package db

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"math"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"

	"GoNavi-Wails/internal/connection"
	"GoNavi-Wails/internal/logger"
	"GoNavi-Wails/internal/ssh"
	"GoNavi-Wails/internal/utils"
	_ "github.com/go-sql-driver/mysql"
)

// MySQLDB holds a database/sql handle opened through the "mysql" driver
// together with the timeout used for subsequent Ping calls.
type MySQLDB struct {
	conn        *sql.DB
	pingTimeout time.Duration
}

// defaultMySQLPort is used whenever a host is given without a usable port.
const defaultMySQLPort = 3306

// mysqlStringLiteralEscaper escapes text for use inside a single-quoted MySQL
// string literal. Backslashes are doubled as well as quotes because, under the
// default sql_mode, a backslash escapes the following character.
var mysqlStringLiteralEscaper = strings.NewReplacer(`\`, `\\`, `'`, `''`)

// mysqlEscapeStringLiteral returns text escaped for embedding between single
// quotes in a SQL statement.
func mysqlEscapeStringLiteral(text string) string {
	return mysqlStringLiteralEscaper.Replace(text)
}

// mysqlQuoteIdentifier wraps name in backticks and doubles any backtick inside
// it so the name cannot terminate the quoted identifier early.
func mysqlQuoteIdentifier(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// mysqlQualifiedTableName returns the quoted table name, prefixed with the
// quoted database name when dbName is not empty.
func mysqlQualifiedTableName(dbName, tableName string) string {
	if dbName == "" {
		return mysqlQuoteIdentifier(tableName)
	}
	return mysqlQuoteIdentifier(dbName) + "." + mysqlQuoteIdentifier(tableName)
}

// mysqlIntFromValue converts a scanned column value to int. Numeric types are
// converted directly, textual values are parsed as decimal integers, and
// anything else (including nil) yields 0.
func mysqlIntFromValue(value interface{}) int {
	switch v := value.(type) {
	case int:
		return v
	case int8:
		return int(v)
	case int16:
		return int(v)
	case int32:
		return int(v)
	case int64:
		return int(v)
	case uint:
		return int(v)
	case uint8:
		return int(v)
	case uint16:
		return int(v)
	case uint32:
		return int(v)
	case uint64:
		return int(v)
	case float32:
		return int(v)
	case float64:
		return int(v)
	case []byte:
		if n, err := strconv.Atoi(strings.TrimSpace(string(v))); err == nil {
			return n
		}
	case string:
		if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
			return n
		}
	}
	return 0
}

// parseMySQLCompatibleURI forwards raw and the accepted schemes to
// parseConnectionURI and returns its result unchanged.
func parseMySQLCompatibleURI(raw string, allowedSchemes ...string) (*url.URL, bool) {
	return parseConnectionURI(raw, allowedSchemes...)
}

// mysqlConnectionParamsFromText forwards raw to connectionParamsFromText and
// returns its result unchanged.
func mysqlConnectionParamsFromText(raw string) url.Values {
	return connectionParamsFromText(raw)
}

// parseMySQLBoolParam interprets common boolean spellings, ignoring case and
// surrounding whitespace. The second result is false when raw is not one of
// the recognised spellings.
func parseMySQLBoolParam(raw string) (bool, bool) {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "1", "true", "yes", "on":
		return true, true
	case "0", "false", "no", "off":
		return false, true
	default:
		return false, false
	}
}

// normalizeMySQLDurationParam turns a bare non-negative integer into a Go
// duration string, interpreting the number in the given unit. Any other text
// (for example "30s", or a number too large to represent) is returned trimmed
// but otherwise untouched.
func normalizeMySQLDurationParam(raw string, unit time.Duration) string {
	text := strings.TrimSpace(raw)
	if text == "" {
		return text
	}
	n, err := strconv.Atoi(text)
	if err != nil || n < 0 {
		return text
	}
	// Refuse to multiply when the product would overflow time.Duration.
	if unit > 0 && int64(n) > math.MaxInt64/int64(unit) {
		return text
	}
	return (time.Duration(n) * unit).String()
}

// normalizeMySQLCharsetParam maps a handful of charset aliases onto the names
// listed below; unrecognised values are returned trimmed with their original
// casing.
func normalizeMySQLCharsetParam(raw string) string {
	text := strings.TrimSpace(raw)
	if text == "" {
		return ""
	}
	lower := strings.ToLower(text)
	switch lower {
	case "utf-8", "utf_8", "unicode":
		return "utf8mb4"
	case "utf8", "utf8mb4", "latin1", "gbk", "gb2312", "gb18030", "big5", "sjis", "cp932":
		return lower
	case "iso-8859-1", "iso8859-1", "iso88591":
		return "latin1"
	default:
		return text
	}
}

