diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index cfc3b71..b6cb4e3 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -1061,6 +1061,7 @@ const ConnectionModal: React.FC<{ }); } else if (type !== 'custom') { form.setFieldsValue({ + database: '', port: defaultPort, mysqlTopology: 'single', mongoTopology: 'single', @@ -1199,6 +1200,7 @@ const ConnectionModal: React.FC<{ type: 'mysql', host: 'localhost', port: 3306, + database: '', user: 'root', useSSH: false, sshPort: 22, @@ -1338,6 +1340,16 @@ const ConnectionModal: React.FC<{ )} + {(dbType === 'postgres' || dbType === 'kingbase' || dbType === 'highgo' || dbType === 'vastbase') && ( + + + + )} + {(dbType === 'mysql' || dbType === 'mariadb' || dbType === 'diros' || dbType === 'sphinx') && ( <> diff --git a/internal/app/app.go b/internal/app/app.go index 124b9d2..9d8f081 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -82,10 +82,6 @@ func getCacheKey(config connection.ConnectionConfig) string { if !config.UseProxy { config.Proxy = connection.ProxyConfig{} } - // 保持与驱动默认一致,避免同一连接被重复缓存 - if config.Type == "postgres" && config.Database == "" { - config.Database = "postgres" - } b, _ := json.Marshal(config) sum := sha256.Sum256(b) diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 47587f6..e31b9e8 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -190,9 +190,6 @@ func (a *App) RenameDatabase(config connection.ConnectionConfig, oldName string, return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再重命名"} } runConfig := config - if strings.TrimSpace(runConfig.Database) == "" { - runConfig.Database = "postgres" - } dbInst, err := a.getDatabase(runConfig) if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} @@ -228,9 +225,6 @@ func (a *App) DropDatabase(config connection.ConnectionConfig, dbName string) co return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再删除"} } runConfig = config - if strings.TrimSpace(runConfig.Database) == "" { - runConfig.Database = "postgres" - } sql = fmt.Sprintf("DROP DATABASE %s", quoteIdentByType(dbType, dbName)) default: return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持删除数据库", dbType)} diff --git a/internal/db/postgres_connect_test.go b/internal/db/postgres_connect_test.go new file mode 100644 index 0000000..d8707c5 --- /dev/null +++ b/internal/db/postgres_connect_test.go @@ -0,0 +1,48 @@ +package db + +import ( + "reflect" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestResolvePostgresConnectDatabases_ExplicitDatabase(t *testing.T) { + cfg := connection.ConnectionConfig{ + Type: "postgres", + Database: "analytics", + User: "app_user", + } + + got := resolvePostgresConnectDatabases(cfg) + want := []string{"analytics"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected databases, got=%v want=%v", got, want) + } +} + +func TestResolvePostgresConnectDatabases_FallbackOrder(t *testing.T) { + cfg := connection.ConnectionConfig{ + Type: "postgres", + User: "app_user", + } + + got := resolvePostgresConnectDatabases(cfg) + want := []string{"postgres", "template1", "app_user"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected databases, got=%v want=%v", got, want) + } +} + +func TestResolvePostgresConnectDatabases_DeduplicateUserDefault(t *testing.T) { + cfg := connection.ConnectionConfig{ + Type: "postgres", + User: "postgres", + } + + got := resolvePostgresConnectDatabases(cfg) + want := []string{"postgres", "template1"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected databases, got=%v want=%v", got, want) + } +} diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index a179e91..7727773 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -24,6 +24,30 @@ type PostgresDB struct { forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } +func resolvePostgresConnectDatabases(config connection.ConnectionConfig) []string { + explicit := strings.TrimSpace(config.Database) + if explicit != "" { + return []string{explicit} + } + + candidates := []string{"postgres", "template1", strings.TrimSpace(config.User)} + seen := make(map[string]struct{}, len(candidates)) + result := make([]string, 0, len(candidates)) + for _, name := range candidates { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + continue + } + normalized := strings.ToLower(trimmed) + if _, exists := seen[normalized]; exists { + continue + } + seen[normalized] = struct{}{} + result = append(result, trimmed) + } + return result +} + func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string { // postgres://user:password@host:port/dbname?sslmode=disable dbname := config.Database @@ -53,8 +77,23 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error { return fmt.Errorf("%s", reason) } - var dsn string - var err error + runConfig := config + p.pingTimeout = getConnectTimeout(config) + + cleanupOnFailure := true + defer func() { + if !cleanupOnFailure { + return + } + if p.conn != nil { + _ = p.conn.Close() + p.conn = nil + } + if p.forwarder != nil { + _ = p.forwarder.Close() + p.forwarder = nil + } + }() if config.UseSSH { // Create SSH tunnel with local port forwarding @@ -83,24 +122,44 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error { localConfig.Port = port localConfig.UseSSH = false // Disable SSH flag for DSN generation - dsn = p.getDSN(localConfig) + runConfig = localConfig logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) - } else { - dsn = p.getDSN(config) } - db, err := sql.Open("postgres", dsn) - if err != nil { - return fmt.Errorf("打开数据库连接失败:%w", err) - } - p.conn = db - p.pingTimeout = getConnectTimeout(config) + attemptDBs := resolvePostgresConnectDatabases(runConfig) + var failures []string + for _, dbName := range attemptDBs { + attemptConfig := runConfig + attemptConfig.Database = dbName + dsn := p.getDSN(attemptConfig) - // Force verification - if err := p.Ping(); err != nil { - return fmt.Errorf("连接建立后验证失败:%w", err) + dbConn, err := sql.Open("postgres", dsn) + if err != nil { + failures = append(failures, fmt.Sprintf("数据库=%s 打开连接失败: %v", dbName, err)) + continue + } + p.conn = dbConn + + // Force verification + if err := p.Ping(); err != nil { + failures = append(failures, fmt.Sprintf("数据库=%s 验证失败: %v", dbName, err)) + _ = dbConn.Close() + p.conn = nil + continue + } + + if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") { + logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName) + } + + cleanupOnFailure = false + return nil } - return nil + + if len(failures) == 0 { + return fmt.Errorf("连接建立后验证失败:未找到可用的连接数据库") + } + return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } func (p *PostgresDB) Close() error {