合并拉取请求 #362

release/0.6.7
This commit is contained in:
Syngnat
2026-04-12 12:51:23 +08:00
committed by GitHub
85 changed files with 12985 additions and 433 deletions

5
.gitignore vendored
View File

@@ -1,7 +1,7 @@
# IDE
.idea/
*.iml
.gitignore
# build / release artifacts
frontend/release/
**/release/
@@ -27,4 +27,5 @@ docs/需求追踪/
CLAUDE.md
**/CLAUDE.md
.worktrees
docs
docs
.tmp_superpowers_edit

View File

@@ -2,14 +2,14 @@
Thank you for contributing to this project.
This repository follows a release-first workflow: `main` is the default public branch, while releases are prepared through `release/*` branches.
This repository uses `dev` as the default integration branch, while stable releases are published from `main` through `release/*` branches.
---
## Branch Model
- `main`: stable release branch and default branch
- `dev`: day-to-day integration branch for maintainers
- `dev`: default branch and day-to-day integration branch
- `main`: stable release branch
- `release/*`: release preparation branches for maintainers
- Recommended branch names for external contributors:
- `fix/*`: bug fixes
@@ -25,21 +25,21 @@ feature/* / fix/* -> dev -> release/* -> main -> tag(vX.Y.Z)
## How External Contributors Should Open Pull Requests
Whether your branch is `fix/*` or `feature/*`, external contributors should **open pull requests directly against `main`**.
Whether your branch is `fix/*` or `feature/*`, external contributors should **open pull requests directly against `dev`**.
Reasons:
- `main` is the default branch, so the PR entry point is clearer
- merged contributions are immediately visible on the default branch
- maintainers can handle downstream sync and release preparation in one place
- `dev` is the active integration branch, so changes can be reviewed in the same lane as ongoing work
- contributors align with the branch that triggers day-to-day validation and dev builds
- maintainers can cut `release/*` branches from `dev` without re-syncing external changes first
Recommended flow:
1. Fork this repository
2. Create a branch in your fork (`fix/*` or `feature/*` is recommended)
2. Sync your fork with `dev` and create a branch from `dev` (`fix/*` or `feature/*` is recommended)
3. Make your changes and perform basic self-checks
4. Push the branch to your fork
5. Open a pull request against the `main` branch of this repository
5. Open a pull request against the `dev` branch of this repository
---
@@ -63,33 +63,21 @@ Recommended expectations:
## Merge Strategy for Maintainers
Pull requests merged into `main` should generally use **Squash and merge**.
Pull requests merged into `dev` should generally use **Squash and merge**.
Reasons:
- keeps `main` history clean and linear
- maps each PR to a single commit on `main`
- reduces release, audit, and rollback complexity
- keeps `dev` history readable and easier to audit during active iteration
- maps each PR to a single integration commit on `dev`
- reduces cherry-pick and conflict cost before creating `release/*`
---
## Maintainer Sync Rules
Because external pull requests are merged directly into `main`, maintainers must sync `main` back to development and release branches to avoid branch drift.
Because external pull requests are merged directly into `dev`, maintainers should treat `dev` as the source branch for daily collaboration and release preparation.
### 1. Sync `main` -> `dev` (required)
The automatic GitHub Actions sync workflow has been removed.
Maintainers should sync `main` back to `dev` manually when needed:
```bash
git checkout dev
git pull
git merge main
git push
```
### 2. Create `release/*` from `dev`
### 1. Create `release/*` from `dev`
Before a release, create a release branch from `dev`, for example:
@@ -100,7 +88,7 @@ git checkout -b release/v0.6.0
git push -u origin release/v0.6.0
```
### 3. Release from `release/*` back to `main`
### 2. Release from `release/*` back to `main`
When release preparation is complete, merge the release branch back into `main` and create a tag:
@@ -113,9 +101,9 @@ git tag v0.6.0
git push origin v0.6.0
```
### 4. Sync `main` back to `dev` after release
### 3. Sync `main` back to `dev` after release
After the release, the same automation still applies. If needed, you can run the workflow manually (`workflow_dispatch`) or execute the fallback commands:
After the release, sync `main` back into `dev` so the next iteration starts from the released code line:
```bash
git checkout dev

View File

@@ -2,14 +2,14 @@
感谢你对本项目的贡献。
本项目采用“发布优先(`main` 为默认分支)+ `release/*` 分支发版”的协作模型。为减少分支漂移与 PR 处理成本,请在提交贡献前先阅读本指南。
本项目当前采用“`dev` 作为默认集成分支,`main` 作为稳定发布分支,`release/*` 负责发版准备”的协作模型。为减少分支漂移与 PR 处理成本,请在提交贡献前先阅读本指南。
---
## 分支模型
- `main`:稳定发布分支,也是仓库默认分支
- `dev`:日常开发集成分支,主要供维护者使用
- `dev`:默认分支,也是日常开发集成分支
- `main`:稳定发布分支
- `release/*`:发布准备分支,主要供维护者使用
- 外部贡献者建议使用以下分支命名:
- `fix/*`:问题修复
@@ -25,21 +25,21 @@ feature/* / fix/* -> dev -> release/* -> main -> tag(vX.Y.Z)
## 外部贡献者如何提 Pull Request
无论是 `fix/*` 还是 `feature/*`**外部贡献者统一直接向 `main` 发起 Pull Request**。
无论是 `fix/*` 还是 `feature/*`**外部贡献者统一直接向 `dev` 发起 Pull Request**。
这样做的原因:
- `main` 是默认分支PR 入口更直观
- 合并后贡献会直接体现在默认分支
- 便于维护者统一做后续同步与发版整理
- `dev` 是当前日常集成分支,评审与合入路径和维护者开发流程一致
- 外部贡献会直接进入触发日常校验和 dev 构建的分支
- 维护者可以直接从 `dev``release/*`,减少额外同步步骤
建议流程:
1. Fork 本仓库
2. 从你自己的仓库创建分支(建议命名为 `fix/*``feature/*`
2. 先同步你 fork 中的 `dev`,再从 `dev` 创建分支(建议命名为 `fix/*``feature/*`
3. 完成代码修改,并进行必要自检
4. 推送到你的远程分支
5. 向本仓库的 `main` 分支发起 Pull Request
5. 向本仓库的 `dev` 分支发起 Pull Request
---
@@ -63,33 +63,21 @@ feature/* / fix/* -> dev -> release/* -> main -> tag(vX.Y.Z)
## PR 合并策略(维护者)
`main` 分支上的 PR 建议使用 **Squash and merge**
`dev` 分支上的 PR 建议使用 **Squash and merge**
原因:
- 保持 `main` 历史干净、线性
- 每个 PR 在 `main` 上对应一个清晰提交
- 降低发布排查与回滚成本
- 保持 `dev` 集成历史清晰、便于审查
- 每个 PR 在 `dev` 上对应一个明确的集成提交
- 降低发版前整理与冲突处理成本
---
## 维护者同步规则
由于外部 PR 会直接合入 `main`,维护者必须及时将 `main` 的变更同步到开发与发布分支,避免分支漂移
由于外部 PR 会直接合入 `dev`,维护者应将 `dev` 作为日常协作与发版准备的主线分支
### 1. main → dev 同步(必做)
仓库已移除 GitHub Actions 自动回灌 workflow。
当前统一采用手动方式将 `main` 同步回 `dev`
```bash
git checkout dev
git pull
git merge main
git push
```
### 2. 发版前从 dev 切 release/*
### 1. 发版前从 dev 切 release/*
发布前由维护者基于 `dev` 创建发布分支,例如:
@@ -100,7 +88,7 @@ git checkout -b release/v0.6.0
git push -u origin release/v0.6.0
```
### 3. release/* → main 发版
### 2. release/* → main 发版
发布准备完成后,将 `release/*` 合并回 `main`,并打标签发布:
@@ -113,9 +101,9 @@ git tag v0.6.0
git push origin v0.6.0
```
### 4. main 回流到 dev发版后必做
### 3. main 回流到 dev发版后必做
发布完成后,仍沿用同一套自动化流程;如有需要,也可以手动触发 `workflow_dispatch`,或执行以下兜底命令,确保开发线与发布线一致
发布完成后,需要将 `main` 回流到 `dev`,确保下一轮开发从已发布代码线继续推进
```bash
git checkout dev

View File

@@ -212,7 +212,7 @@ For the full workflow, branch model, and maintainer sync rules, see:
- [CONTRIBUTING.md](CONTRIBUTING.md)
External contributors should open pull requests directly against `main`.
External contributors should branch from `dev` and open pull requests against `dev`.
## Star History
<a href="https://www.star-history.com/?repos=Syngnat%2FGoNavi&type=date&legend=top-left">

View File

@@ -195,7 +195,7 @@ sudo apt-get install -y libgtk-3-0 libwebkit2gtk-4.0-37 libjavascriptcoregtk-4.0
- [CONTRIBUTING.zh-CN.md](CONTRIBUTING.zh-CN.md)
外部贡献者统一直接向 `main` 发起 Pull Request。
外部贡献者应从 `dev` 拉出分支,并统一向 `dev` 发起 Pull Request。
## Star History (Star 增长趋势)

339
cmd/manualtestseed/main.go Normal file
View File

@@ -0,0 +1,339 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"GoNavi-Wails/internal/ai"
aiservice "GoNavi-Wails/internal/ai/service"
"GoNavi-Wails/internal/app"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/secretstore"
)
const (
modeSeedSecureStorage = "seed-secure-storage"
modeSeedAIUpdate = "seed-ai-update"
)
const (
testConnectionID = "manualtest-postgres"
testSecureProviderID = "manualtest-secure-provider"
testPendingProviderID = "manualtest-pending-provider"
testBackupDirName = "manual-test-backups"
connectionsFileName = "connections.json"
globalProxyFileName = "global_proxy.json"
aiConfigFileName = "ai_config.json"
securityUpdateFileName = "config-security-update.json"
)
type backupManifest struct {
CreatedAt string `json:"createdAt"`
ConfigDir string `json:"configDir"`
Files []backupManifestFile `json:"files"`
}
type backupManifestFile struct {
RelativePath string `json:"relativePath"`
Existed bool `json:"existed"`
}
type storedAIConfig struct {
SchemaVersion int `json:"schemaVersion,omitempty"`
Providers []ai.ProviderConfig `json:"providers"`
ActiveProvider string `json:"activeProvider"`
SafetyLevel string `json:"safetyLevel"`
ContextLevel string `json:"contextLevel"`
}
func main() {
mode := flag.String("mode", modeSeedSecureStorage, "seed mode: seed-secure-storage | seed-ai-update")
flag.Parse()
configDir, err := resolveConfigDir()
if err != nil {
fatalf("resolve config dir failed: %v", err)
}
store := secretstore.NewKeyringStore()
if err := store.HealthCheck(); err != nil {
fatalf("secret store unavailable: %v", err)
}
backupDir, err := backupConfigFiles(configDir)
if err != nil {
fatalf("backup config files failed: %v", err)
}
switch strings.TrimSpace(*mode) {
case modeSeedSecureStorage:
if err := seedSecureStorage(configDir, store); err != nil {
fatalf("seed secure storage failed: %v", err)
}
fmt.Printf("mode=%s\nbackup=%s\nconnectionId=%s\nproviderId=%s\n", modeSeedSecureStorage, backupDir, testConnectionID, testSecureProviderID)
case modeSeedAIUpdate:
if err := seedAIUpdate(configDir, store); err != nil {
fatalf("seed ai update failed: %v", err)
}
fmt.Printf("mode=%s\nbackup=%s\npendingProviderId=%s\n", modeSeedAIUpdate, backupDir, testPendingProviderID)
default:
fatalf("unsupported mode: %s", *mode)
}
}
func fatalf(format string, args ...any) {
fmt.Fprintf(os.Stderr, format+"\n", args...)
os.Exit(1)
}
func resolveConfigDir() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(homeDir, ".gonavi"), nil
}
func backupConfigFiles(configDir string) (string, error) {
backupDir := filepath.Join(configDir, testBackupDirName, time.Now().Format("20060102-150405"))
files := []string{
connectionsFileName,
globalProxyFileName,
aiConfigFileName,
filepath.Join("migrations", securityUpdateFileName),
}
manifest := backupManifest{
CreatedAt: time.Now().Format(time.RFC3339),
ConfigDir: configDir,
Files: make([]backupManifestFile, 0, len(files)),
}
for _, relativePath := range files {
srcPath := filepath.Join(configDir, relativePath)
info, err := os.Stat(srcPath)
if err != nil {
if os.IsNotExist(err) {
manifest.Files = append(manifest.Files, backupManifestFile{
RelativePath: relativePath,
Existed: false,
})
continue
}
return "", err
}
if info.IsDir() {
continue
}
dstPath := filepath.Join(backupDir, relativePath)
if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil {
return "", err
}
data, err := os.ReadFile(srcPath)
if err != nil {
return "", err
}
if err := os.WriteFile(dstPath, data, 0o644); err != nil {
return "", err
}
manifest.Files = append(manifest.Files, backupManifestFile{
RelativePath: relativePath,
Existed: true,
})
}
if err := os.MkdirAll(backupDir, 0o755); err != nil {
return "", err
}
manifestData, err := json.MarshalIndent(manifest, "", " ")
if err != nil {
return "", err
}
if err := os.WriteFile(filepath.Join(backupDir, "manifest.json"), manifestData, 0o644); err != nil {
return "", err
}
return backupDir, nil
}
func seedSecureStorage(configDir string, store secretstore.SecretStore) error {
if err := cleanupKnownTestSecrets(store); err != nil {
return err
}
appService := app.NewAppWithSecretStore(store)
_ = appService.DeleteConnection(testConnectionID)
if _, err := appService.SaveConnection(connection.SavedConnectionInput{
ID: testConnectionID,
Name: "手工测试 PostgreSQL",
Config: connection.ConnectionConfig{
ID: testConnectionID,
Type: "postgres",
Host: "127.0.0.1",
Port: 5432,
User: "postgres",
Password: "manualtest-pg-secret",
Database: "postgres",
},
}); err != nil {
return err
}
if _, err := appService.SaveGlobalProxy(connection.SaveGlobalProxyInput{
Enabled: true,
Type: "http",
Host: "127.0.0.1",
Port: 7890,
User: "manual-test",
Password: "manualtest-proxy-secret",
}); err != nil {
return err
}
storeConfig := aiservice.NewProviderConfigStore(configDir, store)
snapshot, err := storeConfig.LoadRuntime()
if err != nil {
return err
}
snapshot.Providers = filterProviders(snapshot.Providers, testSecureProviderID, testPendingProviderID)
snapshot.Providers = append(snapshot.Providers, ai.ProviderConfig{
ID: testSecureProviderID,
Type: "custom",
Name: "手工测试 Secure Provider",
APIKey: "manualtest-ai-secret",
BaseURL: "https://api.openai.com/v1",
Model: "gpt-4o-mini",
APIFormat: "openai",
Headers: map[string]string{
"Authorization": "Bearer manualtest-header-secret",
"X-Trace-Id": "manualtest-visible",
},
MaxTokens: 2048,
Temperature: 0.2,
})
if snapshot.SafetyLevel == "" {
snapshot.SafetyLevel = ai.PermissionReadOnly
}
if snapshot.ContextLevel == "" {
snapshot.ContextLevel = ai.ContextSchemaOnly
}
return storeConfig.Save(snapshot)
}
func seedAIUpdate(configDir string, store secretstore.SecretStore) error {
if err := cleanupKnownTestSecrets(store); err != nil {
return err
}
configPath := filepath.Join(configDir, aiConfigFileName)
cfg, err := readStoredAIConfig(configPath)
if err != nil {
return err
}
cfg.Providers = filterProviders(cfg.Providers, testSecureProviderID, testPendingProviderID)
cfg.Providers = append(cfg.Providers, ai.ProviderConfig{
ID: testPendingProviderID,
Type: "custom",
Name: "手工测试 待迁移 AI",
APIKey: "manualtest-ai-update-secret",
BaseURL: "https://api.openai.com/v1",
Model: "gpt-4o-mini",
APIFormat: "openai",
MaxTokens: 1024,
})
if cfg.SchemaVersion == 0 {
cfg.SchemaVersion = 2
}
if cfg.Providers == nil {
cfg.Providers = []ai.ProviderConfig{}
}
if err := os.MkdirAll(configDir, 0o755); err != nil {
return err
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
return os.WriteFile(configPath, data, 0o644)
}
func readStoredAIConfig(configPath string) (storedAIConfig, error) {
cfg := storedAIConfig{
Providers: []ai.ProviderConfig{},
SafetyLevel: string(ai.PermissionReadOnly),
ContextLevel: string(ai.ContextSchemaOnly),
SchemaVersion: 2,
ActiveProvider: "",
}
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
return cfg, nil
}
return storedAIConfig{}, err
}
if err := json.Unmarshal(data, &cfg); err != nil {
return storedAIConfig{}, err
}
if cfg.Providers == nil {
cfg.Providers = []ai.ProviderConfig{}
}
return cfg, nil
}
func filterProviders(providers []ai.ProviderConfig, excludedIDs ...string) []ai.ProviderConfig {
excluded := make(map[string]struct{}, len(excludedIDs))
for _, id := range excludedIDs {
excluded[strings.TrimSpace(id)] = struct{}{}
}
filtered := make([]ai.ProviderConfig, 0, len(providers))
for _, provider := range providers {
if _, skip := excluded[strings.TrimSpace(provider.ID)]; skip {
continue
}
filtered = append(filtered, provider)
}
return filtered
}
func cleanupKnownTestSecrets(store secretstore.SecretStore) error {
type secretRef struct {
kind string
id string
}
refs := []secretRef{
{kind: "connection", id: testConnectionID},
{kind: "global-proxy", id: "default"},
{kind: "ai-provider", id: testSecureProviderID},
{kind: "ai-provider", id: testPendingProviderID},
}
for _, item := range refs {
ref, err := secretstore.BuildRef(item.kind, item.id)
if err != nil {
return err
}
if err := store.Delete(ref); err != nil && !isIgnorableDeleteError(err) {
return err
}
}
return nil
}
func isIgnorableDeleteError(err error) bool {
if err == nil || os.IsNotExist(err) {
return true
}
message := strings.ToLower(strings.TrimSpace(err.Error()))
return strings.Contains(message, "could not be found") ||
strings.Contains(message, "not be found in the keyring") ||
strings.Contains(message, "element not found")
}

View File

@@ -1,12 +1,12 @@
{
"name": "gonavi-client",
"version": "0.0.1",
"version": "0.6.5",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "gonavi-client",
"version": "0.0.1",
"version": "0.6.5",
"dependencies": {
"@ant-design/icons": "^5.2.6",
"@dnd-kit/core": "^6.3.1",

View File

@@ -1,7 +1,7 @@
{
"name": "gonavi-client",
"private": true,
"version": "0.0.1",
"version": "0.6.5",
"type": "module",
"scripts": {
"dev": "vite",

View File

@@ -1 +1 @@
f697e821b4acd5cf614d63d46453e8a4
8cc5d6401a6ce7dd0f500c66ce8bb4a9

View File

@@ -375,3 +375,47 @@ body[data-theme='light'] .redis-viewer-workbench .ant-radio-button-wrapper-check
.driver-manager-hscroll-inner {
height: 1px;
}
.security-update-action-btn.ant-btn,
.security-update-action-btn.ant-btn-default,
.security-update-action-btn.ant-btn-primary,
.security-update-action-btn.ant-btn-text {
box-shadow: none !important;
}
.security-update-action-btn.ant-btn:focus,
.security-update-action-btn.ant-btn:focus-visible,
.security-update-action-btn.ant-btn-default:focus,
.security-update-action-btn.ant-btn-default:focus-visible,
.security-update-action-btn.ant-btn-primary:focus,
.security-update-action-btn.ant-btn-primary:focus-visible,
.security-update-action-btn.ant-btn-text:focus,
.security-update-action-btn.ant-btn-text:focus-visible {
outline: none !important;
box-shadow: none !important;
}
.security-update-banner {
position: relative;
isolation: isolate;
}
.security-update-result-card {
transition: background 0.22s ease, box-shadow 0.22s ease, transform 0.22s ease;
}
.security-update-result-card-active {
animation: security-update-result-pulse 1.8s ease;
}
@keyframes security-update-result-pulse {
0% {
transform: translateY(0);
}
30% {
transform: translateY(-2px);
}
100% {
transform: translateY(0);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -28,6 +28,7 @@ interface AISettingsModalProps {
onClose: () => void;
darkMode: boolean;
overlayTheme: OverlayWorkbenchTheme;
focusProviderId?: string;
}
// 预设配置:每个预设映射到后端 typeopenai/anthropic/gemini/custom并附带默认 URL 和 Model
@@ -79,7 +80,7 @@ const CONTEXT_OPTIONS: { label: string; value: AIContextLevel; desc: string; ico
{ label: '含查询结果', value: 'with_results', desc: '传递最近的查询结果作为上下文', icon: '📑' },
];
const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMode, overlayTheme }) => {
const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMode, overlayTheme, focusProviderId }) => {
const [providers, setProviders] = useState<AIProviderConfig[]>([]);
const [activeProviderId, setActiveProviderId] = useState<string>('');
const [safetyLevel, setSafetyLevel] = useState<AISafetyLevel>('readonly');
@@ -135,6 +136,17 @@ const AISettingsModal: React.FC<AISettingsModalProps> = ({ open, onClose, darkMo
useEffect(() => { if (open) void loadConfig(); }, [open, loadConfig]);
useEffect(() => {
if (!open || !focusProviderId) {
return;
}
if (!providers.some((provider) => provider.id === focusProviderId)) {
return;
}
setActiveSection('providers');
setActiveProviderId(focusProviderId);
}, [focusProviderId, open, providers]);
const applyProviderEditorSession = useCallback((session: ProviderEditorSession) => {
setEditingProvider(session.editingProvider as AIProviderConfig | null);
setIsEditing(session.isEditing);

View File

@@ -5,6 +5,11 @@ import { getDbIcon, getDbDefaultColor, getDbIconLabel, DB_ICON_TYPES, PRESET_ICO
import { useStore } from '../store';
import { buildOverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
import {
getStoredSecretPlaceholder,
normalizeConnectionSecretErrorMessage,
resolveConnectionTestFailureFeedback,
} from '../utils/connectionModalPresentation';
import { resolveConnectionSecretDraft } from '../utils/connectionSecretDraft';
import { getCustomConnectionDsnValidationMessage } from '../utils/customConnectionDsn';
import { DBGetDatabases, GetDriverStatusList, MongoDiscoverMembers, TestConnection, RedisConnect, SelectDatabaseFile, SelectSSHKeyFile } from '../../wailsjs/go/app/App';
@@ -135,7 +140,8 @@ const ConnectionModal: React.FC<{
onClose: () => void;
initialValues?: SavedConnection | null;
onOpenDriverManager?: () => void;
}> = ({ open, onClose, initialValues, onOpenDriverManager }) => {
onSaved?: (savedConnection: SavedConnection) => void | Promise<void>;
}> = ({ open, onClose, initialValues, onOpenDriverManager, onSaved }) => {
const [form] = Form.useForm();
const [loading, setLoading] = useState(false);
const [useSSL, setUseSSL] = useState(false);
@@ -1443,6 +1449,13 @@ const ConnectionModal: React.FC<{
message.success('配置已保存(未连接)');
}
if (onSaved) {
void Promise.resolve(onSaved(savedConnection)).catch((error: unknown) => {
console.warn('Failed to refresh post-save state', error);
void message.warning('配置已保存,但安全更新状态暂未刷新,请稍后重新检查');
});
}
form.resetFields();
setUseSSL(false);
setUseSSH(false);
@@ -1453,7 +1466,7 @@ const ConnectionModal: React.FC<{
setClearSecrets(createEmptyConnectionSecretClearState());
onClose();
} catch (e: any) {
message.error(e?.message || '保存失败');
message.error(normalizeConnectionSecretErrorMessage(e?.message || e, '保存失败'));
} finally {
setLoading(false);
}
@@ -1508,10 +1521,14 @@ const ConnectionModal: React.FC<{
}
return null;
};
const buildTestFailureMessage = (reason: unknown, fallback: string) => {
const text = String(reason ?? '').trim();
const normalized = text && text !== 'undefined' && text !== 'null' ? text : fallback;
return `测试失败: ${normalized}`;
const applyTestFailureFeedback = (feedback: { message: string; shouldToast: boolean }) => {
setTestResult({ type: 'error', message: feedback.message });
if (feedback.shouldToast) {
void message.error({
content: feedback.message,
key: 'connection-test-failure',
});
}
};
const handleTest = async () => {
@@ -1522,14 +1539,21 @@ const ConnectionModal: React.FC<{
const values = form.getFieldsValue(true);
const unavailableReason = await resolveDriverUnavailableReason(values.type);
if (unavailableReason) {
const failMessage = buildTestFailureMessage(unavailableReason, '驱动未安装启用');
setTestResult({ type: 'error', message: failMessage });
applyTestFailureFeedback(resolveConnectionTestFailureFeedback({
kind: 'driver_unavailable',
reason: unavailableReason,
fallback: '驱动未安装启用',
}));
promptInstallDriver(values.type, unavailableReason);
return;
}
const blockingSecretClearMessage = getBlockingSecretClearMessage(values);
if (blockingSecretClearMessage) {
setTestResult({ type: 'error', message: blockingSecretClearMessage });
applyTestFailureFeedback(resolveConnectionTestFailureFeedback({
kind: 'secret_blocked',
reason: blockingSecretClearMessage,
fallback: '连接参数不完整',
}));
return;
}
setLoading(true);
@@ -1555,6 +1579,7 @@ const ConnectionModal: React.FC<{
);
if (res.success) {
void message.destroy('connection-test-failure');
setTestResult({ type: 'success', message: res.message });
if (isRedisType) {
setRedisDbList(Array.from({ length: 16 }, (_, i) => i));
@@ -1578,27 +1603,33 @@ const ConnectionModal: React.FC<{
}
} else {
setDbList([]);
message.warning(`连接成功,但获取数据库列表失败:${dbRes.message || '未知错误'}`);
message.warning(`连接成功,但获取数据库列表失败:${normalizeConnectionSecretErrorMessage(dbRes.message, '未知错误')}`);
}
}
} else {
const failMessage = buildTestFailureMessage(
res?.message,
'连接被拒绝或参数无效,请检查后重试'
);
setTestResult({ type: 'error', message: failMessage });
applyTestFailureFeedback(resolveConnectionTestFailureFeedback({
kind: 'runtime',
reason: res?.message,
fallback: '连接被拒绝或参数无效,请检查后重试',
}));
}
} catch (e: unknown) {
if (e && typeof e === 'object' && 'errorFields' in e) {
const failMessage = '测试失败: 请先完善必填项后再测试连接';
setTestResult({ type: 'error', message: failMessage });
applyTestFailureFeedback(resolveConnectionTestFailureFeedback({
kind: 'validation',
reason: '',
fallback: '请先完善必填项后再测试连接',
}));
return;
}
const reason = e instanceof Error
? e.message
: (typeof e === 'string' ? e : '未知异常');
const failMessage = buildTestFailureMessage(reason, '未知异常');
setTestResult({ type: 'error', message: failMessage });
applyTestFailureFeedback(resolveConnectionTestFailureFeedback({
kind: 'runtime',
reason,
fallback: '未知异常',
}));
} finally {
testInFlightRef.current = false;
setLoading(false);
@@ -1624,7 +1655,7 @@ const ConnectionModal: React.FC<{
}
const result = await MongoDiscoverMembers(config as any);
if (!result.success) {
message.error(result.message || '成员发现失败');
message.error(normalizeConnectionSecretErrorMessage(result.message, '成员发现失败'));
return;
}
const data = (result.data as Record<string, any>) || {};
@@ -1645,7 +1676,7 @@ const ConnectionModal: React.FC<{
}
message.success(result.message || `发现 ${members.length} 个成员`);
} catch (error: any) {
message.error(error?.message || '成员发现失败');
message.error(normalizeConnectionSecretErrorMessage(error?.message || error, '成员发现失败'));
} finally {
setDiscoveringMembers(false);
}
@@ -2233,7 +2264,14 @@ const ConnectionModal: React.FC<{
<Input {...noAutoCapInputProps} placeholder="留空沿用主库用户名" />
</Form.Item>
<Form.Item name="mysqlReplicaPassword" label="从库密码(可选)" style={{ marginBottom: 0 }}>
<Input.Password {...noAutoCapInputProps} placeholder="留空沿用主库密码" />
<Input.Password
{...noAutoCapInputProps}
placeholder={getStoredSecretPlaceholder({
hasStoredSecret: initialValues?.hasMySQLReplicaPassword,
emptyPlaceholder: '留空沿用主库密码',
retainedLabel: '已保存从库密码',
})}
/>
</Form.Item>
</div>
{renderStoredSecretControls({
@@ -2283,7 +2321,14 @@ const ConnectionModal: React.FC<{
</Form.Item>
</div>
<Form.Item name="mongoReplicaPassword" label="副本集密码(可选)" style={{ marginBottom: 0 }}>
<Input.Password {...noAutoCapInputProps} placeholder="留空沿用主密码" />
<Input.Password
{...noAutoCapInputProps}
placeholder={getStoredSecretPlaceholder({
hasStoredSecret: initialValues?.hasMongoReplicaPassword,
emptyPlaceholder: '留空沿用主密码',
retainedLabel: '已保存副本集密码',
})}
/>
</Form.Item>
{renderStoredSecretControls({
fieldName: 'mongoReplicaPassword',
@@ -2364,7 +2409,14 @@ const ConnectionModal: React.FC<{
</Form.Item>
)}
<Form.Item name="password" label="密码 (可选)">
<Input.Password {...noAutoCapInputProps} placeholder="Redis 密码(如果设置了 requirepass" />
<Input.Password
{...noAutoCapInputProps}
placeholder={getStoredSecretPlaceholder({
hasStoredSecret: initialValues?.hasPrimaryPassword,
emptyPlaceholder: 'Redis 密码(如果设置了 requirepass',
retainedLabel: '已保存 Redis 密码',
})}
/>
</Form.Item>
{renderStoredSecretControls({
fieldName: 'password',
@@ -2397,7 +2449,14 @@ const ConnectionModal: React.FC<{
<Input {...noAutoCapInputProps} />
</Form.Item>
<Form.Item name="password" label="密码" style={{ marginBottom: 0 }}>
<Input.Password {...noAutoCapInputProps} />
<Input.Password
{...noAutoCapInputProps}
placeholder={getStoredSecretPlaceholder({
hasStoredSecret: initialValues?.hasPrimaryPassword,
emptyPlaceholder: '密码',
retainedLabel: '已保存密码',
})}
/>
</Form.Item>
{dbType === 'mongodb' && (
<Form.Item name="mongoAuthMechanism" label="验证方式" style={{ marginBottom: 0 }}>
@@ -2518,7 +2577,14 @@ const ConnectionModal: React.FC<{
<Input {...noAutoCapInputProps} placeholder="root" />
</Form.Item>
<Form.Item name="sshPassword" label="SSH 密码" style={{ flex: 1 }}>
<Input.Password {...noAutoCapInputProps} placeholder="密码" />
<Input.Password
{...noAutoCapInputProps}
placeholder={getStoredSecretPlaceholder({
hasStoredSecret: initialValues?.hasSSHPassword,
emptyPlaceholder: '密码',
retainedLabel: '已保存 SSH 密码',
})}
/>
</Form.Item>
</div>
<Form.Item label="私钥路径 (可选)" help="例如: /Users/name/.ssh/id_rsa">
@@ -2573,7 +2639,14 @@ const ConnectionModal: React.FC<{
<Input {...noAutoCapInputProps} placeholder="留空表示无认证" />
</Form.Item>
<Form.Item name="proxyPassword" label="代理密码(可选)" style={{ flex: 1 }}>
<Input.Password {...noAutoCapInputProps} placeholder="留空表示无认证" />
<Input.Password
{...noAutoCapInputProps}
placeholder={getStoredSecretPlaceholder({
hasStoredSecret: initialValues?.hasProxyPassword,
emptyPlaceholder: '留空表示无认证',
retainedLabel: '已保存代理密码',
})}
/>
</Form.Item>
</div>
{renderStoredSecretControls({
@@ -2611,7 +2684,14 @@ const ConnectionModal: React.FC<{
<Input {...noAutoCapInputProps} placeholder="留空表示无认证" />
</Form.Item>
<Form.Item name="httpTunnelPassword" label="隧道密码(可选)" style={{ flex: 1 }}>
<Input.Password {...noAutoCapInputProps} placeholder="留空表示无认证" />
<Input.Password
{...noAutoCapInputProps}
placeholder={getStoredSecretPlaceholder({
hasStoredSecret: initialValues?.hasHttpTunnelPassword,
emptyPlaceholder: '留空表示无认证',
retainedLabel: '已保存隧道密码',
})}
/>
</Form.Item>
</div>
{renderStoredSecretControls({

View File

@@ -0,0 +1,102 @@
import React from 'react';
import { Checkbox, Input, Modal, Typography } from 'antd';
const { Text } = Typography;
type ConnectionPackagePasswordModalMode = 'import' | 'export';
export interface ConnectionPackagePasswordModalProps {
open: boolean;
title: string;
mode?: ConnectionPackagePasswordModalMode;
includeSecrets?: boolean;
useFilePassword?: boolean;
password: string;
error?: string;
confirmLoading?: boolean;
confirmText?: string;
cancelText?: string;
onIncludeSecretsChange?: (value: boolean) => void;
onUseFilePasswordChange?: (value: boolean) => void;
onPasswordChange: (value: string) => void;
onConfirm: () => void;
onCancel: () => void;
}
export default function ConnectionPackagePasswordModal({
open,
title,
mode = 'import',
includeSecrets = true,
useFilePassword = false,
password,
error,
confirmLoading,
confirmText = '确认',
cancelText = '取消',
onIncludeSecretsChange,
onUseFilePasswordChange,
onPasswordChange,
onConfirm,
onCancel,
}: ConnectionPackagePasswordModalProps) {
const isExportMode = mode === 'export';
const showFilePasswordInput = isExportMode ? useFilePassword : true;
const placeholder = isExportMode ? '请输入文件保护密码(可选)' : '请输入恢复包密码';
const helperText = !includeSecrets
? '将仅导出连接配置,不包含密码。'
: (useFilePassword
? '请通过单独渠道将密码告知接收方,不要和文件一起发送。'
: '密码已加密保护。如需通过公网传输,建议设置文件保护密码。');
return (
<Modal
open={open}
title={title}
okText={confirmText}
cancelText={cancelText}
confirmLoading={confirmLoading}
onOk={onConfirm}
onCancel={onCancel}
destroyOnClose={false}
maskClosable={false}
>
{isExportMode ? (
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
<Checkbox
checked={includeSecrets}
onChange={(event) => onIncludeSecretsChange?.(event.target.checked)}
>
</Checkbox>
<Checkbox
checked={useFilePassword}
disabled={!includeSecrets}
onChange={(event) => onUseFilePasswordChange?.(event.target.checked)}
>
</Checkbox>
</div>
) : null}
{showFilePasswordInput ? (
<Input.Password
autoFocus
value={password}
placeholder={placeholder}
disabled={isExportMode && !useFilePassword}
onChange={(event) => onPasswordChange(event.target.value)}
/>
) : null}
{isExportMode ? (
<Text type={useFilePassword ? 'warning' : 'secondary'} style={{ display: 'block', marginTop: 8 }}>
{helperText}
</Text>
) : null}
{error ? (
<Text type="danger" style={{ display: 'block', marginTop: 8 }}>
{error}
</Text>
) : null}
</Modal>
);
}

View File

@@ -11,6 +11,7 @@ import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
import { convertMongoShellToJsonCommand } from '../utils/mongodb';
import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts';
import { useAutoFetchVisibility } from '../utils/autoFetchVisibility';
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
const SQL_KEYWORDS = [
@@ -249,6 +250,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
const setQueryOptions = useStore(state => state.setQueryOptions);
const shortcutOptions = useStore(state => state.shortcutOptions);
const activeTabId = useStore(state => state.activeTabId);
const autoFetchVisible = useAutoFetchVisibility();
const currentSavedQuery = useMemo(() => {
const savedId = String(tab.savedQueryId || '').trim();
@@ -324,6 +326,10 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
// Fetch Database List
useEffect(() => {
if (!autoFetchVisible) {
return;
}
const fetchDbs = async () => {
const conn = connections.find(c => c.id === currentConnectionId);
if (!conn) return;
@@ -367,10 +373,14 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
}
};
void fetchDbs();
}, [currentConnectionId, connections]);
}, [autoFetchVisible, currentConnectionId, connections]);
// Fetch Metadata for Autocomplete (Cross-database)
useEffect(() => {
if (!autoFetchVisible) {
return;
}
const fetchMetadata = async () => {
const conn = connections.find(c => c.id === currentConnectionId);
if (!conn) return;
@@ -424,7 +434,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
}
};
void fetchMetadata();
}, [currentConnectionId, connections, dbList]); // dbList 变化时触发重新加载
}, [autoFetchVisible, currentConnectionId, connections, dbList]); // dbList 变化时触发重新加载
// Query ID management helpers
const setQueryId = (id: string) => {

View File

@@ -0,0 +1,154 @@
import { Button } from 'antd';
import { CloseOutlined, SafetyCertificateOutlined } from '@ant-design/icons';
import type { SecurityUpdateStatus } from '../types';
import { getSecurityUpdateStatusMeta } from '../utils/securityUpdatePresentation';
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
import {
SECURITY_UPDATE_ACTION_BUTTON_CLASS,
SECURITY_UPDATE_BANNER_CLASS,
getSecurityUpdateActionButtonStyle,
getSecurityUpdateBannerSurfaceStyle,
} from '../utils/securityUpdateVisuals';
interface SecurityUpdateBannerProps {
status: SecurityUpdateStatus;
darkMode: boolean;
overlayTheme: OverlayWorkbenchTheme;
surfaceOpacity?: number;
onStart: () => void;
onRetry: () => void;
onRestart: () => void;
onOpenDetails: () => void;
onDismiss: () => void;
}
const resolvePrimaryAction = (
status: SecurityUpdateStatus,
actions: Pick<SecurityUpdateBannerProps, 'onStart' | 'onRetry' | 'onRestart' | 'onOpenDetails'>,
) => {
switch (status.overallStatus) {
case 'postponed':
return {
label: '立即更新',
onClick: actions.onStart,
};
case 'needs_attention':
return {
label: '查看详情',
onClick: actions.onOpenDetails,
};
case 'rolled_back':
return {
label: '重新开始更新',
onClick: actions.onRestart,
};
default:
return {
label: '查看详情',
onClick: actions.onOpenDetails,
};
}
};
const resolveSecondaryAction = (
status: SecurityUpdateStatus,
actions: Pick<SecurityUpdateBannerProps, 'onRetry' | 'onOpenDetails'>,
) => {
switch (status.overallStatus) {
case 'needs_attention':
return {
label: '重新检查',
onClick: actions.onRetry,
};
case 'rolled_back':
return {
label: '查看详情',
onClick: actions.onOpenDetails,
};
default:
return null;
}
};
const SecurityUpdateBanner = ({
status,
darkMode,
overlayTheme,
surfaceOpacity = 1,
onStart,
onRetry,
onRestart,
onOpenDetails,
onDismiss,
}: SecurityUpdateBannerProps) => {
const statusMeta = getSecurityUpdateStatusMeta(status);
const primaryAction = resolvePrimaryAction(status, { onStart, onRetry, onRestart, onOpenDetails });
const secondaryAction = resolveSecondaryAction(status, { onRetry, onOpenDetails });
const actionButtonStyle = getSecurityUpdateActionButtonStyle();
return (
<div
className={SECURITY_UPDATE_BANNER_CLASS}
style={{
margin: '12px 12px 0',
padding: '14px 16px',
borderRadius: 16,
...getSecurityUpdateBannerSurfaceStyle(overlayTheme, surfaceOpacity),
display: 'flex',
alignItems: 'center',
gap: 16,
overflow: 'hidden',
}}
>
<div
style={{
width: 40,
height: 40,
borderRadius: 14,
display: 'grid',
placeItems: 'center',
background: overlayTheme.iconBg,
color: overlayTheme.iconColor,
flexShrink: 0,
fontSize: 18,
}}
>
<SafetyCertificateOutlined />
</div>
<div style={{ minWidth: 0, flex: 1 }}>
<div style={{ fontSize: 15, fontWeight: 700, color: overlayTheme.titleText }}>
</div>
<div style={{ marginTop: 4, fontSize: 13, color: overlayTheme.mutedText, lineHeight: 1.7 }}>
{statusMeta.description}
</div>
</div>
<div style={{ display: 'flex', alignItems: 'center', gap: 8, flexShrink: 0 }}>
{secondaryAction ? (
<Button className={SECURITY_UPDATE_ACTION_BUTTON_CLASS} style={actionButtonStyle} onClick={secondaryAction.onClick}>
{secondaryAction.label}
</Button>
) : null}
<Button
className={SECURITY_UPDATE_ACTION_BUTTON_CLASS}
style={actionButtonStyle}
type="primary"
onClick={primaryAction.onClick}
>
{primaryAction.label}
</Button>
<Button
className={SECURITY_UPDATE_ACTION_BUTTON_CLASS}
style={{ ...actionButtonStyle, width: 36, minWidth: 36, paddingInline: 0 }}
type="text"
icon={<CloseOutlined />}
onClick={onDismiss}
/>
</div>
</div>
);
};
export type { SecurityUpdateBannerProps };
export default SecurityUpdateBanner;

View File

@@ -0,0 +1,133 @@
import { Button, Modal } from 'antd';
import { SafetyCertificateOutlined } from '@ant-design/icons';
import type { CSSProperties } from 'react';
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
import {
SECURITY_UPDATE_ACTION_BUTTON_CLASS,
SECURITY_UPDATE_MODAL_CLASS,
getSecurityUpdateActionButtonStyle,
getSecurityUpdateShellSurfaceStyle,
} from '../utils/securityUpdateVisuals';
interface SecurityUpdateIntroModalProps {
open: boolean;
loading?: boolean;
darkMode: boolean;
overlayTheme: OverlayWorkbenchTheme;
surfaceOpacity?: number;
onStart: () => void;
onPostpone: () => void;
onViewDetails: () => void;
}
const actionButtonStyle: CSSProperties = {
...getSecurityUpdateActionButtonStyle(),
height: 38,
paddingInline: 18,
};
const SecurityUpdateIntroModal = ({
open,
loading = false,
darkMode,
overlayTheme,
surfaceOpacity = 1,
onStart,
onPostpone,
onViewDetails,
}: SecurityUpdateIntroModalProps) => {
return (
<Modal
rootClassName={SECURITY_UPDATE_MODAL_CLASS}
title={(
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 12 }}>
<div
style={{
width: 38,
height: 38,
borderRadius: 12,
display: 'grid',
placeItems: 'center',
background: overlayTheme.iconBg,
color: overlayTheme.iconColor,
fontSize: 18,
flexShrink: 0,
}}
>
<SafetyCertificateOutlined />
</div>
<div>
<div style={{ fontSize: 16, fontWeight: 800, color: overlayTheme.titleText }}>
</div>
<div style={{ marginTop: 3, color: overlayTheme.mutedText, fontSize: 12 }}>
使
</div>
</div>
</div>
)}
open={open}
closable={!loading}
maskClosable={!loading}
keyboard={!loading}
onCancel={onPostpone}
width={560}
styles={{
content: getSecurityUpdateShellSurfaceStyle(overlayTheme, surfaceOpacity),
header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 },
body: { paddingTop: 8 },
footer: { background: 'transparent', borderTop: 'none', paddingTop: 10 },
}}
footer={[
<Button
key="details"
className={SECURITY_UPDATE_ACTION_BUTTON_CLASS}
type="primary"
ghost
style={actionButtonStyle}
onClick={onViewDetails}
disabled={loading}
>
</Button>,
<Button
key="later"
className={SECURITY_UPDATE_ACTION_BUTTON_CLASS}
type="primary"
ghost
style={actionButtonStyle}
onClick={onPostpone}
disabled={loading}
>
</Button>,
<Button
key="start"
className={SECURITY_UPDATE_ACTION_BUTTON_CLASS}
type="primary"
style={actionButtonStyle}
loading={loading}
onClick={onStart}
>
</Button>,
]}
>
<div
style={{
padding: '12px 0 6px',
color: darkMode ? 'rgba(255,255,255,0.82)' : '#2f3b52',
lineHeight: 1.8,
fontSize: 14,
}}
>
使
</div>
</Modal>
);
};
export type { SecurityUpdateIntroModalProps };
export default SecurityUpdateIntroModal;

View File

@@ -0,0 +1,69 @@
import { Modal, Spin } from 'antd';
import { SafetyCertificateOutlined } from '@ant-design/icons';
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
import {
SECURITY_UPDATE_MODAL_CLASS,
getSecurityUpdateShellSurfaceStyle,
} from '../utils/securityUpdateVisuals';
interface SecurityUpdateProgressModalProps {
open: boolean;
stageText: string;
detailText?: string;
overlayTheme: OverlayWorkbenchTheme;
surfaceOpacity?: number;
}
const SecurityUpdateProgressModal = ({
open,
stageText,
detailText,
overlayTheme,
surfaceOpacity = 1,
}: SecurityUpdateProgressModalProps) => {
return (
<Modal
rootClassName={SECURITY_UPDATE_MODAL_CLASS}
open={open}
closable={false}
maskClosable={false}
keyboard={false}
footer={null}
width={420}
centered
styles={{
content: getSecurityUpdateShellSurfaceStyle(overlayTheme, surfaceOpacity),
header: { display: 'none' },
body: { padding: 28 },
}}
>
<div style={{ display: 'flex', flexDirection: 'column', alignItems: 'center', textAlign: 'center', gap: 16 }}>
<div
style={{
width: 52,
height: 52,
borderRadius: 18,
display: 'grid',
placeItems: 'center',
background: overlayTheme.iconBg,
color: overlayTheme.iconColor,
fontSize: 22,
}}
>
<SafetyCertificateOutlined />
</div>
<div style={{ fontSize: 16, fontWeight: 700, color: overlayTheme.titleText }}>
{stageText}
</div>
<div style={{ fontSize: 13, color: overlayTheme.mutedText, lineHeight: 1.7 }}>
{detailText ?? '更新过程中会保留当前可用配置,请稍候。'}
</div>
<Spin size="large" />
</div>
</Modal>
);
};
export type { SecurityUpdateProgressModalProps };
export default SecurityUpdateProgressModal;

View File

@@ -0,0 +1,337 @@
import { useEffect, useRef, useState } from 'react';
import { Button, Empty, Modal, Tag } from 'antd';
import { SafetyCertificateOutlined } from '@ant-design/icons';
import type { SecurityUpdateIssue, SecurityUpdateStatus } from '../types';
import {
getSecurityUpdateIssueActionMeta,
getSecurityUpdateIssueSeverityMeta,
getSecurityUpdateItemStatusMeta,
getSecurityUpdateStatusMeta,
sortSecurityUpdateIssues,
} from '../utils/securityUpdatePresentation';
import {
hasSecurityUpdateRecentResult,
resolveSecurityUpdateFocusState,
type SecurityUpdateFocusState,
type SecurityUpdateSettingsFocusTarget,
} from '../utils/securityUpdateRepairFlow';
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
import {
SECURITY_UPDATE_ACTION_BUTTON_CLASS,
SECURITY_UPDATE_MODAL_CLASS,
SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS,
SECURITY_UPDATE_RESULT_CARD_CLASS,
getSecurityUpdateActionButtonStyle,
getSecurityUpdateSectionSurfaceStyle,
getSecurityUpdateShellSurfaceStyle,
} from '../utils/securityUpdateVisuals';
interface SecurityUpdateSettingsModalProps {
open: boolean;
darkMode: boolean;
overlayTheme: OverlayWorkbenchTheme;
surfaceOpacity?: number;
status: SecurityUpdateStatus;
focusTarget?: SecurityUpdateSettingsFocusTarget | null;
focusRequest?: number;
onClose: () => void;
onStart: () => void;
onRetry: () => void;
onRestart: () => void;
onIssueAction: (issue: SecurityUpdateIssue) => void;
}
const sectionStyle = (
overlayTheme: OverlayWorkbenchTheme,
surfaceOpacity: number,
options?: { emphasized?: boolean },
) => ({
borderRadius: 14,
padding: 16,
...getSecurityUpdateSectionSurfaceStyle(overlayTheme, {
...options,
surfaceOpacity,
}),
});
const EMPTY_FOCUS_STATE: SecurityUpdateFocusState = {
target: null,
pulseKey: null,
};
const SecurityUpdateSettingsModal = ({
open,
darkMode,
overlayTheme,
surfaceOpacity = 1,
status,
focusTarget = null,
focusRequest = 0,
onClose,
onStart,
onRetry,
onRestart,
onIssueAction,
}: SecurityUpdateSettingsModalProps) => {
const statusMeta = getSecurityUpdateStatusMeta(status);
const sortedIssues = sortSecurityUpdateIssues(status.issues);
const showRecentResult = hasSecurityUpdateRecentResult(status);
const showStart = status.overallStatus === 'pending' || status.overallStatus === 'postponed';
const showRetry = status.overallStatus === 'needs_attention';
const showRestart = status.overallStatus === 'needs_attention' || status.overallStatus === 'rolled_back';
const actionButtonStyle = getSecurityUpdateActionButtonStyle();
const [activeFocus, setActiveFocus] = useState<SecurityUpdateFocusState>(EMPTY_FOCUS_STATE);
const statusSectionRef = useRef<HTMLDivElement | null>(null);
const recentResultRef = useRef<HTMLDivElement | null>(null);
useEffect(() => {
const nextFocus = resolveSecurityUpdateFocusState(open, focusTarget, focusRequest);
if (!nextFocus.target || !nextFocus.pulseKey) {
setActiveFocus(EMPTY_FOCUS_STATE);
return undefined;
}
const targetNode = nextFocus.target === 'recent_result'
? recentResultRef.current
: statusSectionRef.current;
if (!targetNode) {
return undefined;
}
setActiveFocus(EMPTY_FOCUS_STATE);
const animationFrame = window.requestAnimationFrame(() => {
targetNode.scrollIntoView({
block: 'nearest',
behavior: 'smooth',
});
targetNode.focus({ preventScroll: true });
setActiveFocus(nextFocus);
});
const highlightTimer = window.setTimeout(() => {
setActiveFocus((current) => (
current.pulseKey === nextFocus.pulseKey ? EMPTY_FOCUS_STATE : current
));
}, 1800);
return () => {
window.cancelAnimationFrame(animationFrame);
window.clearTimeout(highlightTimer);
};
}, [focusRequest, focusTarget, open]);
return (
<Modal
rootClassName={SECURITY_UPDATE_MODAL_CLASS}
title={(
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 12 }}>
<div
style={{
width: 38,
height: 38,
borderRadius: 12,
display: 'grid',
placeItems: 'center',
background: overlayTheme.iconBg,
color: overlayTheme.iconColor,
fontSize: 18,
flexShrink: 0,
}}
>
<SafetyCertificateOutlined />
</div>
<div>
<div style={{ fontSize: 16, fontWeight: 800, color: overlayTheme.titleText }}>
</div>
<div style={{ marginTop: 3, color: overlayTheme.mutedText, fontSize: 12 }}>
</div>
</div>
</div>
)}
open={open}
onCancel={onClose}
footer={[
showRetry ? (
<Button key="retry" className={SECURITY_UPDATE_ACTION_BUTTON_CLASS} style={actionButtonStyle} onClick={onRetry}>
</Button>
) : null,
showRestart ? (
<Button key="restart" className={SECURITY_UPDATE_ACTION_BUTTON_CLASS} style={actionButtonStyle} onClick={onRestart}>
</Button>
) : null,
showStart ? (
<Button
key="start"
className={SECURITY_UPDATE_ACTION_BUTTON_CLASS}
style={actionButtonStyle}
type="primary"
onClick={onStart}
>
</Button>
) : null,
<Button key="close" className={SECURITY_UPDATE_ACTION_BUTTON_CLASS} style={actionButtonStyle} onClick={onClose}>
</Button>,
]}
width={760}
styles={{
content: getSecurityUpdateShellSurfaceStyle(overlayTheme, surfaceOpacity),
header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 },
body: { paddingTop: 8, maxHeight: 640, overflowY: 'auto' },
footer: { background: 'transparent', borderTop: 'none', paddingTop: 10 },
}}
>
<div style={{ display: 'grid', gap: 14, padding: '12px 0' }}>
<div
ref={statusSectionRef}
tabIndex={-1}
style={sectionStyle(overlayTheme, surfaceOpacity, { emphasized: activeFocus.target === 'status' })}
>
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', gap: 12, flexWrap: 'wrap' }}>
<div>
<div style={{ fontSize: 15, fontWeight: 700, color: overlayTheme.titleText }}>
{statusMeta.label}
</div>
<div style={{ marginTop: 6, fontSize: 13, color: overlayTheme.mutedText, lineHeight: 1.7 }}>
{statusMeta.description}
</div>
</div>
<Tag color={
statusMeta.tone === 'success'
? 'success'
: statusMeta.tone === 'error'
? 'error'
: statusMeta.tone === 'processing'
? 'processing'
: statusMeta.tone === 'warning'
? 'warning'
: 'default'
}>
{statusMeta.label}
</Tag>
</div>
</div>
<div style={sectionStyle(overlayTheme, surfaceOpacity)}>
<div style={{ fontSize: 14, fontWeight: 700, color: overlayTheme.titleText, marginBottom: 12 }}>
</div>
<div style={{ display: 'grid', gridTemplateColumns: 'repeat(5, minmax(0, 1fr))', gap: 10 }}>
{[
{ label: '总计', value: status.summary.total },
{ label: '已更新', value: status.summary.updated },
{ label: '待处理', value: status.summary.pending },
{ label: '已跳过', value: status.summary.skipped },
{ label: '失败', value: status.summary.failed },
].map((item) => (
<div
key={item.label}
style={{
...getSecurityUpdateSectionSurfaceStyle(overlayTheme, { surfaceOpacity }),
borderRadius: 12,
padding: '12px 10px',
}}
>
<div style={{ fontSize: 12, color: overlayTheme.mutedText }}>{item.label}</div>
<div style={{ marginTop: 6, fontSize: 20, fontWeight: 700, color: overlayTheme.titleText }}>{item.value}</div>
</div>
))}
</div>
</div>
<div style={sectionStyle(overlayTheme, surfaceOpacity)}>
<div style={{ fontSize: 14, fontWeight: 700, color: overlayTheme.titleText, marginBottom: 12 }}>
</div>
{sortedIssues.length === 0 ? (
<Empty
image={Empty.PRESENTED_IMAGE_SIMPLE}
description="当前没有待处理项"
/>
) : (
<div style={{ display: 'grid', gap: 10 }}>
{sortedIssues.map((issue) => {
const actionMeta = getSecurityUpdateIssueActionMeta(issue);
const itemStatusMeta = getSecurityUpdateItemStatusMeta(issue.status);
const issueSeverityMeta = getSecurityUpdateIssueSeverityMeta(issue.severity);
return (
<div
key={issue.id}
style={{
...getSecurityUpdateSectionSurfaceStyle(overlayTheme, { surfaceOpacity }),
borderRadius: 12,
padding: 14,
display: 'flex',
alignItems: 'flex-start',
justifyContent: 'space-between',
gap: 16,
}}
>
<div style={{ minWidth: 0 }}>
<div style={{ display: 'flex', alignItems: 'center', gap: 8, flexWrap: 'wrap' }}>
<div style={{ fontSize: 14, fontWeight: 700, color: overlayTheme.titleText }}>
{issue.title || issue.message || issue.id}
</div>
<Tag color={itemStatusMeta.color}>
{itemStatusMeta.label}
</Tag>
<Tag color={issueSeverityMeta.color}>
{issueSeverityMeta.label}
</Tag>
</div>
<div style={{ marginTop: 6, fontSize: 13, color: overlayTheme.mutedText, lineHeight: 1.7 }}>
{issue.message || '当前项需要进一步处理后才能完成安全更新。'}
</div>
</div>
<Button
className={SECURITY_UPDATE_ACTION_BUTTON_CLASS}
style={actionButtonStyle}
type={actionMeta.emphasis === 'primary' ? 'primary' : 'default'}
onClick={() => onIssueAction(issue)}
>
{actionMeta.label}
</Button>
</div>
);
})}
</div>
)}
</div>
{showRecentResult ? (
<div
ref={recentResultRef}
tabIndex={-1}
className={[
SECURITY_UPDATE_RESULT_CARD_CLASS,
activeFocus.target === 'recent_result' ? SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS : '',
].filter(Boolean).join(' ')}
style={sectionStyle(overlayTheme, surfaceOpacity, { emphasized: activeFocus.target === 'recent_result' })}
>
<div style={{ fontSize: 14, fontWeight: 700, color: overlayTheme.titleText, marginBottom: 8 }}>
</div>
{status.backupPath ? (
<div style={{ fontSize: 13, color: overlayTheme.mutedText, lineHeight: 1.7 }}>
<span style={{ color: overlayTheme.titleText }}>{status.backupPath}</span>
</div>
) : null}
{status.lastError ? (
<div style={{ marginTop: 8, fontSize: 13, color: '#ff7875', lineHeight: 1.7 }}>
{status.lastError}
</div>
) : null}
</div>
) : null}
</div>
</Modal>
);
};
export type { SecurityUpdateSettingsModalProps };
export default SecurityUpdateSettingsModal;

View File

@@ -42,6 +42,7 @@ import { getDbIcon } from './DatabaseIcons';
import { getTableDataDangerActionMeta, supportsTableTruncateAction, type TableDataDangerActionKind } from './tableDataDangerActions';
import { EventsOn } from '../../wailsjs/runtime/runtime';
import { normalizeOpacityForPlatform, resolveAppearanceValues } from '../utils/appearance';
import { useAutoFetchVisibility } from '../utils/autoFetchVisibility';
import FindInDatabaseModal from './FindInDatabaseModal';
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
@@ -119,6 +120,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const darkMode = theme === 'dark';
const resolvedAppearance = resolveAppearanceValues(appearance);
const opacity = normalizeOpacityForPlatform(resolvedAppearance.opacity);
const autoFetchVisible = useAutoFetchVisibility();
const [treeData, setTreeData] = useState<TreeNode[]>([]);
// Background Helper (Duplicate logic for now, ideally shared)
@@ -293,6 +295,10 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const [findInDbContext, setFindInDbContext] = useState<{ open: boolean; connectionId: string; dbName: string }>({ open: false, connectionId: '', dbName: '' });
useEffect(() => {
if (!autoFetchVisible) {
return;
}
// Refresh queries for expanded databases
const findNode = (nodes: TreeNode[], k: React.Key): TreeNode | null => {
for (const node of nodes) {
@@ -311,7 +317,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
loadTables(node);
}
});
}, [savedQueries]);
}, [autoFetchVisible, savedQueries]);
useEffect(() => {
setTreeData((prev) => {

View File

@@ -4,6 +4,7 @@ import { TableOutlined, SearchOutlined, ReloadOutlined, SortAscendingOutlined, D
import { useStore } from '../store';
import { DBQuery, DBShowCreateTable, ExportTable, DropTable, RenameTable } from '../../wailsjs/go/app/App';
import type { TabData } from '../types';
import { useAutoFetchVisibility } from '../utils/autoFetchVisibility';
import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig';
import { getTableDataDangerActionMeta, supportsTableTruncateAction, type TableDataDangerActionKind } from './tableDataDangerActions';
@@ -152,6 +153,7 @@ const TableOverview: React.FC<TableOverviewProps> = ({ tab }) => {
const [viewMode, setViewMode] = useState<ViewMode>('list');
const connection = useMemo(() => connections.find(c => c.id === tab.connectionId), [connections, tab.connectionId]);
const autoFetchVisible = useAutoFetchVisibility();
const loadData = useCallback(async () => {
if (!connection) return;
@@ -180,7 +182,12 @@ const TableOverview: React.FC<TableOverviewProps> = ({ tab }) => {
}
}, [connection, tab.dbName]);
useEffect(() => { loadData(); }, [loadData]);
useEffect(() => {
if (!autoFetchVisible) {
return;
}
void loadData();
}, [autoFetchVisible, loadData]);
const sortedFiltered = useMemo(() => {
let list = [...tables];

View File

@@ -0,0 +1,99 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
vi.mock('./App', () => ({
default: () => null,
}));
const createRootMock = vi.fn(() => ({
render: vi.fn(),
}));
vi.mock('react-dom/client', () => ({
default: {
createRoot: createRootMock,
},
createRoot: createRootMock,
}));
const dayjsLocaleMock = vi.fn();
vi.mock('dayjs', () => ({
default: Object.assign(() => null, {
locale: dayjsLocaleMock,
}),
}));
vi.mock('dayjs/locale/zh-cn', () => ({}));
const loaderConfigMock = vi.fn();
vi.mock('@monaco-editor/react', () => ({
loader: {
config: loaderConfigMock,
},
}));
const defineThemeMock = vi.fn();
vi.mock('monaco-editor', () => ({
editor: {
defineTheme: defineThemeMock,
},
}));
vi.mock('monaco-editor/esm/nls.messages.zh-cn', () => ({}));
const importMain = async () => {
await import('./main');
return (globalThis as typeof globalThis & {
window: {
go?: {
app?: {
App?: {
ImportConfigFile: () => Promise<{ success: boolean; message?: string }>;
ImportConnectionsPayload: (raw: string, password?: string) => Promise<unknown>;
ExportConnectionsPackage: (options?: { includeSecrets?: boolean; filePassword?: string }) => Promise<{ success: boolean; message?: string }>;
};
};
};
};
}).window.go?.app?.App;
};
describe('main browser mock', () => {
beforeEach(() => {
vi.resetModules();
vi.stubGlobal('window', {});
vi.stubGlobal('document', {
getElementById: vi.fn(() => ({})),
});
});
afterEach(() => {
vi.unstubAllGlobals();
vi.clearAllMocks();
vi.resetModules();
});
it('returns explicit browser-mode messages for import picker and package export', async () => {
const app = await importMain();
expect(app).toBeDefined();
await expect(app!.ImportConfigFile()).resolves.toEqual({
success: false,
message: '已取消',
});
await expect(app!.ExportConnectionsPackage({ includeSecrets: true, filePassword: '' })).resolves.toEqual({
success: false,
message: '浏览器 mock 不支持恢复包导出',
});
});
it('rejects non-array payloads instead of treating them as successful imports', async () => {
const app = await importMain();
await expect(app!.ImportConnectionsPayload('{"version":1}')).rejects.toThrow(
'浏览器 mock 不支持恢复包导入,仅支持历史 JSON 连接数组',
);
});
});

View File

@@ -127,11 +127,24 @@ if (typeof window !== 'undefined' && !(window as any).go) {
GetAppInfo: async () => ({}),
GetDataRootDirectoryInfo: async () => ({ success: true, data: cloneBrowserMockValue(mockDataRootInfo) }),
CheckForUpdates: async () => ({ success: false }),
CheckForUpdatesSilently: async () => ({ success: false }),
OpenDownloadedUpdateDirectory: async () => ({ success: false }),
OpenDriverDownloadDirectory: async (path: string) => ({ success: true, data: { path } }),
OpenDataRootDirectory: async () => ({ success: true }),
InstallUpdateAndRestart: async () => ({ success: false }),
ImportConfigFile: async () => ({ success: false }),
ImportConfigFile: async () => ({ success: false, message: '已取消' }),
ImportConnectionsPayload: async (raw: string, _password?: string) => {
try {
const parsed = JSON.parse(raw);
if (Array.isArray(parsed)) {
return parsed.map((item) => saveMockConnection(item));
}
} catch {
throw new Error('浏览器 mock 不支持恢复包导入,仅支持历史 JSON 连接数组');
}
throw new Error('浏览器 mock 不支持恢复包导入,仅支持历史 JSON 连接数组');
},
ExportConnectionsPackage: async (_options?: { includeSecrets?: boolean; filePassword?: string }) => ({ success: false, message: '浏览器 mock 不支持恢复包导出' }),
ExportData: async () => ({ success: false }),
GetGlobalProxyConfig: async () => ({ success: true, data: cloneBrowserMockValue(mockGlobalProxy) }),
SaveGlobalProxy: async (input: any) => saveMockGlobalProxy(input),

View File

@@ -91,4 +91,52 @@ describe('store appearance persistence', () => {
expect(appearance.showDataTableVerticalBorders).toBe(true);
expect(appearance.dataTableColumnWidthMode).toBe('compact');
});
it('does not clear persisted legacy connections during hydration migration', async () => {
storage.setItem('lite-db-storage', JSON.stringify({
state: {
connections: [
{
id: 'legacy-1',
name: 'Legacy',
config: {
id: 'legacy-1',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
password: 'secret',
},
},
],
},
version: 7,
}));
const { useStore } = await importStore();
expect(useStore.getState().connections).toHaveLength(1);
expect(useStore.getState().connections[0]?.config.password).toBe('secret');
});
it('keeps legacy global proxy password during hydration until explicit cleanup', async () => {
storage.setItem('lite-db-storage', JSON.stringify({
state: {
globalProxy: {
enabled: true,
type: 'http',
host: '127.0.0.1',
port: 8080,
user: 'ops',
password: 'proxy-secret',
},
},
version: 7,
}));
const { useStore } = await importStore();
expect(useStore.getState().globalProxy.password).toBe('proxy-secret');
expect(useStore.getState().globalProxy.hasPassword).toBe(true);
});
});

View File

@@ -553,6 +553,34 @@ const sanitizeSavedQueries = (value: unknown): SavedQuery[] => {
return result;
};
const hasLegacyConnectionSecrets = (connections: SavedConnection[]): boolean => {
return connections.some((connection) => {
const config = connection?.config && typeof connection.config === 'object'
? connection.config as unknown as Record<string, unknown>
: {};
const ssh = config.ssh && typeof config.ssh === 'object'
? config.ssh as Record<string, unknown>
: {};
const proxy = config.proxy && typeof config.proxy === 'object'
? config.proxy as Record<string, unknown>
: {};
const httpTunnel = config.httpTunnel && typeof config.httpTunnel === 'object'
? config.httpTunnel as Record<string, unknown>
: {};
return (
toTrimmedString(config.password) !== ''
|| toTrimmedString(ssh.password) !== ''
|| toTrimmedString(proxy.password) !== ''
|| toTrimmedString(httpTunnel.password) !== ''
|| toTrimmedString(config.mysqlReplicaPassword) !== ''
|| toTrimmedString(config.mongoReplicaPassword) !== ''
|| toTrimmedString(config.uri) !== ''
|| toTrimmedString(config.dsn) !== ''
);
});
};
const sanitizeTheme = (value: unknown): 'light' | 'dark' => (value === 'dark' ? 'dark' : 'light');
const sanitizeSqlFormatOptions = (value: unknown): { keywordCase: 'upper' | 'lower' } => {
@@ -1242,7 +1270,7 @@ export const useStore = create<AppState>()(
migrate: (persistedState: unknown, version: number) => {
const state = unwrapPersistedAppState(persistedState) as Partial<AppState>;
const nextState: Partial<AppState> = { ...state };
nextState.connections = [];
nextState.connections = sanitizeConnections(state.connections);
if (version < 5) {
nextState.connectionTags = sanitizeConnectionTags(state.connectionTags);
} else {
@@ -1254,7 +1282,7 @@ export const useStore = create<AppState>()(
nextState.uiScale = sanitizeUiScale(state.uiScale);
nextState.fontSize = sanitizeFontSize(state.fontSize);
nextState.startupFullscreen = sanitizeStartupFullscreen(state.startupFullscreen);
nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy, { allowPassword: false });
nextState.globalProxy = sanitizeGlobalProxy(state.globalProxy);
nextState.sqlFormatOptions = sanitizeSqlFormatOptions(state.sqlFormatOptions);
nextState.queryOptions = sanitizeQueryOptions(state.queryOptions);
nextState.shortcutOptions = sanitizeShortcutOptions(state.shortcutOptions);
@@ -1281,7 +1309,7 @@ export const useStore = create<AppState>()(
return {
...currentState,
...state,
connections: currentState.connections,
connections: sanitizeConnections(state.connections),
connectionTags: sanitizeConnectionTags(state.connectionTags),
savedQueries: sanitizeSavedQueries(state.savedQueries),
theme: sanitizeTheme(state.theme),
@@ -1289,7 +1317,7 @@ export const useStore = create<AppState>()(
uiScale: sanitizeUiScale(state.uiScale),
fontSize: sanitizeFontSize(state.fontSize),
startupFullscreen: sanitizeStartupFullscreen(state.startupFullscreen),
globalProxy: sanitizeGlobalProxy(state.globalProxy, { allowPassword: false }),
globalProxy: sanitizeGlobalProxy(state.globalProxy),
tableSortPreference: sanitizeTableSortPreference(state.tableSortPreference),
tableColumnOrders: sanitizeTableColumnOrders(state.tableColumnOrders),
enableColumnOrderMemory: state.enableColumnOrderMemory !== false,
@@ -1309,30 +1337,39 @@ export const useStore = create<AppState>()(
aiChatSessions: [],
};
},
partialize: (state) => ({
connectionTags: state.connectionTags,
savedQueries: state.savedQueries,
theme: state.theme,
appearance: state.appearance,
uiScale: state.uiScale,
fontSize: state.fontSize,
startupFullscreen: state.startupFullscreen,
globalProxy: toPersistedGlobalProxy(state.globalProxy),
sqlFormatOptions: state.sqlFormatOptions,
queryOptions: state.queryOptions,
shortcutOptions: state.shortcutOptions,
tableAccessCount: state.tableAccessCount,
tableSortPreference: state.tableSortPreference,
tableColumnOrders: state.tableColumnOrders,
enableColumnOrderMemory: state.enableColumnOrderMemory,
tableHiddenColumns: state.tableHiddenColumns,
enableHiddenColumnMemory: state.enableHiddenColumnMemory,
windowBounds: state.windowBounds,
windowState: state.windowState,
sidebarWidth: state.sidebarWidth,
partialize: (state) => {
const partialState: Partial<AppState> = {
connectionTags: state.connectionTags,
savedQueries: state.savedQueries,
theme: state.theme,
appearance: state.appearance,
uiScale: state.uiScale,
fontSize: state.fontSize,
startupFullscreen: state.startupFullscreen,
globalProxy: toTrimmedString(state.globalProxy.password) !== ''
? { ...state.globalProxy }
: toPersistedGlobalProxy(state.globalProxy),
sqlFormatOptions: state.sqlFormatOptions,
queryOptions: state.queryOptions,
shortcutOptions: state.shortcutOptions,
tableAccessCount: state.tableAccessCount,
tableSortPreference: state.tableSortPreference,
tableColumnOrders: state.tableColumnOrders,
enableColumnOrderMemory: state.enableColumnOrderMemory,
tableHiddenColumns: state.tableHiddenColumns,
enableHiddenColumnMemory: state.enableHiddenColumnMemory,
windowBounds: state.windowBounds,
windowState: state.windowState,
sidebarWidth: state.sidebarWidth,
};
if (hasLegacyConnectionSecrets(state.connections)) {
partialState.connections = state.connections;
}
// AI 会话数据已迁移到后端文件持久化(~/.gonavi/sessions/),不再写入 localStorage
}), // Don't persist logs
return partialState as AppState;
}, // Don't persist logs
}
)
);

View File

@@ -262,4 +262,70 @@ export interface AISafetyResult {
warningMessage?: string;
}
export type SecurityUpdateOverallStatus =
| 'not_detected'
| 'pending'
| 'postponed'
| 'in_progress'
| 'needs_attention'
| 'completed'
| 'rolled_back';
export type SecurityUpdateIssueScope = 'connection' | 'global_proxy' | 'ai_provider' | 'system';
export type SecurityUpdateIssueSeverity = 'high' | 'medium' | 'low';
export type SecurityUpdateItemStatus = 'pending' | 'updated' | 'needs_attention' | 'skipped' | 'failed';
export type SecurityUpdateIssueReasonCode =
| 'migration_required'
| 'secret_missing'
| 'field_invalid'
| 'write_conflict'
| 'validation_failed'
| 'environment_blocked';
export type SecurityUpdateIssueAction =
| 'open_connection'
| 'open_proxy_settings'
| 'open_ai_settings'
| 'retry_update'
| 'view_details';
export interface SecurityUpdateSummary {
total: number;
updated: number;
pending: number;
skipped: number;
failed: number;
}
export interface SecurityUpdateIssue {
id: string;
scope?: SecurityUpdateIssueScope;
refId?: string;
title?: string;
severity?: SecurityUpdateIssueSeverity;
status?: SecurityUpdateItemStatus;
reasonCode?: SecurityUpdateIssueReasonCode;
action?: SecurityUpdateIssueAction;
message?: string;
}
export interface SecurityUpdateStatus {
schemaVersion?: number;
migrationId?: string;
overallStatus: SecurityUpdateOverallStatus;
sourceType?: 'current_app_saved_config';
reminderVisible?: boolean;
canStart?: boolean;
canPostpone?: boolean;
canRetry?: boolean;
backupAvailable?: boolean;
backupPath?: string;
startedAt?: string;
updatedAt?: string;
completedAt?: string;
postponedAt?: string;
summary: SecurityUpdateSummary;
issues: SecurityUpdateIssue[];
lastError?: string;
}

View File

@@ -0,0 +1,22 @@
import { describe, expect, it } from 'vitest';
import { isAutoFetchVisible } from './autoFetchVisibility';
describe('isAutoFetchVisible', () => {
it('allows auto fetch only when the document is visible and not hidden', () => {
expect(isAutoFetchVisible({ hidden: false, visibilityState: 'visible' })).toBe(true);
});
it('blocks auto fetch when the page is hidden even if visibilityState looks visible', () => {
expect(isAutoFetchVisible({ hidden: true, visibilityState: 'visible' })).toBe(false);
});
it('blocks auto fetch when visibilityState is not visible', () => {
expect(isAutoFetchVisible({ hidden: false, visibilityState: 'hidden' })).toBe(false);
});
it('defaults to allowing auto fetch when document visibility APIs are unavailable', () => {
expect(isAutoFetchVisible(undefined)).toBe(true);
expect(isAutoFetchVisible({})).toBe(true);
});
});

View File

@@ -0,0 +1,54 @@
import { useEffect, useState } from 'react';
type AutoFetchVisibilitySource = Partial<Pick<Document, 'hidden' | 'visibilityState'>> | undefined;
export const isAutoFetchVisible = (source?: AutoFetchVisibilitySource): boolean => {
if (!source) {
return true;
}
if (source.hidden === true) {
return false;
}
if (source.visibilityState && source.visibilityState !== 'visible') {
return false;
}
return true;
};
const getDocumentAutoFetchVisibility = (): boolean => {
if (typeof document === 'undefined') {
return true;
}
return isAutoFetchVisible(document);
};
export const useAutoFetchVisibility = (): boolean => {
const [isVisible, setIsVisible] = useState<boolean>(() => getDocumentAutoFetchVisibility());
useEffect(() => {
if (typeof document === 'undefined') {
return undefined;
}
const syncVisibility = () => {
setIsVisible(getDocumentAutoFetchVisibility());
};
syncVisibility();
document.addEventListener('visibilitychange', syncVisibility);
window.addEventListener('focus', syncVisibility);
window.addEventListener('pageshow', syncVisibility);
return () => {
document.removeEventListener('visibilitychange', syncVisibility);
window.removeEventListener('focus', syncVisibility);
window.removeEventListener('pageshow', syncVisibility);
};
}, []);
return isVisible;
};

View File

@@ -0,0 +1,186 @@
import { describe, expect, it } from 'vitest';
import {
detectConnectionImportKind,
isConnectionPackagePasswordRequiredError,
isConnectionPackageExportCanceled,
resolveConnectionPackageExportResult,
normalizeConnectionPackagePassword,
} from './connectionExport';
describe('connectionExport', () => {
it('detects v2 app-managed packages', () => {
expect(detectConnectionImportKind(JSON.stringify({
v: 2,
kind: 'gonavi_connection_package',
p: 1,
exportedAt: '2026-04-11T21:00:00Z',
connections: [],
}))).toBe('app-managed-package');
});
it('detects v2 encrypted packages', () => {
expect(detectConnectionImportKind(JSON.stringify({
v: 2,
kind: 'gonavi_connection_package',
p: 2,
kdf: {
n: 'a2id',
m: 65536,
t: 3,
l: 4,
s: 'c2FsdA==',
},
nc: 'bm9uY2Utbm9uY2U=',
d: 'encrypted-data',
}))).toBe('encrypted-package');
});
it('rejects malformed v2 app-managed packages without connections array', () => {
expect(detectConnectionImportKind(JSON.stringify({
v: 2,
kind: 'gonavi_connection_package',
p: 1,
exportedAt: '2026-04-11T21:00:00Z',
}))).toBe('invalid');
});
it('rejects malformed v2 encrypted packages without protected payload fields', () => {
expect(detectConnectionImportKind(JSON.stringify({
v: 2,
kind: 'gonavi_connection_package',
p: 2,
kdf: {
n: 'a2id',
m: 65536,
t: 3,
l: 4,
},
}))).toBe('invalid');
});
it('detects v1 encrypted packages by gonavi envelope kind', () => {
expect(detectConnectionImportKind(JSON.stringify({
schemaVersion: 1,
kind: 'gonavi_connection_package',
cipher: 'AES-256-GCM',
kdf: {
name: 'Argon2id',
memoryKiB: 65536,
timeCost: 3,
parallelism: 4,
salt: 'c2FsdA==',
},
nonce: 'bm9uY2Utbm9uY2U=',
payload: 'encrypted-data',
}))).toBe('encrypted-package');
});
it('detects legacy imports from historical json arrays', () => {
expect(detectConnectionImportKind(JSON.stringify([
{
id: 'conn-1',
name: 'Primary',
config: {
type: 'postgres',
},
},
]))).toBe('legacy-json');
});
it('returns invalid for malformed or unsupported content', () => {
expect(detectConnectionImportKind('{not-json}')).toBe('invalid');
expect(detectConnectionImportKind(JSON.stringify({
v: 2,
kind: 'gonavi_connection_package',
p: 0,
}))).toBe('invalid');
expect(detectConnectionImportKind(JSON.stringify({
v: 2,
kind: 'gonavi_connection_package',
}))).toBe('invalid');
expect(detectConnectionImportKind(JSON.stringify({
kind: 'gonavi_connection_package',
payload: 'encrypted-data',
}))).toBe('invalid');
expect(detectConnectionImportKind(JSON.stringify([
{
foo: 'bar',
},
]))).toBe('invalid');
expect(detectConnectionImportKind(JSON.stringify({
kind: 'other_package',
payload: 'encrypted-data',
}))).toBe('invalid');
expect(detectConnectionImportKind('null')).toBe('invalid');
});
it('trims package passwords before use', () => {
expect(normalizeConnectionPackagePassword(' secret-pass ')).toBe('secret-pass');
expect(normalizeConnectionPackagePassword('\n\t \t')).toBe('');
});
it('recognizes backend password-required errors for protected packages', () => {
expect(isConnectionPackagePasswordRequiredError(new Error('恢复包密码不能为空'))).toBe(true);
expect(isConnectionPackagePasswordRequiredError({ message: '恢复包密码不能为空' })).toBe(true);
expect(isConnectionPackagePasswordRequiredError('恢复包密码不能为空')).toBe(true);
expect(isConnectionPackagePasswordRequiredError(new Error('文件密码错误或文件已损坏'))).toBe(false);
expect(isConnectionPackagePasswordRequiredError(undefined)).toBe(false);
});
it('treats export cancel as a non-error backend result', () => {
expect(isConnectionPackageExportCanceled({ success: false, message: '已取消' })).toBe(true);
expect(isConnectionPackageExportCanceled({ success: false, message: '导出失败' })).toBe(false);
expect(isConnectionPackageExportCanceled({ success: true, message: '已取消' })).toBe(false);
expect(isConnectionPackageExportCanceled(undefined)).toBe(false);
});
it('maps export results to dialog state transitions', () => {
const staleDialog = {
open: true,
mode: 'export' as const,
includeSecrets: true,
useFilePassword: false,
password: ' secret-pass ',
error: '上一次失败',
confirmLoading: false,
};
const canceledResult = resolveConnectionPackageExportResult(staleDialog, { success: false, message: '已取消' });
expect(canceledResult.kind).toBe('canceled');
if (canceledResult.kind === 'canceled') {
expect(typeof canceledResult.nextDialog).toBe('function');
expect((canceledResult.nextDialog as (current: typeof staleDialog) => typeof staleDialog)({
open: false,
mode: 'export',
includeSecrets: true,
useFilePassword: false,
password: 'secret-pass',
error: '更新后的错误',
confirmLoading: true,
})).toEqual({
open: false,
mode: 'export',
includeSecrets: true,
useFilePassword: false,
password: 'secret-pass',
error: '',
confirmLoading: false,
});
}
expect(resolveConnectionPackageExportResult(staleDialog, { success: true, message: '导出完成' })).toEqual({
kind: 'succeeded',
});
expect(resolveConnectionPackageExportResult(staleDialog, { success: false, message: '磁盘已满' })).toEqual({
kind: 'failed',
error: '磁盘已满',
});
expect(resolveConnectionPackageExportResult(staleDialog, undefined)).toEqual({
kind: 'failed',
error: '导出失败',
});
});
});

View File

@@ -0,0 +1,189 @@
import type { ConnectionConfig, SavedConnection } from '../types';
export type ConnectionImportKind = 'app-managed-package' | 'encrypted-package' | 'legacy-json' | 'invalid';
export type ConnectionPackageDialogSnapshot = {
open: boolean;
mode: 'export' | 'import';
includeSecrets: boolean;
useFilePassword: boolean;
password: string;
error: string;
confirmLoading: boolean;
};
export type ConnectionPackageDialogUpdater = (
current: ConnectionPackageDialogSnapshot,
) => ConnectionPackageDialogSnapshot;
export type ConnectionPackageExportResult =
| { kind: 'canceled'; nextDialog: ConnectionPackageDialogUpdater }
| { kind: 'succeeded' }
| { kind: 'failed'; error: string };
type JsonObject = Record<string, unknown>;
const CONNECTION_PACKAGE_KIND = 'gonavi_connection_package';
const CONNECTION_PACKAGE_SCHEMA_VERSION_V2 = 2;
const CONNECTION_PACKAGE_PROTECTION_APP_MANAGED = 1;
const CONNECTION_PACKAGE_PROTECTION_FILE_PASSWORD = 2;
const CANCELED_MESSAGE = '已取消';
const CONNECTION_PACKAGE_PASSWORD_REQUIRED_MESSAGE = '恢复包密码不能为空';
const isJsonObject = (value: unknown): value is JsonObject => (
typeof value === 'object' && value !== null && !Array.isArray(value)
);
const isConnectionPackageKDF = (value: unknown): value is JsonObject => (
isJsonObject(value)
&& typeof value.name === 'string'
&& typeof value.memoryKiB === 'number'
&& typeof value.timeCost === 'number'
&& typeof value.parallelism === 'number'
&& typeof value.salt === 'string'
);
const isConnectionPackageEnvelope = (value: unknown): value is JsonObject => (
isJsonObject(value)
&& typeof value.schemaVersion === 'number'
&& value.kind === CONNECTION_PACKAGE_KIND
&& typeof value.cipher === 'string'
&& isConnectionPackageKDF(value.kdf)
&& typeof value.nonce === 'string'
&& typeof value.payload === 'string'
);
const isConnectionPackageV2Envelope = (value: unknown): value is JsonObject => (
isJsonObject(value)
&& value.kind === CONNECTION_PACKAGE_KIND
&& value.v === CONNECTION_PACKAGE_SCHEMA_VERSION_V2
&& typeof value.p === 'number'
);
const isConnectionPackageKDFV2 = (value: unknown): value is JsonObject => (
isJsonObject(value)
&& typeof value.n === 'string'
&& typeof value.m === 'number'
&& typeof value.t === 'number'
&& typeof value.l === 'number'
&& typeof value.s === 'string'
);
const isConnectionPackageV2AppManagedEnvelope = (value: unknown): value is JsonObject => (
isConnectionPackageV2Envelope(value)
&& value.p === CONNECTION_PACKAGE_PROTECTION_APP_MANAGED
&& Array.isArray(value.connections)
);
const isConnectionPackageV2ProtectedEnvelope = (value: unknown): value is JsonObject => (
isConnectionPackageV2Envelope(value)
&& value.p === CONNECTION_PACKAGE_PROTECTION_FILE_PASSWORD
&& isConnectionPackageKDFV2(value.kdf)
&& typeof value.nc === 'string'
&& typeof value.d === 'string'
);
const isLegacyConnectionConfig = (value: unknown): value is JsonObject => (
isJsonObject(value)
&& typeof value.type === 'string'
);
const isLegacyConnectionItem = (value: unknown): value is JsonObject => (
isJsonObject(value)
&& typeof value.id === 'string'
&& typeof value.name === 'string'
&& isLegacyConnectionConfig(value.config)
);
const parseConnectionImportRaw = (raw: unknown): unknown => {
if (typeof raw !== 'string') {
return raw;
}
try {
return JSON.parse(raw);
} catch {
return undefined;
}
};
export const detectConnectionImportKind = (raw: unknown): ConnectionImportKind => {
const parsed = parseConnectionImportRaw(raw);
if (isConnectionPackageV2AppManagedEnvelope(parsed)) {
return 'app-managed-package';
}
if (isConnectionPackageV2ProtectedEnvelope(parsed)) {
return 'encrypted-package';
}
if (isConnectionPackageV2Envelope(parsed)) {
return 'invalid';
}
if (Array.isArray(parsed) && parsed.every((item) => isLegacyConnectionItem(item))) {
return 'legacy-json';
}
if (isConnectionPackageEnvelope(parsed)) {
return 'encrypted-package';
}
return 'invalid';
};
export const normalizeConnectionPackagePassword = (value: string): string => value.trim();
export const isConnectionPackagePasswordRequiredError = (value: unknown): boolean => {
if (typeof value === 'string') {
return value.trim() === CONNECTION_PACKAGE_PASSWORD_REQUIRED_MESSAGE;
}
if (value instanceof Error) {
return value.message.trim() === CONNECTION_PACKAGE_PASSWORD_REQUIRED_MESSAGE;
}
return isJsonObject(value)
&& typeof value.message === 'string'
&& value.message.trim() === CONNECTION_PACKAGE_PASSWORD_REQUIRED_MESSAGE;
};
export const isConnectionPackageExportCanceled = (result: unknown): boolean => (
isJsonObject(result)
&& result.success === false
&& result.message === CANCELED_MESSAGE
);
export const resolveConnectionPackageExportResult = (
_currentDialog: ConnectionPackageDialogSnapshot,
result: unknown,
): ConnectionPackageExportResult => {
if (isConnectionPackageExportCanceled(result)) {
return {
kind: 'canceled',
nextDialog: (current) => ({
...current,
confirmLoading: false,
error: '',
}),
};
}
if (isJsonObject(result) && result.success === true) {
return { kind: 'succeeded' };
}
return {
kind: 'failed',
error: isJsonObject(result) && typeof result.message === 'string' && result.message.trim()
? result.message
: '导出失败',
};
};
const legacyExportRemovedError = (): never => {
throw new Error('Legacy connection JSON export has been removed. Use the recovery package flow instead.');
};
export const sanitizeConnectionConfigForExport = (_config: ConnectionConfig): never => legacyExportRemovedError();
export const buildExportableConnections = (_connections: SavedConnection[]): never => legacyExportRemovedError();

View File

@@ -0,0 +1,57 @@
import { describe, expect, it } from 'vitest';
import {
getStoredSecretPlaceholder,
normalizeConnectionSecretErrorMessage,
resolveConnectionTestFailureFeedback,
} from './connectionModalPresentation';
describe('connectionModalPresentation', () => {
it('shows an explicit stored-secret placeholder instead of an empty-looking password field', () => {
expect(getStoredSecretPlaceholder({
hasStoredSecret: true,
emptyPlaceholder: '密码',
retainedLabel: '已保存密码',
})).toBe('••••••(留空表示继续沿用已保存密码)');
});
it('keeps the original placeholder when no stored secret exists', () => {
expect(getStoredSecretPlaceholder({
hasStoredSecret: false,
emptyPlaceholder: '密码',
retainedLabel: '已保存密码',
})).toBe('密码');
});
it('maps missing saved-connection errors to a secret-specific hint', () => {
expect(normalizeConnectionSecretErrorMessage('saved connection not found: conn-1')).toBe(
'未找到当前连接对应的已保存密文,请重新填写密码并保存后再试',
);
});
it('preserves existing user-facing messages', () => {
expect(normalizeConnectionSecretErrorMessage('连接测试超时')).toBe('连接测试超时');
});
it('shows a toast-worthy failure message for saved-secret lookup errors during connection tests', () => {
expect(resolveConnectionTestFailureFeedback({
kind: 'runtime',
reason: 'saved connection not found: conn-1',
fallback: '连接失败',
})).toEqual({
message: '测试失败: 未找到当前连接对应的已保存密文,请重新填写密码并保存后再试',
shouldToast: true,
});
});
it('keeps required-field validation failures inline without an extra toast', () => {
expect(resolveConnectionTestFailureFeedback({
kind: 'validation',
reason: '',
fallback: '连接失败',
})).toEqual({
message: '测试失败: 请先完善必填项后再测试连接',
shouldToast: false,
});
});
});

View File

@@ -0,0 +1,78 @@
type StoredSecretPlaceholderOptions = {
hasStoredSecret?: boolean;
emptyPlaceholder: string;
retainedLabel: string;
};
type ConnectionTestFailureKind =
| 'validation'
| 'runtime'
| 'driver_unavailable'
| 'secret_blocked';
type ConnectionTestFailureFeedback = {
message: string;
shouldToast: boolean;
};
const normalizeText = (value: unknown, fallback = ''): string => {
const text = String(value ?? '').trim();
if (!text || text === 'undefined' || text === 'null') {
return fallback;
}
return text;
};
export const getStoredSecretPlaceholder = ({
hasStoredSecret,
emptyPlaceholder,
retainedLabel,
}: StoredSecretPlaceholderOptions): string => (
hasStoredSecret
? `••••••(留空表示继续沿用${retainedLabel}`
: emptyPlaceholder
);
export const normalizeConnectionSecretErrorMessage = (
value: unknown,
fallback = '',
): string => {
const text = normalizeText(value, fallback);
const lower = text.toLowerCase();
if (lower.includes('saved connection not found:')) {
return '未找到当前连接对应的已保存密文,请重新填写密码并保存后再试';
}
if (lower.includes('secret store unavailable')) {
return '系统密文存储当前不可用,请检查系统钥匙串或凭据管理器后再试';
}
return text;
};
export const resolveConnectionTestFailureFeedback = ({
kind,
reason,
fallback,
}: {
kind: ConnectionTestFailureKind;
reason: unknown;
fallback: string;
}): ConnectionTestFailureFeedback => {
if (kind === 'validation') {
return {
message: '测试失败: 请先完善必填项后再测试连接',
shouldToast: false,
};
}
return {
message: `测试失败: ${normalizeConnectionSecretErrorMessage(reason, fallback)}`,
shouldToast: true,
};
};
export type {
ConnectionTestFailureFeedback,
ConnectionTestFailureKind,
};

View File

@@ -1,6 +1,11 @@
import { describe, expect, it } from 'vitest';
import { readLegacyPersistedSecrets, stripLegacyPersistedSecrets } from './legacyConnectionStorage';
import {
hasLegacyMigratableSensitiveItems,
readLegacyPersistedSecrets,
stripLegacyPersistedConnectionById,
stripLegacyPersistedSecrets,
} from './legacyConnectionStorage';
describe('legacy connection storage', () => {
it('extracts legacy saved connections and global proxy password from lite-db-storage', () => {
@@ -37,7 +42,7 @@ describe('legacy connection storage', () => {
expect(result.globalProxy?.password).toBe('proxy-secret');
});
it('strips persisted connection secrets but keeps secretless proxy metadata', () => {
it('clears legacy connection and proxy source data after cleanup', () => {
const payload = JSON.stringify({
state: {
connections: [
@@ -69,7 +74,110 @@ describe('legacy connection storage', () => {
const parsed = JSON.parse(sanitized);
expect(parsed.state.connections).toEqual([]);
expect(parsed.state.globalProxy.password).toBeUndefined();
expect(parsed.state.globalProxy.hasPassword).toBe(true);
expect(parsed.state.globalProxy).toBeUndefined();
});
it('treats a meaningful legacy global proxy as migratable even when it has no password', () => {
const payload = JSON.stringify({
state: {
globalProxy: {
enabled: true,
type: 'http',
host: '127.0.0.1',
port: 8080,
user: 'ops',
},
},
});
expect(hasLegacyMigratableSensitiveItems(payload)).toBe(true);
});
it('detects migratable sensitive items before cleanup and clears the signal after cleanup', () => {
const payload = JSON.stringify({
state: {
connections: [
{
id: 'conn-1',
name: 'Primary',
config: {
id: 'conn-1',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
password: 'secret',
},
},
],
globalProxy: {
enabled: true,
type: 'http',
host: '127.0.0.1',
port: 8080,
user: 'ops',
password: 'proxy-secret',
},
},
});
expect(hasLegacyMigratableSensitiveItems(payload)).toBe(true);
expect(hasLegacyMigratableSensitiveItems(stripLegacyPersistedSecrets(payload))).toBe(false);
});
it('removes only the repaired legacy connection while preserving other source data', () => {
const payload = JSON.stringify({
state: {
connections: [
{
id: 'conn-1',
name: 'Primary',
config: {
id: 'conn-1',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
password: 'secret',
},
},
{
id: 'conn-2',
name: 'Replica',
config: {
id: 'conn-2',
type: 'mysql',
host: 'replica.local',
port: 3306,
user: 'root',
password: 'replica-secret',
},
},
],
globalProxy: {
enabled: true,
type: 'http',
host: '127.0.0.1',
port: 8080,
user: 'ops',
password: 'proxy-secret',
},
},
});
const sanitized = stripLegacyPersistedConnectionById(payload, 'conn-1');
const parsed = JSON.parse(sanitized);
expect(parsed.state.connections).toEqual([
expect.objectContaining({
id: 'conn-2',
config: expect.objectContaining({
password: 'replica-secret',
}),
}),
]);
expect(parsed.state.globalProxy).toEqual(expect.objectContaining({
password: 'proxy-secret',
}));
});
});

View File

@@ -79,6 +79,11 @@ export function readLegacyPersistedSecrets(payload: string | null | undefined):
};
}
export function hasLegacyMigratableSensitiveItems(payload: string | null | undefined): boolean {
const legacy = readLegacyPersistedSecrets(payload);
return legacy.connections.length > 0 || legacy.globalProxy !== null;
}
export function stripLegacyPersistedSecrets(payload: string | null | undefined): string {
if (!payload || typeof payload !== 'string') {
return '';
@@ -96,15 +101,42 @@ export function stripLegacyPersistedSecrets(payload: string | null | undefined):
: parsed;
state.connections = [];
if (state.globalProxy && typeof state.globalProxy === 'object') {
const proxy = { ...(state.globalProxy as Record<string, unknown>) };
const password = toTrimmedString(proxy.password);
delete proxy.password;
if (password !== '') {
proxy.hasPassword = true;
}
state.globalProxy = proxy;
if (state.globalProxy !== undefined) {
delete state.globalProxy;
}
return JSON.stringify(parsed);
}
export function stripLegacyPersistedConnectionById(
payload: string | null | undefined,
connectionId: string,
): string {
if (!payload || typeof payload !== 'string') {
return '';
}
let parsed: Record<string, unknown>;
try {
parsed = JSON.parse(payload) as Record<string, unknown>;
} catch {
return payload;
}
const state = parsed.state && typeof parsed.state === 'object'
? parsed.state as Record<string, unknown>
: parsed;
const targetId = toTrimmedString(connectionId);
if (!targetId || !Array.isArray(state.connections)) {
return payload;
}
state.connections = state.connections.filter((item) => {
if (!item || typeof item !== 'object') {
return true;
}
return toTrimmedString((item as { id?: unknown }).id) !== targetId;
});
return JSON.stringify(parsed);
}

View File

@@ -0,0 +1,17 @@
import { describe, expect, it } from 'vitest';
import { shouldEnableMacWindowDiagnostics } from './macWindowDiagnostics';
describe('macWindowDiagnostics', () => {
it('stays disabled outside macOS runtime', () => {
expect(shouldEnableMacWindowDiagnostics(false, true)).toBe(false);
});
it('stays disabled for production builds on macOS', () => {
expect(shouldEnableMacWindowDiagnostics(true, false)).toBe(false);
});
it('enables diagnostics only for macOS development builds', () => {
expect(shouldEnableMacWindowDiagnostics(true, true)).toBe(true);
});
});

View File

@@ -0,0 +1,6 @@
export const shouldEnableMacWindowDiagnostics = (
isMacRuntime: boolean,
isDevBuild: boolean,
): boolean => {
return isMacRuntime && isDevBuild;
};

View File

@@ -0,0 +1,698 @@
import { describe, expect, it, vi } from 'vitest';
import { LEGACY_PERSIST_KEY } from './legacyConnectionStorage';
import {
bootstrapSecureConfig,
finalizeSecurityUpdateStatus,
mergeSecurityUpdateStatusWithLegacySource,
startSecurityUpdateFromBootstrap,
} from './secureConfigBootstrap';
import { stripLegacyPersistedConnectionById } from './legacyConnectionStorage';
const legacyPayload = JSON.stringify({
state: {
connections: [
{
id: 'legacy-1',
name: 'Legacy',
config: {
id: 'legacy-1',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
password: 'secret',
},
},
],
globalProxy: {
enabled: true,
type: 'http',
host: '127.0.0.1',
port: 8080,
user: 'ops',
password: 'proxy-secret',
},
},
});
const createMemoryStorage = () => {
const data = new Map<string, string>();
return {
getItem: (key: string) => data.get(key) ?? null,
setItem: (key: string, value: string) => {
data.set(key, value);
},
removeItem: (key: string) => {
data.delete(key);
},
};
};
const createBaseArgs = (storage = createMemoryStorage()) => {
const replaceConnections = vi.fn();
const replaceGlobalProxy = vi.fn();
storage.setItem(LEGACY_PERSIST_KEY, legacyPayload);
return {
storage,
replaceConnections,
replaceGlobalProxy,
};
};
describe('secureConfigBootstrap', () => {
it('builds legacy pending summary and issue list before the first round starts', async () => {
const args = createBaseArgs();
const result = await bootstrapSecureConfig({
...args,
backend: {
GetSecurityUpdateStatus: vi.fn().mockResolvedValue({
overallStatus: 'not_detected',
summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 },
issues: [],
}),
},
});
expect(result.status.overallStatus).toBe('pending');
expect(result.status.summary).toEqual({
total: 2,
updated: 0,
pending: 2,
skipped: 0,
failed: 0,
});
expect(result.status.issues).toEqual(expect.arrayContaining([
expect.objectContaining({
scope: 'connection',
refId: 'legacy-1',
action: 'open_connection',
}),
expect.objectContaining({
scope: 'global_proxy',
action: 'open_proxy_settings',
}),
]));
});
it('shows intro when legacy sensitive items exist and backend status is pending', async () => {
const args = createBaseArgs();
const result = await bootstrapSecureConfig({
...args,
backend: {
GetSecurityUpdateStatus: vi.fn().mockResolvedValue({
overallStatus: 'pending',
summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 },
issues: [],
}),
},
});
expect(result.status.overallStatus).toBe('pending');
expect(result.shouldShowIntro).toBe(true);
expect(result.shouldShowBanner).toBe(false);
expect(args.replaceConnections).toHaveBeenCalledWith(
expect.arrayContaining([expect.objectContaining({ id: 'legacy-1' })]),
);
});
it('keeps banner flow without intro when backend status is postponed', async () => {
const args = createBaseArgs();
const result = await bootstrapSecureConfig({
...args,
backend: {
GetSecurityUpdateStatus: vi.fn().mockResolvedValue({
overallStatus: 'postponed',
summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 },
issues: [],
}),
},
});
expect(result.shouldShowIntro).toBe(false);
expect(result.shouldShowBanner).toBe(true);
});
it('keeps legacy pending summary and issues when a pre-start round is postponed', async () => {
const args = createBaseArgs();
const result = await bootstrapSecureConfig({
...args,
backend: {
GetSecurityUpdateStatus: vi.fn().mockResolvedValue({
overallStatus: 'postponed',
summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 },
issues: [],
}),
},
});
expect(result.status.overallStatus).toBe('postponed');
expect(result.status.summary.total).toBe(2);
expect(result.status.summary.pending).toBe(2);
expect(result.status.issues).toEqual(expect.arrayContaining([
expect.objectContaining({ scope: 'connection', refId: 'legacy-1' }),
expect.objectContaining({ scope: 'global_proxy' }),
]));
});
it('merges backend pending issues with legacy source items before the first round starts', async () => {
const args = createBaseArgs();
const result = await bootstrapSecureConfig({
...args,
backend: {
GetSecurityUpdateStatus: vi.fn().mockResolvedValue({
overallStatus: 'pending',
summary: { total: 1, updated: 0, pending: 1, skipped: 0, failed: 0 },
issues: [
{
id: 'ai-provider-openai-main',
scope: 'ai_provider',
refId: 'openai-main',
title: 'OpenAI',
severity: 'medium',
status: 'pending',
reasonCode: 'secret_missing',
action: 'open_ai_settings',
message: 'AI 提供商配置仍需完成安全更新',
},
],
}),
},
});
expect(result.status.overallStatus).toBe('pending');
expect(result.status.summary).toEqual({
total: 3,
updated: 0,
pending: 3,
skipped: 0,
failed: 0,
});
expect(result.status.issues).toEqual(expect.arrayContaining([
expect.objectContaining({ scope: 'ai_provider', refId: 'openai-main' }),
expect.objectContaining({ scope: 'connection', refId: 'legacy-1' }),
expect.objectContaining({ scope: 'global_proxy' }),
]));
});
it('keeps banner flow without intro when backend status is rolled_back', async () => {
const args = createBaseArgs();
const result = await bootstrapSecureConfig({
...args,
backend: {
GetSecurityUpdateStatus: vi.fn().mockResolvedValue({
overallStatus: 'rolled_back',
summary: { total: 1, updated: 0, pending: 0, skipped: 0, failed: 1 },
issues: [],
}),
},
});
expect(result.shouldShowIntro).toBe(false);
expect(result.shouldShowBanner).toBe(true);
});
it('merges legacy pending items into rolled_back status without overwriting backend system issues', () => {
const status = mergeSecurityUpdateStatusWithLegacySource({
overallStatus: 'rolled_back',
summary: { total: 1, updated: 0, pending: 0, skipped: 0, failed: 1 },
issues: [
{
id: 'system-blocked',
scope: 'system',
title: '系统回滚',
severity: 'high',
status: 'failed',
reasonCode: 'environment_blocked',
action: 'view_details',
message: '后端已回滚本轮更新,需要处理后重试。',
},
],
}, legacyPayload);
expect(status.overallStatus).toBe('rolled_back');
expect(status.summary).toEqual({
total: 3,
updated: 0,
pending: 2,
skipped: 0,
failed: 1,
});
expect(status.issues).toEqual(expect.arrayContaining([
expect.objectContaining({ id: 'system-blocked', scope: 'system' }),
expect.objectContaining({ id: 'legacy-connection-legacy-1', scope: 'connection', refId: 'legacy-1' }),
expect.objectContaining({ id: 'legacy-global-proxy-default', scope: 'global_proxy' }),
]));
});
it('merges legacy pending items into needs_attention status without overwriting backend system issues', () => {
const status = mergeSecurityUpdateStatusWithLegacySource({
overallStatus: 'needs_attention',
summary: { total: 2, updated: 1, pending: 0, skipped: 0, failed: 1 },
issues: [
{
id: 'system-partial-failure',
scope: 'system',
title: '部分失败',
severity: 'high',
status: 'failed',
reasonCode: 'environment_blocked',
action: 'view_details',
message: '部分项目迁移失败,需要人工处理。',
},
{
id: 'ai-provider-openai-main',
scope: 'ai_provider',
refId: 'openai-main',
title: 'OpenAI',
severity: 'medium',
status: 'updated',
action: 'open_ai_settings',
message: 'AI 提供商配置已完成安全更新。',
},
],
}, legacyPayload);
expect(status.overallStatus).toBe('needs_attention');
expect(status.summary).toEqual({
total: 4,
updated: 1,
pending: 2,
skipped: 0,
failed: 1,
});
expect(status.issues).toEqual(expect.arrayContaining([
expect.objectContaining({ id: 'system-partial-failure', scope: 'system' }),
expect.objectContaining({ id: 'ai-provider-openai-main', scope: 'ai_provider', refId: 'openai-main' }),
expect.objectContaining({ id: 'legacy-connection-legacy-1', scope: 'connection', refId: 'legacy-1' }),
expect.objectContaining({ id: 'legacy-global-proxy-default', scope: 'global_proxy' }),
]));
});
it('does not merge local legacy pending items back into an active migration round that already reports needs_attention', () => {
const status = mergeSecurityUpdateStatusWithLegacySource({
migrationId: 'migration-active-1',
overallStatus: 'needs_attention',
summary: { total: 3, updated: 2, pending: 1, skipped: 0, failed: 0 },
issues: [
{
id: 'ai-provider-openai-main',
scope: 'ai_provider',
refId: 'openai-main',
title: 'OpenAI',
severity: 'medium',
status: 'needs_attention',
reasonCode: 'secret_missing',
action: 'open_ai_settings',
message: 'AI 提供商配置需要补充后才能完成安全更新。',
},
],
}, legacyPayload);
expect(status.overallStatus).toBe('needs_attention');
expect(status.summary).toEqual({
total: 3,
updated: 2,
pending: 1,
skipped: 0,
failed: 0,
});
expect(status.issues).toEqual([
expect.objectContaining({ id: 'ai-provider-openai-main', scope: 'ai_provider', refId: 'openai-main' }),
]);
});
it('does not merge local legacy pending items back into a rolled_back migration round', () => {
const status = mergeSecurityUpdateStatusWithLegacySource({
migrationId: 'migration-active-2',
overallStatus: 'rolled_back',
summary: { total: 3, updated: 1, pending: 0, skipped: 0, failed: 2 },
issues: [
{
id: 'system-blocked',
scope: 'system',
title: '系统回滚',
severity: 'high',
status: 'failed',
reasonCode: 'environment_blocked',
action: 'view_details',
message: '后端已回滚本轮更新,需要处理后重试。',
},
],
}, legacyPayload);
expect(status.overallStatus).toBe('rolled_back');
expect(status.summary).toEqual({
total: 3,
updated: 1,
pending: 0,
skipped: 0,
failed: 2,
});
expect(status.issues).toEqual([
expect.objectContaining({ id: 'system-blocked', scope: 'system' }),
]);
});
it('loads backend secure config directly when no legacy source exists', async () => {
const storage = createMemoryStorage();
const replaceConnections = vi.fn();
const replaceGlobalProxy = vi.fn();
const result = await bootstrapSecureConfig({
storage,
replaceConnections,
replaceGlobalProxy,
backend: {
GetSecurityUpdateStatus: vi.fn().mockResolvedValue({
overallStatus: 'not_detected',
summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 },
issues: [],
}),
GetSavedConnections: vi.fn().mockResolvedValue([
{
id: 'secure-1',
name: 'Secure',
config: {
id: 'secure-1',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
},
},
]),
},
});
expect(result.status.overallStatus).toBe('not_detected');
expect(replaceConnections).toHaveBeenCalledWith(
expect.arrayContaining([expect.objectContaining({ id: 'secure-1' })]),
);
});
it('shows intro when backend status is pending even without legacy local source', async () => {
const storage = createMemoryStorage();
const replaceConnections = vi.fn();
const replaceGlobalProxy = vi.fn();
const result = await bootstrapSecureConfig({
storage,
replaceConnections,
replaceGlobalProxy,
backend: {
GetSecurityUpdateStatus: vi.fn().mockResolvedValue({
overallStatus: 'pending',
summary: { total: 1, updated: 0, pending: 1, skipped: 0, failed: 0 },
issues: [],
}),
},
});
expect(result.status.overallStatus).toBe('pending');
expect(result.shouldShowIntro).toBe(true);
expect(result.shouldShowBanner).toBe(false);
});
it('falls back to legacy visible config when StartSecurityUpdate throws', async () => {
const args = createBaseArgs();
const result = await startSecurityUpdateFromBootstrap({
...args,
backend: {
StartSecurityUpdate: vi.fn().mockRejectedValue(new Error('boom')),
},
});
expect(result.status).toBeNull();
expect(result.error?.message).toContain('boom');
expect(args.replaceConnections).toHaveBeenCalledWith(
expect.arrayContaining([expect.objectContaining({ id: 'legacy-1' })]),
);
expect(args.storage.getItem(LEGACY_PERSIST_KEY)).toContain('"password":"secret"');
});
it('starts security update even when rawPayload is empty but backend supports AI-only update', async () => {
const storage = createMemoryStorage();
const replaceConnections = vi.fn();
const replaceGlobalProxy = vi.fn();
const StartSecurityUpdate = vi.fn().mockResolvedValue({
overallStatus: 'completed',
summary: { total: 1, updated: 1, pending: 0, skipped: 0, failed: 0 },
issues: [],
});
const result = await startSecurityUpdateFromBootstrap({
storage,
replaceConnections,
replaceGlobalProxy,
backend: {
StartSecurityUpdate,
},
});
expect(result.error).toBeNull();
expect(result.status?.overallStatus).toBe('completed');
expect(StartSecurityUpdate).toHaveBeenCalledWith({
sourceType: 'current_app_saved_config',
rawPayload: '',
options: {
allowPartial: true,
writeBackup: true,
},
});
});
it('keeps source-side secrets when update ends in needs_attention', async () => {
const args = createBaseArgs();
const result = await startSecurityUpdateFromBootstrap({
...args,
backend: {
StartSecurityUpdate: vi.fn().mockResolvedValue({
overallStatus: 'needs_attention',
summary: { total: 3, updated: 2, pending: 1, skipped: 0, failed: 0 },
issues: [{ id: 'ai-1' }],
}),
GetSavedConnections: vi.fn().mockResolvedValue([]),
},
});
expect(result.status?.overallStatus).toBe('needs_attention');
expect(args.storage.getItem(LEGACY_PERSIST_KEY)).toContain('"password":"secret"');
});
it('cleans source-side secrets only after completed update and backend refresh', async () => {
const args = createBaseArgs();
const result = await startSecurityUpdateFromBootstrap({
...args,
backend: {
StartSecurityUpdate: vi.fn().mockResolvedValue({
overallStatus: 'completed',
summary: { total: 3, updated: 3, pending: 0, skipped: 0, failed: 0 },
issues: [],
}),
GetSavedConnections: vi.fn().mockResolvedValue([
{
id: 'secure-1',
name: 'Secure',
config: {
id: 'secure-1',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
},
hasPrimaryPassword: true,
},
]),
GetGlobalProxyConfig: vi.fn().mockResolvedValue({
success: true,
data: {
enabled: true,
type: 'http',
host: '127.0.0.1',
port: 8080,
user: 'ops',
hasPassword: true,
},
}),
},
});
expect(result.status?.overallStatus).toBe('completed');
expect(args.storage.getItem(LEGACY_PERSIST_KEY)).not.toContain('"password":"secret"');
expect(args.replaceConnections).toHaveBeenLastCalledWith(
expect.arrayContaining([expect.objectContaining({ id: 'secure-1' })]),
);
});
it('refreshes backend config and strips source-side secrets when a later round finishes as completed', async () => {
const args = createBaseArgs();
const status = await finalizeSecurityUpdateStatus({
...args,
backend: {
GetSavedConnections: vi.fn().mockResolvedValue([
{
id: 'secure-1',
name: 'Secure',
config: {
id: 'secure-1',
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
},
hasPrimaryPassword: true,
},
]),
GetGlobalProxyConfig: vi.fn().mockResolvedValue({
success: true,
data: {
enabled: true,
type: 'http',
host: '127.0.0.1',
port: 8080,
user: 'ops',
hasPassword: true,
},
}),
},
}, {
overallStatus: 'completed',
summary: { total: 3, updated: 3, pending: 0, skipped: 0, failed: 0 },
issues: [],
});
expect(status.overallStatus).toBe('completed');
expect(args.storage.getItem(LEGACY_PERSIST_KEY)).not.toContain('"password":"secret"');
expect(args.replaceConnections).toHaveBeenLastCalledWith(
expect.arrayContaining([expect.objectContaining({ id: 'secure-1' })]),
);
});
it('reduces legacy pending issues after a single connection is repaired before the first round starts', () => {
const initialStatus = mergeSecurityUpdateStatusWithLegacySource({
overallStatus: 'not_detected',
summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 },
issues: [],
}, legacyPayload);
const nextPayload = stripLegacyPersistedConnectionById(legacyPayload, 'legacy-1');
const status = mergeSecurityUpdateStatusWithLegacySource({
overallStatus: 'not_detected',
summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 },
issues: [],
}, nextPayload, {
previousStatus: initialStatus,
});
expect(status.overallStatus).toBe('pending');
expect(status.summary).toEqual({
total: 2,
updated: 1,
pending: 1,
skipped: 0,
failed: 0,
});
expect(status.issues).toEqual([
expect.objectContaining({
scope: 'global_proxy',
action: 'open_proxy_settings',
}),
]);
});
it('accumulates pre-start repaired progress across multiple connection saves in the same round-free session', () => {
const multiConnectionPayload = JSON.stringify({
state: {
connections: [
{
id: 'legacy-1',
name: 'Legacy 1',
config: {
id: 'legacy-1',
type: 'postgres',
host: 'db-1.local',
port: 5432,
user: 'postgres',
password: 'secret-1',
},
},
{
id: 'legacy-2',
name: 'Legacy 2',
config: {
id: 'legacy-2',
type: 'postgres',
host: 'db-2.local',
port: 5432,
user: 'postgres',
password: 'secret-2',
},
},
{
id: 'legacy-3',
name: 'Legacy 3',
config: {
id: 'legacy-3',
type: 'postgres',
host: 'db-3.local',
port: 5432,
user: 'postgres',
password: 'secret-3',
},
},
],
},
});
const backendStatus = {
overallStatus: 'not_detected' as const,
summary: { total: 0, updated: 0, pending: 0, skipped: 0, failed: 0 },
issues: [],
};
const initialStatus = mergeSecurityUpdateStatusWithLegacySource(backendStatus, multiConnectionPayload);
const afterFirstRepairPayload = stripLegacyPersistedConnectionById(multiConnectionPayload, 'legacy-1');
const afterFirstRepairStatus = mergeSecurityUpdateStatusWithLegacySource(backendStatus, afterFirstRepairPayload, {
previousStatus: initialStatus,
});
const afterSecondRepairPayload = stripLegacyPersistedConnectionById(afterFirstRepairPayload, 'legacy-2');
const afterSecondRepairStatus = mergeSecurityUpdateStatusWithLegacySource(backendStatus, afterSecondRepairPayload, {
previousStatus: afterFirstRepairStatus,
});
expect(afterFirstRepairStatus.summary).toEqual({
total: 3,
updated: 1,
pending: 2,
skipped: 0,
failed: 0,
});
expect(afterSecondRepairStatus.summary).toEqual({
total: 3,
updated: 2,
pending: 1,
skipped: 0,
failed: 0,
});
expect(afterSecondRepairStatus.issues).toEqual([
expect.objectContaining({
id: 'legacy-connection-legacy-3',
scope: 'connection',
refId: 'legacy-3',
}),
]);
});
});

View File

@@ -0,0 +1,412 @@
import {
GlobalProxyConfig,
SavedConnection,
SecurityUpdateIssue,
SecurityUpdateStatus,
SecurityUpdateSummary,
} from '../types';
import { createGlobalProxyDraft } from './globalProxyDraft';
import {
LEGACY_PERSIST_KEY,
hasLegacyMigratableSensitiveItems,
readLegacyPersistedSecrets,
stripLegacyPersistedSecrets,
} from './legacyConnectionStorage';
type StorageLike = Pick<Storage, 'getItem' | 'setItem' | 'removeItem'>;
type BackendGlobalProxyResult = {
success?: boolean;
data?: Partial<GlobalProxyConfig>;
};
type SecurityUpdateBackend = {
GetSecurityUpdateStatus?: () => Promise<Partial<SecurityUpdateStatus> | undefined>;
StartSecurityUpdate?: (request: {
sourceType: 'current_app_saved_config';
rawPayload: string;
options?: {
allowPartial?: boolean;
writeBackup?: boolean;
};
}) => Promise<Partial<SecurityUpdateStatus> | undefined>;
GetSavedConnections?: () => Promise<SavedConnection[]>;
GetGlobalProxyConfig?: () => Promise<BackendGlobalProxyResult | undefined>;
};
type SecureConfigBootstrapArgs = {
backend?: SecurityUpdateBackend;
storage?: StorageLike;
replaceConnections: (connections: SavedConnection[]) => void;
replaceGlobalProxy: (proxy: GlobalProxyConfig) => void;
};
type SecureConfigBootstrapResult = {
status: SecurityUpdateStatus;
rawPayload: string | null;
hasLegacySensitiveItems: boolean;
shouldShowIntro: boolean;
shouldShowBanner: boolean;
};
type StartSecurityUpdateResult = {
status: SecurityUpdateStatus | null;
error: Error | null;
};
type MergeSecurityUpdateStatusOptions = {
previousStatus?: Partial<SecurityUpdateStatus> | null;
};
const defaultSummary = () => ({
total: 0,
updated: 0,
pending: 0,
skipped: 0,
failed: 0,
});
const hasMeaningfulSummary = (summary: SecurityUpdateSummary): boolean => (
summary.total > 0
|| summary.updated > 0
|| summary.pending > 0
|| summary.skipped > 0
|| summary.failed > 0
);
const buildLegacyPendingDetails = (rawPayload: string | null): {
hasLegacyItems: boolean;
summary: SecurityUpdateSummary;
issues: SecurityUpdateIssue[];
} => {
const legacy = readLegacyPersistedSecrets(rawPayload);
const issues: SecurityUpdateIssue[] = legacy.connections.map((connection) => ({
id: `legacy-connection-${connection.id}`,
scope: 'connection',
refId: connection.id,
title: connection.name || connection.id,
severity: 'medium',
status: 'pending',
reasonCode: 'migration_required',
action: 'open_connection',
message: '该连接仍保存在当前应用的本地配置中,完成安全更新后会迁入新的安全存储。',
}));
if (legacy.globalProxy) {
issues.push({
id: 'legacy-global-proxy-default',
scope: 'global_proxy',
title: '全局代理',
severity: 'medium',
status: 'pending',
reasonCode: 'migration_required',
action: 'open_proxy_settings',
message: '全局代理仍保存在当前应用的本地配置中,完成安全更新后会迁入新的安全存储。',
});
}
return {
hasLegacyItems: issues.length > 0,
summary: {
total: issues.length,
updated: 0,
pending: issues.length,
skipped: 0,
failed: 0,
},
issues,
};
};
const mergeSecurityUpdateIssues = (
baseIssues: SecurityUpdateIssue[],
legacyIssues: SecurityUpdateIssue[],
): {
issues: SecurityUpdateIssue[];
addedCount: number;
} => {
const issueIds = new Set(baseIssues.map((issue) => issue.id));
const additions = legacyIssues.filter((issue) => !issueIds.has(issue.id));
return {
issues: [...baseIssues, ...additions],
addedCount: additions.length,
};
};
const isLocalLegacyIssue = (issue: Partial<SecurityUpdateIssue> | null | undefined): boolean => {
const issueId = String(issue?.id || '').trim();
return issueId.startsWith('legacy-connection-') || issueId === 'legacy-global-proxy-default';
};
const countLocalLegacyIssues = (issues: SecurityUpdateIssue[]): number => (
issues.filter((issue) => isLocalLegacyIssue(issue)).length
);
const deriveLegacySummary = (
base: SecurityUpdateStatus,
currentLegacyCount: number,
previousStatus?: Partial<SecurityUpdateStatus> | null,
): {
summary: SecurityUpdateSummary;
hasContribution: boolean;
} => {
const previousSummary = previousStatus?.summary ?? defaultSummary();
const previousIssues = Array.isArray(previousStatus?.issues) ? previousStatus.issues : [];
const previousLegacyCount = countLocalLegacyIssues(previousIssues);
const previousLegacyTotal = Math.max(
0,
previousSummary.total - base.summary.total,
previousSummary.updated - base.summary.updated + previousLegacyCount,
previousLegacyCount,
);
const previousLegacyUpdated = Math.max(
0,
Math.min(previousLegacyTotal, previousSummary.updated - base.summary.updated),
);
const repairedSincePrevious = Math.max(0, previousLegacyCount - currentLegacyCount);
const nextLegacyUpdated = Math.min(previousLegacyTotal, previousLegacyUpdated + repairedSincePrevious);
const nextLegacyTotal = Math.max(previousLegacyTotal, nextLegacyUpdated + currentLegacyCount);
return {
summary: {
total: base.summary.total + nextLegacyTotal,
updated: base.summary.updated + nextLegacyUpdated,
pending: base.summary.pending + currentLegacyCount,
skipped: base.summary.skipped,
failed: base.summary.failed,
},
hasContribution: nextLegacyTotal > 0,
};
};
export const mergeSecurityUpdateStatusWithLegacySource = (
status: Partial<SecurityUpdateStatus> | undefined,
rawPayload: string | null,
options?: MergeSecurityUpdateStatusOptions,
): SecurityUpdateStatus => {
const base: SecurityUpdateStatus = {
...defaultStatus(),
...status,
summary: {
...defaultSummary(),
...(status?.summary ?? {}),
},
issues: Array.isArray(status?.issues) ? status.issues : [],
};
const hasActiveMigrationRound = String(base.migrationId || '').trim() !== '';
const baseNonLegacyIssues = base.issues.filter((issue) => !isLocalLegacyIssue(issue));
const legacy = buildLegacyPendingDetails(rawPayload);
const legacySummary = deriveLegacySummary(base, legacy.issues.length, options?.previousStatus);
if (!legacySummary.hasContribution) {
return base;
}
const mergedIssues = mergeSecurityUpdateIssues(baseNonLegacyIssues, legacy.issues).issues;
if (base.overallStatus === 'not_detected') {
if (!legacy.hasLegacyItems) {
return base;
}
return {
...base,
overallStatus: 'pending',
reminderVisible: true,
canStart: true,
canPostpone: true,
summary: legacySummary.summary,
issues: mergedIssues,
};
}
if (base.overallStatus === 'pending' || base.overallStatus === 'postponed') {
return {
...base,
summary: hasMeaningfulSummary(base.summary) || legacy.hasLegacyItems ? legacySummary.summary : legacy.summary,
issues: mergedIssues,
canStart: true,
canPostpone: true,
reminderVisible: base.overallStatus === 'pending' ? true : base.reminderVisible,
};
}
if (base.overallStatus === 'rolled_back' || base.overallStatus === 'needs_attention') {
if (hasActiveMigrationRound) {
return base;
}
return {
...base,
summary: hasMeaningfulSummary(base.summary) || legacy.hasLegacyItems ? legacySummary.summary : legacy.summary,
issues: mergedIssues,
};
}
return base;
};
const defaultStatus = (): SecurityUpdateStatus => ({
overallStatus: 'not_detected',
summary: defaultSummary(),
issues: [],
});
const resolveStorage = (storage?: StorageLike): StorageLike | undefined => {
if (storage) {
return storage;
}
if (typeof window === 'undefined') {
return undefined;
}
return window.localStorage;
};
const applyLegacyVisibleConfig = (
rawPayload: string | null,
replaceConnections: (connections: SavedConnection[]) => void,
replaceGlobalProxy: (proxy: GlobalProxyConfig) => void,
) => {
const legacy = readLegacyPersistedSecrets(rawPayload);
if (legacy.connections.length > 0) {
replaceConnections(legacy.connections);
}
if (legacy.globalProxy) {
replaceGlobalProxy(createGlobalProxyDraft(legacy.globalProxy));
}
};
const refreshVisibleConfigFromBackend = async (
backend: SecurityUpdateBackend | undefined,
replaceConnections: (connections: SavedConnection[]) => void,
replaceGlobalProxy: (proxy: GlobalProxyConfig) => void,
allowEmptyConnections: boolean,
) => {
if (typeof backend?.GetSavedConnections === 'function') {
try {
const connections = await backend.GetSavedConnections();
if (Array.isArray(connections) && (allowEmptyConnections || connections.length > 0)) {
replaceConnections(connections);
}
} catch {
// Keep current visible state as fallback.
}
}
if (typeof backend?.GetGlobalProxyConfig === 'function') {
try {
const proxyResult = await backend.GetGlobalProxyConfig();
if (proxyResult?.success && proxyResult.data) {
replaceGlobalProxy(createGlobalProxyDraft(proxyResult.data));
}
} catch {
// Keep current visible state as fallback.
}
}
};
const cleanupLegacySourceIfCompleted = (
storage: StorageLike | undefined,
rawPayload: string | null,
status: SecurityUpdateStatus,
) => {
if (!storage || !rawPayload || status.overallStatus !== 'completed') {
return;
}
const sanitizedPayload = stripLegacyPersistedSecrets(rawPayload);
if (sanitizedPayload && sanitizedPayload !== rawPayload) {
storage.setItem(LEGACY_PERSIST_KEY, sanitizedPayload);
}
};
export async function finalizeSecurityUpdateStatus(
args: SecureConfigBootstrapArgs,
rawStatus: Partial<SecurityUpdateStatus> | undefined,
): Promise<SecurityUpdateStatus> {
const storage = resolveStorage(args.storage);
const rawPayload = storage?.getItem(LEGACY_PERSIST_KEY) ?? null;
const status = mergeSecurityUpdateStatusWithLegacySource(rawStatus, rawPayload);
if (status.overallStatus === 'completed') {
await refreshVisibleConfigFromBackend(args.backend, args.replaceConnections, args.replaceGlobalProxy, true);
cleanupLegacySourceIfCompleted(storage, rawPayload, status);
}
return status;
}
export async function bootstrapSecureConfig(args: SecureConfigBootstrapArgs): Promise<SecureConfigBootstrapResult> {
const storage = resolveStorage(args.storage);
const rawPayload = storage?.getItem(LEGACY_PERSIST_KEY) ?? null;
const hasLegacySensitiveItems = hasLegacyMigratableSensitiveItems(rawPayload);
applyLegacyVisibleConfig(rawPayload, args.replaceConnections, args.replaceGlobalProxy);
const backendStatus = typeof args.backend?.GetSecurityUpdateStatus === 'function'
? await args.backend.GetSecurityUpdateStatus()
: undefined;
const status = mergeSecurityUpdateStatusWithLegacySource(backendStatus, rawPayload);
if (!hasLegacySensitiveItems) {
await refreshVisibleConfigFromBackend(args.backend, args.replaceConnections, args.replaceGlobalProxy, true);
} else if (status.overallStatus === 'completed') {
await refreshVisibleConfigFromBackend(args.backend, args.replaceConnections, args.replaceGlobalProxy, true);
cleanupLegacySourceIfCompleted(storage, rawPayload, status);
}
return {
status,
rawPayload,
hasLegacySensitiveItems,
shouldShowIntro: status.overallStatus === 'pending',
shouldShowBanner: ['postponed', 'rolled_back', 'needs_attention'].includes(status.overallStatus),
};
}
export async function startSecurityUpdateFromBootstrap(args: SecureConfigBootstrapArgs): Promise<StartSecurityUpdateResult> {
const storage = resolveStorage(args.storage);
const rawPayload = storage?.getItem(LEGACY_PERSIST_KEY) ?? null;
const startPayload = rawPayload ?? '';
applyLegacyVisibleConfig(rawPayload, args.replaceConnections, args.replaceGlobalProxy);
if (typeof args.backend?.StartSecurityUpdate !== 'function') {
return {
status: null,
error: new Error('安全更新能力不可用'),
};
}
try {
const rawStatus = await args.backend.StartSecurityUpdate({
sourceType: 'current_app_saved_config',
rawPayload: startPayload,
options: {
allowPartial: true,
writeBackup: true,
},
});
const status = mergeSecurityUpdateStatusWithLegacySource(rawStatus, rawPayload);
if (status.overallStatus === 'completed') {
await refreshVisibleConfigFromBackend(args.backend, args.replaceConnections, args.replaceGlobalProxy, true);
cleanupLegacySourceIfCompleted(storage, rawPayload, status);
}
return { status, error: null };
} catch (error) {
applyLegacyVisibleConfig(rawPayload, args.replaceConnections, args.replaceGlobalProxy);
return {
status: null,
error: error instanceof Error ? error : new Error(String(error)),
};
}
}
export type {
BackendGlobalProxyResult,
MergeSecurityUpdateStatusOptions,
SecurityUpdateBackend,
SecureConfigBootstrapArgs,
SecureConfigBootstrapResult,
StartSecurityUpdateResult,
};

View File

@@ -0,0 +1,96 @@
import { describe, expect, it } from 'vitest';
import type { SecurityUpdateIssue, SecurityUpdateStatus } from '../types';
import {
getSecurityUpdateIssueSeverityMeta,
getSecurityUpdateItemStatusMeta,
getSecurityUpdateIssueActionMeta,
getSecurityUpdateStatusMeta,
resolveSecurityUpdateEntryVisibility,
sortSecurityUpdateIssues,
} from './securityUpdatePresentation';
const createStatus = (overallStatus: SecurityUpdateStatus['overallStatus']): SecurityUpdateStatus => ({
overallStatus,
summary: {
total: 0,
updated: 0,
pending: 0,
skipped: 0,
failed: 0,
},
issues: [],
});
describe('securityUpdatePresentation', () => {
it('sorts issues by severity from high to low', () => {
const issues: SecurityUpdateIssue[] = [
{ id: 'medium-1', severity: 'medium' },
{ id: 'low-1', severity: 'low' },
{ id: 'high-1', severity: 'high' },
{ id: 'medium-2', severity: 'medium' },
];
expect(sortSecurityUpdateIssues(issues).map((issue) => issue.id)).toEqual([
'high-1',
'medium-1',
'medium-2',
'low-1',
]);
});
it('maps needs_attention, rolled_back and completed to stable display labels', () => {
expect(getSecurityUpdateStatusMeta(createStatus('needs_attention')).label).toBe('待处理');
expect(getSecurityUpdateStatusMeta(createStatus('rolled_back')).label).toBe('已回退');
expect(getSecurityUpdateStatusMeta(createStatus('completed')).label).toBe('已完成');
});
it('resolves intro, banner and detail entry visibility for key overall states', () => {
expect(resolveSecurityUpdateEntryVisibility(createStatus('pending'))).toEqual({
showIntro: true,
showBanner: false,
showDetailEntry: true,
});
expect(resolveSecurityUpdateEntryVisibility(createStatus('postponed'))).toEqual({
showIntro: false,
showBanner: true,
showDetailEntry: true,
});
expect(resolveSecurityUpdateEntryVisibility(createStatus('rolled_back'))).toEqual({
showIntro: false,
showBanner: true,
showDetailEntry: true,
});
});
it('maps issue scope actions to existing repair entry labels', () => {
expect(getSecurityUpdateIssueActionMeta({ id: 'conn', scope: 'connection', action: 'open_connection' }).label).toBe('打开连接');
expect(getSecurityUpdateIssueActionMeta({ id: 'proxy', scope: 'global_proxy', action: 'open_proxy_settings' }).label).toBe('代理设置');
expect(getSecurityUpdateIssueActionMeta({ id: 'ai', scope: 'ai_provider', action: 'open_ai_settings' }).label).toBe('AI 设置');
expect(getSecurityUpdateIssueActionMeta({ id: 'system', scope: 'system', action: 'view_details' }).label).toBe('查看详情');
});
it('maps item status to explicit Chinese labels instead of reusing severity wording', () => {
expect(getSecurityUpdateItemStatusMeta('needs_attention')).toEqual({
label: '待处理',
color: 'warning',
});
expect(getSecurityUpdateItemStatusMeta('updated')).toEqual({
label: '已更新',
color: 'success',
});
});
it('maps issue severity to dedicated risk labels', () => {
expect(getSecurityUpdateIssueSeverityMeta('medium')).toEqual({
label: '中风险',
color: 'warning',
});
expect(getSecurityUpdateIssueSeverityMeta('high')).toEqual({
label: '高风险',
color: 'error',
});
});
});

View File

@@ -0,0 +1,210 @@
import type {
SecurityUpdateIssue,
SecurityUpdateIssueAction,
SecurityUpdateIssueSeverity,
SecurityUpdateItemStatus,
SecurityUpdateStatus,
} from '../types';
type SecurityUpdateTone = 'default' | 'warning' | 'processing' | 'success' | 'error';
type SecurityUpdateStatusMeta = {
label: string;
description: string;
tone: SecurityUpdateTone;
};
type SecurityUpdateEntryVisibility = {
showIntro: boolean;
showBanner: boolean;
showDetailEntry: boolean;
};
type SecurityUpdateIssueActionMeta = {
label: string;
emphasis: 'primary' | 'default';
};
type SecurityUpdateBadgeMeta = {
label: string;
color: SecurityUpdateTone;
};
const severityWeight: Record<SecurityUpdateIssueSeverity, number> = {
high: 0,
medium: 1,
low: 2,
};
const actionMetaMap: Record<SecurityUpdateIssueAction, SecurityUpdateIssueActionMeta> = {
open_connection: {
label: '打开连接',
emphasis: 'primary',
},
open_proxy_settings: {
label: '代理设置',
emphasis: 'primary',
},
open_ai_settings: {
label: 'AI 设置',
emphasis: 'primary',
},
retry_update: {
label: '重新检查',
emphasis: 'primary',
},
view_details: {
label: '查看详情',
emphasis: 'default',
},
};
const itemStatusMetaMap: Record<SecurityUpdateItemStatus, SecurityUpdateBadgeMeta> = {
pending: {
label: '待更新',
color: 'processing',
},
updated: {
label: '已更新',
color: 'success',
},
needs_attention: {
label: '待处理',
color: 'warning',
},
skipped: {
label: '已跳过',
color: 'default',
},
failed: {
label: '失败',
color: 'error',
},
};
const issueSeverityMetaMap: Record<SecurityUpdateIssueSeverity, SecurityUpdateBadgeMeta> = {
high: {
label: '高风险',
color: 'error',
},
medium: {
label: '中风险',
color: 'warning',
},
low: {
label: '低风险',
color: 'default',
},
};
export function sortSecurityUpdateIssues(issues: SecurityUpdateIssue[]): SecurityUpdateIssue[] {
return [...issues].sort((left, right) => {
const leftWeight = severityWeight[left.severity ?? 'low'];
const rightWeight = severityWeight[right.severity ?? 'low'];
if (leftWeight !== rightWeight) {
return leftWeight - rightWeight;
}
return left.id.localeCompare(right.id);
});
}
export function getSecurityUpdateStatusMeta(status: SecurityUpdateStatus): SecurityUpdateStatusMeta {
switch (status.overallStatus) {
case 'pending':
return {
label: '待更新',
description: '检测到可进行的安全更新,你可以现在开始或稍后继续。',
tone: 'warning',
};
case 'postponed':
return {
label: '待更新',
description: '本次安全更新已延后,当前可用配置会继续保留。',
tone: 'warning',
};
case 'in_progress':
return {
label: '更新中',
description: '正在检查并更新已保存配置的安全存储。',
tone: 'processing',
};
case 'needs_attention':
return {
label: '待处理',
description: '更新尚未完成,有少量配置需要你处理。',
tone: 'warning',
};
case 'completed':
return {
label: '已完成',
description: '已保存配置已完成安全更新。',
tone: 'success',
};
case 'rolled_back':
return {
label: '已回退',
description: '本次更新未完成,系统已保留当前可用配置。',
tone: 'error',
};
case 'not_detected':
default:
return {
label: '未检测到',
description: '当前没有需要处理的安全更新。',
tone: 'default',
};
}
}
export function resolveSecurityUpdateEntryVisibility(status: SecurityUpdateStatus): SecurityUpdateEntryVisibility {
switch (status.overallStatus) {
case 'pending':
return {
showIntro: true,
showBanner: false,
showDetailEntry: true,
};
case 'postponed':
case 'needs_attention':
case 'rolled_back':
return {
showIntro: false,
showBanner: true,
showDetailEntry: true,
};
case 'completed':
case 'in_progress':
return {
showIntro: false,
showBanner: false,
showDetailEntry: true,
};
case 'not_detected':
default:
return {
showIntro: false,
showBanner: false,
showDetailEntry: false,
};
}
}
export function getSecurityUpdateIssueActionMeta(issue: Partial<SecurityUpdateIssue>): SecurityUpdateIssueActionMeta {
return actionMetaMap[issue.action ?? 'view_details'] ?? actionMetaMap.view_details;
}
export function getSecurityUpdateItemStatusMeta(status?: SecurityUpdateItemStatus): SecurityUpdateBadgeMeta {
return itemStatusMetaMap[status ?? 'pending'] ?? itemStatusMetaMap.pending;
}
export function getSecurityUpdateIssueSeverityMeta(severity?: SecurityUpdateIssueSeverity): SecurityUpdateBadgeMeta {
return issueSeverityMetaMap[severity ?? 'low'] ?? issueSeverityMetaMap.low;
}
export type {
SecurityUpdateBadgeMeta,
SecurityUpdateEntryVisibility,
SecurityUpdateIssueActionMeta,
SecurityUpdateStatusMeta,
SecurityUpdateTone,
};

View File

@@ -0,0 +1,155 @@
import { describe, expect, it } from 'vitest';
import type { SavedConnection, SecurityUpdateIssue, SecurityUpdateStatus } from '../types';
import {
hasSecurityUpdateRecentResult,
resolveSecurityUpdateFocusState,
resolveSecurityUpdateRepairEntry,
resolveSecurityUpdateSettingsFocusTarget,
shouldRefreshSecurityUpdateDetailsFocus,
shouldReopenSecurityUpdateDetails,
shouldRetrySecurityUpdateAfterRepairSave,
} from './securityUpdateRepairFlow';
const createConnection = (id: string): SavedConnection => ({
id,
name: `连接-${id}`,
config: {
id,
type: 'postgres',
host: 'db.local',
port: 5432,
user: 'postgres',
},
});
const createStatus = (overrides: Partial<SecurityUpdateStatus> = {}): SecurityUpdateStatus => ({
overallStatus: 'needs_attention',
summary: {
total: 1,
updated: 0,
pending: 1,
skipped: 0,
failed: 0,
},
issues: [],
...overrides,
});
describe('securityUpdateRepairFlow', () => {
it('opens the matching connection and preserves the return source for security update repairs', () => {
const target = createConnection('conn-1');
const issue: SecurityUpdateIssue = {
id: 'issue-1',
action: 'open_connection',
refId: 'conn-1',
};
expect(resolveSecurityUpdateRepairEntry(issue, [target])).toEqual({
type: 'connection',
connection: target,
repairSource: 'connection',
});
});
it('returns a user-facing warning when the target connection no longer exists', () => {
const issue: SecurityUpdateIssue = {
id: 'issue-1',
action: 'open_connection',
refId: 'missing-conn',
};
expect(resolveSecurityUpdateRepairEntry(issue, [createConnection('conn-1')])).toEqual({
type: 'warning',
message: '未找到对应连接,请先重新检查最新状态',
});
});
it('maps proxy, ai and retry actions to the expected repair entry', () => {
expect(resolveSecurityUpdateRepairEntry({ id: 'proxy', action: 'open_proxy_settings' }, [])).toEqual({
type: 'proxy',
repairSource: 'proxy',
});
expect(resolveSecurityUpdateRepairEntry({ id: 'ai', action: 'open_ai_settings', refId: 'provider-1' }, [])).toEqual({
type: 'ai',
providerId: 'provider-1',
repairSource: 'ai',
});
expect(resolveSecurityUpdateRepairEntry({ id: 'retry', action: 'retry_update' }, [])).toEqual({
type: 'retry',
});
});
it('routes view_details actions to the latest result section when a recent result exists', () => {
const status = createStatus({
backupPath: '/tmp/gonavi-backup.json',
lastError: '写入新密钥失败',
});
expect(hasSecurityUpdateRecentResult(status)).toBe(true);
expect(resolveSecurityUpdateSettingsFocusTarget(status)).toBe('recent_result');
expect(resolveSecurityUpdateRepairEntry({ id: 'details', action: 'view_details' }, [], status)).toEqual({
type: 'details',
focusTarget: 'recent_result',
});
});
it('falls back to the status section when no recent result is available yet', () => {
const status = createStatus();
expect(hasSecurityUpdateRecentResult(status)).toBe(false);
expect(resolveSecurityUpdateSettingsFocusTarget(status)).toBe('status');
expect(resolveSecurityUpdateRepairEntry({ id: 'details', action: 'view_details' }, [], status)).toEqual({
type: 'details',
focusTarget: 'status',
});
});
it('builds a fresh focus pulse for repeated details clicks and clears it when the modal closes', () => {
expect(resolveSecurityUpdateFocusState(true, 'status', 1)).toEqual({
target: 'status',
pulseKey: 'status:1',
});
expect(resolveSecurityUpdateFocusState(true, 'status', 2)).toEqual({
target: 'status',
pulseKey: 'status:2',
});
expect(resolveSecurityUpdateFocusState(false, 'status', 2)).toEqual({
target: null,
pulseKey: null,
});
expect(resolveSecurityUpdateFocusState(true, null, 3)).toEqual({
target: null,
pulseKey: null,
});
});
it('reopens security update details after closing a repair entry opened from that page', () => {
expect(shouldReopenSecurityUpdateDetails('connection')).toBe(true);
expect(shouldReopenSecurityUpdateDetails('proxy')).toBe(true);
expect(shouldReopenSecurityUpdateDetails('ai')).toBe(true);
expect(shouldReopenSecurityUpdateDetails(null)).toBe(false);
});
it('retries the current round automatically after saving a connection from the repair flow', () => {
expect(shouldRetrySecurityUpdateAfterRepairSave('connection')).toBe(true);
expect(shouldRetrySecurityUpdateAfterRepairSave('proxy')).toBe(false);
expect(shouldRetrySecurityUpdateAfterRepairSave('ai')).toBe(false);
expect(shouldRetrySecurityUpdateAfterRepairSave(null)).toBe(false);
});
it('does not force a new focus pulse when the details modal is already open and only the round result is refreshing', () => {
expect(shouldRefreshSecurityUpdateDetailsFocus({
requestedOpen: true,
wasOpen: true,
})).toBe(false);
expect(shouldRefreshSecurityUpdateDetailsFocus({
requestedOpen: true,
wasOpen: false,
})).toBe(true);
expect(shouldRefreshSecurityUpdateDetailsFocus({
requestedOpen: false,
wasOpen: true,
})).toBe(false);
});
});

View File

@@ -0,0 +1,126 @@
import type { SavedConnection, SecurityUpdateIssue, SecurityUpdateStatus } from '../types';
export type SecurityUpdateRepairSource = 'connection' | 'proxy' | 'ai';
export type SecurityUpdateSettingsFocusTarget = 'recent_result' | 'status';
export type SecurityUpdateFocusState = {
target: SecurityUpdateSettingsFocusTarget | null;
pulseKey: string | null;
};
export type SecurityUpdateRepairEntry =
| {
type: 'connection';
connection: SavedConnection;
repairSource: 'connection';
}
| {
type: 'proxy';
repairSource: 'proxy';
}
| {
type: 'ai';
providerId?: string;
repairSource: 'ai';
}
| {
type: 'retry';
}
| {
type: 'details';
focusTarget: SecurityUpdateSettingsFocusTarget;
}
| {
type: 'warning';
message: string;
};
export const hasSecurityUpdateRecentResult = (
status?: Pick<SecurityUpdateStatus, 'backupPath' | 'lastError'> | null,
): boolean => Boolean(status?.backupPath || status?.lastError);
export const resolveSecurityUpdateSettingsFocusTarget = (
status?: Pick<SecurityUpdateStatus, 'backupPath' | 'lastError'> | null,
): SecurityUpdateSettingsFocusTarget => (
hasSecurityUpdateRecentResult(status) ? 'recent_result' : 'status'
);
export const resolveSecurityUpdateFocusState = (
open: boolean,
focusTarget: SecurityUpdateSettingsFocusTarget | null | undefined,
focusRequest: number,
): SecurityUpdateFocusState => {
if (!open || !focusTarget) {
return {
target: null,
pulseKey: null,
};
}
return {
target: focusTarget,
pulseKey: `${focusTarget}:${focusRequest}`,
};
};
export const resolveSecurityUpdateRepairEntry = (
issue: SecurityUpdateIssue,
connections: SavedConnection[],
status?: Pick<SecurityUpdateStatus, 'backupPath' | 'lastError'> | null,
): SecurityUpdateRepairEntry => {
if (issue.action === 'open_connection') {
const target = connections.find((connection) => connection.id === issue.refId);
if (!target) {
return {
type: 'warning',
message: '未找到对应连接,请先重新检查最新状态',
};
}
return {
type: 'connection',
connection: target,
repairSource: 'connection',
};
}
if (issue.action === 'open_proxy_settings') {
return {
type: 'proxy',
repairSource: 'proxy',
};
}
if (issue.action === 'open_ai_settings') {
return {
type: 'ai',
providerId: issue.refId || undefined,
repairSource: 'ai',
};
}
if (issue.action === 'retry_update') {
return {
type: 'retry',
};
}
return {
type: 'details',
focusTarget: resolveSecurityUpdateSettingsFocusTarget(status),
};
};
export const shouldReopenSecurityUpdateDetails = (
repairSource: SecurityUpdateRepairSource | null | undefined,
): boolean => repairSource === 'connection' || repairSource === 'proxy' || repairSource === 'ai';
export const shouldRefreshSecurityUpdateDetailsFocus = ({
requestedOpen,
wasOpen,
}: {
requestedOpen: boolean;
wasOpen: boolean;
}): boolean => requestedOpen && !wasOpen;
export const shouldRetrySecurityUpdateAfterRepairSave = (
repairSource: SecurityUpdateRepairSource | null | undefined,
): boolean => repairSource === 'connection';

View File

@@ -0,0 +1,99 @@
import { describe, expect, it } from 'vitest';
import { buildOverlayWorkbenchTheme } from './overlayWorkbenchTheme';
import {
SECURITY_UPDATE_ACTION_BUTTON_CLASS,
SECURITY_UPDATE_BANNER_CLASS,
SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS,
SECURITY_UPDATE_RESULT_CARD_CLASS,
getSecurityUpdateActionButtonStyle,
getSecurityUpdateBannerSurfaceStyle,
getSecurityUpdateSectionSurfaceStyle,
getSecurityUpdateShellSurfaceStyle,
} from './securityUpdateVisuals';
describe('securityUpdateVisuals', () => {
it('builds action buttons without default ant focus glow shadow', () => {
expect(SECURITY_UPDATE_ACTION_BUTTON_CLASS).toBe('security-update-action-btn');
expect(SECURITY_UPDATE_BANNER_CLASS).toBe('security-update-banner');
expect(SECURITY_UPDATE_RESULT_CARD_CLASS).toBe('security-update-result-card');
expect(SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS).toBe('security-update-result-card-active');
expect(getSecurityUpdateActionButtonStyle()).toMatchObject({
height: 36,
borderRadius: 12,
boxShadow: 'none',
fontWeight: 600,
});
});
it('keeps the shell surface aligned with overlay shell tokens in light and dark mode', () => {
const lightTheme = buildOverlayWorkbenchTheme(false);
const darkTheme = buildOverlayWorkbenchTheme(true);
expect(getSecurityUpdateShellSurfaceStyle(lightTheme)).toMatchObject({
border: lightTheme.shellBorder,
background: lightTheme.shellBg,
boxShadow: lightTheme.shellShadow,
backdropFilter: lightTheme.shellBackdropFilter,
});
expect(getSecurityUpdateShellSurfaceStyle(darkTheme)).toMatchObject({
border: darkTheme.shellBorder,
background: darkTheme.shellBg,
boxShadow: darkTheme.shellShadow,
backdropFilter: darkTheme.shellBackdropFilter,
});
});
it('keeps the banner surface aligned with overlay shell tokens instead of translucent section tokens', () => {
const lightTheme = buildOverlayWorkbenchTheme(false);
const darkTheme = buildOverlayWorkbenchTheme(true);
expect(getSecurityUpdateBannerSurfaceStyle(lightTheme)).toMatchObject({
border: lightTheme.shellBorder,
background: lightTheme.shellBg,
boxShadow: 'none',
backdropFilter: lightTheme.shellBackdropFilter,
});
expect(getSecurityUpdateBannerSurfaceStyle(darkTheme)).toMatchObject({
border: darkTheme.shellBorder,
background: darkTheme.shellBg,
boxShadow: 'none',
backdropFilter: darkTheme.shellBackdropFilter,
});
});
it('can scale shell surface alpha with the current appearance opacity so reminder layers stay visually consistent', () => {
const lightTheme = buildOverlayWorkbenchTheme(false);
const fadedShell = getSecurityUpdateShellSurfaceStyle(lightTheme, 0.5);
const fadedBanner = getSecurityUpdateBannerSurfaceStyle(lightTheme, 0.5);
expect(fadedShell.background).not.toBe(lightTheme.shellBg);
expect(fadedShell.border).not.toBe(lightTheme.shellBorder);
expect(fadedShell.background).toContain('0.49');
expect(fadedBanner.background).toContain('0.49');
});
it('can emphasize a section surface for transient focus and recent-result highlighting', () => {
const lightTheme = buildOverlayWorkbenchTheme(false);
const darkTheme = buildOverlayWorkbenchTheme(true);
expect(getSecurityUpdateSectionSurfaceStyle(lightTheme)).toMatchObject({
border: lightTheme.sectionBorder,
background: lightTheme.sectionBg,
boxShadow: 'none',
});
expect(getSecurityUpdateSectionSurfaceStyle(darkTheme)).toMatchObject({
border: darkTheme.sectionBorder,
background: darkTheme.sectionBg,
boxShadow: 'none',
});
const emphasizedLight = getSecurityUpdateSectionSurfaceStyle(lightTheme, { emphasized: true });
const emphasizedDark = getSecurityUpdateSectionSurfaceStyle(darkTheme, { emphasized: true });
expect(emphasizedLight.background).not.toBe(lightTheme.sectionBg);
expect(emphasizedLight.boxShadow).not.toBe('none');
expect(emphasizedDark.background).not.toBe(darkTheme.sectionBg);
expect(emphasizedDark.boxShadow).not.toBe('none');
});
});

View File

@@ -0,0 +1,94 @@
import type { CSSProperties } from 'react';
import type { OverlayWorkbenchTheme } from './overlayWorkbenchTheme';
export const SECURITY_UPDATE_ACTION_BUTTON_CLASS = 'security-update-action-btn';
export const SECURITY_UPDATE_BANNER_CLASS = 'security-update-banner';
export const SECURITY_UPDATE_MODAL_CLASS = 'security-update-modal';
export const SECURITY_UPDATE_RESULT_CARD_CLASS = 'security-update-result-card';
export const SECURITY_UPDATE_RESULT_CARD_ACTIVE_CLASS = 'security-update-result-card-active';
type SecurityUpdateSectionSurfaceOptions = {
emphasized?: boolean;
surfaceOpacity?: number;
};
const clampOpacity = (value: number): number => Math.min(1, Math.max(0.1, value));
const formatAlpha = (value: number): string => (
Number(value.toFixed(3)).toString()
);
const applySurfaceOpacity = (token: string, surfaceOpacity = 1): string => {
const normalizedOpacity = clampOpacity(surfaceOpacity);
if (normalizedOpacity >= 0.999) {
return token;
}
return token.replace(
/rgba\(\s*([^)]+?)\s*,\s*([0-9]*\.?[0-9]+)\s*\)/g,
(_, channels: string, alpha: string) => `rgba(${channels}, ${formatAlpha(Number(alpha) * normalizedOpacity)})`,
);
};
const getSecurityUpdateHighlightBorder = (overlayTheme: OverlayWorkbenchTheme): string => (
overlayTheme.isDark
? '1px solid rgba(255,214,102,0.26)'
: '1px solid rgba(22,119,255,0.22)'
);
const getSecurityUpdateHighlightBackground = (overlayTheme: OverlayWorkbenchTheme): string => (
overlayTheme.isDark
? 'linear-gradient(180deg, rgba(255,214,102,0.14) 0%, rgba(255,255,255,0.05) 100%)'
: 'linear-gradient(180deg, rgba(22,119,255,0.12) 0%, rgba(255,255,255,0.96) 100%)'
);
const getSecurityUpdateHighlightShadow = (overlayTheme: OverlayWorkbenchTheme): string => (
overlayTheme.isDark
? '0 0 0 1px rgba(255,214,102,0.12), 0 12px 24px rgba(0,0,0,0.16)'
: '0 0 0 1px rgba(22,119,255,0.08), 0 10px 22px rgba(15,23,42,0.08)'
);
export const getSecurityUpdateActionButtonStyle = (): CSSProperties => ({
height: 36,
borderRadius: 12,
paddingInline: 16,
boxShadow: 'none',
fontWeight: 600,
});
export const getSecurityUpdateShellSurfaceStyle = (
overlayTheme: OverlayWorkbenchTheme,
surfaceOpacity = 1,
): CSSProperties => ({
border: applySurfaceOpacity(overlayTheme.shellBorder, surfaceOpacity),
background: applySurfaceOpacity(overlayTheme.shellBg, surfaceOpacity),
boxShadow: applySurfaceOpacity(overlayTheme.shellShadow, surfaceOpacity),
backdropFilter: overlayTheme.shellBackdropFilter,
});
export const getSecurityUpdateBannerSurfaceStyle = (
overlayTheme: OverlayWorkbenchTheme,
surfaceOpacity = 1,
): CSSProperties => ({
...getSecurityUpdateShellSurfaceStyle(overlayTheme, surfaceOpacity),
boxShadow: 'none',
});
export const getSecurityUpdateSectionSurfaceStyle = (
overlayTheme: OverlayWorkbenchTheme,
options: SecurityUpdateSectionSurfaceOptions = {},
): CSSProperties => ({
border: applySurfaceOpacity(
options.emphasized ? getSecurityUpdateHighlightBorder(overlayTheme) : overlayTheme.sectionBorder,
options.surfaceOpacity,
),
background: applySurfaceOpacity(
options.emphasized ? getSecurityUpdateHighlightBackground(overlayTheme) : overlayTheme.sectionBg,
options.surfaceOpacity,
),
boxShadow: options.emphasized
? applySurfaceOpacity(getSecurityUpdateHighlightShadow(overlayTheme), options.surfaceOpacity)
: 'none',
transition: 'background 180ms ease, border-color 180ms ease, box-shadow 180ms ease',
});

View File

@@ -2,6 +2,7 @@
// This file is automatically generated. DO NOT EDIT
import {connection} from '../models';
import {sync} from '../models';
import {app} from '../models';
import {redis} from '../models';
export function ApplyChanges(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:connection.ChangeSet):Promise<connection.QueryResult>;
@@ -16,6 +17,8 @@ export function CheckDriverNetworkStatus():Promise<connection.QueryResult>;
export function CheckForUpdates():Promise<connection.QueryResult>;
export function CheckForUpdatesSilently():Promise<connection.QueryResult>;
export function ClearTables(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;
export function ConfigureDriverRuntimeDirectory(arg1:string):Promise<connection.QueryResult>;
@@ -58,6 +61,8 @@ export function DataSyncPreview(arg1:sync.SyncConfig,arg2:string,arg3:number):Pr
export function DeleteConnection(arg1:string):Promise<void>;
export function DismissSecurityUpdateReminder():Promise<app.SecurityUpdateStatus>;
export function DownloadDriverPackage(arg1:string,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function DownloadUpdate():Promise<connection.QueryResult>;
@@ -74,6 +79,8 @@ export function DuplicateConnection(arg1:string):Promise<connection.SavedConnect
export function ExecuteSQLFile(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function ExportConnectionsPackage(arg1:app.ConnectionExportOptions):Promise<connection.QueryResult>;
export function ExportData(arg1:Array<Record<string, any>>,arg2:Array<string>,arg3:string,arg4:string):Promise<connection.QueryResult>;
export function ExportDatabaseSQL(arg1:connection.ConnectionConfig,arg2:string,arg3:boolean):Promise<connection.QueryResult>;
@@ -102,8 +109,12 @@ export function GetGlobalProxyConfig():Promise<connection.QueryResult>;
export function GetSavedConnections():Promise<Array<connection.SavedConnectionView>>;
export function GetSecurityUpdateStatus():Promise<app.SecurityUpdateStatus>;
export function ImportConfigFile():Promise<connection.QueryResult>;
export function ImportConnectionsPayload(arg1:string,arg2:string):Promise<Array<connection.SavedConnectionView>>;
export function ImportData(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise<connection.QueryResult>;
export function ImportDataWithProgress(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise<connection.QueryResult>;
@@ -202,6 +213,10 @@ export function ResolveDriverPackageDownloadURL(arg1:string,arg2:string):Promise
export function ResolveDriverRepositoryURL(arg1:string):Promise<connection.QueryResult>;
export function RestartSecurityUpdate(arg1:app.RestartSecurityUpdateRequest):Promise<app.SecurityUpdateStatus>;
export function RetrySecurityUpdateCurrentRound(arg1:app.RetrySecurityUpdateRequest):Promise<app.SecurityUpdateStatus>;
export function SaveConnection(arg1:connection.SavedConnectionInput):Promise<connection.SavedConnectionView>;
export function SaveGlobalProxy(arg1:connection.SaveGlobalProxyInput):Promise<connection.GlobalProxyView>;
@@ -222,6 +237,8 @@ export function SetMacNativeWindowControls(arg1:boolean):Promise<void>;
export function SetWindowTranslucency(arg1:number,arg2:number):Promise<void>;
export function StartSecurityUpdate(arg1:app.StartSecurityUpdateRequest):Promise<app.SecurityUpdateStatus>;
export function TestConnection(arg1:connection.ConnectionConfig):Promise<connection.QueryResult>;
export function TruncateTables(arg1:connection.ConnectionConfig,arg2:string,arg3:Array<string>):Promise<connection.QueryResult>;

View File

@@ -26,6 +26,10 @@ export function CheckForUpdates() {
return window['go']['app']['App']['CheckForUpdates']();
}
export function CheckForUpdatesSilently() {
return window['go']['app']['App']['CheckForUpdatesSilently']();
}
export function ClearTables(arg1, arg2, arg3) {
return window['go']['app']['App']['ClearTables'](arg1, arg2, arg3);
}
@@ -110,6 +114,10 @@ export function DeleteConnection(arg1) {
return window['go']['app']['App']['DeleteConnection'](arg1);
}
export function DismissSecurityUpdateReminder() {
return window['go']['app']['App']['DismissSecurityUpdateReminder']();
}
export function DownloadDriverPackage(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['DownloadDriverPackage'](arg1, arg2, arg3, arg4);
}
@@ -142,6 +150,10 @@ export function ExecuteSQLFile(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExecuteSQLFile'](arg1, arg2, arg3, arg4);
}
export function ExportConnectionsPackage(arg1) {
return window['go']['app']['App']['ExportConnectionsPackage'](arg1);
}
export function ExportData(arg1, arg2, arg3, arg4) {
return window['go']['app']['App']['ExportData'](arg1, arg2, arg3, arg4);
}
@@ -198,10 +210,18 @@ export function GetSavedConnections() {
return window['go']['app']['App']['GetSavedConnections']();
}
export function GetSecurityUpdateStatus() {
return window['go']['app']['App']['GetSecurityUpdateStatus']();
}
export function ImportConfigFile() {
return window['go']['app']['App']['ImportConfigFile']();
}
export function ImportConnectionsPayload(arg1, arg2) {
return window['go']['app']['App']['ImportConnectionsPayload'](arg1, arg2);
}
export function ImportData(arg1, arg2, arg3) {
return window['go']['app']['App']['ImportData'](arg1, arg2, arg3);
}
@@ -398,6 +418,14 @@ export function ResolveDriverRepositoryURL(arg1) {
return window['go']['app']['App']['ResolveDriverRepositoryURL'](arg1);
}
export function RestartSecurityUpdate(arg1) {
return window['go']['app']['App']['RestartSecurityUpdate'](arg1);
}
export function RetrySecurityUpdateCurrentRound(arg1) {
return window['go']['app']['App']['RetrySecurityUpdateCurrentRound'](arg1);
}
export function SaveConnection(arg1) {
return window['go']['app']['App']['SaveConnection'](arg1);
}
@@ -438,6 +466,10 @@ export function SetWindowTranslucency(arg1, arg2) {
return window['go']['app']['App']['SetWindowTranslucency'](arg1, arg2);
}
export function StartSecurityUpdate(arg1) {
return window['go']['app']['App']['StartSecurityUpdate'](arg1);
}
export function TestConnection(arg1) {
return window['go']['app']['App']['TestConnection'](arg1);
}

View File

@@ -179,6 +179,233 @@ export namespace ai {
}
export namespace app {
export class ConnectionExportOptions {
includeSecrets: boolean;
filePassword?: string;
static createFrom(source: any = {}) {
return new ConnectionExportOptions(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.includeSecrets = source["includeSecrets"];
this.filePassword = source["filePassword"];
}
}
export class SecurityUpdateOptions {
allowPartial?: boolean;
writeBackup?: boolean;
static createFrom(source: any = {}) {
return new SecurityUpdateOptions(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.allowPartial = source["allowPartial"];
this.writeBackup = source["writeBackup"];
}
}
export class RestartSecurityUpdateRequest {
migrationId?: string;
sourceType: string;
rawPayload?: string;
options?: SecurityUpdateOptions;
static createFrom(source: any = {}) {
return new RestartSecurityUpdateRequest(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.migrationId = source["migrationId"];
this.sourceType = source["sourceType"];
this.rawPayload = source["rawPayload"];
this.options = this.convertValues(source["options"], SecurityUpdateOptions);
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
if (!a) {
return a;
}
if (a.slice && a.map) {
return (a as any[]).map(elem => this.convertValues(elem, classs));
} else if ("object" === typeof a) {
if (asMap) {
for (const key of Object.keys(a)) {
a[key] = new classs(a[key]);
}
return a;
}
return new classs(a);
}
return a;
}
}
export class RetrySecurityUpdateRequest {
migrationId?: string;
static createFrom(source: any = {}) {
return new RetrySecurityUpdateRequest(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.migrationId = source["migrationId"];
}
}
export class SecurityUpdateIssue {
id: string;
scope: string;
refId?: string;
title: string;
severity: string;
status: string;
reasonCode: string;
action: string;
message: string;
static createFrom(source: any = {}) {
return new SecurityUpdateIssue(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.id = source["id"];
this.scope = source["scope"];
this.refId = source["refId"];
this.title = source["title"];
this.severity = source["severity"];
this.status = source["status"];
this.reasonCode = source["reasonCode"];
this.action = source["action"];
this.message = source["message"];
}
}
export class SecurityUpdateSummary {
total: number;
updated: number;
pending: number;
skipped: number;
failed: number;
static createFrom(source: any = {}) {
return new SecurityUpdateSummary(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.total = source["total"];
this.updated = source["updated"];
this.pending = source["pending"];
this.skipped = source["skipped"];
this.failed = source["failed"];
}
}
export class SecurityUpdateStatus {
schemaVersion?: number;
migrationId?: string;
overallStatus: string;
sourceType?: string;
reminderVisible: boolean;
canStart: boolean;
canPostpone: boolean;
canRetry: boolean;
backupAvailable: boolean;
backupPath?: string;
startedAt?: string;
updatedAt?: string;
completedAt?: string;
postponedAt?: string;
summary: SecurityUpdateSummary;
issues: SecurityUpdateIssue[];
lastError?: string;
static createFrom(source: any = {}) {
return new SecurityUpdateStatus(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.schemaVersion = source["schemaVersion"];
this.migrationId = source["migrationId"];
this.overallStatus = source["overallStatus"];
this.sourceType = source["sourceType"];
this.reminderVisible = source["reminderVisible"];
this.canStart = source["canStart"];
this.canPostpone = source["canPostpone"];
this.canRetry = source["canRetry"];
this.backupAvailable = source["backupAvailable"];
this.backupPath = source["backupPath"];
this.startedAt = source["startedAt"];
this.updatedAt = source["updatedAt"];
this.completedAt = source["completedAt"];
this.postponedAt = source["postponedAt"];
this.summary = this.convertValues(source["summary"], SecurityUpdateSummary);
this.issues = this.convertValues(source["issues"], SecurityUpdateIssue);
this.lastError = source["lastError"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
if (!a) {
return a;
}
if (a.slice && a.map) {
return (a as any[]).map(elem => this.convertValues(elem, classs));
} else if ("object" === typeof a) {
if (asMap) {
for (const key of Object.keys(a)) {
a[key] = new classs(a[key]);
}
return a;
}
return new classs(a);
}
return a;
}
}
export class StartSecurityUpdateRequest {
sourceType: string;
rawPayload?: string;
options?: SecurityUpdateOptions;
static createFrom(source: any = {}) {
return new StartSecurityUpdateRequest(source);
}
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.sourceType = source["sourceType"];
this.rawPayload = source["rawPayload"];
this.options = this.convertValues(source["options"], SecurityUpdateOptions);
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
if (!a) {
return a;
}
if (a.slice && a.map) {
return (a as any[]).map(elem => this.convertValues(elem, classs));
} else if ("object" === typeof a) {
if (asMap) {
for (const key of Object.keys(a)) {
a[key] = new classs(a[key]);
}
return a;
}
return new classs(a);
}
return a;
}
}
}
export namespace connection {
export class UpdateRow {

6
go.mod
View File

@@ -26,6 +26,12 @@ require (
modernc.org/sqlite v1.44.3
)
require (
github.com/kr/pretty v0.3.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect

13
go.sum
View File

@@ -38,6 +38,7 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0=
github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -126,6 +127,9 @@ github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxh
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -174,7 +178,6 @@ github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs
github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/paulmach/orb v0.12.0 h1:z+zOwjmG3MyEEqzv92UN49Lg1JFYx0L9GpGKNVDKk1s=
github.com/paulmach/orb v0.12.0/go.mod h1:5mULz1xQfs3bmQm63QEJA6lNGujuRafwA5S/EnuLaLU=
@@ -183,6 +186,7 @@ github.com/pierrec/lz4/v4 v4.1.25 h1:kocOqRffaIbU5djlIBr7Wh+cx82C0vtFb0fOurZHqD0
github.com/pierrec/lz4/v4 v4.1.25/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -200,6 +204,9 @@ github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTK
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/samber/lo v1.49.1 h1:4BIFyVfuQSEpluc7Fua+j1NolZHiEHEpaSEKdsH0tew=
github.com/samber/lo v1.49.1/go.mod h1:dO6KHFzUKXgP8LDhU0oI8d2hekjXnGOu0DB8Jecxd6o=
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
@@ -356,9 +363,9 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,262 @@
package aiservice
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"GoNavi-Wails/internal/ai"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/secretstore"
)
const (
aiConfigSchemaVersion = 2
aiConfigFileName = "ai_config.json"
)
type aiConfig struct {
SchemaVersion int `json:"schemaVersion,omitempty"`
Providers []ai.ProviderConfig `json:"providers"`
ActiveProvider string `json:"activeProvider"`
SafetyLevel string `json:"safetyLevel"`
ContextLevel string `json:"contextLevel"`
}
type ProviderConfigStoreSnapshot struct {
Providers []ai.ProviderConfig
ActiveProvider string
SafetyLevel ai.SQLPermissionLevel
ContextLevel ai.ContextLevel
}
type ProviderConfigStoreInspection struct {
Snapshot ProviderConfigStoreSnapshot
ProvidersNeedingMigration []string
}
type ProviderConfigStore struct {
configDir string
secretStore secretstore.SecretStore
}
func NewProviderConfigStore(configDir string, store secretstore.SecretStore) *ProviderConfigStore {
if strings.TrimSpace(configDir) == "" {
configDir = resolveConfigDir()
}
if store == nil {
store = secretstore.NewUnavailableStore("secret store unavailable")
}
return &ProviderConfigStore{
configDir: configDir,
secretStore: store,
}
}
func newProviderConfigStore(configDir string, store secretstore.SecretStore) *ProviderConfigStore {
return NewProviderConfigStore(configDir, store)
}
func (s *ProviderConfigStore) configPath() string {
return filepath.Join(s.configDir, aiConfigFileName)
}
func (s *ProviderConfigStore) Load() (ProviderConfigStoreSnapshot, error) {
cfg, snapshot, err := s.readStoredSnapshot()
if err != nil {
return snapshot, err
}
shouldRewrite := cfg.SchemaVersion != aiConfigSchemaVersion
providers := make([]ai.ProviderConfig, 0, len(snapshot.Providers))
for _, providerConfig := range snapshot.Providers {
runtimeConfig, rewritten, loadErr := s.loadStoredProviderConfig(providerConfig)
if loadErr != nil {
return snapshot, fmt.Errorf("加载 AI Provider secret 失败(provider=%s): %w", providerConfig.ID, loadErr)
}
if rewritten {
shouldRewrite = true
}
providers = append(providers, runtimeConfig)
}
if providers == nil {
providers = []ai.ProviderConfig{}
}
snapshot.Providers = providers
if shouldRewrite {
if err := s.Save(snapshot); err != nil {
return snapshot, fmt.Errorf("重写 AI 配置失败: %w", err)
}
}
return snapshot, nil
}
func (s *ProviderConfigStore) LoadRuntime() (ProviderConfigStoreSnapshot, error) {
_, snapshot, err := s.readStoredSnapshot()
if err != nil {
return snapshot, err
}
providers := make([]ai.ProviderConfig, 0, len(snapshot.Providers))
for _, providerConfig := range snapshot.Providers {
runtimeConfig, loadErr := s.loadRuntimeProviderConfig(providerConfig)
if loadErr != nil {
logger.Error(loadErr, "加载 AI Provider secret 失败provider=%s", providerConfig.ID)
}
providers = append(providers, runtimeConfig)
}
if providers == nil {
providers = []ai.ProviderConfig{}
}
snapshot.Providers = providers
return snapshot, nil
}
func (s *ProviderConfigStore) Inspect() (ProviderConfigStoreInspection, error) {
_, snapshot, err := s.readStoredSnapshot()
inspection := ProviderConfigStoreInspection{
Snapshot: snapshot,
ProvidersNeedingMigration: []string{},
}
if err != nil {
return inspection, err
}
for _, providerConfig := range snapshot.Providers {
if providerNeedsMigration(providerConfig) {
inspection.ProvidersNeedingMigration = append(inspection.ProvidersNeedingMigration, providerConfig.ID)
}
}
return inspection, nil
}
func (s *ProviderConfigStore) Save(snapshot ProviderConfigStoreSnapshot) error {
providers := make([]ai.ProviderConfig, 0, len(snapshot.Providers))
for _, providerConfig := range snapshot.Providers {
runtimeConfig := normalizeProviderConfig(providerConfig)
meta, bundle := splitProviderSecrets(runtimeConfig)
if bundle.hasAny() {
storedMeta, err := persistProviderSecretBundle(s.secretStore, meta, bundle)
if err != nil {
return fmt.Errorf("保存 Provider secret 失败: %w", err)
}
meta = storedMeta
}
providers = append(providers, providerMetadataView(meta))
}
if providers == nil {
providers = []ai.ProviderConfig{}
}
cfg := aiConfig{
SchemaVersion: aiConfigSchemaVersion,
Providers: providers,
ActiveProvider: snapshot.ActiveProvider,
SafetyLevel: string(snapshot.SafetyLevel),
ContextLevel: string(snapshot.ContextLevel),
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("序列化 AI 配置失败: %w", err)
}
if err := os.MkdirAll(s.configDir, 0o755); err != nil {
return fmt.Errorf("创建配置目录失败: %w", err)
}
if err := os.WriteFile(s.configPath(), data, 0o644); err != nil {
return fmt.Errorf("写入 AI 配置失败: %w", err)
}
return nil
}
func (s *ProviderConfigStore) readStoredSnapshot() (aiConfig, ProviderConfigStoreSnapshot, error) {
snapshot := ProviderConfigStoreSnapshot{
Providers: []ai.ProviderConfig{},
SafetyLevel: ai.PermissionReadOnly,
ContextLevel: ai.ContextSchemaOnly,
}
data, err := os.ReadFile(s.configPath())
if err != nil {
if os.IsNotExist(err) {
return aiConfig{}, snapshot, nil
}
return aiConfig{}, snapshot, fmt.Errorf("读取 AI 配置失败: %w", err)
}
var cfg aiConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return aiConfig{}, snapshot, fmt.Errorf("加载 AI 配置失败: %w", err)
}
snapshot.ActiveProvider = cfg.ActiveProvider
switch ai.SQLPermissionLevel(cfg.SafetyLevel) {
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
snapshot.SafetyLevel = ai.SQLPermissionLevel(cfg.SafetyLevel)
}
switch ai.ContextLevel(cfg.ContextLevel) {
case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults:
snapshot.ContextLevel = ai.ContextLevel(cfg.ContextLevel)
}
providers := make([]ai.ProviderConfig, 0, len(cfg.Providers))
for _, providerConfig := range cfg.Providers {
providers = append(providers, normalizeProviderConfig(providerConfig))
}
if providers == nil {
providers = []ai.ProviderConfig{}
}
snapshot.Providers = providers
return cfg, snapshot, nil
}
func (s *ProviderConfigStore) loadStoredProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, bool, error) {
meta, bundle := splitProviderSecrets(config)
if bundle.hasAny() {
storedMeta, err := persistProviderSecretBundle(s.secretStore, meta, bundle)
if err != nil {
return meta, false, err
}
return mergeProviderSecrets(storedMeta, bundle), true, nil
}
if !meta.HasSecret {
return meta, false, nil
}
resolved, err := resolveProviderConfigSecrets(s.secretStore, meta)
if err != nil {
if os.IsNotExist(err) {
return meta, false, nil
}
return meta, false, err
}
return resolved, false, nil
}
func (s *ProviderConfigStore) loadRuntimeProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, error) {
meta, bundle := splitProviderSecrets(config)
if bundle.hasAny() {
return mergeProviderSecrets(meta, bundle), nil
}
if !meta.HasSecret {
return meta, nil
}
resolved, err := resolveProviderConfigSecrets(s.secretStore, meta)
if err != nil {
return meta, err
}
return resolved, nil
}
func providerNeedsMigration(config ai.ProviderConfig) bool {
_, bundle := splitProviderSecrets(normalizeProviderConfig(config))
return bundle.hasAny()
}

View File

@@ -0,0 +1,206 @@
package aiservice
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"GoNavi-Wails/internal/ai"
"GoNavi-Wails/internal/secretstore"
)
func TestProviderConfigStoreLoadMigratesPlaintextProviderSecrets(t *testing.T) {
store := newFakeProviderSecretStore()
configStore := newProviderConfigStore(t.TempDir(), store)
legacy := aiConfig{
Providers: []ai.ProviderConfig{
{
ID: "openai-main",
Type: "openai",
Name: "OpenAI",
APIKey: "sk-test",
BaseURL: "https://api.openai.com/v1",
Headers: map[string]string{
"Authorization": "Bearer test",
"X-Team": "platform",
},
},
},
}
data, err := json.MarshalIndent(legacy, "", " ")
if err != nil {
t.Fatalf("MarshalIndent returned error: %v", err)
}
if err := os.WriteFile(filepath.Join(configStore.configDir, aiConfigFileName), data, 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
snapshot, err := configStore.Load()
if err != nil {
t.Fatalf("Load returned error: %v", err)
}
if len(snapshot.Providers) != 1 {
t.Fatalf("expected 1 provider, got %d", len(snapshot.Providers))
}
if snapshot.Providers[0].APIKey != "sk-test" {
t.Fatalf("expected runtime provider to restore apiKey, got %q", snapshot.Providers[0].APIKey)
}
if snapshot.Providers[0].Headers["Authorization"] != "Bearer test" {
t.Fatalf("expected runtime provider to restore sensitive header, got %#v", snapshot.Providers[0].Headers)
}
stored, err := store.Get(snapshot.Providers[0].SecretRef)
if err != nil {
t.Fatalf("expected migrated provider secret bundle, got %v", err)
}
var bundle providerSecretBundle
if err := json.Unmarshal(stored, &bundle); err != nil {
t.Fatalf("Unmarshal returned error: %v", err)
}
if bundle.APIKey != "sk-test" {
t.Fatalf("expected migrated apiKey in store, got %q", bundle.APIKey)
}
rewritten, err := os.ReadFile(filepath.Join(configStore.configDir, aiConfigFileName))
if err != nil {
t.Fatalf("ReadFile returned error: %v", err)
}
text := string(rewritten)
if strings.Contains(text, "sk-test") {
t.Fatalf("expected rewritten config to be secretless, got %s", text)
}
if strings.Contains(text, "Bearer test") {
t.Fatalf("expected rewritten config to remove sensitive headers, got %s", text)
}
}
func TestProviderConfigStoreSavePersistsSecretlessMetadata(t *testing.T) {
store := newFakeProviderSecretStore()
configStore := newProviderConfigStore(t.TempDir(), store)
err := configStore.Save(ProviderConfigStoreSnapshot{
Providers: []ai.ProviderConfig{
{
ID: "openai-main",
Type: "openai",
Name: "OpenAI",
APIKey: "sk-test",
BaseURL: "https://api.openai.com/v1",
Headers: map[string]string{
"Authorization": "Bearer test",
"X-Team": "platform",
},
},
},
ActiveProvider: "openai-main",
SafetyLevel: ai.PermissionReadOnly,
ContextLevel: ai.ContextSchemaOnly,
})
if err != nil {
t.Fatalf("Save returned error: %v", err)
}
configData, err := os.ReadFile(filepath.Join(configStore.configDir, aiConfigFileName))
if err != nil {
t.Fatalf("ReadFile returned error: %v", err)
}
text := string(configData)
if strings.Contains(text, "sk-test") {
t.Fatalf("expected config file to be secretless, got %s", text)
}
if strings.Contains(text, "Bearer test") {
t.Fatalf("expected config file to remove sensitive headers, got %s", text)
}
ref, err := secretstore.BuildRef(providerSecretKind, "openai-main")
if err != nil {
t.Fatalf("BuildRef returned error: %v", err)
}
stored, err := store.Get(ref)
if err != nil {
t.Fatalf("expected provider secret bundle in store, got %v", err)
}
var bundle providerSecretBundle
if err := json.Unmarshal(stored, &bundle); err != nil {
t.Fatalf("Unmarshal returned error: %v", err)
}
if bundle.APIKey != "sk-test" {
t.Fatalf("expected stored apiKey, got %q", bundle.APIKey)
}
if bundle.SensitiveHeaders["Authorization"] != "Bearer test" {
t.Fatalf("expected stored sensitive header, got %#v", bundle.SensitiveHeaders)
}
}
func TestProviderConfigStoreSaveKeepsExistingSecretRef(t *testing.T) {
store := newFakeProviderSecretStore()
configStore := newProviderConfigStore(t.TempDir(), store)
ref, err := secretstore.BuildRef(providerSecretKind, "openai-main")
if err != nil {
t.Fatalf("BuildRef returned error: %v", err)
}
payload, err := json.Marshal(providerSecretBundle{
APIKey: "sk-existing",
SensitiveHeaders: map[string]string{
"Authorization": "Bearer existing",
},
})
if err != nil {
t.Fatalf("Marshal returned error: %v", err)
}
if err := store.Put(ref, payload); err != nil {
t.Fatalf("Put returned error: %v", err)
}
err = configStore.Save(ProviderConfigStoreSnapshot{
Providers: []ai.ProviderConfig{
{
ID: "openai-main",
Type: "openai",
Name: "OpenAI",
HasSecret: true,
SecretRef: ref,
BaseURL: "https://gateway.openai.com/v1",
Headers: map[string]string{
"X-Team": "platform",
},
},
},
ActiveProvider: "openai-main",
SafetyLevel: ai.PermissionReadOnly,
ContextLevel: ai.ContextSchemaOnly,
})
if err != nil {
t.Fatalf("Save returned error: %v", err)
}
stored, err := store.Get(ref)
if err != nil {
t.Fatalf("expected existing provider secret bundle to remain available, got %v", err)
}
var bundle providerSecretBundle
if err := json.Unmarshal(stored, &bundle); err != nil {
t.Fatalf("Unmarshal returned error: %v", err)
}
if bundle.APIKey != "sk-existing" {
t.Fatalf("expected existing apiKey to be kept, got %q", bundle.APIKey)
}
snapshot, err := configStore.Load()
if err != nil {
t.Fatalf("Load returned error: %v", err)
}
if len(snapshot.Providers) != 1 {
t.Fatalf("expected 1 provider after reload, got %d", len(snapshot.Providers))
}
if snapshot.Providers[0].APIKey != "sk-existing" {
t.Fatalf("expected reload to restore existing apiKey, got %q", snapshot.Providers[0].APIKey)
}
if snapshot.Providers[0].Headers["Authorization"] != "Bearer existing" {
t.Fatalf("expected reload to restore existing sensitive header, got %#v", snapshot.Providers[0].Headers)
}
}

View File

@@ -120,17 +120,17 @@ func mergeProviderSecrets(cfg ai.ProviderConfig, bundle providerSecretBundle) ai
return merged
}
func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) {
func persistProviderSecretBundle(store secretstore.SecretStore, meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) {
meta, _ = splitProviderSecrets(meta)
if !bundle.hasAny() {
meta.HasSecret = false
meta.SecretRef = ""
return meta, nil
}
if s.secretStore == nil {
if store == nil {
return meta, fmt.Errorf("secret store unavailable")
}
if err := s.secretStore.HealthCheck(); err != nil {
if err := store.HealthCheck(); err != nil {
return meta, err
}
@@ -147,7 +147,7 @@ func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle pro
if err != nil {
return meta, fmt.Errorf("序列化 provider secret bundle 失败: %w", err)
}
if err := s.secretStore.Put(ref, payload); err != nil {
if err := store.Put(ref, payload); err != nil {
return meta, err
}
@@ -156,7 +156,7 @@ func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle pro
return meta, nil
}
func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, error) {
func resolveProviderConfigSecrets(store secretstore.SecretStore, cfg ai.ProviderConfig) (ai.ProviderConfig, error) {
cfg = normalizeProviderConfig(cfg)
meta, bundle := splitProviderSecrets(cfg)
if bundle.hasAny() {
@@ -165,7 +165,7 @@ func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.Provid
if !meta.HasSecret {
return meta, nil
}
if s.secretStore == nil {
if store == nil {
return meta, fmt.Errorf("secret store unavailable")
}
@@ -179,7 +179,7 @@ func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.Provid
meta.SecretRef = ref
}
payload, err := s.secretStore.Get(ref)
payload, err := store.Get(ref)
if err != nil {
return meta, err
}
@@ -191,6 +191,14 @@ func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.Provid
return mergeProviderSecrets(meta, stored), nil
}
func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) {
return persistProviderSecretBundle(s.secretStore, meta, bundle)
}
func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, error) {
return resolveProviderConfigSecrets(s.secretStore, cfg)
}
func providerMetadataView(cfg ai.ProviderConfig) ai.ProviderConfig {
meta, _ := splitProviderSecrets(normalizeProviderConfig(cfg))
return meta

View File

@@ -82,7 +82,7 @@ func TestResolveProviderConfigSecretsRestoresStoredSecretBundle(t *testing.T) {
}
}
func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) {
func TestLoadConfigUsesPlaintextProviderSecretsWithoutSilentMigration(t *testing.T) {
store := newFakeProviderSecretStore()
service := NewServiceWithSecretStore(store)
service.configDir = t.TempDir()
@@ -118,24 +118,28 @@ func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) {
t.Fatalf("expected 1 provider, got %d", len(providers))
}
if providers[0].APIKey != "" {
t.Fatalf("expected migrated provider to be secretless, got %q", providers[0].APIKey)
t.Fatalf("expected provider view to stay secretless, got %q", providers[0].APIKey)
}
if !providers[0].HasSecret {
t.Fatal("expected migrated provider to report HasSecret=true")
t.Fatal("expected provider view to report HasSecret=true")
}
stored, err := store.Get(providers[0].SecretRef)
if len(service.providers) != 1 {
t.Fatalf("expected runtime providers to be loaded, got %d", len(service.providers))
}
if service.providers[0].APIKey != "sk-test" {
t.Fatalf("expected runtime provider to keep plaintext apiKey, got %q", service.providers[0].APIKey)
}
if service.providers[0].Headers["Authorization"] != "Bearer test" {
t.Fatalf("expected runtime provider to keep sensitive header, got %#v", service.providers[0].Headers)
}
ref, err := secretstore.BuildRef("ai-provider", "openai-main")
if err != nil {
t.Fatalf("expected secret bundle in store, got error: %v", err)
t.Fatalf("BuildRef returned error: %v", err)
}
var bundle providerSecretBundle
if err := json.Unmarshal(stored, &bundle); err != nil {
t.Fatalf("Unmarshal returned error: %v", err)
}
if bundle.APIKey != "sk-test" {
t.Fatalf("expected migrated apiKey in store, got %q", bundle.APIKey)
}
if bundle.SensitiveHeaders["Authorization"] != "Bearer test" {
t.Fatalf("expected migrated sensitive header in store, got %#v", bundle.SensitiveHeaders)
if _, err := store.Get(ref); !os.IsNotExist(err) {
t.Fatalf("expected startup load to avoid secret-store migration, got %v", err)
}
rewritten, err := os.ReadFile(configPath)
@@ -143,11 +147,124 @@ func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) {
t.Fatalf("ReadFile returned error: %v", err)
}
text := string(rewritten)
if strings.Contains(text, "sk-test") {
t.Fatalf("expected rewritten config to remove api key, got %s", text)
if !strings.Contains(text, "sk-test") {
t.Fatalf("expected config file to remain unchanged, got %s", text)
}
if strings.Contains(text, "Bearer test") {
t.Fatalf("expected rewritten config to remove sensitive header, got %s", text)
if !strings.Contains(text, "Bearer test") {
t.Fatalf("expected config file to keep sensitive header, got %s", text)
}
}
func TestAISaveProviderKeepsLegacyPlaintextSecretAfterStartupLoad(t *testing.T) {
store := newFakeProviderSecretStore()
service := NewServiceWithSecretStore(store)
service.configDir = t.TempDir()
legacy := aiConfig{
Providers: []ai.ProviderConfig{
{
ID: "openai-main",
Type: "custom",
Name: "OpenAI",
APIKey: "sk-test",
BaseURL: "",
Headers: map[string]string{
"Authorization": "Bearer test",
"X-Team": "db",
},
},
},
}
data, err := json.MarshalIndent(legacy, "", " ")
if err != nil {
t.Fatalf("MarshalIndent returned error: %v", err)
}
if err := os.WriteFile(filepath.Join(service.configDir, aiConfigFileName), data, 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
service.loadConfig()
if err := service.AISaveProvider(ai.ProviderConfig{
ID: "openai-main",
Type: "custom",
Name: "OpenAI Updated",
BaseURL: "",
HasSecret: true,
Headers: map[string]string{
"X-Team": "platform",
},
}); err != nil {
t.Fatalf("AISaveProvider returned error: %v", err)
}
if service.providers[0].APIKey != "sk-test" {
t.Fatalf("expected runtime provider to keep legacy plaintext apiKey, got %q", service.providers[0].APIKey)
}
if service.providers[0].Headers["Authorization"] != "Bearer test" {
t.Fatalf("expected runtime provider to keep legacy sensitive header, got %#v", service.providers[0].Headers)
}
ref, err := secretstore.BuildRef("ai-provider", "openai-main")
if err != nil {
t.Fatalf("BuildRef returned error: %v", err)
}
stored, err := store.Get(ref)
if err != nil {
t.Fatalf("expected save to persist provider secret bundle, got %v", err)
}
var bundle providerSecretBundle
if err := json.Unmarshal(stored, &bundle); err != nil {
t.Fatalf("Unmarshal returned error: %v", err)
}
if bundle.APIKey != "sk-test" {
t.Fatalf("expected persisted apiKey, got %q", bundle.APIKey)
}
}
func TestAITestProviderUsesLegacyPlaintextSecretAfterStartupLoad(t *testing.T) {
store := newFakeProviderSecretStore()
service := NewServiceWithSecretStore(store)
service.configDir = t.TempDir()
legacy := aiConfig{
Providers: []ai.ProviderConfig{
{
ID: "openai-main",
Type: "custom",
Name: "OpenAI",
APIKey: "sk-test",
BaseURL: "",
Headers: map[string]string{
"Authorization": "Bearer test",
"X-Team": "db",
},
},
},
}
data, err := json.MarshalIndent(legacy, "", " ")
if err != nil {
t.Fatalf("MarshalIndent returned error: %v", err)
}
if err := os.WriteFile(filepath.Join(service.configDir, aiConfigFileName), data, 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
service.loadConfig()
result := service.AITestProvider(ai.ProviderConfig{
ID: "openai-main",
Type: "custom",
Name: "OpenAI",
BaseURL: "",
HasSecret: true,
Headers: map[string]string{
"X-Team": "db",
},
})
if success, _ := result["success"].(bool); !success {
t.Fatalf("expected test provider to use in-memory legacy secret, got %#v", result)
}
}

View File

@@ -184,11 +184,16 @@ func (s *Service) AISaveProvider(config ai.ProviderConfig) error {
case found && (config.HasSecret || existing.HasSecret):
meta.SecretRef = existing.SecretRef
meta.HasSecret = config.HasSecret || existing.HasSecret
resolved, err := s.resolveProviderConfigSecrets(meta)
if err != nil {
return fmt.Errorf("读取已保存 Provider secret 失败: %w", err)
meta, existingBundle := applyExistingRuntimeProviderSecrets(meta, existing)
if existingBundle.hasAny() {
runtimeConfig = mergeProviderSecrets(meta, existingBundle)
} else {
resolved, err := s.resolveProviderConfigSecrets(meta)
if err != nil {
return fmt.Errorf("读取已保存 Provider secret 失败: %w", err)
}
runtimeConfig = resolved
}
runtimeConfig = resolved
default:
runtimeConfig = meta
}
@@ -258,22 +263,47 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{
}
if strings.TrimSpace(config.APIKey) == "" && (config.HasSecret || strings.TrimSpace(config.SecretRef) != "") {
s.mu.RLock()
var existing ai.ProviderConfig
found := false
if strings.TrimSpace(config.SecretRef) == "" {
for _, providerConfig := range s.providers {
if providerConfig.ID == config.ID {
existing = providerConfig
found = true
config.SecretRef = providerConfig.SecretRef
config.HasSecret = config.HasSecret || providerConfig.HasSecret
break
}
}
} else {
for _, providerConfig := range s.providers {
if providerConfig.ID == config.ID {
existing = providerConfig
found = true
break
}
}
}
s.mu.RUnlock()
resolved, err := s.resolveProviderConfigSecrets(config)
if err != nil {
return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
if found {
config, existingBundle := applyExistingRuntimeProviderSecrets(config, existing)
if existingBundle.hasAny() {
config = mergeProviderSecrets(config, existingBundle)
} else {
resolved, err := s.resolveProviderConfigSecrets(config)
if err != nil {
return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
}
config = resolved
}
} else {
resolved, err := s.resolveProviderConfigSecrets(config)
if err != nil {
return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
}
config = resolved
}
config = resolved
}
config = normalizeProviderConfig(config)
@@ -463,6 +493,15 @@ func normalizeProviderConfig(config ai.ProviderConfig) ai.ProviderConfig {
return config
}
func applyExistingRuntimeProviderSecrets(meta ai.ProviderConfig, existing ai.ProviderConfig) (ai.ProviderConfig, providerSecretBundle) {
existingMeta, existingBundle := splitProviderSecrets(normalizeProviderConfig(existing))
if strings.TrimSpace(meta.SecretRef) == "" {
meta.SecretRef = strings.TrimSpace(existingMeta.SecretRef)
}
meta.HasSecret = meta.HasSecret || existingMeta.HasSecret || existingBundle.hasAny()
return meta, existingBundle
}
func resolveModelsURL(config ai.ProviderConfig) string {
config = normalizeProviderConfig(config)
providerType := normalizedProviderType(config)
@@ -920,117 +959,27 @@ func (s *Service) getActiveProvider() (provider.Provider, error) {
// --- 配置持久化 ---
const aiConfigSchemaVersion = 2
type aiConfig struct {
SchemaVersion int `json:"schemaVersion,omitempty"`
Providers []ai.ProviderConfig `json:"providers"`
ActiveProvider string `json:"activeProvider"`
SafetyLevel string `json:"safetyLevel"`
ContextLevel string `json:"contextLevel"`
}
func (s *Service) loadRuntimeProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, bool, error) {
meta, bundle := splitProviderSecrets(config)
if bundle.hasAny() {
storedMeta, err := s.persistProviderSecretBundle(meta, bundle)
if err != nil {
meta.HasSecret = false
meta.SecretRef = ""
return meta, true, err
}
return mergeProviderSecrets(storedMeta, bundle), true, nil
}
resolved, err := s.resolveProviderConfigSecrets(meta)
if err != nil {
return meta, false, err
}
return resolved, false, nil
}
func (s *Service) loadConfig() {
path := filepath.Join(s.configDir, "ai_config.json")
data, err := os.ReadFile(path)
snapshot, err := NewProviderConfigStore(s.configDir, s.secretStore).LoadRuntime()
if err != nil {
return // 首次启动,无配置文件
}
var cfg aiConfig
if err := json.Unmarshal(data, &cfg); err != nil {
logger.Error(err, "加载 AI 配置失败")
return
}
providers := make([]ai.ProviderConfig, 0, len(cfg.Providers))
shouldRewrite := cfg.SchemaVersion != aiConfigSchemaVersion
for _, providerConfig := range cfg.Providers {
runtimeConfig, rewritten, err := s.loadRuntimeProviderConfig(normalizeProviderConfig(providerConfig))
if err != nil {
logger.Error(err, "加载 AI Provider secret 失败provider=%s", providerConfig.ID)
}
if rewritten {
shouldRewrite = true
}
providers = append(providers, runtimeConfig)
}
if providers == nil {
providers = make([]ai.ProviderConfig, 0)
}
s.providers = providers
s.activeProvider = cfg.ActiveProvider
switch ai.SQLPermissionLevel(cfg.SafetyLevel) {
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
s.safetyLevel = ai.SQLPermissionLevel(cfg.SafetyLevel)
default:
s.safetyLevel = ai.PermissionReadOnly
}
s.providers = snapshot.Providers
s.activeProvider = snapshot.ActiveProvider
s.safetyLevel = snapshot.SafetyLevel
s.guard.SetPermissionLevel(s.safetyLevel)
switch ai.ContextLevel(cfg.ContextLevel) {
case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults:
s.contextLevel = ai.ContextLevel(cfg.ContextLevel)
default:
s.contextLevel = ai.ContextSchemaOnly
}
if shouldRewrite {
if err := s.saveConfig(); err != nil {
logger.Error(err, "重写 AI 配置失败")
}
}
s.contextLevel = snapshot.ContextLevel
}
func (s *Service) saveConfig() error {
providers := make([]ai.ProviderConfig, len(s.providers))
for i := range s.providers {
providers[i] = providerMetadataView(s.providers[i])
}
cfg := aiConfig{
SchemaVersion: aiConfigSchemaVersion,
Providers: providers,
return NewProviderConfigStore(s.configDir, s.secretStore).Save(ProviderConfigStoreSnapshot{
Providers: s.providers,
ActiveProvider: s.activeProvider,
SafetyLevel: string(s.safetyLevel),
ContextLevel: string(s.contextLevel),
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("序列化 AI 配置失败: %w", err)
}
if err := os.MkdirAll(s.configDir, 0o755); err != nil {
return fmt.Errorf("创建配置目录失败: %w", err)
}
path := filepath.Join(s.configDir, "ai_config.json")
if err := os.WriteFile(path, data, 0o644); err != nil {
return fmt.Errorf("写入 AI 配置失败: %w", err)
}
return nil
SafetyLevel: s.safetyLevel,
ContextLevel: s.contextLevel,
})
}
// --- 会话文件持久化 ---

View File

@@ -25,6 +25,7 @@ import (
)
const dbCachePingInterval = 30 * time.Second
const dbConnectFailureCooldown = 30 * time.Second
const (
startupConnectRetryWindow = 20 * time.Second
@@ -42,6 +43,11 @@ type cachedDatabase struct {
lastPing time.Time
}
type cachedConnectFailure struct {
occurredAt time.Time
err error
}
type queryContext struct {
cancel context.CancelFunc
started time.Time
@@ -52,6 +58,7 @@ type App struct {
ctx context.Context
startedAt time.Time
dbCache map[string]cachedDatabase // Cache for DB connections
connectFailures map[string]cachedConnectFailure
mu sync.RWMutex // Mutex for cache access
updateMu sync.Mutex
updateState updateState
@@ -72,6 +79,7 @@ func NewAppWithSecretStore(store secretstore.SecretStore) *App {
}
return &App{
dbCache: make(map[string]cachedDatabase),
connectFailures: make(map[string]cachedConnectFailure),
runningQueries: make(map[string]queryContext),
configDir: resolveAppConfigDir(),
secretStore: store,
@@ -94,7 +102,9 @@ func (a *App) startup(ctx context.Context) {
db.SetExternalDriverDownloadDirectory(appdata.DriverRoot(a.configDir))
logger.Init()
a.loadPersistedGlobalProxy()
installMacNativeWindowDiagnostics(logger.Path())
if shouldInstallMacNativeWindowDiagnostics() {
installMacNativeWindowDiagnostics(logger.Path())
}
applyMacWindowTranslucencyFix()
logger.Infof("应用启动完成(首次连接保护窗口=%s最多重试=%d 次)", startupConnectRetryWindow, startupConnectRetryAttempts)
}
@@ -602,14 +612,28 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing
if isFileDB {
logger.Infof("未命中文件库连接缓存,开始创建连接:类型=%s 缓存Key=%s", strings.TrimSpace(effectiveConfig.Type), shortKey)
}
if failure, remaining, ok := a.getCachedConnectFailureByKey(key); ok {
message := fmt.Sprintf("连接最近失败,正在冷却中,请 %s 后重试;上次错误:%s",
formatConnectFailureCooldown(remaining),
normalizeErrorMessage(failure.err),
)
logger.Warnf("命中数据库连接失败冷却:%s 缓存Key=%s 剩余=%s 原因=%s",
formatConnSummary(effectiveConfig), shortKey, formatConnectFailureCooldown(remaining), normalizeErrorMessage(failure.err))
return nil, withLogHint{err: fmt.Errorf("%s", message), logPath: logger.Path()}
}
initialKey := key
dbInst, connectedConfig, err := a.connectDatabaseWithStartupRetry(resolvedConfig)
if err != nil {
failedKey := getCacheKey(connectedConfig)
a.recordConnectFailureByKey(failedKey, err)
return nil, err
}
a.clearConnectFailureByKey(initialKey)
effectiveConfig = connectedConfig
key = getCacheKey(effectiveConfig)
shortKey = shortenCacheKey(key)
a.clearConnectFailureByKey(key)
now := time.Now()
@@ -630,6 +654,62 @@ func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing
return dbInst, nil
}
func (a *App) getCachedConnectFailureByKey(key string) (cachedConnectFailure, time.Duration, bool) {
if a == nil || strings.TrimSpace(key) == "" {
return cachedConnectFailure{}, 0, false
}
a.mu.RLock()
entry, exists := a.connectFailures[key]
a.mu.RUnlock()
if !exists || entry.err == nil || entry.occurredAt.IsZero() {
return cachedConnectFailure{}, 0, false
}
remaining := dbConnectFailureCooldown - time.Since(entry.occurredAt)
if remaining <= 0 {
a.clearConnectFailureByKey(key)
return cachedConnectFailure{}, 0, false
}
return entry, remaining, true
}
func (a *App) recordConnectFailureByKey(key string, err error) {
if a == nil || strings.TrimSpace(key) == "" || err == nil {
return
}
a.mu.Lock()
if a.connectFailures == nil {
a.connectFailures = make(map[string]cachedConnectFailure)
}
a.connectFailures[key] = cachedConnectFailure{
occurredAt: time.Now(),
err: err,
}
a.mu.Unlock()
}
func (a *App) clearConnectFailureByKey(key string) {
if a == nil || strings.TrimSpace(key) == "" {
return
}
a.mu.Lock()
if a.connectFailures != nil {
delete(a.connectFailures, key)
}
a.mu.Unlock()
}
func formatConnectFailureCooldown(remaining time.Duration) time.Duration {
if remaining <= time.Second {
return time.Second
}
return remaining.Truncate(time.Second)
}
func shortenCacheKey(key string) string {
if len(key) > 12 {
return key[:12]

View File

@@ -277,3 +277,167 @@ func TestIsTransientStartupConnectError(t *testing.T) {
t.Fatal("expected authentication failure to not be treated as transient startup connect error")
}
}
func TestGetDatabaseWithPing_CoolsDownRepeatedFailures(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
defer func() {
newDatabaseFunc = originalNewDatabaseFunc
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
}()
connectCalls := 0
newDatabaseFunc = func(dbType string) (db.Database, error) {
return &fakeStartupRetryDB{
connect: func(config connection.ConnectionConfig) error {
connectCalls++
return errors.New("dial tcp 10.1.131.86:5432: connect: connection refused")
},
}, nil
}
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
return raw, nil
}
a := &App{
startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second),
dbCache: make(map[string]cachedDatabase),
connectFailures: make(map[string]cachedConnectFailure),
runningQueries: make(map[string]queryContext),
}
config := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"}
_, firstErr := a.getDatabaseWithPing(config, false)
if firstErr == nil {
t.Fatal("expected first connection attempt to fail")
}
if connectCalls != 1 {
t.Fatalf("expected first request to use 1 connect attempt outside startup window, got %d", connectCalls)
}
_, secondErr := a.getDatabaseWithPing(config, false)
if secondErr == nil {
t.Fatal("expected second connection attempt to fail during cooldown")
}
if connectCalls != 1 {
t.Fatalf("expected repeated request during cooldown to avoid reconnecting, got %d connect attempts", connectCalls)
}
}
func TestGetDatabaseWithPing_AllowsRetryAfterFailureCooldown(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
defer func() {
newDatabaseFunc = originalNewDatabaseFunc
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
}()
connectCalls := 0
newDatabaseFunc = func(dbType string) (db.Database, error) {
return &fakeStartupRetryDB{
connect: func(config connection.ConnectionConfig) error {
connectCalls++
if connectCalls == 1 {
return errors.New("dial tcp 10.1.131.86:5432: connect: connection refused")
}
return nil
},
}, nil
}
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
return raw, nil
}
a := &App{
startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second),
dbCache: make(map[string]cachedDatabase),
connectFailures: make(map[string]cachedConnectFailure),
runningQueries: make(map[string]queryContext),
}
config := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"}
_, firstErr := a.getDatabaseWithPing(config, false)
if firstErr == nil {
t.Fatal("expected first connection attempt to fail")
}
if connectCalls != 1 {
t.Fatalf("expected first request to use 1 connect attempt outside startup window, got %d", connectCalls)
}
key := getCacheKey(config)
a.mu.Lock()
a.connectFailures[key] = cachedConnectFailure{
occurredAt: time.Now().Add(-dbConnectFailureCooldown - time.Second),
err: errors.New("dial tcp 10.1.131.86:5432: connect: connection refused"),
}
a.mu.Unlock()
inst, secondErr := a.getDatabaseWithPing(config, false)
if secondErr != nil {
t.Fatalf("expected retry after cooldown to be allowed, got error: %v", secondErr)
}
if inst == nil {
t.Fatal("expected database instance after cooldown retry")
}
if connectCalls != 2 {
t.Fatalf("expected reconnect after cooldown expiration, got %d connect attempts", connectCalls)
}
}
func TestGetDatabaseWithPing_ClearsFailureCooldownAfterSuccess(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
defer func() {
newDatabaseFunc = originalNewDatabaseFunc
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
}()
connectCalls := 0
newDatabaseFunc = func(dbType string) (db.Database, error) {
return &fakeStartupRetryDB{
connect: func(config connection.ConnectionConfig) error {
connectCalls++
if connectCalls == 1 {
return errors.New("dial tcp 10.1.131.86:5432: connect: connection refused")
}
return nil
},
}, nil
}
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
return raw, nil
}
a := &App{
startedAt: time.Now().Add(-startupConnectRetryWindow - time.Second),
dbCache: make(map[string]cachedDatabase),
connectFailures: make(map[string]cachedConnectFailure),
runningQueries: make(map[string]queryContext),
}
config := connection.ConnectionConfig{Type: "postgres", Host: "10.1.131.86", Port: 5432, User: "postgres"}
_, firstErr := a.getDatabaseWithPing(config, false)
if firstErr == nil {
t.Fatal("expected first connection attempt to fail")
}
key := getCacheKey(config)
a.mu.Lock()
a.connectFailures[key] = cachedConnectFailure{
occurredAt: time.Now().Add(-dbConnectFailureCooldown - time.Second),
err: errors.New("dial tcp 10.1.131.86:5432: connect: connection refused"),
}
a.mu.Unlock()
_, secondErr := a.getDatabaseWithPing(config, false)
if secondErr != nil {
t.Fatalf("expected retry after cooldown to succeed, got error: %v", secondErr)
}
a.mu.RLock()
_, exists := a.connectFailures[key]
a.mu.RUnlock()
if exists {
t.Fatal("expected successful connection to clear cached failure cooldown")
}
}

View File

@@ -0,0 +1,196 @@
package app
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"strings"
"sync"
"golang.org/x/crypto/argon2"
)
const (
connectionPackageAppKeyPurpose = "gonavi-export-key-v2"
connectionPackageAppKeyFallbackSeed = "gonavi-connection-package-v2-seed"
connectionPackageAppKeyFallbackSalt = "gonavi-connection-package-v2-salt"
)
var (
connectionPackageAppKeySeed string
connectionPackageAppKeySalt string
connectionPackageAppKeyMu sync.Mutex
connectionPackageAppKeyCached []byte
)
func deriveConnectionPackageAppKey() ([]byte, error) {
connectionPackageAppKeyMu.Lock()
defer connectionPackageAppKeyMu.Unlock()
if len(connectionPackageAppKeyCached) == connectionPackageAES256KeyBytes {
return append([]byte(nil), connectionPackageAppKeyCached...), nil
}
seed := strings.TrimSpace(connectionPackageAppKeySeed)
if seed == "" {
seed = connectionPackageAppKeyFallbackSeed
}
saltValue := strings.TrimSpace(connectionPackageAppKeySalt)
if saltValue == "" {
saltValue = connectionPackageAppKeyFallbackSalt
}
mac := hmac.New(sha256.New, []byte(seed))
if _, err := mac.Write([]byte(connectionPackageAppKeyPurpose)); err != nil {
return nil, err
}
intermediate := mac.Sum(nil)
saltHash := sha256.Sum256([]byte(saltValue))
key := argon2.IDKey(
intermediate,
saltHash[:connectionPackageSaltBytes],
connectionPackageKDFDefaultTimeCost,
connectionPackageKDFDefaultMemoryKiB,
connectionPackageKDFDefaultParallelism,
connectionPackageAES256KeyBytes,
)
connectionPackageAppKeyCached = append([]byte(nil), key...)
return append([]byte(nil), key...), nil
}
func resetConnectionPackageAppKeyCache() {
connectionPackageAppKeyMu.Lock()
defer connectionPackageAppKeyMu.Unlock()
connectionPackageAppKeyCached = nil
}
func encryptSecretField(appKey []byte, plaintext string, aad string) (string, error) {
if plaintext == "" {
return "", nil
}
aead, err := newConnectionPackageAEAD(appKey)
if err != nil {
return "", err
}
nonce := make([]byte, connectionPackageNonceBytes)
if _, err := rand.Read(nonce); err != nil {
return "", err
}
ciphertext := aead.Seal(nil, nonce, []byte(plaintext), []byte(aad))
encoded := make([]byte, 0, len(nonce)+len(ciphertext))
encoded = append(encoded, nonce...)
encoded = append(encoded, ciphertext...)
return base64.StdEncoding.EncodeToString(encoded), nil
}
func decryptSecretField(appKey []byte, encrypted string, aad string) (string, error) {
if encrypted == "" {
return "", nil
}
raw, err := base64.StdEncoding.DecodeString(encrypted)
if err != nil {
return "", err
}
if len(raw) <= connectionPackageNonceBytes {
return "", errors.New("invalid encrypted secret")
}
aead, err := newConnectionPackageAEAD(appKey)
if err != nil {
return "", err
}
plain, err := aead.Open(nil, raw[:connectionPackageNonceBytes], raw[connectionPackageNonceBytes:], []byte(aad))
if err != nil {
return "", err
}
return string(plain), nil
}
func encryptSecretBundle(appKey []byte, bundle connectionSecretBundle, connectionID string) (connectionSecretBundle, error) {
var encrypted connectionSecretBundle
var err error
encrypted.Password, err = encryptSecretField(appKey, bundle.Password, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
encrypted.SSHPassword, err = encryptSecretField(appKey, bundle.SSHPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
encrypted.ProxyPassword, err = encryptSecretField(appKey, bundle.ProxyPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
encrypted.HTTPTunnelPassword, err = encryptSecretField(appKey, bundle.HTTPTunnelPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
encrypted.MySQLReplicaPassword, err = encryptSecretField(appKey, bundle.MySQLReplicaPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
encrypted.MongoReplicaPassword, err = encryptSecretField(appKey, bundle.MongoReplicaPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
encrypted.OpaqueURI, err = encryptSecretField(appKey, bundle.OpaqueURI, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
encrypted.OpaqueDSN, err = encryptSecretField(appKey, bundle.OpaqueDSN, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
return encrypted, nil
}
func decryptSecretBundle(appKey []byte, bundle connectionSecretBundle, connectionID string) (connectionSecretBundle, error) {
var decrypted connectionSecretBundle
var err error
decrypted.Password, err = decryptSecretField(appKey, bundle.Password, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
decrypted.SSHPassword, err = decryptSecretField(appKey, bundle.SSHPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
decrypted.ProxyPassword, err = decryptSecretField(appKey, bundle.ProxyPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
decrypted.HTTPTunnelPassword, err = decryptSecretField(appKey, bundle.HTTPTunnelPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
decrypted.MySQLReplicaPassword, err = decryptSecretField(appKey, bundle.MySQLReplicaPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
decrypted.MongoReplicaPassword, err = decryptSecretField(appKey, bundle.MongoReplicaPassword, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
decrypted.OpaqueURI, err = decryptSecretField(appKey, bundle.OpaqueURI, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
decrypted.OpaqueDSN, err = decryptSecretField(appKey, bundle.OpaqueDSN, connectionID)
if err != nil {
return connectionSecretBundle{}, err
}
return decrypted, nil
}

View File

@@ -0,0 +1,141 @@
package app
import (
"encoding/base64"
"reflect"
"strings"
"testing"
)
func TestDeriveConnectionPackageAppKeyIsStable(t *testing.T) {
originalSeed := connectionPackageAppKeySeed
originalSalt := connectionPackageAppKeySalt
t.Cleanup(func() {
connectionPackageAppKeySeed = originalSeed
connectionPackageAppKeySalt = originalSalt
resetConnectionPackageAppKeyCache()
})
connectionPackageAppKeySeed = "unit-test-seed"
connectionPackageAppKeySalt = "unit-test-salt"
resetConnectionPackageAppKeyCache()
first, err := deriveConnectionPackageAppKey()
if err != nil {
t.Fatalf("deriveConnectionPackageAppKey returned error: %v", err)
}
second, err := deriveConnectionPackageAppKey()
if err != nil {
t.Fatalf("deriveConnectionPackageAppKey returned error on second call: %v", err)
}
if len(first) != connectionPackageAES256KeyBytes {
t.Fatalf("expected %d-byte app key, got %d", connectionPackageAES256KeyBytes, len(first))
}
if !reflect.DeepEqual(first, second) {
t.Fatal("expected deriveConnectionPackageAppKey to be stable across repeated calls")
}
connectionPackageAppKeySeed = "unit-test-seed-rotated"
resetConnectionPackageAppKeyCache()
rotated, err := deriveConnectionPackageAppKey()
if err != nil {
t.Fatalf("deriveConnectionPackageAppKey returned error after seed rotation: %v", err)
}
if reflect.DeepEqual(first, rotated) {
t.Fatal("expected different injected seed to produce a different app key")
}
}
func TestEncryptSecretFieldRoundTrip(t *testing.T) {
appKey := []byte("0123456789abcdef0123456789abcdef")
encrypted, err := encryptSecretField(appKey, "super-secret", "conn-1")
if err != nil {
t.Fatalf("encryptSecretField returned error: %v", err)
}
if strings.HasPrefix(encrypted, "ENC:") {
t.Fatalf("encrypted field must not carry ENC prefix, got %q", encrypted)
}
raw, err := base64.StdEncoding.DecodeString(encrypted)
if err != nil {
t.Fatalf("encrypted field must be base64, got error: %v", err)
}
if len(raw) <= connectionPackageNonceBytes {
t.Fatalf("expected nonce+ciphertext output, got %d bytes", len(raw))
}
decrypted, err := decryptSecretField(appKey, encrypted, "conn-1")
if err != nil {
t.Fatalf("decryptSecretField returned error: %v", err)
}
if decrypted != "super-secret" {
t.Fatalf("round-trip mismatch: got %q", decrypted)
}
}
func TestDecryptSecretFieldRejectsAADMismatch(t *testing.T) {
appKey := []byte("0123456789abcdef0123456789abcdef")
encrypted, err := encryptSecretField(appKey, "super-secret", "conn-1")
if err != nil {
t.Fatalf("encryptSecretField returned error: %v", err)
}
if _, err := decryptSecretField(appKey, encrypted, "conn-2"); err == nil {
t.Fatal("expected decryptSecretField to reject mismatched AAD")
}
}
func TestEncryptSecretBundleRoundTripAndAADBinding(t *testing.T) {
appKey := []byte("0123456789abcdef0123456789abcdef")
plain := connectionSecretBundle{
Password: "primary-secret",
SSHPassword: "ssh-secret",
ProxyPassword: "proxy-secret",
HTTPTunnelPassword: "http-secret",
MySQLReplicaPassword: "mysql-secret",
MongoReplicaPassword: "mongo-secret",
OpaqueURI: "postgres://user:pass@db.local/app",
OpaqueDSN: "server=db.local;password=secret",
}
encrypted, err := encryptSecretBundle(appKey, plain, "conn-1")
if err != nil {
t.Fatalf("encryptSecretBundle returned error: %v", err)
}
for name, value := range map[string]string{
"password": encrypted.Password,
"sshPassword": encrypted.SSHPassword,
"proxyPassword": encrypted.ProxyPassword,
"httpTunnelPassword": encrypted.HTTPTunnelPassword,
"mysqlReplicaPassword": encrypted.MySQLReplicaPassword,
"mongoReplicaPassword": encrypted.MongoReplicaPassword,
"opaqueURI": encrypted.OpaqueURI,
"opaqueDSN": encrypted.OpaqueDSN,
} {
if value == "" {
t.Fatalf("expected encrypted %s field to be populated", name)
}
if strings.HasPrefix(value, "ENC:") {
t.Fatalf("encrypted %s field must not carry ENC prefix", name)
}
if value == plain.Password || value == plain.SSHPassword || value == plain.ProxyPassword ||
value == plain.HTTPTunnelPassword || value == plain.MySQLReplicaPassword || value == plain.MongoReplicaPassword ||
value == plain.OpaqueURI || value == plain.OpaqueDSN {
t.Fatalf("expected encrypted %s field to differ from plaintext", name)
}
}
decrypted, err := decryptSecretBundle(appKey, encrypted, "conn-1")
if err != nil {
t.Fatalf("decryptSecretBundle returned error: %v", err)
}
if !reflect.DeepEqual(decrypted, plain) {
t.Fatalf("bundle round-trip mismatch: got=%+v want=%+v", decrypted, plain)
}
if _, err := decryptSecretBundle(appKey, encrypted, "conn-2"); err == nil {
t.Fatal("expected decryptSecretBundle to reject mismatched connection AAD")
}
}

View File

@@ -0,0 +1,582 @@
package app
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"strings"
"golang.org/x/crypto/argon2"
)
const (
connectionPackageAES256KeyBytes = 32
connectionPackageSaltBytes = 16
connectionPackageNonceBytes = 12
)
type connectionPackageAAD struct {
SchemaVersion int `json:"schemaVersion"`
Kind string `json:"kind"`
Cipher string `json:"cipher"`
KDF connectionPackageKDFSpec `json:"kdf"`
Nonce string `json:"nonce"`
}
type connectionPackageAADV2Protected struct {
V int `json:"v"`
Kind string `json:"kind"`
P int `json:"p"`
KDF connectionPackageKDFSpecV2 `json:"kdf"`
NC string `json:"nc"`
}
func encryptConnectionPackage(payload connectionPackagePayload, password string) (connectionPackageFile, error) {
normalizedPassword := normalizeConnectionPackagePassword(password)
if normalizedPassword == "" {
return connectionPackageFile{}, errConnectionPackagePasswordRequired
}
plain, err := json.Marshal(payload)
if err != nil {
return connectionPackageFile{}, err
}
salt := make([]byte, connectionPackageSaltBytes)
if _, err := rand.Read(salt); err != nil {
return connectionPackageFile{}, err
}
nonce := make([]byte, connectionPackageNonceBytes)
if _, err := rand.Read(nonce); err != nil {
return connectionPackageFile{}, err
}
file := connectionPackageFile{
SchemaVersion: connectionPackageSchemaVersion,
Kind: connectionPackageKind,
Cipher: connectionPackageCipher,
KDF: defaultConnectionPackageKDFSpec(),
Nonce: base64.StdEncoding.EncodeToString(nonce),
}
file.KDF.Salt = base64.StdEncoding.EncodeToString(salt)
key, err := deriveConnectionPackageKey(normalizedPassword, file.KDF)
if err != nil {
return connectionPackageFile{}, err
}
aad, err := marshalConnectionPackageAAD(file)
if err != nil {
return connectionPackageFile{}, err
}
aead, err := newConnectionPackageAEAD(key)
if err != nil {
return connectionPackageFile{}, err
}
ciphertext := aead.Seal(nil, nonce, plain, aad)
if len(ciphertext) > connectionPackageMaxCiphertextBytes {
return connectionPackageFile{}, errConnectionPackagePayloadTooLarge
}
file.Payload = base64.StdEncoding.EncodeToString(ciphertext)
if len(file.Payload) > connectionPackageMaxPayloadBase64Bytes {
return connectionPackageFile{}, errConnectionPackagePayloadTooLarge
}
return file, nil
}
func decryptConnectionPackage(file connectionPackageFile, password string) (connectionPackagePayload, error) {
normalizedPassword := normalizeConnectionPackagePassword(password)
if normalizedPassword == "" {
return connectionPackagePayload{}, errConnectionPackagePasswordRequired
}
if err := validateConnectionPackageFileHeader(file); err != nil {
return connectionPackagePayload{}, err
}
plain, err := decryptConnectionPackagePlaintext(file, normalizedPassword)
if err != nil {
if errors.Is(err, errConnectionPackagePayloadTooLarge) {
return connectionPackagePayload{}, err
}
return connectionPackagePayload{}, errConnectionPackageDecryptFailed
}
var payload connectionPackagePayload
if err := json.Unmarshal(plain, &payload); err != nil {
return connectionPackagePayload{}, errConnectionPackageDecryptFailed
}
return payload, nil
}
func isConnectionPackageEnvelope(raw string) bool {
file, err := decodeConnectionPackageEnvelope(raw)
if err != nil {
return false
}
return validateConnectionPackageFileHeader(file) == nil
}
func encryptConnectionPackageV2AppManaged(payload connectionPackagePayload) (connectionPackageFileV2, error) {
appKey, err := deriveConnectionPackageAppKey()
if err != nil {
return connectionPackageFileV2{}, err
}
encryptedPayload, err := encryptConnectionPackagePayloadSecrets(payload, appKey)
if err != nil {
return connectionPackageFileV2{}, err
}
return connectionPackageFileV2{
V: connectionPackageSchemaVersionV2,
Kind: connectionPackageKind,
P: connectionPackageProtectionAppManaged,
ExportedAt: encryptedPayload.ExportedAt,
Connections: encryptedPayload.Connections,
}, nil
}
func encryptConnectionPackageV2Protected(payload connectionPackagePayload, password string) (connectionPackageFileV2Protected, error) {
normalizedPassword := normalizeConnectionPackagePassword(password)
if normalizedPassword == "" {
return connectionPackageFileV2Protected{}, errConnectionPackagePasswordRequired
}
appKey, err := deriveConnectionPackageAppKey()
if err != nil {
return connectionPackageFileV2Protected{}, err
}
encryptedPayload, err := encryptConnectionPackagePayloadSecrets(payload, appKey)
if err != nil {
return connectionPackageFileV2Protected{}, err
}
plain, err := json.Marshal(encryptedPayload)
if err != nil {
return connectionPackageFileV2Protected{}, err
}
salt := make([]byte, connectionPackageSaltBytes)
if _, err := rand.Read(salt); err != nil {
return connectionPackageFileV2Protected{}, err
}
nonce := make([]byte, connectionPackageNonceBytes)
if _, err := rand.Read(nonce); err != nil {
return connectionPackageFileV2Protected{}, err
}
file := connectionPackageFileV2Protected{
V: connectionPackageSchemaVersionV2,
Kind: connectionPackageKind,
P: connectionPackageProtectionPasswordProtected,
KDF: defaultConnectionPackageKDFSpecV2(),
NC: base64.StdEncoding.EncodeToString(nonce),
}
file.KDF.S = base64.StdEncoding.EncodeToString(salt)
key, err := deriveConnectionPackageKeyV2(normalizedPassword, file.KDF)
if err != nil {
return connectionPackageFileV2Protected{}, err
}
aad, err := marshalConnectionPackageAADV2Protected(file)
if err != nil {
return connectionPackageFileV2Protected{}, err
}
aead, err := newConnectionPackageAEAD(key)
if err != nil {
return connectionPackageFileV2Protected{}, err
}
ciphertext := aead.Seal(nil, nonce, plain, aad)
if len(ciphertext) > connectionPackageMaxCiphertextBytes {
return connectionPackageFileV2Protected{}, errConnectionPackagePayloadTooLarge
}
file.D = base64.StdEncoding.EncodeToString(ciphertext)
if len(file.D) > connectionPackageMaxPayloadBase64Bytes {
return connectionPackageFileV2Protected{}, errConnectionPackagePayloadTooLarge
}
return file, nil
}
func decryptConnectionPackageV2AppManaged(file connectionPackageFileV2) (connectionPackagePayload, error) {
if err := validateConnectionPackageFileHeaderV2AppManaged(file); err != nil {
return connectionPackagePayload{}, err
}
appKey, err := deriveConnectionPackageAppKey()
if err != nil {
return connectionPackagePayload{}, err
}
payload, err := decryptConnectionPackagePayloadSecrets(connectionPackagePayload{
ExportedAt: file.ExportedAt,
Connections: file.Connections,
}, appKey)
if err != nil {
return connectionPackagePayload{}, errConnectionPackageDecryptFailed
}
return payload, nil
}
func decryptConnectionPackageV2Protected(file connectionPackageFileV2Protected, password string) (connectionPackagePayload, error) {
normalizedPassword := normalizeConnectionPackagePassword(password)
if normalizedPassword == "" {
return connectionPackagePayload{}, errConnectionPackagePasswordRequired
}
if err := validateConnectionPackageFileHeaderV2Protected(file); err != nil {
return connectionPackagePayload{}, err
}
plain, err := decryptConnectionPackageV2ProtectedPlaintext(file, normalizedPassword)
if err != nil {
if errors.Is(err, errConnectionPackagePayloadTooLarge) {
return connectionPackagePayload{}, err
}
return connectionPackagePayload{}, errConnectionPackageDecryptFailed
}
var encryptedPayload connectionPackagePayload
if err := json.Unmarshal(plain, &encryptedPayload); err != nil {
return connectionPackagePayload{}, errConnectionPackageDecryptFailed
}
appKey, err := deriveConnectionPackageAppKey()
if err != nil {
return connectionPackagePayload{}, err
}
payload, err := decryptConnectionPackagePayloadSecrets(encryptedPayload, appKey)
if err != nil {
return connectionPackagePayload{}, errConnectionPackageDecryptFailed
}
return payload, nil
}
func isConnectionPackageV2AppManaged(raw string) bool {
header, err := decodeConnectionPackageV2Header(raw)
if err != nil {
return false
}
return header.Kind == connectionPackageKind &&
header.V == connectionPackageSchemaVersionV2 &&
header.P == connectionPackageProtectionAppManaged
}
func isConnectionPackageV2Protected(raw string) bool {
header, err := decodeConnectionPackageV2Header(raw)
if err != nil {
return false
}
return header.Kind == connectionPackageKind &&
header.V == connectionPackageSchemaVersionV2 &&
header.P == connectionPackageProtectionPasswordProtected
}
func encodeConnectionPackageEnvelope(file connectionPackageFile) (string, error) {
raw, err := json.Marshal(file)
if err != nil {
return "", err
}
return string(raw), nil
}
func decodeConnectionPackageEnvelope(raw string) (connectionPackageFile, error) {
var file connectionPackageFile
if err := json.Unmarshal([]byte(raw), &file); err != nil {
return connectionPackageFile{}, err
}
return file, nil
}
func decodeConnectionPackageV2Header(raw string) (struct {
V int `json:"v"`
Kind string `json:"kind"`
P int `json:"p"`
}, error) {
var header struct {
V int `json:"v"`
Kind string `json:"kind"`
P int `json:"p"`
}
if err := json.Unmarshal([]byte(raw), &header); err != nil {
return header, err
}
return header, nil
}
func decryptConnectionPackagePlaintext(file connectionPackageFile, password string) ([]byte, error) {
if err := validateConnectionPackageFileHeader(file); err != nil {
return nil, err
}
nonce, err := base64.StdEncoding.DecodeString(file.Nonce)
if err != nil || len(nonce) != connectionPackageNonceBytes {
return nil, errors.New("invalid nonce")
}
if len(file.Payload) > connectionPackageMaxPayloadBase64Bytes {
return nil, errConnectionPackagePayloadTooLarge
}
ciphertext, err := base64.StdEncoding.DecodeString(file.Payload)
if err != nil || len(ciphertext) == 0 {
return nil, errors.New("invalid payload")
}
if len(ciphertext) > connectionPackageMaxCiphertextBytes {
return nil, errConnectionPackagePayloadTooLarge
}
key, err := deriveConnectionPackageKey(password, file.KDF)
if err != nil {
return nil, err
}
aad, err := marshalConnectionPackageAAD(file)
if err != nil {
return nil, err
}
aead, err := newConnectionPackageAEAD(key)
if err != nil {
return nil, err
}
plain, err := aead.Open(nil, nonce, ciphertext, aad)
if err != nil {
return nil, err
}
return plain, nil
}
func deriveConnectionPackageKey(password string, spec connectionPackageKDFSpec) ([]byte, error) {
if password == "" {
return nil, errConnectionPackagePasswordRequired
}
if err := validateConnectionPackageKDFSpec(spec); err != nil {
return nil, err
}
salt, err := base64.StdEncoding.DecodeString(spec.Salt)
if err != nil || len(salt) == 0 {
return nil, errors.New("invalid salt")
}
key := argon2.IDKey(
[]byte(password),
salt,
spec.TimeCost,
spec.MemoryKiB,
spec.Parallelism,
connectionPackageAES256KeyBytes,
)
return key, nil
}
func deriveConnectionPackageKeyV2(password string, spec connectionPackageKDFSpecV2) ([]byte, error) {
if password == "" {
return nil, errConnectionPackagePasswordRequired
}
if err := validateConnectionPackageKDFSpecV2(spec); err != nil {
return nil, err
}
salt, err := base64.StdEncoding.DecodeString(spec.S)
if err != nil || len(salt) == 0 {
return nil, errors.New("invalid salt")
}
key := argon2.IDKey(
[]byte(password),
salt,
spec.T,
spec.M,
spec.L,
connectionPackageAES256KeyBytes,
)
return key, nil
}
func marshalConnectionPackageAAD(file connectionPackageFile) ([]byte, error) {
aad := connectionPackageAAD{
SchemaVersion: file.SchemaVersion,
Kind: file.Kind,
Cipher: file.Cipher,
KDF: file.KDF,
Nonce: file.Nonce,
}
return json.Marshal(aad)
}
func marshalConnectionPackageAADV2Protected(file connectionPackageFileV2Protected) ([]byte, error) {
return json.Marshal(connectionPackageAADV2Protected{
V: file.V,
Kind: file.Kind,
P: file.P,
KDF: file.KDF,
NC: file.NC,
})
}
func newConnectionPackageAEAD(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(block)
}
func validateConnectionPackageFileHeader(file connectionPackageFile) error {
switch {
case file.SchemaVersion != connectionPackageSchemaVersion:
return errConnectionPackageUnsupported
case strings.TrimSpace(file.Kind) != connectionPackageKind:
return errConnectionPackageUnsupported
case strings.TrimSpace(file.Cipher) != connectionPackageCipher:
return errConnectionPackageUnsupported
case validateConnectionPackageKDFSpec(file.KDF) != nil:
return errConnectionPackageUnsupported
default:
return nil
}
}
func validateConnectionPackageFileHeaderV2AppManaged(file connectionPackageFileV2) error {
switch {
case file.V != connectionPackageSchemaVersionV2:
return errConnectionPackageUnsupported
case strings.TrimSpace(file.Kind) != connectionPackageKind:
return errConnectionPackageUnsupported
case file.P != connectionPackageProtectionAppManaged:
return errConnectionPackageUnsupported
default:
return nil
}
}
func validateConnectionPackageFileHeaderV2Protected(file connectionPackageFileV2Protected) error {
switch {
case file.V != connectionPackageSchemaVersionV2:
return errConnectionPackageUnsupported
case strings.TrimSpace(file.Kind) != connectionPackageKind:
return errConnectionPackageUnsupported
case file.P != connectionPackageProtectionPasswordProtected:
return errConnectionPackageUnsupported
case validateConnectionPackageKDFSpecV2(file.KDF) != nil:
return errConnectionPackageUnsupported
default:
return nil
}
}
func validateConnectionPackageKDFSpec(spec connectionPackageKDFSpec) error {
switch {
case strings.TrimSpace(spec.Name) != connectionPackageKDFName:
return errConnectionPackageUnsupported
case spec.MemoryKiB == 0 || spec.TimeCost == 0 || spec.Parallelism == 0:
return errConnectionPackageUnsupported
case spec.MemoryKiB > connectionPackageKDFMaxMemoryKiB:
return errConnectionPackageUnsupported
case spec.TimeCost > connectionPackageKDFMaxTimeCost:
return errConnectionPackageUnsupported
case spec.Parallelism > connectionPackageKDFMaxParallelism:
return errConnectionPackageUnsupported
default:
return nil
}
}
func validateConnectionPackageKDFSpecV2(spec connectionPackageKDFSpecV2) error {
switch {
case strings.TrimSpace(spec.N) != connectionPackageKDFNameV2:
return errConnectionPackageUnsupported
case spec.M == 0 || spec.T == 0 || spec.L == 0:
return errConnectionPackageUnsupported
case spec.M > connectionPackageKDFMaxMemoryKiB:
return errConnectionPackageUnsupported
case spec.T > connectionPackageKDFMaxTimeCost:
return errConnectionPackageUnsupported
case spec.L > connectionPackageKDFMaxParallelism:
return errConnectionPackageUnsupported
default:
return nil
}
}
func decryptConnectionPackageV2ProtectedPlaintext(file connectionPackageFileV2Protected, password string) ([]byte, error) {
if err := validateConnectionPackageFileHeaderV2Protected(file); err != nil {
return nil, err
}
nonce, err := base64.StdEncoding.DecodeString(file.NC)
if err != nil || len(nonce) != connectionPackageNonceBytes {
return nil, errors.New("invalid nonce")
}
if len(file.D) > connectionPackageMaxPayloadBase64Bytes {
return nil, errConnectionPackagePayloadTooLarge
}
ciphertext, err := base64.StdEncoding.DecodeString(file.D)
if err != nil || len(ciphertext) == 0 {
return nil, errors.New("invalid payload")
}
if len(ciphertext) > connectionPackageMaxCiphertextBytes {
return nil, errConnectionPackagePayloadTooLarge
}
key, err := deriveConnectionPackageKeyV2(password, file.KDF)
if err != nil {
return nil, err
}
aad, err := marshalConnectionPackageAADV2Protected(file)
if err != nil {
return nil, err
}
aead, err := newConnectionPackageAEAD(key)
if err != nil {
return nil, err
}
return aead.Open(nil, nonce, ciphertext, aad)
}
func encryptConnectionPackagePayloadSecrets(payload connectionPackagePayload, appKey []byte) (connectionPackagePayload, error) {
encrypted := connectionPackagePayload{
ExportedAt: payload.ExportedAt,
Connections: make([]connectionPackageItem, len(payload.Connections)),
}
for index, item := range payload.Connections {
encryptedItem := item
bundle, err := encryptSecretBundle(appKey, item.Secrets, connectionPackageItemAAD(item))
if err != nil {
return connectionPackagePayload{}, err
}
encryptedItem.Secrets = bundle
encrypted.Connections[index] = encryptedItem
}
return encrypted, nil
}
func decryptConnectionPackagePayloadSecrets(payload connectionPackagePayload, appKey []byte) (connectionPackagePayload, error) {
decrypted := connectionPackagePayload{
ExportedAt: payload.ExportedAt,
Connections: make([]connectionPackageItem, len(payload.Connections)),
}
for index, item := range payload.Connections {
decryptedItem := item
bundle, err := decryptSecretBundle(appKey, item.Secrets, connectionPackageItemAAD(item))
if err != nil {
return connectionPackagePayload{}, err
}
decryptedItem.Secrets = bundle
decrypted.Connections[index] = decryptedItem
}
return decrypted, nil
}
func connectionPackageItemAAD(item connectionPackageItem) string {
if strings.TrimSpace(item.ID) != "" {
return item.ID
}
return item.Config.ID
}

View File

@@ -0,0 +1,477 @@
package app
import (
"encoding/base64"
"encoding/json"
"errors"
"reflect"
"strings"
"testing"
"GoNavi-Wails/internal/connection"
)
func TestConnectionPackageCryptoRoundTrip(t *testing.T) {
payload := connectionPackagePayload{
ExportedAt: "2026-04-10T12:00:00+08:00",
Connections: []connectionPackageItem{
{
ID: "conn-1",
Name: "local-mysql",
IncludeDatabases: []string{"app"},
IconType: "database",
IconColor: "#2f855a",
Config: connection.ConnectionConfig{
Type: "mysql",
Host: "127.0.0.1",
Port: 3306,
User: "root",
Database: "app",
},
},
},
}
file, err := encryptConnectionPackage(payload, "strong-password")
if err != nil {
t.Fatalf("encryptConnectionPackage returned error: %v", err)
}
raw, err := json.Marshal(file)
if err != nil {
t.Fatalf("json.Marshal envelope returned error: %v", err)
}
if !isConnectionPackageEnvelope(string(raw)) {
t.Fatalf("isConnectionPackageEnvelope should return true for valid envelope")
}
var decoded connectionPackageFile
if err := json.Unmarshal(raw, &decoded); err != nil {
t.Fatalf("json.Unmarshal envelope returned error: %v", err)
}
got, err := decryptConnectionPackage(decoded, "strong-password")
if err != nil {
t.Fatalf("decryptConnectionPackage returned error: %v", err)
}
if !reflect.DeepEqual(got, payload) {
t.Fatalf("round-trip mismatch: got=%+v want=%+v", got, payload)
}
}
func TestConnectionPackageV2AppManagedRoundTrip(t *testing.T) {
payload := connectionPackagePayload{
ExportedAt: "2026-04-11T12:00:00Z",
Connections: []connectionPackageItem{
{
ID: "conn-v2-1",
Name: "app-managed",
Config: connection.ConnectionConfig{
ID: "conn-v2-1",
Type: "postgres",
Host: "db.local",
Port: 5432,
User: "postgres",
Database: "app",
},
Secrets: connectionSecretBundle{
Password: "primary-secret",
SSHPassword: "ssh-secret",
OpaqueURI: "postgres://postgres:primary-secret@db.local/app",
},
},
},
}
file, err := encryptConnectionPackageV2AppManaged(payload)
if err != nil {
t.Fatalf("encryptConnectionPackageV2AppManaged returned error: %v", err)
}
if file.V != connectionPackageSchemaVersionV2 {
t.Fatalf("expected v2 schema, got %d", file.V)
}
if file.P != connectionPackageProtectionAppManaged {
t.Fatalf("expected p=1, got %d", file.P)
}
if len(file.Connections) != 1 {
t.Fatalf("expected 1 connection, got %d", len(file.Connections))
}
if file.Connections[0].Secrets.Password == payload.Connections[0].Secrets.Password {
t.Fatal("expected p=1 secrets to stay encrypted in file")
}
raw, err := json.Marshal(file)
if err != nil {
t.Fatalf("json.Marshal returned error: %v", err)
}
if !isConnectionPackageV2AppManaged(string(raw)) {
t.Fatal("expected raw v2 p=1 payload to be detected")
}
if isConnectionPackageEnvelope(string(raw)) {
t.Fatal("v2 p=1 payload must not be misclassified as v1 envelope")
}
rawString := string(raw)
for _, forbidden := range []string{
"schemaVersion",
"cipher",
"protectionLevel",
"ENC:",
"primary-secret",
"ssh-secret",
"postgres://postgres:primary-secret@db.local/app",
} {
if strings.Contains(rawString, forbidden) {
t.Fatalf("v2 p=1 payload must not contain %q: %s", forbidden, rawString)
}
}
got, err := decryptConnectionPackageV2AppManaged(file)
if err != nil {
t.Fatalf("decryptConnectionPackageV2AppManaged returned error: %v", err)
}
if !reflect.DeepEqual(got, payload) {
t.Fatalf("round-trip mismatch: got=%+v want=%+v", got, payload)
}
}
func TestConnectionPackageV2ProtectedRoundTrip(t *testing.T) {
payload := connectionPackagePayload{
ExportedAt: "2026-04-11T12:00:00Z",
Connections: []connectionPackageItem{
{
ID: "conn-v2-2",
Name: "password-protected",
Config: connection.ConnectionConfig{
ID: "conn-v2-2",
Type: "mysql",
Host: "db.local",
Port: 3306,
User: "root",
Database: "app",
},
Secrets: connectionSecretBundle{
Password: "primary-secret",
SSHPassword: "ssh-secret",
ProxyPassword: "proxy-secret",
HTTPTunnelPassword: "http-secret",
MySQLReplicaPassword: "mysql-secret",
MongoReplicaPassword: "mongo-secret",
OpaqueURI: "mysql://root:primary-secret@tcp(db.local:3306)/app",
OpaqueDSN: "root:primary-secret@tcp(db.local:3306)/app",
},
},
},
}
file, err := encryptConnectionPackageV2Protected(payload, "package-password")
if err != nil {
t.Fatalf("encryptConnectionPackageV2Protected returned error: %v", err)
}
if file.V != connectionPackageSchemaVersionV2 {
t.Fatalf("expected v2 schema, got %d", file.V)
}
if file.P != connectionPackageProtectionPasswordProtected {
t.Fatalf("expected p=2, got %d", file.P)
}
if file.D == "" || file.NC == "" {
t.Fatal("expected p=2 file to carry outer encrypted payload")
}
if strings.HasPrefix(file.D, "ENC:") {
t.Fatalf("outer payload must not carry ENC prefix, got %q", file.D)
}
raw, err := json.Marshal(file)
if err != nil {
t.Fatalf("json.Marshal returned error: %v", err)
}
if !isConnectionPackageV2Protected(string(raw)) {
t.Fatal("expected raw v2 p=2 payload to be detected")
}
if isConnectionPackageEnvelope(string(raw)) {
t.Fatal("v2 p=2 payload must not be misclassified as v1 envelope")
}
rawString := string(raw)
for _, forbidden := range []string{
"schemaVersion",
"cipher",
"protectionLevel",
"ENC:",
"primary-secret",
"ssh-secret",
} {
if strings.Contains(rawString, forbidden) {
t.Fatalf("v2 p=2 payload must not contain %q: %s", forbidden, rawString)
}
}
got, err := decryptConnectionPackageV2Protected(file, "package-password")
if err != nil {
t.Fatalf("decryptConnectionPackageV2Protected returned error: %v", err)
}
if !reflect.DeepEqual(got, payload) {
t.Fatalf("round-trip mismatch: got=%+v want=%+v", got, payload)
}
}
func TestConnectionPackageV2ProtectedWrongPasswordReturnsUnifiedError(t *testing.T) {
file, err := encryptConnectionPackageV2Protected(connectionPackagePayload{
Connections: []connectionPackageItem{
{
ID: "conn-v2-3",
Name: "wrong-password",
Config: connection.ConnectionConfig{
ID: "conn-v2-3",
Type: "postgres",
},
Secrets: connectionSecretBundle{
Password: "primary-secret",
},
},
},
}, "correct-password")
if err != nil {
t.Fatalf("encryptConnectionPackageV2Protected returned error: %v", err)
}
_, err = decryptConnectionPackageV2Protected(file, "wrong-password")
if !errors.Is(err, errConnectionPackageDecryptFailed) {
t.Fatalf("wrong p=2 password should return unified error, got: %v", err)
}
}
func TestConnectionPackageDecryptWrongPasswordReturnsUnifiedError(t *testing.T) {
payload := connectionPackagePayload{
Connections: []connectionPackageItem{
{
ID: "conn-1",
Name: "test",
Config: connection.ConnectionConfig{
Type: "mysql",
},
},
},
}
file, err := encryptConnectionPackage(payload, "correct-password")
if err != nil {
t.Fatalf("encryptConnectionPackage returned error: %v", err)
}
_, err = decryptConnectionPackage(file, "wrong-password")
if !errors.Is(err, errConnectionPackageDecryptFailed) {
t.Fatalf("wrong password should return unified error, got: %v", err)
}
}
func TestConnectionPackageDecryptTamperedHeaderFailsAADValidation(t *testing.T) {
payload := connectionPackagePayload{
Connections: []connectionPackageItem{
{
ID: "conn-1",
Name: "test",
Config: connection.ConnectionConfig{
Type: "mysql",
},
},
},
}
file, err := encryptConnectionPackage(payload, "correct-password")
if err != nil {
t.Fatalf("encryptConnectionPackage returned error: %v", err)
}
t.Run("cipher", func(t *testing.T) {
tampered := file
tampered.Nonce = "AAAAAAAAAAAAAAAA"
_, err := decryptConnectionPackage(tampered, "correct-password")
if !errors.Is(err, errConnectionPackageDecryptFailed) {
t.Fatalf("tampered nonce should fail with unified error, got: %v", err)
}
})
t.Run("kdf-salt", func(t *testing.T) {
tampered := file
tampered.KDF.Salt = "AAAAAAAAAAAAAAAAAAAAAA=="
_, err := decryptConnectionPackage(tampered, "correct-password")
if !errors.Is(err, errConnectionPackageDecryptFailed) {
t.Fatalf("tampered kdf salt should fail with unified error, got: %v", err)
}
})
}
func TestConnectionPackagePasswordRequired(t *testing.T) {
payload := connectionPackagePayload{
Connections: []connectionPackageItem{
{
ID: "conn-1",
Name: "test",
Config: connection.ConnectionConfig{
Type: "mysql",
},
},
},
}
_, err := encryptConnectionPackage(payload, " ")
if !errors.Is(err, errConnectionPackagePasswordRequired) {
t.Fatalf("encryptConnectionPackage should return password required error, got: %v", err)
}
_, err = decryptConnectionPackage(connectionPackageFile{}, " ")
if !errors.Is(err, errConnectionPackagePasswordRequired) {
t.Fatalf("decryptConnectionPackage should return password required error, got: %v", err)
}
}
func TestConnectionPackageDecryptUnsupportedHeaderReturnsUnsupportedError(t *testing.T) {
payload := connectionPackagePayload{
Connections: []connectionPackageItem{
{
ID: "conn-1",
Name: "test",
Config: connection.ConnectionConfig{
Type: "mysql",
},
},
},
}
file, err := encryptConnectionPackage(payload, "correct-password")
if err != nil {
t.Fatalf("encryptConnectionPackage returned error: %v", err)
}
t.Run("schemaVersion", func(t *testing.T) {
tampered := file
tampered.SchemaVersion = tampered.SchemaVersion + 1
_, err := decryptConnectionPackage(tampered, "correct-password")
if !errors.Is(err, errConnectionPackageUnsupported) {
t.Fatalf("unsupported schemaVersion should return unsupported error, got: %v", err)
}
})
t.Run("kind", func(t *testing.T) {
tampered := file
tampered.Kind = "other_connection_package"
_, err := decryptConnectionPackage(tampered, "correct-password")
if !errors.Is(err, errConnectionPackageUnsupported) {
t.Fatalf("unsupported kind should return unsupported error, got: %v", err)
}
})
t.Run("cipher", func(t *testing.T) {
tampered := file
tampered.Cipher = "AES-128-GCM"
_, err := decryptConnectionPackage(tampered, "correct-password")
if !errors.Is(err, errConnectionPackageUnsupported) {
t.Fatalf("unsupported cipher should return unsupported error, got: %v", err)
}
})
t.Run("kdf-name", func(t *testing.T) {
tampered := file
tampered.KDF.Name = "PBKDF2"
_, err := decryptConnectionPackage(tampered, "correct-password")
if !errors.Is(err, errConnectionPackageUnsupported) {
t.Fatalf("unsupported kdf name should return unsupported error, got: %v", err)
}
})
}
func TestValidateConnectionPackageKDFSpecRejectsOversizedParams(t *testing.T) {
t.Run("memory", func(t *testing.T) {
spec := defaultConnectionPackageKDFSpec()
spec.MemoryKiB = connectionPackageKDFMaxMemoryKiB + 1
if err := validateConnectionPackageKDFSpec(spec); !errors.Is(err, errConnectionPackageUnsupported) {
t.Fatalf("oversized memory should return unsupported error, got: %v", err)
}
})
t.Run("timeCost", func(t *testing.T) {
spec := defaultConnectionPackageKDFSpec()
spec.TimeCost = connectionPackageKDFMaxTimeCost + 1
if err := validateConnectionPackageKDFSpec(spec); !errors.Is(err, errConnectionPackageUnsupported) {
t.Fatalf("oversized timeCost should return unsupported error, got: %v", err)
}
})
t.Run("parallelism", func(t *testing.T) {
spec := defaultConnectionPackageKDFSpec()
spec.Parallelism = connectionPackageKDFMaxParallelism + 1
if err := validateConnectionPackageKDFSpec(spec); !errors.Is(err, errConnectionPackageUnsupported) {
t.Fatalf("oversized parallelism should return unsupported error, got: %v", err)
}
})
}
func TestDecryptConnectionPackagePlaintextRejectsOversizedPayload(t *testing.T) {
nonce := base64.StdEncoding.EncodeToString(make([]byte, connectionPackageNonceBytes))
salt := base64.StdEncoding.EncodeToString(make([]byte, connectionPackageSaltBytes))
payload := base64.StdEncoding.EncodeToString(make([]byte, connectionPackageMaxCiphertextBytes+1))
file := connectionPackageFile{
SchemaVersion: connectionPackageSchemaVersion,
Kind: connectionPackageKind,
Cipher: connectionPackageCipher,
KDF: connectionPackageKDFSpec{
Name: connectionPackageKDFName,
MemoryKiB: connectionPackageKDFDefaultMemoryKiB,
TimeCost: connectionPackageKDFDefaultTimeCost,
Parallelism: connectionPackageKDFDefaultParallelism,
Salt: salt,
},
Nonce: nonce,
Payload: payload,
}
_, err := decryptConnectionPackagePlaintext(file, "correct-password")
if !errors.Is(err, errConnectionPackagePayloadTooLarge) {
t.Fatalf("oversized payload should return errConnectionPackagePayloadTooLarge, got: %v", err)
}
}
func TestDecryptConnectionPackagePlaintextRejectsOversizedBase64PayloadBeforeDecode(t *testing.T) {
nonce := base64.StdEncoding.EncodeToString(make([]byte, connectionPackageNonceBytes))
file := connectionPackageFile{
SchemaVersion: connectionPackageSchemaVersion,
Kind: connectionPackageKind,
Cipher: connectionPackageCipher,
KDF: connectionPackageKDFSpec{
Name: connectionPackageKDFName,
MemoryKiB: connectionPackageKDFDefaultMemoryKiB,
TimeCost: connectionPackageKDFDefaultTimeCost,
Parallelism: connectionPackageKDFDefaultParallelism,
Salt: base64.StdEncoding.EncodeToString(make([]byte, connectionPackageSaltBytes)),
},
Nonce: nonce,
Payload: strings.Repeat("A", connectionPackageMaxPayloadBase64Bytes+4),
}
_, err := decryptConnectionPackagePlaintext(file, "correct-password")
if !errors.Is(err, errConnectionPackagePayloadTooLarge) {
t.Fatalf("oversized base64 payload should return errConnectionPackagePayloadTooLarge, got: %v", err)
}
}
func TestEncryptConnectionPackageRejectsOversizedPayload(t *testing.T) {
_, err := encryptConnectionPackage(connectionPackagePayload{
Connections: []connectionPackageItem{
{
ID: "conn-large",
Name: strings.Repeat("x", connectionPackageMaxCiphertextBytes),
Config: connection.ConnectionConfig{
ID: "conn-large",
Type: "postgres",
Host: "db.large.local",
Port: 5432,
User: "postgres",
},
},
},
}, "correct-password")
if !errors.Is(err, errConnectionPackagePayloadTooLarge) {
t.Fatalf("oversized export payload should return errConnectionPackagePayloadTooLarge, got: %v", err)
}
}

View File

@@ -0,0 +1,362 @@
package app
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/secretstore"
"github.com/google/uuid"
)
func newConnectionPackageItem(view connection.SavedConnectionView, bundle connectionSecretBundle) connectionPackageItem {
return connectionPackageItem{
ID: view.ID,
Name: view.Name,
IncludeDatabases: cloneStringSlice(view.IncludeDatabases),
IncludeRedisDatabases: cloneIntSlice(view.IncludeRedisDatabases),
IconType: view.IconType,
IconColor: view.IconColor,
Config: view.Config,
Secrets: bundle,
}
}
func (a *App) buildConnectionPackagePayload() (connectionPackagePayload, error) {
repo := a.savedConnectionRepository()
items, err := repo.List()
if err != nil {
return connectionPackagePayload{}, err
}
connections := make([]connectionPackageItem, 0, len(items))
for _, item := range items {
bundle, bundleErr := repo.loadSecretBundle(item)
if bundleErr != nil {
return connectionPackagePayload{}, bundleErr
}
connections = append(connections, newConnectionPackageItem(item, bundle))
}
return connectionPackagePayload{
ExportedAt: time.Now().UTC().Format(time.RFC3339),
Connections: connections,
}, nil
}
func (a *App) buildExportedConnectionPackage(options ConnectionExportOptions) ([]byte, error) {
payload, err := a.buildConnectionPackagePayload()
if err != nil {
return nil, err
}
if !options.IncludeSecrets {
for index := range payload.Connections {
payload.Connections[index].Secrets = connectionSecretBundle{}
}
}
normalizedPassword := normalizeConnectionPackagePassword(options.FilePassword)
if !options.IncludeSecrets || normalizedPassword == "" {
file, err := encryptConnectionPackageV2AppManaged(payload)
if err != nil {
return nil, err
}
return json.MarshalIndent(file, "", " ")
}
file, err := encryptConnectionPackageV2Protected(payload, normalizedPassword)
if err != nil {
return nil, err
}
return json.MarshalIndent(file, "", " ")
}
func newSavedConnectionInputFromPackageItem(item connectionPackageItem) connection.SavedConnectionInput {
id := strings.TrimSpace(item.ID)
if id == "" {
id = strings.TrimSpace(item.Config.ID)
}
config := item.Config
config.ID = id
config.SavePassword = false
secrets := item.Secrets
config.Password = secrets.Password
config.SSH.Password = secrets.SSHPassword
config.Proxy.Password = secrets.ProxyPassword
config.HTTPTunnel.Password = secrets.HTTPTunnelPassword
config.MySQLReplicaPassword = secrets.MySQLReplicaPassword
config.MongoReplicaPassword = secrets.MongoReplicaPassword
config.URI = secrets.OpaqueURI
config.DSN = secrets.OpaqueDSN
return connection.SavedConnectionInput{
ID: id,
Name: item.Name,
Config: config,
IncludeDatabases: cloneStringSlice(item.IncludeDatabases),
IncludeRedisDatabases: cloneIntSlice(item.IncludeRedisDatabases),
IconType: item.IconType,
IconColor: item.IconColor,
// 连接恢复包以最新导入文件为准;载荷中缺失的密文字段需要显式清空旧值。
ClearPrimaryPassword: strings.TrimSpace(secrets.Password) == "",
ClearSSHPassword: strings.TrimSpace(secrets.SSHPassword) == "",
ClearProxyPassword: strings.TrimSpace(secrets.ProxyPassword) == "",
ClearHTTPTunnelPassword: strings.TrimSpace(secrets.HTTPTunnelPassword) == "",
ClearMySQLReplicaPassword: strings.TrimSpace(secrets.MySQLReplicaPassword) == "",
ClearMongoReplicaPassword: strings.TrimSpace(secrets.MongoReplicaPassword) == "",
ClearOpaqueURI: strings.TrimSpace(secrets.OpaqueURI) == "",
ClearOpaqueDSN: strings.TrimSpace(secrets.OpaqueDSN) == "",
}
}
func dedupeImportedSavedConnectionViews(views []connection.SavedConnectionView) []connection.SavedConnectionView {
if len(views) < 2 {
return views
}
lastIndexByID := make(map[string]int, len(views))
for index, view := range views {
id := strings.TrimSpace(view.ID)
if id == "" {
continue
}
lastIndexByID[id] = index
}
result := make([]connection.SavedConnectionView, 0, len(views))
for index, view := range views {
id := strings.TrimSpace(view.ID)
if id != "" && lastIndexByID[id] != index {
continue
}
result = append(result, view)
}
return result
}
func dedupeImportedSavedConnectionInputs(inputs []connection.SavedConnectionInput) []connection.SavedConnectionInput {
if len(inputs) < 2 {
return inputs
}
lastIndexByID := make(map[string]int, len(inputs))
for index, input := range inputs {
id := strings.TrimSpace(input.ID)
if id == "" {
continue
}
lastIndexByID[id] = index
}
result := make([]connection.SavedConnectionInput, 0, len(inputs))
for index, input := range inputs {
id := strings.TrimSpace(input.ID)
if id != "" && lastIndexByID[id] != index {
continue
}
result = append(result, input)
}
return result
}
func normalizeImportedSavedConnectionInput(input connection.SavedConnectionInput) connection.SavedConnectionInput {
if strings.TrimSpace(input.ID) == "" && strings.TrimSpace(input.Config.ID) == "" {
input.ID = "conn-" + uuid.New().String()[:8]
}
if strings.TrimSpace(input.ID) == "" {
input.ID = strings.TrimSpace(input.Config.ID)
}
input.Config.ID = input.ID
return input
}
func (a *App) importSavedConnectionsAtomically(inputs []connection.SavedConnectionInput) ([]connection.SavedConnectionView, error) {
repo := a.savedConnectionRepository()
normalizedInputs := make([]connection.SavedConnectionInput, 0, len(inputs))
for _, input := range inputs {
normalizedInputs = append(normalizedInputs, normalizeImportedSavedConnectionInput(input))
}
finalInputs := dedupeImportedSavedConnectionInputs(normalizedInputs)
rollbackSnapshot, err := captureConnectionImportRollbackSnapshot(a, finalInputs)
if err != nil {
return nil, err
}
result := make([]connection.SavedConnectionView, 0, len(finalInputs))
for _, input := range finalInputs {
view, err := repo.Save(input)
if err != nil {
if rollbackErr := rollbackSnapshot.restore(a); rollbackErr != nil {
return nil, errors.Join(err, fmt.Errorf("restore connection import rollback: %w", rollbackErr))
}
return nil, err
}
result = append(result, view)
}
return dedupeImportedSavedConnectionViews(result), nil
}
func (a *App) importConnectionPackagePayload(payload connectionPackagePayload) ([]connection.SavedConnectionView, error) {
inputs := make([]connection.SavedConnectionInput, 0, len(payload.Connections))
for _, item := range payload.Connections {
inputs = append(inputs, newSavedConnectionInputFromPackageItem(item))
}
return a.importSavedConnectionsAtomically(inputs)
}
func (a *App) ImportConnectionsPayload(raw string, password string) ([]connection.SavedConnectionView, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil, errConnectionPackageUnsupported
}
if len(trimmed) > connectionImportMaxFileBytes {
return nil, errConnectionImportFileTooLarge
}
if isConnectionPackageV2AppManaged(trimmed) {
var file connectionPackageFileV2
if err := json.Unmarshal([]byte(trimmed), &file); err != nil {
return nil, errConnectionPackageUnsupported
}
payload, err := decryptConnectionPackageV2AppManaged(file)
if err != nil {
return nil, err
}
return a.importConnectionPackagePayload(payload)
}
if isConnectionPackageV2Protected(trimmed) {
var file connectionPackageFileV2Protected
if err := json.Unmarshal([]byte(trimmed), &file); err != nil {
return nil, errConnectionPackageUnsupported
}
payload, err := decryptConnectionPackageV2Protected(file, password)
if err != nil {
return nil, err
}
return a.importConnectionPackagePayload(payload)
}
if isConnectionPackageEnvelope(trimmed) {
var file connectionPackageFile
if err := json.Unmarshal([]byte(trimmed), &file); err != nil {
return nil, errConnectionPackageUnsupported
}
payload, err := decryptConnectionPackage(file, password)
if err != nil {
return nil, err
}
return a.importConnectionPackagePayload(payload)
}
var legacy []connection.LegacySavedConnection
if err := json.Unmarshal([]byte(trimmed), &legacy); err != nil {
return nil, errConnectionPackageUnsupported
}
return a.ImportLegacyConnections(legacy)
}
type connectionPackageImportRollbackSnapshot struct {
connectionsFileExists bool
connectionsFileData []byte
connectionSecrets map[string]securityUpdateSecretSnapshot
connectionCleanupRefs []string
}
func captureConnectionImportRollbackSnapshot(a *App, inputs []connection.SavedConnectionInput) (connectionPackageImportRollbackSnapshot, error) {
snapshot := connectionPackageImportRollbackSnapshot{
connectionSecrets: make(map[string]securityUpdateSecretSnapshot),
}
repo := a.savedConnectionRepository()
connectionFileData, connectionFileExists, err := readOptionalFile(repo.connectionsPath())
if err != nil {
return snapshot, err
}
snapshot.connectionsFileExists = connectionFileExists
snapshot.connectionsFileData = connectionFileData
existingConnections, err := repo.load()
if err != nil {
return snapshot, err
}
existingConnectionsByID := make(map[string]connection.SavedConnectionView, len(existingConnections))
for _, item := range existingConnections {
existingConnectionsByID[item.ID] = item
}
cleanupSet := make(map[string]struct{})
seenIDs := make(map[string]struct{})
for _, input := range inputs {
connectionID := strings.TrimSpace(input.ID)
if connectionID == "" {
connectionID = strings.TrimSpace(input.Config.ID)
}
if connectionID == "" {
continue
}
if _, alreadySeen := seenIDs[connectionID]; alreadySeen {
continue
}
seenIDs[connectionID] = struct{}{}
defaultRef, refErr := secretstore.BuildRef(savedConnectionSecretKind, connectionID)
if refErr == nil {
cleanupSet[defaultRef] = struct{}{}
}
existing, ok := existingConnectionsByID[connectionID]
if !ok || !savedConnectionViewHasSecrets(existing) {
continue
}
ref := strings.TrimSpace(existing.SecretRef)
if ref == "" {
ref = defaultRef
}
if ref == "" {
continue
}
secretSnapshot, captureErr := captureSecurityUpdateSecretSnapshot(a.secretStore, ref)
if captureErr != nil {
return snapshot, captureErr
}
snapshot.connectionSecrets[ref] = secretSnapshot
cleanupSet[ref] = struct{}{}
}
snapshot.connectionCleanupRefs = make([]string, 0, len(cleanupSet))
for ref := range cleanupSet {
snapshot.connectionCleanupRefs = append(snapshot.connectionCleanupRefs, ref)
}
return snapshot, nil
}
func (s connectionPackageImportRollbackSnapshot) restore(a *App) error {
repo := a.savedConnectionRepository()
if err := restoreOptionalFile(repo.connectionsPath(), s.connectionsFileExists, s.connectionsFileData); err != nil {
return err
}
for ref, secretSnapshot := range s.connectionSecrets {
if err := restoreSecurityUpdateSecretSnapshot(a.secretStore, ref, secretSnapshot); err != nil {
return err
}
}
for _, ref := range s.connectionCleanupRefs {
if _, alreadyRestored := s.connectionSecrets[ref]; alreadyRestored {
continue
}
if err := deleteSecurityUpdateSecretRef(a.secretStore, ref); err != nil {
return err
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,156 @@
package app
import (
"encoding/json"
"errors"
"strings"
"GoNavi-Wails/internal/connection"
)
const (
connectionPackageSchemaVersion = 1
connectionPackageSchemaVersionV2 = 2
connectionPackageKind = "gonavi_connection_package"
connectionPackageCipher = "AES-256-GCM"
connectionPackageKDFName = "Argon2id"
connectionPackageKDFNameV2 = "a2id"
connectionPackageExtension = ".gonavi-conn"
connectionPackageProtectionAppManaged = 1
connectionPackageProtectionPasswordProtected = 2
connectionPackageKDFDefaultMemoryKiB = 65536
connectionPackageKDFDefaultTimeCost = 3
connectionPackageKDFDefaultParallelism = 4
connectionPackageKDFMaxMemoryKiB = 262144
connectionPackageKDFMaxTimeCost = 10
connectionPackageKDFMaxParallelism = 16
connectionPackageMaxCiphertextBytes = 16 * 1024 * 1024
connectionPackageMaxPayloadBase64Bytes = ((connectionPackageMaxCiphertextBytes + 2) / 3) * 4
connectionImportMaxFileBytes = connectionPackageMaxPayloadBase64Bytes + (1 * 1024 * 1024)
)
var (
errConnectionPackagePasswordRequired = errors.New("恢复包密码不能为空")
errConnectionPackageDecryptFailed = errors.New("文件密码错误或文件已损坏")
errConnectionPackageUnsupported = errors.New("不支持的连接恢复包格式")
errConnectionImportFileTooLarge = errors.New("连接导入文件过大")
errConnectionPackagePayloadTooLarge = errors.New("连接恢复包过大")
errConnectionPackageNotImplemented = errors.New("connection package not implemented")
)
type connectionPackageFile struct {
SchemaVersion int `json:"schemaVersion"`
Kind string `json:"kind"`
Cipher string `json:"cipher"`
KDF connectionPackageKDFSpec `json:"kdf"`
Nonce string `json:"nonce"`
Payload string `json:"payload"`
}
type connectionPackageKDFSpec struct {
Name string `json:"name"`
MemoryKiB uint32 `json:"memoryKiB"`
TimeCost uint32 `json:"timeCost"`
Parallelism uint8 `json:"parallelism"`
Salt string `json:"salt"`
}
type connectionPackageFileV2 struct {
V int `json:"v"`
Kind string `json:"kind"`
P int `json:"p"`
ExportedAt string `json:"exportedAt,omitempty"`
Connections []connectionPackageItem `json:"connections"`
}
type connectionPackageFileV2Protected struct {
V int `json:"v"`
Kind string `json:"kind"`
P int `json:"p"`
KDF connectionPackageKDFSpecV2 `json:"kdf"`
NC string `json:"nc"`
D string `json:"d"`
}
type connectionPackageKDFSpecV2 struct {
N string `json:"n"`
M uint32 `json:"m"`
T uint32 `json:"t"`
L uint8 `json:"l"`
S string `json:"s"`
}
type connectionPackagePayload struct {
ExportedAt string `json:"exportedAt,omitempty"`
Connections []connectionPackageItem `json:"connections"`
}
type connectionPackageItem struct {
ID string `json:"id"`
Name string `json:"name"`
IncludeDatabases []string `json:"includeDatabases,omitempty"`
IncludeRedisDatabases []int `json:"includeRedisDatabases,omitempty"`
IconType string `json:"iconType,omitempty"`
IconColor string `json:"iconColor,omitempty"`
Config connection.ConnectionConfig `json:"config"`
Secrets connectionSecretBundle `json:"secrets,omitempty"`
}
func (i connectionPackageItem) MarshalJSON() ([]byte, error) {
type connectionPackageItemJSON struct {
ID string `json:"id"`
Name string `json:"name"`
IncludeDatabases []string `json:"includeDatabases,omitempty"`
IncludeRedisDatabases []int `json:"includeRedisDatabases,omitempty"`
IconType string `json:"iconType,omitempty"`
IconColor string `json:"iconColor,omitempty"`
Config connection.ConnectionConfig `json:"config"`
Secrets *connectionSecretBundle `json:"secrets,omitempty"`
}
item := connectionPackageItemJSON{
ID: i.ID,
Name: i.Name,
IncludeDatabases: i.IncludeDatabases,
IncludeRedisDatabases: i.IncludeRedisDatabases,
IconType: i.IconType,
IconColor: i.IconColor,
Config: i.Config,
}
if i.Secrets.hasAny() {
secrets := i.Secrets
item.Secrets = &secrets
}
return json.Marshal(item)
}
type ConnectionExportOptions struct {
IncludeSecrets bool `json:"includeSecrets"`
FilePassword string `json:"filePassword,omitempty"`
}
func defaultConnectionPackageKDFSpec() connectionPackageKDFSpec {
return connectionPackageKDFSpec{
Name: connectionPackageKDFName,
MemoryKiB: connectionPackageKDFDefaultMemoryKiB,
TimeCost: connectionPackageKDFDefaultTimeCost,
Parallelism: connectionPackageKDFDefaultParallelism,
}
}
func defaultConnectionPackageKDFSpecV2() connectionPackageKDFSpecV2 {
return connectionPackageKDFSpecV2{
N: connectionPackageKDFNameV2,
M: connectionPackageKDFDefaultMemoryKiB,
T: connectionPackageKDFDefaultTimeCost,
L: connectionPackageKDFDefaultParallelism,
}
}
func normalizeConnectionPackagePassword(password string) string {
return strings.TrimSpace(password)
}

View File

@@ -1,9 +1,13 @@
package app
import (
"errors"
"fmt"
"os"
"strings"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/secretstore"
)
func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (connection.ConnectionConfig, error) {
@@ -14,7 +18,10 @@ func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (conn
repo := newSavedConnectionRepository(a.configDir, a.secretStore)
view, err := repo.Find(config.ID)
if err != nil {
return config, err
if shouldFallbackToInlineConnectionSecrets(config, err) {
return config, nil
}
return config, normalizeConnectionSecretResolutionError(config, err)
}
base := config
@@ -23,13 +30,86 @@ func (a *App) resolveConnectionSecrets(config connection.ConnectionConfig) (conn
}
bundle, err := repo.loadSecretBundle(view)
if err != nil {
return base, err
if shouldFallbackToInlineConnectionSecrets(config, err) {
return mergeInlineConnectionSecrets(base, config), nil
}
return base, normalizeConnectionSecretResolutionError(base, err)
}
resolved := mergeConnectionSecretBundleIntoConfig(base, bundle)
resolved.ID = view.ID
return resolved, nil
}
func shouldFallbackToInlineConnectionSecrets(config connection.ConnectionConfig, err error) bool {
if err == nil || !connectionConfigCarriesInlineSecrets(config) || secretstore.IsUnavailable(err) {
return false
}
if errors.Is(err, os.ErrNotExist) {
return true
}
lower := strings.ToLower(strings.TrimSpace(err.Error()))
return strings.Contains(lower, "saved connection not found:")
}
func connectionConfigCarriesInlineSecrets(config connection.ConnectionConfig) bool {
return strings.TrimSpace(config.Password) != "" ||
strings.TrimSpace(config.SSH.Password) != "" ||
strings.TrimSpace(config.Proxy.Password) != "" ||
strings.TrimSpace(config.HTTPTunnel.Password) != "" ||
strings.TrimSpace(config.MySQLReplicaPassword) != "" ||
strings.TrimSpace(config.MongoReplicaPassword) != "" ||
strings.TrimSpace(config.URI) != "" ||
strings.TrimSpace(config.DSN) != ""
}
func mergeInlineConnectionSecrets(base connection.ConnectionConfig, inline connection.ConnectionConfig) connection.ConnectionConfig {
merged := base
if strings.TrimSpace(inline.Password) != "" {
merged.Password = inline.Password
}
if strings.TrimSpace(inline.SSH.Password) != "" {
merged.SSH.Password = inline.SSH.Password
}
if strings.TrimSpace(inline.Proxy.Password) != "" {
merged.Proxy.Password = inline.Proxy.Password
}
if strings.TrimSpace(inline.HTTPTunnel.Password) != "" {
merged.HTTPTunnel.Password = inline.HTTPTunnel.Password
}
if strings.TrimSpace(inline.MySQLReplicaPassword) != "" {
merged.MySQLReplicaPassword = inline.MySQLReplicaPassword
}
if strings.TrimSpace(inline.MongoReplicaPassword) != "" {
merged.MongoReplicaPassword = inline.MongoReplicaPassword
}
if strings.TrimSpace(inline.URI) != "" {
merged.URI = inline.URI
}
if strings.TrimSpace(inline.DSN) != "" {
merged.DSN = inline.DSN
}
return merged
}
func normalizeConnectionSecretResolutionError(config connection.ConnectionConfig, err error) error {
if err == nil {
return nil
}
lower := strings.ToLower(strings.TrimSpace(err.Error()))
switch {
case strings.Contains(lower, "saved connection not found:"):
if connectionMetadataLooksEmpty(config) {
return fmt.Errorf("未找到已保存连接,可能已被删除,请刷新后重试")
}
return fmt.Errorf("未找到当前连接对应的已保存密文,请重新填写密码并保存后再试")
case strings.Contains(lower, "secret store unavailable"):
return fmt.Errorf("系统密文存储当前不可用,请检查系统钥匙串或凭据管理器后再试")
default:
return err
}
}
func connectionMetadataLooksEmpty(config connection.ConnectionConfig) bool {
return strings.TrimSpace(config.Type) == "" &&
strings.TrimSpace(config.Host) == "" &&

View File

@@ -1,6 +1,7 @@
package app
import (
"strings"
"testing"
"GoNavi-Wails/internal/connection"
@@ -40,3 +41,98 @@ func TestResolveConnectionConfigByIDLoadsSecretsFromStore(t *testing.T) {
t.Fatalf("expected restored DSN, got %q", resolved.DSN)
}
}
func TestResolveConnectionSecretsReturnsFriendlyMessageWhenSavedSecretSourceIsMissing(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
_, err := app.resolveConnectionSecrets(connection.ConnectionConfig{
ID: "conn-missing",
Type: "postgres",
Host: "db.local",
Port: 5432,
User: "postgres",
})
if err == nil {
t.Fatal("expected resolveConnectionSecrets to fail for a missing saved connection")
}
if !strings.Contains(err.Error(), "已保存密文") {
t.Fatalf("expected a secret-specific error message, got %q", err.Error())
}
}
func TestResolveConnectionSecretsFallsBackToInlineSecretsWhenSavedConnectionIsMissing(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
input := connection.ConnectionConfig{
ID: "legacy-inline",
Type: "postgres",
Host: "db.local",
Port: 5432,
User: "postgres",
Password: "inline-secret",
DSN: "postgres://postgres:inline-secret@db.local/app",
}
resolved, err := app.resolveConnectionSecrets(input)
if err != nil {
t.Fatalf("expected inline secrets to be used as fallback, got error: %v", err)
}
if resolved.Password != "inline-secret" {
t.Fatalf("expected inline password to be preserved, got %q", resolved.Password)
}
if resolved.DSN != "postgres://postgres:inline-secret@db.local/app" {
t.Fatalf("expected inline DSN to be preserved, got %q", resolved.DSN)
}
}
func TestResolveConnectionSecretsFallsBackToInlineSecretsWhenSavedSecretBundleIsMissing(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
view, err := app.SaveConnection(connection.SavedConnectionInput{
ID: "conn-inline-fallback",
Name: "Primary",
Config: connection.ConnectionConfig{
ID: "conn-inline-fallback",
Type: "postgres",
Host: "db.local",
Port: 5432,
User: "postgres",
Password: "stored-secret",
DSN: "postgres://postgres:stored-secret@db.local/app",
},
})
if err != nil {
t.Fatalf("SaveConnection returned error: %v", err)
}
if view.SecretRef == "" {
t.Fatal("expected saved connection to allocate a secret ref")
}
if err := store.Delete(view.SecretRef); err != nil {
t.Fatalf("Delete returned error: %v", err)
}
resolved, err := app.resolveConnectionSecrets(connection.ConnectionConfig{
ID: "conn-inline-fallback",
Type: "postgres",
Host: "db.local",
Port: 5432,
User: "postgres",
Password: "inline-secret",
DSN: "postgres://postgres:inline-secret@db.local/app",
})
if err != nil {
t.Fatalf("expected inline secrets to be used when secret bundle is missing, got error: %v", err)
}
if resolved.Password != "inline-secret" {
t.Fatalf("expected inline password to be preserved, got %q", resolved.Password)
}
if resolved.DSN != "postgres://postgres:inline-secret@db.local/app" {
t.Fatalf("expected inline DSN to be preserved, got %q", resolved.DSN)
}
}

5
internal/app/env.go Normal file
View File

@@ -0,0 +1,5 @@
package app
import "os"
var getenv = os.Getenv

View File

@@ -259,10 +259,30 @@ func (cr *countingReader) Read(p []byte) (int, error) {
return n, err
}
func readImportedConnectionConfigFile(path string) (string, error) {
info, err := os.Stat(path)
if err != nil {
return "", err
}
if info.Size() > connectionImportMaxFileBytes {
return "", errConnectionImportFileTooLarge
}
content, err := os.ReadFile(path)
if err != nil {
return "", err
}
return string(content), nil
}
func (a *App) ImportConfigFile() connection.QueryResult {
selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{
Title: "Select Config File",
Filters: []runtime.FileFilter{
{
DisplayName: "GoNavi Connection Package (*.gonavi-conn)",
Pattern: "*.gonavi-conn",
},
{
DisplayName: "JSON Files (*.json)",
Pattern: "*.json",
@@ -278,12 +298,52 @@ func (a *App) ImportConfigFile() connection.QueryResult {
return connection.QueryResult{Success: false, Message: "已取消"}
}
content, err := os.ReadFile(selection)
content, err := readImportedConnectionConfigFile(selection)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: string(content)}
return connection.QueryResult{Success: true, Data: content}
}
func (a *App) ExportConnectionsPackage(options ConnectionExportOptions) connection.QueryResult {
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
Title: "Export Connections",
DefaultFilename: "connections" + connectionPackageExtension,
Filters: []runtime.FileFilter{
{
DisplayName: "GoNavi Connection Package (*.gonavi-conn)",
Pattern: "*.gonavi-conn",
},
},
})
if err != nil || strings.TrimSpace(filename) == "" {
return connection.QueryResult{Success: false, Message: "已取消"}
}
filename = normalizeConnectionPackageExportFilename(filename)
content, err := a.buildExportedConnectionPackage(options)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
if len(content) > connectionImportMaxFileBytes {
return connection.QueryResult{Success: false, Message: errConnectionImportFileTooLarge.Error()}
}
if err := os.WriteFile(filename, content, 0o644); err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Message: "导出完成"}
}
func normalizeConnectionPackageExportFilename(filename string) string {
trimmed := strings.TrimSpace(filename)
if trimmed == "" {
return ""
}
if strings.EqualFold(filepath.Ext(trimmed), connectionPackageExtension) {
return trimmed
}
return trimmed + connectionPackageExtension
}
func (a *App) SelectSSHKeyFile(currentPath string) connection.QueryResult {

View File

@@ -0,0 +1,33 @@
package app
import (
"errors"
"os"
"path/filepath"
"testing"
)
func TestReadImportedConnectionConfigFileRejectsOversizedFiles(t *testing.T) {
for _, ext := range []string{connectionPackageExtension, ".json"} {
t.Run(ext, func(t *testing.T) {
path := filepath.Join(t.TempDir(), "connections"+ext)
file, err := os.Create(path)
if err != nil {
t.Fatalf("Create returned error: %v", err)
}
if err := file.Truncate(connectionImportMaxFileBytes + 1); err != nil {
file.Close()
t.Fatalf("Truncate returned error: %v", err)
}
if err := file.Close(); err != nil {
t.Fatalf("Close returned error: %v", err)
}
_, err = readImportedConnectionConfigFile(path)
if !errors.Is(err, errConnectionImportFileTooLarge) {
t.Fatalf("oversized import file should return errConnectionImportFileTooLarge, got: %v", err)
}
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"math"
"net/url"
"strconv"
"strings"
"sync"
@@ -19,12 +20,20 @@ import (
var (
redisCache = make(map[string]redis.RedisClient)
redisCacheMu sync.Mutex
newRedisClientFunc = redis.NewRedisClient
)
// getRedisClient gets or creates a Redis client from cache
func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisClient, error) {
effectiveConfig := applyGlobalProxyToConnection(config)
connectConfig, proxyErr := resolveDialConfigWithProxy(effectiveConfig)
resolvedConfig, err := a.resolveConnectionSecrets(config)
if err != nil {
wrapped := wrapConnectError(config, err)
logger.Error(wrapped, "Redis 密文解析失败:%s", formatRedisConnSummary(config))
return nil, wrapped
}
effectiveConfig := applyGlobalProxyToConnection(resolvedConfig)
connectConfig, proxyErr := resolveDialConfigWithProxyFunc(effectiveConfig)
if proxyErr != nil {
wrapped := wrapConnectError(effectiveConfig, proxyErr)
logger.Error(wrapped, "Redis 代理准备失败:%s", formatRedisConnSummary(effectiveConfig))
@@ -54,18 +63,78 @@ func (a *App) getRedisClient(config connection.ConnectionConfig) (redis.RedisCli
}
logger.Infof("创建 Redis 客户端实例缓存Key=%s", shortKey)
client := redis.NewRedisClient()
if err := client.Connect(connectConfig); err != nil {
wrapped := wrapConnectError(effectiveConfig, err)
logger.Error(wrapped, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey)
client, connectedConfig, connectErr := connectRedisClientWithLegacyRootFallback(connectConfig)
if connectErr != nil {
wrapped := wrapConnectError(connectedConfig, connectErr)
logger.Error(wrapped, "Redis 连接失败:%s 缓存Key=%s", formatRedisConnSummary(connectedConfig), shortKey)
return nil, wrapped
}
redisCache[key] = client
logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(effectiveConfig), shortKey)
logger.Infof("Redis 连接成功并写入缓存:%s 缓存Key=%s", formatRedisConnSummary(connectedConfig), shortKey)
return client, nil
}
func connectRedisClientWithLegacyRootFallback(config connection.ConnectionConfig) (redis.RedisClient, connection.ConnectionConfig, error) {
client := newRedisClientFunc()
if err := client.Connect(config); err == nil {
return client, config, nil
} else {
client.Close()
if !shouldRetryRedisWithClearedLegacyRoot(config, err) {
return nil, config, err
}
fallbackConfig := config
fallbackConfig.User = ""
logger.Warnf("Redis 使用用户名 root 认证失败,已按历史默认值回退为空用户名重试:%s", formatRedisConnSummary(config))
fallbackClient := newRedisClientFunc()
if retryErr := fallbackClient.Connect(fallbackConfig); retryErr != nil {
fallbackClient.Close()
return nil, fallbackConfig, retryErr
}
return fallbackClient, fallbackConfig, nil
}
}
func shouldRetryRedisWithClearedLegacyRoot(config connection.ConnectionConfig, err error) bool {
if err == nil || strings.ToLower(strings.TrimSpace(config.Type)) != "redis" {
return false
}
if strings.TrimSpace(config.User) != "root" {
return false
}
if _, ok := extractExplicitRedisUsername(config.URI); ok {
return false
}
lower := strings.ToLower(strings.TrimSpace(err.Error()))
return strings.Contains(lower, "wrongpass") ||
strings.Contains(lower, "invalid username-password pair") ||
strings.Contains(lower, "auth failed") ||
strings.Contains(lower, "wrong number of arguments for 'auth' command") ||
strings.Contains(lower, "authentication failed")
}
func extractExplicitRedisUsername(rawURI string) (string, bool) {
trimmed := strings.TrimSpace(rawURI)
if trimmed == "" {
return "", false
}
parsed, err := url.Parse(trimmed)
if err != nil || parsed.User == nil {
return "", false
}
username := strings.TrimSpace(parsed.User.Username())
if username == "" {
return "", false
}
return username, true
}
func getRedisClientCacheKey(config connection.ConnectionConfig) string {
normalized := normalizeCacheKeyConfig(config)
b, _ := json.Marshal(normalized)

View File

@@ -0,0 +1,428 @@
package app
import (
"errors"
"testing"
"GoNavi-Wails/internal/connection"
redislib "GoNavi-Wails/internal/redis"
)
type capturingRedisClient struct {
connectConfig connection.ConnectionConfig
}
func (c *capturingRedisClient) Connect(config connection.ConnectionConfig) error {
c.connectConfig = config
return nil
}
func (c *capturingRedisClient) Close() error { return nil }
func (c *capturingRedisClient) Ping() error { return nil }
func (c *capturingRedisClient) ScanKeys(pattern string, cursor uint64, count int64) (*redislib.RedisScanResult, error) {
return &redislib.RedisScanResult{}, nil
}
func (c *capturingRedisClient) GetKeyType(key string) (string, error) { return "", nil }
func (c *capturingRedisClient) GetTTL(key string) (int64, error) { return 0, nil }
func (c *capturingRedisClient) SetTTL(key string, ttl int64) error { return nil }
func (c *capturingRedisClient) DeleteKeys(keys []string) (int64, error) { return 0, nil }
func (c *capturingRedisClient) RenameKey(oldKey, newKey string) error { return nil }
func (c *capturingRedisClient) KeyExists(key string) (bool, error) { return false, nil }
func (c *capturingRedisClient) GetValue(key string) (*redislib.RedisValue, error) {
return &redislib.RedisValue{}, nil
}
func (c *capturingRedisClient) GetString(key string) (string, error) { return "", nil }
func (c *capturingRedisClient) SetString(key, value string, ttl int64) error { return nil }
func (c *capturingRedisClient) GetHash(key string) (map[string]string, error) { return map[string]string{}, nil }
func (c *capturingRedisClient) SetHashField(key, field, value string) error { return nil }
func (c *capturingRedisClient) DeleteHashField(key string, fields ...string) error { return nil }
func (c *capturingRedisClient) GetList(key string, start, stop int64) ([]string, error) { return nil, nil }
func (c *capturingRedisClient) ListPush(key string, values ...string) error { return nil }
func (c *capturingRedisClient) ListSet(key string, index int64, value string) error { return nil }
func (c *capturingRedisClient) GetSet(key string) ([]string, error) { return nil, nil }
func (c *capturingRedisClient) SetAdd(key string, members ...string) error { return nil }
func (c *capturingRedisClient) SetRemove(key string, members ...string) error { return nil }
func (c *capturingRedisClient) GetZSet(key string, start, stop int64) ([]redislib.ZSetMember, error) {
return nil, nil
}
func (c *capturingRedisClient) ZSetAdd(key string, members ...redislib.ZSetMember) error { return nil }
func (c *capturingRedisClient) ZSetRemove(key string, members ...string) error { return nil }
func (c *capturingRedisClient) GetStream(key, start, stop string, count int64) ([]redislib.StreamEntry, error) {
return nil, nil
}
func (c *capturingRedisClient) StreamAdd(key string, fields map[string]string, id string) (string, error) {
return "", nil
}
func (c *capturingRedisClient) StreamDelete(key string, ids ...string) (int64, error) { return 0, nil }
func (c *capturingRedisClient) ExecuteCommand(args []string) (interface{}, error) { return nil, nil }
func (c *capturingRedisClient) GetServerInfo() (map[string]string, error) { return map[string]string{}, nil }
func (c *capturingRedisClient) GetDatabases() ([]redislib.RedisDBInfo, error) { return nil, nil }
func (c *capturingRedisClient) SelectDB(index int) error { return nil }
func (c *capturingRedisClient) GetCurrentDB() int { return 0 }
func (c *capturingRedisClient) FlushDB() error { return nil }
type scriptedRedisClient struct {
capturingRedisClient
connectErr error
connectCalls *[]connection.ConnectionConfig
}
func (c *scriptedRedisClient) Connect(config connection.ConnectionConfig) error {
c.connectConfig = config
if c.connectCalls != nil {
*c.connectCalls = append(*c.connectCalls, config)
}
return c.connectErr
}
func TestRedisConnectResolvesSavedSecretsByConnectionID(t *testing.T) {
testCases := []struct {
name string
savedConfig connection.ConnectionConfig
runtimeConfig connection.ConnectionConfig
assertResolved func(t *testing.T, got connection.ConnectionConfig)
}{
{
name: "redis and ssh secrets",
savedConfig: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
Password: "redis-secret",
UseSSH: true,
SSH: connection.SSHConfig{
Host: "ssh.local",
Port: 22,
User: "ops",
Password: "ssh-secret",
},
},
runtimeConfig: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
UseSSH: true,
SSH: connection.SSHConfig{
Host: "ssh.local",
Port: 22,
User: "ops",
},
},
assertResolved: func(t *testing.T, got connection.ConnectionConfig) {
t.Helper()
if got.Password != "redis-secret" {
t.Fatalf("expected RedisConnect to resolve saved Redis password, got %q", got.Password)
}
if got.SSH.Password != "ssh-secret" {
t.Fatalf("expected RedisConnect to resolve saved SSH password, got %q", got.SSH.Password)
}
},
},
{
name: "proxy secret",
savedConfig: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
Password: "redis-secret",
UseProxy: true,
Proxy: connection.ProxyConfig{
Type: "http",
Host: "proxy.local",
Port: 8080,
User: "proxy-user",
Password: "proxy-secret",
},
},
runtimeConfig: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
UseProxy: true,
Proxy: connection.ProxyConfig{
Type: "http",
Host: "proxy.local",
Port: 8080,
User: "proxy-user",
},
},
assertResolved: func(t *testing.T, got connection.ConnectionConfig) {
t.Helper()
if got.Password != "redis-secret" {
t.Fatalf("expected RedisConnect to resolve saved Redis password, got %q", got.Password)
}
if got.Proxy.Password != "proxy-secret" {
t.Fatalf("expected RedisConnect to resolve saved proxy password, got %q", got.Proxy.Password)
}
},
},
{
name: "http tunnel secret",
savedConfig: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
Password: "redis-secret",
UseHTTPTunnel: true,
HTTPTunnel: connection.HTTPTunnelConfig{
Host: "tunnel.local",
Port: 8443,
User: "tunnel-user",
Password: "tunnel-secret",
},
},
runtimeConfig: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
UseHTTPTunnel: true,
HTTPTunnel: connection.HTTPTunnelConfig{
Host: "tunnel.local",
Port: 8443,
User: "tunnel-user",
},
},
assertResolved: func(t *testing.T, got connection.ConnectionConfig) {
t.Helper()
if got.Password != "redis-secret" {
t.Fatalf("expected RedisConnect to resolve saved Redis password, got %q", got.Password)
}
if got.HTTPTunnel.Password != "tunnel-secret" {
t.Fatalf("expected RedisConnect to resolve saved HTTP tunnel password, got %q", got.HTTPTunnel.Password)
}
},
},
{
name: "explicit redis username from uri is preserved even when it is root",
savedConfig: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
User: "root",
Password: "redis-secret",
URI: "redis://root:redis-secret@redis.local:6379/0",
},
runtimeConfig: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
User: "root",
},
assertResolved: func(t *testing.T, got connection.ConnectionConfig) {
t.Helper()
if got.User != "root" {
t.Fatalf("expected RedisConnect to preserve explicit uri user root, got %q", got.User)
}
if got.URI != "redis://root:redis-secret@redis.local:6379/0" {
t.Fatalf("expected RedisConnect to restore saved redis uri, got %q", got.URI)
}
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
_, err := app.SaveConnection(connection.SavedConnectionInput{
ID: "redis-1",
Name: "Redis Saved",
Config: testCase.savedConfig,
})
if err != nil {
t.Fatalf("SaveConnection returned error: %v", err)
}
CloseAllRedisClients()
client := &capturingRedisClient{}
originalNewRedisClientFunc := newRedisClientFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
defer func() {
newRedisClientFunc = originalNewRedisClientFunc
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
CloseAllRedisClients()
}()
newRedisClientFunc = func() redislib.RedisClient {
return client
}
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
return raw, nil
}
result := app.RedisConnect(testCase.runtimeConfig)
if !result.Success {
t.Fatalf("RedisConnect returned failure: %+v", result)
}
testCase.assertResolved(t, client.connectConfig)
})
}
}
func TestRedisConnectPreservesExplicitRootUserWithoutURIWhenConnectSucceeds(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
_, err := app.SaveConnection(connection.SavedConnectionInput{
ID: "redis-1",
Name: "Redis Saved",
Config: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
User: "root",
Password: "redis-secret",
},
})
if err != nil {
t.Fatalf("SaveConnection returned error: %v", err)
}
CloseAllRedisClients()
connectCalls := make([]connection.ConnectionConfig, 0, 1)
client := &scriptedRedisClient{connectCalls: &connectCalls}
originalNewRedisClientFunc := newRedisClientFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
defer func() {
newRedisClientFunc = originalNewRedisClientFunc
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
CloseAllRedisClients()
}()
newRedisClientFunc = func() redislib.RedisClient {
return client
}
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
return raw, nil
}
result := app.RedisConnect(connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
User: "root",
})
if !result.Success {
t.Fatalf("RedisConnect returned failure: %+v", result)
}
if len(connectCalls) != 1 {
t.Fatalf("expected exactly one Redis connect attempt, got %d", len(connectCalls))
}
if connectCalls[0].User != "root" {
t.Fatalf("expected RedisConnect to preserve explicit root user when connect succeeds, got %q", connectCalls[0].User)
}
}
func TestRedisConnectRetriesLegacyDefaultRootUserWithoutUsernameAfterAuthFailure(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
_, err := app.SaveConnection(connection.SavedConnectionInput{
ID: "redis-1",
Name: "Redis Saved",
Config: connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
User: "root",
Password: "redis-secret",
},
})
if err != nil {
t.Fatalf("SaveConnection returned error: %v", err)
}
CloseAllRedisClients()
connectCalls := make([]connection.ConnectionConfig, 0, 2)
clients := []redislib.RedisClient{
&scriptedRedisClient{
connectErr: errors.New("WRONGPASS invalid username-password pair"),
connectCalls: &connectCalls,
},
&scriptedRedisClient{
connectCalls: &connectCalls,
},
}
clientIndex := 0
originalNewRedisClientFunc := newRedisClientFunc
originalResolveDialConfigWithProxyFunc := resolveDialConfigWithProxyFunc
defer func() {
newRedisClientFunc = originalNewRedisClientFunc
resolveDialConfigWithProxyFunc = originalResolveDialConfigWithProxyFunc
CloseAllRedisClients()
}()
newRedisClientFunc = func() redislib.RedisClient {
if clientIndex >= len(clients) {
t.Fatalf("unexpected Redis client allocation #%d", clientIndex+1)
}
client := clients[clientIndex]
clientIndex++
return client
}
resolveDialConfigWithProxyFunc = func(raw connection.ConnectionConfig) (connection.ConnectionConfig, error) {
return raw, nil
}
result := app.RedisConnect(connection.ConnectionConfig{
ID: "redis-1",
Type: "redis",
Host: "redis.local",
Port: 6379,
User: "root",
})
if !result.Success {
t.Fatalf("RedisConnect returned failure after fallback: %+v", result)
}
if len(connectCalls) != 2 {
t.Fatalf("expected RedisConnect to retry exactly once after auth failure, got %d attempts", len(connectCalls))
}
if connectCalls[0].User != "root" {
t.Fatalf("expected first Redis connect attempt to keep root user, got %q", connectCalls[0].User)
}
if connectCalls[1].User != "" {
t.Fatalf("expected fallback Redis connect attempt to clear legacy root user, got %q", connectCalls[1].User)
}
}

View File

@@ -1,6 +1,10 @@
package app
import "GoNavi-Wails/internal/connection"
import (
"strings"
"GoNavi-Wails/internal/connection"
)
func (a *App) savedConnectionRepository() *savedConnectionRepository {
return newSavedConnectionRepository(a.configDir, a.secretStore)
@@ -23,16 +27,20 @@ func (a *App) DuplicateConnection(id string) (connection.SavedConnectionView, er
}
func (a *App) ImportLegacyConnections(items []connection.LegacySavedConnection) ([]connection.SavedConnectionView, error) {
result := make([]connection.SavedConnectionView, 0, len(items))
repo := a.savedConnectionRepository()
inputs := make([]connection.SavedConnectionInput, 0, len(items))
for _, item := range items {
view, err := repo.Save(connection.SavedConnectionInput(item))
if err != nil {
return nil, err
}
result = append(result, view)
input := connection.SavedConnectionInput(item)
input.ClearPrimaryPassword = strings.TrimSpace(item.Config.Password) == ""
input.ClearSSHPassword = strings.TrimSpace(item.Config.SSH.Password) == ""
input.ClearProxyPassword = strings.TrimSpace(item.Config.Proxy.Password) == ""
input.ClearHTTPTunnelPassword = strings.TrimSpace(item.Config.HTTPTunnel.Password) == ""
input.ClearMySQLReplicaPassword = strings.TrimSpace(item.Config.MySQLReplicaPassword) == ""
input.ClearMongoReplicaPassword = strings.TrimSpace(item.Config.MongoReplicaPassword) == ""
input.ClearOpaqueURI = strings.TrimSpace(item.Config.URI) == ""
input.ClearOpaqueDSN = strings.TrimSpace(item.Config.DSN) == ""
inputs = append(inputs, input)
}
return result, nil
return a.importSavedConnectionsAtomically(inputs)
}
func (a *App) SaveGlobalProxy(input connection.SaveGlobalProxyInput) (connection.GlobalProxyView, error) {

View File

@@ -185,3 +185,89 @@ func TestSaveGlobalProxyReturnsSecretlessView(t *testing.T) {
t.Fatal("expected hasPassword=true")
}
}
func TestImportLegacyConnectionsIsIdempotentForSameID(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
legacy := connection.LegacySavedConnection{
ID: "legacy-1",
Name: "Legacy",
Config: connection.ConnectionConfig{
ID: "legacy-1",
Type: "postgres",
Host: "db.local",
Port: 5432,
User: "postgres",
Password: "secret-1",
},
}
if _, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{legacy}); err != nil {
t.Fatalf("first ImportLegacyConnections returned error: %v", err)
}
if _, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{legacy}); err != nil {
t.Fatalf("second ImportLegacyConnections returned error: %v", err)
}
saved, err := app.GetSavedConnections()
if err != nil {
t.Fatalf("GetSavedConnections returned error: %v", err)
}
if len(saved) != 1 {
t.Fatalf("expected a single saved connection after repeated import, got %d", len(saved))
}
}
func TestImportLegacyConnectionsClearsExistingSecretWhenReimportOmitsPassword(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
if _, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{
{
ID: "legacy-1",
Name: "Legacy",
Config: connection.ConnectionConfig{
ID: "legacy-1",
Type: "postgres",
Host: "db.local",
Port: 5432,
User: "postgres",
Password: "secret-1",
},
},
}); err != nil {
t.Fatalf("initial ImportLegacyConnections returned error: %v", err)
}
if _, err := app.ImportLegacyConnections([]connection.LegacySavedConnection{
{
ID: "legacy-1",
Name: "Legacy Updated",
Config: connection.ConnectionConfig{
ID: "legacy-1",
Type: "postgres",
Host: "db.local",
Port: 5432,
User: "postgres",
},
},
}); err != nil {
t.Fatalf("update ImportLegacyConnections returned error: %v", err)
}
saved, err := app.GetSavedConnections()
if err != nil {
t.Fatalf("GetSavedConnections returned error: %v", err)
}
if len(saved) != 1 {
t.Fatalf("expected 1 saved connection, got %d", len(saved))
}
resolved, err := app.resolveConnectionSecrets(saved[0].Config)
if err != nil {
t.Fatalf("resolveConnectionSecrets returned error: %v", err)
}
if resolved.Password != "" {
t.Fatalf("expected missing import password to clear existing secret, got %q", resolved.Password)
}
}

View File

@@ -30,6 +30,12 @@ const (
updateDownloadProgressEvent = "update:download-progress"
)
var (
updateFetchLatestRelease = fetchLatestRelease
updateFetchReleaseSHA256 = fetchReleaseSHA256
updateLogCheckError = func(err error) { logger.Error(err, "检查更新失败") }
)
type updateState struct {
lastCheck *UpdateInfo
downloading bool
@@ -100,9 +106,19 @@ type githubAsset struct {
}
func (a *App) CheckForUpdates() connection.QueryResult {
return a.checkForUpdates(true)
}
func (a *App) CheckForUpdatesSilently() connection.QueryResult {
return a.checkForUpdates(false)
}
func (a *App) checkForUpdates(logFailure bool) connection.QueryResult {
info, err := fetchLatestUpdateInfo()
if err != nil {
logger.Error(err, "检查更新失败")
if logFailure {
updateLogCheckError(err)
}
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -359,7 +375,7 @@ func (a *App) downloadAndStageUpdate(info UpdateInfo) connection.QueryResult {
}
func fetchLatestUpdateInfo() (UpdateInfo, error) {
release, err := fetchLatestRelease()
release, err := updateFetchLatestRelease()
if err != nil {
return UpdateInfo{}, err
}
@@ -370,6 +386,17 @@ func fetchLatestUpdateInfo() (UpdateInfo, error) {
return UpdateInfo{}, errors.New("无法解析最新版本号")
}
hasUpdate := compareVersion(currentVersion, latestVersion) < 0
if !hasUpdate {
return UpdateInfo{
HasUpdate: false,
CurrentVersion: currentVersion,
LatestVersion: latestVersion,
ReleaseName: release.Name,
ReleaseNotesURL: release.HTMLURL,
}, nil
}
assetVersion := strings.TrimSpace(release.TagName)
if assetVersion == "" {
assetVersion = latestVersion
@@ -383,7 +410,7 @@ func fetchLatestUpdateInfo() (UpdateInfo, error) {
return UpdateInfo{}, err
}
hashMap, err := fetchReleaseSHA256(release.Assets)
hashMap, err := updateFetchReleaseSHA256(release.Assets)
if err != nil {
return UpdateInfo{}, err
}
@@ -391,9 +418,6 @@ func fetchLatestUpdateInfo() (UpdateInfo, error) {
if sha256Value == "" {
return UpdateInfo{}, errors.New("SHA256SUMS 未包含当前平台更新包")
}
hasUpdate := compareVersion(currentVersion, latestVersion) < 0
return UpdateInfo{
HasUpdate: hasUpdate,
CurrentVersion: currentVersion,
@@ -407,6 +431,30 @@ func fetchLatestUpdateInfo() (UpdateInfo, error) {
}, nil
}
func swapUpdateFetchLatestRelease(next func() (*githubRelease, error)) func() {
original := updateFetchLatestRelease
updateFetchLatestRelease = next
return func() {
updateFetchLatestRelease = original
}
}
func swapUpdateFetchReleaseSHA256(next func([]githubAsset) (map[string]string, error)) func() {
original := updateFetchReleaseSHA256
updateFetchReleaseSHA256 = next
return func() {
updateFetchReleaseSHA256 = original
}
}
func swapUpdateCheckErrorLogger(next func(error)) func() {
original := updateLogCheckError
updateLogCheckError = next
return func() {
updateLogCheckError = original
}
}
func getCurrentAuthor() string {
if env := strings.TrimSpace(os.Getenv("GONAVI_AUTHOR")); env != "" {
return env

View File

@@ -0,0 +1,160 @@
package app
import (
"errors"
stdRuntime "runtime"
"testing"
)
func TestFetchLatestUpdateInfoSkipsChecksumWhenCurrentVersionIsAlreadyLatest(t *testing.T) {
assetName, err := expectedAssetName(stdRuntime.GOOS, stdRuntime.GOARCH, "v0.6.5")
if err != nil {
t.Fatalf("expectedAssetName returned error: %v", err)
}
originalVersion := AppVersion
AppVersion = "0.6.5"
defer func() {
AppVersion = originalVersion
}()
releaseCalled := false
restoreRelease := swapUpdateFetchLatestRelease(func() (*githubRelease, error) {
releaseCalled = true
return &githubRelease{
TagName: "v0.6.5",
Name: "v0.6.5",
HTMLURL: "https://github.com/Syngnat/GoNavi/releases/tag/v0.6.5",
Assets: []githubAsset{
{
Name: assetName,
BrowserDownloadURL: "https://example.com/" + assetName,
Size: 1024,
},
},
}, nil
})
defer restoreRelease()
checksumCalled := false
restoreChecksum := swapUpdateFetchReleaseSHA256(func([]githubAsset) (map[string]string, error) {
checksumCalled = true
return nil, errors.New("checksum should not be fetched when no update is needed")
})
defer restoreChecksum()
info, err := fetchLatestUpdateInfo()
if err != nil {
t.Fatalf("fetchLatestUpdateInfo returned error: %v", err)
}
if !releaseCalled {
t.Fatal("expected latest release metadata to be fetched")
}
if checksumCalled {
t.Fatal("expected SHA256SUMS fetch to be skipped when current version is already latest")
}
if info.HasUpdate {
t.Fatalf("expected HasUpdate=false, got %#v", info)
}
if info.LatestVersion != "0.6.5" || info.CurrentVersion != "0.6.5" {
t.Fatalf("unexpected version info: %#v", info)
}
}
func TestFetchLatestUpdateInfoFetchesChecksumWhenUpdateIsAvailable(t *testing.T) {
assetName, err := expectedAssetName(stdRuntime.GOOS, stdRuntime.GOARCH, "v0.6.5")
if err != nil {
t.Fatalf("expectedAssetName returned error: %v", err)
}
originalVersion := AppVersion
AppVersion = "0.6.4"
defer func() {
AppVersion = originalVersion
}()
restoreRelease := swapUpdateFetchLatestRelease(func() (*githubRelease, error) {
return &githubRelease{
TagName: "v0.6.5",
Name: "v0.6.5",
HTMLURL: "https://github.com/Syngnat/GoNavi/releases/tag/v0.6.5",
Assets: []githubAsset{
{
Name: assetName,
BrowserDownloadURL: "https://example.com/" + assetName,
Size: 4096,
},
},
}, nil
})
defer restoreRelease()
checksumCalled := false
restoreChecksum := swapUpdateFetchReleaseSHA256(func([]githubAsset) (map[string]string, error) {
checksumCalled = true
return map[string]string{
assetName: "abc123",
}, nil
})
defer restoreChecksum()
info, err := fetchLatestUpdateInfo()
if err != nil {
t.Fatalf("fetchLatestUpdateInfo returned error: %v", err)
}
if !checksumCalled {
t.Fatal("expected SHA256SUMS fetch when update is available")
}
if !info.HasUpdate {
t.Fatalf("expected HasUpdate=true, got %#v", info)
}
if info.SHA256 != "abc123" || info.AssetName != assetName {
t.Fatalf("unexpected update info: %#v", info)
}
}
func TestCheckForUpdatesLogsFailuresForManualChecks(t *testing.T) {
app := &App{}
restoreRelease := swapUpdateFetchLatestRelease(func() (*githubRelease, error) {
return nil, errors.New("request timed out")
})
defer restoreRelease()
logged := 0
restoreLogger := swapUpdateCheckErrorLogger(func(error) {
logged++
})
defer restoreLogger()
result := app.CheckForUpdates()
if result.Success {
t.Fatalf("expected failure result, got %#v", result)
}
if logged != 1 {
t.Fatalf("expected manual check to log once, got %d", logged)
}
}
func TestCheckForUpdatesSilentlySkipsFailureLogs(t *testing.T) {
app := &App{}
restoreRelease := swapUpdateFetchLatestRelease(func() (*githubRelease, error) {
return nil, errors.New("request timed out")
})
defer restoreRelease()
logged := 0
restoreLogger := swapUpdateCheckErrorLogger(func(error) {
logged++
})
defer restoreLogger()
result := app.CheckForUpdatesSilently()
if result.Success {
t.Fatalf("expected failure result, got %#v", result)
}
if logged != 0 {
t.Fatalf("expected silent check to skip error logging, got %d", logged)
}
}

View File

@@ -0,0 +1,561 @@
package app
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"GoNavi-Wails/internal/ai"
aiservice "GoNavi-Wails/internal/ai/service"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/secretstore"
)
type securityUpdateNormalizedPreview struct {
SourceType SecurityUpdateSourceType `json:"sourceType"`
ConnectionIDs []string `json:"connectionIds"`
HasGlobalProxy bool `json:"hasGlobalProxy"`
AIProviderIDs []string `json:"aiProviderIds"`
AIProvidersNeedingAttention []string `json:"aiProvidersNeedingAttention,omitempty"`
}
func (a *App) GetSecurityUpdateStatus() (SecurityUpdateStatus, error) {
a.updateMu.Lock()
defer a.updateMu.Unlock()
repo := newSecurityUpdateStateRepository(a.configDir)
status, err := repo.LoadMarker()
if err != nil {
if os.IsNotExist(err) {
inspection, inspectErr := aiservice.NewProviderConfigStore(a.configDir, a.secretStore).Inspect()
if inspectErr != nil {
return SecurityUpdateStatus{}, inspectErr
}
if len(inspection.ProvidersNeedingMigration) > 0 {
return buildSecurityUpdatePendingStatusFromInspection(inspection, SecurityUpdateOverallStatusPending), nil
}
return SecurityUpdateStatus{
SchemaVersion: securityUpdateSchemaVersion,
OverallStatus: SecurityUpdateOverallStatusNotDetected,
Summary: SecurityUpdateSummary{},
Issues: []SecurityUpdateIssue{},
}, nil
}
return SecurityUpdateStatus{}, err
}
return status, nil
}
func (a *App) StartSecurityUpdate(request StartSecurityUpdateRequest) (SecurityUpdateStatus, error) {
a.updateMu.Lock()
defer a.updateMu.Unlock()
repo := newSecurityUpdateStateRepository(a.configDir)
status, err := repo.StartRound(request)
if err != nil {
return SecurityUpdateStatus{}, err
}
return a.executeSecurityUpdateRound(repo, status, request.SourceType, request.RawPayload)
}
func (a *App) RetrySecurityUpdateCurrentRound(request RetrySecurityUpdateRequest) (SecurityUpdateStatus, error) {
a.updateMu.Lock()
defer a.updateMu.Unlock()
repo := newSecurityUpdateStateRepository(a.configDir)
status, err := repo.RetryRound(request)
if err != nil {
return SecurityUpdateStatus{}, err
}
previewData, err := os.ReadFile(filepath.Join(status.BackupPath, securityUpdateNormalizedPreviewFileName))
if err != nil {
failed := newSecurityUpdateSystemFailureStatus(status, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err)
_ = repo.WriteResult(failed)
return failed, nil
}
var preview securityUpdateNormalizedPreview
if err := json.Unmarshal(previewData, &preview); err != nil {
failed := newSecurityUpdateSystemFailureStatus(status, SecurityUpdateIssueReasonCodeValidationFailed, err)
_ = repo.WriteResult(failed)
return failed, nil
}
finalStatus, execErr := a.validateSecurityUpdateCurrentAppRound(status, preview)
if execErr != nil {
_ = repo.WriteResult(finalStatus)
return finalStatus, nil
}
if err := repo.WriteResult(finalStatus); err != nil {
return SecurityUpdateStatus{}, err
}
return finalStatus, nil
}
func (a *App) RestartSecurityUpdate(request RestartSecurityUpdateRequest) (SecurityUpdateStatus, error) {
a.updateMu.Lock()
defer a.updateMu.Unlock()
repo := newSecurityUpdateStateRepository(a.configDir)
status, err := repo.RestartRound(request)
if err != nil {
return SecurityUpdateStatus{}, err
}
return a.executeSecurityUpdateRound(repo, status, request.SourceType, request.RawPayload)
}
func (a *App) DismissSecurityUpdateReminder() (SecurityUpdateStatus, error) {
a.updateMu.Lock()
defer a.updateMu.Unlock()
now := nowRFC3339()
repo := newSecurityUpdateStateRepository(a.configDir)
status, err := repo.LoadMarker()
if err != nil {
if !os.IsNotExist(err) {
return SecurityUpdateStatus{}, err
}
inspection, inspectErr := aiservice.NewProviderConfigStore(a.configDir, a.secretStore).Inspect()
if inspectErr != nil {
return SecurityUpdateStatus{}, inspectErr
}
if len(inspection.ProvidersNeedingMigration) > 0 {
status = buildSecurityUpdatePendingStatusFromInspection(inspection, SecurityUpdateOverallStatusPostponed)
} else {
status = SecurityUpdateStatus{
SchemaVersion: securityUpdateSchemaVersion,
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
Summary: SecurityUpdateSummary{},
Issues: []SecurityUpdateIssue{},
}
}
}
status.SchemaVersion = securityUpdateSchemaVersion
if strings.TrimSpace(string(status.SourceType)) == "" {
status.SourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig
}
if status.Issues == nil {
status.Issues = []SecurityUpdateIssue{}
}
if status.OverallStatus == SecurityUpdateOverallStatusCompleted || status.OverallStatus == SecurityUpdateOverallStatusRolledBack {
return status, nil
}
status.OverallStatus = SecurityUpdateOverallStatusPostponed
status.PostponedAt = now
status.UpdatedAt = now
if err := repo.WriteResult(status); err != nil {
return SecurityUpdateStatus{}, err
}
return repo.LoadMarker()
}
func (a *App) executeSecurityUpdateRound(repo *securityUpdateStateRepository, round SecurityUpdateStatus, sourceType SecurityUpdateSourceType, rawPayload string) (SecurityUpdateStatus, error) {
if strings.TrimSpace(string(sourceType)) == "" {
sourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig
}
if sourceType != SecurityUpdateSourceTypeCurrentAppSavedConfig {
failed := newSecurityUpdateSystemFailureStatus(round, SecurityUpdateIssueReasonCodeValidationFailed, fmt.Errorf("unsupported source type: %s", sourceType))
_ = repo.WriteResult(failed)
return failed, nil
}
source, rawParsed, err := parseSecurityUpdateCurrentAppSource(rawPayload)
if err != nil {
failed := newSecurityUpdateSystemFailureStatus(round, SecurityUpdateIssueReasonCodeValidationFailed, err)
_ = repo.WriteResult(failed)
return failed, nil
}
rollbackSnapshot, err := captureSecurityUpdateCurrentAppRollbackSnapshot(a, source)
if err != nil {
failed := newSecurityUpdateSystemFailureStatus(round, securityUpdateFailureReasonForError(err), err)
_ = repo.WriteResult(failed)
return failed, nil
}
if err := securityUpdateWriteJSONFile(filepath.Join(round.BackupPath, securityUpdateSourceCurrentAppFileName), rawParsed); err != nil {
return SecurityUpdateStatus{}, err
}
finalStatus, preview, execErr := a.runSecurityUpdateCurrentAppRound(round, source)
if previewErr := securityUpdateWriteJSONFile(filepath.Join(round.BackupPath, securityUpdateNormalizedPreviewFileName), preview); previewErr != nil {
return a.rollbackSecurityUpdatePersistenceFailure(repo, rollbackSnapshot, finalStatus, previewErr)
}
if execErr != nil {
if rollbackErr := rollbackSnapshot.restore(a); rollbackErr != nil {
failed := newSecurityUpdateSystemFailureStatus(finalStatus, securityUpdateFailureReasonForError(rollbackErr), rollbackErr)
_ = repo.WriteResult(failed)
return failed, nil
}
_ = repo.WriteResult(finalStatus)
return finalStatus, nil
}
if err := repo.WriteResult(finalStatus); err != nil {
return a.rollbackSecurityUpdatePersistenceFailure(repo, rollbackSnapshot, finalStatus, err)
}
return finalStatus, nil
}
func (a *App) rollbackSecurityUpdatePersistenceFailure(
repo *securityUpdateStateRepository,
rollbackSnapshot securityUpdateCurrentAppRollbackSnapshot,
base SecurityUpdateStatus,
cause error,
) (SecurityUpdateStatus, error) {
if rollbackErr := rollbackSnapshot.restore(a); rollbackErr != nil {
failed := newSecurityUpdateSystemFailureStatus(base, securityUpdateFailureReasonForError(rollbackErr), rollbackErr)
_ = repo.WriteResult(failed)
return failed, nil
}
failed := newSecurityUpdateSystemFailureStatus(base, SecurityUpdateIssueReasonCodeEnvironmentBlocked, cause)
_ = repo.WriteResult(failed)
return failed, nil
}
func (a *App) runSecurityUpdateCurrentAppRound(round SecurityUpdateStatus, source securityUpdateCurrentAppSource) (SecurityUpdateStatus, securityUpdateNormalizedPreview, error) {
finalStatus := newSecurityUpdateRoundBaseStatus(round, SecurityUpdateSourceTypeCurrentAppSavedConfig)
preview := securityUpdateNormalizedPreview{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
ConnectionIDs: make([]string, 0, len(source.Connections)),
HasGlobalProxy: source.GlobalProxy != nil,
AIProviderIDs: []string{},
}
connectionRepo := a.savedConnectionRepository()
for _, item := range source.Connections {
finalStatus.Summary.Total++
preview.ConnectionIDs = append(preview.ConnectionIDs, item.ID)
if _, err := connectionRepo.Save(connection.SavedConnectionInput(item)); err != nil {
failed := newSecurityUpdateSystemFailureStatus(finalStatus, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err)
return failed, preview, err
}
finalStatus.Summary.Updated++
}
if source.GlobalProxy != nil {
finalStatus.Summary.Total++
if _, err := a.saveGlobalProxy(connection.SaveGlobalProxyInput(*source.GlobalProxy)); err != nil {
failed := newSecurityUpdateSystemFailureStatus(finalStatus, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err)
return failed, preview, err
}
finalStatus.Summary.Updated++
}
providerSnapshot, err := aiservice.NewProviderConfigStore(a.configDir, a.secretStore).Load()
if err != nil {
failed := newSecurityUpdateSystemFailureStatus(finalStatus, securityUpdateFailureReasonForError(err), err)
return failed, preview, err
}
for _, provider := range providerSnapshot.Providers {
if !providerParticipatesInSecurityUpdate(provider) {
continue
}
preview.AIProviderIDs = append(preview.AIProviderIDs, provider.ID)
finalStatus.Summary.Total++
if provider.HasSecret && strings.TrimSpace(provider.APIKey) == "" {
finalStatus.OverallStatus = SecurityUpdateOverallStatusNeedsAttention
finalStatus.Summary.Pending++
finalStatus.Issues = append(finalStatus.Issues, SecurityUpdateIssue{
ID: "ai-provider-" + provider.ID,
Scope: SecurityUpdateIssueScopeAIProvider,
RefID: provider.ID,
Title: provider.Name,
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusNeedsAttention,
ReasonCode: SecurityUpdateIssueReasonCodeSecretMissing,
Action: SecurityUpdateIssueActionOpenAISettings,
Message: "AI 提供商配置需要补充后才能完成安全更新",
})
preview.AIProvidersNeedingAttention = append(preview.AIProvidersNeedingAttention, provider.ID)
continue
}
finalStatus.Summary.Updated++
}
if finalStatus.OverallStatus == SecurityUpdateOverallStatusCompleted {
finalStatus.CompletedAt = finalStatus.UpdatedAt
}
return finalStatus, preview, nil
}
func (a *App) validateSecurityUpdateCurrentAppRound(round SecurityUpdateStatus, preview securityUpdateNormalizedPreview) (SecurityUpdateStatus, error) {
if strings.TrimSpace(string(preview.SourceType)) == "" {
preview.SourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig
}
finalStatus := newSecurityUpdateRoundBaseStatus(round, preview.SourceType)
connectionRepo := a.savedConnectionRepository()
for _, id := range preview.ConnectionIDs {
finalStatus.Summary.Total++
savedConnection, err := connectionRepo.Find(id)
if err != nil {
markSecurityUpdateNeedsAttention(
&finalStatus,
SecurityUpdateIssue{
ID: "connection-" + id,
Scope: SecurityUpdateIssueScopeConnection,
RefID: id,
Title: id,
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusNeedsAttention,
ReasonCode: SecurityUpdateIssueReasonCodeValidationFailed,
Action: SecurityUpdateIssueActionOpenConnection,
Message: "连接配置已不存在或仍需重新保存后才能完成安全更新",
},
)
continue
}
if _, err := a.resolveConnectionSecrets(savedConnection.Config); err != nil {
if secretstore.IsUnavailable(err) {
failed := newSecurityUpdateSystemFailureStatus(finalStatus, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err)
return failed, err
}
reason := SecurityUpdateIssueReasonCodeValidationFailed
message := "连接配置仍需补充后才能完成安全更新"
if os.IsNotExist(err) {
reason = SecurityUpdateIssueReasonCodeSecretMissing
message = "连接密码已丢失,请重新保存后再继续"
}
markSecurityUpdateNeedsAttention(
&finalStatus,
SecurityUpdateIssue{
ID: "connection-" + id,
Scope: SecurityUpdateIssueScopeConnection,
RefID: id,
Title: savedConnection.Name,
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusNeedsAttention,
ReasonCode: reason,
Action: SecurityUpdateIssueActionOpenConnection,
Message: message,
},
)
continue
}
finalStatus.Summary.Updated++
}
if preview.HasGlobalProxy {
finalStatus.Summary.Total++
proxyView, err := a.loadStoredGlobalProxyView()
if err != nil {
if !os.IsNotExist(err) {
failed := newSecurityUpdateSystemFailureStatus(finalStatus, securityUpdateFailureReasonForError(err), err)
return failed, err
}
markSecurityUpdateNeedsAttention(
&finalStatus,
SecurityUpdateIssue{
ID: "global-proxy-default",
Scope: SecurityUpdateIssueScopeGlobalProxy,
Title: "全局代理",
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusNeedsAttention,
ReasonCode: SecurityUpdateIssueReasonCodeValidationFailed,
Action: SecurityUpdateIssueActionOpenProxySettings,
Message: "全局代理配置已不存在或仍需重新保存后才能完成安全更新",
},
)
} else {
if proxyView.HasPassword {
if _, err := a.loadGlobalProxySecretBundle(proxyView); err != nil {
if secretstore.IsUnavailable(err) {
failed := newSecurityUpdateSystemFailureStatus(finalStatus, SecurityUpdateIssueReasonCodeEnvironmentBlocked, err)
return failed, err
}
reason := SecurityUpdateIssueReasonCodeValidationFailed
message := "全局代理密码仍需补充后才能完成安全更新"
if os.IsNotExist(err) {
reason = SecurityUpdateIssueReasonCodeSecretMissing
message = "全局代理密码已丢失,请重新保存后再继续"
}
markSecurityUpdateNeedsAttention(
&finalStatus,
SecurityUpdateIssue{
ID: "global-proxy-default",
Scope: SecurityUpdateIssueScopeGlobalProxy,
Title: "全局代理",
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusNeedsAttention,
ReasonCode: reason,
Action: SecurityUpdateIssueActionOpenProxySettings,
Message: message,
},
)
goto validateProviders
}
}
finalStatus.Summary.Updated++
}
}
validateProviders:
providerSnapshot, err := aiservice.NewProviderConfigStore(a.configDir, a.secretStore).Load()
if err != nil {
failed := newSecurityUpdateSystemFailureStatus(finalStatus, securityUpdateFailureReasonForError(err), err)
return failed, err
}
providersByID := make(map[string]ai.ProviderConfig, len(providerSnapshot.Providers))
for _, provider := range providerSnapshot.Providers {
providersByID[provider.ID] = provider
}
for _, providerID := range preview.AIProviderIDs {
finalStatus.Summary.Total++
provider, ok := providersByID[providerID]
if !ok {
markSecurityUpdateNeedsAttention(
&finalStatus,
SecurityUpdateIssue{
ID: "ai-provider-" + providerID,
Scope: SecurityUpdateIssueScopeAIProvider,
RefID: providerID,
Title: providerID,
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusNeedsAttention,
ReasonCode: SecurityUpdateIssueReasonCodeValidationFailed,
Action: SecurityUpdateIssueActionOpenAISettings,
Message: "AI 提供商配置已不存在或仍需重新保存后才能完成安全更新",
},
)
continue
}
if provider.HasSecret && strings.TrimSpace(provider.APIKey) == "" {
markSecurityUpdateNeedsAttention(
&finalStatus,
SecurityUpdateIssue{
ID: "ai-provider-" + provider.ID,
Scope: SecurityUpdateIssueScopeAIProvider,
RefID: provider.ID,
Title: provider.Name,
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusNeedsAttention,
ReasonCode: SecurityUpdateIssueReasonCodeSecretMissing,
Action: SecurityUpdateIssueActionOpenAISettings,
Message: "AI 提供商配置需要补充后才能完成安全更新",
},
)
continue
}
finalStatus.Summary.Updated++
}
if finalStatus.OverallStatus == SecurityUpdateOverallStatusCompleted {
finalStatus.CompletedAt = finalStatus.UpdatedAt
}
return finalStatus, nil
}
func providerParticipatesInSecurityUpdate(provider ai.ProviderConfig) bool {
return provider.HasSecret || strings.TrimSpace(provider.APIKey) != ""
}
func buildSecurityUpdatePendingStatusFromInspection(
inspection aiservice.ProviderConfigStoreInspection,
overallStatus SecurityUpdateOverallStatus,
) SecurityUpdateStatus {
providersByID := make(map[string]ai.ProviderConfig, len(inspection.Snapshot.Providers))
for _, provider := range inspection.Snapshot.Providers {
providersByID[provider.ID] = provider
}
issues := make([]SecurityUpdateIssue, 0, len(inspection.ProvidersNeedingMigration))
for _, providerID := range inspection.ProvidersNeedingMigration {
provider := providersByID[providerID]
title := strings.TrimSpace(provider.Name)
if title == "" {
title = providerID
}
issues = append(issues, SecurityUpdateIssue{
ID: "ai-provider-" + providerID,
Scope: SecurityUpdateIssueScopeAIProvider,
RefID: providerID,
Title: title,
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusPending,
ReasonCode: SecurityUpdateIssueReasonCodeMigrationRequired,
Action: SecurityUpdateIssueActionOpenAISettings,
Message: "AI 提供商配置仍保存在当前应用配置中,完成安全更新后会迁入新的安全存储。",
})
}
return SecurityUpdateStatus{
SchemaVersion: securityUpdateSchemaVersion,
OverallStatus: overallStatus,
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
ReminderVisible: overallStatus == SecurityUpdateOverallStatusPending,
CanStart: overallStatus == SecurityUpdateOverallStatusPending || overallStatus == SecurityUpdateOverallStatusPostponed,
CanPostpone: overallStatus == SecurityUpdateOverallStatusPending || overallStatus == SecurityUpdateOverallStatusPostponed,
Summary: SecurityUpdateSummary{
Total: len(issues),
Pending: len(issues),
},
Issues: issues,
}
}
func newSecurityUpdateRoundBaseStatus(round SecurityUpdateStatus, sourceType SecurityUpdateSourceType) SecurityUpdateStatus {
if strings.TrimSpace(string(sourceType)) == "" {
sourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig
}
return SecurityUpdateStatus{
SchemaVersion: securityUpdateSchemaVersion,
MigrationID: round.MigrationID,
OverallStatus: SecurityUpdateOverallStatusCompleted,
SourceType: sourceType,
BackupAvailable: round.BackupAvailable || strings.TrimSpace(round.BackupPath) != "",
BackupPath: round.BackupPath,
StartedAt: round.StartedAt,
UpdatedAt: nowRFC3339(),
Summary: SecurityUpdateSummary{},
Issues: []SecurityUpdateIssue{},
}
}
func markSecurityUpdateNeedsAttention(status *SecurityUpdateStatus, issue SecurityUpdateIssue) {
status.OverallStatus = SecurityUpdateOverallStatusNeedsAttention
status.Summary.Pending++
status.Issues = append(status.Issues, issue)
}
func securityUpdateFailureReasonForError(err error) SecurityUpdateIssueReasonCode {
if secretstore.IsUnavailable(err) {
return SecurityUpdateIssueReasonCodeEnvironmentBlocked
}
return SecurityUpdateIssueReasonCodeValidationFailed
}
func newSecurityUpdateSystemFailureStatus(base SecurityUpdateStatus, reasonCode SecurityUpdateIssueReasonCode, err error) SecurityUpdateStatus {
status := base
status.SchemaVersion = securityUpdateSchemaVersion
status.OverallStatus = SecurityUpdateOverallStatusRolledBack
status.BackupAvailable = status.BackupAvailable || strings.TrimSpace(status.BackupPath) != ""
status.UpdatedAt = nowRFC3339()
status.CompletedAt = ""
status.LastError = err.Error()
status.Summary.Failed++
status.Issues = []SecurityUpdateIssue{
{
ID: "system-blocked",
Scope: SecurityUpdateIssueScopeSystem,
Title: "安全更新未完成",
Severity: SecurityUpdateIssueSeverityHigh,
Status: SecurityUpdateItemStatusFailed,
ReasonCode: reasonCode,
Action: SecurityUpdateIssueActionViewDetails,
Message: "当前环境无法完成本次安全更新,请稍后重试",
},
}
return status
}

View File

@@ -0,0 +1,942 @@
package app
import (
"errors"
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
aiservice "GoNavi-Wails/internal/ai/service"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/secretstore"
)
func TestStartSecurityUpdateCreatesBackupAndImportsSavedConfig(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"apiKey": "sk-ai-test",
"baseUrl": "https://api.openai.com/v1",
"headers": map[string]any{
"Authorization": "Bearer ai-test",
"X-Team": "platform",
},
},
},
})
status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusCompleted {
t.Fatalf("expected completed status, got %q", status.OverallStatus)
}
if status.MigrationID == "" {
t.Fatal("expected migration ID to be created")
}
if status.Summary.Total != 3 || status.Summary.Updated != 3 {
t.Fatalf("expected summary total=3 updated=3, got %#v", status.Summary)
}
savedConnections, err := app.GetSavedConnections()
if err != nil {
t.Fatalf("GetSavedConnections returned error: %v", err)
}
if len(savedConnections) != 1 {
t.Fatalf("expected 1 saved connection, got %d", len(savedConnections))
}
resolvedConnection, err := app.resolveConnectionSecrets(savedConnections[0].Config)
if err != nil {
t.Fatalf("resolveConnectionSecrets returned error: %v", err)
}
if resolvedConnection.Password != "postgres-secret" {
t.Fatalf("expected imported connection password, got %q", resolvedConnection.Password)
}
globalProxyView, err := app.loadStoredGlobalProxyView()
if err != nil {
t.Fatalf("loadStoredGlobalProxyView returned error: %v", err)
}
globalProxyBundle, err := app.loadGlobalProxySecretBundle(globalProxyView)
if err != nil {
t.Fatalf("loadGlobalProxySecretBundle returned error: %v", err)
}
if globalProxyBundle.Password != "proxy-secret" {
t.Fatalf("expected imported proxy password, got %q", globalProxyBundle.Password)
}
providerStore := aiservice.NewProviderConfigStore(app.configDir, app.secretStore)
providerSnapshot, err := providerStore.Load()
if err != nil {
t.Fatalf("provider store Load returned error: %v", err)
}
if len(providerSnapshot.Providers) != 1 {
t.Fatalf("expected 1 AI provider, got %d", len(providerSnapshot.Providers))
}
if providerSnapshot.Providers[0].APIKey != "sk-ai-test" {
t.Fatalf("expected migrated AI provider apiKey, got %q", providerSnapshot.Providers[0].APIKey)
}
for _, name := range []string{
securityUpdateManifestFileName,
securityUpdateSourceCurrentAppFileName,
securityUpdateNormalizedPreviewFileName,
securityUpdateResultFileName,
} {
if _, err := os.Stat(filepath.Join(status.BackupPath, name)); err != nil {
t.Fatalf("expected backup artifact %q: %v", name, err)
}
}
}
func TestGetSecurityUpdateStatusReturnsPendingWhenOnlyAIProviderNeedsSecurityUpdate(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"apiKey": "sk-ai-test",
"baseUrl": "https://api.openai.com/v1",
},
},
})
status, err := app.GetSecurityUpdateStatus()
if err != nil {
t.Fatalf("GetSecurityUpdateStatus returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusPending {
t.Fatalf("expected pending status, got %q", status.OverallStatus)
}
if !status.CanStart || !status.ReminderVisible {
t.Fatalf("expected pending status to expose start/reminder flags, got %#v", status)
}
}
func TestGetSecurityUpdateStatusIncludesPendingAIProviderIssuesBeforeStart(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"apiKey": "sk-ai-test",
"baseUrl": "https://api.openai.com/v1",
},
},
})
status, err := app.GetSecurityUpdateStatus()
if err != nil {
t.Fatalf("GetSecurityUpdateStatus returned error: %v", err)
}
if len(status.Issues) != 1 {
t.Fatalf("expected 1 pending issue, got %#v", status.Issues)
}
if status.Summary.Total != 1 || status.Summary.Pending != 1 {
t.Fatalf("expected summary total=1 pending=1, got %#v", status.Summary)
}
issue := status.Issues[0]
if issue.Scope != SecurityUpdateIssueScopeAIProvider {
t.Fatalf("expected AI provider issue scope, got %#v", issue)
}
if issue.RefID != "openai-main" || issue.Title != "OpenAI" {
t.Fatalf("expected provider issue to point at openai-main/OpenAI, got %#v", issue)
}
if issue.Status != SecurityUpdateItemStatusPending || issue.Action != SecurityUpdateIssueActionOpenAISettings {
t.Fatalf("expected pending AI settings issue, got %#v", issue)
}
}
func TestRetrySecurityUpdateCurrentRoundReusesMigrationIDAfterPendingIssueIsFixed(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
ref, err := secretstore.BuildRef("ai-provider", "openai-main")
if err != nil {
t.Fatalf("BuildRef returned error: %v", err)
}
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"hasSecret": true,
"secretRef": ref,
"baseUrl": "https://api.openai.com/v1",
},
},
})
initial, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if initial.OverallStatus != SecurityUpdateOverallStatusNeedsAttention {
t.Fatalf("expected needs_attention status, got %q", initial.OverallStatus)
}
if len(initial.Issues) != 1 || initial.Issues[0].Scope != SecurityUpdateIssueScopeAIProvider {
t.Fatalf("expected AI provider issue, got %#v", initial.Issues)
}
if err := store.Put(ref, []byte(`{"apiKey":"sk-fixed","sensitiveHeaders":{"Authorization":"Bearer fixed"}}`)); err != nil {
t.Fatalf("Put returned error: %v", err)
}
retried, err := app.RetrySecurityUpdateCurrentRound(RetrySecurityUpdateRequest{
MigrationID: initial.MigrationID,
})
if err != nil {
t.Fatalf("RetrySecurityUpdateCurrentRound returned error: %v", err)
}
if retried.MigrationID != initial.MigrationID {
t.Fatalf("expected retry to reuse migration ID %q, got %q", initial.MigrationID, retried.MigrationID)
}
if retried.OverallStatus != SecurityUpdateOverallStatusCompleted {
t.Fatalf("expected completed status after retry, got %q", retried.OverallStatus)
}
}
func TestRetrySecurityUpdateCurrentRoundDoesNotReimportBrokenLegacySourceAfterUserFix(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
ref, err := secretstore.BuildRef("ai-provider", "openai-main")
if err != nil {
t.Fatalf("BuildRef returned error: %v", err)
}
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"hasSecret": true,
"secretRef": ref,
"baseUrl": "https://api.openai.com/v1",
},
},
})
initial, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if initial.OverallStatus != SecurityUpdateOverallStatusNeedsAttention {
t.Fatalf("expected needs_attention status, got %q", initial.OverallStatus)
}
if _, err := app.SaveConnection(connection.SavedConnectionInput{
ID: "legacy-1",
Name: "Legacy Fixed",
Config: connection.ConnectionConfig{
ID: "legacy-1",
Type: "postgres",
Host: "db-fixed.local",
Port: 5432,
User: "postgres",
Password: "postgres-fixed",
},
}); err != nil {
t.Fatalf("SaveConnection returned error: %v", err)
}
if err := store.Put(ref, []byte(`{"apiKey":"sk-fixed"}`)); err != nil {
t.Fatalf("Put returned error: %v", err)
}
retried, err := app.RetrySecurityUpdateCurrentRound(RetrySecurityUpdateRequest{
MigrationID: initial.MigrationID,
})
if err != nil {
t.Fatalf("RetrySecurityUpdateCurrentRound returned error: %v", err)
}
if retried.OverallStatus != SecurityUpdateOverallStatusCompleted {
t.Fatalf("expected completed status after retry, got %q", retried.OverallStatus)
}
savedConnections, err := app.GetSavedConnections()
if err != nil {
t.Fatalf("GetSavedConnections returned error: %v", err)
}
if len(savedConnections) != 1 {
t.Fatalf("expected 1 saved connection, got %d", len(savedConnections))
}
resolvedConnection, err := app.resolveConnectionSecrets(savedConnections[0].Config)
if err != nil {
t.Fatalf("resolveConnectionSecrets returned error: %v", err)
}
if resolvedConnection.Host != "db-fixed.local" {
t.Fatalf("expected retry to keep user-fixed host, got %q", resolvedConnection.Host)
}
if resolvedConnection.Password != "postgres-fixed" {
t.Fatalf("expected retry to keep user-fixed password, got %q", resolvedConnection.Password)
}
}
func TestRestartSecurityUpdateCreatesNewMigrationID(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
initial, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
restarted, err := app.RestartSecurityUpdate(RestartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("RestartSecurityUpdate returned error: %v", err)
}
if restarted.MigrationID == initial.MigrationID {
t.Fatal("expected restart to create a new migration ID")
}
}
func TestDismissSecurityUpdateReminderMarksStatusPostponed(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
status, err := app.DismissSecurityUpdateReminder()
if err != nil {
t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusPostponed {
t.Fatalf("expected postponed status, got %q", status.OverallStatus)
}
if status.PostponedAt == "" {
t.Fatal("expected postponedAt to be recorded")
}
}
func TestDismissSecurityUpdateReminderKeepsCurrentRoundContext(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
ref, err := secretstore.BuildRef("ai-provider", "openai-main")
if err != nil {
t.Fatalf("BuildRef returned error: %v", err)
}
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"hasSecret": true,
"secretRef": ref,
"baseUrl": "https://api.openai.com/v1",
},
},
})
initial, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if initial.OverallStatus != SecurityUpdateOverallStatusNeedsAttention {
t.Fatalf("expected needs_attention status, got %q", initial.OverallStatus)
}
postponed, err := app.DismissSecurityUpdateReminder()
if err != nil {
t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err)
}
if postponed.OverallStatus != SecurityUpdateOverallStatusPostponed {
t.Fatalf("expected postponed status, got %q", postponed.OverallStatus)
}
if postponed.MigrationID != initial.MigrationID {
t.Fatalf("expected migration ID %q to be preserved, got %q", initial.MigrationID, postponed.MigrationID)
}
if postponed.BackupPath != initial.BackupPath {
t.Fatalf("expected backupPath %q to be preserved, got %q", initial.BackupPath, postponed.BackupPath)
}
if postponed.Summary != initial.Summary {
t.Fatalf("expected summary %#v to be preserved, got %#v", initial.Summary, postponed.Summary)
}
if len(postponed.Issues) != len(initial.Issues) {
t.Fatalf("expected %d issues to be preserved, got %#v", len(initial.Issues), postponed.Issues)
}
if postponed.PostponedAt == "" {
t.Fatal("expected postponedAt to be recorded")
}
}
func TestDismissSecurityUpdateReminderKeepsPendingAIProviderDetailsWithoutCurrentRound(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"apiKey": "sk-ai-test",
"baseUrl": "https://api.openai.com/v1",
},
},
})
status, err := app.DismissSecurityUpdateReminder()
if err != nil {
t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusPostponed {
t.Fatalf("expected postponed status, got %q", status.OverallStatus)
}
if status.Summary.Total != 1 || status.Summary.Pending != 1 {
t.Fatalf("expected summary total=1 pending=1, got %#v", status.Summary)
}
if len(status.Issues) != 1 {
t.Fatalf("expected 1 pending issue, got %#v", status.Issues)
}
if status.Issues[0].RefID != "openai-main" || status.Issues[0].Action != SecurityUpdateIssueActionOpenAISettings {
t.Fatalf("expected postponed issue to keep AI provider repair entry, got %#v", status.Issues[0])
}
}
func TestDismissSecurityUpdateReminderDoesNotOverrideCompletedRound(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
repo := newSecurityUpdateStateRepository(app.configDir)
completed := SecurityUpdateStatus{
SchemaVersion: securityUpdateSchemaVersion,
MigrationID: "migration-1",
OverallStatus: SecurityUpdateOverallStatusCompleted,
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
BackupPath: filepath.Join(app.configDir, securityUpdateBackupRootDirName, "migration-1"),
StartedAt: "2026-04-09T00:00:00Z",
UpdatedAt: "2026-04-09T00:05:00Z",
CompletedAt: "2026-04-09T00:05:00Z",
Summary: SecurityUpdateSummary{
Total: 1,
Updated: 1,
},
Issues: []SecurityUpdateIssue{},
}
if err := repo.WriteResult(completed); err != nil {
t.Fatalf("WriteResult returned error: %v", err)
}
status, err := app.DismissSecurityUpdateReminder()
if err != nil {
t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusCompleted {
t.Fatalf("expected completed status to be preserved, got %q", status.OverallStatus)
}
if status.MigrationID != completed.MigrationID {
t.Fatalf("expected migration ID %q to be preserved, got %q", completed.MigrationID, status.MigrationID)
}
if status.PostponedAt != "" {
t.Fatalf("expected completed round to keep empty postponedAt, got %q", status.PostponedAt)
}
}
func TestDismissSecurityUpdateReminderDoesNotOverrideRolledBackRound(t *testing.T) {
app := NewAppWithSecretStore(newFakeAppSecretStore())
app.configDir = t.TempDir()
repo := newSecurityUpdateStateRepository(app.configDir)
rolledBack := SecurityUpdateStatus{
SchemaVersion: securityUpdateSchemaVersion,
MigrationID: "migration-1",
OverallStatus: SecurityUpdateOverallStatusRolledBack,
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
BackupPath: filepath.Join(app.configDir, securityUpdateBackupRootDirName, "migration-1"),
StartedAt: "2026-04-09T00:00:00Z",
UpdatedAt: "2026-04-09T00:05:00Z",
Summary: SecurityUpdateSummary{
Total: 1,
Failed: 1,
},
Issues: []SecurityUpdateIssue{
{
ID: "system-blocked",
Scope: SecurityUpdateIssueScopeSystem,
Title: "安全更新未完成",
Severity: SecurityUpdateIssueSeverityHigh,
Status: SecurityUpdateItemStatusFailed,
ReasonCode: SecurityUpdateIssueReasonCodeEnvironmentBlocked,
Action: SecurityUpdateIssueActionViewDetails,
Message: "当前环境无法完成本次安全更新,请稍后重试",
},
},
}
if err := repo.WriteResult(rolledBack); err != nil {
t.Fatalf("WriteResult returned error: %v", err)
}
status, err := app.DismissSecurityUpdateReminder()
if err != nil {
t.Fatalf("DismissSecurityUpdateReminder returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected rolled_back status to be preserved, got %q", status.OverallStatus)
}
if status.MigrationID != rolledBack.MigrationID {
t.Fatalf("expected migration ID %q to be preserved, got %q", rolledBack.MigrationID, status.MigrationID)
}
if status.PostponedAt != "" {
t.Fatalf("expected rolled_back round to keep empty postponedAt, got %q", status.PostponedAt)
}
if len(status.Issues) != 1 || status.Issues[0].Scope != SecurityUpdateIssueScopeSystem {
t.Fatalf("expected rolled_back issue details to be preserved, got %#v", status.Issues)
}
}
func TestStartSecurityUpdateRollsBackWhenSecretStoreUnavailable(t *testing.T) {
app := NewAppWithSecretStore(nil)
app.configDir = t.TempDir()
status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected rolled_back status, got %q", status.OverallStatus)
}
if len(status.Issues) != 1 || status.Issues[0].Scope != SecurityUpdateIssueScopeSystem {
t.Fatalf("expected single system issue, got %#v", status.Issues)
}
}
func TestStartSecurityUpdateRollsBackWhenAIProviderSecretStoreUnavailable(t *testing.T) {
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("blocked"))
app.configDir = t.TempDir()
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"apiKey": "sk-ai-test",
"baseUrl": "https://api.openai.com/v1",
},
},
})
status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: "",
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected rolled_back status, got %q", status.OverallStatus)
}
if len(status.Issues) != 1 || status.Issues[0].Scope != SecurityUpdateIssueScopeSystem {
t.Fatalf("expected single system issue, got %#v", status.Issues)
}
}
func TestStartSecurityUpdateRollsBackPartialConnectionImportWhenLaterProviderStepFails(t *testing.T) {
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("blocked"))
app.configDir = t.TempDir()
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"apiKey": "sk-ai-test",
"baseUrl": "https://api.openai.com/v1",
},
},
})
payload, err := json.Marshal(map[string]any{
"state": map[string]any{
"connections": []map[string]any{
{
"id": "legacy-1",
"name": "Legacy",
"config": map[string]any{
"id": "legacy-1",
"type": "postgres",
"host": "db.local",
"port": 5432,
"user": "postgres",
},
},
},
},
})
if err != nil {
t.Fatalf("Marshal returned error: %v", err)
}
status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: string(payload),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected rolled_back status, got %q", status.OverallStatus)
}
savedConnections, err := app.GetSavedConnections()
if err != nil {
t.Fatalf("GetSavedConnections returned error: %v", err)
}
if len(savedConnections) != 0 {
t.Fatalf("expected rollback to leave no imported connections, got %#v", savedConnections)
}
}
func TestStartSecurityUpdateRollsBackExistingConnectionMetadataAndSecretWhenLaterProviderStepFails(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
if _, err := app.SaveConnection(connection.SavedConnectionInput{
ID: "legacy-1",
Name: "Existing",
Config: connection.ConnectionConfig{
ID: "legacy-1",
Type: "postgres",
Host: "db-old.local",
Port: 5432,
User: "postgres",
Password: "old-secret",
},
}); err != nil {
t.Fatalf("SaveConnection returned error: %v", err)
}
if err := os.WriteFile(filepath.Join(app.configDir, "ai_config.json"), []byte("{"), 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
payload, err := json.Marshal(map[string]any{
"state": map[string]any{
"connections": []map[string]any{
{
"id": "legacy-1",
"name": "Migrated",
"config": map[string]any{
"id": "legacy-1",
"type": "postgres",
"host": "db-new.local",
"port": 5432,
"user": "postgres",
"password": "new-secret",
},
},
},
},
})
if err != nil {
t.Fatalf("Marshal returned error: %v", err)
}
status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: string(payload),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected rolled_back status, got %q", status.OverallStatus)
}
savedConnections, err := app.GetSavedConnections()
if err != nil {
t.Fatalf("GetSavedConnections returned error: %v", err)
}
if len(savedConnections) != 1 {
t.Fatalf("expected existing connection to remain, got %#v", savedConnections)
}
if savedConnections[0].Name != "Existing" || savedConnections[0].Config.Host != "db-old.local" {
t.Fatalf("expected existing connection metadata to be restored, got %#v", savedConnections[0])
}
resolved, err := app.resolveConnectionSecrets(savedConnections[0].Config)
if err != nil {
t.Fatalf("resolveConnectionSecrets returned error: %v", err)
}
if resolved.Password != "old-secret" {
t.Fatalf("expected existing connection secret to be restored, got %q", resolved.Password)
}
}
func TestStartSecurityUpdateRollsBackExistingGlobalProxyWhenLaterProviderStepFails(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
if _, err := app.saveGlobalProxy(connection.SaveGlobalProxyInput{
Enabled: true,
Type: "http",
Host: "proxy-old.local",
Port: 8080,
User: "ops",
Password: "old-proxy-secret",
}); err != nil {
t.Fatalf("saveGlobalProxy returned error: %v", err)
}
if err := os.WriteFile(filepath.Join(app.configDir, "ai_config.json"), []byte("{"), 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
payload, err := json.Marshal(map[string]any{
"state": map[string]any{
"globalProxy": map[string]any{
"enabled": true,
"type": "http",
"host": "proxy-new.local",
"port": 8081,
"user": "ops-new",
"password": "new-proxy-secret",
},
},
})
if err != nil {
t.Fatalf("Marshal returned error: %v", err)
}
status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: string(payload),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected rolled_back status, got %q", status.OverallStatus)
}
view, err := app.loadStoredGlobalProxyView()
if err != nil {
t.Fatalf("loadStoredGlobalProxyView returned error: %v", err)
}
if view.Host != "proxy-old.local" || view.Port != 8080 || view.User != "ops" {
t.Fatalf("expected existing global proxy metadata to be restored, got %#v", view)
}
bundle, err := app.loadGlobalProxySecretBundle(view)
if err != nil {
t.Fatalf("loadGlobalProxySecretBundle returned error: %v", err)
}
if bundle.Password != "old-proxy-secret" {
t.Fatalf("expected existing global proxy secret to be restored, got %q", bundle.Password)
}
}
func TestStartSecurityUpdateRollsBackAllChangesWhenPreviewArtifactWriteFails(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"apiKey": "sk-ai-test",
"baseUrl": "https://api.openai.com/v1",
"headers": map[string]any{
"Authorization": "Bearer ai-test",
},
},
},
})
restoreWriteJSONFile := swapSecurityUpdateWriteJSONFile(func(path string, payload any) error {
if strings.HasSuffix(filepath.ToSlash(path), "/"+securityUpdateNormalizedPreviewFileName) {
return errors.New("forced preview write failure")
}
return writeJSONFile(path, payload)
})
defer restoreWriteJSONFile()
status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected rolled_back status, got %q", status.OverallStatus)
}
assertSecurityUpdateRollbackRestoredCurrentAppState(t, app, store)
}
func TestStartSecurityUpdateRollsBackAllChangesWhenFinalResultWriteFails(t *testing.T) {
store := newFakeAppSecretStore()
app := NewAppWithSecretStore(store)
app.configDir = t.TempDir()
writeLegacyAIProviderConfig(t, app.configDir, map[string]any{
"providers": []map[string]any{
{
"id": "openai-main",
"type": "openai",
"name": "OpenAI",
"apiKey": "sk-ai-test",
"baseUrl": "https://api.openai.com/v1",
"headers": map[string]any{
"Authorization": "Bearer ai-test",
},
},
},
})
resultWrites := 0
restoreWriteJSONFile := swapSecurityUpdateWriteJSONFile(func(path string, payload any) error {
if strings.HasSuffix(filepath.ToSlash(path), "/"+securityUpdateResultFileName) {
resultWrites++
if resultWrites == 2 {
return errors.New("forced result write failure")
}
}
return writeJSONFile(path, payload)
})
defer restoreWriteJSONFile()
status, err := app.StartSecurityUpdate(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
RawPayload: buildLegacySecurityUpdatePayload(),
})
if err != nil {
t.Fatalf("StartSecurityUpdate returned error: %v", err)
}
if status.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected rolled_back status, got %q", status.OverallStatus)
}
assertSecurityUpdateRollbackRestoredCurrentAppState(t, app, store)
}
func buildLegacySecurityUpdatePayload() string {
payload, _ := json.Marshal(map[string]any{
"state": map[string]any{
"connections": []map[string]any{
{
"id": "legacy-1",
"name": "Legacy",
"config": map[string]any{
"id": "legacy-1",
"type": "postgres",
"host": "db.local",
"port": 5432,
"user": "postgres",
"password": "postgres-secret",
},
},
},
"globalProxy": map[string]any{
"enabled": true,
"type": "http",
"host": "127.0.0.1",
"port": 8080,
"user": "ops",
"password": "proxy-secret",
},
},
})
return string(payload)
}
func writeLegacyAIProviderConfig(t *testing.T, configDir string, payload map[string]any) {
t.Helper()
data, err := json.MarshalIndent(payload, "", " ")
if err != nil {
t.Fatalf("MarshalIndent returned error: %v", err)
}
if err := os.WriteFile(filepath.Join(configDir, "ai_config.json"), data, 0o644); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}
}
func swapSecurityUpdateWriteJSONFile(next func(path string, payload any) error) func() {
original := securityUpdateWriteJSONFile
securityUpdateWriteJSONFile = next
return func() {
securityUpdateWriteJSONFile = original
}
}
func assertSecurityUpdateRollbackRestoredCurrentAppState(t *testing.T, app *App, store *fakeAppSecretStore) {
t.Helper()
savedConnections, err := app.GetSavedConnections()
if err != nil {
t.Fatalf("GetSavedConnections returned error: %v", err)
}
if len(savedConnections) != 0 {
t.Fatalf("expected rollback to leave no imported connections, got %#v", savedConnections)
}
if _, err := app.loadStoredGlobalProxyView(); !os.IsNotExist(err) {
t.Fatalf("expected rollback to remove imported global proxy, got err=%v", err)
}
inspection, err := aiservice.NewProviderConfigStore(app.configDir, app.secretStore).Inspect()
if err != nil {
t.Fatalf("Inspect returned error: %v", err)
}
if len(inspection.ProvidersNeedingMigration) != 1 || inspection.ProvidersNeedingMigration[0] != "openai-main" {
t.Fatalf("expected AI provider migration requirement to be restored, got %#v", inspection.ProvidersNeedingMigration)
}
ref, err := secretstore.BuildRef("ai-provider", "openai-main")
if err != nil {
t.Fatalf("BuildRef returned error: %v", err)
}
if _, err := store.Get(ref); !os.IsNotExist(err) {
t.Fatalf("expected rollback to remove migrated AI provider secret, got err=%v", err)
}
}

View File

@@ -0,0 +1,314 @@
package app
import (
"os"
"path/filepath"
"strings"
aiservice "GoNavi-Wails/internal/ai/service"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/secretstore"
)
const (
securityUpdateAIConfigFileName = "ai_config.json"
securityUpdateAIProviderSecretKind = "ai-provider"
)
type securityUpdateSecretSnapshot struct {
Exists bool
Payload []byte
}
type securityUpdateCurrentAppRollbackSnapshot struct {
connectionsFileExists bool
connectionsFileData []byte
connectionSecrets map[string]securityUpdateSecretSnapshot
connectionCleanupRefs []string
globalProxyFileExists bool
globalProxyFileData []byte
globalProxySecretRef string
globalProxySecret securityUpdateSecretSnapshot
globalProxyCleanupRef string
aiConfigFileExists bool
aiConfigFileData []byte
aiProviderSecrets map[string]securityUpdateSecretSnapshot
aiProviderCleanupRefs []string
}
func captureSecurityUpdateCurrentAppRollbackSnapshot(a *App, source securityUpdateCurrentAppSource) (securityUpdateCurrentAppRollbackSnapshot, error) {
snapshot := securityUpdateCurrentAppRollbackSnapshot{
connectionSecrets: make(map[string]securityUpdateSecretSnapshot),
aiProviderSecrets: make(map[string]securityUpdateSecretSnapshot),
}
configDir := strings.TrimSpace(a.configDir)
if configDir == "" {
configDir = resolveAppConfigDir()
}
connectionRepo := a.savedConnectionRepository()
connectionFileData, connectionFileExists, err := readOptionalFile(connectionRepo.connectionsPath())
if err != nil {
return snapshot, err
}
snapshot.connectionsFileExists = connectionFileExists
snapshot.connectionsFileData = connectionFileData
existingConnections, err := connectionRepo.load()
if err != nil {
return snapshot, err
}
existingConnectionsByID := make(map[string]connection.SavedConnectionView, len(existingConnections))
for _, item := range existingConnections {
existingConnectionsByID[item.ID] = item
}
connectionCleanupSet := make(map[string]struct{})
for _, item := range source.Connections {
connectionID := strings.TrimSpace(item.ID)
if connectionID == "" {
connectionID = strings.TrimSpace(item.Config.ID)
}
if connectionID == "" {
continue
}
defaultRef, refErr := secretstore.BuildRef(savedConnectionSecretKind, connectionID)
if refErr == nil {
connectionCleanupSet[defaultRef] = struct{}{}
}
existing, ok := existingConnectionsByID[connectionID]
if !ok || !savedConnectionViewHasSecrets(existing) {
continue
}
ref := strings.TrimSpace(existing.SecretRef)
if ref == "" {
ref = defaultRef
}
if ref == "" {
continue
}
secretSnapshot, captureErr := captureSecurityUpdateSecretSnapshot(a.secretStore, ref)
if captureErr != nil {
return snapshot, captureErr
}
snapshot.connectionSecrets[ref] = secretSnapshot
connectionCleanupSet[ref] = struct{}{}
}
snapshot.connectionCleanupRefs = make([]string, 0, len(connectionCleanupSet))
for ref := range connectionCleanupSet {
snapshot.connectionCleanupRefs = append(snapshot.connectionCleanupRefs, ref)
}
if source.GlobalProxy != nil {
globalProxyFileData, globalProxyFileExists, err := readOptionalFile(globalProxyMetadataPath(configDir))
if err != nil {
return snapshot, err
}
snapshot.globalProxyFileExists = globalProxyFileExists
snapshot.globalProxyFileData = globalProxyFileData
defaultProxyRef, refErr := secretstore.BuildRef(globalProxySecretKind, globalProxySecretID)
if refErr == nil {
snapshot.globalProxyCleanupRef = defaultProxyRef
}
existingProxy, err := a.loadStoredGlobalProxyView()
if err != nil {
if !os.IsNotExist(err) {
return snapshot, err
}
} else if existingProxy.HasPassword {
ref := strings.TrimSpace(existingProxy.SecretRef)
if ref == "" {
ref = snapshot.globalProxyCleanupRef
}
if ref != "" {
secretSnapshot, captureErr := captureSecurityUpdateSecretSnapshot(a.secretStore, ref)
if captureErr != nil {
return snapshot, captureErr
}
snapshot.globalProxySecretRef = ref
snapshot.globalProxySecret = secretSnapshot
}
}
}
aiConfigPath := filepath.Join(configDir, securityUpdateAIConfigFileName)
aiConfigFileData, aiConfigFileExists, err := readOptionalFile(aiConfigPath)
if err != nil {
return snapshot, err
}
snapshot.aiConfigFileExists = aiConfigFileExists
snapshot.aiConfigFileData = aiConfigFileData
inspection, err := aiservice.NewProviderConfigStore(configDir, a.secretStore).Inspect()
if err != nil {
return snapshot, err
}
aiProviderCleanupSet := make(map[string]struct{})
for _, provider := range inspection.Snapshot.Providers {
providerID := strings.TrimSpace(provider.ID)
if providerID == "" {
continue
}
ref := strings.TrimSpace(provider.SecretRef)
if ref == "" && (provider.HasSecret || strings.TrimSpace(provider.APIKey) != "" || len(provider.Headers) > 0) {
builtRef, refErr := secretstore.BuildRef(securityUpdateAIProviderSecretKind, providerID)
if refErr == nil {
ref = builtRef
}
}
if ref == "" {
continue
}
secretSnapshot, captureErr := captureSecurityUpdateSecretSnapshot(a.secretStore, ref)
if captureErr != nil {
return snapshot, captureErr
}
snapshot.aiProviderSecrets[ref] = secretSnapshot
aiProviderCleanupSet[ref] = struct{}{}
}
snapshot.aiProviderCleanupRefs = make([]string, 0, len(aiProviderCleanupSet))
for ref := range aiProviderCleanupSet {
snapshot.aiProviderCleanupRefs = append(snapshot.aiProviderCleanupRefs, ref)
}
return snapshot, nil
}
func (s securityUpdateCurrentAppRollbackSnapshot) restore(a *App) error {
configDir := strings.TrimSpace(a.configDir)
if configDir == "" {
configDir = resolveAppConfigDir()
}
connectionRepo := a.savedConnectionRepository()
if err := restoreOptionalFile(connectionRepo.connectionsPath(), s.connectionsFileExists, s.connectionsFileData); err != nil {
return err
}
for ref, secretSnapshot := range s.connectionSecrets {
if err := restoreSecurityUpdateSecretSnapshot(a.secretStore, ref, secretSnapshot); err != nil {
return err
}
}
for _, ref := range s.connectionCleanupRefs {
if _, alreadyRestored := s.connectionSecrets[ref]; alreadyRestored {
continue
}
if err := deleteSecurityUpdateSecretRef(a.secretStore, ref); err != nil {
return err
}
}
if err := restoreOptionalFile(globalProxyMetadataPath(configDir), s.globalProxyFileExists, s.globalProxyFileData); err != nil {
return err
}
if s.globalProxySecretRef != "" {
if err := restoreSecurityUpdateSecretSnapshot(a.secretStore, s.globalProxySecretRef, s.globalProxySecret); err != nil {
return err
}
}
if s.globalProxyCleanupRef != "" && s.globalProxyCleanupRef != s.globalProxySecretRef {
if err := deleteSecurityUpdateSecretRef(a.secretStore, s.globalProxyCleanupRef); err != nil {
return err
}
}
if err := restoreOptionalFile(filepath.Join(configDir, securityUpdateAIConfigFileName), s.aiConfigFileExists, s.aiConfigFileData); err != nil {
return err
}
for ref, secretSnapshot := range s.aiProviderSecrets {
if err := restoreSecurityUpdateSecretSnapshot(a.secretStore, ref, secretSnapshot); err != nil {
return err
}
}
for _, ref := range s.aiProviderCleanupRefs {
if _, alreadyRestored := s.aiProviderSecrets[ref]; alreadyRestored {
continue
}
if err := deleteSecurityUpdateSecretRef(a.secretStore, ref); err != nil {
return err
}
}
if s.globalProxyFileExists {
a.loadPersistedGlobalProxy()
return nil
}
_, err := setGlobalProxyConfig(false, connection.ProxyConfig{})
return err
}
func readOptionalFile(path string) ([]byte, bool, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, false, nil
}
return nil, false, err
}
return append([]byte(nil), data...), true, nil
}
func restoreOptionalFile(path string, exists bool, data []byte) error {
if !exists {
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
return os.WriteFile(path, data, 0o644)
}
func captureSecurityUpdateSecretSnapshot(store secretstore.SecretStore, ref string) (securityUpdateSecretSnapshot, error) {
if store == nil || strings.TrimSpace(ref) == "" {
return securityUpdateSecretSnapshot{}, nil
}
payload, err := store.Get(ref)
if err != nil {
if os.IsNotExist(err) || secretstore.IsUnavailable(err) {
return securityUpdateSecretSnapshot{}, nil
}
return securityUpdateSecretSnapshot{}, err
}
return securityUpdateSecretSnapshot{
Exists: true,
Payload: append([]byte(nil), payload...),
}, nil
}
func restoreSecurityUpdateSecretSnapshot(store secretstore.SecretStore, ref string, snapshot securityUpdateSecretSnapshot) error {
if store == nil || strings.TrimSpace(ref) == "" {
return nil
}
if snapshot.Exists {
if err := store.Put(ref, snapshot.Payload); err != nil {
if secretstore.IsUnavailable(err) {
return nil
}
return err
}
return nil
}
return deleteSecurityUpdateSecretRef(store, ref)
}
func deleteSecurityUpdateSecretRef(store secretstore.SecretStore, ref string) error {
if store == nil || strings.TrimSpace(ref) == "" {
return nil
}
if err := store.Delete(ref); err != nil {
if os.IsNotExist(err) || secretstore.IsUnavailable(err) {
return nil
}
return err
}
return nil
}

View File

@@ -0,0 +1,85 @@
package app
import (
"encoding/json"
"strings"
"GoNavi-Wails/internal/connection"
)
const (
securityUpdateSourceCurrentAppFileName = "source-current-app.json"
securityUpdateNormalizedPreviewFileName = "normalized-preview.json"
)
type securityUpdateCurrentAppEnvelope struct {
State securityUpdateCurrentAppPayload `json:"state"`
Connections []connection.LegacySavedConnection `json:"connections"`
GlobalProxy *connection.LegacyGlobalProxyInput `json:"globalProxy"`
}
type securityUpdateCurrentAppPayload struct {
Connections []connection.LegacySavedConnection `json:"connections"`
GlobalProxy *connection.LegacyGlobalProxyInput `json:"globalProxy"`
}
type securityUpdateCurrentAppSource struct {
Connections []connection.LegacySavedConnection `json:"connections"`
GlobalProxy *connection.LegacyGlobalProxyInput `json:"globalProxy,omitempty"`
}
func parseSecurityUpdateCurrentAppSource(rawPayload string) (securityUpdateCurrentAppSource, any, error) {
trimmed := strings.TrimSpace(rawPayload)
if trimmed == "" {
return securityUpdateCurrentAppSource{Connections: []connection.LegacySavedConnection{}}, map[string]any{}, nil
}
var raw any
if err := json.Unmarshal([]byte(trimmed), &raw); err != nil {
return securityUpdateCurrentAppSource{}, nil, err
}
var envelope securityUpdateCurrentAppEnvelope
if err := json.Unmarshal([]byte(trimmed), &envelope); err != nil {
return securityUpdateCurrentAppSource{}, nil, err
}
connections := envelope.Connections
globalProxy := envelope.GlobalProxy
if len(envelope.State.Connections) > 0 || envelope.State.GlobalProxy != nil {
connections = envelope.State.Connections
globalProxy = envelope.State.GlobalProxy
}
normalizedConnections := make([]connection.LegacySavedConnection, 0, len(connections))
for _, item := range connections {
if strings.TrimSpace(item.ID) == "" && strings.TrimSpace(item.Config.ID) == "" {
continue
}
if strings.TrimSpace(item.ID) == "" {
item.ID = strings.TrimSpace(item.Config.ID)
}
item.Config.ID = item.ID
normalizedConnections = append(normalizedConnections, item)
}
if globalProxy != nil {
normalizedType := strings.ToLower(strings.TrimSpace(globalProxy.Type))
if normalizedType != "http" {
normalizedType = "socks5"
}
globalProxy.Type = normalizedType
if globalProxy.Port <= 0 || globalProxy.Port > 65535 {
if normalizedType == "http" {
globalProxy.Port = 8080
} else {
globalProxy.Port = 1080
}
}
}
return securityUpdateCurrentAppSource{
Connections: normalizedConnections,
GlobalProxy: globalProxy,
}, raw, nil
}

View File

@@ -0,0 +1,293 @@
package app
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
)
const (
securityUpdateSchemaVersion = 1
securityUpdateMarkerDirName = "migrations"
securityUpdateMarkerFileName = "config-security-update.json"
securityUpdateBackupRootDirName = "migration-backups"
securityUpdateManifestFileName = "manifest.json"
securityUpdateResultFileName = "result.json"
)
var securityUpdateWriteJSONFile = writeJSONFile
type securityUpdateStateRepository struct {
configDir string
}
type securityUpdateMarker struct {
SchemaVersion int `json:"schemaVersion"`
MigrationID string `json:"migrationId"`
SourceType SecurityUpdateSourceType `json:"sourceType"`
Status SecurityUpdateOverallStatus `json:"status"`
StartedAt string `json:"startedAt,omitempty"`
UpdatedAt string `json:"updatedAt,omitempty"`
CompletedAt string `json:"completedAt,omitempty"`
PostponedAt string `json:"postponedAt,omitempty"`
BackupPath string `json:"backupPath,omitempty"`
BackupSHA256 string `json:"backupSha256,omitempty"`
Summary SecurityUpdateSummary `json:"summary"`
Issues []SecurityUpdateIssue `json:"issues"`
LastError string `json:"lastError,omitempty"`
}
type securityUpdateBackupManifest struct {
SchemaVersion int `json:"schemaVersion"`
MigrationID string `json:"migrationId"`
SourceType SecurityUpdateSourceType `json:"sourceType"`
CreatedAt string `json:"createdAt"`
StartedAt string `json:"startedAt,omitempty"`
BackupPath string `json:"backupPath"`
}
func newSecurityUpdateStateRepository(configDir string) *securityUpdateStateRepository {
if strings.TrimSpace(configDir) == "" {
configDir = resolveAppConfigDir()
}
return &securityUpdateStateRepository{configDir: configDir}
}
func (r *securityUpdateStateRepository) markerPath() string {
return filepath.Join(r.configDir, securityUpdateMarkerDirName, securityUpdateMarkerFileName)
}
func (r *securityUpdateStateRepository) backupRootPath() string {
return filepath.Join(r.configDir, securityUpdateBackupRootDirName)
}
func (r *securityUpdateStateRepository) backupPath(migrationID string) string {
return filepath.Join(r.backupRootPath(), migrationID)
}
func (r *securityUpdateStateRepository) manifestPath(migrationID string) string {
return filepath.Join(r.backupPath(migrationID), securityUpdateManifestFileName)
}
func (r *securityUpdateStateRepository) resultPath(migrationID string) string {
return filepath.Join(r.backupPath(migrationID), securityUpdateResultFileName)
}
func (r *securityUpdateStateRepository) LoadMarker() (SecurityUpdateStatus, error) {
marker, err := r.readMarker()
if err != nil {
return SecurityUpdateStatus{}, err
}
return buildSecurityUpdateStatus(marker), nil
}
func (r *securityUpdateStateRepository) StartRound(request StartSecurityUpdateRequest) (SecurityUpdateStatus, error) {
marker := r.newRoundMarker(request.SourceType)
if err := r.initializeRoundArtifacts(marker); err != nil {
return SecurityUpdateStatus{}, err
}
status := buildSecurityUpdateStatus(marker)
if err := r.WriteResult(status); err != nil {
return SecurityUpdateStatus{}, err
}
return status, nil
}
func (r *securityUpdateStateRepository) RetryRound(request RetrySecurityUpdateRequest) (SecurityUpdateStatus, error) {
marker, err := r.readMarker()
if err != nil {
return SecurityUpdateStatus{}, err
}
if requestedID := strings.TrimSpace(request.MigrationID); requestedID != "" && requestedID != marker.MigrationID {
return SecurityUpdateStatus{}, fmt.Errorf("migration ID mismatch: current=%s requested=%s", marker.MigrationID, requestedID)
}
if marker.Status != SecurityUpdateOverallStatusNeedsAttention {
return SecurityUpdateStatus{}, fmt.Errorf(
"retry current round requires status %s: current=%s",
SecurityUpdateOverallStatusNeedsAttention,
marker.Status,
)
}
marker.Status = SecurityUpdateOverallStatusInProgress
marker.UpdatedAt = nowRFC3339()
if marker.BackupPath == "" {
marker.BackupPath = r.backupPath(marker.MigrationID)
}
if err := os.MkdirAll(marker.BackupPath, 0o755); err != nil {
return SecurityUpdateStatus{}, err
}
status := buildSecurityUpdateStatus(marker)
if err := r.WriteResult(status); err != nil {
return SecurityUpdateStatus{}, err
}
return status, nil
}
func (r *securityUpdateStateRepository) RestartRound(request RestartSecurityUpdateRequest) (SecurityUpdateStatus, error) {
marker := r.newRoundMarker(request.SourceType)
if err := r.initializeRoundArtifacts(marker); err != nil {
return SecurityUpdateStatus{}, err
}
status := buildSecurityUpdateStatus(marker)
if err := r.WriteResult(status); err != nil {
return SecurityUpdateStatus{}, err
}
return status, nil
}
func (r *securityUpdateStateRepository) WriteResult(status SecurityUpdateStatus) error {
marker := markerFromStatus(status)
if err := r.writeMarker(marker); err != nil {
return err
}
if strings.TrimSpace(marker.BackupPath) == "" {
return nil
}
if err := os.MkdirAll(marker.BackupPath, 0o755); err != nil {
return err
}
return securityUpdateWriteJSONFile(r.resultPath(marker.MigrationID), buildSecurityUpdateStatus(marker))
}
func (r *securityUpdateStateRepository) newRoundMarker(sourceType SecurityUpdateSourceType) securityUpdateMarker {
now := nowRFC3339()
if strings.TrimSpace(string(sourceType)) == "" {
sourceType = SecurityUpdateSourceTypeCurrentAppSavedConfig
}
migrationID := uuid.NewString()
return securityUpdateMarker{
SchemaVersion: securityUpdateSchemaVersion,
MigrationID: migrationID,
SourceType: sourceType,
Status: SecurityUpdateOverallStatusInProgress,
StartedAt: now,
UpdatedAt: now,
BackupPath: r.backupPath(migrationID),
Summary: SecurityUpdateSummary{},
Issues: []SecurityUpdateIssue{},
}
}
func (r *securityUpdateStateRepository) initializeRoundArtifacts(marker securityUpdateMarker) error {
if err := os.MkdirAll(marker.BackupPath, 0o755); err != nil {
return err
}
manifest := securityUpdateBackupManifest{
SchemaVersion: securityUpdateSchemaVersion,
MigrationID: marker.MigrationID,
SourceType: marker.SourceType,
CreatedAt: marker.UpdatedAt,
StartedAt: marker.StartedAt,
BackupPath: marker.BackupPath,
}
if err := securityUpdateWriteJSONFile(r.manifestPath(marker.MigrationID), manifest); err != nil {
return err
}
return r.writeMarker(marker)
}
func (r *securityUpdateStateRepository) readMarker() (securityUpdateMarker, error) {
data, err := os.ReadFile(r.markerPath())
if err != nil {
return securityUpdateMarker{}, err
}
var marker securityUpdateMarker
if err := json.Unmarshal(data, &marker); err != nil {
return securityUpdateMarker{}, err
}
if marker.Issues == nil {
marker.Issues = []SecurityUpdateIssue{}
}
return marker, nil
}
func (r *securityUpdateStateRepository) writeMarker(marker securityUpdateMarker) error {
if err := os.MkdirAll(filepath.Dir(r.markerPath()), 0o755); err != nil {
return err
}
return securityUpdateWriteJSONFile(r.markerPath(), marker)
}
func buildSecurityUpdateStatus(marker securityUpdateMarker) SecurityUpdateStatus {
status := SecurityUpdateStatus{
SchemaVersion: marker.SchemaVersion,
MigrationID: marker.MigrationID,
OverallStatus: marker.Status,
SourceType: marker.SourceType,
BackupAvailable: strings.TrimSpace(marker.BackupPath) != "",
BackupPath: marker.BackupPath,
StartedAt: marker.StartedAt,
UpdatedAt: marker.UpdatedAt,
CompletedAt: marker.CompletedAt,
PostponedAt: marker.PostponedAt,
Summary: marker.Summary,
Issues: marker.Issues,
LastError: marker.LastError,
}
if status.Issues == nil {
status.Issues = []SecurityUpdateIssue{}
}
switch status.OverallStatus {
case SecurityUpdateOverallStatusPending:
status.ReminderVisible = true
status.CanStart = true
status.CanPostpone = true
case SecurityUpdateOverallStatusPostponed:
status.CanStart = true
case SecurityUpdateOverallStatusNeedsAttention:
status.CanRetry = true
status.CanStart = true
case SecurityUpdateOverallStatusRolledBack:
status.CanStart = true
case SecurityUpdateOverallStatusCompleted:
status.BackupAvailable = strings.TrimSpace(status.BackupPath) != ""
}
return status
}
func markerFromStatus(status SecurityUpdateStatus) securityUpdateMarker {
marker := securityUpdateMarker{
SchemaVersion: securityUpdateSchemaVersion,
MigrationID: strings.TrimSpace(status.MigrationID),
SourceType: status.SourceType,
Status: status.OverallStatus,
StartedAt: status.StartedAt,
UpdatedAt: status.UpdatedAt,
CompletedAt: status.CompletedAt,
PostponedAt: status.PostponedAt,
BackupPath: status.BackupPath,
Summary: status.Summary,
Issues: status.Issues,
LastError: status.LastError,
}
if marker.SchemaVersion == 0 {
marker.SchemaVersion = securityUpdateSchemaVersion
}
if marker.Issues == nil {
marker.Issues = []SecurityUpdateIssue{}
}
if marker.BackupPath == "" && marker.MigrationID != "" {
marker.BackupPath = filepath.Join(resolveAppConfigDir(), securityUpdateBackupRootDirName, marker.MigrationID)
}
if marker.UpdatedAt == "" {
marker.UpdatedAt = nowRFC3339()
}
return marker
}
func writeJSONFile(path string, payload any) error {
data, err := json.MarshalIndent(payload, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}
func nowRFC3339() string {
return time.Now().UTC().Format(time.RFC3339)
}

View File

@@ -0,0 +1,226 @@
package app
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestSecurityUpdateStateStartRoundCreatesMarkerAndManifest(t *testing.T) {
repo := newSecurityUpdateStateRepository(t.TempDir())
status, err := repo.StartRound(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
})
if err != nil {
t.Fatalf("StartRound returned error: %v", err)
}
if status.MigrationID == "" {
t.Fatal("expected migration ID to be created")
}
if status.SourceType != SecurityUpdateSourceTypeCurrentAppSavedConfig {
t.Fatalf("expected source type %q, got %q", SecurityUpdateSourceTypeCurrentAppSavedConfig, status.SourceType)
}
if status.OverallStatus != SecurityUpdateOverallStatusInProgress {
t.Fatalf("expected overall status %q, got %q", SecurityUpdateOverallStatusInProgress, status.OverallStatus)
}
if !status.BackupAvailable {
t.Fatal("expected backupAvailable=true")
}
markerPath := filepath.Join(repo.configDir, securityUpdateMarkerDirName, securityUpdateMarkerFileName)
if _, err := os.Stat(markerPath); err != nil {
t.Fatalf("expected marker file at %q: %v", markerPath, err)
}
data, err := os.ReadFile(markerPath)
if err != nil {
t.Fatalf("ReadFile marker failed: %v", err)
}
var marker securityUpdateMarker
if err := json.Unmarshal(data, &marker); err != nil {
t.Fatalf("Unmarshal marker failed: %v", err)
}
if marker.MigrationID != status.MigrationID {
t.Fatalf("expected marker migration ID %q, got %q", status.MigrationID, marker.MigrationID)
}
manifestPath := filepath.Join(repo.configDir, securityUpdateBackupRootDirName, status.MigrationID, securityUpdateManifestFileName)
if _, err := os.Stat(manifestPath); err != nil {
t.Fatalf("expected manifest file at %q: %v", manifestPath, err)
}
}
func TestSecurityUpdateStateRetryRoundReusesCurrentMigrationID(t *testing.T) {
repo := newSecurityUpdateStateRepository(t.TempDir())
initial, err := repo.StartRound(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
})
if err != nil {
t.Fatalf("StartRound returned error: %v", err)
}
initial.OverallStatus = SecurityUpdateOverallStatusNeedsAttention
initial.UpdatedAt = nowRFC3339()
initial.Summary = SecurityUpdateSummary{
Total: 1,
Pending: 1,
}
initial.Issues = []SecurityUpdateIssue{
{
ID: "connection-legacy-1",
Scope: SecurityUpdateIssueScopeConnection,
RefID: "legacy-1",
Title: "Legacy",
Severity: SecurityUpdateIssueSeverityMedium,
Status: SecurityUpdateItemStatusNeedsAttention,
ReasonCode: SecurityUpdateIssueReasonCodeSecretMissing,
Action: SecurityUpdateIssueActionOpenConnection,
Message: "连接密码已丢失,请重新保存后再继续",
},
}
if err := repo.WriteResult(initial); err != nil {
t.Fatalf("WriteResult returned error: %v", err)
}
retried, err := repo.RetryRound(RetrySecurityUpdateRequest{
MigrationID: initial.MigrationID,
})
if err != nil {
t.Fatalf("RetryRound returned error: %v", err)
}
if retried.MigrationID != initial.MigrationID {
t.Fatalf("expected retry to reuse migration ID %q, got %q", initial.MigrationID, retried.MigrationID)
}
entries, err := os.ReadDir(filepath.Join(repo.configDir, securityUpdateBackupRootDirName))
if err != nil {
t.Fatalf("ReadDir backup root failed: %v", err)
}
if len(entries) != 1 {
t.Fatalf("expected retry to keep a single backup directory, got %d", len(entries))
}
}
func TestSecurityUpdateStateRetryRoundRejectsRolledBackRound(t *testing.T) {
repo := newSecurityUpdateStateRepository(t.TempDir())
marker := securityUpdateMarker{
SchemaVersion: securityUpdateSchemaVersion,
MigrationID: "migration-1",
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
Status: SecurityUpdateOverallStatusRolledBack,
StartedAt: "2026-04-09T00:00:00Z",
UpdatedAt: "2026-04-09T00:05:00Z",
BackupPath: repo.backupPath("migration-1"),
Summary: SecurityUpdateSummary{
Total: 1,
Failed: 1,
},
Issues: []SecurityUpdateIssue{
{
ID: "system-blocked",
Scope: SecurityUpdateIssueScopeSystem,
Title: "安全更新未完成",
Severity: SecurityUpdateIssueSeverityHigh,
Status: SecurityUpdateItemStatusFailed,
ReasonCode: SecurityUpdateIssueReasonCodeEnvironmentBlocked,
Action: SecurityUpdateIssueActionViewDetails,
Message: "当前环境无法完成本次安全更新,请稍后重试",
},
},
}
if err := repo.writeMarker(marker); err != nil {
t.Fatalf("writeMarker returned error: %v", err)
}
if _, err := repo.RetryRound(RetrySecurityUpdateRequest{MigrationID: marker.MigrationID}); err == nil {
t.Fatal("expected RetryRound to reject rolled_back round")
}
current, err := repo.LoadMarker()
if err != nil {
t.Fatalf("LoadMarker returned error: %v", err)
}
if current.OverallStatus != SecurityUpdateOverallStatusRolledBack {
t.Fatalf("expected marker to remain rolled_back, got %q", current.OverallStatus)
}
}
func TestBuildSecurityUpdateStatusDoesNotAllowRetryAfterRollback(t *testing.T) {
status := buildSecurityUpdateStatus(securityUpdateMarker{
SchemaVersion: securityUpdateSchemaVersion,
MigrationID: "migration-1",
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
Status: SecurityUpdateOverallStatusRolledBack,
StartedAt: "2026-04-09T00:00:00Z",
UpdatedAt: "2026-04-09T00:05:00Z",
BackupPath: filepath.Join("backup", "migration-1"),
Summary: SecurityUpdateSummary{
Total: 1,
Failed: 1,
},
Issues: []SecurityUpdateIssue{
{
ID: "system-blocked",
Scope: SecurityUpdateIssueScopeSystem,
Title: "安全更新未完成",
Severity: SecurityUpdateIssueSeverityHigh,
Status: SecurityUpdateItemStatusFailed,
ReasonCode: SecurityUpdateIssueReasonCodeEnvironmentBlocked,
Action: SecurityUpdateIssueActionViewDetails,
Message: "当前环境无法完成本次安全更新,请稍后重试",
},
},
})
if status.CanRetry {
t.Fatal("expected rolled_back status to require restart instead of retry")
}
if !status.CanStart {
t.Fatal("expected rolled_back status to allow starting a new round")
}
}
func TestSecurityUpdateStateRestartRoundCreatesNewMigrationID(t *testing.T) {
repo := newSecurityUpdateStateRepository(t.TempDir())
initial, err := repo.StartRound(StartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
})
if err != nil {
t.Fatalf("StartRound returned error: %v", err)
}
restarted, err := repo.RestartRound(RestartSecurityUpdateRequest{
SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig,
})
if err != nil {
t.Fatalf("RestartRound returned error: %v", err)
}
if restarted.MigrationID == initial.MigrationID {
t.Fatal("expected restart to create a new migration ID")
}
entries, err := os.ReadDir(filepath.Join(repo.configDir, securityUpdateBackupRootDirName))
if err != nil {
t.Fatalf("ReadDir backup root failed: %v", err)
}
if len(entries) != 2 {
t.Fatalf("expected restart to create a second backup directory, got %d", len(entries))
}
current, err := repo.LoadMarker()
if err != nil {
t.Fatalf("LoadMarker returned error: %v", err)
}
if current.MigrationID != restarted.MigrationID {
t.Fatalf("expected marker to point to latest migration ID %q, got %q", restarted.MigrationID, current.MigrationID)
}
}

View File

@@ -0,0 +1,129 @@
package app
type SecurityUpdateSourceType string
const (
SecurityUpdateSourceTypeCurrentAppSavedConfig SecurityUpdateSourceType = "current_app_saved_config"
)
type SecurityUpdateOverallStatus string
const (
SecurityUpdateOverallStatusNotDetected SecurityUpdateOverallStatus = "not_detected"
SecurityUpdateOverallStatusPending SecurityUpdateOverallStatus = "pending"
SecurityUpdateOverallStatusPostponed SecurityUpdateOverallStatus = "postponed"
SecurityUpdateOverallStatusInProgress SecurityUpdateOverallStatus = "in_progress"
SecurityUpdateOverallStatusNeedsAttention SecurityUpdateOverallStatus = "needs_attention"
SecurityUpdateOverallStatusCompleted SecurityUpdateOverallStatus = "completed"
SecurityUpdateOverallStatusRolledBack SecurityUpdateOverallStatus = "rolled_back"
)
type SecurityUpdateIssueScope string
const (
SecurityUpdateIssueScopeConnection SecurityUpdateIssueScope = "connection"
SecurityUpdateIssueScopeGlobalProxy SecurityUpdateIssueScope = "global_proxy"
SecurityUpdateIssueScopeAIProvider SecurityUpdateIssueScope = "ai_provider"
SecurityUpdateIssueScopeSystem SecurityUpdateIssueScope = "system"
)
type SecurityUpdateIssueSeverity string
const (
SecurityUpdateIssueSeverityHigh SecurityUpdateIssueSeverity = "high"
SecurityUpdateIssueSeverityMedium SecurityUpdateIssueSeverity = "medium"
SecurityUpdateIssueSeverityLow SecurityUpdateIssueSeverity = "low"
)
type SecurityUpdateItemStatus string
const (
SecurityUpdateItemStatusPending SecurityUpdateItemStatus = "pending"
SecurityUpdateItemStatusUpdated SecurityUpdateItemStatus = "updated"
SecurityUpdateItemStatusNeedsAttention SecurityUpdateItemStatus = "needs_attention"
SecurityUpdateItemStatusSkipped SecurityUpdateItemStatus = "skipped"
SecurityUpdateItemStatusFailed SecurityUpdateItemStatus = "failed"
)
type SecurityUpdateIssueReasonCode string
const (
SecurityUpdateIssueReasonCodeMigrationRequired SecurityUpdateIssueReasonCode = "migration_required"
SecurityUpdateIssueReasonCodeSecretMissing SecurityUpdateIssueReasonCode = "secret_missing"
SecurityUpdateIssueReasonCodeFieldInvalid SecurityUpdateIssueReasonCode = "field_invalid"
SecurityUpdateIssueReasonCodeWriteConflict SecurityUpdateIssueReasonCode = "write_conflict"
SecurityUpdateIssueReasonCodeValidationFailed SecurityUpdateIssueReasonCode = "validation_failed"
SecurityUpdateIssueReasonCodeEnvironmentBlocked SecurityUpdateIssueReasonCode = "environment_blocked"
)
type SecurityUpdateIssueAction string
const (
SecurityUpdateIssueActionOpenConnection SecurityUpdateIssueAction = "open_connection"
SecurityUpdateIssueActionOpenProxySettings SecurityUpdateIssueAction = "open_proxy_settings"
SecurityUpdateIssueActionOpenAISettings SecurityUpdateIssueAction = "open_ai_settings"
SecurityUpdateIssueActionRetryUpdate SecurityUpdateIssueAction = "retry_update"
SecurityUpdateIssueActionViewDetails SecurityUpdateIssueAction = "view_details"
)
type SecurityUpdateSummary struct {
Total int `json:"total"`
Updated int `json:"updated"`
Pending int `json:"pending"`
Skipped int `json:"skipped"`
Failed int `json:"failed"`
}
type SecurityUpdateIssue struct {
ID string `json:"id"`
Scope SecurityUpdateIssueScope `json:"scope"`
RefID string `json:"refId,omitempty"`
Title string `json:"title"`
Severity SecurityUpdateIssueSeverity `json:"severity"`
Status SecurityUpdateItemStatus `json:"status"`
ReasonCode SecurityUpdateIssueReasonCode `json:"reasonCode"`
Action SecurityUpdateIssueAction `json:"action"`
Message string `json:"message"`
}
type SecurityUpdateStatus struct {
SchemaVersion int `json:"schemaVersion,omitempty"`
MigrationID string `json:"migrationId,omitempty"`
OverallStatus SecurityUpdateOverallStatus `json:"overallStatus"`
SourceType SecurityUpdateSourceType `json:"sourceType,omitempty"`
ReminderVisible bool `json:"reminderVisible"`
CanStart bool `json:"canStart"`
CanPostpone bool `json:"canPostpone"`
CanRetry bool `json:"canRetry"`
BackupAvailable bool `json:"backupAvailable"`
BackupPath string `json:"backupPath,omitempty"`
StartedAt string `json:"startedAt,omitempty"`
UpdatedAt string `json:"updatedAt,omitempty"`
CompletedAt string `json:"completedAt,omitempty"`
PostponedAt string `json:"postponedAt,omitempty"`
Summary SecurityUpdateSummary `json:"summary"`
Issues []SecurityUpdateIssue `json:"issues"`
LastError string `json:"lastError,omitempty"`
}
type SecurityUpdateOptions struct {
AllowPartial bool `json:"allowPartial,omitempty"`
WriteBackup bool `json:"writeBackup,omitempty"`
}
type StartSecurityUpdateRequest struct {
SourceType SecurityUpdateSourceType `json:"sourceType"`
RawPayload string `json:"rawPayload,omitempty"`
Options *SecurityUpdateOptions `json:"options,omitempty"`
}
type RetrySecurityUpdateRequest struct {
MigrationID string `json:"migrationId,omitempty"`
}
type RestartSecurityUpdateRequest struct {
MigrationID string `json:"migrationId,omitempty"`
SourceType SecurityUpdateSourceType `json:"sourceType"`
RawPayload string `json:"rawPayload,omitempty"`
Options *SecurityUpdateOptions `json:"options,omitempty"`
}

View File

@@ -0,0 +1,14 @@
package app
import "strings"
const macWindowDiagnosticsEnv = "GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS"
func shouldInstallMacNativeWindowDiagnostics() bool {
switch strings.ToLower(strings.TrimSpace(getenv(macWindowDiagnosticsEnv))) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}

View File

@@ -0,0 +1,37 @@
package app
import "testing"
func TestShouldInstallMacNativeWindowDiagnosticsDefaultsDisabled(t *testing.T) {
t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", "")
if shouldInstallMacNativeWindowDiagnostics() {
t.Fatal("expected mac native window diagnostics to stay disabled by default")
}
}
func TestShouldInstallMacNativeWindowDiagnosticsHonorsEnvOptIn(t *testing.T) {
t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", "1")
if !shouldInstallMacNativeWindowDiagnostics() {
t.Fatal("expected mac native window diagnostics to enable when explicitly opted in")
}
t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", "true")
if !shouldInstallMacNativeWindowDiagnostics() {
t.Fatal("expected mac native window diagnostics to accept true as opt-in value")
}
t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", "0")
if shouldInstallMacNativeWindowDiagnostics() {
t.Fatal("expected mac native window diagnostics to stay disabled for non-opt-in values")
}
}
func TestShouldInstallMacNativeWindowDiagnosticsIgnoresCaseAndWhitespace(t *testing.T) {
t.Setenv("GONAVI_ENABLE_MAC_WINDOW_DIAGNOSTICS", " TRUE ")
if !shouldInstallMacNativeWindowDiagnostics() {
t.Fatal("expected helper to trim and lowercase opt-in values")
}
}

View File

@@ -3,7 +3,10 @@ package secretstore
import (
"errors"
"fmt"
"os"
"runtime"
"strings"
"syscall"
"github.com/99designs/keyring"
)
@@ -56,19 +59,32 @@ func (s *keyringStore) Delete(ref string) error {
func (s *keyringStore) HealthCheck() error {
_, err := s.ring.Get(healthCheckRef)
if err == nil || errors.Is(err, keyring.ErrKeyNotFound) {
if err == nil || isKeyringSecretNotFound(err) {
return nil
}
return wrapKeyringError(err)
}
func wrapKeyringError(err error) error {
if err == nil || errors.Is(err, keyring.ErrKeyNotFound) || IsUnavailable(err) {
if err == nil || IsUnavailable(err) {
return err
}
if isKeyringSecretNotFound(err) {
return os.ErrNotExist
}
return &UnavailableError{Reason: err.Error()}
}
func isKeyringSecretNotFound(err error) bool {
if err == nil {
return false
}
if errors.Is(err, keyring.ErrKeyNotFound) || errors.Is(err, syscall.Errno(1168)) {
return true
}
return strings.EqualFold(strings.TrimSpace(err.Error()), keyring.ErrKeyNotFound.Error())
}
func keyringConfigFor(goos string) (keyring.Config, error) {
backends := allowedBackendsFor(goos)
if len(backends) == 0 {

View File

@@ -2,6 +2,9 @@ package secretstore
import (
"errors"
"fmt"
"os"
"syscall"
"testing"
"github.com/99designs/keyring"
@@ -58,6 +61,33 @@ func TestKeyringStoreHealthCheckTreatsMissingProbeItemAsHealthy(t *testing.T) {
}
}
func TestKeyringStoreHealthCheckTreatsWinCredNotFoundMessageAsHealthy(t *testing.T) {
t.Parallel()
store := &keyringStore{ring: fakeKeyringClient{getErr: errors.New("The specified item could not be found in the keyring")}}
if err := store.HealthCheck(); err != nil {
t.Fatalf("HealthCheck should accept WinCred not-found errors, got %v", err)
}
}
func TestKeyringStoreHealthCheckDoesNotTreatWrappedOsErrNotExistAsHealthy(t *testing.T) {
t.Parallel()
store := &keyringStore{ring: fakeKeyringClient{getErr: fmt.Errorf("backend unavailable: %w", os.ErrNotExist)}}
if err := store.HealthCheck(); err == nil {
t.Fatal("HealthCheck should not accept unrelated wrapped os.ErrNotExist errors as healthy")
}
}
func TestKeyringStoreHealthCheckDoesNotTreatPlainOsErrNotExistAsHealthy(t *testing.T) {
t.Parallel()
store := &keyringStore{ring: fakeKeyringClient{getErr: os.ErrNotExist}}
if err := store.HealthCheck(); err == nil {
t.Fatal("HealthCheck should not accept plain os.ErrNotExist errors as healthy")
}
}
func TestKeyringStoreHealthCheckReturnsUnavailableErrorOnBackendFailure(t *testing.T) {
t.Parallel()
@@ -82,6 +112,67 @@ func TestNewKeyringStoreReturnsUnavailableStoreWhenOpenFails(t *testing.T) {
}
}
func TestWrapKeyringErrorNormalizesWinCredNotFoundMessage(t *testing.T) {
t.Parallel()
err := wrapKeyringError(errors.New("The specified item could not be found in the keyring"))
if err == nil {
t.Fatal("wrapKeyringError should preserve missing-secret semantics")
}
if !os.IsNotExist(err) {
t.Fatalf("wrapKeyringError should map WinCred not-found errors to os.ErrNotExist, got %v", err)
}
if IsUnavailable(err) {
t.Fatalf("wrapKeyringError should not treat WinCred not-found errors as unavailable, got %v", err)
}
}
func TestWrapKeyringErrorNormalizesWrappedKeyringErrKeyNotFound(t *testing.T) {
t.Parallel()
err := wrapKeyringError(fmt.Errorf("wrapped: %w", keyring.ErrKeyNotFound))
if err == nil {
t.Fatal("wrapKeyringError should preserve wrapped missing-secret semantics")
}
if !os.IsNotExist(err) {
t.Fatalf("wrapKeyringError should map wrapped ErrKeyNotFound to os.ErrNotExist, got %v", err)
}
if IsUnavailable(err) {
t.Fatalf("wrapKeyringError should not treat wrapped ErrKeyNotFound as unavailable, got %v", err)
}
}
func TestWrapKeyringErrorNormalizesWinCredErrno1168(t *testing.T) {
t.Parallel()
err := wrapKeyringError(syscall.Errno(1168))
if err == nil {
t.Fatal("wrapKeyringError should preserve WinCred errno missing-secret semantics")
}
if !os.IsNotExist(err) {
t.Fatalf("wrapKeyringError should map WinCred errno to os.ErrNotExist, got %v", err)
}
if IsUnavailable(err) {
t.Fatalf("wrapKeyringError should not treat WinCred errno as unavailable, got %v", err)
}
}
func TestWrapKeyringErrorDoesNotSwallowUnrelatedElementNotFoundMessages(t *testing.T) {
t.Parallel()
backendErr := errors.New("database element not found while enumerating providers")
err := wrapKeyringError(backendErr)
if err == nil {
t.Fatal("wrapKeyringError should preserve backend failures")
}
if os.IsNotExist(err) {
t.Fatalf("wrapKeyringError should not map unrelated element-not-found errors to os.ErrNotExist, got %v", err)
}
if !IsUnavailable(err) {
t.Fatalf("wrapKeyringError should keep unrelated backend failures unavailable, got %v", err)
}
}
type fakeKeyringClient struct {
getErr error
item keyring.Item