// normalizeMySQLServerTimezoneParam converts a serverTimezone-style value into
// a location name. It recognises a fixed list of aliases for Local, UTC and
// Asia/Shanghai, and otherwise accepts any "Area/City" name that
// time.LoadLocation can resolve. The second result reports whether a usable
// location was found.
func normalizeMySQLServerTimezoneParam(raw string) (string, bool) {
	text := strings.TrimSpace(raw)
	if text == "" {
		return "", false
	}
	compact := strings.ToUpper(strings.ReplaceAll(text, " ", ""))
	switch compact {
	case "LOCAL":
		return "Local", true
	case "UTC", "Z", "GMT", "GMT+0", "GMT-0", "GMT+00", "GMT-00", "GMT+00:00", "GMT-00:00",
		"UTC+0", "UTC-0", "UTC+00", "UTC-00", "UTC+00:00", "UTC-00:00":
		return "UTC", true
	case "GMT+8", "GMT+08", "GMT+08:00", "UTC+8", "UTC+08", "UTC+08:00", "ASIA/SHANGHAI", "PRC", "CTT":
		return "Asia/Shanghai", true
	}
	if strings.Contains(text, "/") {
		if _, err := time.LoadLocation(text); err == nil {
			return text, true
		}
	}
	return "", false
}

// mergeMySQLConnectionParam applies one user-supplied key/value pair to params.
// Several JDBC-style names are translated to the DSN parameter they correspond
// to, a few are dropped, and everything else is copied through under its
// original name.
func mergeMySQLConnectionParam(params url.Values, key string, value string) {
	name := strings.TrimSpace(key)
	if name == "" {
		return
	}

	// setDuration stores a normalised duration under the exact parameter name
	// the driver expects. Empty values are skipped so they do not replace a
	// default that is already present with an unparsable empty string.
	setDuration := func(target string, unit time.Duration) {
		if text := normalizeMySQLDurationParam(value, unit); text != "" {
			params.Set(target, text)
		}
	}

	switch strings.ToLower(name) {
	case "topology":
		// Consumed by applyMySQLURI; not a DSN parameter.
	case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior":
		// Dropped rather than forwarded.
	case "charset", "characterencoding":
		if charset := normalizeMySQLCharsetParam(value); charset != "" {
			params.Set("charset", charset)
		}
	case "servertimezone":
		if loc, ok := normalizeMySQLServerTimezoneParam(value); ok {
			params.Set("loc", loc)
		}
	case "usessl":
		if enabled, ok := parseMySQLBoolParam(value); ok {
			if enabled {
				params.Set("tls", "true")
			} else {
				params.Set("tls", "false")
			}
		}
	case "verifyservercertificate":
		// Only relaxes verification when TLS has not been switched off.
		if verified, ok := parseMySQLBoolParam(value); ok && !verified && params.Get("tls") != "false" {
			params.Set("tls", "skip-verify")
		}
	case "trustservercertificate":
		if trusted, ok := parseMySQLBoolParam(value); ok && trusted && params.Get("tls") != "false" {
			params.Set("tls", "skip-verify")
		}
	case "connecttimeout":
		setDuration("timeout", time.Millisecond)
	case "sockettimeout":
		setDuration("readTimeout", time.Millisecond)
	case "timeout":
		setDuration("timeout", time.Second)
	case "readtimeout":
		// Canonical casing matters: the key is matched case-insensitively here
		// but must be emitted as "readTimeout".
		setDuration("readTimeout", time.Second)
	case "writetimeout":
		setDuration("writeTimeout", time.Second)
	default:
		params.Set(name, value)
	}
}

// mergeMySQLConnectionParams merges values into params in sorted key order.
// The certificate-verification keys are applied in a second pass so that they
// observe the final "tls" setting produced by the first pass.
func mergeMySQLConnectionParams(params url.Values, values url.Values) {
	keys := make([]string, 0, len(values))
	for key := range values {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	isCertificateKey := func(key string) bool {
		switch strings.ToLower(strings.TrimSpace(key)) {
		case "verifyservercertificate", "trustservercertificate":
			return true
		}
		return false
	}

	for _, certificatePass := range []bool{false, true} {
		for _, key := range keys {
			if isCertificateKey(key) != certificatePass {
				continue
			}
			for _, value := range values[key] {
				mergeMySQLConnectionParam(params, key, value)
			}
		}
	}
}

// buildMySQLCompatibleDSN assembles a "user:password@protocol(address)/db?params"
// DSN. Defaults are set first, then overridden by query parameters found in
// config.URI, and finally by config.ConnectionParams.
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string {
	timeout := getConnectTimeoutSeconds(config)
	tlsMode := resolveMySQLTLSMode(config)

	params := url.Values{}
	params.Set("charset", "utf8mb4")
	params.Set("parseTime", "True")
	params.Set("loc", "Local")
	params.Set("timeout", fmt.Sprintf("%ds", timeout))
	params.Set("tls", tlsMode)
	params.Set("multiStatements", "true")

	if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok {
		mergeMySQLConnectionParams(params, parsed.Query())
	}
	mergeMySQLConnectionParams(params, mysqlConnectionParamsFromText(config.ConnectionParams))

	return fmt.Sprintf(
		"%s:%s@%s(%s)/%s?%s",
		config.User, config.Password, protocol, address, database, params.Encode(),
	)
}

// parseHostPortWithDefault splits "host", "host:port" or "[host]:port" into its
// parts. defaultPort is used when the port is missing or not a positive
// integer. Text containing more than one colon and no brackets is treated as a
// bare host. The final result is false only for empty input.
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
	text := strings.TrimSpace(raw)
	if text == "" {
		return "", 0, false
	}

	if strings.HasPrefix(text, "[") {
		end := strings.Index(text, "]")
		if end < 0 {
			return text, defaultPort, true
		}
		host := text[1:end]
		portText := strings.TrimSpace(text[end+1:])
		if strings.HasPrefix(portText, ":") {
			if p, err := strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(portText, ":"))); err == nil && p > 0 {
				return host, p, true
			}
		}
		return host, defaultPort, true
	}

	lastColon := strings.LastIndex(text, ":")
	if lastColon > 0 && strings.Count(text, ":") == 1 {
		host := strings.TrimSpace(text[:lastColon])
		portText := strings.TrimSpace(text[lastColon+1:])
		if host != "" {
			if p, err := strconv.Atoi(portText); err == nil && p > 0 {
				return host, p, true
			}
			return host, defaultPort, true
		}
	}
	return text, defaultPort, true
}

