From f2ffeeaf453209cf95f4c1f1f799331f5537878f Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sun, 14 Jun 2026 21:37:02 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(sql-editor):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=AD=98=E5=82=A8=E8=BF=87=E7=A8=8B=E4=B8=8E=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E7=BB=93=E6=9E=9C=E5=86=99=E8=AF=AD=E5=8F=A5=E7=9A=84?= =?UTF-8?q?=E7=BB=93=E6=9E=9C=E8=AF=86=E5=88=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 补齐 SQL 分类逻辑,识别 SQL Server 裸存储过程调用、RETURNING/OUTPUT、SELECT INTO 及消息块场景 - 调整多语句执行与批量写入分支,避免返回行或服务端消息被 Exec 路径吞掉 - 为 PostgreSQL、OpenGauss、Kingbase、HighGo 补充 notice 回传能力并增加回归测试 --- internal/app/methods_db.go | 70 +++++++- internal/app/methods_db_multi_test.go | 228 ++++++++++++++++++++++++++ internal/app/methods_file.go | 6 +- internal/app/sql_sanitize.go | 105 ++++++++++++ internal/app/sql_sanitize_test.go | 49 ++++++ internal/db/highgo_impl.go | 55 ++++++- internal/db/kingbase_impl.go | 55 ++++++- internal/db/notice_query.go | 60 +++++++ internal/db/postgres_impl.go | 55 ++++++- 9 files changed, 668 insertions(+), 15 deletions(-) create mode 100644 internal/db/notice_query.go diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 86b4444..2018134 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -932,15 +932,23 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu if !allReadOnly { allWrite := true containsPLSQLBlock := false + containsQueryFirstWrite := false for _, stmt := range statements { - if strings.TrimSpace(stmt) != "" && !isBatchableWriteSQLStatement(runConfig.Type, stmt) { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + if !isBatchableWriteSQLStatement(runConfig.Type, stmt) { allWrite = false } + if shouldTryQueryResultFirst(runConfig.Type, stmt) { + containsQueryFirstWrite = true + } if isPLSQLBlockStatement(stmt) { containsPLSQLBlock = true } } - if allWrite && !containsPLSQLBlock && len(statements) == 1 { + if allWrite && !containsPLSQLBlock && !containsQueryFirstWrite && len(statements) == 1 { batcher := sessionBatchTarget if batcher == nil { if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok { @@ -1125,8 +1133,8 @@ func shouldUseNativeMultiResultBatch(dbType string, statements []string, allRead } func shouldTryQueryResultFirst(dbType string, query string) bool { - isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver") - if keyword, withHasWrite := sqlDataOperationInfo(query); withHasWrite && keyword == "select" { + isSQLServer := isSQLServerDBType(dbType) + if sqlWriteStatementReturnsRows(dbType, query) { return true } keyword := leadingSQLKeyword(query) @@ -1137,14 +1145,66 @@ func shouldTryQueryResultFirst(dbType string, query string) bool { return isSQLServer case "dbcc": return isSQLServer + case "do": + return isPostgresNoticeCapableDBType(dbType) && strings.Contains(strings.ToLower(query), "raise") default: if isSQLServer { - return strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_") + if strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_") { + return true + } + if sqlServerControlFlowMayReturnMessages(query) { + return true + } + return looksLikeSQLServerProcedureInvocation(query) } return false } } +func looksLikeSQLServerProcedureInvocation(query string) bool { + switch leadingSQLKeyword(query) { + case "select", "with", "insert", "update", "delete", "merge", "replace", "upsert", + "if", "begin", "declare", "while", "create", "alter", "drop", "truncate", "grant", "revoke", + "use", "set", "print", "dbcc", "commit", "rollback", "save", "return", "throw", "raiserror", + "waitfor", "open", "fetch", "close", "deallocate": + return false + } + + pos := skipSQLTrivia(query, 0) + if pos >= len(query) { + return false + } + + next, ok := skipSQLIdentifierToken(query, pos) + if !ok || next <= pos { + return false + } + pos = skipSQLTrivia(query, next) + for pos < len(query) && query[pos] == '.' { + pos = skipSQLTrivia(query, pos+1) + next, ok = skipSQLIdentifierToken(query, pos) + if !ok || next <= pos { + return false + } + pos = skipSQLTrivia(query, next) + } + + if pos >= len(query) { + return true + } + switch ch := query[pos]; { + case ch == ';' || ch == ',' || ch == '@' || ch == '\'' || ch == '"' || ch == '[' || ch == '(': + return true + case ch == '+' || ch == '-': + return true + case ch >= '0' && ch <= '9': + return true + default: + keyword, _ := nextSQLKeyword(query, pos) + return keyword != "" + } +} + func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult { runConfig := normalizeRunConfig(config, dbName) diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index db6261a..7bc13a5 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -1446,3 +1446,231 @@ func TestDBQueryMultiKeepsAllResultSetsFromSingleSQLServerStatement(t *testing.T t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls) } } + +func TestDBQueryMultiTreatsBareSQLServerProcedureCallAsQueryFirst(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := `p_get_select c_dyscript,'projectid = 1',1` + fakeDB := &fakeBatchWriteDB{ + messageMap: map[string][]string{ + query: {`INSERT c_dyscript(id,name) values (1,"demo")`}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryMulti(config, "master", query, "sqlserver-bare-proc-query-first-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.session == nil { + t.Fatal("expected bare SQL Server procedure call to use a pinned query session") + } + if fakeDB.session.queryCalls != 1 { + t.Fatalf("expected one session query call, got %d", fakeDB.session.queryCalls) + } + if fakeDB.session.execCalls != 0 { + t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 1 { + t.Fatalf("expected one result set, got %#v", resultSets) + } + if len(resultSets[0].Rows) != 0 || len(resultSets[0].Columns) != 0 { + t.Fatalf("expected message-only result set, got %#v", resultSets[0]) + } + if len(resultSets[0].Messages) != 1 || !strings.Contains(resultSets[0].Messages[0], "INSERT c_dyscript") { + t.Fatalf("expected procedure output message to be preserved, got %#v", resultSets[0].Messages) + } +} + +func TestDBQueryMultiTreatsReturningWriteAsQueryFirst(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "INSERT INTO audit_logs(id) VALUES (1) RETURNING id" + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: { + {"id": 1}, + }, + }, + fieldMap: map[string][]string{ + query: {"id"}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"} + + result := app.DBQueryMulti(config, "main", query, "postgres-returning-query-first-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.batchCalls != 0 { + t.Fatalf("expected RETURNING write to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls) + } + if fakeDB.session == nil || fakeDB.session.queryCalls != 1 { + t.Fatalf("expected RETURNING write to query through pinned session, got session=%#v", fakeDB.session) + } + if fakeDB.session.execCalls != 0 { + t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 || resultSets[0].Rows[0]["id"] != 1 { + t.Fatalf("expected RETURNING rows to be preserved, got %#v", resultSets) + } +} + +func TestDBQueryMultiTreatsSQLServerOutputWriteAsQueryFirst(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "UPDATE users SET name = 'next' OUTPUT inserted.id WHERE id = 1" + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: { + {"id": 1}, + }, + }, + fieldMap: map[string][]string{ + query: {"id"}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryMulti(config, "master", query, "sqlserver-output-query-first-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.batchCalls != 0 { + t.Fatalf("expected OUTPUT write to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls) + } + if fakeDB.session == nil || fakeDB.session.queryCalls != 1 { + t.Fatalf("expected OUTPUT write to query through pinned session, got session=%#v", fakeDB.session) + } + if fakeDB.session.execCalls != 0 { + t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 || resultSets[0].Rows[0]["id"] != 1 { + t.Fatalf("expected OUTPUT rows to be preserved, got %#v", resultSets) + } +} + +func TestDBQueryMultiTreatsWrappedMessageBlocksAsQueryFirst(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + sqlServerQuery := "IF 1 = 1 PRINT 'done'" + postgresQuery := "DO $$ BEGIN RAISE NOTICE 'done'; END $$" + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + sqlServerQuery: {}, + postgresQuery: {}, + }, + fieldMap: map[string][]string{ + sqlServerQuery: {}, + postgresQuery: {}, + }, + messageMap: map[string][]string{ + sqlServerQuery: {"done"}, + postgresQuery: {"done"}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + + sqlServerResult := app.DBQueryMulti(connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}, "master", sqlServerQuery, "sqlserver-print-block-test") + if !sqlServerResult.Success { + t.Fatalf("expected SQL Server block success, got failure: %s", sqlServerResult.Message) + } + + postgresResult := app.DBQueryMulti(connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}, "main", postgresQuery, "postgres-notice-block-test") + if !postgresResult.Success { + t.Fatalf("expected PostgreSQL notice block success, got failure: %s", postgresResult.Message) + } + if fakeDB.batchCalls != 0 { + t.Fatalf("expected message blocks to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls) + } + if fakeDB.execCalls != 0 { + t.Fatalf("expected message blocks to avoid shared exec path, got execCalls=%d", fakeDB.execCalls) + } +} + +func TestDBQueryMultiTransactionalTreatsSelectIntoAsManagedWrite(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + stmt := "SELECT * INTO archived_users FROM users" + fakeDB := &fakeBatchWriteDB{ + execAffected: map[string]int64{ + stmt: 12, + }, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"} + + result := app.DBQueryMultiTransactional(config, "main", stmt, "select-into-managed-tx-test") + if !result.Success { + t.Fatalf("expected managed SELECT INTO success, got failure: %s", result.Message) + } + if result.TransactionID == "" || !result.TransactionPending { + t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending) + } + if fakeDB.session == nil || fakeDB.session.closed { + t.Fatal("expected managed SELECT INTO transaction session to stay open") + } + wantExecs := []string{"BEGIN", stmt} + if len(fakeDB.execQueries) != len(wantExecs) { + t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries) + } + for i, want := range wantExecs { + if fakeDB.execQueries[i] != want { + t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i]) + } + } +} diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 8773c09..4fc0d3e 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -806,12 +806,10 @@ func isSQLFileBatchableWriteStatement(dbType string, stmt string) bool { if isPLSQLBlockStatement(stmt) { return false } - switch leadingSQLKeyword(stmt) { - case "insert", "update", "delete", "replace", "merge", "upsert": - return true - default: + if shouldTryQueryResultFirst(dbType, stmt) { return false } + return isBatchableWriteSQLStatement(dbType, stmt) } func sqlFileBatchTransactionSQL(dbType string) (beginSQL string, commitSQL string, rollbackSQL string, ok bool) { diff --git a/internal/app/sql_sanitize.go b/internal/app/sql_sanitize.go index 78c7d3d..b2dfe5d 100644 --- a/internal/app/sql_sanitize.go +++ b/internal/app/sql_sanitize.go @@ -336,6 +336,105 @@ func isSQLKeywordByte(ch byte) bool { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' } +func normalizeSQLClassifierDBType(dbType string) string { + normalized := strings.ToLower(strings.TrimSpace(dbType)) + switch normalized { + case "postgresql": + return "postgres" + case "mssql", "sql_server", "sql-server": + return "sqlserver" + case "open_gauss", "open-gauss": + return "opengauss" + case "gauss_db", "gauss-db": + return "gaussdb" + case "kingbase8", "kingbasees", "kingbasev8": + return "kingbase" + default: + return normalized + } +} + +func isSQLServerDBType(dbType string) bool { + return normalizeSQLClassifierDBType(dbType) == "sqlserver" +} + +func isOracleLikeDBType(dbType string) bool { + switch normalizeSQLClassifierDBType(dbType) { + case "oracle", "dameng": + return true + default: + return false + } +} + +func isPostgresNoticeCapableDBType(dbType string) bool { + switch normalizeSQLClassifierDBType(dbType) { + case "postgres", "opengauss": + return true + default: + return false + } +} + +func isSQLSelectIntoStatement(query string) bool { + keyword, _ := sqlDataOperationInfo(query) + return keyword == "select" && findTopLevelSQLKeyword(query, 0, "into") >= 0 +} + +func sqlContainsKeyword(text string, want string) bool { + for pos := 0; pos < len(text); { + if next, ok := skipSQLQuotedOrComment(text, pos); ok { + pos = next + continue + } + if isSQLKeywordByte(text[pos]) { + end := pos + 1 + for end < len(text) && isSQLKeywordByte(text[end]) { + end++ + } + if strings.EqualFold(text[pos:end], want) { + return true + } + pos = end + continue + } + pos++ + } + return false +} + +func sqlWriteStatementReturnsRows(dbType string, query string) bool { + keyword, withHasWrite := sqlDataOperationInfo(query) + if withHasWrite && keyword == "select" { + return true + } + if !isSQLDataWriteKeyword(keyword) { + return false + } + if !isOracleLikeDBType(dbType) && findTopLevelSQLKeyword(query, 0, "returning") >= 0 { + return true + } + if isSQLServerDBType(dbType) && findTopLevelSQLKeyword(query, 0, "output") >= 0 { + return true + } + return false +} + +func sqlServerControlFlowMayReturnMessages(query string) bool { + switch leadingSQLKeyword(query) { + case "if", "begin", "while", "declare", "try", "catch": + return sqlContainsKeyword(query, "exec") || + sqlContainsKeyword(query, "execute") || + sqlContainsKeyword(query, "print") || + sqlContainsKeyword(query, "raiserror") || + sqlContainsKeyword(query, "throw") || + sqlContainsKeyword(query, "dbcc") || + sqlContainsKeyword(query, "set") + default: + return false + } +} + func isReadOnlySQLQuery(dbType string, query string) bool { if strings.ToLower(strings.TrimSpace(dbType)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") { return true @@ -345,6 +444,9 @@ func isReadOnlySQLQuery(dbType string, query string) bool { if withHasWrite { return false } + if keyword == "select" && isSQLSelectIntoStatement(query) { + return false + } switch keyword { case "select", "with", "show", "describe", "desc", "explain", "pragma", "values", "consume": return true @@ -362,6 +464,9 @@ func isBatchableWriteSQLStatement(dbType string, query string) bool { if withHasWrite { return true } + if keyword == "select" && isSQLSelectIntoStatement(query) { + return true + } return isSQLDataWriteKeyword(keyword) } diff --git a/internal/app/sql_sanitize_test.go b/internal/app/sql_sanitize_test.go index 248e205..870f08a 100644 --- a/internal/app/sql_sanitize_test.go +++ b/internal/app/sql_sanitize_test.go @@ -86,6 +86,13 @@ func TestIsReadOnlySQLQuery_ClassifiesWithByTopLevelOperation(t *testing.T) { } } +func TestIsReadOnlySQLQuery_TreatsSelectIntoAsWrite(t *testing.T) { + query := "SELECT * INTO archived_users FROM users" + if isReadOnlySQLQuery("postgres", query) { + t.Fatal("SELECT INTO should not be treated as read-only") + } +} + func TestIsReadOnlySQLQuery_TreatsKafkaConsumeAsReadOnly(t *testing.T) { if !isReadOnlySQLQuery("kafka", `CONSUME GROUP "analytics" FROM "orders.events" LIMIT 20`) { t.Fatal("Kafka CONSUME should be treated as read-only") @@ -102,6 +109,9 @@ func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing. if !isBatchableWriteSQLStatement("postgres", "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved") { t.Fatal("expected data-changing CTE to be treated as batchable write") } + if !isBatchableWriteSQLStatement("postgres", "SELECT * INTO archived_users FROM users") { + t.Fatal("expected SELECT INTO to be treated as batchable write") + } if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") { t.Fatal("EXEC should not be treated as batchable write") } @@ -131,6 +141,45 @@ func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t * } } +func TestShouldTryQueryResultFirst_TreatsSQLServerBareProcedureCallsAsQueryFirst(t *testing.T) { + if !shouldTryQueryResultFirst("sqlserver", `p_get_select c_dyscript,'projectid = 1',1`) { + t.Fatal("expected bare SQL Server procedure call to try query-first") + } + if !shouldTryQueryResultFirst("sqlserver", `dbo.p_get_select c_dyscript,'projectid = 1',1`) { + t.Fatal("expected schema-qualified SQL Server procedure call to try query-first") + } + if !shouldTryQueryResultFirst("sqlserver", `[dbo].[p_get_select] c_dyscript,'projectid = 1',1`) { + t.Fatal("expected bracket-qualified SQL Server procedure call to try query-first") + } +} + +func TestShouldTryQueryResultFirst_TreatsReturningAndOutputWritesAsQueryFirst(t *testing.T) { + if !shouldTryQueryResultFirst("postgres", "INSERT INTO audit_logs(id) VALUES (1) RETURNING id") { + t.Fatal("expected INSERT ... RETURNING to try query-first") + } + if !shouldTryQueryResultFirst("sqlserver", "UPDATE users SET name = 'next' OUTPUT inserted.id WHERE id = 1") { + t.Fatal("expected SQL Server OUTPUT DML to try query-first") + } +} + +func TestShouldTryQueryResultFirst_TreatsWrappedMessageBlocksAsQueryFirst(t *testing.T) { + if !shouldTryQueryResultFirst("sqlserver", "IF 1 = 1 EXEC dbo.p_get_select @name = 'demo'") { + t.Fatal("expected control-flow wrapped SQL Server procedure call to try query-first") + } + if !shouldTryQueryResultFirst("sqlserver", "BEGIN PRINT 'done'; END") { + t.Fatal("expected SQL Server BEGIN/PRINT block to try query-first") + } + if !shouldTryQueryResultFirst("postgres", "DO $$ BEGIN RAISE NOTICE 'done'; END $$") { + t.Fatal("expected PostgreSQL DO/RAISE NOTICE block to try query-first") + } +} + +func TestShouldTryQueryResultFirst_DoesNotMisclassifyPlainSQLServerDML(t *testing.T) { + if shouldTryQueryResultFirst("sqlserver", "UPDATE users SET name = 'next' WHERE id = 1") { + t.Fatal("plain SQL Server UPDATE should not try query-first") + } +} + func TestShouldTryQueryResultFirst_TreatsDataChangingCTESelectAsQueryFirst(t *testing.T) { query := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved" if !shouldTryQueryResultFirst("postgres", query) { diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go index bfe68b9..19fb06c 100644 --- a/internal/db/highgo_impl.go +++ b/internal/db/highgo_impl.go @@ -5,6 +5,7 @@ package db import ( "context" "database/sql" + "database/sql/driver" "fmt" "net" "net/url" @@ -17,7 +18,7 @@ import ( "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" - _ "github.com/highgo/pq-sm3" // HighGo uses dedicated SM3-capable driver + highgopq "github.com/highgo/pq-sm3" // HighGo uses dedicated SM3-capable driver ) // HighGoDB implements Database interface for HighGo (瀚高) database @@ -28,6 +29,13 @@ type HighGoDB struct { forwarder *ssh.LocalForwarder } +type highgoSessionExecer struct { + *sqlConnStatementExecer +} + +var _ QueryMessageExecer = (*HighGoDB)(nil) +var _ StatementQueryMessageExecer = (*highgoSessionExecer)(nil) + func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string { // postgres://user:password@host:port/dbname?sslmode=disable dbname := config.Database @@ -152,6 +160,20 @@ func (h *HighGoDB) QueryContext(ctx context.Context, query string) ([]map[string return scanRows(rows) } +func (h *HighGoDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + if h.conn == nil { + return nil, nil, nil, fmt.Errorf("连接未打开") + } + + conn, err := h.conn.Conn(ctx) + if err != nil { + return nil, nil, nil, err + } + defer conn.Close() + + return queryHighGoConnWithMessages(ctx, conn, query) +} + func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, error) { if h.conn == nil { return nil, nil, fmt.Errorf("连接未打开") @@ -165,6 +187,10 @@ func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, erro return scanRows(rows) } +func (h *HighGoDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return h.QueryContextWithMessages(context.Background(), query) +} + func (h *HighGoDB) ExecContext(ctx context.Context, query string) (int64, error) { if h.conn == nil { return 0, fmt.Errorf("连接未打开") @@ -195,7 +221,7 @@ func (h *HighGoDB) OpenSessionExecer(ctx context.Context) (StatementExecer, erro if err != nil { return nil, err } - return NewSQLConnStatementExecer(conn), nil + return &highgoSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil } func (h *HighGoDB) Exec(query string) (int64, error) { @@ -209,6 +235,31 @@ func (h *HighGoDB) Exec(query string) (int64, error) { return res.RowsAffected() } +func (e *highgoSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return e.QueryContextWithMessages(context.Background(), query) +} + +func (e *highgoSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + if e == nil || e.conn == nil { + return nil, nil, nil, fmt.Errorf("连接未打开") + } + return queryHighGoConnWithMessages(ctx, e.conn, query) +} + +func queryHighGoConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) { + return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) { + if addNotice == nil { + highgopq.SetNoticeHandler(driverConn, nil) + return + } + highgopq.SetNoticeHandler(driverConn, func(notice *highgopq.Error) { + if notice != nil { + addNotice(notice.Message) + } + }) + }) +} + func (h *HighGoDB) GetDatabases() ([]string, error) { data, _, err := h.Query("SELECT datname FROM pg_database WHERE datistemplate = false") if err != nil { diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 161a222..1daa776 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -5,6 +5,7 @@ package db import ( "context" "database/sql" + "database/sql/driver" "fmt" "net" "net/url" @@ -18,7 +19,7 @@ import ( "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" - _ "gitea.com/kingbase/gokb" // Registers "kingbase" driver + gokb "gitea.com/kingbase/gokb" // Registers "kingbase" driver ) type KingbaseDB struct { @@ -27,6 +28,13 @@ type KingbaseDB struct { forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } +type kingbaseSessionExecer struct { + *sqlConnStatementExecer +} + +var _ QueryMessageExecer = (*KingbaseDB)(nil) +var _ StatementQueryMessageExecer = (*kingbaseSessionExecer)(nil) + func quoteConnValue(v string) string { if v == "" { return "''" @@ -280,6 +288,20 @@ func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[stri return scanRows(rows) } +func (k *KingbaseDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + if k.conn == nil { + return nil, nil, nil, fmt.Errorf("连接未打开") + } + + conn, err := k.conn.Conn(ctx) + if err != nil { + return nil, nil, nil, err + } + defer conn.Close() + + return queryKingbaseConnWithMessages(ctx, conn, query) +} + func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) { if k.conn == nil { return nil, nil, fmt.Errorf("连接未打开") @@ -293,6 +315,10 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er return scanRows(rows) } +func (k *KingbaseDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return k.QueryContextWithMessages(context.Background(), query) +} + func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) { if k.conn == nil { return 0, fmt.Errorf("连接未打开") @@ -323,7 +349,7 @@ func (k *KingbaseDB) OpenSessionExecer(ctx context.Context) (StatementExecer, er if err != nil { return nil, err } - return NewSQLConnStatementExecer(conn), nil + return &kingbaseSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil } func (k *KingbaseDB) Exec(query string) (int64, error) { @@ -337,6 +363,31 @@ func (k *KingbaseDB) Exec(query string) (int64, error) { return res.RowsAffected() } +func (e *kingbaseSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return e.QueryContextWithMessages(context.Background(), query) +} + +func (e *kingbaseSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + if e == nil || e.conn == nil { + return nil, nil, nil, fmt.Errorf("连接未打开") + } + return queryKingbaseConnWithMessages(ctx, e.conn, query) +} + +func queryKingbaseConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) { + return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) { + if addNotice == nil { + gokb.SetNoticeHandler(driverConn, nil) + return + } + gokb.SetNoticeHandler(driverConn, func(notice *gokb.Error) { + if notice != nil { + addNotice(notice.Message) + } + }) + }) +} + func (k *KingbaseDB) GetDatabases() ([]string, error) { data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false") if err == nil { diff --git a/internal/db/notice_query.go b/internal/db/notice_query.go new file mode 100644 index 0000000..57f1d97 --- /dev/null +++ b/internal/db/notice_query.go @@ -0,0 +1,60 @@ +package db + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "strings" +) + +type sqlTextNoticeHandlerSetter func(driver.Conn, func(string)) + +func querySQLConnWithTextNotices(ctx context.Context, conn *sql.Conn, query string, setHandler sqlTextNoticeHandlerSetter) ([]map[string]interface{}, []string, []string, error) { + if conn == nil { + return nil, nil, nil, fmt.Errorf("连接未打开") + } + if setHandler == nil { + return nil, nil, nil, fmt.Errorf("未配置消息捕获处理器") + } + if ctx == nil { + ctx = context.Background() + } + + notices := make([]string, 0, 2) + addNotice := func(text string) { + text = strings.TrimSpace(text) + if text != "" { + notices = append(notices, text) + } + } + + if err := conn.Raw(func(rawConn interface{}) error { + driverConn, ok := rawConn.(driver.Conn) + if !ok { + return fmt.Errorf("底层连接类型不支持消息捕获") + } + setHandler(driverConn, addNotice) + return nil + }); err != nil { + return nil, nil, nil, err + } + defer func() { + _ = conn.Raw(func(rawConn interface{}) error { + driverConn, ok := rawConn.(driver.Conn) + if ok { + setHandler(driverConn, nil) + } + return nil + }) + }() + + rows, err := conn.QueryContext(ctx, query) + if err != nil { + return nil, nil, append([]string(nil), notices...), err + } + defer rows.Close() + + data, columns, err := scanRows(rows) + return data, columns, append([]string(nil), notices...), err +} diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index 7753203..30f5fac 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -3,6 +3,7 @@ package db import ( "context" "database/sql" + "database/sql/driver" "fmt" "net" "net/url" @@ -15,7 +16,7 @@ import ( "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" - _ "github.com/lib/pq" + "github.com/lib/pq" ) type PostgresDB struct { @@ -24,6 +25,13 @@ type PostgresDB struct { forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } +type postgresSessionExecer struct { + *sqlConnStatementExecer +} + +var _ QueryMessageExecer = (*PostgresDB)(nil) +var _ StatementQueryMessageExecer = (*postgresSessionExecer)(nil) + func resolvePostgresConnectDatabases(config connection.ConnectionConfig) []string { explicit := strings.TrimSpace(config.Database) if explicit != "" { @@ -225,6 +233,20 @@ func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[stri return scanRows(rows) } +func (p *PostgresDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + if p.conn == nil { + return nil, nil, nil, fmt.Errorf("连接未打开") + } + + conn, err := p.conn.Conn(ctx) + if err != nil { + return nil, nil, nil, err + } + defer conn.Close() + + return queryPostgresConnWithMessages(ctx, conn, query) +} + func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) { if p.conn == nil { return nil, nil, fmt.Errorf("连接未打开") @@ -238,6 +260,10 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er return scanRows(rows) } +func (p *PostgresDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return p.QueryContextWithMessages(context.Background(), query) +} + func (p *PostgresDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { if p.conn == nil { return 0, fmt.Errorf("连接未打开") @@ -257,7 +283,7 @@ func (p *PostgresDB) OpenSessionExecer(ctx context.Context) (StatementExecer, er if err != nil { return nil, err } - return NewSQLConnStatementExecer(conn), nil + return &postgresSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil } func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) { @@ -282,6 +308,31 @@ func (p *PostgresDB) Exec(query string) (int64, error) { return res.RowsAffected() } +func (e *postgresSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return e.QueryContextWithMessages(context.Background(), query) +} + +func (e *postgresSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + if e == nil || e.conn == nil { + return nil, nil, nil, fmt.Errorf("连接未打开") + } + return queryPostgresConnWithMessages(ctx, e.conn, query) +} + +func queryPostgresConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) { + return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) { + if addNotice == nil { + pq.SetNoticeHandler(driverConn, nil) + return + } + pq.SetNoticeHandler(driverConn, func(notice *pq.Error) { + if notice != nil { + addNotice(notice.Message) + } + }) + }) +} + func (p *PostgresDB) GetDatabases() ([]string, error) { data, _, err := p.Query("SELECT datname FROM pg_database WHERE datistemplate = false") if err != nil {