From ff0661d285fdacc6b5123306b85dbb5fc4edcf83 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Fri, 8 May 2026 21:41:01 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(sqlserver):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E6=96=B0=E5=BB=BA=E6=95=B0=E6=8D=AE=E5=BA=93=E8=AF=AD?= =?UTF-8?q?=E6=B3=95=E5=85=BC=E5=AE=B9=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refs #438 - SQL Server 创建数据库改用方言标识符 - 补齐 mssql/sql_server 别名归一 - 增加回归测试 --- internal/app/methods_db.go | 16 +++- internal/app/methods_db_create_test.go | 115 +++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 internal/app/methods_db_create_test.go diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index c469fb2..62b9885 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -106,6 +106,11 @@ func (a *App) MongoDiscoverMembers(config connection.ConnectionConfig) connectio } func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string) connection.QueryResult { + dbName = strings.TrimSpace(dbName) + if dbName == "" { + return connection.QueryResult{Success: false, Message: "数据库名称不能为空"} + } + runConfig := config runConfig.Database = "" @@ -120,6 +125,8 @@ func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string) if dbType == "postgres" || dbType == "kingbase" || dbType == "highgo" || dbType == "vastbase" || dbType == "opengauss" { escapedDbName = strings.ReplaceAll(dbName, `"`, `""`) query = fmt.Sprintf("CREATE DATABASE \"%s\"", escapedDbName) + } else if dbType == "sqlserver" { + query = fmt.Sprintf("CREATE DATABASE %s", quoteIdentByType(dbType, dbName)) } else if dbType == "tdengine" { query = fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", quoteIdentByType(dbType, dbName)) } else if dbType == "clickhouse" { @@ -145,6 +152,9 @@ func resolveDDLDBType(config connection.ConnectionConfig) string { if dbType == "doris" { return "diros" } + if dbType == "mssql" || dbType == "sql_server" || dbType == "sql-server" { + return "sqlserver" + } if dbType == "oceanbase" && isOceanBaseOracleProtocol(config) { return "oracle" } @@ -164,6 +174,8 @@ func resolveDDLDBType(config connection.ConnectionConfig) string { return "sqlite" case "sphinxql": return "sphinx" + case "mssql", "sqlserver", "sql_server", "sql-server": + return "sqlserver" case "diros", "doris": return "diros" case "kingbase", "kingbase8", "kingbasees", "kingbasev8": @@ -191,6 +203,8 @@ func resolveDDLDBType(config connection.ConnectionConfig) string { return "sqlite" case strings.Contains(driver, "sphinx"): return "sphinx" + case strings.Contains(driver, "sqlserver"), strings.Contains(driver, "sql_server"), strings.Contains(driver, "sql-server"), strings.Contains(driver, "mssql"): + return "sqlserver" case strings.Contains(driver, "diros"), strings.Contains(driver, "doris"): return "diros" case strings.Contains(driver, "oceanbase"): @@ -257,7 +271,7 @@ func buildRunConfigForDDL(config connection.ConnectionConfig, dbType string, dbN if strings.EqualFold(strings.TrimSpace(config.Type), "custom") { // custom 连接的 dbName 语义依赖 driver,尽量在常见驱动上对齐内置类型行为。 switch dbType { - case "mysql", "mariadb", "oceanbase", "diros", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "dameng", "clickhouse": + case "mysql", "mariadb", "oceanbase", "diros", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "dameng", "sqlserver", "clickhouse": if strings.TrimSpace(dbName) != "" { runConfig.Database = strings.TrimSpace(dbName) } diff --git a/internal/app/methods_db_create_test.go b/internal/app/methods_db_create_test.go new file mode 100644 index 0000000..cad9ab1 --- /dev/null +++ b/internal/app/methods_db_create_test.go @@ -0,0 +1,115 @@ +package app + +import ( + "testing" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" + "GoNavi-Wails/internal/secretstore" +) + +type fakeCreateDatabaseDB struct { + connectConfig connection.ConnectionConfig + execQueries []string +} + +func (f *fakeCreateDatabaseDB) Connect(config connection.ConnectionConfig) error { + f.connectConfig = config + return nil +} +func (f *fakeCreateDatabaseDB) Close() error { return nil } +func (f *fakeCreateDatabaseDB) Ping() error { return nil } +func (f *fakeCreateDatabaseDB) Query(query string) ([]map[string]interface{}, []string, error) { + return nil, nil, nil +} +func (f *fakeCreateDatabaseDB) Exec(query string) (int64, error) { + f.execQueries = append(f.execQueries, query) + return 0, nil +} +func (f *fakeCreateDatabaseDB) GetDatabases() ([]string, error) { return nil, nil } +func (f *fakeCreateDatabaseDB) GetTables(dbName string) ([]string, error) { + return nil, nil +} +func (f *fakeCreateDatabaseDB) GetCreateStatement(dbName, tableName string) (string, error) { + return "", nil +} +func (f *fakeCreateDatabaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + return nil, nil +} +func (f *fakeCreateDatabaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + return nil, nil +} +func (f *fakeCreateDatabaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + return nil, nil +} +func (f *fakeCreateDatabaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + return nil, nil +} +func (f *fakeCreateDatabaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + return nil, nil +} + +var _ db.Database = (*fakeCreateDatabaseDB)(nil) + +func TestResolveDDLDBType_SQLServerAliases(t *testing.T) { + tests := []connection.ConnectionConfig{ + {Type: "mssql"}, + {Type: "sql_server"}, + {Type: "custom", Driver: "mssql"}, + {Type: "custom", Driver: "sql-server"}, + } + + for _, cfg := range tests { + if got := resolveDDLDBType(cfg); got != "sqlserver" { + t.Fatalf("resolveDDLDBType(%+v) = %q, want sqlserver", cfg, got) + } + } +} + +func TestBuildRunConfigForDDL_CustomSQLServerUsesDatabase(t *testing.T) { + got := buildRunConfigForDDL(connection.ConnectionConfig{ + Type: "custom", + Driver: "mssql", + Database: "master", + }, "sqlserver", "target_db") + if got.Database != "target_db" { + t.Fatalf("expected custom SQL Server DDL database target_db, got %q", got.Database) + } +} + +func TestCreateDatabase_SQLServerUsesBracketIdentifiers(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + }) + + fakeDB := &fakeCreateDatabaseDB{} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + result := app.CreateDatabase(connection.ConnectionConfig{ + Type: "sqlserver", + Database: "master", + }, "lg") + + if !result.Success { + t.Fatalf("expected SQL Server create database success, got failure: %s", result.Message) + } + if fakeDB.connectConfig.Database != "" { + t.Fatalf("expected create database connection to clear database and use default master, got %q", fakeDB.connectConfig.Database) + } + if len(fakeDB.execQueries) != 1 { + t.Fatalf("expected one create database statement, got %d: %#v", len(fakeDB.execQueries), fakeDB.execQueries) + } + const want = "CREATE DATABASE [lg]" + if fakeDB.execQueries[0] != want { + t.Fatalf("unexpected SQL Server create database SQL, want %q got %q", want, fakeDB.execQueries[0]) + } +}