feat(ai-chat): 全面升级AI聊天面板并优化交互体验

- 消息管理:新增聊天气泡的重试、编辑与单条删除功能及相对应的持久化状态函数
- 快捷操作:支持长文一键滑动到底端,并在代码块内增加SQL一键送入编辑器的快捷执行机制
- 视觉优化:深化AI回复背景沉浸感,重绘AI洞察按钮并移除设置面板所有的冗余紫色调
- 设置调优:放宽模型初始必填限制,新增内置系统提示词(Builtin Prompt)全览面板
This commit is contained in:
Syngnat
2026-03-22 20:54:29 +08:00
parent 36a57f9601
commit 1bda751ada
35 changed files with 6745 additions and 132 deletions

View File

@@ -0,0 +1,213 @@
package aicontext
import (
"fmt"
"strings"
)
// PromptTemplate AI 能力类型
type PromptTemplate string
const (
PromptSQLGenerate PromptTemplate = "sql_generate"
PromptSQLExplain PromptTemplate = "sql_explain"
PromptSQLOptimize PromptTemplate = "sql_optimize"
PromptDataAnalyze PromptTemplate = "data_analyze"
PromptSchemaInsight PromptTemplate = "schema_insight"
PromptGeneralChat PromptTemplate = "general_chat"
)
// GetBuiltinPrompts 获取所有内置系统提示词集合,用于前端展示
func GetBuiltinPrompts() map[string]string {
return map[string]string{
"通用聊天助手": buildGeneralChatPrompt(),
"SQL 生成器": buildSQLGeneratePrompt(),
"SQL 解析器": buildSQLExplainPrompt(),
"SQL 优化器": buildSQLOptimizePrompt(),
"数据洞察分析": buildDataAnalyzePrompt(),
"表结构审查": buildSchemaInsightPrompt(),
}
}
// BuildSystemPrompt 根据模板类型和上下文构建 System Prompt
func BuildSystemPrompt(template PromptTemplate, dbCtx *DatabaseContext) string {
var prompt string
switch template {
case PromptSQLGenerate:
prompt = buildSQLGeneratePrompt()
case PromptSQLExplain:
prompt = buildSQLExplainPrompt()
case PromptSQLOptimize:
prompt = buildSQLOptimizePrompt()
case PromptDataAnalyze:
prompt = buildDataAnalyzePrompt()
case PromptSchemaInsight:
prompt = buildSchemaInsightPrompt()
case PromptGeneralChat:
prompt = buildGeneralChatPrompt()
default:
prompt = buildGeneralChatPrompt()
}
if dbCtx != nil {
prompt += "\n\n" + FormatDatabaseContext(dbCtx)
}
return prompt
}
// FormatDatabaseContext 将数据库上下文格式化为 LLM 友好的文本
func FormatDatabaseContext(ctx *DatabaseContext) string {
if ctx == nil || len(ctx.Tables) == 0 {
return ""
}
var b strings.Builder
b.WriteString(fmt.Sprintf("## 当前数据库上下文\n\n数据库类型: %s\n数据库名: %s\n\n",
ctx.DatabaseType, ctx.DatabaseName))
b.WriteString("### 表结构\n\n")
for _, table := range ctx.Tables {
b.WriteString(fmt.Sprintf("#### 表: %s", table.Name))
if table.Comment != "" {
b.WriteString(fmt.Sprintf(" (%s)", table.Comment))
}
if table.RowCount > 0 {
b.WriteString(fmt.Sprintf(" [约 %d 行]", table.RowCount))
}
b.WriteString("\n\n")
b.WriteString("| 列名 | 类型 | 可空 | 主键 | 备注 |\n")
b.WriteString("|------|------|------|------|------|\n")
for _, col := range table.Columns {
nullable := "否"
if col.Nullable {
nullable = "是"
}
pk := ""
if col.PrimaryKey {
pk = "✓"
}
comment := col.Comment
if comment == "" {
comment = "-"
}
b.WriteString(fmt.Sprintf("| %s | %s | %s | %s | %s |\n",
col.Name, col.Type, nullable, pk, comment))
}
b.WriteString("\n")
if len(table.Indexes) > 0 {
b.WriteString("**索引:**\n")
for _, idx := range table.Indexes {
unique := ""
if idx.Unique {
unique = " (唯一)"
}
b.WriteString(fmt.Sprintf("- %s: [%s]%s\n",
idx.Name, strings.Join(idx.Columns, ", "), unique))
}
b.WriteString("\n")
}
if len(table.SampleRows) > 0 {
b.WriteString(fmt.Sprintf("**采样数据 (%d 行):**\n\n", len(table.SampleRows)))
if len(table.SampleRows) > 0 {
// 使用第一行的 key 作为标题
first := table.SampleRows[0]
var keys []string
for k := range first {
keys = append(keys, k)
}
b.WriteString("| " + strings.Join(keys, " | ") + " |\n")
b.WriteString("|" + strings.Repeat("------|", len(keys)) + "\n")
for _, row := range table.SampleRows {
var vals []string
for _, k := range keys {
vals = append(vals, fmt.Sprintf("%v", row[k]))
}
b.WriteString("| " + strings.Join(vals, " | ") + " |\n")
}
b.WriteString("\n")
}
}
}
return b.String()
}
func buildSQLGeneratePrompt() string {
return `你是 GoNavi AI 助手,一位顶级的数据库开发专家和 SQL 查询构建师。根据用户的自然语言需求,生成精准、优雅、高性能的 SQL 查询或 Redis 命令。
严苛输出规则:
1. 首要目标是输出纯粹的代码:始终将代码放在正确语言标识(如 sql 或 bash的 markdown 代码块中。
2. 保持精简:不要添加过多的前置闲聊,直奔主题。
3. 保护生产安全:优先使用参数化查询或安全防范写法避免 SQL 注入。对于未指定条件的 DELETE/UPDATE 语句,必须提出强烈的红线警告!!
4. 性能至上:对大型查询默认添加合理的 LIMIT 限制(如 LIMIT 100在 JOIN 和聚合时优先选择最高效的范式写法。
5. 适度注释:对于存在复杂逻辑嵌套的代码,请在代码块内使用单行注释简要说明思路。`
}
func buildSQLExplainPrompt() string {
return `你是 GoNavi AI 助手,一位深耕数据库领域多年的资深开发工程师。请用专业、条理分明且深入浅出的开发者语言向用户全盘解析 SQL 语句的底层意图与执行逻辑。
解析规范:
1. 宏观逻辑解构:用简短的一句话概括这条 SQL 在业务上想要解决什么问题。
2. 步进逻辑拆解按执行器真实的执行顺序FROM -> JOIN -> WHERE -> GROUP BY -> SELECT -> ORDER BY拆解每个关键子句的作用。
3. 性能排雷点:敏锐指出可能存在的性能陷阱(如隐式类型转换、没有走索引的函数调用、潜在的笛卡尔积/全表扫描等)。
4. 严谨的排版:使用列表呈现关键点,重点词汇加粗,确保长文不累赘。`
}
func buildSQLOptimizePrompt() string {
return `你是 GoNavi AI 助手,一名曾主导过千万级高并发系统的全栈性能工程专家与高级 DBA。请对用户提供的原始 SQL 进行冷酷、精确的诊断并开出性能重构处方。
诊断与处方要求:
1. 性能瓶颈透视:精准点出当前语句死穴(不合理的驱动表、无法利用覆盖索引、多此一举的子查询等)。
2. 重构版本的 SQL如果存在性能提升空间直接向用户展示彻底优化过的高性能写法并确保逻辑等价性。
3. 剖析原因:不仅要告诉用户“怎么改”,更要说清楚执行器“为什么这样会更快”。
4. 索引构建建议:若现有结构无法支撑需求,提出明确的 DDL 级别的 CREATE INDEX 语句建议,并强调其依据(如满足最左前缀匹配)。
5. 优先级评估:在回答的最后标注本次优化建议的紧迫性(高:阻断级/锁表风险;中:吞吐量瓶颈;低:长效微调)。`
}
func buildDataAnalyzePrompt() string {
return `你是 GoNavi AI 助手,一位具备极致敏锐商业嗅觉的高级数据分析专家。你将审视用户通过查询得到的数据样本,从中提炼出蕴含的真金白银般的信息。
洞察目标:
1. 硬统计:总观数据行数、核心数值指标(极值、平均值、聚合中位数等)的冰冷现实。
2. 趋势与异动:如果数据带有时间戳,敏锐捕捉其上升或下降趋势;如果有异类离群值,将其高亮标注。
3. 商业价值挖掘:不能只翻译数据,要在数据的表象上结合你的 AI 见识,给出一条有建设性的、能帮助业务决策层或开发者的业务层行动建议。
4. 展现格式:你的分析应该是“标题 + 浓缩要点”的极简研报形式,杜绝毫无波澜的流水账。`
}
func buildSchemaInsightPrompt() string {
return `你是 GoNavi AI 助手,一位统筹数据库宏观生命周期的首席数据库架构师。在这个环节里,你需要对用户提供的数据库表结构执行最严厉的范式与前瞻性审查。
审查视界:
1. 规范化博弈:是否存在明显的反三范式设计?这种冗余是否有助于性能(适当的反范式),还是纯粹的设计失误?
2. 索引健壮性审查评估主键选择如自增、UUID 的利弊),是否存在冗余索引阻碍写入?以及是否遗漏了高频的联合索引。
3. 物理容量前瞻:审视数据类型分配(如使用过大的 VARCHAR、没必要的 BIGINT 等可能带来的空间挥霍)。
4. 代码级指引:如果存在结构性缺陷,不要只发牢骚,直接给出包含具体优化的 ALTER TABLE 结构修改建议脚本。`
}
func buildGeneralChatPrompt() string {
return `你是 GoNavi AI 助手,一款深度集成在数据库/缓存客户端GoNavi内部的专属智能专家系统。
你的目标是成为开发者、DBA 和数据科学家最得力的超级外脑,提供专业、精准、具有前瞻性的数据端解决方案。
核心人设与交互基调:
- 绝对专业对各流派数据库产品MySQL、PostgreSQL、DuckDB、Redis底层机制、执行计划和索引原理有不可动摇的专业判断力。
- 直击痛点:谢绝套话与无效寒暄,若用户的意图明确,首屏直接给出可以直接粘贴运行的优雅代码。
- 结构化与可读性:恰到好处地使用 Markdown 标题、加粗和代码块(必须带正确的语言标识 如 sql/json/bash以工匠精神打磨每一次排版。
- 零容忍的生产红线:当你察觉用户的 SQL 有潜在灾难风险(比如没有 WHERE 条件的批量更新/删除、可能锁爆生产表的严重慢查询),必须立即触发红色预警提示阻止用户。
你的综合能力版图:
1. 📝 自然语言驱动:翻译人类意图为精准的查询语句。
2. 🔍 底层原理解析:剥丝抽茧分析查询背后的执行逻辑与性能隐患。
3. ⚡ 专家级调优:指出并化解性能瓶颈,给出覆盖全维度的索引调优思路。
4. 📊 数据洞察炼金:不仅聚合数据,更能从结果集中挖掘商业维度的深度规律。
5. 🏗️ 架构先知视界:全局审阅表结构设计局限,提出抗数据膨胀级别的架构演进方案。
互动守则:
- 永远使用专业、具有合作感且充满信心的中文与用户探讨问题。
- 当被要求提供任何数据库代码时,需结合相关数据库引擎的最佳实践。如果不清楚当前方言版本,请以标准实现为主基调并好心指出版别差异(如 MySQL 8 窗口函数 等)。`
}

