From 21c427bc39d7ecb85495f60835dff77d2a64da7a Mon Sep 17 00:00:00 2001 From: Syngnat Date: Thu, 18 Jun 2026 20:29:19 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(connection):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=A4=9A=E6=95=B0=E6=8D=AE=E6=BA=90=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E6=95=B0=E5=8D=A0=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 测试连接改为隔离连接,成功后立即关闭并避免写入全局缓存 - 新增通用 SQL 连接池配置,限制网络型数据源空闲连接长期占用 - Redis 测试连接改为临时客户端并立即释放 - MySQL 连接数超限时释放同实例缓存连接并重试 - 补充连接释放、缓存重试和连接池参数回归测试 --- internal/app/app.go | 113 +++++++++++++++++- internal/app/methods_db.go | 58 +++++---- internal/app/methods_db_conn_test.go | 123 +++++++++++++++++++- internal/app/methods_redis.go | 40 ++++++- internal/app/methods_redis_test.go | 49 +++++++- internal/db/clickhouse_impl.go | 1 + internal/db/custom_impl.go | 3 + internal/db/dameng_impl.go | 1 + internal/db/diros_impl.go | 1 + internal/db/gaussdb_impl.go | 2 + internal/db/highgo_impl.go | 1 + internal/db/iris_impl.go | 3 + internal/db/kingbase_impl.go | 4 +- internal/db/mariadb_impl.go | 3 + internal/db/mysql_connection_params_test.go | 15 +++ internal/db/mysql_impl.go | 1 + internal/db/oceanbase_impl.go | 2 + internal/db/oracle_impl.go | 1 + internal/db/postgres_impl.go | 2 + internal/db/sql_pool.go | 27 +++++ internal/db/sqlserver_impl.go | 3 + internal/db/starrocks_impl.go | 1 + internal/db/tdengine_impl.go | 1 + internal/db/vastbase_impl.go | 1 + 24 files changed, 423 insertions(+), 33 deletions(-) create mode 100644 internal/db/sql_pool.go diff --git a/internal/app/app.go b/internal/app/app.go index 57603b4..3a7dc5e 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -348,6 +348,7 @@ func normalizeConnectionReleaseMatchConfig(config connection.ConnectionConfig) c normalized := normalizeCacheKeyConfig(config) normalized.Database = "" normalized.RedisDB = 0 + normalized.ConnectionParams = "" return normalized } @@ -358,6 +359,72 @@ func getConnectionReleaseMatchKey(config connection.ConnectionConfig) string { return hex.EncodeToString(sum[:]) } +type cachedDatabaseCloseTarget struct { + key string + inst db.Database +} + +func (a *App) releaseCachedDatabaseConnectionsForConfig(config connection.ConnectionConfig) int { + if a == nil { + return 0 + } + return a.releaseCachedDatabaseConnectionsByMatchKey(getConnectionReleaseMatchKey(config)) +} + +func (a *App) releaseCachedDatabaseConnectionsByMatchKey(targetKey string) int { + if a == nil || strings.TrimSpace(targetKey) == "" { + return 0 + } + + targets := make([]cachedDatabaseCloseTarget, 0) + a.mu.Lock() + for key, entry := range a.dbCache { + entryConfig := entry.config + if strings.TrimSpace(entryConfig.Type) == "" { + continue + } + if getConnectionReleaseMatchKey(entryConfig) != targetKey { + continue + } + targets = append(targets, cachedDatabaseCloseTarget{key: key, inst: entry.inst}) + delete(a.dbCache, key) + } + a.mu.Unlock() + + for _, target := range targets { + if target.inst == nil { + continue + } + if closeErr := target.inst.Close(); closeErr != nil { + logger.Error(closeErr, "关闭缓存连接失败:缓存Key=%s", shortCacheKey(target.key)) + } + } + + return len(targets) +} + +func isMySQLMaxUserConnectionsError(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(normalizeErrorMessage(err)) + return strings.Contains(message, "max_user_connections") || + (strings.Contains(message, "error 1226") && strings.Contains(message, "has exceeded")) +} + +func withMySQLMaxUserConnectionsHint(err error, released int) error { + if err == nil { + return nil + } + if !isMySQLMaxUserConnectionsError(err) { + return err + } + if released > 0 { + return fmt.Errorf("%w;数据库账号连接数已达上限(max_user_connections),GoNavi 已释放同一连接实例的 %d 个缓存连接并重试;若仍失败,请关闭 Navicat/其他客户端连接或提高数据库用户 max_user_connections", err, released) + } + return fmt.Errorf("%w;数据库账号连接数已达上限(max_user_connections),GoNavi 未找到可释放的同实例缓存连接;请关闭 Navicat/其他客户端连接或提高数据库用户 max_user_connections", err) +} + func shortCacheKey(cacheKey string) string { shortKey := cacheKey if len(shortKey) > 12 { @@ -638,11 +705,10 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro } func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Database, error) { - resolvedConfig, err := a.resolveConnectionSecrets(config) + effectiveConfig, err := a.resolveEffectiveConnectionConfig(config) if err != nil { - return nil, wrapConnectError(config, err) + return nil, err } - effectiveConfig := applyGlobalProxyToConnection(resolvedConfig) if supported, reason := driverRuntimeSupportStatusFunc(effectiveConfig.Type); !supported { if strings.TrimSpace(reason) == "" { reason = fmt.Sprintf("%s 驱动未启用,请先在驱动管理中安装启用", strings.TrimSpace(effectiveConfig.Type)) @@ -670,6 +736,14 @@ func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Datab return dbInst, nil } +func (a *App) resolveEffectiveConnectionConfig(config connection.ConnectionConfig) (connection.ConnectionConfig, error) { + resolvedConfig, err := a.resolveConnectionSecrets(config) + if err != nil { + return config, wrapConnectError(config, err) + } + return applyGlobalProxyToConnection(resolvedConfig), nil +} + func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing bool) (db.Database, error) { resolvedConfig, err := a.resolveConnectionSecrets(config) if err != nil { @@ -771,9 +845,14 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing initialKey := key dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(resolvedConfig) if err != nil { - failedKey := getCacheKey(connectedConfig) - a.recordConnectFailureByKey(failedKey, err) - return nil, err + retryInst, retryConfig, retryErr := a.retryConnectAfterMySQLMaxUserConnections(resolvedConfig, connectedConfig, err) + if retryErr != nil { + failedKey := getCacheKey(retryConfig) + a.recordConnectFailureByKey(failedKey, retryErr) + return nil, retryErr + } + dbInst = retryInst + connectedConfig = retryConfig } a.clearConnectFailureByKey(initialKey) effectiveConfig = connectedConfig @@ -800,6 +879,28 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing return dbInst, nil } +func (a *App) retryConnectAfterMySQLMaxUserConnections(rawConfig connection.ConnectionConfig, failedConfig connection.ConnectionConfig, err error) (db.Database, connection.ConnectionConfig, error) { + if !isMySQLMaxUserConnectionsError(err) { + return nil, failedConfig, err + } + + released := a.releaseCachedDatabaseConnectionsForConfig(failedConfig) + logger.Warnf("检测到 MySQL 用户连接数超限,已释放同实例缓存连接:%s 数量=%d", formatConnSummary(failedConfig), released) + if released <= 0 { + return nil, failedConfig, withMySQLMaxUserConnectionsHint(err, released) + } + + dbInst, connectedConfig, retryErr := a.connectDatabaseWithStartupRetry(rawConfig) + if retryErr != nil { + if isMySQLMaxUserConnectionsError(retryErr) { + return nil, connectedConfig, withMySQLMaxUserConnectionsHint(retryErr, released) + } + return nil, connectedConfig, retryErr + } + logger.Infof("MySQL 用户连接数超限释放缓存后重连成功:%s 释放数量=%d", formatConnSummary(connectedConfig), released) + return dbInst, connectedConfig, nil +} + func (a *App) getCachedConnectFailureByKey(key string) (cachedConnectFailure, time.Duration, bool) { if a == nil || strings.TrimSpace(key) == "" { return cachedConnectFailure{}, 0, false diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 8745199..8de9e8d 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -81,27 +81,7 @@ func (a *App) DBReleaseConnection(config connection.ConnectionConfig) connection logger.Error(wrapped, "DBReleaseConnection 解析连接密文失败:%s", formatConnSummary(config)) return connection.QueryResult{Success: false, Message: wrapped.Error()} } - targetKey := getConnectionReleaseMatchKey(applyGlobalProxyToConnection(resolvedConfig)) - closed := 0 - - a.mu.Lock() - for key, entry := range a.dbCache { - entryConfig := entry.config - if strings.TrimSpace(entryConfig.Type) == "" { - continue - } - if getConnectionReleaseMatchKey(entryConfig) != targetKey { - continue - } - if entry.inst != nil { - if closeErr := entry.inst.Close(); closeErr != nil { - logger.Error(closeErr, "DBReleaseConnection 关闭缓存连接失败:缓存Key=%s", shortCacheKey(key)) - } - } - delete(a.dbCache, key) - closed++ - } - a.mu.Unlock() + closed := a.releaseCachedDatabaseConnectionsForConfig(applyGlobalProxyToConnection(resolvedConfig)) logger.Infof("DBReleaseConnection 已释放数据库连接:%s 数量=%d", formatConnSummary(resolvedConfig), closed) return connection.QueryResult{Success: true, Message: "连接已释放", Data: map[string]int{"closed": closed}} @@ -115,16 +95,50 @@ func (a *App) TestConnection(config connection.ConnectionConfig) connection.Quer logger.Warnf("TestConnection 参数校验失败:耗时=%s %s 原因=%s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig), err.Error()) return connection.QueryResult{Success: false, Message: err.Error()} } - _, err := a.getDatabaseForcePing(testConfig) + dbInst, err := a.openDatabaseIsolated(testConfig) + if err != nil { + dbInst, err = a.retryIsolatedTestConnectionAfterMySQLMaxUserConnections(testConfig, err) + } if err != nil { logger.Error(err, "TestConnection 连接测试失败:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig)) return connection.QueryResult{Success: false, Message: err.Error()} } + if dbInst != nil { + if closeErr := dbInst.Close(); closeErr != nil { + logger.Error(closeErr, "TestConnection 释放临时连接失败:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig)) + return connection.QueryResult{Success: false, Message: fmt.Sprintf("连接成功但释放测试连接失败:%v", closeErr)} + } + } logger.Infof("TestConnection 连接测试成功:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig)) return connection.QueryResult{Success: true, Message: "连接成功"} } +func (a *App) retryIsolatedTestConnectionAfterMySQLMaxUserConnections(config connection.ConnectionConfig, err error) (db.Database, error) { + if !isMySQLMaxUserConnectionsError(err) { + return nil, err + } + + effectiveConfig, resolveErr := a.resolveEffectiveConnectionConfig(config) + if resolveErr != nil { + return nil, err + } + released := a.releaseCachedDatabaseConnectionsForConfig(effectiveConfig) + logger.Warnf("测试连接检测到 MySQL 用户连接数超限,已释放同实例缓存连接:%s 数量=%d", formatConnSummary(effectiveConfig), released) + if released <= 0 { + return nil, withMySQLMaxUserConnectionsHint(err, released) + } + + dbInst, retryErr := a.openDatabaseIsolated(config) + if retryErr != nil { + if isMySQLMaxUserConnectionsError(retryErr) { + return nil, withMySQLMaxUserConnectionsHint(retryErr, released) + } + return nil, retryErr + } + return dbInst, nil +} + func (a *App) MongoDiscoverMembers(config connection.ConnectionConfig) connection.QueryResult { config.Type = "mongodb" diff --git a/internal/app/methods_db_conn_test.go b/internal/app/methods_db_conn_test.go index 39c365b..daa51f1 100644 --- a/internal/app/methods_db_conn_test.go +++ b/internal/app/methods_db_conn_test.go @@ -1,17 +1,25 @@ package app import ( + "errors" "strings" "testing" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" ) type releaseRecordingDB struct { - closed int + closed int + connect func(config connection.ConnectionConfig) error } -func (f *releaseRecordingDB) Connect(config connection.ConnectionConfig) error { return nil } +func (f *releaseRecordingDB) Connect(config connection.ConnectionConfig) error { + if f.connect != nil { + return f.connect(config) + } + return nil +} func (f *releaseRecordingDB) Close() error { f.closed++ return nil @@ -214,3 +222,114 @@ func TestDBReleaseConnectionClosesAllDatabaseCacheEntriesForSameInstance(t *test t.Fatalf("expected only unrelated cache entry to remain, got %d", len(app.dbCache)) } } + +func TestTestConnectionUsesIsolatedConnectionAndClosesIt(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + proxySnapshot := currentGlobalProxyConfig() + defer func() { + newDatabaseFunc = originalNewDatabaseFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + if _, err := setGlobalProxyConfig(proxySnapshot.Enabled, proxySnapshot.Proxy); err != nil { + t.Fatalf("restore global proxy failed: %v", err) + } + }() + if _, err := setGlobalProxyConfig(false, proxySnapshot.Proxy); err != nil { + t.Fatalf("disable global proxy failed: %v", err) + } + + testDB := &releaseRecordingDB{} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return testDB, nil + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + app := NewApp() + result := app.TestConnection(connection.ConnectionConfig{ + Type: "mysql", + Host: "127.0.0.1", + Port: 3306, + User: "root", + Database: "app", + }) + + if !result.Success { + t.Fatalf("expected test connection success, got %s", result.Message) + } + if testDB.closed != 1 { + t.Fatalf("expected isolated test connection to be closed once, got %d", testDB.closed) + } + if len(app.dbCache) != 0 { + t.Fatalf("test connection must not write global db cache, got %d entries", len(app.dbCache)) + } +} + +func TestGetDatabaseReleasesSameInstanceCacheAndRetriesOnMaxUserConnections(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + proxySnapshot := currentGlobalProxyConfig() + defer func() { + newDatabaseFunc = originalNewDatabaseFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + if _, err := setGlobalProxyConfig(proxySnapshot.Enabled, proxySnapshot.Proxy); err != nil { + t.Fatalf("restore global proxy failed: %v", err) + } + }() + if _, err := setGlobalProxyConfig(false, proxySnapshot.Proxy); err != nil { + t.Fatalf("disable global proxy failed: %v", err) + } + + connectCalls := 0 + newDatabaseFunc = func(dbType string) (db.Database, error) { + return &releaseRecordingDB{ + connect: func(config connection.ConnectionConfig) error { + connectCalls++ + if connectCalls == 1 { + return errors.New("Error 1226 (42000): User 'yangguofeng' has exceeded the 'max_user_connections' resource (current value: 5)") + } + return nil + }, + }, nil + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + app := NewApp() + mainConfig := connection.ConnectionConfig{Type: "mysql", Host: "db.example.com", Port: 3306, User: "yangguofeng", Database: "main"} + analyticsConfig := mainConfig + analyticsConfig.Database = "analytics" + analyticsConfig.ConnectionParams = "charset=utf8mb4" + otherConfig := mainConfig + otherConfig.User = "other" + + mainDB := &releaseRecordingDB{} + analyticsDB := &releaseRecordingDB{} + otherDB := &releaseRecordingDB{} + app.dbCache[getCacheKey(mainConfig)] = cachedDatabase{inst: mainDB, config: normalizeCacheKeyConfig(mainConfig)} + app.dbCache[getCacheKey(analyticsConfig)] = cachedDatabase{inst: analyticsDB, config: normalizeCacheKeyConfig(analyticsConfig)} + app.dbCache[getCacheKey(otherConfig)] = cachedDatabase{inst: otherDB, config: normalizeCacheKeyConfig(otherConfig)} + + targetConfig := mainConfig + targetConfig.Database = "target" + targetConfig.ConnectionParams = "timeout=10" + + inst, err := app.getDatabase(targetConfig) + if err != nil { + t.Fatalf("expected retry after releasing cached same-instance connections, got %v", err) + } + if inst == nil { + t.Fatal("expected database instance") + } + if connectCalls != 2 { + t.Fatalf("expected one failed connect and one retry, got %d calls", connectCalls) + } + if mainDB.closed != 1 || analyticsDB.closed != 1 { + t.Fatalf("expected same-instance cached connections closed, got main=%d analytics=%d", mainDB.closed, analyticsDB.closed) + } + if otherDB.closed != 0 { + t.Fatalf("expected other user cache to remain open, got closed=%d", otherDB.closed) + } +} diff --git a/internal/app/methods_redis.go b/internal/app/methods_redis.go index f15e03b..569a762 100644 --- a/internal/app/methods_redis.go +++ b/internal/app/methods_redis.go @@ -78,6 +78,31 @@ func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisCli return client, nil } +func (a *App) openRedisClientIsolated(config connection.ConnectionConfig) (redis.RedisClient, error) { + resolvedConfig, err := a.resolveConnectionSecrets(config) + if err != nil { + wrapped := wrapConnectError(config, err) + logger.Error(wrapped, "Redis 密文解析失败:%s", formatRedisConnSummary(config)) + return nil, wrapped + } + + effectiveConfig := applyGlobalProxyToConnection(resolvedConfig) + connectConfig, proxyErr := resolveDialConfigWithProxyFunc(effectiveConfig) + if proxyErr != nil { + wrapped := wrapConnectError(effectiveConfig, proxyErr) + logger.Error(wrapped, "Redis 代理准备失败:%s", formatRedisConnSummary(effectiveConfig)) + return nil, wrapped + } + + client, connectedConfig, connectErr := connectRedisClientWithLegacyRootFallback(connectConfig) + if connectErr != nil { + wrapped := wrapConnectError(connectedConfig, connectErr) + logger.Error(wrapped, "Redis 临时连接失败:%s", formatRedisConnSummary(connectedConfig)) + return nil, wrapped + } + return client, nil +} + func connectRedisClientWithLegacyRootFallback(config connection.ConnectionConfig) (redis.RedisClient, connection.ConnectionConfig, error) { client := newRedisClientFunc() if err := client.Connect(config); err == nil { @@ -237,7 +262,20 @@ func (a *App) RedisConnect(config connection.ConnectionConfig) connection.QueryR // RedisTestConnection tests a Redis connection (alias for RedisConnect) func (a *App) RedisTestConnection(config connection.ConnectionConfig) connection.QueryResult { - return a.RedisConnect(config) + config.Type = "redis" + client, err := a.openRedisClientIsolated(config) + if err != nil { + logger.Error(err, "RedisTestConnection 连接失败:%s", formatRedisConnSummary(config)) + return connection.QueryResult{Success: false, Message: err.Error()} + } + if client != nil { + if closeErr := client.Close(); closeErr != nil { + logger.Error(closeErr, "RedisTestConnection 释放临时连接失败:%s", formatRedisConnSummary(config)) + return connection.QueryResult{Success: false, Message: fmt.Sprintf("连接成功但释放测试连接失败:%v", closeErr)} + } + } + logger.Infof("RedisTestConnection 连接成功:%s", formatRedisConnSummary(config)) + return connection.QueryResult{Success: true, Message: "连接成功"} } // RedisScanKeys scans keys matching a pattern diff --git a/internal/app/methods_redis_test.go b/internal/app/methods_redis_test.go index d17c4f0..b9de677 100644 --- a/internal/app/methods_redis_test.go +++ b/internal/app/methods_redis_test.go @@ -12,6 +12,7 @@ type capturingRedisClient struct { connectConfig connection.ConnectionConfig deletedHashKey string deletedHashFields []string + closed int } func (c *capturingRedisClient) Connect(config connection.ConnectionConfig) error { @@ -19,7 +20,10 @@ func (c *capturingRedisClient) Connect(config connection.ConnectionConfig) error return nil } -func (c *capturingRedisClient) Close() error { return nil } +func (c *capturingRedisClient) Close() error { + c.closed++ + return nil +} func (c *capturingRedisClient) Ping() error { return nil } @@ -119,6 +123,49 @@ func (c *scriptedRedisClient) Connect(config connection.ConnectionConfig) error return c.connectErr } +func TestRedisTestConnectionUsesIsolatedClientAndClosesIt(t *testing.T) { + originalNewRedisClientFunc := newRedisClientFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + proxySnapshot := currentGlobalProxyConfig() + defer func() { + newRedisClientFunc = originalNewRedisClientFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + if _, err := setGlobalProxyConfig(proxySnapshot.Enabled, proxySnapshot.Proxy); err != nil { + t.Fatalf("restore global proxy failed: %v", err) + } + CloseAllRedisClients() + }() + CloseAllRedisClients() + if _, err := setGlobalProxyConfig(false, proxySnapshot.Proxy); err != nil { + t.Fatalf("disable global proxy failed: %v", err) + } + + client := &capturingRedisClient{} + newRedisClientFunc = func() redislib.RedisClient { + return client + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + app := NewApp() + result := app.RedisTestConnection(connection.ConnectionConfig{ + Type: "redis", + Host: "127.0.0.1", + Port: 6379, + }) + + if !result.Success { + t.Fatalf("expected redis test connection success, got %s", result.Message) + } + if client.closed != 1 { + t.Fatalf("expected isolated redis test client to be closed once, got %d", client.closed) + } + if len(redisCache) != 0 { + t.Fatalf("redis test connection must not write global redis cache, got %d entries", len(redisCache)) + } +} + func TestRedisConnectResolvesSavedSecretsByConnectionID(t *testing.T) { testCases := []struct { name string diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index 0dac365..d0952b8 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -643,6 +643,7 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error { break } c.conn = clickhouse.OpenDB(opts) + configureSQLConnectionPool(c.conn, "clickhouse") if err := c.Ping(); err != nil { lastProtocolErr = err failureMessage := clickHouseAttemptFailureMessage(protocol, err) diff --git a/internal/db/custom_impl.go b/internal/db/custom_impl.go index 0f2d027..657fa92 100644 --- a/internal/db/custom_impl.go +++ b/internal/db/custom_impl.go @@ -34,10 +34,13 @@ func (c *CustomDB) Connect(config connection.ConnectionConfig) error { if err != nil { return formatCustomDriverOpenError(driver, err) } + configureSQLConnectionPool(db, driver) c.conn = db c.driver = driver c.pingTimeout = getConnectTimeout(config) if err := c.Ping(); err != nil { + _ = db.Close() + c.conn = nil return fmt.Errorf("连接建立后验证失败:%w", err) } return nil diff --git a/internal/db/dameng_impl.go b/internal/db/dameng_impl.go index 7b05601..6b57df4 100644 --- a/internal/db/dameng_impl.go +++ b/internal/db/dameng_impl.go @@ -110,6 +110,7 @@ func (d *DamengDB) Connect(config connection.ConnectionConfig) error { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) continue } + configureSQLConnectionPool(db, "dameng") d.conn = db d.pingTimeout = getConnectTimeout(attempt) if err := d.Ping(); err != nil { diff --git a/internal/db/diros_impl.go b/internal/db/diros_impl.go index 4fa0c27..7619528 100644 --- a/internal/db/diros_impl.go +++ b/internal/db/diros_impl.go @@ -187,6 +187,7 @@ func (d *DirosDB) Connect(config connection.ConnectionConfig) error { errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err)) continue } + configureSQLConnectionPool(db, "diros") timeout := getConnectTimeout(candidateConfig) ctx, cancel := utils.ContextWithTimeout(timeout) diff --git a/internal/db/gaussdb_impl.go b/internal/db/gaussdb_impl.go index 88ba92b..8bb79ce 100644 --- a/internal/db/gaussdb_impl.go +++ b/internal/db/gaussdb_impl.go @@ -179,6 +179,7 @@ func (g *GaussDB) Connect(config connection.ConnectionConfig) error { failures = append(failures, fmt.Sprintf("%s 数据库=%s 打开连接失败: %v", sslLabel, dbName, err)) continue } + configureSQLConnectionPool(dbConn, "gaussdb") g.conn = dbConn if err := g.Ping(); err != nil { @@ -233,6 +234,7 @@ func (g *GaussDB) ensureSearchPath(baseDSN string) { newDB, err := sql.Open("gaussdb", newDSN) if err == nil { + configureSQLConnectionPool(newDB, "gaussdb") newDB.SetConnMaxLifetime(5 * time.Minute) oldConn := g.conn g.conn = newDB diff --git a/internal/db/highgo_impl.go b/internal/db/highgo_impl.go index 19fb06c..1edc350 100644 --- a/internal/db/highgo_impl.go +++ b/internal/db/highgo_impl.go @@ -103,6 +103,7 @@ func (h *HighGoDB) Connect(config connection.ConnectionConfig) error { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) continue } + configureSQLConnectionPool(db, "highgo") h.conn = db h.pingTimeout = getConnectTimeout(attempt) if err := h.Ping(); err != nil { diff --git a/internal/db/iris_impl.go b/internal/db/iris_impl.go index 44ed8bb..1750565 100644 --- a/internal/db/iris_impl.go +++ b/internal/db/iris_impl.go @@ -141,9 +141,12 @@ func (i *IrisDB) Connect(config connection.ConnectionConfig) error { if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) } + configureSQLConnectionPool(db, "iris") i.conn = db i.pingTimeout = getConnectTimeout(runConfig) if err := i.Ping(); err != nil { + _ = db.Close() + i.conn = nil return fmt.Errorf("连接建立后验证失败:%w", err) } cleanupOnFailure = false diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index 1daa776..44e78d9 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -157,6 +157,7 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) continue } + configureSQLConnectionPool(db, "kingbase") k.conn = db k.pingTimeout = getConnectTimeout(attempt) if err := k.Ping(); err != nil { @@ -175,8 +176,9 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { // 将 search_path 参数拼入 DSN finalDSN := dsn + " search_path=" + quoteConnValue(searchPathStr) if finalDB, err := sql.Open("kingbase", finalDSN); err == nil { - k.pingTimeout = getConnectTimeout(attempt) + configureSQLConnectionPool(finalDB, "kingbase") finalDB.SetConnMaxLifetime(5 * time.Minute) + k.pingTimeout = getConnectTimeout(attempt) // 临时将 k.conn 指向 finalDB 来做 ping 测试 oldConn := k.conn diff --git a/internal/db/mariadb_impl.go b/internal/db/mariadb_impl.go index 508fddb..82987af 100644 --- a/internal/db/mariadb_impl.go +++ b/internal/db/mariadb_impl.go @@ -49,10 +49,13 @@ func (m *MariaDB) Connect(config connection.ConnectionConfig) error { if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) } + configureSQLConnectionPool(db, "mariadb") m.conn = db m.pingTimeout = getConnectTimeout(config) if err := m.Ping(); err != nil { + _ = db.Close() + m.conn = nil return fmt.Errorf("连接建立后验证失败:%w", err) } return nil diff --git a/internal/db/mysql_connection_params_test.go b/internal/db/mysql_connection_params_test.go index bf7135a..4e36600 100644 --- a/internal/db/mysql_connection_params_test.go +++ b/internal/db/mysql_connection_params_test.go @@ -45,6 +45,21 @@ func parseMySQLDriverCharsetsForTest(t *testing.T, dsn string) []string { return charsets } +func TestConfigureSQLConnectionPoolCapsOpenConnections(t *testing.T) { + dbConn, err := sql.Open("mysql", "root@tcp(127.0.0.1:1)/test") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer dbConn.Close() + + configureSQLConnectionPool(dbConn, "mysql") + + stats := dbConn.Stats() + if stats.MaxOpenConnections != defaultSQLMaxOpenConns { + t.Fatalf("expected max open connections %d, got %d", defaultSQLMaxOpenConns, stats.MaxOpenConnections) + } +} + func TestMySQLDSN_MergesConnectionParamsWithDefaults(t *testing.T) { t.Parallel() diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 2b4adf6..3a81a43 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -847,6 +847,7 @@ func (m *MySQLDB) Connect(config connection.ConnectionConfig) error { } continue } + configureSQLConnectionPool(db, candidateConfig.Type) timeout := getConnectTimeout(candidateConfig) ctx, cancel := utils.ContextWithTimeout(timeout) diff --git a/internal/db/oceanbase_impl.go b/internal/db/oceanbase_impl.go index 742f795..c1a647f 100644 --- a/internal/db/oceanbase_impl.go +++ b/internal/db/oceanbase_impl.go @@ -621,6 +621,7 @@ func (o *OceanBaseDB) connectOracleViaOBClient(config connection.ConnectionConfi errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败:%v", address, err)) continue } + configureSQLConnectionPool(db, "oceanbase") timeout := getConnectTimeout(candidateConfig) ctx, cancel := utils.ContextWithTimeout(timeout) @@ -741,6 +742,7 @@ func (o *OceanBaseDB) Connect(config connection.ConnectionConfig) error { errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败:%v", address, err)) continue } + configureSQLConnectionPool(db, "oceanbase") timeout := getConnectTimeout(candidateConfig) ctx, cancel := utils.ContextWithTimeout(timeout) diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 5373bdb..a988d25 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -161,6 +161,7 @@ func (o *OracleDB) Connect(config connection.ConnectionConfig) error { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) continue } + configureSQLConnectionPool(db, "oracle") o.conn = db o.pingTimeout = getConnectTimeout(attempt) if err := o.Ping(); err != nil { diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index 30f5fac..9942ef7 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -159,6 +159,7 @@ func (p *PostgresDB) Connect(config connection.ConnectionConfig) error { failures = append(failures, fmt.Sprintf("%s 数据库=%s 打开连接失败: %v", sslLabel, dbName, err)) continue } + configureSQLConnectionPool(dbConn, "postgres") p.conn = dbConn // Force verification @@ -604,6 +605,7 @@ func (p *PostgresDB) ensureSearchPath(baseDSN string) { newDB, err := sql.Open("postgres", newDSN) if err == nil { + configureSQLConnectionPool(newDB, "postgres") newDB.SetConnMaxLifetime(5 * time.Minute) oldConn := p.conn p.conn = newDB diff --git a/internal/db/sql_pool.go b/internal/db/sql_pool.go new file mode 100644 index 0000000..146ec2a --- /dev/null +++ b/internal/db/sql_pool.go @@ -0,0 +1,27 @@ +package db + +import ( + "database/sql" + "strings" + "time" +) + +const ( + defaultSQLMaxOpenConns = 4 + defaultSQLConnMaxLifetime = 30 * time.Minute + defaultSQLConnMaxIdleTime = 30 * time.Second +) + +func configureSQLConnectionPool(db *sql.DB, dbType string) { + if db == nil { + return + } + switch strings.ToLower(strings.TrimSpace(dbType)) { + case "sqlite", "duckdb": + return + } + db.SetMaxOpenConns(defaultSQLMaxOpenConns) + db.SetMaxIdleConns(0) + db.SetConnMaxIdleTime(defaultSQLConnMaxIdleTime) + db.SetConnMaxLifetime(defaultSQLConnMaxLifetime) +} diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index 429a488..07b1bd3 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -176,10 +176,13 @@ func (s *SqlServerDB) Connect(config connection.ConnectionConfig) error { if err != nil { return fmt.Errorf("打开数据库连接失败:%w", err) } + configureSQLConnectionPool(db, "sqlserver") s.conn = db s.pingTimeout = getConnectTimeout(config) if err := s.Ping(); err != nil { + _ = db.Close() + s.conn = nil return fmt.Errorf("连接建立后验证失败:%w", err) } return nil diff --git a/internal/db/starrocks_impl.go b/internal/db/starrocks_impl.go index af2df6d..e65a131 100644 --- a/internal/db/starrocks_impl.go +++ b/internal/db/starrocks_impl.go @@ -270,6 +270,7 @@ func (s *StarRocksDB) Connect(config connection.ConnectionConfig) error { errorDetails = append(errorDetails, fmt.Sprintf("%s 打开失败: %v", address, err)) continue } + configureSQLConnectionPool(db, "starrocks") timeout := getConnectTimeout(candidateConfig) ctx, cancel := utils.ContextWithTimeout(timeout) diff --git a/internal/db/tdengine_impl.go b/internal/db/tdengine_impl.go index 02270cb..b9f8d74 100644 --- a/internal/db/tdengine_impl.go +++ b/internal/db/tdengine_impl.go @@ -96,6 +96,7 @@ func (t *TDengineDB) Connect(config connection.ConnectionConfig) error { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) continue } + configureSQLConnectionPool(db, "tdengine") t.conn = db t.pingTimeout = getConnectTimeout(attempt) diff --git a/internal/db/vastbase_impl.go b/internal/db/vastbase_impl.go index a1ece0a..71cf70f 100644 --- a/internal/db/vastbase_impl.go +++ b/internal/db/vastbase_impl.go @@ -94,6 +94,7 @@ func (v *VastbaseDB) Connect(config connection.ConnectionConfig) error { failures = append(failures, fmt.Sprintf("第%d次连接打开失败: %v", idx+1, err)) continue } + configureSQLConnectionPool(db, "vastbase") v.conn = db v.pingTimeout = getConnectTimeout(attempt) if err := v.Ping(); err != nil {