// normalizeMySQLAddress renders host and port as "host:port", substituting
// "localhost" for an empty host and defaultMySQLPort for a non-positive port.
// A host that contains a colon (an IPv6 literal) is wrapped in brackets so the
// result can be parsed back by parseHostPortWithDefault.
func normalizeMySQLAddress(host string, port int) string {
	h := strings.TrimSpace(host)
	if h == "" {
		h = "localhost"
	}
	if strings.Contains(h, ":") && !strings.HasPrefix(h, "[") {
		h = "[" + h + "]"
	}
	p := port
	if p <= 0 {
		p = defaultMySQLPort
	}
	return fmt.Sprintf("%s:%d", h, p)
}

// mysqlDatabaseQueries lists the statements tried, in order, when enumerating
// databases.
var mysqlDatabaseQueries = []string{
	"SHOW DATABASES",
	"SELECT DATABASE() AS `Database`",
}

// collectMySQLDatabaseNames runs each statement in mysqlDatabaseQueries through
// queryFn until one yields at least one database name, and returns the
// de-duplicated names in the order they were seen. If no statement yields a
// name, the last query error is returned, or a generic error when every query
// succeeded but returned nothing.
func collectMySQLDatabaseNames(queryFn func(string) ([]map[string]interface{}, []string, error)) ([]string, error) {
	if queryFn == nil {
		return nil, fmt.Errorf("查询函数为空")
	}

	names := make([]string, 0, 8)
	seen := make(map[string]struct{}, 8)
	var lastErr error

	appendNames := func(rows []map[string]interface{}) {
		for _, row := range rows {
			// Take the first usable spelling of the column name per row.
			for _, key := range []string{"Database", "database"} {
				val, ok := row[key]
				if !ok || val == nil {
					continue
				}
				name := strings.TrimSpace(fmt.Sprintf("%v", val))
				if name == "" {
					continue
				}
				if _, exists := seen[name]; exists {
					continue
				}
				seen[name] = struct{}{}
				names = append(names, name)
				break
			}
		}
	}

	for _, sqlText := range mysqlDatabaseQueries {
		rows, _, err := queryFn(sqlText)
		if err != nil {
			lastErr = err
			continue
		}
		appendNames(rows)
		if len(names) > 0 {
			return names, nil
		}
	}

	if lastErr != nil {
		return nil, lastErr
	}
	return nil, fmt.Errorf("未获取到可用数据库")
}

// applyMySQLURI fills empty fields of config (user, password, database, hosts,
// host/port and topology) from config.URI when it parses as a "mysql" URI.
// Values already present in config are never overwritten.
func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConfig {
	uriText := strings.TrimSpace(config.URI)
	if uriText == "" {
		return config
	}
	parsed, ok := parseMySQLCompatibleURI(uriText, "mysql")
	if !ok {
		return config
	}

	if parsed.User != nil {
		if config.User == "" {
			config.User = parsed.User.Username()
		}
		if pass, ok := parsed.User.Password(); ok && config.Password == "" {
			config.Password = pass
		}
	}
	if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
		config.Database = dbName
	}

	defaultPort := config.Port
	if defaultPort <= 0 {
		defaultPort = defaultMySQLPort
	}

	// The host part may carry several comma-separated addresses.
	hostsFromURI := make([]string, 0, 4)
	if hostText := strings.TrimSpace(parsed.Host); hostText != "" {
		for _, entry := range strings.Split(hostText, ",") {
			host, port, ok := parseHostPortWithDefault(entry, defaultPort)
			if !ok {
				continue
			}
			hostsFromURI = append(hostsFromURI, normalizeMySQLAddress(host, port))
		}
	}

	if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
		config.Hosts = hostsFromURI
	}
	if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
		if host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort); ok {
			config.Host = host
			config.Port = port
		}
	}
	if config.Topology == "" {
		if topology := strings.TrimSpace(parsed.Query().Get("topology")); topology != "" {
			config.Topology = strings.ToLower(topology)
		}
	}
	return config
}

