mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-07-06 02:21:33 +08:00
🐛 fix(sql-editor): 修复结果消息展示与数据目录迁移稳定性
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -17,12 +18,8 @@ import (
|
||||
"github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
|
||||
var migratableDataRootEntries = []string{
|
||||
"connections.json",
|
||||
"global_proxy.json",
|
||||
"ai_config.json",
|
||||
"sessions",
|
||||
"drivers",
|
||||
var dataRootMigrationExcludedEntries = map[string]struct{}{
|
||||
"storage_root.json": {},
|
||||
}
|
||||
|
||||
func (a *App) GetDataRootDirectoryInfo() connection.QueryResult {
|
||||
@@ -122,22 +119,43 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error {
|
||||
if sourceRoot == "" || targetRoot == "" {
|
||||
return fmt.Errorf("数据目录不能为空")
|
||||
}
|
||||
if filepath.Clean(sourceRoot) == filepath.Clean(targetRoot) {
|
||||
sourceAbs, err := filepath.Abs(sourceRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析源数据目录失败:%w", err)
|
||||
}
|
||||
targetAbs, err := filepath.Abs(targetRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析目标数据目录失败:%w", err)
|
||||
}
|
||||
if filepath.Clean(sourceAbs) == filepath.Clean(targetAbs) {
|
||||
return nil
|
||||
}
|
||||
if rel, err := filepath.Rel(sourceAbs, targetAbs); err == nil && rel != "." && rel != "" && !strings.HasPrefix(rel, "..") && !filepath.IsAbs(rel) {
|
||||
return fmt.Errorf("目标数据目录不能位于源目录内部")
|
||||
}
|
||||
sourceRoot = sourceAbs
|
||||
targetRoot = targetAbs
|
||||
if err := os.MkdirAll(targetRoot, 0o755); err != nil {
|
||||
return fmt.Errorf("创建目标数据目录失败:%w", err)
|
||||
}
|
||||
for _, name := range migratableDataRootEntries {
|
||||
entries, err := os.ReadDir(sourceRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取源数据目录失败:%w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name())
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, excluded := dataRootMigrationExcludedEntries[name]; excluded {
|
||||
continue
|
||||
}
|
||||
sourcePath := filepath.Join(sourceRoot, name)
|
||||
info, err := os.Stat(sourcePath)
|
||||
targetPath := filepath.Join(targetRoot, name)
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("读取源数据失败(%s):%w", name, err)
|
||||
}
|
||||
targetPath := filepath.Join(targetRoot, name)
|
||||
if info.IsDir() {
|
||||
if err := copyDir(sourcePath, targetPath); err != nil {
|
||||
return fmt.Errorf("迁移目录失败(%s):%w", name, err)
|
||||
@@ -148,6 +166,75 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error {
|
||||
return fmt.Errorf("迁移文件失败(%s):%w", name, err)
|
||||
}
|
||||
}
|
||||
if err := rewriteMigratedDataRootState(targetRoot); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rewriteMigratedDataRootState(targetRoot string) error {
|
||||
if err := rewriteSecurityUpdateBackupPaths(targetRoot); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rewriteSecurityUpdateBackupPaths(targetRoot string) error {
|
||||
repo := newSecurityUpdateStateRepository(targetRoot)
|
||||
marker, err := repo.readMarker()
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("读取迁移后的安全更新状态失败:%w", err)
|
||||
}
|
||||
|
||||
migrationID := strings.TrimSpace(marker.MigrationID)
|
||||
if migrationID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetBackupPath := repo.backupPath(migrationID)
|
||||
marker.BackupPath = targetBackupPath
|
||||
if err := repo.writeMarker(marker); err != nil {
|
||||
return fmt.Errorf("写入迁移后的安全更新状态失败:%w", err)
|
||||
}
|
||||
|
||||
manifestPath := repo.manifestPath(migrationID)
|
||||
manifestData, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("读取迁移后的安全更新备份清单失败:%w", err)
|
||||
}
|
||||
} else {
|
||||
var manifest securityUpdateBackupManifest
|
||||
if err := json.Unmarshal(manifestData, &manifest); err != nil {
|
||||
return fmt.Errorf("解析迁移后的安全更新备份清单失败:%w", err)
|
||||
}
|
||||
manifest.BackupPath = targetBackupPath
|
||||
if err := securityUpdateWriteJSONFile(manifestPath, manifest); err != nil {
|
||||
return fmt.Errorf("写入迁移后的安全更新备份清单失败:%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
resultPath := repo.resultPath(migrationID)
|
||||
resultData, err := os.ReadFile(resultPath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("读取迁移后的安全更新结果失败:%w", err)
|
||||
}
|
||||
} else {
|
||||
var result SecurityUpdateStatus
|
||||
if err := json.Unmarshal(resultData, &result); err != nil {
|
||||
return fmt.Errorf("解析迁移后的安全更新结果失败:%w", err)
|
||||
}
|
||||
result.BackupPath = targetBackupPath
|
||||
result.BackupAvailable = strings.TrimSpace(targetBackupPath) != ""
|
||||
if err := securityUpdateWriteJSONFile(resultPath, result); err != nil {
|
||||
return fmt.Errorf("写入迁移后的安全更新结果失败:%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -105,6 +105,36 @@ func TestMigrateDataRootContentsCopiesSecurityUpdateStateAndRewritesBackupPaths(
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateDataRootContentsToleratesMissingSecurityUpdateArtifacts(t *testing.T) {
|
||||
sourceRoot := t.TempDir()
|
||||
targetRoot := filepath.Join(t.TempDir(), "gonavi-data")
|
||||
sourceRepo := newSecurityUpdateStateRepository(sourceRoot)
|
||||
started, err := sourceRepo.StartRound(StartSecurityUpdateRequest{SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig})
|
||||
if err != nil {
|
||||
t.Fatalf("start security update round failed: %v", err)
|
||||
}
|
||||
if err := os.Remove(sourceRepo.manifestPath(started.MigrationID)); err != nil {
|
||||
t.Fatalf("remove source manifest failed: %v", err)
|
||||
}
|
||||
if err := os.Remove(sourceRepo.resultPath(started.MigrationID)); err != nil {
|
||||
t.Fatalf("remove source result failed: %v", err)
|
||||
}
|
||||
|
||||
if err := migrateDataRootContents(sourceRoot, targetRoot); err != nil {
|
||||
t.Fatalf("migrateDataRootContents should tolerate missing security update artifacts, got: %v", err)
|
||||
}
|
||||
|
||||
targetRepo := newSecurityUpdateStateRepository(targetRoot)
|
||||
targetStatus, err := targetRepo.LoadMarker()
|
||||
if err != nil {
|
||||
t.Fatalf("load migrated marker failed: %v", err)
|
||||
}
|
||||
expectedBackupPath := filepath.Join(targetRoot, securityUpdateBackupRootDirName, started.MigrationID)
|
||||
if targetStatus.BackupPath != expectedBackupPath {
|
||||
t.Fatalf("expected migrated marker backupPath %q, got %q", expectedBackupPath, targetStatus.BackupPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateDataRootContentsCopiesDailySecretsForSavedConnections(t *testing.T) {
|
||||
sourceRoot := t.TempDir()
|
||||
targetRoot := filepath.Join(t.TempDir(), "gonavi-data")
|
||||
|
||||
@@ -623,6 +623,7 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
}()
|
||||
|
||||
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
|
||||
tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query)
|
||||
|
||||
runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) {
|
||||
if q, ok := inst.(interface {
|
||||
@@ -633,6 +634,14 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
return inst.Query(query)
|
||||
}
|
||||
|
||||
runReadQueryWithMessages := func(inst db.Database) ([]map[string]interface{}, []string, []string, error) {
|
||||
if q, ok := inst.(db.QueryMessageExecer); ok {
|
||||
return q.QueryContextWithMessages(ctx, query)
|
||||
}
|
||||
data, columns, err := runReadQuery(inst)
|
||||
return data, columns, nil, err
|
||||
}
|
||||
|
||||
runExecQuery := func(inst db.Database) (int64, error) {
|
||||
if e, ok := inst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
@@ -642,8 +651,8 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
return inst.Exec(query)
|
||||
}
|
||||
|
||||
if isReadQuery {
|
||||
data, columns, err := runReadQuery(dbInst)
|
||||
if isReadQuery || tryQueryFirst {
|
||||
data, columns, messages, err := runReadQueryWithMessages(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
@@ -651,32 +660,34 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
data, columns, err = runReadQuery(retryInst)
|
||||
data, columns, messages, err = runReadQueryWithMessages(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages, QueryID: queryID}
|
||||
}
|
||||
if isReadQuery {
|
||||
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns, QueryID: queryID}
|
||||
} else {
|
||||
affected, err := runExecQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
affected, err = runExecQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
|
||||
}
|
||||
|
||||
affected, err := runExecQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
affected, err = runExecQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
|
||||
}
|
||||
|
||||
// DBQueryMulti 执行可能包含多条 SQL 语句的查询,返回多个结果集。
|
||||
@@ -727,20 +738,25 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
}
|
||||
|
||||
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, error) {
|
||||
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, []string, error) {
|
||||
if !allReadOnly {
|
||||
return nil, nil // 包含写操作,走逐条执行路径
|
||||
return nil, nil, nil // 包含写操作,走逐条执行路径
|
||||
}
|
||||
if q, ok := inst.(db.MultiResultQueryMessageExecer); ok {
|
||||
return q.QueryMultiContextWithMessages(ctx, query)
|
||||
}
|
||||
if q, ok := inst.(db.MultiResultQuerierContext); ok {
|
||||
return q.QueryMultiContext(ctx, query)
|
||||
results, err := q.QueryMultiContext(ctx, query)
|
||||
return results, nil, err
|
||||
}
|
||||
if q, ok := inst.(db.MultiResultQuerier); ok {
|
||||
return q.QueryMulti(query)
|
||||
results, err := q.QueryMulti(query)
|
||||
return results, nil, err
|
||||
}
|
||||
return nil, nil // 返回 nil 表示不支持
|
||||
return nil, nil, nil // 返回 nil 表示不支持
|
||||
}
|
||||
|
||||
results, err := runMultiQuery(dbInst)
|
||||
results, resultMessages, err := runMultiQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
@@ -748,7 +764,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
logger.Error(retryErr, "DBQueryMulti 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error(), QueryID: queryID}
|
||||
}
|
||||
results, err = runMultiQuery(retryInst)
|
||||
results, resultMessages, err = runMultiQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -758,7 +774,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
|
||||
// 驱动支持多结果集,直接返回
|
||||
if results != nil {
|
||||
return connection.QueryResult{Success: true, Data: results, QueryID: queryID}
|
||||
return connection.QueryResult{Success: true, Data: results, Messages: resultMessages, QueryID: queryID}
|
||||
}
|
||||
|
||||
// 驱动不支持多结果集,回退到逐条执行
|
||||
@@ -771,13 +787,50 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
}
|
||||
|
||||
var sessionQueryTarget db.StatementQueryExecer
|
||||
var sessionQueryMessageTarget db.StatementQueryMessageExecer
|
||||
var sessionMultiQueryTarget db.StatementMultiResultQueryExecer
|
||||
var sessionMultiQueryMessageTarget db.StatementMultiResultQueryMessageExecer
|
||||
var sessionExecTarget db.StatementExecer
|
||||
var sessionBatchTarget db.BatchWriteExecer
|
||||
closeExecTarget := func() {}
|
||||
if provider, ok := dbInst.(db.SessionExecerProvider); ok {
|
||||
sessionExecer, sessionErr := provider.OpenSessionExecer(ctx)
|
||||
if sessionErr != nil {
|
||||
logger.Warnf("DBQueryMulti 打开会话级执行器失败,将回退共享连接:%s SQL片段=%q err=%v", formatConnSummary(runConfig), sqlSnippet(query), sessionErr)
|
||||
} else {
|
||||
if statementQueryExecer, ok := sessionExecer.(db.StatementQueryExecer); ok {
|
||||
sessionQueryTarget = statementQueryExecer
|
||||
}
|
||||
if statementQueryMessageExecer, ok := sessionExecer.(db.StatementQueryMessageExecer); ok {
|
||||
sessionQueryMessageTarget = statementQueryMessageExecer
|
||||
}
|
||||
if statementMultiResultQueryExecer, ok := sessionExecer.(db.StatementMultiResultQueryExecer); ok {
|
||||
sessionMultiQueryTarget = statementMultiResultQueryExecer
|
||||
}
|
||||
if statementMultiResultQueryMessageExecer, ok := sessionExecer.(db.StatementMultiResultQueryMessageExecer); ok {
|
||||
sessionMultiQueryMessageTarget = statementMultiResultQueryMessageExecer
|
||||
}
|
||||
sessionExecTarget = sessionExecer
|
||||
if batcher, ok := sessionExecer.(db.BatchWriteExecer); ok {
|
||||
sessionBatchTarget = batcher
|
||||
}
|
||||
closeExecTarget = func() {
|
||||
if err := sessionExecer.Close(); err != nil {
|
||||
logger.Warnf("DBQueryMulti 关闭会话级执行器失败:%v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
defer closeExecTarget()
|
||||
|
||||
// 全部为写操作且驱动支持批量 Exec → 一次性发送,大幅减少网络往返
|
||||
// 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动
|
||||
if !allReadOnly {
|
||||
allWrite := true
|
||||
containsPLSQLBlock := false
|
||||
for _, stmt := range statements {
|
||||
if strings.TrimSpace(stmt) != "" && isReadOnlySQLQuery(runConfig.Type, stmt) {
|
||||
if strings.TrimSpace(stmt) != "" && !isBatchableWriteSQLStatement(runConfig.Type, stmt) {
|
||||
allWrite = false
|
||||
}
|
||||
if isPLSQLBlockStatement(stmt) {
|
||||
@@ -785,7 +838,13 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
}
|
||||
if allWrite && !containsPLSQLBlock {
|
||||
if batcher, ok := dbInst.(db.BatchWriteExecer); ok {
|
||||
batcher := sessionBatchTarget
|
||||
if batcher == nil {
|
||||
if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok {
|
||||
batcher = fallbackBatcher
|
||||
}
|
||||
}
|
||||
if batcher != nil {
|
||||
affected, batchErr := batcher.ExecBatchContext(ctx, query)
|
||||
if batchErr != nil && shouldRefreshCachedConnection(batchErr) {
|
||||
if a.invalidateCachedDatabase(runConfig, batchErr) {
|
||||
@@ -823,17 +882,80 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
continue
|
||||
}
|
||||
|
||||
if isReadOnlySQLQuery(runConfig.Type, stmt) {
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt)
|
||||
tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)
|
||||
if isReadStmt || tryQueryStmtFirst {
|
||||
var (
|
||||
data []map[string]interface{}
|
||||
columns []string
|
||||
messages []string
|
||||
statementResults []connection.ResultSetData
|
||||
usedMultiResult bool
|
||||
)
|
||||
if sessionMultiQueryMessageTarget != nil {
|
||||
statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if sessionMultiQueryTarget != nil {
|
||||
statementResults, err = sessionMultiQueryTarget.QueryMultiContext(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if q, ok := dbInst.(db.MultiResultQueryMessageExecer); ok {
|
||||
statementResults, messages, err = q.QueryMultiContextWithMessages(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if q, ok := dbInst.(db.MultiResultQuerierContext); ok {
|
||||
statementResults, err = q.QueryMultiContext(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if q, ok := dbInst.(db.MultiResultQuerier); ok {
|
||||
statementResults, err = q.QueryMulti(stmt)
|
||||
usedMultiResult = true
|
||||
} else if sessionQueryMessageTarget != nil {
|
||||
data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt)
|
||||
} else if sessionQueryTarget != nil {
|
||||
data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt)
|
||||
} else if q, ok := dbInst.(db.QueryMessageExecer); ok {
|
||||
data, columns, messages, err = q.QueryContextWithMessages(ctx, stmt)
|
||||
} else if q, ok := dbInst.(interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}); ok {
|
||||
data, columns, err = q.QueryContext(ctx, stmt)
|
||||
} else {
|
||||
data, columns, err = dbInst.Query(stmt)
|
||||
}
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
if usedMultiResult {
|
||||
if len(statementResults) == 0 && len(messages) > 0 {
|
||||
statementResults = []connection.ResultSetData{{
|
||||
Rows: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
Messages: append([]string(nil), messages...),
|
||||
}}
|
||||
}
|
||||
for _, statementResult := range statementResults {
|
||||
if statementResult.Rows == nil {
|
||||
statementResult.Rows = []map[string]interface{}{}
|
||||
}
|
||||
if statementResult.Columns == nil {
|
||||
statementResult.Columns = []string{}
|
||||
}
|
||||
statementResult.StatementIndex = idx + 1
|
||||
resultSets = append(resultSets, statementResult)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if data == nil {
|
||||
data = make([]map[string]interface{}, 0)
|
||||
}
|
||||
if columns == nil {
|
||||
columns = []string{}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: data,
|
||||
Columns: columns,
|
||||
Messages: messages,
|
||||
StatementIndex: idx + 1,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if isReadStmt {
|
||||
logger.Error(err, "DBQueryMulti 逐条查询失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
|
||||
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
|
||||
if len(resultSets) > 0 {
|
||||
@@ -841,35 +963,30 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
if data == nil {
|
||||
data = make([]map[string]interface{}, 0)
|
||||
}
|
||||
if columns == nil {
|
||||
columns = []string{}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{Rows: data, Columns: columns})
|
||||
} else {
|
||||
var affected int64
|
||||
if e, ok := dbInst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
affected, err = e.ExecContext(ctx, stmt)
|
||||
} else {
|
||||
affected, err = dbInst.Exec(stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
|
||||
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
|
||||
if len(resultSets) > 0 {
|
||||
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": affected}},
|
||||
Columns: []string{"affectedRows"},
|
||||
})
|
||||
}
|
||||
|
||||
var affected int64
|
||||
if sessionExecTarget != nil {
|
||||
affected, err = sessionExecTarget.ExecContext(ctx, stmt)
|
||||
} else if e, ok := dbInst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
affected, err = e.ExecContext(ctx, stmt)
|
||||
} else {
|
||||
affected, err = dbInst.Exec(stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
|
||||
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
|
||||
if len(resultSets) > 0 {
|
||||
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": affected}},
|
||||
Columns: []string{"affectedRows"},
|
||||
})
|
||||
}
|
||||
|
||||
if resultSets == nil {
|
||||
@@ -883,6 +1000,24 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
return connection.QueryResult{Success: true, Data: resultSets, QueryID: queryID, Message: fallbackMsg}
|
||||
}
|
||||
|
||||
func shouldTryQueryResultFirst(dbType string, query string) bool {
|
||||
isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver")
|
||||
keyword := leadingSQLKeyword(query)
|
||||
switch keyword {
|
||||
case "exec", "execute", "call":
|
||||
return true
|
||||
case "set", "print":
|
||||
return isSQLServer
|
||||
case "dbcc":
|
||||
return isSQLServer
|
||||
default:
|
||||
if isSQLServer {
|
||||
return strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_")
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
@@ -906,22 +1041,30 @@ func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string,
|
||||
defer cancel()
|
||||
|
||||
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
|
||||
tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query)
|
||||
|
||||
if isReadQuery {
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
if isReadQuery || tryQueryFirst {
|
||||
var (
|
||||
data []map[string]interface{}
|
||||
columns []string
|
||||
messages []string
|
||||
)
|
||||
if q, ok := dbInst.(db.QueryMessageExecer); ok {
|
||||
data, columns, messages, err = q.QueryContextWithMessages(ctx, query)
|
||||
} else if q, ok := dbInst.(interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}); ok {
|
||||
data, columns, err = q.QueryContext(ctx, query)
|
||||
} else {
|
||||
data, columns, err = dbInst.Query(query)
|
||||
}
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages}
|
||||
}
|
||||
if isReadQuery {
|
||||
logger.Error(err, "DBQueryIsolated 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns}
|
||||
}
|
||||
|
||||
var affected int64
|
||||
|
||||
@@ -10,10 +10,17 @@ import (
|
||||
)
|
||||
|
||||
type fakeBatchWriteDB struct {
|
||||
batchCalls int
|
||||
execCalls int
|
||||
batchCalls int
|
||||
execCalls int
|
||||
execQueries []string
|
||||
lastQuery string
|
||||
lastQuery string
|
||||
queryCalls int
|
||||
queryMap map[string][]map[string]interface{}
|
||||
fieldMap map[string][]string
|
||||
messageMap map[string][]string
|
||||
multiResult map[string][]connection.ResultSetData
|
||||
queryErr map[string]error
|
||||
session *fakeBatchWriteSession
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error {
|
||||
@@ -29,7 +36,16 @@ func (f *fakeBatchWriteDB) Ping() error {
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return nil, nil, nil
|
||||
f.queryCalls++
|
||||
if err := f.queryErr[query]; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return f.queryMap[query], f.fieldMap[query], nil
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
rows, fields, err := f.Query(query)
|
||||
return rows, fields, f.messageMap[query], err
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
|
||||
@@ -76,12 +92,143 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
f.queryCalls++
|
||||
if err := f.queryErr[query]; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return f.queryMap[query], f.fieldMap[query], nil
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
rows, fields, err := f.QueryContext(ctx, query)
|
||||
return rows, fields, f.messageMap[query], err
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||||
f.batchCalls++
|
||||
f.lastQuery = query
|
||||
return 500, nil
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) {
|
||||
f.session = &fakeBatchWriteSession{parent: f}
|
||||
return f.session, nil
|
||||
}
|
||||
|
||||
type fakeBatchWriteSession struct {
|
||||
parent *fakeBatchWriteDB
|
||||
queryCalls int
|
||||
execCalls int
|
||||
batchCalls int
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return s.QueryContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
s.queryCalls++
|
||||
return s.parent.QueryContext(ctx, query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return s.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
s.queryCalls++
|
||||
return s.parent.QueryContextWithMessages(ctx, query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
return s.QueryMultiContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
if multi := s.parent.multiResult[query]; len(multi) > 0 {
|
||||
s.queryCalls++
|
||||
return cloneResultSets(multi), nil
|
||||
}
|
||||
rows, columns, err := s.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []connection.ResultSetData{{Rows: rows, Columns: columns}}, nil
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
|
||||
return s.QueryMultiContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
|
||||
if err := s.parent.queryErr[query]; err != nil {
|
||||
s.queryCalls++
|
||||
return nil, nil, err
|
||||
}
|
||||
if multi := s.parent.multiResult[query]; len(multi) > 0 {
|
||||
s.queryCalls++
|
||||
return cloneResultSets(multi), append([]string(nil), s.parent.messageMap[query]...), nil
|
||||
}
|
||||
rows, columns, messages, err := s.QueryContextWithMessages(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return []connection.ResultSetData{{
|
||||
Rows: rows,
|
||||
Columns: columns,
|
||||
Messages: append([]string(nil), messages...),
|
||||
}}, append([]string(nil), messages...), nil
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) Exec(query string) (int64, error) {
|
||||
return s.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
s.execCalls++
|
||||
return s.parent.ExecContext(ctx, query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||||
s.batchCalls++
|
||||
return s.parent.ExecBatchContext(ctx, query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) Close() error {
|
||||
s.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneResultSets(input []connection.ResultSetData) []connection.ResultSetData {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make([]connection.ResultSetData, 0, len(input))
|
||||
for _, item := range input {
|
||||
rows := make([]map[string]interface{}, 0, len(item.Rows))
|
||||
for _, row := range item.Rows {
|
||||
if row == nil {
|
||||
rows = append(rows, nil)
|
||||
continue
|
||||
}
|
||||
rowCopy := make(map[string]interface{}, len(row))
|
||||
for key, value := range row {
|
||||
rowCopy[key] = value
|
||||
}
|
||||
rows = append(rows, rowCopy)
|
||||
}
|
||||
cloned = append(cloned, connection.ResultSetData{
|
||||
Rows: rows,
|
||||
Columns: append([]string(nil), item.Columns...),
|
||||
Messages: append([]string(nil), item.Messages...),
|
||||
StatementIndex: item.StatementIndex,
|
||||
})
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func TestDBQueryMultiKeepsOracleAnonymousBlockAsSingleStatement(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
@@ -122,6 +269,85 @@ END;`
|
||||
}
|
||||
|
||||
var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil)
|
||||
var _ db.SessionExecerProvider = (*fakeBatchWriteDB)(nil)
|
||||
var _ db.QueryMessageExecer = (*fakeBatchWriteDB)(nil)
|
||||
var _ db.StatementQueryMessageExecer = (*fakeBatchWriteSession)(nil)
|
||||
|
||||
func TestDBQueryWithCancelReturnsResultSetForExecStoredProcedure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "EXEC sp_who2"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {
|
||||
{"SPID": 52, "STATUS": "RUNNABLE"},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"SPID", "STATUS"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryWithCancel(config, "master", query, "sp-who2-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message)
|
||||
}
|
||||
rows, ok := result.Data.([]map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected []map[string]interface{}, got %T", result.Data)
|
||||
}
|
||||
if len(rows) != 1 || rows[0]["SPID"] != 52 {
|
||||
t.Fatalf("unexpected rows: %#v", rows)
|
||||
}
|
||||
if fakeDB.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryWithCancelReturnsMessagesForSQLServerQuery(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "SET STATISTICS IO ON"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {},
|
||||
},
|
||||
messageMap: map[string][]string{
|
||||
query: {"Table 'users'. Scan count 1, logical reads 3."},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryWithCancel(config, "master", query, "statistics-io-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message)
|
||||
}
|
||||
if len(result.Messages) != 1 || result.Messages[0] == "" {
|
||||
t.Fatalf("expected SQL Server messages to be returned, got %#v", result.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
@@ -168,3 +394,206 @@ func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
|
||||
t.Fatalf("expected affectedRows=500, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiPrefersResultSetForExecStoredProcedure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "EXEC sp_who2"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {
|
||||
{"SPID": 77, "STATUS": "SUSPENDED"},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"SPID", "STATUS"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", query, "sp-who2-multi-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
|
||||
t.Fatalf("unexpected result sets: %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["SPID"]; got != 77 {
|
||||
t.Fatalf("expected SPID=77, got %#v", got)
|
||||
}
|
||||
if fakeDB.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiDoesNotBatchExecStoredProcedureAsWriteStatement(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "EXEC sp_who2"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {
|
||||
{"SPID": 88, "STATUS": "RUNNING"},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"SPID", "STATUS"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", query, "sp-who2-batch-guard-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.batchCalls != 0 {
|
||||
t.Fatalf("expected stored procedure to skip batch write path, got batchCalls=%d", fakeDB.batchCalls)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
|
||||
t.Fatalf("unexpected result sets: %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["SPID"]; got != 88 {
|
||||
t.Fatalf("expected SPID=88, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiUsesPinnedSessionForSequentialFallback(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
"SELECT 1 AS value": {
|
||||
{"value": 1},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
"SELECT 1 AS value": {"value"},
|
||||
},
|
||||
messageMap: map[string][]string{
|
||||
"SET NOCOUNT ON": {"NOCOUNT 已开启"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", "SET NOCOUNT ON;\nSELECT 1 AS value;", "session-fallback-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.session == nil {
|
||||
t.Fatal("expected DBQueryMulti to open a pinned session for sequential fallback")
|
||||
}
|
||||
if !fakeDB.session.closed {
|
||||
t.Fatal("expected DBQueryMulti to close the pinned session")
|
||||
}
|
||||
if fakeDB.session.execCalls != 0 {
|
||||
t.Fatalf("expected SQL Server SET statement to avoid exec-only path, got execCalls=%d", fakeDB.session.execCalls)
|
||||
}
|
||||
if fakeDB.session.queryCalls != 2 {
|
||||
t.Fatalf("expected both statements to query through pinned session, got queryCalls=%d", fakeDB.session.queryCalls)
|
||||
}
|
||||
if fakeDB.queryCalls != 2 {
|
||||
t.Fatalf("expected exactly two underlying query calls, got %d", fakeDB.queryCalls)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 2 {
|
||||
t.Fatalf("expected two result sets, got %#v", resultSets)
|
||||
}
|
||||
if len(resultSets[0].Messages) != 1 || resultSets[0].Messages[0] != "NOCOUNT 已开启" {
|
||||
t.Fatalf("expected first result set to keep session message, got %#v", resultSets[0].Messages)
|
||||
}
|
||||
if got := resultSets[1].Rows[0]["value"]; got != 1 {
|
||||
t.Fatalf("expected second result set value=1, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiKeepsAllResultSetsFromSingleSQLServerStatement(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "EXEC sp_helpdb"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
multiResult: map[string][]connection.ResultSetData{
|
||||
query: {
|
||||
{
|
||||
Rows: []map[string]interface{}{{"name": "master"}},
|
||||
Columns: []string{"name"},
|
||||
},
|
||||
{
|
||||
Rows: []map[string]interface{}{{"owner": "sa"}},
|
||||
Columns: []string{"owner"},
|
||||
},
|
||||
},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", query, "sp-helpdb-multi-result-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 2 {
|
||||
t.Fatalf("expected two result sets, got %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["name"]; got != "master" {
|
||||
t.Fatalf("expected first result set to keep master row, got %#v", got)
|
||||
}
|
||||
if got := resultSets[1].Rows[0]["owner"]; got != "sa" {
|
||||
t.Fatalf("expected second result set to keep owner row, got %#v", got)
|
||||
}
|
||||
if resultSets[0].StatementIndex != 1 || resultSets[1].StatementIndex != 1 {
|
||||
t.Fatalf("expected both result sets to map to the first statement, got %#v", resultSets)
|
||||
}
|
||||
if fakeDB.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,19 @@ func isReadOnlySQLQuery(dbType string, query string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isBatchableWriteSQLStatement(dbType string, query string) bool {
|
||||
if isReadOnlySQLQuery(dbType, query) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch leadingSQLKeyword(query) {
|
||||
case "insert", "update", "delete", "replace", "merge", "upsert":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeSQLForPgLike(dbType string, query string) string {
|
||||
normalizedType := strings.ToLower(strings.TrimSpace(dbType))
|
||||
switch normalizedType {
|
||||
|
||||
@@ -62,3 +62,42 @@ func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) {
|
||||
t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReadOnlySQLQuery_DoesNotTreatExecAsReadOnly(t *testing.T) {
|
||||
if isReadOnlySQLQuery("sqlserver", "EXEC sp_who2") {
|
||||
t.Fatal("EXEC should not be treated as read-only SQL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
|
||||
if !isBatchableWriteSQLStatement("mysql", "INSERT INTO demo(id) VALUES (1)") {
|
||||
t.Fatal("expected INSERT to be treated as batchable write")
|
||||
}
|
||||
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
|
||||
t.Fatal("EXEC should not be treated as batchable write")
|
||||
}
|
||||
if isBatchableWriteSQLStatement("sqlserver", "SET STATISTICS IO ON") {
|
||||
t.Fatal("SET STATISTICS should not be treated as batchable write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsSQLServerSetAsQueryFirst(t *testing.T) {
|
||||
if !shouldTryQueryResultFirst("sqlserver", "SET STATISTICS IO ON") {
|
||||
t.Fatal("expected SQL Server SET STATISTICS to try query-first for notice capture")
|
||||
}
|
||||
if shouldTryQueryResultFirst("mysql", "SET sql_mode = ''") {
|
||||
t.Fatal("non-SQLServer SET should not force query-first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *testing.T) {
|
||||
if !shouldTryQueryResultFirst("sqlserver", "sp_who2") {
|
||||
t.Fatal("expected bare SQL Server system procedure to try query-first")
|
||||
}
|
||||
if !shouldTryQueryResultFirst("sqlserver", "DBCC INPUTBUFFER(52)") {
|
||||
t.Fatal("expected SQL Server DBCC command to try query-first")
|
||||
}
|
||||
if shouldTryQueryResultFirst("mysql", "sp_who2") {
|
||||
t.Fatal("non-SQLServer system procedure name should not force query-first")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,8 @@ func (p *panicChromium) PutZoomFactor(float64) {
|
||||
}
|
||||
|
||||
type panicFrontend struct {
|
||||
chromium *panicChromium
|
||||
chromium *panicChromium
|
||||
mainWindow *fakeWindow
|
||||
}
|
||||
|
||||
// 测试必须用 wails 一致的 string key "frontend" 作为 context.WithValue 的 key,
|
||||
@@ -122,7 +123,10 @@ func TestResetWebViewZoomFactorErrorsWhenFrontendMissing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResetWebViewZoomFactorRecoversFromPutZoomFactorPanic(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{chromium: &panicChromium{}})
|
||||
ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{
|
||||
chromium: &panicChromium{},
|
||||
mainWindow: &fakeWindow{},
|
||||
})
|
||||
|
||||
err := resetWebViewZoomFactor(ctx, 1.0)
|
||||
if err == nil {
|
||||
|
||||
@@ -122,17 +122,20 @@ type ConnectionConfig struct {
|
||||
|
||||
// ResultSetData 表示一个查询结果集(行 + 列名),用于多结果集场景。
|
||||
type ResultSetData struct {
|
||||
Rows []map[string]interface{} `json:"rows"`
|
||||
Columns []string `json:"columns"`
|
||||
Rows []map[string]interface{} `json:"rows"`
|
||||
Columns []string `json:"columns"`
|
||||
Messages []string `json:"messages,omitempty"`
|
||||
StatementIndex int `json:"statementIndex,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。
|
||||
type QueryResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
Messages []string `json:"messages,omitempty"`
|
||||
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
|
||||
}
|
||||
|
||||
// ColumnDefinition 描述表的一个列定义。
|
||||
|
||||
@@ -67,6 +67,51 @@ type StatementExecer interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
// StatementQueryExecer can run queries on a pinned session/connection.
|
||||
// Drivers that return sqlConnStatementExecer automatically satisfy it.
|
||||
type StatementQueryExecer interface {
|
||||
StatementExecer
|
||||
Query(query string) ([]map[string]interface{}, []string, error)
|
||||
QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error)
|
||||
}
|
||||
|
||||
// StatementQueryMessageExecer can run queries on a pinned session and return
|
||||
// extra server messages/notices alongside rows.
|
||||
type StatementQueryMessageExecer interface {
|
||||
StatementQueryExecer
|
||||
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
|
||||
QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error)
|
||||
}
|
||||
|
||||
// StatementMultiResultQueryExecer can run multi-result queries on a pinned session/connection.
|
||||
type StatementMultiResultQueryExecer interface {
|
||||
StatementExecer
|
||||
QueryMulti(query string) ([]connection.ResultSetData, error)
|
||||
QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error)
|
||||
}
|
||||
|
||||
// StatementMultiResultQueryMessageExecer can run multi-result queries on a
|
||||
// pinned session/connection and return server messages/notices.
|
||||
type StatementMultiResultQueryMessageExecer interface {
|
||||
StatementMultiResultQueryExecer
|
||||
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
|
||||
QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error)
|
||||
}
|
||||
|
||||
// QueryMessageExecer is an optional database-level interface for returning
|
||||
// informational server messages alongside one result set.
|
||||
type QueryMessageExecer interface {
|
||||
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
|
||||
QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error)
|
||||
}
|
||||
|
||||
// MultiResultQueryMessageExecer is an optional database-level interface for
|
||||
// returning informational server messages alongside multi-result queries.
|
||||
type MultiResultQueryMessageExecer interface {
|
||||
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
|
||||
QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error)
|
||||
}
|
||||
|
||||
// SessionExecerProvider is implemented by database/sql based drivers that can
|
||||
// pin a long-running job to one physical connection.
|
||||
type SessionExecerProvider interface {
|
||||
@@ -96,6 +141,38 @@ func (e *sqlConnStatementExecer) Exec(query string) (int64, error) {
|
||||
return e.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := e.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return e.QueryContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := e.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
return e.QueryMultiContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||||
return e.ExecContext(ctx, query)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
"github.com/golang-sql/sqlexp"
|
||||
_ "github.com/microsoft/go-mssqldb"
|
||||
)
|
||||
|
||||
@@ -26,6 +27,85 @@ type SqlServerDB struct {
|
||||
forwarder *ssh.LocalForwarder
|
||||
}
|
||||
|
||||
type sqlServerSessionExecer struct {
|
||||
conn *sql.Conn
|
||||
}
|
||||
|
||||
func scanSQLServerRowsWithMessages(ctx context.Context, rows *sql.Rows, retmsg *sqlexp.ReturnMessage) ([]connection.ResultSetData, []string, error) {
|
||||
if rows == nil {
|
||||
return []connection.ResultSetData{{Rows: []map[string]interface{}{}, Columns: []string{}}}, nil, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
var (
|
||||
resultSets []connection.ResultSetData
|
||||
messages []string
|
||||
allMessages []string
|
||||
)
|
||||
active := true
|
||||
for active {
|
||||
raw := retmsg.Message(ctx)
|
||||
switch msg := raw.(type) {
|
||||
case sqlexp.MsgNotice:
|
||||
text := strings.TrimSpace(fmt.Sprint(msg.Message))
|
||||
if text != "" {
|
||||
messages = append(messages, text)
|
||||
allMessages = append(allMessages, text)
|
||||
}
|
||||
case sqlexp.MsgNext:
|
||||
data, cols, err := scanRows(rows)
|
||||
if err != nil {
|
||||
return resultSets, messages, err
|
||||
}
|
||||
if data == nil {
|
||||
data = []map[string]interface{}{}
|
||||
}
|
||||
if cols == nil {
|
||||
cols = []string{}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: data,
|
||||
Columns: cols,
|
||||
Messages: append([]string(nil), messages...),
|
||||
})
|
||||
messages = nil
|
||||
case sqlexp.MsgRowsAffected:
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": msg.Count}},
|
||||
Columns: []string{"affectedRows"},
|
||||
Messages: append([]string(nil), messages...),
|
||||
})
|
||||
messages = nil
|
||||
case sqlexp.MsgNextResultSet:
|
||||
active = rows.NextResultSet()
|
||||
case sqlexp.MsgError:
|
||||
return resultSets, messages, msg.Error
|
||||
default:
|
||||
active = false
|
||||
}
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
Messages: append([]string(nil), messages...),
|
||||
})
|
||||
}
|
||||
if len(resultSets) == 0 {
|
||||
resultSets = []connection.ResultSetData{{
|
||||
Rows: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
}}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return resultSets, allMessages, err
|
||||
}
|
||||
return resultSets, allMessages, nil
|
||||
}
|
||||
|
||||
// quoteBracket escapes ] in identifiers for safe use in SQL Server [bracket] notation
|
||||
func quoteBracket(name string) string {
|
||||
return strings.ReplaceAll(name, "]", "]]")
|
||||
@@ -133,54 +213,76 @@ func (s *SqlServerDB) Ping() error {
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
results, _, err := s.QueryMultiWithMessages(query)
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := s.conn.Query(query)
|
||||
ctx := context.Background()
|
||||
retmsg := &sqlexp.ReturnMessage{}
|
||||
rows, err := s.conn.QueryContext(ctx, query, retmsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
results, _, err := s.QueryMultiContextWithMessages(ctx, query)
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
retmsg := &sqlexp.ReturnMessage{}
|
||||
rows, err := s.conn.QueryContext(ctx, query, retmsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
rows, columns, _, err := s.QueryContextWithMessages(ctx, query)
|
||||
return rows, columns, err
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
resultSets, messages, err := s.QueryMultiContextWithMessages(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
if len(resultSets) == 0 {
|
||||
return []map[string]interface{}{}, []string{}, messages, nil
|
||||
}
|
||||
first := resultSets[0]
|
||||
if first.Rows == nil {
|
||||
first.Rows = []map[string]interface{}{}
|
||||
}
|
||||
if first.Columns == nil {
|
||||
first.Columns = []string{}
|
||||
}
|
||||
return first.Rows, first.Columns, messages, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, columns, _, err := s.QueryWithMessages(query)
|
||||
return rows, columns, err
|
||||
}
|
||||
|
||||
rows, err := s.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
func (s *SqlServerDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return s.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
@@ -213,7 +315,7 @@ func (s *SqlServerDB) OpenSessionExecer(ctx context.Context) (StatementExecer, e
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLConnStatementExecer(conn), nil
|
||||
return &sqlServerSessionExecer{conn: conn}, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Exec(query string) (int64, error) {
|
||||
@@ -227,6 +329,87 @@ func (s *SqlServerDB) Exec(query string) (int64, error) {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) Exec(query string) (int64, error) {
|
||||
return e.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := e.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
rows, columns, _, err := e.QueryWithMessages(query)
|
||||
return rows, columns, err
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
rows, columns, _, err := e.QueryContextWithMessages(ctx, query)
|
||||
return rows, columns, err
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return e.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
results, messages, err := e.QueryMultiContextWithMessages(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return []map[string]interface{}{}, []string{}, messages, nil
|
||||
}
|
||||
first := results[0]
|
||||
if first.Rows == nil {
|
||||
first.Rows = []map[string]interface{}{}
|
||||
}
|
||||
if first.Columns == nil {
|
||||
first.Columns = []string{}
|
||||
}
|
||||
return first.Rows, first.Columns, messages, nil
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
results, _, err := e.QueryMultiWithMessages(query)
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
results, _, err := e.QueryMultiContextWithMessages(ctx, query)
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
|
||||
return e.QueryMultiContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
retmsg := &sqlexp.ReturnMessage{}
|
||||
rows, err := e.conn.QueryContext(ctx, query, retmsg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) Close() error {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return e.conn.Close()
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetDatabases() ([]string, error) {
|
||||
query := "SELECT name FROM sys.databases WHERE state_desc = 'ONLINE' ORDER BY name"
|
||||
data, _, err := s.Query(query)
|
||||
|
||||
Reference in New Issue
Block a user