-
}
- onClick={() => {
- setRenameViewTarget(null); // Create mode
- createTagForm.resetFields();
- setIsCreateTagModalOpen(true);
- }}
- style={{ flex: '1 1 auto' }}
- >
- 新建组
-
-
}
- onClick={() => openBatchOperationModal()}
- style={{ flex: '1 1 auto' }}
- >
- 批量操作表
-
-
}
- onClick={() => openBatchDatabaseModal()}
- style={{ flex: '1 1 auto' }}
- >
- 批量操作库
-
-
}
- onClick={handleOpenSQLFileFromToolbar}
- style={{ flex: '1 1 auto' }}
- >
- 运行外部SQL文件
-
+
+
+ } onClick={() => { setRenameViewTarget(null); createTagForm.resetFields(); setIsCreateTagModalOpen(true); }} style={{ color: darkMode ? 'rgba(255,255,255,0.65)' : 'rgba(0,0,0,0.65)' }} />
+
+
+ } onClick={() => openBatchOperationModal()} style={{ color: darkMode ? 'rgba(255,255,255,0.65)' : 'rgba(0,0,0,0.65)' }} />
+
+
+ } onClick={() => openBatchDatabaseModal()} style={{ color: darkMode ? 'rgba(255,255,255,0.65)' : 'rgba(0,0,0,0.65)' }} />
+
+
+ } onClick={handleOpenSQLFileFromToolbar} style={{ color: darkMode ? 'rgba(255,255,255,0.65)' : 'rgba(0,0,0,0.65)' }} />
+
diff --git a/frontend/src/components/TableOverview.tsx b/frontend/src/components/TableOverview.tsx
index 4a93783..f1d641d 100644
--- a/frontend/src/components/TableOverview.tsx
+++ b/frontend/src/components/TableOverview.tsx
@@ -138,6 +138,7 @@ const TableOverview: React.FC
= ({ tab }) => {
const connections = useStore(state => state.connections);
const theme = useStore(state => state.theme);
const addTab = useStore(state => state.addTab);
+ const setActiveContext = useStore(state => state.setActiveContext);
const darkMode = theme === 'dark';
const [tables, setTables] = useState([]);
@@ -195,6 +196,7 @@ const TableOverview: React.FC = ({ tab }) => {
const openTable = useCallback((tableName: string) => {
if (!connection) return;
+ setActiveContext({ connectionId: connection.id, dbName: tab.dbName || '' });
addTab({
id: `${connection.id}-${tab.dbName}-${tableName}`,
title: tableName,
@@ -203,10 +205,11 @@ const TableOverview: React.FC = ({ tab }) => {
dbName: tab.dbName,
tableName,
});
- }, [connection, tab.dbName, addTab]);
+ }, [connection, tab.dbName, addTab, setActiveContext]);
const openDesign = useCallback((tableName: string) => {
if (!connection) return;
+ setActiveContext({ connectionId: connection.id, dbName: tab.dbName || '' });
addTab({
id: `design-${connection.id}-${tab.dbName}-${tableName}`,
title: `设计表 (${tableName})`,
@@ -217,7 +220,7 @@ const TableOverview: React.FC = ({ tab }) => {
initialTab: 'columns',
readOnly: false,
});
- }, [connection, tab.dbName, addTab]);
+ }, [connection, tab.dbName, addTab, setActiveContext]);
const buildConfig = useCallback(() => {
if (!connection) return null;
@@ -383,6 +386,7 @@ const TableOverview: React.FC = ({ tab }) => {
menu={{
items: [
{ key: 'new-query', label: '新建查询', icon: , onClick: () => {
+ setActiveContext({ connectionId: tab.connectionId, dbName: tab.dbName || '' });
addTab({
id: `query-${Date.now()}`,
title: '新建查询',
diff --git a/frontend/src/store.ts b/frontend/src/store.ts
index 6a87182..e671270 100644
--- a/frontend/src/store.ts
+++ b/frontend/src/store.ts
@@ -1,6 +1,6 @@
import { create } from 'zustand';
import { persist } from 'zustand/middleware';
-import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag } from './types';
+import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag, AIChatMessage } from './types';
import {
ShortcutAction,
ShortcutBinding,
@@ -424,6 +424,12 @@ interface AppState {
windowState: 'normal' | 'fullscreen' | 'maximized';
sidebarWidth: number;
+ // AI 运行时与持久化状态
+ aiPanelVisible: boolean;
+ aiChatHistory: Record; // sessionId -> messages
+ aiChatSessions: { id: string; title: string; updatedAt: number }[]; // 历史会话列表
+ aiActiveSessionId: string | null;
+
addConnection: (conn: SavedConnection) => void;
updateConnection: (conn: SavedConnection) => void;
removeConnection: (id: string) => void;
@@ -475,6 +481,18 @@ interface AppState {
setWindowBounds: (bounds: { width: number; height: number; x: number; y: number }) => void;
setWindowState: (state: 'normal' | 'fullscreen' | 'maximized') => void;
setSidebarWidth: (width: number) => void;
+
+ // AI actions
+ toggleAIPanel: () => void;
+ setAIPanelVisible: (visible: boolean) => void;
+ addAIChatMessage: (sessionId: string, message: AIChatMessage) => void;
+ updateAIChatMessage: (sessionId: string, messageId: string, updates: Partial) => void;
+ deleteAIChatMessage: (sessionId: string, messageId: string) => void;
+ truncateAIChatMessages: (sessionId: string, upToMessageId: string) => void;
+ clearAIChatHistory: (sessionId: string) => void;
+ deleteAISession: (sessionId: string) => void;
+ createNewAISession: () => void;
+ setAIActiveSessionId: (sessionId: string | null) => void;
}
const sanitizeSavedQueries = (value: unknown): SavedQuery[] => {
@@ -671,6 +689,12 @@ export const useStore = create()(
windowState: 'normal' as const,
sidebarWidth: 330,
+ // AI 运行状态
+ aiPanelVisible: false,
+ aiChatHistory: {},
+ aiChatSessions: [],
+ aiActiveSessionId: null,
+
addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })),
updateConnection: (conn) => set((state) => ({
connections: state.connections.map(c => c.id === conn.id ? conn : c)
@@ -950,6 +974,83 @@ export const useStore = create()(
setWindowState: (state) => set({ windowState: state }),
setSidebarWidth: (width) => set({ sidebarWidth: Math.max(200, Math.min(600, Math.trunc(width))) }),
+
+ // AI actions
+ toggleAIPanel: () => set((state) => ({ aiPanelVisible: !state.aiPanelVisible })),
+ setAIPanelVisible: (visible) => set({ aiPanelVisible: visible }),
+ addAIChatMessage: (sessionId, message) => set((state) => {
+ const history = { ...state.aiChatHistory };
+ const messages = history[sessionId] || [];
+ history[sessionId] = [...messages, message];
+
+ let newSessions = [...state.aiChatSessions];
+ const existingSession = newSessions.find(s => s.id === sessionId);
+
+ if (!existingSession) {
+ // 生成标题(首个 user message 内容前 20 字符)
+      let title = message.role === 'user' ? message.content : '新的对话';
+      if (title.length > 20) {
+        title = Array.from(title).slice(0, 20).join('') + '...'; // slice by code points so surrogate pairs (emoji, rare CJK) are not split mid-character
+      }
+ newSessions.unshift({ id: sessionId, title, updatedAt: Date.now() });
+ } else {
+ // 提至最新
+ newSessions = newSessions.filter(s => s.id !== sessionId);
+ newSessions.unshift({ ...existingSession, updatedAt: Date.now() });
+ }
+
+ return { aiChatHistory: history, aiChatSessions: newSessions };
+ }),
+ updateAIChatMessage: (sessionId, messageId, updates) => set((state) => {
+ const history = { ...state.aiChatHistory };
+ const messages = history[sessionId];
+ if (!messages) return state;
+ history[sessionId] = messages.map(m =>
+ m.id === messageId ? { ...m, ...updates } : m
+ );
+ let newSessions = [...state.aiChatSessions];
+ const existingSession = newSessions.find(s => s.id === sessionId);
+ if (existingSession) {
+ newSessions = newSessions.filter(s => s.id !== sessionId);
+ newSessions.unshift({ ...existingSession, updatedAt: Date.now() });
+ }
+ return { aiChatHistory: history, aiChatSessions: newSessions };
+ }),
+ deleteAIChatMessage: (sessionId, messageId) => set((state) => {
+ const history = { ...state.aiChatHistory };
+ if (history[sessionId]) {
+ history[sessionId] = history[sessionId].filter(m => m.id !== messageId);
+ }
+ return { aiChatHistory: history };
+ }),
+ truncateAIChatMessages: (sessionId, upToMessageId) => set((state) => {
+ const history = { ...state.aiChatHistory };
+ const messages = history[sessionId];
+ if (messages) {
+ const idx = messages.findIndex(m => m.id === upToMessageId);
+ if (idx >= 0) {
+ history[sessionId] = messages.slice(0, idx + 1);
+ }
+ }
+ return { aiChatHistory: history };
+ }),
+ clearAIChatHistory: (sessionId) => set((state) => {
+ const history = { ...state.aiChatHistory };
+ delete history[sessionId];
+ return { aiChatHistory: history };
+ }),
+ deleteAISession: (sessionId) => set((state) => {
+ const history = { ...state.aiChatHistory };
+ delete history[sessionId];
+ const newSessions = state.aiChatSessions.filter(s => s.id !== sessionId);
+ const newActive = state.aiActiveSessionId === sessionId ? null : state.aiActiveSessionId;
+ return { aiChatHistory: history, aiChatSessions: newSessions, aiActiveSessionId: newActive };
+ }),
+ createNewAISession: () => set(() => {
+ const newId = `session-${Date.now()}`;
+ return { aiActiveSessionId: newId };
+ }),
+ setAIActiveSessionId: (sessionId) => set({ aiActiveSessionId: sessionId }),
}),
{
name: 'lite-db-storage', // name of the item in the storage (must be unique)
@@ -985,6 +1086,10 @@ export const useStore = create()(
nextState.windowBounds = sanitizeWindowBounds(state.windowBounds);
nextState.windowState = sanitizeWindowState(state.windowState);
nextState.sidebarWidth = sanitizeSidebarWidth(state.sidebarWidth);
+
+ // 保留原有的 AI 持久化记录,或者为空(版本兼容)
+ nextState.aiChatHistory = (state.aiChatHistory && typeof state.aiChatHistory === 'object') ? state.aiChatHistory : {};
+ nextState.aiChatSessions = Array.isArray(state.aiChatSessions) ? state.aiChatSessions : [];
return nextState as AppState;
},
merge: (persistedState, currentState) => {
@@ -1014,6 +1119,9 @@ export const useStore = create()(
queryOptions: sanitizeQueryOptions(state.queryOptions),
shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions),
tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount),
+
+ aiChatHistory: (state.aiChatHistory && typeof state.aiChatHistory === 'object') ? state.aiChatHistory : {},
+ aiChatSessions: Array.isArray(state.aiChatSessions) ? state.aiChatSessions : [],
};
},
partialize: (state) => ({
@@ -1038,6 +1146,9 @@ export const useStore = create()(
windowBounds: state.windowBounds,
windowState: state.windowState,
sidebarWidth: state.sidebarWidth,
+
+ aiChatHistory: state.aiChatHistory,
+ aiChatSessions: state.aiChatSessions,
}), // Don't persist logs
}
)
diff --git a/frontend/src/types.ts b/frontend/src/types.ts
index ea10867..072d65c 100644
--- a/frontend/src/types.ts
+++ b/frontend/src/types.ts
@@ -183,3 +183,39 @@ export interface StreamEntry {
id: string;
fields: Record;
}
+
+// --- AI Types ---
+
+export type AIProviderType = 'openai' | 'anthropic' | 'gemini' | 'custom';
+export type AISafetyLevel = 'readonly' | 'readwrite' | 'full';
+export type AIContextLevel = 'schema_only' | 'with_samples' | 'with_results';
+
+export interface AIProviderConfig {
+ id: string;
+ type: AIProviderType;
+ name: string;
+ apiKey: string;
+ baseUrl: string;
+ model: string;
+ models?: string[];
+ apiFormat?: string; // custom 专用: openai | anthropic | gemini
+ headers?: Record;
+ maxTokens: number;
+ temperature: number;
+}
+
+export interface AIChatMessage {
+ id: string;
+ role: 'user' | 'assistant' | 'system';
+ content: string;
+ timestamp: number;
+ loading?: boolean;
+}
+
+export interface AISafetyResult {
+ allowed: boolean;
+ operationType: 'query' | 'dml' | 'ddl' | 'other';
+ requiresConfirm: boolean;
+ warningMessage?: string;
+}
+
diff --git a/frontend/wailsjs/go/aiservice/Service.d.ts b/frontend/wailsjs/go/aiservice/Service.d.ts
new file mode 100644
index 0000000..872b5f5
--- /dev/null
+++ b/frontend/wailsjs/go/aiservice/Service.d.ts
@@ -0,0 +1,38 @@
+// Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
+// This file is automatically generated. DO NOT EDIT
+import {ai} from '../models';
+import {context} from '../models';
+
+export function AIChatCancel(arg1:string):Promise;
+
+export function AIChatSend(arg1:Array>):Promise>;
+
+export function AIChatStream(arg1:string,arg2:Array>):Promise;
+
+export function AICheckSQL(arg1:string):Promise;
+
+export function AIDeleteProvider(arg1:string):Promise;
+
+export function AIGetActiveProvider():Promise;
+
+export function AIGetBuiltinPrompts():Promise>;
+
+export function AIGetContextLevel():Promise;
+
+export function AIGetProviders():Promise>;
+
+export function AIGetSafetyLevel():Promise;
+
+export function AIListModels():Promise>;
+
+export function AISaveProvider(arg1:ai.ProviderConfig):Promise;
+
+export function AISetActiveProvider(arg1:string):Promise;
+
+export function AISetContextLevel(arg1:string):Promise;
+
+export function AISetSafetyLevel(arg1:string):Promise;
+
+export function AITestProvider(arg1:ai.ProviderConfig):Promise>;
+
+export function Startup(arg1:context.Context):Promise;
diff --git a/frontend/wailsjs/go/aiservice/Service.js b/frontend/wailsjs/go/aiservice/Service.js
new file mode 100644
index 0000000..2e3dcf4
--- /dev/null
+++ b/frontend/wailsjs/go/aiservice/Service.js
@@ -0,0 +1,71 @@
+// @ts-check
+// Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
+// This file is automatically generated. DO NOT EDIT
+
+export function AIChatCancel(arg1) {
+ return window['go']['aiservice']['Service']['AIChatCancel'](arg1);
+}
+
+export function AIChatSend(arg1) {
+ return window['go']['aiservice']['Service']['AIChatSend'](arg1);
+}
+
+export function AIChatStream(arg1, arg2) {
+ return window['go']['aiservice']['Service']['AIChatStream'](arg1, arg2);
+}
+
+export function AICheckSQL(arg1) {
+ return window['go']['aiservice']['Service']['AICheckSQL'](arg1);
+}
+
+export function AIDeleteProvider(arg1) {
+ return window['go']['aiservice']['Service']['AIDeleteProvider'](arg1);
+}
+
+export function AIGetActiveProvider() {
+ return window['go']['aiservice']['Service']['AIGetActiveProvider']();
+}
+
+export function AIGetBuiltinPrompts() {
+ return window['go']['aiservice']['Service']['AIGetBuiltinPrompts']();
+}
+
+export function AIGetContextLevel() {
+ return window['go']['aiservice']['Service']['AIGetContextLevel']();
+}
+
+export function AIGetProviders() {
+ return window['go']['aiservice']['Service']['AIGetProviders']();
+}
+
+export function AIGetSafetyLevel() {
+ return window['go']['aiservice']['Service']['AIGetSafetyLevel']();
+}
+
+export function AIListModels() {
+ return window['go']['aiservice']['Service']['AIListModels']();
+}
+
+export function AISaveProvider(arg1) {
+ return window['go']['aiservice']['Service']['AISaveProvider'](arg1);
+}
+
+export function AISetActiveProvider(arg1) {
+ return window['go']['aiservice']['Service']['AISetActiveProvider'](arg1);
+}
+
+export function AISetContextLevel(arg1) {
+ return window['go']['aiservice']['Service']['AISetContextLevel'](arg1);
+}
+
+export function AISetSafetyLevel(arg1) {
+ return window['go']['aiservice']['Service']['AISetSafetyLevel'](arg1);
+}
+
+export function AITestProvider(arg1) {
+ return window['go']['aiservice']['Service']['AITestProvider'](arg1);
+}
+
+export function Startup(arg1) {
+ return window['go']['aiservice']['Service']['Startup'](arg1);
+}
diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts
index d6fb6a4..2baf781 100755
--- a/frontend/wailsjs/go/app/App.d.ts
+++ b/frontend/wailsjs/go/app/App.d.ts
@@ -4,6 +4,7 @@ import {connection} from '../models';
import {time} from '../models';
import {sync} from '../models';
import {redis} from '../models';
+import {context} from '../models';
export function ApplyChanges(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:connection.ChangeSet):Promise;
@@ -197,6 +198,8 @@ export function SetMacNativeWindowControls(arg1:boolean):Promise;
export function SetWindowTranslucency(arg1:number,arg2:number):Promise;
+export function Startup(arg1:context.Context):Promise;
+
export function TestConnection(arg1:connection.ConnectionConfig):Promise;
export function TruncateTables(arg1:connection.ConnectionConfig,arg2:string,arg3:Array):Promise;
diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js
index 495177e..f0e3782 100755
--- a/frontend/wailsjs/go/app/App.js
+++ b/frontend/wailsjs/go/app/App.js
@@ -386,6 +386,10 @@ export function SetWindowTranslucency(arg1, arg2) {
return window['go']['app']['App']['SetWindowTranslucency'](arg1, arg2);
}
+export function Startup(arg1) {
+ return window['go']['app']['App']['Startup'](arg1);
+}
+
export function TestConnection(arg1) {
return window['go']['app']['App']['TestConnection'](arg1);
}
diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts
index 9a25b59..258b148 100755
--- a/frontend/wailsjs/go/models.ts
+++ b/frontend/wailsjs/go/models.ts
@@ -1,3 +1,58 @@
+export namespace ai {
+
+ export class ProviderConfig {
+ id: string;
+ type: string;
+ name: string;
+ apiKey: string;
+ baseUrl: string;
+ model: string;
+ models?: string[];
+ apiFormat?: string;
+ headers?: Record;
+ maxTokens: number;
+ temperature: number;
+
+ static createFrom(source: any = {}) {
+ return new ProviderConfig(source);
+ }
+
+ constructor(source: any = {}) {
+ if ('string' === typeof source) source = JSON.parse(source);
+ this.id = source["id"];
+ this.type = source["type"];
+ this.name = source["name"];
+ this.apiKey = source["apiKey"];
+ this.baseUrl = source["baseUrl"];
+ this.model = source["model"];
+ this.models = source["models"];
+ this.apiFormat = source["apiFormat"];
+ this.headers = source["headers"];
+ this.maxTokens = source["maxTokens"];
+ this.temperature = source["temperature"];
+ }
+ }
+ export class SafetyResult {
+ allowed: boolean;
+ operationType: string;
+ requiresConfirm: boolean;
+ warningMessage?: string;
+
+ static createFrom(source: any = {}) {
+ return new SafetyResult(source);
+ }
+
+ constructor(source: any = {}) {
+ if ('string' === typeof source) source = JSON.parse(source);
+ this.allowed = source["allowed"];
+ this.operationType = source["operationType"];
+ this.requiresConfirm = source["requiresConfirm"];
+ this.warningMessage = source["warningMessage"];
+ }
+ }
+
+}
+
export namespace connection {
export class UpdateRow {
diff --git a/internal/ai/context/builder.go b/internal/ai/context/builder.go
new file mode 100644
index 0000000..9c4e75a
--- /dev/null
+++ b/internal/ai/context/builder.go
@@ -0,0 +1,213 @@
+package aicontext
+
+import (
+ "fmt"
+ "strings"
+)
+
+// PromptTemplate AI 能力类型
+type PromptTemplate string
+
+const (
+ PromptSQLGenerate PromptTemplate = "sql_generate"
+ PromptSQLExplain PromptTemplate = "sql_explain"
+ PromptSQLOptimize PromptTemplate = "sql_optimize"
+ PromptDataAnalyze PromptTemplate = "data_analyze"
+ PromptSchemaInsight PromptTemplate = "schema_insight"
+ PromptGeneralChat PromptTemplate = "general_chat"
+)
+
+// GetBuiltinPrompts 获取所有内置系统提示词集合,用于前端展示
+func GetBuiltinPrompts() map[string]string {
+ return map[string]string{
+ "通用聊天助手": buildGeneralChatPrompt(),
+ "SQL 生成器": buildSQLGeneratePrompt(),
+ "SQL 解析器": buildSQLExplainPrompt(),
+ "SQL 优化器": buildSQLOptimizePrompt(),
+ "数据洞察分析": buildDataAnalyzePrompt(),
+ "表结构审查": buildSchemaInsightPrompt(),
+ }
+}
+
+// BuildSystemPrompt 根据模板类型和上下文构建 System Prompt
+func BuildSystemPrompt(template PromptTemplate, dbCtx *DatabaseContext) string {
+ var prompt string
+
+ switch template {
+ case PromptSQLGenerate:
+ prompt = buildSQLGeneratePrompt()
+ case PromptSQLExplain:
+ prompt = buildSQLExplainPrompt()
+ case PromptSQLOptimize:
+ prompt = buildSQLOptimizePrompt()
+ case PromptDataAnalyze:
+ prompt = buildDataAnalyzePrompt()
+ case PromptSchemaInsight:
+ prompt = buildSchemaInsightPrompt()
+ case PromptGeneralChat:
+ prompt = buildGeneralChatPrompt()
+ default:
+ prompt = buildGeneralChatPrompt()
+ }
+
+ if dbCtx != nil {
+ prompt += "\n\n" + FormatDatabaseContext(dbCtx)
+ }
+
+ return prompt
+}
+
+// FormatDatabaseContext 将数据库上下文格式化为 LLM 友好的文本
+func FormatDatabaseContext(ctx *DatabaseContext) string {
+ if ctx == nil || len(ctx.Tables) == 0 {
+ return ""
+ }
+
+ var b strings.Builder
+ b.WriteString(fmt.Sprintf("## 当前数据库上下文\n\n数据库类型: %s\n数据库名: %s\n\n",
+ ctx.DatabaseType, ctx.DatabaseName))
+
+ b.WriteString("### 表结构\n\n")
+ for _, table := range ctx.Tables {
+ b.WriteString(fmt.Sprintf("#### 表: %s", table.Name))
+ if table.Comment != "" {
+ b.WriteString(fmt.Sprintf(" (%s)", table.Comment))
+ }
+ if table.RowCount > 0 {
+ b.WriteString(fmt.Sprintf(" [约 %d 行]", table.RowCount))
+ }
+ b.WriteString("\n\n")
+
+ b.WriteString("| 列名 | 类型 | 可空 | 主键 | 备注 |\n")
+ b.WriteString("|------|------|------|------|------|\n")
+ for _, col := range table.Columns {
+ nullable := "否"
+ if col.Nullable {
+ nullable = "是"
+ }
+ pk := ""
+ if col.PrimaryKey {
+ pk = "✓"
+ }
+ comment := col.Comment
+ if comment == "" {
+ comment = "-"
+ }
+ b.WriteString(fmt.Sprintf("| %s | %s | %s | %s | %s |\n",
+ col.Name, col.Type, nullable, pk, comment))
+ }
+ b.WriteString("\n")
+
+ if len(table.Indexes) > 0 {
+ b.WriteString("**索引:**\n")
+ for _, idx := range table.Indexes {
+ unique := ""
+ if idx.Unique {
+ unique = " (唯一)"
+ }
+ b.WriteString(fmt.Sprintf("- %s: [%s]%s\n",
+ idx.Name, strings.Join(idx.Columns, ", "), unique))
+ }
+ b.WriteString("\n")
+ }
+
+ if len(table.SampleRows) > 0 {
+ b.WriteString(fmt.Sprintf("**采样数据 (%d 行):**\n\n", len(table.SampleRows)))
+ if len(table.SampleRows) > 0 {
+ // 使用第一行的 key 作为标题
+ first := table.SampleRows[0]
+ var keys []string
+ for k := range first {
+ keys = append(keys, k)
+ }
+ b.WriteString("| " + strings.Join(keys, " | ") + " |\n")
+ b.WriteString("|" + strings.Repeat("------|", len(keys)) + "\n")
+ for _, row := range table.SampleRows {
+ var vals []string
+ for _, k := range keys {
+ vals = append(vals, fmt.Sprintf("%v", row[k]))
+ }
+ b.WriteString("| " + strings.Join(vals, " | ") + " |\n")
+ }
+ b.WriteString("\n")
+ }
+ }
+ }
+
+ return b.String()
+}
+
+func buildSQLGeneratePrompt() string {
+ return `你是 GoNavi AI 助手,一位顶级的数据库开发专家和 SQL 查询构建师。根据用户的自然语言需求,生成精准、优雅、高性能的 SQL 查询或 Redis 命令。
+
+严苛输出规则:
+1. 首要目标是输出纯粹的代码:始终将代码放在正确语言标识(如 sql 或 bash)的 markdown 代码块中。
+2. 保持精简:不要添加过多的前置闲聊,直奔主题。
+3. 保护生产安全:优先使用参数化查询或安全防范写法避免 SQL 注入。对于未指定条件的 DELETE/UPDATE 语句,必须提出强烈的红线警告!!
+4. 性能至上:对大型查询默认添加合理的 LIMIT 限制(如 LIMIT 100),在 JOIN 和聚合时优先选择最高效的范式写法。
+5. 适度注释:对于存在复杂逻辑嵌套的代码,请在代码块内使用单行注释简要说明思路。`
+}
+
+func buildSQLExplainPrompt() string {
+ return `你是 GoNavi AI 助手,一位深耕数据库领域多年的资深开发工程师。请用专业、条理分明且深入浅出的开发者语言向用户全盘解析 SQL 语句的底层意图与执行逻辑。
+
+解析规范:
+1. 宏观逻辑解构:用简短的一句话概括这条 SQL 在业务上想要解决什么问题。
+2. 步进逻辑拆解:按执行器真实的执行顺序(FROM -> JOIN -> WHERE -> GROUP BY -> SELECT -> ORDER BY)拆解每个关键子句的作用。
+3. 性能排雷点:敏锐指出可能存在的性能陷阱(如隐式类型转换、没有走索引的函数调用、潜在的笛卡尔积/全表扫描等)。
+4. 严谨的排版:使用列表呈现关键点,重点词汇加粗,确保长文不累赘。`
+}
+
+func buildSQLOptimizePrompt() string {
+ return `你是 GoNavi AI 助手,一名曾主导过千万级高并发系统的全栈性能工程专家与高级 DBA。请对用户提供的原始 SQL 进行冷酷、精确的诊断并开出性能重构处方。
+
+诊断与处方要求:
+1. 性能瓶颈透视:精准点出当前语句死穴(不合理的驱动表、无法利用覆盖索引、多此一举的子查询等)。
+2. 重构版本的 SQL:如果存在性能提升空间,直接向用户展示彻底优化过的高性能写法,并确保逻辑等价性。
+3. 剖析原因:不仅要告诉用户“怎么改”,更要说清楚执行器“为什么这样会更快”。
+4. 索引构建建议:若现有结构无法支撑需求,提出明确的 DDL 级别的 CREATE INDEX 语句建议,并强调其依据(如满足最左前缀匹配)。
+5. 优先级评估:在回答的最后标注本次优化建议的紧迫性(高:阻断级/锁表风险;中:吞吐量瓶颈;低:长效微调)。`
+}
+
+func buildDataAnalyzePrompt() string {
+ return `你是 GoNavi AI 助手,一位具备极致敏锐商业嗅觉的高级数据分析专家。你将审视用户通过查询得到的数据样本,从中提炼出蕴含的真金白银般的信息。
+
+洞察目标:
+1. 硬统计:总观数据行数、核心数值指标(极值、平均值、聚合中位数等)的冰冷现实。
+2. 趋势与异动:如果数据带有时间戳,敏锐捕捉其上升或下降趋势;如果有异类离群值,将其高亮标注。
+3. 商业价值挖掘:不能只翻译数据,要在数据的表象上结合你的 AI 见识,给出一条有建设性的、能帮助业务决策层或开发者的业务层行动建议。
+4. 展现格式:你的分析应该是“标题 + 浓缩要点”的极简研报形式,杜绝毫无波澜的流水账。`
+}
+
+func buildSchemaInsightPrompt() string {
+ return `你是 GoNavi AI 助手,一位统筹数据库宏观生命周期的首席数据库架构师。在这个环节里,你需要对用户提供的数据库表结构执行最严厉的范式与前瞻性审查。
+
+审查视界:
+1. 规范化博弈:是否存在明显的反三范式设计?这种冗余是否有助于性能(适当的反范式),还是纯粹的设计失误?
+2. 索引健壮性审查:评估主键选择(如自增、UUID 的利弊),是否存在冗余索引阻碍写入?以及是否遗漏了高频的联合索引。
+3. 物理容量前瞻:审视数据类型分配(如使用过大的 VARCHAR、没必要的 BIGINT 等可能带来的空间挥霍)。
+4. 代码级指引:如果存在结构性缺陷,不要只发牢骚,直接给出包含具体优化的 ALTER TABLE 结构修改建议脚本。`
+}
+
+func buildGeneralChatPrompt() string {
+ return `你是 GoNavi AI 助手,一款深度集成在数据库/缓存客户端(GoNavi)内部的专属智能专家系统。
+你的目标是成为开发者、DBA 和数据科学家最得力的超级外脑,提供专业、精准、具有前瞻性的数据端解决方案。
+
+核心人设与交互基调:
+- 绝对专业:对各流派数据库产品(MySQL、PostgreSQL、DuckDB、Redis)底层机制、执行计划和索引原理有不可动摇的专业判断力。
+- 直击痛点:谢绝套话与无效寒暄,若用户的意图明确,首屏直接给出可以直接粘贴运行的优雅代码。
+- 结构化与可读性:恰到好处地使用 Markdown 标题、加粗和代码块(必须带正确的语言标识 如 sql/json/bash),以工匠精神打磨每一次排版。
+- 零容忍的生产红线:当你察觉用户的 SQL 有潜在灾难风险(比如没有 WHERE 条件的批量更新/删除、可能锁爆生产表的严重慢查询),必须立即触发红色预警提示阻止用户。
+
+你的综合能力版图:
+1. 📝 自然语言驱动:翻译人类意图为精准的查询语句。
+2. 🔍 底层原理解析:剥丝抽茧分析查询背后的执行逻辑与性能隐患。
+3. ⚡ 专家级调优:指出并化解性能瓶颈,给出覆盖全维度的索引调优思路。
+4. 📊 数据洞察炼金:不仅聚合数据,更能从结果集中挖掘商业维度的深度规律。
+5. 🏗️ 架构先知视界:全局审阅表结构设计局限,提出抗数据膨胀级别的架构演进方案。
+
+互动守则:
+- 永远使用专业、具有合作感且充满信心的中文与用户探讨问题。
+- 当被要求提供任何数据库代码时,需结合相关数据库引擎的最佳实践。如果不清楚当前方言版本,请以标准实现为主基调并好心指出版别差异(如 MySQL 8 窗口函数 等)。`
+}
+
diff --git a/internal/ai/context/collector.go b/internal/ai/context/collector.go
new file mode 100644
index 0000000..bfa6c36
--- /dev/null
+++ b/internal/ai/context/collector.go
@@ -0,0 +1,42 @@
+package aicontext
+
+// DatabaseContext 数据库上下文信息,传递给 AI 辅助上下文理解
+type DatabaseContext struct {
+ DatabaseType string `json:"databaseType"` // mysql, postgres 等
+ DatabaseName string `json:"databaseName"`
+ Tables []TableContext `json:"tables"`
+}
+
+// TableContext 表的上下文信息
+type TableContext struct {
+ Name string `json:"name"`
+ Comment string `json:"comment,omitempty"`
+ Columns []ColumnInfo `json:"columns"`
+ Indexes []IndexInfo `json:"indexes,omitempty"`
+ SampleRows []map[string]interface{} `json:"sampleRows,omitempty"`
+ RowCount int64 `json:"rowCount,omitempty"`
+}
+
+// ColumnInfo 列信息
+type ColumnInfo struct {
+ Name string `json:"name"`
+ Type string `json:"type"`
+ Nullable bool `json:"nullable"`
+ PrimaryKey bool `json:"primaryKey"`
+ Comment string `json:"comment,omitempty"`
+}
+
+// IndexInfo 索引信息
+type IndexInfo struct {
+ Name string `json:"name"`
+ Columns []string `json:"columns"`
+ Unique bool `json:"unique"`
+}
+
+// QueryResultContext 查询结果上下文
+type QueryResultContext struct {
+ SQL string `json:"sql"`
+ Columns []string `json:"columns"`
+ Rows []map[string]interface{} `json:"rows"`
+ RowCount int `json:"rowCount"`
+}
diff --git a/internal/ai/provider/anthropic.go b/internal/ai/provider/anthropic.go
new file mode 100644
index 0000000..035d2fc
--- /dev/null
+++ b/internal/ai/provider/anthropic.go
@@ -0,0 +1,293 @@
+package provider
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+
+ "GoNavi-Wails/internal/ai"
+)
+
+const (
+ defaultAnthropicBaseURL = "https://api.anthropic.com"
+ defaultAnthropicModel = "claude-3-5-sonnet-20241022"
+ anthropicAPIVersion = "2023-06-01"
+)
+
+// AnthropicProvider 实现 Anthropic Claude API 的 Provider
+type AnthropicProvider struct {
+ config ai.ProviderConfig
+ baseURL string
+ client *http.Client
+}
+
+// NewAnthropicProvider 创建 Anthropic Provider 实例
+func NewAnthropicProvider(config ai.ProviderConfig) (Provider, error) {
+ baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
+ if baseURL == "" {
+ baseURL = defaultAnthropicBaseURL
+ }
+ model := strings.TrimSpace(config.Model)
+ if model == "" {
+ model = defaultAnthropicModel
+ }
+ maxTokens := config.MaxTokens
+ if maxTokens <= 0 {
+ maxTokens = defaultOpenAIMaxTokens
+ }
+ temperature := config.Temperature
+ if temperature <= 0 {
+ temperature = defaultOpenAITemperature
+ }
+
+ normalized := config
+ normalized.BaseURL = baseURL
+ normalized.Model = model
+ normalized.MaxTokens = maxTokens
+ normalized.Temperature = temperature
+
+ return &AnthropicProvider{
+ config: normalized,
+ baseURL: baseURL,
+ client: &http.Client{Timeout: openAIHTTPTimeout},
+ }, nil
+}
+
+func (p *AnthropicProvider) Name() string {
+ if strings.TrimSpace(p.config.Name) != "" {
+ return p.config.Name
+ }
+ return "Anthropic"
+}
+
+func (p *AnthropicProvider) Validate() error {
+ if strings.TrimSpace(p.config.APIKey) == "" {
+ return fmt.Errorf("API Key 不能为空")
+ }
+ return nil
+}
+
+type anthropicRequest struct {
+ Model string `json:"model"`
+ Messages []anthropicMessage `json:"messages"`
+ System string `json:"system,omitempty"`
+ MaxTokens int `json:"max_tokens"`
+ Temperature float64 `json:"temperature,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+}
+
+type anthropicMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type anthropicResponse struct {
+ Content []struct {
+ Text string `json:"text"`
+ } `json:"content"`
+ Usage struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ } `json:"usage"`
+ Error *struct {
+ Message string `json:"message"`
+ } `json:"error,omitempty"`
+}
+
+type anthropicStreamEvent struct {
+ Type string `json:"type"`
+ Delta *struct {
+ Text string `json:"text"`
+ } `json:"delta,omitempty"`
+}
+
+func (p *AnthropicProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
+ if err := p.Validate(); err != nil {
+ return nil, err
+ }
+
+ systemMsg, messages := extractSystemMessage(req.Messages)
+ anthropicMsgs := make([]anthropicMessage, len(messages))
+ for i, m := range messages {
+ anthropicMsgs[i] = anthropicMessage{Role: m.Role, Content: m.Content}
+ }
+
+ temperature := req.Temperature
+ if temperature <= 0 {
+ temperature = p.config.Temperature
+ }
+ maxTokens := req.MaxTokens
+ if maxTokens <= 0 {
+ maxTokens = p.config.MaxTokens
+ }
+
+ body := anthropicRequest{
+ Model: p.config.Model,
+ Messages: anthropicMsgs,
+ System: systemMsg,
+ MaxTokens: maxTokens,
+ Temperature: temperature,
+ }
+
+ respBody, err := p.doRequest(ctx, body)
+ if err != nil {
+ return nil, err
+ }
+ defer respBody.Close()
+
+ var result anthropicResponse
+ if err := json.NewDecoder(respBody).Decode(&result); err != nil {
+ return nil, fmt.Errorf("解析 Anthropic 响应失败: %w", err)
+ }
+ if result.Error != nil && result.Error.Message != "" {
+ return nil, fmt.Errorf("Anthropic API 错误: %s", result.Error.Message)
+ }
+ if len(result.Content) == 0 {
+ return nil, fmt.Errorf("Anthropic 返回空响应")
+ }
+
+ return &ai.ChatResponse{
+ Content: result.Content[0].Text,
+ TokensUsed: ai.TokenUsage{
+ PromptTokens: result.Usage.InputTokens,
+ CompletionTokens: result.Usage.OutputTokens,
+ TotalTokens: result.Usage.InputTokens + result.Usage.OutputTokens,
+ },
+ }, nil
+}
+
+func (p *AnthropicProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
+ if err := p.Validate(); err != nil {
+ return err
+ }
+
+ systemMsg, messages := extractSystemMessage(req.Messages)
+ anthropicMsgs := make([]anthropicMessage, len(messages))
+ for i, m := range messages {
+ anthropicMsgs[i] = anthropicMessage{Role: m.Role, Content: m.Content}
+ }
+
+ temperature := req.Temperature
+ if temperature <= 0 {
+ temperature = p.config.Temperature
+ }
+ maxTokens := req.MaxTokens
+ if maxTokens <= 0 {
+ maxTokens = p.config.MaxTokens
+ }
+
+ body := anthropicRequest{
+ Model: p.config.Model,
+ Messages: anthropicMsgs,
+ System: systemMsg,
+ MaxTokens: maxTokens,
+ Temperature: temperature,
+ Stream: true,
+ }
+
+ respBody, err := p.doRequest(ctx, body)
+ if err != nil {
+ return err
+ }
+ defer respBody.Close()
+
+	scanner := bufio.NewScanner(respBody); scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) // raise the 64KB default token limit — an oversized SSE data line would otherwise abort the stream
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !strings.HasPrefix(line, "data: ") {
+ continue
+ }
+ data := strings.TrimPrefix(line, "data: ")
+
+ var event anthropicStreamEvent
+ if err := json.Unmarshal([]byte(data), &event); err != nil {
+ continue
+ }
+
+ switch event.Type {
+ case "content_block_delta":
+ if event.Delta != nil && event.Delta.Text != "" {
+ callback(ai.StreamChunk{Content: event.Delta.Text})
+ }
+ case "message_stop":
+ callback(ai.StreamChunk{Done: true})
+ return nil
+ }
+ }
+
+ callback(ai.StreamChunk{Done: true})
+ return scanner.Err()
+}
+
+func (p *AnthropicProvider) doRequest(ctx context.Context, body interface{}) (io.ReadCloser, error) {
+ jsonBody, err := json.Marshal(body)
+ if err != nil {
+ return nil, fmt.Errorf("序列化请求失败: %w", err)
+ }
+
+ url := p.baseURL + "/v1/messages"
+ if strings.HasSuffix(p.baseURL, "/v1") {
+ url = p.baseURL + "/messages"
+ }
+
+	// NOTE(review): removed debug logging that printed the request URL,
+	// base URL and a truncated copy of the JSON body to stdout. Chat
+	// requests embed database schema and sample-row context collected
+	// from the user's databases, so echoing them is a data-leak risk in
+	// packaged builds. Reintroduce only behind an explicit opt-in debug
+	// flag if request tracing is needed.
+	// (Previously: bodyStr := string(jsonBody), truncated to 500 chars,
+	// then three fmt.Printf("[Anthropic DEBUG] ...") calls.)
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
+ if err != nil {
+ return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
+ }
+
+ httpReq.Header.Set("Content-Type", "application/json")
+ httpReq.Header.Set("x-api-key", p.config.APIKey)
+ httpReq.Header.Set("anthropic-version", anthropicAPIVersion)
+
+ // 仅官方 API 发 beta 特性头(代理不发,避免触发 Claude Code 验证)
+	isOfficialAPI := p.baseURL == defaultAnthropicBaseURL || strings.HasSuffix(p.baseURL, ".anthropic.com") // suffix match, not Contains: a host like "anthropic.com.evil.com" must not be treated as official
+ if isOfficialAPI {
+ httpReq.Header.Set("anthropic-beta", "interleaved-thinking-2025-05-14,output-128k-2025-02-19,prompt-caching-2024-07-31")
+ }
+
+ // 自定义 headers(用于兼容各类代理服务)
+ for k, v := range p.config.Headers {
+ httpReq.Header.Set(k, v)
+ }
+
+ resp, err := p.client.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("发送请求到 %s 失败: %w", url, err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ defer resp.Body.Close()
+ bodyBytes, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("Anthropic API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
+ }
+
+ return resp.Body, nil
+}
+
+// extractSystemMessage 从消息列表中提取 system 消息(Anthropic 要求 system 作为独立字段)
+func extractSystemMessage(messages []ai.Message) (string, []ai.Message) {
+ var systemParts []string
+ var remaining []ai.Message
+ for _, m := range messages {
+ if m.Role == "system" {
+ systemParts = append(systemParts, m.Content)
+ } else {
+ remaining = append(remaining, m)
+ }
+ }
+ return strings.Join(systemParts, "\n\n"), remaining
+}
diff --git a/internal/ai/provider/claude_cli.go b/internal/ai/provider/claude_cli.go
new file mode 100644
index 0000000..824e413
--- /dev/null
+++ b/internal/ai/provider/claude_cli.go
@@ -0,0 +1,227 @@
+package provider
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "os/exec"
+ "strings"
+
+ ai "GoNavi-Wails/internal/ai"
+)
+
+// ClaudeCLIProvider 通过 Claude Code CLI 发送聊天请求
+// 适用于 anyrouter/newapi 等只支持 Claude Code 协议的代理服务
+type ClaudeCLIProvider struct {
+ config ai.ProviderConfig
+}
+
+// NewClaudeCLIProvider 创建 ClaudeCLIProvider 实例
+func NewClaudeCLIProvider(config ai.ProviderConfig) (Provider, error) {
+ return &ClaudeCLIProvider{config: config}, nil
+}
+
+func (p *ClaudeCLIProvider) Name() string {
+ return "ClaudeCLI"
+}
+
+func (p *ClaudeCLIProvider) Validate() error {
+ _, err := exec.LookPath("claude")
+ if err != nil {
+ return fmt.Errorf("未找到 claude 命令,请先安装 Claude Code CLI: npm install -g @anthropic-ai/claude-code")
+ }
+ return nil
+}
+
+// Chat 非流式聊天: runs `claude -p <prompt> --output-format json` and
+// returns the parsed result text. If the output is not the expected
+// JSON envelope, the raw CLI output is returned verbatim instead.
+func (p *ClaudeCLIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
+	if err := p.Validate(); err != nil {
+		return nil, err
+	}
+
+	cliArgs := []string{"-p", buildPrompt(req.Messages), "--output-format", "json", "--no-session-persistence"}
+	if p.config.Model != "" {
+		cliArgs = append(cliArgs, "--model", p.config.Model)
+	}
+
+	cmd := exec.CommandContext(ctx, "claude", cliArgs...)
+	p.setEnv(cmd)
+
+	out, err := cmd.Output()
+	if err != nil {
+		// Surface the CLI's own stderr when it exited with a status code.
+		if exitErr, ok := err.(*exec.ExitError); ok {
+			return nil, fmt.Errorf("claude CLI 执行失败: %s", string(exitErr.Stderr))
+		}
+		return nil, fmt.Errorf("claude CLI 执行失败: %w", err)
+	}
+
+	var parsed struct {
+		Result string `json:"result"`
+	}
+	if json.Unmarshal(out, &parsed) != nil {
+		// Not the expected JSON envelope — fall back to the raw text.
+		return &ai.ChatResponse{Content: strings.TrimSpace(string(out))}, nil
+	}
+
+	return &ai.ChatResponse{Content: parsed.Result}, nil
+}
+
+// ChatStream runs `claude -p <prompt> --output-format stream-json` and
+// forwards text fragments to callback as they arrive. The callback is
+// always invoked one final time with Done (or Error) set; CLI failures
+// after startup are reported through the callback, not the return value.
+func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
+	if err := p.Validate(); err != nil {
+		return err
+	}
+
+	prompt := buildPrompt(req.Messages)
+	args := []string{"-p", prompt, "--output-format", "stream-json", "--verbose", "--include-partial-messages", "--no-session-persistence"}
+	if p.config.Model != "" {
+		args = append(args, "--model", p.config.Model)
+	}
+
+	fmt.Printf("[ClaudeCLI DEBUG] Running: claude %v\n", args)
+
+	cmd := exec.CommandContext(ctx, "claude", args...)
+	p.setEnv(cmd)
+
+	// Close stdin so the CLI never blocks waiting for interactive input.
+	cmd.Stdin = nil
+
+	stdout, err := cmd.StdoutPipe()
+	if err != nil {
+		return fmt.Errorf("创建 stdout 管道失败: %w", err)
+	}
+
+	// Capture stderr for diagnostics in case the process exits abnormally.
+	var stderrBuf bytes.Buffer
+	cmd.Stderr = &stderrBuf
+
+	if err := cmd.Start(); err != nil {
+		return fmt.Errorf("启动 claude CLI 失败: %w", err)
+	}
+
+	fmt.Printf("[ClaudeCLI DEBUG] Process started, PID: %d\n", cmd.Process.Pid)
+
+	// Tell the UI immediately that the model is thinking, so CLI startup
+	// latency is not mistaken for a hang.
+	callback(ai.StreamChunk{Content: "💭 *正在思考...*\n\n"})
+
+	// Read the stream-json output, one JSON object per line.
+	scanner := bufio.NewScanner(stdout)
+	scanner.Buffer(make([]byte, 64*1024), 1024*1024)
+
+	for scanner.Scan() {
+		line := scanner.Text()
+		if strings.TrimSpace(line) == "" {
+			continue
+		}
+
+		// NOTE(review): `min` requires Go 1.21+; the 200-byte cut may also
+		// split a multi-byte rune in the debug output — cosmetic only.
+		fmt.Printf("[ClaudeCLI DEBUG] Line: %s\n", line[:min(len(line), 200)])
+
+		var event cliStreamEvent
+		if err := json.Unmarshal([]byte(line), &event); err != nil {
+			fmt.Printf("[ClaudeCLI DEBUG] Non-JSON line: %s\n", line)
+			continue
+		}
+
+		switch event.Type {
+		case "assistant":
+			// Full assistant message — emit its text blocks.
+			// NOTE(review): with --include-partial-messages the deltas below
+			// may cover the same text; confirm the CLI does not double-send.
+			if event.Message.Content != nil {
+				for _, block := range event.Message.Content {
+					if block.Type == "text" && block.Text != "" {
+						callback(ai.StreamChunk{Content: block.Text})
+					}
+				}
+			}
+		case "content_block_delta":
+			// Incremental text fragment.
+			if event.Delta.Text != "" {
+				callback(ai.StreamChunk{Content: event.Delta.Text})
+			}
+		case "result":
+			// Final result event — text was already delivered via the
+			// assistant events, so only signal completion here.
+			callback(ai.StreamChunk{Done: true})
+			_ = cmd.Wait()
+			return nil
+		case "error":
+			callback(ai.StreamChunk{Error: event.Error.Message, Done: true})
+			_ = cmd.Wait()
+			return nil
+		}
+	}
+
+	waitErr := cmd.Wait()
+	stderrStr := strings.TrimSpace(stderrBuf.String())
+	fmt.Printf("[ClaudeCLI DEBUG] Process exited. stderr: %s\n", stderrStr)
+
+	if waitErr != nil {
+		// Prefer the CLI's own stderr text over the generic exit error.
+		errMsg := fmt.Sprintf("claude CLI 异常退出: %v", waitErr)
+		if stderrStr != "" {
+			errMsg = fmt.Sprintf("claude CLI 异常退出: %s", stderrStr)
+		}
+		callback(ai.StreamChunk{Error: errMsg, Done: true})
+		return nil
+	}
+
+	callback(ai.StreamChunk{Done: true})
+	return nil
+}
+
+// setEnv injects the Anthropic endpoint and API key into the CLI's
+// environment, on top of the inherited process environment.
+func (p *ClaudeCLIProvider) setEnv(cmd *exec.Cmd) {
+	vars := cmd.Environ()
+	if p.config.BaseURL != "" {
+		vars = append(vars, "ANTHROPIC_BASE_URL="+strings.TrimRight(p.config.BaseURL, "/"))
+	}
+	if p.config.APIKey != "" {
+		vars = append(vars, "ANTHROPIC_API_KEY="+p.config.APIKey)
+	}
+	cmd.Env = vars
+}
+
+// buildPrompt flattens a message list into a single prompt string for
+// `claude -p`. A lone message is passed through verbatim; otherwise
+// system and prior assistant turns are labelled so the CLI can tell
+// them apart from the user text.
+func buildPrompt(messages []ai.Message) string {
+	if len(messages) == 1 {
+		return messages[0].Content
+	}
+
+	var b strings.Builder
+	for _, msg := range messages {
+		switch msg.Role {
+		case "system":
+			b.WriteString("[System]\n" + msg.Content + "\n\n")
+		case "user":
+			b.WriteString(msg.Content + "\n\n")
+		case "assistant":
+			b.WriteString("[Previous Assistant Response]\n" + msg.Content + "\n\n")
+		}
+	}
+	return strings.TrimSpace(b.String())
+}
+
+// cliStreamEvent Claude CLI stream-json 输出的事件结构
+type cliStreamEvent struct {
+ Type string `json:"type"`
+ Message struct {
+ Content []struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ } `json:"content"`
+ } `json:"message,omitempty"`
+ Delta struct {
+ Text string `json:"text"`
+ } `json:"delta,omitempty"`
+ Result string `json:"result,omitempty"`
+ Error struct {
+ Message string `json:"message"`
+ } `json:"error,omitempty"`
+}
diff --git a/internal/ai/provider/custom.go b/internal/ai/provider/custom.go
new file mode 100644
index 0000000..7900ec2
--- /dev/null
+++ b/internal/ai/provider/custom.go
@@ -0,0 +1,74 @@
+package provider
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "GoNavi-Wails/internal/ai"
+)
+
+// CustomProvider 自定义 Provider,根据 apiFormat 选择底层协议
+// 支持 openai / anthropic / gemini 三种 API 格式
+type CustomProvider struct {
+ inner Provider
+ name string
+}
+
+// NewCustomProvider 创建自定义 Provider 实例
+func NewCustomProvider(config ai.ProviderConfig) (Provider, error) {
+ if strings.TrimSpace(config.BaseURL) == "" {
+ return nil, fmt.Errorf("自定义 Provider 必须指定 Base URL")
+ }
+
+ // 根据 apiFormat 决定使用哪个底层协议,默认 openai
+ apiFormat := strings.ToLower(strings.TrimSpace(config.APIFormat))
+ if apiFormat == "" {
+ apiFormat = "openai"
+ }
+
+ var innerProvider Provider
+ var err error
+ switch apiFormat {
+ case "anthropic":
+ innerProvider, err = NewAnthropicProvider(config)
+ case "gemini":
+ innerProvider, err = NewGeminiProvider(config)
+ case "claude-cli":
+ innerProvider, err = NewClaudeCLIProvider(config)
+ default: // "openai" 及其他
+ innerProvider, err = NewOpenAIProvider(config)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ name := strings.TrimSpace(config.Name)
+ if name == "" {
+ name = "Custom"
+ }
+
+ return &CustomProvider{
+ inner: innerProvider,
+ name: name,
+ }, nil
+}
+
+func (p *CustomProvider) Name() string {
+ return p.name
+}
+
+// Validate: for custom providers the API key is intentionally optional
+// (some local services require none), so no eager checks are performed
+// here; the inner protocol provider validates on each request. The
+// previous no-op type-assertion check was dead code and is removed.
+func (p *CustomProvider) Validate() error {
+	return nil
+}
+
+func (p *CustomProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
+ return p.inner.Chat(ctx, req)
+}
+
+func (p *CustomProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
+ return p.inner.ChatStream(ctx, req, callback)
+}
diff --git a/internal/ai/provider/gemini.go b/internal/ai/provider/gemini.go
new file mode 100644
index 0000000..0c5eee7
--- /dev/null
+++ b/internal/ai/provider/gemini.go
@@ -0,0 +1,267 @@
+package provider
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+
+ "GoNavi-Wails/internal/ai"
+)
+
+const (
+ defaultGeminiBaseURL = "https://generativelanguage.googleapis.com"
+ defaultGeminiModel = "gemini-2.0-flash"
+)
+
+// GeminiProvider 实现 Google Gemini API 的 Provider
+type GeminiProvider struct {
+ config ai.ProviderConfig
+ baseURL string
+ client *http.Client
+}
+
+// NewGeminiProvider 创建 Gemini Provider 实例
+func NewGeminiProvider(config ai.ProviderConfig) (Provider, error) {
+ baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
+ if baseURL == "" {
+ baseURL = defaultGeminiBaseURL
+ }
+ model := strings.TrimSpace(config.Model)
+ if model == "" {
+ model = defaultGeminiModel
+ }
+ maxTokens := config.MaxTokens
+ if maxTokens <= 0 {
+ maxTokens = defaultOpenAIMaxTokens
+ }
+ temperature := config.Temperature
+ if temperature <= 0 {
+ temperature = defaultOpenAITemperature
+ }
+
+ normalized := config
+ normalized.BaseURL = baseURL
+ normalized.Model = model
+ normalized.MaxTokens = maxTokens
+ normalized.Temperature = temperature
+
+ return &GeminiProvider{
+ config: normalized,
+ baseURL: baseURL,
+ client: &http.Client{Timeout: openAIHTTPTimeout},
+ }, nil
+}
+
+func (p *GeminiProvider) Name() string {
+ if strings.TrimSpace(p.config.Name) != "" {
+ return p.config.Name
+ }
+ return "Gemini"
+}
+
+func (p *GeminiProvider) Validate() error {
+ if strings.TrimSpace(p.config.APIKey) == "" {
+ return fmt.Errorf("API Key 不能为空")
+ }
+ return nil
+}
+
+type geminiRequest struct {
+ Contents []geminiContent `json:"contents"`
+ SystemInstruction *geminiContent `json:"systemInstruction,omitempty"`
+ GenerationConfig geminiGenConfig `json:"generationConfig,omitempty"`
+}
+
+type geminiContent struct {
+ Role string `json:"role,omitempty"`
+ Parts []geminiPart `json:"parts"`
+}
+
+type geminiPart struct {
+ Text string `json:"text"`
+}
+
+type geminiGenConfig struct {
+ Temperature float64 `json:"temperature,omitempty"`
+ MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
+}
+
+type geminiResponse struct {
+ Candidates []struct {
+ Content struct {
+ Parts []struct {
+ Text string `json:"text"`
+ } `json:"parts"`
+ } `json:"content"`
+ } `json:"candidates"`
+ UsageMetadata *struct {
+ PromptTokenCount int `json:"promptTokenCount"`
+ CandidatesTokenCount int `json:"candidatesTokenCount"`
+ TotalTokenCount int `json:"totalTokenCount"`
+ } `json:"usageMetadata"`
+ Error *struct {
+ Message string `json:"message"`
+ } `json:"error,omitempty"`
+}
+
+func (p *GeminiProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
+ if err := p.Validate(); err != nil {
+ return nil, err
+ }
+
+ geminiReq := p.buildRequest(req)
+
+ url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
+ p.baseURL, p.config.Model, p.config.APIKey)
+
+ respBody, err := p.doRequest(ctx, url, geminiReq)
+ if err != nil {
+ return nil, err
+ }
+ defer respBody.Close()
+
+ var result geminiResponse
+ if err := json.NewDecoder(respBody).Decode(&result); err != nil {
+ return nil, fmt.Errorf("解析 Gemini 响应失败: %w", err)
+ }
+ if result.Error != nil && result.Error.Message != "" {
+ return nil, fmt.Errorf("Gemini API 错误: %s", result.Error.Message)
+ }
+ if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
+ return nil, fmt.Errorf("Gemini 返回空响应")
+ }
+
+ var tokens ai.TokenUsage
+ if result.UsageMetadata != nil {
+ tokens = ai.TokenUsage{
+ PromptTokens: result.UsageMetadata.PromptTokenCount,
+ CompletionTokens: result.UsageMetadata.CandidatesTokenCount,
+ TotalTokens: result.UsageMetadata.TotalTokenCount,
+ }
+ }
+
+ var textParts []string
+ for _, part := range result.Candidates[0].Content.Parts {
+ if part.Text != "" {
+ textParts = append(textParts, part.Text)
+ }
+ }
+
+ return &ai.ChatResponse{
+ Content: strings.Join(textParts, ""),
+ TokensUsed: tokens,
+ }, nil
+}
+
+// ChatStream streams the response over SSE, invoking callback once per
+// text fragment and finally with Done (or Error) set.
+//
+// Fixes vs. the original: the scanner buffer is enlarged (long SSE data
+// lines would otherwise abort the stream with bufio.ErrTooLong — the
+// OpenAI implementation in this package already does this), and API
+// error objects embedded in the stream are surfaced instead of dropped.
+func (p *GeminiProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
+	if err := p.Validate(); err != nil {
+		return err
+	}
+
+	geminiReq := p.buildRequest(req)
+
+	url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
+		p.baseURL, p.config.Model, p.config.APIKey)
+
+	respBody, err := p.doRequest(ctx, url, geminiReq)
+	if err != nil {
+		return err
+	}
+	defer respBody.Close()
+
+	scanner := bufio.NewScanner(respBody)
+	// Long SSE data lines exceed the default 64K token limit.
+	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+	for scanner.Scan() {
+		line := scanner.Text()
+		if !strings.HasPrefix(line, "data: ") {
+			continue
+		}
+		data := strings.TrimPrefix(line, "data: ")
+
+		var chunk geminiResponse
+		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+			continue // skip malformed lines
+		}
+
+		// Surface API-level errors embedded in the stream.
+		if chunk.Error != nil && chunk.Error.Message != "" {
+			callback(ai.StreamChunk{Error: fmt.Sprintf("Gemini API 错误: %s", chunk.Error.Message), Done: true})
+			return nil
+		}
+
+		if len(chunk.Candidates) > 0 {
+			for _, part := range chunk.Candidates[0].Content.Parts {
+				if part.Text != "" {
+					callback(ai.StreamChunk{Content: part.Text})
+				}
+			}
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		return fmt.Errorf("读取 Gemini 流式响应失败: %w", err)
+	}
+
+	callback(ai.StreamChunk{Done: true})
+	return nil
+}
+
+// buildRequest converts a generic ChatRequest into the Gemini wire
+// format. System messages are lifted into systemInstruction — multiple
+// system messages are concatenated (the original kept only the last
+// one, inconsistent with extractSystemMessage in the Anthropic path) —
+// and the "assistant" role is mapped to Gemini's "model" role.
+func (p *GeminiProvider) buildRequest(req ai.ChatRequest) geminiRequest {
+	temperature := req.Temperature
+	if temperature <= 0 {
+		temperature = p.config.Temperature
+	}
+	maxTokens := req.MaxTokens
+	if maxTokens <= 0 {
+		maxTokens = p.config.MaxTokens
+	}
+
+	var systemParts []string
+	var contents []geminiContent
+
+	for _, m := range req.Messages {
+		if m.Role == "system" {
+			systemParts = append(systemParts, m.Content)
+			continue
+		}
+		role := m.Role
+		if role == "assistant" {
+			// Gemini uses "model" for assistant turns.
+			role = "model"
+		}
+		contents = append(contents, geminiContent{
+			Role:  role,
+			Parts: []geminiPart{{Text: m.Content}},
+		})
+	}
+
+	var systemInstruction *geminiContent
+	if len(systemParts) > 0 {
+		systemInstruction = &geminiContent{
+			Parts: []geminiPart{{Text: strings.Join(systemParts, "\n\n")}},
+		}
+	}
+
+	return geminiRequest{
+		Contents:          contents,
+		SystemInstruction: systemInstruction,
+		GenerationConfig: geminiGenConfig{
+			Temperature:     temperature,
+			MaxOutputTokens: maxTokens,
+		},
+	}
+}
+
+func (p *GeminiProvider) doRequest(ctx context.Context, url string, body interface{}) (io.ReadCloser, error) {
+ jsonBody, err := json.Marshal(body)
+ if err != nil {
+ return nil, fmt.Errorf("序列化请求失败: %w", err)
+ }
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
+ if err != nil {
+ return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
+ }
+ httpReq.Header.Set("Content-Type", "application/json")
+
+ resp, err := p.client.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("发送请求到 Gemini 失败: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ defer resp.Body.Close()
+ bodyBytes, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("Gemini API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
+ }
+
+ return resp.Body, nil
+}
diff --git a/internal/ai/provider/openai.go b/internal/ai/provider/openai.go
new file mode 100644
index 0000000..ff674a9
--- /dev/null
+++ b/internal/ai/provider/openai.go
@@ -0,0 +1,316 @@
+package provider
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "GoNavi-Wails/internal/ai"
+)
+
+const (
+ defaultOpenAIBaseURL = "https://api.openai.com/v1"
+ defaultOpenAIModel = "gpt-4o"
+ defaultOpenAIMaxTokens = 4096
+ defaultOpenAITemperature = 0.7
+ openAIHTTPTimeout = 120 * time.Second
+)
+
+// OpenAIProvider 实现 OpenAI / OpenAI 兼容 API 的 Provider
+type OpenAIProvider struct {
+ config ai.ProviderConfig
+ baseURL string
+ client *http.Client
+}
+
+// NewOpenAIProvider 创建 OpenAI Provider 实例
+func NewOpenAIProvider(config ai.ProviderConfig) (Provider, error) {
+ baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
+ if baseURL == "" {
+ baseURL = defaultOpenAIBaseURL
+ }
+ // 确保 baseURL 包含 /v1 路径(兼容用户只填域名的情况,如 https://anyrouter.top)
+ if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
+ baseURL = baseURL + "/v1"
+ }
+ model := strings.TrimSpace(config.Model)
+ if model == "" {
+ model = defaultOpenAIModel
+ }
+ maxTokens := config.MaxTokens
+ if maxTokens <= 0 {
+ maxTokens = defaultOpenAIMaxTokens
+ }
+ temperature := config.Temperature
+ if temperature <= 0 {
+ temperature = defaultOpenAITemperature
+ }
+
+ normalized := config
+ normalized.BaseURL = baseURL
+ normalized.Model = model
+ normalized.MaxTokens = maxTokens
+ normalized.Temperature = temperature
+
+ return &OpenAIProvider{
+ config: normalized,
+ baseURL: baseURL,
+ client: &http.Client{
+ Timeout: openAIHTTPTimeout,
+ },
+ }, nil
+}
+
+func (p *OpenAIProvider) Name() string {
+ if strings.TrimSpace(p.config.Name) != "" {
+ return p.config.Name
+ }
+ return "OpenAI"
+}
+
+func (p *OpenAIProvider) Validate() error {
+ if strings.TrimSpace(p.config.APIKey) == "" {
+ return fmt.Errorf("API Key 不能为空")
+ }
+ return nil
+}
+
+// openAIChatRequest OpenAI API 请求体
+type openAIChatRequest struct {
+ Model string `json:"model"`
+ Messages []openAIChatMessage `json:"messages"`
+ Temperature float64 `json:"temperature,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+}
+
+type openAIChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+// openAIChatResponse OpenAI API 响应体
+type openAIChatResponse struct {
+ Choices []struct {
+ Message struct {
+ Content string `json:"content"`
+ } `json:"message"`
+ FinishReason string `json:"finish_reason"`
+ } `json:"choices"`
+ Usage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ } `json:"usage"`
+ Error *struct {
+ Message string `json:"message"`
+ } `json:"error,omitempty"`
+}
+
+// openAIStreamChunk SSE 流式响应片段
+type openAIStreamChunk struct {
+ Choices []struct {
+ Delta struct {
+ Content string `json:"content"`
+ } `json:"delta"`
+ FinishReason *string `json:"finish_reason"`
+ } `json:"choices"`
+ Error *struct {
+ Message string `json:"message"`
+ } `json:"error,omitempty"`
+}
+
+func (p *OpenAIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error) {
+ if err := p.Validate(); err != nil {
+ return nil, err
+ }
+
+ messages := make([]openAIChatMessage, len(req.Messages))
+ for i, m := range req.Messages {
+ messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
+ }
+
+ temperature := req.Temperature
+ if temperature <= 0 {
+ temperature = p.config.Temperature
+ }
+ maxTokens := req.MaxTokens
+ if maxTokens <= 0 {
+ maxTokens = p.config.MaxTokens
+ }
+
+ body := openAIChatRequest{
+ Model: p.config.Model,
+ Messages: messages,
+ Temperature: temperature,
+ MaxTokens: maxTokens,
+ Stream: false,
+ }
+
+ respBody, err := p.doRequest(ctx, body)
+ if err != nil {
+ return nil, err
+ }
+ defer respBody.Close()
+
+ var result openAIChatResponse
+ if err := json.NewDecoder(respBody).Decode(&result); err != nil {
+ return nil, fmt.Errorf("解析 OpenAI 响应失败: %w", err)
+ }
+ if result.Error != nil && result.Error.Message != "" {
+ return nil, fmt.Errorf("OpenAI API 错误: %s", result.Error.Message)
+ }
+ if len(result.Choices) == 0 {
+ return nil, fmt.Errorf("OpenAI 返回空响应")
+ }
+
+ return &ai.ChatResponse{
+ Content: result.Choices[0].Message.Content,
+ TokensUsed: ai.TokenUsage{
+ PromptTokens: result.Usage.PromptTokens,
+ CompletionTokens: result.Usage.CompletionTokens,
+ TotalTokens: result.Usage.TotalTokens,
+ },
+ }, nil
+}
+
+func (p *OpenAIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error {
+ if err := p.Validate(); err != nil {
+ return err
+ }
+
+ messages := make([]openAIChatMessage, len(req.Messages))
+ for i, m := range req.Messages {
+ messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
+ }
+
+ temperature := req.Temperature
+ if temperature <= 0 {
+ temperature = p.config.Temperature
+ }
+ maxTokens := req.MaxTokens
+ if maxTokens <= 0 {
+ maxTokens = p.config.MaxTokens
+ }
+
+ body := openAIChatRequest{
+ Model: p.config.Model,
+ Messages: messages,
+ Temperature: temperature,
+ MaxTokens: maxTokens,
+ Stream: true,
+ }
+
+ respBody, err := p.doRequest(ctx, body)
+ if err != nil {
+ return err
+ }
+ defer respBody.Close()
+
+ receivedContent := false
+ scanner := bufio.NewScanner(respBody)
+ // 增大 scanner buffer,防止长行被截断
+ scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if line == "" {
+ continue
+ }
+ if !strings.HasPrefix(line, "data: ") {
+ // 非 SSE 数据行,可能是错误信息,记录日志
+ if strings.Contains(line, "error") || strings.Contains(line, "Error") {
+ callback(ai.StreamChunk{Error: fmt.Sprintf("服务端返回异常: %s", line), Done: true})
+ return nil
+ }
+ continue
+ }
+ data := strings.TrimPrefix(line, "data: ")
+ if data == "[DONE]" {
+ callback(ai.StreamChunk{Done: true})
+ return nil
+ }
+
+ var chunk openAIStreamChunk
+ if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+ continue // 跳过格式异常的行
+ }
+ if chunk.Error != nil && chunk.Error.Message != "" {
+ callback(ai.StreamChunk{Error: fmt.Sprintf("API 错误: %s", chunk.Error.Message), Done: true})
+ return nil
+ }
+ if len(chunk.Choices) > 0 {
+ content := chunk.Choices[0].Delta.Content
+ if content != "" {
+ receivedContent = true
+ callback(ai.StreamChunk{Content: content})
+ }
+ if chunk.Choices[0].FinishReason != nil {
+ callback(ai.StreamChunk{Done: true})
+ return nil
+ }
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ return fmt.Errorf("读取 OpenAI 流式响应失败: %w", err)
+ }
+
+ // 如果流正常结束但没有收到任何内容,可能是 API 响应格式不兼容
+ if !receivedContent {
+ callback(ai.StreamChunk{Error: "未收到任何有效响应内容,请检查 API 端点和模型是否正确", Done: true})
+ return nil
+ }
+
+ callback(ai.StreamChunk{Done: true})
+ return nil
+}
+
+func (p *OpenAIProvider) doRequest(ctx context.Context, body interface{}) (io.ReadCloser, error) {
+ jsonBody, err := json.Marshal(body)
+ if err != nil {
+ return nil, fmt.Errorf("序列化请求失败: %w", err)
+ }
+
+ url := p.baseURL + "/chat/completions"
+
+ // 调试日志
+ bodyStr := string(jsonBody)
+ if len(bodyStr) > 500 {
+ bodyStr = bodyStr[:500] + "..."
+ }
+ fmt.Printf("[OpenAI DEBUG] URL: %s\n", url)
+ fmt.Printf("[OpenAI DEBUG] BaseURL: %s\n", p.baseURL)
+ fmt.Printf("[OpenAI DEBUG] Body: %s\n", bodyStr)
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
+ if err != nil {
+ return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
+ }
+
+ httpReq.Header.Set("Content-Type", "application/json")
+ httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
+
+ // 自定义 headers(用于兼容各类 OpenAI 兼容服务)
+ for k, v := range p.config.Headers {
+ httpReq.Header.Set(k, v)
+ }
+
+ resp, err := p.client.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("发送请求到 %s 失败: %w", url, err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ defer resp.Body.Close()
+ bodyBytes, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("OpenAI API 返回错误 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
+ }
+
+ return resp.Body, nil
+}
diff --git a/internal/ai/provider/openai_test.go b/internal/ai/provider/openai_test.go
new file mode 100644
index 0000000..94671a4
--- /dev/null
+++ b/internal/ai/provider/openai_test.go
@@ -0,0 +1,86 @@
+package provider
+
+import (
+ "GoNavi-Wails/internal/ai"
+ "testing"
+)
+
+func TestOpenAIProvider_Validate_MissingAPIKey(t *testing.T) {
+ p, err := NewOpenAIProvider(ai.ProviderConfig{Type: "openai", Model: "gpt-4o"})
+ if err != nil {
+ t.Fatalf("unexpected constructor error: %v", err)
+ }
+ if err := p.Validate(); err == nil {
+ t.Fatal("expected validation error for missing API key")
+ }
+}
+
+func TestOpenAIProvider_Validate_Valid(t *testing.T) {
+ p, err := NewOpenAIProvider(ai.ProviderConfig{
+ Type: "openai", APIKey: "sk-test-key", Model: "gpt-4o",
+ })
+ if err != nil {
+ t.Fatalf("unexpected constructor error: %v", err)
+ }
+ if err := p.Validate(); err != nil {
+ t.Fatalf("unexpected validation error: %v", err)
+ }
+}
+
+func TestOpenAIProvider_Name_Custom(t *testing.T) {
+ p, _ := NewOpenAIProvider(ai.ProviderConfig{
+ Type: "openai", Name: "My OpenAI", APIKey: "sk-test",
+ })
+ if p.Name() != "My OpenAI" {
+ t.Fatalf("expected name 'My OpenAI', got '%s'", p.Name())
+ }
+}
+
+func TestOpenAIProvider_Name_Default(t *testing.T) {
+ p, _ := NewOpenAIProvider(ai.ProviderConfig{
+ Type: "openai", APIKey: "sk-test",
+ })
+ if p.Name() != "OpenAI" {
+ t.Fatalf("expected default name 'OpenAI', got '%s'", p.Name())
+ }
+}
+
+func TestOpenAIProvider_DefaultBaseURL(t *testing.T) {
+ p, _ := NewOpenAIProvider(ai.ProviderConfig{
+ Type: "openai", APIKey: "sk-test", Model: "gpt-4o",
+ })
+ op := p.(*OpenAIProvider)
+ if op.baseURL != "https://api.openai.com/v1" {
+ t.Fatalf("expected default base URL, got '%s'", op.baseURL)
+ }
+}
+
+func TestOpenAIProvider_CustomBaseURL(t *testing.T) {
+ p, _ := NewOpenAIProvider(ai.ProviderConfig{
+ Type: "openai", APIKey: "sk-test", BaseURL: "https://my-proxy.com/v1",
+ })
+ op := p.(*OpenAIProvider)
+ if op.baseURL != "https://my-proxy.com/v1" {
+ t.Fatalf("expected custom base URL, got '%s'", op.baseURL)
+ }
+}
+
+func TestOpenAIProvider_DefaultModel(t *testing.T) {
+ p, _ := NewOpenAIProvider(ai.ProviderConfig{
+ Type: "openai", APIKey: "sk-test",
+ })
+ op := p.(*OpenAIProvider)
+ if op.config.Model != "gpt-4o" {
+ t.Fatalf("expected default model 'gpt-4o', got '%s'", op.config.Model)
+ }
+}
+
+func TestOpenAIProvider_DefaultMaxTokens(t *testing.T) {
+ p, _ := NewOpenAIProvider(ai.ProviderConfig{
+ Type: "openai", APIKey: "sk-test",
+ })
+ op := p.(*OpenAIProvider)
+ if op.config.MaxTokens != 4096 {
+ t.Fatalf("expected default max tokens 4096, got %d", op.config.MaxTokens)
+ }
+}
diff --git a/internal/ai/provider/provider.go b/internal/ai/provider/provider.go
new file mode 100644
index 0000000..e9f1d8e
--- /dev/null
+++ b/internal/ai/provider/provider.go
@@ -0,0 +1,19 @@
+package provider
+
+import (
+ "context"
+
+ "GoNavi-Wails/internal/ai"
+)
+
+// Provider is the interface implemented by every AI model backend.
+type Provider interface {
+	// Chat sends the request and returns the complete response.
+	Chat(ctx context.Context, req ai.ChatRequest) (*ai.ChatResponse, error)
+	// ChatStream sends the request and delivers the response
+	// incrementally through callback.
+	ChatStream(ctx context.Context, req ai.ChatRequest, callback func(ai.StreamChunk)) error
+	// Name returns the human-readable provider name.
+	Name() string
+	// Validate reports whether the configuration is usable.
+	Validate() error
+}
diff --git a/internal/ai/provider/registry.go b/internal/ai/provider/registry.go
new file mode 100644
index 0000000..2cd0c08
--- /dev/null
+++ b/internal/ai/provider/registry.go
@@ -0,0 +1,25 @@
+package provider
+
+import (
+ "fmt"
+ "strings"
+
+ "GoNavi-Wails/internal/ai"
+)
+
+// NewProvider instantiates the Provider matching config.Type
+// (case-insensitive, whitespace-trimmed). Unknown types yield an error.
+func NewProvider(config ai.ProviderConfig) (Provider, error) {
+	constructors := map[string]func(ai.ProviderConfig) (Provider, error){
+		"openai":    NewOpenAIProvider,
+		"anthropic": NewAnthropicProvider,
+		"gemini":    NewGeminiProvider,
+		"custom":    NewCustomProvider,
+	}
+	ctor, ok := constructors[strings.ToLower(strings.TrimSpace(config.Type))]
+	if !ok {
+		return nil, fmt.Errorf("不支持的 AI Provider 类型: %s", config.Type)
+	}
+	return ctor(config)
+}
diff --git a/internal/ai/safety/classifier.go b/internal/ai/safety/classifier.go
new file mode 100644
index 0000000..dfd9816
--- /dev/null
+++ b/internal/ai/safety/classifier.go
@@ -0,0 +1,101 @@
+package safety
+
+import (
+ "strings"
+ "unicode"
+
+ "GoNavi-Wails/internal/ai"
+)
+
+// ClassifySQL reports which broad operation class a SQL statement
+// belongs to, judged by its first keyword (leading comments skipped).
+func ClassifySQL(sql string) ai.SQLOperationType {
+	opByKeyword := map[string]ai.SQLOperationType{
+		"select": ai.SQLOpQuery, "with": ai.SQLOpQuery, "show": ai.SQLOpQuery,
+		"describe": ai.SQLOpQuery, "desc": ai.SQLOpQuery, "explain": ai.SQLOpQuery,
+		"pragma": ai.SQLOpQuery, "values": ai.SQLOpQuery,
+		"insert": ai.SQLOpDML, "update": ai.SQLOpDML, "delete": ai.SQLOpDML,
+		"replace": ai.SQLOpDML, "merge": ai.SQLOpDML, "upsert": ai.SQLOpDML,
+		"create": ai.SQLOpDDL, "alter": ai.SQLOpDDL, "drop": ai.SQLOpDDL,
+		"truncate": ai.SQLOpDDL, "rename": ai.SQLOpDDL,
+	}
+	if op, ok := opByKeyword[leadingSQLKeyword(sql)]; ok {
+		return op
+	}
+	return ai.SQLOpOther
+}
+
+// IsHighRiskSQL flags statements that can cause irreversible or
+// table-wide damage, returning a user-facing warning when it does.
+func IsHighRiskSQL(sql string) (bool, string) {
+	lowered := strings.ToLower(sql)
+
+	switch leadingSQLKeyword(sql) {
+	case "drop":
+		return true, "⚠️ 高危操作:DROP 语句将永久删除数据库对象"
+	case "truncate":
+		return true, "⚠️ 高危操作:TRUNCATE 将清空表中所有数据"
+	case "delete":
+		// A DELETE without WHERE wipes the whole table.
+		if !containsWhereClause(lowered) {
+			return true, "⚠️ 高危操作:DELETE 语句缺少 WHERE 条件,将删除所有数据"
+		}
+	case "update":
+		// An UPDATE without WHERE rewrites every row.
+		if !containsWhereClause(lowered) {
+			return true, "⚠️ 高危操作:UPDATE 语句缺少 WHERE 条件,将更新所有记录"
+		}
+	}
+
+	return false, ""
+}
+
+// containsWhereClause reports whether the (lower-cased) SQL contains a
+// standalone WHERE keyword. The keyword must be preceded by whitespace
+// and followed by whitespace or end-of-string, so quoted literals like
+// "='where'" still do not match (same safe direction as before), while
+// "where" followed by a newline or at the end of the statement is now
+// detected (the old space-delimited check missed e.g. "where\nid=1").
+func containsWhereClause(normalizedSQL string) bool {
+	isSpace := func(b byte) bool {
+		return b == ' ' || b == '\t' || b == '\n' || b == '\r'
+	}
+	for start := 0; ; {
+		i := strings.Index(normalizedSQL[start:], "where")
+		if i < 0 {
+			return false
+		}
+		i += start
+		end := i + len("where")
+		if i > 0 && isSpace(normalizedSQL[i-1]) &&
+			(end == len(normalizedSQL) || isSpace(normalizedSQL[end])) {
+			return true
+		}
+		start = i + 1
+	}
+}
+
+// leadingSQLKeyword 提取 SQL 语句的首个关键字(跳过注释和空白)
+func leadingSQLKeyword(query string) string {
+ text := strings.TrimSpace(query)
+ for len(text) > 0 {
+ trimmed := strings.TrimLeft(text, " \t\r\n")
+ if trimmed == "" {
+ return ""
+ }
+ text = trimmed
+
+ switch {
+ case strings.HasPrefix(text, "--"):
+ if idx := strings.IndexByte(text, '\n'); idx >= 0 {
+ text = text[idx+1:]
+ continue
+ }
+ return ""
+ case strings.HasPrefix(text, "#"):
+ if idx := strings.IndexByte(text, '\n'); idx >= 0 {
+ text = text[idx+1:]
+ continue
+ }
+ return ""
+ case strings.HasPrefix(text, "/*"):
+ if idx := strings.Index(text, "*/"); idx >= 0 {
+ text = text[idx+2:]
+ continue
+ }
+ return ""
+ }
+ break
+ }
+
+ if text == "" {
+ return ""
+ }
+ for i, r := range text {
+ if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' {
+ continue
+ }
+ if i == 0 {
+ return ""
+ }
+ return strings.ToLower(text[:i])
+ }
+ return strings.ToLower(text)
+}
diff --git a/internal/ai/safety/classifier_test.go b/internal/ai/safety/classifier_test.go
new file mode 100644
index 0000000..280b261
--- /dev/null
+++ b/internal/ai/safety/classifier_test.go
@@ -0,0 +1,145 @@
+package safety
+
+import (
+ "GoNavi-Wails/internal/ai"
+ "testing"
+)
+
+func TestClassifySQL(t *testing.T) {
+ tests := []struct {
+ sql string
+ want ai.SQLOperationType
+ }{
+ {"SELECT * FROM users", ai.SQLOpQuery},
+ {" select id from t", ai.SQLOpQuery},
+ {"SHOW TABLES", ai.SQLOpQuery},
+ {"DESCRIBE users", ai.SQLOpQuery},
+ {"DESC users", ai.SQLOpQuery},
+ {"EXPLAIN SELECT 1", ai.SQLOpQuery},
+ {"WITH cte AS (SELECT 1) SELECT * FROM cte", ai.SQLOpQuery},
+ {"PRAGMA table_info(t)", ai.SQLOpQuery},
+ {"VALUES (1, 2)", ai.SQLOpQuery},
+ {"INSERT INTO users VALUES (1)", ai.SQLOpDML},
+ {"UPDATE users SET name='x'", ai.SQLOpDML},
+ {"DELETE FROM users WHERE id=1", ai.SQLOpDML},
+ {"REPLACE INTO users VALUES (1)", ai.SQLOpDML},
+ {"MERGE INTO t USING s ON t.id=s.id", ai.SQLOpDML},
+ {"CREATE TABLE t (id INT)", ai.SQLOpDDL},
+ {"ALTER TABLE t ADD col INT", ai.SQLOpDDL},
+ {"DROP TABLE t", ai.SQLOpDDL},
+ {"TRUNCATE TABLE t", ai.SQLOpDDL},
+ {"RENAME TABLE old TO new", ai.SQLOpDDL},
+ {"/* comment */ SELECT 1", ai.SQLOpQuery},
+ {"-- comment\nDELETE FROM t", ai.SQLOpDML},
+ {"-- line1\n-- line2\nSELECT 1", ai.SQLOpQuery},
+ {"/* block */ -- line\nUPDATE t SET x=1", ai.SQLOpDML},
+ {"", ai.SQLOpOther},
+ {" ", ai.SQLOpOther},
+ {"-- only comment", ai.SQLOpOther},
+ }
+ for _, tt := range tests {
+ got := ClassifySQL(tt.sql)
+ if got != tt.want {
+ t.Errorf("ClassifySQL(%q) = %s, want %s", tt.sql, got, tt.want)
+ }
+ }
+}
+
+func TestIsHighRiskSQL(t *testing.T) {
+ tests := []struct {
+ sql string
+ highRisk bool
+ }{
+ {"DROP TABLE users", true},
+ {"DROP DATABASE test", true},
+ {"TRUNCATE TABLE users", true},
+ {"DELETE FROM users", true}, // 无 WHERE
+ {"DELETE FROM users WHERE id=1", false}, // 有 WHERE
+ {"UPDATE users SET name='x'", true}, // 无 WHERE
+ {"UPDATE users SET name='x' WHERE id=1", false}, // 有 WHERE
+ {"SELECT * FROM users", false},
+ {"INSERT INTO users VALUES (1)", false},
+ }
+ for _, tt := range tests {
+ highRisk, _ := IsHighRiskSQL(tt.sql)
+ if highRisk != tt.highRisk {
+ t.Errorf("IsHighRiskSQL(%q) = %v, want %v", tt.sql, highRisk, tt.highRisk)
+ }
+ }
+}
+
+func TestGuard_ReadOnly(t *testing.T) {
+ g := NewGuard(ai.PermissionReadOnly)
+ tests := []struct {
+ sql string
+ allowed bool
+ }{
+ {"SELECT * FROM t", true},
+ {"INSERT INTO t VALUES (1)", false},
+ {"UPDATE t SET x=1", false},
+ {"DELETE FROM t", false},
+ {"DROP TABLE t", false},
+ {"CREATE TABLE t (id INT)", false},
+ }
+ for _, tt := range tests {
+ result := g.Check(tt.sql)
+ if result.Allowed != tt.allowed {
+ t.Errorf("Guard[readonly].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
+ }
+ }
+}
+
+func TestGuard_ReadWrite(t *testing.T) {
+ g := NewGuard(ai.PermissionReadWrite)
+ tests := []struct {
+ sql string
+ allowed bool
+ confirm bool
+ }{
+ {"SELECT * FROM t", true, false},
+ {"INSERT INTO t VALUES (1)", true, true},
+ {"UPDATE t SET x=1", true, true}, // 允许但需确认
+ {"DELETE FROM t WHERE id=1", true, true}, // 允许但需确认
+ {"DROP TABLE t", false, true}, // DDL 不允许
+ {"CREATE TABLE t (id INT)", false, true},
+ }
+ for _, tt := range tests {
+ result := g.Check(tt.sql)
+ if result.Allowed != tt.allowed {
+ t.Errorf("Guard[readwrite].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
+ }
+ if result.RequiresConfirm != tt.confirm {
+ t.Errorf("Guard[readwrite].Check(%q).RequiresConfirm = %v, want %v", tt.sql, result.RequiresConfirm, tt.confirm)
+ }
+ }
+}
+
+func TestGuard_Full(t *testing.T) {
+ g := NewGuard(ai.PermissionFull)
+ tests := []struct {
+ sql string
+ allowed bool
+ }{
+ {"SELECT * FROM t", true},
+ {"INSERT INTO t VALUES (1)", true},
+ {"DROP TABLE t", true},
+ {"CREATE TABLE t (id INT)", true},
+ }
+ for _, tt := range tests {
+ result := g.Check(tt.sql)
+ if result.Allowed != tt.allowed {
+ t.Errorf("Guard[full].Check(%q).Allowed = %v, want %v", tt.sql, result.Allowed, tt.allowed)
+ }
+ }
+}
+
+func TestGuard_HighRiskWarning(t *testing.T) {
+ g := NewGuard(ai.PermissionFull)
+ result := g.Check("DELETE FROM users")
+ if result.WarningMessage == "" {
+ t.Error("expected high-risk warning for DELETE without WHERE")
+ }
+ if !result.RequiresConfirm {
+ t.Error("expected RequiresConfirm for high-risk SQL")
+ }
+}
diff --git a/internal/ai/safety/guard.go b/internal/ai/safety/guard.go
new file mode 100644
index 0000000..ca31bf2
--- /dev/null
+++ b/internal/ai/safety/guard.go
@@ -0,0 +1,71 @@
+package safety
+
+import (
+ "GoNavi-Wails/internal/ai"
+)
+
+// Guard AI SQL 安全策略守卫
+type Guard struct {
+ permissionLevel ai.SQLPermissionLevel
+}
+
+// NewGuard 创建安全策略守卫
+func NewGuard(level ai.SQLPermissionLevel) *Guard {
+ return &Guard{permissionLevel: level}
+}
+
+// SetPermissionLevel 设置权限级别
+func (g *Guard) SetPermissionLevel(level ai.SQLPermissionLevel) {
+ g.permissionLevel = level
+}
+
+// GetPermissionLevel 获取当前权限级别
+func (g *Guard) GetPermissionLevel() ai.SQLPermissionLevel {
+ return g.permissionLevel
+}
+
+// Check 检查 AI 生成的 SQL 是否在允许范围内
+func (g *Guard) Check(sql string) ai.SafetyResult {
+ opType := ClassifySQL(sql)
+ allowed := g.isAllowed(opType)
+ requiresConfirm := g.requiresConfirmation(opType)
+ warningMessage := ""
+
+ if isHighRisk, msg := IsHighRiskSQL(sql); isHighRisk {
+ warningMessage = msg
+ requiresConfirm = true
+ }
+
+ return ai.SafetyResult{
+ Allowed: allowed,
+ OperationType: opType,
+ RequiresConfirm: requiresConfirm,
+ WarningMessage: warningMessage,
+ }
+}
+
+func (g *Guard) isAllowed(opType ai.SQLOperationType) bool {
+ switch g.permissionLevel {
+ case ai.PermissionReadOnly:
+ return opType == ai.SQLOpQuery
+ case ai.PermissionReadWrite:
+ return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
+ case ai.PermissionFull:
+ return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
+ default:
+ return opType == ai.SQLOpQuery
+ }
+}
+
+func (g *Guard) requiresConfirmation(opType ai.SQLOperationType) bool {
+ switch opType {
+ case ai.SQLOpQuery:
+ return false
+ case ai.SQLOpDML:
+ return true // DML 始终需要确认
+ case ai.SQLOpDDL:
+ return true // DDL 始终需要确认
+ default:
+ return true
+ }
+}
diff --git a/internal/ai/service/service.go b/internal/ai/service/service.go
new file mode 100644
index 0000000..52af6ba
--- /dev/null
+++ b/internal/ai/service/service.go
@@ -0,0 +1,573 @@
+package aiservice
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "GoNavi-Wails/internal/ai"
+ aicontext "GoNavi-Wails/internal/ai/context"
+ "GoNavi-Wails/internal/ai/provider"
+ "GoNavi-Wails/internal/ai/safety"
+ "GoNavi-Wails/internal/logger"
+
+ "github.com/google/uuid"
+ wailsRuntime "github.com/wailsapp/wails/v2/pkg/runtime"
+)
+
+// Service AI 服务,作为 Wails Binding 暴露给前端
+type Service struct {
+ ctx context.Context
+ mu sync.RWMutex
+ providers []ai.ProviderConfig
+ activeProvider string // active provider ID
+ safetyLevel ai.SQLPermissionLevel
+ contextLevel ai.ContextLevel
+ guard *safety.Guard
+ configDir string // 配置存储目录
+ cancelFuncs map[string]context.CancelFunc // 记录每个 session 的 context 取消函数
+}
+
+// NewService 创建 AI Service 实例
+func NewService() *Service {
+ return &Service{
+ providers: make([]ai.ProviderConfig, 0),
+ safetyLevel: ai.PermissionReadOnly,
+ contextLevel: ai.ContextSchemaOnly,
+ guard: safety.NewGuard(ai.PermissionReadOnly),
+ cancelFuncs: make(map[string]context.CancelFunc),
+ }
+}
+
+// Startup Wails 生命周期回调
+func (s *Service) Startup(ctx context.Context) {
+ s.ctx = ctx
+ s.configDir = resolveConfigDir()
+ s.loadConfig()
+ logger.Infof("AI Service 启动完成,已加载 %d 个 Provider", len(s.providers))
+}
+
+// --- Provider 管理 ---
+
+// AIGetProviders 获取所有 Provider 配置
+func (s *Service) AIGetProviders() []ai.ProviderConfig {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ result := make([]ai.ProviderConfig, len(s.providers))
+ copy(result, s.providers)
+ return result
+}
+
+// AISaveProvider 保存/更新 Provider 配置
+func (s *Service) AISaveProvider(config ai.ProviderConfig) error {
+	logger.Infof("AISaveProvider: ID=%s, Model=%s", config.ID, config.Model)
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if strings.TrimSpace(config.ID) == "" {
+ config.ID = "provider-" + uuid.New().String()[:8]
+ }
+
+ found := false
+ for i, p := range s.providers {
+ if p.ID == config.ID {
+ s.providers[i] = config
+ found = true
+ break
+ }
+ }
+ if !found {
+ s.providers = append(s.providers, config)
+ }
+
+ return s.saveConfig()
+}
+
+// AIDeleteProvider 删除 Provider
+func (s *Service) AIDeleteProvider(id string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ newProviders := make([]ai.ProviderConfig, 0, len(s.providers))
+ for _, p := range s.providers {
+ if p.ID != id {
+ newProviders = append(newProviders, p)
+ }
+ }
+ s.providers = newProviders
+
+ if s.activeProvider == id {
+ s.activeProvider = ""
+ if len(s.providers) > 0 {
+ s.activeProvider = s.providers[0].ID
+ }
+ }
+
+ return s.saveConfig()
+}
+
+// AITestProvider 测试 Provider 配置是否可用
+func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{} {
+ // 如果传入脱敏的 key,使用已保存的 key
+ s.mu.RLock()
+ if isMaskedAPIKey(config.APIKey) {
+ for _, p := range s.providers {
+ if p.ID == config.ID {
+ config.APIKey = p.APIKey
+ break
+ }
+ }
+ }
+ s.mu.RUnlock()
+
+ p, err := provider.NewProvider(config)
+ if err != nil {
+ return map[string]interface{}{"success": false, "message": err.Error()}
+ }
+ if err := p.Validate(); err != nil {
+ return map[string]interface{}{"success": false, "message": err.Error()}
+ }
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ resp, err := p.Chat(ctx, ai.ChatRequest{
+ Messages: []ai.Message{
+ {Role: "user", Content: "Hi, please respond with 'OK' to confirm the connection is working."},
+ },
+ MaxTokens: 10,
+ })
+ if err != nil {
+ return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
+ }
+
+ return map[string]interface{}{
+ "success": true,
+ "message": fmt.Sprintf("连接成功!模型响应: %s", truncateString(resp.Content, 100)),
+ }
+}
+
+// AISetActiveProvider 设置活动 Provider
+func (s *Service) AISetActiveProvider(id string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.activeProvider = id
+ _ = s.saveConfig()
+}
+
+// AIGetActiveProvider 获取活动 Provider ID
+func (s *Service) AIGetActiveProvider() string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.activeProvider
+}
+
+// AIGetBuiltinPrompts 返回内置的各类系统提示词,用于前端展示或查询
+func (s *Service) AIGetBuiltinPrompts() map[string]string {
+ return aicontext.GetBuiltinPrompts()
+}
+
+// AIListModels 获取当前活跃 Provider 的可用模型列表
+func (s *Service) AIListModels() map[string]interface{} {
+ s.mu.RLock()
+ var config ai.ProviderConfig
+ found := false
+ for _, p := range s.providers {
+ if p.ID == s.activeProvider {
+ config = p
+ found = true
+ break
+ }
+ }
+ s.mu.RUnlock()
+
+ if !found {
+ return map[string]interface{}{"success": false, "models": []string{}, "error": "未找到活跃 Provider"}
+ }
+
+ models, err := fetchModels(config)
+ if err != nil {
+ // 回退到配置中的静态模型列表
+ if len(config.Models) > 0 {
+ return map[string]interface{}{"success": true, "models": config.Models, "source": "static"}
+ }
+ return map[string]interface{}{"success": false, "models": []string{}, "error": err.Error()}
+ }
+
+ return map[string]interface{}{"success": true, "models": models, "source": "api"}
+}
+
+// fetchModels 从供应商 API 获取可用模型列表
+func fetchModels(config ai.ProviderConfig) ([]string, error) {
+ providerType := config.Type
+ if providerType == "custom" && config.APIFormat != "" {
+ providerType = config.APIFormat
+ }
+
+ switch providerType {
+ case "openai":
+ return fetchOpenAIModels(config)
+ case "anthropic":
+ // Anthropic 没有公开的 /models 端点,返回硬编码列表
+ return []string{"claude-opus-4-6", "claude-sonnet-4-6"}, nil
+ case "gemini":
+ return fetchGeminiModels(config)
+ default:
+ return fetchOpenAIModels(config)
+ }
+}
+
+// fetchOpenAIModels 获取 OpenAI 兼容 API 的模型列表
+func fetchOpenAIModels(config ai.ProviderConfig) ([]string, error) {
+ baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
+ if baseURL == "" {
+ baseURL = "https://api.openai.com/v1"
+ }
+ // 确保 baseURL 以 /v1 结尾
+ if !strings.HasSuffix(baseURL, "/v1") {
+ baseURL = baseURL + "/v1"
+ }
+
+ req, err := http.NewRequest("GET", baseURL+"/models", nil)
+ if err != nil {
+ return nil, fmt.Errorf("创建请求失败: %w", err)
+ }
+ req.Header.Set("Authorization", "Bearer "+config.APIKey)
+
+ client := &http.Client{Timeout: 15 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("请求模型列表失败: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
+ return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body))
+ }
+
+ var result struct {
+ Data []struct {
+ ID string `json:"id"`
+ } `json:"data"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, fmt.Errorf("解析模型列表失败: %w", err)
+ }
+
+ models := make([]string, 0, len(result.Data))
+ for _, m := range result.Data {
+ models = append(models, m.ID)
+ }
+ return models, nil
+}
+
+// fetchGeminiModels 获取 Gemini API 的模型列表
+func fetchGeminiModels(config ai.ProviderConfig) ([]string, error) {
+ baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
+ if baseURL == "" {
+ baseURL = "https://generativelanguage.googleapis.com"
+ }
+
+ req, err := http.NewRequest("GET", baseURL+"/v1beta/models?key="+config.APIKey, nil)
+ if err != nil {
+ return nil, fmt.Errorf("创建请求失败: %w", err)
+ }
+
+ client := &http.Client{Timeout: 15 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("请求模型列表失败: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
+ return nil, fmt.Errorf("获取模型列表失败 (HTTP %d): %s", resp.StatusCode, string(body))
+ }
+
+ var result struct {
+ Models []struct {
+ Name string `json:"name"` // e.g. "models/gemini-2.5-flash"
+ } `json:"models"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, fmt.Errorf("解析模型列表失败: %w", err)
+ }
+
+ models := make([]string, 0, len(result.Models))
+ for _, m := range result.Models {
+ // 去掉 "models/" 前缀
+ name := m.Name
+ if strings.HasPrefix(name, "models/") {
+ name = strings.TrimPrefix(name, "models/")
+ }
+ models = append(models, name)
+ }
+ return models, nil
+}
+
+// --- 安全控制 ---
+
+// AIGetSafetyLevel 获取当前安全级别
+func (s *Service) AIGetSafetyLevel() string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return string(s.safetyLevel)
+}
+
+// AISetSafetyLevel 设置安全级别
+func (s *Service) AISetSafetyLevel(level string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ switch ai.SQLPermissionLevel(level) {
+ case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
+ s.safetyLevel = ai.SQLPermissionLevel(level)
+ default:
+ s.safetyLevel = ai.PermissionReadOnly
+ }
+ s.guard.SetPermissionLevel(s.safetyLevel)
+ _ = s.saveConfig()
+}
+
+// --- 上下文控制 ---
+
+// AIGetContextLevel 获取上下文传递级别
+func (s *Service) AIGetContextLevel() string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return string(s.contextLevel)
+}
+
+// AISetContextLevel 设置上下文传递级别
+func (s *Service) AISetContextLevel(level string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ switch ai.ContextLevel(level) {
+ case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults:
+ s.contextLevel = ai.ContextLevel(level)
+ default:
+ s.contextLevel = ai.ContextSchemaOnly
+ }
+ _ = s.saveConfig()
+}
+
+// --- AI 对话 ---
+
+// AIChatSend 同步发送 AI 对话(非流式)
+func (s *Service) AIChatSend(messages []map[string]string) map[string]interface{} {
+ p, err := s.getActiveProvider()
+ if err != nil {
+ return map[string]interface{}{"success": false, "error": err.Error()}
+ }
+
+ var aiMessages []ai.Message
+ for _, m := range messages {
+ aiMessages = append(aiMessages, ai.Message{Role: m["role"], Content: m["content"]})
+ }
+
+ resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: aiMessages})
+ if err != nil {
+ return map[string]interface{}{"success": false, "error": err.Error()}
+ }
+
+ return map[string]interface{}{
+ "success": true,
+ "content": resp.Content,
+ "tokensUsed": map[string]int{
+ "promptTokens": resp.TokensUsed.PromptTokens,
+ "completionTokens": resp.TokensUsed.CompletionTokens,
+ "totalTokens": resp.TokensUsed.TotalTokens,
+ },
+ }
+}
+
+// AIChatStream 流式发送 AI 对话(通过 EventsEmit 推送)
+func (s *Service) AIChatStream(sessionID string, messages []map[string]string) {
+ streamCtx, cancel := context.WithCancel(context.Background())
+ s.mu.Lock()
+ s.cancelFuncs[sessionID] = cancel
+ s.mu.Unlock()
+
+ go func() {
+ defer func() {
+ s.mu.Lock()
+ delete(s.cancelFuncs, sessionID)
+ s.mu.Unlock()
+ cancel() // 确保释放
+ }()
+
+ p, err := s.getActiveProvider()
+ if err != nil {
+ wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
+ "error": err.Error(),
+ "done": true,
+ })
+ return
+ }
+
+ var aiMessages []ai.Message
+ for _, m := range messages {
+ aiMessages = append(aiMessages, ai.Message{Role: m["role"], Content: m["content"]})
+ }
+
+ err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: aiMessages}, func(chunk ai.StreamChunk) {
+ wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
+ "content": chunk.Content,
+ "done": chunk.Done,
+ "error": chunk.Error,
+ })
+ })
+
+ // 当 context 被主动 cancel 的时候,不把这个视为向外抛的 error
+ if err != nil && err != context.Canceled {
+ wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
+ "error": err.Error(),
+ "done": true,
+ })
+ }
+ }()
+}
+
+// AIChatCancel 立即终止某个 Session 的流式对话请求
+func (s *Service) AIChatCancel(sessionID string) {
+ s.mu.RLock()
+ cancel, ok := s.cancelFuncs[sessionID]
+ s.mu.RUnlock()
+ if ok && cancel != nil {
+ cancel()
+ }
+}
+
+// AICheckSQL 检查 SQL 的安全性
+func (s *Service) AICheckSQL(sql string) ai.SafetyResult {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.guard.Check(sql)
+}
+
+// --- 内部方法 ---
+
+func (s *Service) getActiveProvider() (provider.Provider, error) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	active := s.activeProvider // 只读快照;不在 RLock 下写入共享状态,避免数据竞争
+	if active == "" && len(s.providers) > 0 {
+		active = s.providers[0].ID
+	}
+	for _, cfg := range s.providers {
+		if cfg.ID == active {
+			return provider.NewProvider(cfg)
+		}
+	}
+
+	return nil, fmt.Errorf("未配置 AI Provider,请先在设置中配置")
+}
+
+// --- 配置持久化 ---
+
+type aiConfig struct {
+ Providers []ai.ProviderConfig `json:"providers"`
+ ActiveProvider string `json:"activeProvider"`
+ SafetyLevel string `json:"safetyLevel"`
+ ContextLevel string `json:"contextLevel"`
+}
+
+func (s *Service) loadConfig() {
+ path := filepath.Join(s.configDir, "ai_config.json")
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return // 首次启动,无配置文件
+ }
+
+ var cfg aiConfig
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ logger.Error(err, "加载 AI 配置失败")
+ return
+ }
+
+ s.providers = cfg.Providers
+ if s.providers == nil {
+ s.providers = make([]ai.ProviderConfig, 0)
+ }
+ s.activeProvider = cfg.ActiveProvider
+
+ switch ai.SQLPermissionLevel(cfg.SafetyLevel) {
+ case ai.PermissionReadOnly, ai.PermissionReadWrite, ai.PermissionFull:
+ s.safetyLevel = ai.SQLPermissionLevel(cfg.SafetyLevel)
+ default:
+ s.safetyLevel = ai.PermissionReadOnly
+ }
+ s.guard.SetPermissionLevel(s.safetyLevel)
+
+ switch ai.ContextLevel(cfg.ContextLevel) {
+ case ai.ContextSchemaOnly, ai.ContextWithSamples, ai.ContextWithResults:
+ s.contextLevel = ai.ContextLevel(cfg.ContextLevel)
+ default:
+ s.contextLevel = ai.ContextSchemaOnly
+ }
+}
+
+func (s *Service) saveConfig() error {
+ cfg := aiConfig{
+ Providers: s.providers,
+ ActiveProvider: s.activeProvider,
+ SafetyLevel: string(s.safetyLevel),
+ ContextLevel: string(s.contextLevel),
+ }
+
+ data, err := json.MarshalIndent(cfg, "", " ")
+ if err != nil {
+ return fmt.Errorf("序列化 AI 配置失败: %w", err)
+ }
+
+ if err := os.MkdirAll(s.configDir, 0o755); err != nil {
+ return fmt.Errorf("创建配置目录失败: %w", err)
+ }
+
+ path := filepath.Join(s.configDir, "ai_config.json")
+ if err := os.WriteFile(path, data, 0o644); err != nil {
+ return fmt.Errorf("写入 AI 配置失败: %w", err)
+ }
+
+ return nil
+}
+
+// --- 工具函数 ---
+
+func resolveConfigDir() string {
+ configDir, err := os.UserConfigDir()
+ if err != nil {
+ configDir = "."
+ }
+ return filepath.Join(configDir, "GoNavi")
+}
+
+func maskAPIKey(apiKey string) string {
+ if len(apiKey) <= 8 {
+ return "****"
+ }
+ return apiKey[:4] + "****" + apiKey[len(apiKey)-4:]
+}
+
+func isMaskedAPIKey(apiKey string) bool {
+ return strings.Contains(apiKey, "****")
+}
+
+func truncateString(s string, maxLen int) string {
+	if r := []rune(s); len(r) > maxLen {
+		return string(r[:maxLen]) + "..." // 按 rune 截断,避免切断多字节 UTF-8 字符
+	}
+	return s
+}
diff --git a/internal/ai/types.go b/internal/ai/types.go
new file mode 100644
index 0000000..eb55a6f
--- /dev/null
+++ b/internal/ai/types.go
@@ -0,0 +1,85 @@
+package ai
+
+// Message 表示一条对话消息
+type Message struct {
+ Role string `json:"role"` // "system" | "user" | "assistant"
+ Content string `json:"content"`
+}
+
+// ChatRequest AI 对话请求
+type ChatRequest struct {
+ Messages []Message `json:"messages"`
+ Temperature float64 `json:"temperature"`
+ MaxTokens int `json:"maxTokens"`
+}
+
+// ChatResponse AI 对话响应
+type ChatResponse struct {
+ Content string `json:"content"`
+ TokensUsed TokenUsage `json:"tokensUsed"`
+}
+
+// TokenUsage token 用量统计
+type TokenUsage struct {
+ PromptTokens int `json:"promptTokens"`
+ CompletionTokens int `json:"completionTokens"`
+ TotalTokens int `json:"totalTokens"`
+}
+
+// StreamChunk 流式响应片段
+type StreamChunk struct {
+ Content string `json:"content"`
+ Done bool `json:"done"`
+ Error string `json:"error,omitempty"`
+}
+
+// ProviderConfig AI Provider 配置
+type ProviderConfig struct {
+ ID string `json:"id"`
+ Type string `json:"type"` // openai | anthropic | gemini | custom
+ Name string `json:"name"`
+ APIKey string `json:"apiKey"`
+ BaseURL string `json:"baseUrl"`
+ Model string `json:"model"`
+ Models []string `json:"models,omitempty"`
+ APIFormat string `json:"apiFormat,omitempty"` // custom 专用: openai | anthropic | gemini
+ Headers map[string]string `json:"headers,omitempty"`
+ MaxTokens int `json:"maxTokens"`
+ Temperature float64 `json:"temperature"`
+}
+
+// SQLPermissionLevel AI SQL 执行权限级别
+type SQLPermissionLevel string
+
+const (
+ PermissionReadOnly SQLPermissionLevel = "readonly"
+ PermissionReadWrite SQLPermissionLevel = "readwrite"
+ PermissionFull SQLPermissionLevel = "full"
+)
+
+// ContextLevel AI 上下文传递级别
+type ContextLevel string
+
+const (
+ ContextSchemaOnly ContextLevel = "schema_only"
+ ContextWithSamples ContextLevel = "with_samples"
+ ContextWithResults ContextLevel = "with_results"
+)
+
+// SQLOperationType SQL 操作类型
+type SQLOperationType string
+
+const (
+ SQLOpQuery SQLOperationType = "query" // SELECT, SHOW, DESCRIBE, EXPLAIN
+ SQLOpDML SQLOperationType = "dml" // INSERT, UPDATE, DELETE
+ SQLOpDDL SQLOperationType = "ddl" // CREATE, ALTER, DROP, TRUNCATE
+ SQLOpOther SQLOperationType = "other"
+)
+
+// SafetyResult 安全检查结果
+type SafetyResult struct {
+ Allowed bool `json:"allowed"`
+ OperationType SQLOperationType `json:"operationType"`
+ RequiresConfirm bool `json:"requiresConfirm"`
+ WarningMessage string `json:"warningMessage,omitempty"`
+}
diff --git a/main.go b/main.go
index 02cedcb..4e3bc59 100644
--- a/main.go
+++ b/main.go
@@ -1,8 +1,10 @@
package main
import (
+ "context"
"embed"
+ aiservice "GoNavi-Wails/internal/ai/service"
"GoNavi-Wails/internal/app"
"GoNavi-Wails/internal/logger"
@@ -19,6 +21,7 @@ var assets embed.FS
func main() {
// Create an instance of the app structure
application := app.NewApp()
+ aiService := aiservice.NewService()
// Create application with options
err := wails.Run(&options.App{
@@ -30,10 +33,14 @@ func main() {
Assets: assets,
},
BackgroundColour: &options.RGBA{R: 0, G: 0, B: 0, A: 0},
- OnStartup: application.Startup,
- OnShutdown: application.Shutdown,
+ OnStartup: func(ctx context.Context) {
+ application.Startup(ctx)
+ aiService.Startup(ctx)
+ },
+ OnShutdown: application.Shutdown,
Bind: []interface{}{
application,
+ aiService,
},
Windows: &windows.Options{
WebviewIsTransparent: true,