// collectMySQLAddresses returns the normalised, de-duplicated "host:port"
// candidates for config: config.Hosts when it is non-empty, otherwise the
// single config.Host/config.Port pair.
func collectMySQLAddresses(config connection.ConnectionConfig) []string {
	defaultPort := config.Port
	if defaultPort <= 0 {
		defaultPort = defaultMySQLPort
	}

	candidates := make([]string, 0, len(config.Hosts)+1)
	if len(config.Hosts) > 0 {
		candidates = append(candidates, config.Hosts...)
	} else {
		candidates = append(candidates, normalizeMySQLAddress(config.Host, defaultPort))
	}

	result := make([]string, 0, len(candidates))
	seen := make(map[string]struct{}, len(candidates))
	for _, entry := range candidates {
		host, port, ok := parseHostPortWithDefault(entry, defaultPort)
		if !ok {
			continue
		}
		normalized := normalizeMySQLAddress(host, port)
		if _, exists := seen[normalized]; exists {
			continue
		}
		seen[normalized] = struct{}{}
		result = append(result, normalized)
	}
	return result
}

// getDSN builds the DSN for config. The protocol is "tcp" unless config.UseSSH
// is set, in which case the network name returned by ssh.RegisterSSHNetwork is
// used instead.
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
	protocol := "tcp"
	address := normalizeMySQLAddress(config.Host, config.Port)

	if config.UseSSH {
		netName, err := ssh.RegisterSSHNetwork(config.SSH)
		if err != nil {
			return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
		}
		protocol = netName
	}
	return buildMySQLCompatibleDSN(config, protocol, address, config.Database), nil
}

// resolveMySQLCredential picks the user/password pair for the address at
// addressIndex. The replica credential is used for every address after the
// first when a replica user is configured, and also for the first address when
// no primary user is set; otherwise the primary credential is returned.
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
	primaryUser := strings.TrimSpace(config.User)
	replicaUser := strings.TrimSpace(config.MySQLReplicaUser)

	if replicaUser != "" && (addressIndex > 0 || primaryUser == "") {
		return replicaUser, config.MySQLReplicaPassword
	}
	return config.User, config.Password
}

// Connect tries each candidate address in order and keeps the first pool whose
// ping succeeds. When every address fails, the returned error lists the
// per-address failures.
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
	runConfig := applyMySQLURI(config)
	addresses := collectMySQLAddresses(runConfig)
	if len(addresses) == 0 {
		return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
	}

	var errorDetails []string
	for index, address := range addresses {
		host, port, ok := parseHostPortWithDefault(address, defaultMySQLPort)
		if !ok {
			continue
		}

		candidateConfig := runConfig
		candidateConfig.Host = host
		candidateConfig.Port = port
		candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)

		dsn, err := m.getDSN(candidateConfig)
		if err != nil {
			errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
			continue
		}
		db, err := sql.Open("mysql", dsn)
		if err != nil {
			errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
			continue
		}

		timeout := getConnectTimeout(candidateConfig)
		ctx, cancel := utils.ContextWithTimeout(timeout)
		pingErr := db.PingContext(ctx)
		cancel()
		if pingErr != nil {
			// Best-effort cleanup of a pool that never became usable.
			_ = db.Close()
			errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
			continue
		}

		m.conn = db
		m.pingTimeout = timeout
		return nil
	}

	if len(errorDetails) == 0 {
		return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
	}
	return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, ";"))
}

// Close closes the underlying pool, if one was opened.
func (m *MySQLDB) Close() error {
	if m.conn != nil {
		return m.conn.Close()
	}
	return nil
}

// Ping verifies the connection using the timeout recorded by Connect, or five
// seconds when none was recorded.
func (m *MySQLDB) Ping() error {
	if m.conn == nil {
		return fmt.Errorf("连接未打开")
	}
	timeout := m.pingTimeout
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	ctx, cancel := utils.ContextWithTimeout(timeout)
	defer cancel()
	return m.conn.PingContext(ctx)
}

// QueryMulti runs query and passes the rows to scanMultiRows.
func (m *MySQLDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
	if m.conn == nil {
		return nil, fmt.Errorf("连接未打开")
	}
	rows, err := m.conn.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanMultiRows(rows)
}

// QueryMultiContext is QueryMulti bound to ctx.
func (m *MySQLDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
	if m.conn == nil {
		return nil, fmt.Errorf("连接未打开")
	}
	rows, err := m.conn.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanMultiRows(rows)
}

// QueryContext runs query under ctx and passes the rows to scanRows.
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
	if m.conn == nil {
		return nil, nil, fmt.Errorf("连接未打开")
	}
	rows, err := m.conn.QueryContext(ctx, query)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()
	return scanRows(rows)
}

// Query runs query and passes the rows to scanRows.
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
	if m.conn == nil {
		return nil, nil, fmt.Errorf("连接未打开")
	}
	rows, err := m.conn.Query(query)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()
	return scanRows(rows)
}

// ExecBatchContext executes query under ctx and returns the affected row count
// reported by the driver.
func (m *MySQLDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
	if m.conn == nil {
		return 0, fmt.Errorf("连接未打开")
	}
	res, err := m.conn.ExecContext(ctx, query)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}

