From b62d22395b5031e49dd40ac45487e758119ca2b3 Mon Sep 17 00:00:00 2001 From: tianqijiuyun-latiao <69459608+tianqijiuyun-latiao@users.noreply.github.com> Date: Fri, 3 Apr 2026 00:18:06 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(security):=20=E6=8B=86?= =?UTF-8?q?=E5=88=86=20AI=20=E4=BE=9B=E5=BA=94=E5=95=86=E5=85=83=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E4=B8=8E=E5=AF=86=E9=92=A5=E5=AD=98=E5=82=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/ai/service/provider_secret.go | 231 +++++++++++++ internal/ai/service/provider_secret_test.go | 348 ++++++++++++++++++++ internal/ai/service/service.go | 174 ++++++++-- internal/ai/types.go | 2 + 4 files changed, 728 insertions(+), 27 deletions(-) create mode 100644 internal/ai/service/provider_secret.go create mode 100644 internal/ai/service/provider_secret_test.go diff --git a/internal/ai/service/provider_secret.go b/internal/ai/service/provider_secret.go new file mode 100644 index 0000000..6fe22bc --- /dev/null +++ b/internal/ai/service/provider_secret.go @@ -0,0 +1,231 @@ +package aiservice + +import ( + "encoding/json" + "fmt" + "strings" + "unicode" + + "GoNavi-Wails/internal/ai" + "GoNavi-Wails/internal/secretstore" +) + +const providerSecretKind = "ai-provider" + +type providerSecretBundle struct { + APIKey string `json:"apiKey,omitempty"` + SensitiveHeaders map[string]string `json:"sensitiveHeaders,omitempty"` +} + +func (b providerSecretBundle) hasAny() bool { + return strings.TrimSpace(b.APIKey) != "" || len(b.SensitiveHeaders) > 0 +} + +func mergeProviderSecretBundles(base, overlay providerSecretBundle) providerSecretBundle { + merged := providerSecretBundle{ + APIKey: base.APIKey, + SensitiveHeaders: cloneStringMap(base.SensitiveHeaders), + } + if strings.TrimSpace(overlay.APIKey) != "" { + merged.APIKey = overlay.APIKey + } + for key, value := range overlay.SensitiveHeaders { + if merged.SensitiveHeaders == nil { + merged.SensitiveHeaders = make(map[string]string, len(overlay.SensitiveHeaders)) + } + merged.SensitiveHeaders[key] = value + } + if len(merged.SensitiveHeaders) == 0 { + merged.SensitiveHeaders = nil + } + return merged +} + +func splitProviderSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, providerSecretBundle) { + meta := cfg + meta.APIKey = "" + + bundle := providerSecretBundle{} + if apiKey := strings.TrimSpace(cfg.APIKey); apiKey != "" { + bundle.APIKey = apiKey + } + + if len(cfg.Headers) > 0 { + safeHeaders := make(map[string]string, len(cfg.Headers)) + sensitiveHeaders := make(map[string]string) + for key, value := range cfg.Headers { + if isSensitiveProviderHeader(key) { + if strings.TrimSpace(value) != "" { + sensitiveHeaders[key] = value + } + continue + } + safeHeaders[key] = value + } + if len(safeHeaders) > 0 { + meta.Headers = safeHeaders + } else { + meta.Headers = nil + } + if len(sensitiveHeaders) > 0 { + bundle.SensitiveHeaders = sensitiveHeaders + } + } else { + meta.Headers = nil + } + + meta.HasSecret = cfg.HasSecret || bundle.hasAny() + meta.SecretRef = strings.TrimSpace(cfg.SecretRef) + if meta.HasSecret && meta.SecretRef == "" && strings.TrimSpace(cfg.ID) != "" { + if ref, err := secretstore.BuildRef(providerSecretKind, cfg.ID); err == nil { + meta.SecretRef = ref + } + } + if !meta.HasSecret { + meta.SecretRef = "" + } + + return meta, bundle +} + +func mergeProviderSecrets(cfg ai.ProviderConfig, bundle providerSecretBundle) ai.ProviderConfig { + merged := cfg + merged.APIKey = bundle.APIKey + + headers := cloneStringMap(cfg.Headers) + if len(bundle.SensitiveHeaders) > 0 { + if headers == nil { + headers = make(map[string]string, len(bundle.SensitiveHeaders)) + } + for key, value := range bundle.SensitiveHeaders { + headers[key] = value + } + } + if len(headers) > 0 { + merged.Headers = headers + } else { + merged.Headers = nil + } + + merged.HasSecret = cfg.HasSecret || bundle.hasAny() + if merged.HasSecret && strings.TrimSpace(merged.SecretRef) == "" && strings.TrimSpace(merged.ID) != "" { + if ref, err := secretstore.BuildRef(providerSecretKind, merged.ID); err == nil { + merged.SecretRef = ref + } + } + if !merged.HasSecret { + merged.SecretRef = "" + } + + return merged +} + +func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) { + meta, _ = splitProviderSecrets(meta) + if !bundle.hasAny() { + meta.HasSecret = false + meta.SecretRef = "" + return meta, nil + } + if s.secretStore == nil { + return meta, fmt.Errorf("secret store unavailable") + } + if err := s.secretStore.HealthCheck(); err != nil { + return meta, err + } + + ref := strings.TrimSpace(meta.SecretRef) + if ref == "" { + var err error + ref, err = secretstore.BuildRef(providerSecretKind, meta.ID) + if err != nil { + return meta, err + } + } + + payload, err := json.Marshal(bundle) + if err != nil { + return meta, fmt.Errorf("序列化 provider secret bundle 失败: %w", err) + } + if err := s.secretStore.Put(ref, payload); err != nil { + return meta, err + } + + meta.SecretRef = ref + meta.HasSecret = true + return meta, nil +} + +func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, error) { + cfg = normalizeProviderConfig(cfg) + meta, bundle := splitProviderSecrets(cfg) + if bundle.hasAny() { + return mergeProviderSecrets(meta, bundle), nil + } + if !meta.HasSecret { + return meta, nil + } + if s.secretStore == nil { + return meta, fmt.Errorf("secret store unavailable") + } + + ref := strings.TrimSpace(meta.SecretRef) + if ref == "" { + var err error + ref, err = secretstore.BuildRef(providerSecretKind, meta.ID) + if err != nil { + return meta, err + } + meta.SecretRef = ref + } + + payload, err := s.secretStore.Get(ref) + if err != nil { + return meta, err + } + + var stored providerSecretBundle + if err := json.Unmarshal(payload, &stored); err != nil { + return meta, fmt.Errorf("解析 provider secret bundle 失败: %w", err) + } + return mergeProviderSecrets(meta, stored), nil +} + +func providerMetadataView(cfg ai.ProviderConfig) ai.ProviderConfig { + meta, _ := splitProviderSecrets(normalizeProviderConfig(cfg)) + return meta +} + +func isSensitiveProviderHeader(name string) bool { + normalized := strings.TrimSpace(strings.ToLower(name)) + switch normalized { + case "authorization", "proxy-authorization", "x-api-key", "api-key": + return true + } + + for _, token := range providerHeaderTokens(normalized) { + switch token { + case "auth", "authorization", "token", "secret", "key", "apikey": + return true + } + } + + return false +} + +func providerHeaderTokens(name string) []string { + return strings.FieldsFunc(name, func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsDigit(r) + }) +} + +func cloneStringMap(input map[string]string) map[string]string { + if len(input) == 0 { + return nil + } + cloned := make(map[string]string, len(input)) + for key, value := range input { + cloned[key] = value + } + return cloned +} diff --git a/internal/ai/service/provider_secret_test.go b/internal/ai/service/provider_secret_test.go new file mode 100644 index 0000000..033b24f --- /dev/null +++ b/internal/ai/service/provider_secret_test.go @@ -0,0 +1,348 @@ +package aiservice + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "GoNavi-Wails/internal/ai" + "GoNavi-Wails/internal/secretstore" +) + +func TestSplitProviderSecretsStripsAPIKeyAndSensitiveHeaders(t *testing.T) { + input := ai.ProviderConfig{ + ID: "openai-main", + APIKey: "sk-test", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "db", + }, + } + + meta, bundle := splitProviderSecrets(input) + if meta.APIKey != "" { + t.Fatal("apiKey should not stay in metadata") + } + if meta.Headers["Authorization"] != "" { + t.Fatal("sensitive header should not stay in metadata") + } + if meta.Headers["X-Team"] != "db" { + t.Fatal("non-sensitive header should stay in metadata") + } + if bundle.APIKey != "sk-test" { + t.Fatal("bundle should keep apiKey") + } + if bundle.SensitiveHeaders["Authorization"] != "Bearer test" { + t.Fatal("bundle should keep sensitive header") + } +} + +func TestResolveProviderConfigSecretsRestoresStoredSecretBundle(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + ref, err := secretstore.BuildRef("ai-provider", "openai-main") + if err != nil { + t.Fatalf("BuildRef returned error: %v", err) + } + payload, err := json.Marshal(providerSecretBundle{ + APIKey: "sk-test", + SensitiveHeaders: map[string]string{ + "Authorization": "Bearer test", + }, + }) + if err != nil { + t.Fatalf("Marshal returned error: %v", err) + } + if err := store.Put(ref, payload); err != nil { + t.Fatalf("Put returned error: %v", err) + } + + resolved, err := service.resolveProviderConfigSecrets(ai.ProviderConfig{ + ID: "openai-main", + SecretRef: ref, + HasSecret: true, + Headers: map[string]string{ + "X-Team": "db", + }, + }) + if err != nil { + t.Fatalf("resolveProviderConfigSecrets returned error: %v", err) + } + if resolved.APIKey != "sk-test" { + t.Fatalf("expected restored apiKey, got %q", resolved.APIKey) + } + if resolved.Headers["Authorization"] != "Bearer test" { + t.Fatalf("expected restored sensitive header, got %#v", resolved.Headers) + } + if resolved.Headers["X-Team"] != "db" { + t.Fatalf("expected non-sensitive header to survive, got %#v", resolved.Headers) + } +} + +func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + legacy := aiConfig{ + Providers: []ai.ProviderConfig{ + { + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-test", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "db", + }, + }, + }, + } + data, err := json.MarshalIndent(legacy, "", " ") + if err != nil { + t.Fatalf("MarshalIndent returned error: %v", err) + } + configPath := filepath.Join(service.configDir, "ai_config.json") + if err := os.WriteFile(configPath, data, 0o644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + + service.loadConfig() + + providers := service.AIGetProviders() + if len(providers) != 1 { + t.Fatalf("expected 1 provider, got %d", len(providers)) + } + if providers[0].APIKey != "" { + t.Fatalf("expected migrated provider to be secretless, got %q", providers[0].APIKey) + } + if !providers[0].HasSecret { + t.Fatal("expected migrated provider to report HasSecret=true") + } + stored, err := store.Get(providers[0].SecretRef) + if err != nil { + t.Fatalf("expected secret bundle in store, got error: %v", err) + } + var bundle providerSecretBundle + if err := json.Unmarshal(stored, &bundle); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if bundle.APIKey != "sk-test" { + t.Fatalf("expected migrated apiKey in store, got %q", bundle.APIKey) + } + if bundle.SensitiveHeaders["Authorization"] != "Bearer test" { + t.Fatalf("expected migrated sensitive header in store, got %#v", bundle.SensitiveHeaders) + } + + rewritten, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + text := string(rewritten) + if strings.Contains(text, "sk-test") { + t.Fatalf("expected rewritten config to remove api key, got %s", text) + } + if strings.Contains(text, "Bearer test") { + t.Fatalf("expected rewritten config to remove sensitive header, got %s", text) + } +} + +func TestAISaveProviderPersistsSecretlessConfigAndReturnsSecretlessView(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-test", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer test", + "X-Team": "db", + }, + }) + if err != nil { + t.Fatalf("AISaveProvider returned error: %v", err) + } + + providers := service.AIGetProviders() + if len(providers) != 1 { + t.Fatalf("expected 1 provider, got %d", len(providers)) + } + if providers[0].APIKey != "" { + t.Fatalf("expected secretless provider view, got %q", providers[0].APIKey) + } + if !providers[0].HasSecret { + t.Fatal("expected saved provider view to report HasSecret=true") + } + if providers[0].Headers["Authorization"] != "" { + t.Fatalf("expected secretless provider headers, got %#v", providers[0].Headers) + } + if service.providers[0].APIKey != "sk-test" { + t.Fatalf("expected runtime provider to keep apiKey, got %q", service.providers[0].APIKey) + } + if service.providers[0].Headers["Authorization"] != "Bearer test" { + t.Fatalf("expected runtime provider to keep sensitive header, got %#v", service.providers[0].Headers) + } + + configPath := filepath.Join(service.configDir, "ai_config.json") + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile returned error: %v", err) + } + text := string(data) + if strings.Contains(text, "sk-test") { + t.Fatalf("expected config file to be secretless, got %s", text) + } + if strings.Contains(text, "Bearer test") { + t.Fatalf("expected config file to remove sensitive headers, got %s", text) + } +} + +func TestAISaveProviderKeepsExistingSecretWhenInputOmitsAPIKey(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-original", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer original", + "X-Team": "db", + }, + }); err != nil { + t.Fatalf("initial AISaveProvider returned error: %v", err) + } + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI Updated", + BaseURL: "https://gateway.openai.com/v1", + HasSecret: true, + Headers: map[string]string{ + "X-Team": "platform", + }, + }); err != nil { + t.Fatalf("update AISaveProvider returned error: %v", err) + } + + if service.providers[0].APIKey != "sk-original" { + t.Fatalf("expected runtime provider to keep original apiKey, got %q", service.providers[0].APIKey) + } + if service.providers[0].Headers["Authorization"] != "Bearer original" { + t.Fatalf("expected runtime provider to keep original sensitive header, got %#v", service.providers[0].Headers) + } + if service.providers[0].Headers["X-Team"] != "platform" { + t.Fatalf("expected runtime provider to update non-sensitive headers, got %#v", service.providers[0].Headers) + } + if service.providers[0].BaseURL != "https://gateway.openai.com/v1" { + t.Fatalf("expected runtime provider to update metadata, got %q", service.providers[0].BaseURL) + } + + providers := service.AIGetProviders() + if len(providers) != 1 || !providers[0].HasSecret { + t.Fatalf("expected provider view to keep HasSecret=true, got %#v", providers) + } + if providers[0].APIKey != "" { + t.Fatalf("expected provider view to stay secretless, got %q", providers[0].APIKey) + } +} + +func TestAISaveProviderMergesStoredSensitiveHeadersWhenUpdatingOnlyAPIKey(t *testing.T) { + store := newFakeProviderSecretStore() + service := NewServiceWithSecretStore(store) + service.configDir = t.TempDir() + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-original", + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "Authorization": "Bearer original", + "X-Team": "db", + }, + }); err != nil { + t.Fatalf("initial AISaveProvider returned error: %v", err) + } + + if err := service.AISaveProvider(ai.ProviderConfig{ + ID: "openai-main", + Type: "openai", + Name: "OpenAI", + APIKey: "sk-updated", + HasSecret: true, + BaseURL: "https://api.openai.com/v1", + Headers: map[string]string{ + "X-Team": "db", + }, + }); err != nil { + t.Fatalf("update AISaveProvider returned error: %v", err) + } + + if service.providers[0].APIKey != "sk-updated" { + t.Fatalf("expected updated apiKey, got %q", service.providers[0].APIKey) + } + if service.providers[0].Headers["Authorization"] != "Bearer original" { + t.Fatalf("expected existing sensitive header to be kept, got %#v", service.providers[0].Headers) + } + + stored, err := store.Get(service.providers[0].SecretRef) + if err != nil { + t.Fatalf("expected merged secret bundle in store, got %v", err) + } + var bundle providerSecretBundle + if err := json.Unmarshal(stored, &bundle); err != nil { + t.Fatalf("Unmarshal returned error: %v", err) + } + if bundle.APIKey != "sk-updated" { + t.Fatalf("expected store to keep updated apiKey, got %q", bundle.APIKey) + } + if bundle.SensitiveHeaders["Authorization"] != "Bearer original" { + t.Fatalf("expected store to keep existing sensitive header, got %#v", bundle.SensitiveHeaders) + } +} + +type fakeProviderSecretStore struct { + items map[string][]byte +} + +func newFakeProviderSecretStore() *fakeProviderSecretStore { + return &fakeProviderSecretStore{items: make(map[string][]byte)} +} + +func (s *fakeProviderSecretStore) Put(ref string, payload []byte) error { + s.items[ref] = append([]byte(nil), payload...) + return nil +} + +func (s *fakeProviderSecretStore) Get(ref string) ([]byte, error) { + payload, ok := s.items[ref] + if !ok { + return nil, os.ErrNotExist + } + return append([]byte(nil), payload...), nil +} + +func (s *fakeProviderSecretStore) Delete(ref string) error { + delete(s.items, ref) + return nil +} + +func (s *fakeProviderSecretStore) HealthCheck() error { + return nil +} + +var _ secretstore.SecretStore = (*fakeProviderSecretStore)(nil) diff --git a/internal/ai/service/service.go b/internal/ai/service/service.go index 6897820..a5bd44e 100644 --- a/internal/ai/service/service.go +++ b/internal/ai/service/service.go @@ -18,6 +18,7 @@ import ( "GoNavi-Wails/internal/ai/provider" "GoNavi-Wails/internal/ai/safety" "GoNavi-Wails/internal/logger" + "GoNavi-Wails/internal/secretstore" "github.com/google/uuid" wailsRuntime "github.com/wailsapp/wails/v2/pkg/runtime" @@ -32,7 +33,8 @@ type Service struct { safetyLevel ai.SQLPermissionLevel contextLevel ai.ContextLevel guard *safety.Guard - configDir string // 配置存储目录 + configDir string // 配置存储目录 + secretStore secretstore.SecretStore cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数 } @@ -97,11 +99,19 @@ var claudeCLIHealthCheckFunc = func(config ai.ProviderConfig) error { // NewService 创建 AI Service 实例 func NewService() *Service { + return NewServiceWithSecretStore(secretstore.NewKeyringStore()) +} + +func NewServiceWithSecretStore(store secretstore.SecretStore) *Service { + if store == nil { + store = secretstore.NewUnavailableStore("secret store unavailable") + } return &Service{ providers: make([]ai.ProviderConfig, 0), safetyLevel: ai.PermissionReadOnly, contextLevel: ai.ContextSchemaOnly, guard: safety.NewGuard(ai.PermissionReadOnly), + secretStore: store, cancelFuncs: make(map[string]context.CancelFunc), } } @@ -127,35 +137,80 @@ func (s *Service) AIGetProviders() []ai.ProviderConfig { defer s.mu.RUnlock() result := make([]ai.ProviderConfig, len(s.providers)) - copy(result, s.providers) - for i := range result { - result[i] = normalizeProviderConfig(result[i]) + for i := range s.providers { + result[i] = providerMetadataView(s.providers[i]) } return result } // AISaveProvider 保存/更新 Provider 配置 func (s *Service) AISaveProvider(config ai.ProviderConfig) error { - fmt.Printf("[AISaveProvider DEBUG] ID: %s, Model: %s\n", config.ID, config.Model) s.mu.Lock() defer s.mu.Unlock() config = normalizeProviderConfig(config) - if strings.TrimSpace(config.ID) == "" { config.ID = "provider-" + uuid.New().String()[:8] } + var existing ai.ProviderConfig found := false - for i, p := range s.providers { - if p.ID == config.ID { - s.providers[i] = config + for _, providerConfig := range s.providers { + if providerConfig.ID == config.ID { + existing = providerConfig found = true break } } - if !found { - s.providers = append(s.providers, config) + + meta, bundle := splitProviderSecrets(config) + var runtimeConfig ai.ProviderConfig + switch { + case bundle.hasAny(): + mergedBundle := bundle + if found && existing.HasSecret { + _, existingBundle := splitProviderSecrets(existing) + mergedBundle = mergeProviderSecretBundles(existingBundle, bundle) + } + if found && strings.TrimSpace(meta.SecretRef) == "" { + meta.SecretRef = existing.SecretRef + } + storedMeta, err := s.persistProviderSecretBundle(meta, mergedBundle) + if err != nil { + return fmt.Errorf("保存 Provider secret 失败: %w", err) + } + runtimeConfig = mergeProviderSecrets(storedMeta, mergedBundle) + case found && (config.HasSecret || existing.HasSecret): + meta.SecretRef = existing.SecretRef + meta.HasSecret = config.HasSecret || existing.HasSecret + resolved, err := s.resolveProviderConfigSecrets(meta) + if err != nil { + return fmt.Errorf("读取已保存 Provider secret 失败: %w", err) + } + runtimeConfig = resolved + default: + runtimeConfig = meta + } + + if !runtimeConfig.HasSecret && found && strings.TrimSpace(existing.SecretRef) != "" { + if err := s.secretStore.Delete(existing.SecretRef); err != nil { + return fmt.Errorf("删除 Provider secret 失败: %w", err) + } + } + if !runtimeConfig.HasSecret { + runtimeConfig.SecretRef = "" + } + + runtimeConfig = normalizeProviderConfig(runtimeConfig) + if found { + for i := range s.providers { + if s.providers[i].ID == runtimeConfig.ID { + s.providers[i] = runtimeConfig + break + } + } + } else { + s.providers = append(s.providers, runtimeConfig) } return s.saveConfig() @@ -167,9 +222,19 @@ func (s *Service) AIDeleteProvider(id string) error { defer s.mu.Unlock() newProviders := make([]ai.ProviderConfig, 0, len(s.providers)) - for _, p := range s.providers { - if p.ID != id { - newProviders = append(newProviders, p) + var removed ai.ProviderConfig + removedFound := false + for _, providerConfig := range s.providers { + if providerConfig.ID == id { + removed = providerConfig + removedFound = true + continue + } + newProviders = append(newProviders, providerConfig) + } + if removedFound && strings.TrimSpace(removed.SecretRef) != "" { + if err := s.secretStore.Delete(removed.SecretRef); err != nil { + return fmt.Errorf("删除 Provider secret 失败: %w", err) } } s.providers = newProviders @@ -186,17 +251,29 @@ func (s *Service) AIDeleteProvider(id string) error { // AITestProvider 测试 Provider 配置是否可用,仅测试端点连通性与密钥,不实际调用对话 func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{} { - // 如果传入脱敏的 key,使用已保存的 key - s.mu.RLock() if isMaskedAPIKey(config.APIKey) { - for _, p := range s.providers { - if p.ID == config.ID { - config.APIKey = p.APIKey - break + config.APIKey = "" + config.HasSecret = true + } + if strings.TrimSpace(config.APIKey) == "" && (config.HasSecret || strings.TrimSpace(config.SecretRef) != "") { + s.mu.RLock() + if strings.TrimSpace(config.SecretRef) == "" { + for _, providerConfig := range s.providers { + if providerConfig.ID == config.ID { + config.SecretRef = providerConfig.SecretRef + config.HasSecret = config.HasSecret || providerConfig.HasSecret + break + } } } + s.mu.RUnlock() + + resolved, err := s.resolveProviderConfigSecrets(config) + if err != nil { + return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())} + } + config = resolved } - s.mu.RUnlock() config = normalizeProviderConfig(config) baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/") @@ -842,13 +919,35 @@ func (s *Service) getActiveProvider() (provider.Provider, error) { // --- 配置持久化 --- +const aiConfigSchemaVersion = 2 + type aiConfig struct { + SchemaVersion int `json:"schemaVersion,omitempty"` Providers []ai.ProviderConfig `json:"providers"` ActiveProvider string `json:"activeProvider"` SafetyLevel string `json:"safetyLevel"` ContextLevel string `json:"contextLevel"` } +func (s *Service) loadRuntimeProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, bool, error) { + meta, bundle := splitProviderSecrets(config) + if bundle.hasAny() { + storedMeta, err := s.persistProviderSecretBundle(meta, bundle) + if err != nil { + meta.HasSecret = false + meta.SecretRef = "" + return meta, true, err + } + return mergeProviderSecrets(storedMeta, bundle), true, nil + } + + resolved, err := s.resolveProviderConfigSecrets(meta) + if err != nil { + return meta, false, err + } + return resolved, false, nil +} + func (s *Service) loadConfig() { path := filepath.Join(s.configDir, "ai_config.json") data, err := os.ReadFile(path) @@ -862,13 +961,22 @@ func (s *Service) loadConfig() { return } - s.providers = cfg.Providers - if s.providers == nil { - s.providers = make([]ai.ProviderConfig, 0) + providers := make([]ai.ProviderConfig, 0, len(cfg.Providers)) + shouldRewrite := cfg.SchemaVersion != aiConfigSchemaVersion + for _, providerConfig := range cfg.Providers { + runtimeConfig, rewritten, err := s.loadRuntimeProviderConfig(normalizeProviderConfig(providerConfig)) + if err != nil { + logger.Error(err, "加载 AI Provider secret 失败,provider=%s", providerConfig.ID) + } + if rewritten { + shouldRewrite = true + } + providers = append(providers, runtimeConfig) } - for i := range s.providers { - s.providers[i] = normalizeProviderConfig(s.providers[i]) + if providers == nil { + providers = make([]ai.ProviderConfig, 0) } + s.providers = providers s.activeProvider = cfg.ActiveProvider switch ai.SQLPermissionLevel(cfg.SafetyLevel) { @@ -885,11 +993,23 @@ func (s *Service) loadConfig() { default: s.contextLevel = ai.ContextSchemaOnly } + + if shouldRewrite { + if err := s.saveConfig(); err != nil { + logger.Error(err, "重写 AI 配置失败") + } + } } func (s *Service) saveConfig() error { + providers := make([]ai.ProviderConfig, len(s.providers)) + for i := range s.providers { + providers[i] = providerMetadataView(s.providers[i]) + } + cfg := aiConfig{ - Providers: s.providers, + SchemaVersion: aiConfigSchemaVersion, + Providers: providers, ActiveProvider: s.activeProvider, SafetyLevel: string(s.safetyLevel), ContextLevel: string(s.contextLevel), diff --git a/internal/ai/types.go b/internal/ai/types.go index 5f4ddae..790b023 100644 --- a/internal/ai/types.go +++ b/internal/ai/types.go @@ -69,6 +69,8 @@ type ProviderConfig struct { Type string `json:"type"` // openai | anthropic | gemini | custom Name string `json:"name"` APIKey string `json:"apiKey"` + SecretRef string `json:"secretRef,omitempty"` + HasSecret bool `json:"hasSecret,omitempty"` BaseURL string `json:"baseUrl"` Model string `json:"model"` Models []string `json:"models,omitempty"`