Files
MyGoNavi/internal/db/mariadb_impl.go
Syngnat 0daf702d25 feat(data-sync): 扩展跨库迁移链路并优化数据同步交互
- 统一同库同步与跨库迁移入口,补充模式区分与风险提示
- 扩展 ClickHouse 与 PG-like 双向迁移,并新增 PG-like、ClickHouse、TDengine 到 MongoDB 的迁移路由
- 完善 TDengine 目标端建表规划、回归测试与需求追踪文档
- refs #51
2026-03-09 17:22:26 +08:00

424 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build gonavi_full_drivers || gonavi_mariadb_driver
package db
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/go-sql-driver/mysql"
)
// MariaDB implements the Database interface for MariaDB servers.
// MariaDB is MySQL-compatible, so the MySQL driver (registered as "mysql")
// is reused instead of a dedicated driver.
type MariaDB struct {
// conn is the database/sql handle assigned by Connect; it is nil on a zero-value MariaDB.
conn *sql.DB
// pingTimeout bounds Ping; Ping falls back to 5 seconds when this is not positive.
pingTimeout time.Duration
}
// getDSN assembles the data source name handed to sql.Open for config.
//
// The network part defaults to "tcp". When config.UseSSH is set, the network
// name returned by ssh.RegisterSSHNetwork replaces it; if that registration
// fails, a warning is logged and the plain "tcp" network is kept. The address
// is always config.Host:config.Port. The connect timeout (in seconds) and the
// TLS mode are obtained from getConnectTimeoutSeconds and resolveMySQLTLSMode.
func (m *MariaDB) getDSN(config connection.ConnectionConfig) string {
	network := "tcp"
	hostPort := fmt.Sprintf("%s:%d", config.Host, config.Port)

	if config.UseSSH {
		sshNetwork, err := ssh.RegisterSSHNetwork(config.SSH)
		if err != nil {
			logger.Warnf("注册 SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s原因%v", config.Host, config.Port, config.User, err)
		} else {
			network = sshNetwork
		}
	}

	timeoutSeconds := getConnectTimeoutSeconds(config)
	tlsParam := url.QueryEscape(resolveMySQLTLSMode(config))

	return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%ds&tls=%s",
		config.User, config.Password, network, hostPort, config.Database, timeoutSeconds, tlsParam)
}
// Connect opens a database handle for config with the "mysql" driver and
// verifies it with Ping before returning.
//
// The ping timeout is taken from getConnectTimeout(config). If verification
// fails, the freshly opened handle is closed and m.conn is reset to nil, so
// the handle is not leaked and later calls report "connection not open".
// Returned errors wrap the underlying cause.
func (m *MariaDB) Connect(config connection.ConnectionConfig) error {
	dsn := m.getDSN(config)
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return fmt.Errorf("打开数据库连接失败:%w", err)
	}

	// Ping reads m.conn and m.pingTimeout, so both must be set before verifying.
	m.conn = db
	m.pingTimeout = getConnectTimeout(config)
	if err := m.Ping(); err != nil {
		// The handle is unusable; release it instead of leaving it behind.
		// The ping error is the one worth reporting, so a Close error is dropped.
		_ = db.Close()
		m.conn = nil
		return fmt.Errorf("连接建立后验证失败:%w", err)
	}
	return nil
}
// Close closes the underlying database handle.
// It returns nil without doing anything when no handle has been assigned.
func (m *MariaDB) Close() error {
	if m.conn == nil {
		return nil
	}
	return m.conn.Close()
}
// Ping checks that the database is reachable, bounded by a timeout.
//
// The timeout is m.pingTimeout when it is positive and 5 seconds otherwise.
// It returns an error when no connection has been opened.
func (m *MariaDB) Ping() error {
	if m.conn == nil {
		return fmt.Errorf("connection not open")
	}

	limit := 5 * time.Second
	if m.pingTimeout > 0 {
		limit = m.pingTimeout
	}

	ctx, cancel := utils.ContextWithTimeout(limit)
	defer cancel()
	return m.conn.PingContext(ctx)
}
// QueryContext runs query under ctx and returns whatever scanRows produces
// for the result set. The rows are closed before the method returns.
// It returns an error when no connection has been opened or the query fails.
func (m *MariaDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
	db := m.conn
	if db == nil {
		return nil, nil, fmt.Errorf("connection not open")
	}

	resultSet, queryErr := db.QueryContext(ctx, query)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer resultSet.Close()

	return scanRows(resultSet)
}
// Query runs query without a caller-supplied context and returns whatever
// scanRows produces for the result set. The rows are closed before the
// method returns. It returns an error when no connection has been opened or
// the query fails.
func (m *MariaDB) Query(query string) ([]map[string]interface{}, []string, error) {
	db := m.conn
	if db == nil {
		return nil, nil, fmt.Errorf("connection not open")
	}

	resultSet, queryErr := db.Query(query)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer resultSet.Close()

	return scanRows(resultSet)
}
// ExecContext executes query under ctx and reports the number of affected
// rows. It returns an error when no connection has been opened, when the
// statement fails, or when the affected-row count cannot be obtained.
func (m *MariaDB) ExecContext(ctx context.Context, query string) (int64, error) {
	db := m.conn
	if db == nil {
		return 0, fmt.Errorf("connection not open")
	}

	outcome, execErr := db.ExecContext(ctx, query)
	if execErr != nil {
		return 0, execErr
	}

	affected, err := outcome.RowsAffected()
	return affected, err
}
// Exec executes query without a caller-supplied context and reports the
// number of affected rows. It returns an error when no connection has been
// opened, when the statement fails, or when the affected-row count cannot be
// obtained.
func (m *MariaDB) Exec(query string) (int64, error) {
	db := m.conn
	if db == nil {
		return 0, fmt.Errorf("connection not open")
	}

	outcome, execErr := db.Exec(query)
	if execErr != nil {
		return 0, execErr
	}

	affected, err := outcome.RowsAffected()
	return affected, err
}
// GetDatabases lists database names by running SHOW DATABASES.
//
// Each row's name is read from the "Database" column, falling back to
// "database"; rows that have neither column are skipped.
func (m *MariaDB) GetDatabases() ([]string, error) {
	rows, _, err := m.Query("SHOW DATABASES")
	if err != nil {
		return nil, err
	}

	var names []string
	for _, row := range rows {
		name, found := row["Database"]
		if !found {
			name, found = row["database"]
		}
		if found {
			names = append(names, fmt.Sprintf("%v", name))
		}
	}
	return names, nil
}
// GetTables lists the tables of dbName, or of the connection's current
// database when dbName is empty.
//
// dbName is quoted as a backtick identifier with embedded backticks doubled,
// so a name containing a backtick cannot terminate the identifier early.
func (m *MariaDB) GetTables(dbName string) ([]string, error) {
	query := "SHOW TABLES"
	if dbName != "" {
		query = fmt.Sprintf("SHOW TABLES FROM `%s`", strings.ReplaceAll(dbName, "`", "``"))
	}

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var tables []string
	for _, row := range data {
		// Take one value from each row and stop; the column name embeds
		// the database name, so it is not looked up by key.
		for _, v := range row {
			tables = append(tables, fmt.Sprintf("%v", v))
			break
		}
	}
	return tables, nil
}
// GetCreateStatement returns the "Create Table" column of
// SHOW CREATE TABLE for tableName. When dbName is empty the table is
// resolved against the connection's current database.
//
// Both names are quoted as backtick identifiers with embedded backticks
// doubled. An error is returned when the query fails or the result has no
// "Create Table" column.
func (m *MariaDB) GetCreateStatement(dbName, tableName string) (string, error) {
	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}

	target := quoteIdent(tableName)
	if dbName != "" {
		target = quoteIdent(dbName) + "." + target
	}

	data, _, err := m.Query("SHOW CREATE TABLE " + target)
	if err != nil {
		return "", err
	}

	if len(data) > 0 {
		if val, ok := data[0]["Create Table"]; ok {
			return fmt.Sprintf("%v", val), nil
		}
	}
	return "", fmt.Errorf("create statement not found")
}
// GetColumns describes the columns of tableName using SHOW FULL COLUMNS.
// When dbName is empty the table is resolved against the connection's
// current database.
//
// Both names are quoted as backtick identifiers with embedded backticks
// doubled. Default is left nil when the "Default" value of a row is nil.
func (m *MariaDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}

	target := quoteIdent(tableName)
	if dbName != "" {
		target = quoteIdent(dbName) + "." + target
	}

	data, _, err := m.Query("SHOW FULL COLUMNS FROM " + target)
	if err != nil {
		return nil, err
	}

	var columns []connection.ColumnDefinition
	for _, row := range data {
		col := connection.ColumnDefinition{
			Name:     fmt.Sprintf("%v", row["Field"]),
			Type:     fmt.Sprintf("%v", row["Type"]),
			Nullable: fmt.Sprintf("%v", row["Null"]),
			Key:      fmt.Sprintf("%v", row["Key"]),
			Extra:    fmt.Sprintf("%v", row["Extra"]),
			Comment:  fmt.Sprintf("%v", row["Comment"]),
		}
		if row["Default"] != nil {
			d := fmt.Sprintf("%v", row["Default"])
			col.Default = &d
		}
		columns = append(columns, col)
	}
	return columns, nil
}
// GetIndexes describes the indexes of tableName using SHOW INDEX, one
// entry per indexed column. When dbName is empty the table is resolved
// against the connection's current database.
//
// Both names are quoted as backtick identifiers with embedded backticks
// doubled. The numeric columns (Non_unique, Seq_in_index, Sub_part) are
// coerced to int from whichever Go representation the row carries; a
// missing, nil or unparsable value becomes 0.
func (m *MariaDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}
	// parseInt reads a decimal integer from text, yielding 0 on failure.
	parseInt := func(text string) int {
		var n int
		if _, err := fmt.Sscanf(text, "%d", &n); err != nil {
			return 0
		}
		return n
	}
	// toInt accepts the numeric and textual forms a row value may take,
	// rather than only float64 and int64, so no form is silently read as 0.
	toInt := func(v interface{}) int {
		switch n := v.(type) {
		case int:
			return n
		case int32:
			return int(n)
		case int64:
			return int(n)
		case uint32:
			return int(n)
		case uint64:
			return int(n)
		case float32:
			return int(n)
		case float64:
			return int(n)
		case []byte:
			return parseInt(string(n))
		case string:
			return parseInt(n)
		}
		return 0
	}

	target := quoteIdent(tableName)
	if dbName != "" {
		target = quoteIdent(dbName) + "." + target
	}

	data, _, err := m.Query("SHOW INDEX FROM " + target)
	if err != nil {
		return nil, err
	}

	var indexes []connection.IndexDefinition
	for _, row := range data {
		idx := connection.IndexDefinition{
			Name:       fmt.Sprintf("%v", row["Key_name"]),
			ColumnName: fmt.Sprintf("%v", row["Column_name"]),
			NonUnique:  toInt(row["Non_unique"]),
			SeqInIndex: toInt(row["Seq_in_index"]),
			IndexType:  fmt.Sprintf("%v", row["Index_type"]),
			SubPart:    toInt(row["Sub_part"]),
		}
		indexes = append(indexes, idx)
	}
	return indexes, nil
}
// GetForeignKeys lists the foreign-key columns of tableName from
// information_schema.KEY_COLUMN_USAGE, one entry per referencing column.
//
// When dbName is empty the schema is matched against DATABASE(), i.e. the
// connection's current database, in line with the other metadata methods.
// Names are embedded as single-quoted literals with backslashes and single
// quotes escaped so they cannot terminate the literal early.
// NOTE(review): under the NO_BACKSLASH_ESCAPES SQL mode a doubled backslash
// is read literally, so a name containing a backslash would not match there.
func (m *MariaDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
	literal := strings.NewReplacer(`\`, `\\`, `'`, `''`)

	schemaExpr := "DATABASE()"
	if dbName != "" {
		schemaExpr = "'" + literal.Replace(dbName) + "'"
	}

	query := "SELECT CONSTRAINT_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME\n" +
		"FROM information_schema.KEY_COLUMN_USAGE\n" +
		"WHERE TABLE_SCHEMA = " + schemaExpr +
		" AND TABLE_NAME = '" + literal.Replace(tableName) + "'" +
		" AND REFERENCED_TABLE_NAME IS NOT NULL"

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var fks []connection.ForeignKeyDefinition
	for _, row := range data {
		fk := connection.ForeignKeyDefinition{
			Name:           fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
			ColumnName:     fmt.Sprintf("%v", row["COLUMN_NAME"]),
			RefTableName:   fmt.Sprintf("%v", row["REFERENCED_TABLE_NAME"]),
			RefColumnName:  fmt.Sprintf("%v", row["REFERENCED_COLUMN_NAME"]),
			ConstraintName: fmt.Sprintf("%v", row["CONSTRAINT_NAME"]),
		}
		fks = append(fks, fk)
	}
	return fks, nil
}
// GetTriggers lists the triggers defined on tableName using SHOW TRIGGERS.
//
// When dbName is empty the FROM clause is omitted so the statement applies
// to the connection's current database instead of naming an empty
// identifier. dbName is quoted as a backtick identifier with embedded
// backticks doubled; tableName is embedded as a single-quoted literal with
// backslashes and single quotes escaped.
// NOTE(review): under the NO_BACKSLASH_ESCAPES SQL mode a doubled backslash
// is read literally, so a table name containing a backslash would not match.
func (m *MariaDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
	query := "SHOW TRIGGERS"
	if dbName != "" {
		query += " FROM `" + strings.ReplaceAll(dbName, "`", "``") + "`"
	}
	literal := strings.NewReplacer(`\`, `\\`, `'`, `''`)
	query += " WHERE `Table` = '" + literal.Replace(tableName) + "'"

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var triggers []connection.TriggerDefinition
	for _, row := range data {
		trig := connection.TriggerDefinition{
			Name:      fmt.Sprintf("%v", row["Trigger"]),
			Timing:    fmt.Sprintf("%v", row["Timing"]),
			Event:     fmt.Sprintf("%v", row["Event"]),
			Statement: fmt.Sprintf("%v", row["Statement"]),
		}
		triggers = append(triggers, trig)
	}
	return triggers, nil
}
// ApplyChanges applies a ChangeSet to tableName inside one transaction:
// deletes first, then updates, then inserts. Any failure returns early and
// the deferred Rollback discards everything done so far.
//
// Values are passed as placeholders after normalizeMySQLDateTimeValue and
// normalizeMySQLComplexValue. The table and column names are quoted as
// backtick identifiers with embedded backticks doubled. Deletes and inserts
// with no columns are skipped; an update that has values but no key columns
// is rejected so it cannot touch every row. Returned errors wrap the
// underlying driver error.
func (m *MariaDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
	if m.conn == nil {
		return fmt.Errorf("connection not open")
	}

	tx, err := m.conn.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback reports sql.ErrTxDone, which is ignored.
	defer tx.Rollback()

	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}
	table := quoteIdent(tableName)

	// 1. Deletes
	for _, pk := range changes.Deletes {
		var wheres []string
		var args []interface{}
		for k, v := range pk {
			wheres = append(wheres, quoteIdent(k)+" = ?")
			args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v)))
		}
		if len(wheres) == 0 {
			continue
		}
		query := fmt.Sprintf("DELETE FROM %s WHERE %s", table, strings.Join(wheres, " AND "))
		if _, err := tx.Exec(query, args...); err != nil {
			return fmt.Errorf("delete error: %w", err)
		}
	}

	// 2. Updates
	for _, update := range changes.Updates {
		var sets []string
		var args []interface{}
		for k, v := range update.Values {
			sets = append(sets, quoteIdent(k)+" = ?")
			args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v)))
		}
		if len(sets) == 0 {
			continue
		}
		// SET arguments come first, so the WHERE arguments are appended after them.
		var wheres []string
		for k, v := range update.Keys {
			wheres = append(wheres, quoteIdent(k)+" = ?")
			args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v)))
		}
		if len(wheres) == 0 {
			return fmt.Errorf("update requires keys")
		}
		query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
		if _, err := tx.Exec(query, args...); err != nil {
			return fmt.Errorf("update error: %w", err)
		}
	}

	// 3. Inserts
	for _, row := range changes.Inserts {
		var cols []string
		var placeholders []string
		var args []interface{}
		for k, v := range row {
			cols = append(cols, quoteIdent(k))
			placeholders = append(placeholders, "?")
			args = append(args, normalizeMySQLComplexValue(normalizeMySQLDateTimeValue(v)))
		}
		if len(cols) == 0 {
			continue
		}
		query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
		if _, err := tx.Exec(query, args...); err != nil {
			return fmt.Errorf("insert error: %w", err)
		}
	}

	return tx.Commit()
}
// GetAllColumns returns the table name, column name and column type of
// every column in dbName, read from information_schema.COLUMNS.
//
// dbName is required; an empty name is rejected before any query is built.
// The name is embedded as a single-quoted literal with backslashes and
// single quotes escaped so it cannot terminate the literal early.
// NOTE(review): under the NO_BACKSLASH_ESCAPES SQL mode a doubled backslash
// is read literally, so a name containing a backslash would not match there.
func (m *MariaDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
	if dbName == "" {
		return nil, fmt.Errorf("database name required for GetAllColumns")
	}

	schema := strings.NewReplacer(`\`, `\\`, `'`, `''`).Replace(dbName)
	query := fmt.Sprintf("SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = '%s'", schema)

	data, _, err := m.Query(query)
	if err != nil {
		return nil, err
	}

	var cols []connection.ColumnDefinitionWithTable
	for _, row := range data {
		col := connection.ColumnDefinitionWithTable{
			TableName: fmt.Sprintf("%v", row["TABLE_NAME"]),
			Name:      fmt.Sprintf("%v", row["COLUMN_NAME"]),
			Type:      fmt.Sprintf("%v", row["COLUMN_TYPE"]),
		}
		cols = append(cols, col)
	}
	return cols, nil
}