mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-05-11 19:29:44 +08:00
✨ feat(security): 拆分 AI 供应商元数据与密钥存储
This commit is contained in:
231
internal/ai/service/provider_secret.go
Normal file
231
internal/ai/service/provider_secret.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
)
|
||||
|
||||
const providerSecretKind = "ai-provider"
|
||||
|
||||
type providerSecretBundle struct {
|
||||
APIKey string `json:"apiKey,omitempty"`
|
||||
SensitiveHeaders map[string]string `json:"sensitiveHeaders,omitempty"`
|
||||
}
|
||||
|
||||
func (b providerSecretBundle) hasAny() bool {
|
||||
return strings.TrimSpace(b.APIKey) != "" || len(b.SensitiveHeaders) > 0
|
||||
}
|
||||
|
||||
func mergeProviderSecretBundles(base, overlay providerSecretBundle) providerSecretBundle {
|
||||
merged := providerSecretBundle{
|
||||
APIKey: base.APIKey,
|
||||
SensitiveHeaders: cloneStringMap(base.SensitiveHeaders),
|
||||
}
|
||||
if strings.TrimSpace(overlay.APIKey) != "" {
|
||||
merged.APIKey = overlay.APIKey
|
||||
}
|
||||
for key, value := range overlay.SensitiveHeaders {
|
||||
if merged.SensitiveHeaders == nil {
|
||||
merged.SensitiveHeaders = make(map[string]string, len(overlay.SensitiveHeaders))
|
||||
}
|
||||
merged.SensitiveHeaders[key] = value
|
||||
}
|
||||
if len(merged.SensitiveHeaders) == 0 {
|
||||
merged.SensitiveHeaders = nil
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func splitProviderSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, providerSecretBundle) {
|
||||
meta := cfg
|
||||
meta.APIKey = ""
|
||||
|
||||
bundle := providerSecretBundle{}
|
||||
if apiKey := strings.TrimSpace(cfg.APIKey); apiKey != "" {
|
||||
bundle.APIKey = apiKey
|
||||
}
|
||||
|
||||
if len(cfg.Headers) > 0 {
|
||||
safeHeaders := make(map[string]string, len(cfg.Headers))
|
||||
sensitiveHeaders := make(map[string]string)
|
||||
for key, value := range cfg.Headers {
|
||||
if isSensitiveProviderHeader(key) {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
sensitiveHeaders[key] = value
|
||||
}
|
||||
continue
|
||||
}
|
||||
safeHeaders[key] = value
|
||||
}
|
||||
if len(safeHeaders) > 0 {
|
||||
meta.Headers = safeHeaders
|
||||
} else {
|
||||
meta.Headers = nil
|
||||
}
|
||||
if len(sensitiveHeaders) > 0 {
|
||||
bundle.SensitiveHeaders = sensitiveHeaders
|
||||
}
|
||||
} else {
|
||||
meta.Headers = nil
|
||||
}
|
||||
|
||||
meta.HasSecret = cfg.HasSecret || bundle.hasAny()
|
||||
meta.SecretRef = strings.TrimSpace(cfg.SecretRef)
|
||||
if meta.HasSecret && meta.SecretRef == "" && strings.TrimSpace(cfg.ID) != "" {
|
||||
if ref, err := secretstore.BuildRef(providerSecretKind, cfg.ID); err == nil {
|
||||
meta.SecretRef = ref
|
||||
}
|
||||
}
|
||||
if !meta.HasSecret {
|
||||
meta.SecretRef = ""
|
||||
}
|
||||
|
||||
return meta, bundle
|
||||
}
|
||||
|
||||
func mergeProviderSecrets(cfg ai.ProviderConfig, bundle providerSecretBundle) ai.ProviderConfig {
|
||||
merged := cfg
|
||||
merged.APIKey = bundle.APIKey
|
||||
|
||||
headers := cloneStringMap(cfg.Headers)
|
||||
if len(bundle.SensitiveHeaders) > 0 {
|
||||
if headers == nil {
|
||||
headers = make(map[string]string, len(bundle.SensitiveHeaders))
|
||||
}
|
||||
for key, value := range bundle.SensitiveHeaders {
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
merged.Headers = headers
|
||||
} else {
|
||||
merged.Headers = nil
|
||||
}
|
||||
|
||||
merged.HasSecret = cfg.HasSecret || bundle.hasAny()
|
||||
if merged.HasSecret && strings.TrimSpace(merged.SecretRef) == "" && strings.TrimSpace(merged.ID) != "" {
|
||||
if ref, err := secretstore.BuildRef(providerSecretKind, merged.ID); err == nil {
|
||||
merged.SecretRef = ref
|
||||
}
|
||||
}
|
||||
if !merged.HasSecret {
|
||||
merged.SecretRef = ""
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func (s *Service) persistProviderSecretBundle(meta ai.ProviderConfig, bundle providerSecretBundle) (ai.ProviderConfig, error) {
|
||||
meta, _ = splitProviderSecrets(meta)
|
||||
if !bundle.hasAny() {
|
||||
meta.HasSecret = false
|
||||
meta.SecretRef = ""
|
||||
return meta, nil
|
||||
}
|
||||
if s.secretStore == nil {
|
||||
return meta, fmt.Errorf("secret store unavailable")
|
||||
}
|
||||
if err := s.secretStore.HealthCheck(); err != nil {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
ref := strings.TrimSpace(meta.SecretRef)
|
||||
if ref == "" {
|
||||
var err error
|
||||
ref, err = secretstore.BuildRef(providerSecretKind, meta.ID)
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(bundle)
|
||||
if err != nil {
|
||||
return meta, fmt.Errorf("序列化 provider secret bundle 失败: %w", err)
|
||||
}
|
||||
if err := s.secretStore.Put(ref, payload); err != nil {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
meta.SecretRef = ref
|
||||
meta.HasSecret = true
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveProviderConfigSecrets(cfg ai.ProviderConfig) (ai.ProviderConfig, error) {
|
||||
cfg = normalizeProviderConfig(cfg)
|
||||
meta, bundle := splitProviderSecrets(cfg)
|
||||
if bundle.hasAny() {
|
||||
return mergeProviderSecrets(meta, bundle), nil
|
||||
}
|
||||
if !meta.HasSecret {
|
||||
return meta, nil
|
||||
}
|
||||
if s.secretStore == nil {
|
||||
return meta, fmt.Errorf("secret store unavailable")
|
||||
}
|
||||
|
||||
ref := strings.TrimSpace(meta.SecretRef)
|
||||
if ref == "" {
|
||||
var err error
|
||||
ref, err = secretstore.BuildRef(providerSecretKind, meta.ID)
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
meta.SecretRef = ref
|
||||
}
|
||||
|
||||
payload, err := s.secretStore.Get(ref)
|
||||
if err != nil {
|
||||
return meta, err
|
||||
}
|
||||
|
||||
var stored providerSecretBundle
|
||||
if err := json.Unmarshal(payload, &stored); err != nil {
|
||||
return meta, fmt.Errorf("解析 provider secret bundle 失败: %w", err)
|
||||
}
|
||||
return mergeProviderSecrets(meta, stored), nil
|
||||
}
|
||||
|
||||
func providerMetadataView(cfg ai.ProviderConfig) ai.ProviderConfig {
|
||||
meta, _ := splitProviderSecrets(normalizeProviderConfig(cfg))
|
||||
return meta
|
||||
}
|
||||
|
||||
func isSensitiveProviderHeader(name string) bool {
|
||||
normalized := strings.TrimSpace(strings.ToLower(name))
|
||||
switch normalized {
|
||||
case "authorization", "proxy-authorization", "x-api-key", "api-key":
|
||||
return true
|
||||
}
|
||||
|
||||
for _, token := range providerHeaderTokens(normalized) {
|
||||
switch token {
|
||||
case "auth", "authorization", "token", "secret", "key", "apikey":
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func providerHeaderTokens(name string) []string {
|
||||
return strings.FieldsFunc(name, func(r rune) bool {
|
||||
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
|
||||
})
|
||||
}
|
||||
|
||||
func cloneStringMap(input map[string]string) map[string]string {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make(map[string]string, len(input))
|
||||
for key, value := range input {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
348
internal/ai/service/provider_secret_test.go
Normal file
348
internal/ai/service/provider_secret_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
)
|
||||
|
||||
func TestSplitProviderSecretsStripsAPIKeyAndSensitiveHeaders(t *testing.T) {
|
||||
input := ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test",
|
||||
"X-Team": "db",
|
||||
},
|
||||
}
|
||||
|
||||
meta, bundle := splitProviderSecrets(input)
|
||||
if meta.APIKey != "" {
|
||||
t.Fatal("apiKey should not stay in metadata")
|
||||
}
|
||||
if meta.Headers["Authorization"] != "" {
|
||||
t.Fatal("sensitive header should not stay in metadata")
|
||||
}
|
||||
if meta.Headers["X-Team"] != "db" {
|
||||
t.Fatal("non-sensitive header should stay in metadata")
|
||||
}
|
||||
if bundle.APIKey != "sk-test" {
|
||||
t.Fatal("bundle should keep apiKey")
|
||||
}
|
||||
if bundle.SensitiveHeaders["Authorization"] != "Bearer test" {
|
||||
t.Fatal("bundle should keep sensitive header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveProviderConfigSecretsRestoresStoredSecretBundle(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
ref, err := secretstore.BuildRef("ai-provider", "openai-main")
|
||||
if err != nil {
|
||||
t.Fatalf("BuildRef returned error: %v", err)
|
||||
}
|
||||
payload, err := json.Marshal(providerSecretBundle{
|
||||
APIKey: "sk-test",
|
||||
SensitiveHeaders: map[string]string{
|
||||
"Authorization": "Bearer test",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal returned error: %v", err)
|
||||
}
|
||||
if err := store.Put(ref, payload); err != nil {
|
||||
t.Fatalf("Put returned error: %v", err)
|
||||
}
|
||||
|
||||
resolved, err := service.resolveProviderConfigSecrets(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
SecretRef: ref,
|
||||
HasSecret: true,
|
||||
Headers: map[string]string{
|
||||
"X-Team": "db",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resolveProviderConfigSecrets returned error: %v", err)
|
||||
}
|
||||
if resolved.APIKey != "sk-test" {
|
||||
t.Fatalf("expected restored apiKey, got %q", resolved.APIKey)
|
||||
}
|
||||
if resolved.Headers["Authorization"] != "Bearer test" {
|
||||
t.Fatalf("expected restored sensitive header, got %#v", resolved.Headers)
|
||||
}
|
||||
if resolved.Headers["X-Team"] != "db" {
|
||||
t.Fatalf("expected non-sensitive header to survive, got %#v", resolved.Headers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigMigratesPlaintextProviderSecrets(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
legacy := aiConfig{
|
||||
Providers: []ai.ProviderConfig{
|
||||
{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test",
|
||||
"X-Team": "db",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(legacy, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalIndent returned error: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(service.configDir, "ai_config.json")
|
||||
if err := os.WriteFile(configPath, data, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile returned error: %v", err)
|
||||
}
|
||||
|
||||
service.loadConfig()
|
||||
|
||||
providers := service.AIGetProviders()
|
||||
if len(providers) != 1 {
|
||||
t.Fatalf("expected 1 provider, got %d", len(providers))
|
||||
}
|
||||
if providers[0].APIKey != "" {
|
||||
t.Fatalf("expected migrated provider to be secretless, got %q", providers[0].APIKey)
|
||||
}
|
||||
if !providers[0].HasSecret {
|
||||
t.Fatal("expected migrated provider to report HasSecret=true")
|
||||
}
|
||||
stored, err := store.Get(providers[0].SecretRef)
|
||||
if err != nil {
|
||||
t.Fatalf("expected secret bundle in store, got error: %v", err)
|
||||
}
|
||||
var bundle providerSecretBundle
|
||||
if err := json.Unmarshal(stored, &bundle); err != nil {
|
||||
t.Fatalf("Unmarshal returned error: %v", err)
|
||||
}
|
||||
if bundle.APIKey != "sk-test" {
|
||||
t.Fatalf("expected migrated apiKey in store, got %q", bundle.APIKey)
|
||||
}
|
||||
if bundle.SensitiveHeaders["Authorization"] != "Bearer test" {
|
||||
t.Fatalf("expected migrated sensitive header in store, got %#v", bundle.SensitiveHeaders)
|
||||
}
|
||||
|
||||
rewritten, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
text := string(rewritten)
|
||||
if strings.Contains(text, "sk-test") {
|
||||
t.Fatalf("expected rewritten config to remove api key, got %s", text)
|
||||
}
|
||||
if strings.Contains(text, "Bearer test") {
|
||||
t.Fatalf("expected rewritten config to remove sensitive header, got %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAISaveProviderPersistsSecretlessConfigAndReturnsSecretlessView(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test",
|
||||
"X-Team": "db",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
providers := service.AIGetProviders()
|
||||
if len(providers) != 1 {
|
||||
t.Fatalf("expected 1 provider, got %d", len(providers))
|
||||
}
|
||||
if providers[0].APIKey != "" {
|
||||
t.Fatalf("expected secretless provider view, got %q", providers[0].APIKey)
|
||||
}
|
||||
if !providers[0].HasSecret {
|
||||
t.Fatal("expected saved provider view to report HasSecret=true")
|
||||
}
|
||||
if providers[0].Headers["Authorization"] != "" {
|
||||
t.Fatalf("expected secretless provider headers, got %#v", providers[0].Headers)
|
||||
}
|
||||
if service.providers[0].APIKey != "sk-test" {
|
||||
t.Fatalf("expected runtime provider to keep apiKey, got %q", service.providers[0].APIKey)
|
||||
}
|
||||
if service.providers[0].Headers["Authorization"] != "Bearer test" {
|
||||
t.Fatalf("expected runtime provider to keep sensitive header, got %#v", service.providers[0].Headers)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(service.configDir, "ai_config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile returned error: %v", err)
|
||||
}
|
||||
text := string(data)
|
||||
if strings.Contains(text, "sk-test") {
|
||||
t.Fatalf("expected config file to be secretless, got %s", text)
|
||||
}
|
||||
if strings.Contains(text, "Bearer test") {
|
||||
t.Fatalf("expected config file to remove sensitive headers, got %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAISaveProviderKeepsExistingSecretWhenInputOmitsAPIKey(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
if err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-original",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer original",
|
||||
"X-Team": "db",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("initial AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
if err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI Updated",
|
||||
BaseURL: "https://gateway.openai.com/v1",
|
||||
HasSecret: true,
|
||||
Headers: map[string]string{
|
||||
"X-Team": "platform",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("update AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
if service.providers[0].APIKey != "sk-original" {
|
||||
t.Fatalf("expected runtime provider to keep original apiKey, got %q", service.providers[0].APIKey)
|
||||
}
|
||||
if service.providers[0].Headers["Authorization"] != "Bearer original" {
|
||||
t.Fatalf("expected runtime provider to keep original sensitive header, got %#v", service.providers[0].Headers)
|
||||
}
|
||||
if service.providers[0].Headers["X-Team"] != "platform" {
|
||||
t.Fatalf("expected runtime provider to update non-sensitive headers, got %#v", service.providers[0].Headers)
|
||||
}
|
||||
if service.providers[0].BaseURL != "https://gateway.openai.com/v1" {
|
||||
t.Fatalf("expected runtime provider to update metadata, got %q", service.providers[0].BaseURL)
|
||||
}
|
||||
|
||||
providers := service.AIGetProviders()
|
||||
if len(providers) != 1 || !providers[0].HasSecret {
|
||||
t.Fatalf("expected provider view to keep HasSecret=true, got %#v", providers)
|
||||
}
|
||||
if providers[0].APIKey != "" {
|
||||
t.Fatalf("expected provider view to stay secretless, got %q", providers[0].APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAISaveProviderMergesStoredSensitiveHeadersWhenUpdatingOnlyAPIKey(t *testing.T) {
|
||||
store := newFakeProviderSecretStore()
|
||||
service := NewServiceWithSecretStore(store)
|
||||
service.configDir = t.TempDir()
|
||||
|
||||
if err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-original",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer original",
|
||||
"X-Team": "db",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("initial AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
if err := service.AISaveProvider(ai.ProviderConfig{
|
||||
ID: "openai-main",
|
||||
Type: "openai",
|
||||
Name: "OpenAI",
|
||||
APIKey: "sk-updated",
|
||||
HasSecret: true,
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Headers: map[string]string{
|
||||
"X-Team": "db",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("update AISaveProvider returned error: %v", err)
|
||||
}
|
||||
|
||||
if service.providers[0].APIKey != "sk-updated" {
|
||||
t.Fatalf("expected updated apiKey, got %q", service.providers[0].APIKey)
|
||||
}
|
||||
if service.providers[0].Headers["Authorization"] != "Bearer original" {
|
||||
t.Fatalf("expected existing sensitive header to be kept, got %#v", service.providers[0].Headers)
|
||||
}
|
||||
|
||||
stored, err := store.Get(service.providers[0].SecretRef)
|
||||
if err != nil {
|
||||
t.Fatalf("expected merged secret bundle in store, got %v", err)
|
||||
}
|
||||
var bundle providerSecretBundle
|
||||
if err := json.Unmarshal(stored, &bundle); err != nil {
|
||||
t.Fatalf("Unmarshal returned error: %v", err)
|
||||
}
|
||||
if bundle.APIKey != "sk-updated" {
|
||||
t.Fatalf("expected store to keep updated apiKey, got %q", bundle.APIKey)
|
||||
}
|
||||
if bundle.SensitiveHeaders["Authorization"] != "Bearer original" {
|
||||
t.Fatalf("expected store to keep existing sensitive header, got %#v", bundle.SensitiveHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeProviderSecretStore struct {
|
||||
items map[string][]byte
|
||||
}
|
||||
|
||||
func newFakeProviderSecretStore() *fakeProviderSecretStore {
|
||||
return &fakeProviderSecretStore{items: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
func (s *fakeProviderSecretStore) Put(ref string, payload []byte) error {
|
||||
s.items[ref] = append([]byte(nil), payload...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeProviderSecretStore) Get(ref string) ([]byte, error) {
|
||||
payload, ok := s.items[ref]
|
||||
if !ok {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
return append([]byte(nil), payload...), nil
|
||||
}
|
||||
|
||||
func (s *fakeProviderSecretStore) Delete(ref string) error {
|
||||
delete(s.items, ref)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeProviderSecretStore) HealthCheck() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ secretstore.SecretStore = (*fakeProviderSecretStore)(nil)
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"GoNavi-Wails/internal/ai/provider"
|
||||
"GoNavi-Wails/internal/ai/safety"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
"GoNavi-Wails/internal/secretstore"
|
||||
|
||||
"github.com/google/uuid"
|
||||
wailsRuntime "github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
@@ -32,7 +33,8 @@ type Service struct {
|
||||
safetyLevel ai.SQLPermissionLevel
|
||||
contextLevel ai.ContextLevel
|
||||
guard *safety.Guard
|
||||
configDir string // 配置存储目录
|
||||
configDir string // 配置存储目录
|
||||
secretStore secretstore.SecretStore
|
||||
cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
|
||||
}
|
||||
|
||||
@@ -97,11 +99,19 @@ var claudeCLIHealthCheckFunc = func(config ai.ProviderConfig) error {
|
||||
|
||||
// NewService 创建 AI Service 实例
|
||||
func NewService() *Service {
|
||||
return NewServiceWithSecretStore(secretstore.NewKeyringStore())
|
||||
}
|
||||
|
||||
func NewServiceWithSecretStore(store secretstore.SecretStore) *Service {
|
||||
if store == nil {
|
||||
store = secretstore.NewUnavailableStore("secret store unavailable")
|
||||
}
|
||||
return &Service{
|
||||
providers: make([]ai.ProviderConfig, 0),
|
||||
safetyLevel: ai.PermissionReadOnly,
|
||||
contextLevel: ai.ContextSchemaOnly,
|
||||
guard: safety.NewGuard(ai.PermissionReadOnly),
|
||||
secretStore: store,
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
@@ -127,35 +137,80 @@ func (s *Service) AIGetProviders() []ai.ProviderConfig {
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
result := make([]ai.ProviderConfig, len(s.providers))
|
||||
copy(result, s.providers)
|
||||
for i := range result {
|
||||
result[i] = normalizeProviderConfig(result[i])
|
||||
for i := range s.providers {
|
||||
result[i] = providerMetadataView(s.providers[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AISaveProvider 保存/更新 Provider 配置
|
||||
func (s *Service) AISaveProvider(config ai.ProviderConfig) error {
|
||||
fmt.Printf("[AISaveProvider DEBUG] ID: %s, Model: %s\n", config.ID, config.Model)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
config = normalizeProviderConfig(config)
|
||||
|
||||
if strings.TrimSpace(config.ID) == "" {
|
||||
config.ID = "provider-" + uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
var existing ai.ProviderConfig
|
||||
found := false
|
||||
for i, p := range s.providers {
|
||||
if p.ID == config.ID {
|
||||
s.providers[i] = config
|
||||
for _, providerConfig := range s.providers {
|
||||
if providerConfig.ID == config.ID {
|
||||
existing = providerConfig
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
s.providers = append(s.providers, config)
|
||||
|
||||
meta, bundle := splitProviderSecrets(config)
|
||||
var runtimeConfig ai.ProviderConfig
|
||||
switch {
|
||||
case bundle.hasAny():
|
||||
mergedBundle := bundle
|
||||
if found && existing.HasSecret {
|
||||
_, existingBundle := splitProviderSecrets(existing)
|
||||
mergedBundle = mergeProviderSecretBundles(existingBundle, bundle)
|
||||
}
|
||||
if found && strings.TrimSpace(meta.SecretRef) == "" {
|
||||
meta.SecretRef = existing.SecretRef
|
||||
}
|
||||
storedMeta, err := s.persistProviderSecretBundle(meta, mergedBundle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存 Provider secret 失败: %w", err)
|
||||
}
|
||||
runtimeConfig = mergeProviderSecrets(storedMeta, mergedBundle)
|
||||
case found && (config.HasSecret || existing.HasSecret):
|
||||
meta.SecretRef = existing.SecretRef
|
||||
meta.HasSecret = config.HasSecret || existing.HasSecret
|
||||
resolved, err := s.resolveProviderConfigSecrets(meta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取已保存 Provider secret 失败: %w", err)
|
||||
}
|
||||
runtimeConfig = resolved
|
||||
default:
|
||||
runtimeConfig = meta
|
||||
}
|
||||
|
||||
if !runtimeConfig.HasSecret && found && strings.TrimSpace(existing.SecretRef) != "" {
|
||||
if err := s.secretStore.Delete(existing.SecretRef); err != nil {
|
||||
return fmt.Errorf("删除 Provider secret 失败: %w", err)
|
||||
}
|
||||
}
|
||||
if !runtimeConfig.HasSecret {
|
||||
runtimeConfig.SecretRef = ""
|
||||
}
|
||||
|
||||
runtimeConfig = normalizeProviderConfig(runtimeConfig)
|
||||
if found {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].ID == runtimeConfig.ID {
|
||||
s.providers[i] = runtimeConfig
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.providers = append(s.providers, runtimeConfig)
|
||||
}
|
||||
|
||||
return s.saveConfig()
|
||||
@@ -167,9 +222,19 @@ func (s *Service) AIDeleteProvider(id string) error {
|
||||
defer s.mu.Unlock()
|
||||
|
||||
newProviders := make([]ai.ProviderConfig, 0, len(s.providers))
|
||||
for _, p := range s.providers {
|
||||
if p.ID != id {
|
||||
newProviders = append(newProviders, p)
|
||||
var removed ai.ProviderConfig
|
||||
removedFound := false
|
||||
for _, providerConfig := range s.providers {
|
||||
if providerConfig.ID == id {
|
||||
removed = providerConfig
|
||||
removedFound = true
|
||||
continue
|
||||
}
|
||||
newProviders = append(newProviders, providerConfig)
|
||||
}
|
||||
if removedFound && strings.TrimSpace(removed.SecretRef) != "" {
|
||||
if err := s.secretStore.Delete(removed.SecretRef); err != nil {
|
||||
return fmt.Errorf("删除 Provider secret 失败: %w", err)
|
||||
}
|
||||
}
|
||||
s.providers = newProviders
|
||||
@@ -186,17 +251,29 @@ func (s *Service) AIDeleteProvider(id string) error {
|
||||
|
||||
// AITestProvider 测试 Provider 配置是否可用,仅测试端点连通性与密钥,不实际调用对话
|
||||
func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{} {
|
||||
// 如果传入脱敏的 key,使用已保存的 key
|
||||
s.mu.RLock()
|
||||
if isMaskedAPIKey(config.APIKey) {
|
||||
for _, p := range s.providers {
|
||||
if p.ID == config.ID {
|
||||
config.APIKey = p.APIKey
|
||||
break
|
||||
config.APIKey = ""
|
||||
config.HasSecret = true
|
||||
}
|
||||
if strings.TrimSpace(config.APIKey) == "" && (config.HasSecret || strings.TrimSpace(config.SecretRef) != "") {
|
||||
s.mu.RLock()
|
||||
if strings.TrimSpace(config.SecretRef) == "" {
|
||||
for _, providerConfig := range s.providers {
|
||||
if providerConfig.ID == config.ID {
|
||||
config.SecretRef = providerConfig.SecretRef
|
||||
config.HasSecret = config.HasSecret || providerConfig.HasSecret
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
resolved, err := s.resolveProviderConfigSecrets(config)
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
|
||||
}
|
||||
config = resolved
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
config = normalizeProviderConfig(config)
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
@@ -842,13 +919,35 @@ func (s *Service) getActiveProvider() (provider.Provider, error) {
|
||||
|
||||
// --- 配置持久化 ---
|
||||
|
||||
const aiConfigSchemaVersion = 2
|
||||
|
||||
type aiConfig struct {
|
||||
SchemaVersion int `json:"schemaVersion,omitempty"`
|
||||
Providers []ai.ProviderConfig `json:"providers"`
|
||||
ActiveProvider string `json:"activeProvider"`
|
||||
SafetyLevel string `json:"safetyLevel"`
|
||||
ContextLevel string `json:"contextLevel"`
|
||||
}
|
||||
|
||||
func (s *Service) loadRuntimeProviderConfig(config ai.ProviderConfig) (ai.ProviderConfig, bool, error) {
|
||||
meta, bundle := splitProviderSecrets(config)
|
||||
if bundle.hasAny() {
|
||||
storedMeta, err := s.persistProviderSecretBundle(meta, bundle)
|
||||
if err != nil {
|
||||
meta.HasSecret = false
|
||||
meta.SecretRef = ""
|
||||
return meta, true, err
|
||||
}
|
||||
return mergeProviderSecrets(storedMeta, bundle), true, nil
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProviderConfigSecrets(meta)
|
||||
if err != nil {
|
||||
return meta, false, err
|
||||
}
|
||||
return resolved, false, nil
|
||||
}
|
||||
|
||||
func (s *Service) loadConfig() {
|
||||
path := filepath.Join(s.configDir, "ai_config.json")
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -862,13 +961,22 @@ func (s *Service) loadConfig() {
|
||||
return
|
||||
}
|
||||
|
||||
s.providers = cfg.Providers
|
||||
if s.providers == nil {
|
||||
s.providers = make([]ai.ProviderConfig, 0)
|
||||
providers := make([]ai.ProviderConfig, 0, len(cfg.Providers))
|
||||
shouldRewrite := cfg.SchemaVersion != aiConfigSchemaVersion
|
||||
for _, providerConfig := range cfg.Providers {
|
||||
runtimeConfig, rewritten, err := s.loadRuntimeProviderConfig(normalizeProviderConfig(providerConfig))
|
||||
if err != nil {
|
||||
logger.Error(err, "加载 AI Provider secret 失败,provider=%s", providerConfig.ID)
|
||||
}
|
||||
if rewritten {
|
||||
shouldRewrite = true
|
||||
}
|
||||
providers = append(providers, runtimeConfig)
|
||||
}
|
||||
for i := range s.providers {
|
||||
s.providers[i] = normalizeProviderConfig(s.providers[i])
|
||||
if providers == nil {
|
||||
providers = make([]ai.ProviderConfig, 0)
|
||||
}
|
||||
s.providers = providers
|
||||
s.activeProvider = cfg.ActiveProvider
|
||||
|
||||
switch ai.SQLPermissionLevel(cfg.SafetyLevel) {
|
||||
@@ -885,11 +993,23 @@ func (s *Service) loadConfig() {
|
||||
default:
|
||||
s.contextLevel = ai.ContextSchemaOnly
|
||||
}
|
||||
|
||||
if shouldRewrite {
|
||||
if err := s.saveConfig(); err != nil {
|
||||
logger.Error(err, "重写 AI 配置失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) saveConfig() error {
|
||||
providers := make([]ai.ProviderConfig, len(s.providers))
|
||||
for i := range s.providers {
|
||||
providers[i] = providerMetadataView(s.providers[i])
|
||||
}
|
||||
|
||||
cfg := aiConfig{
|
||||
Providers: s.providers,
|
||||
SchemaVersion: aiConfigSchemaVersion,
|
||||
Providers: providers,
|
||||
ActiveProvider: s.activeProvider,
|
||||
SafetyLevel: string(s.safetyLevel),
|
||||
ContextLevel: string(s.contextLevel),
|
||||
|
||||
@@ -69,6 +69,8 @@ type ProviderConfig struct {
|
||||
Type string `json:"type"` // openai | anthropic | gemini | custom
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"apiKey"`
|
||||
SecretRef string `json:"secretRef,omitempty"`
|
||||
HasSecret bool `json:"hasSecret,omitempty"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
Model string `json:"model"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user