Files
MyGoNavi/internal/db/duckdb_impl.go
Syngnat 2254b76232 🐛 fix(duckdb): 修复无主键结果无法安全编辑
- 为 DuckDB 查询结果和表预览补充隐藏 rowid 定位列,允许无主键表安全提交修改
- DataGrid 提交变更时仅将 rowid 用作定位条件,避免把隐藏定位列写回业务字段
- DuckDB ApplyChanges 对 duckdb-rowid 改用未加引号的 rowid 条件,修复更新和删除失效
- 补充前后端回归测试,覆盖 QueryEditor、DataViewer、rowLocator 与 ApplyChanges 链路
2026-06-05 14:05:18 +08:00

491 lines
14 KiB
Go

//go:build gonavi_full_drivers || gonavi_duckdb_driver
package db
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/utils"
)
type DuckDB struct {
conn *sql.DB
pingTimeout time.Duration
}
func (d *DuckDB) Connect(config connection.ConnectionConfig) error {
if supported, reason := duckDBBuildSupportStatus(); !supported {
return fmt.Errorf("DuckDB 驱动不可用:%s", reason)
}
dsn := strings.TrimSpace(config.Host)
if dsn == "" {
dsn = strings.TrimSpace(config.Database)
}
if dsn == "" {
dsn = ":memory:"
}
db, err := sql.Open("duckdb", dsn)
if err != nil {
return fmt.Errorf("打开数据库连接失败:%w", err)
}
d.conn = db
d.pingTimeout = getConnectTimeout(config)
if err := d.Ping(); err != nil {
_ = db.Close()
d.conn = nil
return fmt.Errorf("连接建立后验证失败:%w", err)
}
return nil
}
func (d *DuckDB) Close() error {
if d.conn != nil {
return d.conn.Close()
}
return nil
}
func (d *DuckDB) Ping() error {
if d.conn == nil {
return fmt.Errorf("连接未打开")
}
timeout := d.pingTimeout
if timeout <= 0 {
timeout = 5 * time.Second
}
ctx, cancel := utils.ContextWithTimeout(timeout)
defer cancel()
return d.conn.PingContext(ctx)
}
func (d *DuckDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if d.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := d.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (d *DuckDB) Query(query string) ([]map[string]interface{}, []string, error) {
if d.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := d.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (d *DuckDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
if d.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := d.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (d *DuckDB) OpenSessionExecer(ctx context.Context) (StatementExecer, error) {
if d.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
conn, err := d.conn.Conn(ctx)
if err != nil {
return nil, err
}
return NewSQLConnStatementExecer(conn), nil
}
func (d *DuckDB) ExecContext(ctx context.Context, query string) (int64, error) {
if d.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := d.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (d *DuckDB) Exec(query string) (int64, error) {
if d.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := d.conn.Exec(query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (d *DuckDB) GetDatabases() ([]string, error) {
data, _, err := d.Query("PRAGMA database_list")
if err != nil {
return []string{"main"}, nil
}
seen := map[string]struct{}{}
var names []string
for _, row := range data {
name := strings.TrimSpace(duckDBRowString(row, "name", "database_name", "database"))
if name == "" {
continue
}
if _, exists := seen[name]; exists {
continue
}
seen[name] = struct{}{}
names = append(names, name)
}
if len(names) == 0 {
return []string{"main"}, nil
}
return names, nil
}
func (d *DuckDB) GetTables(dbName string) ([]string, error) {
path := normalizeDuckDBObjectPath(dbName, "")
query := `
SELECT table_catalog, table_schema, table_name
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema NOT IN ('information_schema', 'pg_catalog')
ORDER BY table_catalog, table_schema, table_name`
if path.Catalog != "" {
query = fmt.Sprintf(`
SELECT table_catalog, table_schema, table_name
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema NOT IN ('information_schema', 'pg_catalog')
AND table_catalog = '%s'
ORDER BY table_catalog, table_schema, table_name`, escapeDuckDBLiteral(path.Catalog))
}
data, _, err := d.Query(query)
if err != nil {
return nil, err
}
seen := map[string]struct{}{}
var tables []string
for _, row := range data {
catalog := strings.TrimSpace(duckDBRowString(row, "table_catalog", "database_name"))
schema := strings.TrimSpace(duckDBRowString(row, "table_schema"))
name := strings.TrimSpace(duckDBRowString(row, "table_name"))
if name == "" {
continue
}
qualified := name
if schema != "" {
qualified = schema + "." + name
}
if catalog != "" && !strings.EqualFold(catalog, "memory") && !strings.EqualFold(catalog, "main") {
qualified = catalog + "." + qualified
}
if _, exists := seen[qualified]; exists {
continue
}
seen[qualified] = struct{}{}
tables = append(tables, qualified)
}
return tables, nil
}
func (d *DuckDB) GetCreateStatement(dbName, tableName string) (string, error) {
path := normalizeDuckDBObjectPath(dbName, tableName)
if path.Object == "" {
return "", fmt.Errorf("表名不能为空")
}
escapedTable := escapeDuckDBLiteral(path.Object)
escapedSchema := escapeDuckDBLiteral(path.Schema)
escapedCatalog := escapeDuckDBLiteral(path.Catalog)
queryCandidates := make([]string, 0, 4)
if path.Catalog != "" {
queryCandidates = append(queryCandidates, fmt.Sprintf("SELECT sql FROM duckdb_tables() WHERE table_name = '%s' AND schema_name = '%s' AND database_name = '%s' LIMIT 1", escapedTable, escapedSchema, escapedCatalog))
}
queryCandidates = append(queryCandidates,
fmt.Sprintf("SELECT sql FROM duckdb_tables() WHERE table_name = '%s' AND schema_name = '%s' LIMIT 1", escapedTable, escapedSchema),
fmt.Sprintf("SELECT sql FROM duckdb_tables() WHERE table_name = '%s' LIMIT 1", escapedTable),
fmt.Sprintf("SHOW CREATE TABLE %s", quoteDuckDBQualifiedTable(path.Schema, path.Object)),
)
if path.Catalog != "" {
queryCandidates = append([]string{
fmt.Sprintf("SHOW CREATE TABLE %s.%s", quoteDuckDBIdentifier(path.Catalog), quoteDuckDBQualifiedTable(path.Schema, path.Object)),
}, queryCandidates...)
}
for _, query := range queryCandidates {
data, _, err := d.Query(query)
if err != nil || len(data) == 0 {
continue
}
createSQL := strings.TrimSpace(duckDBRowString(data[0], "sql", "create_table", "Create Table", "create_statement"))
if createSQL != "" {
return createSQL, nil
}
for _, value := range data[0] {
text := strings.TrimSpace(fmt.Sprintf("%v", value))
if text != "" && text != "<nil>" {
return text, nil
}
}
}
return "", fmt.Errorf("未找到建表语句")
}
func (d *DuckDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
path := normalizeDuckDBObjectPath(dbName, tableName)
if path.Object == "" {
return nil, fmt.Errorf("表名不能为空")
}
query := fmt.Sprintf(`
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = '%s' AND table_schema = '%s'
ORDER BY ordinal_position`, escapeDuckDBLiteral(path.Object), escapeDuckDBLiteral(path.Schema))
if path.Catalog != "" {
query = fmt.Sprintf(`
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = '%s' AND table_schema = '%s' AND table_catalog = '%s'
ORDER BY ordinal_position`, escapeDuckDBLiteral(path.Object), escapeDuckDBLiteral(path.Schema), escapeDuckDBLiteral(path.Catalog))
}
data, _, err := d.Query(query)
if err != nil {
return nil, err
}
if len(data) == 0 && path.Schema != "main" {
fallbackQuery := fmt.Sprintf(`
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = '%s'
ORDER BY ordinal_position`, escapeDuckDBLiteral(path.Object))
data, _, err = d.Query(fallbackQuery)
if err != nil {
return nil, err
}
}
constraintQuery := buildDuckDBConstraintMetadataQuery(path, true)
constraintRows, _, constraintErr := d.Query(constraintQuery)
if constraintErr != nil {
return nil, constraintErr
}
if len(constraintRows) == 0 && path.Schema != "main" {
fallbackConstraintQuery := buildDuckDBConstraintMetadataQuery(path, false)
constraintRows, _, constraintErr = d.Query(fallbackConstraintQuery)
if constraintErr != nil {
return nil, constraintErr
}
}
return buildDuckDBColumnDefinitions(data, constraintRows), nil
}
func (d *DuckDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
path := normalizeDuckDBObjectPath(dbName, "")
query := `
SELECT table_catalog, table_schema, table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
ORDER BY table_catalog, table_schema, table_name, ordinal_position`
if path.Catalog != "" {
query = fmt.Sprintf(`
SELECT table_catalog, table_schema, table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
AND table_catalog = '%s'
ORDER BY table_catalog, table_schema, table_name, ordinal_position`, escapeDuckDBLiteral(path.Catalog))
}
data, _, err := d.Query(query)
if err != nil {
return nil, err
}
columns := make([]connection.ColumnDefinitionWithTable, 0, len(data))
for _, row := range data {
catalog := strings.TrimSpace(duckDBRowString(row, "table_catalog", "database_name"))
schema := strings.TrimSpace(duckDBRowString(row, "table_schema"))
tableName := strings.TrimSpace(duckDBRowString(row, "table_name"))
if tableName == "" {
continue
}
if schema != "" {
tableName = schema + "." + tableName
}
if catalog != "" && !strings.EqualFold(catalog, "memory") && !strings.EqualFold(catalog, "main") {
tableName = catalog + "." + tableName
}
columns = append(columns, connection.ColumnDefinitionWithTable{
TableName: tableName,
Name: duckDBRowString(row, "column_name"),
Type: duckDBRowString(row, "data_type"),
})
}
return columns, nil
}
func (d *DuckDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
path := normalizeDuckDBObjectPath(dbName, tableName)
if path.Object == "" {
return nil, fmt.Errorf("表名不能为空")
}
constraintQuery := buildDuckDBConstraintMetadataQuery(path, true)
constraintRows, _, err := d.Query(constraintQuery)
if err != nil {
return nil, err
}
if len(constraintRows) == 0 && path.Schema != "main" {
fallbackQuery := buildDuckDBConstraintMetadataQuery(path, false)
constraintRows, _, err = d.Query(fallbackQuery)
if err != nil {
return nil, err
}
}
indexQuery := buildDuckDBIndexMetadataQuery(path, true)
indexRows, _, indexErr := d.Query(indexQuery)
if indexErr != nil {
return nil, indexErr
}
if len(indexRows) == 0 && path.Schema != "main" {
fallbackIndexQuery := buildDuckDBIndexMetadataQuery(path, false)
indexRows, _, indexErr = d.Query(fallbackIndexQuery)
if indexErr != nil {
return nil, indexErr
}
}
return buildDuckDBIndexDefinitions(constraintRows, indexRows), nil
}
func (d *DuckDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return []connection.ForeignKeyDefinition{}, nil
}
func (d *DuckDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return []connection.TriggerDefinition{}, nil
}
func (d *DuckDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
if d.conn == nil {
return fmt.Errorf("连接未打开")
}
tx, err := d.conn.Begin()
if err != nil {
return err
}
defer tx.Rollback()
quoteIdent := func(name string) string {
n := strings.TrimSpace(name)
n = strings.Trim(n, "\"")
n = strings.ReplaceAll(n, "\"", "\"\"")
if n == "" {
return "\"\""
}
return `"` + n + `"`
}
path := normalizeDuckDBObjectPath("", tableName)
schema := path.Schema
table := path.Object
qualifiedTable := quoteIdent(table)
if schema != "" {
qualifiedTable = fmt.Sprintf("%s.%s", quoteIdent(schema), quoteIdent(table))
}
isDuckDBRowIDLocator := strings.EqualFold(strings.TrimSpace(changes.LocatorStrategy), "duckdb-rowid")
buildWhere := func(keys map[string]interface{}) ([]string, []interface{}) {
var wheres []string
var args []interface{}
for k, v := range keys {
if isDuckDBRowIDLocator && strings.EqualFold(strings.TrimSpace(k), "rowid") {
wheres = append(wheres, "rowid = ?")
args = append(args, v)
continue
}
wheres = append(wheres, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, v)
}
return wheres, args
}
for _, pk := range changes.Deletes {
wheres, args := buildWhere(pk)
if len(wheres) == 0 {
continue
}
query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("删除失败:%v", err)
}
}
for _, update := range changes.Updates {
var sets []string
var args []interface{}
for k, v := range update.Values {
sets = append(sets, fmt.Sprintf("%s = ?", quoteIdent(k)))
args = append(args, v)
}
if len(sets) == 0 {
continue
}
wheres, whereArgs := buildWhere(update.Keys)
args = append(args, whereArgs...)
if len(wheres) == 0 {
return fmt.Errorf("更新操作需要主键条件")
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND "))
if _, err := tx.Exec(query, args...); err != nil {
return fmt.Errorf("更新失败:%v", err)
}
}
if err := execParameterizedInsertBatches(parameterizedInsertConfig{
Table: qualifiedTable,
Rows: changes.Inserts,
QuoteColumn: quoteIdent,
Placeholder: func(int) string { return "?" },
Exec: func(query string, args ...interface{}) (sql.Result, error) {
return tx.Exec(query, args...)
},
MaxArgs: sqliteBatchInsertArgs,
}); err != nil {
return err
}
return tx.Commit()
}