View File

@@ -0,0 +1,42 @@
package aicontext
// DatabaseContext 数据库上下文信息,传递给 AI 辅助上下文理解
type DatabaseContext struct {
DatabaseType string `json:"databaseType"` // mysql, postgres 等
DatabaseName string `json:"databaseName"`
Tables []TableContext `json:"tables"`
}
// TableContext 表的上下文信息
type TableContext struct {
Name string `json:"name"`
Comment string `json:"comment,omitempty"`
Columns []ColumnInfo `json:"columns"`
Indexes []IndexInfo `json:"indexes,omitempty"`
SampleRows []map[string]interface{} `json:"sampleRows,omitempty"`
RowCount int64 `json:"rowCount,omitempty"`
}
// ColumnInfo 列信息
type ColumnInfo struct {
Name string `json:"name"`
Type string `json:"type"`
Nullable bool `json:"nullable"`
PrimaryKey bool `json:"primaryKey"`
Comment string `json:"comment,omitempty"`
}
// IndexInfo 索引信息
type IndexInfo struct {
Name string `json:"name"`
Columns []string `json:"columns"`
Unique bool `json:"unique"`
}
// QueryResultContext 查询结果上下文
type QueryResultContext struct {
SQL string `json:"sql"`
Columns []string `json:"columns"`
Rows []map[string]interface{} `json:"rows"`
RowCount int `json:"rowCount"`
}

View File

