mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-06 20:03:05 +08:00
- 数据源支持:新增 OceanBase 与 OpenGauss optional driver-agent 实现 - 连接适配:复用 MySQL/PostgreSQL 兼容链路并补齐查询、DDL、同步能力 - 前端入口:补充连接表单、侧边栏、图标、SQL 方言和危险操作识别 - 驱动管理:更新 driver manifest、安装提示和 revision 自动生成链路 - 构建发布:支持多平台 driver-agent 打包并优化 release 构建失败提示
1200 lines
30 KiB
Go
1200 lines
30 KiB
Go
package db
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"encoding/json"
|
||
"fmt"
|
||
"math"
|
||
"net/url"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"GoNavi-Wails/internal/connection"
|
||
"GoNavi-Wails/internal/logger"
|
||
"GoNavi-Wails/internal/ssh"
|
||
"GoNavi-Wails/internal/utils"
|
||
|
||
_ "github.com/go-sql-driver/mysql"
|
||
)
|
||
|
||
type MySQLDB struct {
|
||
conn *sql.DB
|
||
pingTimeout time.Duration
|
||
}
|
||
|
||
const defaultMySQLPort = 3306
|
||
|
||
func parseMySQLCompatibleURI(raw string, allowedSchemes ...string) (*url.URL, bool) {
|
||
return parseConnectionURI(raw, allowedSchemes...)
|
||
}
|
||
|
||
func mysqlConnectionParamsFromText(raw string) url.Values {
|
||
return connectionParamsFromText(raw)
|
||
}
|
||
|
||
func parseMySQLBoolParam(raw string) (bool, bool) {
|
||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||
case "1", "true", "yes", "on":
|
||
return true, true
|
||
case "0", "false", "no", "off":
|
||
return false, true
|
||
default:
|
||
return false, false
|
||
}
|
||
}
|
||
|
||
func normalizeMySQLDurationParam(raw string, unit time.Duration) string {
|
||
text := strings.TrimSpace(raw)
|
||
if text == "" {
|
||
return text
|
||
}
|
||
if n, err := strconv.Atoi(text); err == nil && n >= 0 {
|
||
return (time.Duration(n) * unit).String()
|
||
}
|
||
return text
|
||
}
|
||
|
||
func normalizeMySQLCharsetParam(raw string) string {
|
||
text := strings.TrimSpace(raw)
|
||
if text == "" {
|
||
return ""
|
||
}
|
||
lower := strings.ToLower(text)
|
||
switch lower {
|
||
case "utf-8", "utf_8", "unicode":
|
||
return "utf8mb4"
|
||
case "utf8", "utf8mb4", "latin1", "gbk", "gb2312", "gb18030", "big5", "sjis", "cp932":
|
||
return lower
|
||
case "iso-8859-1", "iso8859-1", "iso88591":
|
||
return "latin1"
|
||
default:
|
||
return text
|
||
}
|
||
}
|
||
|
||
func normalizeMySQLServerTimezoneParam(raw string) (string, bool) {
|
||
text := strings.TrimSpace(raw)
|
||
if text == "" {
|
||
return "", false
|
||
}
|
||
compact := strings.ToUpper(strings.ReplaceAll(text, " ", ""))
|
||
switch compact {
|
||
case "LOCAL":
|
||
return "Local", true
|
||
case "UTC", "Z", "GMT", "GMT+0", "GMT-0", "GMT+00", "GMT-00", "GMT+00:00", "GMT-00:00",
|
||
"UTC+0", "UTC-0", "UTC+00", "UTC-00", "UTC+00:00", "UTC-00:00":
|
||
return "UTC", true
|
||
case "GMT+8", "GMT+08", "GMT+08:00", "UTC+8", "UTC+08", "UTC+08:00",
|
||
"ASIA/SHANGHAI", "PRC", "CTT":
|
||
return "Asia/Shanghai", true
|
||
}
|
||
if strings.Contains(text, "/") {
|
||
if _, err := time.LoadLocation(text); err == nil {
|
||
return text, true
|
||
}
|
||
}
|
||
return "", false
|
||
}
|
||
|
||
func mergeMySQLConnectionParam(params url.Values, key string, value string) {
|
||
name := strings.TrimSpace(key)
|
||
if name == "" {
|
||
return
|
||
}
|
||
lowerName := strings.ToLower(name)
|
||
switch lowerName {
|
||
case "topology":
|
||
return
|
||
case "useunicode", "autoreconnect", "useoldaliasmetadatabehavior":
|
||
return
|
||
case "charset":
|
||
if charset := normalizeMySQLCharsetParam(value); charset != "" {
|
||
params.Set("charset", charset)
|
||
}
|
||
return
|
||
case "characterencoding":
|
||
if charset := normalizeMySQLCharsetParam(value); charset != "" {
|
||
params.Set("charset", charset)
|
||
}
|
||
return
|
||
case "servertimezone":
|
||
if loc, ok := normalizeMySQLServerTimezoneParam(value); ok {
|
||
params.Set("loc", loc)
|
||
}
|
||
return
|
||
case "usessl":
|
||
if enabled, ok := parseMySQLBoolParam(value); ok {
|
||
if enabled {
|
||
params.Set("tls", "true")
|
||
} else {
|
||
params.Set("tls", "false")
|
||
}
|
||
}
|
||
return
|
||
case "verifyservercertificate":
|
||
if verified, ok := parseMySQLBoolParam(value); ok && !verified && params.Get("tls") != "false" {
|
||
params.Set("tls", "skip-verify")
|
||
}
|
||
return
|
||
case "trustservercertificate":
|
||
if trusted, ok := parseMySQLBoolParam(value); ok && trusted && params.Get("tls") != "false" {
|
||
params.Set("tls", "skip-verify")
|
||
}
|
||
return
|
||
case "connecttimeout":
|
||
params.Set("timeout", normalizeMySQLDurationParam(value, time.Millisecond))
|
||
return
|
||
case "sockettimeout":
|
||
params.Set("readTimeout", normalizeMySQLDurationParam(value, time.Millisecond))
|
||
return
|
||
case "timeout", "readtimeout", "writetimeout":
|
||
params.Set(name, normalizeMySQLDurationParam(value, time.Second))
|
||
return
|
||
default:
|
||
params.Set(name, value)
|
||
}
|
||
}
|
||
|
||
func mergeMySQLConnectionParams(params url.Values, values url.Values) {
|
||
keys := make([]string, 0, len(values))
|
||
for key := range values {
|
||
keys = append(keys, key)
|
||
}
|
||
sort.Strings(keys)
|
||
for _, key := range keys {
|
||
lowerName := strings.ToLower(strings.TrimSpace(key))
|
||
if lowerName == "verifyservercertificate" || lowerName == "trustservercertificate" {
|
||
continue
|
||
}
|
||
for _, value := range values[key] {
|
||
mergeMySQLConnectionParam(params, key, value)
|
||
}
|
||
}
|
||
for _, key := range keys {
|
||
lowerName := strings.ToLower(strings.TrimSpace(key))
|
||
if lowerName != "verifyservercertificate" && lowerName != "trustservercertificate" {
|
||
continue
|
||
}
|
||
for _, value := range values[key] {
|
||
mergeMySQLConnectionParam(params, key, value)
|
||
}
|
||
}
|
||
}
|
||
|
||
func buildMySQLCompatibleDSN(config connection.ConnectionConfig, protocol, address, database string) string {
|
||
timeout := getConnectTimeoutSeconds(config)
|
||
tlsMode := resolveMySQLTLSMode(config)
|
||
params := url.Values{}
|
||
params.Set("charset", "utf8mb4")
|
||
params.Set("parseTime", "True")
|
||
params.Set("loc", "Local")
|
||
params.Set("timeout", fmt.Sprintf("%ds", timeout))
|
||
params.Set("tls", tlsMode)
|
||
params.Set("multiStatements", "true")
|
||
if parsed, ok := parseMySQLCompatibleURI(config.URI, "mysql", "doris", "diros", "oceanbase"); ok {
|
||
mergeMySQLConnectionParams(params, parsed.Query())
|
||
}
|
||
mergeMySQLConnectionParams(params, mysqlConnectionParamsFromText(config.ConnectionParams))
|
||
return fmt.Sprintf(
|
||
"%s:%s@%s(%s)/%s?%s",
|
||
config.User, config.Password, protocol, address, database, params.Encode(),
|
||
)
|
||
}
|
||
|
||
func parseHostPortWithDefault(raw string, defaultPort int) (string, int, bool) {
|
||
text := strings.TrimSpace(raw)
|
||
if text == "" {
|
||
return "", 0, false
|
||
}
|
||
|
||
if strings.HasPrefix(text, "[") {
|
||
end := strings.Index(text, "]")
|
||
if end < 0 {
|
||
return text, defaultPort, true
|
||
}
|
||
host := text[1:end]
|
||
portText := strings.TrimSpace(text[end+1:])
|
||
if strings.HasPrefix(portText, ":") {
|
||
if p, err := strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(portText, ":"))); err == nil && p > 0 {
|
||
return host, p, true
|
||
}
|
||
}
|
||
return host, defaultPort, true
|
||
}
|
||
|
||
lastColon := strings.LastIndex(text, ":")
|
||
if lastColon > 0 && strings.Count(text, ":") == 1 {
|
||
host := strings.TrimSpace(text[:lastColon])
|
||
portText := strings.TrimSpace(text[lastColon+1:])
|
||
if host != "" {
|
||
if p, err := strconv.Atoi(portText); err == nil && p > 0 {
|
||
return host, p, true
|
||
}
|
||
return host, defaultPort, true
|
||
}
|
||
}
|
||
|
||
return text, defaultPort, true
|
||
}
|
||
|
||
func normalizeMySQLAddress(host string, port int) string {
|
||
h := strings.TrimSpace(host)
|
||
if h == "" {
|
||
h = "localhost"
|
||
}
|
||
p := port
|
||
if p <= 0 {
|
||
p = defaultMySQLPort
|
||
}
|
||
return fmt.Sprintf("%s:%d", h, p)
|
||
}
|
||
|
||
var mysqlDatabaseQueries = []string{
|
||
"SHOW DATABASES",
|
||
"SELECT DATABASE() AS `Database`",
|
||
}
|
||
|
||
func collectMySQLDatabaseNames(queryFn func(string) ([]map[string]interface{}, []string, error)) ([]string, error) {
|
||
if queryFn == nil {
|
||
return nil, fmt.Errorf("查询函数为空")
|
||
}
|
||
|
||
names := make([]string, 0, 8)
|
||
seen := make(map[string]struct{}, 8)
|
||
var lastErr error
|
||
|
||
appendNames := func(rows []map[string]interface{}) {
|
||
for _, row := range rows {
|
||
for _, key := range []string{"Database", "database"} {
|
||
val, ok := row[key]
|
||
if !ok || val == nil {
|
||
continue
|
||
}
|
||
name := strings.TrimSpace(fmt.Sprintf("%v", val))
|
||
if name == "" || strings.EqualFold(name, "<nil>") {
|
||
continue
|
||
}
|
||
if _, exists := seen[name]; exists {
|
||
continue
|
||
}
|
||
seen[name] = struct{}{}
|
||
names = append(names, name)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
for _, sqlText := range mysqlDatabaseQueries {
|
||
rows, _, err := queryFn(sqlText)
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
appendNames(rows)
|
||
if len(names) > 0 {
|
||
return names, nil
|
||
}
|
||
}
|
||
|
||
if len(names) > 0 {
|
||
return names, nil
|
||
}
|
||
if lastErr != nil {
|
||
return nil, lastErr
|
||
}
|
||
return nil, fmt.Errorf("未获取到可用数据库")
|
||
}
|
||
|
||
func applyMySQLURI(config connection.ConnectionConfig) connection.ConnectionConfig {
|
||
uriText := strings.TrimSpace(config.URI)
|
||
if uriText == "" {
|
||
return config
|
||
}
|
||
parsed, ok := parseMySQLCompatibleURI(uriText, "mysql")
|
||
if !ok {
|
||
return config
|
||
}
|
||
|
||
if parsed.User != nil {
|
||
if config.User == "" {
|
||
config.User = parsed.User.Username()
|
||
}
|
||
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
|
||
config.Password = pass
|
||
}
|
||
}
|
||
|
||
if dbName := strings.TrimPrefix(parsed.Path, "/"); dbName != "" && config.Database == "" {
|
||
config.Database = dbName
|
||
}
|
||
|
||
defaultPort := config.Port
|
||
if defaultPort <= 0 {
|
||
defaultPort = defaultMySQLPort
|
||
}
|
||
|
||
hostsFromURI := make([]string, 0, 4)
|
||
hostText := strings.TrimSpace(parsed.Host)
|
||
if hostText != "" {
|
||
for _, entry := range strings.Split(hostText, ",") {
|
||
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
|
||
if !ok {
|
||
continue
|
||
}
|
||
hostsFromURI = append(hostsFromURI, normalizeMySQLAddress(host, port))
|
||
}
|
||
}
|
||
|
||
if len(config.Hosts) == 0 && len(hostsFromURI) > 0 {
|
||
config.Hosts = hostsFromURI
|
||
}
|
||
if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 {
|
||
host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort)
|
||
if ok {
|
||
config.Host = host
|
||
config.Port = port
|
||
}
|
||
}
|
||
|
||
if config.Topology == "" {
|
||
topology := strings.TrimSpace(parsed.Query().Get("topology"))
|
||
if topology != "" {
|
||
config.Topology = strings.ToLower(topology)
|
||
}
|
||
}
|
||
|
||
return config
|
||
}
|
||
|
||
func collectMySQLAddresses(config connection.ConnectionConfig) []string {
|
||
defaultPort := config.Port
|
||
if defaultPort <= 0 {
|
||
defaultPort = defaultMySQLPort
|
||
}
|
||
|
||
candidates := make([]string, 0, len(config.Hosts)+1)
|
||
if len(config.Hosts) > 0 {
|
||
candidates = append(candidates, config.Hosts...)
|
||
} else {
|
||
candidates = append(candidates, normalizeMySQLAddress(config.Host, defaultPort))
|
||
}
|
||
|
||
result := make([]string, 0, len(candidates))
|
||
seen := make(map[string]struct{}, len(candidates))
|
||
for _, entry := range candidates {
|
||
host, port, ok := parseHostPortWithDefault(entry, defaultPort)
|
||
if !ok {
|
||
continue
|
||
}
|
||
normalized := normalizeMySQLAddress(host, port)
|
||
if _, exists := seen[normalized]; exists {
|
||
continue
|
||
}
|
||
seen[normalized] = struct{}{}
|
||
result = append(result, normalized)
|
||
}
|
||
return result
|
||
}
|
||
|
||
func (m *MySQLDB) getDSN(config connection.ConnectionConfig) (string, error) {
|
||
database := config.Database
|
||
protocol := "tcp"
|
||
address := normalizeMySQLAddress(config.Host, config.Port)
|
||
|
||
if config.UseSSH {
|
||
netName, err := ssh.RegisterSSHNetwork(config.SSH)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建 SSH 隧道失败:%w", err)
|
||
}
|
||
protocol = netName
|
||
}
|
||
|
||
return buildMySQLCompatibleDSN(config, protocol, address, database), nil
|
||
}
|
||
|
||
func resolveMySQLCredential(config connection.ConnectionConfig, addressIndex int) (string, string) {
|
||
primaryUser := strings.TrimSpace(config.User)
|
||
primaryPassword := config.Password
|
||
replicaUser := strings.TrimSpace(config.MySQLReplicaUser)
|
||
replicaPassword := config.MySQLReplicaPassword
|
||
|
||
if addressIndex > 0 && replicaUser != "" {
|
||
return replicaUser, replicaPassword
|
||
}
|
||
|
||
if primaryUser == "" && replicaUser != "" {
|
||
return replicaUser, replicaPassword
|
||
}
|
||
|
||
return config.User, primaryPassword
|
||
}
|
||
|
||
func (m *MySQLDB) Connect(config connection.ConnectionConfig) error {
|
||
runConfig := applyMySQLURI(config)
|
||
addresses := collectMySQLAddresses(runConfig)
|
||
if len(addresses) == 0 {
|
||
return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
|
||
}
|
||
|
||
var errorDetails []string
|
||
for index, address := range addresses {
|
||
candidateConfig := runConfig
|
||
host, port, ok := parseHostPortWithDefault(address, defaultMySQLPort)
|
||
if !ok {
|
||
continue
|
||
}
|
||
candidateConfig.Host = host
|
||
candidateConfig.Port = port
|
||
candidateConfig.User, candidateConfig.Password = resolveMySQLCredential(runConfig, index)
|
||
|
||
dsn, err := m.getDSN(candidateConfig)
|
||
if err != nil {
|
||
errorDetails = append(errorDetails, fmt.Sprintf("%s 生成连接串失败: %v", address, err))
|
||
continue
|
||
}
|
||
db, err := sql.Open("mysql", dsn)
|
||
if err != nil {
|
||
errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err))
|
||
continue
|
||
}
|
||
|
||
timeout := getConnectTimeout(candidateConfig)
|
||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||
pingErr := db.PingContext(ctx)
|
||
cancel()
|
||
if pingErr != nil {
|
||
_ = db.Close()
|
||
errorDetails = append(errorDetails, fmt.Sprintf("%s 验证失败: %v", address, pingErr))
|
||
continue
|
||
}
|
||
|
||
m.conn = db
|
||
m.pingTimeout = timeout
|
||
return nil
|
||
}
|
||
|
||
if len(errorDetails) == 0 {
|
||
return fmt.Errorf("连接建立后验证失败:未找到可用的 MySQL 地址")
|
||
}
|
||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, ";"))
|
||
}
|
||
|
||
func (m *MySQLDB) Close() error {
|
||
if m.conn != nil {
|
||
return m.conn.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *MySQLDB) Ping() error {
|
||
if m.conn == nil {
|
||
return fmt.Errorf("连接未打开")
|
||
}
|
||
timeout := m.pingTimeout
|
||
if timeout <= 0 {
|
||
timeout = 5 * time.Second
|
||
}
|
||
ctx, cancel := utils.ContextWithTimeout(timeout)
|
||
defer cancel()
|
||
return m.conn.PingContext(ctx)
|
||
}
|
||
|
||
func (m *MySQLDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||
if m.conn == nil {
|
||
return nil, fmt.Errorf("连接未打开")
|
||
}
|
||
rows, err := m.conn.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
return scanMultiRows(rows)
|
||
}
|
||
|
||
func (m *MySQLDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||
if m.conn == nil {
|
||
return nil, fmt.Errorf("连接未打开")
|
||
}
|
||
rows, err := m.conn.QueryContext(ctx, query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
return scanMultiRows(rows)
|
||
}
|
||
|
||
func (m *MySQLDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||
if m.conn == nil {
|
||
return nil, nil, fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
rows, err := m.conn.QueryContext(ctx, query)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
return scanRows(rows)
|
||
}
|
||
|
||
func (m *MySQLDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||
if m.conn == nil {
|
||
return nil, nil, fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
rows, err := m.conn.Query(query)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
defer rows.Close()
|
||
return scanRows(rows)
|
||
}
|
||
|
||
func (m *MySQLDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||
if m.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := m.conn.ExecContext(ctx, query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (m *MySQLDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||
if m.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := m.conn.ExecContext(ctx, query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (m *MySQLDB) Exec(query string) (int64, error) {
|
||
if m.conn == nil {
|
||
return 0, fmt.Errorf("连接未打开")
|
||
}
|
||
res, err := m.conn.Exec(query)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return res.RowsAffected()
|
||
}
|
||
|
||
func (m *MySQLDB) GetDatabases() ([]string, error) {
|
||
return collectMySQLDatabaseNames(m.Query)
|
||
}
|
||
|
||
func (m *MySQLDB) GetTables(dbName string) ([]string, error) {
|
||
query := "SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_NAME"
|
||
if dbName != "" {
|
||
query = fmt.Sprintf(
|
||
"SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = '%s' AND TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_NAME",
|
||
strings.ReplaceAll(dbName, "'", "''"),
|
||
)
|
||
}
|
||
|
||
data, _, err := m.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var tables []string
|
||
for _, row := range data {
|
||
for _, v := range row {
|
||
tables = append(tables, fmt.Sprintf("%v", v))
|
||
break
|
||
}
|
||
}
|
||
return tables, nil
|
||
}
|
||
|
||
func (m *MySQLDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||
query := fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", dbName, tableName)
|
||
if dbName == "" {
|
||
query = fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName)
|
||
}
|
||
|
||
data, _, err := m.Query(query)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if len(data) > 0 {
|
||
if val, ok := data[0]["Create Table"]; ok {
|
||
return fmt.Sprintf("%v", val), nil
|
||
}
|
||
}
|
||
return "", fmt.Errorf("未找到建表语句")
|
||
}
|
||
|
||
func (m *MySQLDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||
query := fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", dbName, tableName)
|
||
if dbName == "" {
|
||
query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName)
|
||
}
|
||
|
||
data, _, err := m.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var columns []connection.ColumnDefinition
|
||
for _, row := range data {
|
||
col := connection.ColumnDefinition{
|
||
Name: fmt.Sprintf("%v", row["Field"]),
|
||
Type: fmt.Sprintf("%v", row["Type"]),
|
||
Nullable: fmt.Sprintf("%v", row["Null"]),
|
||
Key: fmt.Sprintf("%v", row["Key"]),
|
||
Extra: fmt.Sprintf("%v", row["Extra"]),
|
||
Comment: fmt.Sprintf("%v", row["Comment"]),
|
||
}
|
||
|
||
if row["Default"] != nil {
|
||
d := fmt.Sprintf("%v", row["Default"])
|
||
col.Default = &d
|
||
}
|
||
|
||
columns = append(columns, col)
|
||
}
|
||
return columns, nil
|
||
}
|
||
|
||
func (m *MySQLDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||
query := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", dbName, tableName)
|
||
if dbName == "" {
|
||
query = fmt.Sprintf("SHOW INDEX FROM `%s`", tableName)
|
||
}
|
||
|
||
data, _, err := m.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var indexes []connection.IndexDefinition
|
||
for _, row := range data {
|
||
nonUnique := 0
|
||
if val, ok := row["Non_unique"]; ok {
|
||
if f, ok := val.(float64); ok {
|
||
nonUnique = int(f)
|
||
} else if i, ok := val.(int64); ok {
|
||
nonUnique = int(i)
|
||
}
|
||
}
|
||
|
||
seq := 0
|
||
if val, ok := row["Seq_in_index"]; ok {
|
||
if f, ok := val.(float64); ok {
|
||
seq = int(f)
|
||
} else if i, ok := val.(int64); ok {
|
||
seq = int(i)
|
||
}
|
||
}
|
||
|
||
subPart := 0
|
||
if val, ok := row["Sub_part"]; ok && val != nil {
|
||
if f, ok := val.(float64); ok {
|
||
subPart = int(f)
|
||
} else if i, ok := val.(int64); ok {
|
||
subPart = int(i)
|
||
}
|
||
}
|
||
|
||
idx := connection.IndexDefinition{
|
||
Name: fmt.Sprintf("%v", row["Key_name"]),
|
||
ColumnName: fmt.Sprintf("%v", row["Column_name"]),
|
||
NonUnique: nonUnique,
|
||
SeqInIndex: seq,
|
||
IndexType: fmt.Sprintf("%v", row["Index_type"]),
|
||
SubPart: subPart,
|
||
}
|
||
indexes = append(indexes, idx)
|
||
}
|
||
return indexes, nil
|
||
}
|
||
|
||
func (m *MySQLDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||
query := fmt.Sprintf(`SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
|
||
FROM information_schema.KEY_COLUMN_USAGE
|
||
WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s' AND REFERENCED_TABLE_NAME IS NOT NULL`, dbName, tableName)
|
||
|
||
data, _, err := m.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var fks []connection.ForeignKeyDefinition
|
||
for _, row := range data {
|
||
fk := connection.ForeignKeyDefinition{
|
||
Name: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
|
||
ColumnName: fmt.Sprintf("%v", row["COLUMN_NAME"]),
|
||
RefTableName: fmt.Sprintf("%v", row["REFERENCED_TABLE_NAME"]),
|
||
RefColumnName: fmt.Sprintf("%v", row["REFERENCED_COLUMN_NAME"]),
|
||
ConstraintName: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
|
||
}
|
||
fks = append(fks, fk)
|
||
}
|
||
return fks, nil
|
||
}
|
||
|
||
func (m *MySQLDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||
query := fmt.Sprintf("SHOW TRIGGERS FROM `%s` WHERE `Table` = '%s'", dbName, tableName)
|
||
data, _, err := m.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var triggers []connection.TriggerDefinition
|
||
for _, row := range data {
|
||
trig := connection.TriggerDefinition{
|
||
Name: fmt.Sprintf("%v", row["Trigger"]),
|
||
Timing: fmt.Sprintf("%v", row["Timing"]),
|
||
Event: fmt.Sprintf("%v", row["Event"]),
|
||
Statement: fmt.Sprintf("%v", row["Statement"]),
|
||
}
|
||
triggers = append(triggers, trig)
|
||
}
|
||
return triggers, nil
|
||
}
|
||
|
||
func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
|
||
if m.conn == nil {
|
||
return fmt.Errorf("连接未打开")
|
||
}
|
||
|
||
columnTypeMap := m.loadColumnTypeMap(tableName)
|
||
|
||
tx, err := m.conn.Begin()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer tx.Rollback()
|
||
|
||
// 1. Deletes
|
||
for _, pk := range changes.Deletes {
|
||
var wheres []string
|
||
var args []interface{}
|
||
for k, v := range pk {
|
||
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
|
||
args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
|
||
}
|
||
if len(wheres) == 0 {
|
||
continue
|
||
}
|
||
query := fmt.Sprintf("DELETE FROM `%s` WHERE %s", tableName, strings.Join(wheres, " AND "))
|
||
res, err := tx.Exec(query, args...)
|
||
if err != nil {
|
||
return fmt.Errorf("删除失败:%v", err)
|
||
}
|
||
if err := requireSingleRowAffected(res, "删除"); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 2. Updates
|
||
for _, update := range changes.Updates {
|
||
var sets []string
|
||
var args []interface{}
|
||
|
||
for k, v := range update.Values {
|
||
sets = append(sets, fmt.Sprintf("`%s` = ?", k))
|
||
args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
|
||
}
|
||
|
||
if len(sets) == 0 {
|
||
continue
|
||
}
|
||
|
||
var wheres []string
|
||
for k, v := range update.Keys {
|
||
wheres = append(wheres, fmt.Sprintf("`%s` = ?", k))
|
||
args = append(args, normalizeMySQLValueForWrite(k, v, columnTypeMap))
|
||
}
|
||
|
||
if len(wheres) == 0 {
|
||
return fmt.Errorf("更新操作需要主键条件")
|
||
}
|
||
|
||
query := fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", tableName, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
|
||
res, err := tx.Exec(query, args...)
|
||
if err != nil {
|
||
return fmt.Errorf("更新失败:%v", err)
|
||
}
|
||
if err := requireSingleRowAffected(res, "更新"); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 3. Inserts
|
||
for _, row := range changes.Inserts {
|
||
var cols []string
|
||
var placeholders []string
|
||
var args []interface{}
|
||
|
||
for k, v := range row {
|
||
normalizedValue, omit := normalizeMySQLValueForInsert(k, v, columnTypeMap)
|
||
if omit {
|
||
continue
|
||
}
|
||
cols = append(cols, fmt.Sprintf("`%s`", k))
|
||
placeholders = append(placeholders, "?")
|
||
args = append(args, normalizedValue)
|
||
}
|
||
|
||
if len(cols) == 0 {
|
||
query := fmt.Sprintf("INSERT INTO `%s` () VALUES ()", tableName)
|
||
res, err := tx.Exec(query)
|
||
if err != nil {
|
||
return fmt.Errorf("插入失败:%v", err)
|
||
}
|
||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||
return fmt.Errorf("插入未生效:未影响任何行")
|
||
}
|
||
continue
|
||
}
|
||
|
||
query := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
||
res, err := tx.Exec(query, args...)
|
||
if err != nil {
|
||
return fmt.Errorf("插入失败:%v", err)
|
||
}
|
||
if affected, err := res.RowsAffected(); err == nil && affected == 0 {
|
||
return fmt.Errorf("插入未生效:未影响任何行")
|
||
}
|
||
}
|
||
|
||
return tx.Commit()
|
||
}
|
||
|
||
func normalizeMySQLComplexValue(value interface{}) interface{} {
|
||
switch v := value.(type) {
|
||
case map[string]interface{}, []interface{}:
|
||
if data, err := json.Marshal(v); err == nil {
|
||
return string(data)
|
||
}
|
||
return fmt.Sprintf("%v", value)
|
||
default:
|
||
return value
|
||
}
|
||
}
|
||
|
||
func normalizeMySQLDateTimeValue(value interface{}) interface{} {
|
||
text, ok := value.(string)
|
||
if !ok {
|
||
return value
|
||
}
|
||
raw := strings.TrimSpace(text)
|
||
if raw == "" {
|
||
return value
|
||
}
|
||
|
||
cleaned := strings.ReplaceAll(raw, "+ ", "+")
|
||
cleaned = strings.ReplaceAll(cleaned, "- ", "-")
|
||
|
||
if len(cleaned) >= 19 && cleaned[10] == 'T' {
|
||
if strings.HasSuffix(cleaned, "Z") || hasTimezoneOffset(cleaned) {
|
||
if t, err := time.Parse(time.RFC3339Nano, cleaned); err == nil {
|
||
return formatMySQLDateTime(t)
|
||
}
|
||
if t, err := time.Parse(time.RFC3339, cleaned); err == nil {
|
||
return formatMySQLDateTime(t)
|
||
}
|
||
}
|
||
return strings.Replace(cleaned, "T", " ", 1)
|
||
}
|
||
|
||
if strings.Contains(cleaned, " ") && (strings.HasSuffix(cleaned, "Z") || hasTimezoneOffset(cleaned)) {
|
||
candidate := strings.Replace(cleaned, " ", "T", 1)
|
||
if t, err := time.Parse(time.RFC3339Nano, candidate); err == nil {
|
||
return formatMySQLDateTime(t)
|
||
}
|
||
if t, err := time.Parse(time.RFC3339, candidate); err == nil {
|
||
return formatMySQLDateTime(t)
|
||
}
|
||
}
|
||
|
||
return value
|
||
}
|
||
|
||
func (m *MySQLDB) loadColumnTypeMap(tableName string) map[string]string {
|
||
result := map[string]string{}
|
||
table := strings.TrimSpace(tableName)
|
||
if table == "" {
|
||
return result
|
||
}
|
||
|
||
columns, err := m.GetColumns("", table)
|
||
if err != nil {
|
||
logger.Warnf("加载列元数据失败(不影响提交):表=%s err=%v", table, err)
|
||
return result
|
||
}
|
||
|
||
for _, col := range columns {
|
||
name := strings.ToLower(strings.TrimSpace(col.Name))
|
||
if name == "" {
|
||
continue
|
||
}
|
||
result[name] = strings.TrimSpace(col.Type)
|
||
}
|
||
return result
|
||
}
|
||
|
||
func normalizeMySQLValueForInsert(columnName string, value interface{}, columnTypeMap map[string]string) (interface{}, bool) {
|
||
columnType := strings.ToLower(strings.TrimSpace(columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]))
|
||
if isMySQLBitColumnType(columnType) {
|
||
return normalizeMySQLBitValue(value), false
|
||
}
|
||
if !isMySQLTemporalColumnType(columnType) {
|
||
return normalizeMySQLComplexValue(value), false
|
||
}
|
||
text, ok := value.(string)
|
||
if ok && strings.TrimSpace(text) == "" {
|
||
// INSERT 空时间字段不写入,交给 DB 默认值处理(如 CURRENT_TIMESTAMP)。
|
||
return nil, true
|
||
}
|
||
return normalizeMySQLDateTimeValue(value), false
|
||
}
|
||
|
||
func normalizeMySQLValueForWrite(columnName string, value interface{}, columnTypeMap map[string]string) interface{} {
|
||
columnType := strings.ToLower(strings.TrimSpace(columnTypeMap[strings.ToLower(strings.TrimSpace(columnName))]))
|
||
if isMySQLBitColumnType(columnType) {
|
||
return normalizeMySQLBitValue(value)
|
||
}
|
||
if !isMySQLTemporalColumnType(columnType) {
|
||
return value
|
||
}
|
||
text, ok := value.(string)
|
||
if ok && strings.TrimSpace(text) == "" {
|
||
return nil
|
||
}
|
||
return normalizeMySQLDateTimeValue(value)
|
||
}
|
||
|
||
func isMySQLTemporalColumnType(columnType string) bool {
|
||
raw := strings.ToLower(strings.TrimSpace(columnType))
|
||
if raw == "" {
|
||
return false
|
||
}
|
||
if strings.Contains(raw, "datetime") || strings.Contains(raw, "timestamp") {
|
||
return true
|
||
}
|
||
base := raw
|
||
if idx := strings.IndexAny(base, "( "); idx >= 0 {
|
||
base = base[:idx]
|
||
}
|
||
return base == "date" || base == "time" || base == "year"
|
||
}
|
||
|
||
func isMySQLBitColumnType(columnType string) bool {
|
||
raw := strings.ToLower(strings.TrimSpace(columnType))
|
||
if raw == "" {
|
||
return false
|
||
}
|
||
base := raw
|
||
if idx := strings.IndexAny(base, "( "); idx >= 0 {
|
||
base = base[:idx]
|
||
}
|
||
return base == "bit"
|
||
}
|
||
|
||
func normalizeMySQLBitValue(value interface{}) interface{} {
|
||
switch v := value.(type) {
|
||
case nil:
|
||
return nil
|
||
case []byte:
|
||
return v
|
||
case bool:
|
||
if v {
|
||
return []byte{1}
|
||
}
|
||
return []byte{0}
|
||
case string:
|
||
if bitValue, ok := parseMySQLBitString(v); ok {
|
||
return bitValue
|
||
}
|
||
return value
|
||
case int:
|
||
if v >= 0 {
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
}
|
||
case int8:
|
||
if v >= 0 {
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
}
|
||
case int16:
|
||
if v >= 0 {
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
}
|
||
case int32:
|
||
if v >= 0 {
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
}
|
||
case int64:
|
||
if v >= 0 {
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
}
|
||
case uint:
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
case uint8:
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
case uint16:
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
case uint32:
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
case uint64:
|
||
if bitValue, ok := mysqlBitBytesFromUint64(v); ok {
|
||
return bitValue
|
||
}
|
||
case float32:
|
||
if v >= 0 && math.Trunc(float64(v)) == float64(v) {
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
}
|
||
case float64:
|
||
if v >= 0 && math.Trunc(v) == v {
|
||
if bitValue, ok := mysqlBitBytesFromUint64(uint64(v)); ok {
|
||
return bitValue
|
||
}
|
||
}
|
||
}
|
||
return value
|
||
}
|
||
|
||
func parseMySQLBitString(text string) ([]byte, bool) {
|
||
raw := strings.TrimSpace(text)
|
||
if raw == "" {
|
||
return nil, false
|
||
}
|
||
|
||
switch strings.ToLower(raw) {
|
||
case "true":
|
||
return []byte{1}, true
|
||
case "false":
|
||
return []byte{0}, true
|
||
}
|
||
|
||
if len(raw) > 3 && (raw[0] == 'b' || raw[0] == 'B') && raw[1] == '\'' && raw[len(raw)-1] == '\'' {
|
||
value, err := strconv.ParseUint(raw[2:len(raw)-1], 2, 64)
|
||
if err == nil {
|
||
return mysqlBitBytesFromUint64OrZero(value), true
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
if len(raw) > 2 && (strings.HasPrefix(raw, "0b") || strings.HasPrefix(raw, "0B")) {
|
||
value, err := strconv.ParseUint(raw[2:], 2, 64)
|
||
if err == nil {
|
||
return mysqlBitBytesFromUint64OrZero(value), true
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
value, err := strconv.ParseUint(raw, 10, 64)
|
||
if err != nil {
|
||
return nil, false
|
||
}
|
||
return mysqlBitBytesFromUint64OrZero(value), true
|
||
}
|
||
|
||
func mysqlBitBytesFromUint64(value uint64) ([]byte, bool) {
|
||
return mysqlBitBytesFromUint64OrZero(value), true
|
||
}
|
||
|
||
func mysqlBitBytesFromUint64OrZero(value uint64) []byte {
|
||
if value == 0 {
|
||
return []byte{0}
|
||
}
|
||
var buf [8]byte
|
||
index := len(buf)
|
||
for value > 0 {
|
||
index--
|
||
buf[index] = byte(value)
|
||
value >>= 8
|
||
}
|
||
return append([]byte(nil), buf[index:]...)
|
||
}
|
||
|
||
func hasTimezoneOffset(text string) bool {
|
||
pos := strings.LastIndexAny(text, "+-")
|
||
if pos < 0 || pos < 10 || pos+1 >= len(text) {
|
||
return false
|
||
}
|
||
offset := text[pos+1:]
|
||
if len(offset) == 5 && offset[2] == ':' {
|
||
return isAllDigits(offset[:2]) && isAllDigits(offset[3:])
|
||
}
|
||
if len(offset) == 4 {
|
||
return isAllDigits(offset)
|
||
}
|
||
return false
|
||
}
|
||
|
||
func isAllDigits(text string) bool {
|
||
if text == "" {
|
||
return false
|
||
}
|
||
for _, r := range text {
|
||
if r < '0' || r > '9' {
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
func formatMySQLDateTime(t time.Time) string {
|
||
base := t.Format("2006-01-02 15:04:05")
|
||
nanos := t.Nanosecond()
|
||
if nanos == 0 {
|
||
return base
|
||
}
|
||
micro := nanos / 1000
|
||
return fmt.Sprintf("%s.%06d", base, micro)
|
||
}
|
||
|
||
func (m *MySQLDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||
query := fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'", dbName)
|
||
if dbName == "" {
|
||
return nil, fmt.Errorf("获取全部列信息需要指定数据库名称")
|
||
}
|
||
|
||
data, _, err := m.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var cols []connection.ColumnDefinitionWithTable
|
||
for _, row := range data {
|
||
col := connection.ColumnDefinitionWithTable{
|
||
TableName: fmt.Sprintf("%v", row["TABLE_NAME"]),
|
||
Name: fmt.Sprintf("%v", row["COLUMN_NAME"]),
|
||
Type: fmt.Sprintf("%v", row["COLUMN_TYPE"]),
|
||
}
|
||
cols = append(cols, col)
|
||
}
|
||
return cols, nil
|
||
}
|