diff --git a/build-driver-agents.sh b/build-driver-agents.sh new file mode 100755 index 0000000..e3734d2 --- /dev/null +++ b/build-driver-agents.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +DEFAULT_DRIVERS=(mariadb doris sphinx sqlserver sqlite duckdb dameng kingbase highgo vastbase mongodb tdengine clickhouse) + +usage() { + cat <<'EOF' +用法: + ./build-driver-agents.sh [选项] + +选项: + --drivers <列表> 指定驱动列表(逗号分隔),例如:kingbase,mongodb + --platform + 目标平台,默认使用当前 Go 环境(go env GOOS/GOARCH) + --out-dir <目录> 输出目录根路径,默认:dist/driver-agents + --bundle-name <文件名> 驱动总包 zip 名称,默认:GoNavi-DriverAgents.zip + --strict 任一驱动构建失败即中断(默认失败后继续,最后汇总) + -h, --help 显示帮助 + +示例: + ./build-driver-agents.sh + ./build-driver-agents.sh --drivers kingbase + ./build-driver-agents.sh --platform windows/amd64 --drivers kingbase,mongodb +EOF +} + +normalize_driver() { + local name + name="$(echo "${1:-}" | tr '[:upper:]' '[:lower:]' | xargs)" + case "$name" in + doris|diros) echo "doris" ;; + mariadb|sphinx|sqlserver|sqlite|duckdb|dameng|kingbase|highgo|vastbase|mongodb|tdengine|clickhouse) + echo "$name" + ;; + *) + return 1 + ;; + esac +} + +build_driver_name() { + case "$1" in + doris) echo "diros" ;; + *) echo "$1" ;; + esac +} + +platform_dir_name() { + case "$1" in + windows) echo "Windows" ;; + darwin) echo "MacOS" ;; + linux) echo "Linux" ;; + *) echo "Unknown" ;; + esac +} + +driver_csv="" +target_platform="" +out_root="dist/driver-agents" +bundle_name="GoNavi-DriverAgents.zip" +strict_mode="false" + +while [[ $# -gt 0 ]]; do + case "$1" in + --drivers) + driver_csv="${2:-}" + shift 2 + ;; + --platform) + target_platform="${2:-}" + shift 2 + ;; + --out-dir) + out_root="${2:-}" + shift 2 + ;; + --bundle-name) + bundle_name="${2:-}" + shift 2 + ;; + --strict) + strict_mode="true" + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "❌ 未知参数:$1" + usage + exit 1 + ;; + esac +done + +if ! command -v go >/dev/null 2>&1; then + echo "❌ 未找到 Go,请先安装 Go 并确保 go 在 PATH 中。" + exit 1 +fi + +if [[ -z "$target_platform" ]]; then + target_platform="$(go env GOOS)/$(go env GOARCH)" +fi + +if [[ "$target_platform" != */* ]]; then + echo "❌ --platform 参数格式错误,应为 GOOS/GOARCH,例如 darwin/arm64" + exit 1 +fi + +goos="${target_platform%%/*}" +goarch="${target_platform##*/}" +platform_key="${goos}-${goarch}" +platform_dir="$(platform_dir_name "$goos")" + +declare -a drivers=() +if [[ -n "$driver_csv" ]]; then + IFS=',' read -r -a raw_drivers <<<"$driver_csv" + for item in "${raw_drivers[@]}"; do + normalized="$(normalize_driver "$item")" || { + echo "❌ 不支持的驱动:$item" + exit 1 + } + drivers+=("$normalized") + done +else + drivers=("${DEFAULT_DRIVERS[@]}") +fi + +output_dir="${out_root%/}/${platform_key}" +bundle_stage_dir="$(mktemp -d "${TMPDIR:-/tmp}/gonavi-driver-bundle.XXXXXX")" +bundle_platform_dir="$bundle_stage_dir/$platform_dir" + +cleanup() { + rm -rf "$bundle_stage_dir" +} +trap cleanup EXIT + +mkdir -p "$output_dir" "$bundle_platform_dir" +output_dir_abs="$(cd "$output_dir" && pwd)" +bundle_zip_path="$output_dir_abs/$bundle_name" + +declare -a built_assets=() +declare -a failed_drivers=() +declare -a skipped_drivers=() + +echo "🚀 开始构建 optional-driver-agent" +echo " 平台:$goos/$goarch" +echo " 输出目录:$output_dir_abs" +echo " 驱动列表:${drivers[*]}" + +for driver in "${drivers[@]}"; do + if [[ "$driver" == "duckdb" && "$goos" == "windows" && "$goarch" != "amd64" ]]; then + echo "⚠️ 跳过 duckdb(仅支持 windows/amd64)" + skipped_drivers+=("$driver") + continue + fi + + build_driver="$(build_driver_name "$driver")" + tag="gonavi_${build_driver}_driver" + asset_name="${driver}-driver-agent-${goos}-${goarch}" + if [[ "$goos" == "windows" ]]; then + asset_name="${asset_name}.exe" + fi + output_path="$output_dir_abs/$asset_name" + + cgo_enabled=0 + if [[ "$driver" == "duckdb" ]]; then + cgo_enabled=1 + fi + + echo "🔧 构建 $driver -> $asset_name (tag=$tag, CGO_ENABLED=$cgo_enabled)" + set +e + CGO_ENABLED="$cgo_enabled" GOOS="$goos" GOARCH="$goarch" GOTOOLCHAIN=auto \ + go build -tags "$tag" -trimpath -ldflags "-s -w" -o "$output_path" ./cmd/optional-driver-agent + build_exit=$? + set -e + + if [[ $build_exit -ne 0 ]]; then + echo "❌ 构建失败:$driver" + failed_drivers+=("$driver") + if [[ "$strict_mode" == "true" ]]; then + exit $build_exit + fi + continue + fi + + cp "$output_path" "$bundle_platform_dir/$asset_name" + built_assets+=("$asset_name") +done + +if [[ ${#built_assets[@]} -eq 0 ]]; then + echo "❌ 未成功构建任何驱动代理。" + exit 1 +fi + +rm -f "$bundle_zip_path" +if command -v zip >/dev/null 2>&1; then + ( + cd "$bundle_stage_dir" + zip -qry "$bundle_zip_path" "$platform_dir" + ) +elif command -v ditto >/dev/null 2>&1; then + ( + cd "$bundle_stage_dir" + ditto -c -k --sequesterRsrc --keepParent "$platform_dir" "$bundle_zip_path" + ) +else + echo "❌ 未找到 zip/ditto,无法生成驱动总包 zip。" + exit 1 +fi + +echo "" +echo "✅ 构建完成" +echo " 单文件输出目录:$output_dir_abs" +echo " 驱动总包:$bundle_zip_path" +echo " 已构建:${built_assets[*]}" +if [[ ${#skipped_drivers[@]} -gt 0 ]]; then + echo " 已跳过:${skipped_drivers[*]}" +fi +if [[ ${#failed_drivers[@]} -gt 0 ]]; then + echo "⚠️ 构建失败驱动:${failed_drivers[*]}" + exit 2 +fi diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index ce874a9..1f9d9b5 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -1049,6 +1049,12 @@ const ConnectionModal: React.FC<{ useEffect(() => { if (open) { + setLoading(false); + testInFlightRef.current = false; + if (testTimerRef.current !== null) { + window.clearTimeout(testTimerRef.current); + testTimerRef.current = null; + } setTestResult(null); // Reset test result setTestErrorLogOpen(false); setDbList([]); @@ -1240,6 +1246,22 @@ const ConnectionModal: React.FC<{ }, 0); }; + const withClientTimeout = async (promise: Promise, timeoutMs: number, timeoutMessage: string): Promise => { + let timer: number | null = null; + try { + return await Promise.race([ + promise, + new Promise((_, reject) => { + timer = window.setTimeout(() => reject(new Error(timeoutMessage)), timeoutMs); + }), + ]); + } finally { + if (timer !== null) { + window.clearTimeout(timer); + } + } + }; + const buildTestFailureMessage = (reason: unknown, fallback: string) => { const text = String(reason ?? '').trim(); const normalized = text && text !== 'undefined' && text !== 'null' ? text : fallback; @@ -1262,12 +1284,21 @@ const ConnectionModal: React.FC<{ setLoading(true); setTestResult(null); const config = await buildConfig(values, false); + const timeoutSecondsRaw = Number(values.timeout); + const timeoutSeconds = Number.isFinite(timeoutSecondsRaw) && timeoutSecondsRaw > 0 + ? Math.min(timeoutSecondsRaw, MAX_TIMEOUT_SECONDS) + : 30; + const rpcTimeoutMs = (timeoutSeconds + 5) * 1000; // Use different API for Redis const isRedisType = values.type === 'redis'; - const res = isRedisType - ? await RedisConnect(config as any) - : await TestConnection(config as any); + const res = await withClientTimeout( + isRedisType + ? RedisConnect(config as any) + : TestConnection(config as any), + rpcTimeoutMs, + `连接测试超时(>${timeoutSeconds} 秒),请检查网络/代理/SSH配置后重试` + ); if (res.success) { setTestResult({ type: 'success', message: res.message }); @@ -1275,7 +1306,11 @@ const ConnectionModal: React.FC<{ setRedisDbList(Array.from({ length: 16 }, (_, i) => i)); } else { // Other databases: fetch database list - const dbRes = await DBGetDatabases(config as any); + const dbRes = await withClientTimeout( + DBGetDatabases(config as any), + rpcTimeoutMs, + `连接成功但拉取数据库列表超时(>${timeoutSeconds} 秒)` + ); if (dbRes.success) { const dbRows = Array.isArray(dbRes.data) ? dbRes.data : []; const dbs = dbRows @@ -1572,12 +1607,13 @@ const ConnectionModal: React.FC<{ }; }; - const handleTypeSelect = async (type: string) => { - const unavailableReason = await resolveDriverUnavailableReason(type); - if (unavailableReason) { - const normalized = normalizeDriverType(type); - const driverName = driverStatusMap[normalized]?.name || type; - setTypeSelectWarning({ driverName, reason: unavailableReason }); + const handleTypeSelect = (type: string) => { + const normalized = normalizeDriverType(type); + const snapshot = driverStatusMap[normalized]; + if (snapshot && !snapshot.connectable) { + const driverName = snapshot.name || type; + const reason = snapshot.message || `${driverName} 驱动未安装启用,请先在驱动管理中安装`; + setTypeSelectWarning({ driverName, reason }); return; } setTypeSelectWarning(null); @@ -1679,6 +1715,10 @@ const ConnectionModal: React.FC<{ setMongoMembers([]); setStep(2); + + if (!driverStatusLoaded || !snapshot) { + void refreshDriverStatus(); + } }; const isFileDb = isFileDatabaseType(dbType); @@ -1851,7 +1891,6 @@ const ConnectionModal: React.FC<{ > {isFileDb ? ( diff --git a/internal/app/app.go b/internal/app/app.go index 0709a27..4a0aff9 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "net" + "net/url" + "os" "strings" "sync" "time" @@ -218,6 +220,7 @@ func wrapConnectError(config connection.ConnectionConfig, err error) error { if err == nil { return nil } + err = sanitizeMongoConnectErrorLabel(config, err) var netErr net.Error if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) { @@ -231,6 +234,73 @@ func wrapConnectError(config connection.ConnectionConfig, err error) error { return withLogHint{err: err, logPath: logger.Path()} } +type errorMessageOverride struct { + message string + cause error +} + +func (e errorMessageOverride) Error() string { + return e.message +} + +func (e errorMessageOverride) Unwrap() error { + return e.cause +} + +func sanitizeMongoConnectErrorLabel(config connection.ConnectionConfig, err error) error { + if err == nil { + return nil + } + if strings.ToLower(strings.TrimSpace(config.Type)) != "mongodb" { + return err + } + if mongoConnectUsesTLS(config) { + return err + } + original := err.Error() + rewritten := strings.ReplaceAll(original, "SSL 主库凭据", "主库凭据") + rewritten = strings.ReplaceAll(rewritten, "SSL 从库凭据", "从库凭据") + if rewritten == original { + return err + } + return errorMessageOverride{ + message: rewritten, + cause: err, + } +} + +func mongoConnectUsesTLS(config connection.ConnectionConfig) bool { + if config.UseSSL { + return true + } + uriText := strings.TrimSpace(config.URI) + if uriText == "" { + return false + } + parsed, err := url.Parse(uriText) + if err != nil { + return false + } + for _, key := range []string{"tls", "ssl"} { + if enabled, known := parseMongoBool(parsed.Query().Get(key)); known { + return enabled + } + } + return strings.EqualFold(strings.TrimSpace(parsed.Scheme), "mongodb+srv") +} + +func parseMongoBool(raw string) (enabled bool, known bool) { + value := strings.ToLower(strings.TrimSpace(raw)) + switch value { + case "1", "true", "t", "yes", "y", "on", "required": + return true, true + case "0", "false", "f", "no", "n", "off", "disable", "disabled": + return false, true + default: + return false, false + } +} + type withLogHint struct { err error logPath string @@ -238,10 +308,15 @@ type withLogHint struct { func (e withLogHint) Error() string { message := normalizeErrorMessage(e.err) - if strings.TrimSpace(e.logPath) == "" { + path := strings.TrimSpace(e.logPath) + if path == "" { return message } - return fmt.Sprintf("%s(详细日志:%s)", message, e.logPath) + info, statErr := os.Stat(path) + if statErr != nil || info.IsDir() || info.Size() <= 0 { + return message + } + return fmt.Sprintf("%s(详细日志:%s)", message, path) } func (e withLogHint) Unwrap() error { diff --git a/internal/app/app_connect_error_test.go b/internal/app/app_connect_error_test.go new file mode 100644 index 0000000..36bb99e --- /dev/null +++ b/internal/app/app_connect_error_test.go @@ -0,0 +1,84 @@ +package app + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestWrapConnectError_MongoNoSSL_RemovesMisleadingSSLLabel(t *testing.T) { + config := connection.ConnectionConfig{ + Type: "mongodb", + UseSSL: false, + } + sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error") + + wrapped := wrapConnectError(config, sourceErr) + text := wrapped.Error() + if strings.Contains(text, "SSL 主库凭据") { + t.Fatalf("expected ssl label to be removed when TLS disabled, got: %s", text) + } + if !strings.Contains(text, "主库凭据验证失败") { + t.Fatalf("expected auth label to remain, got: %s", text) + } +} + +func TestWrapConnectError_MongoURIForcesTLS_KeepsSSLLabel(t *testing.T) { + config := connection.ConnectionConfig{ + Type: "mongodb", + UseSSL: false, + URI: "mongodb://user:pass@127.0.0.1:27017/admin?tls=true", + } + sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error") + + wrapped := wrapConnectError(config, sourceErr) + text := wrapped.Error() + if !strings.Contains(text, "SSL 主库凭据") { + t.Fatalf("expected ssl label to remain when URI enables TLS, got: %s", text) + } +} + +func TestWrapConnectError_MongoSRVDefaultTLS_KeepsSSLLabel(t *testing.T) { + config := connection.ConnectionConfig{ + Type: "mongodb", + UseSSL: false, + URI: "mongodb+srv://user:pass@cluster0.example.com/admin", + } + sourceErr := errors.New("MongoDB 连接失败:SSL 主库凭据验证失败: mock error") + + wrapped := wrapConnectError(config, sourceErr) + text := wrapped.Error() + if !strings.Contains(text, "SSL 主库凭据") { + t.Fatalf("expected ssl label to remain for mongodb+srv default TLS, got: %s", text) + } +} + +func TestWithLogHintError_OmitEmptyLogPath(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "gonavi.log") + if err := os.WriteFile(logPath, nil, 0o644); err != nil { + t.Fatalf("write empty log failed: %v", err) + } + err := withLogHint{err: errors.New("连接失败"), logPath: logPath} + text := err.Error() + if strings.Contains(text, "详细日志:") { + t.Fatalf("expected no log hint for empty file, got: %s", text) + } +} + +func TestWithLogHintError_IncludeNonEmptyLogPath(t *testing.T) { + dir := t.TempDir() + logPath := filepath.Join(dir, "gonavi.log") + if err := os.WriteFile(logPath, []byte("log entry\n"), 0o644); err != nil { + t.Fatalf("write log failed: %v", err) + } + err := withLogHint{err: errors.New("连接失败"), logPath: logPath} + text := err.Error() + if !strings.Contains(text, "详细日志:"+logPath) { + t.Fatalf("expected log hint with path, got: %s", text) + } +} diff --git a/internal/app/db_proxy.go b/internal/app/db_proxy.go index e3228b6..14af069 100644 --- a/internal/app/db_proxy.go +++ b/internal/app/db_proxy.go @@ -73,8 +73,8 @@ func resolveDialConfigWithProxy(raw connection.ConnectionConfig) (connection.Con // 文件型/自定义 DSN 类型不走标准 host:port,不在此层改写。 return config, nil } - if normalizedType == "mongodb" && config.MongoSRV { - // Mongo SRV 由驱动侧 Dialer 处理代理,避免破坏 DNS SRV 拓扑发现。 + if normalizedType == "mongodb" { + // MongoDB 统一由驱动侧 Dialer 处理代理,保留原始目标地址,避免将连接目标改写为本地转发地址。 return config, nil } diff --git a/internal/app/db_proxy_test.go b/internal/app/db_proxy_test.go new file mode 100644 index 0000000..5d44170 --- /dev/null +++ b/internal/app/db_proxy_test.go @@ -0,0 +1,64 @@ +package app + +import ( + "reflect" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestResolveDialConfigWithProxy_MongoKeepsTargetAddress(t *testing.T) { + hosts := []string{"10.20.30.40:27017", "10.20.30.41:27017"} + raw := connection.ConnectionConfig{ + Type: "mongodb", + Host: "10.20.30.40", + Port: 27017, + UseProxy: true, + Proxy: connection.ProxyConfig{ + Type: "socks5", + Host: "127.0.0.1", + Port: 1080, + }, + Hosts: hosts, + } + + got, err := resolveDialConfigWithProxy(raw) + if err != nil { + t.Fatalf("resolveDialConfigWithProxy returned error: %v", err) + } + if got.Host != raw.Host || got.Port != raw.Port { + t.Fatalf("mongo target address should be kept, got=%s:%d want=%s:%d", got.Host, got.Port, raw.Host, raw.Port) + } + if !got.UseProxy { + t.Fatalf("mongo should keep UseProxy=true for driver-level dialer") + } + if !reflect.DeepEqual(got.Hosts, hosts) { + t.Fatalf("mongo hosts should be kept, got=%v want=%v", got.Hosts, hosts) + } +} + +func TestResolveDialConfigWithProxy_MongoSRVKeepsTargetAddress(t *testing.T) { + raw := connection.ConnectionConfig{ + Type: "mongodb", + Host: "cluster0.example.com", + Port: 27017, + MongoSRV: true, + UseProxy: true, + Proxy: connection.ProxyConfig{ + Type: "http", + Host: "127.0.0.1", + Port: 7890, + }, + } + + got, err := resolveDialConfigWithProxy(raw) + if err != nil { + t.Fatalf("resolveDialConfigWithProxy returned error: %v", err) + } + if got.Host != raw.Host || got.Port != raw.Port { + t.Fatalf("mongo SRV target address should be kept, got=%s:%d want=%s:%d", got.Host, got.Port, raw.Host, raw.Port) + } + if !got.UseProxy { + t.Fatalf("mongo SRV should keep UseProxy=true for driver-level dialer") + } +} diff --git a/internal/app/global_proxy.go b/internal/app/global_proxy.go index 4361782..016fb26 100644 --- a/internal/app/global_proxy.go +++ b/internal/app/global_proxy.go @@ -72,25 +72,30 @@ func setGlobalProxyConfig(enabled bool, proxyConfig connection.ProxyConfig) (glo } func (a *App) ConfigureGlobalProxy(enabled bool, proxyConfig connection.ProxyConfig) connection.QueryResult { + before := currentGlobalProxyConfig() snapshot, err := setGlobalProxyConfig(enabled, proxyConfig) if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} } - if snapshot.Enabled { - authState := "" - if strings.TrimSpace(snapshot.Proxy.User) != "" { - authState = "(认证:已配置)" + // 前端可能在同一配置下重复触发同步(例如严格模式或状态回放), + // 这里做幂等日志,避免重复刷屏。 + if !globalProxySnapshotEqual(before, snapshot) { + if snapshot.Enabled { + authState := "" + if strings.TrimSpace(snapshot.Proxy.User) != "" { + authState = "(认证:已配置)" + } + logger.Infof( + "全局代理已启用:%s://%s:%d%s", + strings.ToLower(strings.TrimSpace(snapshot.Proxy.Type)), + strings.TrimSpace(snapshot.Proxy.Host), + snapshot.Proxy.Port, + authState, + ) + } else { + logger.Infof("全局代理已关闭") } - logger.Infof( - "全局代理已启用:%s://%s:%d%s", - strings.ToLower(strings.TrimSpace(snapshot.Proxy.Type)), - strings.TrimSpace(snapshot.Proxy.Host), - snapshot.Proxy.Port, - authState, - ) - } else { - logger.Infof("全局代理已关闭") } return connection.QueryResult{ @@ -100,6 +105,24 @@ func (a *App) ConfigureGlobalProxy(enabled bool, proxyConfig connection.ProxyCon } } +func globalProxySnapshotEqual(a, b globalProxySnapshot) bool { + if a.Enabled != b.Enabled { + return false + } + if !a.Enabled { + return true + } + return proxyConfigEqual(a.Proxy, b.Proxy) +} + +func proxyConfigEqual(a, b connection.ProxyConfig) bool { + return strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) && + strings.TrimSpace(a.Host) == strings.TrimSpace(b.Host) && + a.Port == b.Port && + strings.TrimSpace(a.User) == strings.TrimSpace(b.User) && + a.Password == b.Password +} + func (a *App) GetGlobalProxyConfig() connection.QueryResult { return connection.QueryResult{ Success: true, diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index b28109f..f411653 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -13,6 +13,16 @@ import ( "GoNavi-Wails/internal/utils" ) +const testConnectionTimeoutUpperBoundSeconds = 12 + +func normalizeTestConnectionConfig(config connection.ConnectionConfig) connection.ConnectionConfig { + normalized := config + if normalized.Timeout <= 0 || normalized.Timeout > testConnectionTimeoutUpperBoundSeconds { + normalized.Timeout = testConnectionTimeoutUpperBoundSeconds + } + return normalized +} + // Generic DB Methods func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResult { @@ -28,13 +38,16 @@ func (a *App) DBConnect(config connection.ConnectionConfig) connection.QueryResu } func (a *App) TestConnection(config connection.ConnectionConfig) connection.QueryResult { - _, err := a.getDatabaseForcePing(config) + testConfig := normalizeTestConnectionConfig(config) + started := time.Now() + logger.Infof("TestConnection 开始:%s", formatConnSummary(testConfig)) + _, err := a.getDatabaseForcePing(testConfig) if err != nil { - logger.Error(err, "TestConnection 连接测试失败:%s", formatConnSummary(config)) + logger.Error(err, "TestConnection 连接测试失败:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig)) return connection.QueryResult{Success: false, Message: err.Error()} } - logger.Infof("TestConnection 连接测试成功:%s", formatConnSummary(config)) + logger.Infof("TestConnection 连接测试成功:耗时=%s %s", time.Since(started).Round(time.Millisecond), formatConnSummary(testConfig)) return connection.QueryResult{Success: true, Message: "连接成功"} } diff --git a/internal/app/methods_db_timeout_test.go b/internal/app/methods_db_timeout_test.go new file mode 100644 index 0000000..d6cf867 --- /dev/null +++ b/internal/app/methods_db_timeout_test.go @@ -0,0 +1,31 @@ +package app + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestNormalizeTestConnectionConfig_DefaultToUpperBound(t *testing.T) { + config := connection.ConnectionConfig{Type: "mongodb", Timeout: 0} + got := normalizeTestConnectionConfig(config) + if got.Timeout != testConnectionTimeoutUpperBoundSeconds { + t.Fatalf("expected timeout=%d, got=%d", testConnectionTimeoutUpperBoundSeconds, got.Timeout) + } +} + +func TestNormalizeTestConnectionConfig_KeepSmallerTimeout(t *testing.T) { + config := connection.ConnectionConfig{Type: "mongodb", Timeout: 6} + got := normalizeTestConnectionConfig(config) + if got.Timeout != 6 { + t.Fatalf("expected timeout=6, got=%d", got.Timeout) + } +} + +func TestNormalizeTestConnectionConfig_ClampLargeTimeout(t *testing.T) { + config := connection.ConnectionConfig{Type: "mongodb", Timeout: 60} + got := normalizeTestConnectionConfig(config) + if got.Timeout != testConnectionTimeoutUpperBoundSeconds { + t.Fatalf("expected timeout=%d, got=%d", testConnectionTimeoutUpperBoundSeconds, got.Timeout) + } +} diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go index 07a13cc..ca7ce8c 100644 --- a/internal/app/methods_driver.go +++ b/internal/app/methods_driver.go @@ -2792,6 +2792,7 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut driverType := normalizeDriverType(definition.Type) displayName := resolveDriverDisplayName(definition) forceSourceBuild := shouldForceSourceBuildForVersion(driverType, selectedVersion) + preferSourceBuildBeforeDownload := shouldPreferSourceBuildBeforeDownload(driverType, selectedVersion) skipReuseCandidate := shouldSkipReusableAgentCandidate(driverType, selectedVersion) info, err := os.Stat(executablePath) @@ -2799,11 +2800,10 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut if validateErr := db.ValidateOptionalDriverAgentExecutable(driverType, executablePath); validateErr != nil { _ = os.Remove(executablePath) } else { - hash, hashErr := hashFileSHA256(executablePath) - if hashErr != nil { - return "", "", fmt.Errorf("读取已安装 %s 驱动代理摘要失败:%w", displayName, hashErr) + // 用户点击“安装/重装”时应强制刷新驱动代理,避免沿用旧二进制导致修复不生效。 + if removeErr := os.Remove(executablePath); removeErr != nil { + return "", "", fmt.Errorf("清理已安装 %s 驱动代理失败:%w", displayName, removeErr) } - return fmt.Sprintf("local://existing/%s-driver-agent", driverType), hash, nil } } if err == nil && info.IsDir() { @@ -2834,6 +2834,22 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut } var downloadErrs []string + var sourceBuildAttempted bool + var sourceBuildErr error + + if !forceSourceBuild && preferSourceBuildBeforeDownload { + sourceBuildAttempted = true + if a != nil { + a.emitDriverDownloadProgress(driverType, "downloading", 16, 100, fmt.Sprintf("优先使用本地源码构建 %s 驱动代理", displayName)) + } + hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion) + if buildErr == nil { + return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil + } + sourceBuildErr = buildErr + logger.Warnf("预先本地构建 %s 驱动代理失败,将继续尝试下载预编译包:%v", displayName, buildErr) + } + if !forceSourceBuild { downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL) if len(downloadURLs) > 0 { @@ -2866,9 +2882,15 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut a.emitDriverDownloadProgress(driverType, "downloading", 92, 100, "未命中预编译包,尝试开发态本地构建") } - hash, buildErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion) - if buildErr == nil { - return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil + var buildErr error + if sourceBuildAttempted { + buildErr = sourceBuildErr + } else { + hash, runErr := buildOptionalDriverAgentFromSource(definition, executablePath, selectedVersion) + buildErr = runErr + if buildErr == nil { + return fmt.Sprintf("local://go-build/%s-driver-agent", driverType), hash, nil + } } var parts []string @@ -3086,12 +3108,25 @@ func shouldForceSourceBuildForVersion(driverType string, selectedVersion string) return resolveMongoDriverMajorFromVersion(selectedVersion) == 1 } -func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool { - if normalizeDriverType(driverType) != "mongodb" { +func shouldPreferSourceBuildBeforeDownload(driverType string, selectedVersion string) bool { + _ = selectedVersion + switch normalizeDriverType(driverType) { + case "kingbase": + // 金仓迭代期优先本地源码构建,避免下载到旧版本预编译代理导致修复不生效。 + return true + default: return false } +} + +func shouldSkipReusableAgentCandidate(driverType string, selectedVersion string) bool { _ = selectedVersion - return true + switch normalizeDriverType(driverType) { + case "mongodb", "kingbase": + return true + default: + return false + } } func optionalDriverBuildTag(driverType string, selectedVersion string) (string, error) { diff --git a/internal/db/kingbase_identifier_utils.go b/internal/db/kingbase_identifier_utils.go new file mode 100644 index 0000000..f3412ac --- /dev/null +++ b/internal/db/kingbase_identifier_utils.go @@ -0,0 +1,164 @@ +package db + +import "strings" + +func normalizeKingbaseIdentCommon(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + + // 兼容被多次 JSON 序列化后的转义引号: + // \\\"schema\\\" -> \"schema\" -> "schema" + for i := 0; i < 8; i++ { + next := strings.TrimSpace(value) + next = strings.ReplaceAll(next, `\\\"`, `\"`) + next = strings.ReplaceAll(next, `\"`, `"`) + if next == value { + break + } + value = next + } + value = strings.TrimSpace(value) + + stripWrapperOnce := func(text string) string { + t := strings.TrimSpace(text) + if strings.HasPrefix(t, `\`) && len(t) > 1 { + t = strings.TrimSpace(strings.TrimPrefix(t, `\`)) + } + if strings.HasSuffix(t, `\`) && len(t) > 1 { + t = strings.TrimSpace(strings.TrimSuffix(t, `\`)) + } + if len(t) >= 4 && strings.HasPrefix(t, `\"`) && strings.HasSuffix(t, `\"`) { + return strings.TrimSpace(t[2 : len(t)-2]) + } + if len(t) >= 2 && strings.HasPrefix(t, `"`) && strings.HasSuffix(t, `"`) { + return strings.TrimSpace(t[1 : len(t)-1]) + } + if len(t) >= 2 && strings.HasPrefix(t, "`") && strings.HasSuffix(t, "`") { + return strings.TrimSpace(t[1 : len(t)-1]) + } + if len(t) >= 2 && strings.HasPrefix(t, "[") && strings.HasSuffix(t, "]") { + return strings.TrimSpace(t[1 : len(t)-1]) + } + return t + } + + for i := 0; i < 8; i++ { + next := stripWrapperOnce(value) + if next == value { + break + } + value = next + } + value = strings.TrimSpace(value) + + // 兼容错误的二次引用与残留反斜杠。 + value = strings.ReplaceAll(value, `\"`, `"`) + value = strings.ReplaceAll(value, `""`, "") + value = strings.TrimSpace(value) + + for i := 0; i < 8; i++ { + next := strings.TrimSpace(value) + changed := false + if strings.HasPrefix(next, `\`) && len(next) > 1 { + next = strings.TrimSpace(strings.TrimPrefix(next, `\`)) + changed = true + } + if strings.HasSuffix(next, `\`) && len(next) > 1 { + next = strings.TrimSpace(strings.TrimSuffix(next, `\`)) + changed = true + } + if !changed || next == value { + break + } + value = next + } + + return strings.TrimSpace(value) +} + +func splitKingbaseQualifiedNameCommon(raw string) (schema string, table string) { + text := strings.TrimSpace(raw) + if text == "" { + return "", "" + } + + sep := findKingbaseQualifiedSeparator(text) + if sep < 0 { + return "", normalizeKingbaseIdentCommon(text) + } + + schemaPart := normalizeKingbaseIdentCommon(text[:sep]) + tablePart := normalizeKingbaseIdentCommon(text[sep+1:]) + + if tablePart == "" { + if schemaPart == "" { + return "", normalizeKingbaseIdentCommon(text) + } + return "", schemaPart + } + if schemaPart == "" { + return "", tablePart + } + return schemaPart, tablePart +} + +func findKingbaseQualifiedSeparator(raw string) int { + inDouble := false + inBacktick := false + inBracket := false + escaped := false + + for i := 0; i < len(raw); i++ { + ch := raw[i] + if escaped { + escaped = false + continue + } + + if ch == '\\' { + escaped = true + continue + } + + if inDouble { + if ch == '"' { + // SQL 双引号转义:"" 代表字面量 " + if i+1 < len(raw) && raw[i+1] == '"' { + i++ + continue + } + inDouble = false + } + continue + } + + if inBacktick { + if ch == '`' { + inBacktick = false + } + continue + } + + if inBracket { + if ch == ']' { + inBracket = false + } + continue + } + + switch ch { + case '"': + inDouble = true + case '`': + inBacktick = true + case '[': + inBracket = true + case '.': + return i + } + } + + return -1 +} diff --git a/internal/db/kingbase_identifier_utils_test.go b/internal/db/kingbase_identifier_utils_test.go new file mode 100644 index 0000000..69e2b2e --- /dev/null +++ b/internal/db/kingbase_identifier_utils_test.go @@ -0,0 +1,52 @@ +package db + +import "testing" + +func TestNormalizeKingbaseIdentCommon(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "plain", in: "ldf_server", want: "ldf_server"}, + {name: "quoted", in: `"ldf_server"`, want: "ldf_server"}, + {name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"}, + {name: "double escaped quoted", in: `\\\"ldf_server\\\"`, want: "ldf_server"}, + {name: "double quoted", in: `""ldf_server""`, want: "ldf_server"}, + {name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"}, + {name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeKingbaseIdentCommon(tt.in); got != tt.want { + t.Fatalf("normalizeKingbaseIdentCommon(%q)=%q,want=%q", tt.in, got, tt.want) + } + }) + } +} + +func TestSplitKingbaseQualifiedNameCommon(t *testing.T) { + tests := []struct { + name string + in string + wantSchema string + wantTable string + }{ + {name: "plain", in: "ldf_server.andon_events", wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "quoted", in: `"ldf_server"."andon_events"`, wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "escaped quoted", in: `\"ldf_server\".\"andon_events\"`, wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "double escaped quoted", in: `\\\"ldf_server\\\".\\\"andon_events\\\"`, wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "space around dot", in: ` "ldf_server" . "andon_events" `, wantSchema: "ldf_server", wantTable: "andon_events"}, + {name: "table only", in: "andon_events", wantSchema: "", wantTable: "andon_events"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotSchema, gotTable := splitKingbaseQualifiedNameCommon(tt.in) + if gotSchema != tt.wantSchema || gotTable != tt.wantTable { + t.Fatalf("splitKingbaseQualifiedNameCommon(%q)=(%q,%q),want=(%q,%q)", tt.in, gotSchema, gotTable, tt.wantSchema, tt.wantTable) + } + }) + } +} diff --git a/internal/db/kingbase_impl.go b/internal/db/kingbase_impl.go index c227506..d4eda20 100644 --- a/internal/db/kingbase_impl.go +++ b/internal/db/kingbase_impl.go @@ -21,10 +21,9 @@ import ( ) type KingbaseDB struct { - conn *sql.DB - pingTimeout time.Duration - defaultSearchPath string - forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder + conn *sql.DB + pingTimeout time.Duration + forwarder *ssh.LocalForwarder // Store SSH tunnel forwarder } func quoteConnValue(v string) string { @@ -76,9 +75,6 @@ func (k *KingbaseDB) getDSN(config connection.ConnectionConfig) string { quoteConnValue(resolvePostgresSSLMode(config)), getConnectTimeoutSeconds(config), ) - if strings.TrimSpace(k.defaultSearchPath) != "" { - dsn += fmt.Sprintf(" search_path=%s", quoteConnValue(k.defaultSearchPath)) - } return dsn } @@ -124,9 +120,6 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { var failures []string for idx, attempt := range attempts { - // 避免跨连接缓存 defaultSearchPath 造成的污染:每次 Connect 都重新探测一次。 - k.defaultSearchPath = "" - dsn := k.getDSN(attempt) db, err := sql.Open("kingbase", dsn) if err != nil { @@ -145,163 +138,85 @@ func (k *KingbaseDB) Connect(config connection.ConnectionConfig) error { logger.Warnf("人大金仓 SSL 优先连接失败,已回退至明文连接") } - k.reconnectWithPreferredSearchPathIfNeeded(attempt) + // 获取 schema 列表以重构带有 search_path 的连接池 + searchPathStr := k.getSearchPathStr() + if searchPathStr != "" { + // 将 search_path 参数拼入 DSN + finalDSN := dsn + " search_path=" + quoteConnValue(searchPathStr) + if finalDB, err := sql.Open("kingbase", finalDSN); err == nil { + k.pingTimeout = getConnectTimeout(attempt) + finalDB.SetConnMaxLifetime(5 * time.Minute) + + // 临时将 k.conn 指向 finalDB 来做 ping 测试 + oldConn := k.conn + k.conn = finalDB + if err := k.Ping(); err == nil { + // 成功使用带 search_path 的连接池 + _ = oldConn.Close() + logger.Infof("人大金仓已配置连接级 search_path:%s", searchPathStr) + } else { + _ = finalDB.Close() + k.conn = oldConn + } + } + } + if searchPathStr != "" { + timeout := k.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + if _, err := k.conn.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPathStr)); err != nil { + logger.Warnf("人大金仓显式设置 search_path 失败:%v", err) + } else { + logger.Infof("人大金仓已设置默认 search_path:%s", searchPathStr) + } + } + return nil } return fmt.Errorf("连接建立后验证失败:%s", strings.Join(failures, ";")) } -func (k *KingbaseDB) reconnectWithPreferredSearchPathIfNeeded(config connection.ConnectionConfig) { +// getSearchPathStr 查询当前数据库中所有用户 schema,配置 DSN 的 search_path。 +// KingBase 默认 search_path 为 "$user", public,对于自定义 schema 下的表不可见。 +func (k *KingbaseDB) getSearchPathStr() string { if k.conn == nil { - return + return "" } - timeout := k.pingTimeout - if timeout <= 0 { - timeout = 5 * time.Second - } - ctx, cancel := utils.ContextWithTimeout(timeout) - defer cancel() + query := `SELECT nspname FROM pg_namespace + WHERE nspname NOT IN ('pg_catalog', 'information_schema') + AND nspname NOT LIKE 'pg_%' + ORDER BY nspname` - var currentSchema string - if err := k.conn.QueryRowContext(ctx, "SELECT current_schema()").Scan(¤tSchema); err != nil { - logger.Warnf("人大金仓读取当前 schema 失败:%v", err) - return - } - - if schema := strings.TrimSpace(currentSchema); schema != "" && !strings.EqualFold(schema, "public") { - return - } - - searchPath, chosenSchema := k.detectPreferredSearchPath(ctx, config) - if strings.TrimSpace(searchPath) == "" { - return - } - - oldConn := k.conn - prevSearchPath := k.defaultSearchPath - k.defaultSearchPath = searchPath - - dsn := k.getDSN(config) - newConn, err := sql.Open("kingbase", dsn) + rows, err := k.conn.Query(query) if err != nil { - k.defaultSearchPath = prevSearchPath - logger.Warnf("人大金仓重连以设置 search_path 失败:%v", err) - return - } - if err := newConn.PingContext(ctx); err != nil { - _ = newConn.Close() - k.defaultSearchPath = prevSearchPath - logger.Warnf("人大金仓重连后验证失败:%v", err) - return - } - - k.conn = newConn - _ = oldConn.Close() - logger.Infof("人大金仓已设置默认 schema:%s", chosenSchema) -} - -func (k *KingbaseDB) kingbaseSchemaExists(ctx context.Context, schema string) (bool, error) { - if schema = strings.TrimSpace(schema); schema == "" { - return false, nil - } - - var one int - err := k.conn.QueryRowContext(ctx, "SELECT 1 FROM pg_namespace WHERE nspname = $1", schema).Scan(&one) - if err == sql.ErrNoRows { - return false, nil - } - if err != nil { - return false, err - } - return true, nil -} - -func (k *KingbaseDB) detectPreferredSearchPath(ctx context.Context, config connection.ConnectionConfig) (searchPath string, chosenSchema string) { - // 1) 优先使用与数据库名/用户名同名的 schema(需要存在) - candidates := []string{ - normalizeKingbaseIdentifier(config.Database), - normalizeKingbaseIdentifier(config.User), - } - - seen := make(map[string]struct{}, len(candidates)) - for _, candidate := range candidates { - if candidate == "" || strings.EqualFold(candidate, "public") { - continue - } - key := strings.ToLower(candidate) - if _, ok := seen[key]; ok { - continue - } - seen[key] = struct{}{} - - exists, err := k.kingbaseSchemaExists(ctx, candidate) - if err != nil { - logger.Warnf("人大金仓检查 schema 是否存在失败:schema=%s err=%v", candidate, err) - continue - } - if !exists { - continue - } - - return fmt.Sprintf("%s,public", quoteKingbaseIdent(candidate)), candidate - } - - // 2) 如果只有一个“用户 schema”含有表,则将其作为默认 schema(更符合 DB GUI 的直觉) - schema, err := k.detectSingleUserSchemaWithTables(ctx) - if err != nil { - logger.Warnf("人大金仓探测默认 schema 失败:%v", err) - return "", "" - } - if schema == "" || strings.EqualFold(schema, "public") { - return "", "" - } - return fmt.Sprintf("%s,public", quoteKingbaseIdent(schema)), schema -} - -func (k *KingbaseDB) detectSingleUserSchemaWithTables(ctx context.Context) (string, error) { - if k.conn == nil { - return "", nil - } - - // 仅在“唯一用户 schema”场景做兜底,避免多 schema 下误选导致对象解析歧义。 - // 注:information_schema.tables 的视图在 PG/金仓语义稳定且权限要求相对低。 - query := ` -SELECT table_schema, COUNT(*) AS table_count -FROM information_schema.tables -WHERE table_type = 'BASE TABLE' - AND table_schema NOT IN ('pg_catalog', 'information_schema', 'public') - AND table_schema NOT LIKE 'pg_%' -GROUP BY table_schema -ORDER BY table_count DESC, table_schema -LIMIT 2` - - rows, err := k.conn.QueryContext(ctx, query) - if err != nil { - return "", err + logger.Warnf("人大金仓查询用户 schema 失败,跳过 search_path 设置:%v", err) + return "" } defer rows.Close() - type row struct { - schema string - count int64 - } - var results []row + var schemas []string for rows.Next() { - var r row - if scanErr := rows.Scan(&r.schema, &r.count); scanErr != nil { - return "", scanErr + var name string + if err := rows.Scan(&name); err != nil { + continue + } + name = strings.TrimSpace(name) + if name != "" { + // 使用 SQL 标准的双引号包裹标识符 + escaped := strings.ReplaceAll(name, `"`, `""`) + schemas = append(schemas, `"`+escaped+`"`) } - results = append(results, r) - } - if err := rows.Err(); err != nil { - return "", err } - if len(results) != 1 { - return "", nil + if len(schemas) == 0 { + return "" } - return normalizeKingbaseIdentifier(results[0].schema), nil + + return strings.Join(schemas, ", ") } func (k *KingbaseDB) Close() error { @@ -938,34 +853,7 @@ func (k *KingbaseDB) ApplyChanges(tableName string, changes connection.ChangeSet } func normalizeKingbaseIdentifier(raw string) string { - value := strings.TrimSpace(raw) - if value == "" { - return "" - } - - // 兼容 JSON/字符串转义后传入的标识符:\"schema\" -> "schema" - value = strings.ReplaceAll(value, `\"`, `"`) - value = strings.TrimSpace(value) - - // 兼容异常多重包裹引号(例如 ""schema""、""""schema"""")。 - // strings.Trim 会移除两端连续引号,迭代后可收敛到纯标识符。 - for i := 0; i < 4; i++ { - next := strings.TrimSpace(strings.Trim(value, `"`)) - if next == value { - break - } - value = next - } - - // 兼容其他方言可能残留的引用形式 - if len(value) >= 2 && strings.HasPrefix(value, "`") && strings.HasSuffix(value, "`") { - value = strings.TrimSpace(strings.Trim(value, "`")) - } - if len(value) >= 2 && strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { - value = strings.TrimSpace(value[1 : len(value)-1]) - } - - return value + return normalizeKingbaseIdentCommon(raw) } // kingbaseIdentNeedsQuote 判断标识符是否需要双引号包裹。 @@ -1002,7 +890,7 @@ func isKingbaseReservedWord(ident string) bool { "begin", "commit", "rollback", "schema", "database", "view", "function", "procedure", "sequence", "type", "domain", "role", "session", "current", "authorization", "cross", "full", "natural", "some", "cast", "fetch", - "for", "to", "do", "if", "return", "returns", "declare", "cursor": + "for", "to", "do", "if", "return", "returns", "declare", "cursor", "server", "owner": return true } return false @@ -1013,7 +901,6 @@ func quoteKingbaseIdent(name string) string { if n == "" { return "\"\"" } - // 仅在需要时才加双引号,避免 KingbaseES 兼容性问题 if !kingbaseIdentNeedsQuote(n) { return n } @@ -1022,24 +909,7 @@ func quoteKingbaseIdent(name string) string { } func splitKingbaseQualifiedTable(tableName string) (schema string, table string) { - raw := strings.TrimSpace(tableName) - if raw == "" { - return "", "" - } - - if parts := strings.SplitN(raw, ".", 2); len(parts) == 2 { - schema = normalizeKingbaseIdentifier(parts[0]) - table = normalizeKingbaseIdentifier(parts[1]) - if table == "" { - return "", normalizeKingbaseIdentifier(raw) - } - if schema == "" { - return "", table - } - return schema, table - } - - return "", normalizeKingbaseIdentifier(raw) + return splitKingbaseQualifiedNameCommon(tableName) } func (k *KingbaseDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { diff --git a/internal/db/kingbase_impl_test.go b/internal/db/kingbase_impl_test.go index afad520..8b0d6f5 100644 --- a/internal/db/kingbase_impl_test.go +++ b/internal/db/kingbase_impl_test.go @@ -15,8 +15,10 @@ func TestNormalizeKingbaseIdentifier(t *testing.T) { {name: "double quoted", in: `""ldf_server""`, want: "ldf_server"}, {name: "quad quoted", in: `""""ldf_server""""`, want: "ldf_server"}, {name: "escaped quoted", in: `\"ldf_server\"`, want: "ldf_server"}, + {name: "double escaped quoted", in: `\\\"ldf_server\\\"`, want: "ldf_server"}, {name: "backtick quoted", in: "`ldf_server`", want: "ldf_server"}, {name: "bracket quoted", in: "[ldf_server]", want: "ldf_server"}, + {name: "embedded double quotes", in: `ldf""server`, want: "ldfserver"}, } for _, tt := range tests { @@ -99,6 +101,7 @@ func TestSplitKingbaseQualifiedTable(t *testing.T) { {name: "plain qualified", in: "ldf_server.t_user", wantSchema: "ldf_server", wantTable: "t_user"}, {name: "double quoted qualified", in: `""ldf_server"".""t_user""`, wantSchema: "ldf_server", wantTable: "t_user"}, {name: "escaped qualified", in: `\"ldf_server\".\"t_user\"`, wantSchema: "ldf_server", wantTable: "t_user"}, + {name: "double escaped qualified", in: `\\\"ldf_server\\\".\\\"t_user\\\"`, wantSchema: "ldf_server", wantTable: "t_user"}, {name: "bracket qualified", in: "[ldf_server].[t_user]", wantSchema: "ldf_server", wantTable: "t_user"}, {name: "table only", in: `""t_user""`, wantSchema: "", wantTable: "t_user"}, } diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go index 27ac0c7..dff4644 100644 --- a/internal/db/mongodb_impl.go +++ b/internal/db/mongodb_impl.go @@ -151,10 +151,14 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf } } - if len(config.Hosts) == 0 && len(hostsFromURI) > 0 { + explicitHost := strings.TrimSpace(config.Host) != "" + explicitHosts := len(config.Hosts) > 0 + + // 显式填写的 host/hosts 优先级高于 URI,避免表单 host 被 URI 中的 localhost 覆盖。 + if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 { config.Hosts = hostsFromURI } - if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 { + if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 { host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort) if ok { config.Host = host @@ -281,9 +285,44 @@ func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.Con return attempts } +func mongoURIForcesTLS(uriText string) bool { + trimmed := strings.TrimSpace(uriText) + if trimmed == "" { + return false + } + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + query := parsed.Query() + for _, key := range []string{"tls", "ssl"} { + value := strings.ToLower(strings.TrimSpace(query.Get(key))) + switch value { + case "1", "true", "t", "yes", "y", "required": + return true + } + } + return false +} + +func mongoAttemptSSLLabel(config connection.ConnectionConfig, fallbackToPlain bool) string { + if fallbackToPlain { + return "明文回退" + } + if mongoURIForcesTLS(config.URI) { + return "SSL" + } + enabled, _ := resolveMongoTLSSettings(config) + if enabled { + return "SSL" + } + return "明文" +} + func (m *MongoDB) Connect(config connection.ConnectionConfig) error { runConfig := applyMongoURI(config) connectConfig := runConfig + sshRouteHint := "" if runConfig.UseSSH && runConfig.MongoSRV { return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道") @@ -324,6 +363,7 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { localConfig.URI = "" localConfig.Hosts = []string{normalizeMongoAddress(host, port)} connectConfig = localConfig + sshRouteHint = fmt.Sprintf("SSH隧道 %s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) } @@ -337,20 +377,32 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { if shouldTrySSLPreferredFallback(connectConfig) { sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig)) } + totalAttempts := 0 + for _, attemptConfig := range sslAttempts { + totalAttempts += len(buildMongoAuthAttempts(attemptConfig)) + } + attemptNo := 0 var errorDetails []string for sslIndex, sslConfig := range sslAttempts { - sslLabel := "SSL" - if sslIndex > 0 { - sslLabel = "明文回退" - } + sslLabel := mongoAttemptSSLLabel(sslConfig, sslIndex > 0) attemptConfigs := buildMongoAuthAttempts(sslConfig) for index, attemptConfig := range attemptConfigs { + attemptNo++ authLabel := "主库凭据" if index > 0 { authLabel = "从库凭据" } + targets := collectMongoSeeds(attemptConfig) + if len(targets) == 0 { + targets = append(targets, normalizeMongoAddress(attemptConfig.Host, attemptConfig.Port)) + } + attemptStarted := time.Now() + logger.Infof( + "MongoDB 连接尝试:%d/%d 模式=%s 凭据=%s 目标=%s 代理=%t", + attemptNo, totalAttempts, sslLabel, authLabel, strings.Join(targets, ","), attemptConfig.UseProxy, + ) if sslIndex > 0 { attemptConfig.URI = "" @@ -369,7 +421,13 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { } client, err := mongo.Connect(clientOpts) if err != nil { - errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err)) + logger.Warnf("MongoDB 连接尝试失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err) + detail := fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err) + if sshRouteHint != "" { + detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint) + } + errorDetails = append(errorDetails, detail) continue } @@ -379,9 +437,17 @@ func (m *MongoDB) Connect(config connection.ConnectionConfig) error { _ = client.Disconnect(ctx) cancel() m.client = nil - errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err)) + logger.Warnf("MongoDB 连接尝试验证失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err) + detail := fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err) + if sshRouteHint != "" { + detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint) + } + errorDetails = append(errorDetails, detail) continue } + logger.Infof("MongoDB 连接尝试成功:%d/%d 模式=%s 凭据=%s 耗时=%s", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond)) if sslIndex > 0 { logger.Warnf("MongoDB SSL 优先连接失败,已回退至明文连接") } diff --git a/internal/db/mongodb_impl_uri_test.go b/internal/db/mongodb_impl_uri_test.go new file mode 100644 index 0000000..020b293 --- /dev/null +++ b/internal/db/mongodb_impl_uri_test.go @@ -0,0 +1,39 @@ +//go:build gonavi_full_drivers || gonavi_mongodb_driver + +package db + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestApplyMongoURI_ExplicitHostDoesNotAdoptURIHosts(t *testing.T) { + config := connection.ConnectionConfig{ + Host: "10.10.10.10", + Port: 27017, + URI: "mongodb://localhost:27017/admin", + } + + got := applyMongoURI(config) + if got.Host != "10.10.10.10" { + t.Fatalf("expected host to remain explicit, got %q", got.Host) + } + if len(got.Hosts) != 0 { + t.Fatalf("expected hosts to remain empty when explicit host exists, got %v", got.Hosts) + } +} + +func TestApplyMongoURI_ExplicitHostsDoesNotAdoptURIHosts(t *testing.T) { + config := connection.ConnectionConfig{ + Host: "10.10.10.10", + Port: 27017, + Hosts: []string{"10.10.10.10:27017", "10.10.10.11:27017"}, + URI: "mongodb://localhost:27017,localhost:27018/admin?replicaSet=rs0", + } + + got := applyMongoURI(config) + if len(got.Hosts) != 2 || got.Hosts[0] != "10.10.10.10:27017" { + t.Fatalf("expected explicit hosts to stay untouched, got %v", got.Hosts) + } +} diff --git a/internal/db/mongodb_impl_v1.go b/internal/db/mongodb_impl_v1.go index e3aa5b4..60d4fb2 100644 --- a/internal/db/mongodb_impl_v1.go +++ b/internal/db/mongodb_impl_v1.go @@ -152,10 +152,14 @@ func applyMongoURI(config connection.ConnectionConfig) connection.ConnectionConf } } - if len(config.Hosts) == 0 && len(hostsFromURI) > 0 { + explicitHost := strings.TrimSpace(config.Host) != "" + explicitHosts := len(config.Hosts) > 0 + + // 显式填写的 host/hosts 优先级高于 URI,避免表单 host 被 URI 中的 localhost 覆盖。 + if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 { config.Hosts = hostsFromURI } - if strings.TrimSpace(config.Host) == "" && len(hostsFromURI) > 0 { + if !explicitHost && !explicitHosts && len(hostsFromURI) > 0 { host, port, ok := parseHostPortWithDefault(hostsFromURI[0], defaultPort) if ok { config.Host = host @@ -282,9 +286,44 @@ func buildMongoAuthAttempts(config connection.ConnectionConfig) []connection.Con return attempts } +func mongoURIForcesTLS(uriText string) bool { + trimmed := strings.TrimSpace(uriText) + if trimmed == "" { + return false + } + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + query := parsed.Query() + for _, key := range []string{"tls", "ssl"} { + value := strings.ToLower(strings.TrimSpace(query.Get(key))) + switch value { + case "1", "true", "t", "yes", "y", "required": + return true + } + } + return false +} + +func mongoAttemptSSLLabel(config connection.ConnectionConfig, fallbackToPlain bool) string { + if fallbackToPlain { + return "明文回退" + } + if mongoURIForcesTLS(config.URI) { + return "SSL" + } + enabled, _ := resolveMongoTLSSettings(config) + if enabled { + return "SSL" + } + return "明文" +} + func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { runConfig := applyMongoURI(config) connectConfig := runConfig + sshRouteHint := "" if runConfig.UseSSH && runConfig.MongoSRV { return fmt.Errorf("MongoDB SRV 记录模式暂不支持 SSH 隧道") @@ -325,6 +364,7 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { localConfig.URI = "" localConfig.Hosts = []string{normalizeMongoAddress(host, port)} connectConfig = localConfig + sshRouteHint = fmt.Sprintf("SSH隧道 %s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) logger.Infof("MongoDB 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, targetHost, targetPort) } @@ -338,20 +378,32 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { if shouldTrySSLPreferredFallback(connectConfig) { sslAttempts = append(sslAttempts, withSSLDisabled(connectConfig)) } + totalAttempts := 0 + for _, attemptConfig := range sslAttempts { + totalAttempts += len(buildMongoAuthAttempts(attemptConfig)) + } + attemptNo := 0 var errorDetails []string for sslIndex, sslConfig := range sslAttempts { - sslLabel := "SSL" - if sslIndex > 0 { - sslLabel = "明文回退" - } + sslLabel := mongoAttemptSSLLabel(sslConfig, sslIndex > 0) attemptConfigs := buildMongoAuthAttempts(sslConfig) for index, attemptConfig := range attemptConfigs { + attemptNo++ authLabel := "主库凭据" if index > 0 { authLabel = "从库凭据" } + targets := collectMongoSeeds(attemptConfig) + if len(targets) == 0 { + targets = append(targets, normalizeMongoAddress(attemptConfig.Host, attemptConfig.Port)) + } + attemptStarted := time.Now() + logger.Infof( + "MongoDB(v1) 连接尝试:%d/%d 模式=%s 凭据=%s 目标=%s 代理=%t", + attemptNo, totalAttempts, sslLabel, authLabel, strings.Join(targets, ","), attemptConfig.UseProxy, + ) if sslIndex > 0 { attemptConfig.URI = "" @@ -372,7 +424,13 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { client, err := mongo.Connect(connectCtx, clientOpts) connectCancel() if err != nil { - errorDetails = append(errorDetails, fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err)) + logger.Warnf("MongoDB(v1) 连接尝试失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err) + detail := fmt.Sprintf("%s %s连接失败: %v", sslLabel, authLabel, err) + if sshRouteHint != "" { + detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint) + } + errorDetails = append(errorDetails, detail) continue } @@ -382,9 +440,17 @@ func (m *MongoDBV1) Connect(config connection.ConnectionConfig) error { _ = client.Disconnect(ctx) cancel() m.client = nil - errorDetails = append(errorDetails, fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err)) + logger.Warnf("MongoDB(v1) 连接尝试验证失败:%d/%d 模式=%s 凭据=%s 耗时=%s 错误=%v", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond), err) + detail := fmt.Sprintf("%s %s验证失败: %v", sslLabel, authLabel, err) + if sshRouteHint != "" { + detail = fmt.Sprintf("%s(%s)", detail, sshRouteHint) + } + errorDetails = append(errorDetails, detail) continue } + logger.Infof("MongoDB(v1) 连接尝试成功:%d/%d 模式=%s 凭据=%s 耗时=%s", + attemptNo, totalAttempts, sslLabel, authLabel, time.Since(attemptStarted).Round(time.Millisecond)) if sslIndex > 0 { logger.Warnf("MongoDB(v1) SSL 优先连接失败,已回退至明文连接") } diff --git a/internal/db/mongodb_impl_v1_uri_test.go b/internal/db/mongodb_impl_v1_uri_test.go new file mode 100644 index 0000000..8860db2 --- /dev/null +++ b/internal/db/mongodb_impl_v1_uri_test.go @@ -0,0 +1,25 @@ +//go:build gonavi_mongodb_driver_v1 + +package db + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestApplyMongoURIV1_ExplicitHostDoesNotAdoptURIHosts(t *testing.T) { + config := connection.ConnectionConfig{ + Host: "10.10.10.10", + Port: 27017, + URI: "mongodb://localhost:27017/admin", + } + + got := applyMongoURI(config) + if got.Host != "10.10.10.10" { + t.Fatalf("expected host to remain explicit, got %q", got.Host) + } + if len(got.Hosts) != 0 { + t.Fatalf("expected hosts to remain empty when explicit host exists, got %v", got.Hosts) + } +} diff --git a/internal/db/optional_driver_agent_impl.go b/internal/db/optional_driver_agent_impl.go index 2579b7c..07fd7d3 100644 --- a/internal/db/optional_driver_agent_impl.go +++ b/internal/db/optional_driver_agent_impl.go @@ -9,6 +9,7 @@ import ( "io" "os" "os/exec" + "reflect" "runtime" "strings" "sync" @@ -145,6 +146,7 @@ func (c *optionalDriverAgentClient) captureStderr(stderr io.Reader) { if line == "" { continue } + logger.Warnf("%s 驱动代理 stderr: %s", driverDisplayName(c.driver), line) c.stderrMu.Lock() if c.stderr.Len() > 0 { c.stderr.WriteString(" | ") @@ -268,6 +270,7 @@ func (d *OptionalDriverAgentDB) Connect(config connection.ConnectionConfig) erro return err } d.client = client + d.ensureKingbaseSearchPath(config) return nil } @@ -488,6 +491,16 @@ func (d *OptionalDriverAgentDB) ApplyChanges(tableName string, changes connectio if err != nil { return err } + if strings.EqualFold(d.driverType, "kingbase") { + if normalized := normalizeKingbaseAgentTableName(tableName); normalized != "" { + tableName = normalized + } + if normalized, normErr := d.normalizeKingbaseAgentChangeSet(tableName, changes); normErr == nil { + changes = normalized + } else { + logger.Warnf("Kingbase ApplyChanges 字段名规范化失败:%v", normErr) + } + } return client.call(optionalAgentRequest{ Method: optionalAgentMethodApplyChanges, TableName: tableName, @@ -502,6 +515,269 @@ func (d *OptionalDriverAgentDB) requireClient() (*optionalDriverAgentClient, err return d.client, nil } +func (d *OptionalDriverAgentDB) ensureKingbaseSearchPath(config connection.ConnectionConfig) { + if !strings.EqualFold(d.driverType, "kingbase") { + return + } + client, err := d.requireClient() + if err != nil || client == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + schemas, err := d.listKingbaseSchemas(ctx) + if err != nil || len(schemas) == 0 { + if err != nil { + logger.Warnf("人大金仓驱动代理探测 schema 失败:%v", err) + } + return + } + + searchPath := buildKingbaseSearchPathFromSchemas(schemas) + if strings.TrimSpace(searchPath) == "" { + return + } + + if _, err := d.ExecContext(ctx, fmt.Sprintf("SET search_path TO %s", searchPath)); err != nil { + logger.Warnf("人大金仓驱动代理设置 search_path 失败:%v", err) + return + } + logger.Infof("人大金仓驱动代理已设置默认 search_path:%s", searchPath) +} + +func (d *OptionalDriverAgentDB) listKingbaseSchemas(ctx context.Context) ([]string, error) { + query := `SELECT nspname FROM pg_namespace + WHERE nspname NOT IN ('pg_catalog', 'information_schema') + AND nspname NOT LIKE 'pg_%' + ORDER BY nspname` + rows, _, err := d.QueryContext(ctx, query) + if err != nil { + return nil, err + } + + schemas := make([]string, 0, len(rows)) + for _, row := range rows { + for key, val := range row { + if strings.EqualFold(key, "nspname") || strings.EqualFold(key, "schema") { + name := strings.TrimSpace(fmt.Sprintf("%v", val)) + if name != "" { + schemas = append(schemas, name) + } + break + } + } + if len(row) == 1 { + for _, val := range row { + name := strings.TrimSpace(fmt.Sprintf("%v", val)) + if name != "" { + schemas = append(schemas, name) + } + break + } + } + } + return schemas, nil +} + +func buildKingbaseSearchPathFromSchemas(schemas []string) string { + if len(schemas) == 0 { + return "" + } + seen := make(map[string]struct{}, len(schemas)+1) + parts := make([]string, 0, len(schemas)+1) + for _, name := range schemas { + trimmed := normalizeKingbaseAgentIdent(name) + if trimmed == "" { + continue + } + key := strings.ToLower(trimmed) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + parts = append(parts, quoteKingbaseAgentIdent(trimmed)) + } + if _, ok := seen["public"]; !ok { + parts = append(parts, "public") + } + return strings.Join(parts, ", ") +} + +func quoteKingbaseAgentIdent(name string) string { + n := normalizeKingbaseAgentIdent(name) + if n == "" { + return "\"\"" + } + n = strings.ReplaceAll(n, `"`, `""`) + return `"` + n + `"` +} + +func normalizeKingbaseAgentTableName(raw string) string { + schema, table := splitKingbaseQualifiedNameCommon(raw) + if table == "" { + return "" + } + if schema == "" { + return table + } + return schema + "." + table +} + +func normalizeKingbaseAgentIdent(raw string) string { + return normalizeKingbaseIdentCommon(raw) +} + +type kingbaseAgentColumnIndex struct { + exact map[string]string + compact map[string]string +} + +func buildKingbaseAgentColumnIndex(columns []string) kingbaseAgentColumnIndex { + exact := make(map[string]string, len(columns)) + compact := make(map[string]string, len(columns)) + compactSeen := make(map[string]string, len(columns)) + compactDup := make(map[string]struct{}, len(columns)) + + for _, col := range columns { + name := normalizeKingbaseAgentIdent(col) + if name == "" { + continue + } + lower := strings.ToLower(name) + if _, ok := exact[lower]; !ok { + exact[lower] = name + } + key := normalizeKingbaseAgentCompactKey(name) + if key == "" { + continue + } + if prev, ok := compactSeen[key]; ok && !strings.EqualFold(prev, name) { + compactDup[key] = struct{}{} + continue + } + compactSeen[key] = name + } + + if len(compactDup) > 0 { + for key := range compactDup { + delete(compactSeen, key) + } + } + for key, value := range compactSeen { + compact[key] = value + } + return kingbaseAgentColumnIndex{exact: exact, compact: compact} +} + +func normalizeKingbaseAgentCompactKey(raw string) string { + name := normalizeKingbaseAgentIdent(raw) + if name == "" { + return "" + } + name = strings.ToLower(strings.TrimSpace(name)) + name = strings.Join(strings.Fields(name), "") + name = strings.ReplaceAll(name, "_", "") + return name +} + +func resolveKingbaseAgentColumnName(name string, index kingbaseAgentColumnIndex) string { + cleaned := normalizeKingbaseAgentIdent(name) + if cleaned == "" { + return name + } + lower := strings.ToLower(cleaned) + if actual, ok := index.exact[lower]; ok { + return actual + } + compact := normalizeKingbaseAgentCompactKey(cleaned) + if actual, ok := index.compact[compact]; ok { + return actual + } + return cleaned +} + +func normalizeKingbaseAgentChangeSetByColumns(changes connection.ChangeSet, columns []string) (connection.ChangeSet, error) { + index := buildKingbaseAgentColumnIndex(columns) + if len(index.exact) == 0 && len(index.compact) == 0 { + return changes, nil + } + + mapRow := func(row map[string]interface{}) (map[string]interface{}, error) { + if row == nil { + return row, nil + } + out := make(map[string]interface{}, len(row)) + for key, value := range row { + nextKey := resolveKingbaseAgentColumnName(key, index) + if existing, ok := out[nextKey]; ok && !reflect.DeepEqual(existing, value) { + return nil, fmt.Errorf("duplicate mapped column %q", nextKey) + } + out[nextKey] = value + } + return out, nil + } + + next := connection.ChangeSet{ + Inserts: make([]map[string]interface{}, 0, len(changes.Inserts)), + Updates: make([]connection.UpdateRow, 0, len(changes.Updates)), + Deletes: make([]map[string]interface{}, 0, len(changes.Deletes)), + } + + for _, row := range changes.Inserts { + mapped, err := mapRow(row) + if err != nil { + return changes, err + } + next.Inserts = append(next.Inserts, mapped) + } + + for _, upd := range changes.Updates { + keys, err := mapRow(upd.Keys) + if err != nil { + return changes, err + } + values, err := mapRow(upd.Values) + if err != nil { + return changes, err + } + next.Updates = append(next.Updates, connection.UpdateRow{ + Keys: keys, + Values: values, + }) + } + + for _, row := range changes.Deletes { + mapped, err := mapRow(row) + if err != nil { + return changes, err + } + next.Deletes = append(next.Deletes, mapped) + } + + return next, nil +} + +func (d *OptionalDriverAgentDB) normalizeKingbaseAgentChangeSet(tableName string, changes connection.ChangeSet) (connection.ChangeSet, error) { + columns, err := d.GetColumns("", tableName) + if err != nil { + return changes, err + } + if len(columns) == 0 { + return changes, nil + } + names := make([]string, 0, len(columns)) + for _, col := range columns { + name := strings.TrimSpace(col.Name) + if name == "" { + continue + } + names = append(names, name) + } + return normalizeKingbaseAgentChangeSetByColumns(changes, names) +} + func timeoutMsFromContext(ctx context.Context) int64 { deadline, ok := ctx.Deadline() if !ok { diff --git a/internal/db/optional_driver_agent_impl_test.go b/internal/db/optional_driver_agent_impl_test.go index 2273a06..a79b03d 100644 --- a/internal/db/optional_driver_agent_impl_test.go +++ b/internal/db/optional_driver_agent_impl_test.go @@ -1,32 +1,67 @@ package db import ( - "context" "testing" - "time" + + "GoNavi-Wails/internal/connection" ) -func TestTimeoutMsFromContext_NoDeadline(t *testing.T) { - if got := timeoutMsFromContext(context.Background()); got != 0 { - t.Fatalf("无 deadline 时应返回 0,got=%d", got) +func TestNormalizeKingbaseAgentTableName(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "plain", in: "ldf_server.andon_events", want: "ldf_server.andon_events"}, + {name: "quoted", in: `"ldf_server"."andon_events"`, want: "ldf_server.andon_events"}, + {name: "double quoted", in: `""ldf_server"".""andon_events""`, want: "ldf_server.andon_events"}, + {name: "escaped", in: `\"ldf_server\".\"andon_events\"`, want: "ldf_server.andon_events"}, + {name: "double escaped", in: `\\\"ldf_server\\\".\\\"andon_events\\\"`, want: "ldf_server.andon_events"}, + {name: "space around dot", in: ` "ldf_server" . "andon_events" `, want: "ldf_server.andon_events"}, + {name: "table only", in: `bcs_barcode`, want: "bcs_barcode"}, + {name: "table only quoted", in: `"bcs_barcode"`, want: "bcs_barcode"}, + {name: "table only double quoted", in: `""bcs_barcode""`, want: "bcs_barcode"}, + {name: "table only double escaped", in: `\\\"bcs_barcode\\\"`, want: "bcs_barcode"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeKingbaseAgentTableName(tt.in); got != tt.want { + t.Fatalf("normalizeKingbaseAgentTableName(%q) = %q, want %q", tt.in, got, tt.want) + } + }) } } -func TestTimeoutMsFromContext_WithDeadline(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() +func TestNormalizeKingbaseAgentChangeSetByColumns(t *testing.T) { + columns := []string{"andon_events_id", "event_name", "event_code"} + input := connection.ChangeSet{ + Inserts: []map[string]interface{}{ + {"event name": "物料1", "event_code": "EV-0001", "andon_events_id": 1}, + }, + Updates: []connection.UpdateRow{ + {Keys: map[string]interface{}{"andon_events_id": 1}, Values: map[string]interface{}{"event name": "物料2"}}, + }, + Deletes: []map[string]interface{}{ + {"andon_events_id": 1}, + }, + } - got := timeoutMsFromContext(ctx) - if got <= 0 { - t.Fatalf("有 deadline 时应返回正值,got=%d", got) - } -} - -func TestTimeoutMsFromContext_ExpiredDeadline(t *testing.T) { - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) - defer cancel() - - if got := timeoutMsFromContext(ctx); got != 1 { - t.Fatalf("过期 deadline 应返回 1,got=%d", got) + out, err := normalizeKingbaseAgentChangeSetByColumns(input, columns) + if err != nil { + t.Fatalf("normalizeKingbaseAgentChangeSetByColumns error: %v", err) + } + + if _, ok := out.Inserts[0]["event_name"]; !ok { + t.Fatalf("expected insert to map \"event name\" -> \"event_name\"") + } + if _, ok := out.Inserts[0]["event name"]; ok { + t.Fatalf("unexpected insert key \"event name\" after normalization") + } + if _, ok := out.Updates[0].Values["event_name"]; !ok { + t.Fatalf("expected update values to map \"event name\" -> \"event_name\"") + } + if _, ok := out.Updates[0].Values["event name"]; ok { + t.Fatalf("unexpected update value key \"event name\" after normalization") } } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index e224608..56cf583 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -14,8 +14,9 @@ import ( ) const ( - envLogDir = "GONAVI_LOG_DIR" - appDirName = "GoNavi" + envLogDir = "GONAVI_LOG_DIR" + appHiddenDir = ".GoNavi" + appLogDirName = "Logs" logFileName = "gonavi.log" logRotateMaxBytes = 10 * 1024 * 1024 // 10MB @@ -37,7 +38,7 @@ func Init() { defer logMu.Unlock() logPath = path logInst = log.New(out, "", log.Ldate|log.Ltime|log.Lmicroseconds) - logInst.Printf("[信息] 日志初始化完成,日志文件:%s", logPath) + logInst.Printf("[INFO] 日志初始化完成,日志文件:%s", logPath) }) } @@ -62,15 +63,15 @@ func Close() { } func Infof(format string, args ...any) { - printf("信息", format, args...) + printf("INFO", format, args...) } func Warnf(format string, args ...any) { - printf("警告", format, args...) + printf("WARN", format, args...) } func Errorf(format string, args ...any) { - printf("错误", format, args...) + printf("ERROR", format, args...) } func Error(err error, format string, args ...any) { @@ -115,37 +116,58 @@ func ErrorChain(err error) string { func printf(level string, format string, args ...any) { Init() logMu.Lock() + defer logMu.Unlock() inst := logInst - logMu.Unlock() if inst == nil { return } inst.Printf("[%s] %s", level, fmt.Sprintf(format, args...)) + if logFile != nil { + _ = logFile.Sync() + } } func initOutput() (string, io.Writer) { dir := strings.TrimSpace(os.Getenv(envLogDir)) if dir == "" { - base, err := os.UserConfigDir() - if err != nil || strings.TrimSpace(base) == "" { - base = os.TempDir() - } - dir = filepath.Join(base, appDirName, "logs") + dir = defaultLogDir() } + if path, writer, ok := openLogFile(dir); ok { + return path, writer + } + + fallbackDir := filepath.Join(os.TempDir(), appHiddenDir, appLogDirName) + if path, writer, ok := openLogFile(fallbackDir); ok { + return path, writer + } + + return "", os.Stderr +} + +func defaultLogDir() string { + home, err := os.UserHomeDir() + if err != nil || strings.TrimSpace(home) == "" { + return filepath.Join(os.TempDir(), appHiddenDir, appLogDirName) + } + return filepath.Join(home, appHiddenDir, appLogDirName) +} + +func openLogFile(dir string) (string, io.Writer, bool) { + if strings.TrimSpace(dir) == "" { + return "", nil, false + } if err := os.MkdirAll(dir, 0o755); err != nil { - return filepath.Join(dir, logFileName), os.Stderr + return "", nil, false } - path := filepath.Join(dir, logFileName) rotateIfNeeded(path, dir) - f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { - return path, os.Stderr + return "", nil, false } logFile = f - return path, f + return path, f, true } func rotateIfNeeded(path, dir string) {