From eddb9f38c98ee0498a94f670bdf22a49f1f61719 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:28:48 +0800 Subject: [PATCH 01/14] =?UTF-8?q?=F0=9F=90=9B=20fix(data-viewer):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=A4=9A=E5=88=97=E6=8E=92=E5=BA=8F=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E6=AE=8B=E7=95=99=E5=AF=BC=E8=87=B4=E6=8E=92=E5=BA=8F?= =?UTF-8?q?=E5=A4=B1=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将表格排序状态改为按当前 sorter 结果重建\n- 避免取消或切换多列排序后保留失效字段\n- 抽取排序状态归一化工具供数据表复用 --- frontend/src/components/DataGrid.tsx | 34 ++-------------------- frontend/src/utils/dataGridSort.ts | 43 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 31 deletions(-) create mode 100644 frontend/src/utils/dataGridSort.ts diff --git a/frontend/src/components/DataGrid.tsx b/frontend/src/components/DataGrid.tsx index 79644df..3ead509 100644 --- a/frontend/src/components/DataGrid.tsx +++ b/frontend/src/components/DataGrid.tsx @@ -33,6 +33,7 @@ import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, escapeLiteral, import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; import { resolvePaginationPageText, resolvePaginationSummaryText, resolvePaginationTotalForControl } from '../utils/dataGridPagination'; +import { resolveGridSortInfoFromTableSorter } from '../utils/dataGridSort'; import { calculateTableBodyBottomPadding, calculateVirtualTableScrollX } from './dataGridLayout'; import { buildCopyInsertSQL, normalizeTemporalLiteralText } from './dataGridCopyInsert'; @@ -2762,39 +2763,10 @@ const DataGrid: React.FC = ({ const handleTableChange = useCallback((_pag: any, _filtersArg: any, sorter: any) => { if (isResizingRef.current) return; // Block sort if resizing - // Ant Design 多列排序模式下 sorter 可能是数组 - const sorters = Array.isArray(sorter) ? sorter : (sorter?.field ? [sorter] : []); - if (sorters.length === 0) { - setSortInfo([]); - if (onSort) onSort(JSON.stringify([]), ''); - return; - } - // 在现有排序数组基础上增量更新 - const next = [...sortInfo]; - for (const s of sorters) { - const field = String(s.field || ''); - if (!field) continue; - const order = s.order as string; - const normalizedOrder = order === 'ascend' || order === 'descend' ? order : ''; - const existIdx = next.findIndex(item => item.columnKey === field); - if (!normalizedOrder) { - // Ant Design 第三次点击想取消排序: - // 如果该字段已在排序数组中,回转为升序而非移除 - if (existIdx >= 0) { - next[existIdx] = { ...next[existIdx], order: 'ascend', enabled: true }; - } - // 不在数组中则忽略 - } else if (existIdx >= 0) { - // 已存在:更新排序方向 - next[existIdx] = { ...next[existIdx], order: normalizedOrder, enabled: true }; - } else { - // 不存在:追加到末尾 - next.push({ columnKey: field, order: normalizedOrder, enabled: true }); - } - } + const next = resolveGridSortInfoFromTableSorter({ sorter }); setSortInfo(next); if (onSort) onSort(JSON.stringify(next), ''); - }, [onSort, sortInfo]); + }, [onSort]); // Native Drag State const draggingRef = useRef<{ diff --git a/frontend/src/utils/dataGridSort.ts b/frontend/src/utils/dataGridSort.ts new file mode 100644 index 0000000..80749bb --- /dev/null +++ b/frontend/src/utils/dataGridSort.ts @@ -0,0 +1,43 @@ +export type GridSortInfoItem = { + columnKey: string; + order: string; + enabled?: boolean; +}; + +type TableSorterLike = { + field?: unknown; + columnKey?: unknown; + order?: unknown; +}; + +export const resolveGridSortInfoFromTableSorter = ({ + sorter, +}: { + sorter: TableSorterLike | TableSorterLike[] | null | undefined; +}): GridSortInfoItem[] => { + const sorters = Array.isArray(sorter) + ? sorter + : ((sorter?.field || sorter?.columnKey) ? [sorter] : []); + + if (sorters.length === 0) { + return []; + } + + const next: GridSortInfoItem[] = []; + const seen = new Set(); + + for (const item of sorters) { + const field = String(item?.field || item?.columnKey || '').trim(); + if (!field) continue; + + const order = item?.order as string; + const normalizedOrder = order === 'ascend' || order === 'descend' ? order : ''; + if (!normalizedOrder) continue; + const dedupeKey = field.toLowerCase(); + if (seen.has(dedupeKey)) continue; + seen.add(dedupeKey); + next.push({ columnKey: field, order: normalizedOrder, enabled: true }); + } + + return next; +}; From acee1a06e8c01aaae72bfb9f76ba4c6ca7d08801 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Thu, 2 Apr 2026 20:15:49 +0800 Subject: [PATCH 02/14] =?UTF-8?q?fix(driver):=20=E6=94=B6=E7=B4=A7=20Mongo?= =?UTF-8?q?DB=20=E9=A9=B1=E5=8A=A8=E6=94=AF=E6=8C=81=E5=8C=BA=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/components/DriverManagerModal.tsx | 50 +-- internal/app/methods_driver.go | 287 +++++++++++++++--- internal/app/methods_driver_version_test.go | 222 ++++++++++++++ 3 files changed, 501 insertions(+), 58 deletions(-) create mode 100644 internal/app/methods_driver_version_test.go diff --git a/frontend/src/components/DriverManagerModal.tsx b/frontend/src/components/DriverManagerModal.tsx index df5b0cd..6dc3df9 100644 --- a/frontend/src/components/DriverManagerModal.tsx +++ b/frontend/src/components/DriverManagerModal.tsx @@ -1067,29 +1067,35 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG const options = versionMap[row.type] || []; const selectedKey = selectedVersionMap[row.type]; const selectOptions = buildVersionSelectOptions(options); + const mongoHint = row.type === 'mongodb' + ? '当前仅支持 MongoDB 1.17.x 和 2.x;更老 1.x 暂不提供安装。' + : ''; return ( - 0 ? '选择驱动版本' : '点击展开加载版本'} + value={selectedKey} + options={selectOptions as any} + onOpenChange={(open) => { + if (open && options.length === 0 && !versionLoadingMap[row.type]) { + void loadVersionOptions(row, true); + return; + } + if (open && selectedKey) { + void loadVersionPackageSize(row, selectedKey); + } + }} + onChange={(value) => { + setSelectedVersionMap((prev) => ({ ...prev, [row.type]: value })); + void loadVersionPackageSize(row, value); + }} + /> + {mongoHint ? {mongoHint} : null} + ); }, }, diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go index 1024014..e8873f7 100644 --- a/internal/app/methods_driver.go +++ b/internal/app/methods_driver.go @@ -543,7 +543,10 @@ func (a *App) GetDriverVersionPackageSize(driverType string, version string) con if normalizedVersion == "" { return connection.QueryResult{Success: false, Message: "版本号为空"} } - assetName := optionalDriverReleaseAssetName(normalizedType) + if err := validateDriverSelectedVersion(definition, normalizedVersion); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + assetName := optionalDriverReleaseAssetNameForVersion(normalizedType, normalizedVersion) if strings.TrimSpace(assetName) == "" { return connection.QueryResult{Success: false, Message: "驱动资产名称为空"} } @@ -554,14 +557,15 @@ func (a *App) GetDriverVersionPackageSize(driverType string, version string) con if sizeByAsset, err := loadReleaseAssetSizesCached("tag:"+tag, func() (*githubRelease, error) { return fetchReleaseByTag(tag) }); err == nil { - sizeBytes = resolveOptionalDriverAssetSize(sizeByAsset, normalizedType) + sizeBytes = resolveOptionalDriverAssetSizeForVersion(sizeByAsset, normalizedType, normalizedVersion) if sizeBytes > 0 { sizeSource = "tag" } } - if sizeBytes <= 0 { + allowLatestFallback := sameDriverVersion(normalizedVersion, definition.PinnedVersion) || sameDriverVersion(normalizedVersion, latestDriverVersionMap[normalizedType]) + if sizeBytes <= 0 && allowLatestFallback { if sizeByAsset, err := loadReleaseAssetSizesCached("latest", fetchLatestReleaseForDriverAssets); err == nil { - sizeBytes = resolveOptionalDriverAssetSize(sizeByAsset, normalizedType) + sizeBytes = resolveOptionalDriverAssetSizeForVersion(sizeByAsset, normalizedType, normalizedVersion) if sizeBytes > 0 { sizeSource = "latest" } @@ -816,6 +820,9 @@ func (a *App) DownloadDriverPackage(driverType string, version string, downloadU urlText = fmt.Sprintf("builtin://activate/%s", optionalDriverPublicTypeName(definition.Type)) } selectedVersion := resolveDriverInstallVersion(version, urlText, definition) + if err := validateDriverSelectedVersion(definition, selectedVersion); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } resolvedDir, err := resolveDriverDownloadDirectory(downloadDir) if err != nil { @@ -1424,6 +1431,11 @@ func resolveDriverVersionOptions(definition driverDefinition, repositoryURL stri if versionText == "" && urlText == "" { return } + if versionText != "" { + if err := validateDriverSelectedVersion(definition, versionText); err != nil { + return + } + } versionKey := normalizeVersion(versionText) key := "" if versionKey != "" { @@ -1550,6 +1562,16 @@ func resolveVersionedDriverOption(definition driverDefinition, version string, s if versionText == "" { return "", "", false } + if err := validateDriverSelectedVersion(definition, versionText); err != nil { + return "", "", false + } + + if publishedURL, ok := resolvePublishedDriverDownloadURL(definition, versionText); ok { + return versionText, publishedURL, true + } + if !optionalDriverSourceBuildAvailable(definition, versionText) { + return "", "", false + } urlText := strings.TrimSpace(definition.DefaultDownloadURL) if urlText == "" && effectiveDriverEngine(definition) == driverEngineGo { @@ -1580,6 +1602,97 @@ func sameDriverVersion(left, right string) bool { return a != "" && a == b } +func validateDriverSelectedVersion(definition driverDefinition, version string) error { + driverType := normalizeDriverType(definition.Type) + versionText := normalizeVersion(strings.TrimSpace(version)) + if driverType == "" || versionText == "" { + return nil + } + + switch driverType { + case "mongodb": + if strings.HasPrefix(versionText, "2.") { + return nil + } + if strings.HasPrefix(versionText, "1.17.") { + return nil + } + return fmt.Errorf("MongoDB 版本 %s 当前不受支持;仅支持 1.17.x 和 2.x", versionText) + default: + return nil + } +} + +func shouldRestrictToExplicitVersionArtifact(definition driverDefinition, selectedVersion string) bool { + versionText := normalizeVersion(strings.TrimSpace(selectedVersion)) + if versionText == "" { + return false + } + return !sameDriverVersion(versionText, definition.PinnedVersion) +} + +func optionalDriverSourceBuildAvailable(definition driverDefinition, selectedVersion string) bool { + driverType := normalizeDriverType(definition.Type) + if driverType == "" || !db.IsOptionalGoDriver(driverType) { + return false + } + if _, err := optionalDriverBuildTag(driverType, selectedVersion); err != nil { + return false + } + if _, err := exec.LookPath("go"); err != nil { + return false + } + if _, err := locateProjectRootForAgentBuild(); err != nil { + return false + } + return true +} + +func resolvePublishedDriverDownloadURL(definition driverDefinition, version string) (string, bool) { + driverType := normalizeDriverType(definition.Type) + versionText := normalizeVersion(strings.TrimSpace(version)) + if driverType == "" || versionText == "" { + return "", false + } + + tag := "v" + versionText + assetName, ok := resolvePublishedDriverReleaseAssetName(driverType, versionText, tag) + if !ok { + return "", false + } + return fmt.Sprintf("https://github.com/%s/releases/download/%s/%s", updateRepo, tag, assetName), true +} + +func resolvePublishedDriverReleaseAssetName(driverType string, version string, tag string) (string, bool) { + assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, version) + if len(assetNames) == 0 { + return "", false + } + + cacheKey := "tag:" + strings.TrimSpace(tag) + if sizeByAsset, ok := readReleaseAssetSizesFromCache(cacheKey); ok { + for _, assetName := range assetNames { + if sizeByAsset[assetName] > 0 { + return assetName, true + } + } + return "", false + } + + sizeByAsset, err := loadReleaseAssetSizesCached(cacheKey, func() (*githubRelease, error) { + return fetchReleaseByTag(tag) + }) + if err != nil { + return "", false + } + for _, assetName := range assetNames { + if sizeByAsset[assetName] > 0 { + return assetName, true + } + } + return "", false +} + func resolveDriverVersionPackageSizeBytes(definition driverDefinition, option driverVersionOptionItem) int64 { driverType := normalizeDriverType(definition.Type) if driverType == "" || definition.BuiltIn { @@ -1593,20 +1706,20 @@ func resolveDriverVersionPackageSizeBytes(definition driverDefinition, option dr if version == "" { return 0 } - assetName := optionalDriverReleaseAssetName(driverType) - if strings.TrimSpace(assetName) == "" { + assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, version) + if len(assetNames) == 0 { return 0 } tag := "v" + version if sizeByAsset, ok := readReleaseAssetSizesFromCache("tag:" + tag); ok { - return resolveOptionalDriverAssetSize(sizeByAsset, driverType) + return resolveOptionalDriverAssetSizeForVersion(sizeByAsset, driverType, version) } // 下拉版本列表要求快速返回:仅复用已有缓存,不在这里触发网络请求。 if strings.EqualFold(strings.TrimSpace(option.Source), "latest") { if sizeByAsset, ok := readReleaseAssetSizesFromCache("latest"); ok { - return resolveOptionalDriverAssetSize(sizeByAsset, driverType) + return resolveOptionalDriverAssetSizeForVersion(sizeByAsset, driverType, version) } } return 0 @@ -1906,19 +2019,23 @@ func resolveDriverVersionOptionsFromReleases(definition driverDefinition) []driv return nil } - assetName := optionalDriverReleaseAssetName(driverType) - assetNames := optionalDriverReleaseAssetNames(driverType) result := make([]driverVersionOptionItem, 0, len(releases)) for _, release := range releases { if release.Prerelease { continue } tag := strings.TrimSpace(release.TagName) - if tag == "" || !releaseContainsAnyAsset(release, assetNames) { + version := normalizeVersion(tag) + if tag == "" || version == "" { + continue + } + assetName := optionalDriverReleaseAssetNameForVersion(driverType, version) + assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, version) + if !releaseContainsAnyAsset(release, assetNames) { continue } result = append(result, driverVersionOptionItem{ - Version: normalizeVersion(tag), + Version: version, DownloadURL: fmt.Sprintf("https://github.com/%s/releases/download/%s/%s", updateRepo, tag, assetName), Source: "release", }) @@ -2791,9 +2908,10 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, executablePath string, downloadURL string, selectedVersion string) (string, string, error) { driverType := normalizeDriverType(definition.Type) displayName := resolveDriverDisplayName(definition) - forceSourceBuild := shouldForceSourceBuildForVersion(driverType, selectedVersion) + forceSourceBuild := shouldForceSourceBuildForResolvedDownload(driverType, selectedVersion, downloadURL) preferSourceBuildBeforeDownload := shouldPreferSourceBuildBeforeDownload(driverType, selectedVersion) skipReuseCandidate := shouldSkipReusableAgentCandidate(driverType, selectedVersion) + restrictToExplicitArtifact := shouldRestrictToExplicitVersionArtifact(definition, selectedVersion) info, err := os.Stat(executablePath) if err == nil && !info.IsDir() { @@ -2851,7 +2969,7 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut } if !forceSourceBuild { - downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL) + downloadURLs := resolveOptionalDriverAgentDownloadURLs(definition, downloadURL, selectedVersion) if len(downloadURLs) > 0 { for _, candidateURL := range downloadURLs { if a != nil { @@ -2865,7 +2983,7 @@ func ensureOptionalDriverAgentBinary(a *App, definition driverDefinition, execut } } bundleURLs := resolveOptionalDriverBundleDownloadURLs() - if len(bundleURLs) > 0 { + if !restrictToExplicitArtifact && len(bundleURLs) > 0 { for _, bundleURL := range bundleURLs { if a != nil { a.emitDriverDownloadProgress(driverType, "downloading", 20, 100, fmt.Sprintf("从驱动总包提取 %s 代理", displayName)) @@ -3108,6 +3226,23 @@ func shouldForceSourceBuildForVersion(driverType string, selectedVersion string) return resolveMongoDriverMajorFromVersion(selectedVersion) == 1 } +func shouldForceSourceBuildForResolvedDownload(driverType string, selectedVersion string, downloadURL string) bool { + if !shouldForceSourceBuildForVersion(driverType, selectedVersion) { + return false + } + + parsed, err := url.Parse(strings.TrimSpace(downloadURL)) + if err != nil || parsed == nil { + return true + } + switch strings.ToLower(strings.TrimSpace(parsed.Scheme)) { + case "http", "https": + return false + default: + return true + } +} + func shouldPreferSourceBuildBeforeDownload(driverType string, selectedVersion string) bool { _ = selectedVersion switch normalizeDriverType(driverType) { @@ -3224,11 +3359,80 @@ func optionalDriverReleaseAssetNameForType(typeName string, goos string, goarch return name } -func optionalDriverExecutableBaseNames(driverType string) []string { +func optionalDriverNameStemCandidates(driverType string, selectedVersion string) []string { + candidates := make([]string, 0, 3) + seen := make(map[string]struct{}, 3) + appendStem := func(stem string) { + trimmed := strings.TrimSpace(stem) + if trimmed == "" { + return + } + if _, ok := seen[trimmed]; ok { + return + } + seen[trimmed] = struct{}{} + candidates = append(candidates, trimmed) + } + + base := fmt.Sprintf("%s-driver-agent", optionalDriverPublicTypeName(driverType)) + if normalizeDriverType(driverType) == "mongodb" { + switch resolveMongoDriverMajorFromVersion(selectedVersion) { + case 1: + appendStem(base + "-v1") + appendStem(base) + case 2: + appendStem(base) + appendStem(base + "-v2") + default: + appendStem(base) + } + return candidates + } + + appendStem(base) + return candidates +} + +func optionalDriverExecutableBaseNamesForVersion(driverType string, selectedVersion string) []string { names := make([]string, 0, 2) seen := make(map[string]struct{}, 2) - appendName := func(typeName string) { - name := optionalDriverExecutableBaseNameForType(typeName) + appendName := func(stem string) { + name := strings.TrimSpace(stem) + if strings.TrimSpace(name) == "" { + return + } + if stdRuntime.GOOS == "windows" { + name += ".exe" + } + if _, ok := seen[name]; ok { + return + } + seen[name] = struct{}{} + names = append(names, name) + } + + for _, stem := range optionalDriverNameStemCandidates(driverType, selectedVersion) { + appendName(stem) + } + return names +} + +func optionalDriverExecutableBaseNames(driverType string) []string { + return optionalDriverExecutableBaseNamesForVersion(driverType, "") +} + +func optionalDriverReleaseAssetNamesForVersion(driverType string, selectedVersion string) []string { + names := make([]string, 0, 2) + seen := make(map[string]struct{}, 2) + appendName := func(stem string) { + trimmedStem := strings.TrimSpace(stem) + if trimmedStem == "" { + return + } + name := fmt.Sprintf("%s-%s-%s", trimmedStem, stdRuntime.GOOS, stdRuntime.GOARCH) + if strings.EqualFold(stdRuntime.GOOS, "windows") { + name += ".exe" + } if strings.TrimSpace(name) == "" { return } @@ -3239,27 +3443,14 @@ func optionalDriverExecutableBaseNames(driverType string) []string { names = append(names, name) } - appendName(optionalDriverPublicTypeName(driverType)) + for _, stem := range optionalDriverNameStemCandidates(driverType, selectedVersion) { + appendName(stem) + } return names } func optionalDriverReleaseAssetNames(driverType string) []string { - names := make([]string, 0, 2) - seen := make(map[string]struct{}, 2) - appendName := func(typeName string) { - name := optionalDriverReleaseAssetNameForType(typeName, stdRuntime.GOOS, stdRuntime.GOARCH) - if strings.TrimSpace(name) == "" { - return - } - if _, ok := seen[name]; ok { - return - } - seen[name] = struct{}{} - names = append(names, name) - } - - appendName(optionalDriverPublicTypeName(driverType)) - return names + return optionalDriverReleaseAssetNamesForVersion(driverType, "") } func optionalDriverExecutableBaseName(driverType string) string { @@ -3278,6 +3469,14 @@ func optionalDriverReleaseAssetName(driverType string) string { return names[0] } +func optionalDriverReleaseAssetNameForVersion(driverType string, selectedVersion string) string { + names := optionalDriverReleaseAssetNamesForVersion(driverType, selectedVersion) + if len(names) == 0 { + return optionalDriverReleaseAssetNameForType("", stdRuntime.GOOS, stdRuntime.GOARCH) + } + return names[0] +} + func optionalDriverBundlePlatformDir(goos string) string { switch strings.ToLower(strings.TrimSpace(goos)) { case "windows": @@ -3328,6 +3527,19 @@ func resolveOptionalDriverAssetSize(sizeByAsset map[string]int64, driverType str return 0 } +func resolveOptionalDriverAssetSizeForVersion(sizeByAsset map[string]int64, driverType string, version string) int64 { + if len(sizeByAsset) == 0 { + return 0 + } + for _, assetName := range optionalDriverReleaseAssetNamesForVersion(driverType, version) { + sizeBytes := sizeByAsset[assetName] + if sizeBytes > 0 { + return sizeBytes + } + } + return 0 +} + func resolveOptionalDriverBundleDownloadURLs() []string { candidates := make([]string, 0, 2) seen := make(map[string]struct{}, 2) @@ -3351,7 +3563,7 @@ func resolveOptionalDriverBundleDownloadURLs() []string { return candidates } -func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL string) []string { +func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL string, selectedVersion string) []string { driverType := normalizeDriverType(definition.Type) candidates := make([]string, 0, 3) seen := make(map[string]struct{}, 3) @@ -3373,6 +3585,9 @@ func resolveOptionalDriverAgentDownloadURLs(definition driverDefinition, rawURL appendURL(parsed.String()) } } + if shouldRestrictToExplicitVersionArtifact(definition, selectedVersion) { + return candidates + } assetNames := optionalDriverReleaseAssetNames(driverType) currentVersion := normalizeVersion(getCurrentVersion()) diff --git a/internal/app/methods_driver_version_test.go b/internal/app/methods_driver_version_test.go new file mode 100644 index 0000000..5fcfe34 --- /dev/null +++ b/internal/app/methods_driver_version_test.go @@ -0,0 +1,222 @@ +package app + +import ( + "fmt" + "os" + "runtime" + "strings" + "testing" + "time" +) + +func TestResolveVersionedDriverOptionUsesPublishedMongoV1Release(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + version := "1.17.4" + assetName := mongoVersionedReleaseAssetName(1) + seedReleaseAssetSizeCache(t, "tag:v"+version, map[string]int64{ + assetName: 24 << 20, + }) + chdirTemp(t) + + gotVersion, gotURL, ok := resolveVersionedDriverOption(definition, version, "history") + if !ok { + t.Fatal("expected published mongodb v1 option to remain available") + } + if gotVersion != version { + t.Fatalf("expected version %q, got %q", version, gotVersion) + } + + wantURL := fmt.Sprintf("https://github.com/%s/releases/download/v%s/%s", updateRepo, version, assetName) + if gotURL != wantURL { + t.Fatalf("expected published release URL %q, got %q", wantURL, gotURL) + } +} + +func TestDriverVersionSupportRangeForMongoDB(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + if err := validateDriverSelectedVersion(definition, "1.17.4"); err != nil { + t.Fatalf("expected 1.17.4 to stay supported, got %v", err) + } + if err := validateDriverSelectedVersion(definition, "2.5.0"); err != nil { + t.Fatalf("expected 2.5.0 to stay supported, got %v", err) + } + if err := validateDriverSelectedVersion(definition, "1.16.1"); err == nil { + t.Fatal("expected 1.16.1 to be rejected by MongoDB support range") + } +} + +func TestResolveVersionedDriverOptionSkipsMongoV1WithoutPublishedReleaseOrSourceBuild(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + version := "1.17.4" + seedReleaseAssetSizeCache(t, "tag:v"+version, map[string]int64{}) + chdirTemp(t) + + _, _, ok = resolveVersionedDriverOption(definition, version, "history") + if ok { + t.Fatal("expected unpublished mongodb v1 option to be filtered out when source build is unavailable") + } +} + +func TestResolveVersionedDriverOptionRejectsUnsupportedMongoV1Range(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + seedReleaseAssetSizeCache(t, "tag:v1.16.1", map[string]int64{ + mongoVersionedReleaseAssetName(1): 24 << 20, + }) + + _, _, ok = resolveVersionedDriverOption(definition, "1.16.1", "history") + if ok { + t.Fatal("expected MongoDB 1.16.1 to be hidden from the selectable version list") + } +} + +func TestResolveDriverVersionPackageSizeBytesReadsMongoV1VersionedAsset(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + version := "1.17.4" + assetName := mongoVersionedReleaseAssetName(1) + const wantSize int64 = 31 << 20 + seedReleaseAssetSizeCache(t, "tag:v"+version, map[string]int64{ + assetName: wantSize, + }) + + got := resolveDriverVersionPackageSizeBytes(definition, driverVersionOptionItem{ + Version: version, + Source: "history", + }) + if got != wantSize { + t.Fatalf("expected size %d, got %d", wantSize, got) + } +} + +func TestResolveOptionalDriverAgentDownloadURLsDoesNotFallbackForHistoricalVersion(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + explicitURL := fmt.Sprintf("https://github.com/Syngnat/GoNavi/releases/download/v1.17.4/%s", mongoVersionedReleaseAssetName(1)) + urls := resolveOptionalDriverAgentDownloadURLs( + definition, + explicitURL, + "1.17.4", + ) + if len(urls) != 1 { + t.Fatalf("expected only explicit historical URL, got %d candidates: %v", len(urls), urls) + } + if urls[0] != explicitURL { + t.Fatalf("unexpected historical URL candidate: %v", urls) + } +} + +func TestDownloadDriverPackageRejectsUnsupportedMongoVersion(t *testing.T) { + app := &App{} + + result := app.DownloadDriverPackage("mongodb", "1.16.1", "builtin://activate/mongodb?channel=history&version=1.16.1", t.TempDir()) + if result.Success { + t.Fatal("expected unsupported MongoDB 1.16.1 install to be rejected") + } + if !strings.Contains(result.Message, "仅支持 1.17.x 和 2.x") { + t.Fatalf("expected support-range error, got %q", result.Message) + } +} + +func TestShouldForceSourceBuildForResolvedDownload(t *testing.T) { + if !shouldForceSourceBuildForResolvedDownload("mongodb", "1.17.4", "builtin://activate/mongodb?channel=history&version=1.17.4") { + t.Fatal("expected mongodb v1 builtin install to keep source build mode") + } + + explicitURL := fmt.Sprintf("https://github.com/%s/releases/download/v1.17.4/%s", updateRepo, mongoVersionedReleaseAssetName(1)) + if shouldForceSourceBuildForResolvedDownload("mongodb", "1.17.4", explicitURL) { + t.Fatal("expected mongodb v1 published asset install to skip forced source build") + } + + if shouldForceSourceBuildForResolvedDownload("mongodb", "2.5.0", "builtin://activate/mongodb?channel=latest&version=2.5.0") { + t.Fatal("expected mongodb v2 install not to force source build") + } +} + +func seedReleaseAssetSizeCache(t *testing.T, cacheKey string, sizeByKey map[string]int64) { + t.Helper() + + driverReleaseSizeMu.Lock() + original := cloneReleaseAssetSizeCache(driverReleaseSizeMap) + driverReleaseSizeMap[cacheKey] = driverReleaseAssetSizeCacheEntry{ + LoadedAt: time.Now(), + SizeByKey: cloneInt64Map(sizeByKey), + } + driverReleaseSizeMu.Unlock() + + t.Cleanup(func() { + driverReleaseSizeMu.Lock() + driverReleaseSizeMap = original + driverReleaseSizeMu.Unlock() + }) +} + +func cloneReleaseAssetSizeCache(src map[string]driverReleaseAssetSizeCacheEntry) map[string]driverReleaseAssetSizeCacheEntry { + cloned := make(map[string]driverReleaseAssetSizeCacheEntry, len(src)) + for key, value := range src { + cloned[key] = driverReleaseAssetSizeCacheEntry{ + LoadedAt: value.LoadedAt, + SizeByKey: cloneInt64Map(value.SizeByKey), + Err: value.Err, + } + } + return cloned +} + +func cloneInt64Map(src map[string]int64) map[string]int64 { + if len(src) == 0 { + return map[string]int64{} + } + cloned := make(map[string]int64, len(src)) + for key, value := range src { + cloned[key] = value + } + return cloned +} + +func chdirTemp(t *testing.T) { + t.Helper() + + wd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd failed: %v", err) + } + tempDir := t.TempDir() + if err := os.Chdir(tempDir); err != nil { + t.Fatalf("chdir temp failed: %v", err) + } + t.Cleanup(func() { + if err := os.Chdir(wd); err != nil { + t.Fatalf("restore cwd failed: %v", err) + } + }) +} + +func mongoVersionedReleaseAssetName(major int) string { + name := fmt.Sprintf("mongodb-driver-agent-v%d-%s-%s", major, runtime.GOOS, runtime.GOARCH) + if runtime.GOOS == "windows" { + return name + ".exe" + } + return name +} From c1266c225a11ea9cea5ef7cab1936152807d3232 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Thu, 2 Apr 2026 21:17:52 +0800 Subject: [PATCH 03/14] =?UTF-8?q?=F0=9F=90=9B=20fix(ai/provider):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20Claude=20CLI=20=E5=9C=A8=20Windows=20?= =?UTF-8?q?=E4=B8=8A=E7=9A=84=E6=B5=8B=E8=AF=95=E7=A8=B3=E5=AE=9A=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/ai/provider/claude_cli_test.go | 31 +++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/internal/ai/provider/claude_cli_test.go b/internal/ai/provider/claude_cli_test.go index 773c8a2..8929841 100644 --- a/internal/ai/provider/claude_cli_test.go +++ b/internal/ai/provider/claude_cli_test.go @@ -3,8 +3,11 @@ package provider import ( "context" "errors" + "fmt" "os" + "os/exec" "path/filepath" + "runtime" "strings" "testing" "time" @@ -324,6 +327,26 @@ func TestClaudeCLIProvider_ChatStreamReportsApiRetryAuthenticationFailure(t *tes func writeFakeClaudeScript(t *testing.T, content string) string { t.Helper() dir := t.TempDir() + + if runtime.GOOS == "windows" { + bashPath, err := resolveClaudeCodeGitBashPath(os.Environ(), runtime.GOOS, exec.LookPath, fileExists) + if err != nil { + t.Fatalf("failed to resolve git bash for fake claude command: %v", err) + } + + scriptPath := filepath.Join(dir, "claude.sh") + if err := os.WriteFile(scriptPath, []byte(content), 0o755); err != nil { + t.Fatalf("failed to write fake claude shell script: %v", err) + } + + wrapperPath := filepath.Join(dir, "claude.cmd") + wrapper := fmt.Sprintf("@echo off\r\n\"%s\" \"%s\" %%*\r\n", bashPath, scriptPath) + if err := os.WriteFile(wrapperPath, []byte(wrapper), 0o755); err != nil { + t.Fatalf("failed to write fake claude wrapper: %v", err) + } + return wrapperPath + } + path := filepath.Join(dir, "claude") if err := os.WriteFile(path, []byte(content), 0o755); err != nil { t.Fatalf("failed to write fake claude script: %v", err) @@ -335,12 +358,19 @@ func overrideClaudeCLIForTest(t *testing.T, fakeClaudePath string) func() { t.Helper() originalLookPath := claudeLookPath + originalCommandContext := claudeCommandContext claudeLookPath = func(name string) (string, error) { if name == "claude" { return fakeClaudePath, nil } return originalLookPath(name) } + claudeCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + if name == "claude" { + return exec.CommandContext(ctx, fakeClaudePath, args...) + } + return originalCommandContext(ctx, name, args...) + } originalPath := os.Getenv("PATH") if err := os.Setenv("PATH", filepath.Dir(fakeClaudePath)+string(os.PathListSeparator)+originalPath); err != nil { @@ -349,6 +379,7 @@ func overrideClaudeCLIForTest(t *testing.T, fakeClaudePath string) func() { return func() { claudeLookPath = originalLookPath + claudeCommandContext = originalCommandContext _ = os.Setenv("PATH", originalPath) } } From ef64a24e013af6b094ac8405f03ee2a1edf9e74a Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:22:16 +0800 Subject: [PATCH 04/14] =?UTF-8?q?=E2=9C=A8=20feat(security):=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E5=AF=86=E9=92=A5=E5=AD=98=E5=82=A8=E5=9F=BA=E7=A1=80?= =?UTF-8?q?=E6=9E=B6=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 8 ++ go.sum | 17 ++++ internal/secretstore/keyring_store.go | 104 +++++++++++++++++++++ internal/secretstore/keyring_store_test.go | 102 ++++++++++++++++++++ internal/secretstore/store.go | 69 ++++++++++++++ 5 files changed, 300 insertions(+) create mode 100644 internal/secretstore/keyring_store.go create mode 100644 internal/secretstore/keyring_store_test.go create mode 100644 internal/secretstore/store.go diff --git a/go.mod b/go.mod index c215315..d2cd1fe 100644 --- a/go.mod +++ b/go.mod @@ -28,11 +28,14 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect + github.com/99designs/keyring v1.2.2 github.com/ClickHouse/ch-go v0.71.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apache/arrow-go/v18 v18.5.1 // indirect github.com/bep/debounce v1.2.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/danieljoos/wincred v1.1.2 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/duckdb/duckdb-go-bindings v0.3.3 // indirect github.com/duckdb/duckdb-go-bindings/lib/darwin-amd64 v0.3.3 // indirect @@ -41,17 +44,20 @@ require ( github.com/duckdb/duckdb-go-bindings/lib/linux-arm64 v0.3.3 // indirect github.com/duckdb/duckdb-go-bindings/lib/windows-amd64 v0.3.3 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/dvsekhvalnov/jose2go v1.5.0 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.7.1 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/snappy v1.0.0 // indirect github.com/google/flatbuffers v25.12.19+incompatible // indirect github.com/gorilla/websocket v1.5.3 // indirect + github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/hashicorp/go-version v1.8.0 // indirect github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -68,6 +74,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/montanaflynn/stats v0.7.1 // indirect + github.com/mtibben/percent v0.2.1 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/paulmach/orb v0.12.0 // indirect github.com/pierrec/lz4/v4 v4.1.25 // indirect @@ -100,6 +107,7 @@ require ( golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/telemetry v0.0.0-20260116145544-c6413dc483f5 // indirect + golang.org/x/term v0.39.0 // indirect golang.org/x/tools v0.41.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect modernc.org/libc v1.67.6 // indirect diff --git a/go.sum b/go.sum index 6579c99..0bd72f3 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,10 @@ gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3 h1:QjslQNaH5Nuap5i4ni gitea.com/kingbase/gokb v0.0.0-20201021123113-29bd62a876c3/go.mod h1:7lH5A1jzCXD9Nl16DzaBUOfDAT8NPrDmZwKu1p5wf94= gitee.com/chunanyong/dm v1.8.22 h1:H7fsrnUIvEA0jlDWew7vwELry1ff+tLMIu2Fk2cIBSg= gitee.com/chunanyong/dm v1.8.22/go.mod h1:EPRJnuPFgbyOFgJ0TRYCTGzhq+ZT4wdyaj/GW/LLcNg= +github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= +github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= +github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0= +github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTBqhFkHUrPk= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4= @@ -34,6 +38,8 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= +github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -56,6 +62,8 @@ github.com/duckdb/duckdb-go/v2 v2.5.5 h1:TlK8ipnzoKW2aNrjGqRkFWLCDpJDxR/VwH8ezEc github.com/duckdb/duckdb-go/v2 v2.5.5/go.mod h1:6uIbC3gz36NCEygECzboygOo/Z9TeVwox/puG+ohWV0= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/dvsekhvalnov/jose2go v1.5.0 h1:3j8ya4Z4kMCwT5nXIKFSV84YS+HdqSSO0VsTQxaLAeM= +github.com/dvsekhvalnov/jose2go v1.5.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw= github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw= github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg= @@ -68,6 +76,8 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0= +github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= @@ -95,6 +105,8 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= +github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4= github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= @@ -158,6 +170,8 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs= +github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= @@ -201,6 +215,7 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -300,6 +315,7 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -342,6 +358,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/secretstore/keyring_store.go b/internal/secretstore/keyring_store.go new file mode 100644 index 0000000..93fe0bc --- /dev/null +++ b/internal/secretstore/keyring_store.go @@ -0,0 +1,104 @@ +package secretstore + +import ( + "errors" + "fmt" + "runtime" + + "github.com/99designs/keyring" +) + +type keyringClient interface { + Get(key string) (keyring.Item, error) + Set(item keyring.Item) error + Remove(key string) error +} + +type keyringStore struct { + ring keyringClient +} + +type keyringOpener func(cfg keyring.Config) (keyring.Keyring, error) + +func NewKeyringStore() SecretStore { + return newKeyringStoreWithOpener(runtime.GOOS, keyring.Open) +} + +func newKeyringStoreWithOpener(goos string, open keyringOpener) SecretStore { + cfg, err := keyringConfigFor(goos) + if err != nil { + return NewUnavailableStore(err.Error()) + } + + ring, err := open(cfg) + if err != nil { + return NewUnavailableStore(err.Error()) + } + + return &keyringStore{ring: ring} +} + +func (s *keyringStore) Put(ref string, payload []byte) error { + return wrapKeyringError(s.ring.Set(keyring.Item{Key: ref, Data: payload})) +} + +func (s *keyringStore) Get(ref string) ([]byte, error) { + item, err := s.ring.Get(ref) + if err != nil { + return nil, wrapKeyringError(err) + } + return item.Data, nil +} + +func (s *keyringStore) Delete(ref string) error { + return wrapKeyringError(s.ring.Remove(ref)) +} + +func (s *keyringStore) HealthCheck() error { + _, err := s.ring.Get(healthCheckRef) + if err == nil || errors.Is(err, keyring.ErrKeyNotFound) { + return nil + } + return wrapKeyringError(err) +} + +func wrapKeyringError(err error) error { + if err == nil || errors.Is(err, keyring.ErrKeyNotFound) || IsUnavailable(err) { + return err + } + return &UnavailableError{Reason: err.Error()} +} + +func keyringConfigFor(goos string) (keyring.Config, error) { + backends := allowedBackendsFor(goos) + if len(backends) == 0 { + return keyring.Config{}, fmt.Errorf("unsupported keyring platform: %s", goos) + } + + return keyring.Config{ + ServiceName: serviceName, + AllowedBackends: backends, + KeychainTrustApplication: true, + KeychainAccessibleWhenUnlocked: true, + LibSecretCollectionName: "default", + KeyCtlScope: "user", + WinCredPrefix: serviceName, + }, nil +} + +func allowedBackendsFor(goos string) []keyring.BackendType { + switch goos { + case "windows": + return []keyring.BackendType{keyring.WinCredBackend} + case "darwin": + return []keyring.BackendType{keyring.KeychainBackend} + case "linux": + return []keyring.BackendType{ + keyring.SecretServiceBackend, + keyring.KWalletBackend, + keyring.KeyCtlBackend, + } + default: + return nil + } +} diff --git a/internal/secretstore/keyring_store_test.go b/internal/secretstore/keyring_store_test.go new file mode 100644 index 0000000..4d387f0 --- /dev/null +++ b/internal/secretstore/keyring_store_test.go @@ -0,0 +1,102 @@ +package secretstore + +import ( + "errors" + "testing" + + "github.com/99designs/keyring" +) + +func TestBuildRefRejectsEmptyKind(t *testing.T) { + t.Parallel() + + if _, err := BuildRef("", "secret-id"); err == nil { + t.Fatal("BuildRef should reject an empty kind") + } +} + +func TestBuildRefRejectsEmptyID(t *testing.T) { + t.Parallel() + + if _, err := BuildRef("database", ""); err == nil { + t.Fatal("BuildRef should reject an empty id") + } +} + +func TestUnavailableStoreHealthCheckReturnsUnavailableError(t *testing.T) { + t.Parallel() + + store := NewUnavailableStore("keyring backend disabled") + + err := store.HealthCheck() + if err == nil { + t.Fatal("HealthCheck should return an unavailable error") + } + + if !IsUnavailable(err) { + t.Fatalf("HealthCheck error should be detected by IsUnavailable, got %T", err) + } +} + +func TestKeyringStoreHealthCheckTreatsMissingProbeItemAsHealthy(t *testing.T) { + t.Parallel() + + store := &keyringStore{ring: fakeKeyringClient{getErr: keyring.ErrKeyNotFound}} + if err := store.HealthCheck(); err != nil { + t.Fatalf("HealthCheck should accept ErrKeyNotFound, got %v", err) + } +} + +func TestKeyringStoreHealthCheckReturnsUnavailableErrorOnBackendFailure(t *testing.T) { + t.Parallel() + + store := &keyringStore{ring: fakeKeyringClient{getErr: errors.New("backend offline")}} + if err := store.HealthCheck(); err == nil || !IsUnavailable(err) { + t.Fatalf("HealthCheck should wrap backend failures as unavailable, got %v", err) + } +} + +func TestNewKeyringStoreReturnsUnavailableStoreWhenOpenFails(t *testing.T) { + t.Parallel() + + store := newKeyringStoreWithOpener("windows", func(cfg keyring.Config) (keyring.Keyring, error) { + if len(cfg.AllowedBackends) != 1 || cfg.AllowedBackends[0] != keyring.WinCredBackend { + t.Fatalf("unexpected backend config: %#v", cfg.AllowedBackends) + } + return nil, errors.New("no backend") + }) + + if err := store.HealthCheck(); err == nil || !IsUnavailable(err) { + t.Fatalf("expected unavailable store when open fails, got %v", err) + } +} + +type fakeKeyringClient struct { + getErr error + item keyring.Item + removeErr error +} + +func (f fakeKeyringClient) Get(string) (keyring.Item, error) { + if f.getErr != nil { + return keyring.Item{}, f.getErr + } + return f.item, nil +} + +func (f fakeKeyringClient) Set(item keyring.Item) error { + _ = item + return nil +} + +func (f fakeKeyringClient) Remove(string) error { + return f.removeErr +} + +func (f fakeKeyringClient) GetMetadata(string) (keyring.Metadata, error) { + return keyring.Metadata{}, nil +} + +func (f fakeKeyringClient) Keys() ([]string, error) { + return nil, nil +} diff --git a/internal/secretstore/store.go b/internal/secretstore/store.go new file mode 100644 index 0000000..7716c58 --- /dev/null +++ b/internal/secretstore/store.go @@ -0,0 +1,69 @@ +package secretstore + +import ( + "errors" + "fmt" + "strings" +) + +const ( + serviceName = "gonavi" + healthCheckRef = "oskeyring://gonavi/healthcheck/ping" +) + +type SecretStore interface { + Put(ref string, payload []byte) error + Get(ref string) ([]byte, error) + Delete(ref string) error + HealthCheck() error +} + +type UnavailableError struct { + Reason string +} + +func (e *UnavailableError) Error() string { + reason := strings.TrimSpace(e.Reason) + if reason == "" { + return "secret store unavailable" + } + return fmt.Sprintf("secret store unavailable: %s", reason) +} + +func IsUnavailable(err error) bool { + var target *UnavailableError + return errors.As(err, &target) +} + +type unavailableStore struct { + err error +} + +func NewUnavailableStore(reason string) SecretStore { + return unavailableStore{err: &UnavailableError{Reason: strings.TrimSpace(reason)}} +} + +func (s unavailableStore) Put(string, []byte) error { + return s.err +} + +func (s unavailableStore) Get(string) ([]byte, error) { + return nil, s.err +} + +func (s unavailableStore) Delete(string) error { + return s.err +} + +func (s unavailableStore) HealthCheck() error { + return s.err +} + +func BuildRef(kind, id string) (string, error) { + kind = strings.TrimSpace(kind) + id = strings.TrimSpace(id) + if kind == "" || id == "" { + return "", fmt.Errorf("invalid secret ref") + } + return fmt.Sprintf("oskeyring://%s/%s/%s", serviceName, kind, id), nil +} From f74270d58502a689e2f34e777c2903190a916a5d Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:33:18 +0800 Subject: [PATCH 05/14] =?UTF-8?q?=F0=9F=90=9B=20fix(security):=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E5=AF=86=E9=92=A5=E5=AD=98=E5=82=A8=E7=8A=B6=E6=80=81?= =?UTF-8?q?=E6=9E=9A=E4=B8=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/secretstore/keyring_store_test.go | 11 +++++++++++ internal/secretstore/store.go | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/internal/secretstore/keyring_store_test.go b/internal/secretstore/keyring_store_test.go index 4d387f0..03fc49f 100644 --- a/internal/secretstore/keyring_store_test.go +++ b/internal/secretstore/keyring_store_test.go @@ -7,6 +7,17 @@ import ( "github.com/99designs/keyring" ) +func TestStoreStatusValuesRemainStable(t *testing.T) { + t.Parallel() + + if StatusAvailable != "available" { + t.Fatalf("expected StatusAvailable to remain stable, got %q", StatusAvailable) + } + if StatusUnavailable != "unavailable" { + t.Fatalf("expected StatusUnavailable to remain stable, got %q", StatusUnavailable) + } +} + func TestBuildRefRejectsEmptyKind(t *testing.T) { t.Parallel() diff --git a/internal/secretstore/store.go b/internal/secretstore/store.go index 7716c58..2d6b3e2 100644 --- a/internal/secretstore/store.go +++ b/internal/secretstore/store.go @@ -18,6 +18,13 @@ type SecretStore interface { HealthCheck() error } +type StoreStatus string + +const ( + StatusAvailable StoreStatus = "available" + StatusUnavailable StoreStatus = "unavailable" +) + type UnavailableError struct { Reason string } From b62d22395b5031e49dd40ac45487e758119ca2b3 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 3 Apr 2026 00:18:06 +0800 Subject: [PATCH 06/14] =?UTF-8?q?=E2=9C=A8=20feat(security):=20=E6=8B=86?= =?UTF-8?q?=E5=88=86=20AI=20=E4=BE=9B=E5=BA=94=E5=95=86=E5=85=83=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E4=B8=8E=E5=AF=86=E9=92=A5=E5=AD=98=E5=82=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/ai/service/provider_secret.go | 231 +++++++++++++ internal/ai/service/provider_secret_test.go | 348 ++++++++++++++++++++ internal/ai/service/service.go | 174 ++++++++-- internal/ai/types.go | 2 + 4 files changed, 728 insertions(+), 27 deletions(-) create mode 100644 internal/ai/service/provider_secret.go create mode 100644 internal/ai/service/provider_secret_test.go diff --git a/internal/ai/service/provider_secret.go b/internal/ai/service/provider_secret.go new file mode 100644 index 0000000..6fe22bc --- /dev/null +++ b/internal/ai/service/provider_secret.go @@ -0,0 +1,231 @@ +package aiservice + +import ( + "encoding/json" + "fmt" + "strings" + "unicode" + + "GoNavi-Wails/internal/ai" + "GoNavi-Wails/internal/secretstore" +) + +const providerSecretKind = "ai-provider" + +type providerSecretBundle struct { + APIKey string `json:"apiKey,omitempty"` + SensitiveHeaders map[string]string `json:"sensitiveHeaders,omitempty"` +} + +func (b providerSecretBundle) hasAny() bool { + return strings.TrimSpace(b.APIKey) != "" || len(b.SensitiveHeaders) > 0 +} + +func mergeProviderSecretBundles(base, overlay providerSecretBundle) providerSecretBundle { + merged := providerSecretBundle{ + APIKey: base.APIKey, + SensitiveHeaders: cloneStringMap(base.SensitiveHeaders), + } + if strings.TrimSpace(overlay.APIKey) != "" { + merged.APIKey = overlay.APIKey + } + for key, value := range overlay.SensitiveHeaders { + if merged.SensitiveHeaders == nil { + merged.SensitiveHeaders = make(map[string]string, len(overlay.SensitiveHeaders)) + } + merged.SensitiveHeaders[key] = value + } + if len(merged.SensitiveHeaders) == 0 { + merged.SensitiveHeaders = nil + } + return merged +} + +func splitProviderSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, providerSecretBundle) { + meta := cfg + meta.APIKey = "" + + bundle := providerSecretBundle{} + if apiKey := strings.TrimSpace(cfg.APIKey); apiKey != "" { + bundle.APIKey = apiKey + } + + if len(cfg.Headers) > 0 { + safeHeaders := make(map[string]string, len(cfg.Headers)) + sensitiveHeaders := make(map[string]string) + for key, value := range cfg.Headers { + if isSensitiveProviderHeader(key) { + if strings.TrimSpace(value) != "" { + sensitiveHeaders[key] = value + } + continue + } + safeHeaders[key] = value + } + if len(safeHeaders) > 0 { + meta.Headers = safeHeaders + } else { + meta.Headers = nil + } + if len(sensitiveHeaders) > 0 { + bundle.SensitiveHeaders = sensitiveHeaders + } + } else { + meta.Headers = nil + } + + meta.HasSecret = cfg.HasSecret || bundle.hasAny() + meta.SecretRef = strings.TrimSpace(cfg.SecretRef) + if meta.HasSecret && meta.SecretRef == "" && strings.TrimSpace(cfg.ID) != "" { + if ref, err := secretstore.BuildRef(providerSecretKind, cfg.ID); err == nil { + meta.SecretRef = ref + } + } + if !meta.HasSecret { + meta.SecretRef = "" + } + + return meta, bundle +} + +func mergeProviderSecrets(cfg ai.ProviderConfig, bundle providerSecretBundle) ai.ProviderConfig { + merged := cfg + merged.APIKey = bundle.APIKey + + headers := cloneStringMap(cfg.Headers) + if len(bundle.SensitiveHeaders) > 0 { + if headers == nil { + headers = make(map[string]string, len(bundle.SensitiveHeaders)) + } + for key, value := range bundle.SensitiveHeaders { + headers[key] = value + } + } + if len(headers) > 0 { + merged.Headers = headers + } else { + merged.Headers = nil + } + + merged.HasSecret = cfg.HasSecret || bundle.hasAny() + if merged.HasSecret && strings.TrimSpace(merged.SecretRef) == "" && strings.TrimSpace(merged.ID) != "" { + if ref, err := secretstore.BuildRef(providerSecretKind, merged.ID); err == nil { + merged.SecretRef = ref + } + } + if !merged.HasSecret { + merged.SecretRef = "" + } + + return merged +} + +func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) { + meta, _ = splitProviderSecrets(meta) + if !bundle.hasAny() { + meta.HasSecret = false + meta.SecretRef = "" + return meta, nil + } + if s.secretStore == nil { + return meta, fmt.Errorf("secret store unavailable") + } + if err := s.secretStore.HealthCheck(); err != nil { + return meta, err + } + + ref := strings.TrimSpace(meta.SecretRef) + if ref == "" { + var err error + ref, err = secretstore.BuildRef(providerSecretKind, meta.ID) + if err != nil { + return meta, err + } + } + + payload, err := json.Marshal(bundle) + if err != nil { + return meta, fmt.Errorf("序列化 provider secret bundle 失败: %w", err) + } + if err := s.secretStore.Put(ref, payload); err != nil { + return meta, err + } + + meta.SecretRef = ref + meta.HasSecret = true + return meta, nil +} + +func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, error) { + cfg = normalizeProviderConfig(cfg) + meta, bundle := splitProviderSecrets(cfg) + if bundle.hasAny() { + return mergeProviderSecrets(meta, bundle), nil + } + if !meta.HasSecret { + return meta, nil + } + if s.secretStore == nil { + return meta, fmt.Errorf("secret store unavailable") + } + + ref := strings.TrimSpace(meta.SecretRef) + if ref == "" { + var err error + ref, err = secretstore.BuildRef(providerSecretKind, meta.ID) + if err != nil { + return meta, err + } + meta.SecretRef = ref + } + + payload, err := s.secretStore.Get(ref) + if err != nil { + return meta, err + } + + var stored providerSecretBundle + if err := json.Unmarshal(payload, &stored); err != nil { + return meta, fmt.Errorf("解析 provider secret bundle 失败: %w", err) + } + return mergeProviderSecrets(meta, stored), nil +} + +func providerMetadataView(cfg ai.ProviderConfig) ai.ProviderConfig { + meta, _ := splitProviderSecrets(normalizeProviderConfig(cfg)) + return meta +} + +func isSensitiveProviderHeader(name string) bool { + normalized := strings.TrimSpace(strings.ToLower(name)) + switch normalized { + case "authorization", "proxy-authorization", "x-api-key", "api-key": + return true + } + + for _, token := range providerHeaderTokens(normalized) { + switch token { + case "auth", "authorization", "token", "secret", "key", "apikey": + return true + } + } + + return false +} + +func providerHeaderTokens(name string) []string { + return strings.FieldsFunc(name, func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsDigit(r) + }) +} + +func cloneStringMap(input map[string]string) map[string]string { + if len(input) == 0 { + return nil + } + cloned := make(map[string]string, len(input)) + for key, value := range input { + cloned[key] = value + } + return cloned +} diff --git a/internal/ai/service/provider_secret_test.go b/internal/ai/service/provider_secret_test.go new file mode 100644 index 0000000..033b24f --- /dev/null +++ b/internal/ai/service/provider_secret_test.go @@ -0,0 +1,348 @@ +package aiservice + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "GoNavi-Wails/internal/ai" + "GoNavi-Wails/internal/secretstore" +) + +func TestSplitProviderSecretsStripsAPIKeyAndSensitiveHeaders(t *testing.T) { + input := ai.ProviderConfig{ + ID: "openai-main", + APIKey: "sk-test", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "db", + }, + } + + meta, bundle := splitProviderSecrets(input) + if meta.APIKey != "" { + t.Fatal("apiKey should not stay in metadata") + } + if meta.Headers["Authorization"] != "" { + t.Fatal("sensitive header should not stay in metadata") + } + if meta.Headers["X-Team"] != "db" { + t.Fatal("non-sensitive header should stay in metadata") + } + if bundle.APIKey != "sk-test" { + t.Fatal("bundle should keep apiKey") + } + if bundle.SensitiveHeaders["Authorization"] != "Bearer test" { + t.Fatal("bundle should keep sensitive header") + } +} + +func TestResolveProviderConfigSecretsRestoresStoredSecretBundle(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + ref, err := secretstore.BuildRef("ai-provider", "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + payload, err := json.Marshal(providerSecretBundle{ + APIKey: "sk-test", + SensitiveHeaders: map[string]string{ + "Authorization": "Bearer test", + }, + }) + if err != nil { + t.Fatalf("Marshal returned error: %v", err) + } + if err := store.Put(ref, payload); err != nil { + t.Fatalf("Put returned error: %v", err) + } + + resolved, err := service.resolveProviderConfigSecrets(ai.ProviderConfig{ + ID: "openai-main", + SecretRef: ref, + HasSecret: true, + Headers: map[string]string{ + "X-Team": "db", + }, + }) + if err != nil { + t.Fatalf("resolveProviderConfigSecrets returned error: %v", err) + } + if resolved.APIKey != "sk-test" { + t.Fatalf("expected restored apiKey, got %q", resolved.APIKey) + } + if resolved.Headers["Authorization"] != "Bearer test" { + t.Fatalf("expected restored sensitive header, got %#v", resolved.Headers) + } + if resolved.Headers["X-Team"] != "db" { + t.Fatalf("expected non-sensitive header to survive, got %#v", resolved.Headers) + } +} + +func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + legacy := aiConfig{ + Providers: []ai.ProviderConfig{ + { + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-test", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "db", + }, + }, + }, + } + data, err := json.MarshalIndent(legacy, "", " ") + if err != nil { + t.Fatalf("MarshalIndent returned error: %v", err) + } + configPath := filepath.Join(service.configDir, "ai_config.json") + if err := os.WriteFile(configPath, data, 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + service.loadConfig() + + providers := service.AIGetProviders() + if len(providers) != 1 { + t.Fatalf("expected 1 provider, got %d", len(providers)) + } + if providers[0].APIKey != "" { + t.Fatalf("expected migrated provider to be secretless, got %q", providers[0].APIKey) + } + if !providers[0].HasSecret { + t.Fatal("expected migrated provider to report HasSecret=true") + } + stored, err := store.Get(providers[0].SecretRef) + if err != nil { + t.Fatalf("expected secret bundle in store, got error: %v", err) + } + var bundle providerSecretBundle + if err := json.Unmarshal(stored, &bundle); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if bundle.APIKey != "sk-test" { + t.Fatalf("expected migrated apiKey in store, got %q", bundle.APIKey) + } + if bundle.SensitiveHeaders["Authorization"] != "Bearer test" { + t.Fatalf("expected migrated sensitive header in store, got %#v", bundle.SensitiveHeaders) + } + + rewritten, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + text := string(rewritten) + if strings.Contains(text, "sk-test") { + t.Fatalf("expected rewritten config to remove api key, got %s", text) + } + if strings.Contains(text, "Bearer test") { + t.Fatalf("expected rewritten config to remove sensitive header, got %s", text) + } +} + +func TestAISaveProviderPersistsSecretlessConfigAndReturnsSecretlessView(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-test", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "db", + }, + }) + if err != nil { + t.Fatalf("AISaveProvider returned error: %v", err) + } + + providers := service.AIGetProviders() + if len(providers) != 1 { + t.Fatalf("expected 1 provider, got %d", len(providers)) + } + if providers[0].APIKey != "" { + t.Fatalf("expected secretless provider view, got %q", providers[0].APIKey) + } + if !providers[0].HasSecret { + t.Fatal("expected saved provider view to report HasSecret=true") + } + if providers[0].Headers["Authorization"] != "" { + t.Fatalf("expected secretless provider headers, got %#v", providers[0].Headers) + } + if service.providers[0].APIKey != "sk-test" { + t.Fatalf("expected runtime provider to keep apiKey, got %q", service.providers[0].APIKey) + } + if service.providers[0].Headers["Authorization"] != "Bearer test" { + t.Fatalf("expected runtime provider to keep sensitive header, got %#v", service.providers[0].Headers) + } + + configPath := filepath.Join(service.configDir, "ai_config.json") + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + text := string(data) + if strings.Contains(text, "sk-test") { + t.Fatalf("expected config file to be secretless, got %s", text) + } + if strings.Contains(text, "Bearer test") { + t.Fatalf("expected config file to remove sensitive headers, got %s", text) + } +} + +func TestAISaveProviderKeepsExistingSecretWhenInputOmitsAPIKey(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-original", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer original", + "X-Team": "db", + }, + }); err != nil { + t.Fatalf("initial AISaveProvider returned error: %v", err) + } + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI Updated", + BaseURL: "https://gateway.openai.com/v1", + HasSecret: true, + Headers: map[string]string{ + "X-Team": "platform", + }, + }); err != nil { + t.Fatalf("update AISaveProvider returned error: %v", err) + } + + if service.providers[0].APIKey != "sk-original" { + t.Fatalf("expected runtime provider to keep original apiKey, got %q", service.providers[0].APIKey) + } + if service.providers[0].Headers["Authorization"] != "Bearer original" { + t.Fatalf("expected runtime provider to keep original sensitive header, got %#v", service.providers[0].Headers) + } + if service.providers[0].Headers["X-Team"] != "platform" { + t.Fatalf("expected runtime provider to update non-sensitive headers, got %#v", service.providers[0].Headers) + } + if service.providers[0].BaseURL != "https://gateway.openai.com/v1" { + t.Fatalf("expected runtime provider to update metadata, got %q", service.providers[0].BaseURL) + } + + providers := service.AIGetProviders() + if len(providers) != 1 || !providers[0].HasSecret { + t.Fatalf("expected provider view to keep HasSecret=true, got %#v", providers) + } + if providers[0].APIKey != "" { + t.Fatalf("expected provider view to stay secretless, got %q", providers[0].APIKey) + } +} + +func TestAISaveProviderMergesStoredSensitiveHeadersWhenUpdatingOnlyAPIKey(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-original", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer original", + "X-Team": "db", + }, + }); err != nil { + t.Fatalf("initial AISaveProvider returned error: %v", err) + } + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-updated", + HasSecret: true, + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "X-Team": "db", + }, + }); err != nil { + t.Fatalf("update AISaveProvider returned error: %v", err) + } + + if service.providers[0].APIKey != "sk-updated" { + t.Fatalf("expected updated apiKey, got %q", service.providers[0].APIKey) + } + if service.providers[0].Headers["Authorization"] != "Bearer original" { + t.Fatalf("expected existing sensitive header to be kept, got %#v", service.providers[0].Headers) + } + + stored, err := store.Get(service.providers[0].SecretRef) + if err != nil { + t.Fatalf("expected merged secret bundle in store, got %v", err) + } + var bundle providerSecretBundle + if err := json.Unmarshal(stored, &bundle); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if bundle.APIKey != "sk-updated" { + t.Fatalf("expected store to keep updated apiKey, got %q", bundle.APIKey) + } + if bundle.SensitiveHeaders["Authorization"] != "Bearer original" { + t.Fatalf("expected store to keep existing sensitive header, got %#v", bundle.SensitiveHeaders) + } +} + +type fakeProviderSecretStore struct { + items map[string][]byte +} + +func newFakeProviderSecretStore() *fakeProviderSecretStore { + return &fakeProviderSecretStore{items: make(map[string][]byte)} +} + +func (s *fakeProviderSecretStore) Put(ref string, payload []byte) error { + s.items[ref] = append([]byte(nil), payload...) + return nil +} + +func (s *fakeProviderSecretStore) Get(ref string) ([]byte, error) { + payload, ok := s.items[ref] + if !ok { + return nil, os.ErrNotExist + } + return append([]byte(nil), payload...), nil +} + +func (s *fakeProviderSecretStore) Delete(ref string) error { + delete(s.items, ref) + return nil +} + +func (s *fakeProviderSecretStore) HealthCheck() error { + return nil +} + +var _ secretstore.SecretStore = (*fakeProviderSecretStore)(nil) diff --git a/internal/ai/service/service.go b/internal/ai/service/service.go index 6897820..a5bd44e 100644 --- a/internal/ai/service/service.go +++ b/internal/ai/service/service.go @@ -18,6 +18,7 @@ import ( "GoNavi-Wails/internal/ai/provider" "GoNavi-Wails/internal/ai/safety" "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/secretstore" "github.com/google/uuid" wailsRuntime "github.com/wailsapp/wails/v2/pkg/runtime" @@ -32,7 +33,8 @@ type Service struct { safetyLevel ai.SQLPermissionLevel contextLevel ai.ContextLevel guard *safety.Guard - configDir string // 配置存储目录 + configDir string // 配置存储目录 + secretStore secretstore.SecretStore cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数 } @@ -97,11 +99,19 @@ var claudeCLIHealthCheckFunc = func(config ai.ProviderConfig) error { // NewService 创建 AI Service 实例 func NewService() *Service { + return NewServiceWithSecretStore(secretstore.NewKeyringStore()) +} + +func NewServiceWithSecretStore(store secretstore.SecretStore) *Service { + if store == nil { + store = secretstore.NewUnavailableStore("secret store unavailable") + } return &Service{ providers: make([]ai.ProviderConfig, 0), safetyLevel: ai.PermissionReadOnly, contextLevel: ai.ContextSchemaOnly, guard: safety.NewGuard(ai.PermissionReadOnly), + secretStore: store, cancelFuncs: make(map[string]context.CancelFunc), } } @@ -127,35 +137,80 @@ func (s *Service) AIGetProviders() []ai.ProviderConfig { defer s.mu.RUnlock() result := make([]ai.ProviderConfig, len(s.providers)) - copy(result, s.providers) - for i := range result { - result[i] = normalizeProviderConfig(result[i]) + for i := range s.providers { + result[i] = providerMetadataView(s.providers[i]) } return result } // AISaveProvider 保存/更新 Provider 配置 func (s *Service) AISaveProvider(config ai.ProviderConfig) error { - fmt.Printf("[AISaveProvider DEBUG] ID: %s, Model: %s\n", config.ID, config.Model) s.mu.Lock() defer s.mu.Unlock() config = normalizeProviderConfig(config) - if strings.TrimSpace(config.ID) == "" { config.ID = "provider-" + uuid.New().String()[:8] } + var existing ai.ProviderConfig found := false - for i, p := range s.providers { - if p.ID == config.ID { - s.providers[i] = config + for _, providerConfig := range s.providers { + if providerConfig.ID == config.ID { + existing = providerConfig found = true break } } - if !found { - s.providers = append(s.providers, config) + + meta, bundle := splitProviderSecrets(config) + var runtimeConfig ai.ProviderConfig + switch { + case bundle.hasAny(): + mergedBundle := bundle + if found && existing.HasSecret { + _, existingBundle := splitProviderSecrets(existing) + mergedBundle = mergeProviderSecretBundles(existingBundle, bundle) + } + if found && strings.TrimSpace(meta.SecretRef) == "" { + meta.SecretRef = existing.SecretRef + } + storedMeta, err := s.persistProviderSecretBundle(meta, mergedBundle) + if err != nil { + return fmt.Errorf("保存 Provider secret 失败: %w", err) + } + runtimeConfig = mergeProviderSecrets(storedMeta, mergedBundle) + case found && (config.HasSecret || existing.HasSecret): + meta.SecretRef = existing.SecretRef + meta.HasSecret = config.HasSecret || existing.HasSecret + resolved, err := s.resolveProviderConfigSecrets(meta) + if err != nil { + return fmt.Errorf("读取已保存 Provider secret 失败: %w", err) + } + runtimeConfig = resolved + default: + runtimeConfig = meta + } + + if !runtimeConfig.HasSecret && found && strings.TrimSpace(existing.SecretRef) != "" { + if err := s.secretStore.Delete(existing.SecretRef); err != nil { + return fmt.Errorf("删除 Provider secret 失败: %w", err) + } + } + if !runtimeConfig.HasSecret { + runtimeConfig.SecretRef = "" + } + + runtimeConfig = normalizeProviderConfig(runtimeConfig) + if found { + for i := range s.providers { + if s.providers[i].ID == runtimeConfig.ID { + s.providers[i] = runtimeConfig + break + } + } + } else { + s.providers = append(s.providers, runtimeConfig) } return s.saveConfig() @@ -167,9 +222,19 @@ func (s *Service) AIDeleteProvider(id string) error { defer s.mu.Unlock() newProviders := make([]ai.ProviderConfig, 0, len(s.providers)) - for _, p := range s.providers { - if p.ID != id { - newProviders = append(newProviders, p) + var removed ai.ProviderConfig + removedFound := false + for _, providerConfig := range s.providers { + if providerConfig.ID == id { + removed = providerConfig + removedFound = true + continue + } + newProviders = append(newProviders, providerConfig) + } + if removedFound && strings.TrimSpace(removed.SecretRef) != "" { + if err := s.secretStore.Delete(removed.SecretRef); err != nil { + return fmt.Errorf("删除 Provider secret 失败: %w", err) } } s.providers = newProviders @@ -186,17 +251,29 @@ func (s *Service) AIDeleteProvider(id string) error { // AITestProvider 测试 Provider 配置是否可用,仅测试端点连通性与密钥,不实际调用对话 func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{} { - // 如果传入脱敏的 key,使用已保存的 key - s.mu.RLock() if isMaskedAPIKey(config.APIKey) { - for _, p := range s.providers { - if p.ID == config.ID { - config.APIKey = p.APIKey - break + config.APIKey = "" + config.HasSecret = true + } + if strings.TrimSpace(config.APIKey) == "" && (config.HasSecret || strings.TrimSpace(config.SecretRef) != "") { + s.mu.RLock() + if strings.TrimSpace(config.SecretRef) == "" { + for _, providerConfig := range s.providers { + if providerConfig.ID == config.ID { + config.SecretRef = providerConfig.SecretRef + config.HasSecret = config.HasSecret || providerConfig.HasSecret + break + } } } + s.mu.RUnlock() + + resolved, err := s.resolveProviderConfigSecrets(config) + if err != nil { + return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())} + } + config = resolved } - s.mu.RUnlock() config = normalizeProviderConfig(config) baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/") @@ -842,13 +919,35 @@ func (s *Service) getActiveProvider() (provider.Provider, error) { // --- 配置持久化 --- +const aiConfigSchemaVersion = 2 + type aiConfig struct { + SchemaVersion int `json:"schemaVersion,omitempty"` Providers []ai.ProviderConfig `json:"providers"` ActiveProvider string `json:"activeProvider"` SafetyLevel string `json:"safetyLevel"` ContextLevel string `json:"contextLevel"` } +func (s *Service) loadRuntimeProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, bool, error) { + meta, bundle := splitProviderSecrets(config) + if bundle.hasAny() { + storedMeta, err := s.persistProviderSecretBundle(meta, bundle) + if err != nil { + meta.HasSecret = false + meta.SecretRef = "" + return meta, true, err + } + return mergeProviderSecrets(storedMeta, bundle), true, nil + } + + resolved, err := s.resolveProviderConfigSecrets(meta) + if err != nil { + return meta, false, err + } + return resolved, false, nil +} + func (s *Service) loadConfig() { path := filepath.Join(s.configDir, "ai_config.json") data, err := os.ReadFile(path) @@ -862,13 +961,22 @@ func (s *Service) loadConfig() { return } - s.providers = cfg.Providers - if s.providers == nil { - s.providers = make([]ai.ProviderConfig, 0) + providers := make([]ai.ProviderConfig, 0, len(cfg.Providers)) + shouldRewrite := cfg.SchemaVersion != aiConfigSchemaVersion + for _, providerConfig := range cfg.Providers { + runtimeConfig, rewritten, err := s.loadRuntimeProviderConfig(normalizeProviderConfig(providerConfig)) + if err != nil { + logger.Error(err, "加载 AI Provider secret 失败,provider=%s", providerConfig.ID) + } + if rewritten { + shouldRewrite = true + } + providers = append(providers, runtimeConfig) } - for i := range s.providers { - s.providers[i] = normalizeProviderConfig(s.providers[i]) + if providers == nil { + providers = make([]ai.ProviderConfig, 0) } + s.providers = providers s.activeProvider = cfg.ActiveProvider switch ai.SQLPermissionLevel(cfg.SafetyLevel) { @@ -885,11 +993,23 @@ func (s *Service) loadConfig() { default: s.contextLevel = ai.ContextSchemaOnly } + + if shouldRewrite { + if err := s.saveConfig(); err != nil { + logger.Error(err, "重写 AI 配置失败") + } + } } func (s *Service) saveConfig() error { + providers := make([]ai.ProviderConfig, len(s.providers)) + for i := range s.providers { + providers[i] = providerMetadataView(s.providers[i]) + } + cfg := aiConfig{ - Providers: s.providers, + SchemaVersion: aiConfigSchemaVersion, + Providers: providers, ActiveProvider: s.activeProvider, SafetyLevel: string(s.safetyLevel), ContextLevel: string(s.contextLevel), diff --git a/internal/ai/types.go b/internal/ai/types.go index 5f4ddae..790b023 100644 --- a/internal/ai/types.go +++ b/internal/ai/types.go @@ -69,6 +69,8 @@ type ProviderConfig struct { Type string `json:"type"` // openai | anthropic | gemini | custom Name string `json:"name"` APIKey string `json:"apiKey"` + SecretRef string `json:"secretRef,omitempty"` + HasSecret bool `json:"hasSecret,omitempty"` BaseURL string `json:"baseUrl"` Model string `json:"model"` Models []string `json:"models,omitempty"` From b5e8f5c022ef43eda00ec5932139828810904e2d Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 3 Apr 2026 01:04:15 +0800 Subject: [PATCH 07/14] =?UTF-8?q?=E2=9C=A8=20feat(security):=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E8=BF=9E=E6=8E=A5=E9=85=8D=E7=BD=AE=E4=B8=8E=E4=BB=A3?= =?UTF-8?q?=E7=90=86=E7=9A=84=E5=AF=86=E9=92=A5=E4=BB=93=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/app.go | 40 +- internal/app/app_cache_key_test.go | 19 + internal/app/connection_secret_resolution.go | 71 ++++ .../app/connection_secret_resolution_test.go | 42 ++ internal/app/global_proxy_persistence.go | 208 +++++++++ internal/app/global_proxy_secret_test.go | 66 +++ internal/app/saved_connections.go | 395 ++++++++++++++++++ internal/app/saved_connections_test.go | 72 ++++ internal/connection/saved_types.go | 46 ++ internal/connection/types.go | 1 + 10 files changed, 957 insertions(+), 3 deletions(-) create mode 100644 internal/app/connection_secret_resolution.go create mode 100644 internal/app/connection_secret_resolution_test.go create mode 100644 internal/app/global_proxy_persistence.go create mode 100644 internal/app/global_proxy_secret_test.go create mode 100644 internal/app/saved_connections.go create mode 100644 internal/app/saved_connections_test.go create mode 100644 internal/connection/saved_types.go diff --git a/internal/app/app.go b/internal/app/app.go index 861d175..3f17b86 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -18,6 +18,7 @@ import ( "GoNavi-Wails/internal/db" "GoNavi-Wails/internal/logger" proxytunnel "GoNavi-Wails/internal/proxy" + "GoNavi-Wails/internal/secretstore" "github.com/google/uuid" ) @@ -53,14 +54,25 @@ type App struct { updateMu sync.Mutex updateState updateState queryMu sync.RWMutex + configDir string + secretStore secretstore.SecretStore runningQueries map[string]queryContext // queryID -> cancelFunc and start time } // NewApp creates a new App application struct func NewApp() *App { + return NewAppWithSecretStore(secretstore.NewKeyringStore()) +} + +func NewAppWithSecretStore(store secretstore.SecretStore) *App { + if store == nil { + store = secretstore.NewUnavailableStore("secret store unavailable") + } return &App{ dbCache: make(map[string]cachedDatabase), runningQueries: make(map[string]queryContext), + configDir: resolveAppConfigDir(), + secretStore: store, } } @@ -74,7 +86,11 @@ func InitializeLifecycle(a *App, ctx context.Context) { func (a *App) startup(ctx context.Context) { a.ctx = ctx a.startedAt = time.Now() + if strings.TrimSpace(a.configDir) == "" { + a.configDir = resolveAppConfigDir() + } logger.Init() + a.loadPersistedGlobalProxy() applyMacWindowTranslucencyFix() logger.Infof("应用启动完成(首次连接保护窗口=%s,最多重试=%d 次)", startupConnectRetryWindow, startupConnectRetryAttempts) } @@ -111,6 +127,7 @@ func (a *App) Shutdown(ctx context.Context) { func normalizeCacheKeyConfig(config connection.ConnectionConfig) connection.ConnectionConfig { normalized := config + normalized.ID = "" normalized.Type = strings.ToLower(strings.TrimSpace(normalized.Type)) // timeout 仅用于 Query/Ping 控制,不应作为物理连接复用键的一部分。 normalized.Timeout = 0 @@ -216,6 +233,9 @@ func shouldRefreshCachedConnection(err error) bool { } func (a *App) invalidateCachedDatabase(config connection.ConnectionConfig, reason error) bool { + if resolvedConfig, err := a.resolveConnectionSecrets(config); err == nil { + config = resolvedConfig + } effectiveConfig := applyGlobalProxyToConnection(config) key := getCacheKey(effectiveConfig) shortKey := shortCacheKey(key) @@ -439,7 +459,11 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro } func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Database, error) { - effectiveConfig := applyGlobalProxyToConnection(config) + resolvedConfig, err := a.resolveConnectionSecrets(config) + if err != nil { + return nil, wrapConnectError(config, err) + } + effectiveConfig := applyGlobalProxyToConnection(resolvedConfig) if supported, reason := db.DriverRuntimeSupportStatus(effectiveConfig.Type); !supported { if strings.TrimSpace(reason) == "" { reason = fmt.Sprintf("%s 驱动未启用,请先在驱动管理中安装启用", strings.TrimSpace(effectiveConfig.Type)) @@ -465,7 +489,11 @@ func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Datab } func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing bool) (db.Database, error) { - effectiveConfig := applyGlobalProxyToConnection(config) + resolvedConfig, err := a.resolveConnectionSecrets(config) + if err != nil { + return nil, wrapConnectError(config, err) + } + effectiveConfig := applyGlobalProxyToConnection(resolvedConfig) isFileDB := isFileDatabaseType(effectiveConfig.Type) key := getCacheKey(effectiveConfig) @@ -546,7 +574,7 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing logger.Infof("未命中文件库连接缓存,开始创建连接:类型=%s 缓存Key=%s", strings.TrimSpace(effectiveConfig.Type), shortKey) } - dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(config) + dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(resolvedConfig) if err != nil { return nil, err } @@ -581,6 +609,12 @@ func shortenCacheKey(key string) string { } func (a *App) connectDatabaseWithStartupRetry(rawConfig connection.ConnectionConfig) (db.Database, connection.ConnectionConfig, error) { + resolvedConfig, err := a.resolveConnectionSecrets(rawConfig) + if err != nil { + return nil, rawConfig, wrapConnectError(rawConfig, err) + } + rawConfig = resolvedConfig + var lastErr error var lastEffectiveConfig connection.ConnectionConfig diff --git a/internal/app/app_cache_key_test.go b/internal/app/app_cache_key_test.go index ef7714f..26bd175 100644 --- a/internal/app/app_cache_key_test.go +++ b/internal/app/app_cache_key_test.go @@ -24,6 +24,25 @@ func TestGetCacheKey_IgnoreTimeout(t *testing.T) { } } +func TestGetCacheKey_IgnoreConnectionID(t *testing.T) { + base := connection.ConnectionConfig{ + ID: "conn-1", + Type: "mysql", + Host: "127.0.0.1", + Port: 3306, + User: "root", + Password: "root", + } + modified := base + modified.ID = "conn-2" + + left := getCacheKey(base) + right := getCacheKey(modified) + if left != right { + t.Fatalf("expected same cache key when only connection id differs, got %s vs %s", left, right) + } +} + func TestGetCacheKey_DuckDBHostAndDatabaseEquivalent(t *testing.T) { withHost := connection.ConnectionConfig{ Type: "duckdb", diff --git a/internal/app/connection_secret_resolution.go b/internal/app/connection_secret_resolution.go new file mode 100644 index 0000000..14842c1 --- /dev/null +++ b/internal/app/connection_secret_resolution.go @@ -0,0 +1,71 @@ +package app + +import ( + "strings" + + "GoNavi-Wails/internal/connection" +) + +func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (connection.ConnectionConfig, error) { + if strings.TrimSpace(config.ID) == "" { + return config, nil + } + + repo := newSavedConnectionRepository(a.configDir, a.secretStore) + view, err := repo.Find(config.ID) + if err != nil { + return config, err + } + + base := config + if connectionMetadataLooksEmpty(base) { + base = view.Config + } + bundle, err := repo.loadSecretBundle(view) + if err != nil { + return base, err + } + resolved := mergeConnectionSecretBundleIntoConfig(base, bundle) + resolved.ID = view.ID + return resolved, nil +} + +func connectionMetadataLooksEmpty(config connection.ConnectionConfig) bool { + return strings.TrimSpace(config.Type) == "" && + strings.TrimSpace(config.Host) == "" && + config.Port == 0 && + strings.TrimSpace(config.User) == "" && + strings.TrimSpace(config.Database) == "" && + strings.TrimSpace(config.DSN) == "" && + strings.TrimSpace(config.URI) == "" && + len(config.Hosts) == 0 +} + +func mergeConnectionSecretBundleIntoConfig(config connection.ConnectionConfig, bundle connectionSecretBundle) connection.ConnectionConfig { + merged := config + if strings.TrimSpace(merged.Password) == "" { + merged.Password = bundle.Password + } + if strings.TrimSpace(merged.SSH.Password) == "" { + merged.SSH.Password = bundle.SSHPassword + } + if strings.TrimSpace(merged.Proxy.Password) == "" { + merged.Proxy.Password = bundle.ProxyPassword + } + if strings.TrimSpace(merged.HTTPTunnel.Password) == "" { + merged.HTTPTunnel.Password = bundle.HTTPTunnelPassword + } + if strings.TrimSpace(merged.MySQLReplicaPassword) == "" { + merged.MySQLReplicaPassword = bundle.MySQLReplicaPassword + } + if strings.TrimSpace(merged.MongoReplicaPassword) == "" { + merged.MongoReplicaPassword = bundle.MongoReplicaPassword + } + if strings.TrimSpace(merged.URI) == "" { + merged.URI = bundle.OpaqueURI + } + if strings.TrimSpace(merged.DSN) == "" { + merged.DSN = bundle.OpaqueDSN + } + return merged +} diff --git a/internal/app/connection_secret_resolution_test.go b/internal/app/connection_secret_resolution_test.go new file mode 100644 index 0000000..a6336ca --- /dev/null +++ b/internal/app/connection_secret_resolution_test.go @@ -0,0 +1,42 @@ +package app + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestResolveConnectionConfigByIDLoadsSecretsFromStore(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + repo := newSavedConnectionRepository(app.configDir, store) + view, err := repo.Save(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "postgres-secret", + DSN: "postgres://user:pass@db.local/app", + }, + }) + if err != nil { + t.Fatalf("Save returned error: %v", err) + } + + resolved, err := app.resolveConnectionSecrets(view.Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "postgres-secret" { + t.Fatalf("expected restored password, got %q", resolved.Password) + } + if resolved.DSN != "postgres://user:pass@db.local/app" { + t.Fatalf("expected restored DSN, got %q", resolved.DSN) + } +} diff --git a/internal/app/global_proxy_persistence.go b/internal/app/global_proxy_persistence.go new file mode 100644 index 0000000..a10a3b3 --- /dev/null +++ b/internal/app/global_proxy_persistence.go @@ -0,0 +1,208 @@ +package app + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/secretstore" +) + +const ( + globalProxyFileName = "global_proxy.json" + globalProxySecretKind = "global-proxy" + globalProxySecretID = "default" +) + +type globalProxySecretBundle struct { + Password string `json:"password,omitempty"` +} + +func globalProxyMetadataPath(configDir string) string { + return filepath.Join(configDir, globalProxyFileName) +} + +func (a *App) saveGlobalProxy(input connection.SaveGlobalProxyInput) (connection.GlobalProxyView, error) { + if strings.TrimSpace(a.configDir) == "" { + a.configDir = resolveAppConfigDir() + } + + existing, err := a.loadStoredGlobalProxyView() + if err != nil && !os.IsNotExist(err) { + return connection.GlobalProxyView{}, err + } + + view := connection.GlobalProxyView{ + Enabled: input.Enabled, + Type: strings.TrimSpace(input.Type), + Host: strings.TrimSpace(input.Host), + Port: input.Port, + User: strings.TrimSpace(input.User), + } + + bundle := globalProxySecretBundle{} + if strings.TrimSpace(input.Password) != "" { + bundle.Password = input.Password + } else if existing.HasPassword { + existingBundle, loadErr := a.loadGlobalProxySecretBundle(existing) + if loadErr != nil { + return connection.GlobalProxyView{}, loadErr + } + bundle = existingBundle + view.SecretRef = existing.SecretRef + } + + if !view.Enabled { + if strings.TrimSpace(existing.SecretRef) != "" && a.secretStore != nil { + if deleteErr := a.secretStore.Delete(existing.SecretRef); deleteErr != nil { + return connection.GlobalProxyView{}, deleteErr + } + } + view = connection.GlobalProxyView{Enabled: false} + if err := a.persistGlobalProxyView(view); err != nil { + return connection.GlobalProxyView{}, err + } + if _, err := setGlobalProxyConfig(false, connection.ProxyConfig{}); err != nil { + return connection.GlobalProxyView{}, err + } + return view, nil + } + + if strings.TrimSpace(bundle.Password) != "" { + ref, storeErr := a.storeGlobalProxySecret(view.SecretRef, bundle) + if storeErr != nil { + return connection.GlobalProxyView{}, storeErr + } + view.SecretRef = ref + view.HasPassword = true + } else { + if strings.TrimSpace(existing.SecretRef) != "" && a.secretStore != nil { + if deleteErr := a.secretStore.Delete(existing.SecretRef); deleteErr != nil { + return connection.GlobalProxyView{}, deleteErr + } + } + view.SecretRef = "" + view.HasPassword = false + } + + if err := a.persistGlobalProxyView(view); err != nil { + return connection.GlobalProxyView{}, err + } + if _, err := setGlobalProxyConfig(true, connection.ProxyConfig{ + Type: view.Type, + Host: view.Host, + Port: view.Port, + User: view.User, + Password: bundle.Password, + }); err != nil { + return connection.GlobalProxyView{}, err + } + view.Password = "" + return view, nil +} + +func (a *App) persistGlobalProxyView(view connection.GlobalProxyView) error { + if err := os.MkdirAll(a.configDir, 0o755); err != nil { + return err + } + payload, err := json.MarshalIndent(view, "", " ") + if err != nil { + return err + } + return os.WriteFile(globalProxyMetadataPath(a.configDir), payload, 0o644) +} + +func (a *App) loadStoredGlobalProxyView() (connection.GlobalProxyView, error) { + data, err := os.ReadFile(globalProxyMetadataPath(a.configDir)) + if err != nil { + return connection.GlobalProxyView{}, err + } + var view connection.GlobalProxyView + if err := json.Unmarshal(data, &view); err != nil { + return connection.GlobalProxyView{}, err + } + return view, nil +} + +func (a *App) loadGlobalProxySecretBundle(view connection.GlobalProxyView) (globalProxySecretBundle, error) { + if !view.HasPassword { + return globalProxySecretBundle{}, nil + } + if a.secretStore == nil { + return globalProxySecretBundle{}, fmt.Errorf("secret store unavailable") + } + ref := strings.TrimSpace(view.SecretRef) + if ref == "" { + var err error + ref, err = secretstore.BuildRef(globalProxySecretKind, globalProxySecretID) + if err != nil { + return globalProxySecretBundle{}, err + } + } + payload, err := a.secretStore.Get(ref) + if err != nil { + return globalProxySecretBundle{}, err + } + var bundle globalProxySecretBundle + if err := json.Unmarshal(payload, &bundle); err != nil { + return globalProxySecretBundle{}, err + } + return bundle, nil +} + +func (a *App) storeGlobalProxySecret(existingRef string, bundle globalProxySecretBundle) (string, error) { + if a.secretStore == nil { + return "", fmt.Errorf("secret store unavailable") + } + if err := a.secretStore.HealthCheck(); err != nil { + return "", err + } + ref := strings.TrimSpace(existingRef) + if ref == "" { + var err error + ref, err = secretstore.BuildRef(globalProxySecretKind, globalProxySecretID) + if err != nil { + return "", err + } + } + payload, err := json.Marshal(bundle) + if err != nil { + return "", err + } + if err := a.secretStore.Put(ref, payload); err != nil { + return "", err + } + return ref, nil +} + +func (a *App) loadPersistedGlobalProxy() { + view, err := a.loadStoredGlobalProxyView() + if err != nil { + if !os.IsNotExist(err) { + logger.Error(err, "加载全局代理元数据失败") + } + return + } + + proxyConfig := connection.ProxyConfig{ + Type: view.Type, + Host: view.Host, + Port: view.Port, + User: view.User, + } + if view.HasPassword { + bundle, loadErr := a.loadGlobalProxySecretBundle(view) + if loadErr != nil { + logger.Error(loadErr, "加载全局代理密码失败") + return + } + proxyConfig.Password = bundle.Password + } + if _, err := setGlobalProxyConfig(view.Enabled, proxyConfig); err != nil { + logger.Error(err, "恢复全局代理配置失败") + } +} diff --git a/internal/app/global_proxy_secret_test.go b/internal/app/global_proxy_secret_test.go new file mode 100644 index 0000000..177e949 --- /dev/null +++ b/internal/app/global_proxy_secret_test.go @@ -0,0 +1,66 @@ +package app + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestSaveGlobalProxyStripsPasswordFromView(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + view, err := app.saveGlobalProxy(connection.SaveGlobalProxyInput{ + Enabled: true, + Type: "http", + Host: "127.0.0.1", + Port: 8080, + User: "ops", + Password: "proxy-secret", + }) + if err != nil { + t.Fatalf("saveGlobalProxy returned error: %v", err) + } + if view.Password != "" { + t.Fatal("global proxy view must not expose plaintext password") + } + if !view.HasPassword { + t.Fatal("expected hasPassword=true") + } + + snapshot := currentGlobalProxyConfig() + if snapshot.Proxy.Password != "proxy-secret" { + t.Fatalf("expected runtime proxy password to be preserved, got %q", snapshot.Proxy.Password) + } +} + +func TestGetGlobalProxyConfigReturnsSecretlessView(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + if _, err := app.saveGlobalProxy(connection.SaveGlobalProxyInput{ + Enabled: true, + Type: "http", + Host: "127.0.0.1", + Port: 8080, + User: "ops", + Password: "proxy-secret", + }); err != nil { + t.Fatalf("saveGlobalProxy returned error: %v", err) + } + + result := app.GetGlobalProxyConfig() + view, ok := result.Data.(connection.GlobalProxyView) + if !ok { + t.Fatalf("expected GlobalProxyView, got %T", result.Data) + } + if view.Password != "" { + t.Fatal("GetGlobalProxyConfig must not expose plaintext password") + } + if !view.HasPassword { + t.Fatal("expected hasPassword=true") + } +} + diff --git a/internal/app/saved_connections.go b/internal/app/saved_connections.go new file mode 100644 index 0000000..19eb76d --- /dev/null +++ b/internal/app/saved_connections.go @@ -0,0 +1,395 @@ +package app + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" + "github.com/google/uuid" +) + +const ( + savedConnectionsFileName = "connections.json" + savedConnectionSecretKind = "connection" +) + +type connectionSecretBundle struct { + Password string `json:"password,omitempty"` + SSHPassword string `json:"sshPassword,omitempty"` + ProxyPassword string `json:"proxyPassword,omitempty"` + HTTPTunnelPassword string `json:"httpTunnelPassword,omitempty"` + MySQLReplicaPassword string `json:"mysqlReplicaPassword,omitempty"` + MongoReplicaPassword string `json:"mongoReplicaPassword,omitempty"` + OpaqueURI string `json:"opaqueURI,omitempty"` + OpaqueDSN string `json:"opaqueDSN,omitempty"` +} + +type savedConnectionsFile struct { + Connections []connection.SavedConnectionView `json:"connections"` +} + +type savedConnectionRepository struct { + configDir string + secretStore secretstore.SecretStore +} + +func resolveAppConfigDir() string { + homeDir, err := os.UserHomeDir() + if err != nil || strings.TrimSpace(homeDir) == "" { + return "." + } + return filepath.Join(homeDir, ".gonavi") +} + +func newSavedConnectionRepository(configDir string, store secretstore.SecretStore) *savedConnectionRepository { + if strings.TrimSpace(configDir) == "" { + configDir = resolveAppConfigDir() + } + if store == nil { + store = secretstore.NewUnavailableStore("secret store unavailable") + } + return &savedConnectionRepository{configDir: configDir, secretStore: store} +} + +func (b connectionSecretBundle) hasAny() bool { + return strings.TrimSpace(b.Password) != "" || + strings.TrimSpace(b.SSHPassword) != "" || + strings.TrimSpace(b.ProxyPassword) != "" || + strings.TrimSpace(b.HTTPTunnelPassword) != "" || + strings.TrimSpace(b.MySQLReplicaPassword) != "" || + strings.TrimSpace(b.MongoReplicaPassword) != "" || + strings.TrimSpace(b.OpaqueURI) != "" || + strings.TrimSpace(b.OpaqueDSN) != "" +} + +func mergeConnectionSecretBundles(base, overlay connectionSecretBundle) connectionSecretBundle { + merged := base + if strings.TrimSpace(overlay.Password) != "" { + merged.Password = overlay.Password + } + if strings.TrimSpace(overlay.SSHPassword) != "" { + merged.SSHPassword = overlay.SSHPassword + } + if strings.TrimSpace(overlay.ProxyPassword) != "" { + merged.ProxyPassword = overlay.ProxyPassword + } + if strings.TrimSpace(overlay.HTTPTunnelPassword) != "" { + merged.HTTPTunnelPassword = overlay.HTTPTunnelPassword + } + if strings.TrimSpace(overlay.MySQLReplicaPassword) != "" { + merged.MySQLReplicaPassword = overlay.MySQLReplicaPassword + } + if strings.TrimSpace(overlay.MongoReplicaPassword) != "" { + merged.MongoReplicaPassword = overlay.MongoReplicaPassword + } + if strings.TrimSpace(overlay.OpaqueURI) != "" { + merged.OpaqueURI = overlay.OpaqueURI + } + if strings.TrimSpace(overlay.OpaqueDSN) != "" { + merged.OpaqueDSN = overlay.OpaqueDSN + } + return merged +} + +func splitConnectionSecrets(input connection.SavedConnectionInput) (connection.SavedConnectionView, connectionSecretBundle) { + id := strings.TrimSpace(input.ID) + if id == "" { + id = strings.TrimSpace(input.Config.ID) + } + + meta := input.Config + meta.ID = id + meta.SavePassword = false + + bundle := connectionSecretBundle{} + if strings.TrimSpace(meta.Password) != "" { + bundle.Password = meta.Password + meta.Password = "" + } + if strings.TrimSpace(meta.SSH.Password) != "" { + bundle.SSHPassword = meta.SSH.Password + meta.SSH.Password = "" + } + if strings.TrimSpace(meta.Proxy.Password) != "" { + bundle.ProxyPassword = meta.Proxy.Password + meta.Proxy.Password = "" + } + if strings.TrimSpace(meta.HTTPTunnel.Password) != "" { + bundle.HTTPTunnelPassword = meta.HTTPTunnel.Password + meta.HTTPTunnel.Password = "" + } + if strings.TrimSpace(meta.MySQLReplicaPassword) != "" { + bundle.MySQLReplicaPassword = meta.MySQLReplicaPassword + meta.MySQLReplicaPassword = "" + } + if strings.TrimSpace(meta.MongoReplicaPassword) != "" { + bundle.MongoReplicaPassword = meta.MongoReplicaPassword + meta.MongoReplicaPassword = "" + } + if strings.TrimSpace(meta.URI) != "" { + bundle.OpaqueURI = meta.URI + meta.URI = "" + } + if strings.TrimSpace(meta.DSN) != "" { + bundle.OpaqueDSN = meta.DSN + meta.DSN = "" + } + + view := connection.SavedConnectionView{ + ID: id, + Name: strings.TrimSpace(input.Name), + Config: meta, + HasPrimaryPassword: strings.TrimSpace(bundle.Password) != "", + HasSSHPassword: strings.TrimSpace(bundle.SSHPassword) != "", + HasProxyPassword: strings.TrimSpace(bundle.ProxyPassword) != "", + HasHTTPTunnelPassword: strings.TrimSpace(bundle.HTTPTunnelPassword) != "", + HasMySQLReplicaPassword: strings.TrimSpace(bundle.MySQLReplicaPassword) != "", + HasMongoReplicaPassword: strings.TrimSpace(bundle.MongoReplicaPassword) != "", + HasOpaqueURI: strings.TrimSpace(bundle.OpaqueURI) != "", + HasOpaqueDSN: strings.TrimSpace(bundle.OpaqueDSN) != "", + } + return view, bundle +} + +func (r *savedConnectionRepository) connectionsPath() string { + return filepath.Join(r.configDir, savedConnectionsFileName) +} + +func (r *savedConnectionRepository) load() ([]connection.SavedConnectionView, error) { + data, err := os.ReadFile(r.connectionsPath()) + if err != nil { + if os.IsNotExist(err) { + return []connection.SavedConnectionView{}, nil + } + return nil, err + } + + var file savedConnectionsFile + if err := json.Unmarshal(data, &file); err != nil { + return nil, err + } + if file.Connections == nil { + return []connection.SavedConnectionView{}, nil + } + return file.Connections, nil +} + +func (r *savedConnectionRepository) saveAll(connections []connection.SavedConnectionView) error { + if err := os.MkdirAll(r.configDir, 0o755); err != nil { + return err + } + payload, err := json.MarshalIndent(savedConnectionsFile{Connections: connections}, "", " ") + if err != nil { + return err + } + return os.WriteFile(r.connectionsPath(), payload, 0o644) +} + +func (r *savedConnectionRepository) Save(input connection.SavedConnectionInput) (connection.SavedConnectionView, error) { + if strings.TrimSpace(input.ID) == "" && strings.TrimSpace(input.Config.ID) == "" { + input.ID = "conn-" + uuid.New().String()[:8] + } + if strings.TrimSpace(input.ID) == "" { + input.ID = strings.TrimSpace(input.Config.ID) + } + input.Config.ID = input.ID + + connections, err := r.load() + if err != nil { + return connection.SavedConnectionView{}, err + } + + view, bundle := splitConnectionSecrets(input) + index := -1 + var existing connection.SavedConnectionView + for i, item := range connections { + if item.ID == view.ID { + index = i + existing = item + break + } + } + + mergedBundle := bundle + if index >= 0 && savedConnectionViewHasSecrets(existing) { + existingBundle, bundleErr := r.loadSecretBundle(existing) + if bundleErr != nil { + return connection.SavedConnectionView{}, bundleErr + } + mergedBundle = mergeConnectionSecretBundles(existingBundle, bundle) + view.SecretRef = existing.SecretRef + } + + if mergedBundle.hasAny() { + ref, storeErr := r.storeSecretBundle(view.ID, view.SecretRef, mergedBundle) + if storeErr != nil { + return connection.SavedConnectionView{}, storeErr + } + view.SecretRef = ref + applyConnectionBundleFlags(&view, mergedBundle) + } else { + if index >= 0 && strings.TrimSpace(existing.SecretRef) != "" { + if deleteErr := r.secretStore.Delete(existing.SecretRef); deleteErr != nil { + return connection.SavedConnectionView{}, deleteErr + } + } + view.SecretRef = "" + applyConnectionBundleFlags(&view, connectionSecretBundle{}) + } + + if index >= 0 { + connections[index] = view + } else { + connections = append(connections, view) + } + if err := r.saveAll(connections); err != nil { + return connection.SavedConnectionView{}, err + } + return view, nil +} + +func (r *savedConnectionRepository) Find(id string) (connection.SavedConnectionView, error) { + connections, err := r.load() + if err != nil { + return connection.SavedConnectionView{}, err + } + for _, item := range connections { + if item.ID == strings.TrimSpace(id) { + return item, nil + } + } + return connection.SavedConnectionView{}, fmt.Errorf("saved connection not found: %s", id) +} + +func (r *savedConnectionRepository) storeSecretBundle(id string, existingRef string, bundle connectionSecretBundle) (string, error) { + if r.secretStore == nil { + return "", fmt.Errorf("secret store unavailable") + } + if err := r.secretStore.HealthCheck(); err != nil { + return "", err + } + ref := strings.TrimSpace(existingRef) + if ref == "" { + var err error + ref, err = secretstore.BuildRef(savedConnectionSecretKind, id) + if err != nil { + return "", err + } + } + payload, err := json.Marshal(bundle) + if err != nil { + return "", err + } + if err := r.secretStore.Put(ref, payload); err != nil { + return "", err + } + return ref, nil +} + +func (r *savedConnectionRepository) loadSecretBundle(view connection.SavedConnectionView) (connectionSecretBundle, error) { + if !savedConnectionViewHasSecrets(view) { + return connectionSecretBundle{}, nil + } + if r.secretStore == nil { + return connectionSecretBundle{}, fmt.Errorf("secret store unavailable") + } + ref := strings.TrimSpace(view.SecretRef) + if ref == "" { + var err error + ref, err = secretstore.BuildRef(savedConnectionSecretKind, view.ID) + if err != nil { + return connectionSecretBundle{}, err + } + } + payload, err := r.secretStore.Get(ref) + if err != nil { + return connectionSecretBundle{}, err + } + var bundle connectionSecretBundle + if err := json.Unmarshal(payload, &bundle); err != nil { + return connectionSecretBundle{}, err + } + return bundle, nil +} + +func savedConnectionViewHasSecrets(view connection.SavedConnectionView) bool { + return view.HasPrimaryPassword || view.HasSSHPassword || view.HasProxyPassword || view.HasHTTPTunnelPassword || + view.HasMySQLReplicaPassword || view.HasMongoReplicaPassword || view.HasOpaqueURI || view.HasOpaqueDSN +} + +func applyConnectionBundleFlags(view *connection.SavedConnectionView, bundle connectionSecretBundle) { + view.HasPrimaryPassword = strings.TrimSpace(bundle.Password) != "" + view.HasSSHPassword = strings.TrimSpace(bundle.SSHPassword) != "" + view.HasProxyPassword = strings.TrimSpace(bundle.ProxyPassword) != "" + view.HasHTTPTunnelPassword = strings.TrimSpace(bundle.HTTPTunnelPassword) != "" + view.HasMySQLReplicaPassword = strings.TrimSpace(bundle.MySQLReplicaPassword) != "" + view.HasMongoReplicaPassword = strings.TrimSpace(bundle.MongoReplicaPassword) != "" + view.HasOpaqueURI = strings.TrimSpace(bundle.OpaqueURI) != "" + view.HasOpaqueDSN = strings.TrimSpace(bundle.OpaqueDSN) != "" +} + +func (r *savedConnectionRepository) List() ([]connection.SavedConnectionView, error) { + return r.load() +} + +func (r *savedConnectionRepository) Delete(id string) error { + connections, err := r.load() + if err != nil { + return err + } + filtered := make([]connection.SavedConnectionView, 0, len(connections)) + for _, item := range connections { + if item.ID == strings.TrimSpace(id) { + if strings.TrimSpace(item.SecretRef) != "" && r.secretStore != nil { + if deleteErr := r.secretStore.Delete(item.SecretRef); deleteErr != nil { + return deleteErr + } + } + continue + } + filtered = append(filtered, item) + } + return r.saveAll(filtered) +} + +func (r *savedConnectionRepository) Duplicate(id string) (connection.SavedConnectionView, error) { + original, err := r.Find(id) + if err != nil { + return connection.SavedConnectionView{}, err + } + + duplicate := original + duplicate.ID = "conn-" + uuid.New().String()[:8] + duplicate.Config.ID = duplicate.ID + duplicate.Name = original.Name + " Copy" + + bundle, err := r.loadSecretBundle(original) + if err != nil { + return connection.SavedConnectionView{}, err + } + if bundle.hasAny() { + ref, storeErr := r.storeSecretBundle(duplicate.ID, "", bundle) + if storeErr != nil { + return connection.SavedConnectionView{}, storeErr + } + duplicate.SecretRef = ref + applyConnectionBundleFlags(&duplicate, bundle) + } else { + duplicate.SecretRef = "" + applyConnectionBundleFlags(&duplicate, connectionSecretBundle{}) + } + + connections, err := r.load() + if err != nil { + return connection.SavedConnectionView{}, err + } + connections = append(connections, duplicate) + if err := r.saveAll(connections); err != nil { + return connection.SavedConnectionView{}, err + } + return duplicate, nil +} diff --git a/internal/app/saved_connections_test.go b/internal/app/saved_connections_test.go new file mode 100644 index 0000000..3a81e76 --- /dev/null +++ b/internal/app/saved_connections_test.go @@ -0,0 +1,72 @@ +package app + +import ( + "os" + "testing" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" +) + +func TestSplitConnectionSecretsStripsPasswordsAndOpaqueDSN(t *testing.T) { + input := connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.local", + Password: "postgres-secret", + DSN: "postgres://user:pass@db.local/app", + }, + } + + view, bundle := splitConnectionSecrets(input) + if view.Config.Password != "" { + t.Fatal("metadata must not keep password") + } + if bundle.Password != "postgres-secret" { + t.Fatal("bundle should keep primary password") + } + if bundle.OpaqueDSN == "" { + t.Fatal("opaque DSN should be stored as secret") + } + if !view.HasPrimaryPassword { + t.Fatal("expected view to report primary password") + } + if !view.HasOpaqueDSN { + t.Fatal("expected view to report opaque DSN") + } +} + +type fakeAppSecretStore struct { + items map[string][]byte +} + +func newFakeAppSecretStore() *fakeAppSecretStore { + return &fakeAppSecretStore{items: make(map[string][]byte)} +} + +func (s *fakeAppSecretStore) Put(ref string, payload []byte) error { + s.items[ref] = append([]byte(nil), payload...) + return nil +} + +func (s *fakeAppSecretStore) Get(ref string) ([]byte, error) { + payload, ok := s.items[ref] + if !ok { + return nil, os.ErrNotExist + } + return append([]byte(nil), payload...), nil +} + +func (s *fakeAppSecretStore) Delete(ref string) error { + delete(s.items, ref) + return nil +} + +func (s *fakeAppSecretStore) HealthCheck() error { + return nil +} + +var _ secretstore.SecretStore = (*fakeAppSecretStore)(nil) diff --git a/internal/connection/saved_types.go b/internal/connection/saved_types.go new file mode 100644 index 0000000..a364a50 --- /dev/null +++ b/internal/connection/saved_types.go @@ -0,0 +1,46 @@ +package connection + +type SavedConnectionInput struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Config ConnectionConfig `json:"config"` +} + +type SavedConnectionView struct { + ID string `json:"id"` + Name string `json:"name"` + Config ConnectionConfig `json:"config"` + SecretRef string `json:"secretRef,omitempty"` + HasPrimaryPassword bool `json:"hasPrimaryPassword,omitempty"` + HasSSHPassword bool `json:"hasSSHPassword,omitempty"` + HasProxyPassword bool `json:"hasProxyPassword,omitempty"` + HasHTTPTunnelPassword bool `json:"hasHttpTunnelPassword,omitempty"` + HasMySQLReplicaPassword bool `json:"hasMySQLReplicaPassword,omitempty"` + HasMongoReplicaPassword bool `json:"hasMongoReplicaPassword,omitempty"` + HasOpaqueURI bool `json:"hasOpaqueURI,omitempty"` + HasOpaqueDSN bool `json:"hasOpaqueDSN,omitempty"` +} + +type LegacySavedConnection = SavedConnectionInput + +type SaveGlobalProxyInput struct { + Enabled bool `json:"enabled"` + Type string `json:"type"` + Host string `json:"host"` + Port int `json:"port"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` +} + +type GlobalProxyView struct { + Enabled bool `json:"enabled"` + Type string `json:"type"` + Host string `json:"host"` + Port int `json:"port"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + HasPassword bool `json:"hasPassword,omitempty"` + SecretRef string `json:"secretRef,omitempty"` +} + +type LegacyGlobalProxyInput = SaveGlobalProxyInput diff --git a/internal/connection/types.go b/internal/connection/types.go index 2c1a22b..e6e770e 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -28,6 +28,7 @@ type HTTPTunnelConfig struct { // ConnectionConfig 存储数据库连接的完整配置,包括 SSH、代理、SSL 等网络层设置。 type ConnectionConfig struct { + ID string `json:"id,omitempty"` Type string `json:"type"` Host string `json:"host"` Port int `json:"port"` From 263db6bf30c4c0d0bc94c677b54d5c93937b3e9e Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 3 Apr 2026 01:04:42 +0800 Subject: [PATCH 08/14] =?UTF-8?q?=E2=9C=A8=20feat(security):=20=E6=9A=B4?= =?UTF-8?q?=E9=9C=B2=E8=BF=9E=E6=8E=A5=E9=85=8D=E7=BD=AE=E4=B8=8E=E4=BB=A3?= =?UTF-8?q?=E7=90=86=E7=9A=84=E5=AF=86=E9=92=A5=E5=AD=98=E5=82=A8=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/app/global_proxy.go | 18 +++- internal/app/methods_saved_connections.go | 44 +++++++++ .../app/methods_saved_connections_test.go | 94 +++++++++++++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 internal/app/methods_saved_connections.go create mode 100644 internal/app/methods_saved_connections_test.go diff --git a/internal/app/global_proxy.go b/internal/app/global_proxy.go index 016fb26..ffb5b35 100644 --- a/internal/app/global_proxy.go +++ b/internal/app/global_proxy.go @@ -123,11 +123,26 @@ func proxyConfigEqual(a, b connection.ProxyConfig) bool { a.Password == b.Password } +func currentGlobalProxyView() connection.GlobalProxyView { + snapshot := currentGlobalProxyConfig() + if !snapshot.Enabled { + return connection.GlobalProxyView{Enabled: false} + } + return connection.GlobalProxyView{ + Enabled: true, + Type: snapshot.Proxy.Type, + Host: snapshot.Proxy.Host, + Port: snapshot.Proxy.Port, + User: snapshot.Proxy.User, + HasPassword: strings.TrimSpace(snapshot.Proxy.Password) != "", + } +} + func (a *App) GetGlobalProxyConfig() connection.QueryResult { return connection.QueryResult{ Success: true, Message: "OK", - Data: currentGlobalProxyConfig(), + Data: currentGlobalProxyView(), } } @@ -312,3 +327,4 @@ func buildProxyURLFromConfig(proxyConfig connection.ProxyConfig) (*url.URL, erro } return proxyURL, nil } + diff --git a/internal/app/methods_saved_connections.go b/internal/app/methods_saved_connections.go new file mode 100644 index 0000000..d8d916d --- /dev/null +++ b/internal/app/methods_saved_connections.go @@ -0,0 +1,44 @@ +package app + +import "GoNavi-Wails/internal/connection" + +func (a *App) savedConnectionRepository() *savedConnectionRepository { + return newSavedConnectionRepository(a.configDir, a.secretStore) +} + +func (a *App) GetSavedConnections() ([]connection.SavedConnectionView, error) { + return a.savedConnectionRepository().List() +} + +func (a *App) SaveConnection(input connection.SavedConnectionInput) (connection.SavedConnectionView, error) { + return a.savedConnectionRepository().Save(input) +} + +func (a *App) DeleteConnection(id string) error { + return a.savedConnectionRepository().Delete(id) +} + +func (a *App) DuplicateConnection(id string) (connection.SavedConnectionView, error) { + return a.savedConnectionRepository().Duplicate(id) +} + +func (a *App) ImportLegacyConnections(items []connection.LegacySavedConnection) ([]connection.SavedConnectionView, error) { + result := make([]connection.SavedConnectionView, 0, len(items)) + repo := a.savedConnectionRepository() + for _, item := range items { + view, err := repo.Save(connection.SavedConnectionInput(item)) + if err != nil { + return nil, err + } + result = append(result, view) + } + return result, nil +} + +func (a *App) SaveGlobalProxy(input connection.SaveGlobalProxyInput) (connection.GlobalProxyView, error) { + return a.saveGlobalProxy(input) +} + +func (a *App) ImportLegacyGlobalProxy(input connection.LegacyGlobalProxyInput) (connection.GlobalProxyView, error) { + return a.saveGlobalProxy(connection.SaveGlobalProxyInput(input)) +} diff --git a/internal/app/methods_saved_connections_test.go b/internal/app/methods_saved_connections_test.go new file mode 100644 index 0000000..4b117cc --- /dev/null +++ b/internal/app/methods_saved_connections_test.go @@ -0,0 +1,94 @@ +package app + +import ( + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestSaveConnectionMethodReturnsSecretlessView(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + result, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "postgres-secret", + }, + }) + if err != nil { + t.Fatal(err) + } + if result.Config.Password != "" { + t.Fatal("SaveConnection must not return plaintext password") + } + if !result.HasPrimaryPassword { + t.Fatal("expected HasPrimaryPassword=true") + } +} + +func TestDuplicateConnectionClonesSecretBundle(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "postgres-secret", + }, + }) + if err != nil { + t.Fatal(err) + } + + duplicate, err := app.DuplicateConnection("conn-1") + if err != nil { + t.Fatal(err) + } + if duplicate.ID == "conn-1" { + t.Fatal("duplicate should have a new id") + } + + resolved, err := app.resolveConnectionSecrets(duplicate.Config) + if err != nil { + t.Fatal(err) + } + if resolved.Password != "postgres-secret" { + t.Fatalf("expected duplicated secret bundle, got %q", resolved.Password) + } +} + +func TestSaveGlobalProxyReturnsSecretlessView(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + view, err := app.SaveGlobalProxy(connection.SaveGlobalProxyInput{ + Enabled: true, + Type: "http", + Host: "127.0.0.1", + Port: 8080, + User: "ops", + Password: "proxy-secret", + }) + if err != nil { + t.Fatal(err) + } + if view.Password != "" { + t.Fatal("global proxy view must not expose plaintext password") + } + if !view.HasPassword { + t.Fatal("expected hasPassword=true") + } +} From c842201bf47573145580703c5eef9e1af226f6dd Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 3 Apr 2026 08:06:43 +0800 Subject: [PATCH 09/14] =?UTF-8?q?=E2=9C=A8=20feat(security):=20=E5=89=8D?= =?UTF-8?q?=E7=AB=AF=E7=8A=B6=E6=80=81=E8=BF=81=E7=A7=BB=E8=87=B3=E6=97=A0?= =?UTF-8?q?=E6=98=8E=E6=96=87=E5=AF=86=E9=92=A5=E5=AD=98=E5=82=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/src/App.tsx | 167 +++++++++++++----- frontend/src/store.ts | 47 +++-- frontend/src/types.ts | 18 ++ frontend/src/utils/globalProxyDraft.test.ts | 35 ++++ frontend/src/utils/globalProxyDraft.ts | 62 +++++++ .../src/utils/legacyConnectionStorage.test.ts | 75 ++++++++ frontend/src/utils/legacyConnectionStorage.ts | 110 ++++++++++++ frontend/src/utils/startupReadiness.test.ts | 5 +- frontend/src/utils/startupReadiness.ts | 3 +- 9 files changed, 460 insertions(+), 62 deletions(-) create mode 100644 frontend/src/utils/globalProxyDraft.test.ts create mode 100644 frontend/src/utils/globalProxyDraft.ts create mode 100644 frontend/src/utils/legacyConnectionStorage.test.ts create mode 100644 frontend/src/utils/legacyConnectionStorage.ts diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 00cbaaa..c285de3 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -17,6 +17,8 @@ import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, is import { getMacNativeTitlebarPaddingLeft, getMacNativeTitlebarPaddingRight, shouldHandleMacNativeFullscreenShortcut, shouldSuppressMacNativeEscapeExit } from './utils/macWindow'; import { buildOverlayWorkbenchTheme } from './utils/overlayWorkbenchTheme'; import { getConnectionWorkbenchState } from './utils/startupReadiness'; +import { createGlobalProxyDraft, toSaveGlobalProxyInput } from './utils/globalProxyDraft'; +import { LEGACY_PERSIST_KEY, readLegacyPersistedSecrets, stripLegacyPersistedSecrets } from './utils/legacyConnectionStorage'; import { SHORTCUT_ACTION_META, SHORTCUT_ACTION_ORDER, @@ -35,7 +37,7 @@ import { resolveAIEdgeHandleDockStyle, resolveAIEdgeHandleStyle, } from './utils/aiEntryLayout'; -import { ConfigureGlobalProxy, SetMacNativeWindowControls, SetWindowTranslucency } from '../wailsjs/go/app/App'; +import { SetMacNativeWindowControls, SetWindowTranslucency } from '../wailsjs/go/app/App'; import './App.css'; const { Sider, Content } = Layout; @@ -76,6 +78,8 @@ function App() { const setStartupFullscreen = useStore(state => state.setStartupFullscreen); const globalProxy = useStore(state => state.globalProxy); const setGlobalProxy = useStore(state => state.setGlobalProxy); + const replaceConnections = useStore(state => state.replaceConnections); + const replaceGlobalProxy = useStore(state => state.replaceGlobalProxy); const shortcutOptions = useStore(state => state.shortcutOptions); const updateShortcut = useStore(state => state.updateShortcut); const resetShortcutOptions = useStore(state => state.resetShortcutOptions); @@ -100,14 +104,14 @@ function App() { const [runtimePlatform, setRuntimePlatform] = useState(''); const [isLinuxRuntime, setIsLinuxRuntime] = useState(false); const [isStoreHydrated, setIsStoreHydrated] = useState(() => useStore.persist.hasHydrated()); - const [hasAppliedInitialGlobalProxy, setHasAppliedInitialGlobalProxy] = useState(false); + const [hasLoadedSecureConfig, setHasLoadedSecureConfig] = useState(false); const sidebarWidth = useStore(state => state.sidebarWidth); const setSidebarWidth = useStore(state => state.setSidebarWidth); const aiPanelVisible = useStore(state => state.aiPanelVisible); const toggleAIPanel = useStore(state => state.toggleAIPanel); const setAIPanelVisible = useStore(state => state.setAIPanelVisible); const globalProxyInvalidHintShownRef = React.useRef(false); - const connectionWorkbenchState = getConnectionWorkbenchState(isStoreHydrated, hasAppliedInitialGlobalProxy); + const connectionWorkbenchState = getConnectionWorkbenchState(isStoreHydrated, hasLoadedSecureConfig); // 同步 macOS 窗口透明度:opacity=1.0 且 blur=0 时关闭 NSVisualEffectView, // 避免 GPU 持续计算窗口背后的模糊合成 @@ -167,6 +171,90 @@ function App() { return; } + let cancelled = false; + const loadSecureConfig = async () => { + const backendApp = (window as any).go?.app?.App; + const persistedPayload = typeof window !== 'undefined' + ? window.localStorage.getItem(LEGACY_PERSIST_KEY) + : null; + const legacy = readLegacyPersistedSecrets(persistedPayload); + + let importedLegacyConnections = false; + let importedLegacyGlobalProxy = false; + + if (legacy.connections.length > 0) { + if (typeof backendApp?.ImportLegacyConnections === 'function') { + try { + await backendApp.ImportLegacyConnections( + legacy.connections.map(({ id, name, config }) => ({ id, name, config })) + ); + importedLegacyConnections = true; + } catch (err) { + console.warn('Failed to import legacy saved connections', err); + } + } else { + replaceConnections(legacy.connections); + } + } + + if (legacy.globalProxy) { + if (typeof backendApp?.ImportLegacyGlobalProxy === 'function') { + try { + await backendApp.ImportLegacyGlobalProxy(toSaveGlobalProxyInput(legacy.globalProxy)); + importedLegacyGlobalProxy = true; + } catch (err) { + console.warn('Failed to import legacy global proxy', err); + } + } else { + replaceGlobalProxy(createGlobalProxyDraft(legacy.globalProxy)); + } + } + + if ((importedLegacyConnections || importedLegacyGlobalProxy) && persistedPayload && typeof window !== 'undefined') { + const sanitizedPayload = stripLegacyPersistedSecrets(persistedPayload); + if (sanitizedPayload && sanitizedPayload !== persistedPayload) { + window.localStorage.setItem(LEGACY_PERSIST_KEY, sanitizedPayload); + } + } + + if (typeof backendApp?.GetSavedConnections === 'function') { + try { + const savedConnections = await backendApp.GetSavedConnections(); + if (!cancelled && Array.isArray(savedConnections)) { + replaceConnections(savedConnections); + } + } catch (err) { + console.warn('Failed to load saved connections from backend', err); + } + } + + if (typeof backendApp?.GetGlobalProxyConfig === 'function') { + try { + const proxyResult = await backendApp.GetGlobalProxyConfig(); + if (!cancelled && proxyResult?.success && proxyResult.data) { + replaceGlobalProxy(createGlobalProxyDraft(proxyResult.data)); + } + } catch (err) { + console.warn('Failed to load global proxy from backend', err); + } + } + + if (!cancelled) { + setHasLoadedSecureConfig(true); + } + }; + + void loadSecureConfig(); + return () => { + cancelled = true; + }; + }, [isStoreHydrated, replaceConnections, replaceGlobalProxy]); + + useEffect(() => { + if (!isStoreHydrated || !hasLoadedSecureConfig) { + return; + } + const host = String(globalProxy.host || '').trim(); const port = Number(globalProxy.port); const portValid = Number.isFinite(port) && port > 0 && port <= 65535; @@ -180,57 +268,44 @@ function App() { }); globalProxyInvalidHintShownRef.current = true; } - } else { - globalProxyInvalidHintShownRef.current = false; - void message.destroy('global-proxy-invalid'); + return; } - const enabledForBackend = globalProxy.enabled && !invalidWhenEnabled; - let cancelled = false; - try { - ConfigureGlobalProxy(enabledForBackend, { - type: globalProxy.type, - host, - port: portValid ? port : (globalProxy.type === 'http' ? 8080 : 1080), - user: String(globalProxy.user || '').trim(), - password: globalProxy.password || '', - }) - .then((res) => { - if (cancelled || res?.success) { - return; - } - void message.error({ - content: '全局代理配置失败: ' + (res?.message || '未知错误'), - key: 'global-proxy-sync-error', - }); - }) - .catch((err) => { - if (cancelled) { - return; - } - const errMsg = err instanceof Error ? err.message : String(err || '未知错误'); - void message.error({ - content: '全局代理配置失败: ' + errMsg, - key: 'global-proxy-sync-error', - }); - }) - .finally(() => { - if (!cancelled) { - setHasAppliedInitialGlobalProxy(true); - } - }); - } catch (e) { - if (!cancelled) { - setHasAppliedInitialGlobalProxy(true); - } - console.warn("Wails API: ConfigureGlobalProxy unavailable", e); + globalProxyInvalidHintShownRef.current = false; + void message.destroy('global-proxy-invalid'); + + const backendApp = (window as any).go?.app?.App; + if (typeof backendApp?.SaveGlobalProxy !== 'function') { + return; } + let cancelled = false; + Promise.resolve( + backendApp.SaveGlobalProxy( + toSaveGlobalProxyInput({ + ...globalProxy, + host, + port: portValid ? port : (globalProxy.type === 'http' ? 8080 : 1080), + }) + ) + ) + .catch((err) => { + if (cancelled) { + return; + } + const errMsg = err instanceof Error ? err.message : String(err || '未知错误'); + void message.error({ + content: '全局代理配置失败: ' + errMsg, + key: 'global-proxy-sync-error', + }); + }); + return () => { cancelled = true; }; }, [ isStoreHydrated, + hasLoadedSecureConfig, globalProxy.enabled, globalProxy.type, globalProxy.host, @@ -2490,3 +2565,5 @@ function App() { } export default App; + + diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 021c764..32ca88b 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -1,6 +1,6 @@ import { create } from 'zustand'; import { persist } from 'zustand/middleware'; -import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag, AIChatMessage, AIContextItem } from './types'; +import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag, AIChatMessage, AIContextItem, GlobalProxyConfig } from './types'; import { ShortcutAction, ShortcutBinding, @@ -9,6 +9,7 @@ import { cloneShortcutOptions, sanitizeShortcutOptions, } from './utils/shortcuts'; +import { toPersistedGlobalProxy } from './utils/globalProxyDraft'; const DEFAULT_APPEARANCE = { enabled: true, opacity: 1.0, blur: 0, useNativeMacWindowControls: false }; const DEFAULT_UI_SCALE = 1.0; @@ -34,6 +35,7 @@ const DEFAULT_GLOBAL_PROXY: GlobalProxyConfig = { port: 1080, user: '', password: '', + hasPassword: false, }; const SUPPORTED_CONNECTION_TYPES = new Set([ 'mysql', @@ -246,6 +248,7 @@ const sanitizeConnectionConfig = (value: unknown): ConnectionConfig => { const safeConfig: ConnectionConfig & Record = { ...raw, + id: toTrimmedString(raw.id ?? raw.ID), type, host: toTrimmedString(raw.host, 'localhost') || 'localhost', port: normalizePort(raw.port, defaultPort), @@ -321,7 +324,16 @@ const sanitizeSavedConnection = (value: unknown, index: number): SavedConnection return { id, name, - config, + config: { ...config, id: config.id || id }, + secretRef: toTrimmedString(raw.secretRef) || undefined, + hasPrimaryPassword: raw.hasPrimaryPassword === true, + hasSSHPassword: raw.hasSSHPassword === true, + hasProxyPassword: raw.hasProxyPassword === true, + hasHttpTunnelPassword: raw.hasHttpTunnelPassword === true, + hasMySQLReplicaPassword: raw.hasMySQLReplicaPassword === true, + hasMongoReplicaPassword: raw.hasMongoReplicaPassword === true, + hasOpaqueURI: raw.hasOpaqueURI === true, + hasOpaqueDSN: raw.hasOpaqueDSN === true, includeDatabases: includeDatabases.length > 0 ? includeDatabases : undefined, includeRedisDatabases: includeRedisDatabases.length > 0 ? includeRedisDatabases : undefined, }; @@ -393,10 +405,6 @@ export interface QueryOptions { showColumnType: boolean; } -export interface GlobalProxyConfig extends ProxyConfig { - enabled: boolean; -} - interface AppState { connections: SavedConnection[]; connectionTags: ConnectionTag[]; @@ -440,6 +448,7 @@ interface AppState { addConnection: (conn: SavedConnection) => void; updateConnection: (conn: SavedConnection) => void; removeConnection: (id: string) => void; + replaceConnections: (connections: SavedConnection[]) => void; addConnectionTag: (tag: ConnectionTag) => void; updateConnectionTag: (tag: ConnectionTag) => void; @@ -468,6 +477,7 @@ interface AppState { setFontSize: (size: number) => void; setStartupFullscreen: (enabled: boolean) => void; setGlobalProxy: (proxy: Partial) => void; + replaceGlobalProxy: (proxy: Partial) => void; setSqlFormatOptions: (options: { keywordCase: 'upper' | 'lower' }) => void; setQueryOptions: (options: Partial) => void; updateShortcut: (action: ShortcutAction, binding: Partial) => void; @@ -618,18 +628,24 @@ const sanitizeFontSize = (value: unknown): number => { return normalizeIntegerInRange(value, DEFAULT_FONT_SIZE, MIN_FONT_SIZE, MAX_FONT_SIZE); }; -const sanitizeGlobalProxy = (value: unknown): GlobalProxyConfig => { +const sanitizeGlobalProxy = ( + value: unknown, + options: { allowPassword?: boolean } = {} +): GlobalProxyConfig => { const raw = (value && typeof value === 'object') ? value as Record : {}; const typeRaw = toTrimmedString(raw.type, DEFAULT_GLOBAL_PROXY.type).toLowerCase(); const type: 'socks5' | 'http' = typeRaw === 'http' ? 'http' : 'socks5'; const fallbackPort = type === 'http' ? 8080 : 1080; + const password = toTrimmedString(raw.password); return { enabled: raw.enabled === true, type, host: toTrimmedString(raw.host), port: normalizePort(raw.port, fallbackPort), user: toTrimmedString(raw.user), - password: toTrimmedString(raw.password), + password: options.allowPassword === false ? '' : password, + hasPassword: raw.hasPassword === true || password !== '', + secretRef: toTrimmedString(raw.secretRef) || undefined, }; }; @@ -782,6 +798,7 @@ export const useStore = create()( connectionIds: tag.connectionIds.filter(cid => cid !== id) })) })), + replaceConnections: (connections) => set({ connections: sanitizeConnections(connections) }), addConnectionTag: (tag) => set((state) => ({ connectionTags: [...state.connectionTags, tag] })), updateConnectionTag: (tag) => set((state) => ({ @@ -963,6 +980,7 @@ export const useStore = create()( setFontSize: (size) => set({ fontSize: sanitizeFontSize(size) }), setStartupFullscreen: (enabled) => set({ startupFullscreen: !!enabled }), setGlobalProxy: (proxy) => set((state) => ({ globalProxy: sanitizeGlobalProxy({ ...state.globalProxy, ...proxy }) })), + replaceGlobalProxy: (proxy) => set({ globalProxy: sanitizeGlobalProxy({ ...DEFAULT_GLOBAL_PROXY, ...proxy }) }), setSqlFormatOptions: (options) => set({ sqlFormatOptions: options }), setQueryOptions: (options) => set((state) => ({ queryOptions: { ...state.queryOptions, ...options } })), updateShortcut: (action, binding) => set((state) => ({ @@ -1203,7 +1221,7 @@ export const useStore = create()( migrate: (persistedState: unknown, version: number) => { const state = unwrapPersistedAppState(persistedState) as Partial; const nextState: Partial = { ...state }; - nextState.connections = sanitizeConnections(state.connections); + nextState.connections = []; if (version < 5) { nextState.connectionTags = sanitizeConnectionTags(state.connectionTags); } else { @@ -1215,7 +1233,7 @@ export const useStore = create()( nextState.uiScale = sanitizeUiScale(state.uiScale); nextState.fontSize = sanitizeFontSize(state.fontSize); nextState.startupFullscreen = sanitizeStartupFullscreen(state.startupFullscreen); - nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy); + nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy, { allowPassword: false }); nextState.sqlFormatOptions = sanitizeSqlFormatOptions(state.sqlFormatOptions); nextState.queryOptions = sanitizeQueryOptions(state.queryOptions); nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions); @@ -1242,7 +1260,7 @@ export const useStore = create()( return { ...currentState, ...state, - connections: sanitizeConnections(state.connections), + connections: currentState.connections, connectionTags: sanitizeConnectionTags(state.connectionTags), savedQueries: sanitizeSavedQueries(state.savedQueries), theme: sanitizeTheme(state.theme), @@ -1250,7 +1268,7 @@ export const useStore = create()( uiScale: sanitizeUiScale(state.uiScale), fontSize: sanitizeFontSize(state.fontSize), startupFullscreen: sanitizeStartupFullscreen(state.startupFullscreen), - globalProxy: sanitizeGlobalProxy(state.globalProxy), + globalProxy: sanitizeGlobalProxy(state.globalProxy, { allowPassword: false }), tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference), tableColumnOrders: sanitizeTableColumnOrders(state.tableColumnOrders), enableColumnOrderMemory: state.enableColumnOrderMemory !== false, @@ -1271,7 +1289,6 @@ export const useStore = create()( }; }, partialize: (state) => ({ - connections: state.connections, connectionTags: state.connectionTags, savedQueries: state.savedQueries, theme: state.theme, @@ -1279,7 +1296,7 @@ export const useStore = create()( uiScale: state.uiScale, fontSize: state.fontSize, startupFullscreen: state.startupFullscreen, - globalProxy: state.globalProxy, + globalProxy: toPersistedGlobalProxy(state.globalProxy), sqlFormatOptions: state.sqlFormatOptions, queryOptions: state.queryOptions, shortcutOptions: state.shortcutOptions, @@ -1298,3 +1315,5 @@ export const useStore = create()( } ) ); + + diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 88c2fb4..34db0ec 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -22,6 +22,7 @@ export interface HTTPTunnelConfig { } export interface ConnectionConfig { + id?: string; type: string; host: string; port: number; @@ -70,12 +71,27 @@ export interface SavedConnection { id: string; name: string; config: ConnectionConfig; + secretRef?: string; + hasPrimaryPassword?: boolean; + hasSSHPassword?: boolean; + hasProxyPassword?: boolean; + hasHttpTunnelPassword?: boolean; + hasMySQLReplicaPassword?: boolean; + hasMongoReplicaPassword?: boolean; + hasOpaqueURI?: boolean; + hasOpaqueDSN?: boolean; includeDatabases?: string[]; includeRedisDatabases?: number[]; // Redis databases to show (0-15) iconType?: string; // 自定义图标类型(如 'mysql','postgres'),不填则取 config.type iconColor?: string; // 自定义图标颜色(十六进制),不填则取类型默认色 } +export interface GlobalProxyConfig extends ProxyConfig { + enabled: boolean; + hasPassword?: boolean; + secretRef?: string; +} + export interface ConnectionTag { id: string; name: string; @@ -243,3 +259,5 @@ export interface AISafetyResult { requiresConfirm: boolean; warningMessage?: string; } + + diff --git a/frontend/src/utils/globalProxyDraft.test.ts b/frontend/src/utils/globalProxyDraft.test.ts new file mode 100644 index 0000000..7354ca7 --- /dev/null +++ b/frontend/src/utils/globalProxyDraft.test.ts @@ -0,0 +1,35 @@ +import { describe, expect, it } from 'vitest'; + +import { createGlobalProxyDraft, toPersistedGlobalProxy } from './globalProxyDraft'; + +describe('global proxy draft', () => { + it('hydrates a secretless draft from backend metadata while keeping password input blank', () => { + const draft = createGlobalProxyDraft({ + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + hasPassword: true, + password: 'should-be-ignored', + }); + + expect(draft.password).toBe(''); + expect(draft.hasPassword).toBe(true); + }); + + it('drops password from persisted metadata but preserves hasPassword', () => { + const persisted = toPersistedGlobalProxy({ + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + password: 'proxy-secret', + hasPassword: true, + }); + + expect('password' in persisted).toBe(false); + expect(persisted.hasPassword).toBe(true); + }); +}); diff --git a/frontend/src/utils/globalProxyDraft.ts b/frontend/src/utils/globalProxyDraft.ts new file mode 100644 index 0000000..408635c --- /dev/null +++ b/frontend/src/utils/globalProxyDraft.ts @@ -0,0 +1,62 @@ +import { GlobalProxyConfig } from '../types'; + +const toTrimmedString = (value: unknown): string => { + if (typeof value === 'string') { + return value.trim(); + } + if (typeof value === 'number' || typeof value === 'boolean') { + return String(value).trim(); + } + return ''; +}; + +const normalizeProxyType = (value: unknown): 'socks5' | 'http' => { + return toTrimmedString(value).toLowerCase() === 'http' ? 'http' : 'socks5'; +}; + +const normalizePort = (value: unknown, fallbackPort: number): number => { + const parsed = Number(value); + if (!Number.isFinite(parsed)) { + return fallbackPort; + } + const port = Math.trunc(parsed); + if (port <= 0 || port > 65535) { + return fallbackPort; + } + return port; +}; + +export function createGlobalProxyDraft(value: Partial = {}): GlobalProxyConfig { + const type = normalizeProxyType(value.type); + return { + enabled: value.enabled === true, + type, + host: toTrimmedString(value.host), + port: normalizePort(value.port, type === 'http' ? 8080 : 1080), + user: toTrimmedString(value.user), + password: '', + hasPassword: value.hasPassword === true, + secretRef: toTrimmedString(value.secretRef) || undefined, + }; +} + +export function toPersistedGlobalProxy(value: Partial = {}): Omit { + const draft = createGlobalProxyDraft(value); + return { + enabled: draft.enabled, + type: draft.type, + host: draft.host, + port: draft.port, + user: draft.user, + hasPassword: draft.hasPassword, + secretRef: draft.secretRef, + }; +} + +export function toSaveGlobalProxyInput(value: Partial = {}): GlobalProxyConfig { + const draft = createGlobalProxyDraft(value); + return { + ...draft, + password: typeof value.password === 'string' ? value.password : '', + }; +} diff --git a/frontend/src/utils/legacyConnectionStorage.test.ts b/frontend/src/utils/legacyConnectionStorage.test.ts new file mode 100644 index 0000000..7f8a46b --- /dev/null +++ b/frontend/src/utils/legacyConnectionStorage.test.ts @@ -0,0 +1,75 @@ +import { describe, expect, it } from 'vitest'; + +import { readLegacyPersistedSecrets, stripLegacyPersistedSecrets } from './legacyConnectionStorage'; + +describe('legacy connection storage', () => { + it('extracts legacy saved connections and global proxy password from lite-db-storage', () => { + const payload = JSON.stringify({ + state: { + connections: [ + { + id: 'conn-1', + name: 'Primary', + config: { + id: 'conn-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + password: 'secret', + }, + }, + ], + globalProxy: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + password: 'proxy-secret', + }, + }, + }); + + const result = readLegacyPersistedSecrets(payload); + expect(result.connections).toHaveLength(1); + expect(result.connections[0]?.config.password).toBe('secret'); + expect(result.globalProxy?.password).toBe('proxy-secret'); + }); + + it('strips persisted connection secrets but keeps secretless proxy metadata', () => { + const payload = JSON.stringify({ + state: { + connections: [ + { + id: 'conn-1', + name: 'Primary', + config: { + id: 'conn-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + password: 'secret', + }, + }, + ], + globalProxy: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + password: 'proxy-secret', + }, + }, + }); + + const sanitized = stripLegacyPersistedSecrets(payload); + const parsed = JSON.parse(sanitized); + + expect(parsed.state.connections).toEqual([]); + expect(parsed.state.globalProxy.password).toBeUndefined(); + expect(parsed.state.globalProxy.hasPassword).toBe(true); + }); +}); diff --git a/frontend/src/utils/legacyConnectionStorage.ts b/frontend/src/utils/legacyConnectionStorage.ts new file mode 100644 index 0000000..cdbdd6c --- /dev/null +++ b/frontend/src/utils/legacyConnectionStorage.ts @@ -0,0 +1,110 @@ +import { GlobalProxyConfig, SavedConnection } from '../types'; + +export const LEGACY_PERSIST_KEY = 'lite-db-storage'; + +const toTrimmedString = (value: unknown): string => { + if (typeof value === 'string') { + return value.trim(); + } + if (typeof value === 'number' || typeof value === 'boolean') { + return String(value).trim(); + } + return ''; +}; + +const normalizeProxyType = (value: unknown): 'socks5' | 'http' => { + return toTrimmedString(value).toLowerCase() === 'http' ? 'http' : 'socks5'; +}; + +const normalizePort = (value: unknown, fallbackPort: number): number => { + const parsed = Number(value); + if (!Number.isFinite(parsed)) { + return fallbackPort; + } + const port = Math.trunc(parsed); + if (port <= 0 || port > 65535) { + return fallbackPort; + } + return port; +}; + +const parsePersistedEnvelope = (payload: string | null | undefined): Record => { + if (!payload || typeof payload !== 'string') { + return {}; + } + try { + const parsed = JSON.parse(payload) as Record; + if (parsed.state && typeof parsed.state === 'object') { + return parsed.state as Record; + } + return parsed; + } catch { + return {}; + } +}; + +export function readLegacyPersistedSecrets(payload: string | null | undefined): { + connections: SavedConnection[]; + globalProxy: GlobalProxyConfig | null; +} { + const state = parsePersistedEnvelope(payload); + const connections = Array.isArray(state.connections) + ? state.connections.filter((item): item is SavedConnection => !!item && typeof item === 'object') + : []; + + const proxyRaw = state.globalProxy && typeof state.globalProxy === 'object' + ? state.globalProxy as Record + : null; + if (!proxyRaw) { + return { connections, globalProxy: null }; + } + + const type = normalizeProxyType(proxyRaw.type); + const password = toTrimmedString(proxyRaw.password); + const globalProxy: GlobalProxyConfig = { + enabled: proxyRaw.enabled === true, + type, + host: toTrimmedString(proxyRaw.host), + port: normalizePort(proxyRaw.port, type === 'http' ? 8080 : 1080), + user: toTrimmedString(proxyRaw.user), + password, + hasPassword: proxyRaw.hasPassword === true || password !== '', + secretRef: toTrimmedString(proxyRaw.secretRef) || undefined, + }; + + const hasMeaningfulProxyState = globalProxy.enabled || globalProxy.host !== '' || globalProxy.user !== '' || globalProxy.password !== '' || globalProxy.hasPassword === true; + return { + connections, + globalProxy: hasMeaningfulProxyState ? globalProxy : null, + }; +} + +export function stripLegacyPersistedSecrets(payload: string | null | undefined): string { + if (!payload || typeof payload !== 'string') { + return ''; + } + + let parsed: Record; + try { + parsed = JSON.parse(payload) as Record; + } catch { + return payload; + } + + const state = parsed.state && typeof parsed.state === 'object' + ? parsed.state as Record + : parsed; + state.connections = []; + + if (state.globalProxy && typeof state.globalProxy === 'object') { + const proxy = { ...(state.globalProxy as Record) }; + const password = toTrimmedString(proxy.password); + delete proxy.password; + if (password !== '') { + proxy.hasPassword = true; + } + state.globalProxy = proxy; + } + + return JSON.stringify(parsed); +} diff --git a/frontend/src/utils/startupReadiness.test.ts b/frontend/src/utils/startupReadiness.test.ts index 92c72bd..3b34e7e 100644 --- a/frontend/src/utils/startupReadiness.test.ts +++ b/frontend/src/utils/startupReadiness.test.ts @@ -10,10 +10,10 @@ describe('startup readiness helpers', () => { }); }); - it('keeps sidebar blocked until initial global proxy sync finishes', () => { + it('keeps sidebar blocked until secure config bootstrap finishes', () => { expect(getConnectionWorkbenchState(true, false)).toEqual({ ready: false, - message: '正在同步全局代理配置...', + message: '正在加载安全配置...', }); }); @@ -24,3 +24,4 @@ describe('startup readiness helpers', () => { }); }); }); + diff --git a/frontend/src/utils/startupReadiness.ts b/frontend/src/utils/startupReadiness.ts index 3395627..c7aab1d 100644 --- a/frontend/src/utils/startupReadiness.ts +++ b/frontend/src/utils/startupReadiness.ts @@ -16,7 +16,7 @@ export function getConnectionWorkbenchState( if (!hasAppliedInitialGlobalProxy) { return { ready: false, - message: '正在同步全局代理配置...', + message: '正在加载安全配置...', }; } return { @@ -24,3 +24,4 @@ export function getConnectionWorkbenchState( message: '', }; } + From 91b5b8590420a3e74a759954adc9c5919ea8d7f9 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 3 Apr 2026 14:18:16 +0800 Subject: [PATCH 10/14] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(security):?= =?UTF-8?q?=20=E9=80=9A=E8=BF=87=E8=BF=9E=E6=8E=A5=E9=85=8D=E7=BD=AE=20ID?= =?UTF-8?q?=20=E8=B7=AF=E7=94=B1=20RPC=20=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + frontend/src/components/AIChatPanel.tsx | 16 ++- frontend/src/components/DataGrid.tsx | 11 +- frontend/src/components/DataSyncModal.tsx | 32 ++--- frontend/src/components/DataViewer.tsx | 15 ++- frontend/src/components/DefinitionViewer.tsx | 5 +- .../src/components/FindInDatabaseModal.tsx | 7 +- .../src/components/ImportPreviewModal.tsx | 3 +- frontend/src/components/QueryEditor.tsx | 19 +-- .../src/components/RedisCommandEditor.tsx | 3 +- frontend/src/components/RedisMonitor.tsx | 3 +- frontend/src/components/RedisViewer.tsx | 37 +++--- frontend/src/components/Sidebar.tsx | 81 +++++------- frontend/src/components/TableDesigner.tsx | 25 ++-- frontend/src/components/TableOverview.tsx | 13 +- frontend/src/components/TriggerViewer.tsx | 5 +- frontend/src/components/ai/AIChatInput.tsx | 9 +- .../src/utils/connectionRpcConfig.test.ts | 104 +++++++++++++++ frontend/src/utils/connectionRpcConfig.ts | 122 ++++++++++++++++++ 19 files changed, 357 insertions(+), 154 deletions(-) create mode 100644 frontend/src/utils/connectionRpcConfig.test.ts create mode 100644 frontend/src/utils/connectionRpcConfig.ts diff --git a/.gitignore b/.gitignore index f70d8a2..285e157 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ docs/需求追踪/ CLAUDE.md **/CLAUDE.md +.worktrees \ No newline at end of file diff --git a/frontend/src/components/AIChatPanel.tsx b/frontend/src/components/AIChatPanel.tsx index f3ff243..695bf09 100644 --- a/frontend/src/components/AIChatPanel.tsx +++ b/frontend/src/components/AIChatPanel.tsx @@ -14,6 +14,7 @@ import { AIMessageBubble } from './ai/AIMessageBubble'; import { AIChatInput } from './ai/AIChatInput'; import { AIHistoryDrawer } from './ai/AIHistoryDrawer'; import type { AIComposerNotice } from '../utils/aiComposerNotice'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { buildMissingModelNotice, buildMissingProviderNotice, @@ -260,7 +261,7 @@ export const AIChatPanel: React.FC = ({ const conn = useStore.getState().connections.find(c => c.id === connectionId); if (conn) { import('../../wailsjs/go/app/App').then(({ DBShowCreateTable }) => { - DBShowCreateTable(conn.config as any, dbName, tableName).then(res => { + DBShowCreateTable(buildRpcConnectionConfig(conn.config) as any, dbName, tableName).then(res => { if (res.success && res.data) { let createSql = ''; if (typeof res.data === 'string') createSql = res.data; @@ -834,7 +835,7 @@ SELECT * FROM users WHERE status = 1; const conn = useStore.getState().connections.find(c => c.id === args.connectionId); if (conn) { try { - const dbRes = await DBGetDatabases(conn.config as any); + const dbRes = await DBGetDatabases(buildRpcConnectionConfig(conn.config) as any); if (dbRes?.success && Array.isArray(dbRes.data)) { let dNames = dbRes.data.map((r: any) => r.Database || r.database || Object.values(r)[0]); if (dNames.length > 50) dNames = [...dNames.slice(0, 50), '...(截断)']; @@ -855,7 +856,7 @@ SELECT * FROM users WHERE status = 1; try { const rawDbName = args.dbName || args.database; const safeDbName = rawDbName ? String(rawDbName).trim() : ''; - const tbRes = await DBGetTables(conn.config as any, safeDbName); + const tbRes = await DBGetTables(buildRpcConnectionConfig(conn.config) as any, safeDbName); if (tbRes?.success && Array.isArray(tbRes.data)) { let tNames = tbRes.data.map((r: any) => r.Table || r.table || Object.values(r)[0] as string); if (tNames.length > 150) tNames = [...tNames.slice(0, 150), '...(截断)']; @@ -881,7 +882,7 @@ SELECT * FROM users WHERE status = 1; const safeDbName = args.dbName ? String(args.dbName).trim() : ''; const safeTable = args.tableName ? String(args.tableName).trim() : ''; const { DBGetColumns } = await import('../../wailsjs/go/app/App'); - const colRes = await DBGetColumns(conn.config as any, safeDbName, safeTable); + const colRes = await DBGetColumns(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable); if (colRes?.success && Array.isArray(colRes.data)) { // 只保留关键字段信息,减少 token 占用 const cols = colRes.data.map((c: any) => { @@ -912,7 +913,7 @@ SELECT * FROM users WHERE status = 1; const safeDbName = args.dbName ? String(args.dbName).trim() : ''; const safeTable = args.tableName ? String(args.tableName).trim() : ''; const { DBShowCreateTable } = await import('../../wailsjs/go/app/App'); - const ddlRes = await DBShowCreateTable(conn.config as any, safeDbName, safeTable); + const ddlRes = await DBShowCreateTable(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable); if (ddlRes?.success) { resStr = typeof ddlRes.data === 'string' ? ddlRes.data : JSON.stringify(ddlRes.data); success = true; @@ -946,7 +947,7 @@ SELECT * FROM users WHERE status = 1; const finalSql = (isReadQuery && !sqlTrimmed.toLowerCase().includes('limit')) ? sqlTrimmed + ' LIMIT 50' : sqlTrimmed; - const qRes = await DBQuery(conn.config as any, safeDbName, finalSql); + const qRes = await DBQuery(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeSql + (safeSql.toLowerCase().includes('limit') ? '' : ' LIMIT 50')); if (qRes?.success) { const rows = Array.isArray(qRes.data) ? qRes.data : []; const limitedRows = rows.slice(0, 50); @@ -1306,7 +1307,8 @@ SELECT * FROM users WHERE status = 1; const handleDeleteMessage = useCallback((id: string) => deleteAIChatMessage(sid, id), [sid, deleteAIChatMessage]); const activeConnectionConfig = useMemo(() => { if (!inferredConnectionId) return undefined; - return connections.find(c => c.id === inferredConnectionId)?.config; + const connection = connections.find(c => c.id === inferredConnectionId); + return connection ? buildRpcConnectionConfig(connection.config) : undefined; }, [inferredConnectionId, connections]); const contextUsageChars = useMemo(() => messages.reduce((sum, m) => sum + (m.content?.length || 0) + JSON.stringify(m.tool_calls || []).length, 0), diff --git a/frontend/src/components/DataGrid.tsx b/frontend/src/components/DataGrid.tsx index 3ead509..d6a31fe 100644 --- a/frontend/src/components/DataGrid.tsx +++ b/frontend/src/components/DataGrid.tsx @@ -32,6 +32,7 @@ import 'react-resizable/css/styles.css'; import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, escapeLiteral, hasExplicitSort, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql'; import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { resolvePaginationPageText, resolvePaginationSummaryText, resolvePaginationTotalForControl } from '../utils/dataGridPagination'; import { resolveGridSortInfoFromTableSorter } from '../utils/dataGridSort'; import { calculateTableBodyBottomPadding, calculateVirtualTableScrollX } from './dataGridLayout'; @@ -1357,7 +1358,7 @@ const DataGrid: React.FC = ({ }; const seq = ++columnMetaSeqRef.current; - DBGetColumns(config as any, normalizedDbName, normalizedTableName) + DBGetColumns(buildRpcConnectionConfig(config) as any, normalizedDbName, normalizedTableName) .then((res) => { if (seq !== columnMetaSeqRef.current) return; if (!res.success || !Array.isArray(res.data)) { @@ -3500,7 +3501,7 @@ const DataGrid: React.FC = ({ }; const startTime = Date.now(); - const res = await ApplyChanges(config as any, dbName || '', tableName, { inserts, updates, deletes } as any); + const res = await ApplyChanges(buildRpcConnectionConfig(config) as any, dbName || '', tableName, { inserts, updates, deletes } as any); const duration = Date.now() - startTime; // Construct a pseudo-SQL representation for the log @@ -3618,7 +3619,7 @@ const DataGrid: React.FC = ({ if (!config) return; const hide = message.loading(`正在导出...`, 0); try { - const res = await ExportQuery(config as any, dbName || '', sql, defaultName || 'export', format); + const res = await ExportQuery(buildRpcConnectionConfig(config) as any, dbName || '', sql, defaultName || 'export', format); if (res.success) { void message.success("导出成功"); } else if (res.message !== "已取消") { @@ -3736,7 +3737,7 @@ const DataGrid: React.FC = ({ if (!config) return; const hide = message.loading(`正在导出全部数据...`, 0); try { - const res = await ExportTable(config as any, dbName || '', tableName, format); + const res = await ExportTable(buildRpcConnectionConfig(config) as any, dbName || '', tableName, format); if (res.success) { void message.success("导出成功"); } else if (res.message !== "已取消") { @@ -3811,7 +3812,7 @@ const DataGrid: React.FC = ({ const config = buildConnConfig(); if (!config) return; - const res = await ImportData(config as any, dbName || '', tableName); + const res = await ImportData(buildRpcConnectionConfig(config) as any, dbName || '', tableName); if (res.success && res.data && res.data.filePath) { setImportFilePath(res.data.filePath); setImportPreviewVisible(true); diff --git a/frontend/src/components/DataSyncModal.tsx b/frontend/src/components/DataSyncModal.tsx index 7775c08..720eb3e 100644 --- a/frontend/src/components/DataSyncModal.tsx +++ b/frontend/src/components/DataSyncModal.tsx @@ -6,6 +6,7 @@ import { DBGetDatabases, DBGetTables, DataSync, DataSyncAnalyze, DataSyncPreview import { SavedConnection } from '../types'; import { EventsOn } from '../../wailsjs/runtime/runtime'; import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { formatLocalDateTimeLiteral, normalizeTemporalLiteralText } from './dataGridCopyInsert'; const { Title, Text } = Typography; @@ -236,14 +237,11 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, const logBoxRef = useRef(null); const autoScrollRef = useRef(true); - const normalizeConnConfig = (conn: SavedConnection, database?: string) => ({ - ...conn.config, - port: Number((conn.config as any).port), - password: conn.config.password || "", - useSSH: conn.config.useSSH || false, - ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }, - database: typeof database === 'string' ? database : (conn.config.database || ""), - }); + const normalizeConnConfig = (conn: SavedConnection, database?: string) => ( + buildRpcConnectionConfig(conn.config, { + database: typeof database === 'string' ? database : (conn.config.database || ''), + }) + ); useEffect(() => { if (!open) return; @@ -542,22 +540,8 @@ const DataSyncModal: React.FC<{ open: boolean; onClose: () => void }> = ({ open, }); const config = { - sourceConfig: { - ...sConn.config, - port: Number((sConn.config as any).port), - password: sConn.config.password || "", - useSSH: sConn.config.useSSH || false, - ssh: sConn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }, - database: sourceDb, - }, - targetConfig: { - ...tConn.config, - port: Number((tConn.config as any).port), - password: tConn.config.password || "", - useSSH: tConn.config.useSSH || false, - ssh: tConn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" }, - database: targetDb, - }, + sourceConfig: normalizeConnConfig(sConn, sourceDb), + targetConfig: normalizeConnConfig(tConn, targetDb), tables: selectedTables, content: syncContent, mode: syncMode, diff --git a/frontend/src/components/DataViewer.tsx b/frontend/src/components/DataViewer.tsx index d10a2df..e841482 100644 --- a/frontend/src/components/DataViewer.tsx +++ b/frontend/src/components/DataViewer.tsx @@ -9,6 +9,7 @@ import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildM import { buildOracleApproximateTotalSql, parseApproximateTableCountRow, resolveApproximateTableCountStrategy } from '../utils/approximateTableCount'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; import { resolveDataViewerAutoFetchAction } from '../utils/dataViewerAutoFetch'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; type ViewerPaginationState = { current: number; @@ -319,7 +320,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct const countSeq = ++manualCountSeqRef.current; const countStart = Date.now(); setPagination(prev => ({ ...prev, totalCountLoading: true, totalCountCancelled: false })); - const countConfig: any = { ...(config as any), timeout: 120 }; + const countConfig = buildRpcConnectionConfig(config, { timeout: 120 }); try { const resCount = await DBQuery(countConfig as any, dbName, countSql); @@ -478,7 +479,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct const executeDataQuery = async (querySql: string, attemptLabel: string) => { const startTime = Date.now(); try { - const result = await DBQuery(config as any, dbName, querySql); + const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, querySql); addSqlLog({ id: `log-${Date.now()}-data`, timestamp: Date.now(), @@ -514,7 +515,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct let safeSelect = duckdbSafeSelectCacheRef.current[cacheKey] || ''; if (!safeSelect) { try { - const resCols = await DBGetColumns(config as any, dbName, tableName); + const resCols = await DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName); if (resCols?.success && Array.isArray(resCols.data)) { const columnDefs = resCols.data as ColumnDefinition[]; const selectParts = columnDefs.map((col) => { @@ -567,7 +568,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct if (pkKeyRef.current !== pkKey) { pkKeyRef.current = pkKey; const pkSeq = ++pkSeqRef.current; - DBGetColumns(config as any, dbName, tableName) + DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName) .then((resCols: any) => { if (pkSeqRef.current !== pkSeq) return; if (pkKeyRef.current !== pkKey) return; @@ -680,7 +681,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct const countStart = Date.now(); // 大表 COUNT(*) 可能非常慢,且在部分运行时环境下会影响后续操作响应; // DuckDB 大文件场景下该统计会显著拖慢翻页,已禁用后台 COUNT。 - const countConfig: any = { ...(config as any), timeout: 5 }; + const countConfig = buildRpcConnectionConfig(config, { timeout: 5 }); DBQuery(countConfig, dbName, countSql) .then((resCount: any) => { @@ -734,7 +735,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct const { schemaName, pureTableName } = resolveDuckDBSchemaAndTable(dbName, tableName); const escapedSchema = escapeSQLLiteral(schemaName); const escapedTable = escapeSQLLiteral(pureTableName); - const approxConfig: any = { ...(config as any), timeout: 3 }; + const approxConfig = buildRpcConnectionConfig(config, { timeout: 3 }); const approxSqlCandidates = [ `SELECT estimated_size AS approx_total FROM duckdb_tables() WHERE schema_name='${escapedSchema}' AND table_name='${escapedTable}' LIMIT 1`, `SELECT estimated_size AS approx_total FROM duckdb_tables() WHERE table_name='${escapedTable}' ORDER BY CASE WHEN schema_name='${escapedSchema}' THEN 0 ELSE 1 END LIMIT 1`, @@ -775,7 +776,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct if (approximateCountStrategy === 'oracle-num-rows' && oracleApproxKeyRef.current !== countKey) { oracleApproxKeyRef.current = countKey; const approxSeq = ++oracleApproxSeqRef.current; - const approxConfig: any = { ...(config as any), timeout: 3 }; + const approxConfig = buildRpcConnectionConfig(config, { timeout: 3 }); const approxSql = buildOracleApproximateTotalSql({ dbName, tableName }); DBQuery(approxConfig as any, dbName, approxSql) diff --git a/frontend/src/components/DefinitionViewer.tsx b/frontend/src/components/DefinitionViewer.tsx index 9072258..d17b6bd 100644 --- a/frontend/src/components/DefinitionViewer.tsx +++ b/frontend/src/components/DefinitionViewer.tsx @@ -4,6 +4,7 @@ import { Spin, Alert } from 'antd'; import { TabData } from '../types'; import { useStore } from '../store'; import { DBQuery } from '../../wailsjs/go/app/App'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; interface DefinitionViewerProps { tab: TabData; @@ -201,7 +202,7 @@ const DefinitionViewer: React.FC = ({ tab }) => { const sql = String(query || '').trim(); if (!sql) continue; try { - const result = await DBQuery(config as any, dbName, sql); + const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, sql); if (!result.success || !Array.isArray(result.data)) { lastMessage = result.message || lastMessage; continue; @@ -227,7 +228,7 @@ const DefinitionViewer: React.FC = ({ tab }) => { ]; for (const query of candidates) { try { - const result = await DBQuery(config as any, dbName, query); + const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, query); if (!result.success || !Array.isArray(result.data) || result.data.length === 0) { continue; } diff --git a/frontend/src/components/FindInDatabaseModal.tsx b/frontend/src/components/FindInDatabaseModal.tsx index cbe3da2..2a29484 100644 --- a/frontend/src/components/FindInDatabaseModal.tsx +++ b/frontend/src/components/FindInDatabaseModal.tsx @@ -5,6 +5,7 @@ import { DBQuery, DBGetTables, DBGetAllColumns } from '../../wailsjs/go/app/App' import { quoteIdentPart, escapeLiteral } from '../utils/sql'; import { useStore } from '../store'; import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; interface FindInDatabaseModalProps { open: boolean; @@ -106,7 +107,7 @@ const FindInDatabaseModal: React.FC = ({ open, onClose try { // 1. 获取所有表 - const tablesRes = await DBGetTables(config as any, dbName); + const tablesRes = await DBGetTables(buildRpcConnectionConfig(config) as any, dbName); if (!tablesRes.success) { message.error('获取表列表失败: ' + tablesRes.message); setSearching(false); @@ -124,7 +125,7 @@ const FindInDatabaseModal: React.FC = ({ open, onClose setProgress({ current: 0, total: tableNames.length, tableName: '' }); // 2. 获取所有列信息(返回 any[],含 tableName/name/type 字段) - const allColsRes = await DBGetAllColumns(config as any, dbName); + const allColsRes = await DBGetAllColumns(buildRpcConnectionConfig(config) as any, dbName); const allColumns: any[] = (allColsRes?.success && Array.isArray(allColsRes.data)) ? allColsRes.data : []; // 按表名分组 @@ -166,7 +167,7 @@ const FindInDatabaseModal: React.FC = ({ open, onClose const sql = buildLimitedSelectSQL(dbType, baseSql, MAX_MATCH_ROWS_PER_TABLE); try { - const res = await DBQuery(config as any, dbName, sql); + const res = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, sql); if (res.success && Array.isArray(res.data) && res.data.length > 0) { // 检查哪些列实际匹配了 const matchedCols = new Set(); diff --git a/frontend/src/components/ImportPreviewModal.tsx b/frontend/src/components/ImportPreviewModal.tsx index 160eb7d..e77aa6d 100644 --- a/frontend/src/components/ImportPreviewModal.tsx +++ b/frontend/src/components/ImportPreviewModal.tsx @@ -4,6 +4,7 @@ import { CheckCircleOutlined, CloseCircleOutlined } from '@ant-design/icons'; import { PreviewImportFile, ImportDataWithProgress } from '../../wailsjs/go/app/App'; import { EventsOn, EventsOff } from '../../wailsjs/runtime/runtime'; import { useStore } from '../store'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; interface ImportPreviewModalProps { visible: boolean; @@ -107,7 +108,7 @@ const ImportPreviewModal: React.FC = ({ ssh: conn.config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' } }; - const res = await ImportDataWithProgress(config as any, dbName, tableName, filePath); + const res = await ImportDataWithProgress(buildRpcConnectionConfig(config) as any, dbName, tableName, filePath); if (res.success && res.data) { setImportResult(res.data); diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index 1904e44..9b2fd44 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -11,6 +11,7 @@ import DataGrid, { GONAVI_ROW_KEY } from './DataGrid'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; import { convertMongoShellToJsonCommand } from '../utils/mongodb'; import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; const SQL_KEYWORDS = [ 'SELECT', 'FROM', 'WHERE', 'LIMIT', 'INSERT', 'UPDATE', 'DELETE', 'JOIN', 'LEFT', 'RIGHT', @@ -336,7 +337,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } }; - const res = await DBGetDatabases(config as any); + const res = await DBGetDatabases(buildRpcConnectionConfig(config) as any); if (res.success && Array.isArray(res.data)) { let dbs = res.data.map((row: any) => row.Database || row.database); @@ -392,7 +393,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc for (const dbName of visibleDbs) { // 获取表 - const resTables = await DBGetTables(config as any, dbName); + const resTables = await DBGetTables(buildRpcConnectionConfig(config) as any, dbName); if (resTables.success && Array.isArray(resTables.data)) { const tableNames = resTables.data.map((row: any) => Object.values(row)[0] as string); tableNames.forEach((tableName: string) => { @@ -401,7 +402,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } // 获取列 (所有数据库类型都支持 DBGetAllColumns) - const resCols = await DBGetAllColumns(config as any, dbName); + const resCols = await DBGetAllColumns(buildRpcConnectionConfig(config) as any, dbName); if (resCols.success && Array.isArray(resCols.data)) { resCols.data.forEach((col: any) => { allColumns.push({ @@ -577,7 +578,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc const config = buildConnConfig(); if (!config) return [] as ColumnDefinition[]; - const res = await DBGetColumns(config as any, dbName, tableIdent); + const res = await DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableIdent); if (res?.success && Array.isArray(res.data)) { const cols = res.data as ColumnDefinition[]; sharedColumnsCacheData[key] = cols; @@ -1555,7 +1556,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } catch { queryId = 'reload-' + Date.now(); } - const res = await DBQueryMulti(config as any, currentDb, sql, queryId); + const res = await DBQueryMulti(buildRpcConnectionConfig(config) as any, currentDb, sql, queryId); if (!res?.success) { message.error('刷新失败: ' + (res?.message || '未知错误')); return; @@ -1643,7 +1644,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc try { const rawSQL = getSelectedSQL() || currentQuery; - const dbType = String((config as any).type || 'mysql'); + const dbType = String((buildRpcConnectionConfig(config) as any).type || 'mysql'); const normalizedDbType = dbType.trim().toLowerCase(); const normalizedRawSQL = String(rawSQL || '').replace(/;/g, ';'); @@ -1694,7 +1695,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } setQueryId(queryId); - const res = await DBQueryWithCancel(config as any, currentDb, executedSql, queryId); + const res = await DBQueryWithCancel(buildRpcConnectionConfig(config) as any, currentDb, executedSql, queryId); const duration = Date.now() - startTime; addSqlLog({ id: `log-${Date.now()}-query-${idx + 1}`, @@ -1795,7 +1796,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } setQueryId(queryId); - const res = await DBQueryMulti(config as any, currentDb, fullSQL, queryId); + const res = await DBQueryMulti(buildRpcConnectionConfig(config) as any, currentDb, fullSQL, queryId); const duration = Date.now() - startTime; addSqlLog({ @@ -1921,7 +1922,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc setActiveResultKey(nextResultSets[0]?.key || ''); pendingPk.forEach(({ resultKey, tableName }) => { - DBGetColumns(config as any, currentDb, tableName) + DBGetColumns(buildRpcConnectionConfig(config) as any, currentDb, tableName) .then((resCols: any) => { if (runSeqRef.current !== runSeq) return; if (!resCols?.success) { diff --git a/frontend/src/components/RedisCommandEditor.tsx b/frontend/src/components/RedisCommandEditor.tsx index 7cc7b28..8aabd65 100644 --- a/frontend/src/components/RedisCommandEditor.tsx +++ b/frontend/src/components/RedisCommandEditor.tsx @@ -2,6 +2,7 @@ import React, { useState, useCallback, useRef, useEffect } from 'react'; import { Button, Space, message } from 'antd'; import { PlayCircleOutlined, ClearOutlined } from '@ant-design/icons'; import { useStore } from '../store'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import Editor, { OnMount } from '@monaco-editor/react'; interface RedisCommandEditorProps { @@ -201,7 +202,7 @@ const RedisCommandEditor: React.FC = ({ connectionId, r for (const cmd of commands) { const start = Date.now(); try { - const res = await (window as any).go.app.App.RedisExecuteCommand(config, cmd); + const res = await (window as any).go.app.App.RedisExecuteCommand(buildRpcConnectionConfig(config), cmd); newResults.push({ command: cmd, result: res.success ? res.data : null, diff --git a/frontend/src/components/RedisMonitor.tsx b/frontend/src/components/RedisMonitor.tsx index 0e5cc25..cb86550 100644 --- a/frontend/src/components/RedisMonitor.tsx +++ b/frontend/src/components/RedisMonitor.tsx @@ -12,6 +12,7 @@ import { } from '@ant-design/icons'; import { useStore } from '../store'; import { SavedConnection } from '../types'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { RedisGetServerInfo } from '../../wailsjs/go/app/App'; const { Title, Text } = Typography; @@ -61,7 +62,7 @@ const RedisMonitor: React.FC = ({ connectionId, redisDB }) => if (!connection) return; try { - const config = { ...connection.config, redisDB } as any; + const config = buildRpcConnectionConfig(connection.config, { redisDB }); const res = await RedisGetServerInfo(config); if (!mountedRef.current) return; diff --git a/frontend/src/components/RedisViewer.tsx b/frontend/src/components/RedisViewer.tsx index 9329c33..6a1a4ff 100644 --- a/frontend/src/components/RedisViewer.tsx +++ b/frontend/src/components/RedisViewer.tsx @@ -7,6 +7,7 @@ import { RedisKeyInfo, RedisValue, StreamEntry } from '../types'; import Editor from '@monaco-editor/react'; import type { DataNode } from 'antd/es/tree'; import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { applyRenamedRedisKeyState, applyTreeNodeCheck, @@ -429,7 +430,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { setLoading(true); try { - const res = await (window as any).go.app.App.RedisScanKeys(config, normalizedPattern, fromCursor, effectiveTargetCount); + const res = await (window as any).go.app.App.RedisScanKeys(buildRpcConnectionConfig(config), normalizedPattern, fromCursor, effectiveTargetCount); if (requestId !== latestLoadRequestIdRef.current) { return; } @@ -508,7 +509,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { setValueLoading(true); try { - const res = await (window as any).go.app.App.RedisGetValue(config, key); + const res = await (window as any).go.app.App.RedisGetValue(buildRpcConnectionConfig(config), key); if (res.success) { setKeyValue(res.data); setSelectedKey(key); @@ -539,7 +540,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { if (!config) return; try { - const res = await (window as any).go.app.App.RedisDeleteKeys(config, keysToDelete); + const res = await (window as any).go.app.App.RedisDeleteKeys(buildRpcConnectionConfig(config), keysToDelete); if (res.success) { message.success(`已删除 ${res.data.deleted} 个 Key`); setKeys(prev => prev.filter(k => !keysToDelete.includes(k.key))); @@ -567,7 +568,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { try { const values = await ttlForm.validateFields(); - const res = await (window as any).go.app.App.RedisSetTTL(config, selectedKey, values.ttl); + const res = await (window as any).go.app.App.RedisSetTTL(buildRpcConnectionConfig(config), selectedKey, values.ttl); if (res.success) { message.success('TTL 设置成功'); setTtlModalOpen(false); @@ -586,7 +587,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { if (!config || !selectedKey) return; try { - const res = await (window as any).go.app.App.RedisSetString(config, selectedKey, editValue, keyValue?.ttl || -1); + const res = await (window as any).go.app.App.RedisSetString(buildRpcConnectionConfig(config), selectedKey, editValue, keyValue?.ttl || -1); if (res.success) { message.success('保存成功'); setEditModalOpen(false); @@ -605,7 +606,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { try { const values = await newKeyForm.validateFields(); - const res = await (window as any).go.app.App.RedisSetString(config, values.key, values.value, values.ttl || -1); + const res = await (window as any).go.app.App.RedisSetString(buildRpcConnectionConfig(config), values.key, values.value, values.ttl || -1); if (res.success) { message.success('创建成功'); setNewKeyModalOpen(false); @@ -642,7 +643,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { return; } - const existsRes = await (window as any).go.app.App.RedisKeyExists(config, nextKey); + const existsRes = await (window as any).go.app.App.RedisKeyExists(buildRpcConnectionConfig(config), nextKey); if (!existsRes?.success) { message.error('校验目标 Key 失败: ' + (existsRes?.message || '未知错误')); return; @@ -652,7 +653,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { return; } - const res = await (window as any).go.app.App.RedisRenameKey(config, renameTargetKey, nextKey); + const res = await (window as any).go.app.App.RedisRenameKey(buildRpcConnectionConfig(config), renameTargetKey, nextKey); if (res.success) { const nextState = applyRenamedRedisKeyState( { @@ -1177,7 +1178,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { const config = getConfig(); if (!config) return; try { - const res = await (window as any).go.app.App.RedisSetHashField(config, selectedKey, field, newValue); + const res = await (window as any).go.app.App.RedisSetHashField(buildRpcConnectionConfig(config), selectedKey, field, newValue); if (res.success) { message.success('修改成功'); loadKeyValue(selectedKey); @@ -1193,7 +1194,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { const config = getConfig(); if (!config) return; try { - const res = await (window as any).go.app.App.RedisDeleteHashField(config, selectedKey, field); + const res = await (window as any).go.app.App.RedisDeleteHashField(buildRpcConnectionConfig(config), selectedKey, field); if (res.success) { message.success('删除成功'); loadKeyValue(selectedKey); @@ -1338,7 +1339,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { const config = getConfig(); if (!config) return; try { - const res = await (window as any).go.app.App.RedisListSet(config, selectedKey, index, newValue); + const res = await (window as any).go.app.App.RedisListSet(buildRpcConnectionConfig(config), selectedKey, index, newValue); if (res.success) { message.success('修改成功'); loadKeyValue(selectedKey); @@ -1354,7 +1355,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { const config = getConfig(); if (!config) return; try { - const res = await (window as any).go.app.App.RedisListPush(config, selectedKey, { values: [value], position }); + const res = await (window as any).go.app.App.RedisListPush(buildRpcConnectionConfig(config), selectedKey, { values: [value], position }); if (res.success) { message.success('添加成功'); loadKeyValue(selectedKey); @@ -1508,7 +1509,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { const config = getConfig(); if (!config) return; try { - const res = await (window as any).go.app.App.RedisSetAdd(config, selectedKey, [member]); + const res = await (window as any).go.app.App.RedisSetAdd(buildRpcConnectionConfig(config), selectedKey, [member]); if (res.success) { message.success('添加成功'); loadKeyValue(selectedKey); @@ -1524,7 +1525,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { const config = getConfig(); if (!config) return; try { - const res = await (window as any).go.app.App.RedisSetRemove(config, selectedKey, [member]); + const res = await (window as any).go.app.App.RedisSetRemove(buildRpcConnectionConfig(config), selectedKey, [member]); if (res.success) { message.success('删除成功'); loadKeyValue(selectedKey); @@ -1645,7 +1646,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { const config = getConfig(); if (!config) return; try { - const res = await (window as any).go.app.App.RedisZSetAdd(config, selectedKey, [{ member, score }]); + const res = await (window as any).go.app.App.RedisZSetAdd(buildRpcConnectionConfig(config), selectedKey, [{ member, score }]); if (res.success) { message.success('添加成功'); loadKeyValue(selectedKey); @@ -1661,7 +1662,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { const config = getConfig(); if (!config) return; try { - const res = await (window as any).go.app.App.RedisZSetRemove(config, selectedKey, [member]); + const res = await (window as any).go.app.App.RedisZSetRemove(buildRpcConnectionConfig(config), selectedKey, [member]); if (res.success) { message.success('删除成功'); loadKeyValue(selectedKey); @@ -1841,7 +1842,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { } try { - const res = await (window as any).go.app.App.RedisStreamAdd(config, selectedKey, fieldMap, id || '*'); + const res = await (window as any).go.app.App.RedisStreamAdd(buildRpcConnectionConfig(config), selectedKey, fieldMap, id || '*'); if (res.success) { const newID = res.data?.id ? ` (${res.data.id})` : ''; message.success(`添加成功${newID}`); @@ -1859,7 +1860,7 @@ const RedisViewer: React.FC = ({ connectionId, redisDB }) => { if (!config) return; try { - const res = await (window as any).go.app.App.RedisStreamDelete(config, selectedKey, [id]); + const res = await (window as any).go.app.App.RedisStreamDelete(buildRpcConnectionConfig(config), selectedKey, [id]); if (res.success) { const deleted = Number(res.data?.deleted ?? 0); if (deleted > 0) { diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 08d27aa..91b61c3 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -42,6 +42,7 @@ import { getDbIcon } from './DatabaseIcons'; import { EventsOn } from '../../wailsjs/runtime/runtime'; import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; import FindInDatabaseModal from './FindInDatabaseModal'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; const { Search } = Input; @@ -527,7 +528,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> if (SIDEBAR_SCHEMA_DB_TYPES.has(dbType)) return true; if (dbType !== 'custom') return false; - const customDriver = String((conn?.config as any)?.driver || '').trim().toLowerCase(); + const customDriver = String(conn?.config?.driver || '').trim().toLowerCase(); return SIDEBAR_SCHEMA_CUSTOM_DRIVERS.has(customDriver); }; @@ -543,7 +544,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> const getMetadataDialect = (conn: SavedConnection | undefined): string => { const type = String(conn?.config?.type || '').trim().toLowerCase(); if (type === 'custom') { - const driver = String((conn?.config as any)?.driver || '').trim().toLowerCase(); + const driver = String(conn?.config?.driver || '').trim().toLowerCase(); if (driver === 'diros' || driver === 'doris') return 'mysql'; return driver; } @@ -569,7 +570,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> const type = String(conn?.config?.type || '').trim().toLowerCase(); if (type === 'sphinx') return true; if (type !== 'custom') return false; - const driver = String((conn?.config as any)?.driver || '').trim().toLowerCase(); + const driver = String(conn?.config?.driver || '').trim().toLowerCase(); return driver === 'sphinx' || driver === 'sphinxql'; }; @@ -857,7 +858,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> for (const spec of normalizedSpecs) { try { - const result = await DBQuery(config as any, dbName, spec.sql); + const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, spec.sql); if (!result.success || !Array.isArray(result.data)) { continue; } @@ -988,7 +989,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> // Handle Redis connections differently if (conn.config.type === 'redis') { try { - const res = await (window as any).go.app.App.RedisGetDatabases(config); + const res = await (window as any).go.app.App.RedisGetDatabases(buildRpcConnectionConfig(config)); if (res.success) { setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' })); const redisRows: any[] = Array.isArray(res.data) ? res.data : []; @@ -1020,7 +1021,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> } try { - const res = await DBGetDatabases(config as any); + const res = await DBGetDatabases(buildRpcConnectionConfig(config) as any); if (res.success) { setConnectionStates(prev => ({ ...prev, [conn.id]: 'success' })); const dbRows: any[] = Array.isArray(res.data) ? res.data : []; @@ -1094,7 +1095,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } }; try { - const res = await DBGetTables(config as any, conn.dbName); + const res = await DBGetTables(buildRpcConnectionConfig(config) as any, conn.dbName); if (res.success) { setConnectionStates(prev => ({ ...prev, [key as string]: 'success' })); @@ -1578,14 +1579,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> const handleCopyStructure = async (node: any) => { const { config, dbName, tableName } = node.dataRef; - const res = await DBShowCreateTable({ - ...config, - port: Number(config.port), - password: config.password || "", - database: config.database || "", - useSSH: config.useSSH || false, - ssh: config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } - } as any, dbName, tableName); + const res = await DBShowCreateTable(buildRpcConnectionConfig(config) as any, dbName, tableName); if (res.success) { navigator.clipboard.writeText(res.data as string); message.success('表结构已复制到剪贴板'); @@ -1597,14 +1591,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> const handleExport = async (node: any, format: string) => { const { config, dbName, tableName } = node.dataRef; const hide = message.loading(`正在导出 ${tableName} 为 ${format.toUpperCase()}...`, 0); - const res = await ExportTable({ - ...config, - port: Number(config.port), - password: config.password || "", - database: config.database || "", - useSSH: config.useSSH || false, - ssh: config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } - } as any, dbName, tableName, format); + const res = await ExportTable(buildRpcConnectionConfig(config) as any, dbName, tableName, format); hide(); if (res.success) { message.success('导出成功'); @@ -1613,14 +1600,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> } }; - const normalizeConnConfig = (raw: any) => ({ - ...raw, - port: Number(raw.port), - password: raw.password || "", - database: raw.database || "", - useSSH: raw.useSSH || false, - ssh: raw.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } - }); + const normalizeConnConfig = (raw: any) => ( + buildRpcConnectionConfig(raw) + ); const handleExportDatabaseSQL = async (node: any, includeData: boolean) => { const conn = node.dataRef; @@ -1715,7 +1697,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } }; - const res = await DBGetDatabases(config as any); + const res = await DBGetDatabases(buildRpcConnectionConfig(config) as any); if (res.success) { const dbRows: any[] = Array.isArray(res.data) ? res.data : []; let dbs = dbRows.map((row: any) => { @@ -1750,7 +1732,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> }; const [res, viewResult] = await Promise.all([ - DBGetTables(config as any, dbName), + DBGetTables(buildRpcConnectionConfig(config) as any, dbName), loadViews(conn, dbName).catch(() => ({ views: [], supported: false })), ]); @@ -2026,7 +2008,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } }; - const res = await DBGetDatabases(config as any); + const res = await DBGetDatabases(buildRpcConnectionConfig(config) as any); if (res.success) { const dbRows: any[] = Array.isArray(res.data) ? res.data : []; let dbs = dbRows.map((row: any) => { @@ -2238,7 +2220,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } }; - const res = await CreateDatabase(config as any, values.name); + const res = await CreateDatabase(buildRpcConnectionConfig(config) as any, values.name); if (res.success) { message.success("数据库创建成功"); setIsCreateDbModalOpen(false); @@ -2254,14 +2236,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> }; const buildRuntimeConfig = (conn: any, overrideDatabase?: string, clearDatabase: boolean = false) => { - return { - ...conn.config, - port: Number(conn.config.port), - password: conn.config.password || "", - database: clearDatabase ? "" : ((overrideDatabase ?? conn.config.database) || ""), - useSSH: conn.config.useSSH || false, - ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } - }; + return buildRpcConnectionConfig(conn.config, { + database: clearDatabase ? '' : ((overrideDatabase ?? conn.config.database) || ''), + }); }; const getConnectionNodeRef = (connRef: any) => { @@ -2303,7 +2280,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> } const config = buildRuntimeConfig(conn, conn.dbName); - const res = await RenameDatabase(config as any, oldDbName, newDbName); + const res = await RenameDatabase(buildRpcConnectionConfig(config) as any, oldDbName, newDbName); if (res.success) { message.success("数据库重命名成功"); setExpandedKeys(prev => prev.filter(k => !k.toString().startsWith(`${conn.id}-${oldDbName}`))); @@ -2330,7 +2307,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> okButtonProps: { danger: true }, onOk: async () => { const config = buildRuntimeConfig(conn, conn.dbName); - const res = await DropDatabase(config as any, dbName); + const res = await DropDatabase(buildRpcConnectionConfig(config) as any, dbName); if (res.success) { message.success("数据库删除成功"); closeTabsByDatabase(conn.id, dbName); @@ -2360,7 +2337,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> return; } const config = buildRuntimeConfig(conn, conn.dbName); - const res = await RenameTable(config as any, conn.dbName, oldTableName, newTableName); + const res = await RenameTable(buildRpcConnectionConfig(config) as any, conn.dbName, oldTableName, newTableName); if (res.success) { message.success("表重命名成功"); await loadTables(getDatabaseNodeRef(conn, conn.dbName)); @@ -2385,7 +2362,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> okButtonProps: { danger: true }, onOk: async () => { const config = buildRuntimeConfig(conn, conn.dbName); - const res = await DropTable(config as any, conn.dbName, tableName); + const res = await DropTable(buildRpcConnectionConfig(config) as any, conn.dbName, tableName); if (res.success) { message.success("表删除成功"); await loadTables(getDatabaseNodeRef(conn, conn.dbName)); @@ -2445,7 +2422,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> } } if (query) { - const result = await DBQuery(config as any, dbName, query); + const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, query); if (result.success && Array.isArray(result.data) && result.data.length > 0) { const row = result.data[0] as Record; const def = row.view_definition || row.VIEW_DEFINITION || Object.values(row).find(v => typeof v === 'string' && String(v).length > 10) || ''; @@ -2511,7 +2488,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> okButtonProps: { danger: true }, onOk: async () => { const config = buildRuntimeConfig(conn, conn.dbName); - const res = await DropView(config as any, conn.dbName, viewName); + const res = await DropView(buildRpcConnectionConfig(config) as any, conn.dbName, viewName); if (res.success) { message.success("视图删除成功"); await loadTables(getDatabaseNodeRef(conn, conn.dbName)); @@ -2538,7 +2515,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> return; } const config = buildRuntimeConfig(conn, conn.dbName); - const res = await RenameView(config as any, conn.dbName, oldViewName, newViewName); + const res = await RenameView(buildRpcConnectionConfig(config) as any, conn.dbName, oldViewName, newViewName); if (res.success) { message.success("视图重命名成功"); await loadTables(getDatabaseNodeRef(conn, conn.dbName)); @@ -2610,7 +2587,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> } } if (query) { - const result = await DBQuery(config as any, dbName, query); + const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, query); if (result.success && Array.isArray(result.data) && result.data.length > 0) { if (dialect === 'oracle' || dialect === 'dm') { const lines = result.data.map((row: any) => row.text || row.TEXT || Object.values(row)[0] || '').join(''); @@ -2704,7 +2681,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> okButtonProps: { danger: true }, onOk: async () => { const config = buildRuntimeConfig(conn, conn.dbName); - const res = await DropFunction(config as any, conn.dbName, routineName, routineType); + const res = await DropFunction(buildRpcConnectionConfig(config) as any, conn.dbName, routineName, routineType); if (res.success) { message.success(`${typeLabel}删除成功`); await loadTables(getDatabaseNodeRef(conn, conn.dbName)); diff --git a/frontend/src/components/TableDesigner.tsx b/frontend/src/components/TableDesigner.tsx index fa17af0..48c0192 100644 --- a/frontend/src/components/TableDesigner.tsx +++ b/frontend/src/components/TableDesigner.tsx @@ -9,6 +9,7 @@ import { TabData, ColumnDefinition, IndexDefinition, ForeignKeyDefinition, Trigg import { useStore } from '../store'; import { DBGetColumns, DBGetIndexes, DBQuery, DBGetForeignKeys, DBGetTriggers, DBShowCreateTable } from '../../wailsjs/go/app/App'; import { hasIndexFormChanged, normalizeIndexFormFromRow, shouldRestoreOriginalIndex, toggleIndexSelection as getNextIndexSelection, type IndexDisplaySnapshot } from './tableDesignerIndexUtils'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; interface EditableColumn extends ColumnDefinition { _key: string; @@ -751,14 +752,14 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => { }; const promises: Promise[] = [ - DBGetColumns(config as any, tab.dbName || '', tab.tableName || ''), - DBGetIndexes(config as any, tab.dbName || '', tab.tableName || ''), - DBGetForeignKeys(config as any, tab.dbName || '', tab.tableName || ''), - DBGetTriggers(config as any, tab.dbName || '', tab.tableName || '') + DBGetColumns(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || ''), + DBGetIndexes(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || ''), + DBGetForeignKeys(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || ''), + DBGetTriggers(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || '') ]; if (!isNewTable) { - promises.push(DBShowCreateTable(config as any, tab.dbName || '', tab.tableName || '')); + promises.push(DBShowCreateTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tab.tableName || '')); } const results = await Promise.all(promises); @@ -848,7 +849,7 @@ const TableDesigner: React.FC<{ tab: TabData }> = ({ tab }) => { if (!type) return ''; if (type === 'custom') { - return inferDialectFromCustomDriver(String((conn?.config as any)?.driver || '')); + return inferDialectFromCustomDriver(String(conn?.config?.driver || '')); } if (type === 'mariadb' || type === 'diros' || type === 'sphinx') return 'mysql'; @@ -993,7 +994,7 @@ ${selectedTrigger.statement}`; const dropSql = buildDropTriggerSql(selectedTrigger.name); try { - const res = await DBQuery(config as any, tab.dbName || '', dropSql); + const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', dropSql); if (res.success) { message.success('触发器删除成功'); setSelectedTrigger(null); @@ -1030,7 +1031,7 @@ ${selectedTrigger.statement}`; // 如果是编辑模式,先删除旧触发器 if (triggerEditMode === 'edit' && selectedTrigger) { const dropSql = buildDropTriggerSql(selectedTrigger.name); - const dropRes = await DBQuery(config as any, tab.dbName || '', dropSql); + const dropRes = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', dropSql); if (!dropRes.success) { message.error('删除旧触发器失败: ' + dropRes.message); setTriggerExecuting(false); @@ -1039,7 +1040,7 @@ ${selectedTrigger.statement}`; } // 执行创建语句 - const res = await DBQuery(config as any, tab.dbName || '', triggerEditSql); + const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', triggerEditSql); if (res.success) { message.success(triggerEditMode === 'create' ? '触发器创建成功' : '触发器修改成功'); setIsTriggerEditModalOpen(false); @@ -1522,7 +1523,7 @@ ${selectedTrigger.statement}`; const sql = buildCreateTableSql(copyTableName.trim(), selectedColumns, copyCharset, copyCollation); setCopyExecuting(true); try { - const res = await DBQuery(config as any, tab.dbName || '', sql); + const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', sql); if (res.success) { message.success(`已将 ${selectedColumns.length} 个字段复制到新表 ${copyTableName.trim()}`); setIsCopyColumnsModalOpen(false); @@ -1551,7 +1552,7 @@ ${selectedTrigger.statement}`; for (let i = 0; i < statements.length; i++) { let stmt = statements[i]; if (!stmt.endsWith(';')) stmt += ';'; - const res = await DBQuery(config as any, tab.dbName || '', stmt); + const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', stmt); if (!res.success) { const prefix = statements.length > 1 ? `第 ${i + 1}/${statements.length} 条语句执行失败: ` : '执行失败: '; return { @@ -2202,7 +2203,7 @@ END;`; const conn = connections.find(c => c.id === tab.connectionId); if (!conn) return; const config = { ...conn.config, port: Number(conn.config.port), password: conn.config.password || "", database: conn.config.database || "", useSSH: conn.config.useSSH || false, ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } }; - const res = await DBQuery(config as any, tab.dbName || '', previewSql); + const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', previewSql); if (res.success) { message.success(isNewTable ? "表创建成功!" : "表结构修改成功!"); setIsPreviewOpen(false); diff --git a/frontend/src/components/TableOverview.tsx b/frontend/src/components/TableOverview.tsx index b93a66f..bf687a1 100644 --- a/frontend/src/components/TableOverview.tsx +++ b/frontend/src/components/TableOverview.tsx @@ -4,6 +4,7 @@ import { TableOutlined, SearchOutlined, ReloadOutlined, SortAscendingOutlined, D import { useStore } from '../store'; import { DBQuery, DBShowCreateTable, ExportTable, DropTable, RenameTable } from '../../wailsjs/go/app/App'; import type { TabData } from '../types'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; interface TableOverviewProps { tab: TabData; @@ -163,9 +164,9 @@ const TableOverview: React.FC = ({ tab }) => { useSSH: connection.config.useSSH || false, ssh: connection.config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' }, }; - const dialect = getMetadataDialect(connection.config.type, (connection.config as any)?.driver); + const dialect = getMetadataDialect(connection.config.type, connection.config.driver); const sql = buildTableStatusSQL(dialect, tab.dbName || '', (tab as any).schemaName); - const res = await DBQuery(config as any, tab.dbName || '', sql); + const res = await DBQuery(buildRpcConnectionConfig(config) as any, tab.dbName || '', sql); if (res.success && Array.isArray(res.data)) { setTables(parseTableStats(dialect, res.data)); } else { @@ -239,7 +240,7 @@ const TableOverview: React.FC = ({ tab }) => { const handleCopyStructure = useCallback(async (tableName: string) => { const config = buildConfig(); if (!config) return; - const res = await DBShowCreateTable(config as any, tab.dbName || '', tableName); + const res = await DBShowCreateTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tableName); if (res.success) { navigator.clipboard.writeText(res.data as string); message.success('表结构已复制到剪贴板'); @@ -252,7 +253,7 @@ const TableOverview: React.FC = ({ tab }) => { const config = buildConfig(); if (!config) return; const hide = message.loading(`正在导出 ${tableName} 为 ${format.toUpperCase()}...`, 0); - const res = await ExportTable(config as any, tab.dbName || '', tableName, format); + const res = await ExportTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tableName, format); hide(); if (res.success) { message.success('导出成功'); @@ -269,7 +270,7 @@ const TableOverview: React.FC = ({ tab }) => { content: `确定删除表 "${tableName}" 吗?该操作不可恢复。`, okButtonProps: { danger: true }, onOk: async () => { - const res = await DropTable(config as any, tab.dbName || '', tableName); + const res = await DropTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tableName); if (res.success) { message.success('表删除成功'); loadData(); @@ -299,7 +300,7 @@ const TableOverview: React.FC = ({ tab }) => { const trimmed = newName.trim(); if (!trimmed) { message.error('表名不能为空'); return Promise.reject(); } if (trimmed === tableName) { message.warning('新旧表名相同'); return; } - const res = await RenameTable(config as any, tab.dbName || '', tableName, trimmed); + const res = await RenameTable(buildRpcConnectionConfig(config) as any, tab.dbName || '', tableName, trimmed); if (res.success) { message.success('表重命名成功'); loadData(); diff --git a/frontend/src/components/TriggerViewer.tsx b/frontend/src/components/TriggerViewer.tsx index 849a7ca..b380f62 100644 --- a/frontend/src/components/TriggerViewer.tsx +++ b/frontend/src/components/TriggerViewer.tsx @@ -4,6 +4,7 @@ import { Spin, Alert } from 'antd'; import { TabData } from '../types'; import { useStore } from '../store'; import { DBQuery } from '../../wailsjs/go/app/App'; +import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; interface TriggerViewerProps { tab: TabData; @@ -100,7 +101,7 @@ LIMIT 1`]; const sql = String(query || '').trim(); if (!sql) continue; try { - const result = await DBQuery(config as any, dbName, sql); + const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, sql); if (!result.success || !Array.isArray(result.data)) { lastMessage = result.message || lastMessage; continue; @@ -126,7 +127,7 @@ LIMIT 1`]; ]; for (const query of candidates) { try { - const result = await DBQuery(config as any, dbName, query); + const result = await DBQuery(buildRpcConnectionConfig(config) as any, dbName, query); if (!result.success || !Array.isArray(result.data) || result.data.length === 0) { continue; } diff --git a/frontend/src/components/ai/AIChatInput.tsx b/frontend/src/components/ai/AIChatInput.tsx index 37ab15e..0640cfc 100644 --- a/frontend/src/components/ai/AIChatInput.tsx +++ b/frontend/src/components/ai/AIChatInput.tsx @@ -5,6 +5,7 @@ import { useStore } from '../../store'; import { DBGetTables, DBShowCreateTable, DBGetDatabases } from '../../../wailsjs/go/app/App'; import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme'; import type { AIComposerNotice } from '../../utils/aiComposerNotice'; +import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig'; interface AIChatInputProps { input: string; @@ -124,7 +125,7 @@ export const AIChatInput: React.FC = ({ setContextLoading(true); setSelectedDbName(dbName); try { - const res = await DBGetTables(connConfig, dbName); + const res = await DBGetTables(buildRpcConnectionConfig(connConfig), dbName); if (res.success && Array.isArray(res.data)) { setContextTables(res.data.map(r => ({ name: Object.values(r)[0] as string }))); } else { @@ -155,7 +156,7 @@ export const AIChatInput: React.FC = ({ try { // Fetch databases - const dbRes = await DBGetDatabases(conn.config as any); + const dbRes = await DBGetDatabases(buildRpcConnectionConfig(conn.config) as any); if (dbRes.success && Array.isArray(dbRes.data)) { const databases = dbRes.data.map((r: any) => Object.values(r)[0] as string); setDbList(databases); @@ -164,7 +165,7 @@ export const AIChatInput: React.FC = ({ // Fetch tables for the active contextual database const initDbName = activeContext.dbName || ''; setSelectedDbName(initDbName); - const tablesRes = await DBGetTables(conn.config as any, initDbName); + const tablesRes = await DBGetTables(buildRpcConnectionConfig(conn.config) as any, initDbName); if (tablesRes.success && Array.isArray(tablesRes.data)) { setContextTables(tablesRes.data.map((r: any) => ({ name: Object.values(r)[0] as string }))); } else { @@ -201,7 +202,7 @@ export const AIChatInput: React.FC = ({ if (activeContextItems.find(c => c.dbName === dbName && c.tableName === tableName)) { continue; } - const res = await DBShowCreateTable(conn.config as any, dbName, tableName); + const res = await DBShowCreateTable(buildRpcConnectionConfig(conn.config) as any, dbName, tableName); let createSql = ''; if (res.success && res.data) { if (typeof res.data === 'string') { diff --git a/frontend/src/utils/connectionRpcConfig.test.ts b/frontend/src/utils/connectionRpcConfig.test.ts new file mode 100644 index 0000000..197c20a --- /dev/null +++ b/frontend/src/utils/connectionRpcConfig.test.ts @@ -0,0 +1,104 @@ +import { describe, expect, it } from 'vitest'; + +import { connection } from '../../wailsjs/go/models'; +import { buildRpcConnectionConfig } from './connectionRpcConfig'; + +describe('buildRpcConnectionConfig', () => { + it('preserves the saved connection id while normalizing numeric fields', () => { + const result = buildRpcConnectionConfig({ + id: 'conn-1', + type: 'postgres', + host: 'db.local', + port: '5432' as unknown as number, + user: 'postgres', + useSSH: true, + ssh: { + host: 'bastion.local', + port: '2222' as unknown as number, + user: 'ops', + }, + useProxy: true, + proxy: { + type: 'http', + host: '127.0.0.1', + port: '8080' as unknown as number, + }, + } as any, { + id: 'conn-2', + timeout: '120' as unknown as number, + redisDB: '6' as unknown as number, + database: 'app', + }); + + expect(result.id).toBe('conn-1'); + expect(result.port).toBe(5432); + expect(result.ssh?.port).toBe(2222); + expect(result.proxy?.port).toBe(8080); + expect(result.timeout).toBe(120); + expect(result.redisDB).toBe(6); + expect(result.database).toBe('app'); + }); + + it('fills default nested config blocks needed by RPC calls', () => { + const result = buildRpcConnectionConfig({ + id: 'conn-redis', + type: 'redis', + host: '127.0.0.1', + port: 6379, + user: '', + } as any, { + useSSH: true, + useHttpTunnel: true, + redisDB: '4' as unknown as number, + }); + + expect(result.id).toBe('conn-redis'); + expect(result.redisDB).toBe(4); + expect(result.ssh).toEqual({ + host: '', + port: 22, + user: '', + password: '', + keyPath: '', + }); + expect(result.httpTunnel).toEqual({ + host: '', + port: 8080, + user: '', + password: '', + }); + }); + + it('returns a Wails connection model instance for RPC compatibility', () => { + const result = buildRpcConnectionConfig({ + id: 'conn-model', + type: 'mysql', + host: '127.0.0.1', + port: '3306' as unknown as number, + user: 'root', + useSSH: true, + ssh: { + host: 'jump.local', + port: '2222' as unknown as number, + user: 'ops', + }, + useProxy: true, + proxy: { + type: 'http', + host: '127.0.0.1', + port: '8080' as unknown as number, + }, + useHttpTunnel: true, + httpTunnel: { + host: '127.0.0.1', + port: '9000' as unknown as number, + }, + } as any); + + expect(result).toBeInstanceOf(connection.ConnectionConfig); + expect(result.ssh).toBeInstanceOf(connection.SSHConfig); + expect(result.proxy).toBeInstanceOf(connection.ProxyConfig); + expect(result.httpTunnel).toBeInstanceOf(connection.HTTPTunnelConfig); + expect(typeof (result as any).convertValues).toBe('function'); + }); +}); diff --git a/frontend/src/utils/connectionRpcConfig.ts b/frontend/src/utils/connectionRpcConfig.ts new file mode 100644 index 0000000..81d9294 --- /dev/null +++ b/frontend/src/utils/connectionRpcConfig.ts @@ -0,0 +1,122 @@ +import { connection } from '../../wailsjs/go/models'; + +export type RpcConnectionConfig = connection.ConnectionConfig & { id?: string }; +type ConnectionConfigInput = { + id?: string; + ssh?: Record; + proxy?: Record; + httpTunnel?: Record; + [key: string]: any; +}; +type SSHConfigInput = Record; +type ProxyConfigInput = Record; +type HttpTunnelConfigInput = Record; + +const toStringValue = (value: unknown, fallback = ''): string => { + if (typeof value === 'string') { + return value; + } + if (typeof value === 'number' || typeof value === 'boolean') { + return String(value); + } + return fallback; +}; + +const toOptionalInteger = (value: unknown, fallback?: number): number | undefined => { + if (value === undefined || value === null || value === '') { + return fallback; + } + const parsed = Number(value); + if (!Number.isFinite(parsed)) { + return fallback; + } + return Math.trunc(parsed); +}; + +const normalizeProxyType = (value: unknown): 'socks5' | 'http' => { + return toStringValue(value).toLowerCase() === 'http' ? 'http' : 'socks5'; +}; + +const normalizeSSHConfig = (value: unknown): connection.SSHConfig => { + const raw = (value ?? {}) as SSHConfigInput; + return new connection.SSHConfig({ + host: toStringValue(raw.host), + port: toOptionalInteger(raw.port, 22) ?? 22, + user: toStringValue(raw.user), + password: toStringValue(raw.password), + keyPath: toStringValue(raw.keyPath), + }); +}; + +const normalizeProxyConfig = (value: unknown): connection.ProxyConfig => { + const raw = (value ?? {}) as ProxyConfigInput; + const type = normalizeProxyType(raw.type); + return new connection.ProxyConfig({ + type, + host: toStringValue(raw.host), + port: toOptionalInteger(raw.port, type === 'http' ? 8080 : 1080) ?? (type === 'http' ? 8080 : 1080), + user: toStringValue(raw.user), + password: toStringValue(raw.password), + }); +}; + +const normalizeHttpTunnelConfig = (value: unknown): connection.HTTPTunnelConfig => { + const raw = (value ?? {}) as HttpTunnelConfigInput; + return new connection.HTTPTunnelConfig({ + host: toStringValue(raw.host), + port: toOptionalInteger(raw.port, 8080) ?? 8080, + user: toStringValue(raw.user), + password: toStringValue(raw.password), + }); +}; + +export function buildRpcConnectionConfig( + config: ConnectionConfigInput, + overrides: ConnectionConfigInput = {}, +): RpcConnectionConfig { + const mergedSSH = { + ...(config.ssh ?? {}), + ...(overrides.ssh ?? {}), + }; + const mergedProxy = { + ...(config.proxy ?? {}), + ...(overrides.proxy ?? {}), + }; + const mergedHttpTunnel = { + ...(config.httpTunnel ?? {}), + ...(overrides.httpTunnel ?? {}), + }; + const merged: ConnectionConfigInput = { + ...config, + ...overrides, + ssh: mergedSSH, + proxy: mergedProxy, + httpTunnel: mergedHttpTunnel, + }; + + const baseId = toStringValue(config.id).trim() || toStringValue(overrides.id).trim() || undefined; + const timeout = toOptionalInteger(merged.timeout, toOptionalInteger(config.timeout)); + const redisDB = toOptionalInteger(merged.redisDB, toOptionalInteger(config.redisDB)); + + const rpcConfig = new connection.ConnectionConfig({ + ...merged, + type: toStringValue(merged.type), + host: toStringValue(merged.host), + port: toOptionalInteger(merged.port, toOptionalInteger(config.port, 0)) ?? 0, + user: toStringValue(merged.user), + password: toStringValue(merged.password), + database: toStringValue(merged.database), + useSSH: merged.useSSH === true, + ssh: normalizeSSHConfig(merged.ssh), + useProxy: merged.useProxy === true, + proxy: normalizeProxyConfig(merged.proxy), + useHttpTunnel: merged.useHttpTunnel === true, + httpTunnel: normalizeHttpTunnelConfig(merged.httpTunnel), + timeout, + redisDB, + }) as RpcConnectionConfig; + + rpcConfig.id = baseId; + return rpcConfig; +} + From 47187552089deb806c04944498313dfd78db1955 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 3 Apr 2026 20:11:53 +0800 Subject: [PATCH 11/14] =?UTF-8?q?=E2=9C=A8=20feat(security):=20=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E9=85=8D=E7=BD=AE=E5=AF=86=E6=96=87=E5=AD=98=E5=82=A8?= =?UTF-8?q?=E5=89=8D=E5=90=8E=E7=AB=AF=E9=97=AD=E7=8E=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 补齐连接与代理密文字段的保留替换清空语义 - 接通保存复制删除导入接口并返回 secretless 视图 - 刷新 Wails 绑定并补充实现留痕文档 --- frontend/src/App.tsx | 58 +++- frontend/src/components/AIChatPanel.tsx | 7 +- frontend/src/components/AISettingsModal.tsx | 51 ++- frontend/src/components/ConnectionModal.tsx | 310 +++++++++++++++++- frontend/src/components/Sidebar.tsx | 176 +++------- frontend/src/main.tsx | 122 ++++++- frontend/src/types.ts | 2 + .../src/utils/connectionSecretDraft.test.ts | 86 +++++ frontend/src/utils/connectionSecretDraft.ts | 63 ++++ .../src/utils/providerSecretDraft.test.ts | 41 +++ frontend/src/utils/providerSecretDraft.ts | 47 +++ frontend/wailsjs/go/app/App.d.ts | 14 + frontend/wailsjs/go/app/App.js | 28 ++ frontend/wailsjs/go/models.ts | 172 ++++++++++ .../app/methods_saved_connections_test.go | 101 +++++- internal/app/saved_connections.go | 93 +++++- internal/connection/saved_types.go | 22 +- 17 files changed, 1207 insertions(+), 186 deletions(-) create mode 100644 frontend/src/utils/connectionSecretDraft.test.ts create mode 100644 frontend/src/utils/connectionSecretDraft.ts create mode 100644 frontend/src/utils/providerSecretDraft.test.ts create mode 100644 frontend/src/utils/providerSecretDraft.ts diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c285de3..f6c1554 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useMemo, useCallback } from 'react'; +import React, { useState, useEffect, useMemo, useCallback } from 'react'; import { Layout, Button, ConfigProvider, theme, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select, Tooltip } from 'antd'; import zhCN from 'antd/locale/zh_CN'; import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined, BgColorsOutlined, AppstoreOutlined, RobotOutlined } from '@ant-design/icons'; @@ -61,6 +61,24 @@ const detectNavigatorPlatform = (): string => { return navigator.userAgent || ''; }; + +const toLegacySavedConnectionInput = (item: any) => ({ + id: typeof item?.id === 'string' ? item.id : '', + name: typeof item?.name === 'string' ? item.name : '', + config: (item?.config && typeof item.config === 'object') ? item.config : {}, + includeDatabases: Array.isArray(item?.includeDatabases) ? item.includeDatabases : undefined, + includeRedisDatabases: Array.isArray(item?.includeRedisDatabases) ? item.includeRedisDatabases : undefined, + iconType: typeof item?.iconType === 'string' ? item.iconType : '', + iconColor: typeof item?.iconColor === 'string' ? item.iconColor : '', +}); + +const mergeSavedConnections = (current: SavedConnection[], imported: SavedConnection[]): SavedConnection[] => { + const merged = new Map(); + current.forEach((conn) => merged.set(conn.id, conn)); + imported.forEach((conn) => merged.set(conn.id, conn)); + return Array.from(merged.values()); +}; + function App() { const [isModalOpen, setIsModalOpen] = useState(false); const [isSyncModalOpen, setIsSyncModalOpen] = useState(false); @@ -186,7 +204,7 @@ function App() { if (typeof backendApp?.ImportLegacyConnections === 'function') { try { await backendApp.ImportLegacyConnections( - legacy.connections.map(({ id, name, config }) => ({ id, name, config })) + legacy.connections.map(toLegacySavedConnectionInput) ); importedLegacyConnections = true; } catch (err) { @@ -751,7 +769,6 @@ function App() { const addTab = useStore(state => state.addTab); const activeContext = useStore(state => state.activeContext); const connections = useStore(state => state.connections); - const addConnection = useStore(state => state.addConnection); const tabs = useStore(state => state.tabs); const activeTabId = useStore(state => state.activeTabId); const updateCheckInFlightRef = React.useRef(false); @@ -1166,20 +1183,29 @@ function App() { if (res.success) { try { const imported = JSON.parse(res.data); - if (Array.isArray(imported)) { - let count = 0; - imported.forEach((conn: any) => { - if (!connections.some(c => c.id === conn.id)) { - addConnection(conn); - count++; - } - }); - void message.success(`成功导入 ${count} 个连接`); - } else { + if (!Array.isArray(imported)) { void message.error("文件格式错误:需要 JSON 数组"); + return; } - } catch (e) { - void message.error("解析 JSON 失败"); + + const normalizedItems = imported.map(toLegacySavedConnectionInput); + const backendApp = (window as any).go?.app?.App; + + if (typeof backendApp?.ImportLegacyConnections === 'function') { + const importedViews = await backendApp.ImportLegacyConnections(normalizedItems); + if (!Array.isArray(importedViews)) { + throw new Error('导入失败:后端未返回连接列表'); + } + replaceConnections(mergeSavedConnections(connections, importedViews)); + void message.success(`成功导入 ${importedViews.length} 个连接`); + return; + } + + const fallbackItems = normalizedItems as SavedConnection[]; + replaceConnections(mergeSavedConnections(connections, fallbackItems)); + void message.success(`成功导入 ${fallbackItems.length} 个连接`); + } catch (e: any) { + void message.error(e?.message || "解析 JSON 失败"); } } else if (res.message !== "已取消") { void message.error("导入失败: " + res.message); @@ -1191,7 +1217,7 @@ function App() { void message.warning("没有连接可导出"); return; } - const res = await (window as any).go.app.App.ExportData(connections, ['id','name','config','includeDatabases','includeRedisDatabases'], "connections", "json"); + const res = await (window as any).go.app.App.ExportData(connections, ['id','name','config','includeDatabases','includeRedisDatabases','iconType','iconColor'], "connections", "json"); if (res.success) { void message.success("导出成功"); } else if (res.message !== "已取消") { diff --git a/frontend/src/components/AIChatPanel.tsx b/frontend/src/components/AIChatPanel.tsx index 695bf09..754aafe 100644 --- a/frontend/src/components/AIChatPanel.tsx +++ b/frontend/src/components/AIChatPanel.tsx @@ -353,7 +353,12 @@ export const AIChatPanel: React.FC = ({ if (!activeProvider) return; try { const Service = (window as any).go?.aiservice?.Service; - const payload = { ...activeProvider, model: val }; + const payload = { + ...activeProvider, + model: val, + apiKey: activeProvider.apiKey || '', + hasSecret: activeProvider.hasSecret ?? Boolean(activeProvider.secretRef), + }; await Service?.AISaveProvider?.(payload); setActiveProvider(payload); setComposerNotice(null); diff --git a/frontend/src/components/AISettingsModal.tsx b/frontend/src/components/AISettingsModal.tsx index 403f352..7d586a4 100644 --- a/frontend/src/components/AISettingsModal.tsx +++ b/frontend/src/components/AISettingsModal.tsx @@ -1,5 +1,5 @@ -import React, { useState, useEffect, useCallback, useRef } from 'react'; -import { Modal, Button, Input, Select, Form, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd'; +import React, { useState, useEffect, useCallback, useRef } from 'react'; +import { Modal, Button, Input, Select, Form, Checkbox, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd'; import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined } from '@ant-design/icons'; import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel } from '../types'; import { @@ -18,6 +18,7 @@ import { PROVIDER_PRESET_GRID_STYLE, PROVIDER_PRESET_CARD_TITLE_STYLE, } from '../utils/aiSettingsPresetLayout'; +import { resolveProviderSecretDraft } from '../utils/providerSecretDraft'; import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; @@ -88,6 +89,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo const [testStatus, setTestStatus] = useState<'idle' | 'success' | 'error'>('idle'); const [builtinPrompts, setBuiltinPrompts] = useState>({}); const [activeSection, setActiveSection] = useState<'providers' | 'safety' | 'context' | 'prompts' | 'tools'>('providers'); + const [clearProviderSecret, setClearProviderSecret] = useState(false); const [form] = Form.useForm(); const modalBodyRef = useRef(null); @@ -105,6 +107,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo const watchedType = Form.useWatch('type', form); const watchedPresetKey = Form.useWatch('presetKey', form); const watchedApiFormat = Form.useWatch('apiFormat', form) || 'openai'; + const watchedApiKeyInput = Form.useWatch('apiKey', form); const loadConfig = useCallback(async () => { try { @@ -217,12 +220,18 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo presetFixedApiFormat: preset.fixedApiFormat, valuesApiFormat: values.apiFormat, }); - + const secretDraft = resolveProviderSecretDraft({ + hasSecret: editingProvider?.hasSecret, + apiKeyInput: values.apiKey, + clearSecret: clearProviderSecret, + }); const payload = { ...editingProvider, ...values, ...resolvedTransport, name: finalName, + apiKey: secretDraft.apiKey, + hasSecret: secretDraft.hasSecret, model: finalModel, models: resolvedModels, baseUrl: finalBaseUrl, @@ -230,7 +239,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo }; // 后端 AISaveProvider 统一处理新增和更新,返回 void,失败抛异常 await Service?.AISaveProvider?.(payload); - void messageApi.success('已保存'); setIsEditing(false); setEditingProvider(null); void loadConfig(); + void messageApi.success('已保存'); setIsEditing(false); setEditingProvider(null); setClearProviderSecret(false); void loadConfig(); window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed')); } catch (e: any) { if (e?.errorFields) { /* antd form validation error, ignore */ } @@ -287,10 +296,20 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo presetFixedApiFormat: preset.fixedApiFormat, valuesApiFormat: values.apiFormat, }); + const secretDraft = resolveProviderSecretDraft({ + hasSecret: editingProvider?.hasSecret, + apiKeyInput: values.apiKey, + clearSecret: clearProviderSecret, + }); + if (secretDraft.mode === 'clear') { + throw new Error('测试连接前请填写新的 API Key,或取消清除已保存密钥'); + } const res = await Service?.AITestProvider?.({ ...editingProvider, ...values, ...resolvedTransport, + apiKey: secretDraft.apiKey, + hasSecret: secretDraft.hasSecret, baseUrl: finalBaseUrl, model: finalModel, models: resolvedModels, @@ -401,7 +420,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo
{/* 顶部返回 */}
- {editingProvider?.id ? '编辑模型供应商' : '添加模型供应商'} @@ -492,11 +511,25 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo
认证 & 连接
- API Key
} name="apiKey" rules={[{ required: true, message: '请输入 API Key' }]} style={{ marginBottom: 16 }}> - API Key} name="apiKey" rules={[{ validator: (_, value) => { const apiKey = String(value || '').trim(); if (apiKey || clearProviderSecret || editingProvider?.hasSecret) { return Promise.resolve(); } return Promise.reject(new Error('请输入 API Key')); } }]} style={{ marginBottom: editingProvider?.hasSecret ? 8 : 16 }}> + + {editingProvider?.hasSecret && ( +
+
+ 当前已保存 API Key。留空表示继续沿用,输入新值表示替换。 +
+ setClearProviderSecret(event.target.checked)} + > + 清除已保存 API Key + +
+ )} {(presetKeyFromForm === 'custom' || presetKeyFromForm === 'ollama') && ( API Endpoint (URL)} name="baseUrl" rules={[{ required: true, message: '请输入有效的接口地址' }]} style={{ marginBottom: 0 }}> @@ -765,3 +798,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo }; export default AISettingsModal; + + + + diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index c1f17a8..b0f0763 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -1,10 +1,11 @@ -import React, { useState, useEffect, useRef, useMemo } from 'react'; +import React, { useState, useEffect, useRef, useMemo } from 'react'; import { Modal, Form, Input, InputNumber, Button, message, Checkbox, Divider, Select, Alert, Card, Row, Col, Typography, Collapse, Space, Table, Tag } from 'antd'; import { DatabaseOutlined, ConsoleSqlOutlined, FileTextOutlined, CloudServerOutlined, AppstoreAddOutlined, CloudOutlined, CheckCircleFilled, CloseCircleFilled, LinkOutlined, EditOutlined, AppstoreOutlined, BgColorsOutlined } from '@ant-design/icons'; import { getDbIcon, getDbDefaultColor, getDbIconLabel, DB_ICON_TYPES, PRESET_ICON_COLORS } from './DatabaseIcons'; import { useStore } from '../store'; import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; +import { resolveConnectionSecretDraft } from '../utils/connectionSecretDraft'; import { DBGetDatabases, GetDriverStatusList, MongoDiscoverMembers, TestConnection, RedisConnect, SelectDatabaseFile, SelectSSHKeyFile } from '../../wailsjs/go/app/App'; import { ConnectionConfig, MongoMemberInfo, SavedConnection } from '../types'; @@ -18,6 +19,29 @@ const CONNECTION_MODAL_BODY_HEIGHT = 620; const STEP1_SIDEBAR_DIVIDER_DARK = 'rgba(255, 255, 255, 0.16)'; const STEP1_SIDEBAR_DIVIDER_LIGHT = 'rgba(0, 0, 0, 0.08)'; +type ConnectionSecretKey = + | 'primaryPassword' + | 'sshPassword' + | 'proxyPassword' + | 'httpTunnelPassword' + | 'mysqlReplicaPassword' + | 'mongoReplicaPassword' + | 'opaqueURI' + | 'opaqueDSN'; + +type ConnectionSecretClearState = Record; + +const createEmptyConnectionSecretClearState = (): ConnectionSecretClearState => ({ + primaryPassword: false, + sshPassword: false, + proxyPassword: false, + httpTunnelPassword: false, + mysqlReplicaPassword: false, + mongoReplicaPassword: false, + opaqueURI: false, + opaqueDSN: false, +}); + const getDefaultPortByType = (type: string) => { switch (type) { case 'mysql': return 3306; @@ -122,6 +146,7 @@ const ConnectionModal: React.FC<{ const [driverStatusLoaded, setDriverStatusLoaded] = useState(false); const [selectingDbFile, setSelectingDbFile] = useState(false); const [selectingSSHKey, setSelectingSSHKey] = useState(false); + const [clearSecrets, setClearSecrets] = useState(createEmptyConnectionSecretClearState); const testInFlightRef = useRef(false); const testTimerRef = useRef(null); const addConnection = useStore((state) => state.addConnection); @@ -192,6 +217,51 @@ const ConnectionModal: React.FC<{ lineHeight: 1.6, }), [overlayTheme]); + const renderStoredSecretControls = ({ + fieldName, + clearKey, + hasStoredSecret, + clearLabel, + description, + }: { + fieldName: string; + clearKey: ConnectionSecretKey; + hasStoredSecret?: boolean; + clearLabel: string; + description: string; + }) => { + if (!initialValues || !hasStoredSecret) { + return null; + } + return ( + prev[fieldName] !== next[fieldName]}> + {({ getFieldValue }) => { + const draftValue = getFieldValue(fieldName); + const hasDraftValue = String(draftValue ?? '') !== ''; + const cardBorder = darkMode ? '1px solid rgba(255,255,255,0.12)' : '1px solid rgba(16,24,40,0.08)'; + const cardBg = darkMode ? 'rgba(255,255,255,0.03)' : 'rgba(16,24,40,0.03)'; + const effectiveChecked = clearSecrets[clearKey] && !hasDraftValue; + return ( +
+
+ {hasDraftValue ? '已输入新值,保存时会替换当前已保存内容。' : description} +
+ { + const checked = event.target.checked; + setClearSecrets((prev) => ({ ...prev, [clearKey]: checked })); + }} + > + {clearLabel} + +
+ ); + }} +
+ ); + }; const renderConnectionModalTitle = (icon: React.ReactNode, title: string, description: string) => (
@@ -1066,6 +1136,7 @@ const ConnectionModal: React.FC<{ setUriFeedback(null); setCustomIconType(undefined); setCustomIconColor(undefined); + setClearSecrets(createEmptyConnectionSecretClearState()); setTypeSelectWarning(null); setDriverStatusLoaded(false); void refreshDriverStatus(); @@ -1198,6 +1269,107 @@ const ConnectionModal: React.FC<{ }; }, []); + const buildSavedConnectionInput = (config: ConnectionConfig, values: any) => { + const connectionId = initialValues?.id || config.id || Date.now().toString(); + const primaryDraft = resolveConnectionSecretDraft({ + hasSecret: initialValues?.hasPrimaryPassword, + valueInput: config.password, + clearSecret: clearSecrets.primaryPassword, + forceClear: values.type === 'mongodb' && values.savePassword === false, + }); + const sshDraft = resolveConnectionSecretDraft({ + hasSecret: initialValues?.hasSSHPassword, + valueInput: config.ssh?.password, + clearSecret: clearSecrets.sshPassword, + forceClear: !config.useSSH, + }); + const proxyDraft = resolveConnectionSecretDraft({ + hasSecret: initialValues?.hasProxyPassword, + valueInput: config.proxy?.password, + clearSecret: clearSecrets.proxyPassword, + forceClear: !config.useProxy, + }); + const httpTunnelDraft = resolveConnectionSecretDraft({ + hasSecret: initialValues?.hasHttpTunnelPassword, + valueInput: config.httpTunnel?.password, + clearSecret: clearSecrets.httpTunnelPassword, + forceClear: !config.useHttpTunnel, + }); + const mysqlReplicaEnabled = (config.type === 'mysql' || config.type === 'mariadb' || config.type === 'diros' || config.type === 'sphinx') + && config.topology === 'replica'; + const mysqlReplicaDraft = resolveConnectionSecretDraft({ + hasSecret: initialValues?.hasMySQLReplicaPassword, + valueInput: config.mysqlReplicaPassword, + clearSecret: clearSecrets.mysqlReplicaPassword, + forceClear: !mysqlReplicaEnabled, + }); + const mongoReplicaEnabled = config.type === 'mongodb' + && config.topology === 'replica' + && values.savePassword !== false; + const mongoReplicaDraft = resolveConnectionSecretDraft({ + hasSecret: initialValues?.hasMongoReplicaPassword, + valueInput: config.mongoReplicaPassword, + clearSecret: clearSecrets.mongoReplicaPassword, + forceClear: !mongoReplicaEnabled, + }); + const opaqueUriDraft = resolveConnectionSecretDraft({ + hasSecret: initialValues?.hasOpaqueURI, + valueInput: config.uri, + clearSecret: clearSecrets.opaqueURI, + forceClear: values.type === 'custom', + trimInput: true, + }); + const opaqueDsnDraft = resolveConnectionSecretDraft({ + hasSecret: initialValues?.hasOpaqueDSN, + valueInput: config.dsn, + clearSecret: clearSecrets.opaqueDSN, + forceClear: values.type !== 'custom', + trimInput: true, + }); + const isRedisType = values.type === 'redis'; + const displayHost = String((config as any).host || values.host || '').trim(); + const nextName = values.name || (isFileDatabaseType(values.type) + ? (values.type === 'duckdb' ? 'DuckDB DB' : 'SQLite DB') + : (values.type === 'redis' ? `Redis ${displayHost}` : displayHost)); + + return { + id: connectionId, + name: nextName, + config: { + ...config, + id: connectionId, + password: primaryDraft.value, + ssh: { + ...(config.ssh || { host: '', port: 22, user: '', password: '', keyPath: '' }), + password: sshDraft.value, + }, + proxy: { + ...(config.proxy || { type: 'socks5', host: '', port: 1080, user: '', password: '' }), + password: proxyDraft.value, + }, + httpTunnel: { + ...(config.httpTunnel || { host: '', port: 8080, user: '', password: '' }), + password: httpTunnelDraft.value, + }, + uri: opaqueUriDraft.value, + dsn: opaqueDsnDraft.value, + mysqlReplicaPassword: mysqlReplicaDraft.value, + mongoReplicaPassword: mongoReplicaDraft.value, + }, + includeDatabases: values.includeDatabases, + includeRedisDatabases: isRedisType ? values.includeRedisDatabases : undefined, + iconType: customIconType || '', + iconColor: customIconColor || '', + clearPrimaryPassword: primaryDraft.clearStoredSecret, + clearSSHPassword: sshDraft.clearStoredSecret, + clearProxyPassword: proxyDraft.clearStoredSecret, + clearHttpTunnelPassword: httpTunnelDraft.clearStoredSecret, + clearMySQLReplicaPassword: mysqlReplicaDraft.clearStoredSecret, + clearMongoReplicaPassword: mongoReplicaDraft.clearStoredSecret, + clearOpaqueURI: opaqueUriDraft.clearStoredSecret, + clearOpaqueDSN: opaqueDsnDraft.clearStoredSecret, + }; + }; const handleOk = async () => { try { await form.validateFields(); @@ -1211,28 +1383,21 @@ const ConnectionModal: React.FC<{ setLoading(true); const config = await buildConfig(values, true); - const displayHost = String((config as any).host || values.host || '').trim(); - - const isRedisType = values.type === 'redis'; - const newConn = { - id: initialValues ? initialValues.id : Date.now().toString(), - name: values.name || (isFileDatabaseType(values.type) ? (values.type === 'duckdb' ? 'DuckDB DB' : 'SQLite DB') : (values.type === 'redis' ? `Redis ${displayHost}` : displayHost)), - config: config, - includeDatabases: values.includeDatabases, - includeRedisDatabases: isRedisType ? values.includeRedisDatabases : undefined, - iconType: customIconType, - iconColor: customIconColor, - }; + const payload = buildSavedConnectionInput(config, values); + const backendApp = (window as any).go?.app?.App; + const savedConnection = await backendApp?.SaveConnection?.(payload); + if (!savedConnection) { + throw new Error('保存连接失败:后端接口不可用'); + } if (initialValues) { - updateConnection(newConn); + updateConnection(savedConnection); message.success('配置已更新(未连接)'); } else { - addConnection(newConn); + addConnection(savedConnection); message.success('配置已保存(未连接)'); } - setLoading(false); form.resetFields(); setUseSSL(false); setUseSSH(false); @@ -1240,8 +1405,11 @@ const ConnectionModal: React.FC<{ setUseHttpTunnel(false); setDbType('mysql'); setStep(1); + setClearSecrets(createEmptyConnectionSecretClearState()); onClose(); - } catch (e) { + } catch (e: any) { + message.error(e?.message || '保存失败'); + } finally { setLoading(false); } }; @@ -1271,6 +1439,30 @@ const ConnectionModal: React.FC<{ } }; + const getBlockingSecretClearMessage = (values: any): string | null => { + if (clearSecrets.primaryPassword && values.type !== 'custom' && !isFileDatabaseType(values.type) && String(values.password ?? '') === '') { + return '测试连接前请填写新的密码,或取消清除已保存密码'; + } + if (clearSecrets.sshPassword && values.useSSH && String(values.sshPassword ?? '') === '') { + return '测试连接前请填写新的 SSH 密码,或取消清除已保存 SSH 密码'; + } + if (clearSecrets.proxyPassword && values.useProxy && !values.useHttpTunnel && String(values.proxyPassword ?? '') === '') { + return '测试连接前请填写新的代理密码,或取消清除已保存代理密码'; + } + if (clearSecrets.httpTunnelPassword && values.useHttpTunnel && String(values.httpTunnelPassword ?? '') === '') { + return '测试连接前请填写新的隧道密码,或取消清除已保存隧道密码'; + } + if (clearSecrets.mysqlReplicaPassword && (values.type === 'mysql' || values.type === 'mariadb' || values.type === 'diros' || values.type === 'sphinx') && values.mysqlTopology === 'replica' && String(values.mysqlReplicaPassword ?? '') === '') { + return '测试连接前请填写新的从库密码,或取消清除已保存从库密码'; + } + if (clearSecrets.mongoReplicaPassword && values.type === 'mongodb' && values.mongoTopology === 'replica' && String(values.mongoReplicaPassword ?? '') === '') { + return '测试连接前请填写新的副本集密码,或取消清除已保存副本集密码'; + } + if (values.type === 'mongodb' && values.savePassword === false && initialValues?.hasPrimaryPassword && String(values.password ?? '') === '') { + return '测试连接前请填写新的 MongoDB 密码,或重新勾选保存密码'; + } + return null; + }; const buildTestFailureMessage = (reason: unknown, fallback: string) => { const text = String(reason ?? '').trim(); const normalized = text && text !== 'undefined' && text !== 'null' ? text : fallback; @@ -1290,9 +1482,17 @@ const ConnectionModal: React.FC<{ promptInstallDriver(values.type, unavailableReason); return; } + const blockingSecretClearMessage = getBlockingSecretClearMessage(values); + if (blockingSecretClearMessage) { + setTestResult({ type: 'error', message: blockingSecretClearMessage }); + return; + } setLoading(true); setTestResult(null); const config = await buildConfig(values, false); + if (initialValues?.id) { + config.id = initialValues.id; + } const timeoutSecondsRaw = Number(values.timeout); const timeoutSeconds = Number.isFinite(timeoutSecondsRaw) && timeoutSecondsRaw > 0 ? Math.min(timeoutSecondsRaw, MAX_TIMEOUT_SECONDS) @@ -1368,7 +1568,15 @@ const ConnectionModal: React.FC<{ await form.validateFields(); const values = form.getFieldsValue(true); setDiscoveringMembers(true); + const blockingSecretClearMessage = getBlockingSecretClearMessage(values); + if (blockingSecretClearMessage) { + message.error(blockingSecretClearMessage); + return; + } const config = await buildConfig(values, false); + if (initialValues?.id) { + config.id = initialValues.id; + } const result = await MongoDiscoverMembers(config as any); if (!result.success) { message.error(result.message || '成员发现失败'); @@ -1877,6 +2085,13 @@ const ConnectionModal: React.FC<{ style={{ marginBottom: 16 }} /> )} + {renderStoredSecretControls({ + fieldName: 'uri', + clearKey: 'opaqueURI', + hasStoredSecret: initialValues?.hasOpaqueURI, + clearLabel: '清除已保存 URI', + description: '当前已保存连接 URI。留空表示继续沿用,输入新值表示替换。', + })} )} @@ -1888,6 +2103,13 @@ const ConnectionModal: React.FC<{ + {renderStoredSecretControls({ + fieldName: 'dsn', + clearKey: 'opaqueDSN', + hasStoredSecret: initialValues?.hasOpaqueDSN, + clearLabel: '清除已保存 DSN', + description: '当前已保存连接字符串。留空表示继续沿用,输入新值表示替换。', + })} ) : ( <> @@ -1968,6 +2190,13 @@ const ConnectionModal: React.FC<{
+ {renderStoredSecretControls({ + fieldName: 'mysqlReplicaPassword', + clearKey: 'mysqlReplicaPassword', + hasStoredSecret: initialValues?.hasMySQLReplicaPassword, + clearLabel: '清除已保存从库密码', + description: '当前已保存从库密码。留空表示继续沿用,输入新值表示替换。', + })} )} @@ -2010,6 +2239,13 @@ const ConnectionModal: React.FC<{ + {renderStoredSecretControls({ + fieldName: 'mongoReplicaPassword', + clearKey: 'mongoReplicaPassword', + hasStoredSecret: initialValues?.hasMongoReplicaPassword, + clearLabel: '清除已保存副本集密码', + description: '当前已保存副本集密码。留空表示继续沿用,输入新值表示替换。', + })} @@ -2084,6 +2320,13 @@ const ConnectionModal: React.FC<{ + {renderStoredSecretControls({ + fieldName: 'password', + clearKey: 'primaryPassword', + hasStoredSecret: initialValues?.hasPrimaryPassword, + clearLabel: '清除已保存密码', + description: '当前已保存 Redis 密码。留空表示继续沿用,输入新值表示替换。', + })}
)}
+ {renderStoredSecretControls({ + fieldName: 'password', + clearKey: 'primaryPassword', + hasStoredSecret: initialValues?.hasPrimaryPassword, + clearLabel: '清除已保存密码', + description: '当前已保存主连接密码。留空表示继续沿用,输入新值表示替换。', + })} + )} {dbType === 'mongodb' && ( @@ -2233,6 +2485,13 @@ const ConnectionModal: React.FC<{
+ {renderStoredSecretControls({ + fieldName: 'sshPassword', + clearKey: 'sshPassword', + hasStoredSecret: initialValues?.hasSSHPassword, + clearLabel: '清除已保存 SSH 密码', + description: '当前已保存 SSH 密码。留空表示继续沿用,输入新值表示替换。', + })}
)}
@@ -2271,6 +2530,13 @@ const ConnectionModal: React.FC<{
+ {renderStoredSecretControls({ + fieldName: 'proxyPassword', + clearKey: 'proxyPassword', + hasStoredSecret: initialValues?.hasProxyPassword, + clearLabel: '清除已保存代理密码', + description: '当前已保存代理密码。留空表示继续沿用,输入新值表示替换。', + })} )} @@ -2302,6 +2568,13 @@ const ConnectionModal: React.FC<{ + {renderStoredSecretControls({ + fieldName: 'httpTunnelPassword', + clearKey: 'httpTunnelPassword', + hasStoredSecret: initialValues?.hasHttpTunnelPassword, + clearLabel: '清除已保存隧道密码', + description: '当前已保存隧道密码。留空表示继续沿用,输入新值表示替换。', + })} 与“使用代理”互斥,启用后将通过 HTTP CONNECT 建立独立隧道。 )} @@ -2832,3 +3105,6 @@ const ConnectionModal: React.FC<{ }; export default ConnectionModal; + + + diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 91b61c3..ed3743b 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useState, useMemo, useRef } from 'react'; +import React, { useEffect, useState, useMemo, useRef } from 'react'; import { Tree, message, Dropdown, MenuProps, Input, Button, Modal, Form, Badge, Checkbox, Space, Select, Popover, Tooltip, Progress } from 'antd'; import { DatabaseOutlined, @@ -367,129 +367,25 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> }); }, [connections, connectionTags]); - const buildDuplicateConnectionName = (rawName: string): string => { - const baseName = String(rawName || '').trim() || '连接'; - const suffix = ' - 副本'; - const usedNames = new Set(connections.map(conn => String(conn.name || '').trim())); - let candidate = `${baseName}${suffix}`; - let counter = 2; - while (usedNames.has(candidate)) { - candidate = `${baseName}${suffix} ${counter}`; - counter += 1; - } - return candidate; - }; + const handleDuplicateConnection = async (conn: SavedConnection) => { + if (!conn?.id) return; + + const backendApp = (window as any).go?.app?.App; + if (typeof backendApp?.DuplicateConnection !== 'function') { + message.error('复制连接失败:后端接口不可用'); + return; + } - const cloneConnectionConfig = (config: SavedConnection['config']): SavedConnection['config'] => { - const raw: any = config || {}; - let cloned: any = {}; try { - cloned = typeof structuredClone === 'function' - ? structuredClone(raw) - : JSON.parse(JSON.stringify(raw)); - } catch { - cloned = { ...raw }; + const duplicatedConnection = await backendApp.DuplicateConnection(conn.id); + if (!duplicatedConnection) { + throw new Error('复制连接失败:后端未返回结果'); + } + addConnection(duplicatedConnection); + message.success(`已复制连接: ${duplicatedConnection.name}`); + } catch (error: any) { + message.error(error?.message || '复制连接失败'); } - - const readString = (...values: unknown[]): string => { - for (const value of values) { - if (typeof value === 'string') { - return value; - } - } - return ''; - }; - - const readBool = (fallback: boolean, ...values: unknown[]): boolean => { - for (const value of values) { - if (typeof value === 'boolean') { - return value; - } - } - return fallback; - }; - - const readNumber = (fallback: number, ...values: unknown[]): number => { - for (const value of values) { - const num = Number(value); - if (Number.isFinite(num)) { - return num; - } - } - return fallback; - }; - - const rawSSH = (cloned.ssh ?? cloned.SSH ?? {}) as Record; - const normalizedSSH = { - host: readString(rawSSH.host, rawSSH.Host, cloned.sshHost, cloned.SSHHost), - port: readNumber(22, rawSSH.port, rawSSH.Port, cloned.sshPort, cloned.SSHPort), - user: readString(rawSSH.user, rawSSH.User, cloned.sshUser, cloned.SSHUser), - password: readString(rawSSH.password, rawSSH.Password, cloned.sshPassword, cloned.SSHPassword), - keyPath: readString(rawSSH.keyPath, rawSSH.KeyPath, cloned.sshKeyPath, cloned.SSHKeyPath), - }; - const hasSSHDetail = Boolean( - normalizedSSH.host - || normalizedSSH.user - || normalizedSSH.password - || normalizedSSH.keyPath - ); - - const rawProxy = (cloned.proxy ?? cloned.Proxy ?? {}) as Record; - const proxyTypeRaw = readString(rawProxy.type, rawProxy.Type, cloned.proxyType, cloned.ProxyType).toLowerCase(); - const proxyType: 'socks5' | 'http' = proxyTypeRaw === 'http' ? 'http' : 'socks5'; - const normalizedProxy = { - type: proxyType, - host: readString(rawProxy.host, rawProxy.Host, cloned.proxyHost, cloned.ProxyHost), - port: readNumber(proxyType === 'http' ? 8080 : 1080, rawProxy.port, rawProxy.Port, cloned.proxyPort, cloned.ProxyPort), - user: readString(rawProxy.user, rawProxy.User, cloned.proxyUser, cloned.ProxyUser), - password: readString(rawProxy.password, rawProxy.Password, cloned.proxyPassword, cloned.ProxyPassword), - }; - const hasProxyDetail = Boolean(normalizedProxy.host || normalizedProxy.user || normalizedProxy.password); - const rawHttpTunnel = (cloned.httpTunnel ?? cloned.HTTPTunnel ?? {}) as Record; - const normalizedHttpTunnel = { - host: readString(rawHttpTunnel.host, rawHttpTunnel.Host, cloned.httpTunnelHost, cloned.HttpTunnelHost), - port: readNumber(8080, rawHttpTunnel.port, rawHttpTunnel.Port, cloned.httpTunnelPort, cloned.HttpTunnelPort), - user: readString(rawHttpTunnel.user, rawHttpTunnel.User, cloned.httpTunnelUser, cloned.HttpTunnelUser), - password: readString(rawHttpTunnel.password, rawHttpTunnel.Password, cloned.httpTunnelPassword, cloned.HttpTunnelPassword), - }; - const hasHttpTunnelDetail = Boolean(normalizedHttpTunnel.host || normalizedHttpTunnel.user || normalizedHttpTunnel.password); - const normalizedUseHttpTunnel = readBool(hasHttpTunnelDetail, cloned.useHttpTunnel, cloned.UseHTTPTunnel); - const normalizedUseProxy = !normalizedUseHttpTunnel && readBool(hasProxyDetail, cloned.useProxy, cloned.UseProxy); - - const rawHosts = Array.isArray(cloned.hosts) - ? cloned.hosts - : (Array.isArray(cloned.Hosts) ? cloned.Hosts : []); - const normalizedHosts = rawHosts - .map((entry: unknown) => String(entry || '').trim()) - .filter((entry: string) => !!entry); - - return { - ...(cloned as SavedConnection['config']), - useSSH: readBool(hasSSHDetail, cloned.useSSH, cloned.UseSSH), - ssh: normalizedSSH, - useProxy: normalizedUseProxy, - proxy: normalizedProxy, - useHttpTunnel: normalizedUseHttpTunnel, - httpTunnel: normalizedHttpTunnel, - hosts: normalizedHosts, - timeout: readNumber(30, cloned.timeout, cloned.Timeout), - }; - }; - - const handleDuplicateConnection = (conn: SavedConnection) => { - if (!conn) return; - - const duplicatedConnection: SavedConnection = { - ...conn, - id: `${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, - name: buildDuplicateConnectionName(conn.name), - config: cloneConnectionConfig(conn.config), - includeDatabases: conn.includeDatabases ? [...conn.includeDatabases] : undefined, - includeRedisDatabases: conn.includeRedisDatabases ? [...conn.includeRedisDatabases] : undefined, - }; - - addConnection(duplicatedConnection); - message.success(`已复制连接: ${duplicatedConnection.name}`); }; const updateTreeData = (list: TreeNode[], key: React.Key, children: TreeNode[] | undefined): TreeNode[] => { return list.map(node => { @@ -3163,9 +3059,22 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> Modal.confirm({ title: '确认删除', content: `确定要删除连接 "${node.title}" 吗?`, - onOk: () => { - closeTabsByConnection(String(node.key)); - removeConnection(node.key); + onOk: async () => { + const connId = String(node.key); + const backendApp = (window as any).go?.app?.App; + if (typeof backendApp?.DeleteConnection !== 'function') { + message.error('删除连接失败:后端接口不可用'); + throw new Error('DeleteConnection unavailable'); + } + try { + await backendApp.DeleteConnection(connId); + closeTabsByConnection(connId); + removeConnection(connId); + message.success('已删除连接'); + } catch (error: any) { + message.error(error?.message || '删除连接失败'); + throw error; + } } }); } @@ -3300,9 +3209,22 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> Modal.confirm({ title: '确认删除', content: `确定要删除连接 "${node.title}" 吗?`, - onOk: () => { - closeTabsByConnection(String(node.key)); - removeConnection(node.key); + onOk: async () => { + const connId = String(node.key); + const backendApp = (window as any).go?.app?.App; + if (typeof backendApp?.DeleteConnection !== 'function') { + message.error('删除连接失败:后端接口不可用'); + throw new Error('DeleteConnection unavailable'); + } + try { + await backendApp.DeleteConnection(connId); + closeTabsByConnection(connId); + removeConnection(connId); + message.success('已删除连接'); + } catch (error: any) { + message.error(error?.message || '删除连接失败'); + throw error; + } } }); } diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index 4f8fa6e..becce4a 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -1,4 +1,4 @@ -import React from 'react' +import React from 'react' import ReactDOM from 'react-dom/client' import App from './App' // import './index.css' // Optional global styles @@ -17,15 +17,125 @@ import { loader } from '@monaco-editor/react' import * as monaco from 'monaco-editor' loader.config({ monaco }) +const cloneBrowserMockValue = (value: any) => { + try { + return JSON.parse(JSON.stringify(value)); + } catch { + return value; + } +}; + +const resolveBrowserMockSecretFlag = (nextValue: unknown, clearFlag: boolean, existingFlag?: boolean) => { + if (String(nextValue ?? '') !== '') return true; + if (clearFlag) return false; + return !!existingFlag; +}; + +const buildBrowserMockDuplicateName = (rawName: string, items: any[]): string => { + const baseName = String(rawName || '').trim() || '连接'; + const suffix = ' - 副本'; + const usedNames = new Set(items.map((item) => String(item?.name || '').trim())); + let candidate = `${baseName}${suffix}`; + let counter = 2; + while (usedNames.has(candidate)) { + candidate = `${baseName}${suffix} ${counter}`; + counter += 1; + } + return candidate; +}; + if (typeof window !== 'undefined' && !(window as any).go) { + const mockConnections: any[] = []; + let mockGlobalProxy: any = { enabled: false, type: 'socks5', host: '', port: 1080, user: '', password: '', hasPassword: false }; + + const upsertMockConnection = (view: any) => { + const index = mockConnections.findIndex((item) => item.id === view.id); + if (index >= 0) { + mockConnections[index] = view; + return; + } + mockConnections.push(view); + }; + + const saveMockConnection = (input: any) => { + const existing = mockConnections.find((item) => item.id === input?.id); + const config = (input?.config && typeof input.config === 'object') ? input.config : {}; + const ssh = (config.ssh && typeof config.ssh === 'object') ? config.ssh : {}; + const proxy = (config.proxy && typeof config.proxy === 'object') ? config.proxy : {}; + const httpTunnel = (config.httpTunnel && typeof config.httpTunnel === 'object') ? config.httpTunnel : {}; + const nextId = String(input?.id || existing?.id || `mock-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`); + const view = { + id: nextId, + name: String(input?.name || existing?.name || '未命名连接'), + config: { + ...config, + id: nextId, + password: '', + ssh: { ...ssh, password: '' }, + proxy: { ...proxy, password: '' }, + httpTunnel: { ...httpTunnel, password: '' }, + uri: '', + dsn: '', + mysqlReplicaPassword: '', + mongoReplicaPassword: '', + }, + includeDatabases: Array.isArray(input?.includeDatabases) ? [...input.includeDatabases] : existing?.includeDatabases, + includeRedisDatabases: Array.isArray(input?.includeRedisDatabases) ? [...input.includeRedisDatabases] : existing?.includeRedisDatabases, + iconType: typeof input?.iconType === 'string' ? input.iconType : (existing?.iconType || ''), + iconColor: typeof input?.iconColor === 'string' ? input.iconColor : (existing?.iconColor || ''), + hasPrimaryPassword: resolveBrowserMockSecretFlag(config.password, !!input?.clearPrimaryPassword, existing?.hasPrimaryPassword), + hasSSHPassword: resolveBrowserMockSecretFlag(ssh.password, !!input?.clearSSHPassword, existing?.hasSSHPassword), + hasProxyPassword: resolveBrowserMockSecretFlag(proxy.password, !!input?.clearProxyPassword, existing?.hasProxyPassword), + hasHttpTunnelPassword: resolveBrowserMockSecretFlag(httpTunnel.password, !!input?.clearHttpTunnelPassword, existing?.hasHttpTunnelPassword), + hasMySQLReplicaPassword: resolveBrowserMockSecretFlag(config.mysqlReplicaPassword, !!input?.clearMySQLReplicaPassword, existing?.hasMySQLReplicaPassword), + hasMongoReplicaPassword: resolveBrowserMockSecretFlag(config.mongoReplicaPassword, !!input?.clearMongoReplicaPassword, existing?.hasMongoReplicaPassword), + hasOpaqueURI: resolveBrowserMockSecretFlag(config.uri, !!input?.clearOpaqueURI, existing?.hasOpaqueURI), + hasOpaqueDSN: resolveBrowserMockSecretFlag(config.dsn, !!input?.clearOpaqueDSN, existing?.hasOpaqueDSN), + }; + upsertMockConnection(view); + return cloneBrowserMockValue(view); + }; + + const saveMockGlobalProxy = (input: any) => { + const nextPassword = String(input?.password ?? ''); + mockGlobalProxy = { + ...mockGlobalProxy, + ...input, + password: '', + hasPassword: nextPassword !== '' ? true : !!mockGlobalProxy.hasPassword, + }; + return cloneBrowserMockValue(mockGlobalProxy); + }; + (window as any).go = { app: { App: { CheckUpdate: async () => ({ success: false }), DownloadUpdate: async () => ({ success: false }), - GetSavedConnections: async () => [], - SaveConnection: async () => null, - DeleteConnection: async () => null, + GetSavedConnections: async () => cloneBrowserMockValue(mockConnections), + SaveConnection: async (input: any) => saveMockConnection(input), + DeleteConnection: async (id: string) => { + const index = mockConnections.findIndex((item) => item.id === id); + if (index >= 0) { + mockConnections.splice(index, 1); + } + return null; + }, + DuplicateConnection: async (id: string) => { + const existing = mockConnections.find((item) => item.id === id); + if (!existing) return null; + const duplicated = cloneBrowserMockValue({ + ...existing, + id: `mock-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, + name: buildBrowserMockDuplicateName(existing.name, mockConnections), + config: cloneBrowserMockValue(existing.config), + includeDatabases: Array.isArray(existing.includeDatabases) ? [...existing.includeDatabases] : undefined, + includeRedisDatabases: Array.isArray(existing.includeRedisDatabases) ? [...existing.includeRedisDatabases] : undefined, + }); + mockConnections.push(duplicated); + return cloneBrowserMockValue(duplicated); + }, + ImportLegacyConnections: async (items: any[]) => items.map((item) => saveMockConnection(item)), OpenConnection: async () => null, CloseConnection: async () => null, GetDatabases: async () => [], @@ -42,11 +152,13 @@ if (typeof window !== 'undefined' && !(window as any).go) { InstallUpdateAndRestart: async () => ({ success: false }), ImportConfigFile: async () => ({ success: false }), ExportData: async () => ({ success: false }), + GetGlobalProxyConfig: async () => ({ success: true, data: cloneBrowserMockValue(mockGlobalProxy) }), + SaveGlobalProxy: async (input: any) => saveMockGlobalProxy(input), + ImportLegacyGlobalProxy: async (input: any) => saveMockGlobalProxy(input), } } }; } - // 全局注册透明主题,避免每个 Editor 组件 beforeMount 中重复定义 monaco.editor.defineTheme('transparent-dark', { base: 'vs-dark', inherit: true, rules: [], diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 34db0ec..40e1e9e 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -217,6 +217,8 @@ export interface AIProviderConfig { type: AIProviderType; name: string; apiKey: string; + secretRef?: string; + hasSecret?: boolean; baseUrl: string; model: string; models?: string[]; diff --git a/frontend/src/utils/connectionSecretDraft.test.ts b/frontend/src/utils/connectionSecretDraft.test.ts new file mode 100644 index 0000000..1577c0b --- /dev/null +++ b/frontend/src/utils/connectionSecretDraft.test.ts @@ -0,0 +1,86 @@ +import { describe, expect, it } from 'vitest'; + +import { resolveConnectionSecretDraft } from './connectionSecretDraft'; + +describe('resolveConnectionSecretDraft', () => { + it('keeps an existing stored secret when edit form leaves the field blank', () => { + const result = resolveConnectionSecretDraft({ + hasSecret: true, + valueInput: '', + clearSecret: false, + }); + + expect(result.value).toBe(''); + expect(result.clearStoredSecret).toBe(false); + expect(result.keepsStoredSecret).toBe(true); + expect(result.hasSecretAfterSave).toBe(true); + }); + + it('replaces the stored secret when a new value is entered', () => { + const result = resolveConnectionSecretDraft({ + hasSecret: true, + valueInput: ' mongodb://demo ', + clearSecret: false, + trimInput: true, + }); + + expect(result.value).toBe('mongodb://demo'); + expect(result.clearStoredSecret).toBe(false); + expect(result.keepsStoredSecret).toBe(false); + expect(result.hasSecretAfterSave).toBe(true); + }); + + it('clears the stored secret when explicitly requested', () => { + const result = resolveConnectionSecretDraft({ + hasSecret: true, + valueInput: '', + clearSecret: true, + }); + + expect(result.value).toBe(''); + expect(result.clearStoredSecret).toBe(true); + expect(result.keepsStoredSecret).toBe(false); + expect(result.hasSecretAfterSave).toBe(false); + }); + + it('prefers a newly entered value over a stale clear toggle', () => { + const result = resolveConnectionSecretDraft({ + hasSecret: true, + valueInput: 'new-password', + clearSecret: true, + }); + + expect(result.value).toBe('new-password'); + expect(result.clearStoredSecret).toBe(false); + expect(result.keepsStoredSecret).toBe(false); + expect(result.hasSecretAfterSave).toBe(true); + }); + + it('does not emit a clear flag for a brand new blank field', () => { + const result = resolveConnectionSecretDraft({ + hasSecret: false, + valueInput: '', + clearSecret: false, + }); + + expect(result.value).toBe(''); + expect(result.clearStoredSecret).toBe(false); + expect(result.keepsStoredSecret).toBe(false); + expect(result.hasSecretAfterSave).toBe(false); + }); + + it('supports force clearing stored secrets', () => { + const result = resolveConnectionSecretDraft({ + hasSecret: true, + valueInput: 'temporary', + clearSecret: false, + forceClear: true, + }); + + expect(result.value).toBe(''); + expect(result.clearStoredSecret).toBe(true); + expect(result.keepsStoredSecret).toBe(false); + expect(result.hasSecretAfterSave).toBe(false); + }); +}); + diff --git a/frontend/src/utils/connectionSecretDraft.ts b/frontend/src/utils/connectionSecretDraft.ts new file mode 100644 index 0000000..368aeb5 --- /dev/null +++ b/frontend/src/utils/connectionSecretDraft.ts @@ -0,0 +1,63 @@ +export interface ConnectionSecretDraftInput { + valueInput?: string; + hasSecret?: boolean; + clearSecret?: boolean; + forceClear?: boolean; + trimInput?: boolean; +} + +export interface ConnectionSecretDraftResult { + value: string; + clearStoredSecret: boolean; + keepsStoredSecret: boolean; + hasSecretAfterSave: boolean; +} + +export function resolveConnectionSecretDraft(input: ConnectionSecretDraftInput): ConnectionSecretDraftResult { + const rawValue = input.valueInput ?? ''; + const value = input.trimInput ? String(rawValue).trim() : String(rawValue); + + if (input.forceClear) { + return { + value: '', + clearStoredSecret: true, + keepsStoredSecret: false, + hasSecretAfterSave: false, + }; + } + + if (value !== '') { + return { + value, + clearStoredSecret: false, + keepsStoredSecret: false, + hasSecretAfterSave: true, + }; + } + + if (input.clearSecret) { + return { + value: '', + clearStoredSecret: true, + keepsStoredSecret: false, + hasSecretAfterSave: false, + }; + } + + if (input.hasSecret) { + return { + value: '', + clearStoredSecret: false, + keepsStoredSecret: true, + hasSecretAfterSave: true, + }; + } + + return { + value: '', + clearStoredSecret: false, + keepsStoredSecret: false, + hasSecretAfterSave: false, + }; +} + diff --git a/frontend/src/utils/providerSecretDraft.test.ts b/frontend/src/utils/providerSecretDraft.test.ts new file mode 100644 index 0000000..09d652a --- /dev/null +++ b/frontend/src/utils/providerSecretDraft.test.ts @@ -0,0 +1,41 @@ +import { describe, expect, it } from 'vitest'; + +import { resolveProviderSecretDraft } from './providerSecretDraft'; + +describe('resolveProviderSecretDraft', () => { + it('keeps existing provider secret when edit form leaves apiKey blank', () => { + const result = resolveProviderSecretDraft({ + hasSecret: true, + apiKeyInput: '', + clearSecret: false, + }); + + expect(result.mode).toBe('keep'); + expect(result.apiKey).toBe(''); + expect(result.hasSecret).toBe(true); + }); + + it('replaces the provider secret when a new apiKey is entered', () => { + const result = resolveProviderSecretDraft({ + hasSecret: true, + apiKeyInput: ' sk-new ', + clearSecret: false, + }); + + expect(result.mode).toBe('replace'); + expect(result.apiKey).toBe('sk-new'); + expect(result.hasSecret).toBe(true); + }); + + it('clears the stored provider secret when requested', () => { + const result = resolveProviderSecretDraft({ + hasSecret: true, + apiKeyInput: '', + clearSecret: true, + }); + + expect(result.mode).toBe('clear'); + expect(result.apiKey).toBe(''); + expect(result.hasSecret).toBe(false); + }); +}); diff --git a/frontend/src/utils/providerSecretDraft.ts b/frontend/src/utils/providerSecretDraft.ts new file mode 100644 index 0000000..be0ce45 --- /dev/null +++ b/frontend/src/utils/providerSecretDraft.ts @@ -0,0 +1,47 @@ +export type ProviderSecretDraftMode = 'keep' | 'replace' | 'clear'; + +export interface ProviderSecretDraftInput { + hasSecret?: boolean; + apiKeyInput?: string; + clearSecret?: boolean; +} + +export interface ProviderSecretDraftResult { + mode: ProviderSecretDraftMode; + apiKey: string; + hasSecret: boolean; +} + +export function resolveProviderSecretDraft(input: ProviderSecretDraftInput): ProviderSecretDraftResult { + const apiKey = String(input.apiKeyInput || '').trim(); + + if (input.clearSecret) { + return { + mode: 'clear', + apiKey: '', + hasSecret: false, + }; + } + + if (apiKey) { + return { + mode: 'replace', + apiKey, + hasSecret: true, + }; + } + + if (input.hasSecret) { + return { + mode: 'keep', + apiKey: '', + hasSecret: true, + }; + } + + return { + mode: 'clear', + apiKey: '', + hasSecret: false, + }; +} diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index d8203ed..08c1dd8 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -52,6 +52,8 @@ export function DataSyncAnalyze(arg1:sync.SyncConfig):Promise; +export function DeleteConnection(arg1:string):Promise; + export function DownloadDriverPackage(arg1:string,arg2:string,arg3:string,arg4:string):Promise; export function DownloadUpdate():Promise; @@ -64,6 +66,8 @@ export function DropTable(arg1:connection.ConnectionConfig,arg2:string,arg3:stri export function DropView(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; +export function DuplicateConnection(arg1:string):Promise; + export function ExecuteSQLFile(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise; export function ExportData(arg1:Array>,arg2:Array,arg3:string,arg4:string):Promise; @@ -90,12 +94,18 @@ export function GetDriverVersionPackageSize(arg1:string,arg2:string):Promise; +export function GetSavedConnections():Promise>; + export function ImportConfigFile():Promise; export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; export function ImportDataWithProgress(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise; +export function ImportLegacyConnections(arg1:Array):Promise>; + +export function ImportLegacyGlobalProxy(arg1:connection.SaveGlobalProxyInput):Promise; + export function InstallLocalDriverPackage(arg1:string,arg2:string,arg3:string):Promise; export function InstallUpdateAndRestart():Promise; @@ -180,6 +190,10 @@ export function ResolveDriverPackageDownloadURL(arg1:string,arg2:string):Promise export function ResolveDriverRepositoryURL(arg1:string):Promise; +export function SaveConnection(arg1:connection.SavedConnectionInput):Promise; + +export function SaveGlobalProxy(arg1:connection.SaveGlobalProxyInput):Promise; + export function SelectDatabaseFile(arg1:string,arg2:string):Promise; export function SelectDriverDownloadDirectory(arg1:string):Promise; diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js index 8862f24..9564131 100755 --- a/frontend/wailsjs/go/app/App.js +++ b/frontend/wailsjs/go/app/App.js @@ -98,6 +98,10 @@ export function DataSyncPreview(arg1, arg2, arg3) { return window['go']['app']['App']['DataSyncPreview'](arg1, arg2, arg3); } +export function DeleteConnection(arg1) { + return window['go']['app']['App']['DeleteConnection'](arg1); +} + export function DownloadDriverPackage(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['DownloadDriverPackage'](arg1, arg2, arg3, arg4); } @@ -122,6 +126,10 @@ export function DropView(arg1, arg2, arg3) { return window['go']['app']['App']['DropView'](arg1, arg2, arg3); } +export function DuplicateConnection(arg1) { + return window['go']['app']['App']['DuplicateConnection'](arg1); +} + export function ExecuteSQLFile(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['ExecuteSQLFile'](arg1, arg2, arg3, arg4); } @@ -174,6 +182,10 @@ export function GetGlobalProxyConfig() { return window['go']['app']['App']['GetGlobalProxyConfig'](); } +export function GetSavedConnections() { + return window['go']['app']['App']['GetSavedConnections'](); +} + export function ImportConfigFile() { return window['go']['app']['App']['ImportConfigFile'](); } @@ -186,6 +198,14 @@ export function ImportDataWithProgress(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['ImportDataWithProgress'](arg1, arg2, arg3, arg4); } +export function ImportLegacyConnections(arg1) { + return window['go']['app']['App']['ImportLegacyConnections'](arg1); +} + +export function ImportLegacyGlobalProxy(arg1) { + return window['go']['app']['App']['ImportLegacyGlobalProxy'](arg1); +} + export function InstallLocalDriverPackage(arg1, arg2, arg3) { return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3); } @@ -354,6 +374,14 @@ export function ResolveDriverRepositoryURL(arg1) { return window['go']['app']['App']['ResolveDriverRepositoryURL'](arg1); } +export function SaveConnection(arg1) { + return window['go']['app']['App']['SaveConnection'](arg1); +} + +export function SaveGlobalProxy(arg1) { + return window['go']['app']['App']['SaveGlobalProxy'](arg1); +} + export function SelectDatabaseFile(arg1, arg2) { return window['go']['app']['App']['SelectDatabaseFile'](arg1, arg2); } diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index e9558a8..433b7bc 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -78,6 +78,8 @@ export namespace ai { type: string; name: string; apiKey: string; + secretRef?: string; + hasSecret?: boolean; baseUrl: string; model: string; models?: string[]; @@ -96,6 +98,8 @@ export namespace ai { this.type = source["type"]; this.name = source["name"]; this.apiKey = source["apiKey"]; + this.secretRef = source["secretRef"]; + this.hasSecret = source["hasSecret"]; this.baseUrl = source["baseUrl"]; this.model = source["model"]; this.models = source["models"]; @@ -284,6 +288,7 @@ export namespace connection { } } export class ConnectionConfig { + id?: string; type: string; host: string; port: number; @@ -324,6 +329,7 @@ export namespace connection { constructor(source: any = {}) { if ('string' === typeof source) source = JSON.parse(source); + this.id = source["id"]; this.type = source["type"]; this.host = source["host"]; this.port = source["port"]; @@ -377,6 +383,32 @@ export namespace connection { return a; } } + export class GlobalProxyView { + enabled: boolean; + type: string; + host: string; + port: number; + user?: string; + password?: string; + hasPassword?: boolean; + secretRef?: string; + + static createFrom(source: any = {}) { + return new GlobalProxyView(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.enabled = source["enabled"]; + this.type = source["type"]; + this.host = source["host"]; + this.port = source["port"]; + this.user = source["user"]; + this.password = source["password"]; + this.hasPassword = source["hasPassword"]; + this.secretRef = source["secretRef"]; + } + } export class QueryResult { @@ -400,6 +432,146 @@ export namespace connection { } } + export class SaveGlobalProxyInput { + enabled: boolean; + type: string; + host: string; + port: number; + user?: string; + password?: string; + + static createFrom(source: any = {}) { + return new SaveGlobalProxyInput(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.enabled = source["enabled"]; + this.type = source["type"]; + this.host = source["host"]; + this.port = source["port"]; + this.user = source["user"]; + this.password = source["password"]; + } + } + export class SavedConnectionInput { + id?: string; + name: string; + config: ConnectionConfig; + includeDatabases?: string[]; + includeRedisDatabases?: number[]; + iconType?: string; + iconColor?: string; + clearPrimaryPassword?: boolean; + clearSSHPassword?: boolean; + clearProxyPassword?: boolean; + clearHttpTunnelPassword?: boolean; + clearMySQLReplicaPassword?: boolean; + clearMongoReplicaPassword?: boolean; + clearOpaqueURI?: boolean; + clearOpaqueDSN?: boolean; + + static createFrom(source: any = {}) { + return new SavedConnectionInput(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.id = source["id"]; + this.name = source["name"]; + this.config = this.convertValues(source["config"], ConnectionConfig); + this.includeDatabases = source["includeDatabases"]; + this.includeRedisDatabases = source["includeRedisDatabases"]; + this.iconType = source["iconType"]; + this.iconColor = source["iconColor"]; + this.clearPrimaryPassword = source["clearPrimaryPassword"]; + this.clearSSHPassword = source["clearSSHPassword"]; + this.clearProxyPassword = source["clearProxyPassword"]; + this.clearHttpTunnelPassword = source["clearHttpTunnelPassword"]; + this.clearMySQLReplicaPassword = source["clearMySQLReplicaPassword"]; + this.clearMongoReplicaPassword = source["clearMongoReplicaPassword"]; + this.clearOpaqueURI = source["clearOpaqueURI"]; + this.clearOpaqueDSN = source["clearOpaqueDSN"]; + } + + convertValues(a: any, classs: any, asMap: boolean = false): any { + if (!a) { + return a; + } + if (a.slice && a.map) { + return (a as any[]).map(elem => this.convertValues(elem, classs)); + } else if ("object" === typeof a) { + if (asMap) { + for (const key of Object.keys(a)) { + a[key] = new classs(a[key]); + } + return a; + } + return new classs(a); + } + return a; + } + } + export class SavedConnectionView { + id: string; + name: string; + config: ConnectionConfig; + includeDatabases?: string[]; + includeRedisDatabases?: number[]; + iconType?: string; + iconColor?: string; + secretRef?: string; + hasPrimaryPassword?: boolean; + hasSSHPassword?: boolean; + hasProxyPassword?: boolean; + hasHttpTunnelPassword?: boolean; + hasMySQLReplicaPassword?: boolean; + hasMongoReplicaPassword?: boolean; + hasOpaqueURI?: boolean; + hasOpaqueDSN?: boolean; + + static createFrom(source: any = {}) { + return new SavedConnectionView(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.id = source["id"]; + this.name = source["name"]; + this.config = this.convertValues(source["config"], ConnectionConfig); + this.includeDatabases = source["includeDatabases"]; + this.includeRedisDatabases = source["includeRedisDatabases"]; + this.iconType = source["iconType"]; + this.iconColor = source["iconColor"]; + this.secretRef = source["secretRef"]; + this.hasPrimaryPassword = source["hasPrimaryPassword"]; + this.hasSSHPassword = source["hasSSHPassword"]; + this.hasProxyPassword = source["hasProxyPassword"]; + this.hasHttpTunnelPassword = source["hasHttpTunnelPassword"]; + this.hasMySQLReplicaPassword = source["hasMySQLReplicaPassword"]; + this.hasMongoReplicaPassword = source["hasMongoReplicaPassword"]; + this.hasOpaqueURI = source["hasOpaqueURI"]; + this.hasOpaqueDSN = source["hasOpaqueDSN"]; + } + + convertValues(a: any, classs: any, asMap: boolean = false): any { + if (!a) { + return a; + } + if (a.slice && a.map) { + return (a as any[]).map(elem => this.convertValues(elem, classs)); + } else if ("object" === typeof a) { + if (asMap) { + for (const key of Object.keys(a)) { + a[key] = new classs(a[key]); + } + return a; + } + return new classs(a); + } + return a; + } + } } diff --git a/internal/app/methods_saved_connections_test.go b/internal/app/methods_saved_connections_test.go index 4b117cc..d17785f 100644 --- a/internal/app/methods_saved_connections_test.go +++ b/internal/app/methods_saved_connections_test.go @@ -1,6 +1,7 @@ package app import ( + "reflect" "testing" "GoNavi-Wails/internal/connection" @@ -11,8 +12,11 @@ func TestSaveConnectionMethodReturnsSecretlessView(t *testing.T) { app.configDir = t.TempDir() result, err := app.SaveConnection(connection.SavedConnectionInput{ - ID: "conn-1", - Name: "Primary", + ID: "conn-1", + Name: "Primary", + IncludeDatabases: []string{"appdb"}, + IconType: "postgres", + IconColor: "#1677ff", Config: connection.ConnectionConfig{ ID: "conn-1", Type: "postgres", @@ -31,6 +35,79 @@ func TestSaveConnectionMethodReturnsSecretlessView(t *testing.T) { if !result.HasPrimaryPassword { t.Fatal("expected HasPrimaryPassword=true") } + if !reflect.DeepEqual(result.IncludeDatabases, []string{"appdb"}) { + t.Fatalf("expected include databases to be preserved, got %#v", result.IncludeDatabases) + } + if result.IconType != "postgres" || result.IconColor != "#1677ff" { + t.Fatalf("expected icon metadata to be preserved, got type=%q color=%q", result.IconType, result.IconColor) + } +} + +func TestSaveConnectionClearsRequestedSecretFields(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "postgres-secret", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.local", + Port: 22, + User: "ops", + Password: "ssh-secret", + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + view, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.local", + Port: 22, + User: "ops", + }, + }, + ClearPrimaryPassword: true, + }) + if err != nil { + t.Fatal(err) + } + if view.HasPrimaryPassword { + t.Fatal("expected HasPrimaryPassword=false after clearing") + } + if !view.HasSSHPassword { + t.Fatal("expected SSH password to stay stored") + } + + resolved, err := app.resolveConnectionSecrets(view.Config) + if err != nil { + t.Fatal(err) + } + if resolved.Password != "" { + t.Fatalf("expected cleared primary password, got %q", resolved.Password) + } + if resolved.SSH.Password != "ssh-secret" { + t.Fatalf("expected SSH password to stay stored, got %q", resolved.SSH.Password) + } } func TestDuplicateConnectionClonesSecretBundle(t *testing.T) { @@ -38,8 +115,12 @@ func TestDuplicateConnectionClonesSecretBundle(t *testing.T) { app.configDir = t.TempDir() _, err := app.SaveConnection(connection.SavedConnectionInput{ - ID: "conn-1", - Name: "Primary", + ID: "conn-1", + Name: "Primary", + IncludeDatabases: []string{"appdb"}, + IncludeRedisDatabases: []int{0, 1}, + IconType: "postgres", + IconColor: "#1677ff", Config: connection.ConnectionConfig{ ID: "conn-1", Type: "postgres", @@ -60,6 +141,18 @@ func TestDuplicateConnectionClonesSecretBundle(t *testing.T) { if duplicate.ID == "conn-1" { t.Fatal("duplicate should have a new id") } + if duplicate.Name != "Primary - 副本" { + t.Fatalf("expected duplicate name to keep existing UX, got %q", duplicate.Name) + } + if !reflect.DeepEqual(duplicate.IncludeDatabases, []string{"appdb"}) { + t.Fatalf("expected include databases to be cloned, got %#v", duplicate.IncludeDatabases) + } + if !reflect.DeepEqual(duplicate.IncludeRedisDatabases, []int{0, 1}) { + t.Fatalf("expected redis include databases to be cloned, got %#v", duplicate.IncludeRedisDatabases) + } + if duplicate.IconType != "postgres" || duplicate.IconColor != "#1677ff" { + t.Fatalf("expected icon metadata to be cloned, got type=%q color=%q", duplicate.IconType, duplicate.IconColor) + } resolved, err := app.resolveConnectionSecrets(duplicate.Config) if err != nil { diff --git a/internal/app/saved_connections.go b/internal/app/saved_connections.go index 19eb76d..4a5dbb7 100644 --- a/internal/app/saved_connections.go +++ b/internal/app/saved_connections.go @@ -95,6 +95,53 @@ func mergeConnectionSecretBundles(base, overlay connectionSecretBundle) connecti return merged } +func applyConnectionSecretClears(bundle connectionSecretBundle, input connection.SavedConnectionInput) connectionSecretBundle { + cleared := bundle + if input.ClearPrimaryPassword { + cleared.Password = "" + } + if input.ClearSSHPassword { + cleared.SSHPassword = "" + } + if input.ClearProxyPassword { + cleared.ProxyPassword = "" + } + if input.ClearHTTPTunnelPassword { + cleared.HTTPTunnelPassword = "" + } + if input.ClearMySQLReplicaPassword { + cleared.MySQLReplicaPassword = "" + } + if input.ClearMongoReplicaPassword { + cleared.MongoReplicaPassword = "" + } + if input.ClearOpaqueURI { + cleared.OpaqueURI = "" + } + if input.ClearOpaqueDSN { + cleared.OpaqueDSN = "" + } + return cleared +} + +func cloneStringSlice(input []string) []string { + if len(input) == 0 { + return nil + } + cloned := make([]string, len(input)) + copy(cloned, input) + return cloned +} + +func cloneIntSlice(input []int) []int { + if len(input) == 0 { + return nil + } + cloned := make([]int, len(input)) + copy(cloned, input) + return cloned +} + func splitConnectionSecrets(input connection.SavedConnectionInput) (connection.SavedConnectionView, connectionSecretBundle) { id := strings.TrimSpace(input.ID) if id == "" { @@ -143,6 +190,10 @@ func splitConnectionSecrets(input connection.SavedConnectionInput) (connection.S ID: id, Name: strings.TrimSpace(input.Name), Config: meta, + IncludeDatabases: cloneStringSlice(input.IncludeDatabases), + IncludeRedisDatabases: cloneIntSlice(input.IncludeRedisDatabases), + IconType: strings.TrimSpace(input.IconType), + IconColor: strings.TrimSpace(input.IconColor), HasPrimaryPassword: strings.TrimSpace(bundle.Password) != "", HasSSHPassword: strings.TrimSpace(bundle.SSHPassword) != "", HasProxyPassword: strings.TrimSpace(bundle.ProxyPassword) != "", @@ -223,6 +274,7 @@ func (r *savedConnectionRepository) Save(input connection.SavedConnectionInput) mergedBundle = mergeConnectionSecretBundles(existingBundle, bundle) view.SecretRef = existing.SecretRef } + mergedBundle = applyConnectionSecretClears(mergedBundle, input) if mergedBundle.hasAny() { ref, storeErr := r.storeSecretBundle(view.ID, view.SecretRef, mergedBundle) @@ -332,6 +384,27 @@ func applyConnectionBundleFlags(view *connection.SavedConnectionView, bundle con view.HasOpaqueDSN = strings.TrimSpace(bundle.OpaqueDSN) != "" } +func buildDuplicateConnectionName(baseName string, existing []connection.SavedConnectionView) string { + trimmedBaseName := strings.TrimSpace(baseName) + if trimmedBaseName == "" { + trimmedBaseName = "连接" + } + suffix := " - 副本" + usedNames := make(map[string]struct{}, len(existing)) + for _, item := range existing { + usedNames[strings.TrimSpace(item.Name)] = struct{}{} + } + candidate := trimmedBaseName + suffix + counter := 2 + for { + if _, exists := usedNames[candidate]; !exists { + return candidate + } + candidate = fmt.Sprintf("%s%s %d", trimmedBaseName, suffix, counter) + counter++ + } +} + func (r *savedConnectionRepository) List() ([]connection.SavedConnectionView, error) { return r.load() } @@ -357,15 +430,27 @@ func (r *savedConnectionRepository) Delete(id string) error { } func (r *savedConnectionRepository) Duplicate(id string) (connection.SavedConnectionView, error) { - original, err := r.Find(id) + connections, err := r.load() if err != nil { return connection.SavedConnectionView{}, err } + index := -1 + for i, item := range connections { + if item.ID == strings.TrimSpace(id) { + index = i + break + } + } + if index < 0 { + return connection.SavedConnectionView{}, fmt.Errorf("saved connection not found: %s", id) + } + + original := connections[index] duplicate := original duplicate.ID = "conn-" + uuid.New().String()[:8] duplicate.Config.ID = duplicate.ID - duplicate.Name = original.Name + " Copy" + duplicate.Name = buildDuplicateConnectionName(original.Name, connections) bundle, err := r.loadSecretBundle(original) if err != nil { @@ -383,10 +468,6 @@ func (r *savedConnectionRepository) Duplicate(id string) (connection.SavedConnec applyConnectionBundleFlags(&duplicate, connectionSecretBundle{}) } - connections, err := r.load() - if err != nil { - return connection.SavedConnectionView{}, err - } connections = append(connections, duplicate) if err := r.saveAll(connections); err != nil { return connection.SavedConnectionView{}, err diff --git a/internal/connection/saved_types.go b/internal/connection/saved_types.go index a364a50..c99bf1a 100644 --- a/internal/connection/saved_types.go +++ b/internal/connection/saved_types.go @@ -1,15 +1,31 @@ package connection type SavedConnectionInput struct { - ID string `json:"id,omitempty"` - Name string `json:"name"` - Config ConnectionConfig `json:"config"` + ID string `json:"id,omitempty"` + Name string `json:"name"` + Config ConnectionConfig `json:"config"` + IncludeDatabases []string `json:"includeDatabases,omitempty"` + IncludeRedisDatabases []int `json:"includeRedisDatabases,omitempty"` + IconType string `json:"iconType,omitempty"` + IconColor string `json:"iconColor,omitempty"` + ClearPrimaryPassword bool `json:"clearPrimaryPassword,omitempty"` + ClearSSHPassword bool `json:"clearSSHPassword,omitempty"` + ClearProxyPassword bool `json:"clearProxyPassword,omitempty"` + ClearHTTPTunnelPassword bool `json:"clearHttpTunnelPassword,omitempty"` + ClearMySQLReplicaPassword bool `json:"clearMySQLReplicaPassword,omitempty"` + ClearMongoReplicaPassword bool `json:"clearMongoReplicaPassword,omitempty"` + ClearOpaqueURI bool `json:"clearOpaqueURI,omitempty"` + ClearOpaqueDSN bool `json:"clearOpaqueDSN,omitempty"` } type SavedConnectionView struct { ID string `json:"id"` Name string `json:"name"` Config ConnectionConfig `json:"config"` + IncludeDatabases []string `json:"includeDatabases,omitempty"` + IncludeRedisDatabases []int `json:"includeRedisDatabases,omitempty"` + IconType string `json:"iconType,omitempty"` + IconColor string `json:"iconColor,omitempty"` SecretRef string `json:"secretRef,omitempty"` HasPrimaryPassword bool `json:"hasPrimaryPassword,omitempty"` HasSSHPassword bool `json:"hasSSHPassword,omitempty"` From 255cc14bf64150286b9e0f1503c7b2895dbe7c14 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Sat, 4 Apr 2026 10:51:32 +0800 Subject: [PATCH 12/14] =?UTF-8?q?=F0=9F=90=9B=20fix(config-secret-storage)?= =?UTF-8?q?:=20=E4=BF=AE=E5=A4=8D=E5=AF=86=E6=96=87=E7=BC=96=E8=BE=91?= =?UTF-8?q?=E4=B8=8E=E7=8A=B6=E6=80=81=E6=AE=8B=E7=95=99=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复自定义连接编辑时已保存 DSN 无法留空沿用的问题 - 重置 AI 供应商编辑态与清空密钥开关,避免关闭后状态残留 - 对齐浏览器 mock 复制连接的 config.id 语义并补充回归测试 --- frontend/src/components/AISettingsModal.tsx | 75 ++++++++++----- frontend/src/components/ConnectionModal.tsx | 20 +++- frontend/src/main.tsx | 44 ++------- .../src/utils/aiProviderEditorState.test.ts | 49 ++++++++++ frontend/src/utils/aiProviderEditorState.ts | 92 +++++++++++++++++++ .../src/utils/browserMockConnections.test.ts | 26 ++++++ frontend/src/utils/browserMockConnections.ts | 47 ++++++++++ .../src/utils/customConnectionDsn.test.ts | 37 ++++++++ frontend/src/utils/customConnectionDsn.ts | 27 ++++++ 9 files changed, 355 insertions(+), 62 deletions(-) create mode 100644 frontend/src/utils/aiProviderEditorState.test.ts create mode 100644 frontend/src/utils/aiProviderEditorState.ts create mode 100644 frontend/src/utils/browserMockConnections.test.ts create mode 100644 frontend/src/utils/browserMockConnections.ts create mode 100644 frontend/src/utils/customConnectionDsn.test.ts create mode 100644 frontend/src/utils/customConnectionDsn.ts diff --git a/frontend/src/components/AISettingsModal.tsx b/frontend/src/components/AISettingsModal.tsx index 7d586a4..ecd4b62 100644 --- a/frontend/src/components/AISettingsModal.tsx +++ b/frontend/src/components/AISettingsModal.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useCallback, useRef } from 'react'; +import React, { useState, useEffect, useCallback, useRef } from 'react'; import { Modal, Button, Input, Select, Form, Checkbox, message as antdMessage, Tooltip, Tabs, Space, Popconfirm, Slider } from 'antd'; import { PlusOutlined, DeleteOutlined, EditOutlined, CheckOutlined, ApiOutlined, SafetyCertificateOutlined, RobotOutlined, ThunderboltOutlined, CloudOutlined, ExperimentOutlined, KeyOutlined, LinkOutlined, AppstoreOutlined, ToolOutlined } from '@ant-design/icons'; import type { AIProviderConfig, AIProviderType, AISafetyLevel, AIContextLevel } from '../types'; @@ -19,6 +19,7 @@ import { PROVIDER_PRESET_CARD_TITLE_STYLE, } from '../utils/aiSettingsPresetLayout'; import { resolveProviderSecretDraft } from '../utils/providerSecretDraft'; +import { buildAddProviderEditorSession, buildClosedProviderEditorSession, buildEditProviderEditorSession, type ProviderEditorSession } from '../utils/aiProviderEditorState'; import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; @@ -134,18 +135,41 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo useEffect(() => { if (open) void loadConfig(); }, [open, loadConfig]); + const applyProviderEditorSession = useCallback((session: ProviderEditorSession) => { + setEditingProvider(session.editingProvider as AIProviderConfig | null); + setIsEditing(session.isEditing); + setTestStatus(session.testStatus); + setClearProviderSecret(session.clearProviderSecret); + form.resetFields(); + if (session.formValues) { + form.setFieldsValue(session.formValues); + } + }, [form]); + + const resetProviderEditorSession = useCallback(() => { + applyProviderEditorSession(buildClosedProviderEditorSession()); + }, [applyProviderEditorSession]); + + const handleModalClose = useCallback(() => { + resetProviderEditorSession(); + onClose(); + }, [onClose, resetProviderEditorSession]); + + useEffect(() => { + if (!open) { + resetProviderEditorSession(); + } + }, [open, resetProviderEditorSession]); const handleAddProvider = () => { const preset = findPreset('openai'); - const newProvider: AIProviderConfig = { - id: '', type: preset.backendType, name: '', apiKey: '', - baseUrl: preset.defaultBaseUrl, model: preset.defaultModel, - models: [], maxTokens: 4096, temperature: 0.7, - }; - setEditingProvider({ ...newProvider, presetKey: 'openai' } as any); - setIsEditing(true); - setTestStatus('idle'); - form.resetFields(); - form.setFieldsValue({ ...newProvider, presetKey: 'openai', apiFormat: 'openai' }); + applyProviderEditorSession(buildAddProviderEditorSession({ + presetKey: 'openai', + presetBackendType: preset.backendType, + presetBaseUrl: preset.defaultBaseUrl, + presetModel: preset.defaultModel, + presetModels: preset.models, + apiFormat: 'openai', + })); }; const handleEditProvider = (p: AIProviderConfig) => { @@ -156,17 +180,16 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo presetFixedApiFormat: matchedPreset.fixedApiFormat, valuesApiFormat: p.apiFormat, }); - setEditingProvider(p); - setIsEditing(true); - setTestStatus('idle'); - form.resetFields(); - form.setFieldsValue({ - ...p, - type: resolvedTransport.type, - models: p.models || [], - presetKey: matchedPreset.key, - apiFormat: resolvedTransport.apiFormat || p.apiFormat || 'openai', - }); + applyProviderEditorSession(buildEditProviderEditorSession({ + provider: { ...p, presetKey: matchedPreset.key } as any, + formValues: { + ...p, + type: resolvedTransport.type, + models: p.models || [], + presetKey: matchedPreset.key, + apiFormat: resolvedTransport.apiFormat || p.apiFormat || 'openai', + }, + })); }; const handleDeleteProvider = async (id: string) => { @@ -239,7 +262,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo }; // 后端 AISaveProvider 统一处理新增和更新,返回 void,失败抛异常 await Service?.AISaveProvider?.(payload); - void messageApi.success('已保存'); setIsEditing(false); setEditingProvider(null); setClearProviderSecret(false); void loadConfig(); + void messageApi.success('已保存'); resetProviderEditorSession(); void loadConfig(); window.dispatchEvent(new CustomEvent('gonavi:ai:provider-changed')); } catch (e: any) { if (e?.errorFields) { /* antd form validation error, ignore */ } @@ -420,7 +443,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo
{/* 顶部返回 */}
- {editingProvider?.id ? '编辑模型供应商' : '添加模型供应商'} @@ -732,7 +755,7 @@ const AISettingsModal: React.FC = ({ open, onClose, darkMo
} open={open} - onCancel={onClose} + onCancel={handleModalClose} footer={null} width={820} styles={{ @@ -802,3 +825,5 @@ export default AISettingsModal; + + diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index b0f0763..0f0dec8 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useRef, useMemo } from 'react'; +import React, { useState, useEffect, useRef, useMemo } from 'react'; import { Modal, Form, Input, InputNumber, Button, message, Checkbox, Divider, Select, Alert, Card, Row, Col, Typography, Collapse, Space, Table, Tag } from 'antd'; import { DatabaseOutlined, ConsoleSqlOutlined, FileTextOutlined, CloudServerOutlined, AppstoreAddOutlined, CloudOutlined, CheckCircleFilled, CloseCircleFilled, LinkOutlined, EditOutlined, AppstoreOutlined, BgColorsOutlined } from '@ant-design/icons'; import { getDbIcon, getDbDefaultColor, getDbIconLabel, DB_ICON_TYPES, PRESET_ICON_COLORS } from './DatabaseIcons'; @@ -6,6 +6,7 @@ import { useStore } from '../store'; import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; import { resolveConnectionSecretDraft } from '../utils/connectionSecretDraft'; +import { getCustomConnectionDsnValidationMessage } from '../utils/customConnectionDsn'; import { DBGetDatabases, GetDriverStatusList, MongoDiscoverMembers, TestConnection, RedisConnect, SelectDatabaseFile, SelectSSHKeyFile } from '../../wailsjs/go/app/App'; import { ConnectionConfig, MongoMemberInfo, SavedConnection } from '../types'; @@ -819,6 +820,19 @@ const ConnectionModal: React.FC<{ } }); + const createCustomDsnRule = () => ({ + validator(_: unknown, value: unknown) { + const validationMessage = getCustomConnectionDsnValidationMessage({ + dsnInput: value, + hasStoredSecret: initialValues?.hasOpaqueDSN, + clearStoredSecret: clearSecrets.opaqueDSN, + }); + return validationMessage + ? Promise.reject(new Error(validationMessage)) + : Promise.resolve(); + } + }); + const getUriPlaceholder = () => { if (dbType === 'mysql' || dbType === 'mariadb' || dbType === 'diros' || dbType === 'sphinx') { const defaultPort = getDefaultPortByType(dbType); @@ -2100,7 +2114,7 @@ const ConnectionModal: React.FC<{ - + {renderStoredSecretControls({ @@ -3108,3 +3122,5 @@ export default ConnectionModal; + + diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index becce4a..f1c7eab 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -1,4 +1,4 @@ -import React from 'react' +import React from 'react' import ReactDOM from 'react-dom/client' import App from './App' // import './index.css' // Optional global styles @@ -15,35 +15,9 @@ dayjs.locale('zh-cn') import 'monaco-editor/esm/nls.messages.zh-cn' import { loader } from '@monaco-editor/react' import * as monaco from 'monaco-editor' +import { cloneBrowserMockValue, duplicateBrowserMockConnection, resolveBrowserMockSecretFlag } from './utils/browserMockConnections' loader.config({ monaco }) -const cloneBrowserMockValue = (value: any) => { - try { - return JSON.parse(JSON.stringify(value)); - } catch { - return value; - } -}; - -const resolveBrowserMockSecretFlag = (nextValue: unknown, clearFlag: boolean, existingFlag?: boolean) => { - if (String(nextValue ?? '') !== '') return true; - if (clearFlag) return false; - return !!existingFlag; -}; - -const buildBrowserMockDuplicateName = (rawName: string, items: any[]): string => { - const baseName = String(rawName || '').trim() || '连接'; - const suffix = ' - 副本'; - const usedNames = new Set(items.map((item) => String(item?.name || '').trim())); - let candidate = `${baseName}${suffix}`; - let counter = 2; - while (usedNames.has(candidate)) { - candidate = `${baseName}${suffix} ${counter}`; - counter += 1; - } - return candidate; -}; - if (typeof window !== 'undefined' && !(window as any).go) { const mockConnections: any[] = []; let mockGlobalProxy: any = { enabled: false, type: 'socks5', host: '', port: 1080, user: '', password: '', hasPassword: false }; @@ -124,13 +98,10 @@ if (typeof window !== 'undefined' && !(window as any).go) { DuplicateConnection: async (id: string) => { const existing = mockConnections.find((item) => item.id === id); if (!existing) return null; - const duplicated = cloneBrowserMockValue({ - ...existing, - id: `mock-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, - name: buildBrowserMockDuplicateName(existing.name, mockConnections), - config: cloneBrowserMockValue(existing.config), - includeDatabases: Array.isArray(existing.includeDatabases) ? [...existing.includeDatabases] : undefined, - includeRedisDatabases: Array.isArray(existing.includeRedisDatabases) ? [...existing.includeRedisDatabases] : undefined, + const duplicated = duplicateBrowserMockConnection({ + existing, + items: mockConnections, + nextId: `mock-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`, }); mockConnections.push(duplicated); return cloneBrowserMockValue(duplicated); @@ -174,3 +145,6 @@ ReactDOM.createRoot(document.getElementById('root')!).render( , ) + + + diff --git a/frontend/src/utils/aiProviderEditorState.test.ts b/frontend/src/utils/aiProviderEditorState.test.ts new file mode 100644 index 0000000..f869d4b --- /dev/null +++ b/frontend/src/utils/aiProviderEditorState.test.ts @@ -0,0 +1,49 @@ +import { describe, expect, it } from 'vitest'; + +import { + buildAddProviderEditorSession, + buildClosedProviderEditorSession, + buildEditProviderEditorSession, +} from './aiProviderEditorState'; + +describe('aiProviderEditorState', () => { + it('resets clearProviderSecret when starting add flow', () => { + const session = buildAddProviderEditorSession({ + previousClearProviderSecret: true, + presetBackendType: 'openai', + presetBaseUrl: 'https://api.openai.com/v1', + presetModel: 'gpt-4.1', + }); + + expect(session.clearProviderSecret).toBe(false); + expect(session.isEditing).toBe(true); + expect(session.testStatus).toBe('idle'); + }); + + it('resets clearProviderSecret when starting edit flow', () => { + const session = buildEditProviderEditorSession({ + previousClearProviderSecret: true, + provider: { + id: 'provider-1', + type: 'openai', + name: 'OpenAI', + apiKey: '', + hasSecret: true, + }, + }); + + expect(session.clearProviderSecret).toBe(false); + expect(session.isEditing).toBe(true); + expect(session.editingProvider?.id).toBe('provider-1'); + }); + + it('resets clearProviderSecret when the modal closes', () => { + const session = buildClosedProviderEditorSession({ + previousClearProviderSecret: true, + }); + + expect(session.clearProviderSecret).toBe(false); + expect(session.isEditing).toBe(false); + expect(session.editingProvider).toBeNull(); + }); +}); diff --git a/frontend/src/utils/aiProviderEditorState.ts b/frontend/src/utils/aiProviderEditorState.ts new file mode 100644 index 0000000..6ce5e0f --- /dev/null +++ b/frontend/src/utils/aiProviderEditorState.ts @@ -0,0 +1,92 @@ +import type { AIProviderConfig, AIProviderType } from '../types'; + +type ProviderEditorStatus = 'idle' | 'success' | 'error'; + +type ProviderEditorConfig = Partial & Pick & { presetKey?: string }; + +export interface ProviderEditorSession { + editingProvider: ProviderEditorConfig | null; + formValues: Record | null; + isEditing: boolean; + clearProviderSecret: boolean; + testStatus: ProviderEditorStatus; +} + +interface BuildAddProviderEditorSessionInput { + previousClearProviderSecret?: boolean; + presetKey?: string; + presetBackendType: AIProviderType; + presetBaseUrl: string; + presetModel: string; + presetModels?: string[]; + apiFormat?: string; +} + +interface BuildEditProviderEditorSessionInput { + previousClearProviderSecret?: boolean; + provider: ProviderEditorConfig; + formValues?: Record; +} + +interface BuildClosedProviderEditorSessionInput { + previousClearProviderSecret?: boolean; +} + +export const buildAddProviderEditorSession = ({ + presetKey = 'openai', + presetBackendType, + presetBaseUrl, + presetModel, + presetModels = [], + apiFormat = 'openai', +}: BuildAddProviderEditorSessionInput): ProviderEditorSession => { + const editingProvider: ProviderEditorConfig = { + id: '', + type: presetBackendType, + name: '', + apiKey: '', + baseUrl: presetBaseUrl, + model: presetModel, + models: [...presetModels], + maxTokens: 4096, + temperature: 0.7, + presetKey, + }; + + return { + editingProvider, + formValues: { + ...editingProvider, + presetKey, + apiFormat, + }, + isEditing: true, + clearProviderSecret: false, + testStatus: 'idle', + }; +}; + +export const buildEditProviderEditorSession = ({ + provider, + formValues, +}: BuildEditProviderEditorSessionInput): ProviderEditorSession => ({ + editingProvider: provider, + formValues: formValues || { + ...provider, + models: provider.models || [], + presetKey: provider.presetKey, + apiFormat: provider.apiFormat || 'openai', + }, + isEditing: true, + clearProviderSecret: false, + testStatus: 'idle', +}); + +export const buildClosedProviderEditorSession = (_input?: BuildClosedProviderEditorSessionInput): ProviderEditorSession => ({ + editingProvider: null, + formValues: null, + isEditing: false, + clearProviderSecret: false, + testStatus: 'idle', +}); + diff --git a/frontend/src/utils/browserMockConnections.test.ts b/frontend/src/utils/browserMockConnections.test.ts new file mode 100644 index 0000000..10299c6 --- /dev/null +++ b/frontend/src/utils/browserMockConnections.test.ts @@ -0,0 +1,26 @@ +import { describe, expect, it } from 'vitest'; + +import { duplicateBrowserMockConnection } from './browserMockConnections'; + +describe('duplicateBrowserMockConnection', () => { + it('rewrites config.id to match the duplicated top-level id', () => { + const duplicated = duplicateBrowserMockConnection({ + existing: { + id: 'conn-1', + name: 'Primary', + config: { + id: 'conn-1', + type: 'postgres', + }, + includeDatabases: ['appdb'], + }, + items: [], + nextId: 'conn-2', + }); + + expect(duplicated.id).toBe('conn-2'); + expect(duplicated.config.id).toBe('conn-2'); + expect(duplicated.name).toBe('Primary - 副本'); + expect(duplicated.includeDatabases).toEqual(['appdb']); + }); +}); diff --git a/frontend/src/utils/browserMockConnections.ts b/frontend/src/utils/browserMockConnections.ts new file mode 100644 index 0000000..402cd6f --- /dev/null +++ b/frontend/src/utils/browserMockConnections.ts @@ -0,0 +1,47 @@ +export const cloneBrowserMockValue = (value: T): T => { + try { + return JSON.parse(JSON.stringify(value)); + } catch { + return value; + } +}; + +export const resolveBrowserMockSecretFlag = (nextValue: unknown, clearFlag: boolean, existingFlag?: boolean) => { + if (String(nextValue ?? '') !== '') return true; + if (clearFlag) return false; + return !!existingFlag; +}; + +export const buildBrowserMockDuplicateName = (rawName: string, items: any[]): string => { + const baseName = String(rawName || '').trim() || '连接'; + const suffix = ' - 副本'; + const usedNames = new Set(items.map((item) => String(item?.name || '').trim())); + let candidate = `${baseName}${suffix}`; + let counter = 2; + while (usedNames.has(candidate)) { + candidate = `${baseName}${suffix} ${counter}`; + counter += 1; + } + return candidate; +}; + +interface DuplicateBrowserMockConnectionInput { + existing: any; + items: any[]; + nextId: string; +} + +export const duplicateBrowserMockConnection = ({ existing, items, nextId }: DuplicateBrowserMockConnectionInput) => { + const duplicated = cloneBrowserMockValue({ + ...existing, + id: nextId, + name: buildBrowserMockDuplicateName(existing?.name, items), + config: { + ...cloneBrowserMockValue(existing?.config), + id: nextId, + }, + includeDatabases: Array.isArray(existing?.includeDatabases) ? [...existing.includeDatabases] : undefined, + includeRedisDatabases: Array.isArray(existing?.includeRedisDatabases) ? [...existing.includeRedisDatabases] : undefined, + }); + return duplicated; +}; diff --git a/frontend/src/utils/customConnectionDsn.test.ts b/frontend/src/utils/customConnectionDsn.test.ts new file mode 100644 index 0000000..8c35fb5 --- /dev/null +++ b/frontend/src/utils/customConnectionDsn.test.ts @@ -0,0 +1,37 @@ +import { describe, expect, it } from 'vitest'; + +import { shouldAllowBlankCustomDsn } from './customConnectionDsn'; + +describe('shouldAllowBlankCustomDsn', () => { + it('allows a blank DSN when editing a connection that already has a stored opaque DSN', () => { + expect(shouldAllowBlankCustomDsn({ + dsnInput: '', + hasStoredSecret: true, + clearStoredSecret: false, + })).toBe(true); + }); + + it('requires a new DSN when the user chooses to clear the stored opaque DSN', () => { + expect(shouldAllowBlankCustomDsn({ + dsnInput: '', + hasStoredSecret: true, + clearStoredSecret: true, + })).toBe(false); + }); + + it('requires a DSN for brand new custom connections', () => { + expect(shouldAllowBlankCustomDsn({ + dsnInput: '', + hasStoredSecret: false, + clearStoredSecret: false, + })).toBe(false); + }); + + it('accepts a newly entered DSN even when a stored secret already exists', () => { + expect(shouldAllowBlankCustomDsn({ + dsnInput: 'driver://demo', + hasStoredSecret: true, + clearStoredSecret: true, + })).toBe(true); + }); +}); diff --git a/frontend/src/utils/customConnectionDsn.ts b/frontend/src/utils/customConnectionDsn.ts new file mode 100644 index 0000000..58f92ed --- /dev/null +++ b/frontend/src/utils/customConnectionDsn.ts @@ -0,0 +1,27 @@ +export interface CustomConnectionDsnState { + dsnInput: unknown; + hasStoredSecret?: boolean; + clearStoredSecret?: boolean; +} + +export const getCustomConnectionDsnValidationMessage = ({ + dsnInput, + hasStoredSecret, + clearStoredSecret, +}: CustomConnectionDsnState): string | null => { + const dsnText = String(dsnInput ?? '').trim(); + if (dsnText !== '') { + return null; + } + if (hasStoredSecret && !clearStoredSecret) { + return null; + } + if (hasStoredSecret && clearStoredSecret) { + return '请输入新的连接字符串,或取消清除已保存 DSN'; + } + return '请输入连接字符串'; +}; + +export const shouldAllowBlankCustomDsn = (state: CustomConnectionDsnState): boolean => ( + getCustomConnectionDsnValidationMessage(state) === null +); From 37b3c78049f23f4c38146a70b64ff18cec168174 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:41:34 +0800 Subject: [PATCH 13/14] =?UTF-8?q?=E2=9C=A8=20feat(datagrid):=20=E5=A2=9E?= =?UTF-8?q?=E5=BC=BA=E6=95=B0=E6=8D=AE=E8=A1=A8=E6=98=BE=E7=A4=BA=E4=B8=8E?= =?UTF-8?q?=E8=A1=8C=E7=BA=A7SQL=E5=A4=8D=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 DataGrid 竖向分隔线与列宽模式配置并持久化\n- 支持复制 INSERT/UPDATE/DELETE 并按主键或唯一键生成条件\n- 补充外观配置与 SQL 复制相关测试 --- .gitignore | 3 +- frontend/src/App.tsx | 36 +- frontend/src/components/DataGrid.tsx | 283 +++++++++++++--- .../src/components/dataGridCopyInsert.test.ts | 103 +++++- frontend/src/components/dataGridCopyInsert.ts | 308 +++++++++++++++++- frontend/src/store.test.ts | 94 ++++++ frontend/src/store.ts | 35 +- frontend/src/utils/dataGridDisplay.test.ts | 32 ++ frontend/src/utils/dataGridDisplay.ts | 72 ++++ 9 files changed, 903 insertions(+), 63 deletions(-) create mode 100644 frontend/src/store.test.ts create mode 100644 frontend/src/utils/dataGridDisplay.test.ts create mode 100644 frontend/src/utils/dataGridDisplay.ts diff --git a/.gitignore b/.gitignore index 285e157..0d35fb4 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,5 @@ docs/需求追踪/ CLAUDE.md **/CLAUDE.md -.worktrees \ No newline at end of file +.worktrees +docs \ No newline at end of file diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index f6c1554..3c6825d 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,5 +1,5 @@ import React, { useState, useEffect, useMemo, useCallback } from 'react'; -import { Layout, Button, ConfigProvider, theme, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select, Tooltip } from 'antd'; +import { Layout, Button, ConfigProvider, theme, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select, Segmented, Tooltip } from 'antd'; import zhCN from 'antd/locale/zh_CN'; import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined, BgColorsOutlined, AppstoreOutlined, RobotOutlined } from '@ant-design/icons'; import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowGetPosition, WindowGetSize, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowSetPosition, WindowSetSize, WindowToggleMaximise, WindowUnfullscreen } from '../wailsjs/runtime'; @@ -11,9 +11,10 @@ import DriverManagerModal from './components/DriverManagerModal'; import LogPanel from './components/LogPanel'; import AIChatPanel from './components/AIChatPanel'; import AISettingsModal from './components/AISettingsModal'; -import { useStore } from './store'; +import { DEFAULT_APPEARANCE, useStore } from './store'; import { SavedConnection } from './types'; import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform, resolveAppearanceValues } from './utils/appearance'; +import { DATA_GRID_COLUMN_WIDTH_MODE_OPTIONS, sanitizeDataTableColumnWidthMode } from './utils/dataGridDisplay'; import { getMacNativeTitlebarPaddingLeft, getMacNativeTitlebarPaddingRight, shouldHandleMacNativeFullscreenShortcut, shouldSuppressMacNativeEscapeExit } from './utils/macWindow'; import { buildOverlayWorkbenchTheme } from './utils/overlayWorkbenchTheme'; import { getConnectionWorkbenchState } from './utils/startupReadiness'; @@ -2295,6 +2296,33 @@ function App() {
+
+
数据表显示
+
+
+
+
显示数据表竖向分隔线
+
仅作用于数据表页面 DataGrid,不影响其他表格组件。
+
+ setAppearance({ showDataTableVerticalBorders: checked })} + /> +
+
+
数据表列宽模式
+ setAppearance({ dataTableColumnWidthMode: sanitizeDataTableColumnWidthMode(value) })} + /> +
+ 标准模式默认列宽 200px;紧凑模式默认列宽 140px。已手动拖拽调整的列宽优先保留。 +
+
+
+
{isMacRuntime ? (
macOS 窗口控制
@@ -2328,7 +2356,7 @@ function App() { onClick={() => { setUiScale(DEFAULT_UI_SCALE); setFontSize(DEFAULT_FONT_SIZE); - setAppearance({ enabled: true, opacity: 1.0, blur: 0, useNativeMacWindowControls: false }); + setAppearance({ ...DEFAULT_APPEARANCE }); }} > 恢复默认 @@ -2591,5 +2619,3 @@ function App() { } export default App; - - diff --git a/frontend/src/components/DataGrid.tsx b/frontend/src/components/DataGrid.tsx index d6a31fe..63c8528 100644 --- a/frontend/src/components/DataGrid.tsx +++ b/frontend/src/components/DataGrid.tsx @@ -23,20 +23,31 @@ import { arrayMove } from '@dnd-kit/sortable'; import { CSS } from '@dnd-kit/utilities'; -import { ImportData, ExportTable, ExportData, ExportQuery, ApplyChanges, DBGetColumns } from '../../wailsjs/go/app/App'; +import { ImportData, ExportTable, ExportData, ExportQuery, ApplyChanges, DBGetColumns, DBGetIndexes } from '../../wailsjs/go/app/App'; import ImportPreviewModal from './ImportPreviewModal'; import { useStore } from '../store'; -import type { ColumnDefinition } from '../types'; +import type { ColumnDefinition, IndexDefinition } from '../types'; import { v4 as generateUuid } from 'uuid'; import 'react-resizable/css/styles.css'; import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, escapeLiteral, hasExplicitSort, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql'; import { isMacLikePlatform, normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; +import { + resolveDataTableColumnWidth, + resolveDataTableDefaultColumnWidth, + resolveDataTableVerticalBorderColor, +} from '../utils/dataGridDisplay'; import { resolvePaginationPageText, resolvePaginationSummaryText, resolvePaginationTotalForControl } from '../utils/dataGridPagination'; import { resolveGridSortInfoFromTableSorter } from '../utils/dataGridSort'; import { calculateTableBodyBottomPadding, calculateVirtualTableScrollX } from './dataGridLayout'; -import { buildCopyInsertSQL, normalizeTemporalLiteralText } from './dataGridCopyInsert'; +import { + buildCopyDeleteSQL, + buildCopyInsertSQL, + buildCopyUpdateSQL, + normalizeTemporalLiteralText, + resolveUniqueKeyGroupsFromIndexes, +} from './dataGridCopyInsert'; // --- Error Boundary --- interface DataGridErrorBoundaryState { @@ -533,6 +544,8 @@ const DataContext = React.createContext<{ selectedRowKeysRef: React.MutableRefObject; displayDataRef: React.MutableRefObject; handleCopyInsert: (r: any) => void; + handleCopyUpdate: (r: any) => void; + handleCopyDelete: (r: any) => void; handleCopyJson: (r: any) => void; handleCopyCsv: (r: any) => void; handleExportSelected: (format: string, r: any) => Promise; @@ -785,7 +798,19 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => { if (!record || !context) return {children}; - const { selectedRowKeysRef, displayDataRef, handleCopyInsert, handleCopyJson, handleCopyCsv, handleExportSelected, copyToClipboard, enableRowContextMenu, supportsCopyInsert } = context; + const { + selectedRowKeysRef, + displayDataRef, + handleCopyInsert, + handleCopyUpdate, + handleCopyDelete, + handleCopyJson, + handleCopyCsv, + handleExportSelected, + copyToClipboard, + enableRowContextMenu, + supportsCopyInsert, + } = context; if (!enableRowContextMenu) { return {children}; @@ -806,6 +831,16 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => { label: '复制为 INSERT', icon: , onClick: () => handleCopyInsert(record), + }, { + key: 'update', + label: '复制为 UPDATE', + icon: , + onClick: () => handleCopyUpdate(record), + }, { + key: 'delete', + label: '复制为 DELETE', + icon: , + onClick: () => handleCopyDelete(record), }] : []), { key: 'json', label: '复制为 JSON', icon: , onClick: () => handleCopyJson(record) }, { key: 'csv', label: '复制为 CSV', icon: , onClick: () => handleCopyCsv(record) }, @@ -931,6 +966,13 @@ const DataGrid: React.FC = ({ const darkMode = theme === 'dark'; const resolvedAppearance = resolveAppearanceValues(appearance); const opacity = normalizeOpacityForPlatform(resolvedAppearance.opacity); + const showDataTableVerticalBorders = appearance.showDataTableVerticalBorders === true; + const dataTableColumnWidthMode = appearance.dataTableColumnWidthMode; + const defaultColumnWidth = resolveDataTableDefaultColumnWidth(dataTableColumnWidthMode); + const dataTableVerticalBorderColor = resolveDataTableVerticalBorderColor({ + darkMode, + visible: showDataTableVerticalBorders, + }); const canModifyData = !readOnly && !!tableName; const showColumnComment = queryOptions?.showColumnComment ?? true; const showColumnType = queryOptions?.showColumnType ?? true; @@ -1312,8 +1354,11 @@ const DataGrid: React.FC = ({ const [sortInfo, setSortInfo] = useState>([]); const [columnWidths, setColumnWidths] = useState>({}); const [columnMetaMap, setColumnMetaMap] = useState>({}); + const [uniqueKeyGroups, setUniqueKeyGroups] = useState([]); const columnMetaCacheRef = useRef>>({}); const columnMetaSeqRef = useRef(0); + const uniqueKeyGroupsCacheRef = useRef>({}); + const uniqueKeyGroupsSeqRef = useRef(0); useEffect(() => { const ext = sortInfoExternal || []; @@ -1328,10 +1373,12 @@ const DataGrid: React.FC = ({ const normalizedDbName = String(dbName || '').trim(); if (!connectionId || !normalizedTableName) { setColumnMetaMap({}); + setUniqueKeyGroups([]); return; } const cacheKey = `${connectionId}|${normalizedDbName}|${normalizedTableName}`; setColumnMetaMap(columnMetaCacheRef.current[cacheKey] || {}); + setUniqueKeyGroups(uniqueKeyGroupsCacheRef.current[cacheKey] || []); }, [connectionId, dbName, tableName]); useEffect(() => { @@ -1382,6 +1429,47 @@ const DataGrid: React.FC = ({ }); }, [connections, connectionId, dbName, tableName]); + useEffect(() => { + const normalizedTableName = String(tableName || '').trim(); + const normalizedDbName = String(dbName || '').trim(); + if (!connectionId || !normalizedTableName) return; + + const cacheKey = `${connectionId}|${normalizedDbName}|${normalizedTableName}`; + if (uniqueKeyGroupsCacheRef.current[cacheKey]) return; + + const conn = connections.find(c => c.id === connectionId); + if (!conn) { + setUniqueKeyGroups([]); + return; + } + + const config = { + ...conn.config, + port: Number(conn.config.port), + password: conn.config.password || "", + database: conn.config.database || "", + useSSH: conn.config.useSSH || false, + ssh: conn.config.ssh || { host: "", port: 22, user: "", password: "", keyPath: "" } + }; + + const seq = ++uniqueKeyGroupsSeqRef.current; + DBGetIndexes(config as any, normalizedDbName, normalizedTableName) + .then((res) => { + if (seq !== uniqueKeyGroupsSeqRef.current) return; + if (!res.success || !Array.isArray(res.data)) { + setUniqueKeyGroups([]); + return; + } + const nextGroups = resolveUniqueKeyGroupsFromIndexes(res.data as IndexDefinition[]); + uniqueKeyGroupsCacheRef.current[cacheKey] = nextGroups; + setUniqueKeyGroups(nextGroups); + }) + .catch(() => { + if (seq !== uniqueKeyGroupsSeqRef.current) return; + setUniqueKeyGroups([]); + }); + }, [connections, connectionId, dbName, tableName]); + const columnMetaMapByLowerName = useMemo(() => { const next: Record = {}; Object.entries(columnMetaMap).forEach(([name, meta]) => { @@ -1402,6 +1490,17 @@ const DataGrid: React.FC = ({ return next; }, [columnMetaMapByLowerName]); + const allTableColumnNames = useMemo(() => { + const metaColumns = Object.keys(columnMetaMap); + if (metaColumns.length > 0) { + return metaColumns; + } + if (exportScope === 'table') { + return columnNames.filter((columnName) => columnName !== GONAVI_ROW_KEY); + } + return []; + }, [columnMetaMap, exportScope, columnNames]); + const normalizeCommitCellValue = useCallback( (columnName: string, value: any, mode: 'insert' | 'update') => { if (value === undefined) return undefined; @@ -1572,8 +1671,15 @@ const DataGrid: React.FC = ({ overflow: hidden !important; } .${gridId} .ant-table-tbody > tr > td, - .${gridId} .ant-table-tbody .ant-table-row > .ant-table-cell { background: transparent !important; border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important; border-inline-end: 1px solid transparent !important; } - .${gridId} .ant-table-thead > tr > th { background: transparent !important; border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important; border-inline-end: 1px solid transparent !important; } + .${gridId} .ant-table-tbody .ant-table-row > .ant-table-cell, + .${gridId} .ant-table-tbody-virtual-holder .ant-table-row > .ant-table-cell { background: transparent !important; border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important; border-inline-end: 1px solid ${dataTableVerticalBorderColor} !important; } + .${gridId} .ant-table-thead > tr > th { background: transparent !important; border-bottom: 1px solid ${darkMode ? 'rgba(255,255,255,0.05)' : 'rgba(0,0,0,0.05)'} !important; border-inline-end: 1px solid ${dataTableVerticalBorderColor} !important; } + .${gridId} .ant-table-tbody > tr > td:last-child, + .${gridId} .ant-table-tbody .ant-table-row > .ant-table-cell:last-child, + .${gridId} .ant-table-tbody-virtual-holder .ant-table-row > .ant-table-cell:last-child, + .${gridId} .ant-table-thead > tr > th:last-child { + border-inline-end-color: transparent !important; + } /* 选择列对齐:header TH 无 class(Ant Design 虚拟模式),需用 :first-child 匹配 */ .${gridId} .ant-table-header th:first-child, .${gridId} .ant-table-thead > tr > th:first-child { @@ -2010,7 +2116,7 @@ const DataGrid: React.FC = ({ justify-content: center; line-height: 1; } - `, [themeStyles, gridId, tableBodyBottomPadding, darkMode, opacity]); + `, [themeStyles, gridId, tableBodyBottomPadding, darkMode, opacity, dataTableVerticalBorderColor]); const recalculateTableMetrics = useCallback((targetElement?: HTMLElement | null) => { const target = targetElement || containerRef.current; @@ -2805,7 +2911,10 @@ const DataGrid: React.FC = ({ const startX = e.clientX; - const currentWidth = columnWidths[key] || 200; + const currentWidth = resolveDataTableColumnWidth({ + manualWidth: columnWidths[key], + widthMode: dataTableColumnWidthMode, + }); const containerLeft = containerRef.current?.getBoundingClientRect().left ?? 0; @@ -2836,7 +2945,7 @@ const DataGrid: React.FC = ({ document.body.style.userSelect = 'none'; - }, [columnWidths]); + }, [columnWidths, dataTableColumnWidthMode]); // 2. Drag Move (Global) const handleResizeMove = useCallback((e: MouseEvent) => { @@ -3280,7 +3389,10 @@ const DataGrid: React.FC = ({ dataIndex: key, key: key, // 不使用 ellipsis,避免 Ant Design 的 Tooltip 展开行为 - width: columnWidths[key] || 200, + width: resolveDataTableColumnWidth({ + manualWidth: columnWidths[key], + widthMode: dataTableColumnWidthMode, + }), sorter: onSort ? { multiple: displayColumnNames.indexOf(key) + 1 } : false, sortOrder: (sortInfo.find(s => s.columnKey === key && s.enabled !== false)?.order || null) as SortOrder | undefined, editable: canModifyData, // Only editable if table name known and not readonly @@ -3321,7 +3433,7 @@ const DataGrid: React.FC = ({ }, }), })); - }, [displayColumnNames, columnWidths, sortInfo, handleResizeStart, canModifyData, onSort, renderColumnTitle]); + }, [displayColumnNames, columnWidths, sortInfo, handleResizeStart, canModifyData, onSort, renderColumnTitle, dataTableColumnWidthMode]); const mergedColumns = useMemo(() => columns.map((col): ColumnType => { const dataIndex = String(col.dataIndex); @@ -3554,24 +3666,87 @@ const DataGrid: React.FC = ({ return [clickedRecord]; }, []); - const handleCopyInsert = useCallback((record: any) => { + const buildCopySqlBatchText = useCallback((mode: 'insert' | 'update' | 'delete', record: any): string | null => { if (!supportsCopyInsert) { - void message.warning("当前数据源不支持复制为 INSERT,请使用 JSON/CSV/Markdown 复制。"); - return; + void message.warning("当前数据源不支持复制 SQL,请使用 JSON/CSV/Markdown 复制。"); + return null; } const records = getTargets(record); - // 使用 columnNames 保持表定义的字段顺序,而非 Object.keys() 的不确定顺序 const orderedCols = columnNames.filter(c => c !== GONAVI_ROW_KEY); - const sqlList = records.map((r: any) => { - return buildCopyInsertSQL({ + if (mode === 'insert') { + return records.map((row: any) => buildCopyInsertSQL({ dbType, tableName, orderedCols, - record: r, + record: row, columnTypesByLowerName: columnTypeMapByLowerName, - }); + })).join('\n\n'); + } + + const sqlResults = records.map((row: any) => ( + mode === 'update' + ? buildCopyUpdateSQL({ + dbType, + tableName, + orderedCols, + record: row, + pkColumns, + uniqueKeyGroups, + allTableColumns: allTableColumnNames, + columnTypesByLowerName: columnTypeMapByLowerName, + }) + : buildCopyDeleteSQL({ + dbType, + tableName, + orderedCols, + record: row, + pkColumns, + uniqueKeyGroups, + allTableColumns: allTableColumnNames, + columnTypesByLowerName: columnTypeMapByLowerName, + }) + )); + const failedResult = sqlResults.find((result) => result.ok === false); + if (failedResult && failedResult.ok === false) { + void message.warning(failedResult.error); + return null; + } + const sqlTexts: string[] = []; + sqlResults.forEach((result) => { + if (result.ok) { + sqlTexts.push(result.sql); + } }); - copyToClipboard(sqlList.join('\n')); }, [supportsCopyInsert, columnNames, getTargets, copyToClipboard, dbType, tableName, columnTypeMapByLowerName]); + return sqlTexts.join('\n\n'); + }, [ + supportsCopyInsert, + getTargets, + columnNames, + dbType, + tableName, + columnTypeMapByLowerName, + pkColumns, + uniqueKeyGroups, + allTableColumnNames, + ]); + + const handleCopyInsert = useCallback((record: any) => { + const batchText = buildCopySqlBatchText('insert', record); + if (!batchText) return; + copyToClipboard(batchText); + }, [buildCopySqlBatchText, copyToClipboard]); + + const handleCopyUpdate = useCallback((record: any) => { + const batchText = buildCopySqlBatchText('update', record); + if (!batchText) return; + copyToClipboard(batchText); + }, [buildCopySqlBatchText, copyToClipboard]); + + const handleCopyDelete = useCallback((record: any) => { + const batchText = buildCopySqlBatchText('delete', record); + if (!batchText) return; + copyToClipboard(batchText); + }, [buildCopySqlBatchText, copyToClipboard]); const handleCopyJson = useCallback((record: any) => { const records = getTargets(record); @@ -4022,6 +4197,8 @@ const DataGrid: React.FC = ({ selectedRowKeysRef, displayDataRef, handleCopyInsert, + handleCopyUpdate, + handleCopyDelete, handleCopyJson, handleCopyCsv, handleExportSelected, @@ -4029,7 +4206,7 @@ const DataGrid: React.FC = ({ tableName, enableRowContextMenu: false, supportsCopyInsert, - }), [handleCopyCsv, handleCopyInsert, handleCopyJson, handleExportSelected, copyToClipboard, tableName, canModifyData, supportsCopyInsert]); + }), [handleCopyCsv, handleCopyDelete, handleCopyInsert, handleCopyJson, handleCopyUpdate, handleExportSelected, copyToClipboard, tableName, supportsCopyInsert]); const cellContextMenuValue = useMemo(() => ({ showMenu: showCellContextMenu, @@ -4044,7 +4221,7 @@ const DataGrid: React.FC = ({ const rowPropsFactory = useCallback((record: any) => ({ record } as any), []); - const totalWidth = columns.reduce((sum: number, col: any) => sum + (Number(col.width) || 200), 0) + selectionColumnWidth; + const totalWidth = columns.reduce((sum: number, col: any) => sum + (Number(col.width) || defaultColumnWidth), 0) + selectionColumnWidth; const useContextMenuRow = false; const tableScrollX = useMemo(() => { // rc-table 在 scroll.x 小于容器宽度时会把实际列宽按视口补齐。 @@ -5446,21 +5623,53 @@ const DataGrid: React.FC = ({ )} {supportsCopyInsert && ( -
e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} - onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} - onClick={() => { - if (cellContextMenu.record) handleCopyInsert(cellContextMenu.record); - setCellContextMenu(prev => ({ ...prev, visible: false })); - }} - > - 复制为 INSERT -
+ <> +
e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} + onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} + onClick={() => { + if (cellContextMenu.record) handleCopyInsert(cellContextMenu.record); + setCellContextMenu(prev => ({ ...prev, visible: false })); + }} + > + 复制为 INSERT +
+
e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} + onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} + onClick={() => { + if (cellContextMenu.record) handleCopyUpdate(cellContextMenu.record); + setCellContextMenu(prev => ({ ...prev, visible: false })); + }} + > + 复制为 UPDATE +
+
e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'} + onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'} + onClick={() => { + if (cellContextMenu.record) handleCopyDelete(cellContextMenu.record); + setCellContextMenu(prev => ({ ...prev, visible: false })); + }} + > + 复制为 DELETE +
+ )}
{ it('normalizes PostgreSQL timestamp values for copy-as-insert and uses PostgreSQL identifier quoting', () => { @@ -58,4 +63,100 @@ describe('buildCopyInsertSQL', () => { `INSERT INTO public.audit_log (payload) VALUES ('2026-01-21T18:32:26+08:00');`, ); }); + + it('groups composite unique indexes by name and sequence order', () => { + expect(resolveUniqueKeyGroupsFromIndexes([ + { name: 'PRIMARY', columnName: 'id', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' }, + { name: 'uk_order_code', columnName: 'code', nonUnique: 0, seqInIndex: 2, indexType: 'BTREE' }, + { name: 'uk_order_code', columnName: 'tenant_id', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' }, + { name: 'idx_note', columnName: 'note', nonUnique: 1, seqInIndex: 1, indexType: 'BTREE' }, + ])).toEqual([ + ['id'], + ['tenant_id', 'code'], + ]); + }); + + it('builds UPDATE SQL with a primary-key WHERE clause and keeps literal formatting aligned with INSERT', () => { + const result = buildCopyUpdateSQL({ + dbType: 'mysql', + tableName: 'orders', + orderedCols: ['id', 'note', 'deleted_at'], + record: { + id: 7, + note: "O'Brien", + deleted_at: null, + }, + pkColumns: ['id'], + columnTypesByLowerName: { + deleted_at: 'datetime', + }, + allTableColumns: ['id', 'note', 'deleted_at'], + }); + + expect(result).toEqual({ + ok: true, + whereStrategy: 'primary-key', + sql: `UPDATE \`orders\` SET \`id\` = '7', \`note\` = 'O''Brien', \`deleted_at\` = NULL WHERE (\`id\` = '7');`, + }); + }); + + it('builds DELETE SQL with a composite unique-key WHERE clause when no primary key is available', () => { + const result = buildCopyDeleteSQL({ + dbType: 'postgres', + tableName: 'public.audit_log', + orderedCols: ['tenant_id', 'code', 'payload'], + record: { + tenant_id: 'acme', + code: 'evt-7', + payload: '{"ok":true}', + }, + uniqueKeyGroups: [['tenant_id', 'code']], + allTableColumns: ['tenant_id', 'code', 'payload'], + }); + + expect(result).toEqual({ + ok: true, + whereStrategy: 'unique-key', + sql: `DELETE FROM public.audit_log WHERE (tenant_id = 'acme' AND code = 'evt-7');`, + }); + }); + + it('falls back to all-column matching and uses IS NULL for null values', () => { + const result = buildCopyDeleteSQL({ + dbType: 'sqlserver', + tableName: 'dbo.OrderLog', + orderedCols: ['id', 'deleted_at', 'flag'], + allTableColumns: ['id', 'deleted_at', 'flag'], + record: { + id: 5, + deleted_at: null, + flag: true, + }, + }); + + expect(result).toEqual({ + ok: true, + whereStrategy: 'all-columns', + sql: `DELETE FROM [dbo].[OrderLog] WHERE ([id] = '5' AND [deleted_at] IS NULL AND [flag] = 'true');`, + }); + }); + + it('refuses to build UPDATE/DELETE SQL when the result set lacks keys and does not cover all table columns', () => { + const result = buildCopyDeleteSQL({ + dbType: 'mysql', + tableName: 'orders', + orderedCols: ['note'], + allTableColumns: ['id', 'note', 'created_at'], + record: { + note: 'partial row', + }, + }); + + expect(result.ok).toBe(false); + if (result.ok) { + throw new Error('expected buildCopyDeleteSQL to fail'); + } + expect(result.error).toContain('主键'); + expect(result.error).toContain('全部字段'); + }); }); diff --git a/frontend/src/components/dataGridCopyInsert.ts b/frontend/src/components/dataGridCopyInsert.ts index 3034584..8b8c039 100644 --- a/frontend/src/components/dataGridCopyInsert.ts +++ b/frontend/src/components/dataGridCopyInsert.ts @@ -1,3 +1,4 @@ +import type { IndexDefinition } from '../types'; import { escapeLiteral, quoteIdentPart, quoteQualifiedIdent } from '../utils/sql'; type BuildCopyInsertSQLParams = { @@ -8,6 +9,22 @@ type BuildCopyInsertSQLParams = { columnTypesByLowerName?: Record; }; +type BuildCopyMutationSQLParams = BuildCopyInsertSQLParams & { + pkColumns?: string[]; + uniqueKeyGroups?: string[][]; + allTableColumns?: string[]; +}; + +type CopySqlWhereStrategy = 'primary-key' | 'unique-key' | 'all-columns'; + +export type CopyMutationSQLResult = + | { ok: true; sql: string; whereStrategy: CopySqlWhereStrategy } + | { ok: false; error: string }; + +type CopyMutationWhereClauseResult = + | { ok: true; clause: string; whereStrategy: CopySqlWhereStrategy } + | { ok: false; error: string }; + const looksLikeDateTimeText = (val: string): boolean => { if (!val) return false; const len = val.length; @@ -104,6 +121,157 @@ export const formatLocalDateTimeLiteral = (value: Date): string => { return `${year}-${month}-${day} ${hour}:${minute}:${second}`; }; +const getColumnType = (columnTypesByLowerName: Record, columnName: string): string | undefined => ( + columnTypesByLowerName[String(columnName || '').toLowerCase()] +); + +const getRecordValue = ( + record: Record, + columnName: string, +): { exists: boolean; value: any } => { + if (Object.prototype.hasOwnProperty.call(record || {}, columnName)) { + return { exists: true, value: record?.[columnName] }; + } + const loweredColumnName = String(columnName || '').toLowerCase(); + const matchedKey = Object.keys(record || {}).find((key) => key.toLowerCase() === loweredColumnName); + if (!matchedKey) { + return { exists: false, value: undefined }; + } + return { exists: true, value: record?.[matchedKey] }; +}; + +const normalizeColumnList = (columns: string[] | undefined): string[] => { + const seen = new Set(); + const result: string[] = []; + (columns || []).forEach((column) => { + const normalized = String(column || '').trim(); + if (!normalized) return; + const lowered = normalized.toLowerCase(); + if (seen.has(lowered)) return; + seen.add(lowered); + result.push(normalized); + }); + return result; +}; + +const toNormalizedLiteralText = (value: any, columnType?: string): string => { + if (typeof value === 'string') { + return normalizeTemporalLiteralText(value, columnType, true); + } + if (value instanceof Date) { + return formatLocalDateTimeLiteral(value); + } + return String(value); +}; + +const formatCopySqlLiteral = (value: any, columnType?: string): string => { + if (value === null || value === undefined) { + return 'NULL'; + } + return `'${escapeLiteral(toNormalizedLiteralText(value, columnType))}'`; +}; + +const doesResultCoverAllTableColumns = (orderedCols: string[], allTableColumns: string[]): boolean => { + const normalizedOrderedCols = normalizeColumnList(orderedCols); + const normalizedAllTableColumns = normalizeColumnList(allTableColumns); + if (normalizedOrderedCols.length === 0 || normalizedOrderedCols.length !== normalizedAllTableColumns.length) { + return false; + } + const orderedSet = new Set(normalizedOrderedCols.map((column) => column.toLowerCase())); + return normalizedAllTableColumns.every((column) => orderedSet.has(column.toLowerCase())); +}; + +const buildWhereClauseForColumns = ({ + dbType, + columns, + record, + columnTypesByLowerName, + requireNonNullValues, +}: { + dbType: string; + columns: string[]; + record: Record; + columnTypesByLowerName: Record; + requireNonNullValues: boolean; +}): string | null => { + const predicates: string[] = []; + for (const columnName of columns) { + const { exists, value } = getRecordValue(record, columnName); + if (!exists) { + return null; + } + const quotedColumn = quoteIdentPart(dbType, columnName); + if (value === null || value === undefined) { + if (requireNonNullValues) { + return null; + } + predicates.push(`${quotedColumn} IS NULL`); + continue; + } + predicates.push(`${quotedColumn} = ${formatCopySqlLiteral(value, getColumnType(columnTypesByLowerName, columnName))}`); + } + if (predicates.length === 0) { + return null; + } + return `(${predicates.join(' AND ')})`; +}; + +const resolveMutationWhereClause = ({ + dbType, + orderedCols, + record, + pkColumns = [], + uniqueKeyGroups = [], + allTableColumns = [], + columnTypesByLowerName = {}, +}: BuildCopyMutationSQLParams): CopyMutationWhereClauseResult => { + const normalizedPkColumns = normalizeColumnList(pkColumns); + const pkWhereClause = buildWhereClauseForColumns({ + dbType, + columns: normalizedPkColumns, + record, + columnTypesByLowerName, + requireNonNullValues: true, + }); + if (pkWhereClause) { + return { ok: true, clause: pkWhereClause, whereStrategy: 'primary-key' }; + } + + const normalizedUniqueKeyGroups = (uniqueKeyGroups || []) + .map((group) => normalizeColumnList(group)) + .filter((group) => group.length > 0); + for (const group of normalizedUniqueKeyGroups) { + const uniqueWhereClause = buildWhereClauseForColumns({ + dbType, + columns: group, + record, + columnTypesByLowerName, + requireNonNullValues: true, + }); + if (uniqueWhereClause) { + return { ok: true, clause: uniqueWhereClause, whereStrategy: 'unique-key' }; + } + } + + if (doesResultCoverAllTableColumns(orderedCols, allTableColumns)) { + const fullRowWhereClause = buildWhereClauseForColumns({ + dbType, + columns: orderedCols, + record, + columnTypesByLowerName, + requireNonNullValues: false, + }); + if (fullRowWhereClause) { + return { ok: true, clause: fullRowWhereClause, whereStrategy: 'all-columns' }; + } + } + + return { + ok: false, + error: '当前结果集缺少可安全定位行数据的主键/唯一键,且未覆盖表的全部字段,无法生成 WHERE 条件。', + }; +}; + export const buildCopyInsertSQL = ({ dbType, tableName, @@ -114,18 +282,136 @@ export const buildCopyInsertSQL = ({ const targetTable = quoteQualifiedIdent(dbType, tableName || 'table'); const quotedCols = orderedCols.map((col) => quoteIdentPart(dbType, col)); const values = orderedCols.map((col) => { - const value = record?.[col]; - if (value === null || value === undefined) return 'NULL'; - - const columnType = columnTypesByLowerName[String(col || '').toLowerCase()]; - const raw = - typeof value === 'string' - ? normalizeTemporalLiteralText(value, columnType, true) - : value instanceof Date - ? formatLocalDateTimeLiteral(value) - : String(value); - return `'${escapeLiteral(raw)}'`; + const { value } = getRecordValue(record, col); + return formatCopySqlLiteral(value, getColumnType(columnTypesByLowerName, col)); }); return `INSERT INTO ${targetTable} (${quotedCols.join(', ')}) VALUES (${values.join(', ')});`; }; + +const buildCopyMutationSQL = ( + mode: 'update' | 'delete', + { + dbType, + tableName, + orderedCols, + record, + pkColumns = [], + uniqueKeyGroups = [], + allTableColumns = [], + columnTypesByLowerName = {}, + }: BuildCopyMutationSQLParams, +): CopyMutationSQLResult => { + const normalizedTableName = String(tableName || '').trim(); + const normalizedOrderedCols = normalizeColumnList(orderedCols); + if (!normalizedTableName) { + return { + ok: false, + error: `当前结果集未关联明确表名,无法生成 ${mode.toUpperCase()} SQL。`, + }; + } + if (normalizedOrderedCols.length === 0) { + return { + ok: false, + error: '当前结果集没有可复制的字段,无法生成 SQL。', + }; + } + + const whereClause = resolveMutationWhereClause({ + dbType, + orderedCols: normalizedOrderedCols, + record, + pkColumns, + uniqueKeyGroups, + allTableColumns, + columnTypesByLowerName, + }); + if (whereClause.ok === false) { + return { ok: false, error: whereClause.error }; + } + + const targetTable = quoteQualifiedIdent(dbType, normalizedTableName); + if (mode === 'delete') { + return { + ok: true, + sql: `DELETE FROM ${targetTable} WHERE ${whereClause.clause};`, + whereStrategy: whereClause.whereStrategy, + }; + } + + const assignments = normalizedOrderedCols.map((columnName) => { + const { value } = getRecordValue(record, columnName); + return `${quoteIdentPart(dbType, columnName)} = ${formatCopySqlLiteral(value, getColumnType(columnTypesByLowerName, columnName))}`; + }); + + return { + ok: true, + sql: `UPDATE ${targetTable} SET ${assignments.join(', ')} WHERE ${whereClause.clause};`, + whereStrategy: whereClause.whereStrategy, + }; +}; + +export const buildCopyUpdateSQL = (params: BuildCopyMutationSQLParams): CopyMutationSQLResult => ( + buildCopyMutationSQL('update', params) +); + +export const buildCopyDeleteSQL = (params: BuildCopyMutationSQLParams): CopyMutationSQLResult => ( + buildCopyMutationSQL('delete', params) +); + +export const resolveUniqueKeyGroupsFromIndexes = (indexes: IndexDefinition[] | undefined): string[][] => { + type IndexBucket = { + order: number; + columns: Array<{ columnName: string; seqInIndex: number; order: number }>; + }; + + const buckets = new Map(); + (indexes || []).forEach((index, order) => { + if (index?.nonUnique !== 0) { + return; + } + const name = String(index?.name || '').trim(); + const columnName = String(index?.columnName || '').trim(); + if (!name || !columnName) { + return; + } + if (!buckets.has(name)) { + buckets.set(name, { order, columns: [] }); + } + const bucket = buckets.get(name); + if (!bucket) { + return; + } + bucket.columns.push({ + columnName, + seqInIndex: Number.isFinite(Number(index?.seqInIndex)) ? Number(index.seqInIndex) : 0, + order, + }); + }); + + return Array.from(buckets.values()) + .sort((left, right) => left.order - right.order) + .map((bucket) => { + const seen = new Set(); + return bucket.columns + .slice() + .sort((left, right) => { + const leftSeq = left.seqInIndex > 0 ? left.seqInIndex : Number.MAX_SAFE_INTEGER; + const rightSeq = right.seqInIndex > 0 ? right.seqInIndex : Number.MAX_SAFE_INTEGER; + if (leftSeq !== rightSeq) { + return leftSeq - rightSeq; + } + return left.order - right.order; + }) + .map((item) => item.columnName) + .filter((columnName) => { + const lowered = columnName.toLowerCase(); + if (seen.has(lowered)) { + return false; + } + seen.add(lowered); + return true; + }); + }) + .filter((group) => group.length > 0); +}; diff --git a/frontend/src/store.test.ts b/frontend/src/store.test.ts new file mode 100644 index 0000000..633130a --- /dev/null +++ b/frontend/src/store.test.ts @@ -0,0 +1,94 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +class MemoryStorage implements Storage { + private data = new Map(); + + get length(): number { + return this.data.size; + } + + clear(): void { + this.data.clear(); + } + + getItem(key: string): string | null { + return this.data.has(key) ? this.data.get(key)! : null; + } + + key(index: number): string | null { + return Array.from(this.data.keys())[index] ?? null; + } + + removeItem(key: string): void { + this.data.delete(key); + } + + setItem(key: string, value: string): void { + this.data.set(key, String(value)); + } +} + +const importStore = async () => { + const store = await import('./store'); + await store.useStore.persist.rehydrate(); + return store; +}; + +describe('store appearance persistence', () => { + let storage: MemoryStorage; + + beforeEach(() => { + storage = new MemoryStorage(); + vi.stubGlobal('localStorage', storage); + vi.resetModules(); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + vi.resetModules(); + }); + + it('fills missing DataGrid appearance settings with defaults during hydration', async () => { + storage.setItem('lite-db-storage', JSON.stringify({ + state: { + appearance: { + enabled: false, + opacity: 0.75, + blur: 6, + useNativeMacWindowControls: true, + }, + }, + version: 7, + })); + + const { useStore } = await importStore(); + const appearance = useStore.getState().appearance; + + expect(appearance.enabled).toBe(false); + expect(appearance.opacity).toBe(0.75); + expect(appearance.blur).toBe(6); + expect(appearance.useNativeMacWindowControls).toBe(true); + expect(appearance.showDataTableVerticalBorders).toBe(false); + expect(appearance.dataTableColumnWidthMode).toBe('standard'); + }); + + it('persists DataGrid appearance settings and restores them after reload', async () => { + const { useStore } = await importStore(); + + useStore.getState().setAppearance({ + showDataTableVerticalBorders: true, + dataTableColumnWidthMode: 'compact', + }); + + const persisted = JSON.parse(storage.getItem('lite-db-storage') || '{}'); + expect(persisted.state.appearance.showDataTableVerticalBorders).toBe(true); + expect(persisted.state.appearance.dataTableColumnWidthMode).toBe('compact'); + + vi.resetModules(); + const reloaded = await importStore(); + const appearance = reloaded.useStore.getState().appearance; + + expect(appearance.showDataTableVerticalBorders).toBe(true); + expect(appearance.dataTableColumnWidthMode).toBe('compact'); + }); +}); diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 32ca88b..23eff8f 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -10,8 +10,26 @@ import { sanitizeShortcutOptions, } from './utils/shortcuts'; import { toPersistedGlobalProxy } from './utils/globalProxyDraft'; +import { + DEFAULT_DATA_GRID_DISPLAY_SETTINGS, + sanitizeDataGridDisplaySettings, + type DataGridDisplaySettings, +} from './utils/dataGridDisplay'; -const DEFAULT_APPEARANCE = { enabled: true, opacity: 1.0, blur: 0, useNativeMacWindowControls: false }; +export interface AppearanceSettings extends DataGridDisplaySettings { + enabled: boolean; + opacity: number; + blur: number; + useNativeMacWindowControls: boolean; +} + +export const DEFAULT_APPEARANCE: AppearanceSettings = { + enabled: true, + opacity: 1.0, + blur: 0, + useNativeMacWindowControls: false, + ...DEFAULT_DATA_GRID_DISPLAY_SETTINGS, +}; const DEFAULT_UI_SCALE = 1.0; const MIN_UI_SCALE = 0.8; const MAX_UI_SCALE = 1.25; @@ -26,7 +44,7 @@ const MAX_HOST_ENTRY_LENGTH = 512; const MAX_HOST_ENTRIES = 64; const DEFAULT_TIMEOUT_SECONDS = 30; const MAX_TIMEOUT_SECONDS = 3600; -const PERSIST_VERSION = 7; +const PERSIST_VERSION = 8; const DEFAULT_CONNECTION_TYPE = 'mysql'; const DEFAULT_GLOBAL_PROXY: GlobalProxyConfig = { enabled: false, @@ -413,7 +431,7 @@ interface AppState { activeContext: { connectionId: string; dbName: string } | null; savedQueries: SavedQuery[]; theme: 'light' | 'dark'; - appearance: { enabled: boolean; opacity: number; blur: number; useNativeMacWindowControls: boolean }; + appearance: AppearanceSettings; uiScale: number; fontSize: number; startupFullscreen: boolean; @@ -472,7 +490,7 @@ interface AppState { deleteQuery: (id: string) => void; setTheme: (theme: 'light' | 'dark') => void; - setAppearance: (appearance: Partial<{ enabled: boolean; opacity: number; blur: number; useNativeMacWindowControls: boolean }>) => void; + setAppearance: (appearance: Partial) => void; setUiScale: (scale: number) => void; setFontSize: (size: number) => void; setStartupFullscreen: (enabled: boolean) => void; @@ -596,12 +614,13 @@ const sanitizeTableHiddenColumns = (value: unknown): Record => }; const sanitizeAppearance = ( - appearance: Partial<{ enabled: boolean; opacity: number; blur: number; useNativeMacWindowControls: boolean }> | undefined, + appearance: Partial | undefined, version: number -): { enabled: boolean; opacity: number; blur: number; useNativeMacWindowControls: boolean } => { +): AppearanceSettings => { if (!appearance || typeof appearance !== 'object') { return { ...DEFAULT_APPEARANCE }; } + const dataGridDisplaySettings = sanitizeDataGridDisplaySettings(appearance); const nextAppearance = { enabled: typeof appearance.enabled === 'boolean' ? appearance.enabled : DEFAULT_APPEARANCE.enabled, opacity: typeof appearance.opacity === 'number' ? appearance.opacity : DEFAULT_APPEARANCE.opacity, @@ -609,6 +628,8 @@ const sanitizeAppearance = ( useNativeMacWindowControls: typeof appearance.useNativeMacWindowControls === 'boolean' ? appearance.useNativeMacWindowControls : DEFAULT_APPEARANCE.useNativeMacWindowControls, + showDataTableVerticalBorders: dataGridDisplaySettings.showDataTableVerticalBorders, + dataTableColumnWidthMode: dataGridDisplaySettings.dataTableColumnWidthMode, }; if (version < 2 && isLegacyDefaultAppearance(appearance)) { return { ...DEFAULT_APPEARANCE }; @@ -1315,5 +1336,3 @@ export const useStore = create()( } ) ); - - diff --git a/frontend/src/utils/dataGridDisplay.test.ts b/frontend/src/utils/dataGridDisplay.test.ts new file mode 100644 index 0000000..0f7e47b --- /dev/null +++ b/frontend/src/utils/dataGridDisplay.test.ts @@ -0,0 +1,32 @@ +import { describe, expect, it } from 'vitest'; + +import { + DEFAULT_DATA_GRID_DISPLAY_SETTINGS, + resolveDataTableColumnWidth, + resolveDataTableDefaultColumnWidth, + resolveDataTableVerticalBorderColor, + sanitizeDataGridDisplaySettings, +} from './dataGridDisplay'; + +describe('dataGridDisplay helpers', () => { + it('sanitizes missing display settings to safe defaults', () => { + expect(sanitizeDataGridDisplaySettings(undefined)).toEqual(DEFAULT_DATA_GRID_DISPLAY_SETTINGS); + expect(sanitizeDataGridDisplaySettings({ dataTableColumnWidthMode: 'invalid' as never })).toEqual(DEFAULT_DATA_GRID_DISPLAY_SETTINGS); + }); + + it('resolves standard and compact default column widths', () => { + expect(resolveDataTableDefaultColumnWidth('standard')).toBe(200); + expect(resolveDataTableDefaultColumnWidth('compact')).toBe(140); + }); + + it('keeps manual column widths ahead of mode defaults', () => { + expect(resolveDataTableColumnWidth({ manualWidth: 320, widthMode: 'compact' })).toBe(320); + expect(resolveDataTableColumnWidth({ manualWidth: undefined, widthMode: 'compact' })).toBe(140); + }); + + it('uses subtle themed vertical border colors and transparent when disabled', () => { + expect(resolveDataTableVerticalBorderColor({ darkMode: true, visible: true })).toBe('rgba(255, 255, 255, 0.08)'); + expect(resolveDataTableVerticalBorderColor({ darkMode: false, visible: true })).toBe('rgba(15, 23, 42, 0.08)'); + expect(resolveDataTableVerticalBorderColor({ darkMode: false, visible: false })).toBe('transparent'); + }); +}); diff --git a/frontend/src/utils/dataGridDisplay.ts b/frontend/src/utils/dataGridDisplay.ts new file mode 100644 index 0000000..32ed056 --- /dev/null +++ b/frontend/src/utils/dataGridDisplay.ts @@ -0,0 +1,72 @@ +export type DataTableColumnWidthMode = 'standard' | 'compact'; + +export interface DataGridDisplaySettings { + showDataTableVerticalBorders: boolean; + dataTableColumnWidthMode: DataTableColumnWidthMode; +} + +export const DEFAULT_DATA_GRID_DISPLAY_SETTINGS: DataGridDisplaySettings = { + showDataTableVerticalBorders: false, + dataTableColumnWidthMode: 'standard', +}; + +export const DATA_GRID_COLUMN_WIDTH_MODE_OPTIONS = [ + { label: '标准 200px', value: 'standard' as const }, + { label: '紧凑 140px', value: 'compact' as const }, +]; + +const STANDARD_DATA_TABLE_COLUMN_WIDTH = 200; +const COMPACT_DATA_TABLE_COLUMN_WIDTH = 140; + +export const sanitizeDataTableColumnWidthMode = (value: unknown): DataTableColumnWidthMode => { + return value === 'compact' ? 'compact' : 'standard'; +}; + +export const sanitizeDataGridDisplaySettings = ( + value: Partial | undefined +): DataGridDisplaySettings => { + if (!value || typeof value !== 'object') { + return { ...DEFAULT_DATA_GRID_DISPLAY_SETTINGS }; + } + + return { + showDataTableVerticalBorders: value.showDataTableVerticalBorders === true, + dataTableColumnWidthMode: sanitizeDataTableColumnWidthMode(value.dataTableColumnWidthMode), + }; +}; + +export const resolveDataTableDefaultColumnWidth = ( + widthMode: DataTableColumnWidthMode | null | undefined +): number => { + return sanitizeDataTableColumnWidthMode(widthMode) === 'compact' + ? COMPACT_DATA_TABLE_COLUMN_WIDTH + : STANDARD_DATA_TABLE_COLUMN_WIDTH; +}; + +export const resolveDataTableColumnWidth = ({ + manualWidth, + widthMode, +}: { + manualWidth: number | null | undefined; + widthMode: DataTableColumnWidthMode | null | undefined; +}): number => { + if (typeof manualWidth === 'number' && Number.isFinite(manualWidth) && manualWidth > 0) { + return manualWidth; + } + + return resolveDataTableDefaultColumnWidth(widthMode); +}; + +export const resolveDataTableVerticalBorderColor = ({ + darkMode, + visible, +}: { + darkMode: boolean; + visible: boolean; +}): string => { + if (!visible) { + return 'transparent'; + } + + return darkMode ? 'rgba(255, 255, 255, 0.08)' : 'rgba(15, 23, 42, 0.08)'; +}; From ac0b6c05e8f64933013f02959aa16301dd277aaa Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 3 Apr 2026 01:23:38 +0800 Subject: [PATCH 14/14] =?UTF-8?q?=F0=9F=90=9B=20fix(database):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E6=9C=AC=E5=9C=B0=E9=A9=B1=E5=8A=A8=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E8=AF=86=E5=88=AB=E4=B8=8E=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E6=A0=A1=E9=AA=8C=E9=81=97=E6=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - MongoDB 本地导入按所选版本解析目录与压缩包 - ClickHouse 连接测试补充 query path 校验 - 补充驱动版本与查询路径回归测试 --- .../src/components/DriverManagerModal.tsx | 22 +++- frontend/wailsjs/go/app/App.d.ts | 2 +- frontend/wailsjs/go/app/App.js | 4 +- internal/app/methods_driver.go | 45 ++++--- internal/app/methods_driver_version_test.go | 109 ++++++++++++++++ internal/db/clickhouse_impl.go | 39 +++++- internal/db/clickhouse_impl_test.go | 119 ++++++++++++++++++ 7 files changed, 314 insertions(+), 26 deletions(-) create mode 100644 internal/db/clickhouse_impl_test.go diff --git a/frontend/src/components/DriverManagerModal.tsx b/frontend/src/components/DriverManagerModal.tsx index 6dc3df9..9516e31 100644 --- a/frontend/src/components/DriverManagerModal.tsx +++ b/frontend/src/components/DriverManagerModal.tsx @@ -757,6 +757,16 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG }; }, [appendOperationLog, open]); + const resolveLocalImportVersion = useCallback((row: DriverStatusRow) => { + const options = versionMap[row.type] || []; + const selectedKey = selectedVersionMap[row.type]; + const selectedOption = + options.find((item) => buildVersionOptionKey(item) === selectedKey) || + options.find((item) => item.recommended) || + options[0]; + return selectedOption?.version || row.pinnedVersion || ''; + }, [selectedVersionMap, versionMap]); + const installDriver = useCallback(async (row: DriverStatusRow) => { setActionState({ driverType: row.type, kind: 'install' }); setProgressMap((prev) => ({ @@ -820,9 +830,11 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG percent: 0, }, })); - appendOperationLog(row.type, `[START] 开始本地导入(${sourceLabel}):${pathText}`); + const selectedVersion = resolveLocalImportVersion(row); + const versionTip = selectedVersion ? `(${selectedVersion})` : ''; + appendOperationLog(row.type, `[START] 开始本地导入${versionTip}(${sourceLabel}):${pathText}`); try { - const result = await InstallLocalDriverPackage(row.type, pathText, downloadDir); + const result = await InstallLocalDriverPackage(row.type, pathText, downloadDir, selectedVersion); if (!result?.success) { const errText = result?.message || `导入 ${row.name} 本地驱动包失败`; appendOperationLog(row.type, `[ERROR] ${errText}`); @@ -831,9 +843,9 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG } return false; } - appendOperationLog(row.type, '[DONE] 本地导入安装完成'); + appendOperationLog(row.type, `[DONE] 本地导入安装完成 ${versionTip}`.trim()); if (!options?.silentToast) { - message.success(`${row.name} 本地驱动包已安装启用`); + message.success(`${row.name}${versionTip} 本地驱动包已安装启用`); } if (!options?.skipRefresh) { await refreshStatus(false); @@ -842,7 +854,7 @@ const DriverManagerModal: React.FC<{ open: boolean; onClose: () => void; onOpenG } finally { setActionState({ driverType: '', kind: '' }); } - }, [appendOperationLog, downloadDir, refreshStatus]); + }, [appendOperationLog, downloadDir, refreshStatus, resolveLocalImportVersion]); const installDriverFromLocalFile = useCallback(async (row: DriverStatusRow) => { const fileRes = await SelectDriverPackageFile(downloadDir); diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index 08c1dd8..f94ace7 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -106,7 +106,7 @@ export function ImportLegacyConnections(arg1:Array; -export function InstallLocalDriverPackage(arg1:string,arg2:string,arg3:string):Promise; +export function InstallLocalDriverPackage(arg1:string,arg2:string,arg3:string,arg4:string):Promise; export function InstallUpdateAndRestart():Promise; diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js index 9564131..d2e2c50 100755 --- a/frontend/wailsjs/go/app/App.js +++ b/frontend/wailsjs/go/app/App.js @@ -206,8 +206,8 @@ export function ImportLegacyGlobalProxy(arg1) { return window['go']['app']['App']['ImportLegacyGlobalProxy'](arg1); } -export function InstallLocalDriverPackage(arg1, arg2, arg3) { - return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3); +export function InstallLocalDriverPackage(arg1, arg2, arg3, arg4) { + return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3, arg4); } export function InstallUpdateAndRestart() { diff --git a/internal/app/methods_driver.go b/internal/app/methods_driver.go index e8873f7..f8e90dc 100644 --- a/internal/app/methods_driver.go +++ b/internal/app/methods_driver.go @@ -745,7 +745,7 @@ func (a *App) CheckDriverNetworkStatus() connection.QueryResult { } } -func (a *App) InstallLocalDriverPackage(driverType string, filePath string, downloadDir string) connection.QueryResult { +func (a *App) InstallLocalDriverPackage(driverType string, filePath string, downloadDir string, version string) connection.QueryResult { definition, ok := resolveDriverDefinition(driverType) if !ok { return connection.QueryResult{Success: false, Message: "不支持的驱动类型"} @@ -768,7 +768,10 @@ func (a *App) InstallLocalDriverPackage(driverType string, filePath string, down db.SetExternalDriverDownloadDirectory(resolvedDir) a.emitDriverDownloadProgress(definition.Type, "start", 0, 100, "开始安装本地驱动包") - selectedVersion := resolveDriverInstallVersion(definition.PinnedVersion, "local://manual", definition) + selectedVersion := resolveDriverInstallVersion(version, "local://manual", definition) + if err := validateDriverSelectedVersion(definition, selectedVersion); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } meta, installErr := installOptionalDriverAgentFromLocalPath(definition, filePath, resolvedDir, selectedVersion) if installErr != nil { errText := normalizeErrorMessage(installErr) @@ -2628,7 +2631,7 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa sourceName := filepath.Base(pathText) downloadSource := fmt.Sprintf("local://manual/%s", filepath.Base(pathText)) if info.IsDir() { - matchedPath, matchedEntry, resolveErr := resolveLocalDriverAgentFromDirectory(pathText, driverType) + matchedPath, matchedEntry, resolveErr := resolveLocalDriverAgentFromLocalDirectory(pathText, driverType, selectedVersion) if resolveErr != nil { return installedDriverPackage{}, resolveErr } @@ -2641,7 +2644,7 @@ func installOptionalDriverAgentFromLocalPath(definition driverDefinition, filePa } if !info.IsDir() && strings.EqualFold(filepath.Ext(pathText), ".zip") { - entryName, extractErr := installOptionalDriverAgentFromLocalZip(pathText, definition, executablePath) + entryName, extractErr := installOptionalDriverAgentFromLocalZip(pathText, definition, executablePath, selectedVersion) if extractErr != nil { return installedDriverPackage{}, extractErr } @@ -2680,7 +2683,7 @@ type localDriverCandidate struct { inPlatformDir bool } -func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType string) (string, string, error) { +func resolveLocalDriverAgentFromLocalDirectory(directoryPath string, driverType string, selectedVersion string) (string, string, error) { root := strings.TrimSpace(directoryPath) if root == "" { return "", "", fmt.Errorf("本地驱动目录路径为空") @@ -2703,9 +2706,9 @@ func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType strin } displayName := resolveDriverDisplayName(displayDefinition) platformDir := optionalDriverBundlePlatformDir(stdRuntime.GOOS) - assetNameCandidates := optionalDriverReleaseAssetNames(normalizedType) - baseNameCandidates := optionalDriverExecutableBaseNames(normalizedType) - assetName := optionalDriverReleaseAssetName(normalizedType) + assetNameCandidates := optionalDriverReleaseAssetNamesForVersion(normalizedType, selectedVersion) + baseNameCandidates := optionalDriverExecutableBaseNamesForVersion(normalizedType, selectedVersion) + assetName := optionalDriverReleaseAssetNameForVersion(normalizedType, selectedVersion) exactRelativePath := filepath.ToSlash(filepath.Join(platformDir, assetName)) for _, candidateName := range assetNameCandidates { @@ -2820,7 +2823,7 @@ func resolveLocalDriverAgentFromDirectory(directoryPath string, driverType strin ) } -func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDefinition, executablePath string) (string, error) { +func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDefinition, executablePath string, selectedVersion string) (string, error) { driverType := normalizeDriverType(definition.Type) displayName := resolveDriverDisplayName(definition) reader, err := zip.OpenReader(zipPath) @@ -2829,9 +2832,9 @@ func installOptionalDriverAgentFromLocalZip(zipPath string, definition driverDef } defer reader.Close() - entryPath := optionalDriverBundleEntryPath(driverType) - entryPaths := optionalDriverBundleEntryPaths(driverType) - expectedBaseNames := optionalDriverReleaseAssetNames(driverType) + entryPath := optionalDriverBundleEntryPathForVersion(driverType, selectedVersion) + entryPaths := optionalDriverBundleEntryPathsForVersion(driverType, selectedVersion) + expectedBaseNames := optionalDriverReleaseAssetNamesForVersion(driverType, selectedVersion) findEntry := func() *zip.File { for _, file := range reader.File { name := filepath.ToSlash(strings.TrimPrefix(strings.TrimSpace(file.Name), "./")) @@ -3490,9 +3493,9 @@ func optionalDriverBundlePlatformDir(goos string) string { } } -func optionalDriverBundleEntryPaths(driverType string) []string { +func optionalDriverBundleEntryPathsForVersion(driverType string, selectedVersion string) []string { platformDir := optionalDriverBundlePlatformDir(stdRuntime.GOOS) - assetNames := optionalDriverReleaseAssetNames(driverType) + assetNames := optionalDriverReleaseAssetNamesForVersion(driverType, selectedVersion) result := make([]string, 0, len(assetNames)) seen := make(map[string]struct{}, len(assetNames)) for _, assetName := range assetNames { @@ -3506,14 +3509,22 @@ func optionalDriverBundleEntryPaths(driverType string) []string { return result } -func optionalDriverBundleEntryPath(driverType string) string { - paths := optionalDriverBundleEntryPaths(driverType) +func optionalDriverBundleEntryPaths(driverType string) []string { + return optionalDriverBundleEntryPathsForVersion(driverType, "") +} + +func optionalDriverBundleEntryPathForVersion(driverType string, selectedVersion string) string { + paths := optionalDriverBundleEntryPathsForVersion(driverType, selectedVersion) if len(paths) == 0 { - return filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(stdRuntime.GOOS), optionalDriverReleaseAssetName(driverType))) + return filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(stdRuntime.GOOS), optionalDriverReleaseAssetNameForVersion(driverType, selectedVersion))) } return paths[0] } +func optionalDriverBundleEntryPath(driverType string) string { + return optionalDriverBundleEntryPathForVersion(driverType, "") +} + func resolveOptionalDriverAssetSize(sizeByAsset map[string]int64, driverType string) int64 { if len(sizeByAsset) == 0 { return 0 diff --git a/internal/app/methods_driver_version_test.go b/internal/app/methods_driver_version_test.go index 5fcfe34..ff12876 100644 --- a/internal/app/methods_driver_version_test.go +++ b/internal/app/methods_driver_version_test.go @@ -1,8 +1,10 @@ package app import ( + "archive/zip" "fmt" "os" + "path/filepath" "runtime" "strings" "testing" @@ -154,6 +156,66 @@ func TestShouldForceSourceBuildForResolvedDownload(t *testing.T) { } } +func TestInstallOptionalDriverAgentFromLocalPathSupportsMongoV1DirectoryImport(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + packageRoot := t.TempDir() + platformDir := filepath.Join(packageRoot, optionalDriverBundlePlatformDir(runtime.GOOS)) + if err := os.MkdirAll(platformDir, 0o755); err != nil { + t.Fatalf("mkdir package dir failed: %v", err) + } + + assetName := mongoVersionedReleaseAssetName(1) + writeSelfExecutable(t, filepath.Join(platformDir, assetName)) + + installRoot := filepath.Join(t.TempDir(), "drivers") + meta, err := installOptionalDriverAgentFromLocalPath(definition, packageRoot, installRoot, "1.17.4") + if err != nil { + t.Fatalf("expected mongodb v1 directory import to succeed, got %v", err) + } + if meta.Version != "1.17.4" { + t.Fatalf("expected imported version to stay 1.17.4, got %q", meta.Version) + } + if filepath.Base(meta.FilePath) != assetName { + t.Fatalf("expected source file %q, got %q", assetName, meta.FilePath) + } + if !strings.Contains(meta.DownloadURL, assetName) { + t.Fatalf("expected download source to reference %q, got %q", assetName, meta.DownloadURL) + } + if _, err := os.Stat(meta.ExecutablePath); err != nil { + t.Fatalf("expected imported executable to exist, got %v", err) + } +} + +func TestInstallOptionalDriverAgentFromLocalPathSupportsMongoV1ZipImport(t *testing.T) { + definition, ok := resolveDriverDefinition("mongodb") + if !ok { + t.Fatal("expected mongodb driver definition") + } + + assetName := mongoVersionedReleaseAssetName(1) + zipPath := filepath.Join(t.TempDir(), "mongodb-v1.zip") + writeZipWithSelfExecutable(t, zipPath, filepath.ToSlash(filepath.Join(optionalDriverBundlePlatformDir(runtime.GOOS), assetName))) + + installRoot := filepath.Join(t.TempDir(), "drivers") + meta, err := installOptionalDriverAgentFromLocalPath(definition, zipPath, installRoot, "1.17.4") + if err != nil { + t.Fatalf("expected mongodb v1 zip import to succeed, got %v", err) + } + if meta.Version != "1.17.4" { + t.Fatalf("expected imported version to stay 1.17.4, got %q", meta.Version) + } + if !strings.Contains(meta.DownloadURL, assetName) { + t.Fatalf("expected zip download source to reference %q, got %q", assetName, meta.DownloadURL) + } + if _, err := os.Stat(meta.ExecutablePath); err != nil { + t.Fatalf("expected imported executable to exist, got %v", err) + } +} + func seedReleaseAssetSizeCache(t *testing.T, cacheKey string, sizeByKey map[string]int64) { t.Helper() @@ -220,3 +282,50 @@ func mongoVersionedReleaseAssetName(major int) string { } return name } + +func writeSelfExecutable(t *testing.T, targetPath string) { + t.Helper() + + selfPath, err := os.Executable() + if err != nil { + t.Fatalf("executable path failed: %v", err) + } + content, err := os.ReadFile(selfPath) + if err != nil { + t.Fatalf("read self executable failed: %v", err) + } + if err := os.WriteFile(targetPath, content, 0o755); err != nil { + t.Fatalf("write executable failed: %v", err) + } +} + +func writeZipWithSelfExecutable(t *testing.T, zipPath string, entryName string) { + t.Helper() + + selfPath, err := os.Executable() + if err != nil { + t.Fatalf("executable path failed: %v", err) + } + content, err := os.ReadFile(selfPath) + if err != nil { + t.Fatalf("read self executable failed: %v", err) + } + + file, err := os.Create(zipPath) + if err != nil { + t.Fatalf("create zip failed: %v", err) + } + defer file.Close() + + writer := zip.NewWriter(file) + entry, err := writer.Create(entryName) + if err != nil { + t.Fatalf("create zip entry failed: %v", err) + } + if _, err := entry.Write(content); err != nil { + t.Fatalf("write zip entry failed: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close zip writer failed: %v", err) + } +} diff --git a/internal/db/clickhouse_impl.go b/internal/db/clickhouse_impl.go index d98ca3d..c6be5d0 100644 --- a/internal/db/clickhouse_impl.go +++ b/internal/db/clickhouse_impl.go @@ -279,7 +279,44 @@ func (c *ClickHouseDB) Ping() error { } ctx, cancel := utils.ContextWithTimeout(timeout) defer cancel() - return c.conn.PingContext(ctx) + if err := c.conn.PingContext(ctx); err != nil { + return err + } + return c.validateQueryPath() +} + +func (c *ClickHouseDB) validateQueryPath() error { + if c.conn == nil { + return fmt.Errorf("连接未打开") + } + timeout := c.pingTimeout + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := utils.ContextWithTimeout(timeout) + defer cancel() + + rows, err := c.conn.QueryContext(ctx, "SELECT currentDatabase()") + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return err + } + return fmt.Errorf("连接查询验证未返回结果") + } + + var current sql.NullString + if err := rows.Scan(¤t); err != nil { + return err + } + if err := rows.Err(); err != nil { + return err + } + return nil } func (c *ClickHouseDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { diff --git a/internal/db/clickhouse_impl_test.go b/internal/db/clickhouse_impl_test.go new file mode 100644 index 0000000..a2fd40a --- /dev/null +++ b/internal/db/clickhouse_impl_test.go @@ -0,0 +1,119 @@ +//go:build gonavi_full_drivers || gonavi_clickhouse_driver + +package db + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "strings" + "sync" + "testing" + "time" +) + +const fakeClickHouseDriverName = "gonavi-fake-clickhouse" + +var ( + registerFakeClickHouseDriverOnce sync.Once + fakeClickHouseStateMu sync.Mutex + fakeClickHouseState = struct { + pingErr error + queryErr error + lastQuery string + }{ + lastQuery: "", + } +) + +func TestClickHousePingValidatesQueryPath(t *testing.T) { + registerFakeClickHouseDriverOnce.Do(func() { + sql.Register(fakeClickHouseDriverName, fakeClickHouseDriver{}) + }) + + db, err := sql.Open(fakeClickHouseDriverName, "") + if err != nil { + t.Fatalf("open fake clickhouse db failed: %v", err) + } + defer db.Close() + + fakeClickHouseStateMu.Lock() + fakeClickHouseState.pingErr = nil + fakeClickHouseState.queryErr = errors.New("query path failed") + fakeClickHouseState.lastQuery = "" + fakeClickHouseStateMu.Unlock() + + client := &ClickHouseDB{ + conn: db, + pingTimeout: time.Second, + } + err = client.Ping() + if err == nil { + t.Fatal("expected Ping to fail when query validation fails") + } + if !strings.Contains(err.Error(), "query path failed") { + t.Fatalf("expected query validation error, got %v", err) + } + + fakeClickHouseStateMu.Lock() + lastQuery := fakeClickHouseState.lastQuery + fakeClickHouseStateMu.Unlock() + if lastQuery != "SELECT currentDatabase()" { + t.Fatalf("expected query validation SQL to run, got %q", lastQuery) + } +} + +type fakeClickHouseDriver struct{} + +func (fakeClickHouseDriver) Open(name string) (driver.Conn, error) { + return fakeClickHouseConn{}, nil +} + +type fakeClickHouseConn struct{} + +func (fakeClickHouseConn) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("prepare not implemented") +} + +func (fakeClickHouseConn) Close() error { + return nil +} + +func (fakeClickHouseConn) Begin() (driver.Tx, error) { + return nil, errors.New("transactions not implemented") +} + +func (fakeClickHouseConn) Ping(ctx context.Context) error { + fakeClickHouseStateMu.Lock() + defer fakeClickHouseStateMu.Unlock() + return fakeClickHouseState.pingErr +} + +func (fakeClickHouseConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + fakeClickHouseStateMu.Lock() + defer fakeClickHouseStateMu.Unlock() + fakeClickHouseState.lastQuery = query + if fakeClickHouseState.queryErr != nil { + return nil, fakeClickHouseState.queryErr + } + return &fakeClickHouseRows{}, nil +} + +type fakeClickHouseRows struct{} + +func (r *fakeClickHouseRows) Columns() []string { + return []string{"currentDatabase"} +} + +func (r *fakeClickHouseRows) Close() error { + return nil +} + +func (r *fakeClickHouseRows) Next(dest []driver.Value) error { + if len(dest) > 0 { + dest[0] = "default" + } + return io.EOF +}