mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-14 18:39:54 +08:00
♻️ refactor(ai-chat): 抽离本地工具执行器并补齐测试
- 将 AIChatPanel 中本地工具执行 switch 抽成独立 helper,降低主面板耦合度 - 为表结构、SQL 安全拦截和工具结果映射补充独立单测 - 保持 AI 面板交互不变,补做构建与浏览器冒烟验证
This commit is contained in:
@@ -23,16 +23,15 @@ describe('AIChatPanel message render isolation', () => {
|
||||
it('loads MCP tools and skills into the runtime tool chain', () => {
|
||||
expect(source).toContain('AIListMCPTools');
|
||||
expect(source).toContain('AIGetSkills');
|
||||
expect(source).toContain('AICallMCPTool');
|
||||
expect(source).toContain('executeLocalAIToolCall');
|
||||
expect(source).toContain('以下是当前启用的 Skill');
|
||||
expect(source).toContain('buildAvailableAIChatTools');
|
||||
});
|
||||
|
||||
it('teaches the runtime to use deeper schema tools when analyzing structure details', () => {
|
||||
expect(source).toContain('get_indexes、get_foreign_keys、get_triggers、get_table_ddl');
|
||||
expect(source).toContain("case 'get_indexes':");
|
||||
expect(source).toContain("case 'get_foreign_keys':");
|
||||
expect(source).toContain("case 'get_triggers':");
|
||||
expect(source).toContain('toolContextMap: toolContextMapRef.current');
|
||||
expect(source).toContain('buildToolResultMessage');
|
||||
});
|
||||
|
||||
it('keeps the v2 history mode sorted by the latest updated session first', () => {
|
||||
|
||||
@@ -2,7 +2,6 @@ import React, { useState, useRef, useEffect, useCallback, useMemo } from 'react'
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useStore, loadAISessionsFromBackend, loadAISessionFromBackend } from '../store';
|
||||
import { EventsOn, EventsOff } from '../../wailsjs/runtime';
|
||||
import { DBGetDatabases, DBGetTables } from '../../wailsjs/go/app/App';
|
||||
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
|
||||
import type {
|
||||
AIChatMessage,
|
||||
@@ -28,13 +27,16 @@ import {
|
||||
buildMissingProviderNotice,
|
||||
buildModelFetchFailedNotice,
|
||||
} from '../utils/aiComposerNotice';
|
||||
import { buildAIReadonlyPreviewSQL } from '../utils/aiSqlLimit';
|
||||
import { resolveAITableSchemaToolResult } from '../utils/aiTableSchemaTool';
|
||||
import { consumeAIChatSendShortcutOnKeyDown } from '../utils/aiChatSendShortcut';
|
||||
import { toAIRequestMessage } from '../utils/aiMessagePayload';
|
||||
import { getShortcutPlatform, resolveShortcutBinding } from '../utils/shortcuts';
|
||||
import { isMacLikePlatform } from '../utils/appearance';
|
||||
import { buildAvailableAIChatTools } from '../utils/aiToolRegistry';
|
||||
import {
|
||||
buildToolResultMessage,
|
||||
executeLocalAIToolCall,
|
||||
type AIToolContextEntry,
|
||||
} from './ai/aiLocalToolExecutor';
|
||||
|
||||
interface AIChatPanelProps {
|
||||
width?: number;
|
||||
@@ -1226,7 +1228,7 @@ SELECT * FROM users WHERE status = 1;
|
||||
}, [availableTools, skills, userPromptSettings]);
|
||||
|
||||
// 记录所有成功的 get_tables 调用结果,用于表级精确匹配
|
||||
const toolContextMapRef = useRef<Map<string, { connectionId: string; dbName: string; tables: string[] }>>(new Map());
|
||||
const toolContextMapRef = useRef<Map<string, AIToolContextEntry>>(new Map());
|
||||
|
||||
const executeLocalTools = useCallback(async (toolCalls: AIToolCall[], currentAsstMsgId: string) => {
|
||||
const currentAsstMsg = (useStore.getState().aiChatHistory[sid] || []).find(m => m.id === currentAsstMsgId);
|
||||
@@ -1253,233 +1255,21 @@ SELECT * FROM users WHERE status = 1;
|
||||
}
|
||||
|
||||
const results: AIChatMessage[] = [];
|
||||
const mcpToolMap = new Map(mcpTools.map((tool) => [tool.alias, tool]));
|
||||
const currentConnections = useStore.getState().connections;
|
||||
// 【串行逐条执行 + 实时写入 store】
|
||||
for (const tc of toolCalls) {
|
||||
let resStr = '';
|
||||
let success = false;
|
||||
try {
|
||||
const args = JSON.parse(tc.function.arguments || '{}');
|
||||
const mcpToolDescriptor = mcpToolMap.get(tc.function.name);
|
||||
switch (tc.function.name) {
|
||||
case 'get_connections':
|
||||
const conns = useStore.getState().connections.map(c => ({
|
||||
id: c.id,
|
||||
name: c.name,
|
||||
type: c.config?.type,
|
||||
host: (c.config as any)?.host || (c.config as any)?.addr || ''
|
||||
}));
|
||||
resStr = JSON.stringify(conns);
|
||||
success = true;
|
||||
break;
|
||||
case 'get_databases': {
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const dbRes = await DBGetDatabases(buildRpcConnectionConfig(conn.config) as any);
|
||||
if (dbRes?.success && Array.isArray(dbRes.data)) {
|
||||
let dNames = dbRes.data.map((r: any) => r.Database || r.database || Object.values(r)[0]);
|
||||
if (dNames.length > 50) dNames = [...dNames.slice(0, 50), '...(截断)'];
|
||||
resStr = JSON.stringify(dNames);
|
||||
success = true;
|
||||
} else {
|
||||
resStr = dbRes?.message || 'Failed to fetch DBs';
|
||||
}
|
||||
} catch (e: any) {
|
||||
resStr = `获取数据库列表失败: ${e?.message || e}`;
|
||||
}
|
||||
} else { resStr = 'Connection not found'; }
|
||||
break;
|
||||
}
|
||||
case 'get_tables': {
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const rawDbName = args.dbName || args.database;
|
||||
const safeDbName = rawDbName ? String(rawDbName).trim() : '';
|
||||
const tbRes = await DBGetTables(buildRpcConnectionConfig(conn.config) as any, safeDbName);
|
||||
if (tbRes?.success && Array.isArray(tbRes.data)) {
|
||||
let tNames = tbRes.data.map((r: any) => r.Table || r.table || Object.values(r)[0] as string);
|
||||
if (tNames.length > 150) tNames = [...tNames.slice(0, 150), '...(截断)'];
|
||||
resStr = JSON.stringify(tNames);
|
||||
success = true;
|
||||
// 🔑 记录已验证的上下文参数和表列表(用于后续表级精确匹配)
|
||||
toolContextMapRef.current.set(`${args.connectionId}:${safeDbName}`, {
|
||||
connectionId: args.connectionId,
|
||||
dbName: safeDbName,
|
||||
tables: tNames.filter((t: string) => t !== '...(截断)')
|
||||
});
|
||||
} else { resStr = tbRes?.message || 'Failed to fetch Tables'; }
|
||||
} catch (e: any) {
|
||||
resStr = `获取表列表失败: ${e?.message || e}`;
|
||||
}
|
||||
} else { resStr = 'Connection not found'; }
|
||||
break;
|
||||
}
|
||||
case 'get_columns': {
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const { DBGetColumns } = await import('../../wailsjs/go/app/App');
|
||||
const colRes = await DBGetColumns(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
|
||||
if (colRes?.success && Array.isArray(colRes.data)) {
|
||||
// 只保留关键字段信息,减少 token 占用
|
||||
const cols = colRes.data.map((c: any) => {
|
||||
const keys = Object.keys(c);
|
||||
return {
|
||||
field: c.Field || c.field || c.COLUMN_NAME || c.column_name || c.Name || c.name || (keys.length > 0 ? c[keys[0]] : ''),
|
||||
type: c.Type || c.type || c.DATA_TYPE || c.data_type || (keys.length > 1 ? c[keys[1]] : ''),
|
||||
nullable: c.Null || c.null || c.IS_NULLABLE || c.is_nullable || c.Nullable || c.nullable || '',
|
||||
default: c.Default || c.default || c.COLUMN_DEFAULT || c.column_default || c.DefaultValue || '',
|
||||
comment: c.Comment || c.comment || c.COLUMN_COMMENT || c.column_comment || c.Description || '',
|
||||
};
|
||||
});
|
||||
// ⚠️ 在工具返回结果中直接注入强制警告,确保模型使用精确字段名
|
||||
const fieldNames = cols.map((c: any) => c.field).join(', ');
|
||||
resStr = `⚠️ 以下为 ${safeTable} 表的真实字段列表。生成 SQL 时只能使用这些 field 值作为列名,必须原样使用,禁止修改、缩写或自行拼凑字段名。\n可用字段:${fieldNames}\n详细信息:${JSON.stringify(cols)}`;
|
||||
success = true;
|
||||
} else { resStr = colRes?.message || 'Failed to fetch columns'; }
|
||||
} catch (e: any) {
|
||||
resStr = `获取字段列表失败: ${e?.message || e}`;
|
||||
}
|
||||
} else { resStr = 'Connection not found'; }
|
||||
break;
|
||||
}
|
||||
case 'get_indexes': {
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const { DBGetIndexes } = await import('../../wailsjs/go/app/App');
|
||||
const indexRes = await DBGetIndexes(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
|
||||
if (indexRes?.success && Array.isArray(indexRes.data)) {
|
||||
resStr = JSON.stringify(indexRes.data);
|
||||
success = true;
|
||||
} else { resStr = indexRes?.message || 'Failed to fetch indexes'; }
|
||||
} catch (e: any) {
|
||||
resStr = `获取索引定义失败: ${e?.message || e}`;
|
||||
}
|
||||
} else { resStr = 'Connection not found'; }
|
||||
break;
|
||||
}
|
||||
case 'get_foreign_keys': {
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const { DBGetForeignKeys } = await import('../../wailsjs/go/app/App');
|
||||
const foreignKeyRes = await DBGetForeignKeys(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
|
||||
if (foreignKeyRes?.success && Array.isArray(foreignKeyRes.data)) {
|
||||
resStr = JSON.stringify(foreignKeyRes.data);
|
||||
success = true;
|
||||
} else { resStr = foreignKeyRes?.message || 'Failed to fetch foreign keys'; }
|
||||
} catch (e: any) {
|
||||
resStr = `获取外键关系失败: ${e?.message || e}`;
|
||||
}
|
||||
} else { resStr = 'Connection not found'; }
|
||||
break;
|
||||
}
|
||||
case 'get_triggers': {
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const { DBGetTriggers } = await import('../../wailsjs/go/app/App');
|
||||
const triggerRes = await DBGetTriggers(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
|
||||
if (triggerRes?.success && Array.isArray(triggerRes.data)) {
|
||||
resStr = JSON.stringify(triggerRes.data);
|
||||
success = true;
|
||||
} else { resStr = triggerRes?.message || 'Failed to fetch triggers'; }
|
||||
} catch (e: any) {
|
||||
resStr = `获取触发器定义失败: ${e?.message || e}`;
|
||||
}
|
||||
} else { resStr = 'Connection not found'; }
|
||||
break;
|
||||
}
|
||||
case 'get_table_ddl': {
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const { DBShowCreateTable, DBGetColumns } = await import('../../wailsjs/go/app/App');
|
||||
const rpcConfig = buildRpcConnectionConfig(conn.config) as any;
|
||||
const toolResult = await resolveAITableSchemaToolResult({
|
||||
tableName: safeTable,
|
||||
fetchDDL: () => DBShowCreateTable(rpcConfig, safeDbName, safeTable),
|
||||
fetchColumns: () => DBGetColumns(rpcConfig, safeDbName, safeTable),
|
||||
});
|
||||
resStr = toolResult.content;
|
||||
success = toolResult.success;
|
||||
} catch (e: any) {
|
||||
resStr = `获取建表语句失败: ${e?.message || e}`;
|
||||
}
|
||||
} else { resStr = 'Connection not found'; }
|
||||
break;
|
||||
}
|
||||
case 'execute_sql': {
|
||||
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
|
||||
if (conn) {
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeSql = args.sql ? String(args.sql).trim() : '';
|
||||
// 安全级别检查
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
if (Service?.AICheckSQL) {
|
||||
const check = await Service.AICheckSQL(safeSql);
|
||||
if (!check.allowed) {
|
||||
resStr = `安全策略拦截:当前安全级别不允许执行 ${check.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const { DBQuery } = await import('../../wailsjs/go/app/App');
|
||||
const finalSql = buildAIReadonlyPreviewSQL(conn.config?.type || '', safeSql, 50, conn.config?.driver || '');
|
||||
const qRes = await DBQuery(buildRpcConnectionConfig(conn.config) as any, safeDbName, finalSql);
|
||||
if (qRes?.success) {
|
||||
const rows = Array.isArray(qRes.data) ? qRes.data : [];
|
||||
const limitedRows = rows.slice(0, 50);
|
||||
resStr = JSON.stringify({ rowCount: rows.length, data: limitedRows });
|
||||
success = true;
|
||||
} else { resStr = qRes?.message || 'SQL 执行失败'; }
|
||||
} catch (e: any) {
|
||||
resStr = `SQL 执行异常: ${e?.message || e}`;
|
||||
}
|
||||
} else { resStr = 'Connection not found'; }
|
||||
break;
|
||||
}
|
||||
default:
|
||||
if (mcpToolDescriptor) {
|
||||
try {
|
||||
const Service = (window as any).go?.aiservice?.Service;
|
||||
const toolResult = await Service?.AICallMCPTool?.(tc.function.name, tc.function.arguments || '{}');
|
||||
resStr = String(toolResult?.content || (toolResult?.isError ? 'MCP 工具调用失败' : ''));
|
||||
success = !!toolResult && !toolResult.isError;
|
||||
} catch (e: any) {
|
||||
resStr = `MCP 工具调用失败: ${e?.message || e}`;
|
||||
}
|
||||
} else {
|
||||
resStr = `Unknown function: ${tc.function.name}`;
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
resStr = e.message;
|
||||
}
|
||||
|
||||
const resolvedToolDescriptor = mcpToolMap.get(tc.function.name);
|
||||
const toolResultMsg: AIChatMessage = {
|
||||
const execution = await executeLocalAIToolCall({
|
||||
toolCall: tc,
|
||||
connections: currentConnections,
|
||||
mcpTools,
|
||||
toolContextMap: toolContextMapRef.current,
|
||||
});
|
||||
const toolResultMsg: AIChatMessage = buildToolResultMessage({
|
||||
id: genId(),
|
||||
role: 'tool',
|
||||
content: resStr,
|
||||
timestamp: Date.now(),
|
||||
tool_call_id: tc.id,
|
||||
tool_name: resolvedToolDescriptor?.title || resolvedToolDescriptor?.originalName || tc.function.name,
|
||||
success
|
||||
};
|
||||
toolCall: tc,
|
||||
execution,
|
||||
});
|
||||
results.push(toolResultMsg);
|
||||
|
||||
// 【实时写入】每执行完一条立即写入 store,让 UI 能实时看到进度打勾
|
||||
|
||||
131
frontend/src/components/ai/aiLocalToolExecutor.test.ts
Normal file
131
frontend/src/components/ai/aiLocalToolExecutor.test.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { AIMCPToolDescriptor, AIToolCall, SavedConnection } from '../../types';
|
||||
import { buildToolResultMessage, executeLocalAIToolCall } from './aiLocalToolExecutor';
|
||||
|
||||
const buildConnection = (): SavedConnection => ({
|
||||
id: 'conn-1',
|
||||
name: '主库',
|
||||
config: {
|
||||
type: 'mysql',
|
||||
host: '127.0.0.1',
|
||||
port: 3306,
|
||||
user: 'root',
|
||||
},
|
||||
});
|
||||
|
||||
const buildToolCall = (name: string, args: Record<string, unknown>): AIToolCall => ({
|
||||
id: `call-${name}`,
|
||||
type: 'function',
|
||||
function: {
|
||||
name,
|
||||
arguments: JSON.stringify(args),
|
||||
},
|
||||
});
|
||||
|
||||
describe('aiLocalToolExecutor', () => {
|
||||
it('caches validated table context after get_tables succeeds', async () => {
|
||||
const toolContextMap = new Map();
|
||||
const result = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('get_tables', { connectionId: 'conn-1', dbName: 'crm' }),
|
||||
connections: [buildConnection()],
|
||||
mcpTools: [],
|
||||
toolContextMap,
|
||||
runtime: {
|
||||
getDatabases: vi.fn(),
|
||||
getTables: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
data: [{ Table: 'users' }, { Table: 'orders' }],
|
||||
}),
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.content).toContain('users');
|
||||
expect(toolContextMap.get('conn-1:crm')).toEqual({
|
||||
connectionId: 'conn-1',
|
||||
dbName: 'crm',
|
||||
tables: ['users', 'orders'],
|
||||
});
|
||||
});
|
||||
|
||||
it('blocks execute_sql when the AI safety check rejects the statement', async () => {
|
||||
const query = vi.fn();
|
||||
const result = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('execute_sql', {
|
||||
connectionId: 'conn-1',
|
||||
dbName: 'crm',
|
||||
sql: 'DELETE FROM users',
|
||||
}),
|
||||
connections: [buildConnection()],
|
||||
mcpTools: [],
|
||||
toolContextMap: new Map(),
|
||||
runtime: {
|
||||
getDatabases: vi.fn(),
|
||||
getTables: vi.fn(),
|
||||
getColumns: vi.fn(),
|
||||
getIndexes: vi.fn(),
|
||||
getForeignKeys: vi.fn(),
|
||||
getTriggers: vi.fn(),
|
||||
showCreateTable: vi.fn(),
|
||||
query,
|
||||
checkSQL: vi.fn().mockResolvedValue({
|
||||
allowed: false,
|
||||
operationType: 'DELETE',
|
||||
}),
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.content).toContain('安全策略拦截');
|
||||
expect(query).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns index definitions and resolves the tool label for MCP descriptors', async () => {
|
||||
const mcpTools: AIMCPToolDescriptor[] = [{
|
||||
alias: 'custom_tool',
|
||||
originalName: 'custom_tool',
|
||||
serverId: 'server-1',
|
||||
serverName: 'demo',
|
||||
title: '自定义探针',
|
||||
description: '',
|
||||
}];
|
||||
const indexResult = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('get_indexes', {
|
||||
connectionId: 'conn-1',
|
||||
dbName: 'crm',
|
||||
tableName: 'users',
|
||||
}),
|
||||
connections: [buildConnection()],
|
||||
mcpTools,
|
||||
toolContextMap: new Map(),
|
||||
runtime: {
|
||||
getDatabases: vi.fn(),
|
||||
getTables: vi.fn(),
|
||||
getColumns: vi.fn(),
|
||||
getIndexes: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
data: [{ keyName: 'idx_users_email', nonUnique: 0 }],
|
||||
}),
|
||||
getForeignKeys: vi.fn(),
|
||||
getTriggers: vi.fn(),
|
||||
showCreateTable: vi.fn(),
|
||||
query: vi.fn(),
|
||||
},
|
||||
});
|
||||
const message = buildToolResultMessage({
|
||||
id: 'msg-1',
|
||||
timestamp: 1,
|
||||
toolCall: buildToolCall('custom_tool', {}),
|
||||
execution: {
|
||||
content: 'ok',
|
||||
success: true,
|
||||
toolName: '自定义探针',
|
||||
},
|
||||
});
|
||||
|
||||
expect(indexResult.success).toBe(true);
|
||||
expect(indexResult.content).toContain('idx_users_email');
|
||||
expect(message.tool_name).toBe('自定义探针');
|
||||
});
|
||||
});
|
||||
364
frontend/src/components/ai/aiLocalToolExecutor.ts
Normal file
364
frontend/src/components/ai/aiLocalToolExecutor.ts
Normal file
@@ -0,0 +1,364 @@
|
||||
import { DBGetDatabases, DBGetTables } from '../../../wailsjs/go/app/App';
|
||||
|
||||
import type { AIChatMessage, AIMCPToolDescriptor, AIToolCall, SavedConnection } from '../../types';
|
||||
import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig';
|
||||
import { buildAIReadonlyPreviewSQL } from '../../utils/aiSqlLimit';
|
||||
import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool';
|
||||
|
||||
export interface AIToolContextEntry {
|
||||
connectionId: string;
|
||||
dbName: string;
|
||||
tables: string[];
|
||||
}
|
||||
|
||||
interface AILocalToolRuntime {
|
||||
getDatabases: (config: any) => Promise<any>;
|
||||
getTables: (config: any, dbName: string) => Promise<any>;
|
||||
getColumns: (config: any, dbName: string, tableName: string) => Promise<any>;
|
||||
getIndexes: (config: any, dbName: string, tableName: string) => Promise<any>;
|
||||
getForeignKeys: (config: any, dbName: string, tableName: string) => Promise<any>;
|
||||
getTriggers: (config: any, dbName: string, tableName: string) => Promise<any>;
|
||||
showCreateTable: (config: any, dbName: string, tableName: string) => Promise<any>;
|
||||
query: (config: any, dbName: string, sql: string) => Promise<any>;
|
||||
checkSQL?: (sql: string) => Promise<{ allowed?: boolean; operationType?: string } | undefined>;
|
||||
callMCPTool?: (name: string, args: string) => Promise<{ content?: string; isError?: boolean } | undefined>;
|
||||
}
|
||||
|
||||
export interface ExecuteLocalAIToolCallOptions {
|
||||
toolCall: AIToolCall;
|
||||
connections: SavedConnection[];
|
||||
mcpTools: AIMCPToolDescriptor[];
|
||||
toolContextMap: Map<string, AIToolContextEntry>;
|
||||
runtime?: Partial<AILocalToolRuntime>;
|
||||
}
|
||||
|
||||
export interface ExecuteLocalAIToolCallResult {
|
||||
content: string;
|
||||
success: boolean;
|
||||
toolName: string;
|
||||
}
|
||||
|
||||
const buildDefaultRuntime = (): AILocalToolRuntime => ({
|
||||
getDatabases: DBGetDatabases,
|
||||
getTables: DBGetTables,
|
||||
getColumns: async (config, dbName, tableName) => {
|
||||
const mod = await import('../../../wailsjs/go/app/App');
|
||||
return mod.DBGetColumns(config, dbName, tableName);
|
||||
},
|
||||
getIndexes: async (config, dbName, tableName) => {
|
||||
const mod = await import('../../../wailsjs/go/app/App');
|
||||
return mod.DBGetIndexes(config, dbName, tableName);
|
||||
},
|
||||
getForeignKeys: async (config, dbName, tableName) => {
|
||||
const mod = await import('../../../wailsjs/go/app/App');
|
||||
return mod.DBGetForeignKeys(config, dbName, tableName);
|
||||
},
|
||||
getTriggers: async (config, dbName, tableName) => {
|
||||
const mod = await import('../../../wailsjs/go/app/App');
|
||||
return mod.DBGetTriggers(config, dbName, tableName);
|
||||
},
|
||||
showCreateTable: async (config, dbName, tableName) => {
|
||||
const mod = await import('../../../wailsjs/go/app/App');
|
||||
return mod.DBShowCreateTable(config, dbName, tableName);
|
||||
},
|
||||
query: async (config, dbName, sql) => {
|
||||
const mod = await import('../../../wailsjs/go/app/App');
|
||||
return mod.DBQuery(config, dbName, sql);
|
||||
},
|
||||
checkSQL: async (sql) => {
|
||||
const service = (window as any).go?.aiservice?.Service;
|
||||
if (typeof service?.AICheckSQL !== 'function') {
|
||||
return undefined;
|
||||
}
|
||||
return service.AICheckSQL(sql);
|
||||
},
|
||||
callMCPTool: async (name, args) => {
|
||||
const service = (window as any).go?.aiservice?.Service;
|
||||
if (typeof service?.AICallMCPTool !== 'function') {
|
||||
return undefined;
|
||||
}
|
||||
return service.AICallMCPTool(name, args);
|
||||
},
|
||||
});
|
||||
|
||||
const normalizeTableList = (rows: any[]): string[] =>
|
||||
rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string));
|
||||
|
||||
const normalizeColumns = (rows: any[]) =>
|
||||
rows.map((column) => {
|
||||
const keys = Object.keys(column);
|
||||
return {
|
||||
field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''),
|
||||
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''),
|
||||
nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '',
|
||||
default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '',
|
||||
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '',
|
||||
};
|
||||
});
|
||||
|
||||
const buildToolName = (toolCall: AIToolCall, descriptor?: AIMCPToolDescriptor) =>
|
||||
descriptor?.title || descriptor?.originalName || toolCall.function.name;
|
||||
|
||||
const findConnection = (connections: SavedConnection[], connectionId: string) =>
|
||||
connections.find((connection) => connection.id === connectionId);
|
||||
|
||||
export async function executeLocalAIToolCall({
|
||||
toolCall,
|
||||
connections,
|
||||
mcpTools,
|
||||
toolContextMap,
|
||||
runtime,
|
||||
}: ExecuteLocalAIToolCallOptions): Promise<ExecuteLocalAIToolCallResult> {
|
||||
const mergedRuntime = { ...buildDefaultRuntime(), ...(runtime || {}) };
|
||||
const descriptor = mcpTools.find((tool) => tool.alias === toolCall.function.name);
|
||||
let content = '';
|
||||
let success = false;
|
||||
|
||||
try {
|
||||
const args = JSON.parse(toolCall.function.arguments || '{}');
|
||||
switch (toolCall.function.name) {
|
||||
case 'get_connections': {
|
||||
const availableConnections = connections.map((connection) => ({
|
||||
id: connection.id,
|
||||
name: connection.name,
|
||||
type: connection.config?.type,
|
||||
host: (connection.config as any)?.host || (connection.config as any)?.addr || '',
|
||||
}));
|
||||
content = JSON.stringify(availableConnections);
|
||||
success = true;
|
||||
break;
|
||||
}
|
||||
case 'get_databases': {
|
||||
const connection = findConnection(connections, args.connectionId);
|
||||
if (!connection) {
|
||||
content = 'Connection not found';
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const result = await mergedRuntime.getDatabases(buildRpcConnectionConfig(connection.config) as any);
|
||||
if (result?.success && Array.isArray(result.data)) {
|
||||
let databaseNames = result.data.map((row: any) => row.Database || row.database || Object.values(row)[0]);
|
||||
if (databaseNames.length > 50) {
|
||||
databaseNames = [...databaseNames.slice(0, 50), '...(截断)'];
|
||||
}
|
||||
content = JSON.stringify(databaseNames);
|
||||
success = true;
|
||||
} else {
|
||||
content = result?.message || 'Failed to fetch DBs';
|
||||
}
|
||||
} catch (error: any) {
|
||||
content = `获取数据库列表失败: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'get_tables': {
|
||||
const connection = findConnection(connections, args.connectionId);
|
||||
if (!connection) {
|
||||
content = 'Connection not found';
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const rawDbName = args.dbName || args.database;
|
||||
const safeDbName = rawDbName ? String(rawDbName).trim() : '';
|
||||
const result = await mergedRuntime.getTables(buildRpcConnectionConfig(connection.config) as any, safeDbName);
|
||||
if (result?.success && Array.isArray(result.data)) {
|
||||
let tableNames = normalizeTableList(result.data);
|
||||
if (tableNames.length > 150) {
|
||||
tableNames = [...tableNames.slice(0, 150), '...(截断)'];
|
||||
}
|
||||
content = JSON.stringify(tableNames);
|
||||
success = true;
|
||||
toolContextMap.set(`${args.connectionId}:${safeDbName}`, {
|
||||
connectionId: args.connectionId,
|
||||
dbName: safeDbName,
|
||||
tables: tableNames.filter((tableName) => tableName !== '...(截断)'),
|
||||
});
|
||||
} else {
|
||||
content = result?.message || 'Failed to fetch Tables';
|
||||
}
|
||||
} catch (error: any) {
|
||||
content = `获取表列表失败: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'get_columns': {
|
||||
const connection = findConnection(connections, args.connectionId);
|
||||
if (!connection) {
|
||||
content = 'Connection not found';
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const result = await mergedRuntime.getColumns(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
|
||||
if (result?.success && Array.isArray(result.data)) {
|
||||
const columns = normalizeColumns(result.data);
|
||||
const fieldNames = columns.map((column) => column.field).join(', ');
|
||||
content = `⚠️ 以下为 ${safeTable} 表的真实字段列表。生成 SQL 时只能使用这些 field 值作为列名,必须原样使用,禁止修改、缩写或自行拼凑字段名。\n可用字段:${fieldNames}\n详细信息:${JSON.stringify(columns)}`;
|
||||
success = true;
|
||||
} else {
|
||||
content = result?.message || 'Failed to fetch columns';
|
||||
}
|
||||
} catch (error: any) {
|
||||
content = `获取字段列表失败: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'get_indexes': {
|
||||
const connection = findConnection(connections, args.connectionId);
|
||||
if (!connection) {
|
||||
content = 'Connection not found';
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const result = await mergedRuntime.getIndexes(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
|
||||
if (result?.success && Array.isArray(result.data)) {
|
||||
content = JSON.stringify(result.data);
|
||||
success = true;
|
||||
} else {
|
||||
content = result?.message || 'Failed to fetch indexes';
|
||||
}
|
||||
} catch (error: any) {
|
||||
content = `获取索引定义失败: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'get_foreign_keys': {
|
||||
const connection = findConnection(connections, args.connectionId);
|
||||
if (!connection) {
|
||||
content = 'Connection not found';
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const result = await mergedRuntime.getForeignKeys(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
|
||||
if (result?.success && Array.isArray(result.data)) {
|
||||
content = JSON.stringify(result.data);
|
||||
success = true;
|
||||
} else {
|
||||
content = result?.message || 'Failed to fetch foreign keys';
|
||||
}
|
||||
} catch (error: any) {
|
||||
content = `获取外键关系失败: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'get_triggers': {
|
||||
const connection = findConnection(connections, args.connectionId);
|
||||
if (!connection) {
|
||||
content = 'Connection not found';
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const result = await mergedRuntime.getTriggers(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
|
||||
if (result?.success && Array.isArray(result.data)) {
|
||||
content = JSON.stringify(result.data);
|
||||
success = true;
|
||||
} else {
|
||||
content = result?.message || 'Failed to fetch triggers';
|
||||
}
|
||||
} catch (error: any) {
|
||||
content = `获取触发器定义失败: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'get_table_ddl': {
|
||||
const connection = findConnection(connections, args.connectionId);
|
||||
if (!connection) {
|
||||
content = 'Connection not found';
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeTable = args.tableName ? String(args.tableName).trim() : '';
|
||||
const rpcConfig = buildRpcConnectionConfig(connection.config) as any;
|
||||
const result = await resolveAITableSchemaToolResult({
|
||||
tableName: safeTable,
|
||||
fetchDDL: () => mergedRuntime.showCreateTable(rpcConfig, safeDbName, safeTable),
|
||||
fetchColumns: () => mergedRuntime.getColumns(rpcConfig, safeDbName, safeTable),
|
||||
});
|
||||
content = result.content;
|
||||
success = result.success;
|
||||
} catch (error: any) {
|
||||
content = `获取建表语句失败: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'execute_sql': {
|
||||
const connection = findConnection(connections, args.connectionId);
|
||||
if (!connection) {
|
||||
content = 'Connection not found';
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
|
||||
const safeSql = args.sql ? String(args.sql).trim() : '';
|
||||
if (typeof mergedRuntime.checkSQL === 'function') {
|
||||
const checkResult = await mergedRuntime.checkSQL(safeSql);
|
||||
if (checkResult && checkResult.allowed === false) {
|
||||
content = `安全策略拦截:当前安全级别不允许执行 ${checkResult.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const finalSql = buildAIReadonlyPreviewSQL(connection.config?.type || '', safeSql, 50, connection.config?.driver || '');
|
||||
const result = await mergedRuntime.query(buildRpcConnectionConfig(connection.config) as any, safeDbName, finalSql);
|
||||
if (result?.success) {
|
||||
const rows = Array.isArray(result.data) ? result.data : [];
|
||||
content = JSON.stringify({ rowCount: rows.length, data: rows.slice(0, 50) });
|
||||
success = true;
|
||||
} else {
|
||||
content = result?.message || 'SQL 执行失败';
|
||||
}
|
||||
} catch (error: any) {
|
||||
content = `SQL 执行异常: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
if (!descriptor) {
|
||||
content = `Unknown function: ${toolCall.function.name}`;
|
||||
break;
|
||||
}
|
||||
try {
|
||||
const result = await mergedRuntime.callMCPTool?.(toolCall.function.name, toolCall.function.arguments || '{}');
|
||||
content = String(result?.content || (result?.isError ? 'MCP 工具调用失败' : ''));
|
||||
success = !!result && !result.isError;
|
||||
} catch (error: any) {
|
||||
content = `MCP 工具调用失败: ${error?.message || error}`;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (error: any) {
|
||||
content = error?.message || String(error);
|
||||
}
|
||||
|
||||
return {
|
||||
content,
|
||||
success,
|
||||
toolName: buildToolName(toolCall, descriptor),
|
||||
};
|
||||
}
|
||||
|
||||
export function buildToolResultMessage(params: {
|
||||
id: string;
|
||||
timestamp: number;
|
||||
toolCall: AIToolCall;
|
||||
execution: ExecuteLocalAIToolCallResult;
|
||||
}): AIChatMessage {
|
||||
const { id, timestamp, toolCall, execution } = params;
|
||||
return {
|
||||
id,
|
||||
role: 'tool',
|
||||
content: execution.content,
|
||||
timestamp,
|
||||
tool_call_id: toolCall.id,
|
||||
tool_name: execution.toolName,
|
||||
success: execution.success,
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user