mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-06 22:49:35 +08:00
✨ feat(ai-chat): 全面升级AI聊天面板并优化交互体验
- 消息管理:新增聊天气泡的重试、编辑与单条删除功能及相对应的持久化状态函数 - 快捷操作:支持长文一键滑动到底端,并在代码块内增加SQL一键送入编辑器的快捷执行机制 - 视觉优化:深化AI回复背景沉浸感,重绘AI洞察按钮并移除设置面板所有的冗余紫色调 - 设置调优:放宽模型初始必填限制,新增内置系统提示词(Builtin Prompt)全览面板
This commit is contained in:
213
internal/ai/context/builder.go
Normal file
213
internal/ai/context/builder.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package aicontext
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PromptTemplate AI 能力类型
|
||||
type PromptTemplate string
|
||||
|
||||
const (
|
||||
PromptSQLGenerate PromptTemplate = "sql_generate"
|
||||
PromptSQLExplain PromptTemplate = "sql_explain"
|
||||
PromptSQLOptimize PromptTemplate = "sql_optimize"
|
||||
PromptDataAnalyze PromptTemplate = "data_analyze"
|
||||
PromptSchemaInsight PromptTemplate = "schema_insight"
|
||||
PromptGeneralChat PromptTemplate = "general_chat"
|
||||
)
|
||||
|
||||
// GetBuiltinPrompts 获取所有内置系统提示词集合,用于前端展示
|
||||
func GetBuiltinPrompts() map[string]string {
|
||||
return map[string]string{
|
||||
"通用聊天助手": buildGeneralChatPrompt(),
|
||||
"SQL 生成器": buildSQLGeneratePrompt(),
|
||||
"SQL 解析器": buildSQLExplainPrompt(),
|
||||
"SQL 优化器": buildSQLOptimizePrompt(),
|
||||
"数据洞察分析": buildDataAnalyzePrompt(),
|
||||
"表结构审查": buildSchemaInsightPrompt(),
|
||||
}
|
||||
}
|
||||
|
||||
// BuildSystemPrompt 根据模板类型和上下文构建 System Prompt
|
||||
func BuildSystemPrompt(template PromptTemplate, dbCtx *DatabaseContext) string {
|
||||
var prompt string
|
||||
|
||||
switch template {
|
||||
case PromptSQLGenerate:
|
||||
prompt = buildSQLGeneratePrompt()
|
||||
case PromptSQLExplain:
|
||||
prompt = buildSQLExplainPrompt()
|
||||
case PromptSQLOptimize:
|
||||
prompt = buildSQLOptimizePrompt()
|
||||
case PromptDataAnalyze:
|
||||
prompt = buildDataAnalyzePrompt()
|
||||
case PromptSchemaInsight:
|
||||
prompt = buildSchemaInsightPrompt()
|
||||
case PromptGeneralChat:
|
||||
prompt = buildGeneralChatPrompt()
|
||||
default:
|
||||
prompt = buildGeneralChatPrompt()
|
||||
}
|
||||
|
||||
if dbCtx != nil {
|
||||
prompt += "\n\n" + FormatDatabaseContext(dbCtx)
|
||||
}
|
||||
|
||||
return prompt
|
||||
}
|
||||
|
||||
// FormatDatabaseContext 将数据库上下文格式化为 LLM 友好的文本
|
||||
func FormatDatabaseContext(ctx *DatabaseContext) string {
|
||||
if ctx == nil || len(ctx.Tables) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("## 当前数据库上下文\n\n数据库类型: %s\n数据库名: %s\n\n",
|
||||
ctx.DatabaseType, ctx.DatabaseName))
|
||||
|
||||
b.WriteString("### 表结构\n\n")
|
||||
for _, table := range ctx.Tables {
|
||||
b.WriteString(fmt.Sprintf("#### 表: %s", table.Name))
|
||||
if table.Comment != "" {
|
||||
b.WriteString(fmt.Sprintf(" (%s)", table.Comment))
|
||||
}
|
||||
if table.RowCount > 0 {
|
||||
b.WriteString(fmt.Sprintf(" [约 %d 行]", table.RowCount))
|
||||
}
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString("| 列名 | 类型 | 可空 | 主键 | 备注 |\n")
|
||||
b.WriteString("|------|------|------|------|------|\n")
|
||||
for _, col := range table.Columns {
|
||||
nullable := "否"
|
||||
if col.Nullable {
|
||||
nullable = "是"
|
||||
}
|
||||
pk := ""
|
||||
if col.PrimaryKey {
|
||||
pk = "✓"
|
||||
}
|
||||
comment := col.Comment
|
||||
if comment == "" {
|
||||
comment = "-"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("| %s | %s | %s | %s | %s |\n",
|
||||
col.Name, col.Type, nullable, pk, comment))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
if len(table.Indexes) > 0 {
|
||||
b.WriteString("**索引:**\n")
|
||||
for _, idx := range table.Indexes {
|
||||
unique := ""
|
||||
if idx.Unique {
|
||||
unique = " (唯一)"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("- %s: [%s]%s\n",
|
||||
idx.Name, strings.Join(idx.Columns, ", "), unique))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(table.SampleRows) > 0 {
|
||||
b.WriteString(fmt.Sprintf("**采样数据 (%d 行):**\n\n", len(table.SampleRows)))
|
||||
if len(table.SampleRows) > 0 {
|
||||
// 使用第一行的 key 作为标题
|
||||
first := table.SampleRows[0]
|
||||
var keys []string
|
||||
for k := range first {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
b.WriteString("| " + strings.Join(keys, " | ") + " |\n")
|
||||
b.WriteString("|" + strings.Repeat("------|", len(keys)) + "\n")
|
||||
for _, row := range table.SampleRows {
|
||||
var vals []string
|
||||
for _, k := range keys {
|
||||
vals = append(vals, fmt.Sprintf("%v", row[k]))
|
||||
}
|
||||
b.WriteString("| " + strings.Join(vals, " | ") + " |\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func buildSQLGeneratePrompt() string {
|
||||
return `你是 GoNavi AI 助手,一位顶级的数据库开发专家和 SQL 查询构建师。根据用户的自然语言需求,生成精准、优雅、高性能的 SQL 查询或 Redis 命令。
|
||||
|
||||
严苛输出规则:
|
||||
1. 首要目标是输出纯粹的代码:始终将代码放在正确语言标识(如 sql 或 bash)的 markdown 代码块中。
|
||||
2. 保持精简:不要添加过多的前置闲聊,直奔主题。
|
||||
3. 保护生产安全:优先使用参数化查询或安全防范写法避免 SQL 注入。对于未指定条件的 DELETE/UPDATE 语句,必须提出强烈的红线警告!!
|
||||
4. 性能至上:对大型查询默认添加合理的 LIMIT 限制(如 LIMIT 100),在 JOIN 和聚合时优先选择最高效的范式写法。
|
||||
5. 适度注释:对于存在复杂逻辑嵌套的代码,请在代码块内使用单行注释简要说明思路。`
|
||||
}
|
||||
|
||||
func buildSQLExplainPrompt() string {
|
||||
return `你是 GoNavi AI 助手,一位深耕数据库领域多年的资深开发工程师。请用专业、条理分明且深入浅出的开发者语言向用户全盘解析 SQL 语句的底层意图与执行逻辑。
|
||||
|
||||
解析规范:
|
||||
1. 宏观逻辑解构:用简短的一句话概括这条 SQL 在业务上想要解决什么问题。
|
||||
2. 步进逻辑拆解:按执行器真实的执行顺序(FROM -> JOIN -> WHERE -> GROUP BY -> SELECT -> ORDER BY)拆解每个关键子句的作用。
|
||||
3. 性能排雷点:敏锐指出可能存在的性能陷阱(如隐式类型转换、没有走索引的函数调用、潜在的笛卡尔积/全表扫描等)。
|
||||
4. 严谨的排版:使用列表呈现关键点,重点词汇加粗,确保长文不累赘。`
|
||||
}
|
||||
|
||||
func buildSQLOptimizePrompt() string {
|
||||
return `你是 GoNavi AI 助手,一名曾主导过千万级高并发系统的全栈性能工程专家与高级 DBA。请对用户提供的原始 SQL 进行冷酷、精确的诊断并开出性能重构处方。
|
||||
|
||||
诊断与处方要求:
|
||||
1. 性能瓶颈透视:精准点出当前语句死穴(不合理的驱动表、无法利用覆盖索引、多此一举的子查询等)。
|
||||
2. 重构版本的 SQL:如果存在性能提升空间,直接向用户展示彻底优化过的高性能写法,并确保逻辑等价性。
|
||||
3. 剖析原因:不仅要告诉用户“怎么改”,更要说清楚执行器“为什么这样会更快”。
|
||||
4. 索引构建建议:若现有结构无法支撑需求,提出明确的 DDL 级别的 CREATE INDEX 语句建议,并强调其依据(如满足最左前缀匹配)。
|
||||
5. 优先级评估:在回答的最后标注本次优化建议的紧迫性(高:阻断级/锁表风险;中:吞吐量瓶颈;低:长效微调)。`
|
||||
}
|
||||
|
||||
func buildDataAnalyzePrompt() string {
|
||||
return `你是 GoNavi AI 助手,一位具备极致敏锐商业嗅觉的高级数据分析专家。你将审视用户通过查询得到的数据样本,从中提炼出蕴含的真金白银般的信息。
|
||||
|
||||
洞察目标:
|
||||
1. 硬统计:总观数据行数、核心数值指标(极值、平均值、聚合中位数等)的冰冷现实。
|
||||
2. 趋势与异动:如果数据带有时间戳,敏锐捕捉其上升或下降趋势;如果有异类离群值,将其高亮标注。
|
||||
3. 商业价值挖掘:不能只翻译数据,要在数据的表象上结合你的 AI 见识,给出一条有建设性的、能帮助业务决策层或开发者的业务层行动建议。
|
||||
4. 展现格式:你的分析应该是“标题 + 浓缩要点”的极简研报形式,杜绝毫无波澜的流水账。`
|
||||
}
|
||||
|
||||
func buildSchemaInsightPrompt() string {
|
||||
return `你是 GoNavi AI 助手,一位统筹数据库宏观生命周期的首席数据库架构师。在这个环节里,你需要对用户提供的数据库表结构执行最严厉的范式与前瞻性审查。
|
||||
|
||||
审查视界:
|
||||
1. 规范化博弈:是否存在明显的反三范式设计?这种冗余是否有助于性能(适当的反范式),还是纯粹的设计失误?
|
||||
2. 索引健壮性审查:评估主键选择(如自增、UUID 的利弊),是否存在冗余索引阻碍写入?以及是否遗漏了高频的联合索引。
|
||||
3. 物理容量前瞻:审视数据类型分配(如使用过大的 VARCHAR、没必要的 BIGINT 等可能带来的空间挥霍)。
|
||||
4. 代码级指引:如果存在结构性缺陷,不要只发牢骚,直接给出包含具体优化的 ALTER TABLE 结构修改建议脚本。`
|
||||
}
|
||||
|
||||
func buildGeneralChatPrompt() string {
|
||||
return `你是 GoNavi AI 助手,一款深度集成在数据库/缓存客户端(GoNavi)内部的专属智能专家系统。
|
||||
你的目标是成为开发者、DBA 和数据科学家最得力的超级外脑,提供专业、精准、具有前瞻性的数据端解决方案。
|
||||
|
||||
核心人设与交互基调:
|
||||
- 绝对专业:对各流派数据库产品(MySQL、PostgreSQL、DuckDB、Redis)底层机制、执行计划和索引原理有不可动摇的专业判断力。
|
||||
- 直击痛点:谢绝套话与无效寒暄,若用户的意图明确,首屏直接给出可以直接粘贴运行的优雅代码。
|
||||
- 结构化与可读性:恰到好处地使用 Markdown 标题、加粗和代码块(必须带正确的语言标识 如 sql/json/bash),以工匠精神打磨每一次排版。
|
||||
- 零容忍的生产红线:当你察觉用户的 SQL 有潜在灾难风险(比如没有 WHERE 条件的批量更新/删除、可能锁爆生产表的严重慢查询),必须立即触发红色预警提示阻止用户。
|
||||
|
||||
你的综合能力版图:
|
||||
1. 📝 自然语言驱动:翻译人类意图为精准的查询语句。
|
||||
2. 🔍 底层原理解析:剥丝抽茧分析查询背后的执行逻辑与性能隐患。
|
||||
3. ⚡ 专家级调优:指出并化解性能瓶颈,给出覆盖全维度的索引调优思路。
|
||||
4. 📊 数据洞察炼金:不仅聚合数据,更能从结果集中挖掘商业维度的深度规律。
|
||||
5. 🏗️ 架构先知视界:全局审阅表结构设计局限,提出抗数据膨胀级别的架构演进方案。
|
||||
|
||||
互动守则:
|
||||
- 永远使用专业、具有合作感且充满信心的中文与用户探讨问题。
|
||||
- 当被要求提供任何数据库代码时,需结合相关数据库引擎的最佳实践。如果不清楚当前方言版本,请以标准实现为主基调并好心指出版别差异(如 MySQL 8 窗口函数 等)。`
|
||||
}
|
||||
|
||||
42
internal/ai/context/collector.go
Normal file
42
internal/ai/context/collector.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package aicontext
|
||||
|
||||
// DatabaseContext 数据库上下文信息,传递给 AI 辅助上下文理解
|
||||
type DatabaseContext struct {
|
||||
DatabaseType string `json:"databaseType"` // mysql, postgres 等
|
||||
DatabaseName string `json:"databaseName"`
|
||||
Tables []TableContext `json:"tables"`
|
||||
}
|
||||
|
||||
// TableContext 表的上下文信息
|
||||
type TableContext struct {
|
||||
Name string `json:"name"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
Columns []ColumnInfo `json:"columns"`
|
||||
Indexes []IndexInfo `json:"indexes,omitempty"`
|
||||
SampleRows []map[string]interface{} `json:"sampleRows,omitempty"`
|
||||
RowCount int64 `json:"rowCount,omitempty"`
|
||||
}
|
||||
|
||||
// ColumnInfo 列信息
|
||||
type ColumnInfo struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Nullable bool `json:"nullable"`
|
||||
PrimaryKey bool `json:"primaryKey"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
// IndexInfo 索引信息
|
||||
type IndexInfo struct {
|
||||
Name string `json:"name"`
|
||||
Columns []string `json:"columns"`
|
||||
Unique bool `json:"unique"`
|
||||
}
|
||||
|
||||
// QueryResultContext 查询结果上下文
|
||||
type QueryResultContext struct {
|
||||
SQL string `json:"sql"`
|
||||
Columns []string `json:"columns"`
|
||||
Rows []map[string]interface{} `json:"rows"`
|
||||
RowCount int `json:"rowCount"`
|
||||
}
|
||||
293
internal/ai/provider/anthropic.go
Normal file
293
internal/ai/provider/anthropic.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAnthropicBaseURL = "https://api.anthropic.com"
|
||||
defaultAnthropicModel = "claude-3-5-sonnet-20241022"
|
||||
anthropicAPIVersion = "2023-06-01"
|
||||
)
|
||||
|
||||
// AnthropicProvider 实现 Anthropic Claude API 的 Provider
|
||||
type AnthropicProvider struct {
|
||||
config ai.ProviderConfig
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewAnthropicProvider 创建 Anthropic Provider 实例
|
||||
func NewAnthropicProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = defaultAnthropicBaseURL
|
||||
}
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if model == "" {
|
||||
model = defaultAnthropicModel
|
||||
}
|
||||
maxTokens := config.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = defaultOpenAIMaxTokens
|
||||
}
|
||||
temperature := config.Temperature
|
||||
if temperature <= 0 {
|
||||
temperature = defaultOpenAITemperature
|
||||
}
|
||||
|
||||
normalized := config
|
||||
normalized.BaseURL = baseURL
|
||||
normalized.Model = model
|
||||
normalized.MaxTokens = maxTokens
|
||||
normalized.Temperature = temperature
|
||||
|
||||
return &AnthropicProvider{
|
||||
config: normalized,
|
||||
baseURL: baseURL,
|
||||
client: &http.Client{Timeout: openAIHTTPTimeout},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) Name() string {
|
||||
if strings.TrimSpace(p.config.Name) != "" {
|
||||
return p.config.Name
|
||||
}
|
||||
return "Anthropic"
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) Validate() error {
|
||||
if strings.TrimSpace(p.config.APIKey) == "" {
|
||||
return fmt.Errorf("API Key 不能为空")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Delta *struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"delta,omitempty"`
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
systemMsg, messages := extractSystemMessage(req.Messages)
|
||||
anthropicMsgs := make([]anthropicMessage, len(messages))
|
||||
for i, m := range messages {
|
||||
anthropicMsgs[i] = anthropicMessage{Role: m.Role, Content: m.Content}
|
||||
}
|
||||
|
||||
temperature := req.Temperature
|
||||
if temperature <= 0 {
|
||||
temperature = p.config.Temperature
|
||||
}
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = p.config.MaxTokens
|
||||
}
|
||||
|
||||
body := anthropicRequest{
|
||||
Model: p.config.Model,
|
||||
Messages: anthropicMsgs,
|
||||
System: systemMsg,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
}
|
||||
|
||||
respBody, err := p.doRequest(ctx, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer respBody.Close()
|
||||
|
||||
var result anthropicResponse
|
||||
if err := json.NewDecoder(respBody).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析 Anthropic 响应失败: %w", err)
|
||||
}
|
||||
if result.Error != nil && result.Error.Message != "" {
|
||||
return nil, fmt.Errorf("Anthropic API 错误: %s", result.Error.Message)
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
return nil, fmt.Errorf("Anthropic 返回空响应")
|
||||
}
|
||||
|
||||
return &ai.ChatResponse{
|
||||
Content: result.Content[0].Text,
|
||||
TokensUsed: ai.TokenUsage{
|
||||
PromptTokens: result.Usage.InputTokens,
|
||||
CompletionTokens: result.Usage.OutputTokens,
|
||||
TotalTokens: result.Usage.InputTokens + result.Usage.OutputTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
systemMsg, messages := extractSystemMessage(req.Messages)
|
||||
anthropicMsgs := make([]anthropicMessage, len(messages))
|
||||
for i, m := range messages {
|
||||
anthropicMsgs[i] = anthropicMessage{Role: m.Role, Content: m.Content}
|
||||
}
|
||||
|
||||
temperature := req.Temperature
|
||||
if temperature <= 0 {
|
||||
temperature = p.config.Temperature
|
||||
}
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = p.config.MaxTokens
|
||||
}
|
||||
|
||||
body := anthropicRequest{
|
||||
Model: p.config.Model,
|
||||
Messages: anthropicMsgs,
|
||||
System: systemMsg,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
respBody, err := p.doRequest(ctx, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer respBody.Close()
|
||||
|
||||
scanner := bufio.NewScanner(respBody)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var event anthropicStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "content_block_delta":
|
||||
if event.Delta != nil && event.Delta.Text != "" {
|
||||
callback(ai.StreamChunk{Content: event.Delta.Text})
|
||||
}
|
||||
case "message_stop":
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) doRequest(ctx context.Context, body interface{}) (io.ReadCloser, error) {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := p.baseURL + "/v1/messages"
|
||||
if strings.HasSuffix(p.baseURL, "/v1") {
|
||||
url = p.baseURL + "/messages"
|
||||
}
|
||||
|
||||
// 调试日志:打印实际请求信息
|
||||
bodyStr := string(jsonBody)
|
||||
if len(bodyStr) > 500 {
|
||||
bodyStr = bodyStr[:500] + "..."
|
||||
}
|
||||
fmt.Printf("[Anthropic DEBUG] URL: %s\n", url)
|
||||
fmt.Printf("[Anthropic DEBUG] BaseURL: %s\n", p.baseURL)
|
||||
fmt.Printf("[Anthropic DEBUG] Body: %s\n", bodyStr)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("x-api-key", p.config.APIKey)
|
||||
httpReq.Header.Set("anthropic-version", anthropicAPIVersion)
|
||||
|
||||
// 仅官方 API 发 beta 特性头(代理不发,避免触发 Claude Code 验证)
|
||||
isOfficialAPI := p.baseURL == defaultAnthropicBaseURL || strings.Contains(p.baseURL, "anthropic.com")
|
||||
if isOfficialAPI {
|
||||
httpReq.Header.Set("anthropic-beta", "interleaved-thinking-2025-05-14,output-128k-2025-02-19,prompt-caching-2024-07-31")
|
||||
}
|
||||
|
||||
// 自定义 headers(用于兼容各类代理服务)
|
||||
for k, v := range p.config.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求到 %s 失败: %w", url, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("Anthropic API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// extractSystemMessage 从消息列表中提取 system 消息(Anthropic 要求 system 作为独立字段)
|
||||
func extractSystemMessage(messages []ai.Message) (string, []ai.Message) {
|
||||
var systemParts []string
|
||||
var remaining []ai.Message
|
||||
for _, m := range messages {
|
||||
if m.Role == "system" {
|
||||
systemParts = append(systemParts, m.Content)
|
||||
} else {
|
||||
remaining = append(remaining, m)
|
||||
}
|
||||
}
|
||||
return strings.Join(systemParts, "\n\n"), remaining
|
||||
}
|
||||
227
internal/ai/provider/claude_cli.go
Normal file
227
internal/ai/provider/claude_cli.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
ai "GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
// ClaudeCLIProvider 通过 Claude Code CLI 发送聊天请求
|
||||
// 适用于 anyrouter/newapi 等只支持 Claude Code 协议的代理服务
|
||||
type ClaudeCLIProvider struct {
|
||||
config ai.ProviderConfig
|
||||
}
|
||||
|
||||
// NewClaudeCLIProvider 创建 ClaudeCLIProvider 实例
|
||||
func NewClaudeCLIProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
return &ClaudeCLIProvider{config: config}, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeCLIProvider) Name() string {
|
||||
return "ClaudeCLI"
|
||||
}
|
||||
|
||||
func (p *ClaudeCLIProvider) Validate() error {
|
||||
_, err := exec.LookPath("claude")
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到 claude 命令,请先安装 Claude Code CLI: npm install -g @anthropic-ai/claude-code")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chat 非流式聊天:调用 claude -p "prompt" --output-format json
|
||||
func (p *ClaudeCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prompt := buildPrompt(req.Messages)
|
||||
args := []string{"-p", prompt, "--output-format", "json", "--no-session-persistence"}
|
||||
if p.config.Model != "" {
|
||||
args = append(args, "--model", p.config.Model)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "claude", args...)
|
||||
p.setEnv(cmd)
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return nil, fmt.Errorf("claude CLI 执行失败: %s", string(exitErr.Stderr))
|
||||
}
|
||||
return nil, fmt.Errorf("claude CLI 执行失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析 JSON 输出
|
||||
var result struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(output, &result); err != nil {
|
||||
// 如果 JSON 解析失败,直接返回原始文本
|
||||
return &ai.ChatResponse{Content: strings.TrimSpace(string(output))}, nil
|
||||
}
|
||||
|
||||
return &ai.ChatResponse{Content: result.Result}, nil
|
||||
}
|
||||
|
||||
// ChatStream 流式聊天:调用 claude -p "prompt" --output-format stream-json
|
||||
func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompt := buildPrompt(req.Messages)
|
||||
args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--no-session-persistence"}
|
||||
if p.config.Model != "" {
|
||||
args = append(args, "--model", p.config.Model)
|
||||
}
|
||||
|
||||
fmt.Printf("[ClaudeCLI DEBUG] Running: claude %v\n", args)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "claude", args...)
|
||||
p.setEnv(cmd)
|
||||
|
||||
// 关闭 stdin,防止 claude CLI 等待输入
|
||||
cmd.Stdin = nil
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 stdout 管道失败: %w", err)
|
||||
}
|
||||
|
||||
// 捕获 stderr
|
||||
var stderrBuf bytes.Buffer
|
||||
cmd.Stderr = &stderrBuf
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("启动 claude CLI 失败: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[ClaudeCLI DEBUG] Process started, PID: %d\n", cmd.Process.Pid)
|
||||
|
||||
// 立即通知前端:AI 正在思考(避免用户以为卡死)
|
||||
callback(ai.StreamChunk{Content: "💭 *正在思考...*\n\n"})
|
||||
|
||||
// 逐行读取流式 JSON 输出
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("[ClaudeCLI DEBUG] Line: %s\n", line[:min(len(line), 200)])
|
||||
|
||||
var event cliStreamEvent
|
||||
if err := json.Unmarshal([]byte(line), &event); err != nil {
|
||||
fmt.Printf("[ClaudeCLI DEBUG] Non-JSON line: %s\n", line)
|
||||
continue
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "assistant":
|
||||
// 助手消息开始或文本内容
|
||||
if event.Message.Content != nil {
|
||||
for _, block := range event.Message.Content {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
callback(ai.StreamChunk{Content: block.Text})
|
||||
}
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
// 增量文本
|
||||
if event.Delta.Text != "" {
|
||||
callback(ai.StreamChunk{Content: event.Delta.Text})
|
||||
}
|
||||
case "result":
|
||||
// 最终结果事件 — 不发送 content(assistant 事件已包含),只标记完成
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
case "error":
|
||||
callback(ai.StreamChunk{Error: event.Error.Message, Done: true})
|
||||
_ = cmd.Wait()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
waitErr := cmd.Wait()
|
||||
stderrStr := strings.TrimSpace(stderrBuf.String())
|
||||
fmt.Printf("[ClaudeCLI DEBUG] Process exited. stderr: %s\n", stderrStr)
|
||||
|
||||
if waitErr != nil {
|
||||
errMsg := fmt.Sprintf("claude CLI 异常退出: %v", waitErr)
|
||||
if stderrStr != "" {
|
||||
errMsg = fmt.Sprintf("claude CLI 异常退出: %s", stderrStr)
|
||||
}
|
||||
callback(ai.StreamChunk{Error: errMsg, Done: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
// setEnv 设置 Claude CLI 的环境变量
|
||||
func (p *ClaudeCLIProvider) setEnv(cmd *exec.Cmd) {
|
||||
env := cmd.Environ()
|
||||
if p.config.BaseURL != "" {
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
env = append(env, "ANTHROPIC_BASE_URL="+baseURL)
|
||||
}
|
||||
if p.config.APIKey != "" {
|
||||
env = append(env, "ANTHROPIC_API_KEY="+p.config.APIKey)
|
||||
}
|
||||
cmd.Env = env
|
||||
}
|
||||
|
||||
// buildPrompt 将消息列表拼接为适合 claude -p 的提示文本
|
||||
func buildPrompt(messages []ai.Message) string {
|
||||
if len(messages) == 1 {
|
||||
return messages[0].Content
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for _, m := range messages {
|
||||
switch m.Role {
|
||||
case "system":
|
||||
sb.WriteString("[System]\n")
|
||||
sb.WriteString(m.Content)
|
||||
sb.WriteString("\n\n")
|
||||
case "user":
|
||||
sb.WriteString(m.Content)
|
||||
sb.WriteString("\n\n")
|
||||
case "assistant":
|
||||
sb.WriteString("[Previous Assistant Response]\n")
|
||||
sb.WriteString(m.Content)
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// cliStreamEvent Claude CLI stream-json 输出的事件结构
|
||||
type cliStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"message,omitempty"`
|
||||
Delta struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"delta,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
74
internal/ai/provider/custom.go
Normal file
74
internal/ai/provider/custom.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
// CustomProvider 自定义 Provider,根据 apiFormat 选择底层协议
|
||||
// 支持 openai / anthropic / gemini 三种 API 格式
|
||||
type CustomProvider struct {
|
||||
inner Provider
|
||||
name string
|
||||
}
|
||||
|
||||
// NewCustomProvider 创建自定义 Provider 实例
|
||||
func NewCustomProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
if strings.TrimSpace(config.BaseURL) == "" {
|
||||
return nil, fmt.Errorf("自定义 Provider 必须指定 Base URL")
|
||||
}
|
||||
|
||||
// 根据 apiFormat 决定使用哪个底层协议,默认 openai
|
||||
apiFormat := strings.ToLower(strings.TrimSpace(config.APIFormat))
|
||||
if apiFormat == "" {
|
||||
apiFormat = "openai"
|
||||
}
|
||||
|
||||
var innerProvider Provider
|
||||
var err error
|
||||
switch apiFormat {
|
||||
case "anthropic":
|
||||
innerProvider, err = NewAnthropicProvider(config)
|
||||
case "gemini":
|
||||
innerProvider, err = NewGeminiProvider(config)
|
||||
case "claude-cli":
|
||||
innerProvider, err = NewClaudeCLIProvider(config)
|
||||
default: // "openai" 及其他
|
||||
innerProvider, err = NewOpenAIProvider(config)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(config.Name)
|
||||
if name == "" {
|
||||
name = "Custom"
|
||||
}
|
||||
|
||||
return &CustomProvider{
|
||||
inner: innerProvider,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *CustomProvider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *CustomProvider) Validate() error {
|
||||
if strings.TrimSpace(p.inner.(interface{ Name() string }).Name()) == "" {
|
||||
// 对自定义 Provider,API Key 可选(部分本地服务不需要)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *CustomProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
return p.inner.Chat(ctx, req)
|
||||
}
|
||||
|
||||
func (p *CustomProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
return p.inner.ChatStream(ctx, req, callback)
|
||||
}
|
||||
267
internal/ai/provider/gemini.go
Normal file
267
internal/ai/provider/gemini.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultGeminiBaseURL = "https://generativelanguage.googleapis.com"
|
||||
defaultGeminiModel = "gemini-2.0-flash"
|
||||
)
|
||||
|
||||
// GeminiProvider 实现 Google Gemini API 的 Provider
|
||||
type GeminiProvider struct {
|
||||
config ai.ProviderConfig
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewGeminiProvider 创建 Gemini Provider 实例
|
||||
func NewGeminiProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = defaultGeminiBaseURL
|
||||
}
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if model == "" {
|
||||
model = defaultGeminiModel
|
||||
}
|
||||
maxTokens := config.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = defaultOpenAIMaxTokens
|
||||
}
|
||||
temperature := config.Temperature
|
||||
if temperature <= 0 {
|
||||
temperature = defaultOpenAITemperature
|
||||
}
|
||||
|
||||
normalized := config
|
||||
normalized.BaseURL = baseURL
|
||||
normalized.Model = model
|
||||
normalized.MaxTokens = maxTokens
|
||||
normalized.Temperature = temperature
|
||||
|
||||
return &GeminiProvider{
|
||||
config: normalized,
|
||||
baseURL: baseURL,
|
||||
client: &http.Client{Timeout: openAIHTTPTimeout},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) Name() string {
|
||||
if strings.TrimSpace(p.config.Name) != "" {
|
||||
return p.config.Name
|
||||
}
|
||||
return "Gemini"
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) Validate() error {
|
||||
if strings.TrimSpace(p.config.APIKey) == "" {
|
||||
return fmt.Errorf("API Key 不能为空")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type geminiRequest struct {
|
||||
Contents []geminiContent `json:"contents"`
|
||||
SystemInstruction *geminiContent `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig geminiGenConfig `json:"generationConfig,omitempty"`
|
||||
}
|
||||
|
||||
type geminiContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []geminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type geminiPart struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type geminiGenConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
}
|
||||
|
||||
type geminiResponse struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata *struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
geminiReq := p.buildRequest(req)
|
||||
|
||||
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
|
||||
p.baseURL, p.config.Model, p.config.APIKey)
|
||||
|
||||
respBody, err := p.doRequest(ctx, url, geminiReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer respBody.Close()
|
||||
|
||||
var result geminiResponse
|
||||
if err := json.NewDecoder(respBody).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析 Gemini 响应失败: %w", err)
|
||||
}
|
||||
if result.Error != nil && result.Error.Message != "" {
|
||||
return nil, fmt.Errorf("Gemini API 错误: %s", result.Error.Message)
|
||||
}
|
||||
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
|
||||
return nil, fmt.Errorf("Gemini 返回空响应")
|
||||
}
|
||||
|
||||
var tokens ai.TokenUsage
|
||||
if result.UsageMetadata != nil {
|
||||
tokens = ai.TokenUsage{
|
||||
PromptTokens: result.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: result.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: result.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
}
|
||||
|
||||
var textParts []string
|
||||
for _, part := range result.Candidates[0].Content.Parts {
|
||||
if part.Text != "" {
|
||||
textParts = append(textParts, part.Text)
|
||||
}
|
||||
}
|
||||
|
||||
return &ai.ChatResponse{
|
||||
Content: strings.Join(textParts, ""),
|
||||
TokensUsed: tokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
geminiReq := p.buildRequest(req)
|
||||
|
||||
url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
|
||||
p.baseURL, p.config.Model, p.config.APIKey)
|
||||
|
||||
respBody, err := p.doRequest(ctx, url, geminiReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer respBody.Close()
|
||||
|
||||
scanner := bufio.NewScanner(respBody)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var chunk geminiResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(chunk.Candidates) > 0 && len(chunk.Candidates[0].Content.Parts) > 0 {
|
||||
for _, part := range chunk.Candidates[0].Content.Parts {
|
||||
if part.Text != "" {
|
||||
callback(ai.StreamChunk{Content: part.Text})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) buildRequest(req ai.ChatRequest) geminiRequest {
|
||||
temperature := req.Temperature
|
||||
if temperature <= 0 {
|
||||
temperature = p.config.Temperature
|
||||
}
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = p.config.MaxTokens
|
||||
}
|
||||
|
||||
var systemInstruction *geminiContent
|
||||
var contents []geminiContent
|
||||
|
||||
for _, m := range req.Messages {
|
||||
if m.Role == "system" {
|
||||
systemInstruction = &geminiContent{
|
||||
Parts: []geminiPart{{Text: m.Content}},
|
||||
}
|
||||
continue
|
||||
}
|
||||
role := m.Role
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
contents = append(contents, geminiContent{
|
||||
Role: role,
|
||||
Parts: []geminiPart{{Text: m.Content}},
|
||||
})
|
||||
}
|
||||
|
||||
return geminiRequest{
|
||||
Contents: contents,
|
||||
SystemInstruction: systemInstruction,
|
||||
GenerationConfig: geminiGenConfig{
|
||||
Temperature: temperature,
|
||||
MaxOutputTokens: maxTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) doRequest(ctx context.Context, url string, body interface{}) (io.ReadCloser, error) {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求到 Gemini 失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("Gemini API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
316
internal/ai/provider/openai.go
Normal file
316
internal/ai/provider/openai.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultOpenAIBaseURL = "https://api.openai.com/v1"
|
||||
defaultOpenAIModel = "gpt-4o"
|
||||
defaultOpenAIMaxTokens = 4096
|
||||
defaultOpenAITemperature = 0.7
|
||||
openAIHTTPTimeout = 120 * time.Second
|
||||
)
|
||||
|
||||
// OpenAIProvider 实现 OpenAI / OpenAI 兼容 API 的 Provider
|
||||
type OpenAIProvider struct {
|
||||
config ai.ProviderConfig
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewOpenAIProvider 创建 OpenAI Provider 实例
|
||||
func NewOpenAIProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = defaultOpenAIBaseURL
|
||||
}
|
||||
// 确保 baseURL 包含 /v1 路径(兼容用户只填域名的情况,如 https://anyrouter.top)
|
||||
if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
|
||||
baseURL = baseURL + "/v1"
|
||||
}
|
||||
model := strings.TrimSpace(config.Model)
|
||||
if model == "" {
|
||||
model = defaultOpenAIModel
|
||||
}
|
||||
maxTokens := config.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = defaultOpenAIMaxTokens
|
||||
}
|
||||
temperature := config.Temperature
|
||||
if temperature <= 0 {
|
||||
temperature = defaultOpenAITemperature
|
||||
}
|
||||
|
||||
normalized := config
|
||||
normalized.BaseURL = baseURL
|
||||
normalized.Model = model
|
||||
normalized.MaxTokens = maxTokens
|
||||
normalized.Temperature = temperature
|
||||
|
||||
return &OpenAIProvider{
|
||||
config: normalized,
|
||||
baseURL: baseURL,
|
||||
client: &http.Client{
|
||||
Timeout: openAIHTTPTimeout,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Name() string {
|
||||
if strings.TrimSpace(p.config.Name) != "" {
|
||||
return p.config.Name
|
||||
}
|
||||
return "OpenAI"
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Validate() error {
|
||||
if strings.TrimSpace(p.config.APIKey) == "" {
|
||||
return fmt.Errorf("API Key 不能为空")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAIChatRequest OpenAI API 请求体
|
||||
type openAIChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// openAIChatResponse OpenAI API 响应体
|
||||
type openAIChatResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// openAIStreamChunk SSE 流式响应片段
|
||||
type openAIStreamChunk struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages := make([]openAIChatMessage, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
|
||||
}
|
||||
|
||||
temperature := req.Temperature
|
||||
if temperature <= 0 {
|
||||
temperature = p.config.Temperature
|
||||
}
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = p.config.MaxTokens
|
||||
}
|
||||
|
||||
body := openAIChatRequest{
|
||||
Model: p.config.Model,
|
||||
Messages: messages,
|
||||
Temperature: temperature,
|
||||
MaxTokens: maxTokens,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
respBody, err := p.doRequest(ctx, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer respBody.Close()
|
||||
|
||||
var result openAIChatResponse
|
||||
if err := json.NewDecoder(respBody).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析 OpenAI 响应失败: %w", err)
|
||||
}
|
||||
if result.Error != nil && result.Error.Message != "" {
|
||||
return nil, fmt.Errorf("OpenAI API 错误: %s", result.Error.Message)
|
||||
}
|
||||
if len(result.Choices) == 0 {
|
||||
return nil, fmt.Errorf("OpenAI 返回空响应")
|
||||
}
|
||||
|
||||
return &ai.ChatResponse{
|
||||
Content: result.Choices[0].Message.Content,
|
||||
TokensUsed: ai.TokenUsage{
|
||||
PromptTokens: result.Usage.PromptTokens,
|
||||
CompletionTokens: result.Usage.CompletionTokens,
|
||||
TotalTokens: result.Usage.TotalTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messages := make([]openAIChatMessage, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
|
||||
}
|
||||
|
||||
temperature := req.Temperature
|
||||
if temperature <= 0 {
|
||||
temperature = p.config.Temperature
|
||||
}
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = p.config.MaxTokens
|
||||
}
|
||||
|
||||
body := openAIChatRequest{
|
||||
Model: p.config.Model,
|
||||
Messages: messages,
|
||||
Temperature: temperature,
|
||||
MaxTokens: maxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
respBody, err := p.doRequest(ctx, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer respBody.Close()
|
||||
|
||||
receivedContent := false
|
||||
scanner := bufio.NewScanner(respBody)
|
||||
// 增大 scanner buffer,防止长行被截断
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
// 非 SSE 数据行,可能是错误信息,记录日志
|
||||
if strings.Contains(line, "error") || strings.Contains(line, "Error") {
|
||||
callback(ai.StreamChunk{Error: fmt.Sprintf("服务端返回异常: %s", line), Done: true})
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var chunk openAIStreamChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue // 跳过格式异常的行
|
||||
}
|
||||
if chunk.Error != nil && chunk.Error.Message != "" {
|
||||
callback(ai.StreamChunk{Error: fmt.Sprintf("API 错误: %s", chunk.Error.Message), Done: true})
|
||||
return nil
|
||||
}
|
||||
if len(chunk.Choices) > 0 {
|
||||
content := chunk.Choices[0].Delta.Content
|
||||
if content != "" {
|
||||
receivedContent = true
|
||||
callback(ai.StreamChunk{Content: content})
|
||||
}
|
||||
if chunk.Choices[0].FinishReason != nil {
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("读取 OpenAI 流式响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果流正常结束但没有收到任何内容,可能是 API 响应格式不兼容
|
||||
if !receivedContent {
|
||||
callback(ai.StreamChunk{Error: "未收到任何有效响应内容,请检查 API 端点和模型是否正确", Done: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
callback(ai.StreamChunk{Done: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) doRequest(ctx context.Context, body interface{}) (io.ReadCloser, error) {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := p.baseURL + "/chat/completions"
|
||||
|
||||
// 调试日志
|
||||
bodyStr := string(jsonBody)
|
||||
if len(bodyStr) > 500 {
|
||||
bodyStr = bodyStr[:500] + "..."
|
||||
}
|
||||
fmt.Printf("[OpenAI DEBUG] URL: %s\n", url)
|
||||
fmt.Printf("[OpenAI DEBUG] BaseURL: %s\n", p.baseURL)
|
||||
fmt.Printf("[OpenAI DEBUG] Body: %s\n", bodyStr)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||
|
||||
// 自定义 headers(用于兼容各类 OpenAI 兼容服务)
|
||||
for k, v := range p.config.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求到 %s 失败: %w", url, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("OpenAI API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
86
internal/ai/provider/openai_test.go
Normal file
86
internal/ai/provider/openai_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"GoNavi-Wails/internal/ai"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenAIProvider_Validate_MissingAPIKey(t *testing.T) {
|
||||
p, err := NewOpenAIProvider(ai.ProviderConfig{Type: "openai", Model: "gpt-4o"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected constructor error: %v", err)
|
||||
}
|
||||
if err := p.Validate(); err == nil {
|
||||
t.Fatal("expected validation error for missing API key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_Validate_Valid(t *testing.T) {
|
||||
p, err := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test-key", Model: "gpt-4o",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected constructor error: %v", err)
|
||||
}
|
||||
if err := p.Validate(); err != nil {
|
||||
t.Fatalf("unexpected validation error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_Name_Custom(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", Name: "My OpenAI", APIKey: "sk-test",
|
||||
})
|
||||
if p.Name() != "My OpenAI" {
|
||||
t.Fatalf("expected name 'My OpenAI', got '%s'", p.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_Name_Default(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test",
|
||||
})
|
||||
if p.Name() != "OpenAI" {
|
||||
t.Fatalf("expected default name 'OpenAI', got '%s'", p.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_DefaultBaseURL(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test", Model: "gpt-4o",
|
||||
})
|
||||
op := p.(*OpenAIProvider)
|
||||
if op.baseURL != "https://api.openai.com/v1" {
|
||||
t.Fatalf("expected default base URL, got '%s'", op.baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_CustomBaseURL(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test", BaseURL: "https://my-proxy.com/v1",
|
||||
})
|
||||
op := p.(*OpenAIProvider)
|
||||
if op.baseURL != "https://my-proxy.com/v1" {
|
||||
t.Fatalf("expected custom base URL, got '%s'", op.baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_DefaultModel(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test",
|
||||
})
|
||||
op := p.(*OpenAIProvider)
|
||||
if op.config.Model != "gpt-4o" {
|
||||
t.Fatalf("expected default model 'gpt-4o', got '%s'", op.config.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIProvider_DefaultMaxTokens(t *testing.T) {
|
||||
p, _ := NewOpenAIProvider(ai.ProviderConfig{
|
||||
Type: "openai", APIKey: "sk-test",
|
||||
})
|
||||
op := p.(*OpenAIProvider)
|
||||
if op.config.MaxTokens != 4096 {
|
||||
t.Fatalf("expected default max tokens 4096, got %d", op.config.MaxTokens)
|
||||
}
|
||||
}
|
||||
19
internal/ai/provider/provider.go
Normal file
19
internal/ai/provider/provider.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
// Provider AI 模型提供者接口
|
||||
type Provider interface {
|
||||
// Chat 发送消息并获取完整响应
|
||||
Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error)
|
||||
// ChatStream 发送消息并以流式返回
|
||||
ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error
|
||||
// Name 返回 Provider 名称
|
||||
Name() string
|
||||
// Validate 校验配置是否有效
|
||||
Validate() error
|
||||
}
|
||||
25
internal/ai/provider/registry.go
Normal file
25
internal/ai/provider/registry.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
// NewProvider 根据配置创建 Provider 实例
|
||||
func NewProvider(config ai.ProviderConfig) (Provider, error) {
|
||||
providerType := strings.ToLower(strings.TrimSpace(config.Type))
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return NewOpenAIProvider(config)
|
||||
case "anthropic":
|
||||
return NewAnthropicProvider(config)
|
||||
case "gemini":
|
||||
return NewGeminiProvider(config)
|
||||
case "custom":
|
||||
return NewCustomProvider(config)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的 AI Provider 类型: %s", config.Type)
|
||||
}
|
||||
}
|
||||
101
internal/ai/safety/classifier.go
Normal file
101
internal/ai/safety/classifier.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package safety
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
// ClassifySQL 分类 SQL 语句的操作类型
|
||||
func ClassifySQL(sql string) ai.SQLOperationType {
|
||||
keyword := leadingSQLKeyword(sql)
|
||||
switch keyword {
|
||||
case "select", "with", "show", "describe", "desc", "explain", "pragma", "values":
|
||||
return ai.SQLOpQuery
|
||||
case "insert", "update", "delete", "replace", "merge", "upsert":
|
||||
return ai.SQLOpDML
|
||||
case "create", "alter", "drop", "truncate", "rename":
|
||||
return ai.SQLOpDDL
|
||||
default:
|
||||
return ai.SQLOpOther
|
||||
}
|
||||
}
|
||||
|
||||
// IsHighRiskSQL 判断 SQL 是否为高风险语句
|
||||
func IsHighRiskSQL(sql string) (bool, string) {
|
||||
keyword := leadingSQLKeyword(sql)
|
||||
normalized := strings.ToLower(sql)
|
||||
|
||||
switch keyword {
|
||||
case "drop":
|
||||
return true, "⚠️ 高危操作:DROP 语句将永久删除数据库对象"
|
||||
case "truncate":
|
||||
return true, "⚠️ 高危操作:TRUNCATE 将清空表中所有数据"
|
||||
case "delete":
|
||||
if !containsWhereClause(normalized) {
|
||||
return true, "⚠️ 高危操作:DELETE 语句缺少 WHERE 条件,将删除所有数据"
|
||||
}
|
||||
case "update":
|
||||
if !containsWhereClause(normalized) {
|
||||
return true, "⚠️ 高危操作:UPDATE 语句缺少 WHERE 条件,将更新所有记录"
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// containsWhereClause 简单判断 SQL 是否包含 WHERE 子句
|
||||
func containsWhereClause(normalizedSQL string) bool {
|
||||
return strings.Contains(normalizedSQL, " where ") ||
|
||||
strings.Contains(normalizedSQL, "\nwhere ") ||
|
||||
strings.Contains(normalizedSQL, "\twhere ")
|
||||
}
|
||||
|
||||
// leadingSQLKeyword 提取 SQL 语句的首个关键字(跳过注释和空白)
|
||||
func leadingSQLKeyword(query string) string {
|
||||
text := strings.TrimSpace(query)
|
||||
for len(text) > 0 {
|
||||
trimmed := strings.TrimLeft(text, " \t\r\n")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
text = trimmed
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(text, "--"):
|
||||
if idx := strings.IndexByte(text, '\n'); idx >= 0 {
|
||||
text = text[idx+1:]
|
||||
continue
|
||||
}
|
||||
return ""
|
||||
case strings.HasPrefix(text, "#"):
|
||||
if idx := strings.IndexByte(text, '\n'); idx >= 0 {
|
||||
text = text[idx+1:]
|
||||
continue
|
||||
}
|
||||
return ""
|
||||
case strings.HasPrefix(text, "/*"):
|
||||
if idx := strings.Index(text, "*/"); idx >= 0 {
|
||||
text = text[idx+2:]
|
||||
continue
|
||||
}
|
||||
return ""
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
for i, r := range text {
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' {
|
||||
continue
|
||||
}
|
||||
if i == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(text[:i])
|
||||
}
|
||||
return strings.ToLower(text)
|
||||
}
|
||||
145
internal/ai/safety/classifier_test.go
Normal file
145
internal/ai/safety/classifier_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package safety
|
||||
|
||||
import (
|
||||
"GoNavi-Wails/internal/ai"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClassifySQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
want ai.SQLOperationType
|
||||
}{
|
||||
{"SELECT * FROM users", ai.SQLOpQuery},
|
||||
{" select id from t", ai.SQLOpQuery},
|
||||
{"SHOW TABLES", ai.SQLOpQuery},
|
||||
{"DESCRIBE users", ai.SQLOpQuery},
|
||||
{"DESC users", ai.SQLOpQuery},
|
||||
{"EXPLAIN SELECT 1", ai.SQLOpQuery},
|
||||
{"WITH cte AS (SELECT 1) SELECT * FROM cte", ai.SQLOpQuery},
|
||||
{"PRAGMA table_info(t)", ai.SQLOpQuery},
|
||||
{"VALUES (1, 2)", ai.SQLOpQuery},
|
||||
{"INSERT INTO users VALUES (1)", ai.SQLOpDML},
|
||||
{"UPDATE users SET name='x'", ai.SQLOpDML},
|
||||
{"DELETE FROM users WHERE id=1", ai.SQLOpDML},
|
||||
{"REPLACE INTO users VALUES (1)", ai.SQLOpDML},
|
||||
{"MERGE INTO t USING s ON t.id=s.id", ai.SQLOpDML},
|
||||
{"CREATE TABLE t (id INT)", ai.SQLOpDDL},
|
||||
{"ALTER TABLE t ADD col INT", ai.SQLOpDDL},
|
||||
{"DROP TABLE t", ai.SQLOpDDL},
|
||||
{"TRUNCATE TABLE t", ai.SQLOpDDL},
|
||||
{"RENAME TABLE old TO new", ai.SQLOpDDL},
|
||||
{"/* comment */ SELECT 1", ai.SQLOpQuery},
|
||||
{"-- comment\nDELETE FROM t", ai.SQLOpDML},
|
||||
{"-- line1\n-- line2\nSELECT 1", ai.SQLOpQuery},
|
||||
{"/* block */ -- line\nUPDATE t SET x=1", ai.SQLOpDML},
|
||||
{"", ai.SQLOpOther},
|
||||
{" ", ai.SQLOpOther},
|
||||
{"-- only comment", ai.SQLOpOther},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := ClassifySQL(tt.sql)
|
||||
if got != tt.want {
|
||||
t.Errorf("ClassifySQL(%q) = %s, want %s", tt.sql, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsHighRiskSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
highRisk bool
|
||||
}{
|
||||
{"DROP TABLE users", true},
|
||||
{"DROP DATABASE test", true},
|
||||
{"TRUNCATE TABLE users", true},
|
||||
{"DELETE FROM users", true}, // 无 WHERE
|
||||
{"DELETE FROM users WHERE id=1", false}, // 有 WHERE
|
||||
{"UPDATE users SET name='x'", true}, // 无 WHERE
|
||||
{"UPDATE users SET name='x' WHERE id=1", false}, // 有 WHERE
|
||||
{"SELECT * FROM users", false},
|
||||
{"INSERT INTO users VALUES (1)", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
highRisk, _ := IsHighRiskSQL(tt.sql)
|
||||
if highRisk != tt.highRisk {
|
||||
t.Errorf("IsHighRiskSQL(%q) = %v, want %v", tt.sql, highRisk, tt.highRisk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ReadOnly(t *testing.T) {
|
||||
g := NewGuard(ai.PermissionReadOnly)
|
||||
tests := []struct {
|
||||
sql string
|
||||
allowed bool
|
||||
}{
|
||||
{"SELECT * FROM t", true},
|
||||
{"INSERT INTO t VALUES (1)", false},
|
||||
{"UPDATE t SET x=1", false},
|
||||
{"DELETE FROM t", false},
|
||||
{"DROP TABLE t", false},
|
||||
{"CREATE TABLE t (id INT)", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := g.Check(tt.sql)
|
||||
if result.Allowed != tt.allowed {
|
||||
t.Errorf("Guard[readonly].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_ReadWrite(t *testing.T) {
|
||||
g := NewGuard(ai.PermissionReadWrite)
|
||||
tests := []struct {
|
||||
sql string
|
||||
allowed bool
|
||||
confirm bool
|
||||
}{
|
||||
{"SELECT * FROM t", true, false},
|
||||
{"INSERT INTO t VALUES (1)", true, true},
|
||||
{"UPDATE t SET x=1", true, true}, // 允许但需确认
|
||||
{"DELETE FROM t WHERE id=1", true, true}, // 允许但需确认
|
||||
{"DROP TABLE t", false, true}, // DDL 不允许
|
||||
{"CREATE TABLE t (id INT)", false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := g.Check(tt.sql)
|
||||
if result.Allowed != tt.allowed {
|
||||
t.Errorf("Guard[readwrite].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
|
||||
}
|
||||
if result.RequiresConfirm != tt.confirm {
|
||||
t.Errorf("Guard[readwrite].Check(%q).RequiresConfirm = %v, want %v", tt.sql, result.RequiresConfirm, tt.confirm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_Full(t *testing.T) {
|
||||
g := NewGuard(ai.PermissionFull)
|
||||
tests := []struct {
|
||||
sql string
|
||||
allowed bool
|
||||
}{
|
||||
{"SELECT * FROM t", true},
|
||||
{"INSERT INTO t VALUES (1)", true},
|
||||
{"DROP TABLE t", true},
|
||||
{"CREATE TABLE t (id INT)", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := g.Check(tt.sql)
|
||||
if result.Allowed != tt.allowed {
|
||||
t.Errorf("Guard[full].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGuard_HighRiskWarning(t *testing.T) {
|
||||
g := NewGuard(ai.PermissionFull)
|
||||
result := g.Check("DELETE FROM users")
|
||||
if result.WarningMessage == "" {
|
||||
t.Error("expected high-risk warning for DELETE without WHERE")
|
||||
}
|
||||
if !result.RequiresConfirm {
|
||||
t.Error("expected RequiresConfirm for high-risk SQL")
|
||||
}
|
||||
}
|
||||
71
internal/ai/safety/guard.go
Normal file
71
internal/ai/safety/guard.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package safety
|
||||
|
||||
import (
|
||||
"GoNavi-Wails/internal/ai"
|
||||
)
|
||||
|
||||
// Guard AI SQL 安全策略守卫
|
||||
type Guard struct {
|
||||
permissionLevel ai.SQLPermissionLevel
|
||||
}
|
||||
|
||||
// NewGuard 创建安全策略守卫
|
||||
func NewGuard(level ai.SQLPermissionLevel) *Guard {
|
||||
return &Guard{permissionLevel: level}
|
||||
}
|
||||
|
||||
// SetPermissionLevel 设置权限级别
|
||||
func (g *Guard) SetPermissionLevel(level ai.SQLPermissionLevel) {
|
||||
g.permissionLevel = level
|
||||
}
|
||||
|
||||
// GetPermissionLevel 获取当前权限级别
|
||||
func (g *Guard) GetPermissionLevel() ai.SQLPermissionLevel {
|
||||
return g.permissionLevel
|
||||
}
|
||||
|
||||
// Check 检查 AI 生成的 SQL 是否在允许范围内
|
||||
func (g *Guard) Check(sql string) ai.SafetyResult {
|
||||
opType := ClassifySQL(sql)
|
||||
allowed := g.isAllowed(opType)
|
||||
requiresConfirm := g.requiresConfirmation(opType)
|
||||
warningMessage := ""
|
||||
|
||||
if isHighRisk, msg := IsHighRiskSQL(sql); isHighRisk {
|
||||
warningMessage = msg
|
||||
requiresConfirm = true
|
||||
}
|
||||
|
||||
return ai.SafetyResult{
|
||||
Allowed: allowed,
|
||||
OperationType: opType,
|
||||
RequiresConfirm: requiresConfirm,
|
||||
WarningMessage: warningMessage,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) isAllowed(opType ai.SQLOperationType) bool {
|
||||
switch g.permissionLevel {
|
||||
case ai.PermissionReadOnly:
|
||||
return opType == ai.SQLOpQuery
|
||||
case ai.PermissionReadWrite:
|
||||
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
|
||||
case ai.PermissionFull:
|
||||
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
|
||||
default:
|
||||
return opType == ai.SQLOpQuery
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) requiresConfirmation(opType ai.SQLOperationType) bool {
|
||||
switch opType {
|
||||
case ai.SQLOpQuery:
|
||||
return false
|
||||
case ai.SQLOpDML:
|
||||
return true // DML 始终需要确认
|
||||
case ai.SQLOpDDL:
|
||||
return true // DDL 始终需要确认
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
573
internal/ai/service/service.go
Normal file
573
internal/ai/service/service.go
Normal file
@@ -0,0 +1,573 @@
|
||||
package aiservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"GoNavi-Wails/internal/ai"
|
||||
aicontext "GoNavi-Wails/internal/ai/context"
|
||||
"GoNavi-Wails/internal/ai/provider"
|
||||
"GoNavi-Wails/internal/ai/safety"
|
||||
"GoNavi-Wails/internal/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
wailsRuntime "github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
|
||||
// Service AI 服务,作为 Wails Binding 暴露给前端
|
||||
type Service struct {
|
||||
ctx context.Context
|
||||
mu sync.RWMutex
|
||||
providers []ai.ProviderConfig
|
||||
activeProvider string // active provider ID
|
||||
safetyLevel ai.SQLPermissionLevel
|
||||
contextLevel ai.ContextLevel
|
||||
guard *safety.Guard
|
||||
configDir string // 配置存储目录
|
||||
cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
|
||||
}
|
||||
|
||||
// NewService 创建 AI Service 实例
|
||||
func NewService() *Service {
|
||||
return &Service{
|
||||
providers: make([]ai.ProviderConfig, 0),
|
||||
safetyLevel: ai.PermissionReadOnly,
|
||||
contextLevel: ai.ContextSchemaOnly,
|
||||
guard: safety.NewGuard(ai.PermissionReadOnly),
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// Startup Wails 生命周期回调
|
||||
func (s *Service) Startup(ctx context.Context) {
|
||||
s.ctx = ctx
|
||||
s.configDir = resolveConfigDir()
|
||||
s.loadConfig()
|
||||
logger.Infof("AI Service 启动完成,已加载 %d 个 Provider", len(s.providers))
|
||||
}
|
||||
|
||||
// --- Provider 管理 ---
|
||||
|
||||
// AIGetProviders 获取所有 Provider 配置
|
||||
func (s *Service) AIGetProviders() []ai.ProviderConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
result := make([]ai.ProviderConfig, len(s.providers))
|
||||
copy(result, s.providers)
|
||||
return result
|
||||
}
|
||||
|
||||
// AISaveProvider 保存/更新 Provider 配置
|
||||
func (s *Service) AISaveProvider(config ai.ProviderConfig) error {
|
||||
fmt.Printf("[AISaveProvider DEBUG] ID: %s, Model: %s\n", config.ID, config.Model)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if strings.TrimSpace(config.ID) == "" {
|
||||
config.ID = "provider-" + uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
found := false
|
||||
for i, p := range s.providers {
|
||||
if p.ID == config.ID {
|
||||
s.providers[i] = config
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
s.providers = append(s.providers, config)
|
||||
}
|
||||
|
||||
return s.saveConfig()
|
||||
}
|
||||
|
||||
// AIDeleteProvider 删除 Provider
|
||||
func (s *Service) AIDeleteProvider(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
newProviders := make([]ai.ProviderConfig, 0, len(s.providers))
|
||||
for _, p := range s.providers {
|
||||
if p.ID != id {
|
||||
newProviders = append(newProviders, p)
|
||||
}
|
||||
}
|
||||
s.providers = newProviders
|
||||
|
||||
if s.activeProvider == id {
|
||||
s.activeProvider = ""
|
||||
if len(s.providers) > 0 {
|
||||
s.activeProvider = s.providers[0].ID
|
||||
}
|
||||
}
|
||||
|
||||
return s.saveConfig()
|
||||
}
|
||||
|
||||
// AITestProvider 测试 Provider 配置是否可用
|
||||
func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{} {
|
||||
// 如果传入脱敏的 key,使用已保存的 key
|
||||
s.mu.RLock()
|
||||
if isMaskedAPIKey(config.APIKey) {
|
||||
for _, p := range s.providers {
|
||||
if p.ID == config.ID {
|
||||
config.APIKey = p.APIKey
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
p, err := provider.NewProvider(config)
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "message": err.Error()}
|
||||
}
|
||||
if err := p.Validate(); err != nil {
|
||||
return map[string]interface{}{"success": false, "message": err.Error()}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*1000*1000*1000) // 30s
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, ai.ChatRequest{
|
||||
Messages: []ai.Message{
|
||||
{Role: "user", Content: "Hi, please respond with 'OK' to confirm the connection is working."},
|
||||
},
|
||||
MaxTokens: 10,
|
||||
})
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("连接成功!模型响应: %s", truncateString(resp.Content, 100)),
|
||||
}
|
||||
}
|
||||
|
||||
// AISetActiveProvider 设置活动 Provider
|
||||
func (s *Service) AISetActiveProvider(id string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.activeProvider = id
|
||||
_ = s.saveConfig()
|
||||
}
|
||||
|
||||
// AIGetActiveProvider 获取活动 Provider ID
|
||||
func (s *Service) AIGetActiveProvider() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.activeProvider
|
||||
}
|
||||
|
||||
// AIGetBuiltinPrompts 返回内部置的各类系统提示词,用于前端展示或查询
|
||||
func (s *Service) AIGetBuiltinPrompts() map[string]string {
|
||||
return aicontext.GetBuiltinPrompts()
|
||||
}
|
||||
|
||||
// AIListModels 获取当前活跃 Provider 的可用模型列表
|
||||
func (s *Service) AIListModels() map[string]interface{} {
|
||||
s.mu.RLock()
|
||||
var config ai.ProviderConfig
|
||||
found := false
|
||||
for _, p := range s.providers {
|
||||
if p.ID == s.activeProvider {
|
||||
config = p
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !found {
|
||||
return map[string]interface{}{"success": false, "models": []string{}, "error": "未找到活跃 Provider"}
|
||||
}
|
||||
|
||||
models, err := fetchModels(config)
|
||||
if err != nil {
|
||||
// 回退到配置中的静态模型列表
|
||||
if len(config.Models) > 0 {
|
||||
return map[string]interface{}{"success": true, "models": config.Models, "source": "static"}
|
||||
}
|
||||
return map[string]interface{}{"success": false, "models": []string{}, "error": err.Error()}
|
||||
}
|
||||
|
||||
return map[string]interface{}{"success": true, "models": models, "source": "api"}
|
||||
}
|
||||
|
||||
// fetchModels 从供应商 API 获取可用模型列表
|
||||
func fetchModels(config ai.ProviderConfig) ([]string, error) {
|
||||
providerType := config.Type
|
||||
if providerType == "custom" && config.APIFormat != "" {
|
||||
providerType = config.APIFormat
|
||||
}
|
||||
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return fetchOpenAIModels(config)
|
||||
case "anthropic":
|
||||
// Anthropic 没有公开的 /models 端点,返回硬编码列表
|
||||
return []string{"claude-opus-4-6", "claude-sonnet-4-6"}, nil
|
||||
case "gemini":
|
||||
return fetchGeminiModels(config)
|
||||
default:
|
||||
return fetchOpenAIModels(config)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchOpenAIModels 获取 OpenAI 兼容 API 的模型列表
|
||||
func fetchOpenAIModels(config ai.ProviderConfig) ([]string, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
// 确保 baseURL 以 /v1 结尾
|
||||
if !strings.HasSuffix(baseURL, "/v1") {
|
||||
baseURL = baseURL + "/v1"
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", baseURL+"/models", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求模型列表失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析模型列表失败: %w", err)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(result.Data))
|
||||
for _, m := range result.Data {
|
||||
models = append(models, m.ID)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// fetchGeminiModels 获取 Gemini API 的模型列表
|
||||
func fetchGeminiModels(config ai.ProviderConfig) ([]string, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", baseURL+"/v1beta/models?key="+config.APIKey, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求模型列表失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Models []struct {
|
||||
Name string `json:"name"` // e.g. "models/gemini-2.5-flash"
|
||||
} `json:"models"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析模型列表失败: %w", err)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(result.Models))
|
||||
for _, m := range result.Models {
|
||||
// 去掉 "models/" 前缀
|
||||
name := m.Name
|
||||
if strings.HasPrefix(name, "models/") {
|
||||
name = strings.TrimPrefix(name, "models/")
|
||||
}
|
||||
models = append(models, name)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// --- 安全控制 ---
|
||||
|
||||
// AIGetSafetyLevel 获取当前安全级别
|
||||
func (s *Service) AIGetSafetyLevel() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return string(s.safetyLevel)
|
||||
}
|
||||
|
||||
// AISetSafetyLevel 设置安全级别
|
||||
func (s *Service) AISetSafetyLevel(level string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
switch ai.SQLPermissionLevel(level) {
|
||||
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
|
||||
s.safetyLevel = ai.SQLPermissionLevel(level)
|
||||
default:
|
||||
s.safetyLevel = ai.PermissionReadOnly
|
||||
}
|
||||
s.guard.SetPermissionLevel(s.safetyLevel)
|
||||
_ = s.saveConfig()
|
||||
}
|
||||
|
||||
// --- 上下文控制 ---
|
||||
|
||||
// AIGetContextLevel 获取上下文传递级别
|
||||
func (s *Service) AIGetContextLevel() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return string(s.contextLevel)
|
||||
}
|
||||
|
||||
// AISetContextLevel 设置上下文传递级别
|
||||
func (s *Service) AISetContextLevel(level string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
switch ai.ContextLevel(level) {
|
||||
case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults:
|
||||
s.contextLevel = ai.ContextLevel(level)
|
||||
default:
|
||||
s.contextLevel = ai.ContextSchemaOnly
|
||||
}
|
||||
_ = s.saveConfig()
|
||||
}
|
||||
|
||||
// --- AI 对话 ---
|
||||
|
||||
// AIChatSend 同步发送 AI 对话(非流式)
|
||||
func (s *Service) AIChatSend(messages []map[string]string) map[string]interface{} {
|
||||
p, err := s.getActiveProvider()
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "error": err.Error()}
|
||||
}
|
||||
|
||||
var aiMessages []ai.Message
|
||||
for _, m := range messages {
|
||||
aiMessages = append(aiMessages, ai.Message{Role: m["role"], Content: m["content"]})
|
||||
}
|
||||
|
||||
resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: aiMessages})
|
||||
if err != nil {
|
||||
return map[string]interface{}{"success": false, "error": err.Error()}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"success": true,
|
||||
"content": resp.Content,
|
||||
"tokensUsed": map[string]int{
|
||||
"promptTokens": resp.TokensUsed.PromptTokens,
|
||||
"completionTokens": resp.TokensUsed.CompletionTokens,
|
||||
"totalTokens": resp.TokensUsed.TotalTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// AIChatStream 流式发送 AI 对话(通过 EventsEmit 推送)
|
||||
func (s *Service) AIChatStream(sessionID string, messages []map[string]string) {
|
||||
streamCtx, cancel := context.WithCancel(context.Background())
|
||||
s.mu.Lock()
|
||||
s.cancelFuncs[sessionID] = cancel
|
||||
s.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
delete(s.cancelFuncs, sessionID)
|
||||
s.mu.Unlock()
|
||||
cancel() // 确保释放
|
||||
}()
|
||||
|
||||
p, err := s.getActiveProvider()
|
||||
if err != nil {
|
||||
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"done": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var aiMessages []ai.Message
|
||||
for _, m := range messages {
|
||||
aiMessages = append(aiMessages, ai.Message{Role: m["role"], Content: m["content"]})
|
||||
}
|
||||
|
||||
err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: aiMessages}, func(chunk ai.StreamChunk) {
|
||||
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
|
||||
"content": chunk.Content,
|
||||
"done": chunk.Done,
|
||||
"error": chunk.Error,
|
||||
})
|
||||
})
|
||||
|
||||
// 当 context 被主动 cancel 的时候,不把这个视为向外抛的 error
|
||||
if err != nil && err != context.Canceled {
|
||||
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"done": true,
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// AIChatCancel 立即终止某个 Session 的流式对话请求
|
||||
func (s *Service) AIChatCancel(sessionID string) {
|
||||
s.mu.RLock()
|
||||
cancel, ok := s.cancelFuncs[sessionID]
|
||||
s.mu.RUnlock()
|
||||
if ok && cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// AICheckSQL 检查 SQL 的安全性
|
||||
func (s *Service) AICheckSQL(sql string) ai.SafetyResult {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.guard.Check(sql)
|
||||
}
|
||||
|
||||
// --- 内部方法 ---
|
||||
|
||||
func (s *Service) getActiveProvider() (provider.Provider, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.activeProvider == "" && len(s.providers) > 0 {
|
||||
s.activeProvider = s.providers[0].ID
|
||||
}
|
||||
|
||||
for _, cfg := range s.providers {
|
||||
if cfg.ID == s.activeProvider {
|
||||
return provider.NewProvider(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("未配置 AI Provider,请先在设置中配置")
|
||||
}
|
||||
|
||||
// --- 配置持久化 ---
|
||||
|
||||
type aiConfig struct {
|
||||
Providers []ai.ProviderConfig `json:"providers"`
|
||||
ActiveProvider string `json:"activeProvider"`
|
||||
SafetyLevel string `json:"safetyLevel"`
|
||||
ContextLevel string `json:"contextLevel"`
|
||||
}
|
||||
|
||||
func (s *Service) loadConfig() {
|
||||
path := filepath.Join(s.configDir, "ai_config.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return // 首次启动,无配置文件
|
||||
}
|
||||
|
||||
var cfg aiConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
logger.Error(err, "加载 AI 配置失败")
|
||||
return
|
||||
}
|
||||
|
||||
s.providers = cfg.Providers
|
||||
if s.providers == nil {
|
||||
s.providers = make([]ai.ProviderConfig, 0)
|
||||
}
|
||||
s.activeProvider = cfg.ActiveProvider
|
||||
|
||||
switch ai.SQLPermissionLevel(cfg.SafetyLevel) {
|
||||
case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
|
||||
s.safetyLevel = ai.SQLPermissionLevel(cfg.SafetyLevel)
|
||||
default:
|
||||
s.safetyLevel = ai.PermissionReadOnly
|
||||
}
|
||||
s.guard.SetPermissionLevel(s.safetyLevel)
|
||||
|
||||
switch ai.ContextLevel(cfg.ContextLevel) {
|
||||
case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults:
|
||||
s.contextLevel = ai.ContextLevel(cfg.ContextLevel)
|
||||
default:
|
||||
s.contextLevel = ai.ContextSchemaOnly
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) saveConfig() error {
|
||||
cfg := aiConfig{
|
||||
Providers: s.providers,
|
||||
ActiveProvider: s.activeProvider,
|
||||
SafetyLevel: string(s.safetyLevel),
|
||||
ContextLevel: string(s.contextLevel),
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化 AI 配置失败: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(s.configDir, 0o755); err != nil {
|
||||
return fmt.Errorf("创建配置目录失败: %w", err)
|
||||
}
|
||||
|
||||
path := filepath.Join(s.configDir, "ai_config.json")
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return fmt.Errorf("写入 AI 配置失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 工具函数 ---
|
||||
|
||||
func resolveConfigDir() string {
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
configDir = "."
|
||||
}
|
||||
return filepath.Join(configDir, "GoNavi")
|
||||
}
|
||||
|
||||
func maskAPIKey(apiKey string) string {
|
||||
if len(apiKey) <= 8 {
|
||||
return "****"
|
||||
}
|
||||
return apiKey[:4] + "****" + apiKey[len(apiKey)-4:]
|
||||
}
|
||||
|
||||
func isMaskedAPIKey(apiKey string) bool {
|
||||
return strings.Contains(apiKey, "****")
|
||||
}
|
||||
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
85
internal/ai/types.go
Normal file
85
internal/ai/types.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package ai
|
||||
|
||||
// Message 表示一条对话消息
|
||||
type Message struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant"
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ChatRequest AI 对话请求
|
||||
type ChatRequest struct {
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
}
|
||||
|
||||
// ChatResponse AI 对话响应
|
||||
type ChatResponse struct {
|
||||
Content string `json:"content"`
|
||||
TokensUsed TokenUsage `json:"tokensUsed"`
|
||||
}
|
||||
|
||||
// TokenUsage token 用量统计
|
||||
type TokenUsage struct {
|
||||
PromptTokens int `json:"promptTokens"`
|
||||
CompletionTokens int `json:"completionTokens"`
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
}
|
||||
|
||||
// StreamChunk 流式响应片段
|
||||
type StreamChunk struct {
|
||||
Content string `json:"content"`
|
||||
Done bool `json:"done"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ProviderConfig AI Provider 配置
|
||||
type ProviderConfig struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // openai | anthropic | gemini | custom
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"apiKey"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
Model string `json:"model"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
APIFormat string `json:"apiFormat,omitempty"` // custom 专用: openai | anthropic | gemini
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
}
|
||||
|
||||
// SQLPermissionLevel AI SQL 执行权限级别
|
||||
type SQLPermissionLevel string
|
||||
|
||||
const (
|
||||
PermissionReadOnly SQLPermissionLevel = "readonly"
|
||||
PermissionReadWrite SQLPermissionLevel = "readwrite"
|
||||
PermissionFull SQLPermissionLevel = "full"
|
||||
)
|
||||
|
||||
// ContextLevel AI 上下文传递级别
|
||||
type ContextLevel string
|
||||
|
||||
const (
|
||||
ContextSchemaOnly ContextLevel = "schema_only"
|
||||
ContextWithSamples ContextLevel = "with_samples"
|
||||
ContextWithResults ContextLevel = "with_results"
|
||||
)
|
||||
|
||||
// SQLOperationType SQL 操作类型
|
||||
type SQLOperationType string
|
||||
|
||||
const (
|
||||
SQLOpQuery SQLOperationType = "query" // SELECT, SHOW, DESCRIBE, EXPLAIN
|
||||
SQLOpDML SQLOperationType = "dml" // INSERT, UPDATE, DELETE
|
||||
SQLOpDDL SQLOperationType = "ddl" // CREATE, ALTER, DROP, TRUNCATE
|
||||
SQLOpOther SQLOperationType = "other"
|
||||
)
|
||||
|
||||
// SafetyResult 安全检查结果
|
||||
type SafetyResult struct {
|
||||
Allowed bool `json:"allowed"`
|
||||
OperationType SQLOperationType `json:"operationType"`
|
||||
RequiresConfirm bool `json:"requiresConfirm"`
|
||||
WarningMessage string `json:"warningMessage,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user