+
{title}
@@ -522,13 +523,55 @@ const AISettingsModal: React.FC
= ({ open, onClose, darkMo
);
- const tabItems = [
- { key: 'providers', label:
Provider, children: isEditing ? renderProviderForm() : renderProviderList() },
- { key: 'safety', label:
安全控制, children: renderSafetySettings() },
- { key: 'context', label:
上下文, children: renderContextSettings() },
- { key: 'prompts', label:
内置提示词, children: renderBuiltinPrompts() },
+ const BUILTIN_TOOLS_INFO = [
+ { name: 'get_connections', icon: '🔗', desc: '获取所有可用的数据库连接', detail: '返回连接 ID、名称、类型 (MySQL/PostgreSQL 等) 和 Host 地址。AI 根据返回信息决定优先探索哪个连接。', params: '无参数' },
+ { name: 'get_databases', icon: '🗄️', desc: '获取指定连接下的所有数据库', detail: '传入 connectionId,返回该连接下的数据库/Schema 名称列表。', params: 'connectionId: 连接 ID' },
+ { name: 'get_tables', icon: '📋', desc: '获取指定数据库下的所有表名', detail: '传入 connectionId 和 dbName,返回表名列表。AI 用它来定位用户提到的目标表。', params: 'connectionId, dbName' },
+ { name: 'get_columns', icon: '🔍', desc: '获取指定表的字段结构', detail: '传入 connectionId、dbName 和 tableName,返回每个字段的名称、类型、是否可空、默认值和注释。AI 在生成 SQL 前必须调用此工具确认真实字段名。', params: 'connectionId, dbName, tableName' },
+ { name: 'get_table_ddl', icon: '📝', desc: '获取表的建表语句 (DDL)', detail: '传入 connectionId、dbName 和 tableName,返回完整的 CREATE TABLE 语句,包含字段定义、索引、约束等信息。', params: 'connectionId, dbName, tableName' },
+ { name: 'execute_sql', icon: '▶️', desc: '执行 SQL 查询并返回结果', detail: '传入 connectionId、dbName 和 sql,在目标数据库上执行 SQL 并返回结果(最多 50 行)。受安全级别控制,只读模式下仅允许 SELECT/SHOW/DESCRIBE。', params: 'connectionId, dbName, sql' },
];
+ const renderBuiltinTools = () => (
+
+
+ AI 助手在处理数据库相关问题时,可以自动调用以下内置工具获取真实数据,全程无需人工干预。
+
+
+ 💡 工作流程:get_connections → get_databases → get_tables → get_columns → 生成 SQL
+
+ {BUILTIN_TOOLS_INFO.map(tool => (
+
+
+
{tool.icon}
+
+
+ {tool.name}
+
+
{tool.desc}
+
+
+
+ {tool.detail}
+
+
+
+ 参数:
+
+ {tool.params}
+
+
+
+ ))}
+
+ );
+
const modalShellStyle = {
background: overlayTheme.shellBg, border: overlayTheme.shellBorder,
boxShadow: overlayTheme.shellShadow, backdropFilter: overlayTheme.shellBackdropFilter,
@@ -555,14 +598,64 @@ const AISettingsModal: React.FC
= ({ open, onClose, darkMo
open={open}
onCancel={onClose}
footer={null}
- width={540}
+ width={820}
styles={{
content: modalShellStyle,
- header: { background: 'transparent', borderBottom: 'none', paddingBottom: 4 },
- body: { paddingTop: 0, height: 520, overflowY: 'auto', overflowX: 'hidden' },
+ header: { background: 'transparent', borderBottom: 'none', paddingBottom: 8 },
+ body: { paddingTop: 8, height: 620, overflow: 'hidden' },
}}
>
-
+
+
+
设置导航
+
+ {[
+ { key: 'providers', title: '模型供应商', description: '配置大模型接口与秘钥', icon:
},
+ { key: 'safety', title: '安全控制', description: '限制 AI 操作风险级别', icon:
},
+ { key: 'context', title: '上下文', description: '配置携带的数据架构信息', icon:
},
+ { key: 'tools', title: '内置工具', description: '查看 AI 可调用的数据探针', icon:
},
+ { key: 'prompts', title: '内置提示词', description: '查看系统预设的底层要求', icon:
},
+ ].map((item) => {
+ const active = activeSection === item.key;
+ return (
+
+ );
+ })}
+
+
+
+ {activeSection === 'providers' && (isEditing ? renderProviderForm() : renderProviderList())}
+ {activeSection === 'safety' && renderSafetySettings()}
+ {activeSection === 'context' && renderContextSettings()}
+ {activeSection === 'tools' && renderBuiltinTools()}
+ {activeSection === 'prompts' && renderBuiltinPrompts()}
+
+
);
};
diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx
index 3bdb46e..c50d662 100644
--- a/frontend/src/components/QueryEditor.tsx
+++ b/frontend/src/components/QueryEditor.tsx
@@ -480,7 +480,9 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
contextMenuOrder: 1,
run: (ed: any) => {
const selection = ed.getModel()?.getValueInRange(ed.getSelection());
- let prompt = action.prompt;
+ const conn = connectionsRef.current.find(c => c.id === currentConnectionIdRef.current);
+ const ctxText = conn ? `【上下文环境:${conn.config?.type || '数据库'} "${conn.name}", 当前库选定为 "${currentDbRef.current || '默认'}"】\n` : '';
+ let prompt = ctxText + action.prompt;
if (action.useSelection && selection) {
prompt = prompt.replace('{SQL}', selection);
}
@@ -853,7 +855,92 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
return { suggestions };
}
});
+ // 注册 / 斜杠命令 AI 快捷补全
+ const slashCmdDefs = [
+ { cmd: '/query', label: '🔍 自然语言查询', desc: '用中文描述你想查什么', prompt: '帮我写一条 SQL 查询:' },
+ { cmd: '/sql', label: '📝 生成 SQL', desc: '描述需求自动生成语句', prompt: '请根据以下需求生成 SQL:' },
+ { cmd: '/explain', label: '💡 解释 SQL', desc: '解释选中 SQL 的逻辑', prompt: '请解释以下 SQL 的执行逻辑和每一步的作用:\n```sql\n{SQL}\n```', useSelection: true },
+ { cmd: '/optimize', label: '⚡ 优化分析', desc: '分析 SQL 性能瓶颈', prompt: '请分析以下 SQL 的性能问题,并给出优化后的版本:\n```sql\n{SQL}\n```', useSelection: true },
+ { cmd: '/schema', label: '🏗️ 表设计评审', desc: '评审表结构设计质量', prompt: '请全面评审当前关联表的设计,包括字段类型、范式、索引策略等方面的改进建议:' },
+ { cmd: '/index', label: '📊 索引建议', desc: '推荐最优索引方案', prompt: '请基于当前表结构和常见查询场景,推荐最优的索引方案并给出建表语句:' },
+ { cmd: '/diff', label: '🔄 表对比', desc: '对比两表差异生成变更', prompt: '请对比以下两张表的结构差异,并生成从旧版本迁移到新版本的 ALTER 语句:' },
+ { cmd: '/mock', label: '🎲 造测试数据', desc: '生成 INSERT 测试数据', prompt: '请为当前关联的表生成 10 条符合业务语义的测试数据 INSERT 语句:' },
+ ];
+ // 全局变量存储命令定义,供 onDidChangeModelContent 使用
+ (window as any).__gonaviSlashCmdDefs = slashCmdDefs;
+
+ monaco.languages.registerCompletionItemProvider('sql', {
+ triggerCharacters: ['/'],
+ provideCompletionItems: (model: any, position: any) => {
+ const lineContent = model.getLineContent(position.lineNumber);
+ const textBefore = lineContent.substring(0, position.column - 1).trimStart();
+ if (!textBefore.startsWith('/')) {
+ return { suggestions: [] };
+ }
+
+ const range = {
+ startLineNumber: position.lineNumber,
+ endLineNumber: position.lineNumber,
+ startColumn: position.column - textBefore.length,
+ endColumn: position.column,
+ };
+
+ return {
+ suggestions: slashCmdDefs.map((c, i) => ({
+ label: `${c.cmd} ${c.label}`,
+ kind: monaco.languages.CompletionItemKind.Event,
+ detail: c.desc,
+ insertText: `__AI_${c.cmd.slice(1).toUpperCase()}__`,
+ range,
+ sortText: String(i).padStart(2, '0'),
+ })),
+ };
+ },
+ });
+
} // end sqlCompletionRegistered guard
+
+ // 每个编辑器实例都注册内容变化监听(检测斜杠命令标记)
+ let _handlingSlash = false;
+ editor.onDidChangeModelContent(() => {
+ if (_handlingSlash) return;
+ const model = editor.getModel();
+ if (!model) return;
+ const content = model.getValue();
+ const markerMatch = content.match(/__AI_(\w+)__/);
+ if (!markerMatch) return;
+
+ const cmdKey = markerMatch[1].toLowerCase();
+ const defs = (window as any).__gonaviSlashCmdDefs || [];
+ const cmdDef = defs.find((c: any) => c.cmd === `/${cmdKey}`);
+ if (!cmdDef) return;
+
+ // 清除标记文本(带递归保护)
+ _handlingSlash = true;
+ const fullText = model.getValue();
+ const newText = fullText.replace(markerMatch[0], '').replace(/^\s*\n/, '');
+ model.setValue(newText);
+ _handlingSlash = false;
+
+ // 组装 prompt
+ const conn = connectionsRef.current.find(c => c.id === currentConnectionIdRef.current);
+ const ctxText = conn ? `【上下文环境:${conn.config?.type || '数据库'} "${conn.name}", 当前库选定为 "${currentDbRef.current || '默认'}"】\n` : '';
+ let finalPrompt = ctxText + cmdDef.prompt;
+ if (cmdDef.useSelection) {
+ const sel = editor.getSelection();
+ const selText = sel ? model.getValueInRange(sel) : '';
+ finalPrompt = finalPrompt.replace('{SQL}', selText || getCurrentQuery());
+ }
+
+ // 打开 AI 面板并注入 prompt
+ const store = useStore.getState();
+ if (!store.aiPanelVisible) {
+ store.setAIPanelVisible(true);
+ }
+ setTimeout(() => {
+ window.dispatchEvent(new CustomEvent('gonavi:ai:inject-prompt', { detail: { prompt: finalPrompt } }));
+ }, store.aiPanelVisible ? 0 : 350);
+ });
};
const handleFormat = () => {
@@ -870,11 +957,14 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const selection = editor?.getModel()?.getValueInRange(editor.getSelection()) || '';
const fullSQL = getCurrentQuery();
+ const conn = connections.find(c => c.id === currentConnectionId);
+ const ctxText = conn ? `【上下文环境:${conn.config?.type || '数据库'} "${conn.name}", 当前库选定为 "${currentDb || '默认'}"】\n` : '';
+
const prompts: Record = {
- generate: '请根据当前数据库表结构生成查询语句:',
- explain: `请解释以下 SQL 语句的执行逻辑:\n\`\`\`sql\n${selection || fullSQL}\n\`\`\``,
- optimize: `请分析以下 SQL 语句的性能并给出优化建议:\n\`\`\`sql\n${selection || fullSQL}\n\`\`\``,
- schema: '请分析当前数据库的表结构并给出优化建议。',
+ generate: `${ctxText}请根据当前数据库表结构生成查询语句:`,
+ explain: `${ctxText}请解释以下 SQL 语句的执行逻辑:\n\`\`\`sql\n${selection || fullSQL}\n\`\`\``,
+ optimize: `${ctxText}请分析以下 SQL 语句的性能并给出优化建议:\n\`\`\`sql\n${selection || fullSQL}\n\`\`\``,
+ schema: `${ctxText}请针对当前数据库的表结构进行系统分析,并给出性能和设计上的优化建议。`,
};
const store = useStore.getState();
@@ -1932,41 +2022,73 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
};
}, [activeTabId, tab.id, handleRun]);
- // 监听并处理外部注入的 SQL 代码 (如 AI 面板)
+ // 监听由 TabManager 分发的专用注入事件
useEffect(() => {
- const handleInsertSql = (e: CustomEvent) => {
- if (activeTabId !== tab.id || !e.detail?.sql) return;
- const sqlText = e.detail.sql;
+ const handleInsertSql = (e: any) => {
+ if (e.detail?.tabId !== tab.id || !e.detail?.sql) return;
+ const { sql: sqlText, connectionId, dbName } = e.detail;
+
+ // 同步更新 ref,防止异步 fetchDbs 竞态覆盖正确的 dbName
+ if (connectionId && connectionId !== currentConnectionId) {
+ if (dbName) {
+ currentDbRef.current = dbName;
+ setCurrentDb(dbName);
+ }
+ setCurrentConnectionId(connectionId);
+ } else if (dbName && dbName !== currentDb) {
+ currentDbRef.current = dbName;
+ setCurrentDb(dbName);
+ }
+
const editor = editorRef.current;
- if (editor && (window as any).monaco) {
- const position = editor.getPosition();
+ const monaco = monacoRef.current;
+ if (editor && monaco) {
+ let position = editor.getPosition();
+ const model = editor.getModel();
+ if (!position && model) {
+ const lineCount = model.getLineCount();
+ const maxCol = model.getLineMaxColumn(lineCount);
+ position = new monaco.Position(lineCount, maxCol);
+ }
+
if (position) {
const mText = (sqlText.endsWith('\n') ? sqlText : sqlText + '\n');
- const startRange = new (window as any).monaco.Range(position.lineNumber, position.column, position.lineNumber, position.column);
+ const startRange = new monaco.Range(position.lineNumber, position.column, position.lineNumber, position.column);
editor.executeEdits('ai-insert', [{
range: startRange,
- text: '\n' + mText,
+ text: (position.column > 1 ? '\n' : '') + mText,
forceMoveMarkers: true
}]);
+
+ // 定位并滚动到可见区域
+ const targetLine = position.lineNumber + (position.column > 1 ? 1 : 0);
+ editor.revealLineInCenterIfOutsideViewport(targetLine);
+ editor.setPosition({ lineNumber: targetLine + mText.split('\n').length - 1, column: 1 });
editor.focus();
+
+ if (!e.detail.runImmediately) {
+ message.success('代码已在当前光标处成功插入');
+ }
if (e.detail.runImmediately) {
const endPosition = editor.getPosition();
- editor.setSelection(new (window as any).monaco.Range(
- position.lineNumber + 1, 1,
+ editor.setSelection(new monaco.Range(
+ targetLine, 1,
endPosition.lineNumber, endPosition.column
));
- setTimeout(() => handleRun(), 50);
+ // 🔧 延迟 500ms 等待连接/数据库切换的 setState 生效后再执行
+ setTimeout(() => handleRun(), 500);
}
}
} else {
setQuery((prev: string) => prev ? prev + '\n' + sqlText : sqlText);
+ message.success('代码已追加');
}
};
- window.addEventListener('gonavi:insert-sql', handleInsertSql as EventListener);
- return () => window.removeEventListener('gonavi:insert-sql', handleInsertSql as EventListener);
- }, [activeTabId, tab.id, handleRun]);
+ window.addEventListener('gonavi:insert-sql-to-tab', handleInsertSql as EventListener);
+ return () => window.removeEventListener('gonavi:insert-sql-to-tab', handleInsertSql as EventListener);
+ }, [tab.id, handleRun]);
const resolveDefaultQueryName = () => {
const rawTitle = String(tab.title || '').trim();
diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx
index 09524cd..a15a62f 100644
--- a/frontend/src/components/Sidebar.tsx
+++ b/frontend/src/components/Sidebar.tsx
@@ -1432,7 +1432,7 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
if (type === 'connection') {
setActiveContext({ connectionId: key, dbName: '' });
} else if (type === 'database') {
- setActiveContext({ connectionId: dataRef.id, dbName: title });
+ setActiveContext({ connectionId: dataRef.id, dbName: dataRef.dbName });
} else if (type === 'table') {
setActiveContext({ connectionId: dataRef.id, dbName: dataRef.dbName });
} else if (type === 'view' || type === 'db-trigger' || type === 'routine') {
@@ -1456,9 +1456,9 @@ const Sidebar: React.FC<{ onEditConnection?: (conn: SavedConnection) => void }>
const onDoubleClick = (e: any, node: any) => {
// 保证用户直接双击节点未触发 onClick/onSelect 时也能强行拿到选中状态
- const { type, dataRef, key: nodeKey, title } = node;
+ const { type, dataRef, key: nodeKey } = node;
if (type === 'connection') setActiveContext({ connectionId: nodeKey, dbName: '' });
- else if (type === 'database') setActiveContext({ connectionId: dataRef.id, dbName: title });
+ else if (type === 'database') setActiveContext({ connectionId: dataRef.id, dbName: dataRef.dbName });
else if (type === 'table' || type === 'view' || type === 'db-trigger' || type === 'routine') setActiveContext({ connectionId: dataRef.id, dbName: dataRef.dbName });
else if (type === 'saved-query') setActiveContext({ connectionId: dataRef.connectionId, dbName: dataRef.dbName });
else if (type === 'redis-db') setActiveContext({ connectionId: dataRef.id, dbName: `db${dataRef.redisDB}` });
diff --git a/frontend/src/components/TabManager.tsx b/frontend/src/components/TabManager.tsx
index 8784f4d..9709e83 100644
--- a/frontend/src/components/TabManager.tsx
+++ b/frontend/src/components/TabManager.tsx
@@ -89,6 +89,7 @@ const TabManager: React.FC = () => {
const theme = useStore(state => state.theme);
const activeTabId = useStore(state => state.activeTabId);
const setActiveTab = useStore(state => state.setActiveTab);
+ const addTab = useStore(state => state.addTab);
const closeTab = useStore(state => state.closeTab);
const closeOtherTabs = useStore(state => state.closeOtherTabs);
const closeTabsToLeft = useStore(state => state.closeTabsToLeft);
@@ -134,6 +135,59 @@ const TabManager: React.FC = () => {
setDraggingTabId(null);
};
+ React.useEffect(() => {
+ const handleGlobalInsertSql = (e: any) => {
+ const { sql, runImmediately, connectionId: eventConnId, dbName: eventDbName } = e.detail;
+ if (!sql) return;
+
+ const activeTab = tabs.find(t => t.id === activeTabId);
+
+ // 🔧 runImmediately(点击"执行")始终新建独立 tab,避免追加到已有 tab 导致 SQL 重复
+ if (runImmediately) {
+ const newTabId = 'tab-' + Date.now();
+ const resolvedConnId = eventConnId || activeTab?.connectionId || (connections.length > 0 ? connections[0].id : '');
+ const resolvedDbName = eventConnId ? (eventDbName || '') : (activeTab?.dbName || '');
+ addTab({
+ id: newTabId,
+ type: 'query',
+ title: '新建查询',
+ query: '',
+ connectionId: resolvedConnId,
+ dbName: resolvedDbName
+ });
+ setActiveTab(newTabId);
+ setTimeout(() => {
+ window.dispatchEvent(new CustomEvent('gonavi:insert-sql-to-tab', {
+ detail: { tabId: newTabId, sql, runImmediately: true, connectionId: resolvedConnId, dbName: resolvedDbName }
+ }));
+ }, 300);
+ return;
+ }
+
+ // 插入模式:追加到已有 tab 或新建 tab
+ if (activeTab && activeTab.type === 'query') {
+ window.dispatchEvent(new CustomEvent('gonavi:insert-sql-to-tab', {
+ detail: { tabId: activeTab.id, sql, runImmediately: false, connectionId: eventConnId, dbName: eventDbName }
+ }));
+ } else {
+ const newTabId = 'tab-' + Date.now();
+ const resolvedConnId = eventConnId || activeTab?.connectionId || (connections.length > 0 ? connections[0].id : '');
+ const resolvedDbName = eventConnId ? (eventDbName || '') : (activeTab?.dbName || '');
+ addTab({
+ id: newTabId,
+ type: 'query',
+ title: '新建查询',
+ query: sql,
+ connectionId: resolvedConnId,
+ dbName: resolvedDbName
+ });
+ setActiveTab(newTabId);
+ }
+ };
+ window.addEventListener('gonavi:insert-sql', handleGlobalInsertSql);
+ return () => window.removeEventListener('gonavi:insert-sql', handleGlobalInsertSql);
+ }, [tabs, activeTabId, addTab, setActiveTab, connections]);
+
const tabIds = useMemo(() => tabs.map((tab) => tab.id), [tabs]);
const renderTabBar: TabsProps['renderTabBar'] = (tabBarProps, DefaultTabBar) => (
diff --git a/frontend/src/components/ai/AIChatHeader.tsx b/frontend/src/components/ai/AIChatHeader.tsx
new file mode 100644
index 0000000..1d73d48
--- /dev/null
+++ b/frontend/src/components/ai/AIChatHeader.tsx
@@ -0,0 +1,76 @@
+import React from 'react';
+import { Button, Tooltip } from 'antd';
+import { HistoryOutlined, RobotOutlined, ClearOutlined, SettingOutlined, CloseOutlined, ExportOutlined } from '@ant-design/icons';
+import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
+import type { AIChatMessage } from '../../types';
+
+interface AIChatHeaderProps {
+ darkMode: boolean;
+ mutedColor: string;
+ textColor: string;
+ overlayTheme: OverlayWorkbenchTheme;
+ onHistoryClick: () => void;
+ onClear: () => void;
+ onSettingsClick: () => void;
+ onClose: () => void;
+ messages?: AIChatMessage[];
+ sessionTitle?: string;
+}
+
+const exportToMarkdown = (messages: AIChatMessage[], title: string) => {
+ const lines: string[] = [`# ${title}`, '', `> 导出时间:${new Date().toLocaleString()}`, ''];
+ messages.forEach(msg => {
+ const role = msg.role === 'user' ? '👤 You' : '🤖 GoNavi AI';
+ lines.push(`## ${role}`);
+ lines.push('');
+ lines.push(msg.content);
+ lines.push('');
+ lines.push('---');
+ lines.push('');
+ });
+ const blob = new Blob([lines.join('\n')], { type: 'text/markdown;charset=utf-8' });
+ const url = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ a.href = url;
+ a.download = `${title.replace(/[/\\?%*:|"<>]/g, '-')}.md`;
+ document.body.appendChild(a);
+ a.click();
+ document.body.removeChild(a);
+ URL.revokeObjectURL(url);
+};
+
+export const AIChatHeader: React.FC = ({
+ darkMode, mutedColor, textColor, overlayTheme,
+ onHistoryClick, onClear, onSettingsClick, onClose,
+ messages = [], sessionTitle = '新对话'
+}) => {
+ return (
+
+
+
+ } onClick={onHistoryClick} style={{ color: mutedColor }} />
+
+
+
+
+
GoNavi AI
+
+
+ {messages.length > 0 && (
+
+ } onClick={() => exportToMarkdown(messages, sessionTitle)} style={{ color: mutedColor }} />
+
+ )}
+
+ } onClick={onClear} style={{ color: mutedColor }} />
+
+
+ } onClick={onSettingsClick} style={{ color: mutedColor }} />
+
+
+ } onClick={onClose} style={{ color: mutedColor }} />
+
+
+
+ );
+};
diff --git a/frontend/src/components/ai/AIChatInput.tsx b/frontend/src/components/ai/AIChatInput.tsx
new file mode 100644
index 0000000..ec78375
--- /dev/null
+++ b/frontend/src/components/ai/AIChatInput.tsx
@@ -0,0 +1,574 @@
+import React from 'react';
+import { Input, Select, AutoComplete, Tooltip, Modal, Checkbox, Spin, message, Button, Tag } from 'antd';
+import { DatabaseOutlined, SendOutlined, TableOutlined, SearchOutlined, PictureOutlined } from '@ant-design/icons';
+import { useStore } from '../../store';
+import { DBGetTables, DBShowCreateTable, DBGetDatabases } from '../../../wailsjs/go/app/App';
+import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
+
+interface AIChatInputProps {
+ input: string;
+ setInput: (val: string) => void;
+ draftImages: string[];
+ setDraftImages: React.Dispatch>;
+ sending: boolean;
+ onSend: () => void;
+ onStop: () => void;
+ handleKeyDown: (e: React.KeyboardEvent) => void;
+ activeConnName: string;
+ activeContext: any;
+ activeProvider: any;
+ dynamicModels: string[];
+ loadingModels: boolean;
+ onModelChange: (val: string) => void;
+ onFetchModels: () => void;
+ textareaRef: React.RefObject;
+ darkMode: boolean;
+ textColor: string;
+ mutedColor: string;
+ overlayTheme: OverlayWorkbenchTheme;
+ contextUsageChars?: number;
+ maxContextChars?: number;
+}
+
+export const AIChatInput: React.FC = ({
+ input, setInput, draftImages, setDraftImages, sending, onSend, onStop, handleKeyDown,
+ activeConnName, activeContext, activeProvider, dynamicModels, loadingModels,
+ onModelChange, onFetchModels, textareaRef, darkMode, textColor, mutedColor, overlayTheme,
+ contextUsageChars, maxContextChars
+}) => {
+ const [contextOpen, setContextOpen] = React.useState(false);
+ const [contextLoading, setContextLoading] = React.useState(false);
+ const [contextTables, setContextTables] = React.useState<{name: string}[]>([]);
+ const [selectedTableKeys, setSelectedTableKeys] = React.useState([]);
+ const [searchText, setSearchText] = React.useState('');
+ const [appendingContext, setAppendingContext] = React.useState(false);
+
+ const fileInputRef = React.useRef(null);
+ const handleImageUpload = (e: React.ChangeEvent) => {
+ const files = Array.from(e.target.files || []);
+ files.forEach(file => {
+ if (file.type.indexOf('image') !== -1) {
+ const reader = new FileReader();
+ reader.onload = (event) => {
+ if (event.target?.result) {
+ setDraftImages(prev => [...prev, event.target!.result as string]);
+ }
+ };
+ reader.readAsDataURL(file);
+ }
+ });
+ if (fileInputRef.current) {
+ fileInputRef.current.value = '';
+ }
+ };
+
+ const [dbList, setDbList] = React.useState([]);
+ const [selectedDbName, setSelectedDbName] = React.useState('');
+
+ const filteredTables = contextTables.filter(t => t.name.toLowerCase().includes(searchText.toLowerCase()));
+ const [contextExpanded, setContextExpanded] = React.useState(false);
+
+ // Slash commands
+ const [showSlashMenu, setShowSlashMenu] = React.useState(false);
+ const [slashFilter, setSlashFilter] = React.useState('');
+ const slashCommands = React.useMemo(() => [
+ { cmd: '/query', label: '🔍 自然语言查询', desc: '用中文描述你想查什么', prompt: '帮我写一条 SQL 查询:' },
+ { cmd: '/sql', label: '📝 生成 SQL', desc: '描述需求自动生成语句', prompt: '请根据以下需求生成 SQL:' },
+ { cmd: '/explain', label: '💡 解释 SQL', desc: '解释选中 SQL 的逻辑', prompt: '请解释以下 SQL 的执行逻辑和每一步的作用:\n```sql\n\n```' },
+ { cmd: '/optimize', label: '⚡ 优化分析', desc: '分析 SQL 性能瓶颈', prompt: '请分析以下 SQL 的性能问题,并给出优化后的版本:\n```sql\n\n```' },
+ { cmd: '/schema', label: '🏗️ 表设计评审', desc: '评审表结构设计质量', prompt: '请全面评审当前关联表的设计,包括字段类型、范式、索引策略等方面的改进建议:' },
+ { cmd: '/index', label: '📊 索引建议', desc: '推荐最优索引方案', prompt: '请基于当前表结构和常见查询场景,推荐最优的索引方案并给出建表语句:' },
+ { cmd: '/diff', label: '🔄 表对比', desc: '对比两表差异生成变更', prompt: '请对比以下两张表的结构差异,并生成从旧版本迁移到新版本的 ALTER 语句:' },
+ { cmd: '/mock', label: '🎲 造测试数据', desc: '生成 INSERT 测试数据', prompt: '请为当前关联的表生成 10 条符合业务语义的测试数据 INSERT 语句:' },
+ ], []);
+ const filteredSlashCmds = slashCommands.filter(c => c.cmd.startsWith(slashFilter.toLowerCase()));
+
+ const aiContexts = useStore(state => state.aiContexts);
+ const addAIContext = useStore(state => state.addAIContext);
+ const removeAIContext = useStore(state => state.removeAIContext);
+
+ const connectionKey = activeContext?.connectionId ? `${activeContext.connectionId}:${activeContext.dbName || ''}` : 'default';
+ const activeContextItems = aiContexts[connectionKey] || [];
+
+ const fetchTablesForDb = async (dbName: string, connConfig: any) => {
+ setContextLoading(true);
+ setSelectedDbName(dbName);
+ try {
+ const res = await DBGetTables(connConfig, dbName);
+ if (res.success && Array.isArray(res.data)) {
+ setContextTables(res.data.map(r => ({ name: Object.values(r)[0] as string })));
+ } else {
+ message.error('获取表格失败: ' + res.message);
+ setContextTables([]);
+ }
+ } catch (e: any) {
+ message.error(e.message);
+ setContextTables([]);
+ } finally {
+ setContextLoading(false);
+ }
+ };
+
+ const handleOpenContext = async () => {
+ if (!activeContext?.connectionId) {
+ message.warning('请先在左侧选择一个数据库作为所聊上下文');
+ return;
+ }
+ const conn = useStore.getState().connections.find(c => c.id === activeContext.connectionId);
+ if (!conn) return;
+
+ setContextOpen(true);
+ setContextLoading(true);
+ setSearchText('');
+ // Store dbName::tableName composite keys
+ setSelectedTableKeys(activeContextItems.map(c => `${c.dbName}::${c.tableName}`));
+
+ try {
+ // Fetch databases
+ const dbRes = await DBGetDatabases(conn.config as any);
+ if (dbRes.success && Array.isArray(dbRes.data)) {
+ const databases = dbRes.data.map((r: any) => Object.values(r)[0] as string);
+ setDbList(databases);
+ }
+
+ // Fetch tables for the active contextual database
+ const initDbName = activeContext.dbName || '';
+ setSelectedDbName(initDbName);
+ const tablesRes = await DBGetTables(conn.config as any, initDbName);
+ if (tablesRes.success && Array.isArray(tablesRes.data)) {
+ setContextTables(tablesRes.data.map((r: any) => ({ name: Object.values(r)[0] as string })));
+ } else {
+ setContextTables([]);
+ }
+ } catch (e: any) {
+ message.error(e.message);
+ } finally {
+ setContextLoading(false);
+ }
+ };
+
+ const handleAppendContext = async () => {
+ const conn = useStore.getState().connections.find(c => c.id === activeContext.connectionId);
+ if (!conn) return;
+
+ setAppendingContext(true);
+ try {
+ let addedCount = 0;
+ let removedCount = 0;
+
+ for (const cx of activeContextItems) {
+ const key = `${cx.dbName}::${cx.tableName}`;
+ if (!selectedTableKeys.includes(key)) {
+ removeAIContext(connectionKey, cx.dbName, cx.tableName);
+ removedCount++;
+ }
+ }
+
+ for (const key of selectedTableKeys) {
+ const [dbName, tableName] = key.split('::');
+ if (!dbName || !tableName) continue;
+
+ if (activeContextItems.find(c => c.dbName === dbName && c.tableName === tableName)) {
+ continue;
+ }
+ const res = await DBShowCreateTable(conn.config as any, dbName, tableName);
+ let createSql = '';
+ if (res.success && res.data) {
+ if (typeof res.data === 'string') {
+ createSql = res.data;
+ } else if (Array.isArray(res.data) && res.data.length > 0) {
+ const row = res.data[0];
+ createSql = (Object.values(row).find(v => typeof v === 'string' && (v.toUpperCase().includes('CREATE TABLE') || v.toUpperCase().includes('CREATE'))) || Object.values(row)[1] || Object.values(row)[0]) as string;
+ }
+ } else {
+ message.error(`获取表 ${dbName}.${tableName} 结构失败: ` + (res.message || '未知错误'));
+ }
+
+ if (createSql) {
+ addAIContext(connectionKey, {
+ dbName: dbName,
+ tableName: tableName,
+ ddl: createSql
+ });
+ addedCount++;
+ }
+ }
+ if (addedCount > 0 || removedCount > 0) {
+ if (addedCount > 0 && removedCount === 0) {
+ message.success(`已添加 ${addedCount} 张表的结构到上下文`);
+ } else if (removedCount > 0 && addedCount === 0) {
+ message.success(`已从上下文移除 ${removedCount} 张表的结构`);
+ } else {
+ message.success(`上下文已同步更新:新增 ${addedCount},移除 ${removedCount}`);
+ }
+ if (addedCount > 0) setContextExpanded(true);
+ } else {
+ message.info('选中的表未发生变化');
+ }
+ setContextOpen(false);
+ } catch (e: any) {
+ message.error(e.message);
+ } finally {
+ setAppendingContext(false);
+ }
+ };
+
+ return (
+
+
+
+ {activeContextItems.length > 0 && (
+
setContextExpanded(!contextExpanded)}
+ style={{ background: darkMode ? 'rgba(24, 144, 255, 0.15)' : 'rgba(24, 144, 255, 0.08)', border: 'none', color: '#1890ff', borderRadius: 12, padding: '4px 10px', display: 'flex', alignItems: 'center', gap: 4, margin: 0, cursor: 'pointer', transition: 'all 0.3s' }}
+ >
+
+ 关联上下文 ({activeContextItems.length}) {contextExpanded ? '▴' : '▾'}
+
+
+ )}
+
+ {contextExpanded && activeContextItems.map((ctx, idx) => (
+
{ e.preventDefault(); removeAIContext(connectionKey, ctx.dbName, ctx.tableName); }}
+ style={{ background: darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.04)', border: 'none', color: textColor, borderRadius: 12, padding: '4px 10px', display: 'flex', alignItems: 'center', gap: 4, margin: 0 }}
+ >
+ 🗄️ {ctx.tableName}
+
+ ))}
+ {draftImages.map((b64, i) => (
+
+

+
setDraftImages(prev => prev.filter((_, idx) => idx !== i))}
+ style={{ position: 'absolute', top: 2, right: 2, background: 'rgba(0,0,0,0.5)', color: '#fff', borderRadius: '50%', width: 16, height: 16, display: 'flex', alignItems: 'center', justifyContent: 'center', cursor: 'pointer', fontSize: 10 }}
+ >
+ ✕
+
+
+ ))}
+
+
+ {showSlashMenu && filteredSlashCmds.length > 0 && (
+
+ {filteredSlashCmds.map(cmd => (
+
e.currentTarget.style.background = darkMode ? 'rgba(255,255,255,0.08)' : 'rgba(0,0,0,0.04)'}
+ onMouseLeave={e => e.currentTarget.style.background = 'transparent'}
+ onClick={() => {
+ setInput(cmd.prompt);
+ setShowSlashMenu(false);
+ setSlashFilter('');
+ textareaRef.current?.focus();
+ }}
+ >
+ {cmd.cmd}
+ {cmd.label}
+ {cmd.desc}
+
+ ))}
+
+ )}
+
{
+ const items = e.clipboardData?.items;
+ if (!items) return;
+ for (let i = 0; i < items.length; i++) {
+ if (items[i].type.indexOf('image') !== -1) {
+ e.preventDefault();
+ const blob = items[i].getAsFile();
+ if (blob) {
+ const reader = new FileReader();
+ reader.onload = (event) => {
+ if (event.target?.result) {
+ setDraftImages(prev => [...prev, event.target!.result as string]);
+ }
+ };
+ reader.readAsDataURL(blob);
+ }
+ }
+ }
+ }}
+ ref={textareaRef as any}
+ value={input}
+ onChange={(e) => {
+ const val = e.target.value;
+ setInput(val);
+ // Slash command detection
+ if (val.startsWith('/')) {
+ setSlashFilter(val.split(/\s/)[0]);
+ setShowSlashMenu(true);
+ } else {
+ setShowSlashMenu(false);
+ setSlashFilter('');
+ }
+ }}
+ onKeyDown={handleKeyDown as any}
+ placeholder="输入消息... (Enter 发送,Shift+Enter 换行,/ 快捷命令)"
+ variant="borderless"
+ autoSize={{ minRows: 1, maxRows: 8 }}
+ style={{ color: textColor, width: '100%', padding: 0, resize: 'none' }}
+ />
+
+
+
+ {activeConnName && (
+
+
+
+
+ {activeConnName}{activeContext?.dbName ? ` / ${activeContext.dbName}` : ''}
+
+
+
+ )}
+
+ {activeProvider && (
+
+
+
+
+
+ }
+ onClick={() => fileInputRef.current?.click()}
+ style={{ color: overlayTheme.mutedText, border: 'none', background: 'transparent', padding: '0 4px', height: 26 }}
+ onMouseEnter={e => e.currentTarget.style.color = textColor}
+ onMouseLeave={e => e.currentTarget.style.color = overlayTheme.mutedText}
+ />
+
+
+ }
+ onClick={handleOpenContext}
+ style={{ color: overlayTheme.mutedText, border: 'none', background: 'transparent', padding: '0 4px', height: 26 }}
+ onMouseEnter={e => e.currentTarget.style.color = textColor}
+ onMouseLeave={e => e.currentTarget.style.color = overlayTheme.mutedText}
+ />
+
+ {sending ? (
+
+ ) : (
+
+ )}
+
+
+
+
+
关联数据库表结构上下文}
+ open={contextOpen}
+ onCancel={() => setContextOpen(false)}
+ onOk={handleAppendContext}
+ confirmLoading={appendingContext}
+ okText="同步所选表至上下文"
+ cancelText="取消"
+ centered
+ styles={{
+ content: { background: darkMode ? '#1e1e1e' : '#ffffff', border: overlayTheme.shellBorder },
+ header: { background: darkMode ? '#1e1e1e' : '#ffffff', borderBottom: overlayTheme.shellBorder },
+ body: { padding: '20px 24px' }
+ }}
+ >
+
+
+ {dbList.length > 0 && (
+
+ {filteredTables.length > 0 ? (
+
+
+ 0 &&
+ filteredTables.some(t => selectedTableKeys.includes(`${selectedDbName}::${t.name}`)) &&
+ !filteredTables.every(t => selectedTableKeys.includes(`${selectedDbName}::${t.name}`))
+ }
+ checked={filteredTables.length > 0 && filteredTables.every(t => selectedTableKeys.includes(`${selectedDbName}::${t.name}`))}
+ onChange={(e) => {
+ if (e.target.checked) {
+ const newSelected = new Set([...selectedTableKeys, ...filteredTables.map(t => `${selectedDbName}::${t.name}`)]);
+ setSelectedTableKeys(Array.from(newSelected));
+ } else {
+ const filteredKeys = filteredTables.map(t => `${selectedDbName}::${t.name}`);
+ setSelectedTableKeys(selectedTableKeys.filter(key => !filteredKeys.includes(key)));
+ }
+ }}
+ style={{ color: textColor, fontWeight: 'bold' }}
+ >
+ 全选匹配的表 ({filteredTables.length})
+
+
+
+
+
+ {filteredTables.map(t => {
+ const key = `${selectedDbName}::${t.name}`;
+ const isSelected = selectedTableKeys.includes(key);
+ return (
+
e.currentTarget.style.background = darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.03)'}
+ onMouseLeave={e => e.currentTarget.style.background = 'transparent'}
+ onClick={(e) => {
+ // If click originated from the checkbox input itself, let its onChange handle it to avoid duplicate toggle
+ if ((e.target as HTMLElement).tagName.toLowerCase() === 'input') return;
+ if (isSelected) {
+ setSelectedTableKeys(selectedTableKeys.filter(k => k !== key));
+ } else {
+ setSelectedTableKeys([...selectedTableKeys, key]);
+ }
+ }}
+ >
+ {
+ if (e.target.checked) setSelectedTableKeys([...selectedTableKeys, key]);
+ else setSelectedTableKeys(selectedTableKeys.filter(k => k !== key));
+ }}
+ style={{ color: textColor, width: '100%' }}
+ >
+ {t.name}
+
+
+ );
+ })}
+
+
+
+ ) : (
+
+ 没有找到匹配 '{searchText}' 的表
+
+ )}
+
+
+
+ );
+};
diff --git a/frontend/src/components/ai/AIChatWelcome.tsx b/frontend/src/components/ai/AIChatWelcome.tsx
new file mode 100644
index 0000000..1e6040a
--- /dev/null
+++ b/frontend/src/components/ai/AIChatWelcome.tsx
@@ -0,0 +1,64 @@
+import React from 'react';
+import { RobotOutlined } from '@ant-design/icons';
+import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
+
+interface AIChatWelcomeProps {
+ overlayTheme: OverlayWorkbenchTheme;
+ quickActionBg: string;
+ quickActionBorder: string;
+ textColor: string;
+ mutedColor: string;
+ onQuickAction: (prompt: string, autoSend?: boolean) => void;
+ contextTableNames?: string[];
+}
+
+export const AIChatWelcome: React.FC = ({
+ overlayTheme, quickActionBg, quickActionBorder, textColor, mutedColor, onQuickAction, contextTableNames = []
+}) => {
+ const hasContext = contextTableNames.length > 0;
+ const tableList = contextTableNames.join('、');
+
+ const quickActions = hasContext
+ ? [
+ { label: '📝 生成 SQL', prompt: `请根据以下表结构生成一条常用查询语句:${tableList}` },
+ { label: '🔍 解释表结构', prompt: `请详细解释以下表的设计意图和字段含义:${tableList}` },
+ { label: '⚡ 优化建议', prompt: `请分析以下表的结构设计,给出索引优化和查询性能优化建议:${tableList}` },
+ { label: '🏗️ Schema 分析', prompt: `请对以下表进行全面的 Schema 分析,包括数据类型选择、范式评估和改进建议:${tableList}` },
+ ]
+ : [
+ { label: '📝 生成 SQL', prompt: '请根据当前数据库表结构生成一条查询语句:' },
+ { label: '🔍 解释 SQL', prompt: '请解释以下 SQL 语句的执行逻辑:\n```sql\n\n```' },
+ { label: '⚡ 优化建议', prompt: '请分析以下 SQL 语句的性能并给出优化建议:\n```sql\n\n```' },
+ { label: '🏗️ Schema 分析', prompt: '请分析当前数据库的表结构并给出优化建议。' },
+ ];
+
+ return (
+
+
+
+ 你好,我是 GoNavi AI
+
+
+ {hasContext
+ ? `已自动关联 ${contextTableNames.length} 张表结构,点击下方按钮快速开始分析。`
+ : '我是你的智能数据库助手。我可以帮你生成 SQL 查询、分析表结构、解释执行逻辑以及优化数据库性能。'}
+
+
+ {quickActions.map(action => (
+
onQuickAction(action.prompt)}
+ >
+ {action.label}
+
+ ))}
+
+
+ );
+};
diff --git a/frontend/src/components/ai/AIHistoryDrawer.tsx b/frontend/src/components/ai/AIHistoryDrawer.tsx
new file mode 100644
index 0000000..520d288
--- /dev/null
+++ b/frontend/src/components/ai/AIHistoryDrawer.tsx
@@ -0,0 +1,127 @@
+import React, { useState } from 'react';
+import { Drawer, Button, Tooltip, Input } from 'antd';
+import { MenuFoldOutlined, PlusOutlined, DeleteOutlined, SearchOutlined } from '@ant-design/icons';
+import { useStore } from '../../store';
+
+interface AIHistoryDrawerProps {
+ open: boolean;
+ onClose: () => void;
+ bgColor?: string;
+ darkMode: boolean;
+ textColor: string;
+ mutedColor: string;
+ borderColor: string;
+ onCreateNew: () => void;
+ sessionId: string;
+}
+
+export const AIHistoryDrawer: React.FC = ({
+ open, onClose, bgColor, darkMode, textColor, mutedColor, borderColor, onCreateNew, sessionId
+}) => {
+ const aiChatSessions = useStore(state => state.aiChatSessions);
+ const setAIActiveSessionId = useStore(state => state.setAIActiveSessionId);
+ const deleteAISession = useStore(state => state.deleteAISession);
+
+ // 阶段4: 历史记录搜索
+ const [searchText, setSearchText] = useState('');
+
+ const filteredSessions = aiChatSessions.filter(s =>
+ !searchText || (s.title && s.title.toLowerCase().includes(searchText.toLowerCase()))
+ );
+
+ return (
+
+ {/* 侧拉面板头部 */}
+
+ 对话历史
+
+ } onClick={onClose} style={{ color: mutedColor }} />
+
+
+
+ {/* 新建对话按钮 */}
+
+ }
+ onClick={() => { onCreateNew(); onClose(); }}
+ style={{ borderColor: borderColor, color: textColor, background: 'transparent' }}
+ >
+ 开启新对话
+
+
+
+ {/* 列表搜索 */}
+
+ }
+ value={searchText}
+ onChange={e => setSearchText(e.target.value)}
+ variant="filled"
+ size="small"
+ style={{ background: darkMode ? 'rgba(255,255,255,0.04)' : 'transparent', color: textColor }}
+ />
+
+
+ {/* 列表容器 */}
+
+ {filteredSessions.length === 0 ? (
+
暂无匹配的对话记录
+ ) : (
+ filteredSessions.map(session => (
+
{ setAIActiveSessionId(session.id); onClose(); }}
+ style={{
+ padding: '10px 12px',
+ borderRadius: 6,
+ marginBottom: 4,
+ cursor: 'pointer',
+ display: 'flex',
+ justifyContent: 'space-between',
+ alignItems: 'center',
+ background: sessionId === session.id ? (darkMode ? 'rgba(255,255,255,0.06)' : 'rgba(0,0,0,0.04)') : 'transparent',
+ transition: 'background 0.2s',
+ }}
+ >
+
+
+ {session.title || '新对话'}
+
+
+ {new Date(session.updatedAt).toLocaleString(undefined, { month: 'numeric', day: 'numeric', hour: '2-digit', minute: '2-digit' })}
+
+
+
+ }
+ onClick={(e) => {
+ e.stopPropagation();
+ deleteAISession(session.id);
+ }}
+ style={{ display: sessionId === session.id ? 'inline-flex' : undefined }}
+ />
+
+
+ ))
+ )}
+
+
+ );
+};
diff --git a/frontend/src/components/ai/AIMessageBubble.tsx b/frontend/src/components/ai/AIMessageBubble.tsx
new file mode 100644
index 0000000..453d1b6
--- /dev/null
+++ b/frontend/src/components/ai/AIMessageBubble.tsx
@@ -0,0 +1,714 @@
+import React, { useState, useEffect, useRef } from 'react';
+import { Tooltip, message } from 'antd';
+import { UserOutlined, RobotOutlined, EditOutlined, ReloadOutlined, DeleteOutlined, CheckOutlined, CopyOutlined, PlayCircleOutlined, ApiOutlined, LoadingOutlined, CaretRightOutlined, CaretDownOutlined } from '@ant-design/icons';
+import ReactMarkdown from 'react-markdown';
+import remarkGfm from 'remark-gfm';
+import mermaid from 'mermaid';
+import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
+import { vscDarkPlus, vs } from 'react-syntax-highlighter/dist/esm/styles/prism';
+import { AIChatMessage, AIToolCall } from '../../types';
+import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
+
+// 🔧 性能优化:将 ReactMarkdown 包装为 Memo 组件并提取固定的 plugins
+const remarkPlugins = [remarkGfm];
+
+const MemoizedMarkdown = React.memo(({
+ content,
+ darkMode,
+ overlayTheme,
+ activeConnectionConfig,
+ activeConnectionId,
+ activeDbName
+}: {
+ content: string;
+ darkMode: boolean;
+ overlayTheme: OverlayWorkbenchTheme;
+ activeConnectionConfig?: any;
+ activeConnectionId?: string;
+ activeDbName?: string;
+}) => {
+ // 缓存 components 对象,避免每次渲染都生成新的函数引用击穿内部子组件的 memo
+ const components = React.useMemo(() => ({
+ code({ node, inline, className, children, ...props }: any) {
+ const match = /language-(\w+)/.exec(className || '');
+ if (!inline && match && match[1] === 'mermaid') {
+ return ;
+ }
+ return !inline && match ? (
+
+ ) : (
+
+ {children}
+
+ );
+ }
+ }), [darkMode, overlayTheme, activeConnectionConfig, activeConnectionId, activeDbName]);
+
+ return (
+
+ {content}
+
+ );
+});
+
+interface AIMessageBubbleProps {
+ msg: AIChatMessage;
+ darkMode: boolean;
+ overlayTheme: OverlayWorkbenchTheme;
+ textColor: string;
+ onEdit: (msg: AIChatMessage) => void;
+ onRetry: (msg: AIChatMessage) => void;
+ onDelete: (id: string) => void;
+ activeConnectionId?: string;
+ activeConnectionConfig?: any;
+ activeDbName?: string;
+ allMessages?: AIChatMessage[];
+}
+
+const AIToolResultItem: React.FC<{ resultMsg: AIChatMessage, darkMode: boolean, overlayTheme: OverlayWorkbenchTheme }> = ({ resultMsg, darkMode, overlayTheme }) => {
+ const [toolExpanded, setToolExpanded] = useState(false);
+ const charCount = resultMsg.content ? resultMsg.content.length : 0;
+ return (
+
+
setToolExpanded(!toolExpanded)}
+ >
+ {toolExpanded ?
:
}
+
+
探针执行结果 ({resultMsg.tool_name || 'unknown'})
+
{charCount > 0 ? `${charCount} 个字符` : '无数据'}
+
+ {toolExpanded && (
+
+ {resultMsg.content}
+
+ )}
+
+ );
+};
+
+const MermaidRenderer = ({ chart, darkMode }: { chart: string, darkMode: boolean }) => {
+ const containerRef = React.useRef(null);
+
+ React.useEffect(() => {
+ if (containerRef.current) {
+ try {
+ mermaid.initialize({ startOnLoad: false, theme: darkMode ? 'dark' : 'default' });
+ const id = `mermaid-${Math.random().toString(36).substring(2)}`;
+ (async () => {
+ const result: any = await mermaid.render(id, chart);
+ if (containerRef.current) {
+ containerRef.current.innerHTML = result.svg || result;
+ }
+ })().catch((e: any) => {
+ if (containerRef.current) {
+ containerRef.current.innerHTML = `Mermaid 解析失败: ${e.message}
`;
+ }
+ });
+ } catch (e: any) {
+ if (containerRef.current) {
+ containerRef.current.innerHTML = `Mermaid 渲染异常: ${e.message}
`;
+ }
+ }
+ }
+ }, [chart, darkMode]);
+
+ return ;
+};
+
+const CodeCopyBtn = ({ text }: { text: string }) => {
+ const [copied, setCopied] = useState(false);
+ return (
+ {
+ navigator.clipboard.writeText(text);
+ setCopied(true);
+ setTimeout(() => setCopied(false), 2000);
+ }}
+ style={{
+ cursor: 'pointer',
+ display: 'flex',
+ alignItems: 'center',
+ opacity: copied ? 1 : 0.6,
+ transition: 'opacity 0.2s',
+ }}
+ onMouseEnter={(e) => { e.currentTarget.style.opacity = '1'; }}
+ onMouseLeave={(e) => { e.currentTarget.style.opacity = copied ? '1' : '0.6'; }}
+ >
+ {copied ? : }
+ {copied ? '已复制' : '复制代码'}
+
+ );
+};
+
+const CodeRunBtn = ({ text, connectionId, dbName }: { text: string; connectionId?: string; dbName?: string }) => {
+ // 解析 SQL 顶部的 @context 注释,格式:-- @context connectionId=xxx dbName=yyy
+ const contextMatch = text.match(/^--\s*@context\s+connectionId=(\S+)\s+dbName=(\S+)/m);
+ const resolvedConnId = contextMatch?.[1] || connectionId;
+ const resolvedDbName = contextMatch?.[2] || dbName;
+ // 发送给查询编辑器时去掉 @context 注释行
+ const cleanSql = text.replace(/^--\s*@context\s+.*\n?/gm, '').trim();
+ const sqlDetail = (runImmediately: boolean) => ({ sql: cleanSql, runImmediately, connectionId: resolvedConnId, dbName: resolvedDbName });
+ const handleExecute = async () => {
+ try {
+ const Service = (window as any).go?.aiservice?.Service;
+ if (Service?.AICheckSQL) {
+ const result = await Service.AICheckSQL(text);
+ if (!result.allowed) {
+ message.error(`🔒 安全策略拦截:当前安全级别不允许执行 ${result.operationType} 类型的 SQL。请在 AI 设置中调整安全级别。`);
+ return;
+ }
+ if (result.requiresConfirm) {
+ const { Modal } = await import('antd');
+ Modal.confirm({
+ title: '⚠️ 安全确认',
+ content: result.warningMessage || `此 SQL 为 ${result.operationType} 操作,确定要执行吗?`,
+ okText: '确认执行',
+ cancelText: '取消',
+ okButtonProps: { danger: true },
+ onOk: () => {
+ window.dispatchEvent(new CustomEvent('gonavi:insert-sql', { detail: sqlDetail(true) }));
+ },
+ });
+ return;
+ }
+ }
+ // Safety check passed or not available, execute directly
+ window.dispatchEvent(new CustomEvent('gonavi:insert-sql', { detail: sqlDetail(true) }));
+ } catch (e) {
+ // If safety check fails, still allow manual execution
+ window.dispatchEvent(new CustomEvent('gonavi:insert-sql', { detail: sqlDetail(true) }));
+ }
+ };
+
+ return (
+
+
+ {
+ window.dispatchEvent(new CustomEvent('gonavi:insert-sql', { detail: sqlDetail(false) }));
+ }}
+ style={{
+ cursor: 'pointer', display: 'flex', alignItems: 'center',
+ opacity: 0.6, transition: 'opacity 0.2s', padding: '0 4px', color: '#10b981'
+ }}
+ onMouseEnter={(e) => { e.currentTarget.style.opacity = '1'; }}
+ onMouseLeave={(e) => { e.currentTarget.style.opacity = '0.6'; }}
+ >
+
+ 插入
+
+
+
+ { e.currentTarget.style.opacity = '1'; }}
+ onMouseLeave={(e) => { e.currentTarget.style.opacity = '0.6'; }}
+ >
+
+ 执行
+
+
+
+ );
+};
+
+// 阶段2: 代码块体验升级 (折叠展开、行号显示、内联SQL预览)
+const AIBlockHashRender = ({ match, darkMode, overlayTheme, children, activeConnectionConfig, activeConnectionId, activeDbName }: any) => {
+ const codeText = String(children).replace(/\n$/, '');
+ // 将 @context 注释行从显示文本中剔除,用户无需看到内部元数据
+ const displayText = codeText.replace(/^--\s*@context\s+.*\n?/gm, '').trim();
+ const [expanded, setExpanded] = useState(false);
+ const [previewData, setPreviewData] = useState(null);
+ const [previewCols, setPreviewCols] = useState([]);
+ const [previewLoading, setPreviewLoading] = useState(false);
+ const [previewError, setPreviewError] = useState('');
+ const [previewExpanded, setPreviewExpanded] = useState(false);
+
+ const MAX_HEIGHT = 300;
+ const isLongCode = displayText.split('\n').length > 15;
+ const isSql = match[1] === 'sql';
+ const isSelectQuery = isSql && /^\s*(SELECT|SHOW|DESCRIBE|DESC|EXPLAIN)\b/i.test(displayText.trim());
+
+ const handleInlineExecute = async () => {
+ if (!activeConnectionConfig || previewLoading) return;
+ setPreviewLoading(true);
+ setPreviewError('');
+ setPreviewData(null);
+ try {
+ const { DBQuery } = await import('../../../wailsjs/go/app/App');
+ const res = await DBQuery(activeConnectionConfig, activeDbName || '', displayText + ' LIMIT 50');
+ if (res.success && Array.isArray(res.data)) {
+ const rows = res.data as any[];
+ const cols = rows.length > 0 ? Object.keys(rows[0]) : [];
+ setPreviewCols(cols);
+ setPreviewData(rows.slice(0, 20));
+ setPreviewExpanded(true);
+ } else {
+ setPreviewError(res.message || '查询无结果');
+ }
+ } catch (err: any) {
+ setPreviewError(err?.message || '执行失败');
+ } finally {
+ setPreviewLoading(false);
+ }
+ };
+
+ return (
+
+
+
{match[1]}
+
+ {isSql && }
+ {isSelectQuery && activeConnectionConfig && (
+
+ { if (!previewLoading) e.currentTarget.style.opacity = '1'; }}
+ onMouseLeave={(e) => { if (!previewLoading) e.currentTarget.style.opacity = '0.6'; }}
+ >
+ {previewLoading ? '⏳' : '👁'}
+ {previewLoading ? '执行中...' : '预览'}
+
+
+ )}
+
+
+
+
+
+
+ {displayText}
+
+
+ {!expanded && isLongCode && (
+
setExpanded(true)}
+ >
+
+ 展开全部代码
+
+
+ )}
+ {expanded && isLongCode && (
+
setExpanded(false)}
+ >
+ 收起代码
+
+ )}
+
+
+ {/* Inline SQL Preview Results */}
+ {previewError && (
+
+ ❌ {previewError}
+
+ )}
+ {previewExpanded && previewData && previewData.length > 0 && (
+
+
+ 📊 预览结果({previewData.length} 行 × {previewCols.length} 列)
+ setPreviewExpanded(false)}>收起 ▴
+
+
+
+
+
+ {previewCols.map(col => (
+ |
+ {col}
+ |
+ ))}
+
+
+
+ {previewData.map((row, ri) => (
+
+ {previewCols.map(col => (
+ |
+ {row[col] === null ? NULL : String(row[col])}
+ |
+ ))}
+
+ ))}
+
+
+
+
+ )}
+ {!previewExpanded && previewData && previewData.length > 0 && (
+
setPreviewExpanded(true)}
+ >
+ 📊 查看结果({previewData.length} 行)▾
+
+ )}
+
+ );
+};
+
+// 可折叠思考过程组件
+const ThinkingBlock: React.FC<{ displayThinking: string; totalLen: number; isTyping: boolean; isGlobalLoading: boolean; darkMode: boolean; overlayTheme: any; hasContent: boolean }> = ({ displayThinking, totalLen, isTyping, isGlobalLoading, darkMode, overlayTheme, hasContent }) => {
+ // 如果整体在loading,且尚未吐出content,我们认为真正的思考还在进行;如果吐出content了,思考框就算告一段落
+ const isActivelyThinking = isGlobalLoading && !hasContent;
+ const [expanded, setExpanded] = useState(isActivelyThinking);
+ const contentRef = React.useRef(null);
+
+ React.useEffect(() => { if (isActivelyThinking) setExpanded(true); }, [isActivelyThinking]);
+
+ // 断开连接或思考结束时,若已有内容且不再产生新内容则默认收起
+ React.useEffect(() => {
+ if (!isGlobalLoading) setExpanded(false);
+ }, [isGlobalLoading]);
+
+ // 自动滚动到思考内容底部
+ React.useEffect(() => {
+ if (expanded && isTyping && contentRef.current) {
+ contentRef.current.scrollTop = contentRef.current.scrollHeight;
+ }
+ }, [displayThinking, expanded, isTyping]);
+
+ return (
+
+
setExpanded(e => !e)}
+ style={{
+ display: 'flex', alignItems: 'center', gap: 6,
+ padding: '6px 10px', cursor: 'pointer',
+ background: darkMode ? 'rgba(255,255,255,0.04)' : 'rgba(0,0,0,0.02)',
+ fontSize: 12, color: overlayTheme.mutedText, userSelect: 'none',
+ }}
+ >
+ ▶
+ 💭 思考过程
+ {isActivelyThinking && 思考中...}
+ {!isActivelyThinking && ({displayThinking.length} 字)}
+
+
+
+ {displayThinking}
+ {isTyping && }
+
+
+
+ );
+};
+
+// 工具调用进度面板聚合展示组件
+const AIToolCallingBlock: React.FC<{ tool_calls: AIToolCall[]; loading: boolean; allMessages: AIChatMessage[]; darkMode: boolean; overlayTheme: any; hasContent: boolean }> = ({ tool_calls, loading, allMessages, darkMode, overlayTheme, hasContent }) => {
+ const totalCalls = tool_calls.length;
+ const allDone = tool_calls.every(tc => allMessages?.find(m => m.role === 'tool' && m.tool_call_id === tc.id));
+ const [expanded, setExpanded] = useState(!allDone && loading);
+
+ // 断开连接或执行完毕时,若已完成则默认收起
+ React.useEffect(() => {
+ if (allDone || !loading) setExpanded(false);
+ }, [allDone, loading]);
+
+ // 显示友好的人类可读动作名
+ const getHumanActionName = (fname: string) => {
+ if (fname === 'get_connections') return '获取可用连接信息';
+ if (fname === 'get_databases') return '扫描数据库列表';
+ if (fname === 'get_tables') return '分析表结构信息';
+ return fname;
+ };
+
+ return (
+
+
setExpanded(!expanded)}
+ style={{
+ display: 'flex', alignItems: 'center', justifyContent: 'space-between',
+ padding: '8px 12px', cursor: 'pointer', userSelect: 'none',
+ background: darkMode ? 'rgba(255,255,255,0.02)' : 'rgba(0,0,0,0.01)',
+ }}
+ >
+
+ {!allDone && loading ? (
+
+ ) : (
+
+ )}
+
{!allDone && loading ? '正在执行数据探针...' : `数据探针执行完毕 (${totalCalls} 项)`}
+
+
▶
+
+
+
+ {tool_calls.map((tc, idx) => {
+ const resultMsg = allMessages?.find(m => m.role === 'tool' && m.tool_call_id === tc.id);
+ const isDone = !!resultMsg;
+ const actionName = getHumanActionName(tc.function.name);
+ return (
+
+
+ {isDone
+ ?
+ : (loading ?
:
)
+ }
+
{actionName}
+
+ {resultMsg &&
}
+
+ );
+ })}
+
+
+
+ );
+};
+
+export const AIMessageBubble: React.FC = React.memo(({ msg, darkMode, overlayTheme, textColor, onEdit, onRetry, onDelete, activeConnectionId, activeConnectionConfig, activeDbName, allMessages }) => {
+ const [isCopied, setIsCopied] = useState(false);
+ const isUser = msg.role === 'user';
+
+ const displayContent = msg.content;
+ const isTypingThinking = !!(msg.loading && msg.phase === 'thinking');
+
+ if (msg.role === 'tool') return null;
+
+ // 如果是纯空壳的加载状态(connecting,或还在思考/工具阶段但还没吐出一个字的 content)
+ const isWaitState = msg.phase === 'connecting' ||
+ (msg.loading && !msg.content && (msg.phase === 'thinking' || msg.phase === 'tool_calling'));
+
+ if (isWaitState) {
+ return (
+
+
+
+
+
+
+
{msg.content || '正在建立连接'}...
+
+
+ {/* 即使在波纹过渡态,如果有 thinking / tool_calls 也要显示出来,只是把它们压在波纹下面 */}
+
0) ? 12 : 0 }}>
+ {!isUser && msg.thinking && (
+
+ )}
+ {!isUser && msg.tool_calls && msg.tool_calls.length > 0 && (
+
+ )}
+
+
+
+ );
+ }
+
+ return (
+
+
+
+
+ {isUser
+ ? <> You>
+ : <> GoNavi AI>}
+
+ {/* 气泡操作栏 */}
+
+
+ {isCopied ? (
+
+ ) : (
+ {
+ navigator.clipboard.writeText(msg.content);
+ setIsCopied(true);
+ setTimeout(() => setIsCopied(false), 2000);
+ }} style={{ cursor: 'pointer', color: overlayTheme.mutedText }} onMouseEnter={e => e.currentTarget.style.color = textColor} onMouseLeave={e => e.currentTarget.style.color = overlayTheme.mutedText} />
+ )}
+
+ {isUser ? (
+
+ onEdit(msg)} style={{ cursor: 'pointer', color: overlayTheme.mutedText }} onMouseEnter={e => e.currentTarget.style.color = textColor} onMouseLeave={e => e.currentTarget.style.color = overlayTheme.mutedText} />
+
+ ) : (
+
+ onRetry(msg)} style={{ cursor: 'pointer', color: overlayTheme.mutedText }} onMouseEnter={e => e.currentTarget.style.color = textColor} onMouseLeave={e => e.currentTarget.style.color = overlayTheme.mutedText} />
+
+ )}
+
+ onDelete(msg.id)} style={{ cursor: 'pointer', color: overlayTheme.mutedText }} onMouseEnter={e => e.currentTarget.style.color = '#ef4444'} onMouseLeave={e => e.currentTarget.style.color = overlayTheme.mutedText} />
+
+
+
+
+ {msg.images && msg.images.length > 0 && (
+
+ {msg.images.map((img, i) => (
+

+ ))}
+
+ )}
+ {/* 可折叠思考过程 */}
+ {!isUser && msg.thinking && (
+
+ )}
+ {isUser ? (
+
{msg.content}
+ ) : (
+
+ )}
+ {/* 错误原文复制按钮 */}
+ {!isUser && msg.rawError && (
+
+
+
+ )}
+ {/* 工具调用进度展示 */}
+ {!isUser && msg.tool_calls && msg.tool_calls.length > 0 && (
+
+ )}
+ {msg.loading && msg.phase !== 'tool_calling' && msg.content && (
+
+ )}
+
+
+
+ );
+});
diff --git a/frontend/src/store.ts b/frontend/src/store.ts
index e671270..55fcc2a 100644
--- a/frontend/src/store.ts
+++ b/frontend/src/store.ts
@@ -1,6 +1,6 @@
import { create } from 'zustand';
import { persist } from 'zustand/middleware';
-import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag, AIChatMessage } from './types';
+import { ConnectionConfig, ProxyConfig, SavedConnection, TabData, SavedQuery, ConnectionTag, AIChatMessage, AIContextItem } from './types';
import {
ShortcutAction,
ShortcutBinding,
@@ -427,8 +427,15 @@ interface AppState {
// AI 运行时与持久化状态
aiPanelVisible: boolean;
aiChatHistory: Record; // sessionId -> messages
+ replaceAIChatHistory: (sessionId: string, messages: AIChatMessage[]) => void;
aiChatSessions: { id: string; title: string; updatedAt: number }[]; // 历史会话列表
aiActiveSessionId: string | null;
+ updateAISessionTitle: (sessionId: string, title: string) => void;
+
+ aiContexts: Record;
+ addAIContext: (connectionKey: string, context: AIContextItem) => void;
+ removeAIContext: (connectionKey: string, dbName: string, tableName: string) => void;
+ clearAIContexts: (connectionKey: string) => void;
addConnection: (conn: SavedConnection) => void;
updateConnection: (conn: SavedConnection) => void;
@@ -694,6 +701,7 @@ export const useStore = create()(
aiChatHistory: {},
aiChatSessions: [],
aiActiveSessionId: null,
+ aiContexts: {},
addConnection: (conn) => set((state) => ({ connections: [...state.connections, conn] })),
updateConnection: (conn) => set((state) => ({
@@ -1002,19 +1010,26 @@ export const useStore = create()(
return { aiChatHistory: history, aiChatSessions: newSessions };
}),
updateAIChatMessage: (sessionId, messageId, updates) => set((state) => {
- const history = { ...state.aiChatHistory };
- const messages = history[sessionId];
+ const messages = state.aiChatHistory[sessionId];
if (!messages) return state;
- history[sessionId] = messages.map(m =>
- m.id === messageId ? { ...m, ...updates } : m
- );
- let newSessions = [...state.aiChatSessions];
- const existingSession = newSessions.find(s => s.id === sessionId);
- if (existingSession) {
- newSessions = newSessions.filter(s => s.id !== sessionId);
- newSessions.unshift({ ...existingSession, updatedAt: Date.now() });
+ // 🔧 性能优化:用 findIndex + 定点替换代替全量 map,长对话场景下从 O(n) 降至 O(1)
+ const idx = messages.findIndex(m => m.id === messageId);
+ if (idx < 0) return state;
+ const newMessages = [...messages];
+ newMessages[idx] = { ...newMessages[idx], ...updates };
+ const history = { ...state.aiChatHistory, [sessionId]: newMessages };
+ // 仅当非纯 content 追加时才重排 session 顺序(性能优化:流式打字时跳过)
+ const isContentOnlyUpdate = Object.keys(updates).length === 1 && 'content' in updates;
+ if (!isContentOnlyUpdate) {
+ let newSessions = [...state.aiChatSessions];
+ const existingSession = newSessions.find(s => s.id === sessionId);
+ if (existingSession) {
+ newSessions = newSessions.filter(s => s.id !== sessionId);
+ newSessions.unshift({ ...existingSession, updatedAt: Date.now() });
+ }
+ return { aiChatHistory: history, aiChatSessions: newSessions };
}
- return { aiChatHistory: history, aiChatSessions: newSessions };
+ return { aiChatHistory: history };
}),
deleteAIChatMessage: (sessionId, messageId) => set((state) => {
const history = { ...state.aiChatHistory };
@@ -1039,6 +1054,11 @@ export const useStore = create()(
delete history[sessionId];
return { aiChatHistory: history };
}),
+ replaceAIChatHistory: (sessionId, messages) => set((state) => {
+ const history = { ...state.aiChatHistory };
+ history[sessionId] = messages;
+ return { aiChatHistory: history };
+ }),
deleteAISession: (sessionId) => set((state) => {
const history = { ...state.aiChatHistory };
delete history[sessionId];
@@ -1051,6 +1071,39 @@ export const useStore = create()(
return { aiActiveSessionId: newId };
}),
setAIActiveSessionId: (sessionId) => set({ aiActiveSessionId: sessionId }),
+ updateAISessionTitle: (sessionId, title) => set((state) => {
+ const newSessions = [...state.aiChatSessions];
+ const session = newSessions.find(s => s.id === sessionId);
+ if (session) {
+ session.title = title;
+ }
+ return { aiChatSessions: newSessions };
+ }),
+ addAIContext: (connectionKey, context) => set((state) => {
+ const contexts = state.aiContexts[connectionKey] || [];
+ if (contexts.find(c => c.dbName === context.dbName && c.tableName === context.tableName)) {
+ return state;
+ }
+ return {
+ aiContexts: {
+ ...state.aiContexts,
+ [connectionKey]: [...contexts, context]
+ }
+ };
+ }),
+ removeAIContext: (connectionKey, dbName, tableName) => set((state) => {
+ const contexts = state.aiContexts[connectionKey] || [];
+ return {
+ aiContexts: {
+ ...state.aiContexts,
+ [connectionKey]: contexts.filter(c => !(c.dbName === dbName && c.tableName === tableName))
+ }
+ };
+ }),
+ clearAIContexts: (connectionKey) => set((state) => {
+ const { [connectionKey]: _, ...rest } = state.aiContexts;
+ return { aiContexts: rest };
+ }),
}),
{
name: 'lite-db-storage', // name of the item in the storage (must be unique)
@@ -1147,8 +1200,17 @@ export const useStore = create()(
windowState: state.windowState,
sidebarWidth: state.sidebarWidth,
- aiChatHistory: state.aiChatHistory,
- aiChatSessions: state.aiChatSessions,
+ // 只持久化最近 20 个会话的聊天记录,防止 localStorage 膨胀
+ aiChatHistory: (() => {
+ const MAX_PERSIST_SESSIONS = 20;
+ const recentIds = new Set(state.aiChatSessions.slice(0, MAX_PERSIST_SESSIONS).map(s => s.id));
+ const trimmed: Record = {};
+ for (const id of recentIds) {
+ if (state.aiChatHistory[id]) trimmed[id] = state.aiChatHistory[id];
+ }
+ return trimmed;
+ })(),
+ aiChatSessions: state.aiChatSessions.slice(0, 50),
}), // Don't persist logs
}
)
diff --git a/frontend/src/types.ts b/frontend/src/types.ts
index 072d65c..8633c10 100644
--- a/frontend/src/types.ts
+++ b/frontend/src/types.ts
@@ -190,6 +190,12 @@ export type AIProviderType = 'openai' | 'anthropic' | 'gemini' | 'custom';
export type AISafetyLevel = 'readonly' | 'readwrite' | 'full';
export type AIContextLevel = 'schema_only' | 'with_samples' | 'with_results';
+export interface AIContextItem {
+ dbName: string;
+ tableName: string;
+ ddl: string;
+}
+
export interface AIProviderConfig {
id: string;
type: AIProviderType;
@@ -204,12 +210,31 @@ export interface AIProviderConfig {
temperature: number;
}
+export interface AIToolCall {
+ id: string;
+ type: string;
+ function: {
+ name: string;
+ arguments: string;
+ };
+}
+
+export type ChatPhase = 'idle' | 'connecting' | 'thinking' | 'generating' | 'tool_calling';
+
export interface AIChatMessage {
id: string;
- role: 'user' | 'assistant' | 'system';
+ role: 'user' | 'assistant' | 'system' | 'tool';
+ phase?: ChatPhase;
content: string;
+ thinking?: string;
timestamp: number;
loading?: boolean;
+ images?: string[]; // base64 encoded images with data URI prefix
+ tool_calls?: AIToolCall[];
+ tool_call_id?: string;
+ tool_name?: string; // used for UI display
+ rawError?: string; // 存储未清洗的原始错误信息,用于用户复制排查
+ success?: boolean; // 标记探针执行是否成功
}
export interface AISafetyResult {
diff --git a/frontend/wailsjs/go/aiservice/Service.d.ts b/frontend/wailsjs/go/aiservice/Service.d.ts
old mode 100644
new mode 100755
index 872b5f5..6ffc07a
--- a/frontend/wailsjs/go/aiservice/Service.d.ts
+++ b/frontend/wailsjs/go/aiservice/Service.d.ts
@@ -5,9 +5,9 @@ import {context} from '../models';
export function AIChatCancel(arg1:string):Promise;
-export function AIChatSend(arg1:Array>):Promise>;
+export function AIChatSend(arg1:Array,arg2:Array):Promise>;
-export function AIChatStream(arg1:string,arg2:Array>):Promise;
+export function AIChatStream(arg1:string,arg2:Array,arg3:Array):Promise;
export function AICheckSQL(arg1:string):Promise;
diff --git a/frontend/wailsjs/go/aiservice/Service.js b/frontend/wailsjs/go/aiservice/Service.js
old mode 100644
new mode 100755
index 2e3dcf4..7f5de4a
--- a/frontend/wailsjs/go/aiservice/Service.js
+++ b/frontend/wailsjs/go/aiservice/Service.js
@@ -6,12 +6,12 @@ export function AIChatCancel(arg1) {
return window['go']['aiservice']['Service']['AIChatCancel'](arg1);
}
-export function AIChatSend(arg1) {
- return window['go']['aiservice']['Service']['AIChatSend'](arg1);
+export function AIChatSend(arg1, arg2) {
+ return window['go']['aiservice']['Service']['AIChatSend'](arg1, arg2);
}
-export function AIChatStream(arg1, arg2) {
- return window['go']['aiservice']['Service']['AIChatStream'](arg1, arg2);
+export function AIChatStream(arg1, arg2, arg3) {
+ return window['go']['aiservice']['Service']['AIChatStream'](arg1, arg2, arg3);
}
export function AICheckSQL(arg1) {
diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts
index 258b148..e9558a8 100755
--- a/frontend/wailsjs/go/models.ts
+++ b/frontend/wailsjs/go/models.ts
@@ -1,5 +1,78 @@
export namespace ai {
+ export class ToolCall {
+ id: string;
+ type: string;
+ // Go type: struct { Name string "json:\"name\""; Arguments string "json:\"arguments\"" }
+ function: any;
+
+ static createFrom(source: any = {}) {
+ return new ToolCall(source);
+ }
+
+ constructor(source: any = {}) {
+ if ('string' === typeof source) source = JSON.parse(source);
+ this.id = source["id"];
+ this.type = source["type"];
+ this.function = this.convertValues(source["function"], Object);
+ }
+
+ convertValues(a: any, classs: any, asMap: boolean = false): any {
+ if (!a) {
+ return a;
+ }
+ if (a.slice && a.map) {
+ return (a as any[]).map(elem => this.convertValues(elem, classs));
+ } else if ("object" === typeof a) {
+ if (asMap) {
+ for (const key of Object.keys(a)) {
+ a[key] = new classs(a[key]);
+ }
+ return a;
+ }
+ return new classs(a);
+ }
+ return a;
+ }
+ }
+ export class Message {
+ role: string;
+ content: string;
+ images?: string[];
+ tool_call_id?: string;
+ tool_calls?: ToolCall[];
+
+ static createFrom(source: any = {}) {
+ return new Message(source);
+ }
+
+ constructor(source: any = {}) {
+ if ('string' === typeof source) source = JSON.parse(source);
+ this.role = source["role"];
+ this.content = source["content"];
+ this.images = source["images"];
+ this.tool_call_id = source["tool_call_id"];
+ this.tool_calls = this.convertValues(source["tool_calls"], ToolCall);
+ }
+
+ convertValues(a: any, classs: any, asMap: boolean = false): any {
+ if (!a) {
+ return a;
+ }
+ if (a.slice && a.map) {
+ return (a as any[]).map(elem => this.convertValues(elem, classs));
+ } else if ("object" === typeof a) {
+ if (asMap) {
+ for (const key of Object.keys(a)) {
+ a[key] = new classs(a[key]);
+ }
+ return a;
+ }
+ return new classs(a);
+ }
+ return a;
+ }
+ }
export class ProviderConfig {
id: string;
type: string;
@@ -50,6 +123,55 @@ export namespace ai {
this.warningMessage = source["warningMessage"];
}
}
+ export class ToolFunction {
+ name: string;
+ description: string;
+ parameters: any;
+
+ static createFrom(source: any = {}) {
+ return new ToolFunction(source);
+ }
+
+ constructor(source: any = {}) {
+ if ('string' === typeof source) source = JSON.parse(source);
+ this.name = source["name"];
+ this.description = source["description"];
+ this.parameters = source["parameters"];
+ }
+ }
+ export class Tool {
+ type: string;
+ function: ToolFunction;
+
+ static createFrom(source: any = {}) {
+ return new Tool(source);
+ }
+
+ constructor(source: any = {}) {
+ if ('string' === typeof source) source = JSON.parse(source);
+ this.type = source["type"];
+ this.function = this.convertValues(source["function"], ToolFunction);
+ }
+
+ convertValues(a: any, classs: any, asMap: boolean = false): any {
+ if (!a) {
+ return a;
+ }
+ if (a.slice && a.map) {
+ return (a as any[]).map(elem => this.convertValues(elem, classs));
+ } else if ("object" === typeof a) {
+ if (asMap) {
+ for (const key of Object.keys(a)) {
+ a[key] = new classs(a[key]);
+ }
+ return a;
+ }
+ return new classs(a);
+ }
+ return a;
+ }
+ }
+
}
diff --git a/internal/ai/context/builder.go b/internal/ai/context/builder.go
index 9c4e75a..f61d49a 100644
--- a/internal/ai/context/builder.go
+++ b/internal/ai/context/builder.go
@@ -208,6 +208,7 @@ func buildGeneralChatPrompt() string {
互动守则:
- 永远使用专业、具有合作感且充满信心的中文与用户探讨问题。
-- 当被要求提供任何数据库代码时,需结合相关数据库引擎的最佳实践。如果不清楚当前方言版本,请以标准实现为主基调并好心指出版别差异(如 MySQL 8 窗口函数 等)。`
+- 当被要求提供任何数据库代码时,需结合相关数据库引擎的最佳实践。如果不清楚当前方言版本,请以标准实现为主基调并好心指出版别差异(如 MySQL 8 窗口函数 等)。
+- 绝不轻易拒绝:如果用户要求写 SQL 但并未显式挂载任何表的详细 DDL,请尽最大努力根据对话上下文中带入的【纯表名列表】去推测他要查询哪个表。如果实在无法推断,请温柔且专业地向用户解释目前已知的表有哪些,并询问到底想查哪张表。`
}
diff --git a/internal/ai/provider/anthropic.go b/internal/ai/provider/anthropic.go
index 035d2fc..94312ec 100644
--- a/internal/ai/provider/anthropic.go
+++ b/internal/ai/provider/anthropic.go
@@ -82,8 +82,42 @@ type anthropicRequest struct {
}
type anthropicMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
+ Role string `json:"role"`
+ Content interface{} `json:"content"`
+}
+
+func buildAnthropicMessages(reqMessages []ai.Message) []anthropicMessage {
+ messages := make([]anthropicMessage, 0, len(reqMessages))
+ for _, m := range reqMessages {
+ if len(m.Images) > 0 {
+ var contentParts []map[string]interface{}
+ for _, img := range m.Images {
+ mimeType, rawBase64, err := ParseDataURI(img)
+ if err == nil {
+ contentParts = append(contentParts, map[string]interface{}{
+ "type": "image",
+ "source": map[string]interface{}{
+ "type": "base64",
+ "media_type": mimeType,
+ "data": rawBase64,
+ },
+ })
+ }
+ }
+ text := m.Content
+ if text == "" {
+ text = "请描述和分析这张图片。" // 防止强 System Prompt 下模型仅看到空文本且忽略图片直接回复打招呼
+ }
+ contentParts = append(contentParts, map[string]interface{}{
+ "type": "text",
+ "text": text,
+ })
+ messages = append(messages, anthropicMessage{Role: m.Role, Content: contentParts})
+ } else {
+ messages = append(messages, anthropicMessage{Role: m.Role, Content: m.Content})
+ }
+ }
+ return messages
}
type anthropicResponse struct {
@@ -112,10 +146,7 @@ func (p *AnthropicProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.C
}
systemMsg, messages := extractSystemMessage(req.Messages)
- anthropicMsgs := make([]anthropicMessage, len(messages))
- for i, m := range messages {
- anthropicMsgs[i] = anthropicMessage{Role: m.Role, Content: m.Content}
- }
+ anthropicMsgs := buildAnthropicMessages(messages)
temperature := req.Temperature
if temperature <= 0 {
@@ -167,10 +198,7 @@ func (p *AnthropicProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
}
systemMsg, messages := extractSystemMessage(req.Messages)
- anthropicMsgs := make([]anthropicMessage, len(messages))
- for i, m := range messages {
- anthropicMsgs[i] = anthropicMessage{Role: m.Role, Content: m.Content}
- }
+ anthropicMsgs := buildAnthropicMessages(messages)
temperature := req.Temperature
if temperature <= 0 {
@@ -253,6 +281,12 @@ func (p *AnthropicProvider) doRequest(ctx context.Context, body interface{}) (io
httpReq.Header.Set("x-api-key", p.config.APIKey)
httpReq.Header.Set("anthropic-version", anthropicAPIVersion)
+ if strings.Contains(string(jsonBody), `"stream":true`) || strings.Contains(string(jsonBody), `"stream": true`) {
+ httpReq.Header.Set("Accept", "text/event-stream")
+ httpReq.Header.Set("Cache-Control", "no-cache")
+ httpReq.Header.Set("Connection", "keep-alive")
+ }
+
// 仅官方 API 发 beta 特性头(代理不发,避免触发 Claude Code 验证)
isOfficialAPI := p.baseURL == defaultAnthropicBaseURL || strings.Contains(p.baseURL, "anthropic.com")
if isOfficialAPI {
diff --git a/internal/ai/provider/claude_cli.go b/internal/ai/provider/claude_cli.go
index 824e413..4cd9b00 100644
--- a/internal/ai/provider/claude_cli.go
+++ b/internal/ai/provider/claude_cli.go
@@ -105,8 +105,7 @@ func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
fmt.Printf("[ClaudeCLI DEBUG] Process started, PID: %d\n", cmd.Process.Pid)
- // 立即通知前端:AI 正在思考(避免用户以为卡死)
- callback(ai.StreamChunk{Content: "💭 *正在思考...*\n\n"})
+ // 前端已有 loading 动画,无需在 content 中注入"正在思考"
// 逐行读取流式 JSON 输出
scanner := bufio.NewScanner(stdout)
@@ -131,14 +130,18 @@ func (p *ClaudeCLIProvider) ChatStream(ctx context.Context, req ai.ChatRequest,
// 助手消息开始或文本内容
if event.Message.Content != nil {
for _, block := range event.Message.Content {
- if block.Type == "text" && block.Text != "" {
+ if block.Type == "thinking" && block.Thinking != "" {
+ callback(ai.StreamChunk{Thinking: block.Thinking})
+ } else if block.Type == "text" && block.Text != "" {
callback(ai.StreamChunk{Content: block.Text})
}
}
}
case "content_block_delta":
- // 增量文本
- if event.Delta.Text != "" {
+ // 增量文本或增量思考
+ if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
+ callback(ai.StreamChunk{Thinking: event.Delta.Thinking})
+ } else if event.Delta.Text != "" {
callback(ai.StreamChunk{Content: event.Delta.Text})
}
case "result":
@@ -213,12 +216,15 @@ type cliStreamEvent struct {
Type string `json:"type"`
Message struct {
Content []struct {
- Type string `json:"type"`
- Text string `json:"text"`
+ Type string `json:"type"`
+ Text string `json:"text"`
+ Thinking string `json:"thinking"`
} `json:"content"`
} `json:"message,omitempty"`
Delta struct {
- Text string `json:"text"`
+ Type string `json:"type"`
+ Text string `json:"text"`
+ Thinking string `json:"thinking"`
} `json:"delta,omitempty"`
Result string `json:"result,omitempty"`
Error struct {
diff --git a/internal/ai/provider/gemini.go b/internal/ai/provider/gemini.go
index 0c5eee7..b4cf910 100644
--- a/internal/ai/provider/gemini.go
+++ b/internal/ai/provider/gemini.go
@@ -83,7 +83,13 @@ type geminiContent struct {
}
type geminiPart struct {
- Text string `json:"text"`
+ Text string `json:"text,omitempty"`
+ InlineData *geminiBlob `json:"inlineData,omitempty"`
+}
+
+type geminiBlob struct {
+ MimeType string `json:"mimeType"`
+ Data string `json:"data"`
}
type geminiGenConfig struct {
@@ -205,10 +211,6 @@ func (p *GeminiProvider) buildRequest(req ai.ChatRequest) geminiRequest {
if temperature <= 0 {
temperature = p.config.Temperature
}
- maxTokens := req.MaxTokens
- if maxTokens <= 0 {
- maxTokens = p.config.MaxTokens
- }
var systemInstruction *geminiContent
var contents []geminiContent
@@ -224,9 +226,29 @@ func (p *GeminiProvider) buildRequest(req ai.ChatRequest) geminiRequest {
if role == "assistant" {
role = "model"
}
+ var parts []geminiPart
+ text := m.Content
+ if text == "" && len(m.Images) > 0 {
+ text = "请描述和分析这张图片。" // 同样避免 Gemini 认为意图不明确
+ }
+ if text != "" {
+ parts = append(parts, geminiPart{Text: text})
+ }
+ for _, img := range m.Images {
+ mimeType, rawBase64, err := ParseDataURI(img)
+ if err == nil {
+ parts = append(parts, geminiPart{
+ InlineData: &geminiBlob{
+ MimeType: mimeType,
+ Data: rawBase64,
+ },
+ })
+ }
+ }
+
contents = append(contents, geminiContent{
Role: role,
- Parts: []geminiPart{{Text: m.Content}},
+ Parts: parts,
})
}
@@ -235,7 +257,6 @@ func (p *GeminiProvider) buildRequest(req ai.ChatRequest) geminiRequest {
SystemInstruction: systemInstruction,
GenerationConfig: geminiGenConfig{
Temperature: temperature,
- MaxOutputTokens: maxTokens,
},
}
}
@@ -252,6 +273,12 @@ func (p *GeminiProvider) doRequest(ctx context.Context, url string, body interfa
}
httpReq.Header.Set("Content-Type", "application/json")
+ if strings.Contains(url, "alt=sse") {
+ httpReq.Header.Set("Accept", "text/event-stream")
+ httpReq.Header.Set("Cache-Control", "no-cache")
+ httpReq.Header.Set("Connection", "keep-alive")
+ }
+
resp, err := p.client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("发送请求到 Gemini 失败: %w", err)
diff --git a/internal/ai/provider/helper.go b/internal/ai/provider/helper.go
new file mode 100644
index 0000000..e6bb0d3
--- /dev/null
+++ b/internal/ai/provider/helper.go
@@ -0,0 +1,26 @@
+package provider
+
+import (
+ "fmt"
+ "strings"
+)
+
+// ParseDataURI 解析前端传递的 Data URI,返回 mimeType 和去掉前缀的 rawBase64
+func ParseDataURI(dataURI string) (mimeType, rawBase64 string, err error) {
+ if !strings.HasPrefix(dataURI, "data:") {
+ // 如果前端漏了前缀,默认容错当做 jpeg 处理
+ return "image/jpeg", dataURI, nil
+ }
+ parts := strings.SplitN(dataURI, ",", 2)
+ if len(parts) != 2 {
+ return "", "", fmt.Errorf("invalid data URI format")
+ }
+ meta := strings.TrimPrefix(parts[0], "data:")
+ metaParts := strings.Split(meta, ";")
+ mimeType = metaParts[0]
+ if mimeType == "" {
+ mimeType = "image/jpeg" // fallback
+ }
+ rawBase64 = parts[1]
+ return mimeType, rawBase64, nil
+}
diff --git a/internal/ai/provider/openai.go b/internal/ai/provider/openai.go
index ff674a9..87d5214 100644
--- a/internal/ai/provider/openai.go
+++ b/internal/ai/provider/openai.go
@@ -88,18 +88,67 @@ type openAIChatRequest struct {
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream,omitempty"`
+ Tools []ai.Tool `json:"tools,omitempty"`
}
type openAIChatMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
+ Role string `json:"role"`
+ Content interface{} `json:"content,omitempty"`
+ ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+}
+
+func buildOpenAIMessages(reqMessages []ai.Message, modelName string, baseURL string) []openAIChatMessage {
+ messages := make([]openAIChatMessage, len(reqMessages))
+ for i, m := range reqMessages {
+ if m.Role == "tool" {
+ messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content, ToolCallID: m.ToolCallID}
+ continue
+ }
+ if len(m.ToolCalls) > 0 {
+ messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content, ToolCalls: m.ToolCalls}
+ continue
+ }
+
+ if len(m.Images) > 0 {
+ var contentParts []map[string]interface{}
+ text := m.Content
+ if text == "" {
+ text = "请描述和分析这张图片。" // 兼容部分模型(如 ZhipuAI/GLM-4V)强制要求图片必须伴随有效文本块,同时防止强 System Prompt 下模型当成空消息处理
+ }
+ contentParts = append(contentParts, map[string]interface{}{
+ "type": "text",
+ "text": text,
+ })
+ for _, img := range m.Images {
+ imgURL := img
+ // 仅当直接请求智谱官方 API 域名时(它原生不接受 data 协议前缀),才截取裸 Base64
+ if strings.Contains(strings.ToLower(baseURL), "bigmodel") {
+ if _, raw, err := ParseDataURI(img); err == nil {
+ imgURL = raw
+ }
+ }
+ contentParts = append(contentParts, map[string]interface{}{
+ "type": "image_url",
+ "image_url": map[string]interface{}{
+ "url": imgURL,
+ },
+ })
+ }
+ messages[i] = openAIChatMessage{Role: m.Role, Content: contentParts}
+ } else {
+ messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
+ }
+ }
+ return messages
}
// openAIChatResponse OpenAI API 响应体
type openAIChatResponse struct {
Choices []struct {
Message struct {
- Content string `json:"content"`
+ Content string `json:"content"`
+ ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
@@ -114,10 +163,22 @@ type openAIChatResponse struct {
}
// openAIStreamChunk SSE 流式响应片段
+type openAIToolCallDelta struct {
+ Index int `json:"index"`
+ ID string `json:"id,omitempty"`
+ Type string `json:"type,omitempty"`
+ Function *struct {
+ Name string `json:"name,omitempty"`
+ Arguments string `json:"arguments,omitempty"`
+ } `json:"function,omitempty"`
+}
+
type openAIStreamChunk struct {
Choices []struct {
Delta struct {
- Content string `json:"content"`
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content"`
+ ToolCalls []openAIToolCallDelta `json:"tool_calls,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
} `json:"choices"`
@@ -131,26 +192,19 @@ func (p *OpenAIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.Chat
return nil, err
}
- messages := make([]openAIChatMessage, len(req.Messages))
- for i, m := range req.Messages {
- messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
- }
+ messages := buildOpenAIMessages(req.Messages, p.config.Model, p.baseURL)
temperature := req.Temperature
if temperature <= 0 {
temperature = p.config.Temperature
}
- maxTokens := req.MaxTokens
- if maxTokens <= 0 {
- maxTokens = p.config.MaxTokens
- }
body := openAIChatRequest{
Model: p.config.Model,
Messages: messages,
Temperature: temperature,
- MaxTokens: maxTokens,
Stream: false,
+ Tools: req.Tools,
}
respBody, err := p.doRequest(ctx, body)
@@ -177,6 +231,7 @@ func (p *OpenAIProvider) Chat(ctx context.Context, req ai.ChatRequest) (*ai.Chat
CompletionTokens: result.Usage.CompletionTokens,
TotalTokens: result.Usage.TotalTokens,
},
+ ToolCalls: result.Choices[0].Message.ToolCalls,
}, nil
}
@@ -185,26 +240,19 @@ func (p *OpenAIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, cal
return err
}
- messages := make([]openAIChatMessage, len(req.Messages))
- for i, m := range req.Messages {
- messages[i] = openAIChatMessage{Role: m.Role, Content: m.Content}
- }
+ messages := buildOpenAIMessages(req.Messages, p.config.Model, p.baseURL)
temperature := req.Temperature
if temperature <= 0 {
temperature = p.config.Temperature
}
- maxTokens := req.MaxTokens
- if maxTokens <= 0 {
- maxTokens = p.config.MaxTokens
- }
body := openAIChatRequest{
Model: p.config.Model,
Messages: messages,
Temperature: temperature,
- MaxTokens: maxTokens,
Stream: true,
+ Tools: req.Tools,
}
respBody, err := p.doRequest(ctx, body)
@@ -214,6 +262,8 @@ func (p *OpenAIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, cal
defer respBody.Close()
receivedContent := false
+ var activeToolCalls []ai.ToolCall
+
scanner := bufio.NewScanner(respBody)
// 增大 scanner buffer,防止长行被截断
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
@@ -245,12 +295,49 @@ func (p *OpenAIProvider) ChatStream(ctx context.Context, req ai.ChatRequest, cal
return nil
}
if len(chunk.Choices) > 0 {
- content := chunk.Choices[0].Delta.Content
+ choice := chunk.Choices[0]
+
+ // Handle ToolCalls delta
+ if len(choice.Delta.ToolCalls) > 0 {
+ receivedContent = true
+ for _, tcDelta := range choice.Delta.ToolCalls {
+ // Expand activeToolCalls slice if index is larger
+ for len(activeToolCalls) <= tcDelta.Index {
+ activeToolCalls = append(activeToolCalls, ai.ToolCall{Type: "function"})
+ }
+ if tcDelta.ID != "" {
+ activeToolCalls[tcDelta.Index].ID = tcDelta.ID
+ }
+ if tcDelta.Function != nil {
+ if tcDelta.Function.Name != "" {
+ activeToolCalls[tcDelta.Index].Function.Name += tcDelta.Function.Name
+ }
+ if tcDelta.Function.Arguments != "" {
+ activeToolCalls[tcDelta.Index].Function.Arguments += tcDelta.Function.Arguments
+ }
+ }
+ }
+ // 实时推送目前已解析的 ToolCalls 状态
+ callback(ai.StreamChunk{ToolCalls: activeToolCalls})
+ }
+
+ content := choice.Delta.Content
if content != "" {
receivedContent = true
callback(ai.StreamChunk{Content: content})
}
- if chunk.Choices[0].FinishReason != nil {
+
+ // 支持 DeepSeek/千问等模型的 reasoning_content 字段
+ if choice.Delta.ReasoningContent != "" {
+ receivedContent = true
+ callback(ai.StreamChunk{Thinking: choice.Delta.ReasoningContent})
+ }
+
+ if choice.FinishReason != nil {
+ if *choice.FinishReason == "tool_calls" {
+ callback(ai.StreamChunk{ToolCalls: activeToolCalls, Done: true})
+ return nil
+ }
callback(ai.StreamChunk{Done: true})
return nil
}
@@ -296,6 +383,13 @@ func (p *OpenAIProvider) doRequest(ctx context.Context, body interface{}) (io.Re
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+p.config.APIKey)
+ // 仅在流式请求时明确声明 SSE,防止代理缓冲
+ if strings.Contains(string(jsonBody), `"stream":true`) || strings.Contains(string(jsonBody), `"stream": true`) {
+ httpReq.Header.Set("Accept", "text/event-stream")
+ httpReq.Header.Set("Cache-Control", "no-cache")
+ httpReq.Header.Set("Connection", "keep-alive")
+ }
+
// 自定义 headers(用于兼容各类 OpenAI 兼容服务)
for k, v := range p.config.Headers {
httpReq.Header.Set(k, v)
diff --git a/internal/ai/service/service.go b/internal/ai/service/service.go
index 52af6ba..addab29 100644
--- a/internal/ai/service/service.go
+++ b/internal/ai/service/service.go
@@ -114,7 +114,7 @@ func (s *Service) AIDeleteProvider(id string) error {
return s.saveConfig()
}
-// AITestProvider 测试 Provider 配置是否可用
+// AITestProvider 测试 Provider 配置是否可用,仅测试端点连通性与密钥,不实际调用对话
func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{} {
// 如果传入脱敏的 key,使用已保存的 key
s.mu.RLock()
@@ -128,30 +128,84 @@ func (s *Service) AITestProvider(config ai.ProviderConfig) map[string]interface{
}
s.mu.RUnlock()
- p, err := provider.NewProvider(config)
- if err != nil {
- return map[string]interface{}{"success": false, "message": err.Error()}
- }
- if err := p.Validate(); err != nil {
- return map[string]interface{}{"success": false, "message": err.Error()}
+ baseURL := strings.TrimRight(strings.TrimSpace(config.BaseURL), "/")
+ providerType := config.Type
+ if providerType == "custom" && config.APIFormat != "" {
+ providerType = config.APIFormat
}
- ctx, cancel := context.WithTimeout(context.Background(), 30*1000*1000*1000) // 30s
- defer cancel()
+ client := &http.Client{Timeout: 10 * time.Second}
+ var err error
+
+ switch providerType {
+ case "openai":
+ if baseURL == "" {
+ baseURL = "https://api.openai.com/v1"
+ }
+ if !strings.HasSuffix(baseURL, "/v1") && !strings.Contains(baseURL, "/v1/") {
+ baseURL = baseURL + "/v1"
+ }
+ // 使用 /models 端点验证连通性和鉴权
+ req, _ := http.NewRequest("GET", baseURL+"/models", nil)
+ req.Header.Set("Authorization", "Bearer "+config.APIKey)
+ for k, v := range config.Headers {
+ req.Header.Set(k, v)
+ }
+ resp, reqErr := client.Do(req)
+ if reqErr != nil {
+ err = reqErr
+ } else {
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusUnauthorized {
+ err = fmt.Errorf("API Key 验证失败 (HTTP %d)", resp.StatusCode)
+ } else if resp.StatusCode >= 500 {
+ err = fmt.Errorf("上游服务器内部错误 (HTTP %d)", resp.StatusCode)
+ }
+ }
+ case "anthropic":
+ if baseURL == "" {
+ baseURL = "https://api.anthropic.com"
+ }
+ req, _ := http.NewRequest("GET", baseURL, nil)
+ resp, reqErr := client.Do(req)
+ if reqErr != nil {
+ err = reqErr
+ } else {
+ resp.Body.Close()
+ }
+ case "gemini":
+ if baseURL == "" {
+ baseURL = "https://generativelanguage.googleapis.com"
+ }
+ req, _ := http.NewRequest("GET", baseURL+"/v1beta/models?key="+config.APIKey, nil)
+ resp, reqErr := client.Do(req)
+ if reqErr != nil {
+ err = reqErr
+ } else {
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusBadRequest {
+ err = fmt.Errorf("API Key 无效或请求错误 (HTTP %d)", resp.StatusCode)
+ }
+ }
+ default:
+ if baseURL != "" {
+ req, _ := http.NewRequest("GET", baseURL, nil)
+ resp, reqErr := client.Do(req)
+ if reqErr != nil {
+ err = reqErr
+ } else {
+ resp.Body.Close()
+ }
+ }
+ }
- resp, err := p.Chat(ctx, ai.ChatRequest{
- Messages: []ai.Message{
- {Role: "user", Content: "Hi, please respond with 'OK' to confirm the connection is working."},
- },
- MaxTokens: 10,
- })
if err != nil {
return map[string]interface{}{"success": false, "message": fmt.Sprintf("连接测试失败: %s", err.Error())}
}
return map[string]interface{}{
"success": true,
- "message": fmt.Sprintf("连接成功!模型响应: %s", truncateString(resp.Content, 100)),
+ "message": "端点连通性测试成功!",
}
}
@@ -364,19 +418,14 @@ func (s *Service) AISetContextLevel(level string) {
// --- AI 对话 ---
-// AIChatSend 同步发送 AI 对话(非流式)
-func (s *Service) AIChatSend(messages []map[string]string) map[string]interface{} {
+// AIChatSend 非流式发送 AI 对话
+func (s *Service) AIChatSend(messages []ai.Message, tools []ai.Tool) map[string]interface{} {
p, err := s.getActiveProvider()
if err != nil {
return map[string]interface{}{"success": false, "error": err.Error()}
}
- var aiMessages []ai.Message
- for _, m := range messages {
- aiMessages = append(aiMessages, ai.Message{Role: m["role"], Content: m["content"]})
- }
-
- resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: aiMessages})
+ resp, err := p.Chat(context.Background(), ai.ChatRequest{Messages: messages, Tools: tools})
if err != nil {
return map[string]interface{}{"success": false, "error": err.Error()}
}
@@ -384,6 +433,7 @@ func (s *Service) AIChatSend(messages []map[string]string) map[string]interface{
return map[string]interface{}{
"success": true,
"content": resp.Content,
+ "tool_calls": resp.ToolCalls,
"tokensUsed": map[string]int{
"promptTokens": resp.TokensUsed.PromptTokens,
"completionTokens": resp.TokensUsed.CompletionTokens,
@@ -393,7 +443,7 @@ func (s *Service) AIChatSend(messages []map[string]string) map[string]interface{
}
// AIChatStream 流式发送 AI 对话(通过 EventsEmit 推送)
-func (s *Service) AIChatStream(sessionID string, messages []map[string]string) {
+func (s *Service) AIChatStream(sessionID string, messages []ai.Message, tools []ai.Tool) {
streamCtx, cancel := context.WithCancel(context.Background())
s.mu.Lock()
s.cancelFuncs[sessionID] = cancel
@@ -416,16 +466,13 @@ func (s *Service) AIChatStream(sessionID string, messages []map[string]string) {
return
}
- var aiMessages []ai.Message
- for _, m := range messages {
- aiMessages = append(aiMessages, ai.Message{Role: m["role"], Content: m["content"]})
- }
-
- err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: aiMessages}, func(chunk ai.StreamChunk) {
+ err = p.ChatStream(streamCtx, ai.ChatRequest{Messages: messages, Tools: tools}, func(chunk ai.StreamChunk) {
wailsRuntime.EventsEmit(s.ctx, "ai:stream:"+sessionID, map[string]interface{}{
- "content": chunk.Content,
- "done": chunk.Done,
- "error": chunk.Error,
+ "content": chunk.Content,
+ "thinking": chunk.Thinking,
+ "tool_calls": chunk.ToolCalls,
+ "done": chunk.Done,
+ "error": chunk.Error,
})
})
diff --git a/internal/ai/types.go b/internal/ai/types.go
index eb55a6f..0c83c6d 100644
--- a/internal/ai/types.go
+++ b/internal/ai/types.go
@@ -1,9 +1,35 @@
package ai
+// ToolCall 表示 AI 发出的工具调用
+type ToolCall struct {
+ ID string `json:"id"`
+ Type string `json:"type"` // "function"
+ Function struct {
+ Name string `json:"name"`
+ Arguments string `json:"arguments"`
+ } `json:"function"`
+}
+
+// ToolFunction 表示可使用的函数定义
+type ToolFunction struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters any `json:"parameters"` // JSON Schema definitions
+}
+
+// Tool 工具申明
+type Tool struct {
+ Type string `json:"type"` // "function"
+ Function ToolFunction `json:"function"`
+}
+
// Message 表示一条对话消息
type Message struct {
- Role string `json:"role"` // "system" | "user" | "assistant"
- Content string `json:"content"`
+ Role string `json:"role"` // "system" | "user" | "assistant" | "tool"
+ Content string `json:"content"`
+ Images []string `json:"images,omitempty"` // base64 encoded images with data:image/png;base64,... prefix
+ ToolCallID string `json:"tool_call_id,omitempty"` // 当 role 为 "tool" 时必须传递
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"` // 当 role 为 "assistant" 并试图调工具时传递
}
// ChatRequest AI 对话请求
@@ -11,12 +37,14 @@ type ChatRequest struct {
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"maxTokens"`
+ Tools []Tool `json:"tools,omitempty"`
}
// ChatResponse AI 对话响应
type ChatResponse struct {
Content string `json:"content"`
TokensUsed TokenUsage `json:"tokensUsed"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// TokenUsage token 用量统计
@@ -28,9 +56,11 @@ type TokenUsage struct {
// StreamChunk 流式响应片段
type StreamChunk struct {
- Content string `json:"content"`
- Done bool `json:"done"`
- Error string `json:"error,omitempty"`
+ Content string `json:"content"`
+ Thinking string `json:"thinking,omitempty"`
+ Done bool `json:"done"`
+ Error string `json:"error,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// ProviderConfig AI Provider 配置