🐛 fix(sql-editor): 修复存储过程与返回结果写语句的结果识别

- 补齐 SQL 分类逻辑,识别 SQL Server 裸存储过程调用、RETURNING/OUTPUT、SELECT INTO 及消息块场景
- 调整多语句执行与批量写入分支,避免返回行或服务端消息被 Exec 路径吞掉
- 为 PostgreSQL、OpenGauss、Kingbase、HighGo 补充 notice 回传能力并增加回归测试
This commit is contained in:
Syngnat
2026-06-14 21:37:02 +08:00
parent d7632e29a6
commit f2ffeeaf45
9 changed files with 668 additions and 15 deletions

View File

@@ -932,15 +932,23 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
if !allReadOnly {
allWrite := true
containsPLSQLBlock := false
containsQueryFirstWrite := false
for _, stmt := range statements {
if strings.TrimSpace(stmt) != "" && !isBatchableWriteSQLStatement(runConfig.Type, stmt) {
stmt = strings.TrimSpace(stmt)
if stmt == "" {
continue
}
if !isBatchableWriteSQLStatement(runConfig.Type, stmt) {
allWrite = false
}
if shouldTryQueryResultFirst(runConfig.Type, stmt) {
containsQueryFirstWrite = true
}
if isPLSQLBlockStatement(stmt) {
containsPLSQLBlock = true
}
}
if allWrite && !containsPLSQLBlock && len(statements) == 1 {
if allWrite && !containsPLSQLBlock && !containsQueryFirstWrite && len(statements) == 1 {
batcher := sessionBatchTarget
if batcher == nil {
if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok {
@@ -1125,8 +1133,8 @@ func shouldUseNativeMultiResultBatch(dbType string, statements []string, allRead
}
func shouldTryQueryResultFirst(dbType string, query string) bool {
isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver")
if keyword, withHasWrite := sqlDataOperationInfo(query); withHasWrite && keyword == "select" {
isSQLServer := isSQLServerDBType(dbType)
if sqlWriteStatementReturnsRows(dbType, query) {
return true
}
keyword := leadingSQLKeyword(query)
@@ -1137,14 +1145,66 @@ func shouldTryQueryResultFirst(dbType string, query string) bool {
return isSQLServer
case "dbcc":
return isSQLServer
case "do":
return isPostgresNoticeCapableDBType(dbType) && strings.Contains(strings.ToLower(query), "raise")
default:
if isSQLServer {
return strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_")
if strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_") {
return true
}
if sqlServerControlFlowMayReturnMessages(query) {
return true
}
return looksLikeSQLServerProcedureInvocation(query)
}
return false
}
}
func looksLikeSQLServerProcedureInvocation(query string) bool {
switch leadingSQLKeyword(query) {
case "select", "with", "insert", "update", "delete", "merge", "replace", "upsert",
"if", "begin", "declare", "while", "create", "alter", "drop", "truncate", "grant", "revoke",
"use", "set", "print", "dbcc", "commit", "rollback", "save", "return", "throw", "raiserror",
"waitfor", "open", "fetch", "close", "deallocate":
return false
}
pos := skipSQLTrivia(query, 0)
if pos >= len(query) {
return false
}
next, ok := skipSQLIdentifierToken(query, pos)
if !ok || next <= pos {
return false
}
pos = skipSQLTrivia(query, next)
for pos < len(query) && query[pos] == '.' {
pos = skipSQLTrivia(query, pos+1)
next, ok = skipSQLIdentifierToken(query, pos)
if !ok || next <= pos {
return false
}
pos = skipSQLTrivia(query, next)
}
if pos >= len(query) {
return true
}
switch ch := query[pos]; {
case ch == ';' || ch == ',' || ch == '@' || ch == '\'' || ch == '"' || ch == '[' || ch == '(':
return true
case ch == '+' || ch == '-':
return true
case ch >= '0' && ch <= '9':
return true
default:
keyword, _ := nextSQLKeyword(query, pos)
return keyword != ""
}
}
func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)

View File

@@ -1446,3 +1446,231 @@ func TestDBQueryMultiKeepsAllResultSetsFromSingleSQLServerStatement(t *testing.T
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
}
}
func TestDBQueryMultiTreatsBareSQLServerProcedureCallAsQueryFirst(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := `p_get_select c_dyscript,'projectid = 1',1`
fakeDB := &fakeBatchWriteDB{
messageMap: map[string][]string{
query: {`INSERT c_dyscript(id,name) values (1,"demo")`},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sqlserver-bare-proc-query-first-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.session == nil {
t.Fatal("expected bare SQL Server procedure call to use a pinned query session")
}
if fakeDB.session.queryCalls != 1 {
t.Fatalf("expected one session query call, got %d", fakeDB.session.queryCalls)
}
if fakeDB.session.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 {
t.Fatalf("expected one result set, got %#v", resultSets)
}
if len(resultSets[0].Rows) != 0 || len(resultSets[0].Columns) != 0 {
t.Fatalf("expected message-only result set, got %#v", resultSets[0])
}
if len(resultSets[0].Messages) != 1 || !strings.Contains(resultSets[0].Messages[0], "INSERT c_dyscript") {
t.Fatalf("expected procedure output message to be preserved, got %#v", resultSets[0].Messages)
}
}
func TestDBQueryMultiTreatsReturningWriteAsQueryFirst(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "INSERT INTO audit_logs(id) VALUES (1) RETURNING id"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"id": 1},
},
},
fieldMap: map[string][]string{
query: {"id"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
result := app.DBQueryMulti(config, "main", query, "postgres-returning-query-first-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.batchCalls != 0 {
t.Fatalf("expected RETURNING write to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls)
}
if fakeDB.session == nil || fakeDB.session.queryCalls != 1 {
t.Fatalf("expected RETURNING write to query through pinned session, got session=%#v", fakeDB.session)
}
if fakeDB.session.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 || resultSets[0].Rows[0]["id"] != 1 {
t.Fatalf("expected RETURNING rows to be preserved, got %#v", resultSets)
}
}
func TestDBQueryMultiTreatsSQLServerOutputWriteAsQueryFirst(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "UPDATE users SET name = 'next' OUTPUT inserted.id WHERE id = 1"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"id": 1},
},
},
fieldMap: map[string][]string{
query: {"id"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sqlserver-output-query-first-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.batchCalls != 0 {
t.Fatalf("expected OUTPUT write to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls)
}
if fakeDB.session == nil || fakeDB.session.queryCalls != 1 {
t.Fatalf("expected OUTPUT write to query through pinned session, got session=%#v", fakeDB.session)
}
if fakeDB.session.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 || resultSets[0].Rows[0]["id"] != 1 {
t.Fatalf("expected OUTPUT rows to be preserved, got %#v", resultSets)
}
}
func TestDBQueryMultiTreatsWrappedMessageBlocksAsQueryFirst(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
sqlServerQuery := "IF 1 = 1 PRINT 'done'"
postgresQuery := "DO $$ BEGIN RAISE NOTICE 'done'; END $$"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
sqlServerQuery: {},
postgresQuery: {},
},
fieldMap: map[string][]string{
sqlServerQuery: {},
postgresQuery: {},
},
messageMap: map[string][]string{
sqlServerQuery: {"done"},
postgresQuery: {"done"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
sqlServerResult := app.DBQueryMulti(connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}, "master", sqlServerQuery, "sqlserver-print-block-test")
if !sqlServerResult.Success {
t.Fatalf("expected SQL Server block success, got failure: %s", sqlServerResult.Message)
}
postgresResult := app.DBQueryMulti(connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}, "main", postgresQuery, "postgres-notice-block-test")
if !postgresResult.Success {
t.Fatalf("expected PostgreSQL notice block success, got failure: %s", postgresResult.Message)
}
if fakeDB.batchCalls != 0 {
t.Fatalf("expected message blocks to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls)
}
if fakeDB.execCalls != 0 {
t.Fatalf("expected message blocks to avoid shared exec path, got execCalls=%d", fakeDB.execCalls)
}
}
func TestDBQueryMultiTransactionalTreatsSelectIntoAsManagedWrite(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
stmt := "SELECT * INTO archived_users FROM users"
fakeDB := &fakeBatchWriteDB{
execAffected: map[string]int64{
stmt: 12,
},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
result := app.DBQueryMultiTransactional(config, "main", stmt, "select-into-managed-tx-test")
if !result.Success {
t.Fatalf("expected managed SELECT INTO success, got failure: %s", result.Message)
}
if result.TransactionID == "" || !result.TransactionPending {
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
}
if fakeDB.session == nil || fakeDB.session.closed {
t.Fatal("expected managed SELECT INTO transaction session to stay open")
}
wantExecs := []string{"BEGIN", stmt}
if len(fakeDB.execQueries) != len(wantExecs) {
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
}
for i, want := range wantExecs {
if fakeDB.execQueries[i] != want {
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
}
}
}

View File

@@ -806,12 +806,10 @@ func isSQLFileBatchableWriteStatement(dbType string, stmt string) bool {
if isPLSQLBlockStatement(stmt) {
return false
}
switch leadingSQLKeyword(stmt) {
case "insert", "update", "delete", "replace", "merge", "upsert":
return true
default:
if shouldTryQueryResultFirst(dbType, stmt) {
return false
}
return isBatchableWriteSQLStatement(dbType, stmt)
}
func sqlFileBatchTransactionSQL(dbType string) (beginSQL string, commitSQL string, rollbackSQL string, ok bool) {

View File

@@ -336,6 +336,105 @@ func isSQLKeywordByte(ch byte) bool {
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_'
}
func normalizeSQLClassifierDBType(dbType string) string {
normalized := strings.ToLower(strings.TrimSpace(dbType))
switch normalized {
case "postgresql":
return "postgres"
case "mssql", "sql_server", "sql-server":
return "sqlserver"
case "open_gauss", "open-gauss":
return "opengauss"
case "gauss_db", "gauss-db":
return "gaussdb"
case "kingbase8", "kingbasees", "kingbasev8":
return "kingbase"
default:
return normalized
}
}
func isSQLServerDBType(dbType string) bool {
return normalizeSQLClassifierDBType(dbType) == "sqlserver"
}
func isOracleLikeDBType(dbType string) bool {
switch normalizeSQLClassifierDBType(dbType) {
case "oracle", "dameng":
return true
default:
return false
}
}
func isPostgresNoticeCapableDBType(dbType string) bool {
switch normalizeSQLClassifierDBType(dbType) {
case "postgres", "opengauss":
return true
default:
return false
}
}
func isSQLSelectIntoStatement(query string) bool {
keyword, _ := sqlDataOperationInfo(query)
return keyword == "select" && findTopLevelSQLKeyword(query, 0, "into") >= 0
}
func sqlContainsKeyword(text string, want string) bool {
for pos := 0; pos < len(text); {
if next, ok := skipSQLQuotedOrComment(text, pos); ok {
pos = next
continue
}
if isSQLKeywordByte(text[pos]) {
end := pos + 1
for end < len(text) && isSQLKeywordByte(text[end]) {
end++
}
if strings.EqualFold(text[pos:end], want) {
return true
}
pos = end
continue
}
pos++
}
return false
}
func sqlWriteStatementReturnsRows(dbType string, query string) bool {
keyword, withHasWrite := sqlDataOperationInfo(query)
if withHasWrite && keyword == "select" {
return true
}
if !isSQLDataWriteKeyword(keyword) {
return false
}
if !isOracleLikeDBType(dbType) && findTopLevelSQLKeyword(query, 0, "returning") >= 0 {
return true
}
if isSQLServerDBType(dbType) && findTopLevelSQLKeyword(query, 0, "output") >= 0 {
return true
}
return false
}
func sqlServerControlFlowMayReturnMessages(query string) bool {
switch leadingSQLKeyword(query) {
case "if", "begin", "while", "declare", "try", "catch":
return sqlContainsKeyword(query, "exec") ||
sqlContainsKeyword(query, "execute") ||
sqlContainsKeyword(query, "print") ||
sqlContainsKeyword(query, "raiserror") ||
sqlContainsKeyword(query, "throw") ||
sqlContainsKeyword(query, "dbcc") ||
sqlContainsKeyword(query, "set")
default:
return false
}
}
func isReadOnlySQLQuery(dbType string, query string) bool {
if strings.ToLower(strings.TrimSpace(dbType)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
return true
@@ -345,6 +444,9 @@ func isReadOnlySQLQuery(dbType string, query string) bool {
if withHasWrite {
return false
}
if keyword == "select" && isSQLSelectIntoStatement(query) {
return false
}
switch keyword {
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values", "consume":
return true
@@ -362,6 +464,9 @@ func isBatchableWriteSQLStatement(dbType string, query string) bool {
if withHasWrite {
return true
}
if keyword == "select" && isSQLSelectIntoStatement(query) {
return true
}
return isSQLDataWriteKeyword(keyword)
}

View File

@@ -86,6 +86,13 @@ func TestIsReadOnlySQLQuery_ClassifiesWithByTopLevelOperation(t *testing.T) {
}
}
func TestIsReadOnlySQLQuery_TreatsSelectIntoAsWrite(t *testing.T) {
query := "SELECT * INTO archived_users FROM users"
if isReadOnlySQLQuery("postgres", query) {
t.Fatal("SELECT INTO should not be treated as read-only")
}
}
func TestIsReadOnlySQLQuery_TreatsKafkaConsumeAsReadOnly(t *testing.T) {
if !isReadOnlySQLQuery("kafka", `CONSUME GROUP "analytics" FROM "orders.events" LIMIT 20`) {
t.Fatal("Kafka CONSUME should be treated as read-only")
@@ -102,6 +109,9 @@ func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.
if !isBatchableWriteSQLStatement("postgres", "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved") {
t.Fatal("expected data-changing CTE to be treated as batchable write")
}
if !isBatchableWriteSQLStatement("postgres", "SELECT * INTO archived_users FROM users") {
t.Fatal("expected SELECT INTO to be treated as batchable write")
}
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
t.Fatal("EXEC should not be treated as batchable write")
}
@@ -131,6 +141,45 @@ func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *
}
}
func TestShouldTryQueryResultFirst_TreatsSQLServerBareProcedureCallsAsQueryFirst(t *testing.T) {
if !shouldTryQueryResultFirst("sqlserver", `p_get_select c_dyscript,'projectid = 1',1`) {
t.Fatal("expected bare SQL Server procedure call to try query-first")
}
if !shouldTryQueryResultFirst("sqlserver", `dbo.p_get_select c_dyscript,'projectid = 1',1`) {
t.Fatal("expected schema-qualified SQL Server procedure call to try query-first")
}
if !shouldTryQueryResultFirst("sqlserver", `[dbo].[p_get_select] c_dyscript,'projectid = 1',1`) {
t.Fatal("expected bracket-qualified SQL Server procedure call to try query-first")
}
}
func TestShouldTryQueryResultFirst_TreatsReturningAndOutputWritesAsQueryFirst(t *testing.T) {
if !shouldTryQueryResultFirst("postgres", "INSERT INTO audit_logs(id) VALUES (1) RETURNING id") {
t.Fatal("expected INSERT ... RETURNING to try query-first")
}
if !shouldTryQueryResultFirst("sqlserver", "UPDATE users SET name = 'next' OUTPUT inserted.id WHERE id = 1") {
t.Fatal("expected SQL Server OUTPUT DML to try query-first")
}
}
func TestShouldTryQueryResultFirst_TreatsWrappedMessageBlocksAsQueryFirst(t *testing.T) {
if !shouldTryQueryResultFirst("sqlserver", "IF 1 = 1 EXEC dbo.p_get_select @name = 'demo'") {
t.Fatal("expected control-flow wrapped SQL Server procedure call to try query-first")
}
if !shouldTryQueryResultFirst("sqlserver", "BEGIN PRINT 'done'; END") {
t.Fatal("expected SQL Server BEGIN/PRINT block to try query-first")
}
if !shouldTryQueryResultFirst("postgres", "DO $$ BEGIN RAISE NOTICE 'done'; END $$") {
t.Fatal("expected PostgreSQL DO/RAISE NOTICE block to try query-first")
}
}
func TestShouldTryQueryResultFirst_DoesNotMisclassifyPlainSQLServerDML(t *testing.T) {
if shouldTryQueryResultFirst("sqlserver", "UPDATE users SET name = 'next' WHERE id = 1") {
t.Fatal("plain SQL Server UPDATE should not try query-first")
}
}
func TestShouldTryQueryResultFirst_TreatsDataChangingCTESelectAsQueryFirst(t *testing.T) {
query := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved"
if !shouldTryQueryResultFirst("postgres", query) {

View File

@@ -5,6 +5,7 @@ package db
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"net"
"net/url"
@@ -17,7 +18,7 @@ import (
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/highgo/pq-sm3" // HighGo uses dedicated SM3-capable driver
highgopq "github.com/highgo/pq-sm3" // HighGo uses dedicated SM3-capable driver
)
// HighGoDB implements Database interface for HighGo (瀚高) database
@@ -28,6 +29,13 @@ type HighGoDB struct {
forwarder *ssh.LocalForwarder
}
type highgoSessionExecer struct {
*sqlConnStatementExecer
}
var _ QueryMessageExecer = (*HighGoDB)(nil)
var _ StatementQueryMessageExecer = (*highgoSessionExecer)(nil)
func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
// postgres://user:password@host:port/dbname?sslmode=disable
dbname := config.Database
@@ -152,6 +160,20 @@ func (h *HighGoDB) QueryContext(ctx context.Context, query string) ([]map[string
return scanRows(rows)
}
func (h *HighGoDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if h.conn == nil {
return nil, nil, nil, fmt.Errorf("连接未打开")
}
conn, err := h.conn.Conn(ctx)
if err != nil {
return nil, nil, nil, err
}
defer conn.Close()
return queryHighGoConnWithMessages(ctx, conn, query)
}
func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, error) {
if h.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
@@ -165,6 +187,10 @@ func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, erro
return scanRows(rows)
}
func (h *HighGoDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return h.QueryContextWithMessages(context.Background(), query)
}
func (h *HighGoDB) ExecContext(ctx context.Context, query string) (int64, error) {
if h.conn == nil {
return 0, fmt.Errorf("连接未打开")
@@ -195,7 +221,7 @@ func (h *HighGoDB) OpenSessionExecer(ctx context.Context) (StatementExecer, erro
if err != nil {
return nil, err
}
return NewSQLConnStatementExecer(conn), nil
return &highgoSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil
}
func (h *HighGoDB) Exec(query string) (int64, error) {
@@ -209,6 +235,31 @@ func (h *HighGoDB) Exec(query string) (int64, error) {
return res.RowsAffected()
}
func (e *highgoSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return e.QueryContextWithMessages(context.Background(), query)
}
func (e *highgoSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if e == nil || e.conn == nil {
return nil, nil, nil, fmt.Errorf("连接未打开")
}
return queryHighGoConnWithMessages(ctx, e.conn, query)
}
func queryHighGoConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) {
return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) {
if addNotice == nil {
highgopq.SetNoticeHandler(driverConn, nil)
return
}
highgopq.SetNoticeHandler(driverConn, func(notice *highgopq.Error) {
if notice != nil {
addNotice(notice.Message)
}
})
})
}
func (h *HighGoDB) GetDatabases() ([]string, error) {
data, _, err := h.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
if err != nil {

View File

@@ -5,6 +5,7 @@ package db
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"net"
"net/url"
@@ -18,7 +19,7 @@ import (
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "gitea.com/kingbase/gokb" // Registers "kingbase" driver
gokb "gitea.com/kingbase/gokb" // Registers "kingbase" driver
)
type KingbaseDB struct {
@@ -27,6 +28,13 @@ type KingbaseDB struct {
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
type kingbaseSessionExecer struct {
*sqlConnStatementExecer
}
var _ QueryMessageExecer = (*KingbaseDB)(nil)
var _ StatementQueryMessageExecer = (*kingbaseSessionExecer)(nil)
func quoteConnValue(v string) string {
if v == "" {
return "''"
@@ -280,6 +288,20 @@ func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[stri
return scanRows(rows)
}
func (k *KingbaseDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if k.conn == nil {
return nil, nil, nil, fmt.Errorf("连接未打开")
}
conn, err := k.conn.Conn(ctx)
if err != nil {
return nil, nil, nil, err
}
defer conn.Close()
return queryKingbaseConnWithMessages(ctx, conn, query)
}
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
if k.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
@@ -293,6 +315,10 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
return scanRows(rows)
}
func (k *KingbaseDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return k.QueryContextWithMessages(context.Background(), query)
}
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
if k.conn == nil {
return 0, fmt.Errorf("连接未打开")
@@ -323,7 +349,7 @@ func (k *KingbaseDB) OpenSessionExecer(ctx context.Context) (StatementExecer, er
if err != nil {
return nil, err
}
return NewSQLConnStatementExecer(conn), nil
return &kingbaseSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil
}
func (k *KingbaseDB) Exec(query string) (int64, error) {
@@ -337,6 +363,31 @@ func (k *KingbaseDB) Exec(query string) (int64, error) {
return res.RowsAffected()
}
func (e *kingbaseSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return e.QueryContextWithMessages(context.Background(), query)
}
func (e *kingbaseSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if e == nil || e.conn == nil {
return nil, nil, nil, fmt.Errorf("连接未打开")
}
return queryKingbaseConnWithMessages(ctx, e.conn, query)
}
func queryKingbaseConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) {
return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) {
if addNotice == nil {
gokb.SetNoticeHandler(driverConn, nil)
return
}
gokb.SetNoticeHandler(driverConn, func(notice *gokb.Error) {
if notice != nil {
addNotice(notice.Message)
}
})
})
}
func (k *KingbaseDB) GetDatabases() ([]string, error) {
data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
if err == nil {

View File

@@ -0,0 +1,60 @@
package db
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"strings"
)
type sqlTextNoticeHandlerSetter func(driver.Conn, func(string))
func querySQLConnWithTextNotices(ctx context.Context, conn *sql.Conn, query string, setHandler sqlTextNoticeHandlerSetter) ([]map[string]interface{}, []string, []string, error) {
if conn == nil {
return nil, nil, nil, fmt.Errorf("连接未打开")
}
if setHandler == nil {
return nil, nil, nil, fmt.Errorf("未配置消息捕获处理器")
}
if ctx == nil {
ctx = context.Background()
}
notices := make([]string, 0, 2)
addNotice := func(text string) {
text = strings.TrimSpace(text)
if text != "" {
notices = append(notices, text)
}
}
if err := conn.Raw(func(rawConn interface{}) error {
driverConn, ok := rawConn.(driver.Conn)
if !ok {
return fmt.Errorf("底层连接类型不支持消息捕获")
}
setHandler(driverConn, addNotice)
return nil
}); err != nil {
return nil, nil, nil, err
}
defer func() {
_ = conn.Raw(func(rawConn interface{}) error {
driverConn, ok := rawConn.(driver.Conn)
if ok {
setHandler(driverConn, nil)
}
return nil
})
}()
rows, err := conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, append([]string(nil), notices...), err
}
defer rows.Close()
data, columns, err := scanRows(rows)
return data, columns, append([]string(nil), notices...), err
}

View File

@@ -3,6 +3,7 @@ package db
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"net"
"net/url"
@@ -15,7 +16,7 @@ import (
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
_ "github.com/lib/pq"
"github.com/lib/pq"
)
type PostgresDB struct {
@@ -24,6 +25,13 @@ type PostgresDB struct {
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
}
type postgresSessionExecer struct {
*sqlConnStatementExecer
}
var _ QueryMessageExecer = (*PostgresDB)(nil)
var _ StatementQueryMessageExecer = (*postgresSessionExecer)(nil)
func resolvePostgresConnectDatabases(config connection.ConnectionConfig) []string {
explicit := strings.TrimSpace(config.Database)
if explicit != "" {
@@ -225,6 +233,20 @@ func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[stri
return scanRows(rows)
}
func (p *PostgresDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if p.conn == nil {
return nil, nil, nil, fmt.Errorf("连接未打开")
}
conn, err := p.conn.Conn(ctx)
if err != nil {
return nil, nil, nil, err
}
defer conn.Close()
return queryPostgresConnWithMessages(ctx, conn, query)
}
func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
if p.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
@@ -238,6 +260,10 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er
return scanRows(rows)
}
func (p *PostgresDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return p.QueryContextWithMessages(context.Background(), query)
}
func (p *PostgresDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
if p.conn == nil {
return 0, fmt.Errorf("连接未打开")
@@ -257,7 +283,7 @@ func (p *PostgresDB) OpenSessionExecer(ctx context.Context) (StatementExecer, er
if err != nil {
return nil, err
}
return NewSQLConnStatementExecer(conn), nil
return &postgresSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil
}
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
@@ -282,6 +308,31 @@ func (p *PostgresDB) Exec(query string) (int64, error) {
return res.RowsAffected()
}
func (e *postgresSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return e.QueryContextWithMessages(context.Background(), query)
}
func (e *postgresSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if e == nil || e.conn == nil {
return nil, nil, nil, fmt.Errorf("连接未打开")
}
return queryPostgresConnWithMessages(ctx, e.conn, query)
}
func queryPostgresConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) {
return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) {
if addNotice == nil {
pq.SetNoticeHandler(driverConn, nil)
return
}
pq.SetNoticeHandler(driverConn, func(notice *pq.Error) {
if notice != nil {
addNotice(notice.Message)
}
})
})
}
func (p *PostgresDB) GetDatabases() ([]string, error) {
data, _, err := p.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
if err != nil {