// ExecContext executes query under ctx and returns the affected row count.
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
	if m.conn == nil {
		return 0, fmt.Errorf("连接未打开")
	}
	res, err := m.conn.ExecContext(ctx, query)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}

// Exec executes query and returns the affected row count.
func (m *MySQLDB) Exec(query string) (int64, error) {
	if m.conn == nil {
		return 0, fmt.Errorf("连接未打开")
	}
	res, err := m.conn.Exec(query)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}

// GetDatabases lists database names via collectMySQLDatabaseNames.
func (m *MySQLDB) GetDatabases() ([]string, error) {
	return collectMySQLDatabaseNames(m.Query)
}

// GetTables returns the base tables of dbName, or of the current database when
// dbName is empty, ordered by name.
func (m *MySQLDB) GetTables(dbName string) ([]string, error) {
	schemaExpr := "DATABASE()"
	if dbName != "" {
		schemaExpr = "'" + mysqlEscapeStringLiteral(dbName) + "'"
	}
	query := fmt.Sprintf(
		"SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = %s AND TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_NAME",
		schemaExpr,
	)

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var tables []string
	for _, row := range data {
		// The statement selects a single column, so any value in the row is it.
		for _, v := range row {
			tables = append(tables, fmt.Sprintf("%v", v))
			break
		}
	}
	return tables, nil
}

// GetCreateStatement returns the DDL reported by SHOW CREATE TABLE. The
// "Create View" column is accepted as well, which is what the statement
// returns when the name refers to a view.
func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) {
	query := "SHOW CREATE TABLE " + mysqlQualifiedTableName(dbName, tableName)

	data, _, err := m.Query(query)
	if err != nil {
		return "", err
	}
	if len(data) > 0 {
		for _, key := range []string{"Create Table", "Create View"} {
			if val, ok := data[0][key]; ok {
				return fmt.Sprintf("%v", val), nil
			}
		}
	}
	return "", fmt.Errorf("未找到建表语句")
}

// GetColumns returns the column definitions reported by SHOW FULL COLUMNS.
// Default is left nil when the column has no default value.
func (m *MySQLDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
	query := "SHOW FULL COLUMNS FROM " + mysqlQualifiedTableName(dbName, tableName)

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	columns := make([]connection.ColumnDefinition, 0, len(data))
	for _, row := range data {
		col := connection.ColumnDefinition{
			Name:     fmt.Sprintf("%v", row["Field"]),
			Type:     fmt.Sprintf("%v", row["Type"]),
			Nullable: fmt.Sprintf("%v", row["Null"]),
			Key:      fmt.Sprintf("%v", row["Key"]),
			Extra:    fmt.Sprintf("%v", row["Extra"]),
			Comment:  fmt.Sprintf("%v", row["Comment"]),
		}
		if row["Default"] != nil {
			d := fmt.Sprintf("%v", row["Default"])
			col.Default = &d
		}
		columns = append(columns, col)
	}
	return columns, nil
}

// GetIndexes returns one entry per indexed column as reported by SHOW INDEX.
func (m *MySQLDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
	query := "SHOW INDEX FROM " + mysqlQualifiedTableName(dbName, tableName)

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	indexes := make([]connection.IndexDefinition, 0, len(data))
	for _, row := range data {
		indexes = append(indexes, connection.IndexDefinition{
			Name:       fmt.Sprintf("%v", row["Key_name"]),
			ColumnName: fmt.Sprintf("%v", row["Column_name"]),
			NonUnique:  mysqlIntFromValue(row["Non_unique"]),
			SeqInIndex: mysqlIntFromValue(row["Seq_in_index"]),
			IndexType:  fmt.Sprintf("%v", row["Index_type"]),
			SubPart:    mysqlIntFromValue(row["Sub_part"]),
		})
	}
	return indexes, nil
}

// GetForeignKeys returns the foreign-key columns of tableName from
// information_schema.KEY_COLUMN_USAGE. An empty dbName selects the current
// database.
func (m *MySQLDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
	schemaExpr := "DATABASE()"
	if dbName != "" {
		schemaExpr = "'" + mysqlEscapeStringLiteral(dbName) + "'"
	}
	query := fmt.Sprintf(
		"SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME "+
			"FROM information_schema.KEY_COLUMN_USAGE "+
			"WHERE TABLE_SCHEMA = %s AND TABLE_NAME = '%s' AND REFERENCED_TABLE_NAME IS NOT NULL",
		schemaExpr, mysqlEscapeStringLiteral(tableName),
	)

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	fks := make([]connection.ForeignKeyDefinition, 0, len(data))
	for _, row := range data {
		fks = append(fks, connection.ForeignKeyDefinition{
			Name:           fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
			ColumnName:     fmt.Sprintf("%v", row["COLUMN_NAME"]),
			RefTableName:   fmt.Sprintf("%v", row["REFERENCED_TABLE_NAME"]),
			RefColumnName:  fmt.Sprintf("%v", row["REFERENCED_COLUMN_NAME"]),
			ConstraintName: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
		})
	}
	return fks, nil
}

