mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-06 20:03:05 +08:00
🐛 fix(postgres-connection): 修复无postgres库时连接失败并支持默认连接库配置
- PostgreSQL 空 database 时按 postgres、template1、用户名同名库回退连接 - 移除后端对 database=postgres 的硬编码写死逻辑 - 连接弹窗新增 PostgreSQL 默认连接数据库(可选)配置项 - refs #120
This commit is contained in:
@@ -1061,6 +1061,7 @@ const ConnectionModal: React.FC<{
|
||||
});
|
||||
} else if (type !== 'custom') {
|
||||
form.setFieldsValue({
|
||||
database: '',
|
||||
port: defaultPort,
|
||||
mysqlTopology: 'single',
|
||||
mongoTopology: 'single',
|
||||
@@ -1199,6 +1200,7 @@ const ConnectionModal: React.FC<{
|
||||
type: 'mysql',
|
||||
host: 'localhost',
|
||||
port: 3306,
|
||||
database: '',
|
||||
user: 'root',
|
||||
useSSH: false,
|
||||
sshPort: 22,
|
||||
@@ -1338,6 +1340,16 @@ const ConnectionModal: React.FC<{
|
||||
)}
|
||||
</div>
|
||||
|
||||
{(dbType === 'postgres' || dbType === 'kingbase' || dbType === 'highgo' || dbType === 'vastbase') && (
|
||||
<Form.Item
|
||||
name="database"
|
||||
label="默认连接数据库(可选)"
|
||||
help="留空会自动尝试 postgres、template1、与当前用户名同名数据库"
|
||||
>
|
||||
<Input placeholder="例如:appdb" />
|
||||
</Form.Item>
|
||||
)}
|
||||
|
||||
{(dbType === 'mysql' || dbType === 'mariadb' || dbType === 'diros' || dbType === 'sphinx') && (
|
||||
<>
|
||||
<Form.Item name="mysqlTopology" label="连接模式">
|
||||
|
||||
@@ -82,10 +82,6 @@ func getCacheKey(config connection.ConnectionConfig) string {
|
||||
if !config.UseProxy {
|
||||
config.Proxy = connection.ProxyConfig{}
|
||||
}
|
||||
// 保持与驱动默认一致,避免同一连接被重复缓存
|
||||
if config.Type == "postgres" && config.Database == "" {
|
||||
config.Database = "postgres"
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(config)
|
||||
sum := sha256.Sum256(b)
|
||||
|
||||
@@ -190,9 +190,6 @@ func (a *App) RenameDatabase(config connection.ConnectionConfig, oldName string,
|
||||
return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再重命名"}
|
||||
}
|
||||
runConfig := config
|
||||
if strings.TrimSpace(runConfig.Database) == "" {
|
||||
runConfig.Database = "postgres"
|
||||
}
|
||||
dbInst, err := a.getDatabase(runConfig)
|
||||
if err != nil {
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
@@ -228,9 +225,6 @@ func (a *App) DropDatabase(config connection.ConnectionConfig, dbName string) co
|
||||
return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再删除"}
|
||||
}
|
||||
runConfig = config
|
||||
if strings.TrimSpace(runConfig.Database) == "" {
|
||||
runConfig.Database = "postgres"
|
||||
}
|
||||
sql = fmt.Sprintf("DROP DATABASE %s", quoteIdentByType(dbType, dbName))
|
||||
default:
|
||||
return connection.QueryResult{Success: false, Message: fmt.Sprintf("当前数据源(%s)暂不支持删除数据库", dbType)}
|
||||
|
||||
48
internal/db/postgres_connect_test.go
Normal file
48
internal/db/postgres_connect_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/connection"
|
||||
)
|
||||
|
||||
func TestResolvePostgresConnectDatabases_ExplicitDatabase(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
Database: "analytics",
|
||||
User: "app_user",
|
||||
}
|
||||
|
||||
got := resolvePostgresConnectDatabases(cfg)
|
||||
want := []string{"analytics"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected databases, got=%v want=%v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePostgresConnectDatabases_FallbackOrder(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
User: "app_user",
|
||||
}
|
||||
|
||||
got := resolvePostgresConnectDatabases(cfg)
|
||||
want := []string{"postgres", "template1", "app_user"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected databases, got=%v want=%v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePostgresConnectDatabases_DeduplicateUserDefault(t *testing.T) {
|
||||
cfg := connection.ConnectionConfig{
|
||||
Type: "postgres",
|
||||
User: "postgres",
|
||||
}
|
||||
|
||||
got := resolvePostgresConnectDatabases(cfg)
|
||||
want := []string{"postgres", "template1"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected databases, got=%v want=%v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,30 @@ type PostgresDB struct {
|
||||
forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder
|
||||
}
|
||||
|
||||
func resolvePostgresConnectDatabases(config connection.ConnectionConfig) []string {
|
||||
explicit := strings.TrimSpace(config.Database)
|
||||
if explicit != "" {
|
||||
return []string{explicit}
|
||||
}
|
||||
|
||||
candidates := []string{"postgres", "template1", strings.TrimSpace(config.User)}
|
||||
seen := make(map[string]struct{}, len(candidates))
|
||||
result := make([]string, 0, len(candidates))
|
||||
for _, name := range candidates {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
normalized := strings.ToLower(trimmed)
|
||||
if _, exists := seen[normalized]; exists {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *PostgresDB) getDSN(config connection.ConnectionConfig) string {
|
||||
// postgres://user:password@host:port/dbname?sslmode=disable
|
||||
dbname := config.Database
|
||||
@@ -53,8 +77,23 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
return fmt.Errorf("%s", reason)
|
||||
}
|
||||
|
||||
var dsn string
|
||||
var err error
|
||||
runConfig := config
|
||||
p.pingTimeout = getConnectTimeout(config)
|
||||
|
||||
cleanupOnFailure := true
|
||||
defer func() {
|
||||
if !cleanupOnFailure {
|
||||
return
|
||||
}
|
||||
if p.conn != nil {
|
||||
_ = p.conn.Close()
|
||||
p.conn = nil
|
||||
}
|
||||
if p.forwarder != nil {
|
||||
_ = p.forwarder.Close()
|
||||
p.forwarder = nil
|
||||
}
|
||||
}()
|
||||
|
||||
if config.UseSSH {
|
||||
// Create SSH tunnel with local port forwarding
|
||||
@@ -83,24 +122,44 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error {
|
||||
localConfig.Port = port
|
||||
localConfig.UseSSH = false // Disable SSH flag for DSN generation
|
||||
|
||||
dsn = p.getDSN(localConfig)
|
||||
runConfig = localConfig
|
||||
logger.Infof("PostgreSQL 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
|
||||
} else {
|
||||
dsn = p.getDSN(config)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开数据库连接失败:%w", err)
|
||||
}
|
||||
p.conn = db
|
||||
p.pingTimeout = getConnectTimeout(config)
|
||||
attemptDBs := resolvePostgresConnectDatabases(runConfig)
|
||||
var failures []string
|
||||
for _, dbName := range attemptDBs {
|
||||
attemptConfig := runConfig
|
||||
attemptConfig.Database = dbName
|
||||
dsn := p.getDSN(attemptConfig)
|
||||
|
||||
// Force verification
|
||||
if err := p.Ping(); err != nil {
|
||||
return fmt.Errorf("连接建立后验证失败:%w", err)
|
||||
dbConn, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
failures = append(failures, fmt.Sprintf("数据库=%s 打开连接失败: %v", dbName, err))
|
||||
continue
|
||||
}
|
||||
p.conn = dbConn
|
||||
|
||||
// Force verification
|
||||
if err := p.Ping(); err != nil {
|
||||
failures = append(failures, fmt.Sprintf("数据库=%s 验证失败: %v", dbName, err))
|
||||
_ = dbConn.Close()
|
||||
p.conn = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.TrimSpace(config.Database) == "" && !strings.EqualFold(dbName, "postgres") {
|
||||
logger.Infof("PostgreSQL 自动选择连接数据库:%s", dbName)
|
||||
}
|
||||
|
||||
cleanupOnFailure = false
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
if len(failures) == 0 {
|
||||
return fmt.Errorf("连接建立后验证失败:未找到可用的连接数据库")
|
||||
}
|
||||
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";"))
|
||||
}
|
||||
|
||||
func (p *PostgresDB) Close() error {
|
||||
|
||||
Reference in New Issue
Block a user