diff --git a/frontend/src/components/DriverManagerModal.tsx b/frontend/src/components/DriverManagerModal.tsx index 6dc3df9..9516e31 100644 --- a/frontend/src/components/DriverManagerModal.tsx +++ b/frontend/src/components/DriverManagerModal.tsx @@ -757,6 +757,16 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG }; }, [appendOperationLog, open]); + const resolveLocalImportVersion = useCallback((row: DriverStatusRow) => { + const options = versionMap[row.type] || []; + const selectedKey = selectedVersionMap[row.type]; + const selectedOption = + options.find((item) => buildVersionOptionKey(item) === selectedKey) || + options.find((item) => item.recommended) || + options[0]; + return selectedOption?.version || row.pinnedVersion || ''; + }, [selectedVersionMap, versionMap]); + const installDriver = useCallback(async (row: DriverStatusRow) => { setActionState({ driverType: row.type, kind: 'install' }); setProgressMap((prev) => ({ @@ -820,9 +830,11 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG percent: 0, }, })); - appendOperationLog(row.type, `[START] 开始本地导入(${sourceLabel}):${pathText}`); + const selectedVersion = resolveLocalImportVersion(row); + const versionTip = selectedVersion ? `(${selectedVersion})` : ''; + appendOperationLog(row.type, `[START] 开始本地导入${versionTip}(${sourceLabel}):${pathText}`); try { - const result = await InstallLocalDriverPackage(row.type, pathText, downloadDir); + const result = await InstallLocalDriverPackage(row.type, pathText, downloadDir, selectedVersion); if (!result?.success) { const errText = result?.message || `导入 ${row.name} 本地驱动包失败`; appendOperationLog(row.type, `[ERROR] ${errText}`); @@ -831,9 +843,9 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG } return false; } - appendOperationLog(row.type, '[DONE] 本地导入安装完成'); + appendOperationLog(row.type, `[DONE] 本地导入安装完成 ${versionTip}`.trim()); if (!options?.silentToast) { - message.success(`${row.name} 本地驱动包已安装启用`); + message.success(`${row.name}${versionTip} 本地驱动包已安装启用`); } if (!options?.skipRefresh) { await refreshStatus(false); @@ -842,7 +854,7 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG } finally { setActionState({ driverType: '', kind: '' }); } - }, [appendOperationLog, downloadDir, refreshStatus]); + }, [appendOperationLog, downloadDir, refreshStatus, resolveLocalImportVersion]); const installDriverFromLocalFile = useCallback(async (row: DriverStatusRow) => { const fileRes = await SelectDriverPackageFile(downloadDir); diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index 08c1dd8..f94ace7 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -106,7 +106,7 @@ export function ImportLegacyConnections(arg1:Array; -export function InstallLocalDriverPackage(arg1:string,arg2:string,arg3:string):Promise; +export function InstallLocalDriverPackage(arg1:string,arg2:string,arg3:string,arg4:string):Promise; export function InstallUpdateAndRestart():Promise; diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js index 9564131..d2e2c50 100755 --- a/frontend/wailsjs/go/app/App.js +++ b/frontend/wailsjs/go/app/App.js @@ -206,8 +206,8 @@ export function ImportLegacyGlobalProxy(arg1) { return window['go']['app']['App']['ImportLegacyGlobalProxy'](arg1); } -export function InstallLocalDriverPackage(arg1, arg2, arg3) { - return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3); +export function InstallLocalDriverPackage(arg1, arg2, arg3, arg4) { + return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3, arg4); } export function InstallUpdateAndRestart() { diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go index e8873f7..f8e90dc 100644 --- a/internal/app/methods_driver.go +++ b/internal/app/methods_driver.go @@ -745,7 +745,7 @@ func (a *App) CheckDriverNetworkStatus() connection.QueryResult { } } -func (a *App) InstallLocalDriverPackage(driverType string, filePath string, downloadDir string) connection.QueryResult { +func (a *App) InstallLocalDriverPackage(driverType string, filePath string, downloadDir string, version string) connection.QueryResult { definition, ok := resolveDriverDefinition(driverType) if !ok { return connection.QueryResult{Success: false, Message: "不支持的驱动类型"} @@ -768,7 +768,10 @@ func (a *App) InstallLocalDriverPackage(driverType string, filePath string, down db.SetExternalDriverDownloadDirectory(resolvedDir) a.emitDriverDownloadProgress(definition.Type, "start", 0, 100, "开始安装本地驱动包") - selectedVersion := resolveDriverInstallVersion(definition.PinnedVersion, "local://manual", definition) + selectedVersion := resolveDriverInstallVersion(version, "local://manual", definition) + if err := validateDriverSelectedVersion(definition, selectedVersion); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } meta, installErr := installOptionalDriverAgentFromLocalPath(definition, filePath, resolvedDir, selectedVersion) if installErr != nil { errText := normalizeErrorMessage(installErr) @@ -2628,7 +2631,7 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa sourceName := filepath.Base(pathText) downloadSource := fmt.Sprintf("local://manual/%s", filepath.Base(pathText)) if info.IsDir() { - matchedPath, matchedEntry, resolveErr := resolveLocalDriverAgentFromDirectory(pathText, driverType) + matchedPath, matchedEntry, resolveErr := resolveLocalDriverAgentFromLocalDirectory(pathText, driverType, selectedVersion) if resolveErr != nil { return installedDriverPackage{}, resolveErr } @@ -2641,7 +2644,7 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa } if !info.IsDir() && strings.EqualFold(filepath.Ext(pathText), ".zip") { - entryName, extractErr := installOptionalDriverAgentFromLocalZip(pathText, definition, executablePath) + entryName, extractErr := installOptionalDriverAgentFromLocalZip(pathText, definition, executablePath, selectedVersion) if extractErr != nil { return installedDriverPackage{}, extractErr } @@ -2680,7 +2683,7 @@ type localDriverCandidate struct { inPlatformDir bool } -func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType string) (string, string, error) { +func resolveLocalDriverAgentFromLocalDirectory(directoryPath string, driverType string, selectedVersion string) (string, string, error) { root := strings.TrimSpace(directoryPath) if root == "" { return "", "", fmt.Errorf("本地驱动目录路径为空") @@ -2703,9 +2706,9 @@ func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType strin } displayName := resolveDriverDisplayName(displayDefinition) platformDir := optionalDriverBundlePlatformDir(stdRuntime.GOOS) - assetNameCandidates := optionalDriverReleaseAssetNames(normalizedType) - baseNameCandidates := optionalDriverExecutableBaseNames(normalizedType) - assetName := optionalDriverReleaseAssetName(normalizedType) + assetNameCandidates := optionalDriverReleaseAssetNamesForVersion(normalizedType, selectedVersion) + baseNameCandidates := optionalDriverExecutableBaseNamesForVersion(normalizedType, selectedVersion) + assetName := optionalDriverReleaseAssetNameForVersion(normalizedType, selectedVersion) exactRelativePath := filepath.ToSlash(filepath.Join(platformDir, assetName)) for _, candidateName := range assetNameCandidates { @@ -2820,7 +2823,7 @@ func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType strin ) } -func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDefinition, executablePath string) (string, error) { +func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDefinition, executablePath string, selectedVersion string) (string, error) { driverType := normalizeDriverType(definition.Type) displayName := resolveDriverDisplayName(definition) reader, err := zip.OpenReader(zipPath) @@ -2829,9 +2832,9 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef } defer reader.Close() - entryPath := optionalDriverBundleEntryPath(driverType) - entryPaths := optionalDriverBundleEntryPaths(driverType) - expectedBaseNames := optionalDriverReleaseAssetNames(driverType) + entryPath := optionalDriverBundleEntryPathForVersion(driverType, selectedVersion) + entryPaths := optionalDriverBundleEntryPathsForVersion(driverType, selectedVersion) + expectedBaseNames := optionalDriverReleaseAssetNamesForVersion(driverType, selectedVersion) findEntry := func() *zip.File { for _, file := range reader.File { name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./")) @@ -3490,9 +3493,9 @@ func optionalDriverBundlePlatformDir(goos string) string { } } -func optionalDriverBundleEntryPaths(driverType string) []string { +func optionalDriverBundleEntryPathsForVersion(driverType string, selectedVersion string) []string { platformDir := optionalDriverBundlePlatformDir(stdRuntime.GOOS) - assetNames := optionalDriverReleaseAssetNames(driverType) + assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, selectedVersion) result := make([]string, 0, len(assetNames)) seen := make(map[string]struct{}, len(assetNames)) for _, assetName := range assetNames { @@ -3506,14 +3509,22 @@ func optionalDriverBundleEntryPaths(driverType string) []string { return result } -func optionalDriverBundleEntryPath(driverType string) string { - paths := optionalDriverBundleEntryPaths(driverType) +func optionalDriverBundleEntryPaths(driverType string) []string { + return optionalDriverBundleEntryPathsForVersion(driverType, "") +} + +func optionalDriverBundleEntryPathForVersion(driverType string, selectedVersion string) string { + paths := optionalDriverBundleEntryPathsForVersion(driverType, selectedVersion) if len(paths) == 0 { - return filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(stdRuntime.GOOS), optionalDriverReleaseAssetName(driverType))) + return filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(stdRuntime.GOOS), optionalDriverReleaseAssetNameForVersion(driverType, selectedVersion))) } return paths[0] } +func optionalDriverBundleEntryPath(driverType string) string { + return optionalDriverBundleEntryPathForVersion(driverType, "") +} + func resolveOptionalDriverAssetSize(sizeByAsset map[string]int64, driverType string) int64 { if len(sizeByAsset) == 0 { return 0 diff --git a/internal/app/methods_driver_version_test.go b/internal/app/methods_driver_version_test.go index 5fcfe34..ff12876 100644 --- a/internal/app/methods_driver_version_test.go +++ b/internal/app/methods_driver_version_test.go @@ -1,8 +1,10 @@ package app import ( + "archive/zip" "fmt" "os" + "path/filepath" "runtime" "strings" "testing" @@ -154,6 +156,66 @@ func TestShouldForceSourceBuildForResolvedDownload(t *testing.T) { } } +func TestInstallOptionalDriverAgentFromLocalPathSupportsMongoV1DirectoryImport(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + packageRoot := t.TempDir() + platformDir := filepath.Join(packageRoot, optionalDriverBundlePlatformDir(runtime.GOOS)) + if err := os.MkdirAll(platformDir, 0o755); err != nil { + t.Fatalf("mkdir package dir failed: %v", err) + } + + assetName := mongoVersionedReleaseAssetName(1) + writeSelfExecutable(t, filepath.Join(platformDir, assetName)) + + installRoot := filepath.Join(t.TempDir(), "drivers") + meta, err := installOptionalDriverAgentFromLocalPath(definition, packageRoot, installRoot, "1.17.4") + if err != nil { + t.Fatalf("expected mongodb v1 directory import to succeed, got %v", err) + } + if meta.Version != "1.17.4" { + t.Fatalf("expected imported version to stay 1.17.4, got %q", meta.Version) + } + if filepath.Base(meta.FilePath) != assetName { + t.Fatalf("expected source file %q, got %q", assetName, meta.FilePath) + } + if !strings.Contains(meta.DownloadURL, assetName) { + t.Fatalf("expected download source to reference %q, got %q", assetName, meta.DownloadURL) + } + if _, err := os.Stat(meta.ExecutablePath); err != nil { + t.Fatalf("expected imported executable to exist, got %v", err) + } +} + +func TestInstallOptionalDriverAgentFromLocalPathSupportsMongoV1ZipImport(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + assetName := mongoVersionedReleaseAssetName(1) + zipPath := filepath.Join(t.TempDir(), "mongodb-v1.zip") + writeZipWithSelfExecutable(t, zipPath, filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(runtime.GOOS), assetName))) + + installRoot := filepath.Join(t.TempDir(), "drivers") + meta, err := installOptionalDriverAgentFromLocalPath(definition, zipPath, installRoot, "1.17.4") + if err != nil { + t.Fatalf("expected mongodb v1 zip import to succeed, got %v", err) + } + if meta.Version != "1.17.4" { + t.Fatalf("expected imported version to stay 1.17.4, got %q", meta.Version) + } + if !strings.Contains(meta.DownloadURL, assetName) { + t.Fatalf("expected zip download source to reference %q, got %q", assetName, meta.DownloadURL) + } + if _, err := os.Stat(meta.ExecutablePath); err != nil { + t.Fatalf("expected imported executable to exist, got %v", err) + } +} + func seedReleaseAssetSizeCache(t *testing.T, cacheKey string, sizeByKey map[string]int64) { t.Helper() @@ -220,3 +282,50 @@ func mongoVersionedReleaseAssetName(major int) string { } return name } + +func writeSelfExecutable(t *testing.T, targetPath string) { + t.Helper() + + selfPath, err := os.Executable() + if err != nil { + t.Fatalf("executable path failed: %v", err) + } + content, err := os.ReadFile(selfPath) + if err != nil { + t.Fatalf("read self executable failed: %v", err) + } + if err := os.WriteFile(targetPath, content, 0o755); err != nil { + t.Fatalf("write executable failed: %v", err) + } +} + +func writeZipWithSelfExecutable(t *testing.T, zipPath string, entryName string) { + t.Helper() + + selfPath, err := os.Executable() + if err != nil { + t.Fatalf("executable path failed: %v", err) + } + content, err := os.ReadFile(selfPath) + if err != nil { + t.Fatalf("read self executable failed: %v", err) + } + + file, err := os.Create(zipPath) + if err != nil { + t.Fatalf("create zip failed: %v", err) + } + defer file.Close() + + writer := zip.NewWriter(file) + entry, err := writer.Create(entryName) + if err != nil { + t.Fatalf("create zip entry failed: %v", err) + } + if _, err := entry.Write(content); err != nil { + t.Fatalf("write zip entry failed: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close zip writer failed: %v", err) + } +} diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index d98ca3d..c6be5d0 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -279,7 +279,44 @@ func (c *ClickHouseDB) Ping() error { } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() - return c.conn.PingContext(ctx) + if err := c.conn.PingContext(ctx); err != nil { + return err + } + return c.validateQueryPath() +} + +func (c *ClickHouseDB) validateQueryPath() error { + if c.conn == nil { + return fmt.Errorf("连接未打开") + } + timeout := c.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + + rows, err := c.conn.QueryContext(ctx, "SELECT currentDatabase()") + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return err + } + return fmt.Errorf("连接查询验证未返回结果") + } + + var current sql.NullString + if err := rows.Scan(¤t); err != nil { + return err + } + if err := rows.Err(); err != nil { + return err + } + return nil } func (c *ClickHouseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { diff --git a/internal/db/clickhouse_impl_test.go b/internal/db/clickhouse_impl_test.go new file mode 100644 index 0000000..a2fd40a --- /dev/null +++ b/internal/db/clickhouse_impl_test.go @@ -0,0 +1,119 @@ +//go:build gonavi_full_drivers || gonavi_clickhouse_driver + +package db + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "strings" + "sync" + "testing" + "time" +) + +const fakeClickHouseDriverName = "gonavi-fake-clickhouse" + +var ( + registerFakeClickHouseDriverOnce sync.Once + fakeClickHouseStateMu sync.Mutex + fakeClickHouseState = struct { + pingErr error + queryErr error + lastQuery string + }{ + lastQuery: "", + } +) + +func TestClickHousePingValidatesQueryPath(t *testing.T) { + registerFakeClickHouseDriverOnce.Do(func() { + sql.Register(fakeClickHouseDriverName, fakeClickHouseDriver{}) + }) + + db, err := sql.Open(fakeClickHouseDriverName, "") + if err != nil { + t.Fatalf("open fake clickhouse db failed: %v", err) + } + defer db.Close() + + fakeClickHouseStateMu.Lock() + fakeClickHouseState.pingErr = nil + fakeClickHouseState.queryErr = errors.New("query path failed") + fakeClickHouseState.lastQuery = "" + fakeClickHouseStateMu.Unlock() + + client := &ClickHouseDB{ + conn: db, + pingTimeout: time.Second, + } + err = client.Ping() + if err == nil { + t.Fatal("expected Ping to fail when query validation fails") + } + if !strings.Contains(err.Error(), "query path failed") { + t.Fatalf("expected query validation error, got %v", err) + } + + fakeClickHouseStateMu.Lock() + lastQuery := fakeClickHouseState.lastQuery + fakeClickHouseStateMu.Unlock() + if lastQuery != "SELECT currentDatabase()" { + t.Fatalf("expected query validation SQL to run, got %q", lastQuery) + } +} + +type fakeClickHouseDriver struct{} + +func (fakeClickHouseDriver) Open(name string) (driver.Conn, error) { + return fakeClickHouseConn{}, nil +} + +type fakeClickHouseConn struct{} + +func (fakeClickHouseConn) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("prepare not implemented") +} + +func (fakeClickHouseConn) Close() error { + return nil +} + +func (fakeClickHouseConn) Begin() (driver.Tx, error) { + return nil, errors.New("transactions not implemented") +} + +func (fakeClickHouseConn) Ping(ctx context.Context) error { + fakeClickHouseStateMu.Lock() + defer fakeClickHouseStateMu.Unlock() + return fakeClickHouseState.pingErr +} + +func (fakeClickHouseConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + fakeClickHouseStateMu.Lock() + defer fakeClickHouseStateMu.Unlock() + fakeClickHouseState.lastQuery = query + if fakeClickHouseState.queryErr != nil { + return nil, fakeClickHouseState.queryErr + } + return &fakeClickHouseRows{}, nil +} + +type fakeClickHouseRows struct{} + +func (r *fakeClickHouseRows) Columns() []string { + return []string{"currentDatabase"} +} + +func (r *fakeClickHouseRows) Close() error { + return nil +} + +func (r *fakeClickHouseRows) Next(dest []driver.Value) error { + if len(dest) > 0 { + dest[0] = "default" + } + return io.EOF +}