// GetTriggers returns the triggers defined on tableName. The FROM clause is
// omitted when dbName is empty so the statement targets the current database.
func (m *MySQLDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
	query := "SHOW TRIGGERS"
	if dbName != "" {
		query += " FROM " + mysqlQuoteIdentifier(dbName)
	}
	query += fmt.Sprintf(" WHERE `Table` = '%s'", mysqlEscapeStringLiteral(tableName))

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	triggers := make([]connection.TriggerDefinition, 0, len(data))
	for _, row := range data {
		triggers = append(triggers, connection.TriggerDefinition{
			Name:      fmt.Sprintf("%v", row["Trigger"]),
			Timing:    fmt.Sprintf("%v", row["Timing"]),
			Event:     fmt.Sprintf("%v", row["Event"]),
			Statement: fmt.Sprintf("%v", row["Statement"]),
		})
	}
	return triggers, nil
}

// ApplyChanges applies the deletes, updates and inserts of changes to tableName
// inside one transaction, in that order. Any failure rolls the whole
// transaction back. Deletes and updates must affect exactly one row, as
// enforced by requireSingleRowAffected.
func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
	if m.conn == nil {
		return fmt.Errorf("连接未打开")
	}

	columnTypeMap := m.loadColumnTypeMap(tableName)
	quotedTable := mysqlQuoteIdentifier(tableName)

	tx, err := m.conn.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this is a no-op that reports sql.ErrTxDone.
	defer func() { _ = tx.Rollback() }()

	// 1. Deletes
	for _, pk := range changes.Deletes {
		var wheres []string
		var args []interface{}
		for k, v := range pk {
			wheres = append(wheres, mysqlQuoteIdentifier(k)+" = ?")
			args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
		}
		if len(wheres) == 0 {
			// Never issue an unconditional DELETE.
			continue
		}
		query := fmt.Sprintf("DELETE FROM %s WHERE %s", quotedTable, strings.Join(wheres, " AND "))
		res, err := tx.Exec(query, args...)
		if err != nil {
			return fmt.Errorf("删除失败:%w", err)
		}
		if err := requireSingleRowAffected(res, "删除"); err != nil {
			return err
		}
	}

	// 2. Updates
	for _, update := range changes.Updates {
		var sets []string
		var args []interface{}
		for k, v := range update.Values {
			sets = append(sets, mysqlQuoteIdentifier(k)+" = ?")
			args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
		}
		if len(sets) == 0 {
			continue
		}

		// WHERE arguments follow the SET arguments, matching placeholder order.
		var wheres []string
		for k, v := range update.Keys {
			wheres = append(wheres, mysqlQuoteIdentifier(k)+" = ?")
			args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
		}
		if len(wheres) == 0 {
			return fmt.Errorf("更新操作需要主键条件")
		}

		query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", quotedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
		res, err := tx.Exec(query, args...)
		if err != nil {
			return fmt.Errorf("更新失败:%w", err)
		}
		if err := requireSingleRowAffected(res, "更新"); err != nil {
			return err
		}
	}

	// 3. Inserts
	for _, row := range changes.Inserts {
		var cols []string
		var placeholders []string
		var args []interface{}
		for k, v := range row {
			normalizedValue, omit := normalizeMySQLValueForInsert(k, v, columnTypeMap)
			if omit {
				continue
			}
			cols = append(cols, mysqlQuoteIdentifier(k))
			placeholders = append(placeholders, "?")
			args = append(args, normalizedValue)
		}

		// With no columns left, "() VALUES ()" inserts a row of defaults.
		query := fmt.Sprintf("INSERT INTO %s () VALUES ()", quotedTable)
		if len(cols) > 0 {
			query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quotedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
		}
		res, err := tx.Exec(query, args...)
		if err != nil {
			return fmt.Errorf("插入失败:%w", err)
		}
		if affected, err := res.RowsAffected(); err == nil && affected == 0 {
			return fmt.Errorf("插入未生效:未影响任何行")
		}
	}

	return tx.Commit()
}

// normalizeMySQLComplexValue JSON-encodes map[string]interface{} and
// []interface{} values so they can be bound as a string parameter. If encoding
// fails the value is rendered with %v. Other values are returned as-is.
func normalizeMySQLComplexValue(value interface{}) interface{} {
	switch v := value.(type) {
	case map[string]interface{}, []interface{}:
		if data, err := json.Marshal(v); err == nil {
			return string(data)
		}
		return fmt.Sprintf("%v", value)
	default:
		return value
	}
}