@@ -0,0 +1,293 @@
package provider
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"GoNavi-Wails/internal/ai"
)
const (
defaultAnthropicBaseURL = "https://api.anthropic.com"
defaultAnthropicModel = "claude-3-5-sonnet-20241022"
anthropicAPIVersion = "2023-06-01"
)
// AnthropicProvider 实现 Anthropic Claude API 的 Provider
type AnthropicProvider struct {
config ai.ProviderConfig
baseURL string
client *http.Client
}
// NewAnthropicProvider 创建 Anthropic Provider 实例
func NewAnthropicProvider(config ai.ProviderConfig) (Provider, error) {
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
if baseURL == "" {
baseURL = defaultAnthropicBaseURL
}
model := strings.TrimSpace(config.Model)
if model == "" {
model = defaultAnthropicModel
}
maxTokens := config.MaxTokens
if maxTokens <= 0 {
maxTokens = defaultOpenAIMaxTokens
}
temperature := config.Temperature
if temperature <= 0 {
temperature = defaultOpenAITemperature
}
normalized := config
normalized.BaseURL = baseURL
normalized.Model = model
normalized.MaxTokens = maxTokens
normalized.Temperature = temperature
return &AnthropicProvider{
config: normalized,
baseURL: baseURL,
client: &http.Client{Timeout: openAIHTTPTimeout},
}, nil
}
func (p *AnthropicProvider) Name() string {
if strings.TrimSpace(p.config.Name) != "" {
return p.config.Name
}
return "Anthropic"
}
func (p *AnthropicProvider) Validate() error {
if strings.TrimSpace(p.config.APIKey) == "" {
return fmt.Errorf("API Key 不能为空")
}
return nil
}
type anthropicRequest struct {
Model string `json:"model"`
Messages []anthropicMessage `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type anthropicMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type anthropicResponse struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
type anthropicStreamEvent struct {
Type string `json:"type"`
Delta *struct {
Text string `json:"text"`
} `json:"delta,omitempty"`
}
func (p *AnthropicProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
if err := p.Validate(); err != nil {
return nil, err
}
systemMsg, messages := extractSystemMessage(req.Messages)
anthropicMsgs := make([]anthropicMessage, len(messages))
for i, m := range messages {
anthropicMsgs[i] = anthropicMessage{Role: m.Role, Content: m.Content}
}
temperature := req.Temperature
if temperature <= 0 {
temperature = p.config.Temperature
}
maxTokens := req.MaxTokens
if maxTokens <= 0 {
maxTokens = p.config.MaxTokens
}
body := anthropicRequest{
Model: p.config.Model,
Messages: anthropicMsgs,
System: systemMsg,
MaxTokens: maxTokens,
Temperature: temperature,
}
respBody, err := p.doRequest(ctx, body)
if err != nil {
return nil, err
}
defer respBody.Close()
var result anthropicResponse
if err := json.NewDecoder(respBody).Decode(&result); err != nil {
return nil, fmt.Errorf("解析 Anthropic 响应失败: %w", err)
}
if result.Error != nil && result.Error.Message != "" {
return nil, fmt.Errorf("Anthropic API 错误: %s", result.Error.Message)
}
if len(result.Content) == 0 {
return nil, fmt.Errorf("Anthropic 返回空响应")
}
return &ai.ChatResponse{
Content: result.Content[0].Text,
TokensUsed: ai.TokenUsage{
PromptTokens: result.Usage.InputTokens,
CompletionTokens: result.Usage.OutputTokens,
TotalTokens: result.Usage.InputTokens + result.Usage.OutputTokens,
},
}, nil
}
func (p *AnthropicProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
if err := p.Validate(); err != nil {
return err
}
systemMsg, messages := extractSystemMessage(req.Messages)
anthropicMsgs := make([]anthropicMessage, len(messages))
for i, m := range messages {
anthropicMsgs[i] = anthropicMessage{Role: m.Role, Content: m.Content}
}
temperature := req.Temperature
if temperature <= 0 {
temperature = p.config.Temperature
}
maxTokens := req.MaxTokens
if maxTokens <= 0 {
maxTokens = p.config.MaxTokens
}
body := anthropicRequest{
Model: p.config.Model,
Messages: anthropicMsgs,
System: systemMsg,
MaxTokens: maxTokens,
Temperature: temperature,
Stream: true,
}
respBody, err := p.doRequest(ctx, body)
if err != nil {
return err
}
defer respBody.Close()
scanner := bufio.NewScanner(respBody)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var event anthropicStreamEvent
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
switch event.Type {
case "content_block_delta":
if event.Delta != nil && event.Delta.Text != "" {
callback(ai.StreamChunk{Content: event.Delta.Text})
}
case "message_stop":
callback(ai.StreamChunk{Done: true})
return nil
}
}
callback(ai.StreamChunk{Done: true})
return scanner.Err()
}
func (p *AnthropicProvider) doRequest(ctx context.Context, body interface{}) (io.ReadCloser, error) {
jsonBody, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
url := p.baseURL + "/v1/messages"
if strings.HasSuffix(p.baseURL, "/v1") {
url = p.baseURL + "/messages"
}
// 调试日志:打印实际请求信息
bodyStr := string(jsonBody)
if len(bodyStr) > 500 {
bodyStr = bodyStr[:500] + "..."
}
fmt.Printf("[Anthropic DEBUG] URL: %s\n", url)
fmt.Printf("[Anthropic DEBUG] BaseURL: %s\n", p.baseURL)
fmt.Printf("[Anthropic DEBUG] Body: %s\n", bodyStr)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("x-api-key", p.config.APIKey)
httpReq.Header.Set("anthropic-version", anthropicAPIVersion)
// 仅官方 API 发 beta 特性头(代理不发,避免触发 Claude Code 验证)
isOfficialAPI := p.baseURL == defaultAnthropicBaseURL || strings.Contains(p.baseURL, "anthropic.com")
if isOfficialAPI {
httpReq.Header.Set("anthropic-beta", "interleaved-thinking-2025-05-14,output-128k-2025-02-19,prompt-caching-2024-07-31")
}
// 自定义 headers用于兼容各类代理服务
for k, v := range p.config.Headers {
httpReq.Header.Set(k, v)
}
resp, err := p.client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("发送请求到 %s 失败: %w", url, err)
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("Anthropic API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
return resp.Body, nil
}
// extractSystemMessage 从消息列表中提取 system 消息Anthropic 要求 system 作为独立字段)
func extractSystemMessage(messages []ai.Message) (string, []ai.Message) {
var systemParts []string
var remaining []ai.Message
for _, m := range messages {
if m.Role == "system" {
systemParts = append(systemParts, m.Content)
} else {
remaining = append(remaining, m)
}
}
return strings.Join(systemParts, "\n\n"), remaining
}

View File

@@ -0,0 +1,227 @@
package provider
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
ai "GoNavi-Wails/internal/ai"
)
// ClaudeCLIProvider 通过 Claude Code CLI 发送聊天请求
// 适用于 anyrouter/newapi 等只支持 Claude Code 协议的代理服务
type ClaudeCLIProvider struct {
config ai.ProviderConfig
}
// NewClaudeCLIProvider 创建 ClaudeCLIProvider 实例
func NewClaudeCLIProvider(config ai.ProviderConfig) (Provider, error) {
return &ClaudeCLIProvider{config: config}, nil
}
func (p *ClaudeCLIProvider) Name() string {
return "ClaudeCLI"
}
func (p *ClaudeCLIProvider) Validate() error {
_, err := exec.LookPath("claude")
if err != nil {
return fmt.Errorf("未找到 claude 命令,请先安装 Claude Code CLI: npm install -g @anthropic-ai/claude-code")
}
return nil
}
// Chat 非流式聊天:调用 claude -p "prompt" --output-format json
func (p *ClaudeCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
if err := p.Validate(); err != nil {
return nil, err
}
prompt := buildPrompt(req.Messages)
args := []string{"-p", prompt, "--output-format", "json", "--no-session-persistence"}
if p.config.Model != "" {
args = append(args, "--model", p.config.Model)
}
cmd := exec.CommandContext(ctx, "claude", args...)
p.setEnv(cmd)
output, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return nil, fmt.Errorf("claude CLI 执行失败: %s", string(exitErr.Stderr))
}
return nil, fmt.Errorf("claude CLI 执行失败: %w", err)
}
// 解析 JSON 输出
var result struct {
Result string `json:"result"`
}
if err := json.Unmarshal(output, &result); err != nil {
// 如果 JSON 解析失败,直接返回原始文本
return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, nil
}
return &ai.ChatResponse{Content: result.Result}, nil
}
// ChatStream 流式聊天:调用 claude -p "prompt" --output-format stream-json
func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
if err := p.Validate(); err != nil {
return err
}
prompt := buildPrompt(req.Messages)
args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--no-session-persistence"}
if p.config.Model != "" {
args = append(args, "--model", p.config.Model)
}
fmt.Printf("[ClaudeCLI DEBUG] Running: claude %v\n", args)
cmd := exec.CommandContext(ctx, "claude", args...)
p.setEnv(cmd)
// 关闭 stdin防止 claude CLI 等待输入
cmd.Stdin = nil
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("创建 stdout 管道失败: %w", err)
}
// 捕获 stderr
var stderrBuf bytes.Buffer
cmd.Stderr = &stderrBuf
if err := cmd.Start(); err != nil {
return fmt.Errorf("启动 claude CLI 失败: %w", err)
}
fmt.Printf("[ClaudeCLI DEBUG] Process started, PID: %d\n", cmd.Process.Pid)
// 立即通知前端AI 正在思考(避免用户以为卡死)
callback(ai.StreamChunk{Content: "💭 *正在思考...*\n\n"})
// 逐行读取流式 JSON 输出
scanner := bufio.NewScanner(stdout)
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "" {
continue
}
fmt.Printf("[ClaudeCLI DEBUG] Line: %s\n", line[:min(len(line), 200)])
var event cliStreamEvent
if err := json.Unmarshal([]byte(line), &event); err != nil {
fmt.Printf("[ClaudeCLI DEBUG] Non-JSON line: %s\n", line)
continue
}
switch event.Type {
case "assistant":
// 助手消息开始或文本内容
if event.Message.Content != nil {
for _, block := range event.Message.Content {
if block.Type == "text" && block.Text != "" {
callback(ai.StreamChunk{Content: block.Text})
}
}
}
case "content_block_delta":
// 增量文本
if event.Delta.Text != "" {
callback(ai.StreamChunk{Content: event.Delta.Text})
}
case "result":
// 最终结果事件 — 不发送 contentassistant 事件已包含),只标记完成
callback(ai.StreamChunk{Done: true})
_ = cmd.Wait()
return nil
case "error":
callback(ai.StreamChunk{Error: event.Error.Message, Done: true})
_ = cmd.Wait()
return nil
}
}
waitErr := cmd.Wait()
stderrStr := strings.TrimSpace(stderrBuf.String())
fmt.Printf("[ClaudeCLI DEBUG] Process exited. stderr: %s\n", stderrStr)
if waitErr != nil {
errMsg := fmt.Sprintf("claude CLI 异常退出: %v", waitErr)
if stderrStr != "" {
errMsg = fmt.Sprintf("claude CLI 异常退出: %s", stderrStr)
}
callback(ai.StreamChunk{Error: errMsg, Done: true})
return nil
}
callback(ai.StreamChunk{Done: true})
return nil
}
// setEnv 设置 Claude CLI 的环境变量
func (p *ClaudeCLIProvider) setEnv(cmd *exec.Cmd) {
env := cmd.Environ()
if p.config.BaseURL != "" {
baseURL := strings.TrimRight(p.config.BaseURL, "/")
env = append(env, "ANTHROPIC_BASE_URL="+baseURL)
}
if p.config.APIKey != "" {
env = append(env, "ANTHROPIC_API_KEY="+p.config.APIKey)
}
cmd.Env = env
}
// buildPrompt 将消息列表拼接为适合 claude -p 的提示文本
func buildPrompt(messages []ai.Message) string {
if len(messages) == 1 {
return messages[0].Content
}
var sb strings.Builder
for _, m := range messages {
switch m.Role {
case "system":
sb.WriteString("[System]\n")
sb.WriteString(m.Content)
sb.WriteString("\n\n")
case "user":
sb.WriteString(m.Content)
sb.WriteString("\n\n")
case "assistant":
sb.WriteString("[Previous Assistant Response]\n")
sb.WriteString(m.Content)
sb.WriteString("\n\n")
}
}
return strings.TrimSpace(sb.String())
}
// cliStreamEvent Claude CLI stream-json 输出的事件结构
type cliStreamEvent struct {
Type string `json:"type"`
Message struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
} `json:"message,omitempty"`
Delta struct {
Text string `json:"text"`
} `json:"delta,omitempty"`
Result string `json:"result,omitempty"`
Error struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}

View File

@@ -0,0 +1,74 @@
package provider
import (
"context"
"fmt"
"strings"
"GoNavi-Wails/internal/ai"
)
// CustomProvider 自定义 Provider根据 apiFormat 选择底层协议
// 支持 openai / anthropic / gemini 三种 API 格式
type CustomProvider struct {
inner Provider
name string
}
// NewCustomProvider 创建自定义 Provider 实例
func NewCustomProvider(config ai.ProviderConfig) (Provider, error) {
if strings.TrimSpace(config.BaseURL) == "" {
return nil, fmt.Errorf("自定义 Provider 必须指定 Base URL")
}
// 根据 apiFormat 决定使用哪个底层协议,默认 openai
apiFormat := strings.ToLower(strings.TrimSpace(config.APIFormat))
if apiFormat == "" {
apiFormat = "openai"
}
var innerProvider Provider
var err error
switch apiFormat {
case "anthropic":
innerProvider, err = NewAnthropicProvider(config)
case "gemini":
innerProvider, err = NewGeminiProvider(config)
case "claude-cli":
innerProvider, err = NewClaudeCLIProvider(config)
default: // "openai" 及其他
innerProvider, err = NewOpenAIProvider(config)
}
if err != nil {
return nil, err
}
name := strings.TrimSpace(config.Name)
if name == "" {
name = "Custom"
}
return &CustomProvider{
inner: innerProvider,
name: name,
}, nil
}
func (p *CustomProvider) Name() string {
return p.name
}
func (p *CustomProvider) Validate() error {
if strings.TrimSpace(p.inner.(interface{ Name() string }).Name()) == "" {
// 对自定义 ProviderAPI Key 可选(部分本地服务不需要)
}
return nil
}
func (p *CustomProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
return p.inner.Chat(ctx, req)
}
func (p *CustomProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
return p.inner.ChatStream(ctx, req, callback)
}

View File

@@ -0,0 +1,267 @@
package provider
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"GoNavi-Wails/internal/ai"
)
const (
defaultGeminiBaseURL = "https://generativelanguage.googleapis.com"
defaultGeminiModel = "gemini-2.0-flash"
)
// GeminiProvider 实现 Google Gemini API 的 Provider
type GeminiProvider struct {
config ai.ProviderConfig
baseURL string
client *http.Client
}
// NewGeminiProvider 创建 Gemini Provider 实例
func NewGeminiProvider(config ai.ProviderConfig) (Provider, error) {
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
if baseURL == "" {
baseURL = defaultGeminiBaseURL
}
model := strings.TrimSpace(config.Model)
if model == "" {
model = defaultGeminiModel
}
maxTokens := config.MaxTokens
if maxTokens <= 0 {
maxTokens = defaultOpenAIMaxTokens
}
temperature := config.Temperature
if temperature <= 0 {
temperature = defaultOpenAITemperature
}
normalized := config
normalized.BaseURL = baseURL
normalized.Model = model
normalized.MaxTokens = maxTokens
normalized.Temperature = temperature
return &GeminiProvider{
config: normalized,
baseURL: baseURL,
client: &http.Client{Timeout: openAIHTTPTimeout},
}, nil
}
func (p *GeminiProvider) Name() string {
if strings.TrimSpace(p.config.Name) != "" {
return p.config.Name
}
return "Gemini"
}
func (p *GeminiProvider) Validate() error {
if strings.TrimSpace(p.config.APIKey) == "" {
return fmt.Errorf("API Key 不能为空")
}
return nil
}
type geminiRequest struct {
Contents []geminiContent `json:"contents"`
SystemInstruction *geminiContent `json:"systemInstruction,omitempty"`
GenerationConfig geminiGenConfig `json:"generationConfig,omitempty"`
}
type geminiContent struct {
Role string `json:"role,omitempty"`
Parts []geminiPart `json:"parts"`
}
type geminiPart struct {
Text string `json:"text"`
}
type geminiGenConfig struct {
Temperature float64 `json:"temperature,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
}
type geminiResponse struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
UsageMetadata *struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
} `json:"usageMetadata"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
func (p *GeminiProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
if err := p.Validate(); err != nil {
return nil, err
}
geminiReq := p.buildRequest(req)
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
p.baseURL, p.config.Model, p.config.APIKey)
respBody, err := p.doRequest(ctx, url, geminiReq)
if err != nil {
return nil, err
}
defer respBody.Close()
var result geminiResponse
if err := json.NewDecoder(respBody).Decode(&result); err != nil {
return nil, fmt.Errorf("解析 Gemini 响应失败: %w", err)
}
if result.Error != nil && result.Error.Message != "" {
return nil, fmt.Errorf("Gemini API 错误: %s", result.Error.Message)
}
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
return nil, fmt.Errorf("Gemini 返回空响应")
}
var tokens ai.TokenUsage
if result.UsageMetadata != nil {
tokens = ai.TokenUsage{
PromptTokens: result.UsageMetadata.PromptTokenCount,
CompletionTokens: result.UsageMetadata.CandidatesTokenCount,
TotalTokens: result.UsageMetadata.TotalTokenCount,
}
}
var textParts []string
for _, part := range result.Candidates[0].Content.Parts {
if part.Text != "" {
textParts = append(textParts, part.Text)
}
}
return &ai.ChatResponse{
Content: strings.Join(textParts, ""),
TokensUsed: tokens,
}, nil
}
func (p *GeminiProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
if err := p.Validate(); err != nil {
return err
}
geminiReq := p.buildRequest(req)
url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
p.baseURL, p.config.Model, p.config.APIKey)
respBody, err := p.doRequest(ctx, url, geminiReq)
if err != nil {
return err
}
defer respBody.Close()
scanner := bufio.NewScanner(respBody)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var chunk geminiResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
if len(chunk.Candidates) > 0 && len(chunk.Candidates[0].Content.Parts) > 0 {
for _, part := range chunk.Candidates[0].Content.Parts {
if part.Text != "" {
callback(ai.StreamChunk{Content: part.Text})
}
}
}
}
callback(ai.StreamChunk{Done: true})
return scanner.Err()
}
func (p *GeminiProvider) buildRequest(req ai.ChatRequest) geminiRequest {
temperature := req.Temperature
if temperature <= 0 {
temperature = p.config.Temperature
}
maxTokens := req.MaxTokens
if maxTokens <= 0 {
maxTokens = p.config.MaxTokens
}
var systemInstruction *geminiContent
var contents []geminiContent
for _, m := range req.Messages {
if m.Role == "system" {
systemInstruction = &geminiContent{
Parts: []geminiPart{{Text: m.Content}},
}
continue
}
role := m.Role
if role == "assistant" {
role = "model"
}
contents = append(contents, geminiContent{
Role: role,
Parts: []geminiPart{{Text: m.Content}},
})
}
return geminiRequest{
Contents: contents,
SystemInstruction: systemInstruction,
GenerationConfig: geminiGenConfig{
Temperature: temperature,
MaxOutputTokens: maxTokens,
},
}
}
func (p *GeminiProvider) doRequest(ctx context.Context, url string, body interface{}) (io.ReadCloser, error) {
jsonBody, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := p.client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("发送请求到 Gemini 失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("Gemini API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
return resp.Body, nil
}

View File

@@ -0,0 +1,316 @@
package provider
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"GoNavi-Wails/internal/ai"
)
const (
defaultOpenAIBaseURL = "https://api.openai.com/v1"
defaultOpenAIModel = "gpt-4o"
defaultOpenAIMaxTokens = 4096
defaultOpenAITemperature = 0.7
openAIHTTPTimeout = 120 * time.Second
)
// OpenAIProvider 实现 OpenAI / OpenAI 兼容 API 的 Provider
type OpenAIProvider struct {
config ai.ProviderConfig
baseURL string
client *http.Client
}
// NewOpenAIProvider 创建 OpenAI Provider 实例
func NewOpenAIProvider(config ai.ProviderConfig) (Provider, error) {
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
if baseURL == "" {
baseURL = defaultOpenAIBaseURL
}
// 确保 baseURL 包含 /v1 路径(兼容用户只填域名的情况,如 https://anyrouter.top
if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
baseURL = baseURL + "/v1"
}
model := strings.TrimSpace(config.Model)
if model == "" {
model = defaultOpenAIModel
}
maxTokens := config.MaxTokens
if maxTokens <= 0 {
maxTokens = defaultOpenAIMaxTokens
}
temperature := config.Temperature
if temperature <= 0 {
temperature = defaultOpenAITemperature
}
normalized := config
normalized.BaseURL = baseURL
normalized.Model = model
normalized.MaxTokens = maxTokens
normalized.Temperature = temperature
return &OpenAIProvider{
config: normalized,
baseURL: baseURL,
client: &http.Client{
Timeout: openAIHTTPTimeout,
},
}, nil
}
func (p *OpenAIProvider) Name() string {
if strings.TrimSpace(p.config.Name) != "" {
return p.config.Name
}
return "OpenAI"
}
func (p *OpenAIProvider) Validate() error {
if strings.TrimSpace(p.config.APIKey) == "" {
return fmt.Errorf("API Key 不能为空")
}
return nil
}
// openAIChatRequest OpenAI API 请求体
type openAIChatRequest struct {
Model string `json:"model"`
Messages []openAIChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type openAIChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// openAIChatResponse OpenAI API 响应体
type openAIChatResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
// openAIStreamChunk SSE 流式响应片段
type openAIStreamChunk struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
} `json:"choices"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
func (p *OpenAIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
if err := p.Validate(); err != nil {
return nil, err
}
messages := make([]openAIChatMessage, len(req.Messages))
for i, m := range req.Messages {
messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
}
temperature := req.Temperature
if temperature <= 0 {
temperature = p.config.Temperature
}
maxTokens := req.MaxTokens
if maxTokens <= 0 {
maxTokens = p.config.MaxTokens
}
body := openAIChatRequest{
Model: p.config.Model,
Messages: messages,
Temperature: temperature,
MaxTokens: maxTokens,
Stream: false,
}
respBody, err := p.doRequest(ctx, body)
if err != nil {
return nil, err
}
defer respBody.Close()
var result openAIChatResponse
if err := json.NewDecoder(respBody).Decode(&result); err != nil {
return nil, fmt.Errorf("解析 OpenAI 响应失败: %w", err)
}
if result.Error != nil && result.Error.Message != "" {
return nil, fmt.Errorf("OpenAI API 错误: %s", result.Error.Message)
}
if len(result.Choices) == 0 {
return nil, fmt.Errorf("OpenAI 返回空响应")
}
return &ai.ChatResponse{
Content: result.Choices[0].Message.Content,
TokensUsed: ai.TokenUsage{
PromptTokens: result.Usage.PromptTokens,
CompletionTokens: result.Usage.CompletionTokens,
TotalTokens: result.Usage.TotalTokens,
},
}, nil
}
func (p *OpenAIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
if err := p.Validate(); err != nil {
return err
}
messages := make([]openAIChatMessage, len(req.Messages))
for i, m := range req.Messages {
messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
}
temperature := req.Temperature
if temperature <= 0 {
temperature = p.config.Temperature
}
maxTokens := req.MaxTokens
if maxTokens <= 0 {
maxTokens = p.config.MaxTokens
}
body := openAIChatRequest{
Model: p.config.Model,
Messages: messages,
Temperature: temperature,
MaxTokens: maxTokens,
Stream: true,
}
respBody, err := p.doRequest(ctx, body)
if err != nil {
return err
}
defer respBody.Close()
receivedContent := false
scanner := bufio.NewScanner(respBody)
// 增大 scanner buffer防止长行被截断
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
continue
}
if !strings.HasPrefix(line, "data: ") {
// 非 SSE 数据行,可能是错误信息,记录日志
if strings.Contains(line, "error") || strings.Contains(line, "Error") {
callback(ai.StreamChunk{Error: fmt.Sprintf("服务端返回异常: %s", line), Done: true})
return nil
}
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
callback(ai.StreamChunk{Done: true})
return nil
}
var chunk openAIStreamChunk
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue // 跳过格式异常的行
}
if chunk.Error != nil && chunk.Error.Message != "" {
callback(ai.StreamChunk{Error: fmt.Sprintf("API 错误: %s", chunk.Error.Message), Done: true})
return nil
}
if len(chunk.Choices) > 0 {
content := chunk.Choices[0].Delta.Content
if content != "" {
receivedContent = true
callback(ai.StreamChunk{Content: content})
}
if chunk.Choices[0].FinishReason != nil {
callback(ai.StreamChunk{Done: true})
return nil
}
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("读取 OpenAI 流式响应失败: %w", err)
}
// 如果流正常结束但没有收到任何内容,可能是 API 响应格式不兼容
if !receivedContent {
callback(ai.StreamChunk{Error: "未收到任何有效响应内容,请检查 API 端点和模型是否正确", Done: true})
return nil
}
callback(ai.StreamChunk{Done: true})
return nil
}
func (p *OpenAIProvider) doRequest(ctx context.Context, body interface{}) (io.ReadCloser, error) {
jsonBody, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
url := p.baseURL + "/chat/completions"
// 调试日志
bodyStr := string(jsonBody)
if len(bodyStr) > 500 {
bodyStr = bodyStr[:500] + "..."
}
fmt.Printf("[OpenAI DEBUG] URL: %s\n", url)
fmt.Printf("[OpenAI DEBUG] BaseURL: %s\n", p.baseURL)
fmt.Printf("[OpenAI DEBUG] Body: %s\n", bodyStr)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
// 自定义 headers用于兼容各类 OpenAI 兼容服务)
for k, v := range p.config.Headers {
httpReq.Header.Set(k, v)
}
resp, err := p.client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("发送请求到 %s 失败: %w", url, err)
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("OpenAI API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
return resp.Body, nil
}

View File

@@ -0,0 +1,86 @@
package provider
import (
"GoNavi-Wails/internal/ai"
"testing"
)
func TestOpenAIProvider_Validate_MissingAPIKey(t *testing.T) {
p, err := NewOpenAIProvider(ai.ProviderConfig{Type: "openai", Model: "gpt-4o"})
if err != nil {
t.Fatalf("unexpected constructor error: %v", err)
}
if err := p.Validate(); err == nil {
t.Fatal("expected validation error for missing API key")
}
}
func TestOpenAIProvider_Validate_Valid(t *testing.T) {
p, err := NewOpenAIProvider(ai.ProviderConfig{
Type: "openai", APIKey: "sk-test-key", Model: "gpt-4o",
})
if err != nil {
t.Fatalf("unexpected constructor error: %v", err)
}
if err := p.Validate(); err != nil {
t.Fatalf("unexpected validation error: %v", err)
}
}
func TestOpenAIProvider_Name_Custom(t *testing.T) {
p, _ := NewOpenAIProvider(ai.ProviderConfig{
Type: "openai", Name: "My OpenAI", APIKey: "sk-test",
})
if p.Name() != "My OpenAI" {
t.Fatalf("expected name 'My OpenAI', got '%s'", p.Name())
}
}
func TestOpenAIProvider_Name_Default(t *testing.T) {
p, _ := NewOpenAIProvider(ai.ProviderConfig{
Type: "openai", APIKey: "sk-test",
})
if p.Name() != "OpenAI" {
t.Fatalf("expected default name 'OpenAI', got '%s'", p.Name())
}
}
func TestOpenAIProvider_DefaultBaseURL(t *testing.T) {
p, _ := NewOpenAIProvider(ai.ProviderConfig{
Type: "openai", APIKey: "sk-test", Model: "gpt-4o",
})
op := p.(*OpenAIProvider)
if op.baseURL != "https://api.openai.com/v1" {
t.Fatalf("expected default base URL, got '%s'", op.baseURL)
}
}
func TestOpenAIProvider_CustomBaseURL(t *testing.T) {
p, _ := NewOpenAIProvider(ai.ProviderConfig{
Type: "openai", APIKey: "sk-test", BaseURL: "https://my-proxy.com/v1",
})
op := p.(*OpenAIProvider)
if op.baseURL != "https://my-proxy.com/v1" {
t.Fatalf("expected custom base URL, got '%s'", op.baseURL)
}
}
func TestOpenAIProvider_DefaultModel(t *testing.T) {
p, _ := NewOpenAIProvider(ai.ProviderConfig{
Type: "openai", APIKey: "sk-test",
})
op := p.(*OpenAIProvider)
if op.config.Model != "gpt-4o" {
t.Fatalf("expected default model 'gpt-4o', got '%s'", op.config.Model)
}
}
func TestOpenAIProvider_DefaultMaxTokens(t *testing.T) {
p, _ := NewOpenAIProvider(ai.ProviderConfig{
Type: "openai", APIKey: "sk-test",
})
op := p.(*OpenAIProvider)
if op.config.MaxTokens != 4096 {
t.Fatalf("expected default max tokens 4096, got %d", op.config.MaxTokens)
}
}

View File

@@ -0,0 +1,19 @@
package provider
import (
"context"
"GoNavi-Wails/internal/ai"
)
// Provider AI 模型提供者接口
type Provider interface {
// Chat 发送消息并获取完整响应
Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error)
// ChatStream 发送消息并以流式返回
ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error
// Name 返回 Provider 名称
Name() string
// Validate 校验配置是否有效
Validate() error
}

