mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-12 07:09:40 +08:00
🐛 fix(sqlserver): 修复新建数据库语法兼容问题
Refs #438 - SQL Server 创建数据库改用方言标识符 - 补齐 mssql/sql_server 别名归一 - 增加回归测试
This commit is contained in:
@@ -106,6 +106,11 @@ func (a *App) MongoDiscoverMembers(config connection.ConnectionConfig) connectio
|
||||
}
|
||||
|
||||
func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string) connection.QueryResult {
|
||||
dbName = strings.TrimSpace(dbName)
|
||||
if dbName == "" {
|
||||
return connection.QueryResult{Success: false, Message: "数据库名称不能为空"}
|
||||
}
|
||||
|
||||
runConfig := config
|
||||
runConfig.Database = ""
|
||||
|
||||
@@ -120,6 +125,8 @@ func (a *App) CreateDatabase(config connection.ConnectionConfig, dbName string)
|
||||
if dbType == "postgres" || dbType == "kingbase" || dbType == "highgo" || dbType == "vastbase" || dbType == "opengauss" {
|
||||
escapedDbName = strings.ReplaceAll(dbName, `"`, `""`)
|
||||
query = fmt.Sprintf("CREATE DATABASE \"%s\"", escapedDbName)
|
||||
} else if dbType == "sqlserver" {
|
||||
query = fmt.Sprintf("CREATE DATABASE %s", quoteIdentByType(dbType, dbName))
|
||||
} else if dbType == "tdengine" {
|
||||
query = fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", quoteIdentByType(dbType, dbName))
|
||||
} else if dbType == "clickhouse" {
|
||||
@@ -145,6 +152,9 @@ func resolveDDLDBType(config connection.ConnectionConfig) string {
|
||||
if dbType == "doris" {
|
||||
return "diros"
|
||||
}
|
||||
if dbType == "mssql" || dbType == "sql_server" || dbType == "sql-server" {
|
||||
return "sqlserver"
|
||||
}
|
||||
if dbType == "oceanbase" && isOceanBaseOracleProtocol(config) {
|
||||
return "oracle"
|
||||
}
|
||||
@@ -164,6 +174,8 @@ func resolveDDLDBType(config connection.ConnectionConfig) string {
|
||||
return "sqlite"
|
||||
case "sphinxql":
|
||||
return "sphinx"
|
||||
case "mssql", "sqlserver", "sql_server", "sql-server":
|
||||
return "sqlserver"
|
||||
case "diros", "doris":
|
||||
return "diros"
|
||||
case "kingbase", "kingbase8", "kingbasees", "kingbasev8":
|
||||
@@ -191,6 +203,8 @@ func resolveDDLDBType(config connection.ConnectionConfig) string {
|
||||
return "sqlite"
|
||||
case strings.Contains(driver, "sphinx"):
|
||||
return "sphinx"
|
||||
case strings.Contains(driver, "sqlserver"), strings.Contains(driver, "sql_server"), strings.Contains(driver, "sql-server"), strings.Contains(driver, "mssql"):
|
||||
return "sqlserver"
|
||||
case strings.Contains(driver, "diros"), strings.Contains(driver, "doris"):
|
||||
return "diros"
|
||||
case strings.Contains(driver, "oceanbase"):
|
||||
@@ -257,7 +271,7 @@ func buildRunConfigForDDL(config connection.ConnectionConfig, dbType string, dbN
|
||||
if strings.EqualFold(strings.TrimSpace(config.Type), "custom") {
|
||||
// custom 连接的 dbName 语义依赖 driver,尽量在常见驱动上对齐内置类型行为。
|
||||
switch dbType {
|
||||
case "mysql", "mariadb", "oceanbase", "diros", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "dameng", "clickhouse":
|
||||
case "mysql", "mariadb", "oceanbase", "diros", "sphinx", "postgres", "kingbase", "highgo", "vastbase", "opengauss", "dameng", "sqlserver", "clickhouse":
|
||||
if strings.TrimSpace(dbName) != "" {
|
||||
runConfig.Database = strings.TrimSpace(dbName)
|
||||
}
|
||||
|
||||
115
internal/app/methods_db_create_test.go
Normal file
115
internal/app/methods_db_create_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
"GoNavi-Wails/internal/db"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
)
|
||||
|
||||
type fakeCreateDatabaseDB struct {
|
||||
connectConfig connection.ConnectionConfig
|
||||
execQueries []string
|
||||
}
|
||||
|
||||
func (f *fakeCreateDatabaseDB) Connect(config connection.ConnectionConfig) error {
|
||||
f.connectConfig = config
|
||||
return nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) Close() error { return nil }
|
||||
func (f *fakeCreateDatabaseDB) Ping() error { return nil }
|
||||
func (f *fakeCreateDatabaseDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) Exec(query string) (int64, error) {
|
||||
f.execQueries = append(f.execQueries, query)
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) GetDatabases() ([]string, error) { return nil, nil }
|
||||
func (f *fakeCreateDatabaseDB) GetTables(dbName string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) GetCreateStatement(dbName, tableName string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeCreateDatabaseDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var _ db.Database = (*fakeCreateDatabaseDB)(nil)
|
||||
|
||||
func TestResolveDDLDBType_SQLServerAliases(t *testing.T) {
|
||||
tests := []connection.ConnectionConfig{
|
||||
{Type: "mssql"},
|
||||
{Type: "sql_server"},
|
||||
{Type: "custom", Driver: "mssql"},
|
||||
{Type: "custom", Driver: "sql-server"},
|
||||
}
|
||||
|
||||
for _, cfg := range tests {
|
||||
if got := resolveDDLDBType(cfg); got != "sqlserver" {
|
||||
t.Fatalf("resolveDDLDBType(%+v) = %q, want sqlserver", cfg, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRunConfigForDDL_CustomSQLServerUsesDatabase(t *testing.T) {
|
||||
got := buildRunConfigForDDL(connection.ConnectionConfig{
|
||||
Type: "custom",
|
||||
Driver: "mssql",
|
||||
Database: "master",
|
||||
}, "sqlserver", "target_db")
|
||||
if got.Database != "target_db" {
|
||||
t.Fatalf("expected custom SQL Server DDL database target_db, got %q", got.Database)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDatabase_SQLServerUsesBracketIdentifiers(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
|
||||
})
|
||||
|
||||
fakeDB := &fakeCreateDatabaseDB{}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
result := app.CreateDatabase(connection.ConnectionConfig{
|
||||
Type: "sqlserver",
|
||||
Database: "master",
|
||||
}, "lg")
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("expected SQL Server create database success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.connectConfig.Database != "" {
|
||||
t.Fatalf("expected create database connection to clear database and use default master, got %q", fakeDB.connectConfig.Database)
|
||||
}
|
||||
if len(fakeDB.execQueries) != 1 {
|
||||
t.Fatalf("expected one create database statement, got %d: %#v", len(fakeDB.execQueries), fakeDB.execQueries)
|
||||
}
|
||||
const want = "CREATE DATABASE [lg]"
|
||||
if fakeDB.execQueries[0] != want {
|
||||
t.Fatalf("unexpected SQL Server create database SQL, want %q got %q", want, fakeDB.execQueries[0])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user