Files
MyGoNavi/internal/db/mysql_impl.go
Syngnat 7fd6d78c83 feat(driver): 新增 OceanBase 与 OpenGauss Agent 数据源
- 数据源支持:新增 OceanBase 与 OpenGauss optional driver-agent 实现
- 连接适配:复用 MySQL/PostgreSQL 兼容链路并补齐查询、DDL、同步能力
- 前端入口:补充连接表单、侧边栏、图标、SQL 方言和危险操作识别
- 驱动管理:更新 driver manifest、安装提示和 revision 自动生成链路
- 构建发布:支持多平台 driver-agent 打包并优化 release 构建失败提示
2026-04-30 13:13:01 +08:00

1200 lines
30 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math"
"net/url"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/go-sql-driver/mysql"
)
// MySQLDB holds a database/sql handle together with the ping timeout that
// goes with it.
type MySQLDB struct {
// conn is the database handle; methods report an error while it is nil.
conn *sql.DB
// pingTimeout is recorded by Connect and reused by Ping.
pingTimeout time.Duration
}
// defaultMySQLPort is the port used when no positive port is configured.
const defaultMySQLPort = 3306
// parseMySQLCompatibleURI hands raw and the accepted schemes to
// parseConnectionURI and returns that function's result unchanged.
func parseMySQLCompatibleURI(raw string, allowedSchemes ...string) (*url.URL, bool) {
	parsed, ok := parseConnectionURI(raw, allowedSchemes...)
	return parsed, ok
}
// mysqlConnectionParamsFromText converts free-form connection parameter text
// into url.Values by delegating to connectionParamsFromText.
func mysqlConnectionParamsFromText(raw string) url.Values {
	values := connectionParamsFromText(raw)
	return values
}
// parseMySQLBoolParam interprets raw as a boolean flag, ignoring case and
// surrounding whitespace.
//
// The first result is the parsed value; the second reports whether raw was a
// recognised spelling ("1/true/yes/on" or "0/false/no/off").
func parseMySQLBoolParam(raw string) (bool, bool) {
	token := strings.ToLower(strings.TrimSpace(raw))
	if token == "1" || token == "true" || token == "yes" || token == "on" {
		return true, true
	}
	if token == "0" || token == "false" || token == "no" || token == "off" {
		return false, true
	}
	return false, false
}
// normalizeMySQLDurationParam turns a bare non-negative integer into a Go
// duration string by multiplying it with unit (e.g. "30" with time.Second
// becomes "30s").
//
// Empty input yields "". Anything that is not a non-negative integer is
// returned trimmed but otherwise untouched, so values that already carry a
// unit ("5s") pass through. Counts so large that count*unit would overflow
// int64 are clamped to the maximum representable duration instead of wrapping
// around to an arbitrary value.
func normalizeMySQLDurationParam(raw string, unit time.Duration) string {
	text := strings.TrimSpace(raw)
	if text == "" {
		return text
	}
	n, err := strconv.Atoi(text)
	if err != nil || n < 0 {
		return text
	}
	count := time.Duration(n)
	// Guard the multiplication: int64 overflow does not trap in Go.
	if unit > 0 && count > math.MaxInt64/unit {
		return time.Duration(math.MaxInt64).String()
	}
	return (count * unit).String()
}
// normalizeMySQLCharsetParam maps common charset spellings onto the names
// used in the DSN.
//
// Generic Unicode aliases become "utf8mb4", ISO-8859-1 aliases become
// "latin1", a fixed list of known names is returned lower-cased, and any
// other non-empty value is returned trimmed with its original casing.
func normalizeMySQLCharsetParam(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}
	folded := strings.ToLower(trimmed)
	if folded == "utf-8" || folded == "utf_8" || folded == "unicode" {
		return "utf8mb4"
	}
	if folded == "iso-8859-1" || folded == "iso8859-1" || folded == "iso88591" {
		return "latin1"
	}
	known := [...]string{"utf8", "utf8mb4", "latin1", "gbk", "gb2312", "gb18030", "big5", "sjis", "cp932"}
	for _, name := range known {
		if folded == name {
			return folded
		}
	}
	return trimmed
}
// normalizeMySQLServerTimezoneParam converts a serverTimezone-style value
// into a location name.
//
// Recognised aliases (compared upper-cased with spaces removed) map to
// "Local", "UTC" or "Asia/Shanghai". Any other value is accepted only when it
// contains a "/" and time.LoadLocation can resolve it; it is then returned as
// written. The boolean reports whether a usable location was found.
func normalizeMySQLServerTimezoneParam(raw string) (string, bool) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return "", false
	}
	key := strings.ToUpper(strings.ReplaceAll(trimmed, " ", ""))
	if key == "LOCAL" {
		return "Local", true
	}
	utcAliases := [...]string{
		"UTC", "Z", "GMT", "GMT+0", "GMT-0", "GMT+00", "GMT-00", "GMT+00:00", "GMT-00:00",
		"UTC+0", "UTC-0", "UTC+00", "UTC-00", "UTC+00:00", "UTC-00:00",
	}
	for _, alias := range utcAliases {
		if key == alias {
			return "UTC", true
		}
	}
	shanghaiAliases := [...]string{
		"GMT+8", "GMT+08", "GMT+08:00", "UTC+8", "UTC+08", "UTC+08:00",
		"ASIA/SHANGHAI", "PRC", "CTT",
	}
	for _, alias := range shanghaiAliases {
		if key == alias {
			return "Asia/Shanghai", true
		}
	}
	// Only Area/City style names are looked up in the zone database.
	if !strings.Contains(trimmed, "/") {
		return "", false
	}
	if _, err := time.LoadLocation(trimmed); err != nil {
		return "", false
	}
	return trimmed, true
}
func mergeMySQLConnectionParam(params url.Values, key string, value string) {
name := strings.TrimSpace(key)
if name == "" {
return
}
lowerName := strings.ToLower(name)
switch lowerName {
case "topology":
return
case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior":
return
case "charset":
if charset := normalizeMySQLCharsetParam(value); charset != "" {
params.Set("charset", charset)
}
return
case "characterencoding":
if charset := normalizeMySQLCharsetParam(value); charset != "" {
params.Set("charset", charset)
}
return
case "servertimezone":
if loc, ok := normalizeMySQLServerTimezoneParam(value); ok {
params.Set("loc", loc)
}
return
case "usessl":
if enabled, ok := parseMySQLBoolParam(value); ok {
if enabled {
params.Set("tls", "true")
} else {
params.Set("tls", "false")
}
}
return
case "verifyservercertificate":
if verified, ok := parseMySQLBoolParam(value); ok && !verified && params.Get("tls") != "false" {
params.Set("tls", "skip-verify")
}
return
case "trustservercertificate":
if trusted, ok := parseMySQLBoolParam(value); ok && trusted && params.Get("tls") != "false" {
params.Set("tls", "skip-verify")
}
return
case "connecttimeout":
params.Set("timeout", normalizeMySQLDurationParam(value, time.Millisecond))
return
case "sockettimeout":
params.Set("readTimeout", normalizeMySQLDurationParam(value, time.Millisecond))
return
case "timeout", "readtimeout", "writetimeout":
params.Set(name, normalizeMySQLDurationParam(value, time.Second))
return
default:
params.Set(name, value)
}
}
// mergeMySQLConnectionParams merges every entry of values into params via
// mergeMySQLConnectionParam.
//
// Keys are visited in sorted order so the outcome does not depend on map
// iteration order. The certificate-related keys (verifyServerCertificate,
// trustServerCertificate) are applied in a second pass, after all other keys.
func mergeMySQLConnectionParams(params url.Values, values url.Values) {
	isCertificateKey := func(key string) bool {
		switch strings.ToLower(strings.TrimSpace(key)) {
		case "verifyservercertificate", "trustservercertificate":
			return true
		}
		return false
	}

	ordered := make([]string, 0, len(values))
	for key := range values {
		ordered = append(ordered, key)
	}
	sort.Strings(ordered)

	// Pass 1 handles ordinary keys, pass 2 the certificate keys.
	for _, certificatePass := range [...]bool{false, true} {
		for _, key := range ordered {
			if isCertificateKey(key) != certificatePass {
				continue
			}
			for _, item := range values[key] {
				mergeMySQLConnectionParam(params, key, item)
			}
		}
	}
}
// buildMySQLCompatibleDSN assembles a DSN of the form
// user:password@protocol(address)/database?params.
//
// It starts from built-in defaults (utf8mb4, parseTime, local time zone, the
// configured connect timeout and TLS mode, multi-statements), then overlays
// the query parameters of config.URI when that URI uses an accepted scheme,
// and finally the parameters from config.ConnectionParams, which therefore
// take precedence.
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string {
	timeoutSeconds := getConnectTimeoutSeconds(config)
	tls := resolveMySQLTLSMode(config)

	query := url.Values{
		"charset":         {"utf8mb4"},
		"parseTime":       {"True"},
		"loc":             {"Local"},
		"timeout":         {fmt.Sprintf("%ds", timeoutSeconds)},
		"tls":             {tls},
		"multiStatements": {"true"},
	}

	if uri, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok {
		mergeMySQLConnectionParams(query, uri.Query())
	}
	mergeMySQLConnectionParams(query, mysqlConnectionParamsFromText(config.ConnectionParams))

	return fmt.Sprintf(
		"%s:%s@%s(%s)/%s?%s",
		config.User, config.Password, protocol, address, database, query.Encode(),
	)
}
// parseHostPortWithDefault splits a "host", "host:port" or "[host]:port"
// entry into host and port.
//
// The boolean is false only for blank input. A missing, non-numeric or
// non-positive port falls back to defaultPort. Entries with more than one
// colon and no brackets, or a leading colon, are returned whole as the host.
// For a bracketed entry the brackets are stripped from the returned host; an
// entry with an opening bracket but no closing one is returned whole.
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
	entry := strings.TrimSpace(raw)
	if entry == "" {
		return "", 0, false
	}

	positivePort := func(s string) (int, bool) {
		n, err := strconv.Atoi(s)
		return n, err == nil && n > 0
	}

	if strings.HasPrefix(entry, "[") {
		closing := strings.Index(entry, "]")
		if closing < 0 {
			return entry, defaultPort, true
		}
		host := entry[1:closing]
		rest := strings.TrimSpace(entry[closing+1:])
		if strings.HasPrefix(rest, ":") {
			if port, ok := positivePort(strings.TrimSpace(strings.TrimPrefix(rest, ":"))); ok {
				return host, port, true
			}
		}
		return host, defaultPort, true
	}

	// Exactly one colon that is not the first character: host:port form.
	if sep := strings.LastIndex(entry, ":"); sep > 0 && strings.Count(entry, ":") == 1 {
		if host := strings.TrimSpace(entry[:sep]); host != "" {
			if port, ok := positivePort(strings.TrimSpace(entry[sep+1:])); ok {
				return host, port, true
			}
			return host, defaultPort, true
		}
	}
	return entry, defaultPort, true
}
// normalizeMySQLAddress renders host and port as a dialable "host:port"
// string.
//
// A blank host becomes "localhost" and a non-positive port becomes
// defaultMySQLPort. Hosts that contain a colon (IPv6 literals) are wrapped in
// square brackets so the result stays unambiguous and can be split again by
// parseHostPortWithDefault; a host that already carries brackets is not
// bracketed twice.
func normalizeMySQLAddress(host string, port int) string {
	h := strings.TrimSpace(host)
	if strings.HasPrefix(h, "[") && strings.HasSuffix(h, "]") {
		h = strings.TrimSpace(h[1 : len(h)-1])
	}
	if h == "" {
		h = "localhost"
	}
	p := port
	if p <= 0 {
		p = defaultMySQLPort
	}
	if strings.Contains(h, ":") {
		return fmt.Sprintf("[%s]:%d", h, p)
	}
	return fmt.Sprintf("%s:%d", h, p)
}
// mysqlDatabaseQueries lists the statements collectMySQLDatabaseNames tries,
// in order, until one of them yields at least one database name.
var mysqlDatabaseQueries = []string{
"SHOW DATABASES",
"SELECT DATABASE() AS `Database`",
}
// collectMySQLDatabaseNames runs the statements in mysqlDatabaseQueries
// through queryFn and returns the database names from the first statement
// that produces any.
//
// Names are read from the "Database" (or "database") column, trimmed,
// de-duplicated and kept in row order. If no statement yields a name, the
// last query error is returned, or a generic error when every query
// succeeded but came back empty. A nil queryFn is rejected.
func collectMySQLDatabaseNames(queryFn func(string) ([]map[string]interface{}, []string, error)) ([]string, error) {
	if queryFn == nil {
		return nil, fmt.Errorf("查询函数为空")
	}

	var (
		collected = make([]string, 0, 8)
		known     = make(map[string]struct{}, 8)
		failure   error
	)

	// pick returns the first usable, not-yet-collected name found in row.
	pick := func(row map[string]interface{}) (string, bool) {
		for _, column := range [...]string{"Database", "database"} {
			raw, present := row[column]
			if !present || raw == nil {
				continue
			}
			candidate := strings.TrimSpace(fmt.Sprintf("%v", raw))
			if candidate == "" || strings.EqualFold(candidate, "<nil>") {
				continue
			}
			if _, duplicate := known[candidate]; duplicate {
				continue
			}
			return candidate, true
		}
		return "", false
	}

	for _, statement := range mysqlDatabaseQueries {
		rows, _, err := queryFn(statement)
		if err != nil {
			failure = err
			continue
		}
		for _, row := range rows {
			if name, ok := pick(row); ok {
				known[name] = struct{}{}
				collected = append(collected, name)
			}
		}
		if len(collected) > 0 {
			return collected, nil
		}
	}

	if failure != nil {
		return nil, failure
	}
	return nil, fmt.Errorf("未获取到可用数据库")
}
// applyMySQLURI fills in blank fields of config from config.URI when that URI
// parses with the "mysql" scheme, and returns the resulting copy.
//
// Explicitly configured values always win: user, password, database, host
// list, primary host/port and topology are only taken from the URI when the
// corresponding field is empty. A comma-separated host list in the URI is
// normalised entry by entry, using config.Port (or the default port) for
// entries without one.
func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConfig {
	rawURI := strings.TrimSpace(config.URI)
	if rawURI == "" {
		return config
	}
	uri, ok := parseMySQLCompatibleURI(rawURI, "mysql")
	if !ok {
		return config
	}

	if info := uri.User; info != nil {
		if config.User == "" {
			config.User = info.Username()
		}
		if secret, has := info.Password(); has && config.Password == "" {
			config.Password = secret
		}
	}

	if config.Database == "" {
		if name := strings.TrimPrefix(uri.Path, "/"); name != "" {
			config.Database = name
		}
	}

	fallbackPort := config.Port
	if fallbackPort <= 0 {
		fallbackPort = defaultMySQLPort
	}

	uriHosts := make([]string, 0, 4)
	if hostList := strings.TrimSpace(uri.Host); hostList != "" {
		for _, item := range strings.Split(hostList, ",") {
			if h, p, valid := parseHostPortWithDefault(item, fallbackPort); valid {
				uriHosts = append(uriHosts, normalizeMySQLAddress(h, p))
			}
		}
	}

	if len(uriHosts) > 0 {
		if len(config.Hosts) == 0 {
			config.Hosts = uriHosts
		}
		if strings.TrimSpace(config.Host) == "" {
			// The first URI host doubles as the primary host/port.
			if h, p, valid := parseHostPortWithDefault(uriHosts[0], fallbackPort); valid {
				config.Host = h
				config.Port = p
			}
		}
	}

	if config.Topology == "" {
		if topology := strings.TrimSpace(uri.Query().Get("topology")); topology != "" {
			config.Topology = strings.ToLower(topology)
		}
	}
	return config
}
// collectMySQLAddresses returns the de-duplicated, normalised "host:port"
// candidates for config, in their configured order.
//
// config.Hosts is used when non-empty; otherwise the single config.Host /
// config.Port pair is the only candidate. Entries without a port inherit
// config.Port, or the default port when that is not positive.
func collectMySQLAddresses(config connection.ConnectionConfig) []string {
	fallbackPort := config.Port
	if fallbackPort <= 0 {
		fallbackPort = defaultMySQLPort
	}

	pending := config.Hosts
	if len(pending) == 0 {
		pending = []string{normalizeMySQLAddress(config.Host, fallbackPort)}
	}

	unique := make([]string, 0, len(pending))
	visited := make(map[string]struct{}, len(pending))
	for _, item := range pending {
		h, p, valid := parseHostPortWithDefault(item, fallbackPort)
		if !valid {
			continue
		}
		addr := normalizeMySQLAddress(h, p)
		if _, dup := visited[addr]; !dup {
			visited[addr] = struct{}{}
			unique = append(unique, addr)
		}
	}
	return unique
}
// getDSN builds the DSN for config's host, port and database.
//
// The network is "tcp" unless config.UseSSH is set, in which case the network
// name returned by ssh.RegisterSSHNetwork is used instead. An error is
// returned only when that registration fails.
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
	network := "tcp"
	target := normalizeMySQLAddress(config.Host, config.Port)

	if config.UseSSH {
		registered, err := ssh.RegisterSSHNetwork(config.SSH)
		if err != nil {
			return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
		}
		network = registered
	}

	dsn := buildMySQLCompatibleDSN(config, network, target, config.Database)
	return dsn, nil
}
// resolveMySQLCredential picks the user/password pair for the candidate
// address at addressIndex.
//
// The replica credentials are used when a replica user is configured and
// either the address is not the first one or no primary user is set. In all
// other cases the primary user and password are returned as configured.
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
	replicaUser := strings.TrimSpace(config.MySQLReplicaUser)
	if replicaUser != "" {
		primaryMissing := strings.TrimSpace(config.User) == ""
		if addressIndex > 0 || primaryMissing {
			return replicaUser, config.MySQLReplicaPassword
		}
	}
	return config.User, config.Password
}
// Connect opens a connection using the first candidate address that answers
// a ping.
//
// The configuration is first completed from its URI (applyMySQLURI) and
// expanded into candidate addresses (collectMySQLAddresses). Each candidate
// is tried in order with the credentials chosen by resolveMySQLCredential. On
// success the handle and the connect timeout are stored on m and nil is
// returned. If every candidate fails, the returned error lists each
// per-address failure separated by "; " so the individual messages stay
// readable.
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
	runConfig := applyMySQLURI(config)
	addresses := collectMySQLAddresses(runConfig)
	if len(addresses) == 0 {
		return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
	}

	var errorDetails []string
	for index, address := range addresses {
		candidateConfig := runConfig
		host, port, ok := parseHostPortWithDefault(address, defaultMySQLPort)
		if !ok {
			continue
		}
		candidateConfig.Host = host
		candidateConfig.Port = port
		candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)

		dsn, err := m.getDSN(candidateConfig)
		if err != nil {
			errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
			continue
		}
		db, err := sql.Open("mysql", dsn)
		if err != nil {
			errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
			continue
		}

		// sql.Open does not dial; the ping is what validates the candidate.
		timeout := getConnectTimeout(candidateConfig)
		ctx, cancel := utils.ContextWithTimeout(timeout)
		pingErr := db.PingContext(ctx)
		cancel()
		if pingErr != nil {
			_ = db.Close() // best effort: the candidate is being discarded anyway
			errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
			continue
		}

		m.conn = db
		m.pingTimeout = timeout
		return nil
	}

	if len(errorDetails) == 0 {
		return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
	}
	return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, "; "))
}
// Close closes the underlying handle, if one exists. Calling it on an
// instance without a handle is a no-op that returns nil.
func (m *MySQLDB) Close() error {
	if m.conn == nil {
		return nil
	}
	return m.conn.Close()
}
// Ping checks that the connection is alive, bounded by the timeout recorded
// at connect time (5 seconds when none was recorded). It fails immediately
// when no connection is open.
func (m *MySQLDB) Ping() error {
	if m.conn == nil {
		return fmt.Errorf("连接未打开")
	}

	limit := 5 * time.Second
	if m.pingTimeout > 0 {
		limit = m.pingTimeout
	}

	ctx, cancel := utils.ContextWithTimeout(limit)
	defer cancel()
	return m.conn.PingContext(ctx)
}
// QueryMulti runs query and returns every result set it produces, as decoded
// by scanMultiRows. It fails when no connection is open.
func (m *MySQLDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
	if m.conn == nil {
		return nil, fmt.Errorf("连接未打开")
	}
	// (*sql.DB).Query is QueryContext with a background context.
	resultRows, queryErr := m.conn.QueryContext(context.Background(), query)
	if queryErr != nil {
		return nil, queryErr
	}
	defer resultRows.Close()
	return scanMultiRows(resultRows)
}
// QueryMultiContext is the context-aware form of QueryMulti: it runs query
// under ctx and returns every result set as decoded by scanMultiRows.
func (m *MySQLDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
	if m.conn == nil {
		return nil, fmt.Errorf("连接未打开")
	}
	resultRows, queryErr := m.conn.QueryContext(ctx, query)
	if queryErr != nil {
		return nil, queryErr
	}
	defer resultRows.Close()

	sets, scanErr := scanMultiRows(resultRows)
	return sets, scanErr
}
// QueryContext runs query under ctx and returns the rows and column names
// produced by scanRows. It fails when no connection is open.
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
	if m.conn == nil {
		return nil, nil, fmt.Errorf("连接未打开")
	}
	resultRows, queryErr := m.conn.QueryContext(ctx, query)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer resultRows.Close()

	records, columns, scanErr := scanRows(resultRows)
	return records, columns, scanErr
}
// Query runs query without a caller-supplied context and returns the rows
// and column names produced by scanRows. It fails when no connection is open.
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
	if m.conn == nil {
		return nil, nil, fmt.Errorf("连接未打开")
	}
	// (*sql.DB).Query is QueryContext with a background context.
	resultRows, queryErr := m.conn.QueryContext(context.Background(), query)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer resultRows.Close()
	return scanRows(resultRows)
}
// ExecBatchContext executes query (which may contain several statements)
// under ctx and returns the affected-row count reported by the driver.
func (m *MySQLDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
	if m.conn == nil {
		return 0, fmt.Errorf("连接未打开")
	}
	outcome, execErr := m.conn.ExecContext(ctx, query)
	if execErr != nil {
		return 0, execErr
	}
	affected, countErr := outcome.RowsAffected()
	return affected, countErr
}
// ExecContext executes query under ctx and returns the number of affected
// rows. It fails when no connection is open.
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
	if m.conn == nil {
		return 0, fmt.Errorf("连接未打开")
	}
	outcome, execErr := m.conn.ExecContext(ctx, query)
	if execErr != nil {
		return 0, execErr
	}
	return outcome.RowsAffected()
}
// Exec executes query without a caller-supplied context and returns the
// number of affected rows. It fails when no connection is open.
func (m *MySQLDB) Exec(query string) (int64, error) {
	if m.conn == nil {
		return 0, fmt.Errorf("连接未打开")
	}
	// (*sql.DB).Exec is ExecContext with a background context.
	outcome, execErr := m.conn.ExecContext(context.Background(), query)
	if execErr != nil {
		return 0, execErr
	}
	return outcome.RowsAffected()
}
// GetDatabases lists the database names visible on this connection by
// running collectMySQLDatabaseNames with m.Query as the query function.
func (m *MySQLDB) GetDatabases() ([]string, error) {
	names, err := collectMySQLDatabaseNames(m.Query)
	return names, err
}
// GetTables returns the names of the base tables in dbName, sorted by name.
// With an empty dbName the connection's current database is used.
func (m *MySQLDB) GetTables(dbName string) ([]string, error) {
	schemaExpr := "DATABASE()"
	if dbName != "" {
		// Single quotes are doubled so the name stays inside the literal.
		schemaExpr = "'" + strings.ReplaceAll(dbName, "'", "''") + "'"
	}
	query := "SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = " +
		schemaExpr + " AND TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_NAME"

	records, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var tables []string
	for _, record := range records {
		// Each row has a single selected column; take it whatever its key is.
		for _, cell := range record {
			tables = append(tables, fmt.Sprintf("%v", cell))
			break
		}
	}
	return tables, nil
}
// GetCreateStatement returns the DDL reported by SHOW CREATE TABLE for
// tableName, qualified with dbName when that is non-empty.
//
// Backticks inside the names are doubled so they cannot terminate the quoted
// identifier. The statement is read from the "Create Table" column, falling
// back to "Create View", which is the column MySQL uses when the object is a
// view. An error is returned when the query fails or neither column exists.
func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) {
	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}

	target := quoteIdent(tableName)
	if dbName != "" {
		target = quoteIdent(dbName) + "." + target
	}

	data, _, err := m.Query("SHOW CREATE TABLE " + target)
	if err != nil {
		return "", err
	}
	if len(data) > 0 {
		for _, column := range [...]string{"Create Table", "Create View"} {
			if val, ok := data[0][column]; ok {
				return fmt.Sprintf("%v", val), nil
			}
		}
	}
	return "", fmt.Errorf("未找到建表语句")
}
// GetColumns returns the column definitions reported by SHOW FULL COLUMNS
// for tableName, qualified with dbName when that is non-empty.
//
// Backticks inside the names are doubled so they cannot terminate the quoted
// identifier. Default is left nil for columns whose default is NULL.
func (m *MySQLDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}

	target := quoteIdent(tableName)
	if dbName != "" {
		target = quoteIdent(dbName) + "." + target
	}

	data, _, err := m.Query("SHOW FULL COLUMNS FROM " + target)
	if err != nil {
		return nil, err
	}

	var columns []connection.ColumnDefinition
	for _, row := range data {
		col := connection.ColumnDefinition{
			Name:     fmt.Sprintf("%v", row["Field"]),
			Type:     fmt.Sprintf("%v", row["Type"]),
			Nullable: fmt.Sprintf("%v", row["Null"]),
			Key:      fmt.Sprintf("%v", row["Key"]),
			Extra:    fmt.Sprintf("%v", row["Extra"]),
			Comment:  fmt.Sprintf("%v", row["Comment"]),
		}
		if row["Default"] != nil {
			d := fmt.Sprintf("%v", row["Default"])
			col.Default = &d
		}
		columns = append(columns, col)
	}
	return columns, nil
}
// GetIndexes returns one entry per indexed column as reported by SHOW INDEX
// for tableName, qualified with dbName when that is non-empty.
//
// Backticks inside the names are doubled so they cannot terminate the quoted
// identifier. The numeric columns (Non_unique, Seq_in_index, Sub_part) are
// accepted as any common integer or float type, or as numeric text; missing,
// NULL or unparsable values become 0.
func (m *MySQLDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}
	// toInt converts whatever representation the row scanner produced.
	toInt := func(raw interface{}) int {
		switch n := raw.(type) {
		case float64:
			return int(n)
		case float32:
			return int(n)
		case int64:
			return int(n)
		case int32:
			return int(n)
		case int:
			return n
		case uint64:
			return int(n)
		case uint32:
			return int(n)
		case uint:
			return int(n)
		case []byte:
			if parsed, err := strconv.Atoi(strings.TrimSpace(string(n))); err == nil {
				return parsed
			}
		case string:
			if parsed, err := strconv.Atoi(strings.TrimSpace(n)); err == nil {
				return parsed
			}
		}
		return 0
	}

	target := quoteIdent(tableName)
	if dbName != "" {
		target = quoteIdent(dbName) + "." + target
	}

	data, _, err := m.Query("SHOW INDEX FROM " + target)
	if err != nil {
		return nil, err
	}

	var indexes []connection.IndexDefinition
	for _, row := range data {
		indexes = append(indexes, connection.IndexDefinition{
			Name:       fmt.Sprintf("%v", row["Key_name"]),
			ColumnName: fmt.Sprintf("%v", row["Column_name"]),
			NonUnique:  toInt(row["Non_unique"]),
			SeqInIndex: toInt(row["Seq_in_index"]),
			IndexType:  fmt.Sprintf("%v", row["Index_type"]),
			SubPart:    toInt(row["Sub_part"]),
		})
	}
	return indexes, nil
}
// GetForeignKeys returns the foreign-key columns of tableName as listed in
// information_schema.KEY_COLUMN_USAGE.
//
// With an empty dbName the connection's current database is used, matching
// GetTables. Single quotes in the names are doubled so they cannot terminate
// the string literal.
func (m *MySQLDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
	quoteLiteral := func(text string) string {
		return "'" + strings.ReplaceAll(text, "'", "''") + "'"
	}

	schemaExpr := "DATABASE()"
	if dbName != "" {
		schemaExpr = quoteLiteral(dbName)
	}
	query := fmt.Sprintf(`SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
FROM information_schema.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s AND REFERENCED_TABLE_NAME IS NOT NULL`, schemaExpr, quoteLiteral(tableName))

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var fks []connection.ForeignKeyDefinition
	for _, row := range data {
		fk := connection.ForeignKeyDefinition{
			Name:           fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
			ColumnName:     fmt.Sprintf("%v", row["COLUMN_NAME"]),
			RefTableName:   fmt.Sprintf("%v", row["REFERENCED_TABLE_NAME"]),
			RefColumnName:  fmt.Sprintf("%v", row["REFERENCED_COLUMN_NAME"]),
			ConstraintName: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
		}
		fks = append(fks, fk)
	}
	return fks, nil
}
// GetTriggers returns the triggers defined on tableName as reported by
// SHOW TRIGGERS.
//
// With an empty dbName the FROM clause is omitted so the statement applies to
// the connection's current database instead of producing invalid SQL.
// Backticks in dbName and single quotes in tableName are doubled so they
// cannot break out of their quoting.
func (m *MySQLDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
	tableLiteral := "'" + strings.ReplaceAll(tableName, "'", "''") + "'"

	query := "SHOW TRIGGERS WHERE `Table` = " + tableLiteral
	if dbName != "" {
		schemaIdent := "`" + strings.ReplaceAll(dbName, "`", "``") + "`"
		query = "SHOW TRIGGERS FROM " + schemaIdent + " WHERE `Table` = " + tableLiteral
	}

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var triggers []connection.TriggerDefinition
	for _, row := range data {
		trig := connection.TriggerDefinition{
			Name:      fmt.Sprintf("%v", row["Trigger"]),
			Timing:    fmt.Sprintf("%v", row["Timing"]),
			Event:     fmt.Sprintf("%v", row["Event"]),
			Statement: fmt.Sprintf("%v", row["Statement"]),
		}
		triggers = append(triggers, trig)
	}
	return triggers, nil
}
// ApplyChanges applies the deletes, updates and inserts in changes to
// tableName inside a single transaction, in that order.
//
// Values are bound as placeholders after being normalised for the column's
// type (as loaded by loadColumnTypeMap). Table and column names are
// backtick-quoted with embedded backticks doubled. Deletes and updates must
// each affect exactly one row (requireSingleRowAffected); inserts must affect
// at least one. An update without key columns is an error, while a delete
// without key columns or an update without values is skipped. Any failure
// rolls the whole transaction back; driver errors are wrapped with %w so
// callers can inspect them.
func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
	if m.conn == nil {
		return fmt.Errorf("连接未打开")
	}
	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}

	columnTypeMap := m.loadColumnTypeMap(tableName)
	table := quoteIdent(tableName)

	tx, err := m.conn.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback has no effect.
	defer tx.Rollback()

	// 1. Deletes
	for _, pk := range changes.Deletes {
		var wheres []string
		var args []interface{}
		for k, v := range pk {
			wheres = append(wheres, quoteIdent(k)+" = ?")
			args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
		}
		if len(wheres) == 0 {
			continue
		}
		query := fmt.Sprintf("DELETE FROM %s WHERE %s", table, strings.Join(wheres, " AND "))
		res, err := tx.Exec(query, args...)
		if err != nil {
			return fmt.Errorf("删除失败:%w", err)
		}
		if err := requireSingleRowAffected(res, "删除"); err != nil {
			return err
		}
	}

	// 2. Updates
	for _, update := range changes.Updates {
		var sets []string
		var args []interface{}
		for k, v := range update.Values {
			sets = append(sets, quoteIdent(k)+" = ?")
			args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
		}
		if len(sets) == 0 {
			continue
		}
		var wheres []string
		for k, v := range update.Keys {
			wheres = append(wheres, quoteIdent(k)+" = ?")
			args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
		}
		if len(wheres) == 0 {
			return fmt.Errorf("更新操作需要主键条件")
		}
		query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
		res, err := tx.Exec(query, args...)
		if err != nil {
			return fmt.Errorf("更新失败:%w", err)
		}
		if err := requireSingleRowAffected(res, "更新"); err != nil {
			return err
		}
	}

	// 3. Inserts
	for _, row := range changes.Inserts {
		var cols []string
		var placeholders []string
		var args []interface{}
		for k, v := range row {
			normalizedValue, omit := normalizeMySQLValueForInsert(k, v, columnTypeMap)
			if omit {
				// Column left out so the database default applies.
				continue
			}
			cols = append(cols, quoteIdent(k))
			placeholders = append(placeholders, "?")
			args = append(args, normalizedValue)
		}

		// With no columns left, insert a row made entirely of defaults.
		query := fmt.Sprintf("INSERT INTO %s () VALUES ()", table)
		if len(cols) > 0 {
			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
		}
		res, err := tx.Exec(query, args...)
		if err != nil {
			return fmt.Errorf("插入失败:%w", err)
		}
		if affected, err := res.RowsAffected(); err == nil && affected == 0 {
			return fmt.Errorf("插入未生效:未影响任何行")
		}
	}

	return tx.Commit()
}
// normalizeMySQLComplexValue converts map[string]interface{} and
// []interface{} values into their JSON text so they can be bound as query
// arguments. If JSON encoding fails the value's default %v formatting is used
// instead. Every other value is returned unchanged.
func normalizeMySQLComplexValue(value interface{}) interface{} {
	_, isObject := value.(map[string]interface{})
	_, isArray := value.([]interface{})
	if !isObject && !isArray {
		return value
	}
	encoded, err := json.Marshal(value)
	if err != nil {
		return fmt.Sprintf("%v", value)
	}
	return string(encoded)
}
// normalizeMySQLDateTimeValue rewrites string timestamps into the
// "YYYY-MM-DD hh:mm:ss[.ffffff]" form.
//
// Non-string and blank values are returned unchanged. For ISO-8601 style
// input (a 'T' at index 10), values carrying "Z" or a numeric offset are
// parsed as RFC 3339 and re-formatted; if that fails, or no zone is present,
// the first 'T' is simply replaced by a space. Space-separated input with a
// zone is parsed the same way after swapping the first space for 'T'.
// Anything else is returned as given.
func normalizeMySQLDateTimeValue(value interface{}) interface{} {
	text, isString := value.(string)
	if !isString {
		return value
	}
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return value
	}

	// Remove a stray space after a sign, e.g. "+ 08:00" becomes "+08:00".
	compact := strings.ReplaceAll(trimmed, "+ ", "+")
	compact = strings.ReplaceAll(compact, "- ", "-")

	// parseZoned tries the RFC 3339 layouts, most precise first.
	parseZoned := func(candidate string) (time.Time, bool) {
		for _, layout := range [...]string{time.RFC3339Nano, time.RFC3339} {
			if parsed, err := time.Parse(layout, candidate); err == nil {
				return parsed, true
			}
		}
		return time.Time{}, false
	}
	zoned := strings.HasSuffix(compact, "Z") || hasTimezoneOffset(compact)

	if len(compact) >= 19 && compact[10] == 'T' {
		if zoned {
			if parsed, ok := parseZoned(compact); ok {
				return formatMySQLDateTime(parsed)
			}
		}
		return strings.Replace(compact, "T", " ", 1)
	}

	if zoned && strings.Contains(compact, " ") {
		if parsed, ok := parseZoned(strings.Replace(compact, " ", "T", 1)); ok {
			return formatMySQLDateTime(parsed)
		}
	}
	return value
}
// loadColumnTypeMap returns a map from lower-cased column name to column type
// for tableName, as reported by GetColumns for the current database.
//
// The lookup is best effort: a blank table name or a metadata error yields an
// empty map (the error is only logged), so callers can proceed without type
// information.
func (m *MySQLDB) loadColumnTypeMap(tableName string) map[string]string {
	types := map[string]string{}

	trimmedTable := strings.TrimSpace(tableName)
	if trimmedTable == "" {
		return types
	}

	defs, err := m.GetColumns("", trimmedTable)
	if err != nil {
		logger.Warnf("加载列元数据失败(不影响提交):表=%s err=%v", trimmedTable, err)
		return types
	}

	for i := range defs {
		if key := strings.ToLower(strings.TrimSpace(defs[i].Name)); key != "" {
			types[key] = strings.TrimSpace(defs[i].Type)
		}
	}
	return types
}
// normalizeMySQLValueForInsert prepares value for an INSERT into columnName,
// using the column type found in columnTypeMap (keyed by lower-cased name).
//
// BIT columns go through normalizeMySQLBitValue, temporal columns through
// normalizeMySQLDateTimeValue, everything else through
// normalizeMySQLComplexValue. The boolean is true when the column should be
// left out of the INSERT entirely, which happens for a blank string destined
// for a temporal column.
func normalizeMySQLValueForInsert(columnName string, value interface{}, columnTypeMap map[string]string) (interface{}, bool) {
	lookupKey := strings.ToLower(strings.TrimSpace(columnName))
	declared := strings.ToLower(strings.TrimSpace(columnTypeMap[lookupKey]))

	switch {
	case isMySQLBitColumnType(declared):
		return normalizeMySQLBitValue(value), false
	case !isMySQLTemporalColumnType(declared):
		return normalizeMySQLComplexValue(value), false
	}

	if s, isString := value.(string); isString && strings.TrimSpace(s) == "" {
		// Empty temporal fields are omitted on INSERT so the database default
		// (for example CURRENT_TIMESTAMP) can apply.
		return nil, true
	}
	return normalizeMySQLDateTimeValue(value), false
}
// normalizeMySQLValueForWrite prepares value for use as a bound argument in
// an UPDATE or DELETE touching columnName, using the column type found in
// columnTypeMap (keyed by lower-cased name).
//
// BIT columns go through normalizeMySQLBitValue and temporal columns through
// normalizeMySQLDateTimeValue, with a blank string becoming NULL. All other
// values go through normalizeMySQLComplexValue, the same as on the insert
// path, so maps and slices are sent as JSON text instead of being rejected by
// database/sql as unsupported argument types.
func normalizeMySQLValueForWrite(columnName string, value interface{}, columnTypeMap map[string]string) interface{} {
	columnType := strings.ToLower(strings.TrimSpace(columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]))
	if isMySQLBitColumnType(columnType) {
		return normalizeMySQLBitValue(value)
	}
	if !isMySQLTemporalColumnType(columnType) {
		return normalizeMySQLComplexValue(value)
	}
	if text, ok := value.(string); ok && strings.TrimSpace(text) == "" {
		return nil
	}
	return normalizeMySQLDateTimeValue(value)
}
// isMySQLTemporalColumnType reports whether columnType names a date/time
// type: anything containing "datetime" or "timestamp", or whose base name
// (the part before the first "(" or space) is "date", "time" or "year".
// Matching ignores case and surrounding whitespace.
func isMySQLTemporalColumnType(columnType string) bool {
	normalized := strings.ToLower(strings.TrimSpace(columnType))
	switch {
	case normalized == "":
		return false
	case strings.Contains(normalized, "datetime"), strings.Contains(normalized, "timestamp"):
		return true
	}

	head := normalized
	if cut := strings.IndexAny(normalized, "( "); cut >= 0 {
		head = normalized[:cut]
	}
	switch head {
	case "date", "time", "year":
		return true
	}
	return false
}
// isMySQLBitColumnType reports whether the base name of columnType (the part
// before the first "(" or space) is "bit", ignoring case and surrounding
// whitespace.
func isMySQLBitColumnType(columnType string) bool {
	normalized := strings.ToLower(strings.TrimSpace(columnType))
	if normalized == "" {
		return false
	}
	if cut := strings.IndexAny(normalized, "( "); cut >= 0 {
		return normalized[:cut] == "bit"
	}
	return normalized == "bit"
}
// normalizeMySQLBitValue converts value into the byte representation bound
// for BIT columns.
//
// nil and []byte pass through; booleans become a single 0/1 byte; strings are
// parsed by parseMySQLBitString; non-negative integers and non-negative
// whole-number floats are encoded via mysqlBitBytesFromUint64. Values that
// cannot be represented (negative numbers, fractions, unparsable strings,
// other types) are returned unchanged. Floats of 2^64 or more, and +Inf, are
// also returned unchanged, because converting them to uint64 has no defined
// result in Go.
func normalizeMySQLBitValue(value interface{}) interface{} {
	// wholeNumber converts f only when it is an integer within uint64 range.
	wholeNumber := func(f float64) (uint64, bool) {
		if f >= 0 && f < 1<<64 && math.Trunc(f) == f {
			return uint64(f), true
		}
		return 0, false
	}

	var (
		bits  uint64
		valid bool
	)
	switch v := value.(type) {
	case nil:
		return nil
	case []byte:
		return v
	case bool:
		if v {
			return []byte{1}
		}
		return []byte{0}
	case string:
		if bitValue, ok := parseMySQLBitString(v); ok {
			return bitValue
		}
		return value
	case int:
		bits, valid = uint64(v), v >= 0
	case int8:
		bits, valid = uint64(v), v >= 0
	case int16:
		bits, valid = uint64(v), v >= 0
	case int32:
		bits, valid = uint64(v), v >= 0
	case int64:
		bits, valid = uint64(v), v >= 0
	case uint:
		bits, valid = uint64(v), true
	case uint8:
		bits, valid = uint64(v), true
	case uint16:
		bits, valid = uint64(v), true
	case uint32:
		bits, valid = uint64(v), true
	case uint64:
		bits, valid = v, true
	case float32:
		bits, valid = wholeNumber(float64(v))
	case float64:
		bits, valid = wholeNumber(v)
	}

	if !valid {
		return value
	}
	if bitValue, ok := mysqlBitBytesFromUint64(bits); ok {
		return bitValue
	}
	return value
}
// parseMySQLBitString parses the textual forms accepted for BIT values and
// returns their byte encoding.
//
// Accepted forms (after trimming): "true"/"false" in any case, binary
// literals written as b'1010' or 0b1010, and plain unsigned decimal numbers.
// The boolean is false for blank or unparsable input.
func parseMySQLBitString(text string) ([]byte, bool) {
	literal := strings.TrimSpace(text)
	if literal == "" {
		return nil, false
	}

	lowered := strings.ToLower(literal)
	if lowered == "true" {
		return []byte{1}, true
	}
	if lowered == "false" {
		return []byte{0}, true
	}

	// Decide which slice of the literal holds digits and in which base.
	digits, base := literal, 10
	switch {
	case len(literal) > 3 && (literal[0] == 'b' || literal[0] == 'B') && literal[1] == '\'' && literal[len(literal)-1] == '\'':
		digits, base = literal[2:len(literal)-1], 2
	case len(literal) > 2 && (strings.HasPrefix(literal, "0b") || strings.HasPrefix(literal, "0B")):
		digits, base = literal[2:], 2
	}

	parsed, err := strconv.ParseUint(digits, base, 64)
	if err != nil {
		return nil, false
	}
	return mysqlBitBytesFromUint64OrZero(parsed), true
}
// mysqlBitBytesFromUint64 encodes value with mysqlBitBytesFromUint64OrZero.
// The boolean result is always true.
func mysqlBitBytesFromUint64(value uint64) ([]byte, bool) {
	encoded := mysqlBitBytesFromUint64OrZero(value)
	return encoded, true
}
// mysqlBitBytesFromUint64OrZero returns value as big-endian bytes with
// leading zero bytes removed. Zero is encoded as a single 0x00 byte, so the
// result is always between one and eight bytes long.
func mysqlBitBytesFromUint64OrZero(value uint64) []byte {
	// Count how many bytes are needed (at least one).
	width := 1
	for rest := value >> 8; rest > 0; rest >>= 8 {
		width++
	}
	out := make([]byte, width)
	for i := width - 1; i >= 0; i-- {
		out[i] = byte(value)
		value >>= 8
	}
	return out
}
// hasTimezoneOffset reports whether text ends in a numeric UTC offset of the
// form +hh:mm, -hh:mm, +hhmm or -hhmm. The sign must sit at index 10 or
// later, so the hyphens of a leading YYYY-MM-DD date are not mistaken for an
// offset.
func hasTimezoneOffset(text string) bool {
	signAt := strings.LastIndexAny(text, "+-")
	if signAt < 10 || signAt+1 >= len(text) {
		return false
	}
	tail := text[signAt+1:]
	switch len(tail) {
	case 5:
		return tail[2] == ':' && isAllDigits(tail[:2]) && isAllDigits(tail[3:])
	case 4:
		return isAllDigits(tail)
	}
	return false
}
// isAllDigits reports whether text is non-empty and consists solely of the
// ASCII digits 0-9.
func isAllDigits(text string) bool {
	if len(text) == 0 {
		return false
	}
	notDigit := func(r rune) bool { return r < '0' || r > '9' }
	return strings.IndexFunc(text, notDigit) < 0
}
// formatMySQLDateTime renders t as "YYYY-MM-DD hh:mm:ss" using t's own
// location. When t has a non-zero sub-second part, a six-digit microsecond
// fraction (truncated, not rounded) is appended.
func formatMySQLDateTime(t time.Time) string {
	const layout = "2006-01-02 15:04:05"
	if ns := t.Nanosecond(); ns != 0 {
		return fmt.Sprintf("%s.%06d", t.Format(layout), ns/1000)
	}
	return t.Format(layout)
}
// GetAllColumns returns the table name, column name and column type of every
// column in dbName, read from information_schema.COLUMNS.
//
// dbName is required; an empty name is rejected before any query is built.
// Single quotes in the name are doubled so they cannot terminate the string
// literal, matching GetTables.
func (m *MySQLDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
	if dbName == "" {
		return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
	}
	query := fmt.Sprintf(
		"SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'",
		strings.ReplaceAll(dbName, "'", "''"),
	)

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var cols []connection.ColumnDefinitionWithTable
	for _, row := range data {
		col := connection.ColumnDefinitionWithTable{
			TableName: fmt.Sprintf("%v", row["TABLE_NAME"]),
			Name:      fmt.Sprintf("%v", row["COLUMN_NAME"]),
			Type:      fmt.Sprintf("%v", row["COLUMN_TYPE"]),
		}
		cols = append(cols, col)
	}
	return cols, nil
}