View File

@@ -0,0 +1,25 @@
package provider
import (
"fmt"
"strings"
"GoNavi-Wails/internal/ai"
)
// NewProvider 根据配置创建 Provider 实例
func NewProvider(config ai.ProviderConfig) (Provider, error) {
providerType := strings.ToLower(strings.TrimSpace(config.Type))
switch providerType {
case "openai":
return NewOpenAIProvider(config)
case "anthropic":
return NewAnthropicProvider(config)
case "gemini":
return NewGeminiProvider(config)
case "custom":
return NewCustomProvider(config)
default:
return nil, fmt.Errorf("不支持的 AI Provider 类型: %s", config.Type)
}
}

View File

@@ -0,0 +1,101 @@
package safety
import (
"strings"
"unicode"
"GoNavi-Wails/internal/ai"
)
// ClassifySQL 分类 SQL 语句的操作类型
func ClassifySQL(sql string) ai.SQLOperationType {
keyword := leadingSQLKeyword(sql)
switch keyword {
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values":
return ai.SQLOpQuery
case "insert", "update", "delete", "replace", "merge", "upsert":
return ai.SQLOpDML
case "create", "alter", "drop", "truncate", "rename":
return ai.SQLOpDDL
default:
return ai.SQLOpOther
}
}
// IsHighRiskSQL 判断 SQL 是否为高风险语句
func IsHighRiskSQL(sql string) (bool, string) {
keyword := leadingSQLKeyword(sql)
normalized := strings.ToLower(sql)
switch keyword {
case "drop":
return true, "⚠️ 高危操作DROP 语句将永久删除数据库对象"
case "truncate":
return true, "⚠️ 高危操作TRUNCATE 将清空表中所有数据"
case "delete":
if !containsWhereClause(normalized) {
return true, "⚠️ 高危操作DELETE 语句缺少 WHERE 条件,将删除所有数据"
}
case "update":
if !containsWhereClause(normalized) {
return true, "⚠️ 高危操作UPDATE 语句缺少 WHERE 条件,将更新所有记录"
}
}
return false, ""
}
// containsWhereClause 简单判断 SQL 是否包含 WHERE 子句
func containsWhereClause(normalizedSQL string) bool {
return strings.Contains(normalizedSQL, " where ") ||
strings.Contains(normalizedSQL, "\nwhere ") ||
strings.Contains(normalizedSQL, "\twhere ")
}
// leadingSQLKeyword 提取 SQL 语句的首个关键字(跳过注释和空白)
func leadingSQLKeyword(query string) string {
text := strings.TrimSpace(query)
for len(text) > 0 {
trimmed := strings.TrimLeft(text, " \t\r\n")
if trimmed == "" {
return ""
}
text = trimmed
switch {
case strings.HasPrefix(text, "--"):
if idx := strings.IndexByte(text, '\n'); idx >= 0 {
text = text[idx+1:]
continue
}
return ""
case strings.HasPrefix(text, "#"):
if idx := strings.IndexByte(text, '\n'); idx >= 0 {
text = text[idx+1:]
continue
}
return ""
case strings.HasPrefix(text, "/*"):
if idx := strings.Index(text, "*/"); idx >= 0 {
text = text[idx+2:]
continue
}
return ""
}
break
}
if text == "" {
return ""
}
for i, r := range text {
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' {
continue
}
if i == 0 {
return ""
}
return strings.ToLower(text[:i])
}
return strings.ToLower(text)
}

