mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-16 19:49:51 +08:00
🐛 fix(sql-editor): 修复存储过程与返回结果写语句的结果识别
- 补齐 SQL 分类逻辑,识别 SQL Server 裸存储过程调用、RETURNING/OUTPUT、SELECT INTO 及消息块场景 - 调整多语句执行与批量写入分支,避免返回行或服务端消息被 Exec 路径吞掉 - 为 PostgreSQL、OpenGauss、Kingbase、HighGo 补充 notice 回传能力并增加回归测试
This commit is contained in:
@@ -932,15 +932,23 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
if !allReadOnly {
|
||||
allWrite := true
|
||||
containsPLSQLBlock := false
|
||||
containsQueryFirstWrite := false
|
||||
for _, stmt := range statements {
|
||||
if strings.TrimSpace(stmt) != "" && !isBatchableWriteSQLStatement(runConfig.Type, stmt) {
|
||||
stmt = strings.TrimSpace(stmt)
|
||||
if stmt == "" {
|
||||
continue
|
||||
}
|
||||
if !isBatchableWriteSQLStatement(runConfig.Type, stmt) {
|
||||
allWrite = false
|
||||
}
|
||||
if shouldTryQueryResultFirst(runConfig.Type, stmt) {
|
||||
containsQueryFirstWrite = true
|
||||
}
|
||||
if isPLSQLBlockStatement(stmt) {
|
||||
containsPLSQLBlock = true
|
||||
}
|
||||
}
|
||||
if allWrite && !containsPLSQLBlock && len(statements) == 1 {
|
||||
if allWrite && !containsPLSQLBlock && !containsQueryFirstWrite && len(statements) == 1 {
|
||||
batcher := sessionBatchTarget
|
||||
if batcher == nil {
|
||||
if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok {
|
||||
@@ -1125,8 +1133,8 @@ func shouldUseNativeMultiResultBatch(dbType string, statements []string, allRead
|
||||
}
|
||||
|
||||
func shouldTryQueryResultFirst(dbType string, query string) bool {
|
||||
isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver")
|
||||
if keyword, withHasWrite := sqlDataOperationInfo(query); withHasWrite && keyword == "select" {
|
||||
isSQLServer := isSQLServerDBType(dbType)
|
||||
if sqlWriteStatementReturnsRows(dbType, query) {
|
||||
return true
|
||||
}
|
||||
keyword := leadingSQLKeyword(query)
|
||||
@@ -1137,14 +1145,66 @@ func shouldTryQueryResultFirst(dbType string, query string) bool {
|
||||
return isSQLServer
|
||||
case "dbcc":
|
||||
return isSQLServer
|
||||
case "do":
|
||||
return isPostgresNoticeCapableDBType(dbType) && strings.Contains(strings.ToLower(query), "raise")
|
||||
default:
|
||||
if isSQLServer {
|
||||
return strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_")
|
||||
if strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_") {
|
||||
return true
|
||||
}
|
||||
if sqlServerControlFlowMayReturnMessages(query) {
|
||||
return true
|
||||
}
|
||||
return looksLikeSQLServerProcedureInvocation(query)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func looksLikeSQLServerProcedureInvocation(query string) bool {
|
||||
switch leadingSQLKeyword(query) {
|
||||
case "select", "with", "insert", "update", "delete", "merge", "replace", "upsert",
|
||||
"if", "begin", "declare", "while", "create", "alter", "drop", "truncate", "grant", "revoke",
|
||||
"use", "set", "print", "dbcc", "commit", "rollback", "save", "return", "throw", "raiserror",
|
||||
"waitfor", "open", "fetch", "close", "deallocate":
|
||||
return false
|
||||
}
|
||||
|
||||
pos := skipSQLTrivia(query, 0)
|
||||
if pos >= len(query) {
|
||||
return false
|
||||
}
|
||||
|
||||
next, ok := skipSQLIdentifierToken(query, pos)
|
||||
if !ok || next <= pos {
|
||||
return false
|
||||
}
|
||||
pos = skipSQLTrivia(query, next)
|
||||
for pos < len(query) && query[pos] == '.' {
|
||||
pos = skipSQLTrivia(query, pos+1)
|
||||
next, ok = skipSQLIdentifierToken(query, pos)
|
||||
if !ok || next <= pos {
|
||||
return false
|
||||
}
|
||||
pos = skipSQLTrivia(query, next)
|
||||
}
|
||||
|
||||
if pos >= len(query) {
|
||||
return true
|
||||
}
|
||||
switch ch := query[pos]; {
|
||||
case ch == ';' || ch == ',' || ch == '@' || ch == '\'' || ch == '"' || ch == '[' || ch == '(':
|
||||
return true
|
||||
case ch == '+' || ch == '-':
|
||||
return true
|
||||
case ch >= '0' && ch <= '9':
|
||||
return true
|
||||
default:
|
||||
keyword, _ := nextSQLKeyword(query, pos)
|
||||
return keyword != ""
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
|
||||
@@ -1446,3 +1446,231 @@ func TestDBQueryMultiKeepsAllResultSetsFromSingleSQLServerStatement(t *testing.T
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTreatsBareSQLServerProcedureCallAsQueryFirst(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := `p_get_select c_dyscript,'projectid = 1',1`
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
messageMap: map[string][]string{
|
||||
query: {`INSERT c_dyscript(id,name) values (1,"demo")`},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", query, "sqlserver-bare-proc-query-first-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.session == nil {
|
||||
t.Fatal("expected bare SQL Server procedure call to use a pinned query session")
|
||||
}
|
||||
if fakeDB.session.queryCalls != 1 {
|
||||
t.Fatalf("expected one session query call, got %d", fakeDB.session.queryCalls)
|
||||
}
|
||||
if fakeDB.session.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 {
|
||||
t.Fatalf("expected one result set, got %#v", resultSets)
|
||||
}
|
||||
if len(resultSets[0].Rows) != 0 || len(resultSets[0].Columns) != 0 {
|
||||
t.Fatalf("expected message-only result set, got %#v", resultSets[0])
|
||||
}
|
||||
if len(resultSets[0].Messages) != 1 || !strings.Contains(resultSets[0].Messages[0], "INSERT c_dyscript") {
|
||||
t.Fatalf("expected procedure output message to be preserved, got %#v", resultSets[0].Messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTreatsReturningWriteAsQueryFirst(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "INSERT INTO audit_logs(id) VALUES (1) RETURNING id"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {
|
||||
{"id": 1},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"id"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
|
||||
|
||||
result := app.DBQueryMulti(config, "main", query, "postgres-returning-query-first-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.batchCalls != 0 {
|
||||
t.Fatalf("expected RETURNING write to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls)
|
||||
}
|
||||
if fakeDB.session == nil || fakeDB.session.queryCalls != 1 {
|
||||
t.Fatalf("expected RETURNING write to query through pinned session, got session=%#v", fakeDB.session)
|
||||
}
|
||||
if fakeDB.session.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 || resultSets[0].Rows[0]["id"] != 1 {
|
||||
t.Fatalf("expected RETURNING rows to be preserved, got %#v", resultSets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTreatsSQLServerOutputWriteAsQueryFirst(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "UPDATE users SET name = 'next' OUTPUT inserted.id WHERE id = 1"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {
|
||||
{"id": 1},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"id"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", query, "sqlserver-output-query-first-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.batchCalls != 0 {
|
||||
t.Fatalf("expected OUTPUT write to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls)
|
||||
}
|
||||
if fakeDB.session == nil || fakeDB.session.queryCalls != 1 {
|
||||
t.Fatalf("expected OUTPUT write to query through pinned session, got session=%#v", fakeDB.session)
|
||||
}
|
||||
if fakeDB.session.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.session.execCalls)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 || resultSets[0].Rows[0]["id"] != 1 {
|
||||
t.Fatalf("expected OUTPUT rows to be preserved, got %#v", resultSets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTreatsWrappedMessageBlocksAsQueryFirst(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
sqlServerQuery := "IF 1 = 1 PRINT 'done'"
|
||||
postgresQuery := "DO $$ BEGIN RAISE NOTICE 'done'; END $$"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
sqlServerQuery: {},
|
||||
postgresQuery: {},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
sqlServerQuery: {},
|
||||
postgresQuery: {},
|
||||
},
|
||||
messageMap: map[string][]string{
|
||||
sqlServerQuery: {"done"},
|
||||
postgresQuery: {"done"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
|
||||
sqlServerResult := app.DBQueryMulti(connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}, "master", sqlServerQuery, "sqlserver-print-block-test")
|
||||
if !sqlServerResult.Success {
|
||||
t.Fatalf("expected SQL Server block success, got failure: %s", sqlServerResult.Message)
|
||||
}
|
||||
|
||||
postgresResult := app.DBQueryMulti(connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}, "main", postgresQuery, "postgres-notice-block-test")
|
||||
if !postgresResult.Success {
|
||||
t.Fatalf("expected PostgreSQL notice block success, got failure: %s", postgresResult.Message)
|
||||
}
|
||||
if fakeDB.batchCalls != 0 {
|
||||
t.Fatalf("expected message blocks to skip batch exec path, got batchCalls=%d", fakeDB.batchCalls)
|
||||
}
|
||||
if fakeDB.execCalls != 0 {
|
||||
t.Fatalf("expected message blocks to avoid shared exec path, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiTransactionalTreatsSelectIntoAsManagedWrite(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
stmt := "SELECT * INTO archived_users FROM users"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
execAffected: map[string]int64{
|
||||
stmt: 12,
|
||||
},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "postgres", Host: "127.0.0.1", Port: 5432, User: "postgres"}
|
||||
|
||||
result := app.DBQueryMultiTransactional(config, "main", stmt, "select-into-managed-tx-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected managed SELECT INTO success, got failure: %s", result.Message)
|
||||
}
|
||||
if result.TransactionID == "" || !result.TransactionPending {
|
||||
t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending)
|
||||
}
|
||||
if fakeDB.session == nil || fakeDB.session.closed {
|
||||
t.Fatal("expected managed SELECT INTO transaction session to stay open")
|
||||
}
|
||||
wantExecs := []string{"BEGIN", stmt}
|
||||
if len(fakeDB.execQueries) != len(wantExecs) {
|
||||
t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries)
|
||||
}
|
||||
for i, want := range wantExecs {
|
||||
if fakeDB.execQueries[i] != want {
|
||||
t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -806,12 +806,10 @@ func isSQLFileBatchableWriteStatement(dbType string, stmt string) bool {
|
||||
if isPLSQLBlockStatement(stmt) {
|
||||
return false
|
||||
}
|
||||
switch leadingSQLKeyword(stmt) {
|
||||
case "insert", "update", "delete", "replace", "merge", "upsert":
|
||||
return true
|
||||
default:
|
||||
if shouldTryQueryResultFirst(dbType, stmt) {
|
||||
return false
|
||||
}
|
||||
return isBatchableWriteSQLStatement(dbType, stmt)
|
||||
}
|
||||
|
||||
func sqlFileBatchTransactionSQL(dbType string) (beginSQL string, commitSQL string, rollbackSQL string, ok bool) {
|
||||
|
||||
@@ -336,6 +336,105 @@ func isSQLKeywordByte(ch byte) bool {
|
||||
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_'
|
||||
}
|
||||
|
||||
func normalizeSQLClassifierDBType(dbType string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(dbType))
|
||||
switch normalized {
|
||||
case "postgresql":
|
||||
return "postgres"
|
||||
case "mssql", "sql_server", "sql-server":
|
||||
return "sqlserver"
|
||||
case "open_gauss", "open-gauss":
|
||||
return "opengauss"
|
||||
case "gauss_db", "gauss-db":
|
||||
return "gaussdb"
|
||||
case "kingbase8", "kingbasees", "kingbasev8":
|
||||
return "kingbase"
|
||||
default:
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
|
||||
func isSQLServerDBType(dbType string) bool {
|
||||
return normalizeSQLClassifierDBType(dbType) == "sqlserver"
|
||||
}
|
||||
|
||||
func isOracleLikeDBType(dbType string) bool {
|
||||
switch normalizeSQLClassifierDBType(dbType) {
|
||||
case "oracle", "dameng":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isPostgresNoticeCapableDBType(dbType string) bool {
|
||||
switch normalizeSQLClassifierDBType(dbType) {
|
||||
case "postgres", "opengauss":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isSQLSelectIntoStatement(query string) bool {
|
||||
keyword, _ := sqlDataOperationInfo(query)
|
||||
return keyword == "select" && findTopLevelSQLKeyword(query, 0, "into") >= 0
|
||||
}
|
||||
|
||||
func sqlContainsKeyword(text string, want string) bool {
|
||||
for pos := 0; pos < len(text); {
|
||||
if next, ok := skipSQLQuotedOrComment(text, pos); ok {
|
||||
pos = next
|
||||
continue
|
||||
}
|
||||
if isSQLKeywordByte(text[pos]) {
|
||||
end := pos + 1
|
||||
for end < len(text) && isSQLKeywordByte(text[end]) {
|
||||
end++
|
||||
}
|
||||
if strings.EqualFold(text[pos:end], want) {
|
||||
return true
|
||||
}
|
||||
pos = end
|
||||
continue
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sqlWriteStatementReturnsRows(dbType string, query string) bool {
|
||||
keyword, withHasWrite := sqlDataOperationInfo(query)
|
||||
if withHasWrite && keyword == "select" {
|
||||
return true
|
||||
}
|
||||
if !isSQLDataWriteKeyword(keyword) {
|
||||
return false
|
||||
}
|
||||
if !isOracleLikeDBType(dbType) && findTopLevelSQLKeyword(query, 0, "returning") >= 0 {
|
||||
return true
|
||||
}
|
||||
if isSQLServerDBType(dbType) && findTopLevelSQLKeyword(query, 0, "output") >= 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sqlServerControlFlowMayReturnMessages(query string) bool {
|
||||
switch leadingSQLKeyword(query) {
|
||||
case "if", "begin", "while", "declare", "try", "catch":
|
||||
return sqlContainsKeyword(query, "exec") ||
|
||||
sqlContainsKeyword(query, "execute") ||
|
||||
sqlContainsKeyword(query, "print") ||
|
||||
sqlContainsKeyword(query, "raiserror") ||
|
||||
sqlContainsKeyword(query, "throw") ||
|
||||
sqlContainsKeyword(query, "dbcc") ||
|
||||
sqlContainsKeyword(query, "set")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isReadOnlySQLQuery(dbType string, query string) bool {
|
||||
if strings.ToLower(strings.TrimSpace(dbType)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
|
||||
return true
|
||||
@@ -345,6 +444,9 @@ func isReadOnlySQLQuery(dbType string, query string) bool {
|
||||
if withHasWrite {
|
||||
return false
|
||||
}
|
||||
if keyword == "select" && isSQLSelectIntoStatement(query) {
|
||||
return false
|
||||
}
|
||||
switch keyword {
|
||||
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values", "consume":
|
||||
return true
|
||||
@@ -362,6 +464,9 @@ func isBatchableWriteSQLStatement(dbType string, query string) bool {
|
||||
if withHasWrite {
|
||||
return true
|
||||
}
|
||||
if keyword == "select" && isSQLSelectIntoStatement(query) {
|
||||
return true
|
||||
}
|
||||
return isSQLDataWriteKeyword(keyword)
|
||||
}
|
||||
|
||||
|
||||
@@ -86,6 +86,13 @@ func TestIsReadOnlySQLQuery_ClassifiesWithByTopLevelOperation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReadOnlySQLQuery_TreatsSelectIntoAsWrite(t *testing.T) {
|
||||
query := "SELECT * INTO archived_users FROM users"
|
||||
if isReadOnlySQLQuery("postgres", query) {
|
||||
t.Fatal("SELECT INTO should not be treated as read-only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReadOnlySQLQuery_TreatsKafkaConsumeAsReadOnly(t *testing.T) {
|
||||
if !isReadOnlySQLQuery("kafka", `CONSUME GROUP "analytics" FROM "orders.events" LIMIT 20`) {
|
||||
t.Fatal("Kafka CONSUME should be treated as read-only")
|
||||
@@ -102,6 +109,9 @@ func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.
|
||||
if !isBatchableWriteSQLStatement("postgres", "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved") {
|
||||
t.Fatal("expected data-changing CTE to be treated as batchable write")
|
||||
}
|
||||
if !isBatchableWriteSQLStatement("postgres", "SELECT * INTO archived_users FROM users") {
|
||||
t.Fatal("expected SELECT INTO to be treated as batchable write")
|
||||
}
|
||||
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
|
||||
t.Fatal("EXEC should not be treated as batchable write")
|
||||
}
|
||||
@@ -131,6 +141,45 @@ func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsSQLServerBareProcedureCallsAsQueryFirst(t *testing.T) {
|
||||
if !shouldTryQueryResultFirst("sqlserver", `p_get_select c_dyscript,'projectid = 1',1`) {
|
||||
t.Fatal("expected bare SQL Server procedure call to try query-first")
|
||||
}
|
||||
if !shouldTryQueryResultFirst("sqlserver", `dbo.p_get_select c_dyscript,'projectid = 1',1`) {
|
||||
t.Fatal("expected schema-qualified SQL Server procedure call to try query-first")
|
||||
}
|
||||
if !shouldTryQueryResultFirst("sqlserver", `[dbo].[p_get_select] c_dyscript,'projectid = 1',1`) {
|
||||
t.Fatal("expected bracket-qualified SQL Server procedure call to try query-first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsReturningAndOutputWritesAsQueryFirst(t *testing.T) {
|
||||
if !shouldTryQueryResultFirst("postgres", "INSERT INTO audit_logs(id) VALUES (1) RETURNING id") {
|
||||
t.Fatal("expected INSERT ... RETURNING to try query-first")
|
||||
}
|
||||
if !shouldTryQueryResultFirst("sqlserver", "UPDATE users SET name = 'next' OUTPUT inserted.id WHERE id = 1") {
|
||||
t.Fatal("expected SQL Server OUTPUT DML to try query-first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsWrappedMessageBlocksAsQueryFirst(t *testing.T) {
|
||||
if !shouldTryQueryResultFirst("sqlserver", "IF 1 = 1 EXEC dbo.p_get_select @name = 'demo'") {
|
||||
t.Fatal("expected control-flow wrapped SQL Server procedure call to try query-first")
|
||||
}
|
||||
if !shouldTryQueryResultFirst("sqlserver", "BEGIN PRINT 'done'; END") {
|
||||
t.Fatal("expected SQL Server BEGIN/PRINT block to try query-first")
|
||||
}
|
||||
if !shouldTryQueryResultFirst("postgres", "DO $$ BEGIN RAISE NOTICE 'done'; END $$") {
|
||||
t.Fatal("expected PostgreSQL DO/RAISE NOTICE block to try query-first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_DoesNotMisclassifyPlainSQLServerDML(t *testing.T) {
|
||||
if shouldTryQueryResultFirst("sqlserver", "UPDATE users SET name = 'next' WHERE id = 1") {
|
||||
t.Fatal("plain SQL Server UPDATE should not try query-first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsDataChangingCTESelectAsQueryFirst(t *testing.T) {
|
||||
query := "WITH moved AS (DELETE FROM audit_logs WHERE created_at < NOW() RETURNING id) SELECT * FROM moved"
|
||||
if !shouldTryQueryResultFirst("postgres", query) {
|
||||
|
||||
@@ -5,6 +5,7 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -17,7 +18,7 @@ import (
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/highgo/pq-sm3" // HighGo uses dedicated SM3-capable driver
|
||||
highgopq "github.com/highgo/pq-sm3" // HighGo uses dedicated SM3-capable driver
|
||||
)
|
||||
|
||||
// HighGoDB implements Database interface for HighGo (瀚高) database
|
||||
@@ -28,6 +29,13 @@ type HighGoDB struct {
|
||||
forwarder *ssh.LocalForwarder
|
||||
}
|
||||
|
||||
type highgoSessionExecer struct {
|
||||
*sqlConnStatementExecer
|
||||
}
|
||||
|
||||
var _ QueryMessageExecer = (*HighGoDB)(nil)
|
||||
var _ StatementQueryMessageExecer = (*highgoSessionExecer)(nil)
|
||||
|
||||
func (h *HighGoDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// postgres://user:password@host:port/dbname?sslmode=disable
|
||||
dbname := config.Database
|
||||
@@ -152,6 +160,20 @@ func (h *HighGoDB) QueryContext(ctx context.Context, query string) ([]map[string
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (h *HighGoDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
if h.conn == nil {
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
conn, err := h.conn.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return queryHighGoConnWithMessages(ctx, conn, query)
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if h.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
@@ -165,6 +187,10 @@ func (h *HighGoDB) Query(query string) ([]map[string]interface{}, []string, erro
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (h *HighGoDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return h.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (h *HighGoDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if h.conn == nil {
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
@@ -195,7 +221,7 @@ func (h *HighGoDB) OpenSessionExecer(ctx context.Context) (StatementExecer, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLConnStatementExecer(conn), nil
|
||||
return &highgoSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil
|
||||
}
|
||||
|
||||
func (h *HighGoDB) Exec(query string) (int64, error) {
|
||||
@@ -209,6 +235,31 @@ func (h *HighGoDB) Exec(query string) (int64, error) {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *highgoSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return e.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *highgoSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
return queryHighGoConnWithMessages(ctx, e.conn, query)
|
||||
}
|
||||
|
||||
func queryHighGoConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) {
|
||||
if addNotice == nil {
|
||||
highgopq.SetNoticeHandler(driverConn, nil)
|
||||
return
|
||||
}
|
||||
highgopq.SetNoticeHandler(driverConn, func(notice *highgopq.Error) {
|
||||
if notice != nil {
|
||||
addNotice(notice.Message)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HighGoDB) GetDatabases() ([]string, error) {
|
||||
data, _, err := h.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "gitea.com/kingbase/gokb" // Registers "kingbase" driver
|
||||
gokb "gitea.com/kingbase/gokb" // Registers "kingbase" driver
|
||||
)
|
||||
|
||||
type KingbaseDB struct {
|
||||
@@ -27,6 +28,13 @@ type KingbaseDB struct {
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
type kingbaseSessionExecer struct {
|
||||
*sqlConnStatementExecer
|
||||
}
|
||||
|
||||
var _ QueryMessageExecer = (*KingbaseDB)(nil)
|
||||
var _ StatementQueryMessageExecer = (*kingbaseSessionExecer)(nil)
|
||||
|
||||
func quoteConnValue(v string) string {
|
||||
if v == "" {
|
||||
return "''"
|
||||
@@ -280,6 +288,20 @@ func (k *KingbaseDB) QueryContext(ctx context.Context, query string) ([]map[stri
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
if k.conn == nil {
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
conn, err := k.conn.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return queryKingbaseConnWithMessages(ctx, conn, query)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if k.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
@@ -293,6 +315,10 @@ func (k *KingbaseDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return k.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if k.conn == nil {
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
@@ -323,7 +349,7 @@ func (k *KingbaseDB) OpenSessionExecer(ctx context.Context) (StatementExecer, er
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLConnStatementExecer(conn), nil
|
||||
return &kingbaseSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) Exec(query string) (int64, error) {
|
||||
@@ -337,6 +363,31 @@ func (k *KingbaseDB) Exec(query string) (int64, error) {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *kingbaseSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return e.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *kingbaseSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
return queryKingbaseConnWithMessages(ctx, e.conn, query)
|
||||
}
|
||||
|
||||
func queryKingbaseConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) {
|
||||
if addNotice == nil {
|
||||
gokb.SetNoticeHandler(driverConn, nil)
|
||||
return
|
||||
}
|
||||
gokb.SetNoticeHandler(driverConn, func(notice *gokb.Error) {
|
||||
if notice != nil {
|
||||
addNotice(notice.Message)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (k *KingbaseDB) GetDatabases() ([]string, error) {
|
||||
data, _, err := k.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
|
||||
if err == nil {
|
||||
|
||||
60
internal/db/notice_query.go
Normal file
60
internal/db/notice_query.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type sqlTextNoticeHandlerSetter func(driver.Conn, func(string))
|
||||
|
||||
func querySQLConnWithTextNotices(ctx context.Context, conn *sql.Conn, query string, setHandler sqlTextNoticeHandlerSetter) ([]map[string]interface{}, []string, []string, error) {
|
||||
if conn == nil {
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
if setHandler == nil {
|
||||
return nil, nil, nil, fmt.Errorf("未配置消息捕获处理器")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
notices := make([]string, 0, 2)
|
||||
addNotice := func(text string) {
|
||||
text = strings.TrimSpace(text)
|
||||
if text != "" {
|
||||
notices = append(notices, text)
|
||||
}
|
||||
}
|
||||
|
||||
if err := conn.Raw(func(rawConn interface{}) error {
|
||||
driverConn, ok := rawConn.(driver.Conn)
|
||||
if !ok {
|
||||
return fmt.Errorf("底层连接类型不支持消息捕获")
|
||||
}
|
||||
setHandler(driverConn, addNotice)
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Raw(func(rawConn interface{}) error {
|
||||
driverConn, ok := rawConn.(driver.Conn)
|
||||
if ok {
|
||||
setHandler(driverConn, nil)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
rows, err := conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, append([]string(nil), notices...), err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
data, columns, err := scanRows(rows)
|
||||
return data, columns, append([]string(nil), notices...), err
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type PostgresDB struct {
|
||||
@@ -24,6 +25,13 @@ type PostgresDB struct {
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
type postgresSessionExecer struct {
|
||||
*sqlConnStatementExecer
|
||||
}
|
||||
|
||||
var _ QueryMessageExecer = (*PostgresDB)(nil)
|
||||
var _ StatementQueryMessageExecer = (*postgresSessionExecer)(nil)
|
||||
|
||||
func resolvePostgresConnectDatabases(config connection.ConnectionConfig) []string {
|
||||
explicit := strings.TrimSpace(config.Database)
|
||||
if explicit != "" {
|
||||
@@ -225,6 +233,20 @@ func (p *PostgresDB) QueryContext(ctx context.Context, query string) ([]map[stri
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (p *PostgresDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
if p.conn == nil {
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
conn, err := p.conn.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return queryPostgresConnWithMessages(ctx, conn, query)
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if p.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
@@ -238,6 +260,10 @@ func (p *PostgresDB) Query(query string) ([]map[string]interface{}, []string, er
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (p *PostgresDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return p.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (p *PostgresDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||||
if p.conn == nil {
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
@@ -257,7 +283,7 @@ func (p *PostgresDB) OpenSessionExecer(ctx context.Context) (StatementExecer, er
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLConnStatementExecer(conn), nil
|
||||
return &postgresSessionExecer{sqlConnStatementExecer: &sqlConnStatementExecer{conn: conn}}, nil
|
||||
}
|
||||
|
||||
func (p *PostgresDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
@@ -282,6 +308,31 @@ func (p *PostgresDB) Exec(query string) (int64, error) {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *postgresSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return e.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *postgresSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
return queryPostgresConnWithMessages(ctx, e.conn, query)
|
||||
}
|
||||
|
||||
func queryPostgresConnWithMessages(ctx context.Context, conn *sql.Conn, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return querySQLConnWithTextNotices(ctx, conn, query, func(driverConn driver.Conn, addNotice func(string)) {
|
||||
if addNotice == nil {
|
||||
pq.SetNoticeHandler(driverConn, nil)
|
||||
return
|
||||
}
|
||||
pq.SetNoticeHandler(driverConn, func(notice *pq.Error) {
|
||||
if notice != nil {
|
||||
addNotice(notice.Message)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (p *PostgresDB) GetDatabases() ([]string, error) {
|
||||
data, _, err := p.Query("SELECT datname FROM pg_database WHERE datistemplate = false")
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user