// normalizeMySQLDateTimeValue rewrites ISO-8601 style strings into
// "YYYY-MM-DD HH:MM:SS[.ffffff]". Strings with a "Z" suffix or numeric offset
// are parsed as RFC 3339 and formatted with the wall-clock fields of the parsed
// time; a "T" separator without zone information is simply replaced by a
// space. Non-string values and unrecognised strings are returned unchanged.
func normalizeMySQLDateTimeValue(value interface{}) interface{} {
	text, ok := value.(string)
	if !ok {
		return value
	}
	raw := strings.TrimSpace(text)
	if raw == "" {
		return value
	}

	// Tolerate a stray space after the sign of an offset ("+ 08:00").
	cleaned := strings.ReplaceAll(raw, "+ ", "+")
	cleaned = strings.ReplaceAll(cleaned, "- ", "-")
	hasZone := strings.HasSuffix(cleaned, "Z") || hasTimezoneOffset(cleaned)

	parseRFC3339 := func(candidate string) (string, bool) {
		for _, layout := range []string{time.RFC3339Nano, time.RFC3339} {
			if t, err := time.Parse(layout, candidate); err == nil {
				return formatMySQLDateTime(t), true
			}
		}
		return "", false
	}

	if len(cleaned) >= 19 && cleaned[10] == 'T' {
		if hasZone {
			if formatted, ok := parseRFC3339(cleaned); ok {
				return formatted
			}
		}
		return strings.Replace(cleaned, "T", " ", 1)
	}

	if hasZone && strings.Contains(cleaned, " ") {
		if formatted, ok := parseRFC3339(strings.Replace(cleaned, " ", "T", 1)); ok {
			return formatted
		}
	}
	return value
}

// loadColumnTypeMap returns a lower-cased column name to column type map for
// tableName. Failures are logged and yield an empty map so that the caller can
// still proceed without type-aware normalisation.
func (m *MySQLDB) loadColumnTypeMap(tableName string) map[string]string {
	result := map[string]string{}
	table := strings.TrimSpace(tableName)
	if table == "" {
		return result
	}

	columns, err := m.GetColumns("", table)
	if err != nil {
		logger.Warnf("加载列元数据失败(不影响提交):表=%s err=%v", table, err)
		return result
	}
	for _, col := range columns {
		name := strings.ToLower(strings.TrimSpace(col.Name))
		if name == "" {
			continue
		}
		result[name] = strings.TrimSpace(col.Type)
	}
	return result
}

// normalizeMySQLValueForInsert prepares value for an INSERT into columnName.
// The second result is true when the column should be left out of the
// statement altogether.
func normalizeMySQLValueForInsert(columnName string, value interface{}, columnTypeMap map[string]string) (interface{}, bool) {
	columnType := strings.ToLower(strings.TrimSpace(columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]))
	if isMySQLBitColumnType(columnType) {
		return normalizeMySQLBitValue(value), false
	}
	if !isMySQLTemporalColumnType(columnType) {
		return normalizeMySQLComplexValue(value), false
	}
	if text, ok := value.(string); ok && strings.TrimSpace(text) == "" {
		// An empty temporal field is omitted from the INSERT so the database
		// default (for example CURRENT_TIMESTAMP) applies.
		return nil, true
	}
	return normalizeMySQLDateTimeValue(value), false
}

// normalizeMySQLValueForWrite prepares value for use as an UPDATE/DELETE
// parameter bound to columnName. Empty strings for temporal columns become
// NULL; maps and slices are JSON-encoded exactly as on the INSERT path.
func normalizeMySQLValueForWrite(columnName string, value interface{}, columnTypeMap map[string]string) interface{} {
	columnType := strings.ToLower(strings.TrimSpace(columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]))
	if isMySQLBitColumnType(columnType) {
		return normalizeMySQLBitValue(value)
	}
	if !isMySQLTemporalColumnType(columnType) {
		return normalizeMySQLComplexValue(value)
	}
	if text, ok := value.(string); ok && strings.TrimSpace(text) == "" {
		return nil
	}
	return normalizeMySQLDateTimeValue(value)
}

// mysqlBaseColumnType returns the lower-cased type name up to the first "(" or
// space, e.g. "bit" for "BIT(8)".
func mysqlBaseColumnType(columnType string) string {
	base := strings.ToLower(strings.TrimSpace(columnType))
	if idx := strings.IndexAny(base, "( "); idx >= 0 {
		base = base[:idx]
	}
	return base
}

// isMySQLTemporalColumnType reports whether columnType is a date/time type:
// anything containing "datetime" or "timestamp", or whose base type is date,
// time or year.
func isMySQLTemporalColumnType(columnType string) bool {
	raw := strings.ToLower(strings.TrimSpace(columnType))
	if raw == "" {
		return false
	}
	if strings.Contains(raw, "datetime") || strings.Contains(raw, "timestamp") {
		return true
	}
	switch mysqlBaseColumnType(raw) {
	case "date", "time", "year":
		return true
	}
	return false
}

// isMySQLBitColumnType reports whether the base type of columnType is "bit".
func isMySQLBitColumnType(columnType string) bool {
	return mysqlBaseColumnType(columnType) == "bit"
}

// mysqlBitFloatLimit is 2^64 as a float64; floats at or above it cannot be
// converted to uint64.
const mysqlBitFloatLimit = float64(1 << 64)