View File

@@ -0,0 +1,145 @@
package safety
import (
"GoNavi-Wails/internal/ai"
"testing"
)
func TestClassifySQL(t *testing.T) {
tests := []struct {
sql string
want ai.SQLOperationType
}{
{"SELECT * FROM users", ai.SQLOpQuery},
{" select id from t", ai.SQLOpQuery},
{"SHOW TABLES", ai.SQLOpQuery},
{"DESCRIBE users", ai.SQLOpQuery},
{"DESC users", ai.SQLOpQuery},
{"EXPLAIN SELECT 1", ai.SQLOpQuery},
{"WITH cte AS (SELECT 1) SELECT * FROM cte", ai.SQLOpQuery},
{"PRAGMA table_info(t)", ai.SQLOpQuery},
{"VALUES (1, 2)", ai.SQLOpQuery},
{"INSERT INTO users VALUES (1)", ai.SQLOpDML},
{"UPDATE users SET name='x'", ai.SQLOpDML},
{"DELETE FROM users WHERE id=1", ai.SQLOpDML},
{"REPLACE INTO users VALUES (1)", ai.SQLOpDML},
{"MERGE INTO t USING s ON t.id=s.id", ai.SQLOpDML},
{"CREATE TABLE t (id INT)", ai.SQLOpDDL},
{"ALTER TABLE t ADD col INT", ai.SQLOpDDL},
{"DROP TABLE t", ai.SQLOpDDL},
{"TRUNCATE TABLE t", ai.SQLOpDDL},
{"RENAME TABLE old TO new", ai.SQLOpDDL},
{"/* comment */ SELECT 1", ai.SQLOpQuery},
{"-- comment\nDELETE FROM t", ai.SQLOpDML},
{"-- line1\n-- line2\nSELECT 1", ai.SQLOpQuery},
{"/* block */ -- line\nUPDATE t SET x=1", ai.SQLOpDML},
{"", ai.SQLOpOther},
{" ", ai.SQLOpOther},
{"-- only comment", ai.SQLOpOther},
}
for _, tt := range tests {
got := ClassifySQL(tt.sql)
if got != tt.want {
t.Errorf("ClassifySQL(%q) = %s, want %s", tt.sql, got, tt.want)
}
}
}
func TestIsHighRiskSQL(t *testing.T) {
tests := []struct {
sql string
highRisk bool
}{
{"DROP TABLE users", true},
{"DROP DATABASE test", true},
{"TRUNCATE TABLE users", true},
{"DELETE FROM users", true}, // 无 WHERE
{"DELETE FROM users WHERE id=1", false}, // 有 WHERE
{"UPDATE users SET name='x'", true}, // 无 WHERE
{"UPDATE users SET name='x' WHERE id=1", false}, // 有 WHERE
{"SELECT * FROM users", false},
{"INSERT INTO users VALUES (1)", false},
}
for _, tt := range tests {
highRisk, _ := IsHighRiskSQL(tt.sql)
if highRisk != tt.highRisk {
t.Errorf("IsHighRiskSQL(%q) = %v, want %v", tt.sql, highRisk, tt.highRisk)
}
}
}
func TestGuard_ReadOnly(t *testing.T) {
g := NewGuard(ai.PermissionReadOnly)
tests := []struct {
sql string
allowed bool
}{
{"SELECT * FROM t", true},
{"INSERT INTO t VALUES (1)", false},
{"UPDATE t SET x=1", false},
{"DELETE FROM t", false},
{"DROP TABLE t", false},
{"CREATE TABLE t (id INT)", false},
}
for _, tt := range tests {
result := g.Check(tt.sql)
if result.Allowed != tt.allowed {
t.Errorf("Guard[readonly].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
}
}
}
func TestGuard_ReadWrite(t *testing.T) {
g := NewGuard(ai.PermissionReadWrite)
tests := []struct {
sql string
allowed bool
confirm bool
}{
{"SELECT * FROM t", true, false},
{"INSERT INTO t VALUES (1)", true, true},
{"UPDATE t SET x=1", true, true}, // 允许但需确认
{"DELETE FROM t WHERE id=1", true, true}, // 允许但需确认
{"DROP TABLE t", false, true}, // DDL 不允许
{"CREATE TABLE t (id INT)", false, true},
}
for _, tt := range tests {
result := g.Check(tt.sql)
if result.Allowed != tt.allowed {
t.Errorf("Guard[readwrite].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
}
if result.RequiresConfirm != tt.confirm {
t.Errorf("Guard[readwrite].Check(%q).RequiresConfirm = %v, want %v", tt.sql, result.RequiresConfirm, tt.confirm)
}
}
}
func TestGuard_Full(t *testing.T) {
g := NewGuard(ai.PermissionFull)
tests := []struct {
sql string
allowed bool
}{
{"SELECT * FROM t", true},
{"INSERT INTO t VALUES (1)", true},
{"DROP TABLE t", true},
{"CREATE TABLE t (id INT)", true},
}
for _, tt := range tests {
result := g.Check(tt.sql)
if result.Allowed != tt.allowed {
t.Errorf("Guard[full].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
}
}
}
func TestGuard_HighRiskWarning(t *testing.T) {
g := NewGuard(ai.PermissionFull)
result := g.Check("DELETE FROM users")
if result.WarningMessage == "" {
t.Error("expected high-risk warning for DELETE without WHERE")
}
if !result.RequiresConfirm {
t.Error("expected RequiresConfirm for high-risk SQL")
}
}

View File

@@ -0,0 +1,71 @@
package safety
import (
"GoNavi-Wails/internal/ai"
)
// Guard AI SQL 安全策略守卫
type Guard struct {
permissionLevel ai.SQLPermissionLevel
}
// NewGuard 创建安全策略守卫
func NewGuard(level ai.SQLPermissionLevel) *Guard {
return &Guard{permissionLevel: level}
}
// SetPermissionLevel 设置权限级别
func (g *Guard) SetPermissionLevel(level ai.SQLPermissionLevel) {
g.permissionLevel = level
}
// GetPermissionLevel 获取当前权限级别
func (g *Guard) GetPermissionLevel() ai.SQLPermissionLevel {
return g.permissionLevel
}
// Check 检查 AI 生成的 SQL 是否在允许范围内
func (g *Guard) Check(sql string) ai.SafetyResult {
opType := ClassifySQL(sql)
allowed := g.isAllowed(opType)
requiresConfirm := g.requiresConfirmation(opType)
warningMessage := ""
if isHighRisk, msg := IsHighRiskSQL(sql); isHighRisk {
warningMessage = msg
requiresConfirm = true
}
return ai.SafetyResult{
Allowed: allowed,
OperationType: opType,
RequiresConfirm: requiresConfirm,
WarningMessage: warningMessage,
}
}
func (g *Guard) isAllowed(opType ai.SQLOperationType) bool {
switch g.permissionLevel {
case ai.PermissionReadOnly:
return opType == ai.SQLOpQuery
case ai.PermissionReadWrite:
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
case ai.PermissionFull:
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
default:
return opType == ai.SQLOpQuery
}
}
func (g *Guard) requiresConfirmation(opType ai.SQLOperationType) bool {
switch opType {
case ai.SQLOpQuery:
return false
case ai.SQLOpDML:
return true // DML 始终需要确认
case ai.SQLOpDDL:
return true // DDL 始终需要确认
default:
return true
}
}

View File

@@ -0,0 +1,573 @@
package aiservice
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"GoNavi-Wails/internal/ai"
aicontext "GoNavi-Wails/internal/ai/context"
"GoNavi-Wails/internal/ai/provider"
"GoNavi-Wails/internal/ai/safety"
"GoNavi-Wails/internal/logger"
"github.com/google/uuid"
wailsRuntime "github.com/wailsapp/wails/v2/pkg/runtime"
)
// Service AI 服务,作为 Wails Binding 暴露给前端
type Service struct {
ctx context.Context
mu sync.RWMutex
providers []ai.ProviderConfig
activeProvider string // active provider ID
safetyLevel ai.SQLPermissionLevel
contextLevel ai.ContextLevel
guard *safety.Guard
configDir string // 配置存储目录
cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
}
// NewService 创建 AI Service 实例
func NewService() *Service {
return &Service{
providers: make([]ai.ProviderConfig, 0),
safetyLevel: ai.PermissionReadOnly,
contextLevel: ai.ContextSchemaOnly,
guard: safety.NewGuard(ai.PermissionReadOnly),
cancelFuncs: make(map[string]context.CancelFunc),
}
}
// Startup Wails 生命周期回调
func (s *Service) Startup(ctx context.Context) {
s.ctx = ctx
s.configDir = resolveConfigDir()
s.loadConfig()
logger.Infof("AI Service 启动完成,已加载 %d 个 Provider", len(s.providers))
}
// --- Provider 管理 ---
// AIGetProviders 获取所有 Provider 配置
func (s *Service) AIGetProviders() []ai.ProviderConfig {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]ai.ProviderConfig, len(s.providers))
copy(result, s.providers)
return result
}
// AISaveProvider 保存/更新 Provider 配置
func (s *Service) AISaveProvider(config ai.ProviderConfig) error {
fmt.Printf("[AISaveProvider DEBUG] ID: %s, Model: %s\n", config.ID, config.Model)
s.mu.Lock()
defer s.mu.Unlock()
if strings.TrimSpace(config.ID) == "" {
config.ID = "provider-" + uuid.New().String()[:8]
}
found := false
for i, p := range s.providers {
if p.ID == config.ID {
s.providers[i] = config
found = true
break
}
}
if !found {
s.providers = append(s.providers, config)
}
return s.saveConfig()
}
// AIDeleteProvider 删除 Provider
func (s *Service) AIDeleteProvider(id string) error {
s.mu.Lock()
defer s.mu.Unlock()
newProviders := make([]ai.ProviderConfig, 0, len(s.providers))
for _, p := range s.providers {
if p.ID != id {
newProviders = append(newProviders, p)
}
}
s.providers = newProviders
if s.activeProvider == id {
s.activeProvider = ""
if len(s.providers) > 0 {
s.activeProvider = s.providers[0].ID
}
}
return s.saveConfig()
}
// AITestProvider 测试 Provider 配置是否可用
func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{} {
// 如果传入脱敏的 key使用已保存的 key
s.mu.RLock()
if isMaskedAPIKey(config.APIKey) {
for _, p := range s.providers {
if p.ID == config.ID {
config.APIKey = p.APIKey
break
}
}
}
s.mu.RUnlock()
p, err := provider.NewProvider(config)
if err != nil {
return map[string]interface{}{"success": false, "message": err.Error()}
}
if err := p.Validate(); err != nil {
return map[string]interface{}{"success": false, "message": err.Error()}
}
ctx, cancel := context.WithTimeout(context.Background(), 30*1000*1000*1000) // 30s
defer cancel()
resp, err := p.Chat(ctx, ai.ChatRequest{
Messages: []ai.Message{
{Role: "user", Content: "Hi, please respond with 'OK' to confirm the connection is working."},
},
MaxTokens: 10,
})
if err != nil {
return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
}
return map[string]interface{}{
"success": true,
"message": fmt.Sprintf("连接成功!模型响应: %s", truncateString(resp.Content, 100)),
}
}
// AISetActiveProvider 设置活动 Provider
func (s *Service) AISetActiveProvider(id string) {
s.mu.Lock()
defer s.mu.Unlock()
s.activeProvider = id
_ = s.saveConfig()
}
// AIGetActiveProvider 获取活动 Provider ID
func (s *Service) AIGetActiveProvider() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.activeProvider
}
// AIGetBuiltinPrompts 返回内部置的各类系统提示词,用于前端展示或查询
func (s *Service) AIGetBuiltinPrompts() map[string]string {
return aicontext.GetBuiltinPrompts()
}
// AIListModels 获取当前活跃 Provider 的可用模型列表
func (s *Service) AIListModels() map[string]interface{} {
s.mu.RLock()
var config ai.ProviderConfig
found := false
for _, p := range s.providers {
if p.ID == s.activeProvider {
config = p
found = true
break
}
}
s.mu.RUnlock()
if !found {
return map[string]interface{}{"success": false, "models": []string{}, "error": "未找到活跃 Provider"}
}
models, err := fetchModels(config)
if err != nil {
// 回退到配置中的静态模型列表
if len(config.Models) > 0 {
return map[string]interface{}{"success": true, "models": config.Models, "source": "static"}
}
return map[string]interface{}{"success": false, "models": []string{}, "error": err.Error()}
}
return map[string]interface{}{"success": true, "models": models, "source": "api"}
}
// fetchModels 从供应商 API 获取可用模型列表
func fetchModels(config ai.ProviderConfig) ([]string, error) {
providerType := config.Type
if providerType == "custom" && config.APIFormat != "" {
providerType = config.APIFormat
}
switch providerType {
case "openai":
return fetchOpenAIModels(config)
case "anthropic":
// Anthropic 没有公开的 /models 端点,返回硬编码列表
return []string{"claude-opus-4-6", "claude-sonnet-4-6"}, nil
case "gemini":
return fetchGeminiModels(config)
default:
return fetchOpenAIModels(config)
}
}
// fetchOpenAIModels 获取 OpenAI 兼容 API 的模型列表
func fetchOpenAIModels(config ai.ProviderConfig) ([]string, error) {
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
// 确保 baseURL 以 /v1 结尾
if !strings.HasSuffix(baseURL, "/v1") {
baseURL = baseURL + "/v1"
}
req, err := http.NewRequest("GET", baseURL+"/models", nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+config.APIKey)
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("请求模型列表失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body))
}
var result struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("解析模型列表失败: %w", err)
}
models := make([]string, 0, len(result.Data))
for _, m := range result.Data {
models = append(models, m.ID)
}
return models, nil
}
// fetchGeminiModels 获取 Gemini API 的模型列表
func fetchGeminiModels(config ai.ProviderConfig) ([]string, error) {
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
req, err := http.NewRequest("GET", baseURL+"/v1beta/models?key="+config.APIKey, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("请求模型列表失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body))
}
var result struct {
Models []struct {
Name string `json:"name"` // e.g. "models/gemini-2.5-flash"
} `json:"models"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("解析模型列表失败: %w", err)
}
models := make([]string, 0, len(result.Models))
for _, m := range result.Models {
// 去掉 "models/" 前缀
name := m.Name
if strings.HasPrefix(name, "models/") {
name = strings.TrimPrefix(name, "models/")
}
models = append(models, name)
}
return models, nil
}
// --- 安全控制 ---
// AIGetSafetyLevel 获取当前安全级别
func (s *Service) AIGetSafetyLevel() string {
s.mu.RLock()
defer s.mu.RUnlock()
return string(s.safetyLevel)
}
// AISetSafetyLevel 设置安全级别
func (s *Service) AISetSafetyLevel(level string) {
s.mu.Lock()
defer s.mu.Unlock()
switch ai.SQLPermissionLevel(level) {
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
s.safetyLevel = ai.SQLPermissionLevel(level)
default:
s.safetyLevel = ai.PermissionReadOnly
}
s.guard.SetPermissionLevel(s.safetyLevel)
_ = s.saveConfig()
}
// --- 上下文控制 ---
// AIGetContextLevel 获取上下文传递级别
func (s *Service) AIGetContextLevel() string {
s.mu.RLock()
defer s.mu.RUnlock()
return string(s.contextLevel)
}
// AISetContextLevel 设置上下文传递级别
func (s *Service) AISetContextLevel(level string) {
s.mu.Lock()
defer s.mu.Unlock()
switch ai.ContextLevel(level) {
case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults:
s.contextLevel = ai.ContextLevel(level)
default:
s.contextLevel = ai.ContextSchemaOnly
}
_ = s.saveConfig()
}
// --- AI 对话 ---
// AIChatSend 同步发送 AI 对话(非流式)
func (s *Service) AIChatSend(messages []map[string]string) map[string]interface{} {
p, err := s.getActiveProvider()
if err != nil {
return map[string]interface{}{"success": false, "error": err.Error()}
}
var aiMessages []ai.Message
for _, m := range messages {
aiMessages = append(aiMessages, ai.Message{Role: m["role"], Content: m["content"]})
}
resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: aiMessages})
if err != nil {
return map[string]interface{}{"success": false, "error": err.Error()}
}
return map[string]interface{}{
"success": true,
"content": resp.Content,
"tokensUsed": map[string]int{
"promptTokens": resp.TokensUsed.PromptTokens,
"completionTokens": resp.TokensUsed.CompletionTokens,
"totalTokens": resp.TokensUsed.TotalTokens,
},
}
}
// AIChatStream 流式发送 AI 对话(通过 EventsEmit 推送)
func (s *Service) AIChatStream(sessionID string, messages []map[string]string) {
streamCtx, cancel := context.WithCancel(context.Background())
s.mu.Lock()
s.cancelFuncs[sessionID] = cancel
s.mu.Unlock()
go func() {
defer func() {
s.mu.Lock()
delete(s.cancelFuncs, sessionID)
s.mu.Unlock()
cancel() // 确保释放
}()
p, err := s.getActiveProvider()
if err != nil {
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
"error": err.Error(),
"done": true,
})
return
}
var aiMessages []ai.Message
for _, m := range messages {
aiMessages = append(aiMessages, ai.Message{Role: m["role"], Content: m["content"]})
}
err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: aiMessages}, func(chunk ai.StreamChunk) {
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
"content": chunk.Content,
"done": chunk.Done,
"error": chunk.Error,
})
})
// 当 context 被主动 cancel 的时候,不把这个视为向外抛的 error
if err != nil && err != context.Canceled {
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
"error": err.Error(),
"done": true,
})
}
}()
}
// AIChatCancel 立即终止某个 Session 的流式对话请求
func (s *Service) AIChatCancel(sessionID string) {
s.mu.RLock()
cancel, ok := s.cancelFuncs[sessionID]
s.mu.RUnlock()
if ok && cancel != nil {
cancel()
}
}
// AICheckSQL 检查 SQL 的安全性
func (s *Service) AICheckSQL(sql string) ai.SafetyResult {
s.mu.RLock()
defer s.mu.RUnlock()
return s.guard.Check(sql)
}
// --- 内部方法 ---
func (s *Service) getActiveProvider() (provider.Provider, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.activeProvider == "" && len(s.providers) > 0 {
s.activeProvider = s.providers[0].ID
}
for _, cfg := range s.providers {
if cfg.ID == s.activeProvider {
return provider.NewProvider(cfg)
}
}
return nil, fmt.Errorf("未配置 AI Provider请先在设置中配置")
}
// --- 配置持久化 ---
type aiConfig struct {
Providers []ai.ProviderConfig `json:"providers"`
ActiveProvider string `json:"activeProvider"`
SafetyLevel string `json:"safetyLevel"`
ContextLevel string `json:"contextLevel"`
}
func (s *Service) loadConfig() {
path := filepath.Join(s.configDir, "ai_config.json")
data, err := os.ReadFile(path)
if err != nil {
return // 首次启动,无配置文件
}
var cfg aiConfig
if err := json.Unmarshal(data, &cfg); err != nil {
logger.Error(err, "加载 AI 配置失败")
return
}
s.providers = cfg.Providers
if s.providers == nil {
s.providers = make([]ai.ProviderConfig, 0)
}
s.activeProvider = cfg.ActiveProvider
switch ai.SQLPermissionLevel(cfg.SafetyLevel) {
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
s.safetyLevel = ai.SQLPermissionLevel(cfg.SafetyLevel)
default:
s.safetyLevel = ai.PermissionReadOnly
}
s.guard.SetPermissionLevel(s.safetyLevel)
switch ai.ContextLevel(cfg.ContextLevel) {
case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults:
s.contextLevel = ai.ContextLevel(cfg.ContextLevel)
default:
s.contextLevel = ai.ContextSchemaOnly
}
}
func (s *Service) saveConfig() error {
cfg := aiConfig{
Providers: s.providers,
ActiveProvider: s.activeProvider,
SafetyLevel: string(s.safetyLevel),
ContextLevel: string(s.contextLevel),
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("序列化 AI 配置失败: %w", err)
}
if err := os.MkdirAll(s.configDir, 0o755); err != nil {
return fmt.Errorf("创建配置目录失败: %w", err)
}
path := filepath.Join(s.configDir, "ai_config.json")
if err := os.WriteFile(path, data, 0o644); err != nil {
return fmt.Errorf("写入 AI 配置失败: %w", err)
}
return nil
}
// --- 工具函数 ---
func resolveConfigDir() string {
configDir, err := os.UserConfigDir()
if err != nil {
configDir = "."
}
return filepath.Join(configDir, "GoNavi")
}
func maskAPIKey(apiKey string) string {
if len(apiKey) <= 8 {
return "****"
}
return apiKey[:4] + "****" + apiKey[len(apiKey)-4:]
}
func isMaskedAPIKey(apiKey string) bool {
return strings.Contains(apiKey, "****")
}
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}

85
internal/ai/types.go Normal file
View File

@@ -0,0 +1,85 @@
package ai
// Message 表示一条对话消息
type Message struct {
Role string `json:"role"` // "system" | "user" | "assistant"
Content string `json:"content"`
}
// ChatRequest AI 对话请求
type ChatRequest struct {
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"maxTokens"`
}
// ChatResponse AI 对话响应
type ChatResponse struct {
Content string `json:"content"`
TokensUsed TokenUsage `json:"tokensUsed"`
}
// TokenUsage token 用量统计
type TokenUsage struct {
PromptTokens int `json:"promptTokens"`
CompletionTokens int `json:"completionTokens"`
TotalTokens int `json:"totalTokens"`
}
// StreamChunk 流式响应片段
type StreamChunk struct {
Content string `json:"content"`
Done bool `json:"done"`
Error string `json:"error,omitempty"`
}
// ProviderConfig AI Provider 配置
type ProviderConfig struct {
ID string `json:"id"`
Type string `json:"type"` // openai | anthropic | gemini | custom
Name string `json:"name"`
APIKey string `json:"apiKey"`
BaseURL string `json:"baseUrl"`
Model string `json:"model"`
Models []string `json:"models,omitempty"`
APIFormat string `json:"apiFormat,omitempty"` // custom 专用: openai | anthropic | gemini
Headers map[string]string `json:"headers,omitempty"`
MaxTokens int `json:"maxTokens"`
Temperature float64 `json:"temperature"`
}
// SQLPermissionLevel AI SQL 执行权限级别
type SQLPermissionLevel string
const (
PermissionReadOnly SQLPermissionLevel = "readonly"
PermissionReadWrite SQLPermissionLevel = "readwrite"
PermissionFull SQLPermissionLevel = "full"
)
// ContextLevel AI 上下文传递级别
type ContextLevel string
const (
ContextSchemaOnly ContextLevel = "schema_only"
ContextWithSamples ContextLevel = "with_samples"
ContextWithResults ContextLevel = "with_results"
)
// SQLOperationType SQL 操作类型
type SQLOperationType string
const (
SQLOpQuery SQLOperationType = "query" // SELECT, SHOW, DESCRIBE, EXPLAIN
SQLOpDML SQLOperationType = "dml" // INSERT, UPDATE, DELETE
SQLOpDDL SQLOperationType = "ddl" // CREATE, ALTER, DROP, TRUNCATE
SQLOpOther SQLOperationType = "other"
)
// SafetyResult 安全检查结果
type SafetyResult struct {
Allowed bool `json:"allowed"`
OperationType SQLOperationType `json:"operationType"`
RequiresConfirm bool `json:"requiresConfirm"`
WarningMessage string `json:"warningMessage,omitempty"`
}