From 1a042321d2a6e8828ac575ea9adba98f8b49f98a Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Sun, 5 Apr 2026 15:27:08 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=F0=9F=90=9B=20fix(connection):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=A4=B1=E8=B4=A5=E8=BF=9E=E6=8E=A5=E9=AB=98=E9=A2=91?= =?UTF-8?q?=E9=87=8D=E8=AF=95=E5=B9=B6=E6=9A=82=E5=81=9C=E5=90=8E=E5=8F=B0?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=E5=85=83=E6=95=B0=E6=8D=AE=E6=8B=89=E5=8F=96?= =?UTF-8?q?=20#331?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 后端为失败数据库连接增加冷却窗口,避免短时间内重复真实建连 - 补充失败冷却回归测试,覆盖重复失败、冷却后重试和成功后清理场景 - 前端在后台态暂停查询页、侧边栏和表概览的自动元数据拉取 - 保持手动刷新、手动展开等显式操作行为不变 --- frontend/src/components/QueryEditor.tsx | 14 +- frontend/src/components/Sidebar.tsx | 8 +- frontend/src/components/TableOverview.tsx | 9 +- .../src/utils/autoFetchVisibility.test.ts | 22 +++ frontend/src/utils/autoFetchVisibility.ts | 54 ++++++ internal/app/app.go | 78 +++++++++ .../app/app_startup_connect_retry_test.go | 164 ++++++++++++++++++ 7 files changed, 345 insertions(+), 4 deletions(-) create mode 100644 frontend/src/utils/autoFetchVisibility.test.ts create mode 100644 frontend/src/utils/autoFetchVisibility.ts diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index 9b2fd44..8b952bc 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -11,6 +11,7 @@ import DataGrid, { GONAVI_ROW_KEY } from './DataGrid'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; import { convertMongoShellToJsonCommand } from '../utils/mongodb'; import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts'; +import { useAutoFetchVisibility } from '../utils/autoFetchVisibility'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; const SQL_KEYWORDS = [ @@ -249,6 +250,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc const setQueryOptions = useStore(state => state.setQueryOptions); const shortcutOptions = useStore(state => state.shortcutOptions); const activeTabId = useStore(state => state.activeTabId); + const autoFetchVisible = useAutoFetchVisibility(); const currentSavedQuery = useMemo(() => { const savedId = String(tab.savedQueryId || '').trim(); @@ -324,6 +326,10 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc // Fetch Database List useEffect(() => { + if (!autoFetchVisible) { + return; + } + const fetchDbs = async () => { const conn = connections.find(c => c.id === currentConnectionId); if (!conn) return; @@ -367,10 +373,14 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } }; void fetchDbs(); - }, [currentConnectionId, connections]); + }, [autoFetchVisible, currentConnectionId, connections]); // Fetch Metadata for Autocomplete (Cross-database) useEffect(() => { + if (!autoFetchVisible) { + return; + } + const fetchMetadata = async () => { const conn = connections.find(c => c.id === currentConnectionId); if (!conn) return; @@ -424,7 +434,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } }; void fetchMetadata(); - }, [currentConnectionId, connections, dbList]); // dbList 变化时触发重新加载 + }, [autoFetchVisible, currentConnectionId, connections, dbList]); // dbList 变化时触发重新加载 // Query ID management helpers const setQueryId = (id: string) => { diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index ed3743b..80688b0 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -41,6 +41,7 @@ import { getDbIcon } from './DatabaseIcons'; import { DBGetDatabases, DBGetTables, DBQuery, DBShowCreateTable, ExportTable, OpenSQLFile, ExecuteSQLFile, CancelSQLFileExecution, CreateDatabase, RenameDatabase, DropDatabase, RenameTable, DropTable, DropView, DropFunction, RenameView } from '../../wailsjs/go/app/App'; import { EventsOn } from '../../wailsjs/runtime/runtime'; import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance'; +import { useAutoFetchVisibility } from '../utils/autoFetchVisibility'; import FindInDatabaseModal from './FindInDatabaseModal'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; @@ -118,6 +119,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> const darkMode = theme === 'dark'; const resolvedAppearance = resolveAppearanceValues(appearance); const opacity = normalizeOpacityForPlatform(resolvedAppearance.opacity); + const autoFetchVisible = useAutoFetchVisibility(); const [treeData, setTreeData] = useState([]); // Background Helper (Duplicate logic for now, ideally shared) @@ -292,6 +294,10 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> const [findInDbContext, setFindInDbContext] = useState<{ open: boolean; connectionId: string; dbName: string }>({ open: false, connectionId: '', dbName: '' }); useEffect(() => { + if (!autoFetchVisible) { + return; + } + // Refresh queries for expanded databases const findNode = (nodes: TreeNode[], k: React.Key): TreeNode | null => { for (const node of nodes) { @@ -310,7 +316,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }> loadTables(node); } }); - }, [savedQueries]); + }, [autoFetchVisible, savedQueries]); useEffect(() => { setTreeData((prev) => { diff --git a/frontend/src/components/TableOverview.tsx b/frontend/src/components/TableOverview.tsx index bf687a1..f18c272 100644 --- a/frontend/src/components/TableOverview.tsx +++ b/frontend/src/components/TableOverview.tsx @@ -4,6 +4,7 @@ import { TableOutlined, SearchOutlined, ReloadOutlined, SortAscendingOutlined, D import { useStore } from '../store'; import { DBQuery, DBShowCreateTable, ExportTable, DropTable, RenameTable } from '../../wailsjs/go/app/App'; import type { TabData } from '../types'; +import { useAutoFetchVisibility } from '../utils/autoFetchVisibility'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; interface TableOverviewProps { @@ -151,6 +152,7 @@ const TableOverview: React.FC = ({ tab }) => { const [viewMode, setViewMode] = useState('card'); const connection = useMemo(() => connections.find(c => c.id === tab.connectionId), [connections, tab.connectionId]); + const autoFetchVisible = useAutoFetchVisibility(); const loadData = useCallback(async () => { if (!connection) return; @@ -179,7 +181,12 @@ const TableOverview: React.FC = ({ tab }) => { } }, [connection, tab.dbName]); - useEffect(() => { loadData(); }, [loadData]); + useEffect(() => { + if (!autoFetchVisible) { + return; + } + void loadData(); + }, [autoFetchVisible, loadData]); const sortedFiltered = useMemo(() => { let list = [...tables]; diff --git a/frontend/src/utils/autoFetchVisibility.test.ts b/frontend/src/utils/autoFetchVisibility.test.ts new file mode 100644 index 0000000..4e0794f --- /dev/null +++ b/frontend/src/utils/autoFetchVisibility.test.ts @@ -0,0 +1,22 @@ +import { describe, expect, it } from 'vitest'; + +import { isAutoFetchVisible } from './autoFetchVisibility'; + +describe('isAutoFetchVisible', () => { + it('allows auto fetch only when the document is visible and not hidden', () => { + expect(isAutoFetchVisible({ hidden: false, visibilityState: 'visible' })).toBe(true); + }); + + it('blocks auto fetch when the page is hidden even if visibilityState looks visible', () => { + expect(isAutoFetchVisible({ hidden: true, visibilityState: 'visible' })).toBe(false); + }); + + it('blocks auto fetch when visibilityState is not visible', () => { + expect(isAutoFetchVisible({ hidden: false, visibilityState: 'hidden' })).toBe(false); + }); + + it('defaults to allowing auto fetch when document visibility APIs are unavailable', () => { + expect(isAutoFetchVisible(undefined)).toBe(true); + expect(isAutoFetchVisible({})).toBe(true); + }); +}); diff --git a/frontend/src/utils/autoFetchVisibility.ts b/frontend/src/utils/autoFetchVisibility.ts new file mode 100644 index 0000000..00836ab --- /dev/null +++ b/frontend/src/utils/autoFetchVisibility.ts @@ -0,0 +1,54 @@ +import { useEffect, useState } from 'react'; + +type AutoFetchVisibilitySource = Partial> | undefined; + +export const isAutoFetchVisible = (source?: AutoFetchVisibilitySource): boolean => { + if (!source) { + return true; + } + + if (source.hidden === true) { + return false; + } + + if (source.visibilityState && source.visibilityState !== 'visible') { + return false; + } + + return true; +}; + +const getDocumentAutoFetchVisibility = (): boolean => { + if (typeof document === 'undefined') { + return true; + } + + return isAutoFetchVisible(document); +}; + +export const useAutoFetchVisibility = (): boolean => { + const [isVisible, setIsVisible] = useState(() => getDocumentAutoFetchVisibility()); + + useEffect(() => { + if (typeof document === 'undefined') { + return undefined; + } + + const syncVisibility = () => { + setIsVisible(getDocumentAutoFetchVisibility()); + }; + + syncVisibility(); + document.addEventListener('visibilitychange', syncVisibility); + window.addEventListener('focus', syncVisibility); + window.addEventListener('pageshow', syncVisibility); + + return () => { + document.removeEventListener('visibilitychange', syncVisibility); + window.removeEventListener('focus', syncVisibility); + window.removeEventListener('pageshow', syncVisibility); + }; + }, []); + + return isVisible; +}; diff --git a/internal/app/app.go b/internal/app/app.go index 3f17b86..2f96f64 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -23,6 +23,7 @@ import ( ) const dbCachePingInterval = 30 * time.Second +const dbConnectFailureCooldown = 30 * time.Second const ( startupConnectRetryWindow = 20 * time.Second @@ -40,6 +41,11 @@ type cachedDatabase struct { lastPing time.Time } +type cachedConnectFailure struct { + occurredAt time.Time + err error +} + type queryContext struct { cancel context.CancelFunc started time.Time @@ -50,6 +56,7 @@ type App struct { ctx context.Context startedAt time.Time dbCache map[string]cachedDatabase // Cache for DB connections + connectFailures map[string]cachedConnectFailure mu sync.RWMutex // Mutex for cache access updateMu sync.Mutex updateState updateState @@ -70,6 +77,7 @@ func NewAppWithSecretStore(store secretstore.SecretStore) *App { } return &App{ dbCache: make(map[string]cachedDatabase), + connectFailures: make(map[string]cachedConnectFailure), runningQueries: make(map[string]queryContext), configDir: resolveAppConfigDir(), secretStore: store, @@ -573,14 +581,28 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing if isFileDB { logger.Infof("未命中文件库连接缓存,开始创建连接:类型=%s 缓存Key=%s", strings.TrimSpace(effectiveConfig.Type), shortKey) } + if failure, remaining, ok := a.getCachedConnectFailureByKey(key); ok { + message := fmt.Sprintf("连接最近失败,正在冷却中,请 %s 后重试;上次错误:%s", + formatConnectFailureCooldown(remaining), + normalizeErrorMessage(failure.err), + ) + logger.Warnf("命中数据库连接失败冷却:%s 缓存Key=%s 剩余=%s 原因=%s", + formatConnSummary(effectiveConfig), shortKey, formatConnectFailureCooldown(remaining), normalizeErrorMessage(failure.err)) + return nil, withLogHint{err: fmt.Errorf("%s", message), logPath: logger.Path()} + } + initialKey := key dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(resolvedConfig) if err != nil { + failedKey := getCacheKey(connectedConfig) + a.recordConnectFailureByKey(failedKey, err) return nil, err } + a.clearConnectFailureByKey(initialKey) effectiveConfig = connectedConfig key = getCacheKey(effectiveConfig) shortKey = shortenCacheKey(key) + a.clearConnectFailureByKey(key) now := time.Now() @@ -601,6 +623,62 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing return dbInst, nil } +func (a *App) getCachedConnectFailureByKey(key string) (cachedConnectFailure, time.Duration, bool) { + if a == nil || strings.TrimSpace(key) == "" { + return cachedConnectFailure{}, 0, false + } + + a.mu.RLock() + entry, exists := a.connectFailures[key] + a.mu.RUnlock() + if !exists || entry.err == nil || entry.occurredAt.IsZero() { + return cachedConnectFailure{}, 0, false + } + + remaining := dbConnectFailureCooldown - time.Since(entry.occurredAt) + if remaining <= 0 { + a.clearConnectFailureByKey(key) + return cachedConnectFailure{}, 0, false + } + + return entry, remaining, true +} + +func (a *App) recordConnectFailureByKey(key string, err error) { + if a == nil || strings.TrimSpace(key) == "" || err == nil { + return + } + + a.mu.Lock() + if a.connectFailures == nil { + a.connectFailures = make(map[string]cachedConnectFailure) + } + a.connectFailures[key] = cachedConnectFailure{ + occurredAt: time.Now(), + err: err, + } + a.mu.Unlock() +} + +func (a *App) clearConnectFailureByKey(key string) { + if a == nil || strings.TrimSpace(key) == "" { + return + } + + a.mu.Lock() + if a.connectFailures != nil { + delete(a.connectFailures, key) + } + a.mu.Unlock() +} + +func formatConnectFailureCooldown(remaining time.Duration) time.Duration { + if remaining <= time.Second { + return time.Second + } + return remaining.Truncate(time.Second) +} + func shortenCacheKey(key string) string { if len(key) > 12 { return key[:12] diff --git a/internal/app/app_startup_connect_retry_test.go b/internal/app/app_startup_connect_retry_test.go index b8fb027..8bd0ec1 100644 --- a/internal/app/app_startup_connect_retry_test.go +++ b/internal/app/app_startup_connect_retry_test.go @@ -303,3 +303,167 @@ func TestIsTransientStartupConnectError(t *testing.T) { t.Fatal("expected authentication failure to not be treated as transient startup connect error") } } + +func TestGetDatabaseWithPing_CoolsDownRepeatedFailures(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + defer func() { + newDatabaseFunc = originalNewDatabaseFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + }() + + connectCalls := 0 + newDatabaseFunc = func(dbType string) (db.Database, error) { + return &fakeStartupRetryDB{ + connect: func(config connection.ConnectionConfig) error { + connectCalls++ + return errors.New("dial tcp 10.1.131.86:5432: connect: connection refused") + }, + }, nil + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + a := &App{ + startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second), + dbCache: make(map[string]cachedDatabase), + connectFailures: make(map[string]cachedConnectFailure), + runningQueries: make(map[string]queryContext), + } + config := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"} + + _, firstErr := a.getDatabaseWithPing(config, false) + if firstErr == nil { + t.Fatal("expected first connection attempt to fail") + } + if connectCalls != 2 { + t.Fatalf("expected first request to use 2 connect attempts, got %d", connectCalls) + } + + _, secondErr := a.getDatabaseWithPing(config, false) + if secondErr == nil { + t.Fatal("expected second connection attempt to fail during cooldown") + } + if connectCalls != 2 { + t.Fatalf("expected repeated request during cooldown to avoid reconnecting, got %d connect attempts", connectCalls) + } +} + +func TestGetDatabaseWithPing_AllowsRetryAfterFailureCooldown(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + defer func() { + newDatabaseFunc = originalNewDatabaseFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + }() + + connectCalls := 0 + newDatabaseFunc = func(dbType string) (db.Database, error) { + return &fakeStartupRetryDB{ + connect: func(config connection.ConnectionConfig) error { + connectCalls++ + if connectCalls <= 2 { + return errors.New("dial tcp 10.1.131.86:5432: connect: connection refused") + } + return nil + }, + }, nil + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + a := &App{ + startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second), + dbCache: make(map[string]cachedDatabase), + connectFailures: make(map[string]cachedConnectFailure), + runningQueries: make(map[string]queryContext), + } + config := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"} + + _, firstErr := a.getDatabaseWithPing(config, false) + if firstErr == nil { + t.Fatal("expected first connection attempt to fail") + } + if connectCalls != 2 { + t.Fatalf("expected first request to use 2 connect attempts, got %d", connectCalls) + } + + key := getCacheKey(config) + a.mu.Lock() + a.connectFailures[key] = cachedConnectFailure{ + occurredAt: time.Now().Add(-dbConnectFailureCooldown - time.Second), + err: errors.New("dial tcp 10.1.131.86:5432: connect: connection refused"), + } + a.mu.Unlock() + + inst, secondErr := a.getDatabaseWithPing(config, false) + if secondErr != nil { + t.Fatalf("expected retry after cooldown to be allowed, got error: %v", secondErr) + } + if inst == nil { + t.Fatal("expected database instance after cooldown retry") + } + if connectCalls != 3 { + t.Fatalf("expected reconnect after cooldown expiration, got %d connect attempts", connectCalls) + } +} + +func TestGetDatabaseWithPing_ClearsFailureCooldownAfterSuccess(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + defer func() { + newDatabaseFunc = originalNewDatabaseFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + }() + + connectCalls := 0 + newDatabaseFunc = func(dbType string) (db.Database, error) { + return &fakeStartupRetryDB{ + connect: func(config connection.ConnectionConfig) error { + connectCalls++ + if connectCalls <= 2 { + return errors.New("dial tcp 10.1.131.86:5432: connect: connection refused") + } + return nil + }, + }, nil + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + a := &App{ + startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second), + dbCache: make(map[string]cachedDatabase), + connectFailures: make(map[string]cachedConnectFailure), + runningQueries: make(map[string]queryContext), + } + config := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"} + + _, firstErr := a.getDatabaseWithPing(config, false) + if firstErr == nil { + t.Fatal("expected first connection attempt to fail") + } + + key := getCacheKey(config) + a.mu.Lock() + a.connectFailures[key] = cachedConnectFailure{ + occurredAt: time.Now().Add(-dbConnectFailureCooldown - time.Second), + err: errors.New("dial tcp 10.1.131.86:5432: connect: connection refused"), + } + a.mu.Unlock() + + _, secondErr := a.getDatabaseWithPing(config, false) + if secondErr != nil { + t.Fatalf("expected retry after cooldown to succeed, got error: %v", secondErr) + } + + a.mu.RLock() + _, exists := a.connectFailures[key] + a.mu.RUnlock() + if exists { + t.Fatal("expected successful connection to clear cached failure cooldown") + } +} From 070ff72ad8e9f0e058e300faec20faac5699d97c Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 10 Apr 2026 21:29:45 +0800 Subject: [PATCH 2/7] =?UTF-8?q?=E2=9C=A8=20feat(security):=20=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E5=AF=86=E6=96=87=E5=8D=87=E7=BA=A7=E4=B8=8E=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E6=81=A2=E5=A4=8D=E5=8C=85=E5=AF=BC=E5=85=A5=E5=AF=BC?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 +- cmd/manualtestseed/main.go | 339 +++++++ frontend/package.json.md5 | 2 +- frontend/src/App.tsx | 727 +++++++++++--- frontend/src/components/AISettingsModal.tsx | 14 +- frontend/src/components/ConnectionModal.tsx | 136 ++- .../ConnectionPackagePasswordModal.tsx | 56 ++ .../src/components/SecurityUpdateBanner.tsx | 135 +++ .../components/SecurityUpdateIntroModal.tsx | 107 ++ .../SecurityUpdateProgressModal.tsx | 67 ++ .../SecurityUpdateSettingsModal.tsx | 247 +++++ frontend/src/main.browserMock.test.ts | 99 ++ frontend/src/main.tsx | 14 +- frontend/src/store.test.ts | 48 + frontend/src/store.ts | 89 +- frontend/src/types.ts | 66 ++ frontend/src/utils/connectionExport.test.ts | 60 ++ frontend/src/utils/connectionExport.ts | 78 ++ .../utils/connectionModalPresentation.test.ts | 57 ++ .../src/utils/connectionModalPresentation.ts | 78 ++ .../src/utils/legacyConnectionStorage.test.ts | 116 ++- frontend/src/utils/legacyConnectionStorage.ts | 48 +- .../src/utils/secureConfigBootstrap.test.ts | 466 +++++++++ frontend/src/utils/secureConfigBootstrap.ts | 351 +++++++ .../utils/securityUpdatePresentation.test.ts | 96 ++ .../src/utils/securityUpdatePresentation.ts | 210 ++++ .../utils/securityUpdateRepairFlow.test.ts | 79 ++ .../src/utils/securityUpdateRepairFlow.ts | 82 ++ frontend/wailsjs/go/app/App.d.ts | 15 + frontend/wailsjs/go/app/App.js | 30 +- frontend/wailsjs/go/models.ts | 213 ++++ go.mod | 6 + go.sum | 13 +- internal/ai/service/config_store.go | 262 +++++ internal/ai/service/config_store_test.go | 206 ++++ internal/ai/service/provider_secret.go | 22 +- internal/ai/service/provider_secret_test.go | 153 ++- internal/ai/service/service.go | 165 ++- internal/app/connection_package_crypto.go | 228 +++++ .../app/connection_package_crypto_test.go | 224 +++++ internal/app/connection_package_transfer.go | 229 +++++ .../app/connection_package_transfer_test.go | 529 ++++++++++ internal/app/connection_package_types.go | 77 ++ internal/app/connection_secret_resolution.go | 24 +- .../app/connection_secret_resolution_test.go | 21 + internal/app/methods_file.go | 51 + .../app/methods_saved_connections_test.go | 86 ++ internal/app/security_update_engine.go | 561 +++++++++++ internal/app/security_update_engine_test.go | 942 ++++++++++++++++++ internal/app/security_update_rollback.go | 314 ++++++ .../app/security_update_source_current_app.go | 85 ++ internal/app/security_update_state.go | 293 ++++++ internal/app/security_update_state_test.go | 226 +++++ internal/app/security_update_types.go | 129 +++ 54 files changed, 8644 insertions(+), 332 deletions(-) create mode 100644 cmd/manualtestseed/main.go create mode 100644 frontend/src/components/ConnectionPackagePasswordModal.tsx create mode 100644 frontend/src/components/SecurityUpdateBanner.tsx create mode 100644 frontend/src/components/SecurityUpdateIntroModal.tsx create mode 100644 frontend/src/components/SecurityUpdateProgressModal.tsx create mode 100644 frontend/src/components/SecurityUpdateSettingsModal.tsx create mode 100644 frontend/src/main.browserMock.test.ts create mode 100644 frontend/src/utils/connectionExport.test.ts create mode 100644 frontend/src/utils/connectionExport.ts create mode 100644 frontend/src/utils/connectionModalPresentation.test.ts create mode 100644 frontend/src/utils/connectionModalPresentation.ts create mode 100644 frontend/src/utils/secureConfigBootstrap.test.ts create mode 100644 frontend/src/utils/secureConfigBootstrap.ts create mode 100644 frontend/src/utils/securityUpdatePresentation.test.ts create mode 100644 frontend/src/utils/securityUpdatePresentation.ts create mode 100644 frontend/src/utils/securityUpdateRepairFlow.test.ts create mode 100644 frontend/src/utils/securityUpdateRepairFlow.ts create mode 100644 internal/ai/service/config_store.go create mode 100644 internal/ai/service/config_store_test.go create mode 100644 internal/app/connection_package_crypto.go create mode 100644 internal/app/connection_package_crypto_test.go create mode 100644 internal/app/connection_package_transfer.go create mode 100644 internal/app/connection_package_transfer_test.go create mode 100644 internal/app/connection_package_types.go create mode 100644 internal/app/security_update_engine.go create mode 100644 internal/app/security_update_engine_test.go create mode 100644 internal/app/security_update_rollback.go create mode 100644 internal/app/security_update_source_current_app.go create mode 100644 internal/app/security_update_state.go create mode 100644 internal/app/security_update_state_test.go create mode 100644 internal/app/security_update_types.go diff --git a/.gitignore b/.gitignore index 0d35fb4..8aca869 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ # IDE .idea/ *.iml - +.gitignore # build / release artifacts frontend/release/ **/release/ @@ -27,4 +27,5 @@ docs/需求追踪/ CLAUDE.md **/CLAUDE.md .worktrees -docs \ No newline at end of file +docs +.tmp_superpowers_edit \ No newline at end of file diff --git a/cmd/manualtestseed/main.go b/cmd/manualtestseed/main.go new file mode 100644 index 0000000..925ec98 --- /dev/null +++ b/cmd/manualtestseed/main.go @@ -0,0 +1,339 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "GoNavi-Wails/internal/ai" + aiservice "GoNavi-Wails/internal/ai/service" + "GoNavi-Wails/internal/app" + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" +) + +const ( + modeSeedSecureStorage = "seed-secure-storage" + modeSeedAIUpdate = "seed-ai-update" +) + +const ( + testConnectionID = "manualtest-postgres" + testSecureProviderID = "manualtest-secure-provider" + testPendingProviderID = "manualtest-pending-provider" + testBackupDirName = "manual-test-backups" + connectionsFileName = "connections.json" + globalProxyFileName = "global_proxy.json" + aiConfigFileName = "ai_config.json" + securityUpdateFileName = "config-security-update.json" +) + +type backupManifest struct { + CreatedAt string `json:"createdAt"` + ConfigDir string `json:"configDir"` + Files []backupManifestFile `json:"files"` +} + +type backupManifestFile struct { + RelativePath string `json:"relativePath"` + Existed bool `json:"existed"` +} + +type storedAIConfig struct { + SchemaVersion int `json:"schemaVersion,omitempty"` + Providers []ai.ProviderConfig `json:"providers"` + ActiveProvider string `json:"activeProvider"` + SafetyLevel string `json:"safetyLevel"` + ContextLevel string `json:"contextLevel"` +} + +func main() { + mode := flag.String("mode", modeSeedSecureStorage, "seed mode: seed-secure-storage | seed-ai-update") + flag.Parse() + + configDir, err := resolveConfigDir() + if err != nil { + fatalf("resolve config dir failed: %v", err) + } + + store := secretstore.NewKeyringStore() + if err := store.HealthCheck(); err != nil { + fatalf("secret store unavailable: %v", err) + } + + backupDir, err := backupConfigFiles(configDir) + if err != nil { + fatalf("backup config files failed: %v", err) + } + + switch strings.TrimSpace(*mode) { + case modeSeedSecureStorage: + if err := seedSecureStorage(configDir, store); err != nil { + fatalf("seed secure storage failed: %v", err) + } + fmt.Printf("mode=%s\nbackup=%s\nconnectionId=%s\nproviderId=%s\n", modeSeedSecureStorage, backupDir, testConnectionID, testSecureProviderID) + case modeSeedAIUpdate: + if err := seedAIUpdate(configDir, store); err != nil { + fatalf("seed ai update failed: %v", err) + } + fmt.Printf("mode=%s\nbackup=%s\npendingProviderId=%s\n", modeSeedAIUpdate, backupDir, testPendingProviderID) + default: + fatalf("unsupported mode: %s", *mode) + } +} + +func fatalf(format string, args ...any) { + fmt.Fprintf(os.Stderr, format+"\n", args...) + os.Exit(1) +} + +func resolveConfigDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, ".gonavi"), nil +} + +func backupConfigFiles(configDir string) (string, error) { + backupDir := filepath.Join(configDir, testBackupDirName, time.Now().Format("20060102-150405")) + files := []string{ + connectionsFileName, + globalProxyFileName, + aiConfigFileName, + filepath.Join("migrations", securityUpdateFileName), + } + + manifest := backupManifest{ + CreatedAt: time.Now().Format(time.RFC3339), + ConfigDir: configDir, + Files: make([]backupManifestFile, 0, len(files)), + } + + for _, relativePath := range files { + srcPath := filepath.Join(configDir, relativePath) + info, err := os.Stat(srcPath) + if err != nil { + if os.IsNotExist(err) { + manifest.Files = append(manifest.Files, backupManifestFile{ + RelativePath: relativePath, + Existed: false, + }) + continue + } + return "", err + } + if info.IsDir() { + continue + } + + dstPath := filepath.Join(backupDir, relativePath) + if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil { + return "", err + } + data, err := os.ReadFile(srcPath) + if err != nil { + return "", err + } + if err := os.WriteFile(dstPath, data, 0o644); err != nil { + return "", err + } + manifest.Files = append(manifest.Files, backupManifestFile{ + RelativePath: relativePath, + Existed: true, + }) + } + + if err := os.MkdirAll(backupDir, 0o755); err != nil { + return "", err + } + manifestData, err := json.MarshalIndent(manifest, "", " ") + if err != nil { + return "", err + } + if err := os.WriteFile(filepath.Join(backupDir, "manifest.json"), manifestData, 0o644); err != nil { + return "", err + } + return backupDir, nil +} + +func seedSecureStorage(configDir string, store secretstore.SecretStore) error { + if err := cleanupKnownTestSecrets(store); err != nil { + return err + } + + appService := app.NewAppWithSecretStore(store) + _ = appService.DeleteConnection(testConnectionID) + + if _, err := appService.SaveConnection(connection.SavedConnectionInput{ + ID: testConnectionID, + Name: "手工测试 PostgreSQL", + Config: connection.ConnectionConfig{ + ID: testConnectionID, + Type: "postgres", + Host: "127.0.0.1", + Port: 5432, + User: "postgres", + Password: "manualtest-pg-secret", + Database: "postgres", + }, + }); err != nil { + return err + } + + if _, err := appService.SaveGlobalProxy(connection.SaveGlobalProxyInput{ + Enabled: true, + Type: "http", + Host: "127.0.0.1", + Port: 7890, + User: "manual-test", + Password: "manualtest-proxy-secret", + }); err != nil { + return err + } + + storeConfig := aiservice.NewProviderConfigStore(configDir, store) + snapshot, err := storeConfig.LoadRuntime() + if err != nil { + return err + } + snapshot.Providers = filterProviders(snapshot.Providers, testSecureProviderID, testPendingProviderID) + snapshot.Providers = append(snapshot.Providers, ai.ProviderConfig{ + ID: testSecureProviderID, + Type: "custom", + Name: "手工测试 Secure Provider", + APIKey: "manualtest-ai-secret", + BaseURL: "https://api.openai.com/v1", + Model: "gpt-4o-mini", + APIFormat: "openai", + Headers: map[string]string{ + "Authorization": "Bearer manualtest-header-secret", + "X-Trace-Id": "manualtest-visible", + }, + MaxTokens: 2048, + Temperature: 0.2, + }) + if snapshot.SafetyLevel == "" { + snapshot.SafetyLevel = ai.PermissionReadOnly + } + if snapshot.ContextLevel == "" { + snapshot.ContextLevel = ai.ContextSchemaOnly + } + return storeConfig.Save(snapshot) +} + +func seedAIUpdate(configDir string, store secretstore.SecretStore) error { + if err := cleanupKnownTestSecrets(store); err != nil { + return err + } + + configPath := filepath.Join(configDir, aiConfigFileName) + cfg, err := readStoredAIConfig(configPath) + if err != nil { + return err + } + + cfg.Providers = filterProviders(cfg.Providers, testSecureProviderID, testPendingProviderID) + cfg.Providers = append(cfg.Providers, ai.ProviderConfig{ + ID: testPendingProviderID, + Type: "custom", + Name: "手工测试 待迁移 AI", + APIKey: "manualtest-ai-update-secret", + BaseURL: "https://api.openai.com/v1", + Model: "gpt-4o-mini", + APIFormat: "openai", + MaxTokens: 1024, + }) + if cfg.SchemaVersion == 0 { + cfg.SchemaVersion = 2 + } + if cfg.Providers == nil { + cfg.Providers = []ai.ProviderConfig{} + } + if err := os.MkdirAll(configDir, 0o755); err != nil { + return err + } + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return err + } + return os.WriteFile(configPath, data, 0o644) +} + +func readStoredAIConfig(configPath string) (storedAIConfig, error) { + cfg := storedAIConfig{ + Providers: []ai.ProviderConfig{}, + SafetyLevel: string(ai.PermissionReadOnly), + ContextLevel: string(ai.ContextSchemaOnly), + SchemaVersion: 2, + ActiveProvider: "", + } + + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + return cfg, nil + } + return storedAIConfig{}, err + } + if err := json.Unmarshal(data, &cfg); err != nil { + return storedAIConfig{}, err + } + if cfg.Providers == nil { + cfg.Providers = []ai.ProviderConfig{} + } + return cfg, nil +} + +func filterProviders(providers []ai.ProviderConfig, excludedIDs ...string) []ai.ProviderConfig { + excluded := make(map[string]struct{}, len(excludedIDs)) + for _, id := range excludedIDs { + excluded[strings.TrimSpace(id)] = struct{}{} + } + filtered := make([]ai.ProviderConfig, 0, len(providers)) + for _, provider := range providers { + if _, skip := excluded[strings.TrimSpace(provider.ID)]; skip { + continue + } + filtered = append(filtered, provider) + } + return filtered +} + +func cleanupKnownTestSecrets(store secretstore.SecretStore) error { + type secretRef struct { + kind string + id string + } + refs := []secretRef{ + {kind: "connection", id: testConnectionID}, + {kind: "global-proxy", id: "default"}, + {kind: "ai-provider", id: testSecureProviderID}, + {kind: "ai-provider", id: testPendingProviderID}, + } + + for _, item := range refs { + ref, err := secretstore.BuildRef(item.kind, item.id) + if err != nil { + return err + } + if err := store.Delete(ref); err != nil && !isIgnorableDeleteError(err) { + return err + } + } + return nil +} + +func isIgnorableDeleteError(err error) bool { + if err == nil || os.IsNotExist(err) { + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + return strings.Contains(message, "could not be found") || + strings.Contains(message, "not be found in the keyring") || + strings.Contains(message, "element not found") +} diff --git a/frontend/package.json.md5 b/frontend/package.json.md5 index b8be944..ad6ce0c 100755 --- a/frontend/package.json.md5 +++ b/frontend/package.json.md5 @@ -1 +1 @@ -f697e821b4acd5cf614d63d46453e8a4 \ No newline at end of file +20168ff7047e0ecea00acb73f413f7db \ No newline at end of file diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 3c6825d..76d869b 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,25 +1,51 @@ import React, { useState, useEffect, useMemo, useCallback } from 'react'; import { Layout, Button, ConfigProvider, theme, message, Modal, Spin, Slider, Progress, Switch, Input, InputNumber, Select, Segmented, Tooltip } from 'antd'; import zhCN from 'antd/locale/zh_CN'; -import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined, BgColorsOutlined, AppstoreOutlined, RobotOutlined } from '@ant-design/icons'; +import { PlusOutlined, ConsoleSqlOutlined, UploadOutlined, DownloadOutlined, CloudDownloadOutlined, BugOutlined, ToolOutlined, GlobalOutlined, InfoCircleOutlined, GithubOutlined, SkinOutlined, CheckOutlined, MinusOutlined, BorderOutlined, CloseOutlined, SettingOutlined, LinkOutlined, BgColorsOutlined, AppstoreOutlined, RobotOutlined, SafetyCertificateOutlined } from '@ant-design/icons'; import { BrowserOpenURL, Environment, EventsOn, Quit, WindowFullscreen, WindowGetPosition, WindowGetSize, WindowIsFullscreen, WindowIsMaximised, WindowMaximise, WindowMinimise, WindowSetPosition, WindowSetSize, WindowToggleMaximise, WindowUnfullscreen } from '../wailsjs/runtime'; import Sidebar from './components/Sidebar'; import TabManager from './components/TabManager'; import ConnectionModal from './components/ConnectionModal'; +import ConnectionPackagePasswordModal from './components/ConnectionPackagePasswordModal'; import DataSyncModal from './components/DataSyncModal'; import DriverManagerModal from './components/DriverManagerModal'; import LogPanel from './components/LogPanel'; import AIChatPanel from './components/AIChatPanel'; import AISettingsModal from './components/AISettingsModal'; +import SecurityUpdateBanner from './components/SecurityUpdateBanner'; +import SecurityUpdateIntroModal from './components/SecurityUpdateIntroModal'; +import SecurityUpdateProgressModal from './components/SecurityUpdateProgressModal'; +import SecurityUpdateSettingsModal from './components/SecurityUpdateSettingsModal'; import { DEFAULT_APPEARANCE, useStore } from './store'; -import { SavedConnection } from './types'; +import { SavedConnection, SecurityUpdateIssue, SecurityUpdateStatus } from './types'; import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform, resolveAppearanceValues } from './utils/appearance'; import { DATA_GRID_COLUMN_WIDTH_MODE_OPTIONS, sanitizeDataTableColumnWidthMode } from './utils/dataGridDisplay'; import { getMacNativeTitlebarPaddingLeft, getMacNativeTitlebarPaddingRight, shouldHandleMacNativeFullscreenShortcut, shouldSuppressMacNativeEscapeExit } from './utils/macWindow'; import { buildOverlayWorkbenchTheme } from './utils/overlayWorkbenchTheme'; import { getConnectionWorkbenchState } from './utils/startupReadiness'; -import { createGlobalProxyDraft, toSaveGlobalProxyInput } from './utils/globalProxyDraft'; -import { LEGACY_PERSIST_KEY, readLegacyPersistedSecrets, stripLegacyPersistedSecrets } from './utils/legacyConnectionStorage'; +import { toSaveGlobalProxyInput } from './utils/globalProxyDraft'; +import { detectConnectionImportKind, normalizeConnectionPackagePassword } from './utils/connectionExport'; +import { + bootstrapSecureConfig, + finalizeSecurityUpdateStatus, + mergeSecurityUpdateStatusWithLegacySource, + startSecurityUpdateFromBootstrap, +} from './utils/secureConfigBootstrap'; +import { + LEGACY_PERSIST_KEY, + hasLegacyMigratableSensitiveItems, + stripLegacyPersistedConnectionById, +} from './utils/legacyConnectionStorage'; +import { + getSecurityUpdateStatusMeta, + resolveSecurityUpdateEntryVisibility, +} from './utils/securityUpdatePresentation'; +import { + resolveSecurityUpdateRepairEntry, + shouldReopenSecurityUpdateDetails, + shouldRetrySecurityUpdateAfterRepairSave, + type SecurityUpdateRepairSource, +} from './utils/securityUpdateRepairFlow'; import { SHORTCUT_ACTION_META, SHORTCUT_ACTION_ORDER, @@ -38,7 +64,7 @@ import { resolveAIEdgeHandleDockStyle, resolveAIEdgeHandleStyle, } from './utils/aiEntryLayout'; -import { SetMacNativeWindowControls, SetWindowTranslucency } from '../wailsjs/go/app/App'; +import { GetSavedConnections, SetMacNativeWindowControls, SetWindowTranslucency } from '../wailsjs/go/app/App'; import './App.css'; const { Sider, Content } = Layout; @@ -48,6 +74,17 @@ const MIN_FONT_SIZE = 12; const MAX_FONT_SIZE = 20; const DEFAULT_UI_SCALE = 1.0; const DEFAULT_FONT_SIZE = 14; +const createEmptySecurityUpdateStatus = (): SecurityUpdateStatus => ({ + overallStatus: 'not_detected', + summary: { + total: 0, + updated: 0, + pending: 0, + skipped: 0, + failed: 0, + }, + issues: [], +}); const detectNavigatorPlatform = (): string => { if (typeof navigator === 'undefined') { @@ -63,16 +100,6 @@ const detectNavigatorPlatform = (): string => { }; -const toLegacySavedConnectionInput = (item: any) => ({ - id: typeof item?.id === 'string' ? item.id : '', - name: typeof item?.name === 'string' ? item.name : '', - config: (item?.config && typeof item.config === 'object') ? item.config : {}, - includeDatabases: Array.isArray(item?.includeDatabases) ? item.includeDatabases : undefined, - includeRedisDatabases: Array.isArray(item?.includeRedisDatabases) ? item.includeRedisDatabases : undefined, - iconType: typeof item?.iconType === 'string' ? item.iconType : '', - iconColor: typeof item?.iconColor === 'string' ? item.iconColor : '', -}); - const mergeSavedConnections = (current: SavedConnection[], imported: SavedConnection[]): SavedConnection[] => { const merged = new Map(); current.forEach((conn) => merged.set(conn.id, conn)); @@ -80,6 +107,24 @@ const mergeSavedConnections = (current: SavedConnection[], imported: SavedConnec return Array.from(merged.values()); }; +type ConnectionPackageDialogMode = 'import' | 'export'; + +type ConnectionPackageDialogState = { + open: boolean; + mode: ConnectionPackageDialogMode; + password: string; + error: string; + confirmLoading: boolean; +}; + +const createClosedConnectionPackageDialogState = (): ConnectionPackageDialogState => ({ + open: false, + mode: 'export', + password: '', + error: '', + confirmLoading: false, +}); + function App() { const [isModalOpen, setIsModalOpen] = useState(false); const [isSyncModalOpen, setIsSyncModalOpen] = useState(false); @@ -124,6 +169,18 @@ function App() { const [isLinuxRuntime, setIsLinuxRuntime] = useState(false); const [isStoreHydrated, setIsStoreHydrated] = useState(() => useStore.persist.hasHydrated()); const [hasLoadedSecureConfig, setHasLoadedSecureConfig] = useState(false); + const [securityUpdateStatus, setSecurityUpdateStatus] = useState(() => createEmptySecurityUpdateStatus()); + const [securityUpdateRawPayload, setSecurityUpdateRawPayload] = useState(null); + const [securityUpdateHasLegacySensitiveItems, setSecurityUpdateHasLegacySensitiveItems] = useState(false); + const [isSecurityUpdateIntroOpen, setIsSecurityUpdateIntroOpen] = useState(false); + const [isSecurityUpdateBannerDismissed, setIsSecurityUpdateBannerDismissed] = useState(false); + const [isSecurityUpdateSettingsOpen, setIsSecurityUpdateSettingsOpen] = useState(false); + const [isSecurityUpdateProgressOpen, setIsSecurityUpdateProgressOpen] = useState(false); + const [securityUpdateProgressStage, setSecurityUpdateProgressStage] = useState('正在检查已保存配置'); + const [securityUpdateRepairSource, setSecurityUpdateRepairSource] = useState(null); + const [focusedAIProviderId, setFocusedAIProviderId] = useState(undefined); + const [connectionPackageDialog, setConnectionPackageDialog] = useState(() => createClosedConnectionPackageDialogState()); + const [pendingConnectionImportPayload, setPendingConnectionImportPayload] = useState(null); const sidebarWidth = useStore(state => state.sidebarWidth); const setSidebarWidth = useStore(state => state.setSidebarWidth); const aiPanelVisible = useStore(state => state.aiPanelVisible); @@ -131,6 +188,14 @@ function App() { const setAIPanelVisible = useStore(state => state.setAIPanelVisible); const globalProxyInvalidHintShownRef = React.useRef(false); const connectionWorkbenchState = getConnectionWorkbenchState(isStoreHydrated, hasLoadedSecureConfig); + const securityUpdateStatusMeta = useMemo( + () => getSecurityUpdateStatusMeta(securityUpdateStatus), + [securityUpdateStatus], + ); + const securityUpdateEntryVisibility = useMemo( + () => resolveSecurityUpdateEntryVisibility(securityUpdateStatus), + [securityUpdateStatus], + ); // 同步 macOS 窗口透明度:opacity=1.0 且 blur=0 时关闭 NSVisualEffectView, // 避免 GPU 持续计算窗口背后的模糊合成 @@ -185,6 +250,39 @@ function App() { }; }, [isStoreHydrated]); + const normalizeSecurityUpdateStatus = useCallback((status?: Partial | null): SecurityUpdateStatus => { + const fallback = createEmptySecurityUpdateStatus(); + return { + ...fallback, + ...(status ?? {}), + summary: { + ...fallback.summary, + ...(status?.summary ?? {}), + }, + issues: Array.isArray(status?.issues) ? status.issues : [], + }; + }, []); + + const applySecurityUpdateStatus = useCallback(( + status?: Partial | null, + options?: { + openSettings?: boolean; + resetBannerDismissed?: boolean; + }, + ) => { + const nextStatus = normalizeSecurityUpdateStatus(status); + const visibility = resolveSecurityUpdateEntryVisibility(nextStatus); + setSecurityUpdateStatus(nextStatus); + setIsSecurityUpdateIntroOpen(visibility.showIntro); + if (options?.resetBannerDismissed !== false) { + setIsSecurityUpdateBannerDismissed(false); + } + if (options?.openSettings) { + setIsSecurityUpdateSettingsOpen(true); + } + return nextStatus; + }, [normalizeSecurityUpdateStatus]); + useEffect(() => { if (!isStoreHydrated) { return; @@ -192,82 +290,32 @@ function App() { let cancelled = false; const loadSecureConfig = async () => { - const backendApp = (window as any).go?.app?.App; - const persistedPayload = typeof window !== 'undefined' - ? window.localStorage.getItem(LEGACY_PERSIST_KEY) - : null; - const legacy = readLegacyPersistedSecrets(persistedPayload); - - let importedLegacyConnections = false; - let importedLegacyGlobalProxy = false; - - if (legacy.connections.length > 0) { - if (typeof backendApp?.ImportLegacyConnections === 'function') { - try { - await backendApp.ImportLegacyConnections( - legacy.connections.map(toLegacySavedConnectionInput) - ); - importedLegacyConnections = true; - } catch (err) { - console.warn('Failed to import legacy saved connections', err); - } - } else { - replaceConnections(legacy.connections); + try { + const result = await bootstrapSecureConfig({ + backend: (window as any).go?.app?.App, + replaceConnections, + replaceGlobalProxy, + }); + if (cancelled) { + return; } - } - - if (legacy.globalProxy) { - if (typeof backendApp?.ImportLegacyGlobalProxy === 'function') { - try { - await backendApp.ImportLegacyGlobalProxy(toSaveGlobalProxyInput(legacy.globalProxy)); - importedLegacyGlobalProxy = true; - } catch (err) { - console.warn('Failed to import legacy global proxy', err); - } - } else { - replaceGlobalProxy(createGlobalProxyDraft(legacy.globalProxy)); + setSecurityUpdateRawPayload(result.rawPayload); + setSecurityUpdateHasLegacySensitiveItems(result.hasLegacySensitiveItems); + applySecurityUpdateStatus(result.status); + } catch (err) { + console.warn('Failed to bootstrap secure config', err); + } finally { + if (!cancelled) { + setHasLoadedSecureConfig(true); } } - - if ((importedLegacyConnections || importedLegacyGlobalProxy) && persistedPayload && typeof window !== 'undefined') { - const sanitizedPayload = stripLegacyPersistedSecrets(persistedPayload); - if (sanitizedPayload && sanitizedPayload !== persistedPayload) { - window.localStorage.setItem(LEGACY_PERSIST_KEY, sanitizedPayload); - } - } - - if (typeof backendApp?.GetSavedConnections === 'function') { - try { - const savedConnections = await backendApp.GetSavedConnections(); - if (!cancelled && Array.isArray(savedConnections)) { - replaceConnections(savedConnections); - } - } catch (err) { - console.warn('Failed to load saved connections from backend', err); - } - } - - if (typeof backendApp?.GetGlobalProxyConfig === 'function') { - try { - const proxyResult = await backendApp.GetGlobalProxyConfig(); - if (!cancelled && proxyResult?.success && proxyResult.data) { - replaceGlobalProxy(createGlobalProxyDraft(proxyResult.data)); - } - } catch (err) { - console.warn('Failed to load global proxy from backend', err); - } - } - - if (!cancelled) { - setHasLoadedSecureConfig(true); - } }; void loadSecureConfig(); return () => { cancelled = true; }; - }, [isStoreHydrated, replaceConnections, replaceGlobalProxy]); + }, [applySecurityUpdateStatus, isStoreHydrated, replaceConnections, replaceGlobalProxy]); useEffect(() => { if (!isStoreHydrated || !hasLoadedSecureConfig) { @@ -772,6 +820,161 @@ function App() { const connections = useStore(state => state.connections); const tabs = useStore(state => state.tabs); const activeTabId = useStore(state => state.activeTabId); + const handleOpenSecurityUpdateSettings = useCallback(() => { + setIsSecurityUpdateIntroOpen(false); + setIsSecurityUpdateSettingsOpen(true); + }, []); + const runSecurityUpdateRound = useCallback(async (mode: 'start' | 'retry' | 'restart') => { + const backendApp = (window as any).go?.app?.App; + const stageText = mode === 'retry' + ? '正在校验更新结果' + : '正在更新安全存储'; + setSecurityUpdateProgressStage(stageText); + setIsSecurityUpdateProgressOpen(true); + setIsSecurityUpdateIntroOpen(false); + + try { + let nextStatus: SecurityUpdateStatus | null = null; + if (mode === 'start') { + const result = await startSecurityUpdateFromBootstrap({ + backend: backendApp, + replaceConnections, + replaceGlobalProxy, + }); + if (result.error) { + throw result.error; + } + nextStatus = normalizeSecurityUpdateStatus(result.status); + } else if (mode === 'retry') { + if (typeof backendApp?.RetrySecurityUpdateCurrentRound !== 'function') { + throw new Error('安全更新能力不可用'); + } + nextStatus = normalizeSecurityUpdateStatus(await backendApp.RetrySecurityUpdateCurrentRound({ + migrationId: securityUpdateStatus.migrationId, + })); + } else { + if (typeof backendApp?.RestartSecurityUpdate !== 'function') { + throw new Error('安全更新能力不可用'); + } + nextStatus = normalizeSecurityUpdateStatus(await backendApp.RestartSecurityUpdate({ + migrationId: securityUpdateStatus.migrationId, + sourceType: 'current_app_saved_config', + rawPayload: securityUpdateRawPayload ?? '', + options: { + allowPartial: true, + writeBackup: true, + }, + })); + } + + if (mode !== 'start') { + nextStatus = await finalizeSecurityUpdateStatus({ + backend: backendApp, + replaceConnections, + replaceGlobalProxy, + }, nextStatus); + } + + const shouldOpenSettings = nextStatus.overallStatus === 'needs_attention' || nextStatus.overallStatus === 'rolled_back'; + applySecurityUpdateStatus(nextStatus, { + openSettings: shouldOpenSettings, + }); + + if (nextStatus.overallStatus === 'completed') { + setSecurityUpdateHasLegacySensitiveItems(false); + setSecurityUpdateRawPayload(null); + setIsSecurityUpdateSettingsOpen(false); + void message.success('已保存配置已完成安全更新'); + } else if (nextStatus.overallStatus === 'needs_attention') { + void message.warning('更新尚未完成,有少量配置需要你处理'); + } else if (nextStatus.overallStatus === 'rolled_back') { + void message.warning('本次更新未完成,系统已保留当前可用配置'); + } + } catch (err: any) { + console.warn('Failed to execute security update round', err); + void message.error(err?.message || '安全更新未完成,请稍后重试'); + } finally { + setIsSecurityUpdateProgressOpen(false); + } + }, [ + applySecurityUpdateStatus, + normalizeSecurityUpdateStatus, + replaceConnections, + replaceGlobalProxy, + securityUpdateRawPayload, + securityUpdateStatus.migrationId, + ]); + const handleStartSecurityUpdate = useCallback(() => { + void runSecurityUpdateRound('start'); + }, [runSecurityUpdateRound]); + const handleRetrySecurityUpdate = useCallback(() => { + void runSecurityUpdateRound('retry'); + }, [runSecurityUpdateRound]); + const handleRestartSecurityUpdate = useCallback(() => { + void runSecurityUpdateRound('restart'); + }, [runSecurityUpdateRound]); + const handlePostponeSecurityUpdate = useCallback(async () => { + const backendApp = (window as any).go?.app?.App; + setIsSecurityUpdateIntroOpen(false); + try { + if (typeof backendApp?.DismissSecurityUpdateReminder === 'function') { + const nextStatus = mergeSecurityUpdateStatusWithLegacySource( + await backendApp.DismissSecurityUpdateReminder(), + securityUpdateRawPayload, + ); + applySecurityUpdateStatus(nextStatus); + return; + } + applySecurityUpdateStatus({ + overallStatus: 'postponed', + canStart: true, + canPostpone: true, + summary: securityUpdateStatus.summary, + issues: securityUpdateStatus.issues, + }); + } catch (err: any) { + console.warn('Failed to dismiss security update reminder', err); + void message.error(err?.message || '暂时无法延后本次安全更新'); + } + }, [ + applySecurityUpdateStatus, + securityUpdateRawPayload, + securityUpdateStatus.issues, + securityUpdateStatus.summary, + ]); + const handleSecurityUpdateIssueAction = useCallback((issue: SecurityUpdateIssue) => { + const repairEntry = resolveSecurityUpdateRepairEntry(issue, connections); + if (repairEntry.type === 'warning') { + void message.warning(repairEntry.message); + return; + } + if (repairEntry.type === 'connection') { + setIsSecurityUpdateSettingsOpen(false); + setSecurityUpdateRepairSource(repairEntry.repairSource); + setEditingConnection(repairEntry.connection); + setIsModalOpen(true); + return; + } + if (repairEntry.type === 'proxy') { + setIsSecurityUpdateSettingsOpen(false); + setSecurityUpdateRepairSource(repairEntry.repairSource); + setIsProxyModalOpen(true); + return; + } + if (repairEntry.type === 'ai') { + setIsSecurityUpdateSettingsOpen(false); + setSecurityUpdateRepairSource(repairEntry.repairSource); + setFocusedAIProviderId(repairEntry.providerId); + setIsAISettingsOpen(true); + return; + } + if (repairEntry.type === 'retry') { + void runSecurityUpdateRound('retry'); + return; + } + setSecurityUpdateRepairSource(null); + setIsSecurityUpdateSettingsOpen(true); + }, [connections, runSecurityUpdateRound]); const updateCheckInFlightRef = React.useRef(false); const updateDownloadInFlightRef = React.useRef(false); const updateUserDismissedRef = React.useRef(false); @@ -1179,37 +1382,74 @@ function App() { }); }, [activeTabId, tabs, connections, activeContext, addTab]); + const closeConnectionPackageDialog = useCallback(() => { + setConnectionPackageDialog(createClosedConnectionPackageDialogState()); + setPendingConnectionImportPayload(null); + }, []); + + const refreshConnectionsAfterImport = useCallback(async (importedViews: SavedConnection[]) => { + const backendApp = (window as any).go?.app?.App; + if (typeof backendApp?.GetSavedConnections === 'function') { + const latestConnections = await GetSavedConnections(); + if (!Array.isArray(latestConnections)) { + throw new Error('导入成功,但刷新连接列表失败:后端未返回连接列表'); + } + replaceConnections(latestConnections as SavedConnection[]); + return; + } + + const latestConnections = useStore.getState().connections; + replaceConnections(mergeSavedConnections(latestConnections, importedViews)); + }, [replaceConnections]); + + const importConnectionsPayload = useCallback(async (raw: string, password: string) => { + const backendApp = (window as any).go?.app?.App; + if (typeof backendApp?.ImportConnectionsPayload !== 'function') { + throw new Error('导入失败:当前后端未提供新版导入能力'); + } + + const importedViews = await backendApp.ImportConnectionsPayload(raw, password); + if (!Array.isArray(importedViews)) { + throw new Error('导入失败:后端未返回连接列表'); + } + await refreshConnectionsAfterImport(importedViews as SavedConnection[]); + return importedViews as SavedConnection[]; + }, [refreshConnectionsAfterImport]); + const handleImportConnections = async () => { const res = await (window as any).go.app.App.ImportConfigFile(); - if (res.success) { - try { - const imported = JSON.parse(res.data); - if (!Array.isArray(imported)) { - void message.error("文件格式错误:需要 JSON 数组"); - return; - } - - const normalizedItems = imported.map(toLegacySavedConnectionInput); - const backendApp = (window as any).go?.app?.App; - - if (typeof backendApp?.ImportLegacyConnections === 'function') { - const importedViews = await backendApp.ImportLegacyConnections(normalizedItems); - if (!Array.isArray(importedViews)) { - throw new Error('导入失败:后端未返回连接列表'); - } - replaceConnections(mergeSavedConnections(connections, importedViews)); - void message.success(`成功导入 ${importedViews.length} 个连接`); - return; - } - - const fallbackItems = normalizedItems as SavedConnection[]; - replaceConnections(mergeSavedConnections(connections, fallbackItems)); - void message.success(`成功导入 ${fallbackItems.length} 个连接`); - } catch (e: any) { - void message.error(e?.message || "解析 JSON 失败"); + if (!res.success) { + if (res.message !== "已取消") { + void message.error("导入失败: " + res.message); } - } else if (res.message !== "已取消") { - void message.error("导入失败: " + res.message); + return; + } + + const raw = typeof res.data === 'string' ? res.data : String(res.data ?? ''); + const importKind = detectConnectionImportKind(raw); + + if (importKind === 'invalid') { + void message.error('文件格式错误:仅支持 GoNavi 恢复包或历史 JSON 连接数组'); + return; + } + + if (importKind === 'encrypted-package') { + setPendingConnectionImportPayload(raw); + setConnectionPackageDialog({ + open: true, + mode: 'import', + password: '', + error: '', + confirmLoading: false, + }); + return; + } + + try { + const importedViews = await importConnectionsPayload(raw, ''); + void message.success(`成功导入 ${importedViews.length} 个连接`); + } catch (e: any) { + void message.error(e?.message || '导入失败'); } }; @@ -1218,11 +1458,64 @@ function App() { void message.warning("没有连接可导出"); return; } - const res = await (window as any).go.app.App.ExportData(connections, ['id','name','config','includeDatabases','includeRedisDatabases','iconType','iconColor'], "connections", "json"); - if (res.success) { - void message.success("导出成功"); - } else if (res.message !== "已取消") { - void message.error("导出失败: " + res.message); + + setConnectionPackageDialog({ + open: true, + mode: 'export', + password: '', + error: '', + confirmLoading: false, + }); + }; + + const handleConfirmConnectionPackageDialog = async () => { + const backendApp = (window as any).go?.app?.App; + const password = normalizeConnectionPackagePassword(connectionPackageDialog.password); + + if (!password) { + setConnectionPackageDialog((current) => ({ + ...current, + error: '恢复包密码不能为空', + })); + return; + } + + setConnectionPackageDialog((current) => ({ + ...current, + password, + error: '', + confirmLoading: true, + })); + + try { + if (connectionPackageDialog.mode === 'export') { + if (typeof backendApp?.ExportConnectionsPackage !== 'function') { + throw new Error('导出失败:当前后端未提供新版导出能力'); + } + + const res = await backendApp.ExportConnectionsPackage(password); + if (!res?.success) { + throw new Error(res?.message || '导出失败'); + } + + closeConnectionPackageDialog(); + void message.success('导出成功'); + return; + } + + if (!pendingConnectionImportPayload) { + throw new Error('导入失败:未找到待导入的恢复包内容'); + } + + const importedViews = await importConnectionsPayload(pendingConnectionImportPayload, password); + closeConnectionPackageDialog(); + void message.success(`成功导入 ${importedViews.length} 个连接`); + } catch (e: any) { + setConnectionPackageDialog((current) => ({ + ...current, + confirmLoading: false, + error: e?.message || (current.mode === 'export' ? '导出失败' : '导入失败'), + })); } }; @@ -1259,7 +1552,10 @@ function App() { key: 'proxy', title: '代理', icon: , - onClick: () => setIsProxyModalOpen(true), + onClick: () => { + setSecurityUpdateRepairSource(null); + setIsProxyModalOpen(true); + }, }, theme: { key: 'theme', @@ -1342,14 +1638,90 @@ function App() { document.removeEventListener('mouseup', handleLogResizeUp); }; + const handleCreateConnection = () => { + setSecurityUpdateRepairSource(null); + setEditingConnection(null); + setIsModalOpen(true); + }; + const handleEditConnection = (conn: SavedConnection) => { + setSecurityUpdateRepairSource(null); setEditingConnection(conn); setIsModalOpen(true); }; + const handleConnectionSaved = useCallback(async (savedConnection: SavedConnection) => { + if (!shouldRetrySecurityUpdateAfterRepairSave(securityUpdateRepairSource)) { + return; + } + + const backendApp = (window as any).go?.app?.App; + if (securityUpdateStatus.migrationId) { + if (typeof backendApp?.RetrySecurityUpdateCurrentRound !== 'function') { + return; + } + + const rawStatus = await backendApp.RetrySecurityUpdateCurrentRound({ + migrationId: securityUpdateStatus.migrationId, + }); + const nextStatus = await finalizeSecurityUpdateStatus({ + backend: backendApp, + replaceConnections, + replaceGlobalProxy, + }, normalizeSecurityUpdateStatus(rawStatus)); + + applySecurityUpdateStatus(nextStatus, { + openSettings: false, + }); + + if (nextStatus.overallStatus === 'completed') { + setSecurityUpdateHasLegacySensitiveItems(false); + setSecurityUpdateRawPayload(null); + } + return; + } + + if (!securityUpdateRawPayload || !savedConnection?.id) { + return; + } + + const nextRawPayload = stripLegacyPersistedConnectionById(securityUpdateRawPayload, savedConnection.id); + if (!nextRawPayload || nextRawPayload === securityUpdateRawPayload) { + return; + } + + window.localStorage.setItem(LEGACY_PERSIST_KEY, nextRawPayload); + + const rawStatus = typeof backendApp?.GetSecurityUpdateStatus === 'function' + ? await backendApp.GetSecurityUpdateStatus() + : securityUpdateStatus; + const nextStatus = mergeSecurityUpdateStatusWithLegacySource(rawStatus, nextRawPayload); + const nextHasLegacySensitiveItems = hasLegacyMigratableSensitiveItems(nextRawPayload); + + setSecurityUpdateRawPayload(nextRawPayload); + setSecurityUpdateHasLegacySensitiveItems(nextHasLegacySensitiveItems); + applySecurityUpdateStatus(nextStatus, { + openSettings: false, + }); + }, [ + applySecurityUpdateStatus, + normalizeSecurityUpdateStatus, + replaceConnections, + replaceGlobalProxy, + securityUpdateRawPayload, + securityUpdateRepairSource, + securityUpdateStatus, + securityUpdateStatus.migrationId, + ]); + const handleCloseModal = () => { + const reopenSecurityUpdateDetails = shouldReopenSecurityUpdateDetails(securityUpdateRepairSource); setIsModalOpen(false); setEditingConnection(null); + setSecurityUpdateRepairSource(null); + if (reopenSecurityUpdateDetails) { + setIsSecurityUpdateSettingsOpen(true); + } }; const handleOpenDriverManagerFromConnection = () => { @@ -1358,6 +1730,45 @@ function App() { setIsDriverModalOpen(true); }; + const handleCloseDriverManager = useCallback(() => { + const reopenSecurityUpdateDetails = shouldReopenSecurityUpdateDetails(securityUpdateRepairSource); + setIsDriverModalOpen(false); + setSecurityUpdateRepairSource(null); + if (reopenSecurityUpdateDetails) { + setIsSecurityUpdateSettingsOpen(true); + } + }, [securityUpdateRepairSource]); + + const handleOpenGlobalProxySettings = useCallback(() => { + setSecurityUpdateRepairSource(null); + setIsProxyModalOpen(true); + }, []); + + const handleCloseGlobalProxySettings = useCallback(() => { + const reopenSecurityUpdateDetails = shouldReopenSecurityUpdateDetails(securityUpdateRepairSource); + setIsProxyModalOpen(false); + setSecurityUpdateRepairSource(null); + if (reopenSecurityUpdateDetails) { + setIsSecurityUpdateSettingsOpen(true); + } + }, [securityUpdateRepairSource]); + + const handleOpenAISettings = useCallback((providerId?: string) => { + setSecurityUpdateRepairSource(null); + setFocusedAIProviderId(providerId); + setIsAISettingsOpen(true); + }, []); + + const handleCloseAISettings = useCallback(() => { + const reopenSecurityUpdateDetails = shouldReopenSecurityUpdateDetails(securityUpdateRepairSource); + setIsAISettingsOpen(false); + setFocusedAIProviderId(undefined); + setSecurityUpdateRepairSource(null); + if (reopenSecurityUpdateDetails) { + setIsSecurityUpdateSettingsOpen(true); + } + }, [securityUpdateRepairSource]); + const handleTitleBarWindowToggle = async () => { try { if (await WindowIsFullscreen()) { @@ -1811,7 +2222,7 @@ function App() { - @@ -1912,6 +2323,18 @@ function App() { /> + {securityUpdateEntryVisibility.showBanner && !isSecurityUpdateBannerDismissed && ( + setIsSecurityUpdateBannerDismissed(true)} + /> + )}
@@ -1928,7 +2351,9 @@ function App() { {renderAIEdgeHandle()}
)} - setAIPanelVisible(false)} onOpenSettings={() => setIsAISettingsOpen(true)} overlayTheme={overlayTheme} /> + setAIPanelVisible(false)} onOpenSettings={() => { + handleOpenAISettings(); + }} overlayTheme={overlayTheme} />
)} @@ -1946,6 +2371,7 @@ function App() { onClose={handleCloseModal} initialValues={editingConnection} onOpenDriverManager={handleOpenDriverManagerFromConnection} + onSaved={handleConnectionSaved} /> , '工具中心', '集中处理连接配置、同步、驱动和快捷键相关操作。')} @@ -2007,6 +2433,18 @@ function App() { setIsShortcutModalOpen(true); }, }, + { + key: 'security-update', + icon: , + title: '安全更新', + description: securityUpdateEntryVisibility.showDetailEntry || securityUpdateHasLegacySensitiveItems + ? `当前状态:${securityUpdateStatusMeta.label}` + : '查看已保存配置的安全更新状态。', + onClick: () => { + setIsToolsModalOpen(false); + setIsSecurityUpdateSettingsOpen(true); + }, + }, ].map((item) => ( + ) : null} + + , + , + , + ]} + > +
+ 为了让已保存的连接、代理和相关服务配置使用新的安全存储方式,本次更新需要进行一次本地配置更新。 + 更新前会自动创建本地备份;如果本次未完成,系统会保留当前可用配置,你也可以稍后继续。 +
+
+ ); +}; + +export type { SecurityUpdateIntroModalProps }; +export default SecurityUpdateIntroModal; diff --git a/frontend/src/components/SecurityUpdateProgressModal.tsx b/frontend/src/components/SecurityUpdateProgressModal.tsx new file mode 100644 index 0000000..dec305e --- /dev/null +++ b/frontend/src/components/SecurityUpdateProgressModal.tsx @@ -0,0 +1,67 @@ +import { Modal, Spin } from 'antd'; +import { SafetyCertificateOutlined } from '@ant-design/icons'; + +import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; + +interface SecurityUpdateProgressModalProps { + open: boolean; + stageText: string; + detailText?: string; + overlayTheme: OverlayWorkbenchTheme; +} + +const SecurityUpdateProgressModal = ({ + open, + stageText, + detailText, + overlayTheme, +}: SecurityUpdateProgressModalProps) => { + return ( + +
+
+ +
+
+ {stageText} +
+
+ {detailText ?? '更新过程中会保留当前可用配置,请稍候。'} +
+ +
+
+ ); +}; + +export type { SecurityUpdateProgressModalProps }; +export default SecurityUpdateProgressModal; diff --git a/frontend/src/components/SecurityUpdateSettingsModal.tsx b/frontend/src/components/SecurityUpdateSettingsModal.tsx new file mode 100644 index 0000000..fb03385 --- /dev/null +++ b/frontend/src/components/SecurityUpdateSettingsModal.tsx @@ -0,0 +1,247 @@ +import { Button, Empty, Modal, Tag } from 'antd'; +import { SafetyCertificateOutlined } from '@ant-design/icons'; + +import type { SecurityUpdateIssue, SecurityUpdateStatus } from '../types'; +import { + getSecurityUpdateIssueActionMeta, + getSecurityUpdateIssueSeverityMeta, + getSecurityUpdateItemStatusMeta, + getSecurityUpdateStatusMeta, + sortSecurityUpdateIssues, +} from '../utils/securityUpdatePresentation'; +import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; + +interface SecurityUpdateSettingsModalProps { + open: boolean; + darkMode: boolean; + overlayTheme: OverlayWorkbenchTheme; + status: SecurityUpdateStatus; + onClose: () => void; + onStart: () => void; + onRetry: () => void; + onRestart: () => void; + onIssueAction: (issue: SecurityUpdateIssue) => void; +} + +const sectionStyle = (overlayTheme: OverlayWorkbenchTheme) => ({ + borderRadius: 14, + border: overlayTheme.sectionBorder, + background: overlayTheme.sectionBg, + padding: 16, +}); + +const SecurityUpdateSettingsModal = ({ + open, + darkMode, + overlayTheme, + status, + onClose, + onStart, + onRetry, + onRestart, + onIssueAction, +}: SecurityUpdateSettingsModalProps) => { + const statusMeta = getSecurityUpdateStatusMeta(status); + const sortedIssues = sortSecurityUpdateIssues(status.issues); + const showStart = status.overallStatus === 'pending' || status.overallStatus === 'postponed'; + const showRetry = status.overallStatus === 'needs_attention'; + const showRestart = status.overallStatus === 'needs_attention' || status.overallStatus === 'rolled_back'; + + return ( + +
+ +
+
+
+ 安全更新 +
+
+ 管理已保存配置的安全更新状态与待处理项。 +
+
+ + )} + open={open} + onCancel={onClose} + footer={[ + showRetry ? ( + + ) : null, + showRestart ? ( + + ) : null, + showStart ? ( + + ) : null, + , + ]} + width={760} + styles={{ + content: { + background: overlayTheme.shellBg, + border: overlayTheme.shellBorder, + boxShadow: overlayTheme.shellShadow, + backdropFilter: overlayTheme.shellBackdropFilter, + }, + header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 }, + body: { paddingTop: 8, maxHeight: 640, overflowY: 'auto' }, + footer: { background: 'transparent', borderTop: 'none', paddingTop: 10 }, + }} + > +
+
+
+
+
+ 当前状态:{statusMeta.label} +
+
+ {statusMeta.description} +
+
+ + {statusMeta.label} + +
+
+ +
+
+ 影响范围 +
+
+ {[ + { label: '总计', value: status.summary.total }, + { label: '已更新', value: status.summary.updated }, + { label: '待处理', value: status.summary.pending }, + { label: '已跳过', value: status.summary.skipped }, + { label: '失败', value: status.summary.failed }, + ].map((item) => ( +
+
{item.label}
+
{item.value}
+
+ ))} +
+
+ +
+
+ 待处理清单 +
+ {sortedIssues.length === 0 ? ( + + ) : ( +
+ {sortedIssues.map((issue) => { + const actionMeta = getSecurityUpdateIssueActionMeta(issue); + const itemStatusMeta = getSecurityUpdateItemStatusMeta(issue.status); + const issueSeverityMeta = getSecurityUpdateIssueSeverityMeta(issue.severity); + return ( +
+
+
+
+ {issue.title || issue.message || issue.id} +
+ + 状态:{itemStatusMeta.label} + + + 级别:{issueSeverityMeta.label} + +
+
+ {issue.message || '当前项需要进一步处理后才能完成安全更新。'} +
+
+ +
+ ); + })} +
+ )} +
+ + {status.backupPath ? ( +
+
+ 最近一次结果 +
+
+ 备份位置:{status.backupPath} +
+ {status.lastError ? ( +
+ 最近错误:{status.lastError} +
+ ) : null} +
+ ) : null} +
+
+ ); +}; + +export type { SecurityUpdateSettingsModalProps }; +export default SecurityUpdateSettingsModal; diff --git a/frontend/src/main.browserMock.test.ts b/frontend/src/main.browserMock.test.ts new file mode 100644 index 0000000..9d18204 --- /dev/null +++ b/frontend/src/main.browserMock.test.ts @@ -0,0 +1,99 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('./App', () => ({ + default: () => null, +})); + +const createRootMock = vi.fn(() => ({ + render: vi.fn(), +})); + +vi.mock('react-dom/client', () => ({ + default: { + createRoot: createRootMock, + }, + createRoot: createRootMock, +})); + +const dayjsLocaleMock = vi.fn(); + +vi.mock('dayjs', () => ({ + default: Object.assign(() => null, { + locale: dayjsLocaleMock, + }), +})); + +vi.mock('dayjs/locale/zh-cn', () => ({})); + +const loaderConfigMock = vi.fn(); + +vi.mock('@monaco-editor/react', () => ({ + loader: { + config: loaderConfigMock, + }, +})); + +const defineThemeMock = vi.fn(); + +vi.mock('monaco-editor', () => ({ + editor: { + defineTheme: defineThemeMock, + }, +})); + +vi.mock('monaco-editor/esm/nls.messages.zh-cn', () => ({})); + +const importMain = async () => { + await import('./main'); + return (globalThis as typeof globalThis & { + window: { + go?: { + app?: { + App?: { + ImportConfigFile: () => Promise<{ success: boolean; message?: string }>; + ImportConnectionsPayload: (raw: string) => Promise; + ExportConnectionsPackage: () => Promise<{ success: boolean; message?: string }>; + }; + }; + }; + }; + }).window.go?.app?.App; +}; + +describe('main browser mock', () => { + beforeEach(() => { + vi.resetModules(); + vi.stubGlobal('window', {}); + vi.stubGlobal('document', { + getElementById: vi.fn(() => ({})), + }); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + vi.clearAllMocks(); + vi.resetModules(); + }); + + it('returns explicit browser-mode messages for import picker and package export', async () => { + const app = await importMain(); + + expect(app).toBeDefined(); + await expect(app!.ImportConfigFile()).resolves.toEqual({ + success: false, + message: '已取消', + }); + await expect(app!.ExportConnectionsPackage()).resolves.toEqual({ + success: false, + message: '浏览器 mock 不支持恢复包导出', + }); + }); + + it('rejects non-array payloads instead of treating them as successful imports', async () => { + const app = await importMain(); + + await expect(app!.ImportConnectionsPayload('{"version":1}')).rejects.toThrow( + '浏览器 mock 不支持恢复包导入,仅支持历史 JSON 连接数组', + ); + }); +}); diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index f1c7eab..bb2e2ac 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -121,7 +121,19 @@ if (typeof window !== 'undefined' && !(window as any).go) { CheckForUpdates: async () => ({ success: false }), OpenDownloadedUpdateDirectory: async () => ({ success: false }), InstallUpdateAndRestart: async () => ({ success: false }), - ImportConfigFile: async () => ({ success: false }), + ImportConfigFile: async () => ({ success: false, message: '已取消' }), + ImportConnectionsPayload: async (raw: string) => { + try { + const parsed = JSON.parse(raw); + if (Array.isArray(parsed)) { + return parsed.map((item) => saveMockConnection(item)); + } + } catch { + throw new Error('浏览器 mock 不支持恢复包导入,仅支持历史 JSON 连接数组'); + } + throw new Error('浏览器 mock 不支持恢复包导入,仅支持历史 JSON 连接数组'); + }, + ExportConnectionsPackage: async () => ({ success: false, message: '浏览器 mock 不支持恢复包导出' }), ExportData: async () => ({ success: false }), GetGlobalProxyConfig: async () => ({ success: true, data: cloneBrowserMockValue(mockGlobalProxy) }), SaveGlobalProxy: async (input: any) => saveMockGlobalProxy(input), diff --git a/frontend/src/store.test.ts b/frontend/src/store.test.ts index 633130a..677a287 100644 --- a/frontend/src/store.test.ts +++ b/frontend/src/store.test.ts @@ -91,4 +91,52 @@ describe('store appearance persistence', () => { expect(appearance.showDataTableVerticalBorders).toBe(true); expect(appearance.dataTableColumnWidthMode).toBe('compact'); }); + + it('does not clear persisted legacy connections during hydration migration', async () => { + storage.setItem('lite-db-storage', JSON.stringify({ + state: { + connections: [ + { + id: 'legacy-1', + name: 'Legacy', + config: { + id: 'legacy-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + password: 'secret', + }, + }, + ], + }, + version: 7, + })); + + const { useStore } = await importStore(); + + expect(useStore.getState().connections).toHaveLength(1); + expect(useStore.getState().connections[0]?.config.password).toBe('secret'); + }); + + it('keeps legacy global proxy password during hydration until explicit cleanup', async () => { + storage.setItem('lite-db-storage', JSON.stringify({ + state: { + globalProxy: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + password: 'proxy-secret', + }, + }, + version: 7, + })); + + const { useStore } = await importStore(); + + expect(useStore.getState().globalProxy.password).toBe('proxy-secret'); + expect(useStore.getState().globalProxy.hasPassword).toBe(true); + }); }); diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 23eff8f..c71863b 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -553,6 +553,34 @@ const sanitizeSavedQueries = (value: unknown): SavedQuery[] => { return result; }; +const hasLegacyConnectionSecrets = (connections: SavedConnection[]): boolean => { + return connections.some((connection) => { + const config = connection?.config && typeof connection.config === 'object' + ? connection.config as unknown as Record + : {}; + const ssh = config.ssh && typeof config.ssh === 'object' + ? config.ssh as Record + : {}; + const proxy = config.proxy && typeof config.proxy === 'object' + ? config.proxy as Record + : {}; + const httpTunnel = config.httpTunnel && typeof config.httpTunnel === 'object' + ? config.httpTunnel as Record + : {}; + + return ( + toTrimmedString(config.password) !== '' + || toTrimmedString(ssh.password) !== '' + || toTrimmedString(proxy.password) !== '' + || toTrimmedString(httpTunnel.password) !== '' + || toTrimmedString(config.mysqlReplicaPassword) !== '' + || toTrimmedString(config.mongoReplicaPassword) !== '' + || toTrimmedString(config.uri) !== '' + || toTrimmedString(config.dsn) !== '' + ); + }); +}; + const sanitizeTheme = (value: unknown): 'light' | 'dark' => (value === 'dark' ? 'dark' : 'light'); const sanitizeSqlFormatOptions = (value: unknown): { keywordCase: 'upper' | 'lower' } => { @@ -1242,7 +1270,7 @@ export const useStore = create()( migrate: (persistedState: unknown, version: number) => { const state = unwrapPersistedAppState(persistedState) as Partial; const nextState: Partial = { ...state }; - nextState.connections = []; + nextState.connections = sanitizeConnections(state.connections); if (version < 5) { nextState.connectionTags = sanitizeConnectionTags(state.connectionTags); } else { @@ -1254,7 +1282,7 @@ export const useStore = create()( nextState.uiScale = sanitizeUiScale(state.uiScale); nextState.fontSize = sanitizeFontSize(state.fontSize); nextState.startupFullscreen = sanitizeStartupFullscreen(state.startupFullscreen); - nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy, { allowPassword: false }); + nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy); nextState.sqlFormatOptions = sanitizeSqlFormatOptions(state.sqlFormatOptions); nextState.queryOptions = sanitizeQueryOptions(state.queryOptions); nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions); @@ -1281,7 +1309,7 @@ export const useStore = create()( return { ...currentState, ...state, - connections: currentState.connections, + connections: sanitizeConnections(state.connections), connectionTags: sanitizeConnectionTags(state.connectionTags), savedQueries: sanitizeSavedQueries(state.savedQueries), theme: sanitizeTheme(state.theme), @@ -1289,7 +1317,7 @@ export const useStore = create()( uiScale: sanitizeUiScale(state.uiScale), fontSize: sanitizeFontSize(state.fontSize), startupFullscreen: sanitizeStartupFullscreen(state.startupFullscreen), - globalProxy: sanitizeGlobalProxy(state.globalProxy, { allowPassword: false }), + globalProxy: sanitizeGlobalProxy(state.globalProxy), tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference), tableColumnOrders: sanitizeTableColumnOrders(state.tableColumnOrders), enableColumnOrderMemory: state.enableColumnOrderMemory !== false, @@ -1309,30 +1337,39 @@ export const useStore = create()( aiChatSessions: [], }; }, - partialize: (state) => ({ - connectionTags: state.connectionTags, - savedQueries: state.savedQueries, - theme: state.theme, - appearance: state.appearance, - uiScale: state.uiScale, - fontSize: state.fontSize, - startupFullscreen: state.startupFullscreen, - globalProxy: toPersistedGlobalProxy(state.globalProxy), - sqlFormatOptions: state.sqlFormatOptions, - queryOptions: state.queryOptions, - shortcutOptions: state.shortcutOptions, - tableAccessCount: state.tableAccessCount, - tableSortPreference: state.tableSortPreference, - tableColumnOrders: state.tableColumnOrders, - enableColumnOrderMemory: state.enableColumnOrderMemory, - tableHiddenColumns: state.tableHiddenColumns, - enableHiddenColumnMemory: state.enableHiddenColumnMemory, - windowBounds: state.windowBounds, - windowState: state.windowState, - sidebarWidth: state.sidebarWidth, + partialize: (state) => { + const partialState: Partial = { + connectionTags: state.connectionTags, + savedQueries: state.savedQueries, + theme: state.theme, + appearance: state.appearance, + uiScale: state.uiScale, + fontSize: state.fontSize, + startupFullscreen: state.startupFullscreen, + globalProxy: toTrimmedString(state.globalProxy.password) !== '' + ? { ...state.globalProxy } + : toPersistedGlobalProxy(state.globalProxy), + sqlFormatOptions: state.sqlFormatOptions, + queryOptions: state.queryOptions, + shortcutOptions: state.shortcutOptions, + tableAccessCount: state.tableAccessCount, + tableSortPreference: state.tableSortPreference, + tableColumnOrders: state.tableColumnOrders, + enableColumnOrderMemory: state.enableColumnOrderMemory, + tableHiddenColumns: state.tableHiddenColumns, + enableHiddenColumnMemory: state.enableHiddenColumnMemory, + windowBounds: state.windowBounds, + windowState: state.windowState, + sidebarWidth: state.sidebarWidth, + }; + + if (hasLegacyConnectionSecrets(state.connections)) { + partialState.connections = state.connections; + } // AI 会话数据已迁移到后端文件持久化(~/.gonavi/sessions/),不再写入 localStorage - }), // Don't persist logs + return partialState as AppState; + }, // Don't persist logs } ) ); diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 40e1e9e..d753408 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -262,4 +262,70 @@ export interface AISafetyResult { warningMessage?: string; } +export type SecurityUpdateOverallStatus = + | 'not_detected' + | 'pending' + | 'postponed' + | 'in_progress' + | 'needs_attention' + | 'completed' + | 'rolled_back'; + +export type SecurityUpdateIssueScope = 'connection' | 'global_proxy' | 'ai_provider' | 'system'; +export type SecurityUpdateIssueSeverity = 'high' | 'medium' | 'low'; +export type SecurityUpdateItemStatus = 'pending' | 'updated' | 'needs_attention' | 'skipped' | 'failed'; +export type SecurityUpdateIssueReasonCode = + | 'migration_required' + | 'secret_missing' + | 'field_invalid' + | 'write_conflict' + | 'validation_failed' + | 'environment_blocked'; +export type SecurityUpdateIssueAction = + | 'open_connection' + | 'open_proxy_settings' + | 'open_ai_settings' + | 'retry_update' + | 'view_details'; + +export interface SecurityUpdateSummary { + total: number; + updated: number; + pending: number; + skipped: number; + failed: number; +} + +export interface SecurityUpdateIssue { + id: string; + scope?: SecurityUpdateIssueScope; + refId?: string; + title?: string; + severity?: SecurityUpdateIssueSeverity; + status?: SecurityUpdateItemStatus; + reasonCode?: SecurityUpdateIssueReasonCode; + action?: SecurityUpdateIssueAction; + message?: string; +} + +export interface SecurityUpdateStatus { + schemaVersion?: number; + migrationId?: string; + overallStatus: SecurityUpdateOverallStatus; + sourceType?: 'current_app_saved_config'; + reminderVisible?: boolean; + canStart?: boolean; + canPostpone?: boolean; + canRetry?: boolean; + backupAvailable?: boolean; + backupPath?: string; + startedAt?: string; + updatedAt?: string; + completedAt?: string; + postponedAt?: string; + summary: SecurityUpdateSummary; + issues: SecurityUpdateIssue[]; + lastError?: string; +} + diff --git a/frontend/src/utils/connectionExport.test.ts b/frontend/src/utils/connectionExport.test.ts new file mode 100644 index 0000000..5b1c53e --- /dev/null +++ b/frontend/src/utils/connectionExport.test.ts @@ -0,0 +1,60 @@ +import { describe, expect, it } from 'vitest'; + +import { + detectConnectionImportKind, + normalizeConnectionPackagePassword, +} from './connectionExport'; + +describe('connectionExport', () => { + it('detects encrypted packages by gonavi envelope kind', () => { + expect(detectConnectionImportKind(JSON.stringify({ + schemaVersion: 1, + kind: 'gonavi_connection_package', + cipher: 'AES-256-GCM', + kdf: { + name: 'Argon2id', + memoryKiB: 65536, + timeCost: 3, + parallelism: 4, + salt: 'c2FsdA==', + }, + nonce: 'bm9uY2Utbm9uY2U=', + payload: 'encrypted-data', + }))).toBe('encrypted-package'); + }); + + it('detects legacy imports from historical json arrays', () => { + expect(detectConnectionImportKind(JSON.stringify([ + { + id: 'conn-1', + name: 'Primary', + config: { + type: 'postgres', + }, + }, + ]))).toBe('legacy-json'); + }); + + it('returns invalid for malformed or unsupported content', () => { + expect(detectConnectionImportKind('{not-json}')).toBe('invalid'); + expect(detectConnectionImportKind(JSON.stringify({ + kind: 'gonavi_connection_package', + payload: 'encrypted-data', + }))).toBe('invalid'); + expect(detectConnectionImportKind(JSON.stringify([ + { + foo: 'bar', + }, + ]))).toBe('invalid'); + expect(detectConnectionImportKind(JSON.stringify({ + kind: 'other_package', + payload: 'encrypted-data', + }))).toBe('invalid'); + expect(detectConnectionImportKind('null')).toBe('invalid'); + }); + + it('trims package passwords before use', () => { + expect(normalizeConnectionPackagePassword(' secret-pass ')).toBe('secret-pass'); + expect(normalizeConnectionPackagePassword('\n\t \t')).toBe(''); + }); +}); diff --git a/frontend/src/utils/connectionExport.ts b/frontend/src/utils/connectionExport.ts new file mode 100644 index 0000000..9cec933 --- /dev/null +++ b/frontend/src/utils/connectionExport.ts @@ -0,0 +1,78 @@ +import type { ConnectionConfig, SavedConnection } from '../types'; + +export type ConnectionImportKind = 'encrypted-package' | 'legacy-json' | 'invalid'; + +type JsonObject = Record; + +const CONNECTION_PACKAGE_KIND = 'gonavi_connection_package'; + +const isJsonObject = (value: unknown): value is JsonObject => ( + typeof value === 'object' && value !== null && !Array.isArray(value) +); + +const isConnectionPackageKDF = (value: unknown): value is JsonObject => ( + isJsonObject(value) + && typeof value.name === 'string' + && typeof value.memoryKiB === 'number' + && typeof value.timeCost === 'number' + && typeof value.parallelism === 'number' + && typeof value.salt === 'string' +); + +const isConnectionPackageEnvelope = (value: unknown): value is JsonObject => ( + isJsonObject(value) + && typeof value.schemaVersion === 'number' + && value.kind === CONNECTION_PACKAGE_KIND + && typeof value.cipher === 'string' + && isConnectionPackageKDF(value.kdf) + && typeof value.nonce === 'string' + && typeof value.payload === 'string' +); + +const isLegacyConnectionConfig = (value: unknown): value is JsonObject => ( + isJsonObject(value) + && typeof value.type === 'string' +); + +const isLegacyConnectionItem = (value: unknown): value is JsonObject => ( + isJsonObject(value) + && typeof value.id === 'string' + && typeof value.name === 'string' + && isLegacyConnectionConfig(value.config) +); + +const parseConnectionImportRaw = (raw: unknown): unknown => { + if (typeof raw !== 'string') { + return raw; + } + + try { + return JSON.parse(raw); + } catch { + return undefined; + } +}; + +export const detectConnectionImportKind = (raw: unknown): ConnectionImportKind => { + const parsed = parseConnectionImportRaw(raw); + + if (Array.isArray(parsed) && parsed.every((item) => isLegacyConnectionItem(item))) { + return 'legacy-json'; + } + + if (isConnectionPackageEnvelope(parsed)) { + return 'encrypted-package'; + } + + return 'invalid'; +}; + +export const normalizeConnectionPackagePassword = (value: string): string => value.trim(); + +const legacyExportRemovedError = (): never => { + throw new Error('Legacy connection JSON export has been removed. Use the recovery package flow instead.'); +}; + +export const sanitizeConnectionConfigForExport = (_config: ConnectionConfig): never => legacyExportRemovedError(); + +export const buildExportableConnections = (_connections: SavedConnection[]): never => legacyExportRemovedError(); diff --git a/frontend/src/utils/connectionModalPresentation.test.ts b/frontend/src/utils/connectionModalPresentation.test.ts new file mode 100644 index 0000000..cca5dd9 --- /dev/null +++ b/frontend/src/utils/connectionModalPresentation.test.ts @@ -0,0 +1,57 @@ +import { describe, expect, it } from 'vitest'; + +import { + getStoredSecretPlaceholder, + normalizeConnectionSecretErrorMessage, + resolveConnectionTestFailureFeedback, +} from './connectionModalPresentation'; + +describe('connectionModalPresentation', () => { + it('shows an explicit stored-secret placeholder instead of an empty-looking password field', () => { + expect(getStoredSecretPlaceholder({ + hasStoredSecret: true, + emptyPlaceholder: '密码', + retainedLabel: '已保存密码', + })).toBe('••••••(留空表示继续沿用已保存密码)'); + }); + + it('keeps the original placeholder when no stored secret exists', () => { + expect(getStoredSecretPlaceholder({ + hasStoredSecret: false, + emptyPlaceholder: '密码', + retainedLabel: '已保存密码', + })).toBe('密码'); + }); + + it('maps missing saved-connection errors to a secret-specific hint', () => { + expect(normalizeConnectionSecretErrorMessage('saved connection not found: conn-1')).toBe( + '未找到当前连接对应的已保存密文,请重新填写密码并保存后再试', + ); + }); + + it('preserves existing user-facing messages', () => { + expect(normalizeConnectionSecretErrorMessage('连接测试超时')).toBe('连接测试超时'); + }); + + it('shows a toast-worthy failure message for saved-secret lookup errors during connection tests', () => { + expect(resolveConnectionTestFailureFeedback({ + kind: 'runtime', + reason: 'saved connection not found: conn-1', + fallback: '连接失败', + })).toEqual({ + message: '测试失败: 未找到当前连接对应的已保存密文,请重新填写密码并保存后再试', + shouldToast: true, + }); + }); + + it('keeps required-field validation failures inline without an extra toast', () => { + expect(resolveConnectionTestFailureFeedback({ + kind: 'validation', + reason: '', + fallback: '连接失败', + })).toEqual({ + message: '测试失败: 请先完善必填项后再测试连接', + shouldToast: false, + }); + }); +}); diff --git a/frontend/src/utils/connectionModalPresentation.ts b/frontend/src/utils/connectionModalPresentation.ts new file mode 100644 index 0000000..fbd841a --- /dev/null +++ b/frontend/src/utils/connectionModalPresentation.ts @@ -0,0 +1,78 @@ +type StoredSecretPlaceholderOptions = { + hasStoredSecret?: boolean; + emptyPlaceholder: string; + retainedLabel: string; +}; + +type ConnectionTestFailureKind = + | 'validation' + | 'runtime' + | 'driver_unavailable' + | 'secret_blocked'; + +type ConnectionTestFailureFeedback = { + message: string; + shouldToast: boolean; +}; + +const normalizeText = (value: unknown, fallback = ''): string => { + const text = String(value ?? '').trim(); + if (!text || text === 'undefined' || text === 'null') { + return fallback; + } + return text; +}; + +export const getStoredSecretPlaceholder = ({ + hasStoredSecret, + emptyPlaceholder, + retainedLabel, +}: StoredSecretPlaceholderOptions): string => ( + hasStoredSecret + ? `••••••(留空表示继续沿用${retainedLabel})` + : emptyPlaceholder +); + +export const normalizeConnectionSecretErrorMessage = ( + value: unknown, + fallback = '', +): string => { + const text = normalizeText(value, fallback); + const lower = text.toLowerCase(); + + if (lower.includes('saved connection not found:')) { + return '未找到当前连接对应的已保存密文,请重新填写密码并保存后再试'; + } + if (lower.includes('secret store unavailable')) { + return '系统密文存储当前不可用,请检查系统钥匙串或凭据管理器后再试'; + } + + return text; +}; + +export const resolveConnectionTestFailureFeedback = ({ + kind, + reason, + fallback, +}: { + kind: ConnectionTestFailureKind; + reason: unknown; + fallback: string; +}): ConnectionTestFailureFeedback => { + if (kind === 'validation') { + return { + message: '测试失败: 请先完善必填项后再测试连接', + shouldToast: false, + }; + } + + return { + message: `测试失败: ${normalizeConnectionSecretErrorMessage(reason, fallback)}`, + shouldToast: true, + }; +}; + +export type { + ConnectionTestFailureFeedback, + ConnectionTestFailureKind, +}; diff --git a/frontend/src/utils/legacyConnectionStorage.test.ts b/frontend/src/utils/legacyConnectionStorage.test.ts index 7f8a46b..d1d8bb8 100644 --- a/frontend/src/utils/legacyConnectionStorage.test.ts +++ b/frontend/src/utils/legacyConnectionStorage.test.ts @@ -1,6 +1,11 @@ import { describe, expect, it } from 'vitest'; -import { readLegacyPersistedSecrets, stripLegacyPersistedSecrets } from './legacyConnectionStorage'; +import { + hasLegacyMigratableSensitiveItems, + readLegacyPersistedSecrets, + stripLegacyPersistedConnectionById, + stripLegacyPersistedSecrets, +} from './legacyConnectionStorage'; describe('legacy connection storage', () => { it('extracts legacy saved connections and global proxy password from lite-db-storage', () => { @@ -37,7 +42,7 @@ describe('legacy connection storage', () => { expect(result.globalProxy?.password).toBe('proxy-secret'); }); - it('strips persisted connection secrets but keeps secretless proxy metadata', () => { + it('clears legacy connection and proxy source data after cleanup', () => { const payload = JSON.stringify({ state: { connections: [ @@ -69,7 +74,110 @@ describe('legacy connection storage', () => { const parsed = JSON.parse(sanitized); expect(parsed.state.connections).toEqual([]); - expect(parsed.state.globalProxy.password).toBeUndefined(); - expect(parsed.state.globalProxy.hasPassword).toBe(true); + expect(parsed.state.globalProxy).toBeUndefined(); + }); + + it('treats a meaningful legacy global proxy as migratable even when it has no password', () => { + const payload = JSON.stringify({ + state: { + globalProxy: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + }, + }, + }); + + expect(hasLegacyMigratableSensitiveItems(payload)).toBe(true); + }); + + it('detects migratable sensitive items before cleanup and clears the signal after cleanup', () => { + const payload = JSON.stringify({ + state: { + connections: [ + { + id: 'conn-1', + name: 'Primary', + config: { + id: 'conn-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + password: 'secret', + }, + }, + ], + globalProxy: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + password: 'proxy-secret', + }, + }, + }); + + expect(hasLegacyMigratableSensitiveItems(payload)).toBe(true); + expect(hasLegacyMigratableSensitiveItems(stripLegacyPersistedSecrets(payload))).toBe(false); + }); + + it('removes only the repaired legacy connection while preserving other source data', () => { + const payload = JSON.stringify({ + state: { + connections: [ + { + id: 'conn-1', + name: 'Primary', + config: { + id: 'conn-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + password: 'secret', + }, + }, + { + id: 'conn-2', + name: 'Replica', + config: { + id: 'conn-2', + type: 'mysql', + host: 'replica.local', + port: 3306, + user: 'root', + password: 'replica-secret', + }, + }, + ], + globalProxy: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + password: 'proxy-secret', + }, + }, + }); + + const sanitized = stripLegacyPersistedConnectionById(payload, 'conn-1'); + const parsed = JSON.parse(sanitized); + + expect(parsed.state.connections).toEqual([ + expect.objectContaining({ + id: 'conn-2', + config: expect.objectContaining({ + password: 'replica-secret', + }), + }), + ]); + expect(parsed.state.globalProxy).toEqual(expect.objectContaining({ + password: 'proxy-secret', + })); }); }); diff --git a/frontend/src/utils/legacyConnectionStorage.ts b/frontend/src/utils/legacyConnectionStorage.ts index cdbdd6c..159d0fe 100644 --- a/frontend/src/utils/legacyConnectionStorage.ts +++ b/frontend/src/utils/legacyConnectionStorage.ts @@ -79,6 +79,11 @@ export function readLegacyPersistedSecrets(payload: string | null | undefined): }; } +export function hasLegacyMigratableSensitiveItems(payload: string | null | undefined): boolean { + const legacy = readLegacyPersistedSecrets(payload); + return legacy.connections.length > 0 || legacy.globalProxy !== null; +} + export function stripLegacyPersistedSecrets(payload: string | null | undefined): string { if (!payload || typeof payload !== 'string') { return ''; @@ -96,15 +101,42 @@ export function stripLegacyPersistedSecrets(payload: string | null | undefined): : parsed; state.connections = []; - if (state.globalProxy && typeof state.globalProxy === 'object') { - const proxy = { ...(state.globalProxy as Record) }; - const password = toTrimmedString(proxy.password); - delete proxy.password; - if (password !== '') { - proxy.hasPassword = true; - } - state.globalProxy = proxy; + if (state.globalProxy !== undefined) { + delete state.globalProxy; } return JSON.stringify(parsed); } + +export function stripLegacyPersistedConnectionById( + payload: string | null | undefined, + connectionId: string, +): string { + if (!payload || typeof payload !== 'string') { + return ''; + } + + let parsed: Record; + try { + parsed = JSON.parse(payload) as Record; + } catch { + return payload; + } + + const state = parsed.state && typeof parsed.state === 'object' + ? parsed.state as Record + : parsed; + const targetId = toTrimmedString(connectionId); + if (!targetId || !Array.isArray(state.connections)) { + return payload; + } + + state.connections = state.connections.filter((item) => { + if (!item || typeof item !== 'object') { + return true; + } + return toTrimmedString((item as { id?: unknown }).id) !== targetId; + }); + + return JSON.stringify(parsed); +} diff --git a/frontend/src/utils/secureConfigBootstrap.test.ts b/frontend/src/utils/secureConfigBootstrap.test.ts new file mode 100644 index 0000000..07e13e6 --- /dev/null +++ b/frontend/src/utils/secureConfigBootstrap.test.ts @@ -0,0 +1,466 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { LEGACY_PERSIST_KEY } from './legacyConnectionStorage'; +import { + bootstrapSecureConfig, + finalizeSecurityUpdateStatus, + mergeSecurityUpdateStatusWithLegacySource, + startSecurityUpdateFromBootstrap, +} from './secureConfigBootstrap'; +import { stripLegacyPersistedConnectionById } from './legacyConnectionStorage'; + +const legacyPayload = JSON.stringify({ + state: { + connections: [ + { + id: 'legacy-1', + name: 'Legacy', + config: { + id: 'legacy-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + password: 'secret', + }, + }, + ], + globalProxy: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + password: 'proxy-secret', + }, + }, +}); + +const createMemoryStorage = () => { + const data = new Map(); + return { + getItem: (key: string) => data.get(key) ?? null, + setItem: (key: string, value: string) => { + data.set(key, value); + }, + removeItem: (key: string) => { + data.delete(key); + }, + }; +}; + +const createBaseArgs = (storage = createMemoryStorage()) => { + const replaceConnections = vi.fn(); + const replaceGlobalProxy = vi.fn(); + + storage.setItem(LEGACY_PERSIST_KEY, legacyPayload); + + return { + storage, + replaceConnections, + replaceGlobalProxy, + }; +}; + +describe('secureConfigBootstrap', () => { + it('builds legacy pending summary and issue list before the first round starts', async () => { + const args = createBaseArgs(); + + const result = await bootstrapSecureConfig({ + ...args, + backend: { + GetSecurityUpdateStatus: vi.fn().mockResolvedValue({ + overallStatus: 'not_detected', + summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }), + }, + }); + + expect(result.status.overallStatus).toBe('pending'); + expect(result.status.summary).toEqual({ + total: 2, + updated: 0, + pending: 2, + skipped: 0, + failed: 0, + }); + expect(result.status.issues).toEqual(expect.arrayContaining([ + expect.objectContaining({ + scope: 'connection', + refId: 'legacy-1', + action: 'open_connection', + }), + expect.objectContaining({ + scope: 'global_proxy', + action: 'open_proxy_settings', + }), + ])); + }); + + it('shows intro when legacy sensitive items exist and backend status is pending', async () => { + const args = createBaseArgs(); + + const result = await bootstrapSecureConfig({ + ...args, + backend: { + GetSecurityUpdateStatus: vi.fn().mockResolvedValue({ + overallStatus: 'pending', + summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }), + }, + }); + + expect(result.status.overallStatus).toBe('pending'); + expect(result.shouldShowIntro).toBe(true); + expect(result.shouldShowBanner).toBe(false); + expect(args.replaceConnections).toHaveBeenCalledWith( + expect.arrayContaining([expect.objectContaining({ id: 'legacy-1' })]), + ); + }); + + it('keeps banner flow without intro when backend status is postponed', async () => { + const args = createBaseArgs(); + + const result = await bootstrapSecureConfig({ + ...args, + backend: { + GetSecurityUpdateStatus: vi.fn().mockResolvedValue({ + overallStatus: 'postponed', + summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }), + }, + }); + + expect(result.shouldShowIntro).toBe(false); + expect(result.shouldShowBanner).toBe(true); + }); + + it('keeps legacy pending summary and issues when a pre-start round is postponed', async () => { + const args = createBaseArgs(); + + const result = await bootstrapSecureConfig({ + ...args, + backend: { + GetSecurityUpdateStatus: vi.fn().mockResolvedValue({ + overallStatus: 'postponed', + summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }), + }, + }); + + expect(result.status.overallStatus).toBe('postponed'); + expect(result.status.summary.total).toBe(2); + expect(result.status.summary.pending).toBe(2); + expect(result.status.issues).toEqual(expect.arrayContaining([ + expect.objectContaining({ scope: 'connection', refId: 'legacy-1' }), + expect.objectContaining({ scope: 'global_proxy' }), + ])); + }); + + it('merges backend pending issues with legacy source items before the first round starts', async () => { + const args = createBaseArgs(); + + const result = await bootstrapSecureConfig({ + ...args, + backend: { + GetSecurityUpdateStatus: vi.fn().mockResolvedValue({ + overallStatus: 'pending', + summary: { total: 1, updated: 0, pending: 1, skipped: 0, failed: 0 }, + issues: [ + { + id: 'ai-provider-openai-main', + scope: 'ai_provider', + refId: 'openai-main', + title: 'OpenAI', + severity: 'medium', + status: 'pending', + reasonCode: 'secret_missing', + action: 'open_ai_settings', + message: 'AI 提供商配置仍需完成安全更新', + }, + ], + }), + }, + }); + + expect(result.status.overallStatus).toBe('pending'); + expect(result.status.summary).toEqual({ + total: 3, + updated: 0, + pending: 3, + skipped: 0, + failed: 0, + }); + expect(result.status.issues).toEqual(expect.arrayContaining([ + expect.objectContaining({ scope: 'ai_provider', refId: 'openai-main' }), + expect.objectContaining({ scope: 'connection', refId: 'legacy-1' }), + expect.objectContaining({ scope: 'global_proxy' }), + ])); + }); + + it('keeps banner flow without intro when backend status is rolled_back', async () => { + const args = createBaseArgs(); + + const result = await bootstrapSecureConfig({ + ...args, + backend: { + GetSecurityUpdateStatus: vi.fn().mockResolvedValue({ + overallStatus: 'rolled_back', + summary: { total: 1, updated: 0, pending: 0, skipped: 0, failed: 1 }, + issues: [], + }), + }, + }); + + expect(result.shouldShowIntro).toBe(false); + expect(result.shouldShowBanner).toBe(true); + }); + + it('loads backend secure config directly when no legacy source exists', async () => { + const storage = createMemoryStorage(); + const replaceConnections = vi.fn(); + const replaceGlobalProxy = vi.fn(); + + const result = await bootstrapSecureConfig({ + storage, + replaceConnections, + replaceGlobalProxy, + backend: { + GetSecurityUpdateStatus: vi.fn().mockResolvedValue({ + overallStatus: 'not_detected', + summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }), + GetSavedConnections: vi.fn().mockResolvedValue([ + { + id: 'secure-1', + name: 'Secure', + config: { + id: 'secure-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + }, + }, + ]), + }, + }); + + expect(result.status.overallStatus).toBe('not_detected'); + expect(replaceConnections).toHaveBeenCalledWith( + expect.arrayContaining([expect.objectContaining({ id: 'secure-1' })]), + ); + }); + + it('shows intro when backend status is pending even without legacy local source', async () => { + const storage = createMemoryStorage(); + const replaceConnections = vi.fn(); + const replaceGlobalProxy = vi.fn(); + + const result = await bootstrapSecureConfig({ + storage, + replaceConnections, + replaceGlobalProxy, + backend: { + GetSecurityUpdateStatus: vi.fn().mockResolvedValue({ + overallStatus: 'pending', + summary: { total: 1, updated: 0, pending: 1, skipped: 0, failed: 0 }, + issues: [], + }), + }, + }); + + expect(result.status.overallStatus).toBe('pending'); + expect(result.shouldShowIntro).toBe(true); + expect(result.shouldShowBanner).toBe(false); + }); + + it('falls back to legacy visible config when StartSecurityUpdate throws', async () => { + const args = createBaseArgs(); + + const result = await startSecurityUpdateFromBootstrap({ + ...args, + backend: { + StartSecurityUpdate: vi.fn().mockRejectedValue(new Error('boom')), + }, + }); + + expect(result.status).toBeNull(); + expect(result.error?.message).toContain('boom'); + expect(args.replaceConnections).toHaveBeenCalledWith( + expect.arrayContaining([expect.objectContaining({ id: 'legacy-1' })]), + ); + expect(args.storage.getItem(LEGACY_PERSIST_KEY)).toContain('"password":"secret"'); + }); + + it('starts security update even when rawPayload is empty but backend supports AI-only update', async () => { + const storage = createMemoryStorage(); + const replaceConnections = vi.fn(); + const replaceGlobalProxy = vi.fn(); + const StartSecurityUpdate = vi.fn().mockResolvedValue({ + overallStatus: 'completed', + summary: { total: 1, updated: 1, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }); + + const result = await startSecurityUpdateFromBootstrap({ + storage, + replaceConnections, + replaceGlobalProxy, + backend: { + StartSecurityUpdate, + }, + }); + + expect(result.error).toBeNull(); + expect(result.status?.overallStatus).toBe('completed'); + expect(StartSecurityUpdate).toHaveBeenCalledWith({ + sourceType: 'current_app_saved_config', + rawPayload: '', + options: { + allowPartial: true, + writeBackup: true, + }, + }); + }); + + it('keeps source-side secrets when update ends in needs_attention', async () => { + const args = createBaseArgs(); + + const result = await startSecurityUpdateFromBootstrap({ + ...args, + backend: { + StartSecurityUpdate: vi.fn().mockResolvedValue({ + overallStatus: 'needs_attention', + summary: { total: 3, updated: 2, pending: 1, skipped: 0, failed: 0 }, + issues: [{ id: 'ai-1' }], + }), + GetSavedConnections: vi.fn().mockResolvedValue([]), + }, + }); + + expect(result.status?.overallStatus).toBe('needs_attention'); + expect(args.storage.getItem(LEGACY_PERSIST_KEY)).toContain('"password":"secret"'); + }); + + it('cleans source-side secrets only after completed update and backend refresh', async () => { + const args = createBaseArgs(); + + const result = await startSecurityUpdateFromBootstrap({ + ...args, + backend: { + StartSecurityUpdate: vi.fn().mockResolvedValue({ + overallStatus: 'completed', + summary: { total: 3, updated: 3, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }), + GetSavedConnections: vi.fn().mockResolvedValue([ + { + id: 'secure-1', + name: 'Secure', + config: { + id: 'secure-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + }, + hasPrimaryPassword: true, + }, + ]), + GetGlobalProxyConfig: vi.fn().mockResolvedValue({ + success: true, + data: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + hasPassword: true, + }, + }), + }, + }); + + expect(result.status?.overallStatus).toBe('completed'); + expect(args.storage.getItem(LEGACY_PERSIST_KEY)).not.toContain('"password":"secret"'); + expect(args.replaceConnections).toHaveBeenLastCalledWith( + expect.arrayContaining([expect.objectContaining({ id: 'secure-1' })]), + ); + }); + + it('refreshes backend config and strips source-side secrets when a later round finishes as completed', async () => { + const args = createBaseArgs(); + + const status = await finalizeSecurityUpdateStatus({ + ...args, + backend: { + GetSavedConnections: vi.fn().mockResolvedValue([ + { + id: 'secure-1', + name: 'Secure', + config: { + id: 'secure-1', + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + }, + hasPrimaryPassword: true, + }, + ]), + GetGlobalProxyConfig: vi.fn().mockResolvedValue({ + success: true, + data: { + enabled: true, + type: 'http', + host: '127.0.0.1', + port: 8080, + user: 'ops', + hasPassword: true, + }, + }), + }, + }, { + overallStatus: 'completed', + summary: { total: 3, updated: 3, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }); + + expect(status.overallStatus).toBe('completed'); + expect(args.storage.getItem(LEGACY_PERSIST_KEY)).not.toContain('"password":"secret"'); + expect(args.replaceConnections).toHaveBeenLastCalledWith( + expect.arrayContaining([expect.objectContaining({ id: 'secure-1' })]), + ); + }); + + it('reduces legacy pending issues after a single connection is repaired before the first round starts', () => { + const nextPayload = stripLegacyPersistedConnectionById(legacyPayload, 'legacy-1'); + + const status = mergeSecurityUpdateStatusWithLegacySource({ + overallStatus: 'not_detected', + summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }, nextPayload); + + expect(status.overallStatus).toBe('pending'); + expect(status.summary).toEqual({ + total: 1, + updated: 0, + pending: 1, + skipped: 0, + failed: 0, + }); + expect(status.issues).toEqual([ + expect.objectContaining({ + scope: 'global_proxy', + action: 'open_proxy_settings', + }), + ]); + }); +}); diff --git a/frontend/src/utils/secureConfigBootstrap.ts b/frontend/src/utils/secureConfigBootstrap.ts new file mode 100644 index 0000000..666178e --- /dev/null +++ b/frontend/src/utils/secureConfigBootstrap.ts @@ -0,0 +1,351 @@ +import { + GlobalProxyConfig, + SavedConnection, + SecurityUpdateIssue, + SecurityUpdateStatus, + SecurityUpdateSummary, +} from '../types'; +import { createGlobalProxyDraft } from './globalProxyDraft'; +import { + LEGACY_PERSIST_KEY, + hasLegacyMigratableSensitiveItems, + readLegacyPersistedSecrets, + stripLegacyPersistedSecrets, +} from './legacyConnectionStorage'; + +type StorageLike = Pick; + +type BackendGlobalProxyResult = { + success?: boolean; + data?: Partial; +}; + +type SecurityUpdateBackend = { + GetSecurityUpdateStatus?: () => Promise | undefined>; + StartSecurityUpdate?: (request: { + sourceType: 'current_app_saved_config'; + rawPayload: string; + options?: { + allowPartial?: boolean; + writeBackup?: boolean; + }; + }) => Promise | undefined>; + GetSavedConnections?: () => Promise; + GetGlobalProxyConfig?: () => Promise; +}; + +type SecureConfigBootstrapArgs = { + backend?: SecurityUpdateBackend; + storage?: StorageLike; + replaceConnections: (connections: SavedConnection[]) => void; + replaceGlobalProxy: (proxy: GlobalProxyConfig) => void; +}; + +type SecureConfigBootstrapResult = { + status: SecurityUpdateStatus; + rawPayload: string | null; + hasLegacySensitiveItems: boolean; + shouldShowIntro: boolean; + shouldShowBanner: boolean; +}; + +type StartSecurityUpdateResult = { + status: SecurityUpdateStatus | null; + error: Error | null; +}; + +const defaultSummary = () => ({ + total: 0, + updated: 0, + pending: 0, + skipped: 0, + failed: 0, +}); + +const hasMeaningfulSummary = (summary: SecurityUpdateSummary): boolean => ( + summary.total > 0 + || summary.updated > 0 + || summary.pending > 0 + || summary.skipped > 0 + || summary.failed > 0 +); + +const buildLegacyPendingDetails = (rawPayload: string | null): { + hasLegacyItems: boolean; + summary: SecurityUpdateSummary; + issues: SecurityUpdateIssue[]; +} => { + const legacy = readLegacyPersistedSecrets(rawPayload); + const issues: SecurityUpdateIssue[] = legacy.connections.map((connection) => ({ + id: `legacy-connection-${connection.id}`, + scope: 'connection', + refId: connection.id, + title: connection.name || connection.id, + severity: 'medium', + status: 'pending', + reasonCode: 'migration_required', + action: 'open_connection', + message: '该连接仍保存在当前应用的本地配置中,完成安全更新后会迁入新的安全存储。', + })); + + if (legacy.globalProxy) { + issues.push({ + id: 'legacy-global-proxy-default', + scope: 'global_proxy', + title: '全局代理', + severity: 'medium', + status: 'pending', + reasonCode: 'migration_required', + action: 'open_proxy_settings', + message: '全局代理仍保存在当前应用的本地配置中,完成安全更新后会迁入新的安全存储。', + }); + } + + return { + hasLegacyItems: issues.length > 0, + summary: { + total: issues.length, + updated: 0, + pending: issues.length, + skipped: 0, + failed: 0, + }, + issues, + }; +}; + +const mergeSecurityUpdateIssues = ( + baseIssues: SecurityUpdateIssue[], + legacyIssues: SecurityUpdateIssue[], +): { + issues: SecurityUpdateIssue[]; + addedCount: number; +} => { + const issueIds = new Set(baseIssues.map((issue) => issue.id)); + const additions = legacyIssues.filter((issue) => !issueIds.has(issue.id)); + return { + issues: [...baseIssues, ...additions], + addedCount: additions.length, + }; +}; + +export const mergeSecurityUpdateStatusWithLegacySource = ( + status: Partial | undefined, + rawPayload: string | null, +): SecurityUpdateStatus => { + const base: SecurityUpdateStatus = { + ...defaultStatus(), + ...status, + summary: { + ...defaultSummary(), + ...(status?.summary ?? {}), + }, + issues: Array.isArray(status?.issues) ? status.issues : [], + }; + + const legacy = buildLegacyPendingDetails(rawPayload); + if (!legacy.hasLegacyItems) { + return base; + } + + if (base.overallStatus === 'not_detected') { + return { + ...base, + overallStatus: 'pending', + reminderVisible: true, + canStart: true, + canPostpone: true, + summary: legacy.summary, + issues: legacy.issues, + }; + } + + if (base.overallStatus === 'pending' || base.overallStatus === 'postponed') { + const mergedIssues = mergeSecurityUpdateIssues(base.issues, legacy.issues); + const summary = hasMeaningfulSummary(base.summary) + ? { + total: base.summary.total + mergedIssues.addedCount, + updated: base.summary.updated, + pending: base.summary.pending + mergedIssues.addedCount, + skipped: base.summary.skipped, + failed: base.summary.failed, + } + : legacy.summary; + + return { + ...base, + summary, + issues: mergedIssues.issues, + canStart: true, + canPostpone: true, + reminderVisible: base.overallStatus === 'pending' ? true : base.reminderVisible, + }; + } + + return base; +}; + +const defaultStatus = (): SecurityUpdateStatus => ({ + overallStatus: 'not_detected', + summary: defaultSummary(), + issues: [], +}); + +const resolveStorage = (storage?: StorageLike): StorageLike | undefined => { + if (storage) { + return storage; + } + if (typeof window === 'undefined') { + return undefined; + } + return window.localStorage; +}; + +const applyLegacyVisibleConfig = ( + rawPayload: string | null, + replaceConnections: (connections: SavedConnection[]) => void, + replaceGlobalProxy: (proxy: GlobalProxyConfig) => void, +) => { + const legacy = readLegacyPersistedSecrets(rawPayload); + if (legacy.connections.length > 0) { + replaceConnections(legacy.connections); + } + if (legacy.globalProxy) { + replaceGlobalProxy(createGlobalProxyDraft(legacy.globalProxy)); + } +}; + +const refreshVisibleConfigFromBackend = async ( + backend: SecurityUpdateBackend | undefined, + replaceConnections: (connections: SavedConnection[]) => void, + replaceGlobalProxy: (proxy: GlobalProxyConfig) => void, + allowEmptyConnections: boolean, +) => { + if (typeof backend?.GetSavedConnections === 'function') { + try { + const connections = await backend.GetSavedConnections(); + if (Array.isArray(connections) && (allowEmptyConnections || connections.length > 0)) { + replaceConnections(connections); + } + } catch { + // Keep current visible state as fallback. + } + } + + if (typeof backend?.GetGlobalProxyConfig === 'function') { + try { + const proxyResult = await backend.GetGlobalProxyConfig(); + if (proxyResult?.success && proxyResult.data) { + replaceGlobalProxy(createGlobalProxyDraft(proxyResult.data)); + } + } catch { + // Keep current visible state as fallback. + } + } +}; + +const cleanupLegacySourceIfCompleted = ( + storage: StorageLike | undefined, + rawPayload: string | null, + status: SecurityUpdateStatus, +) => { + if (!storage || !rawPayload || status.overallStatus !== 'completed') { + return; + } + const sanitizedPayload = stripLegacyPersistedSecrets(rawPayload); + if (sanitizedPayload && sanitizedPayload !== rawPayload) { + storage.setItem(LEGACY_PERSIST_KEY, sanitizedPayload); + } +}; + +export async function finalizeSecurityUpdateStatus( + args: SecureConfigBootstrapArgs, + rawStatus: Partial | undefined, +): Promise { + const storage = resolveStorage(args.storage); + const rawPayload = storage?.getItem(LEGACY_PERSIST_KEY) ?? null; + const status = mergeSecurityUpdateStatusWithLegacySource(rawStatus, rawPayload); + + if (status.overallStatus === 'completed') { + await refreshVisibleConfigFromBackend(args.backend, args.replaceConnections, args.replaceGlobalProxy, true); + cleanupLegacySourceIfCompleted(storage, rawPayload, status); + } + + return status; +} + +export async function bootstrapSecureConfig(args: SecureConfigBootstrapArgs): Promise { + const storage = resolveStorage(args.storage); + const rawPayload = storage?.getItem(LEGACY_PERSIST_KEY) ?? null; + const hasLegacySensitiveItems = hasLegacyMigratableSensitiveItems(rawPayload); + + applyLegacyVisibleConfig(rawPayload, args.replaceConnections, args.replaceGlobalProxy); + + const backendStatus = typeof args.backend?.GetSecurityUpdateStatus === 'function' + ? await args.backend.GetSecurityUpdateStatus() + : undefined; + const status = mergeSecurityUpdateStatusWithLegacySource(backendStatus, rawPayload); + + if (!hasLegacySensitiveItems) { + await refreshVisibleConfigFromBackend(args.backend, args.replaceConnections, args.replaceGlobalProxy, true); + } else if (status.overallStatus === 'completed') { + await refreshVisibleConfigFromBackend(args.backend, args.replaceConnections, args.replaceGlobalProxy, true); + cleanupLegacySourceIfCompleted(storage, rawPayload, status); + } + + return { + status, + rawPayload, + hasLegacySensitiveItems, + shouldShowIntro: status.overallStatus === 'pending', + shouldShowBanner: ['postponed', 'rolled_back', 'needs_attention'].includes(status.overallStatus), + }; +} + +export async function startSecurityUpdateFromBootstrap(args: SecureConfigBootstrapArgs): Promise { + const storage = resolveStorage(args.storage); + const rawPayload = storage?.getItem(LEGACY_PERSIST_KEY) ?? null; + const startPayload = rawPayload ?? ''; + + applyLegacyVisibleConfig(rawPayload, args.replaceConnections, args.replaceGlobalProxy); + + if (typeof args.backend?.StartSecurityUpdate !== 'function') { + return { + status: null, + error: new Error('安全更新能力不可用'), + }; + } + + try { + const rawStatus = await args.backend.StartSecurityUpdate({ + sourceType: 'current_app_saved_config', + rawPayload: startPayload, + options: { + allowPartial: true, + writeBackup: true, + }, + }); + const status = mergeSecurityUpdateStatusWithLegacySource(rawStatus, rawPayload); + + if (status.overallStatus === 'completed') { + await refreshVisibleConfigFromBackend(args.backend, args.replaceConnections, args.replaceGlobalProxy, true); + cleanupLegacySourceIfCompleted(storage, rawPayload, status); + } + + return { status, error: null }; + } catch (error) { + applyLegacyVisibleConfig(rawPayload, args.replaceConnections, args.replaceGlobalProxy); + return { + status: null, + error: error instanceof Error ? error : new Error(String(error)), + }; + } +} + +export type { + BackendGlobalProxyResult, + SecurityUpdateBackend, + SecureConfigBootstrapArgs, + SecureConfigBootstrapResult, + StartSecurityUpdateResult, +}; diff --git a/frontend/src/utils/securityUpdatePresentation.test.ts b/frontend/src/utils/securityUpdatePresentation.test.ts new file mode 100644 index 0000000..4effeed --- /dev/null +++ b/frontend/src/utils/securityUpdatePresentation.test.ts @@ -0,0 +1,96 @@ +import { describe, expect, it } from 'vitest'; + +import type { SecurityUpdateIssue, SecurityUpdateStatus } from '../types'; +import { + getSecurityUpdateIssueSeverityMeta, + getSecurityUpdateItemStatusMeta, + getSecurityUpdateIssueActionMeta, + getSecurityUpdateStatusMeta, + resolveSecurityUpdateEntryVisibility, + sortSecurityUpdateIssues, +} from './securityUpdatePresentation'; + +const createStatus = (overallStatus: SecurityUpdateStatus['overallStatus']): SecurityUpdateStatus => ({ + overallStatus, + summary: { + total: 0, + updated: 0, + pending: 0, + skipped: 0, + failed: 0, + }, + issues: [], +}); + +describe('securityUpdatePresentation', () => { + it('sorts issues by severity from high to low', () => { + const issues: SecurityUpdateIssue[] = [ + { id: 'medium-1', severity: 'medium' }, + { id: 'low-1', severity: 'low' }, + { id: 'high-1', severity: 'high' }, + { id: 'medium-2', severity: 'medium' }, + ]; + + expect(sortSecurityUpdateIssues(issues).map((issue) => issue.id)).toEqual([ + 'high-1', + 'medium-1', + 'medium-2', + 'low-1', + ]); + }); + + it('maps needs_attention, rolled_back and completed to stable display labels', () => { + expect(getSecurityUpdateStatusMeta(createStatus('needs_attention')).label).toBe('待处理'); + expect(getSecurityUpdateStatusMeta(createStatus('rolled_back')).label).toBe('已回退'); + expect(getSecurityUpdateStatusMeta(createStatus('completed')).label).toBe('已完成'); + }); + + it('resolves intro, banner and detail entry visibility for key overall states', () => { + expect(resolveSecurityUpdateEntryVisibility(createStatus('pending'))).toEqual({ + showIntro: true, + showBanner: false, + showDetailEntry: true, + }); + + expect(resolveSecurityUpdateEntryVisibility(createStatus('postponed'))).toEqual({ + showIntro: false, + showBanner: true, + showDetailEntry: true, + }); + + expect(resolveSecurityUpdateEntryVisibility(createStatus('rolled_back'))).toEqual({ + showIntro: false, + showBanner: true, + showDetailEntry: true, + }); + }); + + it('maps issue scope actions to existing repair entry labels', () => { + expect(getSecurityUpdateIssueActionMeta({ id: 'conn', scope: 'connection', action: 'open_connection' }).label).toBe('打开连接'); + expect(getSecurityUpdateIssueActionMeta({ id: 'proxy', scope: 'global_proxy', action: 'open_proxy_settings' }).label).toBe('代理设置'); + expect(getSecurityUpdateIssueActionMeta({ id: 'ai', scope: 'ai_provider', action: 'open_ai_settings' }).label).toBe('AI 设置'); + expect(getSecurityUpdateIssueActionMeta({ id: 'system', scope: 'system', action: 'view_details' }).label).toBe('查看详情'); + }); + + it('maps item status to explicit Chinese labels instead of reusing severity wording', () => { + expect(getSecurityUpdateItemStatusMeta('needs_attention')).toEqual({ + label: '待处理', + color: 'warning', + }); + expect(getSecurityUpdateItemStatusMeta('updated')).toEqual({ + label: '已更新', + color: 'success', + }); + }); + + it('maps issue severity to dedicated risk labels', () => { + expect(getSecurityUpdateIssueSeverityMeta('medium')).toEqual({ + label: '中风险', + color: 'warning', + }); + expect(getSecurityUpdateIssueSeverityMeta('high')).toEqual({ + label: '高风险', + color: 'error', + }); + }); +}); diff --git a/frontend/src/utils/securityUpdatePresentation.ts b/frontend/src/utils/securityUpdatePresentation.ts new file mode 100644 index 0000000..19a16c5 --- /dev/null +++ b/frontend/src/utils/securityUpdatePresentation.ts @@ -0,0 +1,210 @@ +import type { + SecurityUpdateIssue, + SecurityUpdateIssueAction, + SecurityUpdateIssueSeverity, + SecurityUpdateItemStatus, + SecurityUpdateStatus, +} from '../types'; + +type SecurityUpdateTone = 'default' | 'warning' | 'processing' | 'success' | 'error'; + +type SecurityUpdateStatusMeta = { + label: string; + description: string; + tone: SecurityUpdateTone; +}; + +type SecurityUpdateEntryVisibility = { + showIntro: boolean; + showBanner: boolean; + showDetailEntry: boolean; +}; + +type SecurityUpdateIssueActionMeta = { + label: string; + emphasis: 'primary' | 'default'; +}; + +type SecurityUpdateBadgeMeta = { + label: string; + color: SecurityUpdateTone; +}; + +const severityWeight: Record = { + high: 0, + medium: 1, + low: 2, +}; + +const actionMetaMap: Record = { + open_connection: { + label: '打开连接', + emphasis: 'primary', + }, + open_proxy_settings: { + label: '代理设置', + emphasis: 'primary', + }, + open_ai_settings: { + label: 'AI 设置', + emphasis: 'primary', + }, + retry_update: { + label: '重新检查', + emphasis: 'primary', + }, + view_details: { + label: '查看详情', + emphasis: 'default', + }, +}; + +const itemStatusMetaMap: Record = { + pending: { + label: '待更新', + color: 'processing', + }, + updated: { + label: '已更新', + color: 'success', + }, + needs_attention: { + label: '待处理', + color: 'warning', + }, + skipped: { + label: '已跳过', + color: 'default', + }, + failed: { + label: '失败', + color: 'error', + }, +}; + +const issueSeverityMetaMap: Record = { + high: { + label: '高风险', + color: 'error', + }, + medium: { + label: '中风险', + color: 'warning', + }, + low: { + label: '低风险', + color: 'default', + }, +}; + +export function sortSecurityUpdateIssues(issues: SecurityUpdateIssue[]): SecurityUpdateIssue[] { + return [...issues].sort((left, right) => { + const leftWeight = severityWeight[left.severity ?? 'low']; + const rightWeight = severityWeight[right.severity ?? 'low']; + if (leftWeight !== rightWeight) { + return leftWeight - rightWeight; + } + return left.id.localeCompare(right.id); + }); +} + +export function getSecurityUpdateStatusMeta(status: SecurityUpdateStatus): SecurityUpdateStatusMeta { + switch (status.overallStatus) { + case 'pending': + return { + label: '待更新', + description: '检测到可进行的安全更新,你可以现在开始或稍后继续。', + tone: 'warning', + }; + case 'postponed': + return { + label: '待更新', + description: '本次安全更新已延后,当前可用配置会继续保留。', + tone: 'warning', + }; + case 'in_progress': + return { + label: '更新中', + description: '正在检查并更新已保存配置的安全存储。', + tone: 'processing', + }; + case 'needs_attention': + return { + label: '待处理', + description: '更新尚未完成,有少量配置需要你处理。', + tone: 'warning', + }; + case 'completed': + return { + label: '已完成', + description: '已保存配置已完成安全更新。', + tone: 'success', + }; + case 'rolled_back': + return { + label: '已回退', + description: '本次更新未完成,系统已保留当前可用配置。', + tone: 'error', + }; + case 'not_detected': + default: + return { + label: '未检测到', + description: '当前没有需要处理的安全更新。', + tone: 'default', + }; + } +} + +export function resolveSecurityUpdateEntryVisibility(status: SecurityUpdateStatus): SecurityUpdateEntryVisibility { + switch (status.overallStatus) { + case 'pending': + return { + showIntro: true, + showBanner: false, + showDetailEntry: true, + }; + case 'postponed': + case 'needs_attention': + case 'rolled_back': + return { + showIntro: false, + showBanner: true, + showDetailEntry: true, + }; + case 'completed': + case 'in_progress': + return { + showIntro: false, + showBanner: false, + showDetailEntry: true, + }; + case 'not_detected': + default: + return { + showIntro: false, + showBanner: false, + showDetailEntry: false, + }; + } +} + +export function getSecurityUpdateIssueActionMeta(issue: Partial): SecurityUpdateIssueActionMeta { + return actionMetaMap[issue.action ?? 'view_details'] ?? actionMetaMap.view_details; +} + +export function getSecurityUpdateItemStatusMeta(status?: SecurityUpdateItemStatus): SecurityUpdateBadgeMeta { + return itemStatusMetaMap[status ?? 'pending'] ?? itemStatusMetaMap.pending; +} + +export function getSecurityUpdateIssueSeverityMeta(severity?: SecurityUpdateIssueSeverity): SecurityUpdateBadgeMeta { + return issueSeverityMetaMap[severity ?? 'low'] ?? issueSeverityMetaMap.low; +} + +export type { + SecurityUpdateBadgeMeta, + SecurityUpdateEntryVisibility, + SecurityUpdateIssueActionMeta, + SecurityUpdateStatusMeta, + SecurityUpdateTone, +}; diff --git a/frontend/src/utils/securityUpdateRepairFlow.test.ts b/frontend/src/utils/securityUpdateRepairFlow.test.ts new file mode 100644 index 0000000..3e514bb --- /dev/null +++ b/frontend/src/utils/securityUpdateRepairFlow.test.ts @@ -0,0 +1,79 @@ +import { describe, expect, it } from 'vitest'; + +import type { SavedConnection, SecurityUpdateIssue } from '../types'; +import { + resolveSecurityUpdateRepairEntry, + shouldReopenSecurityUpdateDetails, + shouldRetrySecurityUpdateAfterRepairSave, +} from './securityUpdateRepairFlow'; + +const createConnection = (id: string): SavedConnection => ({ + id, + name: `连接-${id}`, + config: { + id, + type: 'postgres', + host: 'db.local', + port: 5432, + user: 'postgres', + }, +}); + +describe('securityUpdateRepairFlow', () => { + it('opens the matching connection and preserves the return source for security update repairs', () => { + const target = createConnection('conn-1'); + const issue: SecurityUpdateIssue = { + id: 'issue-1', + action: 'open_connection', + refId: 'conn-1', + }; + + expect(resolveSecurityUpdateRepairEntry(issue, [target])).toEqual({ + type: 'connection', + connection: target, + repairSource: 'connection', + }); + }); + + it('returns a user-facing warning when the target connection no longer exists', () => { + const issue: SecurityUpdateIssue = { + id: 'issue-1', + action: 'open_connection', + refId: 'missing-conn', + }; + + expect(resolveSecurityUpdateRepairEntry(issue, [createConnection('conn-1')])).toEqual({ + type: 'warning', + message: '未找到对应连接,请先重新检查最新状态', + }); + }); + + it('maps proxy, ai and retry actions to the expected repair entry', () => { + expect(resolveSecurityUpdateRepairEntry({ id: 'proxy', action: 'open_proxy_settings' }, [])).toEqual({ + type: 'proxy', + repairSource: 'proxy', + }); + expect(resolveSecurityUpdateRepairEntry({ id: 'ai', action: 'open_ai_settings', refId: 'provider-1' }, [])).toEqual({ + type: 'ai', + providerId: 'provider-1', + repairSource: 'ai', + }); + expect(resolveSecurityUpdateRepairEntry({ id: 'retry', action: 'retry_update' }, [])).toEqual({ + type: 'retry', + }); + }); + + it('reopens security update details after closing a repair entry opened from that page', () => { + expect(shouldReopenSecurityUpdateDetails('connection')).toBe(true); + expect(shouldReopenSecurityUpdateDetails('proxy')).toBe(true); + expect(shouldReopenSecurityUpdateDetails('ai')).toBe(true); + expect(shouldReopenSecurityUpdateDetails(null)).toBe(false); + }); + + it('retries the current round automatically after saving a connection from the repair flow', () => { + expect(shouldRetrySecurityUpdateAfterRepairSave('connection')).toBe(true); + expect(shouldRetrySecurityUpdateAfterRepairSave('proxy')).toBe(false); + expect(shouldRetrySecurityUpdateAfterRepairSave('ai')).toBe(false); + expect(shouldRetrySecurityUpdateAfterRepairSave(null)).toBe(false); + }); +}); diff --git a/frontend/src/utils/securityUpdateRepairFlow.ts b/frontend/src/utils/securityUpdateRepairFlow.ts new file mode 100644 index 0000000..5df098a --- /dev/null +++ b/frontend/src/utils/securityUpdateRepairFlow.ts @@ -0,0 +1,82 @@ +import type { SavedConnection, SecurityUpdateIssue } from '../types'; + +export type SecurityUpdateRepairSource = 'connection' | 'proxy' | 'ai'; + +export type SecurityUpdateRepairEntry = + | { + type: 'connection'; + connection: SavedConnection; + repairSource: 'connection'; + } + | { + type: 'proxy'; + repairSource: 'proxy'; + } + | { + type: 'ai'; + providerId?: string; + repairSource: 'ai'; + } + | { + type: 'retry'; + } + | { + type: 'details'; + } + | { + type: 'warning'; + message: string; + }; + +export const resolveSecurityUpdateRepairEntry = ( + issue: SecurityUpdateIssue, + connections: SavedConnection[], +): SecurityUpdateRepairEntry => { + if (issue.action === 'open_connection') { + const target = connections.find((connection) => connection.id === issue.refId); + if (!target) { + return { + type: 'warning', + message: '未找到对应连接,请先重新检查最新状态', + }; + } + return { + type: 'connection', + connection: target, + repairSource: 'connection', + }; + } + + if (issue.action === 'open_proxy_settings') { + return { + type: 'proxy', + repairSource: 'proxy', + }; + } + + if (issue.action === 'open_ai_settings') { + return { + type: 'ai', + providerId: issue.refId || undefined, + repairSource: 'ai', + }; + } + + if (issue.action === 'retry_update') { + return { + type: 'retry', + }; + } + + return { + type: 'details', + }; +}; + +export const shouldReopenSecurityUpdateDetails = ( + repairSource: SecurityUpdateRepairSource | null | undefined, +): boolean => repairSource === 'connection' || repairSource === 'proxy' || repairSource === 'ai'; + +export const shouldRetrySecurityUpdateAfterRepairSave = ( + repairSource: SecurityUpdateRepairSource | null | undefined, +): boolean => repairSource === 'connection'; diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index f94ace7..e18d5a7 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -2,6 +2,7 @@ // This file is automatically generated. DO NOT EDIT import {connection} from '../models'; import {sync} from '../models'; +import {app} from '../models'; import {redis} from '../models'; export function ApplyChanges(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:connection.ChangeSet):Promise; @@ -54,6 +55,8 @@ export function DataSyncPreview(arg1:sync.SyncConfig,arg2:string,arg3:number):Pr export function DeleteConnection(arg1:string):Promise; +export function DismissSecurityUpdateReminder():Promise; + export function DownloadDriverPackage(arg1:string,arg2:string,arg3:string,arg4:string):Promise; export function DownloadUpdate():Promise; @@ -70,6 +73,8 @@ export function DuplicateConnection(arg1:string):Promise; +export function ExportConnectionsPackage(arg1:string):Promise; + export function ExportData(arg1:Array>,arg2:Array,arg3:string,arg4:string):Promise; export function ExportDatabaseSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:boolean):Promise; @@ -96,8 +101,12 @@ export function GetGlobalProxyConfig():Promise; export function GetSavedConnections():Promise>; +export function GetSecurityUpdateStatus():Promise; + export function ImportConfigFile():Promise; +export function ImportConnectionsPayload(arg1:string,arg2:string):Promise>; + export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; export function ImportDataWithProgress(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise; @@ -190,6 +199,10 @@ export function ResolveDriverPackageDownloadURL(arg1:string,arg2:string):Promise export function ResolveDriverRepositoryURL(arg1:string):Promise; +export function RestartSecurityUpdate(arg1:app.RestartSecurityUpdateRequest):Promise; + +export function RetrySecurityUpdateCurrentRound(arg1:app.RetrySecurityUpdateRequest):Promise; + export function SaveConnection(arg1:connection.SavedConnectionInput):Promise; export function SaveGlobalProxy(arg1:connection.SaveGlobalProxyInput):Promise; @@ -208,6 +221,8 @@ export function SetMacNativeWindowControls(arg1:boolean):Promise; export function SetWindowTranslucency(arg1:number,arg2:number):Promise; +export function StartSecurityUpdate(arg1:app.StartSecurityUpdateRequest):Promise; + export function TestConnection(arg1:connection.ConnectionConfig):Promise; export function TruncateTables(arg1:connection.ConnectionConfig,arg2:string,arg3:Array):Promise; diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js index d2e2c50..5f65811 100755 --- a/frontend/wailsjs/go/app/App.js +++ b/frontend/wailsjs/go/app/App.js @@ -102,6 +102,10 @@ export function DeleteConnection(arg1) { return window['go']['app']['App']['DeleteConnection'](arg1); } +export function DismissSecurityUpdateReminder() { + return window['go']['app']['App']['DismissSecurityUpdateReminder'](); +} + export function DownloadDriverPackage(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['DownloadDriverPackage'](arg1, arg2, arg3, arg4); } @@ -134,6 +138,10 @@ export function ExecuteSQLFile(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['ExecuteSQLFile'](arg1, arg2, arg3, arg4); } +export function ExportConnectionsPackage(arg1) { + return window['go']['app']['App']['ExportConnectionsPackage'](arg1); +} + export function ExportData(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['ExportData'](arg1, arg2, arg3, arg4); } @@ -186,10 +194,18 @@ export function GetSavedConnections() { return window['go']['app']['App']['GetSavedConnections'](); } +export function GetSecurityUpdateStatus() { + return window['go']['app']['App']['GetSecurityUpdateStatus'](); +} + export function ImportConfigFile() { return window['go']['app']['App']['ImportConfigFile'](); } +export function ImportConnectionsPayload(arg1, arg2) { + return window['go']['app']['App']['ImportConnectionsPayload'](arg1, arg2); +} + export function ImportData(arg1, arg2, arg3) { return window['go']['app']['App']['ImportData'](arg1, arg2, arg3); } @@ -207,7 +223,7 @@ export function ImportLegacyGlobalProxy(arg1) { } export function InstallLocalDriverPackage(arg1, arg2, arg3, arg4) { - return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3, arg4); + return window['go']['app']['App']['InstallLocalDriverPackage'](arg1, arg2, arg3, arg4); } export function InstallUpdateAndRestart() { @@ -374,6 +390,14 @@ export function ResolveDriverRepositoryURL(arg1) { return window['go']['app']['App']['ResolveDriverRepositoryURL'](arg1); } +export function RestartSecurityUpdate(arg1) { + return window['go']['app']['App']['RestartSecurityUpdate'](arg1); +} + +export function RetrySecurityUpdateCurrentRound(arg1) { + return window['go']['app']['App']['RetrySecurityUpdateCurrentRound'](arg1); +} + export function SaveConnection(arg1) { return window['go']['app']['App']['SaveConnection'](arg1); } @@ -410,6 +434,10 @@ export function SetWindowTranslucency(arg1, arg2) { return window['go']['app']['App']['SetWindowTranslucency'](arg1, arg2); } +export function StartSecurityUpdate(arg1) { + return window['go']['app']['App']['StartSecurityUpdate'](arg1); +} + export function TestConnection(arg1) { return window['go']['app']['App']['TestConnection'](arg1); } diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index 433b7bc..ca54253 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -179,6 +179,219 @@ export namespace ai { } +export namespace app { + + export class SecurityUpdateOptions { + allowPartial?: boolean; + writeBackup?: boolean; + + static createFrom(source: any = {}) { + return new SecurityUpdateOptions(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.allowPartial = source["allowPartial"]; + this.writeBackup = source["writeBackup"]; + } + } + export class RestartSecurityUpdateRequest { + migrationId?: string; + sourceType: string; + rawPayload?: string; + options?: SecurityUpdateOptions; + + static createFrom(source: any = {}) { + return new RestartSecurityUpdateRequest(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.migrationId = source["migrationId"]; + this.sourceType = source["sourceType"]; + this.rawPayload = source["rawPayload"]; + this.options = this.convertValues(source["options"], SecurityUpdateOptions); + } + + convertValues(a: any, classs: any, asMap: boolean = false): any { + if (!a) { + return a; + } + if (a.slice && a.map) { + return (a as any[]).map(elem => this.convertValues(elem, classs)); + } else if ("object" === typeof a) { + if (asMap) { + for (const key of Object.keys(a)) { + a[key] = new classs(a[key]); + } + return a; + } + return new classs(a); + } + return a; + } + } + export class RetrySecurityUpdateRequest { + migrationId?: string; + + static createFrom(source: any = {}) { + return new RetrySecurityUpdateRequest(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.migrationId = source["migrationId"]; + } + } + export class SecurityUpdateIssue { + id: string; + scope: string; + refId?: string; + title: string; + severity: string; + status: string; + reasonCode: string; + action: string; + message: string; + + static createFrom(source: any = {}) { + return new SecurityUpdateIssue(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.id = source["id"]; + this.scope = source["scope"]; + this.refId = source["refId"]; + this.title = source["title"]; + this.severity = source["severity"]; + this.status = source["status"]; + this.reasonCode = source["reasonCode"]; + this.action = source["action"]; + this.message = source["message"]; + } + } + + export class SecurityUpdateSummary { + total: number; + updated: number; + pending: number; + skipped: number; + failed: number; + + static createFrom(source: any = {}) { + return new SecurityUpdateSummary(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.total = source["total"]; + this.updated = source["updated"]; + this.pending = source["pending"]; + this.skipped = source["skipped"]; + this.failed = source["failed"]; + } + } + export class SecurityUpdateStatus { + schemaVersion?: number; + migrationId?: string; + overallStatus: string; + sourceType?: string; + reminderVisible: boolean; + canStart: boolean; + canPostpone: boolean; + canRetry: boolean; + backupAvailable: boolean; + backupPath?: string; + startedAt?: string; + updatedAt?: string; + completedAt?: string; + postponedAt?: string; + summary: SecurityUpdateSummary; + issues: SecurityUpdateIssue[]; + lastError?: string; + + static createFrom(source: any = {}) { + return new SecurityUpdateStatus(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.schemaVersion = source["schemaVersion"]; + this.migrationId = source["migrationId"]; + this.overallStatus = source["overallStatus"]; + this.sourceType = source["sourceType"]; + this.reminderVisible = source["reminderVisible"]; + this.canStart = source["canStart"]; + this.canPostpone = source["canPostpone"]; + this.canRetry = source["canRetry"]; + this.backupAvailable = source["backupAvailable"]; + this.backupPath = source["backupPath"]; + this.startedAt = source["startedAt"]; + this.updatedAt = source["updatedAt"]; + this.completedAt = source["completedAt"]; + this.postponedAt = source["postponedAt"]; + this.summary = this.convertValues(source["summary"], SecurityUpdateSummary); + this.issues = this.convertValues(source["issues"], SecurityUpdateIssue); + this.lastError = source["lastError"]; + } + + convertValues(a: any, classs: any, asMap: boolean = false): any { + if (!a) { + return a; + } + if (a.slice && a.map) { + return (a as any[]).map(elem => this.convertValues(elem, classs)); + } else if ("object" === typeof a) { + if (asMap) { + for (const key of Object.keys(a)) { + a[key] = new classs(a[key]); + } + return a; + } + return new classs(a); + } + return a; + } + } + + export class StartSecurityUpdateRequest { + sourceType: string; + rawPayload?: string; + options?: SecurityUpdateOptions; + + static createFrom(source: any = {}) { + return new StartSecurityUpdateRequest(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.sourceType = source["sourceType"]; + this.rawPayload = source["rawPayload"]; + this.options = this.convertValues(source["options"], SecurityUpdateOptions); + } + + convertValues(a: any, classs: any, asMap: boolean = false): any { + if (!a) { + return a; + } + if (a.slice && a.map) { + return (a as any[]).map(elem => this.convertValues(elem, classs)); + } else if ("object" === typeof a) { + if (asMap) { + for (const key of Object.keys(a)) { + a[key] = new classs(a[key]); + } + return a; + } + return new classs(a); + } + return a; + } + } + +} + export namespace connection { export class UpdateRow { diff --git a/go.mod b/go.mod index d2cd1fe..29edc5a 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,12 @@ require ( modernc.org/sqlite v1.44.3 ) +require ( + github.com/kr/pretty v0.3.1 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect +) + require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect diff --git a/go.sum b/go.sum index 0bd72f3..a93b158 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,7 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -126,6 +127,9 @@ github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxh github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -174,7 +178,6 @@ github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/paulmach/orb v0.12.0 h1:z+zOwjmG3MyEEqzv92UN49Lg1JFYx0L9GpGKNVDKk1s= github.com/paulmach/orb v0.12.0/go.mod h1:5mULz1xQfs3bmQm63QEJA6lNGujuRafwA5S/EnuLaLU= @@ -183,6 +186,7 @@ github.com/pierrec/lz4/v4 v4.1.25 h1:kocOqRffaIbU5djlIBr7Wh+cx82C0vtFb0fOurZHqD0 github.com/pierrec/lz4/v4 v4.1.25/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -200,6 +204,9 @@ github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTK github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/samber/lo v1.49.1 h1:4BIFyVfuQSEpluc7Fua+j1NolZHiEHEpaSEKdsH0tew= github.com/samber/lo v1.49.1/go.mod h1:dO6KHFzUKXgP8LDhU0oI8d2hekjXnGOu0DB8Jecxd6o= github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= @@ -356,9 +363,9 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/ai/service/config_store.go b/internal/ai/service/config_store.go new file mode 100644 index 0000000..f1653f6 --- /dev/null +++ b/internal/ai/service/config_store.go @@ -0,0 +1,262 @@ +package aiservice + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "GoNavi-Wails/internal/ai" + "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/secretstore" +) + +const ( + aiConfigSchemaVersion = 2 + aiConfigFileName = "ai_config.json" +) + +type aiConfig struct { + SchemaVersion int `json:"schemaVersion,omitempty"` + Providers []ai.ProviderConfig `json:"providers"` + ActiveProvider string `json:"activeProvider"` + SafetyLevel string `json:"safetyLevel"` + ContextLevel string `json:"contextLevel"` +} + +type ProviderConfigStoreSnapshot struct { + Providers []ai.ProviderConfig + ActiveProvider string + SafetyLevel ai.SQLPermissionLevel + ContextLevel ai.ContextLevel +} + +type ProviderConfigStoreInspection struct { + Snapshot ProviderConfigStoreSnapshot + ProvidersNeedingMigration []string +} + +type ProviderConfigStore struct { + configDir string + secretStore secretstore.SecretStore +} + +func NewProviderConfigStore(configDir string, store secretstore.SecretStore) *ProviderConfigStore { + if strings.TrimSpace(configDir) == "" { + configDir = resolveConfigDir() + } + if store == nil { + store = secretstore.NewUnavailableStore("secret store unavailable") + } + return &ProviderConfigStore{ + configDir: configDir, + secretStore: store, + } +} + +func newProviderConfigStore(configDir string, store secretstore.SecretStore) *ProviderConfigStore { + return NewProviderConfigStore(configDir, store) +} + +func (s *ProviderConfigStore) configPath() string { + return filepath.Join(s.configDir, aiConfigFileName) +} + +func (s *ProviderConfigStore) Load() (ProviderConfigStoreSnapshot, error) { + cfg, snapshot, err := s.readStoredSnapshot() + if err != nil { + return snapshot, err + } + + shouldRewrite := cfg.SchemaVersion != aiConfigSchemaVersion + providers := make([]ai.ProviderConfig, 0, len(snapshot.Providers)) + for _, providerConfig := range snapshot.Providers { + runtimeConfig, rewritten, loadErr := s.loadStoredProviderConfig(providerConfig) + if loadErr != nil { + return snapshot, fmt.Errorf("加载 AI Provider secret 失败(provider=%s): %w", providerConfig.ID, loadErr) + } + if rewritten { + shouldRewrite = true + } + providers = append(providers, runtimeConfig) + } + if providers == nil { + providers = []ai.ProviderConfig{} + } + snapshot.Providers = providers + + if shouldRewrite { + if err := s.Save(snapshot); err != nil { + return snapshot, fmt.Errorf("重写 AI 配置失败: %w", err) + } + } + + return snapshot, nil +} + +func (s *ProviderConfigStore) LoadRuntime() (ProviderConfigStoreSnapshot, error) { + _, snapshot, err := s.readStoredSnapshot() + if err != nil { + return snapshot, err + } + + providers := make([]ai.ProviderConfig, 0, len(snapshot.Providers)) + for _, providerConfig := range snapshot.Providers { + runtimeConfig, loadErr := s.loadRuntimeProviderConfig(providerConfig) + if loadErr != nil { + logger.Error(loadErr, "加载 AI Provider secret 失败,provider=%s", providerConfig.ID) + } + providers = append(providers, runtimeConfig) + } + if providers == nil { + providers = []ai.ProviderConfig{} + } + snapshot.Providers = providers + return snapshot, nil +} + +func (s *ProviderConfigStore) Inspect() (ProviderConfigStoreInspection, error) { + _, snapshot, err := s.readStoredSnapshot() + inspection := ProviderConfigStoreInspection{ + Snapshot: snapshot, + ProvidersNeedingMigration: []string{}, + } + if err != nil { + return inspection, err + } + + for _, providerConfig := range snapshot.Providers { + if providerNeedsMigration(providerConfig) { + inspection.ProvidersNeedingMigration = append(inspection.ProvidersNeedingMigration, providerConfig.ID) + } + } + + return inspection, nil +} + +func (s *ProviderConfigStore) Save(snapshot ProviderConfigStoreSnapshot) error { + providers := make([]ai.ProviderConfig, 0, len(snapshot.Providers)) + for _, providerConfig := range snapshot.Providers { + runtimeConfig := normalizeProviderConfig(providerConfig) + meta, bundle := splitProviderSecrets(runtimeConfig) + if bundle.hasAny() { + storedMeta, err := persistProviderSecretBundle(s.secretStore, meta, bundle) + if err != nil { + return fmt.Errorf("保存 Provider secret 失败: %w", err) + } + meta = storedMeta + } + providers = append(providers, providerMetadataView(meta)) + } + if providers == nil { + providers = []ai.ProviderConfig{} + } + + cfg := aiConfig{ + SchemaVersion: aiConfigSchemaVersion, + Providers: providers, + ActiveProvider: snapshot.ActiveProvider, + SafetyLevel: string(snapshot.SafetyLevel), + ContextLevel: string(snapshot.ContextLevel), + } + + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("序列化 AI 配置失败: %w", err) + } + if err := os.MkdirAll(s.configDir, 0o755); err != nil { + return fmt.Errorf("创建配置目录失败: %w", err) + } + if err := os.WriteFile(s.configPath(), data, 0o644); err != nil { + return fmt.Errorf("写入 AI 配置失败: %w", err) + } + return nil +} + +func (s *ProviderConfigStore) readStoredSnapshot() (aiConfig, ProviderConfigStoreSnapshot, error) { + snapshot := ProviderConfigStoreSnapshot{ + Providers: []ai.ProviderConfig{}, + SafetyLevel: ai.PermissionReadOnly, + ContextLevel: ai.ContextSchemaOnly, + } + + data, err := os.ReadFile(s.configPath()) + if err != nil { + if os.IsNotExist(err) { + return aiConfig{}, snapshot, nil + } + return aiConfig{}, snapshot, fmt.Errorf("读取 AI 配置失败: %w", err) + } + + var cfg aiConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return aiConfig{}, snapshot, fmt.Errorf("加载 AI 配置失败: %w", err) + } + + snapshot.ActiveProvider = cfg.ActiveProvider + switch ai.SQLPermissionLevel(cfg.SafetyLevel) { + case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull: + snapshot.SafetyLevel = ai.SQLPermissionLevel(cfg.SafetyLevel) + } + switch ai.ContextLevel(cfg.ContextLevel) { + case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults: + snapshot.ContextLevel = ai.ContextLevel(cfg.ContextLevel) + } + + providers := make([]ai.ProviderConfig, 0, len(cfg.Providers)) + for _, providerConfig := range cfg.Providers { + providers = append(providers, normalizeProviderConfig(providerConfig)) + } + if providers == nil { + providers = []ai.ProviderConfig{} + } + snapshot.Providers = providers + + return cfg, snapshot, nil +} + +func (s *ProviderConfigStore) loadStoredProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, bool, error) { + meta, bundle := splitProviderSecrets(config) + if bundle.hasAny() { + storedMeta, err := persistProviderSecretBundle(s.secretStore, meta, bundle) + if err != nil { + return meta, false, err + } + return mergeProviderSecrets(storedMeta, bundle), true, nil + } + + if !meta.HasSecret { + return meta, false, nil + } + + resolved, err := resolveProviderConfigSecrets(s.secretStore, meta) + if err != nil { + if os.IsNotExist(err) { + return meta, false, nil + } + return meta, false, err + } + return resolved, false, nil +} + +func (s *ProviderConfigStore) loadRuntimeProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, error) { + meta, bundle := splitProviderSecrets(config) + if bundle.hasAny() { + return mergeProviderSecrets(meta, bundle), nil + } + if !meta.HasSecret { + return meta, nil + } + + resolved, err := resolveProviderConfigSecrets(s.secretStore, meta) + if err != nil { + return meta, err + } + return resolved, nil +} + +func providerNeedsMigration(config ai.ProviderConfig) bool { + _, bundle := splitProviderSecrets(normalizeProviderConfig(config)) + return bundle.hasAny() +} diff --git a/internal/ai/service/config_store_test.go b/internal/ai/service/config_store_test.go new file mode 100644 index 0000000..2ca85a5 --- /dev/null +++ b/internal/ai/service/config_store_test.go @@ -0,0 +1,206 @@ +package aiservice + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "GoNavi-Wails/internal/ai" + "GoNavi-Wails/internal/secretstore" +) + +func TestProviderConfigStoreLoadMigratesPlaintextProviderSecrets(t *testing.T) { + store := newFakeProviderSecretStore() + configStore := newProviderConfigStore(t.TempDir(), store) + + legacy := aiConfig{ + Providers: []ai.ProviderConfig{ + { + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-test", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "platform", + }, + }, + }, + } + data, err := json.MarshalIndent(legacy, "", " ") + if err != nil { + t.Fatalf("MarshalIndent returned error: %v", err) + } + if err := os.WriteFile(filepath.Join(configStore.configDir, aiConfigFileName), data, 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + snapshot, err := configStore.Load() + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + if len(snapshot.Providers) != 1 { + t.Fatalf("expected 1 provider, got %d", len(snapshot.Providers)) + } + if snapshot.Providers[0].APIKey != "sk-test" { + t.Fatalf("expected runtime provider to restore apiKey, got %q", snapshot.Providers[0].APIKey) + } + if snapshot.Providers[0].Headers["Authorization"] != "Bearer test" { + t.Fatalf("expected runtime provider to restore sensitive header, got %#v", snapshot.Providers[0].Headers) + } + + stored, err := store.Get(snapshot.Providers[0].SecretRef) + if err != nil { + t.Fatalf("expected migrated provider secret bundle, got %v", err) + } + var bundle providerSecretBundle + if err := json.Unmarshal(stored, &bundle); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if bundle.APIKey != "sk-test" { + t.Fatalf("expected migrated apiKey in store, got %q", bundle.APIKey) + } + + rewritten, err := os.ReadFile(filepath.Join(configStore.configDir, aiConfigFileName)) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + text := string(rewritten) + if strings.Contains(text, "sk-test") { + t.Fatalf("expected rewritten config to be secretless, got %s", text) + } + if strings.Contains(text, "Bearer test") { + t.Fatalf("expected rewritten config to remove sensitive headers, got %s", text) + } +} + +func TestProviderConfigStoreSavePersistsSecretlessMetadata(t *testing.T) { + store := newFakeProviderSecretStore() + configStore := newProviderConfigStore(t.TempDir(), store) + + err := configStore.Save(ProviderConfigStoreSnapshot{ + Providers: []ai.ProviderConfig{ + { + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-test", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "platform", + }, + }, + }, + ActiveProvider: "openai-main", + SafetyLevel: ai.PermissionReadOnly, + ContextLevel: ai.ContextSchemaOnly, + }) + if err != nil { + t.Fatalf("Save returned error: %v", err) + } + + configData, err := os.ReadFile(filepath.Join(configStore.configDir, aiConfigFileName)) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + text := string(configData) + if strings.Contains(text, "sk-test") { + t.Fatalf("expected config file to be secretless, got %s", text) + } + if strings.Contains(text, "Bearer test") { + t.Fatalf("expected config file to remove sensitive headers, got %s", text) + } + + ref, err := secretstore.BuildRef(providerSecretKind, "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + stored, err := store.Get(ref) + if err != nil { + t.Fatalf("expected provider secret bundle in store, got %v", err) + } + var bundle providerSecretBundle + if err := json.Unmarshal(stored, &bundle); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if bundle.APIKey != "sk-test" { + t.Fatalf("expected stored apiKey, got %q", bundle.APIKey) + } + if bundle.SensitiveHeaders["Authorization"] != "Bearer test" { + t.Fatalf("expected stored sensitive header, got %#v", bundle.SensitiveHeaders) + } +} + +func TestProviderConfigStoreSaveKeepsExistingSecretRef(t *testing.T) { + store := newFakeProviderSecretStore() + configStore := newProviderConfigStore(t.TempDir(), store) + + ref, err := secretstore.BuildRef(providerSecretKind, "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + payload, err := json.Marshal(providerSecretBundle{ + APIKey: "sk-existing", + SensitiveHeaders: map[string]string{ + "Authorization": "Bearer existing", + }, + }) + if err != nil { + t.Fatalf("Marshal returned error: %v", err) + } + if err := store.Put(ref, payload); err != nil { + t.Fatalf("Put returned error: %v", err) + } + + err = configStore.Save(ProviderConfigStoreSnapshot{ + Providers: []ai.ProviderConfig{ + { + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + HasSecret: true, + SecretRef: ref, + BaseURL: "https://gateway.openai.com/v1", + Headers: map[string]string{ + "X-Team": "platform", + }, + }, + }, + ActiveProvider: "openai-main", + SafetyLevel: ai.PermissionReadOnly, + ContextLevel: ai.ContextSchemaOnly, + }) + if err != nil { + t.Fatalf("Save returned error: %v", err) + } + + stored, err := store.Get(ref) + if err != nil { + t.Fatalf("expected existing provider secret bundle to remain available, got %v", err) + } + var bundle providerSecretBundle + if err := json.Unmarshal(stored, &bundle); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if bundle.APIKey != "sk-existing" { + t.Fatalf("expected existing apiKey to be kept, got %q", bundle.APIKey) + } + + snapshot, err := configStore.Load() + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + if len(snapshot.Providers) != 1 { + t.Fatalf("expected 1 provider after reload, got %d", len(snapshot.Providers)) + } + if snapshot.Providers[0].APIKey != "sk-existing" { + t.Fatalf("expected reload to restore existing apiKey, got %q", snapshot.Providers[0].APIKey) + } + if snapshot.Providers[0].Headers["Authorization"] != "Bearer existing" { + t.Fatalf("expected reload to restore existing sensitive header, got %#v", snapshot.Providers[0].Headers) + } +} diff --git a/internal/ai/service/provider_secret.go b/internal/ai/service/provider_secret.go index 6fe22bc..5d116b4 100644 --- a/internal/ai/service/provider_secret.go +++ b/internal/ai/service/provider_secret.go @@ -120,17 +120,17 @@ func mergeProviderSecrets(cfg ai.ProviderConfig, bundle providerSecretBundle) ai return merged } -func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) { +func persistProviderSecretBundle(store secretstore.SecretStore, meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) { meta, _ = splitProviderSecrets(meta) if !bundle.hasAny() { meta.HasSecret = false meta.SecretRef = "" return meta, nil } - if s.secretStore == nil { + if store == nil { return meta, fmt.Errorf("secret store unavailable") } - if err := s.secretStore.HealthCheck(); err != nil { + if err := store.HealthCheck(); err != nil { return meta, err } @@ -147,7 +147,7 @@ func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle pro if err != nil { return meta, fmt.Errorf("序列化 provider secret bundle 失败: %w", err) } - if err := s.secretStore.Put(ref, payload); err != nil { + if err := store.Put(ref, payload); err != nil { return meta, err } @@ -156,7 +156,7 @@ func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle pro return meta, nil } -func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, error) { +func resolveProviderConfigSecrets(store secretstore.SecretStore, cfg ai.ProviderConfig) (ai.ProviderConfig, error) { cfg = normalizeProviderConfig(cfg) meta, bundle := splitProviderSecrets(cfg) if bundle.hasAny() { @@ -165,7 +165,7 @@ func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.Provid if !meta.HasSecret { return meta, nil } - if s.secretStore == nil { + if store == nil { return meta, fmt.Errorf("secret store unavailable") } @@ -179,7 +179,7 @@ func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.Provid meta.SecretRef = ref } - payload, err := s.secretStore.Get(ref) + payload, err := store.Get(ref) if err != nil { return meta, err } @@ -191,6 +191,14 @@ func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.Provid return mergeProviderSecrets(meta, stored), nil } +func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) { + return persistProviderSecretBundle(s.secretStore, meta, bundle) +} + +func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, error) { + return resolveProviderConfigSecrets(s.secretStore, cfg) +} + func providerMetadataView(cfg ai.ProviderConfig) ai.ProviderConfig { meta, _ := splitProviderSecrets(normalizeProviderConfig(cfg)) return meta diff --git a/internal/ai/service/provider_secret_test.go b/internal/ai/service/provider_secret_test.go index 033b24f..b80bdc9 100644 --- a/internal/ai/service/provider_secret_test.go +++ b/internal/ai/service/provider_secret_test.go @@ -82,7 +82,7 @@ func TestResolveProviderConfigSecretsRestoresStoredSecretBundle(t *testing.T) { } } -func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) { +func TestLoadConfigUsesPlaintextProviderSecretsWithoutSilentMigration(t *testing.T) { store := newFakeProviderSecretStore() service := NewServiceWithSecretStore(store) service.configDir = t.TempDir() @@ -118,24 +118,28 @@ func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) { t.Fatalf("expected 1 provider, got %d", len(providers)) } if providers[0].APIKey != "" { - t.Fatalf("expected migrated provider to be secretless, got %q", providers[0].APIKey) + t.Fatalf("expected provider view to stay secretless, got %q", providers[0].APIKey) } if !providers[0].HasSecret { - t.Fatal("expected migrated provider to report HasSecret=true") + t.Fatal("expected provider view to report HasSecret=true") } - stored, err := store.Get(providers[0].SecretRef) + + if len(service.providers) != 1 { + t.Fatalf("expected runtime providers to be loaded, got %d", len(service.providers)) + } + if service.providers[0].APIKey != "sk-test" { + t.Fatalf("expected runtime provider to keep plaintext apiKey, got %q", service.providers[0].APIKey) + } + if service.providers[0].Headers["Authorization"] != "Bearer test" { + t.Fatalf("expected runtime provider to keep sensitive header, got %#v", service.providers[0].Headers) + } + + ref, err := secretstore.BuildRef("ai-provider", "openai-main") if err != nil { - t.Fatalf("expected secret bundle in store, got error: %v", err) + t.Fatalf("BuildRef returned error: %v", err) } - var bundle providerSecretBundle - if err := json.Unmarshal(stored, &bundle); err != nil { - t.Fatalf("Unmarshal returned error: %v", err) - } - if bundle.APIKey != "sk-test" { - t.Fatalf("expected migrated apiKey in store, got %q", bundle.APIKey) - } - if bundle.SensitiveHeaders["Authorization"] != "Bearer test" { - t.Fatalf("expected migrated sensitive header in store, got %#v", bundle.SensitiveHeaders) + if _, err := store.Get(ref); !os.IsNotExist(err) { + t.Fatalf("expected startup load to avoid secret-store migration, got %v", err) } rewritten, err := os.ReadFile(configPath) @@ -143,11 +147,124 @@ func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) { t.Fatalf("ReadFile returned error: %v", err) } text := string(rewritten) - if strings.Contains(text, "sk-test") { - t.Fatalf("expected rewritten config to remove api key, got %s", text) + if !strings.Contains(text, "sk-test") { + t.Fatalf("expected config file to remain unchanged, got %s", text) } - if strings.Contains(text, "Bearer test") { - t.Fatalf("expected rewritten config to remove sensitive header, got %s", text) + if !strings.Contains(text, "Bearer test") { + t.Fatalf("expected config file to keep sensitive header, got %s", text) + } +} + +func TestAISaveProviderKeepsLegacyPlaintextSecretAfterStartupLoad(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + legacy := aiConfig{ + Providers: []ai.ProviderConfig{ + { + ID: "openai-main", + Type: "custom", + Name: "OpenAI", + APIKey: "sk-test", + BaseURL: "", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "db", + }, + }, + }, + } + data, err := json.MarshalIndent(legacy, "", " ") + if err != nil { + t.Fatalf("MarshalIndent returned error: %v", err) + } + if err := os.WriteFile(filepath.Join(service.configDir, aiConfigFileName), data, 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + service.loadConfig() + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "custom", + Name: "OpenAI Updated", + BaseURL: "", + HasSecret: true, + Headers: map[string]string{ + "X-Team": "platform", + }, + }); err != nil { + t.Fatalf("AISaveProvider returned error: %v", err) + } + + if service.providers[0].APIKey != "sk-test" { + t.Fatalf("expected runtime provider to keep legacy plaintext apiKey, got %q", service.providers[0].APIKey) + } + if service.providers[0].Headers["Authorization"] != "Bearer test" { + t.Fatalf("expected runtime provider to keep legacy sensitive header, got %#v", service.providers[0].Headers) + } + + ref, err := secretstore.BuildRef("ai-provider", "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + stored, err := store.Get(ref) + if err != nil { + t.Fatalf("expected save to persist provider secret bundle, got %v", err) + } + var bundle providerSecretBundle + if err := json.Unmarshal(stored, &bundle); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if bundle.APIKey != "sk-test" { + t.Fatalf("expected persisted apiKey, got %q", bundle.APIKey) + } +} + +func TestAITestProviderUsesLegacyPlaintextSecretAfterStartupLoad(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + legacy := aiConfig{ + Providers: []ai.ProviderConfig{ + { + ID: "openai-main", + Type: "custom", + Name: "OpenAI", + APIKey: "sk-test", + BaseURL: "", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "db", + }, + }, + }, + } + data, err := json.MarshalIndent(legacy, "", " ") + if err != nil { + t.Fatalf("MarshalIndent returned error: %v", err) + } + if err := os.WriteFile(filepath.Join(service.configDir, aiConfigFileName), data, 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + service.loadConfig() + + result := service.AITestProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "custom", + Name: "OpenAI", + BaseURL: "", + HasSecret: true, + Headers: map[string]string{ + "X-Team": "db", + }, + }) + + if success, _ := result["success"].(bool); !success { + t.Fatalf("expected test provider to use in-memory legacy secret, got %#v", result) } } diff --git a/internal/ai/service/service.go b/internal/ai/service/service.go index ca067dc..0bb6564 100644 --- a/internal/ai/service/service.go +++ b/internal/ai/service/service.go @@ -183,11 +183,16 @@ func (s *Service) AISaveProvider(config ai.ProviderConfig) error { case found && (config.HasSecret || existing.HasSecret): meta.SecretRef = existing.SecretRef meta.HasSecret = config.HasSecret || existing.HasSecret - resolved, err := s.resolveProviderConfigSecrets(meta) - if err != nil { - return fmt.Errorf("读取已保存 Provider secret 失败: %w", err) + meta, existingBundle := applyExistingRuntimeProviderSecrets(meta, existing) + if existingBundle.hasAny() { + runtimeConfig = mergeProviderSecrets(meta, existingBundle) + } else { + resolved, err := s.resolveProviderConfigSecrets(meta) + if err != nil { + return fmt.Errorf("读取已保存 Provider secret 失败: %w", err) + } + runtimeConfig = resolved } - runtimeConfig = resolved default: runtimeConfig = meta } @@ -257,22 +262,47 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{ } if strings.TrimSpace(config.APIKey) == "" && (config.HasSecret || strings.TrimSpace(config.SecretRef) != "") { s.mu.RLock() + var existing ai.ProviderConfig + found := false if strings.TrimSpace(config.SecretRef) == "" { for _, providerConfig := range s.providers { if providerConfig.ID == config.ID { + existing = providerConfig + found = true config.SecretRef = providerConfig.SecretRef config.HasSecret = config.HasSecret || providerConfig.HasSecret break } } + } else { + for _, providerConfig := range s.providers { + if providerConfig.ID == config.ID { + existing = providerConfig + found = true + break + } + } } s.mu.RUnlock() - resolved, err := s.resolveProviderConfigSecrets(config) - if err != nil { - return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())} + if found { + config, existingBundle := applyExistingRuntimeProviderSecrets(config, existing) + if existingBundle.hasAny() { + config = mergeProviderSecrets(config, existingBundle) + } else { + resolved, err := s.resolveProviderConfigSecrets(config) + if err != nil { + return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())} + } + config = resolved + } + } else { + resolved, err := s.resolveProviderConfigSecrets(config) + if err != nil { + return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())} + } + config = resolved } - config = resolved } config = normalizeProviderConfig(config) @@ -462,6 +492,15 @@ func normalizeProviderConfig(config ai.ProviderConfig) ai.ProviderConfig { return config } +func applyExistingRuntimeProviderSecrets(meta ai.ProviderConfig, existing ai.ProviderConfig) (ai.ProviderConfig, providerSecretBundle) { + existingMeta, existingBundle := splitProviderSecrets(normalizeProviderConfig(existing)) + if strings.TrimSpace(meta.SecretRef) == "" { + meta.SecretRef = strings.TrimSpace(existingMeta.SecretRef) + } + meta.HasSecret = meta.HasSecret || existingMeta.HasSecret || existingBundle.hasAny() + return meta, existingBundle +} + func resolveModelsURL(config ai.ProviderConfig) string { config = normalizeProviderConfig(config) providerType := normalizedProviderType(config) @@ -919,117 +958,27 @@ func (s *Service) getActiveProvider() (provider.Provider, error) { // --- 配置持久化 --- -const aiConfigSchemaVersion = 2 - -type aiConfig struct { - SchemaVersion int `json:"schemaVersion,omitempty"` - Providers []ai.ProviderConfig `json:"providers"` - ActiveProvider string `json:"activeProvider"` - SafetyLevel string `json:"safetyLevel"` - ContextLevel string `json:"contextLevel"` -} - -func (s *Service) loadRuntimeProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, bool, error) { - meta, bundle := splitProviderSecrets(config) - if bundle.hasAny() { - storedMeta, err := s.persistProviderSecretBundle(meta, bundle) - if err != nil { - meta.HasSecret = false - meta.SecretRef = "" - return meta, true, err - } - return mergeProviderSecrets(storedMeta, bundle), true, nil - } - - resolved, err := s.resolveProviderConfigSecrets(meta) - if err != nil { - return meta, false, err - } - return resolved, false, nil -} - func (s *Service) loadConfig() { - path := filepath.Join(s.configDir, "ai_config.json") - data, err := os.ReadFile(path) + snapshot, err := NewProviderConfigStore(s.configDir, s.secretStore).LoadRuntime() if err != nil { - return // 首次启动,无配置文件 - } - - var cfg aiConfig - if err := json.Unmarshal(data, &cfg); err != nil { logger.Error(err, "加载 AI 配置失败") return } - providers := make([]ai.ProviderConfig, 0, len(cfg.Providers)) - shouldRewrite := cfg.SchemaVersion != aiConfigSchemaVersion - for _, providerConfig := range cfg.Providers { - runtimeConfig, rewritten, err := s.loadRuntimeProviderConfig(normalizeProviderConfig(providerConfig)) - if err != nil { - logger.Error(err, "加载 AI Provider secret 失败,provider=%s", providerConfig.ID) - } - if rewritten { - shouldRewrite = true - } - providers = append(providers, runtimeConfig) - } - if providers == nil { - providers = make([]ai.ProviderConfig, 0) - } - s.providers = providers - s.activeProvider = cfg.ActiveProvider - - switch ai.SQLPermissionLevel(cfg.SafetyLevel) { - case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull: - s.safetyLevel = ai.SQLPermissionLevel(cfg.SafetyLevel) - default: - s.safetyLevel = ai.PermissionReadOnly - } + s.providers = snapshot.Providers + s.activeProvider = snapshot.ActiveProvider + s.safetyLevel = snapshot.SafetyLevel s.guard.SetPermissionLevel(s.safetyLevel) - - switch ai.ContextLevel(cfg.ContextLevel) { - case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults: - s.contextLevel = ai.ContextLevel(cfg.ContextLevel) - default: - s.contextLevel = ai.ContextSchemaOnly - } - - if shouldRewrite { - if err := s.saveConfig(); err != nil { - logger.Error(err, "重写 AI 配置失败") - } - } + s.contextLevel = snapshot.ContextLevel } func (s *Service) saveConfig() error { - providers := make([]ai.ProviderConfig, len(s.providers)) - for i := range s.providers { - providers[i] = providerMetadataView(s.providers[i]) - } - - cfg := aiConfig{ - SchemaVersion: aiConfigSchemaVersion, - Providers: providers, + return NewProviderConfigStore(s.configDir, s.secretStore).Save(ProviderConfigStoreSnapshot{ + Providers: s.providers, ActiveProvider: s.activeProvider, - SafetyLevel: string(s.safetyLevel), - ContextLevel: string(s.contextLevel), - } - - data, err := json.MarshalIndent(cfg, "", " ") - if err != nil { - return fmt.Errorf("序列化 AI 配置失败: %w", err) - } - - if err := os.MkdirAll(s.configDir, 0o755); err != nil { - return fmt.Errorf("创建配置目录失败: %w", err) - } - - path := filepath.Join(s.configDir, "ai_config.json") - if err := os.WriteFile(path, data, 0o644); err != nil { - return fmt.Errorf("写入 AI 配置失败: %w", err) - } - - return nil + SafetyLevel: s.safetyLevel, + ContextLevel: s.contextLevel, + }) } // --- 会话文件持久化 --- diff --git a/internal/app/connection_package_crypto.go b/internal/app/connection_package_crypto.go new file mode 100644 index 0000000..e844144 --- /dev/null +++ b/internal/app/connection_package_crypto.go @@ -0,0 +1,228 @@ +package app + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "strings" + + "golang.org/x/crypto/argon2" +) + +const ( + connectionPackageAES256KeyBytes = 32 + connectionPackageSaltBytes = 16 + connectionPackageNonceBytes = 12 +) + +type connectionPackageAAD struct { + SchemaVersion int `json:"schemaVersion"` + Kind string `json:"kind"` + Cipher string `json:"cipher"` + KDF connectionPackageKDFSpec `json:"kdf"` + Nonce string `json:"nonce"` +} + +func encryptConnectionPackage(payload connectionPackagePayload, password string) (connectionPackageFile, error) { + normalizedPassword := normalizeConnectionPackagePassword(password) + if normalizedPassword == "" { + return connectionPackageFile{}, errConnectionPackagePasswordRequired + } + + plain, err := json.Marshal(payload) + if err != nil { + return connectionPackageFile{}, err + } + + salt := make([]byte, connectionPackageSaltBytes) + if _, err := rand.Read(salt); err != nil { + return connectionPackageFile{}, err + } + nonce := make([]byte, connectionPackageNonceBytes) + if _, err := rand.Read(nonce); err != nil { + return connectionPackageFile{}, err + } + + file := connectionPackageFile{ + SchemaVersion: connectionPackageSchemaVersion, + Kind: connectionPackageKind, + Cipher: connectionPackageCipher, + KDF: defaultConnectionPackageKDFSpec(), + Nonce: base64.StdEncoding.EncodeToString(nonce), + } + file.KDF.Salt = base64.StdEncoding.EncodeToString(salt) + + key, err := deriveConnectionPackageKey(normalizedPassword, file.KDF) + if err != nil { + return connectionPackageFile{}, err + } + aad, err := marshalConnectionPackageAAD(file) + if err != nil { + return connectionPackageFile{}, err + } + aead, err := newConnectionPackageAEAD(key) + if err != nil { + return connectionPackageFile{}, err + } + + ciphertext := aead.Seal(nil, nonce, plain, aad) + file.Payload = base64.StdEncoding.EncodeToString(ciphertext) + return file, nil +} + +func decryptConnectionPackage(file connectionPackageFile, password string) (connectionPackagePayload, error) { + normalizedPassword := normalizeConnectionPackagePassword(password) + if normalizedPassword == "" { + return connectionPackagePayload{}, errConnectionPackagePasswordRequired + } + if err := validateConnectionPackageFileHeader(file); err != nil { + return connectionPackagePayload{}, err + } + + plain, err := decryptConnectionPackagePlaintext(file, normalizedPassword) + if err != nil { + return connectionPackagePayload{}, errConnectionPackageDecryptFailed + } + + var payload connectionPackagePayload + if err := json.Unmarshal(plain, &payload); err != nil { + return connectionPackagePayload{}, errConnectionPackageDecryptFailed + } + return payload, nil +} + +func isConnectionPackageEnvelope(raw string) bool { + file, err := decodeConnectionPackageEnvelope(raw) + if err != nil { + return false + } + return file.Kind == connectionPackageKind +} + +func encodeConnectionPackageEnvelope(file connectionPackageFile) (string, error) { + raw, err := json.Marshal(file) + if err != nil { + return "", err + } + return string(raw), nil +} + +func decodeConnectionPackageEnvelope(raw string) (connectionPackageFile, error) { + var file connectionPackageFile + if err := json.Unmarshal([]byte(raw), &file); err != nil { + return connectionPackageFile{}, err + } + return file, nil +} + +func decryptConnectionPackagePlaintext(file connectionPackageFile, password string) ([]byte, error) { + if err := validateConnectionPackageFileHeader(file); err != nil { + return nil, err + } + + nonce, err := base64.StdEncoding.DecodeString(file.Nonce) + if err != nil || len(nonce) != connectionPackageNonceBytes { + return nil, errors.New("invalid nonce") + } + ciphertext, err := base64.StdEncoding.DecodeString(file.Payload) + if err != nil || len(ciphertext) == 0 { + return nil, errors.New("invalid payload") + } + + key, err := deriveConnectionPackageKey(password, file.KDF) + if err != nil { + return nil, err + } + aad, err := marshalConnectionPackageAAD(file) + if err != nil { + return nil, err + } + aead, err := newConnectionPackageAEAD(key) + if err != nil { + return nil, err + } + + plain, err := aead.Open(nil, nonce, ciphertext, aad) + if err != nil { + return nil, err + } + return plain, nil +} + +func deriveConnectionPackageKey(password string, spec connectionPackageKDFSpec) ([]byte, error) { + if password == "" { + return nil, errConnectionPackagePasswordRequired + } + if err := validateConnectionPackageKDFSpec(spec); err != nil { + return nil, err + } + + salt, err := base64.StdEncoding.DecodeString(spec.Salt) + if err != nil || len(salt) == 0 { + return nil, errors.New("invalid salt") + } + + key := argon2.IDKey( + []byte(password), + salt, + spec.TimeCost, + spec.MemoryKiB, + spec.Parallelism, + connectionPackageAES256KeyBytes, + ) + return key, nil +} + +func marshalConnectionPackageAAD(file connectionPackageFile) ([]byte, error) { + aad := connectionPackageAAD{ + SchemaVersion: file.SchemaVersion, + Kind: file.Kind, + Cipher: file.Cipher, + KDF: file.KDF, + Nonce: file.Nonce, + } + return json.Marshal(aad) +} + +func newConnectionPackageAEAD(key []byte) (cipher.AEAD, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewGCM(block) +} + +func validateConnectionPackageFileHeader(file connectionPackageFile) error { + switch { + case file.SchemaVersion != connectionPackageSchemaVersion: + return errConnectionPackageUnsupported + case strings.TrimSpace(file.Kind) != connectionPackageKind: + return errConnectionPackageUnsupported + case strings.TrimSpace(file.Cipher) != connectionPackageCipher: + return errConnectionPackageUnsupported + case validateConnectionPackageKDFSpec(file.KDF) != nil: + return errConnectionPackageUnsupported + default: + return nil + } +} + +func validateConnectionPackageKDFSpec(spec connectionPackageKDFSpec) error { + switch { + case strings.TrimSpace(spec.Name) != connectionPackageKDFName: + return errConnectionPackageUnsupported + case spec.MemoryKiB == 0 || spec.TimeCost == 0 || spec.Parallelism == 0: + return errConnectionPackageUnsupported + case spec.MemoryKiB > connectionPackageKDFMaxMemoryKiB: + return errConnectionPackageUnsupported + case spec.TimeCost > connectionPackageKDFMaxTimeCost: + return errConnectionPackageUnsupported + case spec.Parallelism > connectionPackageKDFMaxParallelism: + return errConnectionPackageUnsupported + default: + return nil + } +} diff --git a/internal/app/connection_package_crypto_test.go b/internal/app/connection_package_crypto_test.go new file mode 100644 index 0000000..b1368e0 --- /dev/null +++ b/internal/app/connection_package_crypto_test.go @@ -0,0 +1,224 @@ +package app + +import ( + "encoding/json" + "errors" + "reflect" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestConnectionPackageCryptoRoundTrip(t *testing.T) { + payload := connectionPackagePayload{ + ExportedAt: "2026-04-10T12:00:00+08:00", + Connections: []connectionPackageItem{ + { + ID: "conn-1", + Name: "local-mysql", + IncludeDatabases: []string{"app"}, + IconType: "database", + IconColor: "#2f855a", + Config: connection.ConnectionConfig{ + Type: "mysql", + Host: "127.0.0.1", + Port: 3306, + User: "root", + Database: "app", + }, + }, + }, + } + + file, err := encryptConnectionPackage(payload, "strong-password") + if err != nil { + t.Fatalf("encryptConnectionPackage returned error: %v", err) + } + + raw, err := json.Marshal(file) + if err != nil { + t.Fatalf("json.Marshal envelope returned error: %v", err) + } + if !isConnectionPackageEnvelope(string(raw)) { + t.Fatalf("isConnectionPackageEnvelope should return true for valid envelope") + } + + var decoded connectionPackageFile + if err := json.Unmarshal(raw, &decoded); err != nil { + t.Fatalf("json.Unmarshal envelope returned error: %v", err) + } + + got, err := decryptConnectionPackage(decoded, "strong-password") + if err != nil { + t.Fatalf("decryptConnectionPackage returned error: %v", err) + } + if !reflect.DeepEqual(got, payload) { + t.Fatalf("round-trip mismatch: got=%+v want=%+v", got, payload) + } +} + +func TestConnectionPackageDecryptWrongPasswordReturnsUnifiedError(t *testing.T) { + payload := connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-1", + Name: "test", + Config: connection.ConnectionConfig{ + Type: "mysql", + }, + }, + }, + } + + file, err := encryptConnectionPackage(payload, "correct-password") + if err != nil { + t.Fatalf("encryptConnectionPackage returned error: %v", err) + } + + _, err = decryptConnectionPackage(file, "wrong-password") + if !errors.Is(err, errConnectionPackageDecryptFailed) { + t.Fatalf("wrong password should return unified error, got: %v", err) + } +} + +func TestConnectionPackageDecryptTamperedHeaderFailsAADValidation(t *testing.T) { + payload := connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-1", + Name: "test", + Config: connection.ConnectionConfig{ + Type: "mysql", + }, + }, + }, + } + + file, err := encryptConnectionPackage(payload, "correct-password") + if err != nil { + t.Fatalf("encryptConnectionPackage returned error: %v", err) + } + + t.Run("cipher", func(t *testing.T) { + tampered := file + tampered.Nonce = "AAAAAAAAAAAAAAAA" + _, err := decryptConnectionPackage(tampered, "correct-password") + if !errors.Is(err, errConnectionPackageDecryptFailed) { + t.Fatalf("tampered nonce should fail with unified error, got: %v", err) + } + }) + + t.Run("kdf-salt", func(t *testing.T) { + tampered := file + tampered.KDF.Salt = "AAAAAAAAAAAAAAAAAAAAAA==" + _, err := decryptConnectionPackage(tampered, "correct-password") + if !errors.Is(err, errConnectionPackageDecryptFailed) { + t.Fatalf("tampered kdf salt should fail with unified error, got: %v", err) + } + }) +} + +func TestConnectionPackagePasswordRequired(t *testing.T) { + payload := connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-1", + Name: "test", + Config: connection.ConnectionConfig{ + Type: "mysql", + }, + }, + }, + } + + _, err := encryptConnectionPackage(payload, " ") + if !errors.Is(err, errConnectionPackagePasswordRequired) { + t.Fatalf("encryptConnectionPackage should return password required error, got: %v", err) + } + + _, err = decryptConnectionPackage(connectionPackageFile{}, " ") + if !errors.Is(err, errConnectionPackagePasswordRequired) { + t.Fatalf("decryptConnectionPackage should return password required error, got: %v", err) + } +} + +func TestConnectionPackageDecryptUnsupportedHeaderReturnsUnsupportedError(t *testing.T) { + payload := connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-1", + Name: "test", + Config: connection.ConnectionConfig{ + Type: "mysql", + }, + }, + }, + } + + file, err := encryptConnectionPackage(payload, "correct-password") + if err != nil { + t.Fatalf("encryptConnectionPackage returned error: %v", err) + } + + t.Run("schemaVersion", func(t *testing.T) { + tampered := file + tampered.SchemaVersion = tampered.SchemaVersion + 1 + _, err := decryptConnectionPackage(tampered, "correct-password") + if !errors.Is(err, errConnectionPackageUnsupported) { + t.Fatalf("unsupported schemaVersion should return unsupported error, got: %v", err) + } + }) + + t.Run("kind", func(t *testing.T) { + tampered := file + tampered.Kind = "other_connection_package" + _, err := decryptConnectionPackage(tampered, "correct-password") + if !errors.Is(err, errConnectionPackageUnsupported) { + t.Fatalf("unsupported kind should return unsupported error, got: %v", err) + } + }) + + t.Run("cipher", func(t *testing.T) { + tampered := file + tampered.Cipher = "AES-128-GCM" + _, err := decryptConnectionPackage(tampered, "correct-password") + if !errors.Is(err, errConnectionPackageUnsupported) { + t.Fatalf("unsupported cipher should return unsupported error, got: %v", err) + } + }) + + t.Run("kdf-name", func(t *testing.T) { + tampered := file + tampered.KDF.Name = "PBKDF2" + _, err := decryptConnectionPackage(tampered, "correct-password") + if !errors.Is(err, errConnectionPackageUnsupported) { + t.Fatalf("unsupported kdf name should return unsupported error, got: %v", err) + } + }) +} + +func TestValidateConnectionPackageKDFSpecRejectsOversizedParams(t *testing.T) { + t.Run("memory", func(t *testing.T) { + spec := defaultConnectionPackageKDFSpec() + spec.MemoryKiB = connectionPackageKDFMaxMemoryKiB + 1 + if err := validateConnectionPackageKDFSpec(spec); !errors.Is(err, errConnectionPackageUnsupported) { + t.Fatalf("oversized memory should return unsupported error, got: %v", err) + } + }) + + t.Run("timeCost", func(t *testing.T) { + spec := defaultConnectionPackageKDFSpec() + spec.TimeCost = connectionPackageKDFMaxTimeCost + 1 + if err := validateConnectionPackageKDFSpec(spec); !errors.Is(err, errConnectionPackageUnsupported) { + t.Fatalf("oversized timeCost should return unsupported error, got: %v", err) + } + }) + + t.Run("parallelism", func(t *testing.T) { + spec := defaultConnectionPackageKDFSpec() + spec.Parallelism = connectionPackageKDFMaxParallelism + 1 + if err := validateConnectionPackageKDFSpec(spec); !errors.Is(err, errConnectionPackageUnsupported) { + t.Fatalf("oversized parallelism should return unsupported error, got: %v", err) + } + }) +} diff --git a/internal/app/connection_package_transfer.go b/internal/app/connection_package_transfer.go new file mode 100644 index 0000000..4cd47e3 --- /dev/null +++ b/internal/app/connection_package_transfer.go @@ -0,0 +1,229 @@ +package app + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" +) + +func newConnectionPackageItem(view connection.SavedConnectionView, bundle connectionSecretBundle) connectionPackageItem { + return connectionPackageItem{ + ID: view.ID, + Name: view.Name, + IncludeDatabases: cloneStringSlice(view.IncludeDatabases), + IncludeRedisDatabases: cloneIntSlice(view.IncludeRedisDatabases), + IconType: view.IconType, + IconColor: view.IconColor, + Config: view.Config, + Secrets: bundle, + } +} + +func (a *App) buildConnectionPackagePayload() (connectionPackagePayload, error) { + repo := a.savedConnectionRepository() + items, err := repo.List() + if err != nil { + return connectionPackagePayload{}, err + } + + connections := make([]connectionPackageItem, 0, len(items)) + for _, item := range items { + bundle, bundleErr := repo.loadSecretBundle(item) + if bundleErr != nil { + return connectionPackagePayload{}, bundleErr + } + connections = append(connections, newConnectionPackageItem(item, bundle)) + } + + return connectionPackagePayload{ + ExportedAt: time.Now().UTC().Format(time.RFC3339), + Connections: connections, + }, nil +} + +func newSavedConnectionInputFromPackageItem(item connectionPackageItem) connection.SavedConnectionInput { + id := strings.TrimSpace(item.ID) + if id == "" { + id = strings.TrimSpace(item.Config.ID) + } + + config := item.Config + config.ID = id + config.SavePassword = false + + secrets := item.Secrets + config.Password = secrets.Password + config.SSH.Password = secrets.SSHPassword + config.Proxy.Password = secrets.ProxyPassword + config.HTTPTunnel.Password = secrets.HTTPTunnelPassword + config.MySQLReplicaPassword = secrets.MySQLReplicaPassword + config.MongoReplicaPassword = secrets.MongoReplicaPassword + config.URI = secrets.OpaqueURI + config.DSN = secrets.OpaqueDSN + + return connection.SavedConnectionInput{ + ID: id, + Name: item.Name, + Config: config, + IncludeDatabases: cloneStringSlice(item.IncludeDatabases), + IncludeRedisDatabases: cloneIntSlice(item.IncludeRedisDatabases), + IconType: item.IconType, + IconColor: item.IconColor, + // 连接恢复包以最新导入文件为准;载荷中缺失的密文字段需要显式清空旧值。 + ClearPrimaryPassword: strings.TrimSpace(secrets.Password) == "", + ClearSSHPassword: strings.TrimSpace(secrets.SSHPassword) == "", + ClearProxyPassword: strings.TrimSpace(secrets.ProxyPassword) == "", + ClearHTTPTunnelPassword: strings.TrimSpace(secrets.HTTPTunnelPassword) == "", + ClearMySQLReplicaPassword: strings.TrimSpace(secrets.MySQLReplicaPassword) == "", + ClearMongoReplicaPassword: strings.TrimSpace(secrets.MongoReplicaPassword) == "", + ClearOpaqueURI: strings.TrimSpace(secrets.OpaqueURI) == "", + ClearOpaqueDSN: strings.TrimSpace(secrets.OpaqueDSN) == "", + } +} + +func (a *App) importConnectionPackagePayload(payload connectionPackagePayload) ([]connection.SavedConnectionView, error) { + repo := a.savedConnectionRepository() + rollbackSnapshot, err := captureConnectionPackageImportRollbackSnapshot(a, payload) + if err != nil { + return nil, err + } + + result := make([]connection.SavedConnectionView, 0, len(payload.Connections)) + for _, item := range payload.Connections { + view, err := repo.Save(newSavedConnectionInputFromPackageItem(item)) + if err != nil { + if rollbackErr := rollbackSnapshot.restore(a); rollbackErr != nil { + return nil, errors.Join(err, fmt.Errorf("restore connection package rollback: %w", rollbackErr)) + } + return nil, err + } + result = append(result, view) + } + return result, nil +} + +func (a *App) ImportConnectionsPayload(raw string, password string) ([]connection.SavedConnectionView, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil, errConnectionPackageUnsupported + } + + if isConnectionPackageEnvelope(trimmed) { + var file connectionPackageFile + if err := json.Unmarshal([]byte(trimmed), &file); err != nil { + return nil, errConnectionPackageUnsupported + } + payload, err := decryptConnectionPackage(file, password) + if err != nil { + return nil, err + } + return a.importConnectionPackagePayload(payload) + } + + var legacy []connection.LegacySavedConnection + if err := json.Unmarshal([]byte(trimmed), &legacy); err != nil { + return nil, errConnectionPackageUnsupported + } + return a.ImportLegacyConnections(legacy) +} + +type connectionPackageImportRollbackSnapshot struct { + connectionsFileExists bool + connectionsFileData []byte + connectionSecrets map[string]securityUpdateSecretSnapshot + connectionCleanupRefs []string +} + +func captureConnectionPackageImportRollbackSnapshot(a *App, payload connectionPackagePayload) (connectionPackageImportRollbackSnapshot, error) { + snapshot := connectionPackageImportRollbackSnapshot{ + connectionSecrets: make(map[string]securityUpdateSecretSnapshot), + } + + repo := a.savedConnectionRepository() + connectionFileData, connectionFileExists, err := readOptionalFile(repo.connectionsPath()) + if err != nil { + return snapshot, err + } + snapshot.connectionsFileExists = connectionFileExists + snapshot.connectionsFileData = connectionFileData + + existingConnections, err := repo.load() + if err != nil { + return snapshot, err + } + existingConnectionsByID := make(map[string]connection.SavedConnectionView, len(existingConnections)) + for _, item := range existingConnections { + existingConnectionsByID[item.ID] = item + } + + cleanupSet := make(map[string]struct{}) + seenIDs := make(map[string]struct{}) + for _, item := range payload.Connections { + input := newSavedConnectionInputFromPackageItem(item) + connectionID := strings.TrimSpace(input.ID) + if connectionID == "" { + continue + } + if _, alreadySeen := seenIDs[connectionID]; alreadySeen { + continue + } + seenIDs[connectionID] = struct{}{} + + defaultRef, refErr := secretstore.BuildRef(savedConnectionSecretKind, connectionID) + if refErr == nil { + cleanupSet[defaultRef] = struct{}{} + } + + existing, ok := existingConnectionsByID[connectionID] + if !ok || !savedConnectionViewHasSecrets(existing) { + continue + } + + ref := strings.TrimSpace(existing.SecretRef) + if ref == "" { + ref = defaultRef + } + if ref == "" { + continue + } + + secretSnapshot, captureErr := captureSecurityUpdateSecretSnapshot(a.secretStore, ref) + if captureErr != nil { + return snapshot, captureErr + } + snapshot.connectionSecrets[ref] = secretSnapshot + cleanupSet[ref] = struct{}{} + } + + snapshot.connectionCleanupRefs = make([]string, 0, len(cleanupSet)) + for ref := range cleanupSet { + snapshot.connectionCleanupRefs = append(snapshot.connectionCleanupRefs, ref) + } + return snapshot, nil +} + +func (s connectionPackageImportRollbackSnapshot) restore(a *App) error { + repo := a.savedConnectionRepository() + if err := restoreOptionalFile(repo.connectionsPath(), s.connectionsFileExists, s.connectionsFileData); err != nil { + return err + } + for ref, secretSnapshot := range s.connectionSecrets { + if err := restoreSecurityUpdateSecretSnapshot(a.secretStore, ref, secretSnapshot); err != nil { + return err + } + } + for _, ref := range s.connectionCleanupRefs { + if _, alreadyRestored := s.connectionSecrets[ref]; alreadyRestored { + continue + } + if err := deleteSecurityUpdateSecretRef(a.secretStore, ref); err != nil { + return err + } + } + return nil +} diff --git a/internal/app/connection_package_transfer_test.go b/internal/app/connection_package_transfer_test.go new file mode 100644 index 0000000..a988266 --- /dev/null +++ b/internal/app/connection_package_transfer_test.go @@ -0,0 +1,529 @@ +package app + +import ( + "encoding/json" + "errors" + "os" + "strings" + "testing" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" +) + +func TestBuildConnectionPackagePayloadIncludesSecretBundles(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "db-secret", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.local", + Port: 22, + User: "ops", + Password: "ssh-secret", + }, + URI: "postgres://postgres:db-secret@db.local/app", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + payload, err := app.buildConnectionPackagePayload() + if err != nil { + t.Fatalf("buildConnectionPackagePayload returned error: %v", err) + } + if _, parseErr := time.Parse(time.RFC3339, payload.ExportedAt); parseErr != nil { + t.Fatalf("expected RFC3339 exportedAt, got %q", payload.ExportedAt) + } + if len(payload.Connections) != 1 { + t.Fatalf("expected 1 connection in payload, got %d", len(payload.Connections)) + } + + item := payload.Connections[0] + if item.ID != "conn-1" { + t.Fatalf("expected ID=conn-1, got %q", item.ID) + } + if item.Config.Password != "" { + t.Fatalf("payload metadata must stay secretless, got password=%q", item.Config.Password) + } + if item.Config.SSH.Password != "" { + t.Fatalf("payload metadata must stay secretless for SSH, got %q", item.Config.SSH.Password) + } + if item.Config.URI != "" { + t.Fatalf("payload metadata must stay secretless for URI, got %q", item.Config.URI) + } + if item.Secrets.Password != "db-secret" { + t.Fatalf("expected bundled primary password, got %q", item.Secrets.Password) + } + if item.Secrets.SSHPassword != "ssh-secret" { + t.Fatalf("expected bundled SSH password, got %q", item.Secrets.SSHPassword) + } + if item.Secrets.OpaqueURI != "postgres://postgres:db-secret@db.local/app" { + t.Fatalf("expected bundled URI secret, got %q", item.Secrets.OpaqueURI) + } + + raw, err := json.Marshal(payload) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + if strings.Contains(string(raw), "secretRef") { + t.Fatalf("payload must not contain secretRef, got %s", string(raw)) + } +} + +func TestImportConnectionPackagePayloadOverwritesExistingSecrets(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.old.local", + Port: 5432, + User: "postgres", + Password: "old-primary", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.old.local", + Port: 22, + User: "ops", + Password: "old-ssh", + }, + URI: "postgres://old", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + imported, err := app.importConnectionPackagePayload(connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-1", + Name: "Imported", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.new.local", + Port: 5432, + User: "postgres", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.new.local", + Port: 22, + User: "ops", + }, + }, + Secrets: connectionSecretBundle{ + Password: "new-primary", + }, + }, + }, + }) + if err != nil { + t.Fatalf("importConnectionPackagePayload returned error: %v", err) + } + if len(imported) != 1 { + t.Fatalf("expected 1 imported item, got %d", len(imported)) + } + if imported[0].Name != "Imported" { + t.Fatalf("expected imported name, got %q", imported[0].Name) + } + if !imported[0].HasPrimaryPassword { + t.Fatal("expected primary password to be present after overwrite") + } + if imported[0].HasSSHPassword { + t.Fatal("expected SSH password to be cleared by package overwrite") + } + if imported[0].HasOpaqueURI { + t.Fatal("expected URI secret to be cleared by package overwrite") + } + + resolved, err := app.resolveConnectionSecrets(imported[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "new-primary" { + t.Fatalf("expected primary password to be overwritten, got %q", resolved.Password) + } + if resolved.SSH.Password != "" { + t.Fatalf("expected SSH password to be cleared, got %q", resolved.SSH.Password) + } + if resolved.URI != "" { + t.Fatalf("expected URI secret to be cleared, got %q", resolved.URI) + } +} + +func TestImportConnectionPackagePayloadLatestEntryWinsForSameID(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.importConnectionPackagePayload(connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-dup", + Name: "First", + Config: connection.ConnectionConfig{ + ID: "conn-dup", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + }, + Secrets: connectionSecretBundle{Password: "first-secret"}, + }, + { + ID: "conn-dup", + Name: "Second", + Config: connection.ConnectionConfig{ + ID: "conn-dup", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + }, + Secrets: connectionSecretBundle{Password: "second-secret"}, + }, + }, + }) + if err != nil { + t.Fatalf("importConnectionPackagePayload returned error: %v", err) + } + + saved, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(saved) != 1 { + t.Fatalf("expected 1 saved item after duplicate id overwrite, got %d", len(saved)) + } + if saved[0].Name != "Second" { + t.Fatalf("expected latest item to win, got %q", saved[0].Name) + } + + resolved, err := app.resolveConnectionSecrets(saved[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "second-secret" { + t.Fatalf("expected latest secret to win, got %q", resolved.Password) + } +} + +func TestImportConnectionPackagePayloadRollsBackOnSaveFailure(t *testing.T) { + failRef, err := secretstore.BuildRef(savedConnectionSecretKind, "conn-2") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + + store := newFailOnPutSecretStore(failRef) + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + _, err = app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Existing", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.old.local", + Port: 5432, + User: "postgres", + Password: "old-primary", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + imported, err := app.importConnectionPackagePayload(connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-1", + Name: "Imported Existing", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.new.local", + Port: 5432, + User: "postgres", + }, + Secrets: connectionSecretBundle{Password: "new-primary"}, + }, + { + ID: "conn-2", + Name: "Imported New", + Config: connection.ConnectionConfig{ + ID: "conn-2", + Type: "mysql", + Host: "db.second.local", + Port: 3306, + User: "root", + }, + Secrets: connectionSecretBundle{Password: "second-primary"}, + }, + }, + }) + if err == nil { + t.Fatal("expected importConnectionPackagePayload to return error") + } + if imported != nil { + t.Fatalf("expected no imported results after rollback, got %#v", imported) + } + + saved, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(saved) != 1 { + t.Fatalf("expected rollback to restore exactly 1 connection, got %d", len(saved)) + } + if saved[0].ID != "conn-1" || saved[0].Name != "Existing" { + t.Fatalf("expected rollback to restore original connection metadata, got %#v", saved[0]) + } + if saved[0].Config.Host != "db.old.local" { + t.Fatalf("expected rollback to restore original host, got %q", saved[0].Config.Host) + } + + resolved, err := app.resolveConnectionSecrets(saved[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "old-primary" { + t.Fatalf("expected rollback to restore original primary password, got %q", resolved.Password) + } + + if _, err := store.Get(failRef); !os.IsNotExist(err) { + t.Fatalf("expected rollback to remove partially imported secret ref, got err=%v", err) + } +} + +func TestImportConnectionsPayloadLegacyJSONKeepsExistingSecretWhenMissing(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "legacy-1", + Name: "Legacy", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "legacy-secret", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + raw, err := json.Marshal([]connection.LegacySavedConnection{ + { + ID: "legacy-1", + Name: "Legacy Updated", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + }, + }, + }) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + + imported, err := app.ImportConnectionsPayload(string(raw), "ignored") + if err != nil { + t.Fatalf("ImportConnectionsPayload returned error: %v", err) + } + if len(imported) != 1 { + t.Fatalf("expected 1 imported item, got %d", len(imported)) + } + if imported[0].Name != "Legacy Updated" { + t.Fatalf("expected legacy metadata to be overwritten, got %q", imported[0].Name) + } + + resolved, err := app.resolveConnectionSecrets(imported[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "legacy-secret" { + t.Fatalf("expected legacy import to preserve existing secret, got %q", resolved.Password) + } +} + +func TestImportConnectionsPayloadEnvelopeRequiresPassword(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + raw := `{ + "schemaVersion": 1, + "kind": "gonavi_connection_package", + "cipher": "AES-256-GCM", + "kdf": { + "name": "Argon2id", + "memoryKiB": 65536, + "timeCost": 3, + "parallelism": 4, + "salt": "salt" + }, + "nonce": "nonce", + "payload": "payload" +}` + + _, err := app.ImportConnectionsPayload(raw, "") + if !errors.Is(err, errConnectionPackagePasswordRequired) { + t.Fatalf("expected errConnectionPackagePasswordRequired, got %v", err) + } +} + +func TestImportConnectionsPayloadEnvelopeImportsAndOverwritesSecrets(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-1", + Name: "Existing", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.old.local", + Port: 5432, + User: "postgres", + Password: "old-primary", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.old.local", + Port: 22, + User: "ops", + Password: "old-ssh", + }, + URI: "postgres://old", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + file, err := encryptConnectionPackage(connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-1", + Name: "Imported", + Config: connection.ConnectionConfig{ + ID: "conn-1", + Type: "postgres", + Host: "db.new.local", + Port: 5432, + User: "postgres", + }, + Secrets: connectionSecretBundle{ + Password: "new-primary", + }, + }, + }, + }, "package-password") + if err != nil { + t.Fatalf("encryptConnectionPackage returned error: %v", err) + } + + raw, err := json.Marshal(file) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + + imported, err := app.ImportConnectionsPayload(string(raw), "package-password") + if err != nil { + t.Fatalf("ImportConnectionsPayload returned error: %v", err) + } + if len(imported) != 1 { + t.Fatalf("expected 1 imported item, got %d", len(imported)) + } + if imported[0].Name != "Imported" { + t.Fatalf("expected imported name, got %q", imported[0].Name) + } + if !imported[0].HasPrimaryPassword { + t.Fatal("expected primary password after envelope import") + } + if imported[0].HasSSHPassword { + t.Fatal("expected missing SSH password in package to clear old secret") + } + if imported[0].HasOpaqueURI { + t.Fatal("expected missing URI in package to clear old secret") + } + + resolved, err := app.resolveConnectionSecrets(imported[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "new-primary" { + t.Fatalf("expected primary password to be overwritten, got %q", resolved.Password) + } + if resolved.SSH.Password != "" { + t.Fatalf("expected SSH password to be cleared, got %q", resolved.SSH.Password) + } + if resolved.URI != "" { + t.Fatalf("expected URI secret to be cleared, got %q", resolved.URI) + } +} + +func TestNormalizeConnectionPackageExportFilenameAddsExtension(t *testing.T) { + filename := normalizeConnectionPackageExportFilename(`C:\tmp\connections`) + if !strings.HasSuffix(filename, connectionPackageExtension) { + t.Fatalf("expected filename to end with %q, got %q", connectionPackageExtension, filename) + } + + alreadyExtended := normalizeConnectionPackageExportFilename(`C:\tmp\connections` + connectionPackageExtension) + if alreadyExtended != `C:\tmp\connections`+connectionPackageExtension { + t.Fatalf("expected existing extension to be preserved, got %q", alreadyExtended) + } +} + +type failOnPutSecretStore struct { + base *fakeAppSecretStore + failRef string +} + +func newFailOnPutSecretStore(failRef string) *failOnPutSecretStore { + return &failOnPutSecretStore{ + base: newFakeAppSecretStore(), + failRef: failRef, + } +} + +func (s *failOnPutSecretStore) Put(ref string, payload []byte) error { + if ref == s.failRef { + return errors.New("injected put failure") + } + return s.base.Put(ref, payload) +} + +func (s *failOnPutSecretStore) Get(ref string) ([]byte, error) { + return s.base.Get(ref) +} + +func (s *failOnPutSecretStore) Delete(ref string) error { + return s.base.Delete(ref) +} + +func (s *failOnPutSecretStore) HealthCheck() error { + return s.base.HealthCheck() +} diff --git a/internal/app/connection_package_types.go b/internal/app/connection_package_types.go new file mode 100644 index 0000000..df959c9 --- /dev/null +++ b/internal/app/connection_package_types.go @@ -0,0 +1,77 @@ +package app + +import ( + "errors" + "strings" + + "GoNavi-Wails/internal/connection" +) + +const ( + connectionPackageSchemaVersion = 1 + connectionPackageKind = "gonavi_connection_package" + connectionPackageCipher = "AES-256-GCM" + connectionPackageKDFName = "Argon2id" + connectionPackageExtension = ".gonavi-conn" + + connectionPackageKDFDefaultMemoryKiB = 65536 + connectionPackageKDFDefaultTimeCost = 3 + connectionPackageKDFDefaultParallelism = 4 + + connectionPackageKDFMaxMemoryKiB = 262144 + connectionPackageKDFMaxTimeCost = 10 + connectionPackageKDFMaxParallelism = 16 +) + +var ( + errConnectionPackagePasswordRequired = errors.New("恢复包密码不能为空") + errConnectionPackageDecryptFailed = errors.New("文件密码错误或文件已损坏") + errConnectionPackageUnsupported = errors.New("不支持的连接恢复包格式") + errConnectionPackageNotImplemented = errors.New("connection package not implemented") +) + +type connectionPackageFile struct { + SchemaVersion int `json:"schemaVersion"` + Kind string `json:"kind"` + Cipher string `json:"cipher"` + KDF connectionPackageKDFSpec `json:"kdf"` + Nonce string `json:"nonce"` + Payload string `json:"payload"` +} + +type connectionPackageKDFSpec struct { + Name string `json:"name"` + MemoryKiB uint32 `json:"memoryKiB"` + TimeCost uint32 `json:"timeCost"` + Parallelism uint8 `json:"parallelism"` + Salt string `json:"salt"` +} + +type connectionPackagePayload struct { + ExportedAt string `json:"exportedAt,omitempty"` + Connections []connectionPackageItem `json:"connections"` +} + +type connectionPackageItem struct { + ID string `json:"id"` + Name string `json:"name"` + IncludeDatabases []string `json:"includeDatabases,omitempty"` + IncludeRedisDatabases []int `json:"includeRedisDatabases,omitempty"` + IconType string `json:"iconType,omitempty"` + IconColor string `json:"iconColor,omitempty"` + Config connection.ConnectionConfig `json:"config"` + Secrets connectionSecretBundle `json:"secrets,omitempty"` +} + +func defaultConnectionPackageKDFSpec() connectionPackageKDFSpec { + return connectionPackageKDFSpec{ + Name: connectionPackageKDFName, + MemoryKiB: connectionPackageKDFDefaultMemoryKiB, + TimeCost: connectionPackageKDFDefaultTimeCost, + Parallelism: connectionPackageKDFDefaultParallelism, + } +} + +func normalizeConnectionPackagePassword(password string) string { + return strings.TrimSpace(password) +} diff --git a/internal/app/connection_secret_resolution.go b/internal/app/connection_secret_resolution.go index 14842c1..5e7eb6f 100644 --- a/internal/app/connection_secret_resolution.go +++ b/internal/app/connection_secret_resolution.go @@ -1,6 +1,7 @@ package app import ( + "fmt" "strings" "GoNavi-Wails/internal/connection" @@ -14,7 +15,7 @@ func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (conn repo := newSavedConnectionRepository(a.configDir, a.secretStore) view, err := repo.Find(config.ID) if err != nil { - return config, err + return config, normalizeConnectionSecretResolutionError(config, err) } base := config @@ -23,13 +24,32 @@ func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (conn } bundle, err := repo.loadSecretBundle(view) if err != nil { - return base, err + return base, normalizeConnectionSecretResolutionError(base, err) } resolved := mergeConnectionSecretBundleIntoConfig(base, bundle) resolved.ID = view.ID return resolved, nil } +func normalizeConnectionSecretResolutionError(config connection.ConnectionConfig, err error) error { + if err == nil { + return nil + } + + lower := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(lower, "saved connection not found:"): + if connectionMetadataLooksEmpty(config) { + return fmt.Errorf("未找到已保存连接,可能已被删除,请刷新后重试") + } + return fmt.Errorf("未找到当前连接对应的已保存密文,请重新填写密码并保存后再试") + case strings.Contains(lower, "secret store unavailable"): + return fmt.Errorf("系统密文存储当前不可用,请检查系统钥匙串或凭据管理器后再试") + default: + return err + } +} + func connectionMetadataLooksEmpty(config connection.ConnectionConfig) bool { return strings.TrimSpace(config.Type) == "" && strings.TrimSpace(config.Host) == "" && diff --git a/internal/app/connection_secret_resolution_test.go b/internal/app/connection_secret_resolution_test.go index a6336ca..e09e24a 100644 --- a/internal/app/connection_secret_resolution_test.go +++ b/internal/app/connection_secret_resolution_test.go @@ -1,6 +1,7 @@ package app import ( + "strings" "testing" "GoNavi-Wails/internal/connection" @@ -40,3 +41,23 @@ func TestResolveConnectionConfigByIDLoadsSecretsFromStore(t *testing.T) { t.Fatalf("expected restored DSN, got %q", resolved.DSN) } } + +func TestResolveConnectionSecretsReturnsFriendlyMessageWhenSavedSecretSourceIsMissing(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + _, err := app.resolveConnectionSecrets(connection.ConnectionConfig{ + ID: "conn-missing", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + }) + if err == nil { + t.Fatal("expected resolveConnectionSecrets to fail for a missing saved connection") + } + if !strings.Contains(err.Error(), "已保存密文") { + t.Fatalf("expected a secret-specific error message, got %q", err.Error()) + } +} diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 81e95ce..64e7aef 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -263,6 +263,10 @@ func (a *App) ImportConfigFile() connection.QueryResult { selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{ Title: "Select Config File", Filters: []runtime.FileFilter{ + { + DisplayName: "GoNavi Connection Package (*.gonavi-conn)", + Pattern: "*.gonavi-conn", + }, { DisplayName: "JSON Files (*.json)", Pattern: "*.json", @@ -286,6 +290,53 @@ func (a *App) ImportConfigFile() connection.QueryResult { return connection.QueryResult{Success: true, Data: string(content)} } +func (a *App) ExportConnectionsPackage(password string) connection.QueryResult { + payload, err := a.buildConnectionPackagePayload() + if err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + + filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{ + Title: "Export Connections", + DefaultFilename: "connections" + connectionPackageExtension, + Filters: []runtime.FileFilter{ + { + DisplayName: "GoNavi Connection Package (*.gonavi-conn)", + Pattern: "*.gonavi-conn", + }, + }, + }) + if err != nil || strings.TrimSpace(filename) == "" { + return connection.QueryResult{Success: false, Message: "已取消"} + } + filename = normalizeConnectionPackageExportFilename(filename) + + pkg, err := encryptConnectionPackage(payload, password) + if err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + + content, err := json.MarshalIndent(pkg, "", " ") + if err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + if err := os.WriteFile(filename, content, 0o644); err != nil { + return connection.QueryResult{Success: false, Message: err.Error()} + } + return connection.QueryResult{Success: true, Message: "导出完成"} +} + +func normalizeConnectionPackageExportFilename(filename string) string { + trimmed := strings.TrimSpace(filename) + if trimmed == "" { + return "" + } + if strings.EqualFold(filepath.Ext(trimmed), connectionPackageExtension) { + return trimmed + } + return trimmed + connectionPackageExtension +} + func (a *App) SelectSSHKeyFile(currentPath string) connection.QueryResult { defaultDir := strings.TrimSpace(currentPath) if defaultDir == "" { diff --git a/internal/app/methods_saved_connections_test.go b/internal/app/methods_saved_connections_test.go index d17785f..a813c53 100644 --- a/internal/app/methods_saved_connections_test.go +++ b/internal/app/methods_saved_connections_test.go @@ -185,3 +185,89 @@ func TestSaveGlobalProxyReturnsSecretlessView(t *testing.T) { t.Fatal("expected hasPassword=true") } } + +func TestImportLegacyConnectionsIsIdempotentForSameID(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + legacy := connection.LegacySavedConnection{ + ID: "legacy-1", + Name: "Legacy", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "secret-1", + }, + } + + if _, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{legacy}); err != nil { + t.Fatalf("first ImportLegacyConnections returned error: %v", err) + } + if _, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{legacy}); err != nil { + t.Fatalf("second ImportLegacyConnections returned error: %v", err) + } + + saved, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(saved) != 1 { + t.Fatalf("expected a single saved connection after repeated import, got %d", len(saved)) + } +} + +func TestImportLegacyConnectionsKeepsExistingSecretWhenReimportOmitsPassword(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + if _, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{ + { + ID: "legacy-1", + Name: "Legacy", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "secret-1", + }, + }, + }); err != nil { + t.Fatalf("initial ImportLegacyConnections returned error: %v", err) + } + + if _, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{ + { + ID: "legacy-1", + Name: "Legacy Updated", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + }, + }, + }); err != nil { + t.Fatalf("update ImportLegacyConnections returned error: %v", err) + } + + saved, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(saved) != 1 { + t.Fatalf("expected 1 saved connection, got %d", len(saved)) + } + resolved, err := app.resolveConnectionSecrets(saved[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "secret-1" { + t.Fatalf("expected original password to be preserved, got %q", resolved.Password) + } +} diff --git a/internal/app/security_update_engine.go b/internal/app/security_update_engine.go new file mode 100644 index 0000000..ac69a10 --- /dev/null +++ b/internal/app/security_update_engine.go @@ -0,0 +1,561 @@ +package app + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "GoNavi-Wails/internal/ai" + aiservice "GoNavi-Wails/internal/ai/service" + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" +) + +type securityUpdateNormalizedPreview struct { + SourceType SecurityUpdateSourceType `json:"sourceType"` + ConnectionIDs []string `json:"connectionIds"` + HasGlobalProxy bool `json:"hasGlobalProxy"` + AIProviderIDs []string `json:"aiProviderIds"` + AIProvidersNeedingAttention []string `json:"aiProvidersNeedingAttention,omitempty"` +} + +func (a *App) GetSecurityUpdateStatus() (SecurityUpdateStatus, error) { + a.updateMu.Lock() + defer a.updateMu.Unlock() + + repo := newSecurityUpdateStateRepository(a.configDir) + status, err := repo.LoadMarker() + if err != nil { + if os.IsNotExist(err) { + inspection, inspectErr := aiservice.NewProviderConfigStore(a.configDir, a.secretStore).Inspect() + if inspectErr != nil { + return SecurityUpdateStatus{}, inspectErr + } + if len(inspection.ProvidersNeedingMigration) > 0 { + return buildSecurityUpdatePendingStatusFromInspection(inspection, SecurityUpdateOverallStatusPending), nil + } + return SecurityUpdateStatus{ + SchemaVersion: securityUpdateSchemaVersion, + OverallStatus: SecurityUpdateOverallStatusNotDetected, + Summary: SecurityUpdateSummary{}, + Issues: []SecurityUpdateIssue{}, + }, nil + } + return SecurityUpdateStatus{}, err + } + return status, nil +} + +func (a *App) StartSecurityUpdate(request StartSecurityUpdateRequest) (SecurityUpdateStatus, error) { + a.updateMu.Lock() + defer a.updateMu.Unlock() + + repo := newSecurityUpdateStateRepository(a.configDir) + status, err := repo.StartRound(request) + if err != nil { + return SecurityUpdateStatus{}, err + } + return a.executeSecurityUpdateRound(repo, status, request.SourceType, request.RawPayload) +} + +func (a *App) RetrySecurityUpdateCurrentRound(request RetrySecurityUpdateRequest) (SecurityUpdateStatus, error) { + a.updateMu.Lock() + defer a.updateMu.Unlock() + + repo := newSecurityUpdateStateRepository(a.configDir) + status, err := repo.RetryRound(request) + if err != nil { + return SecurityUpdateStatus{}, err + } + + previewData, err := os.ReadFile(filepath.Join(status.BackupPath, securityUpdateNormalizedPreviewFileName)) + if err != nil { + failed := newSecurityUpdateSystemFailureStatus(status, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err) + _ = repo.WriteResult(failed) + return failed, nil + } + + var preview securityUpdateNormalizedPreview + if err := json.Unmarshal(previewData, &preview); err != nil { + failed := newSecurityUpdateSystemFailureStatus(status, SecurityUpdateIssueReasonCodeValidationFailed, err) + _ = repo.WriteResult(failed) + return failed, nil + } + + finalStatus, execErr := a.validateSecurityUpdateCurrentAppRound(status, preview) + if execErr != nil { + _ = repo.WriteResult(finalStatus) + return finalStatus, nil + } + if err := repo.WriteResult(finalStatus); err != nil { + return SecurityUpdateStatus{}, err + } + return finalStatus, nil +} + +func (a *App) RestartSecurityUpdate(request RestartSecurityUpdateRequest) (SecurityUpdateStatus, error) { + a.updateMu.Lock() + defer a.updateMu.Unlock() + + repo := newSecurityUpdateStateRepository(a.configDir) + status, err := repo.RestartRound(request) + if err != nil { + return SecurityUpdateStatus{}, err + } + return a.executeSecurityUpdateRound(repo, status, request.SourceType, request.RawPayload) +} + +func (a *App) DismissSecurityUpdateReminder() (SecurityUpdateStatus, error) { + a.updateMu.Lock() + defer a.updateMu.Unlock() + + now := nowRFC3339() + repo := newSecurityUpdateStateRepository(a.configDir) + status, err := repo.LoadMarker() + if err != nil { + if !os.IsNotExist(err) { + return SecurityUpdateStatus{}, err + } + inspection, inspectErr := aiservice.NewProviderConfigStore(a.configDir, a.secretStore).Inspect() + if inspectErr != nil { + return SecurityUpdateStatus{}, inspectErr + } + if len(inspection.ProvidersNeedingMigration) > 0 { + status = buildSecurityUpdatePendingStatusFromInspection(inspection, SecurityUpdateOverallStatusPostponed) + } else { + status = SecurityUpdateStatus{ + SchemaVersion: securityUpdateSchemaVersion, + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + Summary: SecurityUpdateSummary{}, + Issues: []SecurityUpdateIssue{}, + } + } + } + status.SchemaVersion = securityUpdateSchemaVersion + if strings.TrimSpace(string(status.SourceType)) == "" { + status.SourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig + } + if status.Issues == nil { + status.Issues = []SecurityUpdateIssue{} + } + if status.OverallStatus == SecurityUpdateOverallStatusCompleted || status.OverallStatus == SecurityUpdateOverallStatusRolledBack { + return status, nil + } + status.OverallStatus = SecurityUpdateOverallStatusPostponed + status.PostponedAt = now + status.UpdatedAt = now + + if err := repo.WriteResult(status); err != nil { + return SecurityUpdateStatus{}, err + } + return repo.LoadMarker() +} + +func (a *App) executeSecurityUpdateRound(repo *securityUpdateStateRepository, round SecurityUpdateStatus, sourceType SecurityUpdateSourceType, rawPayload string) (SecurityUpdateStatus, error) { + if strings.TrimSpace(string(sourceType)) == "" { + sourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig + } + if sourceType != SecurityUpdateSourceTypeCurrentAppSavedConfig { + failed := newSecurityUpdateSystemFailureStatus(round, SecurityUpdateIssueReasonCodeValidationFailed, fmt.Errorf("unsupported source type: %s", sourceType)) + _ = repo.WriteResult(failed) + return failed, nil + } + + source, rawParsed, err := parseSecurityUpdateCurrentAppSource(rawPayload) + if err != nil { + failed := newSecurityUpdateSystemFailureStatus(round, SecurityUpdateIssueReasonCodeValidationFailed, err) + _ = repo.WriteResult(failed) + return failed, nil + } + + rollbackSnapshot, err := captureSecurityUpdateCurrentAppRollbackSnapshot(a, source) + if err != nil { + failed := newSecurityUpdateSystemFailureStatus(round, securityUpdateFailureReasonForError(err), err) + _ = repo.WriteResult(failed) + return failed, nil + } + + if err := securityUpdateWriteJSONFile(filepath.Join(round.BackupPath, securityUpdateSourceCurrentAppFileName), rawParsed); err != nil { + return SecurityUpdateStatus{}, err + } + + finalStatus, preview, execErr := a.runSecurityUpdateCurrentAppRound(round, source) + if previewErr := securityUpdateWriteJSONFile(filepath.Join(round.BackupPath, securityUpdateNormalizedPreviewFileName), preview); previewErr != nil { + return a.rollbackSecurityUpdatePersistenceFailure(repo, rollbackSnapshot, finalStatus, previewErr) + } + + if execErr != nil { + if rollbackErr := rollbackSnapshot.restore(a); rollbackErr != nil { + failed := newSecurityUpdateSystemFailureStatus(finalStatus, securityUpdateFailureReasonForError(rollbackErr), rollbackErr) + _ = repo.WriteResult(failed) + return failed, nil + } + _ = repo.WriteResult(finalStatus) + return finalStatus, nil + } + if err := repo.WriteResult(finalStatus); err != nil { + return a.rollbackSecurityUpdatePersistenceFailure(repo, rollbackSnapshot, finalStatus, err) + } + return finalStatus, nil +} + +func (a *App) rollbackSecurityUpdatePersistenceFailure( + repo *securityUpdateStateRepository, + rollbackSnapshot securityUpdateCurrentAppRollbackSnapshot, + base SecurityUpdateStatus, + cause error, +) (SecurityUpdateStatus, error) { + if rollbackErr := rollbackSnapshot.restore(a); rollbackErr != nil { + failed := newSecurityUpdateSystemFailureStatus(base, securityUpdateFailureReasonForError(rollbackErr), rollbackErr) + _ = repo.WriteResult(failed) + return failed, nil + } + + failed := newSecurityUpdateSystemFailureStatus(base, SecurityUpdateIssueReasonCodeEnvironmentBlocked, cause) + _ = repo.WriteResult(failed) + return failed, nil +} + +func (a *App) runSecurityUpdateCurrentAppRound(round SecurityUpdateStatus, source securityUpdateCurrentAppSource) (SecurityUpdateStatus, securityUpdateNormalizedPreview, error) { + finalStatus := newSecurityUpdateRoundBaseStatus(round, SecurityUpdateSourceTypeCurrentAppSavedConfig) + + preview := securityUpdateNormalizedPreview{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + ConnectionIDs: make([]string, 0, len(source.Connections)), + HasGlobalProxy: source.GlobalProxy != nil, + AIProviderIDs: []string{}, + } + + connectionRepo := a.savedConnectionRepository() + for _, item := range source.Connections { + finalStatus.Summary.Total++ + preview.ConnectionIDs = append(preview.ConnectionIDs, item.ID) + if _, err := connectionRepo.Save(connection.SavedConnectionInput(item)); err != nil { + failed := newSecurityUpdateSystemFailureStatus(finalStatus, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err) + return failed, preview, err + } + finalStatus.Summary.Updated++ + } + + if source.GlobalProxy != nil { + finalStatus.Summary.Total++ + if _, err := a.saveGlobalProxy(connection.SaveGlobalProxyInput(*source.GlobalProxy)); err != nil { + failed := newSecurityUpdateSystemFailureStatus(finalStatus, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err) + return failed, preview, err + } + finalStatus.Summary.Updated++ + } + + providerSnapshot, err := aiservice.NewProviderConfigStore(a.configDir, a.secretStore).Load() + if err != nil { + failed := newSecurityUpdateSystemFailureStatus(finalStatus, securityUpdateFailureReasonForError(err), err) + return failed, preview, err + } + + for _, provider := range providerSnapshot.Providers { + if !providerParticipatesInSecurityUpdate(provider) { + continue + } + + preview.AIProviderIDs = append(preview.AIProviderIDs, provider.ID) + finalStatus.Summary.Total++ + if provider.HasSecret && strings.TrimSpace(provider.APIKey) == "" { + finalStatus.OverallStatus = SecurityUpdateOverallStatusNeedsAttention + finalStatus.Summary.Pending++ + finalStatus.Issues = append(finalStatus.Issues, SecurityUpdateIssue{ + ID: "ai-provider-" + provider.ID, + Scope: SecurityUpdateIssueScopeAIProvider, + RefID: provider.ID, + Title: provider.Name, + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusNeedsAttention, + ReasonCode: SecurityUpdateIssueReasonCodeSecretMissing, + Action: SecurityUpdateIssueActionOpenAISettings, + Message: "AI 提供商配置需要补充后才能完成安全更新", + }) + preview.AIProvidersNeedingAttention = append(preview.AIProvidersNeedingAttention, provider.ID) + continue + } + finalStatus.Summary.Updated++ + } + + if finalStatus.OverallStatus == SecurityUpdateOverallStatusCompleted { + finalStatus.CompletedAt = finalStatus.UpdatedAt + } + + return finalStatus, preview, nil +} + +func (a *App) validateSecurityUpdateCurrentAppRound(round SecurityUpdateStatus, preview securityUpdateNormalizedPreview) (SecurityUpdateStatus, error) { + if strings.TrimSpace(string(preview.SourceType)) == "" { + preview.SourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig + } + + finalStatus := newSecurityUpdateRoundBaseStatus(round, preview.SourceType) + connectionRepo := a.savedConnectionRepository() + for _, id := range preview.ConnectionIDs { + finalStatus.Summary.Total++ + savedConnection, err := connectionRepo.Find(id) + if err != nil { + markSecurityUpdateNeedsAttention( + &finalStatus, + SecurityUpdateIssue{ + ID: "connection-" + id, + Scope: SecurityUpdateIssueScopeConnection, + RefID: id, + Title: id, + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusNeedsAttention, + ReasonCode: SecurityUpdateIssueReasonCodeValidationFailed, + Action: SecurityUpdateIssueActionOpenConnection, + Message: "连接配置已不存在或仍需重新保存后才能完成安全更新", + }, + ) + continue + } + if _, err := a.resolveConnectionSecrets(savedConnection.Config); err != nil { + if secretstore.IsUnavailable(err) { + failed := newSecurityUpdateSystemFailureStatus(finalStatus, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err) + return failed, err + } + reason := SecurityUpdateIssueReasonCodeValidationFailed + message := "连接配置仍需补充后才能完成安全更新" + if os.IsNotExist(err) { + reason = SecurityUpdateIssueReasonCodeSecretMissing + message = "连接密码已丢失,请重新保存后再继续" + } + markSecurityUpdateNeedsAttention( + &finalStatus, + SecurityUpdateIssue{ + ID: "connection-" + id, + Scope: SecurityUpdateIssueScopeConnection, + RefID: id, + Title: savedConnection.Name, + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusNeedsAttention, + ReasonCode: reason, + Action: SecurityUpdateIssueActionOpenConnection, + Message: message, + }, + ) + continue + } + finalStatus.Summary.Updated++ + } + + if preview.HasGlobalProxy { + finalStatus.Summary.Total++ + proxyView, err := a.loadStoredGlobalProxyView() + if err != nil { + if !os.IsNotExist(err) { + failed := newSecurityUpdateSystemFailureStatus(finalStatus, securityUpdateFailureReasonForError(err), err) + return failed, err + } + markSecurityUpdateNeedsAttention( + &finalStatus, + SecurityUpdateIssue{ + ID: "global-proxy-default", + Scope: SecurityUpdateIssueScopeGlobalProxy, + Title: "全局代理", + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusNeedsAttention, + ReasonCode: SecurityUpdateIssueReasonCodeValidationFailed, + Action: SecurityUpdateIssueActionOpenProxySettings, + Message: "全局代理配置已不存在或仍需重新保存后才能完成安全更新", + }, + ) + } else { + if proxyView.HasPassword { + if _, err := a.loadGlobalProxySecretBundle(proxyView); err != nil { + if secretstore.IsUnavailable(err) { + failed := newSecurityUpdateSystemFailureStatus(finalStatus, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err) + return failed, err + } + reason := SecurityUpdateIssueReasonCodeValidationFailed + message := "全局代理密码仍需补充后才能完成安全更新" + if os.IsNotExist(err) { + reason = SecurityUpdateIssueReasonCodeSecretMissing + message = "全局代理密码已丢失,请重新保存后再继续" + } + markSecurityUpdateNeedsAttention( + &finalStatus, + SecurityUpdateIssue{ + ID: "global-proxy-default", + Scope: SecurityUpdateIssueScopeGlobalProxy, + Title: "全局代理", + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusNeedsAttention, + ReasonCode: reason, + Action: SecurityUpdateIssueActionOpenProxySettings, + Message: message, + }, + ) + goto validateProviders + } + } + finalStatus.Summary.Updated++ + } + } + +validateProviders: + providerSnapshot, err := aiservice.NewProviderConfigStore(a.configDir, a.secretStore).Load() + if err != nil { + failed := newSecurityUpdateSystemFailureStatus(finalStatus, securityUpdateFailureReasonForError(err), err) + return failed, err + } + + providersByID := make(map[string]ai.ProviderConfig, len(providerSnapshot.Providers)) + for _, provider := range providerSnapshot.Providers { + providersByID[provider.ID] = provider + } + + for _, providerID := range preview.AIProviderIDs { + finalStatus.Summary.Total++ + provider, ok := providersByID[providerID] + if !ok { + markSecurityUpdateNeedsAttention( + &finalStatus, + SecurityUpdateIssue{ + ID: "ai-provider-" + providerID, + Scope: SecurityUpdateIssueScopeAIProvider, + RefID: providerID, + Title: providerID, + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusNeedsAttention, + ReasonCode: SecurityUpdateIssueReasonCodeValidationFailed, + Action: SecurityUpdateIssueActionOpenAISettings, + Message: "AI 提供商配置已不存在或仍需重新保存后才能完成安全更新", + }, + ) + continue + } + if provider.HasSecret && strings.TrimSpace(provider.APIKey) == "" { + markSecurityUpdateNeedsAttention( + &finalStatus, + SecurityUpdateIssue{ + ID: "ai-provider-" + provider.ID, + Scope: SecurityUpdateIssueScopeAIProvider, + RefID: provider.ID, + Title: provider.Name, + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusNeedsAttention, + ReasonCode: SecurityUpdateIssueReasonCodeSecretMissing, + Action: SecurityUpdateIssueActionOpenAISettings, + Message: "AI 提供商配置需要补充后才能完成安全更新", + }, + ) + continue + } + finalStatus.Summary.Updated++ + } + + if finalStatus.OverallStatus == SecurityUpdateOverallStatusCompleted { + finalStatus.CompletedAt = finalStatus.UpdatedAt + } + return finalStatus, nil +} + +func providerParticipatesInSecurityUpdate(provider ai.ProviderConfig) bool { + return provider.HasSecret || strings.TrimSpace(provider.APIKey) != "" +} + +func buildSecurityUpdatePendingStatusFromInspection( + inspection aiservice.ProviderConfigStoreInspection, + overallStatus SecurityUpdateOverallStatus, +) SecurityUpdateStatus { + providersByID := make(map[string]ai.ProviderConfig, len(inspection.Snapshot.Providers)) + for _, provider := range inspection.Snapshot.Providers { + providersByID[provider.ID] = provider + } + + issues := make([]SecurityUpdateIssue, 0, len(inspection.ProvidersNeedingMigration)) + for _, providerID := range inspection.ProvidersNeedingMigration { + provider := providersByID[providerID] + title := strings.TrimSpace(provider.Name) + if title == "" { + title = providerID + } + issues = append(issues, SecurityUpdateIssue{ + ID: "ai-provider-" + providerID, + Scope: SecurityUpdateIssueScopeAIProvider, + RefID: providerID, + Title: title, + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusPending, + ReasonCode: SecurityUpdateIssueReasonCodeMigrationRequired, + Action: SecurityUpdateIssueActionOpenAISettings, + Message: "AI 提供商配置仍保存在当前应用配置中,完成安全更新后会迁入新的安全存储。", + }) + } + + return SecurityUpdateStatus{ + SchemaVersion: securityUpdateSchemaVersion, + OverallStatus: overallStatus, + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + ReminderVisible: overallStatus == SecurityUpdateOverallStatusPending, + CanStart: overallStatus == SecurityUpdateOverallStatusPending || overallStatus == SecurityUpdateOverallStatusPostponed, + CanPostpone: overallStatus == SecurityUpdateOverallStatusPending || overallStatus == SecurityUpdateOverallStatusPostponed, + Summary: SecurityUpdateSummary{ + Total: len(issues), + Pending: len(issues), + }, + Issues: issues, + } +} + +func newSecurityUpdateRoundBaseStatus(round SecurityUpdateStatus, sourceType SecurityUpdateSourceType) SecurityUpdateStatus { + if strings.TrimSpace(string(sourceType)) == "" { + sourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig + } + return SecurityUpdateStatus{ + SchemaVersion: securityUpdateSchemaVersion, + MigrationID: round.MigrationID, + OverallStatus: SecurityUpdateOverallStatusCompleted, + SourceType: sourceType, + BackupAvailable: round.BackupAvailable || strings.TrimSpace(round.BackupPath) != "", + BackupPath: round.BackupPath, + StartedAt: round.StartedAt, + UpdatedAt: nowRFC3339(), + Summary: SecurityUpdateSummary{}, + Issues: []SecurityUpdateIssue{}, + } +} + +func markSecurityUpdateNeedsAttention(status *SecurityUpdateStatus, issue SecurityUpdateIssue) { + status.OverallStatus = SecurityUpdateOverallStatusNeedsAttention + status.Summary.Pending++ + status.Issues = append(status.Issues, issue) +} + +func securityUpdateFailureReasonForError(err error) SecurityUpdateIssueReasonCode { + if secretstore.IsUnavailable(err) { + return SecurityUpdateIssueReasonCodeEnvironmentBlocked + } + return SecurityUpdateIssueReasonCodeValidationFailed +} + +func newSecurityUpdateSystemFailureStatus(base SecurityUpdateStatus, reasonCode SecurityUpdateIssueReasonCode, err error) SecurityUpdateStatus { + status := base + status.SchemaVersion = securityUpdateSchemaVersion + status.OverallStatus = SecurityUpdateOverallStatusRolledBack + status.BackupAvailable = status.BackupAvailable || strings.TrimSpace(status.BackupPath) != "" + status.UpdatedAt = nowRFC3339() + status.CompletedAt = "" + status.LastError = err.Error() + status.Summary.Failed++ + status.Issues = []SecurityUpdateIssue{ + { + ID: "system-blocked", + Scope: SecurityUpdateIssueScopeSystem, + Title: "安全更新未完成", + Severity: SecurityUpdateIssueSeverityHigh, + Status: SecurityUpdateItemStatusFailed, + ReasonCode: reasonCode, + Action: SecurityUpdateIssueActionViewDetails, + Message: "当前环境无法完成本次安全更新,请稍后重试", + }, + } + return status +} diff --git a/internal/app/security_update_engine_test.go b/internal/app/security_update_engine_test.go new file mode 100644 index 0000000..fb54774 --- /dev/null +++ b/internal/app/security_update_engine_test.go @@ -0,0 +1,942 @@ +package app + +import ( + "errors" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + aiservice "GoNavi-Wails/internal/ai/service" + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" +) + +func TestStartSecurityUpdateCreatesBackupAndImportsSavedConfig(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "apiKey": "sk-ai-test", + "baseUrl": "https://api.openai.com/v1", + "headers": map[string]any{ + "Authorization": "Bearer ai-test", + "X-Team": "platform", + }, + }, + }, + }) + + status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusCompleted { + t.Fatalf("expected completed status, got %q", status.OverallStatus) + } + if status.MigrationID == "" { + t.Fatal("expected migration ID to be created") + } + if status.Summary.Total != 3 || status.Summary.Updated != 3 { + t.Fatalf("expected summary total=3 updated=3, got %#v", status.Summary) + } + + savedConnections, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(savedConnections) != 1 { + t.Fatalf("expected 1 saved connection, got %d", len(savedConnections)) + } + resolvedConnection, err := app.resolveConnectionSecrets(savedConnections[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolvedConnection.Password != "postgres-secret" { + t.Fatalf("expected imported connection password, got %q", resolvedConnection.Password) + } + + globalProxyView, err := app.loadStoredGlobalProxyView() + if err != nil { + t.Fatalf("loadStoredGlobalProxyView returned error: %v", err) + } + globalProxyBundle, err := app.loadGlobalProxySecretBundle(globalProxyView) + if err != nil { + t.Fatalf("loadGlobalProxySecretBundle returned error: %v", err) + } + if globalProxyBundle.Password != "proxy-secret" { + t.Fatalf("expected imported proxy password, got %q", globalProxyBundle.Password) + } + + providerStore := aiservice.NewProviderConfigStore(app.configDir, app.secretStore) + providerSnapshot, err := providerStore.Load() + if err != nil { + t.Fatalf("provider store Load returned error: %v", err) + } + if len(providerSnapshot.Providers) != 1 { + t.Fatalf("expected 1 AI provider, got %d", len(providerSnapshot.Providers)) + } + if providerSnapshot.Providers[0].APIKey != "sk-ai-test" { + t.Fatalf("expected migrated AI provider apiKey, got %q", providerSnapshot.Providers[0].APIKey) + } + + for _, name := range []string{ + securityUpdateManifestFileName, + securityUpdateSourceCurrentAppFileName, + securityUpdateNormalizedPreviewFileName, + securityUpdateResultFileName, + } { + if _, err := os.Stat(filepath.Join(status.BackupPath, name)); err != nil { + t.Fatalf("expected backup artifact %q: %v", name, err) + } + } +} + +func TestGetSecurityUpdateStatusReturnsPendingWhenOnlyAIProviderNeedsSecurityUpdate(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "apiKey": "sk-ai-test", + "baseUrl": "https://api.openai.com/v1", + }, + }, + }) + + status, err := app.GetSecurityUpdateStatus() + if err != nil { + t.Fatalf("GetSecurityUpdateStatus returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusPending { + t.Fatalf("expected pending status, got %q", status.OverallStatus) + } + if !status.CanStart || !status.ReminderVisible { + t.Fatalf("expected pending status to expose start/reminder flags, got %#v", status) + } +} + +func TestGetSecurityUpdateStatusIncludesPendingAIProviderIssuesBeforeStart(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "apiKey": "sk-ai-test", + "baseUrl": "https://api.openai.com/v1", + }, + }, + }) + + status, err := app.GetSecurityUpdateStatus() + if err != nil { + t.Fatalf("GetSecurityUpdateStatus returned error: %v", err) + } + if len(status.Issues) != 1 { + t.Fatalf("expected 1 pending issue, got %#v", status.Issues) + } + if status.Summary.Total != 1 || status.Summary.Pending != 1 { + t.Fatalf("expected summary total=1 pending=1, got %#v", status.Summary) + } + issue := status.Issues[0] + if issue.Scope != SecurityUpdateIssueScopeAIProvider { + t.Fatalf("expected AI provider issue scope, got %#v", issue) + } + if issue.RefID != "openai-main" || issue.Title != "OpenAI" { + t.Fatalf("expected provider issue to point at openai-main/OpenAI, got %#v", issue) + } + if issue.Status != SecurityUpdateItemStatusPending || issue.Action != SecurityUpdateIssueActionOpenAISettings { + t.Fatalf("expected pending AI settings issue, got %#v", issue) + } +} + +func TestRetrySecurityUpdateCurrentRoundReusesMigrationIDAfterPendingIssueIsFixed(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + ref, err := secretstore.BuildRef("ai-provider", "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "hasSecret": true, + "secretRef": ref, + "baseUrl": "https://api.openai.com/v1", + }, + }, + }) + + initial, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if initial.OverallStatus != SecurityUpdateOverallStatusNeedsAttention { + t.Fatalf("expected needs_attention status, got %q", initial.OverallStatus) + } + if len(initial.Issues) != 1 || initial.Issues[0].Scope != SecurityUpdateIssueScopeAIProvider { + t.Fatalf("expected AI provider issue, got %#v", initial.Issues) + } + + if err := store.Put(ref, []byte(`{"apiKey":"sk-fixed","sensitiveHeaders":{"Authorization":"Bearer fixed"}}`)); err != nil { + t.Fatalf("Put returned error: %v", err) + } + + retried, err := app.RetrySecurityUpdateCurrentRound(RetrySecurityUpdateRequest{ + MigrationID: initial.MigrationID, + }) + if err != nil { + t.Fatalf("RetrySecurityUpdateCurrentRound returned error: %v", err) + } + if retried.MigrationID != initial.MigrationID { + t.Fatalf("expected retry to reuse migration ID %q, got %q", initial.MigrationID, retried.MigrationID) + } + if retried.OverallStatus != SecurityUpdateOverallStatusCompleted { + t.Fatalf("expected completed status after retry, got %q", retried.OverallStatus) + } +} + +func TestRetrySecurityUpdateCurrentRoundDoesNotReimportBrokenLegacySourceAfterUserFix(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + ref, err := secretstore.BuildRef("ai-provider", "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "hasSecret": true, + "secretRef": ref, + "baseUrl": "https://api.openai.com/v1", + }, + }, + }) + + initial, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if initial.OverallStatus != SecurityUpdateOverallStatusNeedsAttention { + t.Fatalf("expected needs_attention status, got %q", initial.OverallStatus) + } + + if _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "legacy-1", + Name: "Legacy Fixed", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db-fixed.local", + Port: 5432, + User: "postgres", + Password: "postgres-fixed", + }, + }); err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + if err := store.Put(ref, []byte(`{"apiKey":"sk-fixed"}`)); err != nil { + t.Fatalf("Put returned error: %v", err) + } + + retried, err := app.RetrySecurityUpdateCurrentRound(RetrySecurityUpdateRequest{ + MigrationID: initial.MigrationID, + }) + if err != nil { + t.Fatalf("RetrySecurityUpdateCurrentRound returned error: %v", err) + } + if retried.OverallStatus != SecurityUpdateOverallStatusCompleted { + t.Fatalf("expected completed status after retry, got %q", retried.OverallStatus) + } + + savedConnections, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(savedConnections) != 1 { + t.Fatalf("expected 1 saved connection, got %d", len(savedConnections)) + } + + resolvedConnection, err := app.resolveConnectionSecrets(savedConnections[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolvedConnection.Host != "db-fixed.local" { + t.Fatalf("expected retry to keep user-fixed host, got %q", resolvedConnection.Host) + } + if resolvedConnection.Password != "postgres-fixed" { + t.Fatalf("expected retry to keep user-fixed password, got %q", resolvedConnection.Password) + } +} + +func TestRestartSecurityUpdateCreatesNewMigrationID(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + initial, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + + restarted, err := app.RestartSecurityUpdate(RestartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("RestartSecurityUpdate returned error: %v", err) + } + if restarted.MigrationID == initial.MigrationID { + t.Fatal("expected restart to create a new migration ID") + } +} + +func TestDismissSecurityUpdateReminderMarksStatusPostponed(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + status, err := app.DismissSecurityUpdateReminder() + if err != nil { + t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusPostponed { + t.Fatalf("expected postponed status, got %q", status.OverallStatus) + } + if status.PostponedAt == "" { + t.Fatal("expected postponedAt to be recorded") + } +} + +func TestDismissSecurityUpdateReminderKeepsCurrentRoundContext(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + ref, err := secretstore.BuildRef("ai-provider", "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "hasSecret": true, + "secretRef": ref, + "baseUrl": "https://api.openai.com/v1", + }, + }, + }) + + initial, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if initial.OverallStatus != SecurityUpdateOverallStatusNeedsAttention { + t.Fatalf("expected needs_attention status, got %q", initial.OverallStatus) + } + + postponed, err := app.DismissSecurityUpdateReminder() + if err != nil { + t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err) + } + if postponed.OverallStatus != SecurityUpdateOverallStatusPostponed { + t.Fatalf("expected postponed status, got %q", postponed.OverallStatus) + } + if postponed.MigrationID != initial.MigrationID { + t.Fatalf("expected migration ID %q to be preserved, got %q", initial.MigrationID, postponed.MigrationID) + } + if postponed.BackupPath != initial.BackupPath { + t.Fatalf("expected backupPath %q to be preserved, got %q", initial.BackupPath, postponed.BackupPath) + } + if postponed.Summary != initial.Summary { + t.Fatalf("expected summary %#v to be preserved, got %#v", initial.Summary, postponed.Summary) + } + if len(postponed.Issues) != len(initial.Issues) { + t.Fatalf("expected %d issues to be preserved, got %#v", len(initial.Issues), postponed.Issues) + } + if postponed.PostponedAt == "" { + t.Fatal("expected postponedAt to be recorded") + } +} + +func TestDismissSecurityUpdateReminderKeepsPendingAIProviderDetailsWithoutCurrentRound(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "apiKey": "sk-ai-test", + "baseUrl": "https://api.openai.com/v1", + }, + }, + }) + + status, err := app.DismissSecurityUpdateReminder() + if err != nil { + t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusPostponed { + t.Fatalf("expected postponed status, got %q", status.OverallStatus) + } + if status.Summary.Total != 1 || status.Summary.Pending != 1 { + t.Fatalf("expected summary total=1 pending=1, got %#v", status.Summary) + } + if len(status.Issues) != 1 { + t.Fatalf("expected 1 pending issue, got %#v", status.Issues) + } + if status.Issues[0].RefID != "openai-main" || status.Issues[0].Action != SecurityUpdateIssueActionOpenAISettings { + t.Fatalf("expected postponed issue to keep AI provider repair entry, got %#v", status.Issues[0]) + } +} + +func TestDismissSecurityUpdateReminderDoesNotOverrideCompletedRound(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + repo := newSecurityUpdateStateRepository(app.configDir) + completed := SecurityUpdateStatus{ + SchemaVersion: securityUpdateSchemaVersion, + MigrationID: "migration-1", + OverallStatus: SecurityUpdateOverallStatusCompleted, + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + BackupPath: filepath.Join(app.configDir, securityUpdateBackupRootDirName, "migration-1"), + StartedAt: "2026-04-09T00:00:00Z", + UpdatedAt: "2026-04-09T00:05:00Z", + CompletedAt: "2026-04-09T00:05:00Z", + Summary: SecurityUpdateSummary{ + Total: 1, + Updated: 1, + }, + Issues: []SecurityUpdateIssue{}, + } + if err := repo.WriteResult(completed); err != nil { + t.Fatalf("WriteResult returned error: %v", err) + } + + status, err := app.DismissSecurityUpdateReminder() + if err != nil { + t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusCompleted { + t.Fatalf("expected completed status to be preserved, got %q", status.OverallStatus) + } + if status.MigrationID != completed.MigrationID { + t.Fatalf("expected migration ID %q to be preserved, got %q", completed.MigrationID, status.MigrationID) + } + if status.PostponedAt != "" { + t.Fatalf("expected completed round to keep empty postponedAt, got %q", status.PostponedAt) + } +} + +func TestDismissSecurityUpdateReminderDoesNotOverrideRolledBackRound(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + repo := newSecurityUpdateStateRepository(app.configDir) + rolledBack := SecurityUpdateStatus{ + SchemaVersion: securityUpdateSchemaVersion, + MigrationID: "migration-1", + OverallStatus: SecurityUpdateOverallStatusRolledBack, + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + BackupPath: filepath.Join(app.configDir, securityUpdateBackupRootDirName, "migration-1"), + StartedAt: "2026-04-09T00:00:00Z", + UpdatedAt: "2026-04-09T00:05:00Z", + Summary: SecurityUpdateSummary{ + Total: 1, + Failed: 1, + }, + Issues: []SecurityUpdateIssue{ + { + ID: "system-blocked", + Scope: SecurityUpdateIssueScopeSystem, + Title: "安全更新未完成", + Severity: SecurityUpdateIssueSeverityHigh, + Status: SecurityUpdateItemStatusFailed, + ReasonCode: SecurityUpdateIssueReasonCodeEnvironmentBlocked, + Action: SecurityUpdateIssueActionViewDetails, + Message: "当前环境无法完成本次安全更新,请稍后重试", + }, + }, + } + if err := repo.WriteResult(rolledBack); err != nil { + t.Fatalf("WriteResult returned error: %v", err) + } + + status, err := app.DismissSecurityUpdateReminder() + if err != nil { + t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected rolled_back status to be preserved, got %q", status.OverallStatus) + } + if status.MigrationID != rolledBack.MigrationID { + t.Fatalf("expected migration ID %q to be preserved, got %q", rolledBack.MigrationID, status.MigrationID) + } + if status.PostponedAt != "" { + t.Fatalf("expected rolled_back round to keep empty postponedAt, got %q", status.PostponedAt) + } + if len(status.Issues) != 1 || status.Issues[0].Scope != SecurityUpdateIssueScopeSystem { + t.Fatalf("expected rolled_back issue details to be preserved, got %#v", status.Issues) + } +} + +func TestStartSecurityUpdateRollsBackWhenSecretStoreUnavailable(t *testing.T) { + app := NewAppWithSecretStore(nil) + app.configDir = t.TempDir() + + status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected rolled_back status, got %q", status.OverallStatus) + } + if len(status.Issues) != 1 || status.Issues[0].Scope != SecurityUpdateIssueScopeSystem { + t.Fatalf("expected single system issue, got %#v", status.Issues) + } +} + +func TestStartSecurityUpdateRollsBackWhenAIProviderSecretStoreUnavailable(t *testing.T) { + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("blocked")) + app.configDir = t.TempDir() + + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "apiKey": "sk-ai-test", + "baseUrl": "https://api.openai.com/v1", + }, + }, + }) + + status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: "", + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected rolled_back status, got %q", status.OverallStatus) + } + if len(status.Issues) != 1 || status.Issues[0].Scope != SecurityUpdateIssueScopeSystem { + t.Fatalf("expected single system issue, got %#v", status.Issues) + } +} + +func TestStartSecurityUpdateRollsBackPartialConnectionImportWhenLaterProviderStepFails(t *testing.T) { + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("blocked")) + app.configDir = t.TempDir() + + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "apiKey": "sk-ai-test", + "baseUrl": "https://api.openai.com/v1", + }, + }, + }) + + payload, err := json.Marshal(map[string]any{ + "state": map[string]any{ + "connections": []map[string]any{ + { + "id": "legacy-1", + "name": "Legacy", + "config": map[string]any{ + "id": "legacy-1", + "type": "postgres", + "host": "db.local", + "port": 5432, + "user": "postgres", + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("Marshal returned error: %v", err) + } + + status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: string(payload), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected rolled_back status, got %q", status.OverallStatus) + } + + savedConnections, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(savedConnections) != 0 { + t.Fatalf("expected rollback to leave no imported connections, got %#v", savedConnections) + } +} + +func TestStartSecurityUpdateRollsBackExistingConnectionMetadataAndSecretWhenLaterProviderStepFails(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + if _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "legacy-1", + Name: "Existing", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db-old.local", + Port: 5432, + User: "postgres", + Password: "old-secret", + }, + }); err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + if err := os.WriteFile(filepath.Join(app.configDir, "ai_config.json"), []byte("{"), 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + payload, err := json.Marshal(map[string]any{ + "state": map[string]any{ + "connections": []map[string]any{ + { + "id": "legacy-1", + "name": "Migrated", + "config": map[string]any{ + "id": "legacy-1", + "type": "postgres", + "host": "db-new.local", + "port": 5432, + "user": "postgres", + "password": "new-secret", + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("Marshal returned error: %v", err) + } + + status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: string(payload), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected rolled_back status, got %q", status.OverallStatus) + } + + savedConnections, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(savedConnections) != 1 { + t.Fatalf("expected existing connection to remain, got %#v", savedConnections) + } + if savedConnections[0].Name != "Existing" || savedConnections[0].Config.Host != "db-old.local" { + t.Fatalf("expected existing connection metadata to be restored, got %#v", savedConnections[0]) + } + resolved, err := app.resolveConnectionSecrets(savedConnections[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "old-secret" { + t.Fatalf("expected existing connection secret to be restored, got %q", resolved.Password) + } +} + +func TestStartSecurityUpdateRollsBackExistingGlobalProxyWhenLaterProviderStepFails(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + if _, err := app.saveGlobalProxy(connection.SaveGlobalProxyInput{ + Enabled: true, + Type: "http", + Host: "proxy-old.local", + Port: 8080, + User: "ops", + Password: "old-proxy-secret", + }); err != nil { + t.Fatalf("saveGlobalProxy returned error: %v", err) + } + + if err := os.WriteFile(filepath.Join(app.configDir, "ai_config.json"), []byte("{"), 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + payload, err := json.Marshal(map[string]any{ + "state": map[string]any{ + "globalProxy": map[string]any{ + "enabled": true, + "type": "http", + "host": "proxy-new.local", + "port": 8081, + "user": "ops-new", + "password": "new-proxy-secret", + }, + }, + }) + if err != nil { + t.Fatalf("Marshal returned error: %v", err) + } + + status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: string(payload), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected rolled_back status, got %q", status.OverallStatus) + } + + view, err := app.loadStoredGlobalProxyView() + if err != nil { + t.Fatalf("loadStoredGlobalProxyView returned error: %v", err) + } + if view.Host != "proxy-old.local" || view.Port != 8080 || view.User != "ops" { + t.Fatalf("expected existing global proxy metadata to be restored, got %#v", view) + } + bundle, err := app.loadGlobalProxySecretBundle(view) + if err != nil { + t.Fatalf("loadGlobalProxySecretBundle returned error: %v", err) + } + if bundle.Password != "old-proxy-secret" { + t.Fatalf("expected existing global proxy secret to be restored, got %q", bundle.Password) + } +} + +func TestStartSecurityUpdateRollsBackAllChangesWhenPreviewArtifactWriteFails(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "apiKey": "sk-ai-test", + "baseUrl": "https://api.openai.com/v1", + "headers": map[string]any{ + "Authorization": "Bearer ai-test", + }, + }, + }, + }) + + restoreWriteJSONFile := swapSecurityUpdateWriteJSONFile(func(path string, payload any) error { + if strings.HasSuffix(filepath.ToSlash(path), "/"+securityUpdateNormalizedPreviewFileName) { + return errors.New("forced preview write failure") + } + return writeJSONFile(path, payload) + }) + defer restoreWriteJSONFile() + + status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected rolled_back status, got %q", status.OverallStatus) + } + + assertSecurityUpdateRollbackRestoredCurrentAppState(t, app, store) +} + +func TestStartSecurityUpdateRollsBackAllChangesWhenFinalResultWriteFails(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + writeLegacyAIProviderConfig(t, app.configDir, map[string]any{ + "providers": []map[string]any{ + { + "id": "openai-main", + "type": "openai", + "name": "OpenAI", + "apiKey": "sk-ai-test", + "baseUrl": "https://api.openai.com/v1", + "headers": map[string]any{ + "Authorization": "Bearer ai-test", + }, + }, + }, + }) + + resultWrites := 0 + restoreWriteJSONFile := swapSecurityUpdateWriteJSONFile(func(path string, payload any) error { + if strings.HasSuffix(filepath.ToSlash(path), "/"+securityUpdateResultFileName) { + resultWrites++ + if resultWrites == 2 { + return errors.New("forced result write failure") + } + } + return writeJSONFile(path, payload) + }) + defer restoreWriteJSONFile() + + status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + RawPayload: buildLegacySecurityUpdatePayload(), + }) + if err != nil { + t.Fatalf("StartSecurityUpdate returned error: %v", err) + } + if status.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected rolled_back status, got %q", status.OverallStatus) + } + + assertSecurityUpdateRollbackRestoredCurrentAppState(t, app, store) +} + +func buildLegacySecurityUpdatePayload() string { + payload, _ := json.Marshal(map[string]any{ + "state": map[string]any{ + "connections": []map[string]any{ + { + "id": "legacy-1", + "name": "Legacy", + "config": map[string]any{ + "id": "legacy-1", + "type": "postgres", + "host": "db.local", + "port": 5432, + "user": "postgres", + "password": "postgres-secret", + }, + }, + }, + "globalProxy": map[string]any{ + "enabled": true, + "type": "http", + "host": "127.0.0.1", + "port": 8080, + "user": "ops", + "password": "proxy-secret", + }, + }, + }) + return string(payload) +} + +func writeLegacyAIProviderConfig(t *testing.T, configDir string, payload map[string]any) { + t.Helper() + + data, err := json.MarshalIndent(payload, "", " ") + if err != nil { + t.Fatalf("MarshalIndent returned error: %v", err) + } + if err := os.WriteFile(filepath.Join(configDir, "ai_config.json"), data, 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } +} + +func swapSecurityUpdateWriteJSONFile(next func(path string, payload any) error) func() { + original := securityUpdateWriteJSONFile + securityUpdateWriteJSONFile = next + return func() { + securityUpdateWriteJSONFile = original + } +} + +func assertSecurityUpdateRollbackRestoredCurrentAppState(t *testing.T, app *App, store *fakeAppSecretStore) { + t.Helper() + + savedConnections, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(savedConnections) != 0 { + t.Fatalf("expected rollback to leave no imported connections, got %#v", savedConnections) + } + + if _, err := app.loadStoredGlobalProxyView(); !os.IsNotExist(err) { + t.Fatalf("expected rollback to remove imported global proxy, got err=%v", err) + } + + inspection, err := aiservice.NewProviderConfigStore(app.configDir, app.secretStore).Inspect() + if err != nil { + t.Fatalf("Inspect returned error: %v", err) + } + if len(inspection.ProvidersNeedingMigration) != 1 || inspection.ProvidersNeedingMigration[0] != "openai-main" { + t.Fatalf("expected AI provider migration requirement to be restored, got %#v", inspection.ProvidersNeedingMigration) + } + + ref, err := secretstore.BuildRef("ai-provider", "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + if _, err := store.Get(ref); !os.IsNotExist(err) { + t.Fatalf("expected rollback to remove migrated AI provider secret, got err=%v", err) + } +} diff --git a/internal/app/security_update_rollback.go b/internal/app/security_update_rollback.go new file mode 100644 index 0000000..3ba633b --- /dev/null +++ b/internal/app/security_update_rollback.go @@ -0,0 +1,314 @@ +package app + +import ( + "os" + "path/filepath" + "strings" + + aiservice "GoNavi-Wails/internal/ai/service" + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" +) + +const ( + securityUpdateAIConfigFileName = "ai_config.json" + securityUpdateAIProviderSecretKind = "ai-provider" +) + +type securityUpdateSecretSnapshot struct { + Exists bool + Payload []byte +} + +type securityUpdateCurrentAppRollbackSnapshot struct { + connectionsFileExists bool + connectionsFileData []byte + connectionSecrets map[string]securityUpdateSecretSnapshot + connectionCleanupRefs []string + + globalProxyFileExists bool + globalProxyFileData []byte + globalProxySecretRef string + globalProxySecret securityUpdateSecretSnapshot + globalProxyCleanupRef string + + aiConfigFileExists bool + aiConfigFileData []byte + aiProviderSecrets map[string]securityUpdateSecretSnapshot + aiProviderCleanupRefs []string +} + +func captureSecurityUpdateCurrentAppRollbackSnapshot(a *App, source securityUpdateCurrentAppSource) (securityUpdateCurrentAppRollbackSnapshot, error) { + snapshot := securityUpdateCurrentAppRollbackSnapshot{ + connectionSecrets: make(map[string]securityUpdateSecretSnapshot), + aiProviderSecrets: make(map[string]securityUpdateSecretSnapshot), + } + configDir := strings.TrimSpace(a.configDir) + if configDir == "" { + configDir = resolveAppConfigDir() + } + + connectionRepo := a.savedConnectionRepository() + connectionFileData, connectionFileExists, err := readOptionalFile(connectionRepo.connectionsPath()) + if err != nil { + return snapshot, err + } + snapshot.connectionsFileExists = connectionFileExists + snapshot.connectionsFileData = connectionFileData + + existingConnections, err := connectionRepo.load() + if err != nil { + return snapshot, err + } + existingConnectionsByID := make(map[string]connection.SavedConnectionView, len(existingConnections)) + for _, item := range existingConnections { + existingConnectionsByID[item.ID] = item + } + + connectionCleanupSet := make(map[string]struct{}) + for _, item := range source.Connections { + connectionID := strings.TrimSpace(item.ID) + if connectionID == "" { + connectionID = strings.TrimSpace(item.Config.ID) + } + if connectionID == "" { + continue + } + + defaultRef, refErr := secretstore.BuildRef(savedConnectionSecretKind, connectionID) + if refErr == nil { + connectionCleanupSet[defaultRef] = struct{}{} + } + + existing, ok := existingConnectionsByID[connectionID] + if !ok || !savedConnectionViewHasSecrets(existing) { + continue + } + + ref := strings.TrimSpace(existing.SecretRef) + if ref == "" { + ref = defaultRef + } + if ref == "" { + continue + } + + secretSnapshot, captureErr := captureSecurityUpdateSecretSnapshot(a.secretStore, ref) + if captureErr != nil { + return snapshot, captureErr + } + snapshot.connectionSecrets[ref] = secretSnapshot + connectionCleanupSet[ref] = struct{}{} + } + + snapshot.connectionCleanupRefs = make([]string, 0, len(connectionCleanupSet)) + for ref := range connectionCleanupSet { + snapshot.connectionCleanupRefs = append(snapshot.connectionCleanupRefs, ref) + } + + if source.GlobalProxy != nil { + globalProxyFileData, globalProxyFileExists, err := readOptionalFile(globalProxyMetadataPath(configDir)) + if err != nil { + return snapshot, err + } + snapshot.globalProxyFileExists = globalProxyFileExists + snapshot.globalProxyFileData = globalProxyFileData + + defaultProxyRef, refErr := secretstore.BuildRef(globalProxySecretKind, globalProxySecretID) + if refErr == nil { + snapshot.globalProxyCleanupRef = defaultProxyRef + } + + existingProxy, err := a.loadStoredGlobalProxyView() + if err != nil { + if !os.IsNotExist(err) { + return snapshot, err + } + } else if existingProxy.HasPassword { + ref := strings.TrimSpace(existingProxy.SecretRef) + if ref == "" { + ref = snapshot.globalProxyCleanupRef + } + if ref != "" { + secretSnapshot, captureErr := captureSecurityUpdateSecretSnapshot(a.secretStore, ref) + if captureErr != nil { + return snapshot, captureErr + } + snapshot.globalProxySecretRef = ref + snapshot.globalProxySecret = secretSnapshot + } + } + } + + aiConfigPath := filepath.Join(configDir, securityUpdateAIConfigFileName) + aiConfigFileData, aiConfigFileExists, err := readOptionalFile(aiConfigPath) + if err != nil { + return snapshot, err + } + snapshot.aiConfigFileExists = aiConfigFileExists + snapshot.aiConfigFileData = aiConfigFileData + + inspection, err := aiservice.NewProviderConfigStore(configDir, a.secretStore).Inspect() + if err != nil { + return snapshot, err + } + aiProviderCleanupSet := make(map[string]struct{}) + for _, provider := range inspection.Snapshot.Providers { + providerID := strings.TrimSpace(provider.ID) + if providerID == "" { + continue + } + + ref := strings.TrimSpace(provider.SecretRef) + if ref == "" && (provider.HasSecret || strings.TrimSpace(provider.APIKey) != "" || len(provider.Headers) > 0) { + builtRef, refErr := secretstore.BuildRef(securityUpdateAIProviderSecretKind, providerID) + if refErr == nil { + ref = builtRef + } + } + if ref == "" { + continue + } + + secretSnapshot, captureErr := captureSecurityUpdateSecretSnapshot(a.secretStore, ref) + if captureErr != nil { + return snapshot, captureErr + } + snapshot.aiProviderSecrets[ref] = secretSnapshot + aiProviderCleanupSet[ref] = struct{}{} + } + snapshot.aiProviderCleanupRefs = make([]string, 0, len(aiProviderCleanupSet)) + for ref := range aiProviderCleanupSet { + snapshot.aiProviderCleanupRefs = append(snapshot.aiProviderCleanupRefs, ref) + } + return snapshot, nil +} + +func (s securityUpdateCurrentAppRollbackSnapshot) restore(a *App) error { + configDir := strings.TrimSpace(a.configDir) + if configDir == "" { + configDir = resolveAppConfigDir() + } + connectionRepo := a.savedConnectionRepository() + if err := restoreOptionalFile(connectionRepo.connectionsPath(), s.connectionsFileExists, s.connectionsFileData); err != nil { + return err + } + for ref, secretSnapshot := range s.connectionSecrets { + if err := restoreSecurityUpdateSecretSnapshot(a.secretStore, ref, secretSnapshot); err != nil { + return err + } + } + for _, ref := range s.connectionCleanupRefs { + if _, alreadyRestored := s.connectionSecrets[ref]; alreadyRestored { + continue + } + if err := deleteSecurityUpdateSecretRef(a.secretStore, ref); err != nil { + return err + } + } + + if err := restoreOptionalFile(globalProxyMetadataPath(configDir), s.globalProxyFileExists, s.globalProxyFileData); err != nil { + return err + } + if s.globalProxySecretRef != "" { + if err := restoreSecurityUpdateSecretSnapshot(a.secretStore, s.globalProxySecretRef, s.globalProxySecret); err != nil { + return err + } + } + if s.globalProxyCleanupRef != "" && s.globalProxyCleanupRef != s.globalProxySecretRef { + if err := deleteSecurityUpdateSecretRef(a.secretStore, s.globalProxyCleanupRef); err != nil { + return err + } + } + + if err := restoreOptionalFile(filepath.Join(configDir, securityUpdateAIConfigFileName), s.aiConfigFileExists, s.aiConfigFileData); err != nil { + return err + } + for ref, secretSnapshot := range s.aiProviderSecrets { + if err := restoreSecurityUpdateSecretSnapshot(a.secretStore, ref, secretSnapshot); err != nil { + return err + } + } + for _, ref := range s.aiProviderCleanupRefs { + if _, alreadyRestored := s.aiProviderSecrets[ref]; alreadyRestored { + continue + } + if err := deleteSecurityUpdateSecretRef(a.secretStore, ref); err != nil { + return err + } + } + + if s.globalProxyFileExists { + a.loadPersistedGlobalProxy() + return nil + } + _, err := setGlobalProxyConfig(false, connection.ProxyConfig{}) + return err +} + +func readOptionalFile(path string) ([]byte, bool, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, false, nil + } + return nil, false, err + } + return append([]byte(nil), data...), true, nil +} + +func restoreOptionalFile(path string, exists bool, data []byte) error { + if !exists { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil + } + return os.WriteFile(path, data, 0o644) +} + +func captureSecurityUpdateSecretSnapshot(store secretstore.SecretStore, ref string) (securityUpdateSecretSnapshot, error) { + if store == nil || strings.TrimSpace(ref) == "" { + return securityUpdateSecretSnapshot{}, nil + } + payload, err := store.Get(ref) + if err != nil { + if os.IsNotExist(err) || secretstore.IsUnavailable(err) { + return securityUpdateSecretSnapshot{}, nil + } + return securityUpdateSecretSnapshot{}, err + } + return securityUpdateSecretSnapshot{ + Exists: true, + Payload: append([]byte(nil), payload...), + }, nil +} + +func restoreSecurityUpdateSecretSnapshot(store secretstore.SecretStore, ref string, snapshot securityUpdateSecretSnapshot) error { + if store == nil || strings.TrimSpace(ref) == "" { + return nil + } + if snapshot.Exists { + if err := store.Put(ref, snapshot.Payload); err != nil { + if secretstore.IsUnavailable(err) { + return nil + } + return err + } + return nil + } + return deleteSecurityUpdateSecretRef(store, ref) +} + +func deleteSecurityUpdateSecretRef(store secretstore.SecretStore, ref string) error { + if store == nil || strings.TrimSpace(ref) == "" { + return nil + } + if err := store.Delete(ref); err != nil { + if os.IsNotExist(err) || secretstore.IsUnavailable(err) { + return nil + } + return err + } + return nil +} diff --git a/internal/app/security_update_source_current_app.go b/internal/app/security_update_source_current_app.go new file mode 100644 index 0000000..1239d68 --- /dev/null +++ b/internal/app/security_update_source_current_app.go @@ -0,0 +1,85 @@ +package app + +import ( + "encoding/json" + "strings" + + "GoNavi-Wails/internal/connection" +) + +const ( + securityUpdateSourceCurrentAppFileName = "source-current-app.json" + securityUpdateNormalizedPreviewFileName = "normalized-preview.json" +) + +type securityUpdateCurrentAppEnvelope struct { + State securityUpdateCurrentAppPayload `json:"state"` + Connections []connection.LegacySavedConnection `json:"connections"` + GlobalProxy *connection.LegacyGlobalProxyInput `json:"globalProxy"` +} + +type securityUpdateCurrentAppPayload struct { + Connections []connection.LegacySavedConnection `json:"connections"` + GlobalProxy *connection.LegacyGlobalProxyInput `json:"globalProxy"` +} + +type securityUpdateCurrentAppSource struct { + Connections []connection.LegacySavedConnection `json:"connections"` + GlobalProxy *connection.LegacyGlobalProxyInput `json:"globalProxy,omitempty"` +} + +func parseSecurityUpdateCurrentAppSource(rawPayload string) (securityUpdateCurrentAppSource, any, error) { + trimmed := strings.TrimSpace(rawPayload) + if trimmed == "" { + return securityUpdateCurrentAppSource{Connections: []connection.LegacySavedConnection{}}, map[string]any{}, nil + } + + var raw any + if err := json.Unmarshal([]byte(trimmed), &raw); err != nil { + return securityUpdateCurrentAppSource{}, nil, err + } + + var envelope securityUpdateCurrentAppEnvelope + if err := json.Unmarshal([]byte(trimmed), &envelope); err != nil { + return securityUpdateCurrentAppSource{}, nil, err + } + + connections := envelope.Connections + globalProxy := envelope.GlobalProxy + if len(envelope.State.Connections) > 0 || envelope.State.GlobalProxy != nil { + connections = envelope.State.Connections + globalProxy = envelope.State.GlobalProxy + } + + normalizedConnections := make([]connection.LegacySavedConnection, 0, len(connections)) + for _, item := range connections { + if strings.TrimSpace(item.ID) == "" && strings.TrimSpace(item.Config.ID) == "" { + continue + } + if strings.TrimSpace(item.ID) == "" { + item.ID = strings.TrimSpace(item.Config.ID) + } + item.Config.ID = item.ID + normalizedConnections = append(normalizedConnections, item) + } + + if globalProxy != nil { + normalizedType := strings.ToLower(strings.TrimSpace(globalProxy.Type)) + if normalizedType != "http" { + normalizedType = "socks5" + } + globalProxy.Type = normalizedType + if globalProxy.Port <= 0 || globalProxy.Port > 65535 { + if normalizedType == "http" { + globalProxy.Port = 8080 + } else { + globalProxy.Port = 1080 + } + } + } + + return securityUpdateCurrentAppSource{ + Connections: normalizedConnections, + GlobalProxy: globalProxy, + }, raw, nil +} diff --git a/internal/app/security_update_state.go b/internal/app/security_update_state.go new file mode 100644 index 0000000..de0d063 --- /dev/null +++ b/internal/app/security_update_state.go @@ -0,0 +1,293 @@ +package app + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" +) + +const ( + securityUpdateSchemaVersion = 1 + securityUpdateMarkerDirName = "migrations" + securityUpdateMarkerFileName = "config-security-update.json" + securityUpdateBackupRootDirName = "migration-backups" + securityUpdateManifestFileName = "manifest.json" + securityUpdateResultFileName = "result.json" +) + +var securityUpdateWriteJSONFile = writeJSONFile + +type securityUpdateStateRepository struct { + configDir string +} + +type securityUpdateMarker struct { + SchemaVersion int `json:"schemaVersion"` + MigrationID string `json:"migrationId"` + SourceType SecurityUpdateSourceType `json:"sourceType"` + Status SecurityUpdateOverallStatus `json:"status"` + StartedAt string `json:"startedAt,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` + CompletedAt string `json:"completedAt,omitempty"` + PostponedAt string `json:"postponedAt,omitempty"` + BackupPath string `json:"backupPath,omitempty"` + BackupSHA256 string `json:"backupSha256,omitempty"` + Summary SecurityUpdateSummary `json:"summary"` + Issues []SecurityUpdateIssue `json:"issues"` + LastError string `json:"lastError,omitempty"` +} + +type securityUpdateBackupManifest struct { + SchemaVersion int `json:"schemaVersion"` + MigrationID string `json:"migrationId"` + SourceType SecurityUpdateSourceType `json:"sourceType"` + CreatedAt string `json:"createdAt"` + StartedAt string `json:"startedAt,omitempty"` + BackupPath string `json:"backupPath"` +} + +func newSecurityUpdateStateRepository(configDir string) *securityUpdateStateRepository { + if strings.TrimSpace(configDir) == "" { + configDir = resolveAppConfigDir() + } + return &securityUpdateStateRepository{configDir: configDir} +} + +func (r *securityUpdateStateRepository) markerPath() string { + return filepath.Join(r.configDir, securityUpdateMarkerDirName, securityUpdateMarkerFileName) +} + +func (r *securityUpdateStateRepository) backupRootPath() string { + return filepath.Join(r.configDir, securityUpdateBackupRootDirName) +} + +func (r *securityUpdateStateRepository) backupPath(migrationID string) string { + return filepath.Join(r.backupRootPath(), migrationID) +} + +func (r *securityUpdateStateRepository) manifestPath(migrationID string) string { + return filepath.Join(r.backupPath(migrationID), securityUpdateManifestFileName) +} + +func (r *securityUpdateStateRepository) resultPath(migrationID string) string { + return filepath.Join(r.backupPath(migrationID), securityUpdateResultFileName) +} + +func (r *securityUpdateStateRepository) LoadMarker() (SecurityUpdateStatus, error) { + marker, err := r.readMarker() + if err != nil { + return SecurityUpdateStatus{}, err + } + return buildSecurityUpdateStatus(marker), nil +} + +func (r *securityUpdateStateRepository) StartRound(request StartSecurityUpdateRequest) (SecurityUpdateStatus, error) { + marker := r.newRoundMarker(request.SourceType) + if err := r.initializeRoundArtifacts(marker); err != nil { + return SecurityUpdateStatus{}, err + } + status := buildSecurityUpdateStatus(marker) + if err := r.WriteResult(status); err != nil { + return SecurityUpdateStatus{}, err + } + return status, nil +} + +func (r *securityUpdateStateRepository) RetryRound(request RetrySecurityUpdateRequest) (SecurityUpdateStatus, error) { + marker, err := r.readMarker() + if err != nil { + return SecurityUpdateStatus{}, err + } + if requestedID := strings.TrimSpace(request.MigrationID); requestedID != "" && requestedID != marker.MigrationID { + return SecurityUpdateStatus{}, fmt.Errorf("migration ID mismatch: current=%s requested=%s", marker.MigrationID, requestedID) + } + if marker.Status != SecurityUpdateOverallStatusNeedsAttention { + return SecurityUpdateStatus{}, fmt.Errorf( + "retry current round requires status %s: current=%s", + SecurityUpdateOverallStatusNeedsAttention, + marker.Status, + ) + } + marker.Status = SecurityUpdateOverallStatusInProgress + marker.UpdatedAt = nowRFC3339() + if marker.BackupPath == "" { + marker.BackupPath = r.backupPath(marker.MigrationID) + } + if err := os.MkdirAll(marker.BackupPath, 0o755); err != nil { + return SecurityUpdateStatus{}, err + } + status := buildSecurityUpdateStatus(marker) + if err := r.WriteResult(status); err != nil { + return SecurityUpdateStatus{}, err + } + return status, nil +} + +func (r *securityUpdateStateRepository) RestartRound(request RestartSecurityUpdateRequest) (SecurityUpdateStatus, error) { + marker := r.newRoundMarker(request.SourceType) + if err := r.initializeRoundArtifacts(marker); err != nil { + return SecurityUpdateStatus{}, err + } + status := buildSecurityUpdateStatus(marker) + if err := r.WriteResult(status); err != nil { + return SecurityUpdateStatus{}, err + } + return status, nil +} + +func (r *securityUpdateStateRepository) WriteResult(status SecurityUpdateStatus) error { + marker := markerFromStatus(status) + if err := r.writeMarker(marker); err != nil { + return err + } + if strings.TrimSpace(marker.BackupPath) == "" { + return nil + } + if err := os.MkdirAll(marker.BackupPath, 0o755); err != nil { + return err + } + return securityUpdateWriteJSONFile(r.resultPath(marker.MigrationID), buildSecurityUpdateStatus(marker)) +} + +func (r *securityUpdateStateRepository) newRoundMarker(sourceType SecurityUpdateSourceType) securityUpdateMarker { + now := nowRFC3339() + if strings.TrimSpace(string(sourceType)) == "" { + sourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig + } + migrationID := uuid.NewString() + return securityUpdateMarker{ + SchemaVersion: securityUpdateSchemaVersion, + MigrationID: migrationID, + SourceType: sourceType, + Status: SecurityUpdateOverallStatusInProgress, + StartedAt: now, + UpdatedAt: now, + BackupPath: r.backupPath(migrationID), + Summary: SecurityUpdateSummary{}, + Issues: []SecurityUpdateIssue{}, + } +} + +func (r *securityUpdateStateRepository) initializeRoundArtifacts(marker securityUpdateMarker) error { + if err := os.MkdirAll(marker.BackupPath, 0o755); err != nil { + return err + } + manifest := securityUpdateBackupManifest{ + SchemaVersion: securityUpdateSchemaVersion, + MigrationID: marker.MigrationID, + SourceType: marker.SourceType, + CreatedAt: marker.UpdatedAt, + StartedAt: marker.StartedAt, + BackupPath: marker.BackupPath, + } + if err := securityUpdateWriteJSONFile(r.manifestPath(marker.MigrationID), manifest); err != nil { + return err + } + return r.writeMarker(marker) +} + +func (r *securityUpdateStateRepository) readMarker() (securityUpdateMarker, error) { + data, err := os.ReadFile(r.markerPath()) + if err != nil { + return securityUpdateMarker{}, err + } + var marker securityUpdateMarker + if err := json.Unmarshal(data, &marker); err != nil { + return securityUpdateMarker{}, err + } + if marker.Issues == nil { + marker.Issues = []SecurityUpdateIssue{} + } + return marker, nil +} + +func (r *securityUpdateStateRepository) writeMarker(marker securityUpdateMarker) error { + if err := os.MkdirAll(filepath.Dir(r.markerPath()), 0o755); err != nil { + return err + } + return securityUpdateWriteJSONFile(r.markerPath(), marker) +} + +func buildSecurityUpdateStatus(marker securityUpdateMarker) SecurityUpdateStatus { + status := SecurityUpdateStatus{ + SchemaVersion: marker.SchemaVersion, + MigrationID: marker.MigrationID, + OverallStatus: marker.Status, + SourceType: marker.SourceType, + BackupAvailable: strings.TrimSpace(marker.BackupPath) != "", + BackupPath: marker.BackupPath, + StartedAt: marker.StartedAt, + UpdatedAt: marker.UpdatedAt, + CompletedAt: marker.CompletedAt, + PostponedAt: marker.PostponedAt, + Summary: marker.Summary, + Issues: marker.Issues, + LastError: marker.LastError, + } + if status.Issues == nil { + status.Issues = []SecurityUpdateIssue{} + } + switch status.OverallStatus { + case SecurityUpdateOverallStatusPending: + status.ReminderVisible = true + status.CanStart = true + status.CanPostpone = true + case SecurityUpdateOverallStatusPostponed: + status.CanStart = true + case SecurityUpdateOverallStatusNeedsAttention: + status.CanRetry = true + status.CanStart = true + case SecurityUpdateOverallStatusRolledBack: + status.CanStart = true + case SecurityUpdateOverallStatusCompleted: + status.BackupAvailable = strings.TrimSpace(status.BackupPath) != "" + } + return status +} + +func markerFromStatus(status SecurityUpdateStatus) securityUpdateMarker { + marker := securityUpdateMarker{ + SchemaVersion: securityUpdateSchemaVersion, + MigrationID: strings.TrimSpace(status.MigrationID), + SourceType: status.SourceType, + Status: status.OverallStatus, + StartedAt: status.StartedAt, + UpdatedAt: status.UpdatedAt, + CompletedAt: status.CompletedAt, + PostponedAt: status.PostponedAt, + BackupPath: status.BackupPath, + Summary: status.Summary, + Issues: status.Issues, + LastError: status.LastError, + } + if marker.SchemaVersion == 0 { + marker.SchemaVersion = securityUpdateSchemaVersion + } + if marker.Issues == nil { + marker.Issues = []SecurityUpdateIssue{} + } + if marker.BackupPath == "" && marker.MigrationID != "" { + marker.BackupPath = filepath.Join(resolveAppConfigDir(), securityUpdateBackupRootDirName, marker.MigrationID) + } + if marker.UpdatedAt == "" { + marker.UpdatedAt = nowRFC3339() + } + return marker +} + +func writeJSONFile(path string, payload any) error { + data, err := json.MarshalIndent(payload, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + +func nowRFC3339() string { + return time.Now().UTC().Format(time.RFC3339) +} diff --git a/internal/app/security_update_state_test.go b/internal/app/security_update_state_test.go new file mode 100644 index 0000000..09a82f5 --- /dev/null +++ b/internal/app/security_update_state_test.go @@ -0,0 +1,226 @@ +package app + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestSecurityUpdateStateStartRoundCreatesMarkerAndManifest(t *testing.T) { + repo := newSecurityUpdateStateRepository(t.TempDir()) + + status, err := repo.StartRound(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + }) + if err != nil { + t.Fatalf("StartRound returned error: %v", err) + } + + if status.MigrationID == "" { + t.Fatal("expected migration ID to be created") + } + if status.SourceType != SecurityUpdateSourceTypeCurrentAppSavedConfig { + t.Fatalf("expected source type %q, got %q", SecurityUpdateSourceTypeCurrentAppSavedConfig, status.SourceType) + } + if status.OverallStatus != SecurityUpdateOverallStatusInProgress { + t.Fatalf("expected overall status %q, got %q", SecurityUpdateOverallStatusInProgress, status.OverallStatus) + } + if !status.BackupAvailable { + t.Fatal("expected backupAvailable=true") + } + + markerPath := filepath.Join(repo.configDir, securityUpdateMarkerDirName, securityUpdateMarkerFileName) + if _, err := os.Stat(markerPath); err != nil { + t.Fatalf("expected marker file at %q: %v", markerPath, err) + } + + data, err := os.ReadFile(markerPath) + if err != nil { + t.Fatalf("ReadFile marker failed: %v", err) + } + + var marker securityUpdateMarker + if err := json.Unmarshal(data, &marker); err != nil { + t.Fatalf("Unmarshal marker failed: %v", err) + } + if marker.MigrationID != status.MigrationID { + t.Fatalf("expected marker migration ID %q, got %q", status.MigrationID, marker.MigrationID) + } + + manifestPath := filepath.Join(repo.configDir, securityUpdateBackupRootDirName, status.MigrationID, securityUpdateManifestFileName) + if _, err := os.Stat(manifestPath); err != nil { + t.Fatalf("expected manifest file at %q: %v", manifestPath, err) + } +} + +func TestSecurityUpdateStateRetryRoundReusesCurrentMigrationID(t *testing.T) { + repo := newSecurityUpdateStateRepository(t.TempDir()) + + initial, err := repo.StartRound(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + }) + if err != nil { + t.Fatalf("StartRound returned error: %v", err) + } + + initial.OverallStatus = SecurityUpdateOverallStatusNeedsAttention + initial.UpdatedAt = nowRFC3339() + initial.Summary = SecurityUpdateSummary{ + Total: 1, + Pending: 1, + } + initial.Issues = []SecurityUpdateIssue{ + { + ID: "connection-legacy-1", + Scope: SecurityUpdateIssueScopeConnection, + RefID: "legacy-1", + Title: "Legacy", + Severity: SecurityUpdateIssueSeverityMedium, + Status: SecurityUpdateItemStatusNeedsAttention, + ReasonCode: SecurityUpdateIssueReasonCodeSecretMissing, + Action: SecurityUpdateIssueActionOpenConnection, + Message: "连接密码已丢失,请重新保存后再继续", + }, + } + if err := repo.WriteResult(initial); err != nil { + t.Fatalf("WriteResult returned error: %v", err) + } + + retried, err := repo.RetryRound(RetrySecurityUpdateRequest{ + MigrationID: initial.MigrationID, + }) + if err != nil { + t.Fatalf("RetryRound returned error: %v", err) + } + + if retried.MigrationID != initial.MigrationID { + t.Fatalf("expected retry to reuse migration ID %q, got %q", initial.MigrationID, retried.MigrationID) + } + + entries, err := os.ReadDir(filepath.Join(repo.configDir, securityUpdateBackupRootDirName)) + if err != nil { + t.Fatalf("ReadDir backup root failed: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected retry to keep a single backup directory, got %d", len(entries)) + } +} + +func TestSecurityUpdateStateRetryRoundRejectsRolledBackRound(t *testing.T) { + repo := newSecurityUpdateStateRepository(t.TempDir()) + + marker := securityUpdateMarker{ + SchemaVersion: securityUpdateSchemaVersion, + MigrationID: "migration-1", + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + Status: SecurityUpdateOverallStatusRolledBack, + StartedAt: "2026-04-09T00:00:00Z", + UpdatedAt: "2026-04-09T00:05:00Z", + BackupPath: repo.backupPath("migration-1"), + Summary: SecurityUpdateSummary{ + Total: 1, + Failed: 1, + }, + Issues: []SecurityUpdateIssue{ + { + ID: "system-blocked", + Scope: SecurityUpdateIssueScopeSystem, + Title: "安全更新未完成", + Severity: SecurityUpdateIssueSeverityHigh, + Status: SecurityUpdateItemStatusFailed, + ReasonCode: SecurityUpdateIssueReasonCodeEnvironmentBlocked, + Action: SecurityUpdateIssueActionViewDetails, + Message: "当前环境无法完成本次安全更新,请稍后重试", + }, + }, + } + if err := repo.writeMarker(marker); err != nil { + t.Fatalf("writeMarker returned error: %v", err) + } + + if _, err := repo.RetryRound(RetrySecurityUpdateRequest{MigrationID: marker.MigrationID}); err == nil { + t.Fatal("expected RetryRound to reject rolled_back round") + } + + current, err := repo.LoadMarker() + if err != nil { + t.Fatalf("LoadMarker returned error: %v", err) + } + if current.OverallStatus != SecurityUpdateOverallStatusRolledBack { + t.Fatalf("expected marker to remain rolled_back, got %q", current.OverallStatus) + } +} + +func TestBuildSecurityUpdateStatusDoesNotAllowRetryAfterRollback(t *testing.T) { + status := buildSecurityUpdateStatus(securityUpdateMarker{ + SchemaVersion: securityUpdateSchemaVersion, + MigrationID: "migration-1", + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + Status: SecurityUpdateOverallStatusRolledBack, + StartedAt: "2026-04-09T00:00:00Z", + UpdatedAt: "2026-04-09T00:05:00Z", + BackupPath: filepath.Join("backup", "migration-1"), + Summary: SecurityUpdateSummary{ + Total: 1, + Failed: 1, + }, + Issues: []SecurityUpdateIssue{ + { + ID: "system-blocked", + Scope: SecurityUpdateIssueScopeSystem, + Title: "安全更新未完成", + Severity: SecurityUpdateIssueSeverityHigh, + Status: SecurityUpdateItemStatusFailed, + ReasonCode: SecurityUpdateIssueReasonCodeEnvironmentBlocked, + Action: SecurityUpdateIssueActionViewDetails, + Message: "当前环境无法完成本次安全更新,请稍后重试", + }, + }, + }) + + if status.CanRetry { + t.Fatal("expected rolled_back status to require restart instead of retry") + } + if !status.CanStart { + t.Fatal("expected rolled_back status to allow starting a new round") + } +} + +func TestSecurityUpdateStateRestartRoundCreatesNewMigrationID(t *testing.T) { + repo := newSecurityUpdateStateRepository(t.TempDir()) + + initial, err := repo.StartRound(StartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + }) + if err != nil { + t.Fatalf("StartRound returned error: %v", err) + } + + restarted, err := repo.RestartRound(RestartSecurityUpdateRequest{ + SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig, + }) + if err != nil { + t.Fatalf("RestartRound returned error: %v", err) + } + + if restarted.MigrationID == initial.MigrationID { + t.Fatal("expected restart to create a new migration ID") + } + + entries, err := os.ReadDir(filepath.Join(repo.configDir, securityUpdateBackupRootDirName)) + if err != nil { + t.Fatalf("ReadDir backup root failed: %v", err) + } + if len(entries) != 2 { + t.Fatalf("expected restart to create a second backup directory, got %d", len(entries)) + } + + current, err := repo.LoadMarker() + if err != nil { + t.Fatalf("LoadMarker returned error: %v", err) + } + if current.MigrationID != restarted.MigrationID { + t.Fatalf("expected marker to point to latest migration ID %q, got %q", restarted.MigrationID, current.MigrationID) + } +} diff --git a/internal/app/security_update_types.go b/internal/app/security_update_types.go new file mode 100644 index 0000000..9f57b6d --- /dev/null +++ b/internal/app/security_update_types.go @@ -0,0 +1,129 @@ +package app + +type SecurityUpdateSourceType string + +const ( + SecurityUpdateSourceTypeCurrentAppSavedConfig SecurityUpdateSourceType = "current_app_saved_config" +) + +type SecurityUpdateOverallStatus string + +const ( + SecurityUpdateOverallStatusNotDetected SecurityUpdateOverallStatus = "not_detected" + SecurityUpdateOverallStatusPending SecurityUpdateOverallStatus = "pending" + SecurityUpdateOverallStatusPostponed SecurityUpdateOverallStatus = "postponed" + SecurityUpdateOverallStatusInProgress SecurityUpdateOverallStatus = "in_progress" + SecurityUpdateOverallStatusNeedsAttention SecurityUpdateOverallStatus = "needs_attention" + SecurityUpdateOverallStatusCompleted SecurityUpdateOverallStatus = "completed" + SecurityUpdateOverallStatusRolledBack SecurityUpdateOverallStatus = "rolled_back" +) + +type SecurityUpdateIssueScope string + +const ( + SecurityUpdateIssueScopeConnection SecurityUpdateIssueScope = "connection" + SecurityUpdateIssueScopeGlobalProxy SecurityUpdateIssueScope = "global_proxy" + SecurityUpdateIssueScopeAIProvider SecurityUpdateIssueScope = "ai_provider" + SecurityUpdateIssueScopeSystem SecurityUpdateIssueScope = "system" +) + +type SecurityUpdateIssueSeverity string + +const ( + SecurityUpdateIssueSeverityHigh SecurityUpdateIssueSeverity = "high" + SecurityUpdateIssueSeverityMedium SecurityUpdateIssueSeverity = "medium" + SecurityUpdateIssueSeverityLow SecurityUpdateIssueSeverity = "low" +) + +type SecurityUpdateItemStatus string + +const ( + SecurityUpdateItemStatusPending SecurityUpdateItemStatus = "pending" + SecurityUpdateItemStatusUpdated SecurityUpdateItemStatus = "updated" + SecurityUpdateItemStatusNeedsAttention SecurityUpdateItemStatus = "needs_attention" + SecurityUpdateItemStatusSkipped SecurityUpdateItemStatus = "skipped" + SecurityUpdateItemStatusFailed SecurityUpdateItemStatus = "failed" +) + +type SecurityUpdateIssueReasonCode string + +const ( + SecurityUpdateIssueReasonCodeMigrationRequired SecurityUpdateIssueReasonCode = "migration_required" + SecurityUpdateIssueReasonCodeSecretMissing SecurityUpdateIssueReasonCode = "secret_missing" + SecurityUpdateIssueReasonCodeFieldInvalid SecurityUpdateIssueReasonCode = "field_invalid" + SecurityUpdateIssueReasonCodeWriteConflict SecurityUpdateIssueReasonCode = "write_conflict" + SecurityUpdateIssueReasonCodeValidationFailed SecurityUpdateIssueReasonCode = "validation_failed" + SecurityUpdateIssueReasonCodeEnvironmentBlocked SecurityUpdateIssueReasonCode = "environment_blocked" +) + +type SecurityUpdateIssueAction string + +const ( + SecurityUpdateIssueActionOpenConnection SecurityUpdateIssueAction = "open_connection" + SecurityUpdateIssueActionOpenProxySettings SecurityUpdateIssueAction = "open_proxy_settings" + SecurityUpdateIssueActionOpenAISettings SecurityUpdateIssueAction = "open_ai_settings" + SecurityUpdateIssueActionRetryUpdate SecurityUpdateIssueAction = "retry_update" + SecurityUpdateIssueActionViewDetails SecurityUpdateIssueAction = "view_details" +) + +type SecurityUpdateSummary struct { + Total int `json:"total"` + Updated int `json:"updated"` + Pending int `json:"pending"` + Skipped int `json:"skipped"` + Failed int `json:"failed"` +} + +type SecurityUpdateIssue struct { + ID string `json:"id"` + Scope SecurityUpdateIssueScope `json:"scope"` + RefID string `json:"refId,omitempty"` + Title string `json:"title"` + Severity SecurityUpdateIssueSeverity `json:"severity"` + Status SecurityUpdateItemStatus `json:"status"` + ReasonCode SecurityUpdateIssueReasonCode `json:"reasonCode"` + Action SecurityUpdateIssueAction `json:"action"` + Message string `json:"message"` +} + +type SecurityUpdateStatus struct { + SchemaVersion int `json:"schemaVersion,omitempty"` + MigrationID string `json:"migrationId,omitempty"` + OverallStatus SecurityUpdateOverallStatus `json:"overallStatus"` + SourceType SecurityUpdateSourceType `json:"sourceType,omitempty"` + ReminderVisible bool `json:"reminderVisible"` + CanStart bool `json:"canStart"` + CanPostpone bool `json:"canPostpone"` + CanRetry bool `json:"canRetry"` + BackupAvailable bool `json:"backupAvailable"` + BackupPath string `json:"backupPath,omitempty"` + StartedAt string `json:"startedAt,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` + CompletedAt string `json:"completedAt,omitempty"` + PostponedAt string `json:"postponedAt,omitempty"` + Summary SecurityUpdateSummary `json:"summary"` + Issues []SecurityUpdateIssue `json:"issues"` + LastError string `json:"lastError,omitempty"` +} + +type SecurityUpdateOptions struct { + AllowPartial bool `json:"allowPartial,omitempty"` + WriteBackup bool `json:"writeBackup,omitempty"` +} + +type StartSecurityUpdateRequest struct { + SourceType SecurityUpdateSourceType `json:"sourceType"` + RawPayload string `json:"rawPayload,omitempty"` + Options *SecurityUpdateOptions `json:"options,omitempty"` +} + +type RetrySecurityUpdateRequest struct { + MigrationID string `json:"migrationId,omitempty"` +} + +type RestartSecurityUpdateRequest struct { + MigrationID string `json:"migrationId,omitempty"` + SourceType SecurityUpdateSourceType `json:"sourceType"` + RawPayload string `json:"rawPayload,omitempty"` + Options *SecurityUpdateOptions `json:"options,omitempty"` +} From 82e06bd94d882c7acf18b3ac22b6ffe546a0f2d4 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Sat, 11 Apr 2026 16:53:03 +0800 Subject: [PATCH 3/7] =?UTF-8?q?=F0=9F=90=9B=20fix(security):=20=E5=AE=8C?= =?UTF-8?q?=E5=96=84=E5=AF=86=E6=96=87=E5=8D=87=E7=BA=A7=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=E8=A6=86=E7=9B=96=E4=B8=8E=E5=AE=89=E5=85=A8=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E9=93=BE=E8=B7=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 完善连接恢复包与 legacy 导入覆盖语义及密文兼容处理 - 修复安全更新详情高亮反馈与相关前后端链路 - 补强 keyring 误判边界与安全更新回归测试 --- frontend/package-lock.json | 4 +- frontend/package.json | 2 +- frontend/package.json.md5 | 2 +- frontend/src/App.css | 44 +++ frontend/src/App.tsx | 55 ++- .../src/components/SecurityUpdateBanner.tsx | 31 +- .../components/SecurityUpdateIntroModal.tsx | 38 ++- .../SecurityUpdateProgressModal.tsx | 2 + .../SecurityUpdateSettingsModal.tsx | 129 +++++-- frontend/src/main.tsx | 1 + frontend/src/utils/connectionExport.test.ts | 52 +++ frontend/src/utils/connectionExport.ts | 49 +++ .../src/utils/secureConfigBootstrap.test.ts | 173 +++++++++- frontend/src/utils/secureConfigBootstrap.ts | 89 ++++- .../utils/securityUpdateRepairFlow.test.ts | 62 +++- .../src/utils/securityUpdateRepairFlow.ts | 38 ++- .../src/utils/securityUpdateVisuals.test.ts | 88 +++++ frontend/src/utils/securityUpdateVisuals.ts | 65 ++++ frontend/wailsjs/go/app/App.d.ts | 2 + frontend/wailsjs/go/app/App.js | 4 + internal/app/connection_package_crypto.go | 15 + .../app/connection_package_crypto_test.go | 73 ++++ internal/app/connection_package_transfer.go | 101 +++++- .../app/connection_package_transfer_test.go | 315 +++++++++++++++++- internal/app/connection_package_types.go | 6 + internal/app/methods_file.go | 23 +- internal/app/methods_file_import_test.go | 33 ++ internal/app/methods_redis.go | 14 +- internal/app/methods_redis_test.go | 258 ++++++++++++++ internal/app/methods_saved_connections.go | 26 +- .../app/methods_saved_connections_test.go | 6 +- internal/app/methods_update.go | 60 +++- internal/app/methods_update_test.go | 160 +++++++++ internal/secretstore/keyring_store.go | 20 +- internal/secretstore/keyring_store_test.go | 91 +++++ 35 files changed, 2021 insertions(+), 110 deletions(-) create mode 100644 frontend/src/utils/securityUpdateVisuals.test.ts create mode 100644 frontend/src/utils/securityUpdateVisuals.ts create mode 100644 internal/app/methods_file_import_test.go create mode 100644 internal/app/methods_redis_test.go create mode 100644 internal/app/methods_update_test.go diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 98891ca..d6c584f 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -1,12 +1,12 @@ { "name": "gonavi-client", - "version": "0.0.1", + "version": "0.6.5", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "gonavi-client", - "version": "0.0.1", + "version": "0.6.5", "dependencies": { "@ant-design/icons": "^5.2.6", "@dnd-kit/core": "^6.3.1", diff --git a/frontend/package.json b/frontend/package.json index 904b7a2..daddbfb 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -1,7 +1,7 @@ { "name": "gonavi-client", "private": true, - "version": "0.0.1", + "version": "0.6.5", "type": "module", "scripts": { "dev": "vite", diff --git a/frontend/package.json.md5 b/frontend/package.json.md5 index ad6ce0c..5774671 100755 --- a/frontend/package.json.md5 +++ b/frontend/package.json.md5 @@ -1 +1 @@ -20168ff7047e0ecea00acb73f413f7db \ No newline at end of file +8cc5d6401a6ce7dd0f500c66ce8bb4a9 \ No newline at end of file diff --git a/frontend/src/App.css b/frontend/src/App.css index 24b8e5b..8927590 100644 --- a/frontend/src/App.css +++ b/frontend/src/App.css @@ -340,3 +340,47 @@ body[data-theme='light'] .redis-viewer-workbench .ant-radio-button-wrapper-check .driver-manager-hscroll-inner { height: 1px; } + +.security-update-action-btn.ant-btn, +.security-update-action-btn.ant-btn-default, +.security-update-action-btn.ant-btn-primary, +.security-update-action-btn.ant-btn-text { + box-shadow: none !important; +} + +.security-update-action-btn.ant-btn:focus, +.security-update-action-btn.ant-btn:focus-visible, +.security-update-action-btn.ant-btn-default:focus, +.security-update-action-btn.ant-btn-default:focus-visible, +.security-update-action-btn.ant-btn-primary:focus, +.security-update-action-btn.ant-btn-primary:focus-visible, +.security-update-action-btn.ant-btn-text:focus, +.security-update-action-btn.ant-btn-text:focus-visible { + outline: none !important; + box-shadow: none !important; +} + +.security-update-banner { + position: relative; + isolation: isolate; +} + +.security-update-result-card { + transition: background 0.22s ease, box-shadow 0.22s ease, transform 0.22s ease; +} + +.security-update-result-card-active { + animation: security-update-result-pulse 1.8s ease; +} + +@keyframes security-update-result-pulse { + 0% { + transform: translateY(0); + } + 30% { + transform: translateY(-2px); + } + 100% { + transform: translateY(0); + } +} diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 76d869b..59a200e 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -24,7 +24,11 @@ import { getMacNativeTitlebarPaddingLeft, getMacNativeTitlebarPaddingRight, shou import { buildOverlayWorkbenchTheme } from './utils/overlayWorkbenchTheme'; import { getConnectionWorkbenchState } from './utils/startupReadiness'; import { toSaveGlobalProxyInput } from './utils/globalProxyDraft'; -import { detectConnectionImportKind, normalizeConnectionPackagePassword } from './utils/connectionExport'; +import { + detectConnectionImportKind, + resolveConnectionPackageExportResult, + normalizeConnectionPackagePassword, +} from './utils/connectionExport'; import { bootstrapSecureConfig, finalizeSecurityUpdateStatus, @@ -41,10 +45,13 @@ import { resolveSecurityUpdateEntryVisibility, } from './utils/securityUpdatePresentation'; import { + hasSecurityUpdateRecentResult, resolveSecurityUpdateRepairEntry, + resolveSecurityUpdateSettingsFocusTarget, shouldReopenSecurityUpdateDetails, shouldRetrySecurityUpdateAfterRepairSave, type SecurityUpdateRepairSource, + type SecurityUpdateSettingsFocusTarget, } from './utils/securityUpdateRepairFlow'; import { SHORTCUT_ACTION_META, @@ -175,6 +182,8 @@ function App() { const [isSecurityUpdateIntroOpen, setIsSecurityUpdateIntroOpen] = useState(false); const [isSecurityUpdateBannerDismissed, setIsSecurityUpdateBannerDismissed] = useState(false); const [isSecurityUpdateSettingsOpen, setIsSecurityUpdateSettingsOpen] = useState(false); + const [securityUpdateSettingsFocusTarget, setSecurityUpdateSettingsFocusTarget] = useState(null); + const [securityUpdateSettingsFocusRequest, setSecurityUpdateSettingsFocusRequest] = useState(0); const [isSecurityUpdateProgressOpen, setIsSecurityUpdateProgressOpen] = useState(false); const [securityUpdateProgressStage, setSecurityUpdateProgressStage] = useState('正在检查已保存配置'); const [securityUpdateRepairSource, setSecurityUpdateRepairSource] = useState(null); @@ -278,6 +287,8 @@ function App() { setIsSecurityUpdateBannerDismissed(false); } if (options?.openSettings) { + setSecurityUpdateSettingsFocusTarget(resolveSecurityUpdateSettingsFocusTarget(nextStatus)); + setSecurityUpdateSettingsFocusRequest((current) => current + 1); setIsSecurityUpdateSettingsOpen(true); } return nextStatus; @@ -820,10 +831,15 @@ function App() { const connections = useStore(state => state.connections); const tabs = useStore(state => state.tabs); const activeTabId = useStore(state => state.activeTabId); - const handleOpenSecurityUpdateSettings = useCallback(() => { + const openSecurityUpdateSettings = useCallback((focusTarget: SecurityUpdateSettingsFocusTarget | null = null) => { setIsSecurityUpdateIntroOpen(false); + setSecurityUpdateSettingsFocusTarget(focusTarget); + setSecurityUpdateSettingsFocusRequest((current) => current + 1); setIsSecurityUpdateSettingsOpen(true); }, []); + const handleOpenSecurityUpdateSettings = useCallback((focusTarget: SecurityUpdateSettingsFocusTarget | null = null) => { + openSecurityUpdateSettings(focusTarget); + }, [openSecurityUpdateSettings]); const runSecurityUpdateRound = useCallback(async (mode: 'start' | 'retry' | 'restart') => { const backendApp = (window as any).go?.app?.App; const stageText = mode === 'retry' @@ -943,7 +959,7 @@ function App() { securityUpdateStatus.summary, ]); const handleSecurityUpdateIssueAction = useCallback((issue: SecurityUpdateIssue) => { - const repairEntry = resolveSecurityUpdateRepairEntry(issue, connections); + const repairEntry = resolveSecurityUpdateRepairEntry(issue, connections, securityUpdateStatus); if (repairEntry.type === 'warning') { void message.warning(repairEntry.message); return; @@ -973,8 +989,8 @@ function App() { return; } setSecurityUpdateRepairSource(null); - setIsSecurityUpdateSettingsOpen(true); - }, [connections, runSecurityUpdateRound]); + openSecurityUpdateSettings(repairEntry.focusTarget); + }, [connections, openSecurityUpdateSettings, runSecurityUpdateRound, securityUpdateStatus]); const updateCheckInFlightRef = React.useRef(false); const updateDownloadInFlightRef = React.useRef(false); const updateUserDismissedRef = React.useRef(false); @@ -1216,7 +1232,11 @@ function App() { if (!silent) { setAboutUpdateStatus('正在检查更新...'); } - const res = await (window as any).go.app.App.CheckForUpdates(); + const updateAPI = (window as any).go.app.App; + const checkFn = silent && typeof updateAPI.CheckForUpdatesSilently === 'function' + ? updateAPI.CheckForUpdatesSilently + : updateAPI.CheckForUpdates; + const res = await checkFn(); updateCheckInFlightRef.current = false; if (!res?.success) { if (!silent) { @@ -1494,8 +1514,13 @@ function App() { } const res = await backendApp.ExportConnectionsPackage(password); - if (!res?.success) { - throw new Error(res?.message || '导出失败'); + const exportResult = resolveConnectionPackageExportResult(connectionPackageDialog, res); + if (exportResult.kind === 'canceled') { + setConnectionPackageDialog(exportResult.nextDialog); + return; + } + if (exportResult.kind === 'failed') { + throw new Error(exportResult.error); } closeConnectionPackageDialog(); @@ -1695,7 +1720,9 @@ function App() { const rawStatus = typeof backendApp?.GetSecurityUpdateStatus === 'function' ? await backendApp.GetSecurityUpdateStatus() : securityUpdateStatus; - const nextStatus = mergeSecurityUpdateStatusWithLegacySource(rawStatus, nextRawPayload); + const nextStatus = mergeSecurityUpdateStatusWithLegacySource(rawStatus, nextRawPayload, { + previousStatus: securityUpdateStatus, + }); const nextHasLegacySensitiveItems = hasLegacyMigratableSensitiveItems(nextRawPayload); setSecurityUpdateRawPayload(nextRawPayload); @@ -2322,7 +2349,7 @@ function App() { title="拖动调整宽度" /> - + {securityUpdateEntryVisibility.showBanner && !isSecurityUpdateBannerDismissed && ( handleOpenSecurityUpdateSettings( + hasSecurityUpdateRecentResult(securityUpdateStatus) ? 'recent_result' : null, + )} onDismiss={() => setIsSecurityUpdateBannerDismissed(true)} /> )} @@ -2474,13 +2503,15 @@ function App() { overlayTheme={overlayTheme} onStart={handleStartSecurityUpdate} onPostpone={handlePostponeSecurityUpdate} - onViewDetails={handleOpenSecurityUpdateSettings} + onViewDetails={() => handleOpenSecurityUpdateSettings()} /> setIsSecurityUpdateSettingsOpen(false)} onStart={handleStartSecurityUpdate} onRetry={handleRetrySecurityUpdate} diff --git a/frontend/src/components/SecurityUpdateBanner.tsx b/frontend/src/components/SecurityUpdateBanner.tsx index b83fc31..ac410b5 100644 --- a/frontend/src/components/SecurityUpdateBanner.tsx +++ b/frontend/src/components/SecurityUpdateBanner.tsx @@ -4,6 +4,12 @@ import { CloseOutlined, SafetyCertificateOutlined } from '@ant-design/icons'; import type { SecurityUpdateStatus } from '../types'; import { getSecurityUpdateStatusMeta } from '../utils/securityUpdatePresentation'; import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; +import { + SECURITY_UPDATE_ACTION_BUTTON_CLASS, + SECURITY_UPDATE_BANNER_CLASS, + getSecurityUpdateActionButtonStyle, + getSecurityUpdateBannerSurfaceStyle, +} from '../utils/securityUpdateVisuals'; interface SecurityUpdateBannerProps { status: SecurityUpdateStatus; @@ -77,20 +83,20 @@ const SecurityUpdateBanner = ({ const statusMeta = getSecurityUpdateStatusMeta(status); const primaryAction = resolvePrimaryAction(status, { onStart, onRetry, onRestart, onOpenDetails }); const secondaryAction = resolveSecondaryAction(status, { onRetry, onOpenDetails }); + const actionButtonStyle = getSecurityUpdateActionButtonStyle(); return (
{secondaryAction ? ( - ) : null} - -
); diff --git a/frontend/src/components/SecurityUpdateIntroModal.tsx b/frontend/src/components/SecurityUpdateIntroModal.tsx index 7123f0c..e02c099 100644 --- a/frontend/src/components/SecurityUpdateIntroModal.tsx +++ b/frontend/src/components/SecurityUpdateIntroModal.tsx @@ -3,6 +3,11 @@ import { SafetyCertificateOutlined } from '@ant-design/icons'; import type { CSSProperties } from 'react'; import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; +import { + SECURITY_UPDATE_ACTION_BUTTON_CLASS, + SECURITY_UPDATE_MODAL_CLASS, + getSecurityUpdateActionButtonStyle, +} from '../utils/securityUpdateVisuals'; interface SecurityUpdateIntroModalProps { open: boolean; @@ -15,10 +20,9 @@ interface SecurityUpdateIntroModalProps { } const actionButtonStyle: CSSProperties = { + ...getSecurityUpdateActionButtonStyle(), height: 38, - borderRadius: 12, paddingInline: 18, - fontWeight: 600, }; const SecurityUpdateIntroModal = ({ @@ -32,6 +36,7 @@ const SecurityUpdateIntroModal = ({ }: SecurityUpdateIntroModalProps) => { return (
+ , - , - , ]} diff --git a/frontend/src/components/SecurityUpdateProgressModal.tsx b/frontend/src/components/SecurityUpdateProgressModal.tsx index dec305e..5e3888b 100644 --- a/frontend/src/components/SecurityUpdateProgressModal.tsx +++ b/frontend/src/components/SecurityUpdateProgressModal.tsx @@ -2,6 +2,7 @@ import { Modal, Spin } from 'antd'; import { SafetyCertificateOutlined } from '@ant-design/icons'; import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; +import { SECURITY_UPDATE_MODAL_CLASS } from '../utils/securityUpdateVisuals'; interface SecurityUpdateProgressModalProps { open: boolean; @@ -18,6 +19,7 @@ const SecurityUpdateProgressModal = ({ }: SecurityUpdateProgressModalProps) => { return ( void; onStart: () => void; onRetry: () => void; @@ -23,18 +41,27 @@ interface SecurityUpdateSettingsModalProps { onIssueAction: (issue: SecurityUpdateIssue) => void; } -const sectionStyle = (overlayTheme: OverlayWorkbenchTheme) => ({ +const sectionStyle = ( + overlayTheme: OverlayWorkbenchTheme, + options?: { emphasized?: boolean }, +) => ({ borderRadius: 14, - border: overlayTheme.sectionBorder, - background: overlayTheme.sectionBg, padding: 16, + ...getSecurityUpdateSectionSurfaceStyle(overlayTheme, options), }); +const EMPTY_FOCUS_STATE: SecurityUpdateFocusState = { + target: null, + pulseKey: null, +}; + const SecurityUpdateSettingsModal = ({ open, darkMode, overlayTheme, status, + focusTarget = null, + focusRequest = 0, onClose, onStart, onRetry, @@ -43,12 +70,53 @@ const SecurityUpdateSettingsModal = ({ }: SecurityUpdateSettingsModalProps) => { const statusMeta = getSecurityUpdateStatusMeta(status); const sortedIssues = sortSecurityUpdateIssues(status.issues); + const showRecentResult = hasSecurityUpdateRecentResult(status); const showStart = status.overallStatus === 'pending' || status.overallStatus === 'postponed'; const showRetry = status.overallStatus === 'needs_attention'; const showRestart = status.overallStatus === 'needs_attention' || status.overallStatus === 'rolled_back'; + const actionButtonStyle = getSecurityUpdateActionButtonStyle(); + const [activeFocus, setActiveFocus] = useState(EMPTY_FOCUS_STATE); + const statusSectionRef = useRef(null); + const recentResultRef = useRef(null); + + useEffect(() => { + const nextFocus = resolveSecurityUpdateFocusState(open, focusTarget, focusRequest); + if (!nextFocus.target || !nextFocus.pulseKey) { + setActiveFocus(EMPTY_FOCUS_STATE); + return undefined; + } + + const targetNode = nextFocus.target === 'recent_result' + ? recentResultRef.current + : statusSectionRef.current; + if (!targetNode) { + return undefined; + } + + setActiveFocus(EMPTY_FOCUS_STATE); + const animationFrame = window.requestAnimationFrame(() => { + targetNode.scrollIntoView({ + block: 'nearest', + behavior: 'smooth', + }); + targetNode.focus({ preventScroll: true }); + setActiveFocus(nextFocus); + }); + const highlightTimer = window.setTimeout(() => { + setActiveFocus((current) => ( + current.pulseKey === nextFocus.pulseKey ? EMPTY_FOCUS_STATE : current + )); + }, 1800); + + return () => { + window.cancelAnimationFrame(animationFrame); + window.clearTimeout(highlightTimer); + }; + }, [focusRequest, focusTarget, open]); return (
+ ) : null, showRestart ? ( - ) : null, showStart ? ( - ) : null, - , ]} width={760} styles={{ - content: { - background: overlayTheme.shellBg, - border: overlayTheme.shellBorder, - boxShadow: overlayTheme.shellShadow, - backdropFilter: overlayTheme.shellBackdropFilter, - }, + content: getSecurityUpdateShellSurfaceStyle(overlayTheme), header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 }, body: { paddingTop: 8, maxHeight: 640, overflowY: 'auto' }, footer: { background: 'transparent', borderTop: 'none', paddingTop: 10 }, }} >
-
+
@@ -153,8 +226,9 @@ const SecurityUpdateSettingsModal = ({
@@ -184,9 +258,8 @@ const SecurityUpdateSettingsModal = ({
- {status.backupPath ? ( -
+ {showRecentResult ? ( +
最近一次结果
-
- 备份位置:{status.backupPath} -
+ {status.backupPath ? ( +
+ 备份位置:{status.backupPath} +
+ ) : null} {status.lastError ? (
最近错误:{status.lastError} diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index bb2e2ac..fea428b 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -119,6 +119,7 @@ if (typeof window !== 'undefined' && !(window as any).go) { DeleteQuery: async () => null, GetAppInfo: async () => ({}), CheckForUpdates: async () => ({ success: false }), + CheckForUpdatesSilently: async () => ({ success: false }), OpenDownloadedUpdateDirectory: async () => ({ success: false }), InstallUpdateAndRestart: async () => ({ success: false }), ImportConfigFile: async () => ({ success: false, message: '已取消' }), diff --git a/frontend/src/utils/connectionExport.test.ts b/frontend/src/utils/connectionExport.test.ts index 5b1c53e..d4f9720 100644 --- a/frontend/src/utils/connectionExport.test.ts +++ b/frontend/src/utils/connectionExport.test.ts @@ -2,6 +2,8 @@ import { describe, expect, it } from 'vitest'; import { detectConnectionImportKind, + isConnectionPackageExportCanceled, + resolveConnectionPackageExportResult, normalizeConnectionPackagePassword, } from './connectionExport'; @@ -57,4 +59,54 @@ describe('connectionExport', () => { expect(normalizeConnectionPackagePassword(' secret-pass ')).toBe('secret-pass'); expect(normalizeConnectionPackagePassword('\n\t \t')).toBe(''); }); + + it('treats export cancel as a non-error backend result', () => { + expect(isConnectionPackageExportCanceled({ success: false, message: '已取消' })).toBe(true); + expect(isConnectionPackageExportCanceled({ success: false, message: '导出失败' })).toBe(false); + expect(isConnectionPackageExportCanceled({ success: true, message: '已取消' })).toBe(false); + expect(isConnectionPackageExportCanceled(undefined)).toBe(false); + }); + + it('maps export results to dialog state transitions', () => { + const staleDialog = { + open: true, + mode: 'export' as const, + password: ' secret-pass ', + error: '上一次失败', + confirmLoading: false, + }; + + const canceledResult = resolveConnectionPackageExportResult(staleDialog, { success: false, message: '已取消' }); + expect(canceledResult.kind).toBe('canceled'); + if (canceledResult.kind === 'canceled') { + expect(typeof canceledResult.nextDialog).toBe('function'); + expect((canceledResult.nextDialog as (current: typeof staleDialog) => typeof staleDialog)({ + open: false, + mode: 'export', + password: 'secret-pass', + error: '更新后的错误', + confirmLoading: true, + })).toEqual({ + open: false, + mode: 'export', + password: 'secret-pass', + error: '', + confirmLoading: false, + }); + } + + expect(resolveConnectionPackageExportResult(staleDialog, { success: true, message: '导出完成' })).toEqual({ + kind: 'succeeded', + }); + + expect(resolveConnectionPackageExportResult(staleDialog, { success: false, message: '磁盘已满' })).toEqual({ + kind: 'failed', + error: '磁盘已满', + }); + + expect(resolveConnectionPackageExportResult(staleDialog, undefined)).toEqual({ + kind: 'failed', + error: '导出失败', + }); + }); }); diff --git a/frontend/src/utils/connectionExport.ts b/frontend/src/utils/connectionExport.ts index 9cec933..13ff987 100644 --- a/frontend/src/utils/connectionExport.ts +++ b/frontend/src/utils/connectionExport.ts @@ -1,10 +1,26 @@ import type { ConnectionConfig, SavedConnection } from '../types'; export type ConnectionImportKind = 'encrypted-package' | 'legacy-json' | 'invalid'; +export type ConnectionPackageDialogSnapshot = { + open: boolean; + mode: 'export' | 'import'; + password: string; + error: string; + confirmLoading: boolean; +}; +export type ConnectionPackageDialogUpdater = ( + current: ConnectionPackageDialogSnapshot, +) => ConnectionPackageDialogSnapshot; + +export type ConnectionPackageExportResult = + | { kind: 'canceled'; nextDialog: ConnectionPackageDialogUpdater } + | { kind: 'succeeded' } + | { kind: 'failed'; error: string }; type JsonObject = Record; const CONNECTION_PACKAGE_KIND = 'gonavi_connection_package'; +const CANCELED_MESSAGE = '已取消'; const isJsonObject = (value: unknown): value is JsonObject => ( typeof value === 'object' && value !== null && !Array.isArray(value) @@ -69,6 +85,39 @@ export const detectConnectionImportKind = (raw: unknown): ConnectionImportKind = export const normalizeConnectionPackagePassword = (value: string): string => value.trim(); +export const isConnectionPackageExportCanceled = (result: unknown): boolean => ( + isJsonObject(result) + && result.success === false + && result.message === CANCELED_MESSAGE +); + +export const resolveConnectionPackageExportResult = ( + _currentDialog: ConnectionPackageDialogSnapshot, + result: unknown, +): ConnectionPackageExportResult => { + if (isConnectionPackageExportCanceled(result)) { + return { + kind: 'canceled', + nextDialog: (current) => ({ + ...current, + confirmLoading: false, + error: '', + }), + }; + } + + if (isJsonObject(result) && result.success === true) { + return { kind: 'succeeded' }; + } + + return { + kind: 'failed', + error: isJsonObject(result) && typeof result.message === 'string' && result.message.trim() + ? result.message + : '导出失败', + }; +}; + const legacyExportRemovedError = (): never => { throw new Error('Legacy connection JSON export has been removed. Use the recovery package flow instead.'); }; diff --git a/frontend/src/utils/secureConfigBootstrap.test.ts b/frontend/src/utils/secureConfigBootstrap.test.ts index 07e13e6..32c9cd5 100644 --- a/frontend/src/utils/secureConfigBootstrap.test.ts +++ b/frontend/src/utils/secureConfigBootstrap.test.ts @@ -220,6 +220,83 @@ describe('secureConfigBootstrap', () => { expect(result.shouldShowBanner).toBe(true); }); + it('merges legacy pending items into rolled_back status without overwriting backend system issues', () => { + const status = mergeSecurityUpdateStatusWithLegacySource({ + overallStatus: 'rolled_back', + summary: { total: 1, updated: 0, pending: 0, skipped: 0, failed: 1 }, + issues: [ + { + id: 'system-blocked', + scope: 'system', + title: '系统回滚', + severity: 'high', + status: 'failed', + reasonCode: 'environment_blocked', + action: 'view_details', + message: '后端已回滚本轮更新,需要处理后重试。', + }, + ], + }, legacyPayload); + + expect(status.overallStatus).toBe('rolled_back'); + expect(status.summary).toEqual({ + total: 3, + updated: 0, + pending: 2, + skipped: 0, + failed: 1, + }); + expect(status.issues).toEqual(expect.arrayContaining([ + expect.objectContaining({ id: 'system-blocked', scope: 'system' }), + expect.objectContaining({ id: 'legacy-connection-legacy-1', scope: 'connection', refId: 'legacy-1' }), + expect.objectContaining({ id: 'legacy-global-proxy-default', scope: 'global_proxy' }), + ])); + }); + + it('merges legacy pending items into needs_attention status without overwriting backend system issues', () => { + const status = mergeSecurityUpdateStatusWithLegacySource({ + overallStatus: 'needs_attention', + summary: { total: 2, updated: 1, pending: 0, skipped: 0, failed: 1 }, + issues: [ + { + id: 'system-partial-failure', + scope: 'system', + title: '部分失败', + severity: 'high', + status: 'failed', + reasonCode: 'environment_blocked', + action: 'view_details', + message: '部分项目迁移失败,需要人工处理。', + }, + { + id: 'ai-provider-openai-main', + scope: 'ai_provider', + refId: 'openai-main', + title: 'OpenAI', + severity: 'medium', + status: 'updated', + action: 'open_ai_settings', + message: 'AI 提供商配置已完成安全更新。', + }, + ], + }, legacyPayload); + + expect(status.overallStatus).toBe('needs_attention'); + expect(status.summary).toEqual({ + total: 4, + updated: 1, + pending: 2, + skipped: 0, + failed: 1, + }); + expect(status.issues).toEqual(expect.arrayContaining([ + expect.objectContaining({ id: 'system-partial-failure', scope: 'system' }), + expect.objectContaining({ id: 'ai-provider-openai-main', scope: 'ai_provider', refId: 'openai-main' }), + expect.objectContaining({ id: 'legacy-connection-legacy-1', scope: 'connection', refId: 'legacy-1' }), + expect.objectContaining({ id: 'legacy-global-proxy-default', scope: 'global_proxy' }), + ])); + }); + it('loads backend secure config directly when no legacy source exists', async () => { const storage = createMemoryStorage(); const replaceConnections = vi.fn(); @@ -440,18 +517,25 @@ describe('secureConfigBootstrap', () => { }); it('reduces legacy pending issues after a single connection is repaired before the first round starts', () => { + const initialStatus = mergeSecurityUpdateStatusWithLegacySource({ + overallStatus: 'not_detected', + summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }, legacyPayload); const nextPayload = stripLegacyPersistedConnectionById(legacyPayload, 'legacy-1'); const status = mergeSecurityUpdateStatusWithLegacySource({ overallStatus: 'not_detected', summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, issues: [], - }, nextPayload); + }, nextPayload, { + previousStatus: initialStatus, + }); expect(status.overallStatus).toBe('pending'); expect(status.summary).toEqual({ - total: 1, - updated: 0, + total: 2, + updated: 1, pending: 1, skipped: 0, failed: 0, @@ -463,4 +547,87 @@ describe('secureConfigBootstrap', () => { }), ]); }); + + it('accumulates pre-start repaired progress across multiple connection saves in the same round-free session', () => { + const multiConnectionPayload = JSON.stringify({ + state: { + connections: [ + { + id: 'legacy-1', + name: 'Legacy 1', + config: { + id: 'legacy-1', + type: 'postgres', + host: 'db-1.local', + port: 5432, + user: 'postgres', + password: 'secret-1', + }, + }, + { + id: 'legacy-2', + name: 'Legacy 2', + config: { + id: 'legacy-2', + type: 'postgres', + host: 'db-2.local', + port: 5432, + user: 'postgres', + password: 'secret-2', + }, + }, + { + id: 'legacy-3', + name: 'Legacy 3', + config: { + id: 'legacy-3', + type: 'postgres', + host: 'db-3.local', + port: 5432, + user: 'postgres', + password: 'secret-3', + }, + }, + ], + }, + }); + + const backendStatus = { + overallStatus: 'not_detected' as const, + summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 }, + issues: [], + }; + const initialStatus = mergeSecurityUpdateStatusWithLegacySource(backendStatus, multiConnectionPayload); + const afterFirstRepairPayload = stripLegacyPersistedConnectionById(multiConnectionPayload, 'legacy-1'); + const afterFirstRepairStatus = mergeSecurityUpdateStatusWithLegacySource(backendStatus, afterFirstRepairPayload, { + previousStatus: initialStatus, + }); + const afterSecondRepairPayload = stripLegacyPersistedConnectionById(afterFirstRepairPayload, 'legacy-2'); + + const afterSecondRepairStatus = mergeSecurityUpdateStatusWithLegacySource(backendStatus, afterSecondRepairPayload, { + previousStatus: afterFirstRepairStatus, + }); + + expect(afterFirstRepairStatus.summary).toEqual({ + total: 3, + updated: 1, + pending: 2, + skipped: 0, + failed: 0, + }); + expect(afterSecondRepairStatus.summary).toEqual({ + total: 3, + updated: 2, + pending: 1, + skipped: 0, + failed: 0, + }); + expect(afterSecondRepairStatus.issues).toEqual([ + expect.objectContaining({ + id: 'legacy-connection-legacy-3', + scope: 'connection', + refId: 'legacy-3', + }), + ]); + }); }); diff --git a/frontend/src/utils/secureConfigBootstrap.ts b/frontend/src/utils/secureConfigBootstrap.ts index 666178e..f457024 100644 --- a/frontend/src/utils/secureConfigBootstrap.ts +++ b/frontend/src/utils/secureConfigBootstrap.ts @@ -54,6 +54,10 @@ type StartSecurityUpdateResult = { error: Error | null; }; +type MergeSecurityUpdateStatusOptions = { + previousStatus?: Partial | null; +}; + const defaultSummary = () => ({ total: 0, updated: 0, @@ -129,9 +133,56 @@ const mergeSecurityUpdateIssues = ( }; }; +const isLocalLegacyIssue = (issue: Partial | null | undefined): boolean => { + const issueId = String(issue?.id || '').trim(); + return issueId.startsWith('legacy-connection-') || issueId === 'legacy-global-proxy-default'; +}; + +const countLocalLegacyIssues = (issues: SecurityUpdateIssue[]): number => ( + issues.filter((issue) => isLocalLegacyIssue(issue)).length +); + +const deriveLegacySummary = ( + base: SecurityUpdateStatus, + currentLegacyCount: number, + previousStatus?: Partial | null, +): { + summary: SecurityUpdateSummary; + hasContribution: boolean; +} => { + const previousSummary = previousStatus?.summary ?? defaultSummary(); + const previousIssues = Array.isArray(previousStatus?.issues) ? previousStatus.issues : []; + const previousLegacyCount = countLocalLegacyIssues(previousIssues); + const previousLegacyTotal = Math.max( + 0, + previousSummary.total - base.summary.total, + previousSummary.updated - base.summary.updated + previousLegacyCount, + previousLegacyCount, + ); + const previousLegacyUpdated = Math.max( + 0, + Math.min(previousLegacyTotal, previousSummary.updated - base.summary.updated), + ); + const repairedSincePrevious = Math.max(0, previousLegacyCount - currentLegacyCount); + const nextLegacyUpdated = Math.min(previousLegacyTotal, previousLegacyUpdated + repairedSincePrevious); + const nextLegacyTotal = Math.max(previousLegacyTotal, nextLegacyUpdated + currentLegacyCount); + + return { + summary: { + total: base.summary.total + nextLegacyTotal, + updated: base.summary.updated + nextLegacyUpdated, + pending: base.summary.pending + currentLegacyCount, + skipped: base.summary.skipped, + failed: base.summary.failed, + }, + hasContribution: nextLegacyTotal > 0, + }; +}; + export const mergeSecurityUpdateStatusWithLegacySource = ( status: Partial | undefined, rawPayload: string | null, + options?: MergeSecurityUpdateStatusOptions, ): SecurityUpdateStatus => { const base: SecurityUpdateStatus = { ...defaultStatus(), @@ -142,46 +193,51 @@ export const mergeSecurityUpdateStatusWithLegacySource = ( }, issues: Array.isArray(status?.issues) ? status.issues : [], }; + const baseNonLegacyIssues = base.issues.filter((issue) => !isLocalLegacyIssue(issue)); const legacy = buildLegacyPendingDetails(rawPayload); - if (!legacy.hasLegacyItems) { + const legacySummary = deriveLegacySummary(base, legacy.issues.length, options?.previousStatus); + + if (!legacySummary.hasContribution) { return base; } + const mergedIssues = mergeSecurityUpdateIssues(baseNonLegacyIssues, legacy.issues).issues; + if (base.overallStatus === 'not_detected') { + if (!legacy.hasLegacyItems) { + return base; + } return { ...base, overallStatus: 'pending', reminderVisible: true, canStart: true, canPostpone: true, - summary: legacy.summary, - issues: legacy.issues, + summary: legacySummary.summary, + issues: mergedIssues, }; } if (base.overallStatus === 'pending' || base.overallStatus === 'postponed') { - const mergedIssues = mergeSecurityUpdateIssues(base.issues, legacy.issues); - const summary = hasMeaningfulSummary(base.summary) - ? { - total: base.summary.total + mergedIssues.addedCount, - updated: base.summary.updated, - pending: base.summary.pending + mergedIssues.addedCount, - skipped: base.summary.skipped, - failed: base.summary.failed, - } - : legacy.summary; - return { ...base, - summary, - issues: mergedIssues.issues, + summary: hasMeaningfulSummary(base.summary) || legacy.hasLegacyItems ? legacySummary.summary : legacy.summary, + issues: mergedIssues, canStart: true, canPostpone: true, reminderVisible: base.overallStatus === 'pending' ? true : base.reminderVisible, }; } + if (base.overallStatus === 'rolled_back' || base.overallStatus === 'needs_attention') { + return { + ...base, + summary: hasMeaningfulSummary(base.summary) || legacy.hasLegacyItems ? legacySummary.summary : legacy.summary, + issues: mergedIssues, + }; + } + return base; }; @@ -344,6 +400,7 @@ export async function startSecurityUpdateFromBootstrap(args: SecureConfigBootstr export type { BackendGlobalProxyResult, + MergeSecurityUpdateStatusOptions, SecurityUpdateBackend, SecureConfigBootstrapArgs, SecureConfigBootstrapResult, diff --git a/frontend/src/utils/securityUpdateRepairFlow.test.ts b/frontend/src/utils/securityUpdateRepairFlow.test.ts index 3e514bb..0cb57f7 100644 --- a/frontend/src/utils/securityUpdateRepairFlow.test.ts +++ b/frontend/src/utils/securityUpdateRepairFlow.test.ts @@ -1,8 +1,11 @@ import { describe, expect, it } from 'vitest'; -import type { SavedConnection, SecurityUpdateIssue } from '../types'; +import type { SavedConnection, SecurityUpdateIssue, SecurityUpdateStatus } from '../types'; import { + hasSecurityUpdateRecentResult, + resolveSecurityUpdateFocusState, resolveSecurityUpdateRepairEntry, + resolveSecurityUpdateSettingsFocusTarget, shouldReopenSecurityUpdateDetails, shouldRetrySecurityUpdateAfterRepairSave, } from './securityUpdateRepairFlow'; @@ -19,6 +22,19 @@ const createConnection = (id: string): SavedConnection => ({ }, }); +const createStatus = (overrides: Partial = {}): SecurityUpdateStatus => ({ + overallStatus: 'needs_attention', + summary: { + total: 1, + updated: 0, + pending: 1, + skipped: 0, + failed: 0, + }, + issues: [], + ...overrides, +}); + describe('securityUpdateRepairFlow', () => { it('opens the matching connection and preserves the return source for security update repairs', () => { const target = createConnection('conn-1'); @@ -63,6 +79,50 @@ describe('securityUpdateRepairFlow', () => { }); }); + it('routes view_details actions to the latest result section when a recent result exists', () => { + const status = createStatus({ + backupPath: '/tmp/gonavi-backup.json', + lastError: '写入新密钥失败', + }); + + expect(hasSecurityUpdateRecentResult(status)).toBe(true); + expect(resolveSecurityUpdateSettingsFocusTarget(status)).toBe('recent_result'); + expect(resolveSecurityUpdateRepairEntry({ id: 'details', action: 'view_details' }, [], status)).toEqual({ + type: 'details', + focusTarget: 'recent_result', + }); + }); + + it('falls back to the status section when no recent result is available yet', () => { + const status = createStatus(); + + expect(hasSecurityUpdateRecentResult(status)).toBe(false); + expect(resolveSecurityUpdateSettingsFocusTarget(status)).toBe('status'); + expect(resolveSecurityUpdateRepairEntry({ id: 'details', action: 'view_details' }, [], status)).toEqual({ + type: 'details', + focusTarget: 'status', + }); + }); + + it('builds a fresh focus pulse for repeated details clicks and clears it when the modal closes', () => { + expect(resolveSecurityUpdateFocusState(true, 'status', 1)).toEqual({ + target: 'status', + pulseKey: 'status:1', + }); + expect(resolveSecurityUpdateFocusState(true, 'status', 2)).toEqual({ + target: 'status', + pulseKey: 'status:2', + }); + expect(resolveSecurityUpdateFocusState(false, 'status', 2)).toEqual({ + target: null, + pulseKey: null, + }); + expect(resolveSecurityUpdateFocusState(true, null, 3)).toEqual({ + target: null, + pulseKey: null, + }); + }); + it('reopens security update details after closing a repair entry opened from that page', () => { expect(shouldReopenSecurityUpdateDetails('connection')).toBe(true); expect(shouldReopenSecurityUpdateDetails('proxy')).toBe(true); diff --git a/frontend/src/utils/securityUpdateRepairFlow.ts b/frontend/src/utils/securityUpdateRepairFlow.ts index 5df098a..9a6be1e 100644 --- a/frontend/src/utils/securityUpdateRepairFlow.ts +++ b/frontend/src/utils/securityUpdateRepairFlow.ts @@ -1,6 +1,11 @@ -import type { SavedConnection, SecurityUpdateIssue } from '../types'; +import type { SavedConnection, SecurityUpdateIssue, SecurityUpdateStatus } from '../types'; export type SecurityUpdateRepairSource = 'connection' | 'proxy' | 'ai'; +export type SecurityUpdateSettingsFocusTarget = 'recent_result' | 'status'; +export type SecurityUpdateFocusState = { + target: SecurityUpdateSettingsFocusTarget | null; + pulseKey: string | null; +}; export type SecurityUpdateRepairEntry = | { @@ -22,15 +27,45 @@ export type SecurityUpdateRepairEntry = } | { type: 'details'; + focusTarget: SecurityUpdateSettingsFocusTarget; } | { type: 'warning'; message: string; }; +export const hasSecurityUpdateRecentResult = ( + status?: Pick | null, +): boolean => Boolean(status?.backupPath || status?.lastError); + +export const resolveSecurityUpdateSettingsFocusTarget = ( + status?: Pick | null, +): SecurityUpdateSettingsFocusTarget => ( + hasSecurityUpdateRecentResult(status) ? 'recent_result' : 'status' +); + +export const resolveSecurityUpdateFocusState = ( + open: boolean, + focusTarget: SecurityUpdateSettingsFocusTarget | null | undefined, + focusRequest: number, +): SecurityUpdateFocusState => { + if (!open || !focusTarget) { + return { + target: null, + pulseKey: null, + }; + } + + return { + target: focusTarget, + pulseKey: `${focusTarget}:${focusRequest}`, + }; +}; + export const resolveSecurityUpdateRepairEntry = ( issue: SecurityUpdateIssue, connections: SavedConnection[], + status?: Pick | null, ): SecurityUpdateRepairEntry => { if (issue.action === 'open_connection') { const target = connections.find((connection) => connection.id === issue.refId); @@ -70,6 +105,7 @@ export const resolveSecurityUpdateRepairEntry = ( return { type: 'details', + focusTarget: resolveSecurityUpdateSettingsFocusTarget(status), }; }; diff --git a/frontend/src/utils/securityUpdateVisuals.test.ts b/frontend/src/utils/securityUpdateVisuals.test.ts new file mode 100644 index 0000000..7d0d8e7 --- /dev/null +++ b/frontend/src/utils/securityUpdateVisuals.test.ts @@ -0,0 +1,88 @@ +import { describe, expect, it } from 'vitest'; + +import { buildOverlayWorkbenchTheme } from './overlayWorkbenchTheme'; +import { + SECURITY_UPDATE_ACTION_BUTTON_CLASS, + SECURITY_UPDATE_BANNER_CLASS, + SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS, + SECURITY_UPDATE_RESULT_CARD_CLASS, + getSecurityUpdateActionButtonStyle, + getSecurityUpdateBannerSurfaceStyle, + getSecurityUpdateSectionSurfaceStyle, + getSecurityUpdateShellSurfaceStyle, +} from './securityUpdateVisuals'; + +describe('securityUpdateVisuals', () => { + it('builds action buttons without default ant focus glow shadow', () => { + expect(SECURITY_UPDATE_ACTION_BUTTON_CLASS).toBe('security-update-action-btn'); + expect(SECURITY_UPDATE_BANNER_CLASS).toBe('security-update-banner'); + expect(SECURITY_UPDATE_RESULT_CARD_CLASS).toBe('security-update-result-card'); + expect(SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS).toBe('security-update-result-card-active'); + expect(getSecurityUpdateActionButtonStyle()).toMatchObject({ + height: 36, + borderRadius: 12, + boxShadow: 'none', + fontWeight: 600, + }); + }); + + it('keeps the shell surface aligned with overlay shell tokens in light and dark mode', () => { + const lightTheme = buildOverlayWorkbenchTheme(false); + const darkTheme = buildOverlayWorkbenchTheme(true); + + expect(getSecurityUpdateShellSurfaceStyle(lightTheme)).toMatchObject({ + border: lightTheme.shellBorder, + background: lightTheme.shellBg, + boxShadow: lightTheme.shellShadow, + backdropFilter: lightTheme.shellBackdropFilter, + }); + expect(getSecurityUpdateShellSurfaceStyle(darkTheme)).toMatchObject({ + border: darkTheme.shellBorder, + background: darkTheme.shellBg, + boxShadow: darkTheme.shellShadow, + backdropFilter: darkTheme.shellBackdropFilter, + }); + }); + + it('keeps the banner surface aligned with overlay shell tokens instead of translucent section tokens', () => { + const lightTheme = buildOverlayWorkbenchTheme(false); + const darkTheme = buildOverlayWorkbenchTheme(true); + + expect(getSecurityUpdateBannerSurfaceStyle(lightTheme)).toMatchObject({ + border: lightTheme.shellBorder, + background: lightTheme.shellBg, + boxShadow: 'none', + backdropFilter: lightTheme.shellBackdropFilter, + }); + expect(getSecurityUpdateBannerSurfaceStyle(darkTheme)).toMatchObject({ + border: darkTheme.shellBorder, + background: darkTheme.shellBg, + boxShadow: 'none', + backdropFilter: darkTheme.shellBackdropFilter, + }); + }); + + it('can emphasize a section surface for transient focus and recent-result highlighting', () => { + const lightTheme = buildOverlayWorkbenchTheme(false); + const darkTheme = buildOverlayWorkbenchTheme(true); + + expect(getSecurityUpdateSectionSurfaceStyle(lightTheme)).toMatchObject({ + border: lightTheme.sectionBorder, + background: lightTheme.sectionBg, + boxShadow: 'none', + }); + expect(getSecurityUpdateSectionSurfaceStyle(darkTheme)).toMatchObject({ + border: darkTheme.sectionBorder, + background: darkTheme.sectionBg, + boxShadow: 'none', + }); + + const emphasizedLight = getSecurityUpdateSectionSurfaceStyle(lightTheme, { emphasized: true }); + const emphasizedDark = getSecurityUpdateSectionSurfaceStyle(darkTheme, { emphasized: true }); + + expect(emphasizedLight.background).not.toBe(lightTheme.sectionBg); + expect(emphasizedLight.boxShadow).not.toBe('none'); + expect(emphasizedDark.background).not.toBe(darkTheme.sectionBg); + expect(emphasizedDark.boxShadow).not.toBe('none'); + }); +}); diff --git a/frontend/src/utils/securityUpdateVisuals.ts b/frontend/src/utils/securityUpdateVisuals.ts new file mode 100644 index 0000000..735ca2d --- /dev/null +++ b/frontend/src/utils/securityUpdateVisuals.ts @@ -0,0 +1,65 @@ +import type { CSSProperties } from 'react'; + +import type { OverlayWorkbenchTheme } from './overlayWorkbenchTheme'; + +export const SECURITY_UPDATE_ACTION_BUTTON_CLASS = 'security-update-action-btn'; +export const SECURITY_UPDATE_BANNER_CLASS = 'security-update-banner'; +export const SECURITY_UPDATE_MODAL_CLASS = 'security-update-modal'; +export const SECURITY_UPDATE_RESULT_CARD_CLASS = 'security-update-result-card'; +export const SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS = 'security-update-result-card-active'; + +type SecurityUpdateSectionSurfaceOptions = { + emphasized?: boolean; +}; + +const getSecurityUpdateHighlightBorder = (overlayTheme: OverlayWorkbenchTheme): string => ( + overlayTheme.isDark + ? '1px solid rgba(255,214,102,0.26)' + : '1px solid rgba(22,119,255,0.22)' +); + +const getSecurityUpdateHighlightBackground = (overlayTheme: OverlayWorkbenchTheme): string => ( + overlayTheme.isDark + ? 'linear-gradient(180deg, rgba(255,214,102,0.14) 0%, rgba(255,255,255,0.05) 100%)' + : 'linear-gradient(180deg, rgba(22,119,255,0.12) 0%, rgba(255,255,255,0.96) 100%)' +); + +const getSecurityUpdateHighlightShadow = (overlayTheme: OverlayWorkbenchTheme): string => ( + overlayTheme.isDark + ? '0 0 0 1px rgba(255,214,102,0.12), 0 12px 24px rgba(0,0,0,0.16)' + : '0 0 0 1px rgba(22,119,255,0.08), 0 10px 22px rgba(15,23,42,0.08)' +); + +export const getSecurityUpdateActionButtonStyle = (): CSSProperties => ({ + height: 36, + borderRadius: 12, + paddingInline: 16, + boxShadow: 'none', + fontWeight: 600, +}); + +export const getSecurityUpdateShellSurfaceStyle = ( + overlayTheme: OverlayWorkbenchTheme, +): CSSProperties => ({ + border: overlayTheme.shellBorder, + background: overlayTheme.shellBg, + boxShadow: overlayTheme.shellShadow, + backdropFilter: overlayTheme.shellBackdropFilter, +}); + +export const getSecurityUpdateBannerSurfaceStyle = ( + overlayTheme: OverlayWorkbenchTheme, +): CSSProperties => ({ + ...getSecurityUpdateShellSurfaceStyle(overlayTheme), + boxShadow: 'none', +}); + +export const getSecurityUpdateSectionSurfaceStyle = ( + overlayTheme: OverlayWorkbenchTheme, + options: SecurityUpdateSectionSurfaceOptions = {}, +): CSSProperties => ({ + border: options.emphasized ? getSecurityUpdateHighlightBorder(overlayTheme) : overlayTheme.sectionBorder, + background: options.emphasized ? getSecurityUpdateHighlightBackground(overlayTheme) : overlayTheme.sectionBg, + boxShadow: options.emphasized ? getSecurityUpdateHighlightShadow(overlayTheme) : 'none', + transition: 'background 180ms ease, border-color 180ms ease, box-shadow 180ms ease', +}); diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index e18d5a7..ab7f7d3 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -15,6 +15,8 @@ export function CheckDriverNetworkStatus():Promise; export function CheckForUpdates():Promise; +export function CheckForUpdatesSilently():Promise; + export function ConfigureDriverRuntimeDirectory(arg1:string):Promise; export function ConfigureGlobalProxy(arg1:boolean,arg2:connection.ProxyConfig):Promise; diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js index 5f65811..269c6af 100755 --- a/frontend/wailsjs/go/app/App.js +++ b/frontend/wailsjs/go/app/App.js @@ -22,6 +22,10 @@ export function CheckForUpdates() { return window['go']['app']['App']['CheckForUpdates'](); } +export function CheckForUpdatesSilently() { + return window['go']['app']['App']['CheckForUpdatesSilently'](); +} + export function ConfigureDriverRuntimeDirectory(arg1) { return window['go']['app']['App']['ConfigureDriverRuntimeDirectory'](arg1); } diff --git a/internal/app/connection_package_crypto.go b/internal/app/connection_package_crypto.go index e844144..8337e2e 100644 --- a/internal/app/connection_package_crypto.go +++ b/internal/app/connection_package_crypto.go @@ -69,7 +69,13 @@ func encryptConnectionPackage(payload connectionPackagePayload, password string) } ciphertext := aead.Seal(nil, nonce, plain, aad) + if len(ciphertext) > connectionPackageMaxCiphertextBytes { + return connectionPackageFile{}, errConnectionPackagePayloadTooLarge + } file.Payload = base64.StdEncoding.EncodeToString(ciphertext) + if len(file.Payload) > connectionPackageMaxPayloadBase64Bytes { + return connectionPackageFile{}, errConnectionPackagePayloadTooLarge + } return file, nil } @@ -84,6 +90,9 @@ func decryptConnectionPackage(file connectionPackageFile, password string) (conn plain, err := decryptConnectionPackagePlaintext(file, normalizedPassword) if err != nil { + if errors.Is(err, errConnectionPackagePayloadTooLarge) { + return connectionPackagePayload{}, err + } return connectionPackagePayload{}, errConnectionPackageDecryptFailed } @@ -127,10 +136,16 @@ func decryptConnectionPackagePlaintext(file connectionPackageFile, password stri if err != nil || len(nonce) != connectionPackageNonceBytes { return nil, errors.New("invalid nonce") } + if len(file.Payload) > connectionPackageMaxPayloadBase64Bytes { + return nil, errConnectionPackagePayloadTooLarge + } ciphertext, err := base64.StdEncoding.DecodeString(file.Payload) if err != nil || len(ciphertext) == 0 { return nil, errors.New("invalid payload") } + if len(ciphertext) > connectionPackageMaxCiphertextBytes { + return nil, errConnectionPackagePayloadTooLarge + } key, err := deriveConnectionPackageKey(password, file.KDF) if err != nil { diff --git a/internal/app/connection_package_crypto_test.go b/internal/app/connection_package_crypto_test.go index b1368e0..22ba2f1 100644 --- a/internal/app/connection_package_crypto_test.go +++ b/internal/app/connection_package_crypto_test.go @@ -1,9 +1,11 @@ package app import ( + "encoding/base64" "encoding/json" "errors" "reflect" + "strings" "testing" "GoNavi-Wails/internal/connection" @@ -222,3 +224,74 @@ func TestValidateConnectionPackageKDFSpecRejectsOversizedParams(t *testing.T) { } }) } + +func TestDecryptConnectionPackagePlaintextRejectsOversizedPayload(t *testing.T) { + nonce := base64.StdEncoding.EncodeToString(make([]byte, connectionPackageNonceBytes)) + salt := base64.StdEncoding.EncodeToString(make([]byte, connectionPackageSaltBytes)) + payload := base64.StdEncoding.EncodeToString(make([]byte, connectionPackageMaxCiphertextBytes+1)) + + file := connectionPackageFile{ + SchemaVersion: connectionPackageSchemaVersion, + Kind: connectionPackageKind, + Cipher: connectionPackageCipher, + KDF: connectionPackageKDFSpec{ + Name: connectionPackageKDFName, + MemoryKiB: connectionPackageKDFDefaultMemoryKiB, + TimeCost: connectionPackageKDFDefaultTimeCost, + Parallelism: connectionPackageKDFDefaultParallelism, + Salt: salt, + }, + Nonce: nonce, + Payload: payload, + } + + _, err := decryptConnectionPackagePlaintext(file, "correct-password") + if !errors.Is(err, errConnectionPackagePayloadTooLarge) { + t.Fatalf("oversized payload should return errConnectionPackagePayloadTooLarge, got: %v", err) + } +} + +func TestDecryptConnectionPackagePlaintextRejectsOversizedBase64PayloadBeforeDecode(t *testing.T) { + nonce := base64.StdEncoding.EncodeToString(make([]byte, connectionPackageNonceBytes)) + + file := connectionPackageFile{ + SchemaVersion: connectionPackageSchemaVersion, + Kind: connectionPackageKind, + Cipher: connectionPackageCipher, + KDF: connectionPackageKDFSpec{ + Name: connectionPackageKDFName, + MemoryKiB: connectionPackageKDFDefaultMemoryKiB, + TimeCost: connectionPackageKDFDefaultTimeCost, + Parallelism: connectionPackageKDFDefaultParallelism, + Salt: base64.StdEncoding.EncodeToString(make([]byte, connectionPackageSaltBytes)), + }, + Nonce: nonce, + Payload: strings.Repeat("A", connectionPackageMaxPayloadBase64Bytes+4), + } + + _, err := decryptConnectionPackagePlaintext(file, "correct-password") + if !errors.Is(err, errConnectionPackagePayloadTooLarge) { + t.Fatalf("oversized base64 payload should return errConnectionPackagePayloadTooLarge, got: %v", err) + } +} + +func TestEncryptConnectionPackageRejectsOversizedPayload(t *testing.T) { + _, err := encryptConnectionPackage(connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-large", + Name: strings.Repeat("x", connectionPackageMaxCiphertextBytes), + Config: connection.ConnectionConfig{ + ID: "conn-large", + Type: "postgres", + Host: "db.large.local", + Port: 5432, + User: "postgres", + }, + }, + }, + }, "correct-password") + if !errors.Is(err, errConnectionPackagePayloadTooLarge) { + t.Fatalf("oversized export payload should return errConnectionPackagePayloadTooLarge, got: %v", err) + } +} diff --git a/internal/app/connection_package_transfer.go b/internal/app/connection_package_transfer.go index 4cd47e3..3fc8e31 100644 --- a/internal/app/connection_package_transfer.go +++ b/internal/app/connection_package_transfer.go @@ -9,6 +9,8 @@ import ( "GoNavi-Wails/internal/connection" "GoNavi-Wails/internal/secretstore" + + "github.com/google/uuid" ) func newConnectionPackageItem(view connection.SavedConnectionView, bundle connectionSecretBundle) connectionPackageItem { @@ -86,25 +88,99 @@ func newSavedConnectionInputFromPackageItem(item connectionPackageItem) connecti } } -func (a *App) importConnectionPackagePayload(payload connectionPackagePayload) ([]connection.SavedConnectionView, error) { +func dedupeImportedSavedConnectionViews(views []connection.SavedConnectionView) []connection.SavedConnectionView { + if len(views) < 2 { + return views + } + + lastIndexByID := make(map[string]int, len(views)) + for index, view := range views { + id := strings.TrimSpace(view.ID) + if id == "" { + continue + } + lastIndexByID[id] = index + } + + result := make([]connection.SavedConnectionView, 0, len(views)) + for index, view := range views { + id := strings.TrimSpace(view.ID) + if id != "" && lastIndexByID[id] != index { + continue + } + result = append(result, view) + } + return result +} + +func dedupeImportedSavedConnectionInputs(inputs []connection.SavedConnectionInput) []connection.SavedConnectionInput { + if len(inputs) < 2 { + return inputs + } + + lastIndexByID := make(map[string]int, len(inputs)) + for index, input := range inputs { + id := strings.TrimSpace(input.ID) + if id == "" { + continue + } + lastIndexByID[id] = index + } + + result := make([]connection.SavedConnectionInput, 0, len(inputs)) + for index, input := range inputs { + id := strings.TrimSpace(input.ID) + if id != "" && lastIndexByID[id] != index { + continue + } + result = append(result, input) + } + return result +} + +func normalizeImportedSavedConnectionInput(input connection.SavedConnectionInput) connection.SavedConnectionInput { + if strings.TrimSpace(input.ID) == "" && strings.TrimSpace(input.Config.ID) == "" { + input.ID = "conn-" + uuid.New().String()[:8] + } + if strings.TrimSpace(input.ID) == "" { + input.ID = strings.TrimSpace(input.Config.ID) + } + input.Config.ID = input.ID + return input +} + +func (a *App) importSavedConnectionsAtomically(inputs []connection.SavedConnectionInput) ([]connection.SavedConnectionView, error) { repo := a.savedConnectionRepository() - rollbackSnapshot, err := captureConnectionPackageImportRollbackSnapshot(a, payload) + normalizedInputs := make([]connection.SavedConnectionInput, 0, len(inputs)) + for _, input := range inputs { + normalizedInputs = append(normalizedInputs, normalizeImportedSavedConnectionInput(input)) + } + finalInputs := dedupeImportedSavedConnectionInputs(normalizedInputs) + rollbackSnapshot, err := captureConnectionImportRollbackSnapshot(a, finalInputs) if err != nil { return nil, err } - result := make([]connection.SavedConnectionView, 0, len(payload.Connections)) - for _, item := range payload.Connections { - view, err := repo.Save(newSavedConnectionInputFromPackageItem(item)) + result := make([]connection.SavedConnectionView, 0, len(finalInputs)) + for _, input := range finalInputs { + view, err := repo.Save(input) if err != nil { if rollbackErr := rollbackSnapshot.restore(a); rollbackErr != nil { - return nil, errors.Join(err, fmt.Errorf("restore connection package rollback: %w", rollbackErr)) + return nil, errors.Join(err, fmt.Errorf("restore connection import rollback: %w", rollbackErr)) } return nil, err } result = append(result, view) } - return result, nil + return dedupeImportedSavedConnectionViews(result), nil +} + +func (a *App) importConnectionPackagePayload(payload connectionPackagePayload) ([]connection.SavedConnectionView, error) { + inputs := make([]connection.SavedConnectionInput, 0, len(payload.Connections)) + for _, item := range payload.Connections { + inputs = append(inputs, newSavedConnectionInputFromPackageItem(item)) + } + return a.importSavedConnectionsAtomically(inputs) } func (a *App) ImportConnectionsPayload(raw string, password string) ([]connection.SavedConnectionView, error) { @@ -112,6 +188,9 @@ func (a *App) ImportConnectionsPayload(raw string, password string) ([]connectio if trimmed == "" { return nil, errConnectionPackageUnsupported } + if len(trimmed) > connectionImportMaxFileBytes { + return nil, errConnectionImportFileTooLarge + } if isConnectionPackageEnvelope(trimmed) { var file connectionPackageFile @@ -139,7 +218,7 @@ type connectionPackageImportRollbackSnapshot struct { connectionCleanupRefs []string } -func captureConnectionPackageImportRollbackSnapshot(a *App, payload connectionPackagePayload) (connectionPackageImportRollbackSnapshot, error) { +func captureConnectionImportRollbackSnapshot(a *App, inputs []connection.SavedConnectionInput) (connectionPackageImportRollbackSnapshot, error) { snapshot := connectionPackageImportRollbackSnapshot{ connectionSecrets: make(map[string]securityUpdateSecretSnapshot), } @@ -163,9 +242,11 @@ func captureConnectionPackageImportRollbackSnapshot(a *App, payload connectionPa cleanupSet := make(map[string]struct{}) seenIDs := make(map[string]struct{}) - for _, item := range payload.Connections { - input := newSavedConnectionInputFromPackageItem(item) + for _, input := range inputs { connectionID := strings.TrimSpace(input.ID) + if connectionID == "" { + connectionID = strings.TrimSpace(input.Config.ID) + } if connectionID == "" { continue } diff --git a/internal/app/connection_package_transfer_test.go b/internal/app/connection_package_transfer_test.go index a988266..81d40ea 100644 --- a/internal/app/connection_package_transfer_test.go +++ b/internal/app/connection_package_transfer_test.go @@ -173,7 +173,7 @@ func TestImportConnectionPackagePayloadLatestEntryWinsForSameID(t *testing.T) { app := NewAppWithSecretStore(newFakeAppSecretStore()) app.configDir = t.TempDir() - _, err := app.importConnectionPackagePayload(connectionPackagePayload{ + imported, err := app.importConnectionPackagePayload(connectionPackagePayload{ Connections: []connectionPackageItem{ { ID: "conn-dup", @@ -204,6 +204,12 @@ func TestImportConnectionPackagePayloadLatestEntryWinsForSameID(t *testing.T) { if err != nil { t.Fatalf("importConnectionPackagePayload returned error: %v", err) } + if len(imported) != 1 { + t.Fatalf("expected duplicate ids to return 1 final imported item, got %d", len(imported)) + } + if imported[0].Name != "Second" { + t.Fatalf("expected returned import result to keep latest entry, got %q", imported[0].Name) + } saved, err := app.GetSavedConnections() if err != nil { @@ -225,6 +231,153 @@ func TestImportConnectionPackagePayloadLatestEntryWinsForSameID(t *testing.T) { } } +func TestImportConnectionsPayloadLegacyJSONRollsBackOnSaveFailure(t *testing.T) { + failRef, err := secretstore.BuildRef(savedConnectionSecretKind, "legacy-2") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + + store := newFailOnPutSecretStore(failRef) + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + _, err = app.SaveConnection(connection.SavedConnectionInput{ + ID: "legacy-1", + Name: "Existing Legacy", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db.old.local", + Port: 5432, + User: "postgres", + Password: "old-primary", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + raw, err := json.Marshal([]connection.LegacySavedConnection{ + { + ID: "legacy-1", + Name: "Imported Existing Legacy", + Config: connection.ConnectionConfig{ + ID: "legacy-1", + Type: "postgres", + Host: "db.new.local", + Port: 5432, + User: "postgres", + }, + }, + { + ID: "legacy-2", + Name: "Imported New Legacy", + Config: connection.ConnectionConfig{ + ID: "legacy-2", + Type: "mysql", + Host: "db.second.local", + Port: 3306, + User: "root", + Password: "second-primary", + }, + }, + }) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + + imported, err := app.ImportConnectionsPayload(string(raw), "ignored") + if err == nil { + t.Fatal("expected ImportConnectionsPayload to return error") + } + if imported != nil { + t.Fatalf("expected no imported results after rollback, got %#v", imported) + } + + saved, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(saved) != 1 { + t.Fatalf("expected rollback to restore exactly 1 legacy connection, got %d", len(saved)) + } + if saved[0].ID != "legacy-1" || saved[0].Name != "Existing Legacy" { + t.Fatalf("expected rollback to restore original legacy metadata, got %#v", saved[0]) + } + if saved[0].Config.Host != "db.old.local" { + t.Fatalf("expected rollback to restore original legacy host, got %q", saved[0].Config.Host) + } + + resolved, err := app.resolveConnectionSecrets(saved[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "old-primary" { + t.Fatalf("expected rollback to restore original legacy password, got %q", resolved.Password) + } + + if _, err := store.Get(failRef); !os.IsNotExist(err) { + t.Fatalf("expected rollback to remove partially imported legacy secret ref, got err=%v", err) + } +} + +func TestImportLegacyConnectionsRollbackRemovesGeneratedSecretRefs(t *testing.T) { + failRef, err := secretstore.BuildRef(savedConnectionSecretKind, "legacy-2") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + + store := newFailOnPutSecretStore(failRef) + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + imported, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{ + { + Name: "Generated ID Legacy", + Config: connection.ConnectionConfig{ + Type: "postgres", + Host: "db.generated.local", + Port: 5432, + User: "postgres", + Password: "generated-secret", + }, + }, + { + ID: "legacy-2", + Name: "Will Fail", + Config: connection.ConnectionConfig{ + ID: "legacy-2", + Type: "mysql", + Host: "db.fail.local", + Port: 3306, + User: "root", + Password: "fail-secret", + }, + }, + }) + if err == nil { + t.Fatal("expected ImportLegacyConnections to return error") + } + if imported != nil { + t.Fatalf("expected no imported results after rollback, got %#v", imported) + } + + saved, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(saved) != 0 { + t.Fatalf("expected rollback to remove generated-id connection, got %d saved connections", len(saved)) + } + + if got := len(store.base.items); got != 0 { + t.Fatalf("expected rollback to remove generated secret refs, got %d remaining items", got) + } + if _, err := store.Get(failRef); !os.IsNotExist(err) { + t.Fatalf("expected rollback to remove failed explicit secret ref, got err=%v", err) + } +} + func TestImportConnectionPackagePayloadRollsBackOnSaveFailure(t *testing.T) { failRef, err := secretstore.BuildRef(savedConnectionSecretKind, "conn-2") if err != nil { @@ -313,7 +466,7 @@ func TestImportConnectionPackagePayloadRollsBackOnSaveFailure(t *testing.T) { } } -func TestImportConnectionsPayloadLegacyJSONKeepsExistingSecretWhenMissing(t *testing.T) { +func TestImportConnectionsPayloadLegacyJSONClearsExistingSecretWhenMissing(t *testing.T) { app := NewAppWithSecretStore(newFakeAppSecretStore()) app.configDir = t.TempDir() @@ -365,8 +518,162 @@ func TestImportConnectionsPayloadLegacyJSONKeepsExistingSecretWhenMissing(t *tes if err != nil { t.Fatalf("resolveConnectionSecrets returned error: %v", err) } - if resolved.Password != "legacy-secret" { - t.Fatalf("expected legacy import to preserve existing secret, got %q", resolved.Password) + if resolved.Password != "" { + t.Fatalf("expected legacy import to clear existing secret when the imported file omits it, got %q", resolved.Password) + } +} + +func TestImportConnectionsPayloadLegacyJSONLatestEntryWinsForSameID(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + raw, err := json.Marshal([]connection.LegacySavedConnection{ + { + ID: "legacy-dup", + Name: "First", + Config: connection.ConnectionConfig{ + ID: "legacy-dup", + Type: "postgres", + Host: "db.first.local", + Port: 5432, + User: "postgres", + Password: "first-secret", + }, + }, + { + ID: "legacy-dup", + Name: "Second", + Config: connection.ConnectionConfig{ + ID: "legacy-dup", + Type: "postgres", + Host: "db.second.local", + Port: 5432, + User: "postgres", + Password: "second-secret", + }, + }, + }) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + + imported, err := app.ImportConnectionsPayload(string(raw), "ignored") + if err != nil { + t.Fatalf("ImportConnectionsPayload returned error: %v", err) + } + if len(imported) != 1 { + t.Fatalf("expected duplicate legacy ids to return 1 final imported item, got %d", len(imported)) + } + if imported[0].Name != "Second" { + t.Fatalf("expected returned import result to keep latest legacy entry, got %q", imported[0].Name) + } + + saved, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(saved) != 1 { + t.Fatalf("expected 1 saved legacy item after duplicate id overwrite, got %d", len(saved)) + } + if saved[0].Name != "Second" { + t.Fatalf("expected latest legacy item to win, got %q", saved[0].Name) + } + + resolved, err := app.resolveConnectionSecrets(saved[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "second-secret" { + t.Fatalf("expected latest legacy secret to win, got %q", resolved.Password) + } +} + +func TestImportConnectionsPayloadLegacyJSONLatestEntryWithoutPasswordDoesNotKeepEarlierDuplicateSecret(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + raw, err := json.Marshal([]connection.LegacySavedConnection{ + { + ID: "legacy-dup", + Name: "First", + Config: connection.ConnectionConfig{ + ID: "legacy-dup", + Type: "postgres", + Host: "db.first.local", + Port: 5432, + User: "postgres", + Password: "first-secret", + }, + }, + { + ID: "legacy-dup", + Name: "Second", + Config: connection.ConnectionConfig{ + ID: "legacy-dup", + Type: "postgres", + Host: "db.second.local", + Port: 5432, + User: "postgres", + }, + }, + }) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + + imported, err := app.ImportConnectionsPayload(string(raw), "ignored") + if err != nil { + t.Fatalf("ImportConnectionsPayload returned error: %v", err) + } + if len(imported) != 1 { + t.Fatalf("expected duplicate legacy ids to return 1 final imported item, got %d", len(imported)) + } + + saved, err := app.GetSavedConnections() + if err != nil { + t.Fatalf("GetSavedConnections returned error: %v", err) + } + if len(saved) != 1 { + t.Fatalf("expected 1 saved legacy item after duplicate id overwrite, got %d", len(saved)) + } + if saved[0].HasPrimaryPassword { + t.Fatalf("expected latest legacy item without password to clear earlier duplicate secret, got view=%#v", saved[0]) + } + + resolved, err := app.resolveConnectionSecrets(saved[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "" { + t.Fatalf("expected latest legacy item without password to keep empty secret, got %q", resolved.Password) + } +} + +func TestImportConnectionsPayloadEnvelopeRejectsOversizedPayloadWithDedicatedError(t *testing.T) { + raw, err := json.Marshal(connectionPackageFile{ + SchemaVersion: connectionPackageSchemaVersion, + Kind: connectionPackageKind, + Cipher: connectionPackageCipher, + KDF: connectionPackageKDFSpec{ + Name: connectionPackageKDFName, + MemoryKiB: connectionPackageKDFDefaultMemoryKiB, + TimeCost: connectionPackageKDFDefaultTimeCost, + Parallelism: connectionPackageKDFDefaultParallelism, + Salt: "AAAAAAAAAAAAAAAAAAAAAA==", + }, + Nonce: "AAAAAAAAAAAAAAAA", + Payload: strings.Repeat("A", connectionPackageMaxPayloadBase64Bytes+4), + }) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err = app.ImportConnectionsPayload(string(raw), "package-password") + if !errors.Is(err, errConnectionPackagePayloadTooLarge) { + t.Fatalf("expected errConnectionPackagePayloadTooLarge, got %v", err) } } diff --git a/internal/app/connection_package_types.go b/internal/app/connection_package_types.go index df959c9..18b28b6 100644 --- a/internal/app/connection_package_types.go +++ b/internal/app/connection_package_types.go @@ -21,12 +21,18 @@ const ( connectionPackageKDFMaxMemoryKiB = 262144 connectionPackageKDFMaxTimeCost = 10 connectionPackageKDFMaxParallelism = 16 + + connectionPackageMaxCiphertextBytes = 16 * 1024 * 1024 + connectionPackageMaxPayloadBase64Bytes = ((connectionPackageMaxCiphertextBytes + 2) / 3) * 4 + connectionImportMaxFileBytes = connectionPackageMaxPayloadBase64Bytes + (1 * 1024 * 1024) ) var ( errConnectionPackagePasswordRequired = errors.New("恢复包密码不能为空") errConnectionPackageDecryptFailed = errors.New("文件密码错误或文件已损坏") errConnectionPackageUnsupported = errors.New("不支持的连接恢复包格式") + errConnectionImportFileTooLarge = errors.New("连接导入文件过大") + errConnectionPackagePayloadTooLarge = errors.New("连接恢复包过大") errConnectionPackageNotImplemented = errors.New("connection package not implemented") ) diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 64e7aef..924730a 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -259,6 +259,22 @@ func (cr *countingReader) Read(p []byte) (int, error) { return n, err } +func readImportedConnectionConfigFile(path string) (string, error) { + info, err := os.Stat(path) + if err != nil { + return "", err + } + if info.Size() > connectionImportMaxFileBytes { + return "", errConnectionImportFileTooLarge + } + + content, err := os.ReadFile(path) + if err != nil { + return "", err + } + return string(content), nil +} + func (a *App) ImportConfigFile() connection.QueryResult { selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{ Title: "Select Config File", @@ -282,12 +298,12 @@ func (a *App) ImportConfigFile() connection.QueryResult { return connection.QueryResult{Success: false, Message: "已取消"} } - content, err := os.ReadFile(selection) + content, err := readImportedConnectionConfigFile(selection) if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} } - return connection.QueryResult{Success: true, Data: string(content)} + return connection.QueryResult{Success: true, Data: content} } func (a *App) ExportConnectionsPackage(password string) connection.QueryResult { @@ -320,6 +336,9 @@ func (a *App) ExportConnectionsPackage(password string) connection.QueryResult { if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} } + if len(content) > connectionImportMaxFileBytes { + return connection.QueryResult{Success: false, Message: errConnectionImportFileTooLarge.Error()} + } if err := os.WriteFile(filename, content, 0o644); err != nil { return connection.QueryResult{Success: false, Message: err.Error()} } diff --git a/internal/app/methods_file_import_test.go b/internal/app/methods_file_import_test.go new file mode 100644 index 0000000..d2b13bc --- /dev/null +++ b/internal/app/methods_file_import_test.go @@ -0,0 +1,33 @@ +package app + +import ( + "errors" + "os" + "path/filepath" + "testing" +) + +func TestReadImportedConnectionConfigFileRejectsOversizedFiles(t *testing.T) { + for _, ext := range []string{connectionPackageExtension, ".json"} { + t.Run(ext, func(t *testing.T) { + path := filepath.Join(t.TempDir(), "connections"+ext) + + file, err := os.Create(path) + if err != nil { + t.Fatalf("Create returned error: %v", err) + } + if err := file.Truncate(connectionImportMaxFileBytes + 1); err != nil { + file.Close() + t.Fatalf("Truncate returned error: %v", err) + } + if err := file.Close(); err != nil { + t.Fatalf("Close returned error: %v", err) + } + + _, err = readImportedConnectionConfigFile(path) + if !errors.Is(err, errConnectionImportFileTooLarge) { + t.Fatalf("oversized import file should return errConnectionImportFileTooLarge, got: %v", err) + } + }) + } +} diff --git a/internal/app/methods_redis.go b/internal/app/methods_redis.go index 8b4a0b0..08f6c22 100644 --- a/internal/app/methods_redis.go +++ b/internal/app/methods_redis.go @@ -19,12 +19,20 @@ import ( var ( redisCache = make(map[string]redis.RedisClient) redisCacheMu sync.Mutex + newRedisClientFunc = redis.NewRedisClient ) // getRedisClient gets or creates a Redis client from cache func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisClient, error) { - effectiveConfig := applyGlobalProxyToConnection(config) - connectConfig, proxyErr := resolveDialConfigWithProxy(effectiveConfig) + resolvedConfig, err := a.resolveConnectionSecrets(config) + if err != nil { + wrapped := wrapConnectError(config, err) + logger.Error(wrapped, "Redis 密文解析失败:%s", formatRedisConnSummary(config)) + return nil, wrapped + } + + effectiveConfig := applyGlobalProxyToConnection(resolvedConfig) + connectConfig, proxyErr := resolveDialConfigWithProxyFunc(effectiveConfig) if proxyErr != nil { wrapped := wrapConnectError(effectiveConfig, proxyErr) logger.Error(wrapped, "Redis 代理准备失败:%s", formatRedisConnSummary(effectiveConfig)) @@ -54,7 +62,7 @@ func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisCli } logger.Infof("创建 Redis 客户端实例:缓存Key=%s", shortKey) - client := redis.NewRedisClient() + client := newRedisClientFunc() if err := client.Connect(connectConfig); err != nil { wrapped := wrapConnectError(effectiveConfig, err) logger.Error(wrapped, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey) diff --git a/internal/app/methods_redis_test.go b/internal/app/methods_redis_test.go new file mode 100644 index 0000000..f713cad --- /dev/null +++ b/internal/app/methods_redis_test.go @@ -0,0 +1,258 @@ +package app + +import ( + "testing" + + "GoNavi-Wails/internal/connection" + redislib "GoNavi-Wails/internal/redis" +) + +type capturingRedisClient struct { + connectConfig connection.ConnectionConfig +} + +func (c *capturingRedisClient) Connect(config connection.ConnectionConfig) error { + c.connectConfig = config + return nil +} + +func (c *capturingRedisClient) Close() error { return nil } + +func (c *capturingRedisClient) Ping() error { return nil } + +func (c *capturingRedisClient) ScanKeys(pattern string, cursor uint64, count int64) (*redislib.RedisScanResult, error) { + return &redislib.RedisScanResult{}, nil +} + +func (c *capturingRedisClient) GetKeyType(key string) (string, error) { return "", nil } + +func (c *capturingRedisClient) GetTTL(key string) (int64, error) { return 0, nil } + +func (c *capturingRedisClient) SetTTL(key string, ttl int64) error { return nil } + +func (c *capturingRedisClient) DeleteKeys(keys []string) (int64, error) { return 0, nil } + +func (c *capturingRedisClient) RenameKey(oldKey, newKey string) error { return nil } + +func (c *capturingRedisClient) KeyExists(key string) (bool, error) { return false, nil } + +func (c *capturingRedisClient) GetValue(key string) (*redislib.RedisValue, error) { + return &redislib.RedisValue{}, nil +} + +func (c *capturingRedisClient) GetString(key string) (string, error) { return "", nil } + +func (c *capturingRedisClient) SetString(key, value string, ttl int64) error { return nil } + +func (c *capturingRedisClient) GetHash(key string) (map[string]string, error) { return map[string]string{}, nil } + +func (c *capturingRedisClient) SetHashField(key, field, value string) error { return nil } + +func (c *capturingRedisClient) DeleteHashField(key string, fields ...string) error { return nil } + +func (c *capturingRedisClient) GetList(key string, start, stop int64) ([]string, error) { return nil, nil } + +func (c *capturingRedisClient) ListPush(key string, values ...string) error { return nil } + +func (c *capturingRedisClient) ListSet(key string, index int64, value string) error { return nil } + +func (c *capturingRedisClient) GetSet(key string) ([]string, error) { return nil, nil } + +func (c *capturingRedisClient) SetAdd(key string, members ...string) error { return nil } + +func (c *capturingRedisClient) SetRemove(key string, members ...string) error { return nil } + +func (c *capturingRedisClient) GetZSet(key string, start, stop int64) ([]redislib.ZSetMember, error) { + return nil, nil +} + +func (c *capturingRedisClient) ZSetAdd(key string, members ...redislib.ZSetMember) error { return nil } + +func (c *capturingRedisClient) ZSetRemove(key string, members ...string) error { return nil } + +func (c *capturingRedisClient) GetStream(key, start, stop string, count int64) ([]redislib.StreamEntry, error) { + return nil, nil +} + +func (c *capturingRedisClient) StreamAdd(key string, fields map[string]string, id string) (string, error) { + return "", nil +} + +func (c *capturingRedisClient) StreamDelete(key string, ids ...string) (int64, error) { return 0, nil } + +func (c *capturingRedisClient) ExecuteCommand(args []string) (interface{}, error) { return nil, nil } + +func (c *capturingRedisClient) GetServerInfo() (map[string]string, error) { return map[string]string{}, nil } + +func (c *capturingRedisClient) GetDatabases() ([]redislib.RedisDBInfo, error) { return nil, nil } + +func (c *capturingRedisClient) SelectDB(index int) error { return nil } + +func (c *capturingRedisClient) GetCurrentDB() int { return 0 } + +func (c *capturingRedisClient) FlushDB() error { return nil } + +func TestRedisConnectResolvesSavedSecretsByConnectionID(t *testing.T) { + testCases := []struct { + name string + savedConfig connection.ConnectionConfig + runtimeConfig connection.ConnectionConfig + assertResolved func(t *testing.T, got connection.ConnectionConfig) + }{ + { + name: "redis and ssh secrets", + savedConfig: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + Password: "redis-secret", + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "ssh.local", + Port: 22, + User: "ops", + Password: "ssh-secret", + }, + }, + runtimeConfig: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "ssh.local", + Port: 22, + User: "ops", + }, + }, + assertResolved: func(t *testing.T, got connection.ConnectionConfig) { + t.Helper() + if got.Password != "redis-secret" { + t.Fatalf("expected RedisConnect to resolve saved Redis password, got %q", got.Password) + } + if got.SSH.Password != "ssh-secret" { + t.Fatalf("expected RedisConnect to resolve saved SSH password, got %q", got.SSH.Password) + } + }, + }, + { + name: "proxy secret", + savedConfig: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + Password: "redis-secret", + UseProxy: true, + Proxy: connection.ProxyConfig{ + Type: "http", + Host: "proxy.local", + Port: 8080, + User: "proxy-user", + Password: "proxy-secret", + }, + }, + runtimeConfig: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + UseProxy: true, + Proxy: connection.ProxyConfig{ + Type: "http", + Host: "proxy.local", + Port: 8080, + User: "proxy-user", + }, + }, + assertResolved: func(t *testing.T, got connection.ConnectionConfig) { + t.Helper() + if got.Password != "redis-secret" { + t.Fatalf("expected RedisConnect to resolve saved Redis password, got %q", got.Password) + } + if got.Proxy.Password != "proxy-secret" { + t.Fatalf("expected RedisConnect to resolve saved proxy password, got %q", got.Proxy.Password) + } + }, + }, + { + name: "http tunnel secret", + savedConfig: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + Password: "redis-secret", + UseHTTPTunnel: true, + HTTPTunnel: connection.HTTPTunnelConfig{ + Host: "tunnel.local", + Port: 8443, + User: "tunnel-user", + Password: "tunnel-secret", + }, + }, + runtimeConfig: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + UseHTTPTunnel: true, + HTTPTunnel: connection.HTTPTunnelConfig{ + Host: "tunnel.local", + Port: 8443, + User: "tunnel-user", + }, + }, + assertResolved: func(t *testing.T, got connection.ConnectionConfig) { + t.Helper() + if got.Password != "redis-secret" { + t.Fatalf("expected RedisConnect to resolve saved Redis password, got %q", got.Password) + } + if got.HTTPTunnel.Password != "tunnel-secret" { + t.Fatalf("expected RedisConnect to resolve saved HTTP tunnel password, got %q", got.HTTPTunnel.Password) + } + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "redis-1", + Name: "Redis Saved", + Config: testCase.savedConfig, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + CloseAllRedisClients() + client := &capturingRedisClient{} + originalNewRedisClientFunc := newRedisClientFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + defer func() { + newRedisClientFunc = originalNewRedisClientFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + CloseAllRedisClients() + }() + newRedisClientFunc = func() redislib.RedisClient { + return client + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + result := app.RedisConnect(testCase.runtimeConfig) + if !result.Success { + t.Fatalf("RedisConnect returned failure: %+v", result) + } + + testCase.assertResolved(t, client.connectConfig) + }) + } +} diff --git a/internal/app/methods_saved_connections.go b/internal/app/methods_saved_connections.go index d8d916d..fc2351e 100644 --- a/internal/app/methods_saved_connections.go +++ b/internal/app/methods_saved_connections.go @@ -1,6 +1,10 @@ package app -import "GoNavi-Wails/internal/connection" +import ( + "strings" + + "GoNavi-Wails/internal/connection" +) func (a *App) savedConnectionRepository() *savedConnectionRepository { return newSavedConnectionRepository(a.configDir, a.secretStore) @@ -23,16 +27,20 @@ func (a *App) DuplicateConnection(id string) (connection.SavedConnectionView, er } func (a *App) ImportLegacyConnections(items []connection.LegacySavedConnection) ([]connection.SavedConnectionView, error) { - result := make([]connection.SavedConnectionView, 0, len(items)) - repo := a.savedConnectionRepository() + inputs := make([]connection.SavedConnectionInput, 0, len(items)) for _, item := range items { - view, err := repo.Save(connection.SavedConnectionInput(item)) - if err != nil { - return nil, err - } - result = append(result, view) + input := connection.SavedConnectionInput(item) + input.ClearPrimaryPassword = strings.TrimSpace(item.Config.Password) == "" + input.ClearSSHPassword = strings.TrimSpace(item.Config.SSH.Password) == "" + input.ClearProxyPassword = strings.TrimSpace(item.Config.Proxy.Password) == "" + input.ClearHTTPTunnelPassword = strings.TrimSpace(item.Config.HTTPTunnel.Password) == "" + input.ClearMySQLReplicaPassword = strings.TrimSpace(item.Config.MySQLReplicaPassword) == "" + input.ClearMongoReplicaPassword = strings.TrimSpace(item.Config.MongoReplicaPassword) == "" + input.ClearOpaqueURI = strings.TrimSpace(item.Config.URI) == "" + input.ClearOpaqueDSN = strings.TrimSpace(item.Config.DSN) == "" + inputs = append(inputs, input) } - return result, nil + return a.importSavedConnectionsAtomically(inputs) } func (a *App) SaveGlobalProxy(input connection.SaveGlobalProxyInput) (connection.GlobalProxyView, error) { diff --git a/internal/app/methods_saved_connections_test.go b/internal/app/methods_saved_connections_test.go index a813c53..dc8b63c 100644 --- a/internal/app/methods_saved_connections_test.go +++ b/internal/app/methods_saved_connections_test.go @@ -219,7 +219,7 @@ func TestImportLegacyConnectionsIsIdempotentForSameID(t *testing.T) { } } -func TestImportLegacyConnectionsKeepsExistingSecretWhenReimportOmitsPassword(t *testing.T) { +func TestImportLegacyConnectionsClearsExistingSecretWhenReimportOmitsPassword(t *testing.T) { app := NewAppWithSecretStore(newFakeAppSecretStore()) app.configDir = t.TempDir() @@ -267,7 +267,7 @@ func TestImportLegacyConnectionsKeepsExistingSecretWhenReimportOmitsPassword(t * if err != nil { t.Fatalf("resolveConnectionSecrets returned error: %v", err) } - if resolved.Password != "secret-1" { - t.Fatalf("expected original password to be preserved, got %q", resolved.Password) + if resolved.Password != "" { + t.Fatalf("expected missing import password to clear existing secret, got %q", resolved.Password) } } diff --git a/internal/app/methods_update.go b/internal/app/methods_update.go index 240f446..43f2d90 100644 --- a/internal/app/methods_update.go +++ b/internal/app/methods_update.go @@ -30,6 +30,12 @@ const ( updateDownloadProgressEvent = "update:download-progress" ) +var ( + updateFetchLatestRelease = fetchLatestRelease + updateFetchReleaseSHA256 = fetchReleaseSHA256 + updateLogCheckError = func(err error) { logger.Error(err, "检查更新失败") } +) + type updateState struct { lastCheck *UpdateInfo downloading bool @@ -100,9 +106,19 @@ type githubAsset struct { } func (a *App) CheckForUpdates() connection.QueryResult { + return a.checkForUpdates(true) +} + +func (a *App) CheckForUpdatesSilently() connection.QueryResult { + return a.checkForUpdates(false) +} + +func (a *App) checkForUpdates(logFailure bool) connection.QueryResult { info, err := fetchLatestUpdateInfo() if err != nil { - logger.Error(err, "检查更新失败") + if logFailure { + updateLogCheckError(err) + } return connection.QueryResult{Success: false, Message: err.Error()} } @@ -359,7 +375,7 @@ func (a *App) downloadAndStageUpdate(info UpdateInfo) connection.QueryResult { } func fetchLatestUpdateInfo() (UpdateInfo, error) { - release, err := fetchLatestRelease() + release, err := updateFetchLatestRelease() if err != nil { return UpdateInfo{}, err } @@ -370,6 +386,17 @@ func fetchLatestUpdateInfo() (UpdateInfo, error) { return UpdateInfo{}, errors.New("无法解析最新版本号") } + hasUpdate := compareVersion(currentVersion, latestVersion) < 0 + if !hasUpdate { + return UpdateInfo{ + HasUpdate: false, + CurrentVersion: currentVersion, + LatestVersion: latestVersion, + ReleaseName: release.Name, + ReleaseNotesURL: release.HTMLURL, + }, nil + } + assetVersion := strings.TrimSpace(release.TagName) if assetVersion == "" { assetVersion = latestVersion @@ -383,7 +410,7 @@ func fetchLatestUpdateInfo() (UpdateInfo, error) { return UpdateInfo{}, err } - hashMap, err := fetchReleaseSHA256(release.Assets) + hashMap, err := updateFetchReleaseSHA256(release.Assets) if err != nil { return UpdateInfo{}, err } @@ -391,9 +418,6 @@ func fetchLatestUpdateInfo() (UpdateInfo, error) { if sha256Value == "" { return UpdateInfo{}, errors.New("SHA256SUMS 未包含当前平台更新包") } - - hasUpdate := compareVersion(currentVersion, latestVersion) < 0 - return UpdateInfo{ HasUpdate: hasUpdate, CurrentVersion: currentVersion, @@ -407,6 +431,30 @@ func fetchLatestUpdateInfo() (UpdateInfo, error) { }, nil } +func swapUpdateFetchLatestRelease(next func() (*githubRelease, error)) func() { + original := updateFetchLatestRelease + updateFetchLatestRelease = next + return func() { + updateFetchLatestRelease = original + } +} + +func swapUpdateFetchReleaseSHA256(next func([]githubAsset) (map[string]string, error)) func() { + original := updateFetchReleaseSHA256 + updateFetchReleaseSHA256 = next + return func() { + updateFetchReleaseSHA256 = original + } +} + +func swapUpdateCheckErrorLogger(next func(error)) func() { + original := updateLogCheckError + updateLogCheckError = next + return func() { + updateLogCheckError = original + } +} + func getCurrentAuthor() string { if env := strings.TrimSpace(os.Getenv("GONAVI_AUTHOR")); env != "" { return env diff --git a/internal/app/methods_update_test.go b/internal/app/methods_update_test.go new file mode 100644 index 0000000..ed1beaf --- /dev/null +++ b/internal/app/methods_update_test.go @@ -0,0 +1,160 @@ +package app + +import ( + "errors" + stdRuntime "runtime" + "testing" +) + +func TestFetchLatestUpdateInfoSkipsChecksumWhenCurrentVersionIsAlreadyLatest(t *testing.T) { + assetName, err := expectedAssetName(stdRuntime.GOOS, stdRuntime.GOARCH, "v0.6.5") + if err != nil { + t.Fatalf("expectedAssetName returned error: %v", err) + } + + originalVersion := AppVersion + AppVersion = "0.6.5" + defer func() { + AppVersion = originalVersion + }() + + releaseCalled := false + restoreRelease := swapUpdateFetchLatestRelease(func() (*githubRelease, error) { + releaseCalled = true + return &githubRelease{ + TagName: "v0.6.5", + Name: "v0.6.5", + HTMLURL: "https://github.com/Syngnat/GoNavi/releases/tag/v0.6.5", + Assets: []githubAsset{ + { + Name: assetName, + BrowserDownloadURL: "https://example.com/" + assetName, + Size: 1024, + }, + }, + }, nil + }) + defer restoreRelease() + + checksumCalled := false + restoreChecksum := swapUpdateFetchReleaseSHA256(func([]githubAsset) (map[string]string, error) { + checksumCalled = true + return nil, errors.New("checksum should not be fetched when no update is needed") + }) + defer restoreChecksum() + + info, err := fetchLatestUpdateInfo() + if err != nil { + t.Fatalf("fetchLatestUpdateInfo returned error: %v", err) + } + if !releaseCalled { + t.Fatal("expected latest release metadata to be fetched") + } + if checksumCalled { + t.Fatal("expected SHA256SUMS fetch to be skipped when current version is already latest") + } + if info.HasUpdate { + t.Fatalf("expected HasUpdate=false, got %#v", info) + } + if info.LatestVersion != "0.6.5" || info.CurrentVersion != "0.6.5" { + t.Fatalf("unexpected version info: %#v", info) + } +} + +func TestFetchLatestUpdateInfoFetchesChecksumWhenUpdateIsAvailable(t *testing.T) { + assetName, err := expectedAssetName(stdRuntime.GOOS, stdRuntime.GOARCH, "v0.6.5") + if err != nil { + t.Fatalf("expectedAssetName returned error: %v", err) + } + + originalVersion := AppVersion + AppVersion = "0.6.4" + defer func() { + AppVersion = originalVersion + }() + + restoreRelease := swapUpdateFetchLatestRelease(func() (*githubRelease, error) { + return &githubRelease{ + TagName: "v0.6.5", + Name: "v0.6.5", + HTMLURL: "https://github.com/Syngnat/GoNavi/releases/tag/v0.6.5", + Assets: []githubAsset{ + { + Name: assetName, + BrowserDownloadURL: "https://example.com/" + assetName, + Size: 4096, + }, + }, + }, nil + }) + defer restoreRelease() + + checksumCalled := false + restoreChecksum := swapUpdateFetchReleaseSHA256(func([]githubAsset) (map[string]string, error) { + checksumCalled = true + return map[string]string{ + assetName: "abc123", + }, nil + }) + defer restoreChecksum() + + info, err := fetchLatestUpdateInfo() + if err != nil { + t.Fatalf("fetchLatestUpdateInfo returned error: %v", err) + } + if !checksumCalled { + t.Fatal("expected SHA256SUMS fetch when update is available") + } + if !info.HasUpdate { + t.Fatalf("expected HasUpdate=true, got %#v", info) + } + if info.SHA256 != "abc123" || info.AssetName != assetName { + t.Fatalf("unexpected update info: %#v", info) + } +} + +func TestCheckForUpdatesLogsFailuresForManualChecks(t *testing.T) { + app := &App{} + + restoreRelease := swapUpdateFetchLatestRelease(func() (*githubRelease, error) { + return nil, errors.New("request timed out") + }) + defer restoreRelease() + + logged := 0 + restoreLogger := swapUpdateCheckErrorLogger(func(error) { + logged++ + }) + defer restoreLogger() + + result := app.CheckForUpdates() + if result.Success { + t.Fatalf("expected failure result, got %#v", result) + } + if logged != 1 { + t.Fatalf("expected manual check to log once, got %d", logged) + } +} + +func TestCheckForUpdatesSilentlySkipsFailureLogs(t *testing.T) { + app := &App{} + + restoreRelease := swapUpdateFetchLatestRelease(func() (*githubRelease, error) { + return nil, errors.New("request timed out") + }) + defer restoreRelease() + + logged := 0 + restoreLogger := swapUpdateCheckErrorLogger(func(error) { + logged++ + }) + defer restoreLogger() + + result := app.CheckForUpdatesSilently() + if result.Success { + t.Fatalf("expected failure result, got %#v", result) + } + if logged != 0 { + t.Fatalf("expected silent check to skip error logging, got %d", logged) + } +} diff --git a/internal/secretstore/keyring_store.go b/internal/secretstore/keyring_store.go index 93fe0bc..e1d4e1e 100644 --- a/internal/secretstore/keyring_store.go +++ b/internal/secretstore/keyring_store.go @@ -3,7 +3,10 @@ package secretstore import ( "errors" "fmt" + "os" "runtime" + "strings" + "syscall" "github.com/99designs/keyring" ) @@ -56,19 +59,32 @@ func (s *keyringStore) Delete(ref string) error { func (s *keyringStore) HealthCheck() error { _, err := s.ring.Get(healthCheckRef) - if err == nil || errors.Is(err, keyring.ErrKeyNotFound) { + if err == nil || isKeyringSecretNotFound(err) { return nil } return wrapKeyringError(err) } func wrapKeyringError(err error) error { - if err == nil || errors.Is(err, keyring.ErrKeyNotFound) || IsUnavailable(err) { + if err == nil || IsUnavailable(err) { return err } + if isKeyringSecretNotFound(err) { + return os.ErrNotExist + } return &UnavailableError{Reason: err.Error()} } +func isKeyringSecretNotFound(err error) bool { + if err == nil { + return false + } + if errors.Is(err, keyring.ErrKeyNotFound) || errors.Is(err, syscall.Errno(1168)) { + return true + } + return strings.EqualFold(strings.TrimSpace(err.Error()), keyring.ErrKeyNotFound.Error()) +} + func keyringConfigFor(goos string) (keyring.Config, error) { backends := allowedBackendsFor(goos) if len(backends) == 0 { diff --git a/internal/secretstore/keyring_store_test.go b/internal/secretstore/keyring_store_test.go index 03fc49f..440b5d6 100644 --- a/internal/secretstore/keyring_store_test.go +++ b/internal/secretstore/keyring_store_test.go @@ -2,6 +2,9 @@ package secretstore import ( "errors" + "fmt" + "os" + "syscall" "testing" "github.com/99designs/keyring" @@ -58,6 +61,33 @@ func TestKeyringStoreHealthCheckTreatsMissingProbeItemAsHealthy(t *testing.T) { } } +func TestKeyringStoreHealthCheckTreatsWinCredNotFoundMessageAsHealthy(t *testing.T) { + t.Parallel() + + store := &keyringStore{ring: fakeKeyringClient{getErr: errors.New("The specified item could not be found in the keyring")}} + if err := store.HealthCheck(); err != nil { + t.Fatalf("HealthCheck should accept WinCred not-found errors, got %v", err) + } +} + +func TestKeyringStoreHealthCheckDoesNotTreatWrappedOsErrNotExistAsHealthy(t *testing.T) { + t.Parallel() + + store := &keyringStore{ring: fakeKeyringClient{getErr: fmt.Errorf("backend unavailable: %w", os.ErrNotExist)}} + if err := store.HealthCheck(); err == nil { + t.Fatal("HealthCheck should not accept unrelated wrapped os.ErrNotExist errors as healthy") + } +} + +func TestKeyringStoreHealthCheckDoesNotTreatPlainOsErrNotExistAsHealthy(t *testing.T) { + t.Parallel() + + store := &keyringStore{ring: fakeKeyringClient{getErr: os.ErrNotExist}} + if err := store.HealthCheck(); err == nil { + t.Fatal("HealthCheck should not accept plain os.ErrNotExist errors as healthy") + } +} + func TestKeyringStoreHealthCheckReturnsUnavailableErrorOnBackendFailure(t *testing.T) { t.Parallel() @@ -82,6 +112,67 @@ func TestNewKeyringStoreReturnsUnavailableStoreWhenOpenFails(t *testing.T) { } } +func TestWrapKeyringErrorNormalizesWinCredNotFoundMessage(t *testing.T) { + t.Parallel() + + err := wrapKeyringError(errors.New("The specified item could not be found in the keyring")) + if err == nil { + t.Fatal("wrapKeyringError should preserve missing-secret semantics") + } + if !os.IsNotExist(err) { + t.Fatalf("wrapKeyringError should map WinCred not-found errors to os.ErrNotExist, got %v", err) + } + if IsUnavailable(err) { + t.Fatalf("wrapKeyringError should not treat WinCred not-found errors as unavailable, got %v", err) + } +} + +func TestWrapKeyringErrorNormalizesWrappedKeyringErrKeyNotFound(t *testing.T) { + t.Parallel() + + err := wrapKeyringError(fmt.Errorf("wrapped: %w", keyring.ErrKeyNotFound)) + if err == nil { + t.Fatal("wrapKeyringError should preserve wrapped missing-secret semantics") + } + if !os.IsNotExist(err) { + t.Fatalf("wrapKeyringError should map wrapped ErrKeyNotFound to os.ErrNotExist, got %v", err) + } + if IsUnavailable(err) { + t.Fatalf("wrapKeyringError should not treat wrapped ErrKeyNotFound as unavailable, got %v", err) + } +} + +func TestWrapKeyringErrorNormalizesWinCredErrno1168(t *testing.T) { + t.Parallel() + + err := wrapKeyringError(syscall.Errno(1168)) + if err == nil { + t.Fatal("wrapKeyringError should preserve WinCred errno missing-secret semantics") + } + if !os.IsNotExist(err) { + t.Fatalf("wrapKeyringError should map WinCred errno to os.ErrNotExist, got %v", err) + } + if IsUnavailable(err) { + t.Fatalf("wrapKeyringError should not treat WinCred errno as unavailable, got %v", err) + } +} + +func TestWrapKeyringErrorDoesNotSwallowUnrelatedElementNotFoundMessages(t *testing.T) { + t.Parallel() + + backendErr := errors.New("database element not found while enumerating providers") + err := wrapKeyringError(backendErr) + if err == nil { + t.Fatal("wrapKeyringError should preserve backend failures") + } + if os.IsNotExist(err) { + t.Fatalf("wrapKeyringError should not map unrelated element-not-found errors to os.ErrNotExist, got %v", err) + } + if !IsUnavailable(err) { + t.Fatalf("wrapKeyringError should keep unrelated backend failures unavailable, got %v", err) + } +} + type fakeKeyringClient struct { getErr error item keyring.Item From 1751e14d209e413d868b2829cef4b6bf8d868e9f Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Sat, 11 Apr 2026 20:12:23 +0800 Subject: [PATCH 4/7] =?UTF-8?q?=F0=9F=90=9B=20fix(security):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=AE=89=E5=85=A8=E6=9B=B4=E6=96=B0=E9=87=8D=E6=A3=80?= =?UTF-8?q?=E5=8D=A1=E6=AD=BB=E4=B8=8E=20Redis=20=E5=AF=86=E6=96=87?= =?UTF-8?q?=E5=85=BC=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/src/App.tsx | 65 +++++-- .../src/components/SecurityUpdateBanner.tsx | 4 +- .../components/SecurityUpdateIntroModal.tsx | 10 +- .../SecurityUpdateProgressModal.tsx | 14 +- .../SecurityUpdateSettingsModal.tsx | 23 ++- .../src/utils/secureConfigBootstrap.test.ts | 65 +++++++ frontend/src/utils/secureConfigBootstrap.ts | 4 + .../utils/securityUpdateRepairFlow.test.ts | 16 ++ .../src/utils/securityUpdateRepairFlow.ts | 8 + .../src/utils/securityUpdateVisuals.test.ts | 11 ++ frontend/src/utils/securityUpdateVisuals.ts | 43 ++++- internal/app/connection_secret_resolution.go | 60 +++++++ .../app/connection_secret_resolution_test.go | 75 ++++++++ internal/app/methods_redis.go | 71 +++++++- internal/app/methods_redis_test.go | 170 ++++++++++++++++++ 15 files changed, 585 insertions(+), 54 deletions(-) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 59a200e..4ae0f0b 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -48,6 +48,7 @@ import { hasSecurityUpdateRecentResult, resolveSecurityUpdateRepairEntry, resolveSecurityUpdateSettingsFocusTarget, + shouldRefreshSecurityUpdateDetailsFocus, shouldReopenSecurityUpdateDetails, shouldRetrySecurityUpdateAfterRepairSave, type SecurityUpdateRepairSource, @@ -276,6 +277,7 @@ function App() { status?: Partial | null, options?: { openSettings?: boolean; + refreshFocus?: boolean; resetBannerDismissed?: boolean; }, ) => { @@ -287,8 +289,10 @@ function App() { setIsSecurityUpdateBannerDismissed(false); } if (options?.openSettings) { - setSecurityUpdateSettingsFocusTarget(resolveSecurityUpdateSettingsFocusTarget(nextStatus)); - setSecurityUpdateSettingsFocusRequest((current) => current + 1); + if (options.refreshFocus !== false) { + setSecurityUpdateSettingsFocusTarget(resolveSecurityUpdateSettingsFocusTarget(nextStatus)); + setSecurityUpdateSettingsFocusRequest((current) => current + 1); + } setIsSecurityUpdateSettingsOpen(true); } return nextStatus; @@ -845,12 +849,16 @@ function App() { const stageText = mode === 'retry' ? '正在校验更新结果' : '正在更新安全存储'; + const detailsWereOpen = isSecurityUpdateSettingsOpen; setSecurityUpdateProgressStage(stageText); setIsSecurityUpdateProgressOpen(true); setIsSecurityUpdateIntroOpen(false); + setIsSecurityUpdateSettingsOpen(false); + let nextStatus: SecurityUpdateStatus | null = null; + let shouldOpenSettings = false; + let refreshSettingsFocus = false; try { - let nextStatus: SecurityUpdateStatus | null = null; if (mode === 'start') { const result = await startSecurityUpdateFromBootstrap({ backend: backendApp, @@ -891,29 +899,44 @@ function App() { }, nextStatus); } - const shouldOpenSettings = nextStatus.overallStatus === 'needs_attention' || nextStatus.overallStatus === 'rolled_back'; - applySecurityUpdateStatus(nextStatus, { - openSettings: shouldOpenSettings, + shouldOpenSettings = nextStatus.overallStatus === 'needs_attention' || nextStatus.overallStatus === 'rolled_back'; + refreshSettingsFocus = shouldRefreshSecurityUpdateDetailsFocus({ + requestedOpen: shouldOpenSettings, + wasOpen: detailsWereOpen, }); - - if (nextStatus.overallStatus === 'completed') { - setSecurityUpdateHasLegacySensitiveItems(false); - setSecurityUpdateRawPayload(null); - setIsSecurityUpdateSettingsOpen(false); - void message.success('已保存配置已完成安全更新'); - } else if (nextStatus.overallStatus === 'needs_attention') { - void message.warning('更新尚未完成,有少量配置需要你处理'); - } else if (nextStatus.overallStatus === 'rolled_back') { - void message.warning('本次更新未完成,系统已保留当前可用配置'); - } } catch (err: any) { console.warn('Failed to execute security update round', err); - void message.error(err?.message || '安全更新未完成,请稍后重试'); - } finally { setIsSecurityUpdateProgressOpen(false); + if (detailsWereOpen) { + setIsSecurityUpdateSettingsOpen(true); + } + void message.error(err?.message || '安全更新未完成,请稍后重试'); + return; + } + + if (!nextStatus) { + setIsSecurityUpdateProgressOpen(false); + return; + } + setIsSecurityUpdateProgressOpen(false); + applySecurityUpdateStatus(nextStatus, { + openSettings: shouldOpenSettings, + refreshFocus: refreshSettingsFocus, + }); + + if (nextStatus.overallStatus === 'completed') { + setSecurityUpdateHasLegacySensitiveItems(false); + setSecurityUpdateRawPayload(null); + setIsSecurityUpdateSettingsOpen(false); + void message.success('已保存配置已完成安全更新'); + } else if (nextStatus.overallStatus === 'needs_attention') { + void message.warning('更新尚未完成,有少量配置需要你处理'); + } else if (nextStatus.overallStatus === 'rolled_back') { + void message.warning('本次更新未完成,系统已保留当前可用配置'); } }, [ applySecurityUpdateStatus, + isSecurityUpdateSettingsOpen, normalizeSecurityUpdateStatus, replaceConnections, replaceGlobalProxy, @@ -2355,6 +2378,7 @@ function App() { status={securityUpdateStatus} darkMode={darkMode} overlayTheme={overlayTheme} + surfaceOpacity={effectiveOpacity} onStart={handleStartSecurityUpdate} onRetry={handleRetrySecurityUpdate} onRestart={handleRestartSecurityUpdate} @@ -2501,6 +2525,7 @@ function App() { loading={isSecurityUpdateProgressOpen} darkMode={darkMode} overlayTheme={overlayTheme} + surfaceOpacity={effectiveOpacity} onStart={handleStartSecurityUpdate} onPostpone={handlePostponeSecurityUpdate} onViewDetails={() => handleOpenSecurityUpdateSettings()} @@ -2509,6 +2534,7 @@ function App() { open={isSecurityUpdateSettingsOpen} darkMode={darkMode} overlayTheme={overlayTheme} + surfaceOpacity={effectiveOpacity} status={securityUpdateStatus} focusTarget={securityUpdateSettingsFocusTarget} focusRequest={securityUpdateSettingsFocusRequest} @@ -2522,6 +2548,7 @@ function App() { open={isSecurityUpdateProgressOpen} stageText={securityUpdateProgressStage} overlayTheme={overlayTheme} + surfaceOpacity={effectiveOpacity} /> void; onRetry: () => void; onRestart: () => void; @@ -74,6 +75,7 @@ const SecurityUpdateBanner = ({ status, darkMode, overlayTheme, + surfaceOpacity = 1, onStart, onRetry, onRestart, @@ -92,7 +94,7 @@ const SecurityUpdateBanner = ({ margin: '12px 12px 0', padding: '14px 16px', borderRadius: 16, - ...getSecurityUpdateBannerSurfaceStyle(overlayTheme), + ...getSecurityUpdateBannerSurfaceStyle(overlayTheme, surfaceOpacity), display: 'flex', alignItems: 'center', gap: 16, diff --git a/frontend/src/components/SecurityUpdateIntroModal.tsx b/frontend/src/components/SecurityUpdateIntroModal.tsx index e02c099..e8d0db5 100644 --- a/frontend/src/components/SecurityUpdateIntroModal.tsx +++ b/frontend/src/components/SecurityUpdateIntroModal.tsx @@ -7,6 +7,7 @@ import { SECURITY_UPDATE_ACTION_BUTTON_CLASS, SECURITY_UPDATE_MODAL_CLASS, getSecurityUpdateActionButtonStyle, + getSecurityUpdateShellSurfaceStyle, } from '../utils/securityUpdateVisuals'; interface SecurityUpdateIntroModalProps { @@ -14,6 +15,7 @@ interface SecurityUpdateIntroModalProps { loading?: boolean; darkMode: boolean; overlayTheme: OverlayWorkbenchTheme; + surfaceOpacity?: number; onStart: () => void; onPostpone: () => void; onViewDetails: () => void; @@ -30,6 +32,7 @@ const SecurityUpdateIntroModal = ({ loading = false, darkMode, overlayTheme, + surfaceOpacity = 1, onStart, onPostpone, onViewDetails, @@ -71,12 +74,7 @@ const SecurityUpdateIntroModal = ({ onCancel={onPostpone} width={560} styles={{ - content: { - background: overlayTheme.shellBg, - border: overlayTheme.shellBorder, - boxShadow: overlayTheme.shellShadow, - backdropFilter: overlayTheme.shellBackdropFilter, - }, + content: getSecurityUpdateShellSurfaceStyle(overlayTheme, surfaceOpacity), header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 }, body: { paddingTop: 8 }, footer: { background: 'transparent', borderTop: 'none', paddingTop: 10 }, diff --git a/frontend/src/components/SecurityUpdateProgressModal.tsx b/frontend/src/components/SecurityUpdateProgressModal.tsx index 5e3888b..35f130c 100644 --- a/frontend/src/components/SecurityUpdateProgressModal.tsx +++ b/frontend/src/components/SecurityUpdateProgressModal.tsx @@ -2,13 +2,17 @@ import { Modal, Spin } from 'antd'; import { SafetyCertificateOutlined } from '@ant-design/icons'; import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; -import { SECURITY_UPDATE_MODAL_CLASS } from '../utils/securityUpdateVisuals'; +import { + SECURITY_UPDATE_MODAL_CLASS, + getSecurityUpdateShellSurfaceStyle, +} from '../utils/securityUpdateVisuals'; interface SecurityUpdateProgressModalProps { open: boolean; stageText: string; detailText?: string; overlayTheme: OverlayWorkbenchTheme; + surfaceOpacity?: number; } const SecurityUpdateProgressModal = ({ @@ -16,6 +20,7 @@ const SecurityUpdateProgressModal = ({ stageText, detailText, overlayTheme, + surfaceOpacity = 1, }: SecurityUpdateProgressModalProps) => { return ( ({ borderRadius: 14, padding: 16, - ...getSecurityUpdateSectionSurfaceStyle(overlayTheme, options), + ...getSecurityUpdateSectionSurfaceStyle(overlayTheme, { + ...options, + surfaceOpacity, + }), }); const EMPTY_FOCUS_STATE: SecurityUpdateFocusState = { @@ -59,6 +64,7 @@ const SecurityUpdateSettingsModal = ({ open, darkMode, overlayTheme, + surfaceOpacity = 1, status, focusTarget = null, focusRequest = 0, @@ -174,7 +180,7 @@ const SecurityUpdateSettingsModal = ({ ]} width={760} styles={{ - content: getSecurityUpdateShellSurfaceStyle(overlayTheme), + content: getSecurityUpdateShellSurfaceStyle(overlayTheme, surfaceOpacity), header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 }, body: { paddingTop: 8, maxHeight: 640, overflowY: 'auto' }, footer: { background: 'transparent', borderTop: 'none', paddingTop: 10 }, @@ -184,7 +190,7 @@ const SecurityUpdateSettingsModal = ({
@@ -211,7 +217,7 @@ const SecurityUpdateSettingsModal = ({
-
+
影响范围
@@ -226,9 +232,8 @@ const SecurityUpdateSettingsModal = ({
@@ -239,7 +244,7 @@ const SecurityUpdateSettingsModal = ({
-
+
待处理清单
@@ -258,7 +263,7 @@ const SecurityUpdateSettingsModal = ({
最近一次结果 diff --git a/frontend/src/utils/secureConfigBootstrap.test.ts b/frontend/src/utils/secureConfigBootstrap.test.ts index 32c9cd5..7f7b0cd 100644 --- a/frontend/src/utils/secureConfigBootstrap.test.ts +++ b/frontend/src/utils/secureConfigBootstrap.test.ts @@ -297,6 +297,71 @@ describe('secureConfigBootstrap', () => { ])); }); + it('does not merge local legacy pending items back into an active migration round that already reports needs_attention', () => { + const status = mergeSecurityUpdateStatusWithLegacySource({ + migrationId: 'migration-active-1', + overallStatus: 'needs_attention', + summary: { total: 3, updated: 2, pending: 1, skipped: 0, failed: 0 }, + issues: [ + { + id: 'ai-provider-openai-main', + scope: 'ai_provider', + refId: 'openai-main', + title: 'OpenAI', + severity: 'medium', + status: 'needs_attention', + reasonCode: 'secret_missing', + action: 'open_ai_settings', + message: 'AI 提供商配置需要补充后才能完成安全更新。', + }, + ], + }, legacyPayload); + + expect(status.overallStatus).toBe('needs_attention'); + expect(status.summary).toEqual({ + total: 3, + updated: 2, + pending: 1, + skipped: 0, + failed: 0, + }); + expect(status.issues).toEqual([ + expect.objectContaining({ id: 'ai-provider-openai-main', scope: 'ai_provider', refId: 'openai-main' }), + ]); + }); + + it('does not merge local legacy pending items back into a rolled_back migration round', () => { + const status = mergeSecurityUpdateStatusWithLegacySource({ + migrationId: 'migration-active-2', + overallStatus: 'rolled_back', + summary: { total: 3, updated: 1, pending: 0, skipped: 0, failed: 2 }, + issues: [ + { + id: 'system-blocked', + scope: 'system', + title: '系统回滚', + severity: 'high', + status: 'failed', + reasonCode: 'environment_blocked', + action: 'view_details', + message: '后端已回滚本轮更新,需要处理后重试。', + }, + ], + }, legacyPayload); + + expect(status.overallStatus).toBe('rolled_back'); + expect(status.summary).toEqual({ + total: 3, + updated: 1, + pending: 0, + skipped: 0, + failed: 2, + }); + expect(status.issues).toEqual([ + expect.objectContaining({ id: 'system-blocked', scope: 'system' }), + ]); + }); + it('loads backend secure config directly when no legacy source exists', async () => { const storage = createMemoryStorage(); const replaceConnections = vi.fn(); diff --git a/frontend/src/utils/secureConfigBootstrap.ts b/frontend/src/utils/secureConfigBootstrap.ts index f457024..cd26ef2 100644 --- a/frontend/src/utils/secureConfigBootstrap.ts +++ b/frontend/src/utils/secureConfigBootstrap.ts @@ -193,6 +193,7 @@ export const mergeSecurityUpdateStatusWithLegacySource = ( }, issues: Array.isArray(status?.issues) ? status.issues : [], }; + const hasActiveMigrationRound = String(base.migrationId || '').trim() !== ''; const baseNonLegacyIssues = base.issues.filter((issue) => !isLocalLegacyIssue(issue)); const legacy = buildLegacyPendingDetails(rawPayload); @@ -231,6 +232,9 @@ export const mergeSecurityUpdateStatusWithLegacySource = ( } if (base.overallStatus === 'rolled_back' || base.overallStatus === 'needs_attention') { + if (hasActiveMigrationRound) { + return base; + } return { ...base, summary: hasMeaningfulSummary(base.summary) || legacy.hasLegacyItems ? legacySummary.summary : legacy.summary, diff --git a/frontend/src/utils/securityUpdateRepairFlow.test.ts b/frontend/src/utils/securityUpdateRepairFlow.test.ts index 0cb57f7..2642dcc 100644 --- a/frontend/src/utils/securityUpdateRepairFlow.test.ts +++ b/frontend/src/utils/securityUpdateRepairFlow.test.ts @@ -6,6 +6,7 @@ import { resolveSecurityUpdateFocusState, resolveSecurityUpdateRepairEntry, resolveSecurityUpdateSettingsFocusTarget, + shouldRefreshSecurityUpdateDetailsFocus, shouldReopenSecurityUpdateDetails, shouldRetrySecurityUpdateAfterRepairSave, } from './securityUpdateRepairFlow'; @@ -136,4 +137,19 @@ describe('securityUpdateRepairFlow', () => { expect(shouldRetrySecurityUpdateAfterRepairSave('ai')).toBe(false); expect(shouldRetrySecurityUpdateAfterRepairSave(null)).toBe(false); }); + + it('does not force a new focus pulse when the details modal is already open and only the round result is refreshing', () => { + expect(shouldRefreshSecurityUpdateDetailsFocus({ + requestedOpen: true, + wasOpen: true, + })).toBe(false); + expect(shouldRefreshSecurityUpdateDetailsFocus({ + requestedOpen: true, + wasOpen: false, + })).toBe(true); + expect(shouldRefreshSecurityUpdateDetailsFocus({ + requestedOpen: false, + wasOpen: true, + })).toBe(false); + }); }); diff --git a/frontend/src/utils/securityUpdateRepairFlow.ts b/frontend/src/utils/securityUpdateRepairFlow.ts index 9a6be1e..bac59c4 100644 --- a/frontend/src/utils/securityUpdateRepairFlow.ts +++ b/frontend/src/utils/securityUpdateRepairFlow.ts @@ -113,6 +113,14 @@ export const shouldReopenSecurityUpdateDetails = ( repairSource: SecurityUpdateRepairSource | null | undefined, ): boolean => repairSource === 'connection' || repairSource === 'proxy' || repairSource === 'ai'; +export const shouldRefreshSecurityUpdateDetailsFocus = ({ + requestedOpen, + wasOpen, +}: { + requestedOpen: boolean; + wasOpen: boolean; +}): boolean => requestedOpen && !wasOpen; + export const shouldRetrySecurityUpdateAfterRepairSave = ( repairSource: SecurityUpdateRepairSource | null | undefined, ): boolean => repairSource === 'connection'; diff --git a/frontend/src/utils/securityUpdateVisuals.test.ts b/frontend/src/utils/securityUpdateVisuals.test.ts index 7d0d8e7..781e47c 100644 --- a/frontend/src/utils/securityUpdateVisuals.test.ts +++ b/frontend/src/utils/securityUpdateVisuals.test.ts @@ -62,6 +62,17 @@ describe('securityUpdateVisuals', () => { }); }); + it('can scale shell surface alpha with the current appearance opacity so reminder layers stay visually consistent', () => { + const lightTheme = buildOverlayWorkbenchTheme(false); + const fadedShell = getSecurityUpdateShellSurfaceStyle(lightTheme, 0.5); + const fadedBanner = getSecurityUpdateBannerSurfaceStyle(lightTheme, 0.5); + + expect(fadedShell.background).not.toBe(lightTheme.shellBg); + expect(fadedShell.border).not.toBe(lightTheme.shellBorder); + expect(fadedShell.background).toContain('0.49'); + expect(fadedBanner.background).toContain('0.49'); + }); + it('can emphasize a section surface for transient focus and recent-result highlighting', () => { const lightTheme = buildOverlayWorkbenchTheme(false); const darkTheme = buildOverlayWorkbenchTheme(true); diff --git a/frontend/src/utils/securityUpdateVisuals.ts b/frontend/src/utils/securityUpdateVisuals.ts index 735ca2d..a93fce6 100644 --- a/frontend/src/utils/securityUpdateVisuals.ts +++ b/frontend/src/utils/securityUpdateVisuals.ts @@ -10,6 +10,25 @@ export const SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS = 'security-update-result- type SecurityUpdateSectionSurfaceOptions = { emphasized?: boolean; + surfaceOpacity?: number; +}; + +const clampOpacity = (value: number): number => Math.min(1, Math.max(0.1, value)); + +const formatAlpha = (value: number): string => ( + Number(value.toFixed(3)).toString() +); + +const applySurfaceOpacity = (token: string, surfaceOpacity = 1): string => { + const normalizedOpacity = clampOpacity(surfaceOpacity); + if (normalizedOpacity >= 0.999) { + return token; + } + + return token.replace( + /rgba\(\s*([^)]+?)\s*,\s*([0-9]*\.?[0-9]+)\s*\)/g, + (_, channels: string, alpha: string) => `rgba(${channels}, ${formatAlpha(Number(alpha) * normalizedOpacity)})`, + ); }; const getSecurityUpdateHighlightBorder = (overlayTheme: OverlayWorkbenchTheme): string => ( @@ -40,17 +59,19 @@ export const getSecurityUpdateActionButtonStyle = (): CSSProperties => ({ export const getSecurityUpdateShellSurfaceStyle = ( overlayTheme: OverlayWorkbenchTheme, + surfaceOpacity = 1, ): CSSProperties => ({ - border: overlayTheme.shellBorder, - background: overlayTheme.shellBg, - boxShadow: overlayTheme.shellShadow, + border: applySurfaceOpacity(overlayTheme.shellBorder, surfaceOpacity), + background: applySurfaceOpacity(overlayTheme.shellBg, surfaceOpacity), + boxShadow: applySurfaceOpacity(overlayTheme.shellShadow, surfaceOpacity), backdropFilter: overlayTheme.shellBackdropFilter, }); export const getSecurityUpdateBannerSurfaceStyle = ( overlayTheme: OverlayWorkbenchTheme, + surfaceOpacity = 1, ): CSSProperties => ({ - ...getSecurityUpdateShellSurfaceStyle(overlayTheme), + ...getSecurityUpdateShellSurfaceStyle(overlayTheme, surfaceOpacity), boxShadow: 'none', }); @@ -58,8 +79,16 @@ export const getSecurityUpdateSectionSurfaceStyle = ( overlayTheme: OverlayWorkbenchTheme, options: SecurityUpdateSectionSurfaceOptions = {}, ): CSSProperties => ({ - border: options.emphasized ? getSecurityUpdateHighlightBorder(overlayTheme) : overlayTheme.sectionBorder, - background: options.emphasized ? getSecurityUpdateHighlightBackground(overlayTheme) : overlayTheme.sectionBg, - boxShadow: options.emphasized ? getSecurityUpdateHighlightShadow(overlayTheme) : 'none', + border: applySurfaceOpacity( + options.emphasized ? getSecurityUpdateHighlightBorder(overlayTheme) : overlayTheme.sectionBorder, + options.surfaceOpacity, + ), + background: applySurfaceOpacity( + options.emphasized ? getSecurityUpdateHighlightBackground(overlayTheme) : overlayTheme.sectionBg, + options.surfaceOpacity, + ), + boxShadow: options.emphasized + ? applySurfaceOpacity(getSecurityUpdateHighlightShadow(overlayTheme), options.surfaceOpacity) + : 'none', transition: 'background 180ms ease, border-color 180ms ease, box-shadow 180ms ease', }); diff --git a/internal/app/connection_secret_resolution.go b/internal/app/connection_secret_resolution.go index 5e7eb6f..ac21714 100644 --- a/internal/app/connection_secret_resolution.go +++ b/internal/app/connection_secret_resolution.go @@ -1,10 +1,13 @@ package app import ( + "errors" "fmt" + "os" "strings" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/secretstore" ) func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (connection.ConnectionConfig, error) { @@ -15,6 +18,9 @@ func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (conn repo := newSavedConnectionRepository(a.configDir, a.secretStore) view, err := repo.Find(config.ID) if err != nil { + if shouldFallbackToInlineConnectionSecrets(config, err) { + return config, nil + } return config, normalizeConnectionSecretResolutionError(config, err) } @@ -24,6 +30,9 @@ func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (conn } bundle, err := repo.loadSecretBundle(view) if err != nil { + if shouldFallbackToInlineConnectionSecrets(config, err) { + return mergeInlineConnectionSecrets(base, config), nil + } return base, normalizeConnectionSecretResolutionError(base, err) } resolved := mergeConnectionSecretBundleIntoConfig(base, bundle) @@ -31,6 +40,57 @@ func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (conn return resolved, nil } +func shouldFallbackToInlineConnectionSecrets(config connection.ConnectionConfig, err error) bool { + if err == nil || !connectionConfigCarriesInlineSecrets(config) || secretstore.IsUnavailable(err) { + return false + } + if errors.Is(err, os.ErrNotExist) { + return true + } + lower := strings.ToLower(strings.TrimSpace(err.Error())) + return strings.Contains(lower, "saved connection not found:") +} + +func connectionConfigCarriesInlineSecrets(config connection.ConnectionConfig) bool { + return strings.TrimSpace(config.Password) != "" || + strings.TrimSpace(config.SSH.Password) != "" || + strings.TrimSpace(config.Proxy.Password) != "" || + strings.TrimSpace(config.HTTPTunnel.Password) != "" || + strings.TrimSpace(config.MySQLReplicaPassword) != "" || + strings.TrimSpace(config.MongoReplicaPassword) != "" || + strings.TrimSpace(config.URI) != "" || + strings.TrimSpace(config.DSN) != "" +} + +func mergeInlineConnectionSecrets(base connection.ConnectionConfig, inline connection.ConnectionConfig) connection.ConnectionConfig { + merged := base + if strings.TrimSpace(inline.Password) != "" { + merged.Password = inline.Password + } + if strings.TrimSpace(inline.SSH.Password) != "" { + merged.SSH.Password = inline.SSH.Password + } + if strings.TrimSpace(inline.Proxy.Password) != "" { + merged.Proxy.Password = inline.Proxy.Password + } + if strings.TrimSpace(inline.HTTPTunnel.Password) != "" { + merged.HTTPTunnel.Password = inline.HTTPTunnel.Password + } + if strings.TrimSpace(inline.MySQLReplicaPassword) != "" { + merged.MySQLReplicaPassword = inline.MySQLReplicaPassword + } + if strings.TrimSpace(inline.MongoReplicaPassword) != "" { + merged.MongoReplicaPassword = inline.MongoReplicaPassword + } + if strings.TrimSpace(inline.URI) != "" { + merged.URI = inline.URI + } + if strings.TrimSpace(inline.DSN) != "" { + merged.DSN = inline.DSN + } + return merged +} + func normalizeConnectionSecretResolutionError(config connection.ConnectionConfig, err error) error { if err == nil { return nil diff --git a/internal/app/connection_secret_resolution_test.go b/internal/app/connection_secret_resolution_test.go index e09e24a..8ecf590 100644 --- a/internal/app/connection_secret_resolution_test.go +++ b/internal/app/connection_secret_resolution_test.go @@ -61,3 +61,78 @@ func TestResolveConnectionSecretsReturnsFriendlyMessageWhenSavedSecretSourceIsMi t.Fatalf("expected a secret-specific error message, got %q", err.Error()) } } + +func TestResolveConnectionSecretsFallsBackToInlineSecretsWhenSavedConnectionIsMissing(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + input := connection.ConnectionConfig{ + ID: "legacy-inline", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "inline-secret", + DSN: "postgres://postgres:inline-secret@db.local/app", + } + + resolved, err := app.resolveConnectionSecrets(input) + if err != nil { + t.Fatalf("expected inline secrets to be used as fallback, got error: %v", err) + } + if resolved.Password != "inline-secret" { + t.Fatalf("expected inline password to be preserved, got %q", resolved.Password) + } + if resolved.DSN != "postgres://postgres:inline-secret@db.local/app" { + t.Fatalf("expected inline DSN to be preserved, got %q", resolved.DSN) + } +} + +func TestResolveConnectionSecretsFallsBackToInlineSecretsWhenSavedSecretBundleIsMissing(t *testing.T) { + store := newFakeAppSecretStore() + app := NewAppWithSecretStore(store) + app.configDir = t.TempDir() + + view, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-inline-fallback", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-inline-fallback", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "stored-secret", + DSN: "postgres://postgres:stored-secret@db.local/app", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + if view.SecretRef == "" { + t.Fatal("expected saved connection to allocate a secret ref") + } + if err := store.Delete(view.SecretRef); err != nil { + t.Fatalf("Delete returned error: %v", err) + } + + resolved, err := app.resolveConnectionSecrets(connection.ConnectionConfig{ + ID: "conn-inline-fallback", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "inline-secret", + DSN: "postgres://postgres:inline-secret@db.local/app", + }) + if err != nil { + t.Fatalf("expected inline secrets to be used when secret bundle is missing, got error: %v", err) + } + if resolved.Password != "inline-secret" { + t.Fatalf("expected inline password to be preserved, got %q", resolved.Password) + } + if resolved.DSN != "postgres://postgres:inline-secret@db.local/app" { + t.Fatalf("expected inline DSN to be preserved, got %q", resolved.DSN) + } +} diff --git a/internal/app/methods_redis.go b/internal/app/methods_redis.go index 08f6c22..71b70a7 100644 --- a/internal/app/methods_redis.go +++ b/internal/app/methods_redis.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "net/url" "strconv" "strings" "sync" @@ -62,18 +63,78 @@ func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisCli } logger.Infof("创建 Redis 客户端实例:缓存Key=%s", shortKey) - client := newRedisClientFunc() - if err := client.Connect(connectConfig); err != nil { - wrapped := wrapConnectError(effectiveConfig, err) - logger.Error(wrapped, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey) + client, connectedConfig, connectErr := connectRedisClientWithLegacyRootFallback(connectConfig) + if connectErr != nil { + wrapped := wrapConnectError(connectedConfig, connectErr) + logger.Error(wrapped, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(connectedConfig), shortKey) return nil, wrapped } redisCache[key] = client - logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey) + logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(connectedConfig), shortKey) return client, nil } +func connectRedisClientWithLegacyRootFallback(config connection.ConnectionConfig) (redis.RedisClient, connection.ConnectionConfig, error) { + client := newRedisClientFunc() + if err := client.Connect(config); err == nil { + return client, config, nil + } else { + client.Close() + if !shouldRetryRedisWithClearedLegacyRoot(config, err) { + return nil, config, err + } + + fallbackConfig := config + fallbackConfig.User = "" + logger.Warnf("Redis 使用用户名 root 认证失败,已按历史默认值回退为空用户名重试:%s", formatRedisConnSummary(config)) + + fallbackClient := newRedisClientFunc() + if retryErr := fallbackClient.Connect(fallbackConfig); retryErr != nil { + fallbackClient.Close() + return nil, fallbackConfig, retryErr + } + return fallbackClient, fallbackConfig, nil + } +} + +func shouldRetryRedisWithClearedLegacyRoot(config connection.ConnectionConfig, err error) bool { + if err == nil || strings.ToLower(strings.TrimSpace(config.Type)) != "redis" { + return false + } + if strings.TrimSpace(config.User) != "root" { + return false + } + if _, ok := extractExplicitRedisUsername(config.URI); ok { + return false + } + + lower := strings.ToLower(strings.TrimSpace(err.Error())) + return strings.Contains(lower, "wrongpass") || + strings.Contains(lower, "invalid username-password pair") || + strings.Contains(lower, "auth failed") || + strings.Contains(lower, "wrong number of arguments for 'auth' command") || + strings.Contains(lower, "authentication failed") +} + +func extractExplicitRedisUsername(rawURI string) (string, bool) { + trimmed := strings.TrimSpace(rawURI) + if trimmed == "" { + return "", false + } + + parsed, err := url.Parse(trimmed) + if err != nil || parsed.User == nil { + return "", false + } + + username := strings.TrimSpace(parsed.User.Username()) + if username == "" { + return "", false + } + return username, true +} + func getRedisClientCacheKey(config connection.ConnectionConfig) string { normalized := normalizeCacheKeyConfig(config) b, _ := json.Marshal(normalized) diff --git a/internal/app/methods_redis_test.go b/internal/app/methods_redis_test.go index f713cad..5f801bf 100644 --- a/internal/app/methods_redis_test.go +++ b/internal/app/methods_redis_test.go @@ -1,6 +1,7 @@ package app import ( + "errors" "testing" "GoNavi-Wails/internal/connection" @@ -92,6 +93,20 @@ func (c *capturingRedisClient) GetCurrentDB() int { return 0 } func (c *capturingRedisClient) FlushDB() error { return nil } +type scriptedRedisClient struct { + capturingRedisClient + connectErr error + connectCalls *[]connection.ConnectionConfig +} + +func (c *scriptedRedisClient) Connect(config connection.ConnectionConfig) error { + c.connectConfig = config + if c.connectCalls != nil { + *c.connectCalls = append(*c.connectCalls, config) + } + return c.connectErr +} + func TestRedisConnectResolvesSavedSecretsByConnectionID(t *testing.T) { testCases := []struct { name string @@ -215,6 +230,34 @@ func TestRedisConnectResolvesSavedSecretsByConnectionID(t *testing.T) { } }, }, + { + name: "explicit redis username from uri is preserved even when it is root", + savedConfig: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + User: "root", + Password: "redis-secret", + URI: "redis://root:redis-secret@redis.local:6379/0", + }, + runtimeConfig: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + User: "root", + }, + assertResolved: func(t *testing.T, got connection.ConnectionConfig) { + t.Helper() + if got.User != "root" { + t.Fatalf("expected RedisConnect to preserve explicit uri user root, got %q", got.User) + } + if got.URI != "redis://root:redis-secret@redis.local:6379/0" { + t.Fatalf("expected RedisConnect to restore saved redis uri, got %q", got.URI) + } + }, + }, } for _, testCase := range testCases { @@ -256,3 +299,130 @@ func TestRedisConnectResolvesSavedSecretsByConnectionID(t *testing.T) { }) } } + +func TestRedisConnectPreservesExplicitRootUserWithoutURIWhenConnectSucceeds(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "redis-1", + Name: "Redis Saved", + Config: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + User: "root", + Password: "redis-secret", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + CloseAllRedisClients() + connectCalls := make([]connection.ConnectionConfig, 0, 1) + client := &scriptedRedisClient{connectCalls: &connectCalls} + originalNewRedisClientFunc := newRedisClientFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + defer func() { + newRedisClientFunc = originalNewRedisClientFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + CloseAllRedisClients() + }() + newRedisClientFunc = func() redislib.RedisClient { + return client + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + result := app.RedisConnect(connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + User: "root", + }) + if !result.Success { + t.Fatalf("RedisConnect returned failure: %+v", result) + } + if len(connectCalls) != 1 { + t.Fatalf("expected exactly one Redis connect attempt, got %d", len(connectCalls)) + } + if connectCalls[0].User != "root" { + t.Fatalf("expected RedisConnect to preserve explicit root user when connect succeeds, got %q", connectCalls[0].User) + } +} + +func TestRedisConnectRetriesLegacyDefaultRootUserWithoutUsernameAfterAuthFailure(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "redis-1", + Name: "Redis Saved", + Config: connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + User: "root", + Password: "redis-secret", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + CloseAllRedisClients() + connectCalls := make([]connection.ConnectionConfig, 0, 2) + clients := []redislib.RedisClient{ + &scriptedRedisClient{ + connectErr: errors.New("WRONGPASS invalid username-password pair"), + connectCalls: &connectCalls, + }, + &scriptedRedisClient{ + connectCalls: &connectCalls, + }, + } + clientIndex := 0 + originalNewRedisClientFunc := newRedisClientFunc + originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc + defer func() { + newRedisClientFunc = originalNewRedisClientFunc + resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc + CloseAllRedisClients() + }() + newRedisClientFunc = func() redislib.RedisClient { + if clientIndex >= len(clients) { + t.Fatalf("unexpected Redis client allocation #%d", clientIndex+1) + } + client := clients[clientIndex] + clientIndex++ + return client + } + resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) { + return raw, nil + } + + result := app.RedisConnect(connection.ConnectionConfig{ + ID: "redis-1", + Type: "redis", + Host: "redis.local", + Port: 6379, + User: "root", + }) + if !result.Success { + t.Fatalf("RedisConnect returned failure after fallback: %+v", result) + } + if len(connectCalls) != 2 { + t.Fatalf("expected RedisConnect to retry exactly once after auth failure, got %d attempts", len(connectCalls)) + } + if connectCalls[0].User != "root" { + t.Fatalf("expected first Redis connect attempt to keep root user, got %q", connectCalls[0].User) + } + if connectCalls[1].User != "" { + t.Fatalf("expected fallback Redis connect attempt to clear legacy root user, got %q", connectCalls[1].User) + } +} From 52d2ee7592116d25ba35ee30ad5c3fec275b2fd6 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Sat, 11 Apr 2026 23:51:43 +0800 Subject: [PATCH 5/7] =?UTF-8?q?=E2=9C=A8=20feat(connection-package):=20?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=BF=9E=E6=8E=A5=E6=81=A2=E5=A4=8D=E5=8C=85?= =?UTF-8?q?=E5=8F=8C=E6=A8=A1=E5=BC=8F=E5=8A=A0=E5=AF=86=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=E5=AF=BC=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 v2 连接恢复包 appKey 与文件密码双模式加密链路 - 扩展前后端导入导出流程并兼容 v1 与 legacy 格式 - 修复无文件密码恢复包导入误弹密码框导致的流程阻塞 --- frontend/src/App.tsx | 83 ++++- .../ConnectionPackagePasswordModal.tsx | 60 ++- frontend/src/main.browserMock.test.ts | 6 +- frontend/src/main.tsx | 4 +- frontend/src/utils/connectionExport.test.ts | 76 +++- frontend/src/utils/connectionExport.ts | 64 +++- frontend/wailsjs/go/app/App.d.ts | 2 +- frontend/wailsjs/go/models.ts | 14 + internal/app/connection_package_appkey.go | 196 ++++++++++ .../app/connection_package_appkey_test.go | 141 ++++++++ internal/app/connection_package_crypto.go | 341 +++++++++++++++++- .../app/connection_package_crypto_test.go | 180 +++++++++ internal/app/connection_package_transfer.go | 52 +++ .../app/connection_package_transfer_test.go | 180 +++++++++ internal/app/connection_package_types.go | 83 ++++- internal/app/methods_file.go | 14 +- 16 files changed, 1447 insertions(+), 49 deletions(-) create mode 100644 internal/app/connection_package_appkey.go create mode 100644 internal/app/connection_package_appkey_test.go diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 4ae0f0b..ea479d4 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -26,6 +26,7 @@ import { getConnectionWorkbenchState } from './utils/startupReadiness'; import { toSaveGlobalProxyInput } from './utils/globalProxyDraft'; import { detectConnectionImportKind, + isConnectionPackagePasswordRequiredError, resolveConnectionPackageExportResult, normalizeConnectionPackagePassword, } from './utils/connectionExport'; @@ -120,6 +121,8 @@ type ConnectionPackageDialogMode = 'import' | 'export'; type ConnectionPackageDialogState = { open: boolean; mode: ConnectionPackageDialogMode; + includeSecrets: boolean; + useFilePassword: boolean; password: string; error: string; confirmLoading: boolean; @@ -128,6 +131,8 @@ type ConnectionPackageDialogState = { const createClosedConnectionPackageDialogState = (): ConnectionPackageDialogState => ({ open: false, mode: 'export', + includeSecrets: true, + useFilePassword: false, password: '', error: '', confirmLoading: false, @@ -1476,22 +1481,24 @@ function App() { return; } - if (importKind === 'encrypted-package') { - setPendingConnectionImportPayload(raw); - setConnectionPackageDialog({ - open: true, - mode: 'import', - password: '', - error: '', - confirmLoading: false, - }); - return; - } - try { + setPendingConnectionImportPayload(null); const importedViews = await importConnectionsPayload(raw, ''); void message.success(`成功导入 ${importedViews.length} 个连接`); } catch (e: any) { + if (isConnectionPackagePasswordRequiredError(e)) { + setPendingConnectionImportPayload(raw); + setConnectionPackageDialog({ + open: true, + mode: 'import', + includeSecrets: true, + useFilePassword: false, + password: '', + error: '', + confirmLoading: false, + }); + return; + } void message.error(e?.message || '导入失败'); } }; @@ -1505,6 +1512,8 @@ function App() { setConnectionPackageDialog({ open: true, mode: 'export', + includeSecrets: true, + useFilePassword: false, password: '', error: '', confirmLoading: false, @@ -1515,7 +1524,7 @@ function App() { const backendApp = (window as any).go?.app?.App; const password = normalizeConnectionPackagePassword(connectionPackageDialog.password); - if (!password) { + if (connectionPackageDialog.mode === 'import' && !password) { setConnectionPackageDialog((current) => ({ ...current, error: '恢复包密码不能为空', @@ -1523,9 +1532,25 @@ function App() { return; } + if ( + connectionPackageDialog.mode === 'export' + && connectionPackageDialog.includeSecrets + && connectionPackageDialog.useFilePassword + && !password + ) { + setConnectionPackageDialog((current) => ({ + ...current, + error: '文件保护密码不能为空', + })); + return; + } + setConnectionPackageDialog((current) => ({ ...current, - password, + password: ( + current.mode === 'export' + && (!current.includeSecrets || !current.useFilePassword) + ) ? '' : password, error: '', confirmLoading: true, })); @@ -1536,7 +1561,13 @@ function App() { throw new Error('导出失败:当前后端未提供新版导出能力'); } - const res = await backendApp.ExportConnectionsPackage(password); + const res = await backendApp.ExportConnectionsPackage({ + includeSecrets: connectionPackageDialog.includeSecrets, + filePassword: ( + connectionPackageDialog.includeSecrets + && connectionPackageDialog.useFilePassword + ) ? password : '', + }); const exportResult = resolveConnectionPackageExportResult(connectionPackageDialog, res); if (exportResult.kind === 'canceled') { setConnectionPackageDialog(exportResult.nextDialog); @@ -2559,11 +2590,31 @@ function App() { /> { + setConnectionPackageDialog((current) => ({ + ...current, + includeSecrets: value, + useFilePassword: value ? current.useFilePassword : false, + password: value ? current.password : '', + error: '', + })); + }} + onUseFilePasswordChange={(value) => { + setConnectionPackageDialog((current) => ({ + ...current, + useFilePassword: value, + password: value ? current.password : '', + error: '', + })); + }} onPasswordChange={(value) => { setConnectionPackageDialog((current) => ({ ...current, diff --git a/frontend/src/components/ConnectionPackagePasswordModal.tsx b/frontend/src/components/ConnectionPackagePasswordModal.tsx index 2f415e1..0f42d29 100644 --- a/frontend/src/components/ConnectionPackagePasswordModal.tsx +++ b/frontend/src/components/ConnectionPackagePasswordModal.tsx @@ -1,16 +1,23 @@ import React from 'react'; -import { Input, Modal, Typography } from 'antd'; +import { Checkbox, Input, Modal, Typography } from 'antd'; const { Text } = Typography; +type ConnectionPackagePasswordModalMode = 'import' | 'export'; + export interface ConnectionPackagePasswordModalProps { open: boolean; title: string; + mode?: ConnectionPackagePasswordModalMode; + includeSecrets?: boolean; + useFilePassword?: boolean; password: string; error?: string; confirmLoading?: boolean; confirmText?: string; cancelText?: string; + onIncludeSecretsChange?: (value: boolean) => void; + onUseFilePasswordChange?: (value: boolean) => void; onPasswordChange: (value: string) => void; onConfirm: () => void; onCancel: () => void; @@ -19,15 +26,29 @@ export interface ConnectionPackagePasswordModalProps { export default function ConnectionPackagePasswordModal({ open, title, + mode = 'import', + includeSecrets = true, + useFilePassword = false, password, error, confirmLoading, confirmText = '确认', cancelText = '取消', + onIncludeSecretsChange, + onUseFilePasswordChange, onPasswordChange, onConfirm, onCancel, }: ConnectionPackagePasswordModalProps) { + const isExportMode = mode === 'export'; + const showFilePasswordInput = isExportMode ? useFilePassword : true; + const placeholder = isExportMode ? '请输入文件保护密码(可选)' : '请输入恢复包密码'; + const helperText = !includeSecrets + ? '将仅导出连接配置,不包含密码。' + : (useFilePassword + ? '请通过单独渠道将密码告知接收方,不要和文件一起发送。' + : '密码已加密保护。如需通过公网传输,建议设置文件保护密码。'); + return ( - onPasswordChange(event.target.value)} - /> + {isExportMode ? ( +
+ onIncludeSecretsChange?.(event.target.checked)} + > + 导出连接密码 + + onUseFilePasswordChange?.(event.target.checked)} + > + 设置文件保护密码 + +
+ ) : null} + {showFilePasswordInput ? ( + onPasswordChange(event.target.value)} + /> + ) : null} + {isExportMode ? ( + + {helperText} + + ) : null} {error ? ( {error} diff --git a/frontend/src/main.browserMock.test.ts b/frontend/src/main.browserMock.test.ts index 9d18204..7802732 100644 --- a/frontend/src/main.browserMock.test.ts +++ b/frontend/src/main.browserMock.test.ts @@ -51,8 +51,8 @@ const importMain = async () => { app?: { App?: { ImportConfigFile: () => Promise<{ success: boolean; message?: string }>; - ImportConnectionsPayload: (raw: string) => Promise; - ExportConnectionsPackage: () => Promise<{ success: boolean; message?: string }>; + ImportConnectionsPayload: (raw: string, password?: string) => Promise; + ExportConnectionsPackage: (options?: { includeSecrets?: boolean; filePassword?: string }) => Promise<{ success: boolean; message?: string }>; }; }; }; @@ -83,7 +83,7 @@ describe('main browser mock', () => { success: false, message: '已取消', }); - await expect(app!.ExportConnectionsPackage()).resolves.toEqual({ + await expect(app!.ExportConnectionsPackage({ includeSecrets: true, filePassword: '' })).resolves.toEqual({ success: false, message: '浏览器 mock 不支持恢复包导出', }); diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index fea428b..392cd43 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -123,7 +123,7 @@ if (typeof window !== 'undefined' && !(window as any).go) { OpenDownloadedUpdateDirectory: async () => ({ success: false }), InstallUpdateAndRestart: async () => ({ success: false }), ImportConfigFile: async () => ({ success: false, message: '已取消' }), - ImportConnectionsPayload: async (raw: string) => { + ImportConnectionsPayload: async (raw: string, _password?: string) => { try { const parsed = JSON.parse(raw); if (Array.isArray(parsed)) { @@ -134,7 +134,7 @@ if (typeof window !== 'undefined' && !(window as any).go) { } throw new Error('浏览器 mock 不支持恢复包导入,仅支持历史 JSON 连接数组'); }, - ExportConnectionsPackage: async () => ({ success: false, message: '浏览器 mock 不支持恢复包导出' }), + ExportConnectionsPackage: async (_options?: { includeSecrets?: boolean; filePassword?: string }) => ({ success: false, message: '浏览器 mock 不支持恢复包导出' }), ExportData: async () => ({ success: false }), GetGlobalProxyConfig: async () => ({ success: true, data: cloneBrowserMockValue(mockGlobalProxy) }), SaveGlobalProxy: async (input: any) => saveMockGlobalProxy(input), diff --git a/frontend/src/utils/connectionExport.test.ts b/frontend/src/utils/connectionExport.test.ts index d4f9720..d18541a 100644 --- a/frontend/src/utils/connectionExport.test.ts +++ b/frontend/src/utils/connectionExport.test.ts @@ -2,13 +2,64 @@ import { describe, expect, it } from 'vitest'; import { detectConnectionImportKind, + isConnectionPackagePasswordRequiredError, isConnectionPackageExportCanceled, resolveConnectionPackageExportResult, normalizeConnectionPackagePassword, } from './connectionExport'; describe('connectionExport', () => { - it('detects encrypted packages by gonavi envelope kind', () => { + it('detects v2 app-managed packages', () => { + expect(detectConnectionImportKind(JSON.stringify({ + v: 2, + kind: 'gonavi_connection_package', + p: 1, + exportedAt: '2026-04-11T21:00:00Z', + connections: [], + }))).toBe('app-managed-package'); + }); + + it('detects v2 encrypted packages', () => { + expect(detectConnectionImportKind(JSON.stringify({ + v: 2, + kind: 'gonavi_connection_package', + p: 2, + kdf: { + n: 'a2id', + m: 65536, + t: 3, + l: 4, + s: 'c2FsdA==', + }, + nc: 'bm9uY2Utbm9uY2U=', + d: 'encrypted-data', + }))).toBe('encrypted-package'); + }); + + it('rejects malformed v2 app-managed packages without connections array', () => { + expect(detectConnectionImportKind(JSON.stringify({ + v: 2, + kind: 'gonavi_connection_package', + p: 1, + exportedAt: '2026-04-11T21:00:00Z', + }))).toBe('invalid'); + }); + + it('rejects malformed v2 encrypted packages without protected payload fields', () => { + expect(detectConnectionImportKind(JSON.stringify({ + v: 2, + kind: 'gonavi_connection_package', + p: 2, + kdf: { + n: 'a2id', + m: 65536, + t: 3, + l: 4, + }, + }))).toBe('invalid'); + }); + + it('detects v1 encrypted packages by gonavi envelope kind', () => { expect(detectConnectionImportKind(JSON.stringify({ schemaVersion: 1, kind: 'gonavi_connection_package', @@ -39,6 +90,15 @@ describe('connectionExport', () => { it('returns invalid for malformed or unsupported content', () => { expect(detectConnectionImportKind('{not-json}')).toBe('invalid'); + expect(detectConnectionImportKind(JSON.stringify({ + v: 2, + kind: 'gonavi_connection_package', + p: 0, + }))).toBe('invalid'); + expect(detectConnectionImportKind(JSON.stringify({ + v: 2, + kind: 'gonavi_connection_package', + }))).toBe('invalid'); expect(detectConnectionImportKind(JSON.stringify({ kind: 'gonavi_connection_package', payload: 'encrypted-data', @@ -60,6 +120,14 @@ describe('connectionExport', () => { expect(normalizeConnectionPackagePassword('\n\t \t')).toBe(''); }); + it('recognizes backend password-required errors for protected packages', () => { + expect(isConnectionPackagePasswordRequiredError(new Error('恢复包密码不能为空'))).toBe(true); + expect(isConnectionPackagePasswordRequiredError({ message: '恢复包密码不能为空' })).toBe(true); + expect(isConnectionPackagePasswordRequiredError('恢复包密码不能为空')).toBe(true); + expect(isConnectionPackagePasswordRequiredError(new Error('文件密码错误或文件已损坏'))).toBe(false); + expect(isConnectionPackagePasswordRequiredError(undefined)).toBe(false); + }); + it('treats export cancel as a non-error backend result', () => { expect(isConnectionPackageExportCanceled({ success: false, message: '已取消' })).toBe(true); expect(isConnectionPackageExportCanceled({ success: false, message: '导出失败' })).toBe(false); @@ -71,6 +139,8 @@ describe('connectionExport', () => { const staleDialog = { open: true, mode: 'export' as const, + includeSecrets: true, + useFilePassword: false, password: ' secret-pass ', error: '上一次失败', confirmLoading: false, @@ -83,12 +153,16 @@ describe('connectionExport', () => { expect((canceledResult.nextDialog as (current: typeof staleDialog) => typeof staleDialog)({ open: false, mode: 'export', + includeSecrets: true, + useFilePassword: false, password: 'secret-pass', error: '更新后的错误', confirmLoading: true, })).toEqual({ open: false, mode: 'export', + includeSecrets: true, + useFilePassword: false, password: 'secret-pass', error: '', confirmLoading: false, diff --git a/frontend/src/utils/connectionExport.ts b/frontend/src/utils/connectionExport.ts index 13ff987..22de9eb 100644 --- a/frontend/src/utils/connectionExport.ts +++ b/frontend/src/utils/connectionExport.ts @@ -1,9 +1,11 @@ import type { ConnectionConfig, SavedConnection } from '../types'; -export type ConnectionImportKind = 'encrypted-package' | 'legacy-json' | 'invalid'; +export type ConnectionImportKind = 'app-managed-package' | 'encrypted-package' | 'legacy-json' | 'invalid'; export type ConnectionPackageDialogSnapshot = { open: boolean; mode: 'export' | 'import'; + includeSecrets: boolean; + useFilePassword: boolean; password: string; error: string; confirmLoading: boolean; @@ -20,7 +22,11 @@ export type ConnectionPackageExportResult = type JsonObject = Record; const CONNECTION_PACKAGE_KIND = 'gonavi_connection_package'; +const CONNECTION_PACKAGE_SCHEMA_VERSION_V2 = 2; +const CONNECTION_PACKAGE_PROTECTION_APP_MANAGED = 1; +const CONNECTION_PACKAGE_PROTECTION_FILE_PASSWORD = 2; const CANCELED_MESSAGE = '已取消'; +const CONNECTION_PACKAGE_PASSWORD_REQUIRED_MESSAGE = '恢复包密码不能为空'; const isJsonObject = (value: unknown): value is JsonObject => ( typeof value === 'object' && value !== null && !Array.isArray(value) @@ -45,6 +51,36 @@ const isConnectionPackageEnvelope = (value: unknown): value is JsonObject => ( && typeof value.payload === 'string' ); +const isConnectionPackageV2Envelope = (value: unknown): value is JsonObject => ( + isJsonObject(value) + && value.kind === CONNECTION_PACKAGE_KIND + && value.v === CONNECTION_PACKAGE_SCHEMA_VERSION_V2 + && typeof value.p === 'number' +); + +const isConnectionPackageKDFV2 = (value: unknown): value is JsonObject => ( + isJsonObject(value) + && typeof value.n === 'string' + && typeof value.m === 'number' + && typeof value.t === 'number' + && typeof value.l === 'number' + && typeof value.s === 'string' +); + +const isConnectionPackageV2AppManagedEnvelope = (value: unknown): value is JsonObject => ( + isConnectionPackageV2Envelope(value) + && value.p === CONNECTION_PACKAGE_PROTECTION_APP_MANAGED + && Array.isArray(value.connections) +); + +const isConnectionPackageV2ProtectedEnvelope = (value: unknown): value is JsonObject => ( + isConnectionPackageV2Envelope(value) + && value.p === CONNECTION_PACKAGE_PROTECTION_FILE_PASSWORD + && isConnectionPackageKDFV2(value.kdf) + && typeof value.nc === 'string' + && typeof value.d === 'string' +); + const isLegacyConnectionConfig = (value: unknown): value is JsonObject => ( isJsonObject(value) && typeof value.type === 'string' @@ -72,6 +108,18 @@ const parseConnectionImportRaw = (raw: unknown): unknown => { export const detectConnectionImportKind = (raw: unknown): ConnectionImportKind => { const parsed = parseConnectionImportRaw(raw); + if (isConnectionPackageV2AppManagedEnvelope(parsed)) { + return 'app-managed-package'; + } + + if (isConnectionPackageV2ProtectedEnvelope(parsed)) { + return 'encrypted-package'; + } + + if (isConnectionPackageV2Envelope(parsed)) { + return 'invalid'; + } + if (Array.isArray(parsed) && parsed.every((item) => isLegacyConnectionItem(item))) { return 'legacy-json'; } @@ -85,6 +133,20 @@ export const detectConnectionImportKind = (raw: unknown): ConnectionImportKind = export const normalizeConnectionPackagePassword = (value: string): string => value.trim(); +export const isConnectionPackagePasswordRequiredError = (value: unknown): boolean => { + if (typeof value === 'string') { + return value.trim() === CONNECTION_PACKAGE_PASSWORD_REQUIRED_MESSAGE; + } + + if (value instanceof Error) { + return value.message.trim() === CONNECTION_PACKAGE_PASSWORD_REQUIRED_MESSAGE; + } + + return isJsonObject(value) + && typeof value.message === 'string' + && value.message.trim() === CONNECTION_PACKAGE_PASSWORD_REQUIRED_MESSAGE; +}; + export const isConnectionPackageExportCanceled = (result: unknown): boolean => ( isJsonObject(result) && result.success === false diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index ab7f7d3..4f09f28 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -75,7 +75,7 @@ export function DuplicateConnection(arg1:string):Promise; -export function ExportConnectionsPackage(arg1:string):Promise; +export function ExportConnectionsPackage(arg1:app.ConnectionExportOptions):Promise; export function ExportData(arg1:Array>,arg2:Array,arg3:string,arg4:string):Promise; diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index ca54253..9c899d1 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -181,6 +181,20 @@ export namespace ai { export namespace app { + export class ConnectionExportOptions { + includeSecrets: boolean; + filePassword?: string; + + static createFrom(source: any = {}) { + return new ConnectionExportOptions(source); + } + + constructor(source: any = {}) { + if ('string' === typeof source) source = JSON.parse(source); + this.includeSecrets = source["includeSecrets"]; + this.filePassword = source["filePassword"]; + } + } export class SecurityUpdateOptions { allowPartial?: boolean; writeBackup?: boolean; diff --git a/internal/app/connection_package_appkey.go b/internal/app/connection_package_appkey.go new file mode 100644 index 0000000..eb9e45f --- /dev/null +++ b/internal/app/connection_package_appkey.go @@ -0,0 +1,196 @@ +package app + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "strings" + "sync" + + "golang.org/x/crypto/argon2" +) + +const ( + connectionPackageAppKeyPurpose = "gonavi-export-key-v2" + connectionPackageAppKeyFallbackSeed = "gonavi-connection-package-v2-seed" + connectionPackageAppKeyFallbackSalt = "gonavi-connection-package-v2-salt" +) + +var ( + connectionPackageAppKeySeed string + connectionPackageAppKeySalt string + + connectionPackageAppKeyMu sync.Mutex + connectionPackageAppKeyCached []byte +) + +func deriveConnectionPackageAppKey() ([]byte, error) { + connectionPackageAppKeyMu.Lock() + defer connectionPackageAppKeyMu.Unlock() + + if len(connectionPackageAppKeyCached) == connectionPackageAES256KeyBytes { + return append([]byte(nil), connectionPackageAppKeyCached...), nil + } + + seed := strings.TrimSpace(connectionPackageAppKeySeed) + if seed == "" { + seed = connectionPackageAppKeyFallbackSeed + } + saltValue := strings.TrimSpace(connectionPackageAppKeySalt) + if saltValue == "" { + saltValue = connectionPackageAppKeyFallbackSalt + } + + mac := hmac.New(sha256.New, []byte(seed)) + if _, err := mac.Write([]byte(connectionPackageAppKeyPurpose)); err != nil { + return nil, err + } + intermediate := mac.Sum(nil) + + saltHash := sha256.Sum256([]byte(saltValue)) + key := argon2.IDKey( + intermediate, + saltHash[:connectionPackageSaltBytes], + connectionPackageKDFDefaultTimeCost, + connectionPackageKDFDefaultMemoryKiB, + connectionPackageKDFDefaultParallelism, + connectionPackageAES256KeyBytes, + ) + connectionPackageAppKeyCached = append([]byte(nil), key...) + return append([]byte(nil), key...), nil +} + +func resetConnectionPackageAppKeyCache() { + connectionPackageAppKeyMu.Lock() + defer connectionPackageAppKeyMu.Unlock() + connectionPackageAppKeyCached = nil +} + +func encryptSecretField(appKey []byte, plaintext string, aad string) (string, error) { + if plaintext == "" { + return "", nil + } + + aead, err := newConnectionPackageAEAD(appKey) + if err != nil { + return "", err + } + + nonce := make([]byte, connectionPackageNonceBytes) + if _, err := rand.Read(nonce); err != nil { + return "", err + } + + ciphertext := aead.Seal(nil, nonce, []byte(plaintext), []byte(aad)) + encoded := make([]byte, 0, len(nonce)+len(ciphertext)) + encoded = append(encoded, nonce...) + encoded = append(encoded, ciphertext...) + return base64.StdEncoding.EncodeToString(encoded), nil +} + +func decryptSecretField(appKey []byte, encrypted string, aad string) (string, error) { + if encrypted == "" { + return "", nil + } + + raw, err := base64.StdEncoding.DecodeString(encrypted) + if err != nil { + return "", err + } + if len(raw) <= connectionPackageNonceBytes { + return "", errors.New("invalid encrypted secret") + } + + aead, err := newConnectionPackageAEAD(appKey) + if err != nil { + return "", err + } + + plain, err := aead.Open(nil, raw[:connectionPackageNonceBytes], raw[connectionPackageNonceBytes:], []byte(aad)) + if err != nil { + return "", err + } + return string(plain), nil +} + +func encryptSecretBundle(appKey []byte, bundle connectionSecretBundle, connectionID string) (connectionSecretBundle, error) { + var encrypted connectionSecretBundle + var err error + + encrypted.Password, err = encryptSecretField(appKey, bundle.Password, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + encrypted.SSHPassword, err = encryptSecretField(appKey, bundle.SSHPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + encrypted.ProxyPassword, err = encryptSecretField(appKey, bundle.ProxyPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + encrypted.HTTPTunnelPassword, err = encryptSecretField(appKey, bundle.HTTPTunnelPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + encrypted.MySQLReplicaPassword, err = encryptSecretField(appKey, bundle.MySQLReplicaPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + encrypted.MongoReplicaPassword, err = encryptSecretField(appKey, bundle.MongoReplicaPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + encrypted.OpaqueURI, err = encryptSecretField(appKey, bundle.OpaqueURI, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + encrypted.OpaqueDSN, err = encryptSecretField(appKey, bundle.OpaqueDSN, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + + return encrypted, nil +} + +func decryptSecretBundle(appKey []byte, bundle connectionSecretBundle, connectionID string) (connectionSecretBundle, error) { + var decrypted connectionSecretBundle + var err error + + decrypted.Password, err = decryptSecretField(appKey, bundle.Password, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + decrypted.SSHPassword, err = decryptSecretField(appKey, bundle.SSHPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + decrypted.ProxyPassword, err = decryptSecretField(appKey, bundle.ProxyPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + decrypted.HTTPTunnelPassword, err = decryptSecretField(appKey, bundle.HTTPTunnelPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + decrypted.MySQLReplicaPassword, err = decryptSecretField(appKey, bundle.MySQLReplicaPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + decrypted.MongoReplicaPassword, err = decryptSecretField(appKey, bundle.MongoReplicaPassword, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + decrypted.OpaqueURI, err = decryptSecretField(appKey, bundle.OpaqueURI, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + decrypted.OpaqueDSN, err = decryptSecretField(appKey, bundle.OpaqueDSN, connectionID) + if err != nil { + return connectionSecretBundle{}, err + } + + return decrypted, nil +} diff --git a/internal/app/connection_package_appkey_test.go b/internal/app/connection_package_appkey_test.go new file mode 100644 index 0000000..619bb84 --- /dev/null +++ b/internal/app/connection_package_appkey_test.go @@ -0,0 +1,141 @@ +package app + +import ( + "encoding/base64" + "reflect" + "strings" + "testing" +) + +func TestDeriveConnectionPackageAppKeyIsStable(t *testing.T) { + originalSeed := connectionPackageAppKeySeed + originalSalt := connectionPackageAppKeySalt + t.Cleanup(func() { + connectionPackageAppKeySeed = originalSeed + connectionPackageAppKeySalt = originalSalt + resetConnectionPackageAppKeyCache() + }) + + connectionPackageAppKeySeed = "unit-test-seed" + connectionPackageAppKeySalt = "unit-test-salt" + resetConnectionPackageAppKeyCache() + + first, err := deriveConnectionPackageAppKey() + if err != nil { + t.Fatalf("deriveConnectionPackageAppKey returned error: %v", err) + } + second, err := deriveConnectionPackageAppKey() + if err != nil { + t.Fatalf("deriveConnectionPackageAppKey returned error on second call: %v", err) + } + if len(first) != connectionPackageAES256KeyBytes { + t.Fatalf("expected %d-byte app key, got %d", connectionPackageAES256KeyBytes, len(first)) + } + if !reflect.DeepEqual(first, second) { + t.Fatal("expected deriveConnectionPackageAppKey to be stable across repeated calls") + } + + connectionPackageAppKeySeed = "unit-test-seed-rotated" + resetConnectionPackageAppKeyCache() + rotated, err := deriveConnectionPackageAppKey() + if err != nil { + t.Fatalf("deriveConnectionPackageAppKey returned error after seed rotation: %v", err) + } + if reflect.DeepEqual(first, rotated) { + t.Fatal("expected different injected seed to produce a different app key") + } +} + +func TestEncryptSecretFieldRoundTrip(t *testing.T) { + appKey := []byte("0123456789abcdef0123456789abcdef") + + encrypted, err := encryptSecretField(appKey, "super-secret", "conn-1") + if err != nil { + t.Fatalf("encryptSecretField returned error: %v", err) + } + if strings.HasPrefix(encrypted, "ENC:") { + t.Fatalf("encrypted field must not carry ENC prefix, got %q", encrypted) + } + raw, err := base64.StdEncoding.DecodeString(encrypted) + if err != nil { + t.Fatalf("encrypted field must be base64, got error: %v", err) + } + if len(raw) <= connectionPackageNonceBytes { + t.Fatalf("expected nonce+ciphertext output, got %d bytes", len(raw)) + } + + decrypted, err := decryptSecretField(appKey, encrypted, "conn-1") + if err != nil { + t.Fatalf("decryptSecretField returned error: %v", err) + } + if decrypted != "super-secret" { + t.Fatalf("round-trip mismatch: got %q", decrypted) + } +} + +func TestDecryptSecretFieldRejectsAADMismatch(t *testing.T) { + appKey := []byte("0123456789abcdef0123456789abcdef") + + encrypted, err := encryptSecretField(appKey, "super-secret", "conn-1") + if err != nil { + t.Fatalf("encryptSecretField returned error: %v", err) + } + + if _, err := decryptSecretField(appKey, encrypted, "conn-2"); err == nil { + t.Fatal("expected decryptSecretField to reject mismatched AAD") + } +} + +func TestEncryptSecretBundleRoundTripAndAADBinding(t *testing.T) { + appKey := []byte("0123456789abcdef0123456789abcdef") + plain := connectionSecretBundle{ + Password: "primary-secret", + SSHPassword: "ssh-secret", + ProxyPassword: "proxy-secret", + HTTPTunnelPassword: "http-secret", + MySQLReplicaPassword: "mysql-secret", + MongoReplicaPassword: "mongo-secret", + OpaqueURI: "postgres://user:pass@db.local/app", + OpaqueDSN: "server=db.local;password=secret", + } + + encrypted, err := encryptSecretBundle(appKey, plain, "conn-1") + if err != nil { + t.Fatalf("encryptSecretBundle returned error: %v", err) + } + + for name, value := range map[string]string{ + "password": encrypted.Password, + "sshPassword": encrypted.SSHPassword, + "proxyPassword": encrypted.ProxyPassword, + "httpTunnelPassword": encrypted.HTTPTunnelPassword, + "mysqlReplicaPassword": encrypted.MySQLReplicaPassword, + "mongoReplicaPassword": encrypted.MongoReplicaPassword, + "opaqueURI": encrypted.OpaqueURI, + "opaqueDSN": encrypted.OpaqueDSN, + } { + if value == "" { + t.Fatalf("expected encrypted %s field to be populated", name) + } + if strings.HasPrefix(value, "ENC:") { + t.Fatalf("encrypted %s field must not carry ENC prefix", name) + } + if value == plain.Password || value == plain.SSHPassword || value == plain.ProxyPassword || + value == plain.HTTPTunnelPassword || value == plain.MySQLReplicaPassword || value == plain.MongoReplicaPassword || + value == plain.OpaqueURI || value == plain.OpaqueDSN { + t.Fatalf("expected encrypted %s field to differ from plaintext", name) + } + } + + decrypted, err := decryptSecretBundle(appKey, encrypted, "conn-1") + if err != nil { + t.Fatalf("decryptSecretBundle returned error: %v", err) + } + if !reflect.DeepEqual(decrypted, plain) { + t.Fatalf("bundle round-trip mismatch: got=%+v want=%+v", decrypted, plain) + } + + if _, err := decryptSecretBundle(appKey, encrypted, "conn-2"); err == nil { + t.Fatal("expected decryptSecretBundle to reject mismatched connection AAD") + } +} diff --git a/internal/app/connection_package_crypto.go b/internal/app/connection_package_crypto.go index 8337e2e..1b7b748 100644 --- a/internal/app/connection_package_crypto.go +++ b/internal/app/connection_package_crypto.go @@ -26,6 +26,14 @@ type connectionPackageAAD struct { Nonce string `json:"nonce"` } +type connectionPackageAADV2Protected struct { + V int `json:"v"` + Kind string `json:"kind"` + P int `json:"p"` + KDF connectionPackageKDFSpecV2 `json:"kdf"` + NC string `json:"nc"` +} + func encryptConnectionPackage(payload connectionPackagePayload, password string) (connectionPackageFile, error) { normalizedPassword := normalizeConnectionPackagePassword(password) if normalizedPassword == "" { @@ -108,7 +116,162 @@ func isConnectionPackageEnvelope(raw string) bool { if err != nil { return false } - return file.Kind == connectionPackageKind + return validateConnectionPackageFileHeader(file) == nil +} + +func encryptConnectionPackageV2AppManaged(payload connectionPackagePayload) (connectionPackageFileV2, error) { + appKey, err := deriveConnectionPackageAppKey() + if err != nil { + return connectionPackageFileV2{}, err + } + + encryptedPayload, err := encryptConnectionPackagePayloadSecrets(payload, appKey) + if err != nil { + return connectionPackageFileV2{}, err + } + + return connectionPackageFileV2{ + V: connectionPackageSchemaVersionV2, + Kind: connectionPackageKind, + P: connectionPackageProtectionAppManaged, + ExportedAt: encryptedPayload.ExportedAt, + Connections: encryptedPayload.Connections, + }, nil +} + +func encryptConnectionPackageV2Protected(payload connectionPackagePayload, password string) (connectionPackageFileV2Protected, error) { + normalizedPassword := normalizeConnectionPackagePassword(password) + if normalizedPassword == "" { + return connectionPackageFileV2Protected{}, errConnectionPackagePasswordRequired + } + + appKey, err := deriveConnectionPackageAppKey() + if err != nil { + return connectionPackageFileV2Protected{}, err + } + encryptedPayload, err := encryptConnectionPackagePayloadSecrets(payload, appKey) + if err != nil { + return connectionPackageFileV2Protected{}, err + } + + plain, err := json.Marshal(encryptedPayload) + if err != nil { + return connectionPackageFileV2Protected{}, err + } + + salt := make([]byte, connectionPackageSaltBytes) + if _, err := rand.Read(salt); err != nil { + return connectionPackageFileV2Protected{}, err + } + nonce := make([]byte, connectionPackageNonceBytes) + if _, err := rand.Read(nonce); err != nil { + return connectionPackageFileV2Protected{}, err + } + + file := connectionPackageFileV2Protected{ + V: connectionPackageSchemaVersionV2, + Kind: connectionPackageKind, + P: connectionPackageProtectionPasswordProtected, + KDF: defaultConnectionPackageKDFSpecV2(), + NC: base64.StdEncoding.EncodeToString(nonce), + } + file.KDF.S = base64.StdEncoding.EncodeToString(salt) + + key, err := deriveConnectionPackageKeyV2(normalizedPassword, file.KDF) + if err != nil { + return connectionPackageFileV2Protected{}, err + } + aad, err := marshalConnectionPackageAADV2Protected(file) + if err != nil { + return connectionPackageFileV2Protected{}, err + } + aead, err := newConnectionPackageAEAD(key) + if err != nil { + return connectionPackageFileV2Protected{}, err + } + + ciphertext := aead.Seal(nil, nonce, plain, aad) + if len(ciphertext) > connectionPackageMaxCiphertextBytes { + return connectionPackageFileV2Protected{}, errConnectionPackagePayloadTooLarge + } + file.D = base64.StdEncoding.EncodeToString(ciphertext) + if len(file.D) > connectionPackageMaxPayloadBase64Bytes { + return connectionPackageFileV2Protected{}, errConnectionPackagePayloadTooLarge + } + return file, nil +} + +func decryptConnectionPackageV2AppManaged(file connectionPackageFileV2) (connectionPackagePayload, error) { + if err := validateConnectionPackageFileHeaderV2AppManaged(file); err != nil { + return connectionPackagePayload{}, err + } + + appKey, err := deriveConnectionPackageAppKey() + if err != nil { + return connectionPackagePayload{}, err + } + + payload, err := decryptConnectionPackagePayloadSecrets(connectionPackagePayload{ + ExportedAt: file.ExportedAt, + Connections: file.Connections, + }, appKey) + if err != nil { + return connectionPackagePayload{}, errConnectionPackageDecryptFailed + } + return payload, nil +} + +func decryptConnectionPackageV2Protected(file connectionPackageFileV2Protected, password string) (connectionPackagePayload, error) { + normalizedPassword := normalizeConnectionPackagePassword(password) + if normalizedPassword == "" { + return connectionPackagePayload{}, errConnectionPackagePasswordRequired + } + if err := validateConnectionPackageFileHeaderV2Protected(file); err != nil { + return connectionPackagePayload{}, err + } + + plain, err := decryptConnectionPackageV2ProtectedPlaintext(file, normalizedPassword) + if err != nil { + if errors.Is(err, errConnectionPackagePayloadTooLarge) { + return connectionPackagePayload{}, err + } + return connectionPackagePayload{}, errConnectionPackageDecryptFailed + } + + var encryptedPayload connectionPackagePayload + if err := json.Unmarshal(plain, &encryptedPayload); err != nil { + return connectionPackagePayload{}, errConnectionPackageDecryptFailed + } + + appKey, err := deriveConnectionPackageAppKey() + if err != nil { + return connectionPackagePayload{}, err + } + payload, err := decryptConnectionPackagePayloadSecrets(encryptedPayload, appKey) + if err != nil { + return connectionPackagePayload{}, errConnectionPackageDecryptFailed + } + return payload, nil +} + +func isConnectionPackageV2AppManaged(raw string) bool { + header, err := decodeConnectionPackageV2Header(raw) + if err != nil { + return false + } + return header.Kind == connectionPackageKind && + header.V == connectionPackageSchemaVersionV2 && + header.P == connectionPackageProtectionAppManaged +} + +func isConnectionPackageV2Protected(raw string) bool { + header, err := decodeConnectionPackageV2Header(raw) + if err != nil { + return false + } + return header.Kind == connectionPackageKind && + header.V == connectionPackageSchemaVersionV2 && + header.P == connectionPackageProtectionPasswordProtected } func encodeConnectionPackageEnvelope(file connectionPackageFile) (string, error) { @@ -127,6 +290,22 @@ func decodeConnectionPackageEnvelope(raw string) (connectionPackageFile, error) return file, nil } +func decodeConnectionPackageV2Header(raw string) (struct { + V int `json:"v"` + Kind string `json:"kind"` + P int `json:"p"` +}, error) { + var header struct { + V int `json:"v"` + Kind string `json:"kind"` + P int `json:"p"` + } + if err := json.Unmarshal([]byte(raw), &header); err != nil { + return header, err + } + return header, nil +} + func decryptConnectionPackagePlaintext(file connectionPackageFile, password string) ([]byte, error) { if err := validateConnectionPackageFileHeader(file); err != nil { return nil, err @@ -191,6 +370,30 @@ func deriveConnectionPackageKey(password string, spec connectionPackageKDFSpec) return key, nil } +func deriveConnectionPackageKeyV2(password string, spec connectionPackageKDFSpecV2) ([]byte, error) { + if password == "" { + return nil, errConnectionPackagePasswordRequired + } + if err := validateConnectionPackageKDFSpecV2(spec); err != nil { + return nil, err + } + + salt, err := base64.StdEncoding.DecodeString(spec.S) + if err != nil || len(salt) == 0 { + return nil, errors.New("invalid salt") + } + + key := argon2.IDKey( + []byte(password), + salt, + spec.T, + spec.M, + spec.L, + connectionPackageAES256KeyBytes, + ) + return key, nil +} + func marshalConnectionPackageAAD(file connectionPackageFile) ([]byte, error) { aad := connectionPackageAAD{ SchemaVersion: file.SchemaVersion, @@ -202,6 +405,16 @@ func marshalConnectionPackageAAD(file connectionPackageFile) ([]byte, error) { return json.Marshal(aad) } +func marshalConnectionPackageAADV2Protected(file connectionPackageFileV2Protected) ([]byte, error) { + return json.Marshal(connectionPackageAADV2Protected{ + V: file.V, + Kind: file.Kind, + P: file.P, + KDF: file.KDF, + NC: file.NC, + }) +} + func newConnectionPackageAEAD(key []byte) (cipher.AEAD, error) { block, err := aes.NewCipher(key) if err != nil { @@ -225,6 +438,34 @@ func validateConnectionPackageFileHeader(file connectionPackageFile) error { } } +func validateConnectionPackageFileHeaderV2AppManaged(file connectionPackageFileV2) error { + switch { + case file.V != connectionPackageSchemaVersionV2: + return errConnectionPackageUnsupported + case strings.TrimSpace(file.Kind) != connectionPackageKind: + return errConnectionPackageUnsupported + case file.P != connectionPackageProtectionAppManaged: + return errConnectionPackageUnsupported + default: + return nil + } +} + +func validateConnectionPackageFileHeaderV2Protected(file connectionPackageFileV2Protected) error { + switch { + case file.V != connectionPackageSchemaVersionV2: + return errConnectionPackageUnsupported + case strings.TrimSpace(file.Kind) != connectionPackageKind: + return errConnectionPackageUnsupported + case file.P != connectionPackageProtectionPasswordProtected: + return errConnectionPackageUnsupported + case validateConnectionPackageKDFSpecV2(file.KDF) != nil: + return errConnectionPackageUnsupported + default: + return nil + } +} + func validateConnectionPackageKDFSpec(spec connectionPackageKDFSpec) error { switch { case strings.TrimSpace(spec.Name) != connectionPackageKDFName: @@ -241,3 +482,101 @@ func validateConnectionPackageKDFSpec(spec connectionPackageKDFSpec) error { return nil } } + +func validateConnectionPackageKDFSpecV2(spec connectionPackageKDFSpecV2) error { + switch { + case strings.TrimSpace(spec.N) != connectionPackageKDFNameV2: + return errConnectionPackageUnsupported + case spec.M == 0 || spec.T == 0 || spec.L == 0: + return errConnectionPackageUnsupported + case spec.M > connectionPackageKDFMaxMemoryKiB: + return errConnectionPackageUnsupported + case spec.T > connectionPackageKDFMaxTimeCost: + return errConnectionPackageUnsupported + case spec.L > connectionPackageKDFMaxParallelism: + return errConnectionPackageUnsupported + default: + return nil + } +} + +func decryptConnectionPackageV2ProtectedPlaintext(file connectionPackageFileV2Protected, password string) ([]byte, error) { + if err := validateConnectionPackageFileHeaderV2Protected(file); err != nil { + return nil, err + } + + nonce, err := base64.StdEncoding.DecodeString(file.NC) + if err != nil || len(nonce) != connectionPackageNonceBytes { + return nil, errors.New("invalid nonce") + } + if len(file.D) > connectionPackageMaxPayloadBase64Bytes { + return nil, errConnectionPackagePayloadTooLarge + } + ciphertext, err := base64.StdEncoding.DecodeString(file.D) + if err != nil || len(ciphertext) == 0 { + return nil, errors.New("invalid payload") + } + if len(ciphertext) > connectionPackageMaxCiphertextBytes { + return nil, errConnectionPackagePayloadTooLarge + } + + key, err := deriveConnectionPackageKeyV2(password, file.KDF) + if err != nil { + return nil, err + } + aad, err := marshalConnectionPackageAADV2Protected(file) + if err != nil { + return nil, err + } + aead, err := newConnectionPackageAEAD(key) + if err != nil { + return nil, err + } + + return aead.Open(nil, nonce, ciphertext, aad) +} + +func encryptConnectionPackagePayloadSecrets(payload connectionPackagePayload, appKey []byte) (connectionPackagePayload, error) { + encrypted := connectionPackagePayload{ + ExportedAt: payload.ExportedAt, + Connections: make([]connectionPackageItem, len(payload.Connections)), + } + + for index, item := range payload.Connections { + encryptedItem := item + bundle, err := encryptSecretBundle(appKey, item.Secrets, connectionPackageItemAAD(item)) + if err != nil { + return connectionPackagePayload{}, err + } + encryptedItem.Secrets = bundle + encrypted.Connections[index] = encryptedItem + } + + return encrypted, nil +} + +func decryptConnectionPackagePayloadSecrets(payload connectionPackagePayload, appKey []byte) (connectionPackagePayload, error) { + decrypted := connectionPackagePayload{ + ExportedAt: payload.ExportedAt, + Connections: make([]connectionPackageItem, len(payload.Connections)), + } + + for index, item := range payload.Connections { + decryptedItem := item + bundle, err := decryptSecretBundle(appKey, item.Secrets, connectionPackageItemAAD(item)) + if err != nil { + return connectionPackagePayload{}, err + } + decryptedItem.Secrets = bundle + decrypted.Connections[index] = decryptedItem + } + + return decrypted, nil +} + +func connectionPackageItemAAD(item connectionPackageItem) string { + if strings.TrimSpace(item.ID) != "" { + return item.ID + } + return item.Config.ID +} diff --git a/internal/app/connection_package_crypto_test.go b/internal/app/connection_package_crypto_test.go index 22ba2f1..57d3748 100644 --- a/internal/app/connection_package_crypto_test.go +++ b/internal/app/connection_package_crypto_test.go @@ -59,6 +59,186 @@ func TestConnectionPackageCryptoRoundTrip(t *testing.T) { } } +func TestConnectionPackageV2AppManagedRoundTrip(t *testing.T) { + payload := connectionPackagePayload{ + ExportedAt: "2026-04-11T12:00:00Z", + Connections: []connectionPackageItem{ + { + ID: "conn-v2-1", + Name: "app-managed", + Config: connection.ConnectionConfig{ + ID: "conn-v2-1", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Database: "app", + }, + Secrets: connectionSecretBundle{ + Password: "primary-secret", + SSHPassword: "ssh-secret", + OpaqueURI: "postgres://postgres:primary-secret@db.local/app", + }, + }, + }, + } + + file, err := encryptConnectionPackageV2AppManaged(payload) + if err != nil { + t.Fatalf("encryptConnectionPackageV2AppManaged returned error: %v", err) + } + if file.V != connectionPackageSchemaVersionV2 { + t.Fatalf("expected v2 schema, got %d", file.V) + } + if file.P != connectionPackageProtectionAppManaged { + t.Fatalf("expected p=1, got %d", file.P) + } + if len(file.Connections) != 1 { + t.Fatalf("expected 1 connection, got %d", len(file.Connections)) + } + if file.Connections[0].Secrets.Password == payload.Connections[0].Secrets.Password { + t.Fatal("expected p=1 secrets to stay encrypted in file") + } + + raw, err := json.Marshal(file) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + if !isConnectionPackageV2AppManaged(string(raw)) { + t.Fatal("expected raw v2 p=1 payload to be detected") + } + if isConnectionPackageEnvelope(string(raw)) { + t.Fatal("v2 p=1 payload must not be misclassified as v1 envelope") + } + rawString := string(raw) + for _, forbidden := range []string{ + "schemaVersion", + "cipher", + "protectionLevel", + "ENC:", + "primary-secret", + "ssh-secret", + "postgres://postgres:primary-secret@db.local/app", + } { + if strings.Contains(rawString, forbidden) { + t.Fatalf("v2 p=1 payload must not contain %q: %s", forbidden, rawString) + } + } + + got, err := decryptConnectionPackageV2AppManaged(file) + if err != nil { + t.Fatalf("decryptConnectionPackageV2AppManaged returned error: %v", err) + } + if !reflect.DeepEqual(got, payload) { + t.Fatalf("round-trip mismatch: got=%+v want=%+v", got, payload) + } +} + +func TestConnectionPackageV2ProtectedRoundTrip(t *testing.T) { + payload := connectionPackagePayload{ + ExportedAt: "2026-04-11T12:00:00Z", + Connections: []connectionPackageItem{ + { + ID: "conn-v2-2", + Name: "password-protected", + Config: connection.ConnectionConfig{ + ID: "conn-v2-2", + Type: "mysql", + Host: "db.local", + Port: 3306, + User: "root", + Database: "app", + }, + Secrets: connectionSecretBundle{ + Password: "primary-secret", + SSHPassword: "ssh-secret", + ProxyPassword: "proxy-secret", + HTTPTunnelPassword: "http-secret", + MySQLReplicaPassword: "mysql-secret", + MongoReplicaPassword: "mongo-secret", + OpaqueURI: "mysql://root:primary-secret@tcp(db.local:3306)/app", + OpaqueDSN: "root:primary-secret@tcp(db.local:3306)/app", + }, + }, + }, + } + + file, err := encryptConnectionPackageV2Protected(payload, "package-password") + if err != nil { + t.Fatalf("encryptConnectionPackageV2Protected returned error: %v", err) + } + if file.V != connectionPackageSchemaVersionV2 { + t.Fatalf("expected v2 schema, got %d", file.V) + } + if file.P != connectionPackageProtectionPasswordProtected { + t.Fatalf("expected p=2, got %d", file.P) + } + if file.D == "" || file.NC == "" { + t.Fatal("expected p=2 file to carry outer encrypted payload") + } + if strings.HasPrefix(file.D, "ENC:") { + t.Fatalf("outer payload must not carry ENC prefix, got %q", file.D) + } + + raw, err := json.Marshal(file) + if err != nil { + t.Fatalf("json.Marshal returned error: %v", err) + } + if !isConnectionPackageV2Protected(string(raw)) { + t.Fatal("expected raw v2 p=2 payload to be detected") + } + if isConnectionPackageEnvelope(string(raw)) { + t.Fatal("v2 p=2 payload must not be misclassified as v1 envelope") + } + rawString := string(raw) + for _, forbidden := range []string{ + "schemaVersion", + "cipher", + "protectionLevel", + "ENC:", + "primary-secret", + "ssh-secret", + } { + if strings.Contains(rawString, forbidden) { + t.Fatalf("v2 p=2 payload must not contain %q: %s", forbidden, rawString) + } + } + + got, err := decryptConnectionPackageV2Protected(file, "package-password") + if err != nil { + t.Fatalf("decryptConnectionPackageV2Protected returned error: %v", err) + } + if !reflect.DeepEqual(got, payload) { + t.Fatalf("round-trip mismatch: got=%+v want=%+v", got, payload) + } +} + +func TestConnectionPackageV2ProtectedWrongPasswordReturnsUnifiedError(t *testing.T) { + file, err := encryptConnectionPackageV2Protected(connectionPackagePayload{ + Connections: []connectionPackageItem{ + { + ID: "conn-v2-3", + Name: "wrong-password", + Config: connection.ConnectionConfig{ + ID: "conn-v2-3", + Type: "postgres", + }, + Secrets: connectionSecretBundle{ + Password: "primary-secret", + }, + }, + }, + }, "correct-password") + if err != nil { + t.Fatalf("encryptConnectionPackageV2Protected returned error: %v", err) + } + + _, err = decryptConnectionPackageV2Protected(file, "wrong-password") + if !errors.Is(err, errConnectionPackageDecryptFailed) { + t.Fatalf("wrong p=2 password should return unified error, got: %v", err) + } +} + func TestConnectionPackageDecryptWrongPasswordReturnsUnifiedError(t *testing.T) { payload := connectionPackagePayload{ Connections: []connectionPackageItem{ diff --git a/internal/app/connection_package_transfer.go b/internal/app/connection_package_transfer.go index 3fc8e31..efba78a 100644 --- a/internal/app/connection_package_transfer.go +++ b/internal/app/connection_package_transfer.go @@ -48,6 +48,34 @@ func (a *App) buildConnectionPackagePayload() (connectionPackagePayload, error) }, nil } +func (a *App) buildExportedConnectionPackage(options ConnectionExportOptions) ([]byte, error) { + payload, err := a.buildConnectionPackagePayload() + if err != nil { + return nil, err + } + + if !options.IncludeSecrets { + for index := range payload.Connections { + payload.Connections[index].Secrets = connectionSecretBundle{} + } + } + + normalizedPassword := normalizeConnectionPackagePassword(options.FilePassword) + if !options.IncludeSecrets || normalizedPassword == "" { + file, err := encryptConnectionPackageV2AppManaged(payload) + if err != nil { + return nil, err + } + return json.MarshalIndent(file, "", " ") + } + + file, err := encryptConnectionPackageV2Protected(payload, normalizedPassword) + if err != nil { + return nil, err + } + return json.MarshalIndent(file, "", " ") +} + func newSavedConnectionInputFromPackageItem(item connectionPackageItem) connection.SavedConnectionInput { id := strings.TrimSpace(item.ID) if id == "" { @@ -192,6 +220,30 @@ func (a *App) ImportConnectionsPayload(raw string, password string) ([]connectio return nil, errConnectionImportFileTooLarge } + if isConnectionPackageV2AppManaged(trimmed) { + var file connectionPackageFileV2 + if err := json.Unmarshal([]byte(trimmed), &file); err != nil { + return nil, errConnectionPackageUnsupported + } + payload, err := decryptConnectionPackageV2AppManaged(file) + if err != nil { + return nil, err + } + return a.importConnectionPackagePayload(payload) + } + + if isConnectionPackageV2Protected(trimmed) { + var file connectionPackageFileV2Protected + if err := json.Unmarshal([]byte(trimmed), &file); err != nil { + return nil, errConnectionPackageUnsupported + } + payload, err := decryptConnectionPackageV2Protected(file, password) + if err != nil { + return nil, err + } + return a.importConnectionPackagePayload(payload) + } + if isConnectionPackageEnvelope(trimmed) { var file connectionPackageFile if err := json.Unmarshal([]byte(trimmed), &file); err != nil { diff --git a/internal/app/connection_package_transfer_test.go b/internal/app/connection_package_transfer_test.go index 81d40ea..0c11d3e 100644 --- a/internal/app/connection_package_transfer_test.go +++ b/internal/app/connection_package_transfer_test.go @@ -83,6 +83,71 @@ func TestBuildConnectionPackagePayloadIncludesSecretBundles(t *testing.T) { } } +func TestBuildExportedConnectionPackageWithoutSecretsUsesV2AppManagedAndImportsWithoutPasswords(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: "conn-v2-no-secrets", + Name: "Primary", + Config: connection.ConnectionConfig{ + ID: "conn-v2-no-secrets", + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: "db-secret", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } + + raw, err := app.buildExportedConnectionPackage(ConnectionExportOptions{ + IncludeSecrets: false, + FilePassword: "ignored-password", + }) + if err != nil { + t.Fatalf("buildExportedConnectionPackage returned error: %v", err) + } + + var file connectionPackageFileV2 + if err := json.Unmarshal(raw, &file); err != nil { + t.Fatalf("json.Unmarshal returned error: %v", err) + } + if file.V != connectionPackageSchemaVersionV2 { + t.Fatalf("expected v2 package, got v=%d", file.V) + } + if file.P != connectionPackageProtectionAppManaged { + t.Fatalf("expected app-managed protection, got p=%d", file.P) + } + if strings.Contains(string(raw), `"secrets"`) { + t.Fatalf("expected exported JSON to omit secrets when IncludeSecrets=false, got %s", string(raw)) + } + + importApp := NewAppWithSecretStore(newFakeAppSecretStore()) + importApp.configDir = t.TempDir() + + imported, err := importApp.ImportConnectionsPayload(string(raw), "") + if err != nil { + t.Fatalf("ImportConnectionsPayload returned error: %v", err) + } + if len(imported) != 1 { + t.Fatalf("expected 1 imported connection, got %d", len(imported)) + } + if imported[0].HasPrimaryPassword { + t.Fatal("expected imported connection to keep empty password when secrets are excluded") + } + + resolved, err := importApp.resolveConnectionSecrets(imported[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "" { + t.Fatalf("expected imported password to be empty, got %q", resolved.Password) + } +} + func TestImportConnectionPackagePayloadOverwritesExistingSecrets(t *testing.T) { app := NewAppWithSecretStore(newFakeAppSecretStore()) app.configDir = t.TempDir() @@ -792,6 +857,93 @@ func TestImportConnectionsPayloadEnvelopeImportsAndOverwritesSecrets(t *testing. } } +func TestBuildExportedConnectionPackageWithSecretsUsesV2AppManagedEncryption(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + saveConnectionForPackageExport(t, app, "conn-v2-app", "app-secret") + + raw, err := app.buildExportedConnectionPackage(ConnectionExportOptions{ + IncludeSecrets: true, + }) + if err != nil { + t.Fatalf("buildExportedConnectionPackage returned error: %v", err) + } + + rawString := string(raw) + if !isConnectionPackageV2AppManaged(rawString) { + t.Fatalf("expected app-managed export, got %s", rawString) + } + for _, forbidden := range []string{ + "app-secret", + "schemaVersion", + "cipher", + "ENC:", + } { + if strings.Contains(rawString, forbidden) { + t.Fatalf("v2 p=1 export must not contain %q: %s", forbidden, rawString) + } + } + + imported, err := app.ImportConnectionsPayload(rawString, "") + if err != nil { + t.Fatalf("ImportConnectionsPayload returned error: %v", err) + } + if len(imported) != 1 { + t.Fatalf("expected 1 imported item, got %d", len(imported)) + } + + resolved, err := app.resolveConnectionSecrets(imported[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "app-secret" { + t.Fatalf("expected v2 p=1 import to restore password, got %q", resolved.Password) + } +} + +func TestBuildExportedConnectionPackageWithFilePasswordUsesV2ProtectedEnvelope(t *testing.T) { + app := NewAppWithSecretStore(newFakeAppSecretStore()) + app.configDir = t.TempDir() + saveConnectionForPackageExport(t, app, "conn-v2-protected", "protected-secret") + + raw, err := app.buildExportedConnectionPackage(ConnectionExportOptions{ + IncludeSecrets: true, + FilePassword: "package-password", + }) + if err != nil { + t.Fatalf("buildExportedConnectionPackage returned error: %v", err) + } + + rawString := string(raw) + if !isConnectionPackageV2Protected(rawString) { + t.Fatalf("expected password-protected export, got %s", rawString) + } + if strings.Contains(rawString, "protected-secret") { + t.Fatalf("v2 p=2 export must not contain plaintext secret: %s", rawString) + } + + _, err = app.ImportConnectionsPayload(rawString, "wrong-password") + if !errors.Is(err, errConnectionPackageDecryptFailed) { + t.Fatalf("wrong v2 p=2 password should return unified error, got %v", err) + } + + imported, err := app.ImportConnectionsPayload(rawString, "package-password") + if err != nil { + t.Fatalf("ImportConnectionsPayload returned error: %v", err) + } + if len(imported) != 1 { + t.Fatalf("expected 1 imported item, got %d", len(imported)) + } + + resolved, err := app.resolveConnectionSecrets(imported[0].Config) + if err != nil { + t.Fatalf("resolveConnectionSecrets returned error: %v", err) + } + if resolved.Password != "protected-secret" { + t.Fatalf("expected v2 p=2 import to restore password, got %q", resolved.Password) + } +} + func TestNormalizeConnectionPackageExportFilenameAddsExtension(t *testing.T) { filename := normalizeConnectionPackageExportFilename(`C:\tmp\connections`) if !strings.HasSuffix(filename, connectionPackageExtension) { @@ -816,6 +968,34 @@ func newFailOnPutSecretStore(failRef string) *failOnPutSecretStore { } } +func saveConnectionForPackageExport(t *testing.T, app *App, id string, primaryPassword string) { + t.Helper() + + _, err := app.SaveConnection(connection.SavedConnectionInput{ + ID: id, + Name: "Exported " + id, + Config: connection.ConnectionConfig{ + ID: id, + Type: "postgres", + Host: "db.local", + Port: 5432, + User: "postgres", + Password: primaryPassword, + UseSSH: true, + SSH: connection.SSHConfig{ + Host: "jump.local", + Port: 22, + User: "ops", + Password: "ssh-" + primaryPassword, + }, + URI: "postgres://postgres:" + primaryPassword + "@db.local/app", + }, + }) + if err != nil { + t.Fatalf("SaveConnection returned error: %v", err) + } +} + func (s *failOnPutSecretStore) Put(ref string, payload []byte) error { if ref == s.failRef { return errors.New("injected put failure") diff --git a/internal/app/connection_package_types.go b/internal/app/connection_package_types.go index 18b28b6..3794656 100644 --- a/internal/app/connection_package_types.go +++ b/internal/app/connection_package_types.go @@ -1,6 +1,7 @@ package app import ( + "encoding/json" "errors" "strings" @@ -8,11 +9,16 @@ import ( ) const ( - connectionPackageSchemaVersion = 1 - connectionPackageKind = "gonavi_connection_package" - connectionPackageCipher = "AES-256-GCM" - connectionPackageKDFName = "Argon2id" - connectionPackageExtension = ".gonavi-conn" + connectionPackageSchemaVersion = 1 + connectionPackageSchemaVersionV2 = 2 + connectionPackageKind = "gonavi_connection_package" + connectionPackageCipher = "AES-256-GCM" + connectionPackageKDFName = "Argon2id" + connectionPackageKDFNameV2 = "a2id" + connectionPackageExtension = ".gonavi-conn" + + connectionPackageProtectionAppManaged = 1 + connectionPackageProtectionPasswordProtected = 2 connectionPackageKDFDefaultMemoryKiB = 65536 connectionPackageKDFDefaultTimeCost = 3 @@ -53,6 +59,31 @@ type connectionPackageKDFSpec struct { Salt string `json:"salt"` } +type connectionPackageFileV2 struct { + V int `json:"v"` + Kind string `json:"kind"` + P int `json:"p"` + ExportedAt string `json:"exportedAt,omitempty"` + Connections []connectionPackageItem `json:"connections"` +} + +type connectionPackageFileV2Protected struct { + V int `json:"v"` + Kind string `json:"kind"` + P int `json:"p"` + KDF connectionPackageKDFSpecV2 `json:"kdf"` + NC string `json:"nc"` + D string `json:"d"` +} + +type connectionPackageKDFSpecV2 struct { + N string `json:"n"` + M uint32 `json:"m"` + T uint32 `json:"t"` + L uint8 `json:"l"` + S string `json:"s"` +} + type connectionPackagePayload struct { ExportedAt string `json:"exportedAt,omitempty"` Connections []connectionPackageItem `json:"connections"` @@ -69,6 +100,39 @@ type connectionPackageItem struct { Secrets connectionSecretBundle `json:"secrets,omitempty"` } +func (i connectionPackageItem) MarshalJSON() ([]byte, error) { + type connectionPackageItemJSON struct { + ID string `json:"id"` + Name string `json:"name"` + IncludeDatabases []string `json:"includeDatabases,omitempty"` + IncludeRedisDatabases []int `json:"includeRedisDatabases,omitempty"` + IconType string `json:"iconType,omitempty"` + IconColor string `json:"iconColor,omitempty"` + Config connection.ConnectionConfig `json:"config"` + Secrets *connectionSecretBundle `json:"secrets,omitempty"` + } + + item := connectionPackageItemJSON{ + ID: i.ID, + Name: i.Name, + IncludeDatabases: i.IncludeDatabases, + IncludeRedisDatabases: i.IncludeRedisDatabases, + IconType: i.IconType, + IconColor: i.IconColor, + Config: i.Config, + } + if i.Secrets.hasAny() { + secrets := i.Secrets + item.Secrets = &secrets + } + return json.Marshal(item) +} + +type ConnectionExportOptions struct { + IncludeSecrets bool `json:"includeSecrets"` + FilePassword string `json:"filePassword,omitempty"` +} + func defaultConnectionPackageKDFSpec() connectionPackageKDFSpec { return connectionPackageKDFSpec{ Name: connectionPackageKDFName, @@ -78,6 +142,15 @@ func defaultConnectionPackageKDFSpec() connectionPackageKDFSpec { } } +func defaultConnectionPackageKDFSpecV2() connectionPackageKDFSpecV2 { + return connectionPackageKDFSpecV2{ + N: connectionPackageKDFNameV2, + M: connectionPackageKDFDefaultMemoryKiB, + T: connectionPackageKDFDefaultTimeCost, + L: connectionPackageKDFDefaultParallelism, + } +} + func normalizeConnectionPackagePassword(password string) string { return strings.TrimSpace(password) } diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index 924730a..0d1c312 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -306,12 +306,7 @@ func (a *App) ImportConfigFile() connection.QueryResult { return connection.QueryResult{Success: true, Data: content} } -func (a *App) ExportConnectionsPackage(password string) connection.QueryResult { - payload, err := a.buildConnectionPackagePayload() - if err != nil { - return connection.QueryResult{Success: false, Message: err.Error()} - } - +func (a *App) ExportConnectionsPackage(options ConnectionExportOptions) connection.QueryResult { filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{ Title: "Export Connections", DefaultFilename: "connections" + connectionPackageExtension, @@ -327,12 +322,7 @@ func (a *App) ExportConnectionsPackage(password string) connection.QueryResult { } filename = normalizeConnectionPackageExportFilename(filename) - pkg, err := encryptConnectionPackage(payload, password) - if err != nil { - return connection.QueryResult{Success: false, Message: err.Error()} - } - - content, err := json.MarshalIndent(pkg, "", " ") + content, err := a.buildExportedConnectionPackage(options) if err != nil { return connection.QueryResult{Success: false, Message: err.Error()} } From 8e0d1b0a8087038aadc7e4fd3cacd96f4a867d14 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sun, 12 Apr 2026 12:34:50 +0800 Subject: [PATCH 6/7] =?UTF-8?q?=F0=9F=93=9D=20docs(contributing):=20?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3=20dev=20=E5=88=86=E6=94=AF=E8=B4=A1=E7=8C=AE?= =?UTF-8?q?=E6=B5=81=E7=A8=8B=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修正文档中的默认分支与集成分支描述 - 调整贡献分支创建基线为 dev - 调整外部 Pull Request 目标分支为 dev - 同步 README 中英文贡献说明 - 更新 release 后 main 回流 dev 的维护说明 Refs: #352 --- CONTRIBUTING.md | 48 ++++++++++++++++--------------------------- CONTRIBUTING.zh-CN.md | 48 ++++++++++++++++--------------------------- README.md | 2 +- README.zh-CN.md | 2 +- 4 files changed, 38 insertions(+), 62 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 162357f..41eb383 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,14 +2,14 @@ Thank you for contributing to this project. -This repository follows a release-first workflow: `main` is the default public branch, while releases are prepared through `release/*` branches. +This repository uses `dev` as the default integration branch, while stable releases are published from `main` through `release/*` branches. --- ## Branch Model -- `main`: stable release branch and default branch -- `dev`: day-to-day integration branch for maintainers +- `dev`: default branch and day-to-day integration branch +- `main`: stable release branch - `release/*`: release preparation branches for maintainers - Recommended branch names for external contributors: - `fix/*`: bug fixes @@ -25,21 +25,21 @@ feature/* / fix/* -> dev -> release/* -> main -> tag(vX.Y.Z) ## How External Contributors Should Open Pull Requests -Whether your branch is `fix/*` or `feature/*`, external contributors should **open pull requests directly against `main`**. +Whether your branch is `fix/*` or `feature/*`, external contributors should **open pull requests directly against `dev`**. Reasons: -- `main` is the default branch, so the PR entry point is clearer -- merged contributions are immediately visible on the default branch -- maintainers can handle downstream sync and release preparation in one place +- `dev` is the active integration branch, so changes can be reviewed in the same lane as ongoing work +- contributors align with the branch that triggers day-to-day validation and dev builds +- maintainers can cut `release/*` branches from `dev` without re-syncing external changes first Recommended flow: 1. Fork this repository -2. Create a branch in your fork (`fix/*` or `feature/*` is recommended) +2. Sync your fork with `dev` and create a branch from `dev` (`fix/*` or `feature/*` is recommended) 3. Make your changes and perform basic self-checks 4. Push the branch to your fork -5. Open a pull request against the `main` branch of this repository +5. Open a pull request against the `dev` branch of this repository --- @@ -63,33 +63,21 @@ Recommended expectations: ## Merge Strategy for Maintainers -Pull requests merged into `main` should generally use **Squash and merge**. +Pull requests merged into `dev` should generally use **Squash and merge**. Reasons: -- keeps `main` history clean and linear -- maps each PR to a single commit on `main` -- reduces release, audit, and rollback complexity +- keeps `dev` history readable and easier to audit during active iteration +- maps each PR to a single integration commit on `dev` +- reduces cherry-pick and conflict cost before creating `release/*` --- ## Maintainer Sync Rules -Because external pull requests are merged directly into `main`, maintainers must sync `main` back to development and release branches to avoid branch drift. +Because external pull requests are merged directly into `dev`, maintainers should treat `dev` as the source branch for daily collaboration and release preparation. -### 1. Sync `main` -> `dev` (required) - -The automatic GitHub Actions sync workflow has been removed. -Maintainers should sync `main` back to `dev` manually when needed: - -```bash -git checkout dev -git pull -git merge main -git push -``` - -### 2. Create `release/*` from `dev` +### 1. Create `release/*` from `dev` Before a release, create a release branch from `dev`, for example: @@ -100,7 +88,7 @@ git checkout -b release/v0.6.0 git push -u origin release/v0.6.0 ``` -### 3. Release from `release/*` back to `main` +### 2. Release from `release/*` back to `main` When release preparation is complete, merge the release branch back into `main` and create a tag: @@ -113,9 +101,9 @@ git tag v0.6.0 git push origin v0.6.0 ``` -### 4. Sync `main` back to `dev` after release +### 3. Sync `main` back to `dev` after release -After the release, the same automation still applies. If needed, you can run the workflow manually (`workflow_dispatch`) or execute the fallback commands: +After the release, sync `main` back into `dev` so the next iteration starts from the released code line: ```bash git checkout dev diff --git a/CONTRIBUTING.zh-CN.md b/CONTRIBUTING.zh-CN.md index a2d3983..0aa2d51 100644 --- a/CONTRIBUTING.zh-CN.md +++ b/CONTRIBUTING.zh-CN.md @@ -2,14 +2,14 @@ 感谢你对本项目的贡献。 -本项目采用“发布优先(`main` 为默认分支)+ `release/*` 分支发版”的协作模型。为减少分支漂移与 PR 处理成本,请在提交贡献前先阅读本指南。 +本项目当前采用“`dev` 作为默认集成分支,`main` 作为稳定发布分支,`release/*` 负责发版准备”的协作模型。为减少分支漂移与 PR 处理成本,请在提交贡献前先阅读本指南。 --- ## 分支模型 -- `main`:稳定发布分支,也是仓库默认分支 -- `dev`:日常开发集成分支,主要供维护者使用 +- `dev`:默认分支,也是日常开发集成分支 +- `main`:稳定发布分支 - `release/*`:发布准备分支,主要供维护者使用 - 外部贡献者建议使用以下分支命名: - `fix/*`:问题修复 @@ -25,21 +25,21 @@ feature/* / fix/* -> dev -> release/* -> main -> tag(vX.Y.Z) ## 外部贡献者如何提 Pull Request -无论是 `fix/*` 还是 `feature/*`,**外部贡献者统一直接向 `main` 发起 Pull Request**。 +无论是 `fix/*` 还是 `feature/*`,**外部贡献者统一直接向 `dev` 发起 Pull Request**。 这样做的原因: -- `main` 是默认分支,PR 入口更直观 -- 合并后贡献会直接体现在默认分支 -- 便于维护者统一做后续同步与发版整理 +- `dev` 是当前日常集成分支,评审与合入路径和维护者开发流程一致 +- 外部贡献会直接进入触发日常校验和 dev 构建的分支 +- 维护者可以直接从 `dev` 切 `release/*`,减少额外同步步骤 建议流程: 1. Fork 本仓库 -2. 从你自己的仓库创建分支(建议命名为 `fix/*` 或 `feature/*`) +2. 先同步你 fork 中的 `dev`,再从 `dev` 创建分支(建议命名为 `fix/*` 或 `feature/*`) 3. 完成代码修改,并进行必要自检 4. 推送到你的远程分支 -5. 向本仓库的 `main` 分支发起 Pull Request +5. 向本仓库的 `dev` 分支发起 Pull Request --- @@ -63,33 +63,21 @@ feature/* / fix/* -> dev -> release/* -> main -> tag(vX.Y.Z) ## PR 合并策略(维护者) -`main` 分支上的 PR 建议使用 **Squash and merge**。 +`dev` 分支上的 PR 建议使用 **Squash and merge**。 原因: -- 保持 `main` 历史干净、线性 -- 每个 PR 在 `main` 上对应一个清晰提交 -- 降低发布排查与回滚成本 +- 保持 `dev` 集成历史清晰、便于审查 +- 每个 PR 在 `dev` 上对应一个明确的集成提交 +- 降低发版前整理与冲突处理成本 --- ## 维护者同步规则 -由于外部 PR 会直接合入 `main`,维护者必须及时将 `main` 的变更同步到开发与发布分支,避免分支漂移。 +由于外部 PR 会直接合入 `dev`,维护者应将 `dev` 作为日常协作与发版准备的主线分支。 -### 1. main → dev 同步(必做) - -仓库已移除 GitHub Actions 自动回灌 workflow。 -当前统一采用手动方式将 `main` 同步回 `dev`: - -```bash -git checkout dev -git pull -git merge main -git push -``` - -### 2. 发版前从 dev 切 release/* +### 1. 发版前从 dev 切 release/* 发布前由维护者基于 `dev` 创建发布分支,例如: @@ -100,7 +88,7 @@ git checkout -b release/v0.6.0 git push -u origin release/v0.6.0 ``` -### 3. release/* → main 发版 +### 2. release/* → main 发版 发布准备完成后,将 `release/*` 合并回 `main`,并打标签发布: @@ -113,9 +101,9 @@ git tag v0.6.0 git push origin v0.6.0 ``` -### 4. main 回流到 dev(发版后必做) +### 3. main 回流到 dev(发版后必做) -发布完成后,仍沿用同一套自动化流程;如有需要,也可以手动触发 `workflow_dispatch`,或执行以下兜底命令,确保开发线与发布线一致: +发布完成后,需要将 `main` 回流到 `dev`,确保下一轮开发从已发布代码线继续推进: ```bash git checkout dev diff --git a/README.md b/README.md index 9b42c4a..06fcc36 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,7 @@ For the full workflow, branch model, and maintainer sync rules, see: - [CONTRIBUTING.md](CONTRIBUTING.md) -External contributors should open pull requests directly against `main`. +External contributors should branch from `dev` and open pull requests against `dev`. ## Star History diff --git a/README.zh-CN.md b/README.zh-CN.md index b392a06..46ea99c 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -195,7 +195,7 @@ sudo apt-get install -y libgtk-3-0 libwebkit2gtk-4.0-37 libjavascriptcoregtk-4.0 - [CONTRIBUTING.zh-CN.md](CONTRIBUTING.zh-CN.md) -外部贡献者统一直接向 `main` 发起 Pull Request。 +外部贡献者应从 `dev` 拉出分支,并统一向 `dev` 发起 Pull Request。 ## Star History (Star 增长趋势) From bb6271246b781789d702a5c4033cd3d618a752f3 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sun, 12 Apr 2026 12:46:15 +0800 Subject: [PATCH 7/7] =?UTF-8?q?=F0=9F=90=9B=20fix(mac):=20=E7=A6=81?= =?UTF-8?q?=E7=94=A8=E6=AD=A3=E5=BC=8F=E5=8C=85=E9=BB=98=E8=AE=A4=E7=AA=97?= =?UTF-8?q?=E5=8F=A3=E8=AF=8A=E6=96=AD=E4=BB=A5=E8=A7=84=E9=81=BF=E5=90=AF?= =?UTF-8?q?=E5=8A=A8=E6=97=A0=E7=AA=97=E4=BD=93=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 macOS 原生窗口诊断改为默认关闭 - 仅在显式设置 GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS 时启用后端诊断 - 仅在前端开发环境启用窗口诊断采集 - 避免正式构建在启动阶段附加额外窗口状态探测与日志观察 - 为诊断开关补充前后端最小回归测试 Refs: #360 --- frontend/src/App.tsx | 10 +++-- .../src/utils/macWindowDiagnostics.test.ts | 17 +++++++++ frontend/src/utils/macWindowDiagnostics.ts | 6 +++ internal/app/app.go | 4 +- internal/app/env.go | 5 +++ internal/app/window_diagnostics.go | 14 +++++++ internal/app/window_diagnostics_test.go | 37 +++++++++++++++++++ 7 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 frontend/src/utils/macWindowDiagnostics.test.ts create mode 100644 frontend/src/utils/macWindowDiagnostics.ts create mode 100644 internal/app/env.go create mode 100644 internal/app/window_diagnostics.go create mode 100644 internal/app/window_diagnostics_test.go diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 502beca..9deb4e7 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -16,6 +16,7 @@ import { SavedConnection } from './types'; import { blurToFilter, normalizeBlurForPlatform, normalizeOpacityForPlatform, isWindowsPlatform, resolveAppearanceValues } from './utils/appearance'; import { DATA_GRID_COLUMN_WIDTH_MODE_OPTIONS, sanitizeDataTableColumnWidthMode } from './utils/dataGridDisplay'; import { getMacNativeTitlebarPaddingLeft, getMacNativeTitlebarPaddingRight, shouldHandleMacNativeFullscreenShortcut, shouldSuppressMacNativeEscapeExit } from './utils/macWindow'; +import { shouldEnableMacWindowDiagnostics } from './utils/macWindowDiagnostics'; import { buildOverlayWorkbenchTheme } from './utils/overlayWorkbenchTheme'; import { getConnectionWorkbenchState } from './utils/startupReadiness'; import { createGlobalProxyDraft, toSaveGlobalProxyInput } from './utils/globalProxyDraft'; @@ -897,9 +898,10 @@ function App() { const isWindowsRuntime = runtimePlatform === 'windows' || (runtimePlatform === '' && isWindowsPlatform()); const useNativeMacWindowControls = isMacRuntime && appearance.useNativeMacWindowControls === true; + const macWindowDiagnosticsEnabled = shouldEnableMacWindowDiagnostics(isMacRuntime, import.meta.env.DEV); const emitWindowDiagnostic = useCallback(async (stage: string, extra: Record = {}) => { - if (!isMacRuntime) { + if (!macWindowDiagnosticsEnabled) { return; } const backendApp = (window as any).go?.app?.App; @@ -953,7 +955,7 @@ function App() { } catch (error) { console.warn('Failed to emit window diagnostic', error); } - }, [isMacRuntime, useNativeMacWindowControls]); + }, [macWindowDiagnosticsEnabled, useNativeMacWindowControls]); useEffect(() => { if (!isStoreHydrated || !isMacRuntime) { @@ -968,7 +970,7 @@ function App() { }, [isMacRuntime, isStoreHydrated, useNativeMacWindowControls]); useEffect(() => { - if (!isMacRuntime) { + if (!macWindowDiagnosticsEnabled) { return; } @@ -1063,7 +1065,7 @@ function App() { window.removeEventListener('compositionend', handleCompositionEnd, true); document.removeEventListener('visibilitychange', handleVisibilityChange); }; - }, [emitWindowDiagnostic, isMacRuntime]); + }, [emitWindowDiagnostic, macWindowDiagnosticsEnabled]); const formatBytes = (bytes?: number) => { if (!bytes || bytes <= 0) return '0 B'; diff --git a/frontend/src/utils/macWindowDiagnostics.test.ts b/frontend/src/utils/macWindowDiagnostics.test.ts new file mode 100644 index 0000000..8c17c75 --- /dev/null +++ b/frontend/src/utils/macWindowDiagnostics.test.ts @@ -0,0 +1,17 @@ +import { describe, expect, it } from 'vitest'; + +import { shouldEnableMacWindowDiagnostics } from './macWindowDiagnostics'; + +describe('macWindowDiagnostics', () => { + it('stays disabled outside macOS runtime', () => { + expect(shouldEnableMacWindowDiagnostics(false, true)).toBe(false); + }); + + it('stays disabled for production builds on macOS', () => { + expect(shouldEnableMacWindowDiagnostics(true, false)).toBe(false); + }); + + it('enables diagnostics only for macOS development builds', () => { + expect(shouldEnableMacWindowDiagnostics(true, true)).toBe(true); + }); +}); diff --git a/frontend/src/utils/macWindowDiagnostics.ts b/frontend/src/utils/macWindowDiagnostics.ts new file mode 100644 index 0000000..90be3ab --- /dev/null +++ b/frontend/src/utils/macWindowDiagnostics.ts @@ -0,0 +1,6 @@ +export const shouldEnableMacWindowDiagnostics = ( + isMacRuntime: boolean, + isDevBuild: boolean, +): boolean => { + return isMacRuntime && isDevBuild; +}; diff --git a/internal/app/app.go b/internal/app/app.go index b91400b..491d103 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -94,7 +94,9 @@ func (a *App) startup(ctx context.Context) { db.SetExternalDriverDownloadDirectory(appdata.DriverRoot(a.configDir)) logger.Init() a.loadPersistedGlobalProxy() - installMacNativeWindowDiagnostics(logger.Path()) + if shouldInstallMacNativeWindowDiagnostics() { + installMacNativeWindowDiagnostics(logger.Path()) + } applyMacWindowTranslucencyFix() logger.Infof("应用启动完成(首次连接保护窗口=%s,最多重试=%d 次)", startupConnectRetryWindow, startupConnectRetryAttempts) } diff --git a/internal/app/env.go b/internal/app/env.go new file mode 100644 index 0000000..b21abbd --- /dev/null +++ b/internal/app/env.go @@ -0,0 +1,5 @@ +package app + +import "os" + +var getenv = os.Getenv diff --git a/internal/app/window_diagnostics.go b/internal/app/window_diagnostics.go new file mode 100644 index 0000000..107ca74 --- /dev/null +++ b/internal/app/window_diagnostics.go @@ -0,0 +1,14 @@ +package app + +import "strings" + +const macWindowDiagnosticsEnv = "GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS" + +func shouldInstallMacNativeWindowDiagnostics() bool { + switch strings.ToLower(strings.TrimSpace(getenv(macWindowDiagnosticsEnv))) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} diff --git a/internal/app/window_diagnostics_test.go b/internal/app/window_diagnostics_test.go new file mode 100644 index 0000000..23b3d83 --- /dev/null +++ b/internal/app/window_diagnostics_test.go @@ -0,0 +1,37 @@ +package app + +import "testing" + +func TestShouldInstallMacNativeWindowDiagnosticsDefaultsDisabled(t *testing.T) { + t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", "") + + if shouldInstallMacNativeWindowDiagnostics() { + t.Fatal("expected mac native window diagnostics to stay disabled by default") + } +} + +func TestShouldInstallMacNativeWindowDiagnosticsHonorsEnvOptIn(t *testing.T) { + t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", "1") + + if !shouldInstallMacNativeWindowDiagnostics() { + t.Fatal("expected mac native window diagnostics to enable when explicitly opted in") + } + + t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", "true") + if !shouldInstallMacNativeWindowDiagnostics() { + t.Fatal("expected mac native window diagnostics to accept true as opt-in value") + } + + t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", "0") + if shouldInstallMacNativeWindowDiagnostics() { + t.Fatal("expected mac native window diagnostics to stay disabled for non-opt-in values") + } +} + +func TestShouldInstallMacNativeWindowDiagnosticsIgnoresCaseAndWhitespace(t *testing.T) { + t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", " TRUE ") + + if !shouldInstallMacNativeWindowDiagnostics() { + t.Fatal("expected helper to trim and lowercase opt-in values") + } +}