package db import ( "context" "database/sql" "fmt" "strings" "time" "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/utils" ) type CustomDB struct { conn *sql.DB driver string pingTimeout time.Duration } func (c *CustomDB) Connect(config connection.ConnectionConfig) error { driver := strings.TrimSpace(config.Driver) dsn := strings.TrimSpace(config.DSN) if driver == "" || dsn == "" { return fmt.Errorf("driver and dsn are required for custom connection") } // Verify driver is registered (implicit check by sql.Open) // We might not need explicit check, sql.Open will fail or Ping will fail if driver not found. db, err := sql.Open(driver, dsn) if err != nil { return formatCustomDriverOpenError(driver, err) } c.conn = db c.driver = driver c.pingTimeout = getConnectTimeout(config) if err := c.Ping(); err != nil { return fmt.Errorf("连接建立后验证失败:%w", err) } return nil } func formatCustomDriverOpenError(driver string, err error) error { if err == nil { return nil } if strings.Contains(strings.ToLower(err.Error()), "unknown driver") { if isLikelySystemODBCDriverName(driver) { return fmt.Errorf("打开数据库连接失败:自定义连接不支持直接填写系统 ODBC/JDBC 驱动名 %q;请填写 GoNavi 已注册的 Go database/sql 驱动名。当前版本未注册通用 ODBC 驱动,因此暂不支持通过 %q 连接 InterSystems IRIS:%w", driver, driver, err) } return fmt.Errorf("打开数据库连接失败:自定义连接驱动 %q 未在 GoNavi 中注册;请填写已注册的 Go database/sql 驱动名,不能填写系统 ODBC/JDBC 驱动名:%w", driver, err) } return fmt.Errorf("打开数据库连接失败:%w", err) } func isLikelySystemODBCDriverName(driver string) bool { normalized := strings.ToLower(strings.TrimSpace(driver)) return strings.Contains(normalized, "odbc") || strings.Contains(normalized, "jdbc") || strings.Contains(normalized, "intersystems") || strings.Contains(normalized, "iris") } func (c *CustomDB) Close() error { if c.conn != nil { return c.conn.Close() } return nil } func (c *CustomDB) Ping() error { if c.conn == nil { return fmt.Errorf("连接未打开") } timeout := c.pingTimeout if timeout <= 0 { timeout = 5 * time.Second } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() return c.conn.PingContext(ctx) } func (c *CustomDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { if c.conn == nil { return nil, nil, fmt.Errorf("连接未打开") } rows, err := c.conn.QueryContext(ctx, query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (c *CustomDB) Query(query string) ([]map[string]interface{}, []string, error) { if c.conn == nil { return nil, nil, fmt.Errorf("连接未打开") } rows, err := c.conn.Query(query) if err != nil { return nil, nil, err } defer rows.Close() return scanRows(rows) } func (c *CustomDB) ExecContext(ctx context.Context, query string) (int64, error) { if c.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := c.conn.ExecContext(ctx, query) if err != nil { return 0, err } return res.RowsAffected() } func (c *CustomDB) Exec(query string) (int64, error) { if c.conn == nil { return 0, fmt.Errorf("连接未打开") } res, err := c.conn.Exec(query) if err != nil { return 0, err } return res.RowsAffected() } func (c *CustomDB) GetDatabases() ([]string, error) { // Try standard information_schema or some known patterns if we can't guess // For "custom", we can't easily know. // But many DBs support SHOW DATABASES or SELECT datname FROM pg_database // We'll try a generic query or return empty. // Users using custom might know their DB context is single. if c.driver == "mysql" { data, _, err := c.Query("SHOW DATABASES") if err == nil { var dbs []string for _, row := range data { for _, v := range row { name := strings.TrimSpace(fmt.Sprintf("%v", v)) if name != "" { dbs = append(dbs, name) } break } } if len(dbs) > 0 { return dbs, nil } } // Fallback for restricted accounts: at least expose current database. data, _, fallbackErr := c.Query("SELECT DATABASE() AS database_name") if fallbackErr == nil { for _, row := range data { for _, v := range row { name := strings.TrimSpace(fmt.Sprintf("%v", v)) if name != "" && !strings.EqualFold(name, "") && !strings.EqualFold(name, "null") { return []string{name}, nil } } } } } // Best effort: return []string{}, nil } func (c *CustomDB) GetTables(dbName string) ([]string, error) { // ANSI Standard query := "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" // If mysql-like if c.driver == "mysql" { query = "SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_NAME" if dbName != "" { query = fmt.Sprintf( "SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_SCHEMA = '%s' AND TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_NAME", strings.ReplaceAll(dbName, "'", "''"), ) } } else if c.driver == "postgres" || c.driver == "kingbase" { query = ` SELECT table_schema AS schemaname, table_name AS tablename FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')` if dbName != "" { query += fmt.Sprintf(" AND table_schema = '%s'", dbName) } query += " ORDER BY table_schema, table_name" } else if c.driver == "sqlite" { query = "SELECT name FROM sqlite_master WHERE type='table'" } else if c.driver == "oracle" || c.driver == "dm" { query = "SELECT table_name FROM user_tables" if dbName != "" { query = fmt.Sprintf("SELECT owner, table_name FROM all_tables WHERE owner = '%s' ORDER BY table_name", strings.ToUpper(dbName)) } } // Fallback generic execution data, _, err := c.Query(query) if err != nil { return nil, fmt.Errorf("failed to get tables for custom driver %s: %v", c.driver, err) } var tables []string for _, row := range data { if schema, okSchema := row["schemaname"]; okSchema { if name, okName := row["tablename"]; okName { tables = append(tables, fmt.Sprintf("%v.%v", schema, name)) continue } } if owner, okOwner := row["OWNER"]; okOwner { if name, okName := row["TABLE_NAME"]; okName { tables = append(tables, fmt.Sprintf("%v.%v", owner, name)) continue } } // iterate keys to find likely column for k, v := range row { if strings.Contains(strings.ToLower(k), "name") || strings.Contains(strings.ToLower(k), "table") { tables = append(tables, fmt.Sprintf("%v", v)) break } } } return tables, nil } func (c *CustomDB) GetCreateStatement(dbName, tableName string) (string, error) { return "Not supported for custom connections yet", nil } func (c *CustomDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { // ANSI Standard // SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_name = '...' schema := "public" if dbName != "" { schema = dbName } query := fmt.Sprintf(`SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_name = '%s'`, tableName) // Adjust for schema if likely supported if c.driver == "postgres" || c.driver == "kingbase" { query += fmt.Sprintf(" AND table_schema = '%s'", schema) } else if c.driver == "mysql" { query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`", tableName) if dbName != "" { query = fmt.Sprintf("SHOW FULL COLUMNS FROM `%s`.`%s`", dbName, tableName) } } data, _, err := c.Query(query) if err != nil { return nil, err } var columns []connection.ColumnDefinition for _, row := range data { col := connection.ColumnDefinition{} // flexible mapping for k, v := range row { kl := strings.ToLower(k) val := fmt.Sprintf("%v", v) if strings.Contains(kl, "field") || strings.Contains(kl, "column_name") { col.Name = val } else if strings.Contains(kl, "type") { col.Type = val } else if strings.Contains(kl, "null") || strings.Contains(kl, "nullable") { col.Nullable = val } else if strings.Contains(kl, "default") { col.Default = &val } else if strings.Contains(kl, "key") { col.Key = val } else if strings.Contains(kl, "comment") { col.Comment = val } } columns = append(columns, col) } return columns, nil } func (c *CustomDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { return nil, fmt.Errorf("not implemented for custom") } func (c *CustomDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { return nil, fmt.Errorf("not implemented for custom") } func (c *CustomDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { return nil, fmt.Errorf("not implemented for custom") } func (c *CustomDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { if c.conn == nil { return fmt.Errorf("连接未打开") } tx, err := c.conn.Begin() if err != nil { return err } defer tx.Rollback() driver := strings.ToLower(strings.TrimSpace(c.driver)) isMySQL := strings.Contains(driver, "mysql") isKingbase := strings.Contains(driver, "kingbase") isPostgres := strings.Contains(driver, "postgres") || isKingbase || strings.Contains(driver, "pg") isOracle := strings.Contains(driver, "oracle") || strings.Contains(driver, "ora") || strings.Contains(driver, "dm") || strings.Contains(driver, "dameng") quoteIdent := func(name string) string { n := strings.TrimSpace(name) if isKingbase { return QuoteKingbaseIdentifier(n) } if isMySQL { n = strings.Trim(n, "`") n = strings.ReplaceAll(n, "`", "``") if n == "" { return "``" } return "`" + n + "`" } n = strings.Trim(n, "\"") n = strings.ReplaceAll(n, "\"", "\"\"") if n == "" { return "\"\"" } return `"` + n + `"` } placeholder := func(idx int) string { if isPostgres { return fmt.Sprintf("$%d", idx) } if isOracle { return fmt.Sprintf(":%d", idx) } // MySQL / SQLite / default return "?" } schema := "" table := strings.TrimSpace(tableName) if isKingbase { schema, table = SplitKingbaseQualifiedName(table) } else if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { schema = strings.TrimSpace(parts[0]) table = strings.TrimSpace(parts[1]) } qualifiedTable := "" if schema != "" { qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table)) } else { qualifiedTable = quoteIdent(table) } // 1. Deletes for _, pk := range changes.Deletes { var wheres []string var args []interface{} idx := 0 for k, v := range pk { idx++ wheres = append(wheres, fmt.Sprintf("%s = %s", quoteIdent(k), placeholder(idx))) args = append(args, v) } if len(wheres) == 0 { continue } query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("删除失败:%v", err) } } // 2. Updates for _, update := range changes.Updates { var sets []string var args []interface{} idx := 0 for k, v := range update.Values { idx++ sets = append(sets, fmt.Sprintf("%s = %s", quoteIdent(k), placeholder(idx))) args = append(args, v) } if len(sets) == 0 { continue } var wheres []string for k, v := range update.Keys { idx++ wheres = append(wheres, fmt.Sprintf("%s = %s", quoteIdent(k), placeholder(idx))) args = append(args, v) } if len(wheres) == 0 { return fmt.Errorf("更新操作需要主键条件") } query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("更新失败:%v", err) } } // 3. Inserts for _, row := range changes.Inserts { var cols []string var placeholders []string var args []interface{} idx := 0 for k, v := range row { idx++ cols = append(cols, quoteIdent(k)) placeholders = append(placeholders, placeholder(idx)) args = append(args, v) } if len(cols) == 0 { continue } query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) if _, err := tx.Exec(query, args...); err != nil { return fmt.Errorf("插入失败:%v", err) } } return tx.Commit() } func (c *CustomDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { return nil, fmt.Errorf("not implemented for custom") }