// normalizeMySQLBitValue converts booleans, non-negative integers, integral
// non-negative floats and bit-like strings to a big-endian byte slice. Values
// that cannot be converted (negative numbers, fractions, unparsable strings,
// unsupported types) are returned unchanged.
func normalizeMySQLBitValue(value interface{}) interface{} {
	fromSigned := func(n int64) interface{} {
		if n < 0 {
			return value
		}
		return mysqlBitBytesFromUint64OrZero(uint64(n))
	}
	fromFloat := func(f float64) interface{} {
		// The comparisons also reject NaN and +Inf.
		if f >= 0 && f < mysqlBitFloatLimit && math.Trunc(f) == f {
			return mysqlBitBytesFromUint64OrZero(uint64(f))
		}
		return value
	}

	switch v := value.(type) {
	case nil:
		return nil
	case []byte:
		return v
	case bool:
		if v {
			return []byte{1}
		}
		return []byte{0}
	case string:
		if bitValue, ok := parseMySQLBitString(v); ok {
			return bitValue
		}
		return value
	case int:
		return fromSigned(int64(v))
	case int8:
		return fromSigned(int64(v))
	case int16:
		return fromSigned(int64(v))
	case int32:
		return fromSigned(int64(v))
	case int64:
		return fromSigned(v)
	case uint:
		return mysqlBitBytesFromUint64OrZero(uint64(v))
	case uint8:
		return mysqlBitBytesFromUint64OrZero(uint64(v))
	case uint16:
		return mysqlBitBytesFromUint64OrZero(uint64(v))
	case uint32:
		return mysqlBitBytesFromUint64OrZero(uint64(v))
	case uint64:
		return mysqlBitBytesFromUint64OrZero(v)
	case float32:
		return fromFloat(float64(v))
	case float64:
		return fromFloat(v)
	}
	return value
}

// parseMySQLBitString parses "true"/"false", b'0101' literals, 0b0101 literals
// and plain decimal numbers into a big-endian byte slice. The second result is
// false when text matches none of these forms.
func parseMySQLBitString(text string) ([]byte, bool) {
	raw := strings.TrimSpace(text)
	if raw == "" {
		return nil, false
	}
	switch strings.ToLower(raw) {
	case "true":
		return []byte{1}, true
	case "false":
		return []byte{0}, true
	}

	digits, base := raw, 10
	switch {
	case len(raw) > 3 && (raw[0] == 'b' || raw[0] == 'B') && raw[1] == '\'' && raw[len(raw)-1] == '\'':
		digits, base = raw[2:len(raw)-1], 2
	case len(raw) > 2 && (strings.HasPrefix(raw, "0b") || strings.HasPrefix(raw, "0B")):
		digits, base = raw[2:], 2
	}

	value, err := strconv.ParseUint(digits, base, 64)
	if err != nil {
		return nil, false
	}
	return mysqlBitBytesFromUint64OrZero(value), true
}

// mysqlBitBytesFromUint64 is mysqlBitBytesFromUint64OrZero with a success flag
// that is always true.
func mysqlBitBytesFromUint64(value uint64) ([]byte, bool) {
	return mysqlBitBytesFromUint64OrZero(value), true
}

// mysqlBitBytesFromUint64OrZero encodes value as big-endian bytes without
// leading zero bytes; zero is encoded as a single 0x00 byte.
func mysqlBitBytesFromUint64OrZero(value uint64) []byte {
	if value == 0 {
		return []byte{0}
	}
	var buf [8]byte
	index := len(buf)
	for value > 0 {
		index--
		buf[index] = byte(value)
		value >>= 8
	}
	return append([]byte(nil), buf[index:]...)
}

// hasTimezoneOffset reports whether text ends in a numeric offset of the form
// ±HH:MM or ±HHMM. The sign must appear at index 10 or later so that the
// hyphens of the date part are not mistaken for one.
func hasTimezoneOffset(text string) bool {
	pos := strings.LastIndexAny(text, "+-")
	if pos < 10 || pos+1 >= len(text) {
		return false
	}
	offset := text[pos+1:]
	switch {
	case len(offset) == 5 && offset[2] == ':':
		return isAllDigits(offset[:2]) && isAllDigits(offset[3:])
	case len(offset) == 4:
		return isAllDigits(offset)
	}
	return false
}

// isAllDigits reports whether text is non-empty and made only of ASCII digits.
func isAllDigits(text string) bool {
	if text == "" {
		return false
	}
	for _, r := range text {
		if r < '0' || r > '9' {
			return false
		}
	}
	return true
}

// formatMySQLDateTime renders t as "YYYY-MM-DD HH:MM:SS", appending
// microseconds when t has a non-zero sub-second part.
func formatMySQLDateTime(t time.Time) string {
	base := t.Format("2006-01-02 15:04:05")
	nanos := t.Nanosecond()
	if nanos == 0 {
		return base
	}
	return fmt.Sprintf("%s.%06d", base, nanos/1000)
}

// GetAllColumns returns every column of every table in dbName. dbName is
// required.
func (m *MySQLDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
	if dbName == "" {
		return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
	}
	query := fmt.Sprintf(
		"SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'",
		mysqlEscapeStringLiteral(dbName),
	)

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	cols := make([]connection.ColumnDefinitionWithTable, 0, len(data))
	for _, row := range data {
		cols = append(cols, connection.ColumnDefinitionWithTable{
			TableName: fmt.Sprintf("%v", row["TABLE_NAME"]),
			Name:      fmt.Sprintf("%v", row["COLUMN_NAME"]),
			Type:      fmt.Sprintf("%v", row["COLUMN_TYPE"]),
		})
	}
	return cols, nil
}