void }> = ({
) : null}
{activeDriverLogLines.length > 0 ? (
-
+
{activeDriverLogLines.join('\n')}
) : (
diff --git a/frontend/src/components/RedisViewer.tsx b/frontend/src/components/RedisViewer.tsx
index 0aa3750..23855aa 100644
--- a/frontend/src/components/RedisViewer.tsx
+++ b/frontend/src/components/RedisViewer.tsx
@@ -218,18 +218,17 @@ const ResizableDivider: React.FC<{
(e.currentTarget.style.background = '#d9d9d9')}
- onMouseLeave={(e) => (e.currentTarget.style.background = '#f0f0f0')}
+ title="拖动调整宽度"
>
-
);
};
@@ -281,6 +280,23 @@ const getRedisScanLoadCount = (pattern: string, append: boolean): number => {
return append ? REDIS_KEY_SEARCH_LOAD_MORE_COUNT : REDIS_KEY_SEARCH_INITIAL_LOAD_COUNT;
};
+const normalizeRedisCursor = (value: unknown): string => {
+ if (typeof value === 'string') {
+ const trimmed = value.trim();
+ return trimmed === '' ? '0' : trimmed;
+ }
+ if (typeof value === 'number') {
+ if (!Number.isFinite(value)) {
+ return '0';
+ }
+ return Math.trunc(value).toString();
+ }
+ if (typeof value === 'bigint') {
+ return value.toString();
+ }
+ return '0';
+};
+
const normalizeKeySegment = (segment: string): string => {
return segment === '' ? EMPTY_SEGMENT_LABEL : segment;
};
@@ -384,7 +400,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => {
const [keys, setKeys] = useState([]);
const [loading, setLoading] = useState(false);
const [searchPattern, setSearchPattern] = useState('*');
- const [cursor, setCursor] = useState(0);
+ const [cursor, setCursor] = useState('0');
const [hasMore, setHasMore] = useState(false);
const [selectedKey, setSelectedKey] = useState(null);
const [keyValue, setKeyValue] = useState(null);
@@ -433,7 +449,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => {
const loadKeys = useCallback(async (
pattern: string = '*',
- fromCursor: number = 0,
+ fromCursor: string = '0',
append: boolean = false,
targetCount?: number
) => {
@@ -454,7 +470,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => {
if (res.success) {
const result = res.data;
const scannedKeys = Array.isArray(result?.keys) ? result.keys : [];
- const nextCursor = Number(result?.cursor || 0);
+ const nextCursor = normalizeRedisCursor(result?.cursor);
if (append) {
setKeys(prev => {
const keyMap = new Map();
@@ -466,7 +482,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => {
setKeys(scannedKeys);
}
setCursor(nextCursor);
- setHasMore(nextCursor !== 0);
+ setHasMore(nextCursor !== '0');
} else {
message.error('加载 Key 失败: ' + res.message);
}
@@ -483,14 +499,14 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => {
}, [getConfig]);
useEffect(() => {
- loadKeys(searchPattern, 0, false, getRedisScanLoadCount(searchPattern, false));
+ loadKeys(searchPattern, '0', false, getRedisScanLoadCount(searchPattern, false));
}, [redisDB]);
const handleSearch = (value: string) => {
const pattern = value.trim() || '*';
setSearchPattern(pattern);
- setCursor(0);
- loadKeys(pattern, 0, false, getRedisScanLoadCount(pattern, false));
+ setCursor('0');
+ loadKeys(pattern, '0', false, getRedisScanLoadCount(pattern, false));
};
const handleLoadMore = () => {
@@ -501,8 +517,8 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => {
};
const handleRefresh = () => {
- setCursor(0);
- loadKeys(searchPattern, 0, false, getRedisScanLoadCount(searchPattern, false));
+ setCursor('0');
+ loadKeys(searchPattern, '0', false, getRedisScanLoadCount(searchPattern, false));
};
const loadKeyValue = async (key: string) => {
diff --git a/frontend/src/store.ts b/frontend/src/store.ts
index 76854c3..88e2d4a 100644
--- a/frontend/src/store.ts
+++ b/frontend/src/store.ts
@@ -24,6 +24,7 @@ const DEFAULT_GLOBAL_PROXY: GlobalProxyConfig = {
const SUPPORTED_CONNECTION_TYPES = new Set([
'mysql',
'mariadb',
+ 'doris',
'diros',
'sphinx',
'clickhouse',
@@ -47,6 +48,7 @@ const getDefaultPortByType = (type: string): number => {
case 'mysql':
case 'mariadb':
return 3306;
+ case 'doris':
case 'diros':
return 9030;
case 'duckdb':
@@ -150,6 +152,9 @@ const sanitizeAddressList = (value: unknown): string[] => {
const normalizeConnectionType = (value: unknown): string => {
const type = toTrimmedString(value).toLowerCase();
+ if (type === 'doris') {
+ return 'diros';
+ }
return SUPPORTED_CONNECTION_TYPES.has(type) ? type : DEFAULT_CONNECTION_TYPE;
};
@@ -241,7 +246,8 @@ const sanitizeSavedConnection = (value: unknown, index: number): SavedConnection
const raw = value as Record;
const config = sanitizeConnectionConfig(resolveConnectionConfigPayload(raw));
const id = toTrimmedString(raw.id, `conn-${index + 1}`) || `conn-${index + 1}`;
- const fallbackName = config.host ? `${config.type}-${config.host}` : `连接-${index + 1}`;
+ const displayType = config.type === 'diros' ? 'doris' : config.type;
+ const fallbackName = config.host ? `${displayType}-${config.host}` : `连接-${index + 1}`;
const name = toTrimmedString(raw.name, fallbackName) || fallbackName;
const includeDatabases = sanitizeStringArray(raw.includeDatabases, 256);
const includeRedisDatabases = sanitizeNumberArray(raw.includeRedisDatabases, 0, 15);
diff --git a/frontend/src/types.ts b/frontend/src/types.ts
index f700677..e8a6cb4 100644
--- a/frontend/src/types.ts
+++ b/frontend/src/types.ts
@@ -137,7 +137,7 @@ export interface RedisKeyInfo {
export interface RedisScanResult {
keys: RedisKeyInfo[];
- cursor: number;
+ cursor: string;
}
export interface RedisValue {
diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts
index b2edb2f..72ad6a1 100755
--- a/frontend/wailsjs/go/app/App.d.ts
+++ b/frontend/wailsjs/go/app/App.d.ts
@@ -34,6 +34,8 @@ export function DBGetTriggers(arg1:connection.ConnectionConfig,arg2:string,arg3:
export function DBQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise;
+export function DBQueryIsolated(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise;
+
export function DBShowCreateTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise;
export function DataSync(arg1:sync.SyncConfig):Promise;
@@ -124,7 +126,7 @@ export function RedisListSet(arg1:connection.ConnectionConfig,arg2:string,arg3:n
export function RedisRenameKey(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise;
-export function RedisScanKeys(arg1:connection.ConnectionConfig,arg2:string,arg3:number,arg4:number):Promise;
+export function RedisScanKeys(arg1:connection.ConnectionConfig,arg2:string,arg3:any,arg4:number):Promise;
export function RedisSelectDB(arg1:connection.ConnectionConfig,arg2:number):Promise;
@@ -164,6 +166,8 @@ export function ResolveDriverRepositoryURL(arg1:string):Promise;
+export function SelectDriverPackageDirectory(arg1:string):Promise;
+
export function SelectDriverPackageFile(arg1:string):Promise;
export function SelectSSHKeyFile(arg1:string):Promise;
diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js
index 6dba529..86f801f 100755
--- a/frontend/wailsjs/go/app/App.js
+++ b/frontend/wailsjs/go/app/App.js
@@ -62,6 +62,10 @@ export function DBQuery(arg1, arg2, arg3) {
return window['go']['app']['App']['DBQuery'](arg1, arg2, arg3);
}
+export function DBQueryIsolated(arg1, arg2, arg3) {
+ return window['go']['app']['App']['DBQueryIsolated'](arg1, arg2, arg3);
+}
+
export function DBShowCreateTable(arg1, arg2, arg3) {
return window['go']['app']['App']['DBShowCreateTable'](arg1, arg2, arg3);
}
@@ -322,6 +326,10 @@ export function SelectDriverDownloadDirectory(arg1) {
return window['go']['app']['App']['SelectDriverDownloadDirectory'](arg1);
}
+export function SelectDriverPackageDirectory(arg1) {
+ return window['go']['app']['App']['SelectDriverPackageDirectory'](arg1);
+}
+
export function SelectDriverPackageFile(arg1) {
return window['go']['app']['App']['SelectDriverPackageFile'](arg1);
}
diff --git a/internal/app/app.go b/internal/app/app.go
index a46726e..b8dd6a7 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -207,6 +207,32 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro
return a.getDatabaseWithPing(config, false)
}
+func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Database, error) {
+ effectiveConfig := applyGlobalProxyToConnection(config)
+ if supported, reason := db.DriverRuntimeSupportStatus(effectiveConfig.Type); !supported {
+ if strings.TrimSpace(reason) == "" {
+ reason = fmt.Sprintf("%s 驱动未启用,请先在驱动管理中安装启用", strings.TrimSpace(effectiveConfig.Type))
+ }
+ return nil, withLogHint{err: fmt.Errorf("%s", reason), logPath: logger.Path()}
+ }
+
+ dbInst, err := db.NewDatabase(effectiveConfig.Type)
+ if err != nil {
+ return nil, err
+ }
+
+ connectConfig, proxyErr := resolveDialConfigWithProxy(effectiveConfig)
+ if proxyErr != nil {
+ _ = dbInst.Close()
+ return nil, wrapConnectError(effectiveConfig, proxyErr)
+ }
+ if err := dbInst.Connect(connectConfig); err != nil {
+ _ = dbInst.Close()
+ return nil, wrapConnectError(effectiveConfig, err)
+ }
+ return dbInst, nil
+}
+
func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing bool) (db.Database, error) {
effectiveConfig := applyGlobalProxyToConnection(config)
diff --git a/internal/app/global_proxy.go b/internal/app/global_proxy.go
index 57db384..4dc8686 100644
--- a/internal/app/global_proxy.go
+++ b/internal/app/global_proxy.go
@@ -1,6 +1,9 @@
package app
import (
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
"fmt"
"net"
"net/http"
@@ -26,6 +29,12 @@ var globalProxyRuntime = struct {
proxy connection.ProxyConfig
}{}
+type localProxyTLSFallbackTransport struct {
+ primary *http.Transport
+ fallback *http.Transport
+ proxyEndpoint string
+}
+
func currentGlobalProxyConfig() globalProxySnapshot {
globalProxyRuntime.mu.RLock()
defer globalProxyRuntime.mu.RUnlock()
@@ -139,7 +148,7 @@ func newHTTPClientWithGlobalProxy(timeout time.Duration) *http.Client {
return client
}
-func buildHTTPTransportWithGlobalProxy() *http.Transport {
+func buildHTTPTransportWithGlobalProxy() http.RoundTripper {
baseTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok || baseTransport == nil {
return nil
@@ -160,7 +169,98 @@ func buildHTTPTransportWithGlobalProxy() *http.Transport {
}
transport.Proxy = http.ProxyURL(proxyURL)
- return transport
+ if !isLoopbackProxyHost(snapshot.Proxy.Host) {
+ return transport
+ }
+
+ fallbackTransport := transport.Clone()
+ fallbackTransport.TLSClientConfig = cloneTLSConfigWithInsecureSkipVerify(fallbackTransport.TLSClientConfig)
+ return &localProxyTLSFallbackTransport{
+ primary: transport,
+ fallback: fallbackTransport,
+ proxyEndpoint: proxyURL.Redacted(),
+ }
+}
+
+func (t *localProxyTLSFallbackTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ resp, err := t.primary.RoundTrip(req)
+ if err == nil {
+ return resp, nil
+ }
+ if !isTLSFallbackCandidate(req.Method, err) {
+ return nil, err
+ }
+
+ retryReq, cloneErr := cloneRequestForRetry(req)
+ if cloneErr != nil {
+ return nil, err
+ }
+ logger.Warnf("检测到本地代理 TLS 证书不受信任,启用兼容回退:代理=%s 目标=%s 错误=%v", t.proxyEndpoint, req.URL.String(), err)
+ return t.fallback.RoundTrip(retryReq)
+}
+
+func isTLSFallbackCandidate(method string, err error) bool {
+ if !isIdempotentRequestMethod(method) {
+ return false
+ }
+ return isUnknownAuthorityError(err)
+}
+
+func isIdempotentRequestMethod(method string) bool {
+ switch strings.ToUpper(strings.TrimSpace(method)) {
+ case http.MethodGet, http.MethodHead:
+ return true
+ default:
+ return false
+ }
+}
+
+func cloneRequestForRetry(req *http.Request) (*http.Request, error) {
+ cloned := req.Clone(req.Context())
+ if req.Body == nil || req.Body == http.NoBody {
+ return cloned, nil
+ }
+ if req.GetBody == nil {
+ return nil, fmt.Errorf("request body not replayable")
+ }
+ body, err := req.GetBody()
+ if err != nil {
+ return nil, err
+ }
+ cloned.Body = body
+ return cloned, nil
+}
+
+func isUnknownAuthorityError(err error) bool {
+ var unknownErr x509.UnknownAuthorityError
+ if errors.As(err, &unknownErr) {
+ return true
+ }
+ return strings.Contains(strings.ToLower(err.Error()), "x509: certificate signed by unknown authority")
+}
+
+func cloneTLSConfigWithInsecureSkipVerify(base *tls.Config) *tls.Config {
+ if base == nil {
+ return &tls.Config{InsecureSkipVerify: true}
+ }
+ cloned := base.Clone()
+ cloned.InsecureSkipVerify = true
+ return cloned
+}
+
+func isLoopbackProxyHost(host string) bool {
+ trimmed := strings.TrimSpace(host)
+ if trimmed == "" {
+ return false
+ }
+ if strings.EqualFold(trimmed, "localhost") {
+ return true
+ }
+ ip := net.ParseIP(trimmed)
+ if ip == nil {
+ return false
+ }
+ return ip.IsLoopback()
}
func buildProxyURLFromConfig(proxyConfig connection.ProxyConfig) (*url.URL, error) {
diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go
index 4c4086b..6d20bdc 100644
--- a/internal/app/methods_db.go
+++ b/internal/app/methods_db.go
@@ -7,6 +7,7 @@ import (
"time"
"GoNavi-Wails/internal/connection"
+ "GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/utils"
)
@@ -112,16 +113,39 @@ func resolveDDLDBType(config connection.ConnectionConfig) string {
driver := strings.ToLower(strings.TrimSpace(config.Driver))
switch driver {
- case "postgresql":
+ case "postgresql", "postgres", "pg", "pq", "pgx":
return "postgres"
- case "dm":
+ case "dm", "dameng", "dm8":
return "dameng"
- case "sqlite3":
+ case "sqlite3", "sqlite":
return "sqlite"
case "sphinxql":
return "sphinx"
case "diros", "doris":
return "diros"
+ case "kingbase", "kingbase8", "kingbasees", "kingbasev8":
+ return "kingbase"
+ case "highgo":
+ return "highgo"
+ case "vastbase":
+ return "vastbase"
+ }
+
+ switch {
+ case strings.Contains(driver, "postgres"):
+ return "postgres"
+ case strings.Contains(driver, "kingbase"):
+ return "kingbase"
+ case strings.Contains(driver, "highgo"):
+ return "highgo"
+ case strings.Contains(driver, "vastbase"):
+ return "vastbase"
+ case strings.Contains(driver, "sqlite"):
+ return "sqlite"
+ case strings.Contains(driver, "sphinx"):
+ return "sphinx"
+ case strings.Contains(driver, "diros"), strings.Contains(driver, "doris"):
+ return "diros"
default:
return driver
}
@@ -186,7 +210,7 @@ func (a *App) RenameDatabase(config connection.ConnectionConfig, oldName string,
dbType := resolveDDLDBType(config)
switch dbType {
case "mysql", "mariadb", "diros", "sphinx":
- return connection.QueryResult{Success: false, Message: "MySQL/MariaDB/Diros/Sphinx 不支持直接重命名数据库,请新建库后迁移数据"}
+ return connection.QueryResult{Success: false, Message: "MySQL/MariaDB/Doris/Sphinx 不支持直接重命名数据库,请新建库后迁移数据"}
case "postgres", "kingbase", "highgo", "vastbase":
if strings.EqualFold(strings.TrimSpace(config.Database), oldName) {
return connection.QueryResult{Success: false, Message: "当前连接正在使用目标数据库,请先连接到其他数据库后再重命名"}
@@ -406,6 +430,66 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s
}
}
+func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
+ runConfig := normalizeRunConfig(config, dbName)
+
+ dbInst, err := a.openDatabaseIsolated(runConfig)
+ if err != nil {
+ logger.Error(err, "DBQueryIsolated 获取连接失败:%s", formatConnSummary(runConfig))
+ return connection.QueryResult{Success: false, Message: err.Error()}
+ }
+ defer func() {
+ if closeErr := dbInst.Close(); closeErr != nil {
+ logger.Error(closeErr, "DBQueryIsolated 关闭临时连接失败:%s", formatConnSummary(runConfig))
+ }
+ }()
+
+ query = sanitizeSQLForPgLike(runConfig.Type, query)
+ timeoutSeconds := runConfig.Timeout
+ if timeoutSeconds <= 0 {
+ timeoutSeconds = 30
+ }
+ ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second)
+ defer cancel()
+
+ lowerQuery := strings.TrimSpace(strings.ToLower(query))
+ isReadQuery := strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain")
+ if !isReadQuery && strings.ToLower(strings.TrimSpace(runConfig.Type)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") {
+ isReadQuery = true
+ }
+
+ if isReadQuery {
+ var data []map[string]interface{}
+ var columns []string
+ if q, ok := dbInst.(interface {
+ QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
+ }); ok {
+ data, columns, err = q.QueryContext(ctx, query)
+ } else {
+ data, columns, err = dbInst.Query(query)
+ }
+ if err != nil {
+ logger.Error(err, "DBQueryIsolated 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
+ return connection.QueryResult{Success: false, Message: err.Error()}
+ }
+ return connection.QueryResult{Success: true, Data: data, Fields: columns}
+ }
+
+ var affected int64
+ if e, ok := dbInst.(interface {
+ ExecContext(context.Context, string) (int64, error)
+ }); ok {
+ affected, err = e.ExecContext(ctx, query)
+ } else {
+ affected, err = dbInst.Exec(query)
+ }
+ if err != nil {
+ logger.Error(err, "DBQueryIsolated 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
+ return connection.QueryResult{Success: false, Message: err.Error()}
+ }
+ return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}}
+}
+
func sqlSnippet(query string) string {
q := strings.TrimSpace(query)
const max = 200
@@ -460,8 +544,8 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con
}
func (a *App) DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult {
- runConfig := normalizeRunConfig(config, dbName)
dbType := resolveDDLDBType(config)
+ runConfig := buildRunConfigForDDL(config, dbType, dbName)
dbInst, err := a.getDatabase(runConfig)
if err != nil {
@@ -469,35 +553,65 @@ func (a *App) DBShowCreateTable(config connection.ConnectionConfig, dbName strin
return connection.QueryResult{Success: false, Message: err.Error()}
}
- schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName)
- sqlStr, err := dbInst.GetCreateStatement(schemaName, pureTableName)
+ sqlStr, err := resolveCreateStatementWithFallback(dbInst, config, dbName, tableName)
if err != nil {
logger.Error(err, "DBShowCreateTable 获取建表语句失败:%s 表=%s", formatConnSummary(runConfig), tableName)
return connection.QueryResult{Success: false, Message: err.Error()}
}
- if shouldFallbackCreateStatement(dbType, sqlStr) {
- columns, colErr := dbInst.GetColumns(schemaName, pureTableName)
- if colErr != nil {
- logger.Error(colErr, "DBShowCreateTable 兜底加载字段失败:%s 表=%s", formatConnSummary(runConfig), tableName)
- return connection.QueryResult{Success: false, Message: colErr.Error()}
- }
- fallbackDDL, buildErr := buildFallbackCreateStatement(dbType, schemaName, pureTableName, columns)
- if buildErr != nil {
- logger.Error(buildErr, "DBShowCreateTable 兜底生成 DDL 失败:%s 表=%s", formatConnSummary(runConfig), tableName)
- return connection.QueryResult{Success: false, Message: buildErr.Error()}
- }
- sqlStr = fallbackDDL
- }
return connection.QueryResult{Success: true, Data: sqlStr}
}
-func shouldFallbackCreateStatement(dbType string, ddl string) bool {
+func resolveCreateStatementWithFallback(dbInst db.Database, config connection.ConnectionConfig, dbName string, tableName string) (string, error) {
+ dbType := resolveDDLDBType(config)
+ schemaName, pureTableName := normalizeSchemaAndTableByType(dbType, dbName, tableName)
+ if pureTableName == "" {
+ return "", fmt.Errorf("表名不能为空")
+ }
+
+ sqlStr, sourceErr := dbInst.GetCreateStatement(schemaName, pureTableName)
+ if sourceErr == nil && !shouldFallbackCreateStatement(dbType, sqlStr) {
+ return sqlStr, nil
+ }
+
+ if !supportsCreateStatementFallback(dbType) {
+ if sourceErr != nil {
+ return "", sourceErr
+ }
+ return sqlStr, nil
+ }
+
+ columns, colErr := dbInst.GetColumns(schemaName, pureTableName)
+ if colErr != nil {
+ if sourceErr != nil {
+ return "", sourceErr
+ }
+ return "", colErr
+ }
+
+ fallbackDDL, buildErr := buildFallbackCreateStatement(dbType, schemaName, pureTableName, columns)
+ if buildErr != nil {
+ if sourceErr != nil {
+ return "", sourceErr
+ }
+ return "", buildErr
+ }
+ return fallbackDDL, nil
+}
+
+func supportsCreateStatementFallback(dbType string) bool {
switch dbType {
case "postgres", "kingbase", "highgo", "vastbase":
+ return true
default:
return false
}
+}
+
+func shouldFallbackCreateStatement(dbType string, ddl string) bool {
+ if !supportsCreateStatementFallback(dbType) {
+ return false
+ }
trimmed := strings.TrimSpace(ddl)
if trimmed == "" {
diff --git a/internal/app/methods_db_create_statement_test.go b/internal/app/methods_db_create_statement_test.go
new file mode 100644
index 0000000..dfbf1fd
--- /dev/null
+++ b/internal/app/methods_db_create_statement_test.go
@@ -0,0 +1,174 @@
+package app
+
+import (
+ "errors"
+ "strings"
+ "testing"
+
+ "GoNavi-Wails/internal/connection"
+)
+
+type fakeCreateStatementDB struct {
+ createSQL string
+ createErr error
+ columns []connection.ColumnDefinition
+ columnsErr error
+
+ createSchema string
+ createTable string
+ colsSchema string
+ colsTable string
+}
+
+func (f *fakeCreateStatementDB) Connect(config connection.ConnectionConfig) error { return nil }
+func (f *fakeCreateStatementDB) Close() error { return nil }
+func (f *fakeCreateStatementDB) Ping() error { return nil }
+func (f *fakeCreateStatementDB) Query(query string) ([]map[string]interface{}, []string, error) {
+ return nil, nil, nil
+}
+func (f *fakeCreateStatementDB) Exec(query string) (int64, error) { return 0, nil }
+func (f *fakeCreateStatementDB) GetDatabases() ([]string, error) { return nil, nil }
+func (f *fakeCreateStatementDB) GetTables(dbName string) ([]string, error) { return nil, nil }
+func (f *fakeCreateStatementDB) GetCreateStatement(dbName, tableName string) (string, error) {
+ f.createSchema = dbName
+ f.createTable = tableName
+ return f.createSQL, f.createErr
+}
+func (f *fakeCreateStatementDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
+ f.colsSchema = dbName
+ f.colsTable = tableName
+ return f.columns, f.columnsErr
+}
+func (f *fakeCreateStatementDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
+ return nil, nil
+}
+func (f *fakeCreateStatementDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
+ return nil, nil
+}
+func (f *fakeCreateStatementDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
+ return nil, nil
+}
+func (f *fakeCreateStatementDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
+ return nil, nil
+}
+
+func TestResolveDDLDBType_CustomDriverAlias(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ driver string
+ want string
+ }{
+ {name: "postgresql alias", driver: "postgresql", want: "postgres"},
+ {name: "pgx alias", driver: "pgx", want: "postgres"},
+ {name: "kingbase8 alias", driver: "kingbase8", want: "kingbase"},
+ {name: "kingbase contains alias", driver: "kingbasees", want: "kingbase"},
+ {name: "dm alias", driver: "dm8", want: "dameng"},
+ {name: "sqlite alias", driver: "sqlite3", want: "sqlite"},
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ cfg := connection.ConnectionConfig{Type: "custom", Driver: tc.driver}
+ if got := resolveDDLDBType(cfg); got != tc.want {
+ t.Fatalf("resolveDDLDBType() mismatch, want=%q got=%q", tc.want, got)
+ }
+ })
+ }
+}
+
+func TestResolveCreateStatementWithFallback_CustomKingbaseUsesPublicSchema(t *testing.T) {
+ t.Parallel()
+
+ dbInst := &fakeCreateStatementDB{
+ createSQL: "SHOW CREATE TABLE not directly supported in Kingbase/Postgres via SQL",
+ columns: []connection.ColumnDefinition{
+ {Name: "id", Type: "bigint", Nullable: "NO", Key: "PRI"},
+ },
+ }
+
+ ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{
+ Type: "custom",
+ Driver: "kingbase8",
+ }, "demo_db", "orders")
+ if err != nil {
+ t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err)
+ }
+ if dbInst.createSchema != "public" || dbInst.colsSchema != "public" {
+ t.Fatalf("expected fallback schema public, got create=%q columns=%q", dbInst.createSchema, dbInst.colsSchema)
+ }
+ if !strings.Contains(ddl, `CREATE TABLE "public"."orders"`) {
+ t.Fatalf("expected fallback DDL with public schema, got: %s", ddl)
+ }
+}
+
+func TestResolveCreateStatementWithFallback_KeepQualifiedSchema(t *testing.T) {
+ t.Parallel()
+
+ dbInst := &fakeCreateStatementDB{
+ createSQL: "-- SHOW CREATE TABLE not fully supported for PostgreSQL in this MVP.",
+ columns: []connection.ColumnDefinition{
+ {Name: "id", Type: "integer", Nullable: "NO", Key: "PRI"},
+ },
+ }
+
+ ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{
+ Type: "custom",
+ Driver: "postgresql",
+ }, "demo_db", "sales.orders")
+ if err != nil {
+ t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err)
+ }
+ if dbInst.createSchema != "sales" || dbInst.colsSchema != "sales" {
+ t.Fatalf("expected schema sales, got create=%q columns=%q", dbInst.createSchema, dbInst.colsSchema)
+ }
+ if !strings.Contains(ddl, `CREATE TABLE "sales"."orders"`) {
+ t.Fatalf("expected fallback DDL with sales schema, got: %s", ddl)
+ }
+}
+
+func TestResolveCreateStatementWithFallback_NoFallbackForMySQL(t *testing.T) {
+ t.Parallel()
+
+ dbInst := &fakeCreateStatementDB{
+ createSQL: "SHOW CREATE TABLE not directly supported in Kingbase/Postgres via SQL",
+ columnsErr: errors.New("should not be called"),
+ }
+
+ ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{
+ Type: "mysql",
+ }, "demo_db", "orders")
+ if err != nil {
+ t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err)
+ }
+ if ddl != dbInst.createSQL {
+ t.Fatalf("expected original ddl for mysql, got: %s", ddl)
+ }
+ if dbInst.colsTable != "" {
+ t.Fatalf("mysql path should not call GetColumns, got table=%q", dbInst.colsTable)
+ }
+}
+
+func TestResolveCreateStatementWithFallback_FallbackWhenCreateStatementError(t *testing.T) {
+ t.Parallel()
+
+ dbInst := &fakeCreateStatementDB{
+ createErr: errors.New("statement unsupported"),
+ columns: []connection.ColumnDefinition{
+ {Name: "id", Type: "bigint", Nullable: "NO", Key: "PRI"},
+ },
+ }
+
+ ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{
+ Type: "postgres",
+ }, "demo_db", "orders")
+ if err != nil {
+ t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err)
+ }
+ if !strings.Contains(ddl, `CREATE TABLE "public"."orders"`) {
+ t.Fatalf("expected fallback DDL for postgres error path, got: %s", ddl)
+ }
+}
diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go
index 0c3a3e9..cef721a 100644
--- a/internal/app/methods_driver.go
+++ b/internal/app/methods_driver.go
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
+ "io/fs"
"net"
"net/http"
"net/url"
@@ -200,6 +201,7 @@ const (
driverBundleIndexMaxSize = 1 << 20
driverManifestMaxSize = 2 << 20
driverNetworkProbeTimeout = 4 * time.Second
+ localDriverDirectoryScanMaxEntries = 20000
driverChecksumPolicyStrict = "strict"
driverChecksumPolicyWarn = "warn"
driverChecksumPolicyOff = "off"
@@ -212,7 +214,7 @@ const builtinDriverManifestJSON = `{
"drivers": {
"mysql": { "engine": "go", "version": "1.9.3", "checksumPolicy": "off" },
"mariadb": { "engine": "go", "version": "1.9.3", "checksumPolicy": "off", "downloadUrl": "builtin://activate/mariadb" },
- "diros": { "engine": "go", "version": "1.9.3", "checksumPolicy": "off", "downloadUrl": "builtin://activate/diros" },
+ "doris": { "engine": "go", "version": "1.9.3", "checksumPolicy": "off", "downloadUrl": "builtin://activate/doris" },
"sphinx": { "engine": "go", "version": "1.9.3", "checksumPolicy": "off", "downloadUrl": "builtin://activate/sphinx" },
"sqlserver": { "engine": "go", "version": "1.9.6", "checksumPolicy": "off", "downloadUrl": "builtin://activate/sqlserver" },
"sqlite": { "engine": "go", "version": "1.44.3", "checksumPolicy": "off", "downloadUrl": "builtin://activate/sqlite" },
@@ -228,18 +230,19 @@ const builtinDriverManifestJSON = `{
}`
var (
- driverManifestCacheMu sync.RWMutex
- driverManifestCache = make(map[string]driverManifestCacheEntry)
- driverReleaseSizeMu sync.RWMutex
- driverReleaseSizeMap = make(map[string]driverReleaseAssetSizeCacheEntry)
- driverReleaseListMu sync.RWMutex
- driverReleaseList = driverManifestReleaseListCache{}
- driverModuleLatestMu sync.RWMutex
- driverModuleLatestMap = make(map[string]goModuleLatestVersionCacheEntry)
- driverModuleVersionMu sync.RWMutex
- driverModuleVersionMap = make(map[string]goModuleVersionListCacheEntry)
- driverVersionWarmupMu sync.Mutex
- driverVersionWarmup = driverVersionWarmupState{}
+ driverManifestCacheMu sync.RWMutex
+ driverManifestCache = make(map[string]driverManifestCacheEntry)
+ driverReleaseSizeMu sync.RWMutex
+ driverReleaseSizeMap = make(map[string]driverReleaseAssetSizeCacheEntry)
+ driverReleaseListMu sync.RWMutex
+ driverReleaseList = driverManifestReleaseListCache{}
+ driverModuleLatestMu sync.RWMutex
+ driverModuleLatestMap = make(map[string]goModuleLatestVersionCacheEntry)
+ driverModuleVersionMu sync.RWMutex
+ driverModuleVersionMap = make(map[string]goModuleVersionListCacheEntry)
+ driverVersionWarmupMu sync.Mutex
+ driverVersionWarmup = driverVersionWarmupState{}
+ errLocalDriverDirScanLimit = errors.New("local_driver_directory_scan_limit_exceeded")
)
type driverVersionWarmupState struct {
@@ -360,9 +363,6 @@ func (a *App) SelectDriverPackageFile(currentPath string) connection.QueryResult
selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{
Title: "选择驱动包文件",
DefaultDirectory: defaultDir,
- Filters: []runtime.FileFilter{
- {DisplayName: "所有文件", Pattern: "*"},
- },
})
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
@@ -377,6 +377,36 @@ func (a *App) SelectDriverPackageFile(currentPath string) connection.QueryResult
return connection.QueryResult{Success: true, Data: map[string]interface{}{"path": selection}}
}
+func (a *App) SelectDriverPackageDirectory(currentPath string) connection.QueryResult {
+ defaultDir := strings.TrimSpace(currentPath)
+ if defaultDir == "" {
+ defaultDir = defaultDriverDownloadDirectory()
+ }
+ if filepath.Ext(defaultDir) != "" {
+ defaultDir = filepath.Dir(defaultDir)
+ }
+ if !filepath.IsAbs(defaultDir) {
+ if abs, err := filepath.Abs(defaultDir); err == nil {
+ defaultDir = abs
+ }
+ }
+
+ selection, err := runtime.OpenDirectoryDialog(a.ctx, runtime.OpenDialogOptions{
+ Title: "选择驱动包目录",
+ DefaultDirectory: defaultDir,
+ })
+ if err != nil {
+ return connection.QueryResult{Success: false, Message: err.Error()}
+ }
+ if strings.TrimSpace(selection) == "" {
+ return connection.QueryResult{Success: false, Message: "Cancelled"}
+ }
+ if abs, err := filepath.Abs(selection); err == nil {
+ selection = abs
+ }
+ return connection.QueryResult{Success: true, Data: map[string]interface{}{"path": selection}}
+}
+
func (a *App) ResolveDriverDownloadDirectory(directory string) connection.QueryResult {
resolved, err := resolveDriverDownloadDirectory(directory)
if err != nil {
@@ -426,7 +456,7 @@ func (a *App) ResolveDriverPackageDownloadURL(driverType string, repositoryURL s
if engine == driverEngineGo && !definition.BuiltIn {
urlText := strings.TrimSpace(definition.DefaultDownloadURL)
if urlText == "" {
- urlText = fmt.Sprintf("builtin://activate/%s", definition.Type)
+ urlText = fmt.Sprintf("builtin://activate/%s", optionalDriverPublicTypeName(definition.Type))
}
data := map[string]interface{}{
"url": urlText,
@@ -500,14 +530,14 @@ func (a *App) GetDriverVersionPackageSize(driverType string, version string) con
if sizeByAsset, err := loadReleaseAssetSizesCached("tag:"+tag, func() (*githubRelease, error) {
return fetchReleaseByTag(tag)
}); err == nil {
- sizeBytes = sizeByAsset[assetName]
+ sizeBytes = resolveOptionalDriverAssetSize(sizeByAsset, normalizedType)
if sizeBytes > 0 {
sizeSource = "tag"
}
}
if sizeBytes <= 0 {
if sizeByAsset, err := loadReleaseAssetSizesCached("latest", fetchLatestReleaseForDriverAssets); err == nil {
- sizeBytes = sizeByAsset[assetName]
+ sizeBytes = resolveOptionalDriverAssetSize(sizeByAsset, normalizedType)
if sizeBytes > 0 {
sizeSource = "latest"
}
@@ -684,7 +714,7 @@ func (a *App) InstallLocalDriverPackage(driverType string, filePath string, down
a.emitDriverDownloadProgress(definition.Type, "start", 0, 100, "开始安装本地驱动包")
selectedVersion := resolveDriverInstallVersion(definition.PinnedVersion, "local://manual", definition)
- meta, installErr := installOptionalDriverAgentFromLocalFile(definition, filePath, resolvedDir, selectedVersion)
+ meta, installErr := installOptionalDriverAgentFromLocalPath(definition, filePath, resolvedDir, selectedVersion)
if installErr != nil {
errText := normalizeErrorMessage(installErr)
a.emitDriverDownloadProgress(definition.Type, "error", 0, 0, errText)
@@ -732,7 +762,7 @@ func (a *App) DownloadDriverPackage(driverType string, version string, downloadU
urlText = strings.TrimSpace(definition.DefaultDownloadURL)
}
if urlText == "" {
- urlText = fmt.Sprintf("builtin://activate/%s", definition.Type)
+ urlText = fmt.Sprintf("builtin://activate/%s", optionalDriverPublicTypeName(definition.Type))
}
selectedVersion := resolveDriverInstallVersion(version, urlText, definition)
@@ -1038,7 +1068,7 @@ func allDriverDefinitionsWithPackages(packages map[string]pinnedDriverPackage) [
// 其他数据源需要先在驱动管理中“安装启用”。
buildOptionalGoDriverDefinition("mariadb", "MariaDB", packages),
- buildOptionalGoDriverDefinition("diros", "Diros", packages),
+ buildOptionalGoDriverDefinition("diros", "Doris", packages),
buildOptionalGoDriverDefinition("sphinx", "Sphinx", packages),
buildOptionalGoDriverDefinition("sqlserver", "SQL Server", packages),
buildOptionalGoDriverDefinition("sqlite", "SQLite", packages),
@@ -1216,7 +1246,7 @@ func resolveDriverVersionOptions(definition driverDefinition, repositoryURL stri
urlText = strings.TrimSpace(definition.DefaultDownloadURL)
}
if urlText == "" && effectiveDriverEngine(definition) == driverEngineGo {
- urlText = fmt.Sprintf("builtin://activate/%s", driverType)
+ urlText = fmt.Sprintf("builtin://activate/%s", optionalDriverPublicTypeName(driverType))
}
if versionText == "" {
versionText = resolveDriverInstallVersion("", urlText, definition)
@@ -1353,7 +1383,7 @@ func resolveVersionedDriverOption(definition driverDefinition, version string, s
urlText := strings.TrimSpace(definition.DefaultDownloadURL)
if urlText == "" && effectiveDriverEngine(definition) == driverEngineGo {
- urlText = fmt.Sprintf("builtin://activate/%s", driverType)
+ urlText = fmt.Sprintf("builtin://activate/%s", optionalDriverPublicTypeName(driverType))
}
if urlText == "" {
return "", "", false
@@ -1400,13 +1430,13 @@ func resolveDriverVersionPackageSizeBytes(definition driverDefinition, option dr
tag := "v" + version
if sizeByAsset, ok := readReleaseAssetSizesFromCache("tag:" + tag); ok {
- return sizeByAsset[assetName]
+ return resolveOptionalDriverAssetSize(sizeByAsset, driverType)
}
// 下拉版本列表要求快速返回:仅复用已有缓存,不在这里触发网络请求。
if strings.EqualFold(strings.TrimSpace(option.Source), "latest") {
if sizeByAsset, ok := readReleaseAssetSizesFromCache("latest"); ok {
- return sizeByAsset[assetName]
+ return resolveOptionalDriverAssetSize(sizeByAsset, driverType)
}
}
return 0
@@ -1635,13 +1665,14 @@ func resolveDriverVersionOptionsFromReleases(definition driverDefinition) []driv
}
assetName := optionalDriverReleaseAssetName(driverType)
+ assetNames := optionalDriverReleaseAssetNames(driverType)
result := make([]driverVersionOptionItem, 0, len(releases))
for _, release := range releases {
if release.Prerelease {
continue
}
tag := strings.TrimSpace(release.TagName)
- if tag == "" || !releaseContainsAsset(release, assetName) {
+ if tag == "" || !releaseContainsAnyAsset(release, assetNames) {
continue
}
result = append(result, driverVersionOptionItem{
@@ -1718,14 +1749,24 @@ func fetchDriverReleaseList() ([]githubRelease, error) {
return releases, nil
}
-func releaseContainsAsset(release githubRelease, assetName string) bool {
- name := strings.TrimSpace(assetName)
- if name == "" {
+func releaseContainsAnyAsset(release githubRelease, assetNames []string) bool {
+ normalizedNames := make([]string, 0, len(assetNames))
+ for _, assetName := range assetNames {
+ name := strings.TrimSpace(assetName)
+ if name == "" {
+ continue
+ }
+ normalizedNames = append(normalizedNames, name)
+ }
+ if len(normalizedNames) == 0 {
return false
}
for _, asset := range release.Assets {
- if strings.EqualFold(strings.TrimSpace(asset.Name), name) {
- return true
+ assetName := strings.TrimSpace(asset.Name)
+ for _, expected := range normalizedNames {
+ if strings.EqualFold(assetName, expected) {
+ return true
+ }
}
}
return false
@@ -2194,7 +2235,7 @@ func installOptionalDriverAgentPackage(a *App, definition driverDefinition, sele
}, nil
}
-func installOptionalDriverAgentFromLocalFile(definition driverDefinition, filePath string, resolvedDir string, selectedVersion string) (installedDriverPackage, error) {
+func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePath string, resolvedDir string, selectedVersion string) (installedDriverPackage, error) {
driverType := normalizeDriverType(definition.Type)
displayName := resolveDriverDisplayName(definition)
pathText := strings.TrimSpace(filePath)
@@ -2208,9 +2249,6 @@ func installOptionalDriverAgentFromLocalFile(definition driverDefinition, filePa
if statErr != nil {
return installedDriverPackage{}, fmt.Errorf("读取本地驱动包失败:%w", statErr)
}
- if info.IsDir() {
- return installedDriverPackage{}, fmt.Errorf("本地驱动包路径为目录:%s", pathText)
- }
executablePath, err := db.ResolveOptionalDriverAgentExecutablePath(resolvedDir, driverType)
if err != nil {
@@ -2220,8 +2258,23 @@ func installOptionalDriverAgentFromLocalFile(definition driverDefinition, filePa
return installedDriverPackage{}, fmt.Errorf("创建 %s 驱动目录失败:%w", displayName, mkErr)
}
+ sourcePath := pathText
+ sourceName := filepath.Base(pathText)
downloadSource := fmt.Sprintf("local://manual/%s", filepath.Base(pathText))
- if strings.EqualFold(filepath.Ext(pathText), ".zip") {
+ if info.IsDir() {
+ matchedPath, matchedEntry, resolveErr := resolveLocalDriverAgentFromDirectory(pathText, driverType)
+ if resolveErr != nil {
+ return installedDriverPackage{}, resolveErr
+ }
+ sourcePath = matchedPath
+ sourceName = filepath.Base(matchedPath)
+ downloadSource = fmt.Sprintf("local://manual-dir/%s", filepath.Base(pathText))
+ if strings.TrimSpace(matchedEntry) != "" {
+ downloadSource = downloadSource + "#" + matchedEntry
+ }
+ }
+
+ if !info.IsDir() && strings.EqualFold(filepath.Ext(pathText), ".zip") {
entryName, extractErr := installOptionalDriverAgentFromLocalZip(pathText, definition, executablePath)
if extractErr != nil {
return installedDriverPackage{}, extractErr
@@ -2230,7 +2283,7 @@ func installOptionalDriverAgentFromLocalFile(definition driverDefinition, filePa
downloadSource = downloadSource + "#" + entryName
}
} else {
- if copyErr := copyAgentBinary(pathText, executablePath); copyErr != nil {
+ if copyErr := copyAgentBinary(sourcePath, executablePath); copyErr != nil {
return installedDriverPackage{}, fmt.Errorf("导入本地驱动代理失败:%w", copyErr)
}
}
@@ -2242,8 +2295,8 @@ func installOptionalDriverAgentFromLocalFile(definition driverDefinition, filePa
return installedDriverPackage{
DriverType: driverType,
Version: strings.TrimSpace(selectedVersion),
- FilePath: pathText,
- FileName: filepath.Base(pathText),
+ FilePath: sourcePath,
+ FileName: sourceName,
ExecutablePath: executablePath,
DownloadURL: downloadSource,
SHA256: hash,
@@ -2251,6 +2304,153 @@ func installOptionalDriverAgentFromLocalFile(definition driverDefinition, filePa
}, nil
}
+type localDriverCandidate struct {
+ absPath string
+ relativePath string
+ depth int
+ inPlatformDir bool
+}
+
+func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType string) (string, string, error) {
+ root := strings.TrimSpace(directoryPath)
+ if root == "" {
+ return "", "", fmt.Errorf("本地驱动目录路径为空")
+ }
+ if absPath, absErr := filepath.Abs(root); absErr == nil {
+ root = absPath
+ }
+ info, statErr := os.Stat(root)
+ if statErr != nil {
+ return "", "", fmt.Errorf("读取本地驱动目录失败:%w", statErr)
+ }
+ if !info.IsDir() {
+ return "", "", fmt.Errorf("本地驱动目录路径不是目录:%s", root)
+ }
+
+ normalizedType := normalizeDriverType(driverType)
+ displayDefinition, found := resolveDriverDefinition(normalizedType)
+ if !found {
+ displayDefinition = driverDefinition{Type: normalizedType, Name: normalizedType}
+ }
+ displayName := resolveDriverDisplayName(displayDefinition)
+ platformDir := optionalDriverBundlePlatformDir(stdRuntime.GOOS)
+ assetNameCandidates := optionalDriverReleaseAssetNames(normalizedType)
+ baseNameCandidates := optionalDriverExecutableBaseNames(normalizedType)
+ assetName := optionalDriverReleaseAssetName(normalizedType)
+
+ exactRelativePath := filepath.ToSlash(filepath.Join(platformDir, assetName))
+ for _, candidateName := range assetNameCandidates {
+ exactPath := filepath.Join(root, platformDir, candidateName)
+ if exactInfo, err := os.Stat(exactPath); err == nil && !exactInfo.IsDir() {
+ return exactPath, filepath.ToSlash(filepath.Join(platformDir, candidateName)), nil
+ }
+ }
+
+ for _, candidateName := range assetNameCandidates {
+ rootAssetPath := filepath.Join(root, candidateName)
+ if rootAssetInfo, err := os.Stat(rootAssetPath); err == nil && !rootAssetInfo.IsDir() {
+ return rootAssetPath, filepath.ToSlash(candidateName), nil
+ }
+ }
+
+ assetCandidates := make([]localDriverCandidate, 0, 8)
+ baseCandidates := make([]localDriverCandidate, 0, 8)
+ visited := 0
+ walkErr := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
+ if err != nil {
+ return nil
+ }
+ visited++
+ if visited > localDriverDirectoryScanMaxEntries {
+ return errLocalDriverDirScanLimit
+ }
+ if d.IsDir() {
+ return nil
+ }
+ name := strings.TrimSpace(d.Name())
+ if name == "" {
+ return nil
+ }
+
+ relative, relErr := filepath.Rel(root, path)
+ if relErr != nil {
+ relative = name
+ }
+ normalizedRelative := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(relative), "./"))
+ if normalizedRelative == "" {
+ normalizedRelative = name
+ }
+ normalizedLower := strings.ToLower(normalizedRelative)
+ platformPrefix := strings.ToLower(platformDir) + "/"
+ inPlatformDir := normalizedLower == strings.ToLower(platformDir) || strings.HasPrefix(normalizedLower, platformPrefix)
+ depth := strings.Count(normalizedRelative, "/")
+ candidate := localDriverCandidate{
+ absPath: path,
+ relativePath: normalizedRelative,
+ depth: depth,
+ inPlatformDir: inPlatformDir,
+ }
+
+ for _, candidateName := range assetNameCandidates {
+ if strings.EqualFold(name, candidateName) {
+ assetCandidates = append(assetCandidates, candidate)
+ return nil
+ }
+ }
+ for _, candidateName := range baseNameCandidates {
+ if strings.EqualFold(name, candidateName) {
+ baseCandidates = append(baseCandidates, candidate)
+ return nil
+ }
+ }
+ return nil
+ })
+ if errors.Is(walkErr, errLocalDriverDirScanLimit) {
+ return "", "", fmt.Errorf("本地驱动目录条目过多(超过 %d),请缩小目录范围或直接选择 zip/单文件", localDriverDirectoryScanMaxEntries)
+ }
+ if walkErr != nil {
+ return "", "", fmt.Errorf("扫描本地驱动目录失败:%w", walkErr)
+ }
+
+ selectBest := func(candidates []localDriverCandidate) (localDriverCandidate, bool) {
+ if len(candidates) == 0 {
+ return localDriverCandidate{}, false
+ }
+ sort.Slice(candidates, func(i, j int) bool {
+ left := candidates[i]
+ right := candidates[j]
+ if left.inPlatformDir != right.inPlatformDir {
+ return left.inPlatformDir
+ }
+ if left.depth != right.depth {
+ return left.depth < right.depth
+ }
+ leftRelative := strings.ToLower(left.relativePath)
+ rightRelative := strings.ToLower(right.relativePath)
+ if leftRelative != rightRelative {
+ return leftRelative < rightRelative
+ }
+ return strings.ToLower(left.absPath) < strings.ToLower(right.absPath)
+ })
+ return candidates[0], true
+ }
+
+ if candidate, ok := selectBest(assetCandidates); ok {
+ return candidate.absPath, candidate.relativePath, nil
+ }
+ if candidate, ok := selectBest(baseCandidates); ok {
+ return candidate.absPath, candidate.relativePath, nil
+ }
+
+ return "", "", fmt.Errorf(
+ "目录中未找到 %s 代理文件(优先路径 %s,候选文件名 %s / %s)",
+ displayName,
+ exactRelativePath,
+ strings.Join(assetNameCandidates, " | "),
+ strings.Join(baseNameCandidates, " | "),
+ )
+}
+
func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDefinition, executablePath string) (string, error) {
driverType := normalizeDriverType(definition.Type)
displayName := resolveDriverDisplayName(definition)
@@ -2261,24 +2461,31 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef
defer reader.Close()
entryPath := optionalDriverBundleEntryPath(driverType)
- expectedBaseName := optionalDriverReleaseAssetName(driverType)
+ entryPaths := optionalDriverBundleEntryPaths(driverType)
+ expectedBaseNames := optionalDriverReleaseAssetNames(driverType)
findEntry := func() *zip.File {
for _, file := range reader.File {
name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./"))
- if name == entryPath {
- return file
+ for _, expectedPath := range entryPaths {
+ if name == expectedPath {
+ return file
+ }
}
}
for _, file := range reader.File {
name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./"))
- if strings.EqualFold(name, entryPath) {
- return file
+ for _, expectedPath := range entryPaths {
+ if strings.EqualFold(name, expectedPath) {
+ return file
+ }
}
}
for _, file := range reader.File {
name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./"))
- if strings.EqualFold(filepath.Base(name), expectedBaseName) {
- return file
+ for _, expectedName := range expectedBaseNames {
+ if strings.EqualFold(filepath.Base(name), expectedName) {
+ return file
+ }
}
}
return nil
@@ -2472,24 +2679,31 @@ func downloadOptionalDriverAgentFromBundle(a *App, definition driverDefinition,
defer reader.Close()
entryPath := optionalDriverBundleEntryPath(driverType)
- expectedBaseName := optionalDriverReleaseAssetName(driverType)
+ entryPaths := optionalDriverBundleEntryPaths(driverType)
+ expectedBaseNames := optionalDriverReleaseAssetNames(driverType)
findEntry := func() *zip.File {
for _, file := range reader.File {
name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./"))
- if name == entryPath {
- return file
+ for _, expectedPath := range entryPaths {
+ if name == expectedPath {
+ return file
+ }
}
}
for _, file := range reader.File {
name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./"))
- if strings.EqualFold(name, entryPath) {
- return file
+ for _, expectedPath := range entryPaths {
+ if strings.EqualFold(name, expectedPath) {
+ return file
+ }
}
}
for _, file := range reader.File {
name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./"))
- if strings.EqualFold(filepath.Base(name), expectedBaseName) {
- return file
+ for _, expectedName := range expectedBaseNames {
+ if strings.EqualFold(filepath.Base(name), expectedName) {
+ return file
+ }
}
}
return nil
@@ -2640,22 +2854,93 @@ func fileExists(path string) bool {
return err == nil && !info.IsDir()
}
-func optionalDriverExecutableBaseName(driverType string) string {
- name := fmt.Sprintf("%s-driver-agent", normalizeDriverType(driverType))
+func optionalDriverPublicTypeName(driverType string) string {
+ switch normalizeDriverType(driverType) {
+ case "diros":
+ return "doris"
+ default:
+ return normalizeDriverType(driverType)
+ }
+}
+
+func optionalDriverExecutableBaseNameForType(typeName string) string {
+ base := strings.TrimSpace(typeName)
+ if base == "" {
+ base = "unknown"
+ }
+ name := fmt.Sprintf("%s-driver-agent", base)
if stdRuntime.GOOS == "windows" {
return name + ".exe"
}
return name
}
-func optionalDriverReleaseAssetName(driverType string) string {
- name := fmt.Sprintf("%s-driver-agent-%s-%s", normalizeDriverType(driverType), stdRuntime.GOOS, stdRuntime.GOARCH)
- if stdRuntime.GOOS == "windows" {
+func optionalDriverReleaseAssetNameForType(typeName string, goos string, goarch string) string {
+ base := strings.TrimSpace(typeName)
+ if base == "" {
+ base = "unknown"
+ }
+ name := fmt.Sprintf("%s-driver-agent-%s-%s", base, goos, goarch)
+ if strings.EqualFold(goos, "windows") {
return name + ".exe"
}
return name
}
+func optionalDriverExecutableBaseNames(driverType string) []string {
+ names := make([]string, 0, 2)
+ seen := make(map[string]struct{}, 2)
+ appendName := func(typeName string) {
+ name := optionalDriverExecutableBaseNameForType(typeName)
+ if strings.TrimSpace(name) == "" {
+ return
+ }
+ if _, ok := seen[name]; ok {
+ return
+ }
+ seen[name] = struct{}{}
+ names = append(names, name)
+ }
+
+ appendName(optionalDriverPublicTypeName(driverType))
+ return names
+}
+
+func optionalDriverReleaseAssetNames(driverType string) []string {
+ names := make([]string, 0, 2)
+ seen := make(map[string]struct{}, 2)
+ appendName := func(typeName string) {
+ name := optionalDriverReleaseAssetNameForType(typeName, stdRuntime.GOOS, stdRuntime.GOARCH)
+ if strings.TrimSpace(name) == "" {
+ return
+ }
+ if _, ok := seen[name]; ok {
+ return
+ }
+ seen[name] = struct{}{}
+ names = append(names, name)
+ }
+
+ appendName(optionalDriverPublicTypeName(driverType))
+ return names
+}
+
+func optionalDriverExecutableBaseName(driverType string) string {
+ names := optionalDriverExecutableBaseNames(driverType)
+ if len(names) == 0 {
+ return optionalDriverExecutableBaseNameForType("")
+ }
+ return names[0]
+}
+
+func optionalDriverReleaseAssetName(driverType string) string {
+ names := optionalDriverReleaseAssetNames(driverType)
+ if len(names) == 0 {
+ return optionalDriverReleaseAssetNameForType("", stdRuntime.GOOS, stdRuntime.GOARCH)
+ }
+ return names[0]
+}
+
func optionalDriverBundlePlatformDir(goos string) string {
switch strings.ToLower(strings.TrimSpace(goos)) {
case "windows":
@@ -2669,8 +2954,41 @@ func optionalDriverBundlePlatformDir(goos string) string {
}
}
+func optionalDriverBundleEntryPaths(driverType string) []string {
+ platformDir := optionalDriverBundlePlatformDir(stdRuntime.GOOS)
+ assetNames := optionalDriverReleaseAssetNames(driverType)
+ result := make([]string, 0, len(assetNames))
+ seen := make(map[string]struct{}, len(assetNames))
+ for _, assetName := range assetNames {
+ entry := filepath.ToSlash(filepath.Join(platformDir, assetName))
+ if _, ok := seen[entry]; ok {
+ continue
+ }
+ seen[entry] = struct{}{}
+ result = append(result, entry)
+ }
+ return result
+}
+
func optionalDriverBundleEntryPath(driverType string) string {
- return filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(stdRuntime.GOOS), optionalDriverReleaseAssetName(driverType)))
+ paths := optionalDriverBundleEntryPaths(driverType)
+ if len(paths) == 0 {
+ return filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(stdRuntime.GOOS), optionalDriverReleaseAssetName(driverType)))
+ }
+ return paths[0]
+}
+
+func resolveOptionalDriverAssetSize(sizeByAsset map[string]int64, driverType string) int64 {
+ if len(sizeByAsset) == 0 {
+ return 0
+ }
+ for _, assetName := range optionalDriverReleaseAssetNames(driverType) {
+ sizeBytes := sizeByAsset[assetName]
+ if sizeBytes > 0 {
+ return sizeBytes
+ }
+ }
+ return 0
}
func resolveOptionalDriverBundleDownloadURLs() []string {
@@ -2719,12 +3037,16 @@ func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL
}
}
- assetName := optionalDriverReleaseAssetName(driverType)
+ assetNames := optionalDriverReleaseAssetNames(driverType)
currentVersion := normalizeVersion(getCurrentVersion())
if currentVersion != "" && currentVersion != "0.0.0" {
- appendURL(fmt.Sprintf("https://github.com/Syngnat/GoNavi/releases/download/v%s/%s", currentVersion, assetName))
+ for _, assetName := range assetNames {
+ appendURL(fmt.Sprintf("https://github.com/Syngnat/GoNavi/releases/download/v%s/%s", currentVersion, assetName))
+ }
+ }
+ for _, assetName := range assetNames {
+ appendURL(fmt.Sprintf("https://github.com/Syngnat/GoNavi/releases/latest/download/%s", assetName))
}
- appendURL(fmt.Sprintf("https://github.com/Syngnat/GoNavi/releases/latest/download/%s", assetName))
return candidates
}
@@ -2753,8 +3075,23 @@ func findExistingOptionalDriverAgentCandidate(definition driverDefinition, targe
func resolveOptionalDriverAgentCandidatePaths(definition driverDefinition) []string {
driverType := normalizeDriverType(definition.Type)
- name := optionalDriverExecutableBaseName(driverType)
- assetName := optionalDriverReleaseAssetName(driverType)
+ names := optionalDriverExecutableBaseNames(driverType)
+ assetNames := optionalDriverReleaseAssetNames(driverType)
+ pathTypeNames := make([]string, 0, 2)
+ seenPathType := make(map[string]struct{}, 2)
+ appendPathType := func(typeName string) {
+ trimmed := strings.TrimSpace(typeName)
+ if trimmed == "" {
+ return
+ }
+ if _, ok := seenPathType[trimmed]; ok {
+ return
+ }
+ seenPathType[trimmed] = struct{}{}
+ pathTypeNames = append(pathTypeNames, trimmed)
+ }
+ appendPathType(optionalDriverPublicTypeName(driverType))
+
candidates := make([]string, 0, 12)
appendPath := func(pathText string) {
trimmed := strings.TrimSpace(pathText)
@@ -2769,18 +3106,36 @@ func resolveOptionalDriverAgentCandidatePaths(definition driverDefinition) []str
resolved = evalPath
}
exeDir := filepath.Dir(resolved)
- appendPath(filepath.Join(exeDir, name))
- appendPath(filepath.Join(exeDir, assetName))
- appendPath(filepath.Join(exeDir, "drivers", driverType, name))
- appendPath(filepath.Join(exeDir, "drivers", driverType, assetName))
+ for _, name := range names {
+ appendPath(filepath.Join(exeDir, name))
+ }
+ for _, assetName := range assetNames {
+ appendPath(filepath.Join(exeDir, assetName))
+ }
+ for _, typeName := range pathTypeNames {
+ for _, name := range names {
+ appendPath(filepath.Join(exeDir, "drivers", typeName, name))
+ }
+ for _, assetName := range assetNames {
+ appendPath(filepath.Join(exeDir, "drivers", typeName, assetName))
+ }
+ }
resourcesDir := filepath.Clean(filepath.Join(exeDir, "..", "Resources"))
- appendPath(filepath.Join(resourcesDir, "drivers", driverType, name))
- appendPath(filepath.Join(resourcesDir, "drivers", driverType, assetName))
+ for _, typeName := range pathTypeNames {
+ for _, name := range names {
+ appendPath(filepath.Join(resourcesDir, "drivers", typeName, name))
+ }
+ for _, assetName := range assetNames {
+ appendPath(filepath.Join(resourcesDir, "drivers", typeName, assetName))
+ }
+ }
}
if wd, err := os.Getwd(); err == nil && strings.TrimSpace(wd) != "" {
- appendPath(filepath.Join(wd, "dist", assetName))
- appendPath(filepath.Join(wd, assetName))
+ for _, assetName := range assetNames {
+ appendPath(filepath.Join(wd, "dist", assetName))
+ appendPath(filepath.Join(wd, assetName))
+ }
}
unique := make([]string, 0, len(candidates))
@@ -2896,8 +3251,7 @@ func preloadOptionalDriverPackageSizes(definitions []driverDefinition) map[strin
fillFromSizes := func(sizeByAsset map[string]int64, driverTypes []string) []string {
missing := make([]string, 0, len(driverTypes))
for _, driverType := range driverTypes {
- assetName := optionalDriverReleaseAssetName(driverType)
- sizeBytes := sizeByAsset[assetName]
+ sizeBytes := resolveOptionalDriverAssetSize(sizeByAsset, driverType)
if sizeBytes > 0 {
result[driverType] = sizeBytes
continue
diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go
index d80c251..561ef9b 100644
--- a/internal/app/methods_file.go
+++ b/internal/app/methods_file.go
@@ -1291,7 +1291,7 @@ func dumpTableSQL(
createSQL = ddl
}
} else {
- ddl, err := dbInst.GetCreateStatement(schemaName, pureTableName)
+ ddl, err := resolveCreateStatementWithFallback(dbInst, config, dbName, tableName)
if err != nil {
if viewDDL, ok := tryGetViewCreateStatement(dbInst, config, dbName, schemaName, pureTableName); ok {
createSQL = viewDDL
diff --git a/internal/app/methods_redis.go b/internal/app/methods_redis.go
index f356277..e88d79d 100644
--- a/internal/app/methods_redis.go
+++ b/internal/app/methods_redis.go
@@ -4,6 +4,9 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
+ "fmt"
+ "math"
+ "strconv"
"strings"
"sync"
@@ -107,14 +110,20 @@ func (a *App) RedisTestConnection(config connection.ConnectionConfig) connection
}
// RedisScanKeys scans keys matching a pattern
-func (a *App) RedisScanKeys(config connection.ConnectionConfig, pattern string, cursor uint64, count int64) connection.QueryResult {
+func (a *App) RedisScanKeys(config connection.ConnectionConfig, pattern string, cursor any, count int64) connection.QueryResult {
config.Type = "redis"
client, err := a.getRedisClient(config)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
- result, err := client.ScanKeys(pattern, cursor, count)
+ parsedCursor, err := parseRedisScanCursor(cursor)
+ if err != nil {
+ logger.Warnf("RedisScanKeys 游标解析失败,已回退到起始游标:cursor=%v err=%v", cursor, err)
+ parsedCursor = 0
+ }
+
+ result, err := client.ScanKeys(pattern, parsedCursor, count)
if err != nil {
logger.Error(err, "RedisScanKeys 扫描失败:pattern=%s", pattern)
return connection.QueryResult{Success: false, Message: err.Error()}
@@ -123,6 +132,82 @@ func (a *App) RedisScanKeys(config connection.ConnectionConfig, pattern string,
return connection.QueryResult{Success: true, Data: result}
}
+func parseRedisScanCursor(cursor any) (uint64, error) {
+ switch v := cursor.(type) {
+ case nil:
+ return 0, nil
+ case uint64:
+ return v, nil
+ case uint32:
+ return uint64(v), nil
+ case uint16:
+ return uint64(v), nil
+ case uint8:
+ return uint64(v), nil
+ case uint:
+ return uint64(v), nil
+ case int64:
+ if v < 0 {
+ return 0, fmt.Errorf("游标不能为负数: %d", v)
+ }
+ return uint64(v), nil
+ case int32:
+ if v < 0 {
+ return 0, fmt.Errorf("游标不能为负数: %d", v)
+ }
+ return uint64(v), nil
+ case int16:
+ if v < 0 {
+ return 0, fmt.Errorf("游标不能为负数: %d", v)
+ }
+ return uint64(v), nil
+ case int8:
+ if v < 0 {
+ return 0, fmt.Errorf("游标不能为负数: %d", v)
+ }
+ return uint64(v), nil
+ case int:
+ if v < 0 {
+ return 0, fmt.Errorf("游标不能为负数: %d", v)
+ }
+ return uint64(v), nil
+ case float64:
+ return parseRedisScanCursorFromFloat(v)
+ case float32:
+ return parseRedisScanCursorFromFloat(float64(v))
+ case json.Number:
+ return parseRedisScanCursor(strings.TrimSpace(v.String()))
+ case string:
+ trimmed := strings.TrimSpace(v)
+ if trimmed == "" {
+ return 0, nil
+ }
+ parsed, err := strconv.ParseUint(trimmed, 10, 64)
+ if err != nil {
+ return 0, fmt.Errorf("无效游标: %q", v)
+ }
+ return parsed, nil
+ default:
+ return 0, fmt.Errorf("不支持的游标类型: %T", cursor)
+ }
+}
+
+func parseRedisScanCursorFromFloat(value float64) (uint64, error) {
+ if math.IsNaN(value) || math.IsInf(value, 0) {
+ return 0, fmt.Errorf("无效浮点游标: %v", value)
+ }
+ if value < 0 {
+ return 0, fmt.Errorf("游标不能为负数: %v", value)
+ }
+ if math.Trunc(value) != value {
+ return 0, fmt.Errorf("游标必须为整数: %v", value)
+ }
+ if value > float64(math.MaxUint64) {
+ return 0, fmt.Errorf("游标超出范围: %v", value)
+ }
+ return uint64(value), nil
+}
+
// RedisGetValue gets the value of a key
func (a *App) RedisGetValue(config connection.ConnectionConfig, key string) connection.QueryResult {
config.Type = "redis"
diff --git a/internal/app/methods_redis_cursor_test.go b/internal/app/methods_redis_cursor_test.go
new file mode 100644
index 0000000..e121d8f
--- /dev/null
+++ b/internal/app/methods_redis_cursor_test.go
@@ -0,0 +1,50 @@
+package app
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+func TestParseRedisScanCursor(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ input any
+ want uint64
+ wantErr bool
+ }{
+ {name: "nil defaults to zero", input: nil, want: 0},
+ {name: "empty string defaults to zero", input: " ", want: 0},
+ {name: "string cursor", input: "123", want: 123},
+ {name: "uint64 cursor", input: uint64(456), want: 456},
+ {name: "int cursor", input: int(789), want: 789},
+ {name: "float cursor", input: float64(42), want: 42},
+ {name: "json number cursor", input: json.Number("88"), want: 88},
+ {name: "negative int rejected", input: -1, wantErr: true},
+ {name: "fraction float rejected", input: float64(1.5), wantErr: true},
+ {name: "invalid string rejected", input: "abc", wantErr: true},
+ {name: "unsupported type rejected", input: true, wantErr: true},
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := parseRedisScanCursor(tc.input)
+ if tc.wantErr {
+ if err == nil {
+ t.Fatalf("expected error, got nil (value=%d)", got)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != tc.want {
+ t.Fatalf("parseRedisScanCursor() mismatch, want=%d got=%d", tc.want, got)
+ }
+ })
+ }
+}
diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go
index 4ba1c85..b20a359 100644
--- a/internal/db/clickhouse_impl.go
+++ b/internal/db/clickhouse_impl.go
@@ -17,7 +17,7 @@ import (
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
- _ "github.com/ClickHouse/clickhouse-go/v2"
+ clickhouse "github.com/ClickHouse/clickhouse-go/v2"
)
const (
@@ -100,25 +100,20 @@ func applyClickHouseURI(config connection.ConnectionConfig) connection.Connectio
return config
}
-func (c *ClickHouseDB) getDSN(config connection.ConnectionConfig) string {
- u := &url.URL{
- Scheme: "clickhouse",
- Host: net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
- Path: "/" + strings.TrimPrefix(strings.TrimSpace(config.Database), "/"),
+func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) *clickhouse.Options {
+ timeout := getConnectTimeout(config)
+ return &clickhouse.Options{
+ Addr: []string{
+ net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
+ },
+ Auth: clickhouse.Auth{
+ Database: strings.TrimSpace(config.Database),
+ Username: strings.TrimSpace(config.User),
+ Password: config.Password,
+ },
+ DialTimeout: timeout,
+ ReadTimeout: timeout,
}
- if strings.TrimSpace(config.Password) != "" {
- u.User = url.UserPassword(strings.TrimSpace(config.User), config.Password)
- } else {
- u.User = url.User(strings.TrimSpace(config.User))
- }
-
- timeoutSeconds := getConnectTimeoutSeconds(config)
- query := u.Query()
- query.Set("dial_timeout", fmt.Sprintf("%ds", timeoutSeconds))
- query.Set("read_timeout", fmt.Sprintf("%ds", timeoutSeconds))
- query.Set("write_timeout", fmt.Sprintf("%ds", timeoutSeconds))
- u.RawQuery = query.Encode()
- return u.String()
}
func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
@@ -165,11 +160,7 @@ func (c *ClickHouseDB) Connect(config connection.ConnectionConfig) error {
logger.Infof("ClickHouse 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
- dbConn, err := sql.Open("clickhouse", c.getDSN(runConfig))
- if err != nil {
- return fmt.Errorf("打开数据库连接失败:%w", err)
- }
- c.conn = dbConn
+ c.conn = clickhouse.OpenDB(c.buildClickHouseOptions(runConfig))
if err := c.Ping(); err != nil {
_ = c.Close()
diff --git a/internal/db/diros_impl.go b/internal/db/diros_impl.go
index 30eb116..38ac270 100644
--- a/internal/db/diros_impl.go
+++ b/internal/db/diros_impl.go
@@ -21,7 +21,7 @@ const (
defaultDirosPort = 9030
)
-// DirosDB 使用独立 driver 名称(diros)接入,底层协议兼容 MySQL。
+// DirosDB 使用独立 driver 名称(diros)接入,底层协议兼容 MySQL(对外显示为 Doris)。
type DirosDB struct {
MySQLDB
}
@@ -146,7 +146,7 @@ func (d *DirosDB) getDSN(config connection.ConnectionConfig) string {
protocol = netName
address = normalizeMySQLAddress(config.Host, config.Port)
} else {
- logger.Warnf("注册 Diros SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err)
+ logger.Warnf("注册 Doris SSH 网络失败,将尝试直连:地址=%s:%d 用户=%s,原因:%v", config.Host, config.Port, config.User, err)
}
}
@@ -177,7 +177,7 @@ func (d *DirosDB) Connect(config connection.ConnectionConfig) error {
runConfig := applyDirosURI(config)
addresses := collectDirosAddresses(runConfig)
if len(addresses) == 0 {
- return fmt.Errorf("连接建立后验证失败:未找到可用的 Diros 地址")
+ return fmt.Errorf("连接建立后验证失败:未找到可用的 Doris 地址")
}
var errorDetails []string
@@ -214,7 +214,7 @@ func (d *DirosDB) Connect(config connection.ConnectionConfig) error {
}
if len(errorDetails) == 0 {
- return fmt.Errorf("连接建立后验证失败:未找到可用的 Diros 地址")
+ return fmt.Errorf("连接建立后验证失败:未找到可用的 Doris 地址")
}
return fmt.Errorf("连接建立后验证失败:%s", strings.Join(errorDetails, ";"))
}
diff --git a/internal/db/driver_support.go b/internal/db/driver_support.go
index 4ffe820..517a81a 100644
--- a/internal/db/driver_support.go
+++ b/internal/db/driver_support.go
@@ -61,7 +61,7 @@ func driverDisplayName(driverType string) string {
case "mariadb":
return "MariaDB"
case "diros":
- return "Diros"
+ return "Doris"
case "sphinx":
return "Sphinx"
case "postgres":
diff --git a/internal/db/dsn_test.go b/internal/db/dsn_test.go
index f3d9392..87ec9f6 100644
--- a/internal/db/dsn_test.go
+++ b/internal/db/dsn_test.go
@@ -5,6 +5,7 @@ package db
import (
"strings"
"testing"
+ "time"
"GoNavi-Wails/internal/connection"
)
@@ -115,7 +116,7 @@ func TestTDengineDSN_UsesWebSocketFormat(t *testing.T) {
}
}
-func TestClickHouseDSN_EscapesPasswordAndSetsTimeout(t *testing.T) {
+func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
c := &ClickHouseDB{}
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
Type: "clickhouse",
@@ -127,17 +128,35 @@ func TestClickHouseDSN_EscapesPasswordAndSetsTimeout(t *testing.T) {
Timeout: 15,
})
- dsn := c.getDSN(cfg)
- if strings.Contains(dsn, cfg.Password) {
- t.Fatalf("dsn 包含原始密码:%s", dsn)
+ opts := c.buildClickHouseOptions(cfg)
+ if opts == nil {
+ t.Fatal("options 为空")
}
- if !strings.Contains(dsn, "p%40ss%3Awo%2Frd") {
- t.Fatalf("dsn 未正确转义密码:%s", dsn)
+ if len(opts.Addr) != 1 || opts.Addr[0] != "127.0.0.1:9000" {
+ t.Fatalf("addr 不符合预期:%v", opts.Addr)
}
- if !strings.Contains(dsn, "dial_timeout=15s") {
- t.Fatalf("dsn 缺少 dial_timeout 参数:%s", dsn)
+ if opts.Auth.Username != "default" {
+ t.Fatalf("username 不符合预期:%s", opts.Auth.Username)
}
- if !strings.Contains(dsn, "/analytics") {
- t.Fatalf("dsn 缺少数据库路径:%s", dsn)
+ if opts.Auth.Password != cfg.Password {
+ t.Fatalf("password 不符合预期:%s", opts.Auth.Password)
+ }
+ if opts.Auth.Database != "analytics" {
+ t.Fatalf("database 不符合预期:%s", opts.Auth.Database)
+ }
+ if opts.DialTimeout != 15*time.Second {
+ t.Fatalf("dial timeout 不符合预期:%s", opts.DialTimeout)
+ }
+ if opts.ReadTimeout != 15*time.Second {
+ t.Fatalf("read timeout 不符合预期:%s", opts.ReadTimeout)
+ }
+ if _, ok := opts.Settings["write_timeout"]; ok {
+ t.Fatalf("options 不应包含 write_timeout 设置:%v", opts.Settings)
+ }
+ if _, ok := opts.Settings["read_timeout"]; ok {
+ t.Fatalf("options 不应通过 settings 传递 read_timeout:%v", opts.Settings)
+ }
+ if _, ok := opts.Settings["dial_timeout"]; ok {
+ t.Fatalf("options 不应通过 settings 传递 dial_timeout:%v", opts.Settings)
}
}
diff --git a/internal/redis/redis.go b/internal/redis/redis.go
index 7e0416d..80e58f6 100644
--- a/internal/redis/redis.go
+++ b/internal/redis/redis.go
@@ -26,7 +26,7 @@ type RedisKeyInfo struct {
// RedisScanResult represents the result of a SCAN operation
type RedisScanResult struct {
Keys []RedisKeyInfo `json:"keys"`
- Cursor uint64 `json:"cursor"`
+ Cursor string `json:"cursor"`
}
// RedisClient defines the interface for Redis operations
diff --git a/internal/redis/redis_impl.go b/internal/redis/redis_impl.go
index 50df382..044f16d 100644
--- a/internal/redis/redis_impl.go
+++ b/internal/redis/redis_impl.go
@@ -175,7 +175,7 @@ func (r *RedisClientImpl) ScanKeys(pattern string, cursor uint64, count int64) (
return &RedisScanResult{
Keys: r.loadRedisKeyInfos(ctx, keys),
- Cursor: currentCursor,
+ Cursor: strconv.FormatUint(currentCursor, 10),
}, nil
}