🐛 fix(sql-editor): 修复结果消息展示与数据目录迁移稳定性

This commit is contained in:
Syngnat
2026-06-04 07:09:42 +08:00
parent 23ac30086f
commit f5166ac3fc
21 changed files with 1608 additions and 153 deletions

View File

@@ -1,6 +1,7 @@
package app
import (
"encoding/json"
"fmt"
"io"
"os"
@@ -17,12 +18,8 @@ import (
"github.com/wailsapp/wails/v2/pkg/runtime"
)
var migratableDataRootEntries = []string{
"connections.json",
"global_proxy.json",
"ai_config.json",
"sessions",
"drivers",
var dataRootMigrationExcludedEntries = map[string]struct{}{
"storage_root.json": {},
}
func (a *App) GetDataRootDirectoryInfo() connection.QueryResult {
@@ -122,22 +119,43 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error {
if sourceRoot == "" || targetRoot == "" {
return fmt.Errorf("数据目录不能为空")
}
if filepath.Clean(sourceRoot) == filepath.Clean(targetRoot) {
sourceAbs, err := filepath.Abs(sourceRoot)
if err != nil {
return fmt.Errorf("解析源数据目录失败:%w", err)
}
targetAbs, err := filepath.Abs(targetRoot)
if err != nil {
return fmt.Errorf("解析目标数据目录失败:%w", err)
}
if filepath.Clean(sourceAbs) == filepath.Clean(targetAbs) {
return nil
}
if rel, err := filepath.Rel(sourceAbs, targetAbs); err == nil && rel != "." && rel != "" && !strings.HasPrefix(rel, "..") && !filepath.IsAbs(rel) {
return fmt.Errorf("目标数据目录不能位于源目录内部")
}
sourceRoot = sourceAbs
targetRoot = targetAbs
if err := os.MkdirAll(targetRoot, 0o755); err != nil {
return fmt.Errorf("创建目标数据目录失败:%w", err)
}
for _, name := range migratableDataRootEntries {
entries, err := os.ReadDir(sourceRoot)
if err != nil {
return fmt.Errorf("读取源数据目录失败:%w", err)
}
for _, entry := range entries {
name := strings.TrimSpace(entry.Name())
if name == "" {
continue
}
if _, excluded := dataRootMigrationExcludedEntries[name]; excluded {
continue
}
sourcePath := filepath.Join(sourceRoot, name)
info, err := os.Stat(sourcePath)
targetPath := filepath.Join(targetRoot, name)
info, err := entry.Info()
if err != nil {
if os.IsNotExist(err) {
continue
}
return fmt.Errorf("读取源数据失败(%s%w", name, err)
}
targetPath := filepath.Join(targetRoot, name)
if info.IsDir() {
if err := copyDir(sourcePath, targetPath); err != nil {
return fmt.Errorf("迁移目录失败(%s%w", name, err)
@@ -148,6 +166,75 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error {
return fmt.Errorf("迁移文件失败(%s%w", name, err)
}
}
if err := rewriteMigratedDataRootState(targetRoot); err != nil {
return err
}
return nil
}
func rewriteMigratedDataRootState(targetRoot string) error {
if err := rewriteSecurityUpdateBackupPaths(targetRoot); err != nil {
return err
}
return nil
}
func rewriteSecurityUpdateBackupPaths(targetRoot string) error {
repo := newSecurityUpdateStateRepository(targetRoot)
marker, err := repo.readMarker()
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("读取迁移后的安全更新状态失败:%w", err)
}
migrationID := strings.TrimSpace(marker.MigrationID)
if migrationID == "" {
return nil
}
targetBackupPath := repo.backupPath(migrationID)
marker.BackupPath = targetBackupPath
if err := repo.writeMarker(marker); err != nil {
return fmt.Errorf("写入迁移后的安全更新状态失败:%w", err)
}
manifestPath := repo.manifestPath(migrationID)
manifestData, err := os.ReadFile(manifestPath)
if err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("读取迁移后的安全更新备份清单失败:%w", err)
}
} else {
var manifest securityUpdateBackupManifest
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return fmt.Errorf("解析迁移后的安全更新备份清单失败:%w", err)
}
manifest.BackupPath = targetBackupPath
if err := securityUpdateWriteJSONFile(manifestPath, manifest); err != nil {
return fmt.Errorf("写入迁移后的安全更新备份清单失败:%w", err)
}
}
resultPath := repo.resultPath(migrationID)
resultData, err := os.ReadFile(resultPath)
if err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("读取迁移后的安全更新结果失败:%w", err)
}
} else {
var result SecurityUpdateStatus
if err := json.Unmarshal(resultData, &result); err != nil {
return fmt.Errorf("解析迁移后的安全更新结果失败:%w", err)
}
result.BackupPath = targetBackupPath
result.BackupAvailable = strings.TrimSpace(targetBackupPath) != ""
if err := securityUpdateWriteJSONFile(resultPath, result); err != nil {
return fmt.Errorf("写入迁移后的安全更新结果失败:%w", err)
}
}
return nil
}

View File

@@ -105,6 +105,36 @@ func TestMigrateDataRootContentsCopiesSecurityUpdateStateAndRewritesBackupPaths(
}
}
func TestMigrateDataRootContentsToleratesMissingSecurityUpdateArtifacts(t *testing.T) {
sourceRoot := t.TempDir()
targetRoot := filepath.Join(t.TempDir(), "gonavi-data")
sourceRepo := newSecurityUpdateStateRepository(sourceRoot)
started, err := sourceRepo.StartRound(StartSecurityUpdateRequest{SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig})
if err != nil {
t.Fatalf("start security update round failed: %v", err)
}
if err := os.Remove(sourceRepo.manifestPath(started.MigrationID)); err != nil {
t.Fatalf("remove source manifest failed: %v", err)
}
if err := os.Remove(sourceRepo.resultPath(started.MigrationID)); err != nil {
t.Fatalf("remove source result failed: %v", err)
}
if err := migrateDataRootContents(sourceRoot, targetRoot); err != nil {
t.Fatalf("migrateDataRootContents should tolerate missing security update artifacts, got: %v", err)
}
targetRepo := newSecurityUpdateStateRepository(targetRoot)
targetStatus, err := targetRepo.LoadMarker()
if err != nil {
t.Fatalf("load migrated marker failed: %v", err)
}
expectedBackupPath := filepath.Join(targetRoot, securityUpdateBackupRootDirName, started.MigrationID)
if targetStatus.BackupPath != expectedBackupPath {
t.Fatalf("expected migrated marker backupPath %q, got %q", expectedBackupPath, targetStatus.BackupPath)
}
}
func TestMigrateDataRootContentsCopiesDailySecretsForSavedConnections(t *testing.T) {
sourceRoot := t.TempDir()
targetRoot := filepath.Join(t.TempDir(), "gonavi-data")

View File

@@ -623,6 +623,7 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
}()
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query)
runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) {
if q, ok := inst.(interface {
@@ -633,6 +634,14 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
return inst.Query(query)
}
runReadQueryWithMessages := func(inst db.Database) ([]map[string]interface{}, []string, []string, error) {
if q, ok := inst.(db.QueryMessageExecer); ok {
return q.QueryContextWithMessages(ctx, query)
}
data, columns, err := runReadQuery(inst)
return data, columns, nil, err
}
runExecQuery := func(inst db.Database) (int64, error) {
if e, ok := inst.(interface {
ExecContext(context.Context, string) (int64, error)
@@ -642,8 +651,8 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
return inst.Exec(query)
}
if isReadQuery {
data, columns, err := runReadQuery(dbInst)
if isReadQuery || tryQueryFirst {
data, columns, messages, err := runReadQueryWithMessages(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
@@ -651,32 +660,34 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
data, columns, err = runReadQuery(retryInst)
data, columns, messages, err = runReadQueryWithMessages(retryInst)
}
}
if err != nil {
if err == nil {
return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages, QueryID: queryID}
}
if isReadQuery {
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
return connection.QueryResult{Success: true, Data: data, Fields: columns, QueryID: queryID}
} else {
affected, err := runExecQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
affected, err = runExecQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
}
affected, err := runExecQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
affected, err = runExecQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
}
// DBQueryMulti 执行可能包含多条 SQL 语句的查询,返回多个结果集。
@@ -727,20 +738,25 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
}
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, error) {
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, []string, error) {
if !allReadOnly {
return nil, nil // 包含写操作,走逐条执行路径
return nil, nil, nil // 包含写操作,走逐条执行路径
}
if q, ok := inst.(db.MultiResultQueryMessageExecer); ok {
return q.QueryMultiContextWithMessages(ctx, query)
}
if q, ok := inst.(db.MultiResultQuerierContext); ok {
return q.QueryMultiContext(ctx, query)
results, err := q.QueryMultiContext(ctx, query)
return results, nil, err
}
if q, ok := inst.(db.MultiResultQuerier); ok {
return q.QueryMulti(query)
results, err := q.QueryMulti(query)
return results, nil, err
}
return nil, nil // 返回 nil 表示不支持
return nil, nil, nil // 返回 nil 表示不支持
}
results, err := runMultiQuery(dbInst)
results, resultMessages, err := runMultiQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
@@ -748,7 +764,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
logger.Error(retryErr, "DBQueryMulti 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error(), QueryID: queryID}
}
results, err = runMultiQuery(retryInst)
results, resultMessages, err = runMultiQuery(retryInst)
}
}
if err != nil {
@@ -758,7 +774,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
// 驱动支持多结果集,直接返回
if results != nil {
return connection.QueryResult{Success: true, Data: results, QueryID: queryID}
return connection.QueryResult{Success: true, Data: results, Messages: resultMessages, QueryID: queryID}
}
// 驱动不支持多结果集,回退到逐条执行
@@ -771,13 +787,50 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
}
var sessionQueryTarget db.StatementQueryExecer
var sessionQueryMessageTarget db.StatementQueryMessageExecer
var sessionMultiQueryTarget db.StatementMultiResultQueryExecer
var sessionMultiQueryMessageTarget db.StatementMultiResultQueryMessageExecer
var sessionExecTarget db.StatementExecer
var sessionBatchTarget db.BatchWriteExecer
closeExecTarget := func() {}
if provider, ok := dbInst.(db.SessionExecerProvider); ok {
sessionExecer, sessionErr := provider.OpenSessionExecer(ctx)
if sessionErr != nil {
logger.Warnf("DBQueryMulti 打开会话级执行器失败,将回退共享连接:%s SQL片段=%q err=%v", formatConnSummary(runConfig), sqlSnippet(query), sessionErr)
} else {
if statementQueryExecer, ok := sessionExecer.(db.StatementQueryExecer); ok {
sessionQueryTarget = statementQueryExecer
}
if statementQueryMessageExecer, ok := sessionExecer.(db.StatementQueryMessageExecer); ok {
sessionQueryMessageTarget = statementQueryMessageExecer
}
if statementMultiResultQueryExecer, ok := sessionExecer.(db.StatementMultiResultQueryExecer); ok {
sessionMultiQueryTarget = statementMultiResultQueryExecer
}
if statementMultiResultQueryMessageExecer, ok := sessionExecer.(db.StatementMultiResultQueryMessageExecer); ok {
sessionMultiQueryMessageTarget = statementMultiResultQueryMessageExecer
}
sessionExecTarget = sessionExecer
if batcher, ok := sessionExecer.(db.BatchWriteExecer); ok {
sessionBatchTarget = batcher
}
closeExecTarget = func() {
if err := sessionExecer.Close(); err != nil {
logger.Warnf("DBQueryMulti 关闭会话级执行器失败:%v", err)
}
}
}
}
defer closeExecTarget()
// 全部为写操作且驱动支持批量 Exec → 一次性发送,大幅减少网络往返
// 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动
if !allReadOnly {
allWrite := true
containsPLSQLBlock := false
for _, stmt := range statements {
if strings.TrimSpace(stmt) != "" && isReadOnlySQLQuery(runConfig.Type, stmt) {
if strings.TrimSpace(stmt) != "" && !isBatchableWriteSQLStatement(runConfig.Type, stmt) {
allWrite = false
}
if isPLSQLBlockStatement(stmt) {
@@ -785,7 +838,13 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
}
if allWrite && !containsPLSQLBlock {
if batcher, ok := dbInst.(db.BatchWriteExecer); ok {
batcher := sessionBatchTarget
if batcher == nil {
if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok {
batcher = fallbackBatcher
}
}
if batcher != nil {
affected, batchErr := batcher.ExecBatchContext(ctx, query)
if batchErr != nil && shouldRefreshCachedConnection(batchErr) {
if a.invalidateCachedDatabase(runConfig, batchErr) {
@@ -823,17 +882,80 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
continue
}
if isReadOnlySQLQuery(runConfig.Type, stmt) {
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt)
tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)
if isReadStmt || tryQueryStmtFirst {
var (
data []map[string]interface{}
columns []string
messages []string
statementResults []connection.ResultSetData
usedMultiResult bool
)
if sessionMultiQueryMessageTarget != nil {
statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt)
usedMultiResult = true
} else if sessionMultiQueryTarget != nil {
statementResults, err = sessionMultiQueryTarget.QueryMultiContext(ctx, stmt)
usedMultiResult = true
} else if q, ok := dbInst.(db.MultiResultQueryMessageExecer); ok {
statementResults, messages, err = q.QueryMultiContextWithMessages(ctx, stmt)
usedMultiResult = true
} else if q, ok := dbInst.(db.MultiResultQuerierContext); ok {
statementResults, err = q.QueryMultiContext(ctx, stmt)
usedMultiResult = true
} else if q, ok := dbInst.(db.MultiResultQuerier); ok {
statementResults, err = q.QueryMulti(stmt)
usedMultiResult = true
} else if sessionQueryMessageTarget != nil {
data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt)
} else if sessionQueryTarget != nil {
data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt)
} else if q, ok := dbInst.(db.QueryMessageExecer); ok {
data, columns, messages, err = q.QueryContextWithMessages(ctx, stmt)
} else if q, ok := dbInst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, stmt)
} else {
data, columns, err = dbInst.Query(stmt)
}
if err != nil {
if err == nil {
if usedMultiResult {
if len(statementResults) == 0 && len(messages) > 0 {
statementResults = []connection.ResultSetData{{
Rows: []map[string]interface{}{},
Columns: []string{},
Messages: append([]string(nil), messages...),
}}
}
for _, statementResult := range statementResults {
if statementResult.Rows == nil {
statementResult.Rows = []map[string]interface{}{}
}
if statementResult.Columns == nil {
statementResult.Columns = []string{}
}
statementResult.StatementIndex = idx + 1
resultSets = append(resultSets, statementResult)
}
continue
}
if data == nil {
data = make([]map[string]interface{}, 0)
}
if columns == nil {
columns = []string{}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: data,
Columns: columns,
Messages: messages,
StatementIndex: idx + 1,
})
continue
}
if isReadStmt {
logger.Error(err, "DBQueryMulti 逐条查询失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
if len(resultSets) > 0 {
@@ -841,35 +963,30 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
if data == nil {
data = make([]map[string]interface{}, 0)
}
if columns == nil {
columns = []string{}
}
resultSets = append(resultSets, connection.ResultSetData{Rows: data, Columns: columns})
} else {
var affected int64
if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, stmt)
} else {
affected, err = dbInst.Exec(stmt)
}
if err != nil {
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
if len(resultSets) > 0 {
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
}
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": affected}},
Columns: []string{"affectedRows"},
})
}
var affected int64
if sessionExecTarget != nil {
affected, err = sessionExecTarget.ExecContext(ctx, stmt)
} else if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, stmt)
} else {
affected, err = dbInst.Exec(stmt)
}
if err != nil {
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
if len(resultSets) > 0 {
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
}
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": affected}},
Columns: []string{"affectedRows"},
})
}
if resultSets == nil {
@@ -883,6 +1000,24 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
return connection.QueryResult{Success: true, Data: resultSets, QueryID: queryID, Message: fallbackMsg}
}
func shouldTryQueryResultFirst(dbType string, query string) bool {
isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver")
keyword := leadingSQLKeyword(query)
switch keyword {
case "exec", "execute", "call":
return true
case "set", "print":
return isSQLServer
case "dbcc":
return isSQLServer
default:
if isSQLServer {
return strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_")
}
return false
}
}
func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)
@@ -906,22 +1041,30 @@ func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string,
defer cancel()
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query)
if isReadQuery {
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
if isReadQuery || tryQueryFirst {
var (
data []map[string]interface{}
columns []string
messages []string
)
if q, ok := dbInst.(db.QueryMessageExecer); ok {
data, columns, messages, err = q.QueryContextWithMessages(ctx, query)
} else if q, ok := dbInst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, query)
} else {
data, columns, err = dbInst.Query(query)
}
if err != nil {
if err == nil {
return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages}
}
if isReadQuery {
logger.Error(err, "DBQueryIsolated 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: data, Fields: columns}
}
var affected int64

View File

@@ -10,10 +10,17 @@ import (
)
type fakeBatchWriteDB struct {
batchCalls int
execCalls int
batchCalls int
execCalls int
execQueries []string
lastQuery string
lastQuery string
queryCalls int
queryMap map[string][]map[string]interface{}
fieldMap map[string][]string
messageMap map[string][]string
multiResult map[string][]connection.ResultSetData
queryErr map[string]error
session *fakeBatchWriteSession
}
func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error {
@@ -29,7 +36,16 @@ func (f *fakeBatchWriteDB) Ping() error {
}
func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []string, error) {
return nil, nil, nil
f.queryCalls++
if err := f.queryErr[query]; err != nil {
return nil, nil, err
}
return f.queryMap[query], f.fieldMap[query], nil
}
func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
rows, fields, err := f.Query(query)
return rows, fields, f.messageMap[query], err
}
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
@@ -76,12 +92,143 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64
return 1, nil
}
func (f *fakeBatchWriteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.queryCalls++
if err := f.queryErr[query]; err != nil {
return nil, nil, err
}
return f.queryMap[query], f.fieldMap[query], nil
}
func (f *fakeBatchWriteDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
rows, fields, err := f.QueryContext(ctx, query)
return rows, fields, f.messageMap[query], err
}
func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
f.batchCalls++
f.lastQuery = query
return 500, nil
}
func (f *fakeBatchWriteDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) {
f.session = &fakeBatchWriteSession{parent: f}
return f.session, nil
}
type fakeBatchWriteSession struct {
parent *fakeBatchWriteDB
queryCalls int
execCalls int
batchCalls int
closed bool
}
func (s *fakeBatchWriteSession) Query(query string) ([]map[string]interface{}, []string, error) {
return s.QueryContext(context.Background(), query)
}
func (s *fakeBatchWriteSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
s.queryCalls++
return s.parent.QueryContext(ctx, query)
}
func (s *fakeBatchWriteSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return s.QueryContextWithMessages(context.Background(), query)
}
func (s *fakeBatchWriteSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
s.queryCalls++
return s.parent.QueryContextWithMessages(ctx, query)
}
func (s *fakeBatchWriteSession) QueryMulti(query string) ([]connection.ResultSetData, error) {
return s.QueryMultiContext(context.Background(), query)
}
func (s *fakeBatchWriteSession) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
if multi := s.parent.multiResult[query]; len(multi) > 0 {
s.queryCalls++
return cloneResultSets(multi), nil
}
rows, columns, err := s.QueryContext(ctx, query)
if err != nil {
return nil, err
}
return []connection.ResultSetData{{Rows: rows, Columns: columns}}, nil
}
func (s *fakeBatchWriteSession) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
return s.QueryMultiContextWithMessages(context.Background(), query)
}
func (s *fakeBatchWriteSession) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if err := s.parent.queryErr[query]; err != nil {
s.queryCalls++
return nil, nil, err
}
if multi := s.parent.multiResult[query]; len(multi) > 0 {
s.queryCalls++
return cloneResultSets(multi), append([]string(nil), s.parent.messageMap[query]...), nil
}
rows, columns, messages, err := s.QueryContextWithMessages(ctx, query)
if err != nil {
return nil, nil, err
}
return []connection.ResultSetData{{
Rows: rows,
Columns: columns,
Messages: append([]string(nil), messages...),
}}, append([]string(nil), messages...), nil
}
func (s *fakeBatchWriteSession) Exec(query string) (int64, error) {
return s.ExecContext(context.Background(), query)
}
func (s *fakeBatchWriteSession) ExecContext(ctx context.Context, query string) (int64, error) {
s.execCalls++
return s.parent.ExecContext(ctx, query)
}
func (s *fakeBatchWriteSession) ExecBatchContext(ctx context.Context, query string) (int64, error) {
s.batchCalls++
return s.parent.ExecBatchContext(ctx, query)
}
func (s *fakeBatchWriteSession) Close() error {
s.closed = true
return nil
}
func cloneResultSets(input []connection.ResultSetData) []connection.ResultSetData {
if len(input) == 0 {
return nil
}
cloned := make([]connection.ResultSetData, 0, len(input))
for _, item := range input {
rows := make([]map[string]interface{}, 0, len(item.Rows))
for _, row := range item.Rows {
if row == nil {
rows = append(rows, nil)
continue
}
rowCopy := make(map[string]interface{}, len(row))
for key, value := range row {
rowCopy[key] = value
}
rows = append(rows, rowCopy)
}
cloned = append(cloned, connection.ResultSetData{
Rows: rows,
Columns: append([]string(nil), item.Columns...),
Messages: append([]string(nil), item.Messages...),
StatementIndex: item.StatementIndex,
})
}
return cloned
}
func TestDBQueryMultiKeepsOracleAnonymousBlockAsSingleStatement(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
@@ -122,6 +269,85 @@ END;`
}
var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil)
var _ db.SessionExecerProvider = (*fakeBatchWriteDB)(nil)
var _ db.QueryMessageExecer = (*fakeBatchWriteDB)(nil)
var _ db.StatementQueryMessageExecer = (*fakeBatchWriteSession)(nil)
func TestDBQueryWithCancelReturnsResultSetForExecStoredProcedure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "EXEC sp_who2"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"SPID": 52, "STATUS": "RUNNABLE"},
},
},
fieldMap: map[string][]string{
query: {"SPID", "STATUS"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryWithCancel(config, "master", query, "sp-who2-test")
if !result.Success {
t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message)
}
rows, ok := result.Data.([]map[string]interface{})
if !ok {
t.Fatalf("expected []map[string]interface{}, got %T", result.Data)
}
if len(rows) != 1 || rows[0]["SPID"] != 52 {
t.Fatalf("unexpected rows: %#v", rows)
}
if fakeDB.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
}
}
func TestDBQueryWithCancelReturnsMessagesForSQLServerQuery(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "SET STATISTICS IO ON"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {},
},
fieldMap: map[string][]string{
query: {},
},
messageMap: map[string][]string{
query: {"Table 'users'. Scan count 1, logical reads 3."},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryWithCancel(config, "master", query, "statistics-io-test")
if !result.Success {
t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message)
}
if len(result.Messages) != 1 || result.Messages[0] == "" {
t.Fatalf("expected SQL Server messages to be returned, got %#v", result.Messages)
}
}
func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
@@ -168,3 +394,206 @@ func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
t.Fatalf("expected affectedRows=500, got %#v", got)
}
}
func TestDBQueryMultiPrefersResultSetForExecStoredProcedure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "EXEC sp_who2"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"SPID": 77, "STATUS": "SUSPENDED"},
},
},
fieldMap: map[string][]string{
query: {"SPID", "STATUS"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sp-who2-multi-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
t.Fatalf("unexpected result sets: %#v", resultSets)
}
if got := resultSets[0].Rows[0]["SPID"]; got != 77 {
t.Fatalf("expected SPID=77, got %#v", got)
}
if fakeDB.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
}
}
func TestDBQueryMultiDoesNotBatchExecStoredProcedureAsWriteStatement(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "EXEC sp_who2"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"SPID": 88, "STATUS": "RUNNING"},
},
},
fieldMap: map[string][]string{
query: {"SPID", "STATUS"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sp-who2-batch-guard-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.batchCalls != 0 {
t.Fatalf("expected stored procedure to skip batch write path, got batchCalls=%d", fakeDB.batchCalls)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
t.Fatalf("unexpected result sets: %#v", resultSets)
}
if got := resultSets[0].Rows[0]["SPID"]; got != 88 {
t.Fatalf("expected SPID=88, got %#v", got)
}
}
func TestDBQueryMultiUsesPinnedSessionForSequentialFallback(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
"SELECT 1 AS value": {
{"value": 1},
},
},
fieldMap: map[string][]string{
"SELECT 1 AS value": {"value"},
},
messageMap: map[string][]string{
"SET NOCOUNT ON": {"NOCOUNT 已开启"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", "SET NOCOUNT ON;\nSELECT 1 AS value;", "session-fallback-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.session == nil {
t.Fatal("expected DBQueryMulti to open a pinned session for sequential fallback")
}
if !fakeDB.session.closed {
t.Fatal("expected DBQueryMulti to close the pinned session")
}
if fakeDB.session.execCalls != 0 {
t.Fatalf("expected SQL Server SET statement to avoid exec-only path, got execCalls=%d", fakeDB.session.execCalls)
}
if fakeDB.session.queryCalls != 2 {
t.Fatalf("expected both statements to query through pinned session, got queryCalls=%d", fakeDB.session.queryCalls)
}
if fakeDB.queryCalls != 2 {
t.Fatalf("expected exactly two underlying query calls, got %d", fakeDB.queryCalls)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 2 {
t.Fatalf("expected two result sets, got %#v", resultSets)
}
if len(resultSets[0].Messages) != 1 || resultSets[0].Messages[0] != "NOCOUNT 已开启" {
t.Fatalf("expected first result set to keep session message, got %#v", resultSets[0].Messages)
}
if got := resultSets[1].Rows[0]["value"]; got != 1 {
t.Fatalf("expected second result set value=1, got %#v", got)
}
}
func TestDBQueryMultiKeepsAllResultSetsFromSingleSQLServerStatement(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "EXEC sp_helpdb"
fakeDB := &fakeBatchWriteDB{
multiResult: map[string][]connection.ResultSetData{
query: {
{
Rows: []map[string]interface{}{{"name": "master"}},
Columns: []string{"name"},
},
{
Rows: []map[string]interface{}{{"owner": "sa"}},
Columns: []string{"owner"},
},
},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sp-helpdb-multi-result-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 2 {
t.Fatalf("expected two result sets, got %#v", resultSets)
}
if got := resultSets[0].Rows[0]["name"]; got != "master" {
t.Fatalf("expected first result set to keep master row, got %#v", got)
}
if got := resultSets[1].Rows[0]["owner"]; got != "sa" {
t.Fatalf("expected second result set to keep owner row, got %#v", got)
}
if resultSets[0].StatementIndex != 1 || resultSets[1].StatementIndex != 1 {
t.Fatalf("expected both result sets to map to the first statement, got %#v", resultSets)
}
if fakeDB.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
}
}

View File

@@ -65,6 +65,19 @@ func isReadOnlySQLQuery(dbType string, query string) bool {
}
}
func isBatchableWriteSQLStatement(dbType string, query string) bool {
if isReadOnlySQLQuery(dbType, query) {
return false
}
switch leadingSQLKeyword(query) {
case "insert", "update", "delete", "replace", "merge", "upsert":
return true
default:
return false
}
}
func sanitizeSQLForPgLike(dbType string, query string) string {
normalizedType := strings.ToLower(strings.TrimSpace(dbType))
switch normalizedType {

View File

@@ -62,3 +62,42 @@ func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) {
t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out)
}
}
func TestIsReadOnlySQLQuery_DoesNotTreatExecAsReadOnly(t *testing.T) {
if isReadOnlySQLQuery("sqlserver", "EXEC sp_who2") {
t.Fatal("EXEC should not be treated as read-only SQL")
}
}
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
if !isBatchableWriteSQLStatement("mysql", "INSERT INTO demo(id) VALUES (1)") {
t.Fatal("expected INSERT to be treated as batchable write")
}
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
t.Fatal("EXEC should not be treated as batchable write")
}
if isBatchableWriteSQLStatement("sqlserver", "SET STATISTICS IO ON") {
t.Fatal("SET STATISTICS should not be treated as batchable write")
}
}
func TestShouldTryQueryResultFirst_TreatsSQLServerSetAsQueryFirst(t *testing.T) {
if !shouldTryQueryResultFirst("sqlserver", "SET STATISTICS IO ON") {
t.Fatal("expected SQL Server SET STATISTICS to try query-first for notice capture")
}
if shouldTryQueryResultFirst("mysql", "SET sql_mode = ''") {
t.Fatal("non-SQLServer SET should not force query-first")
}
}
func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *testing.T) {
if !shouldTryQueryResultFirst("sqlserver", "sp_who2") {
t.Fatal("expected bare SQL Server system procedure to try query-first")
}
if !shouldTryQueryResultFirst("sqlserver", "DBCC INPUTBUFFER(52)") {
t.Fatal("expected SQL Server DBCC command to try query-first")
}
if shouldTryQueryResultFirst("mysql", "sp_who2") {
t.Fatal("non-SQLServer system procedure name should not force query-first")
}
}

View File

@@ -45,7 +45,8 @@ func (p *panicChromium) PutZoomFactor(float64) {
}
type panicFrontend struct {
chromium *panicChromium
chromium *panicChromium
mainWindow *fakeWindow
}
// 测试必须用 wails 一致的 string key "frontend" 作为 context.WithValue 的 key
@@ -122,7 +123,10 @@ func TestResetWebViewZoomFactorErrorsWhenFrontendMissing(t *testing.T) {
}
func TestResetWebViewZoomFactorRecoversFromPutZoomFactorPanic(t *testing.T) {
ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{chromium: &panicChromium{}})
ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{
chromium: &panicChromium{},
mainWindow: &fakeWindow{},
})
err := resetWebViewZoomFactor(ctx, 1.0)
if err == nil {

View File

@@ -122,17 +122,20 @@ type ConnectionConfig struct {
// ResultSetData 表示一个查询结果集(行 + 列名),用于多结果集场景。
type ResultSetData struct {
Rows []map[string]interface{} `json:"rows"`
Columns []string `json:"columns"`
Rows []map[string]interface{} `json:"rows"`
Columns []string `json:"columns"`
Messages []string `json:"messages,omitempty"`
StatementIndex int `json:"statementIndex,omitempty"`
}
// QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。
type QueryResult struct {
Success bool `json:"success"`
Message string `json:"message"`
Data interface{} `json:"data"`
Fields []string `json:"fields,omitempty"`
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
Success bool `json:"success"`
Message string `json:"message"`
Data interface{} `json:"data"`
Fields []string `json:"fields,omitempty"`
Messages []string `json:"messages,omitempty"`
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
}
// ColumnDefinition 描述表的一个列定义。

View File

@@ -67,6 +67,51 @@ type StatementExecer interface {
Close() error
}
// StatementQueryExecer can run queries on a pinned session/connection.
// Drivers that return sqlConnStatementExecer automatically satisfy it.
type StatementQueryExecer interface {
StatementExecer
Query(query string) ([]map[string]interface{}, []string, error)
QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error)
}
// StatementQueryMessageExecer can run queries on a pinned session and return
// extra server messages/notices alongside rows.
type StatementQueryMessageExecer interface {
StatementQueryExecer
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error)
}
// StatementMultiResultQueryExecer can run multi-result queries on a pinned session/connection.
type StatementMultiResultQueryExecer interface {
StatementExecer
QueryMulti(query string) ([]connection.ResultSetData, error)
QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error)
}
// StatementMultiResultQueryMessageExecer can run multi-result queries on a
// pinned session/connection and return server messages/notices.
type StatementMultiResultQueryMessageExecer interface {
StatementMultiResultQueryExecer
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error)
}
// QueryMessageExecer is an optional database-level interface for returning
// informational server messages alongside one result set.
type QueryMessageExecer interface {
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error)
}
// MultiResultQueryMessageExecer is an optional database-level interface for
// returning informational server messages alongside multi-result queries.
type MultiResultQueryMessageExecer interface {
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error)
}
// SessionExecerProvider is implemented by database/sql based drivers that can
// pin a long-running job to one physical connection.
type SessionExecerProvider interface {
@@ -96,6 +141,38 @@ func (e *sqlConnStatementExecer) Exec(query string) (int64, error) {
return e.ExecContext(context.Background(), query)
}
func (e *sqlConnStatementExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if e == nil || e.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := e.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (e *sqlConnStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) {
return e.QueryContext(context.Background(), query)
}
func (e *sqlConnStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
if e == nil || e.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
rows, err := e.conn.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (e *sqlConnStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
return e.QueryMultiContext(context.Background(), query)
}
func (e *sqlConnStatementExecer) ExecBatchContext(ctx context.Context, query string) (int64, error) {
return e.ExecContext(ctx, query)
}

View File

@@ -17,6 +17,7 @@ import (
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
"github.com/golang-sql/sqlexp"
_ "github.com/microsoft/go-mssqldb"
)
@@ -26,6 +27,85 @@ type SqlServerDB struct {
forwarder *ssh.LocalForwarder
}
type sqlServerSessionExecer struct {
conn *sql.Conn
}
func scanSQLServerRowsWithMessages(ctx context.Context, rows *sql.Rows, retmsg *sqlexp.ReturnMessage) ([]connection.ResultSetData, []string, error) {
if rows == nil {
return []connection.ResultSetData{{Rows: []map[string]interface{}{}, Columns: []string{}}}, nil, nil
}
if ctx == nil {
ctx = context.Background()
}
var (
resultSets []connection.ResultSetData
messages []string
allMessages []string
)
active := true
for active {
raw := retmsg.Message(ctx)
switch msg := raw.(type) {
case sqlexp.MsgNotice:
text := strings.TrimSpace(fmt.Sprint(msg.Message))
if text != "" {
messages = append(messages, text)
allMessages = append(allMessages, text)
}
case sqlexp.MsgNext:
data, cols, err := scanRows(rows)
if err != nil {
return resultSets, messages, err
}
if data == nil {
data = []map[string]interface{}{}
}
if cols == nil {
cols = []string{}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: data,
Columns: cols,
Messages: append([]string(nil), messages...),
})
messages = nil
case sqlexp.MsgRowsAffected:
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": msg.Count}},
Columns: []string{"affectedRows"},
Messages: append([]string(nil), messages...),
})
messages = nil
case sqlexp.MsgNextResultSet:
active = rows.NextResultSet()
case sqlexp.MsgError:
return resultSets, messages, msg.Error
default:
active = false
}
}
if len(messages) > 0 {
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{},
Columns: []string{},
Messages: append([]string(nil), messages...),
})
}
if len(resultSets) == 0 {
resultSets = []connection.ResultSetData{{
Rows: []map[string]interface{}{},
Columns: []string{},
}}
}
if err := rows.Err(); err != nil {
return resultSets, allMessages, err
}
return resultSets, allMessages, nil
}
// quoteBracket escapes ] in identifiers for safe use in SQL Server [bracket] notation
func quoteBracket(name string) string {
return strings.ReplaceAll(name, "]", "]]")
@@ -133,54 +213,76 @@ func (s *SqlServerDB) Ping() error {
}
func (s *SqlServerDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
results, _, err := s.QueryMultiWithMessages(query)
return results, err
}
func (s *SqlServerDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
if s.conn == nil {
return nil, fmt.Errorf("连接未打开")
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := s.conn.Query(query)
ctx := context.Background()
retmsg := &sqlexp.ReturnMessage{}
rows, err := s.conn.QueryContext(ctx, query, retmsg)
if err != nil {
return nil, err
return nil, nil, err
}
defer rows.Close()
return scanMultiRows(rows)
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
}
func (s *SqlServerDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
results, _, err := s.QueryMultiContextWithMessages(ctx, query)
return results, err
}
func (s *SqlServerDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if s.conn == nil {
return nil, fmt.Errorf("连接未打开")
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := s.conn.QueryContext(ctx, query)
retmsg := &sqlexp.ReturnMessage{}
rows, err := s.conn.QueryContext(ctx, query, retmsg)
if err != nil {
return nil, err
return nil, nil, err
}
defer rows.Close()
return scanMultiRows(rows)
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
}
func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
rows, columns, _, err := s.QueryContextWithMessages(ctx, query)
return rows, columns, err
}
func (s *SqlServerDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if s.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
return nil, nil, nil, fmt.Errorf("连接未打开")
}
rows, err := s.conn.QueryContext(ctx, query)
resultSets, messages, err := s.QueryMultiContextWithMessages(ctx, query)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
defer rows.Close()
return scanRows(rows)
if len(resultSets) == 0 {
return []map[string]interface{}{}, []string{}, messages, nil
}
first := resultSets[0]
if first.Rows == nil {
first.Rows = []map[string]interface{}{}
}
if first.Columns == nil {
first.Columns = []string{}
}
return first.Rows, first.Columns, messages, nil
}
func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, error) {
if s.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, columns, _, err := s.QueryWithMessages(query)
return rows, columns, err
}
rows, err := s.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
func (s *SqlServerDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return s.QueryContextWithMessages(context.Background(), query)
}
func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, error) {
@@ -213,7 +315,7 @@ func (s *SqlServerDB) OpenSessionExecer(ctx context.Context) (StatementExecer, e
if err != nil {
return nil, err
}
return NewSQLConnStatementExecer(conn), nil
return &sqlServerSessionExecer{conn: conn}, nil
}
func (s *SqlServerDB) Exec(query string) (int64, error) {
@@ -227,6 +329,87 @@ func (s *SqlServerDB) Exec(query string) (int64, error) {
return res.RowsAffected()
}
func (e *sqlServerSessionExecer) Exec(query string) (int64, error) {
return e.ExecContext(context.Background(), query)
}
func (e *sqlServerSessionExecer) ExecContext(ctx context.Context, query string) (int64, error) {
if e == nil || e.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := e.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (e *sqlServerSessionExecer) Query(query string) ([]map[string]interface{}, []string, error) {
rows, columns, _, err := e.QueryWithMessages(query)
return rows, columns, err
}
func (e *sqlServerSessionExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
rows, columns, _, err := e.QueryContextWithMessages(ctx, query)
return rows, columns, err
}
func (e *sqlServerSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return e.QueryContextWithMessages(context.Background(), query)
}
func (e *sqlServerSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
results, messages, err := e.QueryMultiContextWithMessages(ctx, query)
if err != nil {
return nil, nil, nil, err
}
if len(results) == 0 {
return []map[string]interface{}{}, []string{}, messages, nil
}
first := results[0]
if first.Rows == nil {
first.Rows = []map[string]interface{}{}
}
if first.Columns == nil {
first.Columns = []string{}
}
return first.Rows, first.Columns, messages, nil
}
func (e *sqlServerSessionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
results, _, err := e.QueryMultiWithMessages(query)
return results, err
}
func (e *sqlServerSessionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
results, _, err := e.QueryMultiContextWithMessages(ctx, query)
return results, err
}
func (e *sqlServerSessionExecer) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
return e.QueryMultiContextWithMessages(context.Background(), query)
}
func (e *sqlServerSessionExecer) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if e == nil || e.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
retmsg := &sqlexp.ReturnMessage{}
rows, err := e.conn.QueryContext(ctx, query, retmsg)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
}
func (e *sqlServerSessionExecer) Close() error {
if e == nil || e.conn == nil {
return nil
}
return e.conn.Close()
}
func (s *SqlServerDB) GetDatabases() ([]string, error) {
query := "SELECT name FROM sys.databases WHERE state_desc = 'ONLINE' ORDER BY name"
data, _, err := s.Query(query)