From 802385464d12d71fe096bf2c1e4087bed203fc4c Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sun, 7 Jun 2026 22:40:07 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(ai-chat):=20?= =?UTF-8?q?=E6=8A=BD=E7=A6=BB=E6=9C=AC=E5=9C=B0=E5=B7=A5=E5=85=B7=E6=89=A7?= =?UTF-8?q?=E8=A1=8C=E5=99=A8=E5=B9=B6=E8=A1=A5=E9=BD=90=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 AIChatPanel 中本地工具执行 switch 抽成独立 helper,降低主面板耦合度 - 为表结构、SQL 安全拦截和工具结果映射补充独立单测 - 保持 AI 面板交互不变,补做构建与浏览器冒烟验证 --- .../AIChatPanel.message-boundary.test.tsx | 7 +- frontend/src/components/AIChatPanel.tsx | 244 +----------- .../components/ai/aiLocalToolExecutor.test.ts | 131 +++++++ .../src/components/ai/aiLocalToolExecutor.ts | 364 ++++++++++++++++++ 4 files changed, 515 insertions(+), 231 deletions(-) create mode 100644 frontend/src/components/ai/aiLocalToolExecutor.test.ts create mode 100644 frontend/src/components/ai/aiLocalToolExecutor.ts diff --git a/frontend/src/components/AIChatPanel.message-boundary.test.tsx b/frontend/src/components/AIChatPanel.message-boundary.test.tsx index 556cfcd..3144b13 100644 --- a/frontend/src/components/AIChatPanel.message-boundary.test.tsx +++ b/frontend/src/components/AIChatPanel.message-boundary.test.tsx @@ -23,16 +23,15 @@ describe('AIChatPanel message render isolation', () => { it('loads MCP tools and skills into the runtime tool chain', () => { expect(source).toContain('AIListMCPTools'); expect(source).toContain('AIGetSkills'); - expect(source).toContain('AICallMCPTool'); + expect(source).toContain('executeLocalAIToolCall'); expect(source).toContain('以下是当前启用的 Skill'); expect(source).toContain('buildAvailableAIChatTools'); }); it('teaches the runtime to use deeper schema tools when analyzing structure details', () => { expect(source).toContain('get_indexes、get_foreign_keys、get_triggers、get_table_ddl'); - expect(source).toContain("case 'get_indexes':"); - expect(source).toContain("case 'get_foreign_keys':"); - expect(source).toContain("case 'get_triggers':"); + expect(source).toContain('toolContextMap: toolContextMapRef.current'); + expect(source).toContain('buildToolResultMessage'); }); it('keeps the v2 history mode sorted by the latest updated session first', () => { diff --git a/frontend/src/components/AIChatPanel.tsx b/frontend/src/components/AIChatPanel.tsx index 5a9bd72..a54b22c 100644 --- a/frontend/src/components/AIChatPanel.tsx +++ b/frontend/src/components/AIChatPanel.tsx @@ -2,7 +2,6 @@ import React, { useState, useRef, useEffect, useCallback, useMemo } from 'react' import { createPortal } from 'react-dom'; import { useStore, loadAISessionsFromBackend, loadAISessionFromBackend } from '../store'; import { EventsOn, EventsOff } from '../../wailsjs/runtime'; -import { DBGetDatabases, DBGetTables } from '../../wailsjs/go/app/App'; import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme'; import type { AIChatMessage, @@ -28,13 +27,16 @@ import { buildMissingProviderNotice, buildModelFetchFailedNotice, } from '../utils/aiComposerNotice'; -import { buildAIReadonlyPreviewSQL } from '../utils/aiSqlLimit'; -import { resolveAITableSchemaToolResult } from '../utils/aiTableSchemaTool'; import { consumeAIChatSendShortcutOnKeyDown } from '../utils/aiChatSendShortcut'; import { toAIRequestMessage } from '../utils/aiMessagePayload'; import { getShortcutPlatform, resolveShortcutBinding } from '../utils/shortcuts'; import { isMacLikePlatform } from '../utils/appearance'; import { buildAvailableAIChatTools } from '../utils/aiToolRegistry'; +import { + buildToolResultMessage, + executeLocalAIToolCall, + type AIToolContextEntry, +} from './ai/aiLocalToolExecutor'; interface AIChatPanelProps { width?: number; @@ -1226,7 +1228,7 @@ SELECT * FROM users WHERE status = 1; }, [availableTools, skills, userPromptSettings]); // 记录所有成功的 get_tables 调用结果,用于表级精确匹配 - const toolContextMapRef = useRef>(new Map()); + const toolContextMapRef = useRef>(new Map()); const executeLocalTools = useCallback(async (toolCalls: AIToolCall[], currentAsstMsgId: string) => { const currentAsstMsg = (useStore.getState().aiChatHistory[sid] || []).find(m => m.id === currentAsstMsgId); @@ -1253,233 +1255,21 @@ SELECT * FROM users WHERE status = 1; } const results: AIChatMessage[] = []; - const mcpToolMap = new Map(mcpTools.map((tool) => [tool.alias, tool])); + const currentConnections = useStore.getState().connections; // 【串行逐条执行 + 实时写入 store】 for (const tc of toolCalls) { - let resStr = ''; - let success = false; - try { - const args = JSON.parse(tc.function.arguments || '{}'); - const mcpToolDescriptor = mcpToolMap.get(tc.function.name); - switch (tc.function.name) { - case 'get_connections': - const conns = useStore.getState().connections.map(c => ({ - id: c.id, - name: c.name, - type: c.config?.type, - host: (c.config as any)?.host || (c.config as any)?.addr || '' - })); - resStr = JSON.stringify(conns); - success = true; - break; - case 'get_databases': { - const conn = useStore.getState().connections.find(c => c.id === args.connectionId); - if (conn) { - try { - const dbRes = await DBGetDatabases(buildRpcConnectionConfig(conn.config) as any); - if (dbRes?.success && Array.isArray(dbRes.data)) { - let dNames = dbRes.data.map((r: any) => r.Database || r.database || Object.values(r)[0]); - if (dNames.length > 50) dNames = [...dNames.slice(0, 50), '...(截断)']; - resStr = JSON.stringify(dNames); - success = true; - } else { - resStr = dbRes?.message || 'Failed to fetch DBs'; - } - } catch (e: any) { - resStr = `获取数据库列表失败: ${e?.message || e}`; - } - } else { resStr = 'Connection not found'; } - break; - } - case 'get_tables': { - const conn = useStore.getState().connections.find(c => c.id === args.connectionId); - if (conn) { - try { - const rawDbName = args.dbName || args.database; - const safeDbName = rawDbName ? String(rawDbName).trim() : ''; - const tbRes = await DBGetTables(buildRpcConnectionConfig(conn.config) as any, safeDbName); - if (tbRes?.success && Array.isArray(tbRes.data)) { - let tNames = tbRes.data.map((r: any) => r.Table || r.table || Object.values(r)[0] as string); - if (tNames.length > 150) tNames = [...tNames.slice(0, 150), '...(截断)']; - resStr = JSON.stringify(tNames); - success = true; - // 🔑 记录已验证的上下文参数和表列表(用于后续表级精确匹配) - toolContextMapRef.current.set(`${args.connectionId}:${safeDbName}`, { - connectionId: args.connectionId, - dbName: safeDbName, - tables: tNames.filter((t: string) => t !== '...(截断)') - }); - } else { resStr = tbRes?.message || 'Failed to fetch Tables'; } - } catch (e: any) { - resStr = `获取表列表失败: ${e?.message || e}`; - } - } else { resStr = 'Connection not found'; } - break; - } - case 'get_columns': { - const conn = useStore.getState().connections.find(c => c.id === args.connectionId); - if (conn) { - try { - const safeDbName = args.dbName ? String(args.dbName).trim() : ''; - const safeTable = args.tableName ? String(args.tableName).trim() : ''; - const { DBGetColumns } = await import('../../wailsjs/go/app/App'); - const colRes = await DBGetColumns(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable); - if (colRes?.success && Array.isArray(colRes.data)) { - // 只保留关键字段信息,减少 token 占用 - const cols = colRes.data.map((c: any) => { - const keys = Object.keys(c); - return { - field: c.Field || c.field || c.COLUMN_NAME || c.column_name || c.Name || c.name || (keys.length > 0 ? c[keys[0]] : ''), - type: c.Type || c.type || c.DATA_TYPE || c.data_type || (keys.length > 1 ? c[keys[1]] : ''), - nullable: c.Null || c.null || c.IS_NULLABLE || c.is_nullable || c.Nullable || c.nullable || '', - default: c.Default || c.default || c.COLUMN_DEFAULT || c.column_default || c.DefaultValue || '', - comment: c.Comment || c.comment || c.COLUMN_COMMENT || c.column_comment || c.Description || '', - }; - }); - // ⚠️ 在工具返回结果中直接注入强制警告,确保模型使用精确字段名 - const fieldNames = cols.map((c: any) => c.field).join(', '); - resStr = `⚠️ 以下为 ${safeTable} 表的真实字段列表。生成 SQL 时只能使用这些 field 值作为列名,必须原样使用,禁止修改、缩写或自行拼凑字段名。\n可用字段:${fieldNames}\n详细信息:${JSON.stringify(cols)}`; - success = true; - } else { resStr = colRes?.message || 'Failed to fetch columns'; } - } catch (e: any) { - resStr = `获取字段列表失败: ${e?.message || e}`; - } - } else { resStr = 'Connection not found'; } - break; - } - case 'get_indexes': { - const conn = useStore.getState().connections.find(c => c.id === args.connectionId); - if (conn) { - try { - const safeDbName = args.dbName ? String(args.dbName).trim() : ''; - const safeTable = args.tableName ? String(args.tableName).trim() : ''; - const { DBGetIndexes } = await import('../../wailsjs/go/app/App'); - const indexRes = await DBGetIndexes(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable); - if (indexRes?.success && Array.isArray(indexRes.data)) { - resStr = JSON.stringify(indexRes.data); - success = true; - } else { resStr = indexRes?.message || 'Failed to fetch indexes'; } - } catch (e: any) { - resStr = `获取索引定义失败: ${e?.message || e}`; - } - } else { resStr = 'Connection not found'; } - break; - } - case 'get_foreign_keys': { - const conn = useStore.getState().connections.find(c => c.id === args.connectionId); - if (conn) { - try { - const safeDbName = args.dbName ? String(args.dbName).trim() : ''; - const safeTable = args.tableName ? String(args.tableName).trim() : ''; - const { DBGetForeignKeys } = await import('../../wailsjs/go/app/App'); - const foreignKeyRes = await DBGetForeignKeys(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable); - if (foreignKeyRes?.success && Array.isArray(foreignKeyRes.data)) { - resStr = JSON.stringify(foreignKeyRes.data); - success = true; - } else { resStr = foreignKeyRes?.message || 'Failed to fetch foreign keys'; } - } catch (e: any) { - resStr = `获取外键关系失败: ${e?.message || e}`; - } - } else { resStr = 'Connection not found'; } - break; - } - case 'get_triggers': { - const conn = useStore.getState().connections.find(c => c.id === args.connectionId); - if (conn) { - try { - const safeDbName = args.dbName ? String(args.dbName).trim() : ''; - const safeTable = args.tableName ? String(args.tableName).trim() : ''; - const { DBGetTriggers } = await import('../../wailsjs/go/app/App'); - const triggerRes = await DBGetTriggers(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable); - if (triggerRes?.success && Array.isArray(triggerRes.data)) { - resStr = JSON.stringify(triggerRes.data); - success = true; - } else { resStr = triggerRes?.message || 'Failed to fetch triggers'; } - } catch (e: any) { - resStr = `获取触发器定义失败: ${e?.message || e}`; - } - } else { resStr = 'Connection not found'; } - break; - } - case 'get_table_ddl': { - const conn = useStore.getState().connections.find(c => c.id === args.connectionId); - if (conn) { - try { - const safeDbName = args.dbName ? String(args.dbName).trim() : ''; - const safeTable = args.tableName ? String(args.tableName).trim() : ''; - const { DBShowCreateTable, DBGetColumns } = await import('../../wailsjs/go/app/App'); - const rpcConfig = buildRpcConnectionConfig(conn.config) as any; - const toolResult = await resolveAITableSchemaToolResult({ - tableName: safeTable, - fetchDDL: () => DBShowCreateTable(rpcConfig, safeDbName, safeTable), - fetchColumns: () => DBGetColumns(rpcConfig, safeDbName, safeTable), - }); - resStr = toolResult.content; - success = toolResult.success; - } catch (e: any) { - resStr = `获取建表语句失败: ${e?.message || e}`; - } - } else { resStr = 'Connection not found'; } - break; - } - case 'execute_sql': { - const conn = useStore.getState().connections.find(c => c.id === args.connectionId); - if (conn) { - try { - const safeDbName = args.dbName ? String(args.dbName).trim() : ''; - const safeSql = args.sql ? String(args.sql).trim() : ''; - // 安全级别检查 - const Service = (window as any).go?.aiservice?.Service; - if (Service?.AICheckSQL) { - const check = await Service.AICheckSQL(safeSql); - if (!check.allowed) { - resStr = `安全策略拦截:当前安全级别不允许执行 ${check.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`; - break; - } - } - const { DBQuery } = await import('../../wailsjs/go/app/App'); - const finalSql = buildAIReadonlyPreviewSQL(conn.config?.type || '', safeSql, 50, conn.config?.driver || ''); - const qRes = await DBQuery(buildRpcConnectionConfig(conn.config) as any, safeDbName, finalSql); - if (qRes?.success) { - const rows = Array.isArray(qRes.data) ? qRes.data : []; - const limitedRows = rows.slice(0, 50); - resStr = JSON.stringify({ rowCount: rows.length, data: limitedRows }); - success = true; - } else { resStr = qRes?.message || 'SQL 执行失败'; } - } catch (e: any) { - resStr = `SQL 执行异常: ${e?.message || e}`; - } - } else { resStr = 'Connection not found'; } - break; - } - default: - if (mcpToolDescriptor) { - try { - const Service = (window as any).go?.aiservice?.Service; - const toolResult = await Service?.AICallMCPTool?.(tc.function.name, tc.function.arguments || '{}'); - resStr = String(toolResult?.content || (toolResult?.isError ? 'MCP 工具调用失败' : '')); - success = !!toolResult && !toolResult.isError; - } catch (e: any) { - resStr = `MCP 工具调用失败: ${e?.message || e}`; - } - } else { - resStr = `Unknown function: ${tc.function.name}`; - } - } - } catch (e: any) { - resStr = e.message; - } - - const resolvedToolDescriptor = mcpToolMap.get(tc.function.name); - const toolResultMsg: AIChatMessage = { + const execution = await executeLocalAIToolCall({ + toolCall: tc, + connections: currentConnections, + mcpTools, + toolContextMap: toolContextMapRef.current, + }); + const toolResultMsg: AIChatMessage = buildToolResultMessage({ id: genId(), - role: 'tool', - content: resStr, timestamp: Date.now(), - tool_call_id: tc.id, - tool_name: resolvedToolDescriptor?.title || resolvedToolDescriptor?.originalName || tc.function.name, - success - }; + toolCall: tc, + execution, + }); results.push(toolResultMsg); // 【实时写入】每执行完一条立即写入 store,让 UI 能实时看到进度打勾 diff --git a/frontend/src/components/ai/aiLocalToolExecutor.test.ts b/frontend/src/components/ai/aiLocalToolExecutor.test.ts new file mode 100644 index 0000000..31cbe94 --- /dev/null +++ b/frontend/src/components/ai/aiLocalToolExecutor.test.ts @@ -0,0 +1,131 @@ +import { describe, expect, it, vi } from 'vitest'; + +import type { AIMCPToolDescriptor, AIToolCall, SavedConnection } from '../../types'; +import { buildToolResultMessage, executeLocalAIToolCall } from './aiLocalToolExecutor'; + +const buildConnection = (): SavedConnection => ({ + id: 'conn-1', + name: '主库', + config: { + type: 'mysql', + host: '127.0.0.1', + port: 3306, + user: 'root', + }, +}); + +const buildToolCall = (name: string, args: Record): AIToolCall => ({ + id: `call-${name}`, + type: 'function', + function: { + name, + arguments: JSON.stringify(args), + }, +}); + +describe('aiLocalToolExecutor', () => { + it('caches validated table context after get_tables succeeds', async () => { + const toolContextMap = new Map(); + const result = await executeLocalAIToolCall({ + toolCall: buildToolCall('get_tables', { connectionId: 'conn-1', dbName: 'crm' }), + connections: [buildConnection()], + mcpTools: [], + toolContextMap, + runtime: { + getDatabases: vi.fn(), + getTables: vi.fn().mockResolvedValue({ + success: true, + data: [{ Table: 'users' }, { Table: 'orders' }], + }), + }, + }); + + expect(result.success).toBe(true); + expect(result.content).toContain('users'); + expect(toolContextMap.get('conn-1:crm')).toEqual({ + connectionId: 'conn-1', + dbName: 'crm', + tables: ['users', 'orders'], + }); + }); + + it('blocks execute_sql when the AI safety check rejects the statement', async () => { + const query = vi.fn(); + const result = await executeLocalAIToolCall({ + toolCall: buildToolCall('execute_sql', { + connectionId: 'conn-1', + dbName: 'crm', + sql: 'DELETE FROM users', + }), + connections: [buildConnection()], + mcpTools: [], + toolContextMap: new Map(), + runtime: { + getDatabases: vi.fn(), + getTables: vi.fn(), + getColumns: vi.fn(), + getIndexes: vi.fn(), + getForeignKeys: vi.fn(), + getTriggers: vi.fn(), + showCreateTable: vi.fn(), + query, + checkSQL: vi.fn().mockResolvedValue({ + allowed: false, + operationType: 'DELETE', + }), + }, + }); + + expect(result.success).toBe(false); + expect(result.content).toContain('安全策略拦截'); + expect(query).not.toHaveBeenCalled(); + }); + + it('returns index definitions and resolves the tool label for MCP descriptors', async () => { + const mcpTools: AIMCPToolDescriptor[] = [{ + alias: 'custom_tool', + originalName: 'custom_tool', + serverId: 'server-1', + serverName: 'demo', + title: '自定义探针', + description: '', + }]; + const indexResult = await executeLocalAIToolCall({ + toolCall: buildToolCall('get_indexes', { + connectionId: 'conn-1', + dbName: 'crm', + tableName: 'users', + }), + connections: [buildConnection()], + mcpTools, + toolContextMap: new Map(), + runtime: { + getDatabases: vi.fn(), + getTables: vi.fn(), + getColumns: vi.fn(), + getIndexes: vi.fn().mockResolvedValue({ + success: true, + data: [{ keyName: 'idx_users_email', nonUnique: 0 }], + }), + getForeignKeys: vi.fn(), + getTriggers: vi.fn(), + showCreateTable: vi.fn(), + query: vi.fn(), + }, + }); + const message = buildToolResultMessage({ + id: 'msg-1', + timestamp: 1, + toolCall: buildToolCall('custom_tool', {}), + execution: { + content: 'ok', + success: true, + toolName: '自定义探针', + }, + }); + + expect(indexResult.success).toBe(true); + expect(indexResult.content).toContain('idx_users_email'); + expect(message.tool_name).toBe('自定义探针'); + }); +}); diff --git a/frontend/src/components/ai/aiLocalToolExecutor.ts b/frontend/src/components/ai/aiLocalToolExecutor.ts new file mode 100644 index 0000000..629145e --- /dev/null +++ b/frontend/src/components/ai/aiLocalToolExecutor.ts @@ -0,0 +1,364 @@ +import { DBGetDatabases, DBGetTables } from '../../../wailsjs/go/app/App'; + +import type { AIChatMessage, AIMCPToolDescriptor, AIToolCall, SavedConnection } from '../../types'; +import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig'; +import { buildAIReadonlyPreviewSQL } from '../../utils/aiSqlLimit'; +import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool'; + +export interface AIToolContextEntry { + connectionId: string; + dbName: string; + tables: string[]; +} + +interface AILocalToolRuntime { + getDatabases: (config: any) => Promise; + getTables: (config: any, dbName: string) => Promise; + getColumns: (config: any, dbName: string, tableName: string) => Promise; + getIndexes: (config: any, dbName: string, tableName: string) => Promise; + getForeignKeys: (config: any, dbName: string, tableName: string) => Promise; + getTriggers: (config: any, dbName: string, tableName: string) => Promise; + showCreateTable: (config: any, dbName: string, tableName: string) => Promise; + query: (config: any, dbName: string, sql: string) => Promise; + checkSQL?: (sql: string) => Promise<{ allowed?: boolean; operationType?: string } | undefined>; + callMCPTool?: (name: string, args: string) => Promise<{ content?: string; isError?: boolean } | undefined>; +} + +export interface ExecuteLocalAIToolCallOptions { + toolCall: AIToolCall; + connections: SavedConnection[]; + mcpTools: AIMCPToolDescriptor[]; + toolContextMap: Map; + runtime?: Partial; +} + +export interface ExecuteLocalAIToolCallResult { + content: string; + success: boolean; + toolName: string; +} + +const buildDefaultRuntime = (): AILocalToolRuntime => ({ + getDatabases: DBGetDatabases, + getTables: DBGetTables, + getColumns: async (config, dbName, tableName) => { + const mod = await import('../../../wailsjs/go/app/App'); + return mod.DBGetColumns(config, dbName, tableName); + }, + getIndexes: async (config, dbName, tableName) => { + const mod = await import('../../../wailsjs/go/app/App'); + return mod.DBGetIndexes(config, dbName, tableName); + }, + getForeignKeys: async (config, dbName, tableName) => { + const mod = await import('../../../wailsjs/go/app/App'); + return mod.DBGetForeignKeys(config, dbName, tableName); + }, + getTriggers: async (config, dbName, tableName) => { + const mod = await import('../../../wailsjs/go/app/App'); + return mod.DBGetTriggers(config, dbName, tableName); + }, + showCreateTable: async (config, dbName, tableName) => { + const mod = await import('../../../wailsjs/go/app/App'); + return mod.DBShowCreateTable(config, dbName, tableName); + }, + query: async (config, dbName, sql) => { + const mod = await import('../../../wailsjs/go/app/App'); + return mod.DBQuery(config, dbName, sql); + }, + checkSQL: async (sql) => { + const service = (window as any).go?.aiservice?.Service; + if (typeof service?.AICheckSQL !== 'function') { + return undefined; + } + return service.AICheckSQL(sql); + }, + callMCPTool: async (name, args) => { + const service = (window as any).go?.aiservice?.Service; + if (typeof service?.AICallMCPTool !== 'function') { + return undefined; + } + return service.AICallMCPTool(name, args); + }, +}); + +const normalizeTableList = (rows: any[]): string[] => + rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string)); + +const normalizeColumns = (rows: any[]) => + rows.map((column) => { + const keys = Object.keys(column); + return { + field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''), + type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''), + nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '', + default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '', + comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '', + }; + }); + +const buildToolName = (toolCall: AIToolCall, descriptor?: AIMCPToolDescriptor) => + descriptor?.title || descriptor?.originalName || toolCall.function.name; + +const findConnection = (connections: SavedConnection[], connectionId: string) => + connections.find((connection) => connection.id === connectionId); + +export async function executeLocalAIToolCall({ + toolCall, + connections, + mcpTools, + toolContextMap, + runtime, +}: ExecuteLocalAIToolCallOptions): Promise { + const mergedRuntime = { ...buildDefaultRuntime(), ...(runtime || {}) }; + const descriptor = mcpTools.find((tool) => tool.alias === toolCall.function.name); + let content = ''; + let success = false; + + try { + const args = JSON.parse(toolCall.function.arguments || '{}'); + switch (toolCall.function.name) { + case 'get_connections': { + const availableConnections = connections.map((connection) => ({ + id: connection.id, + name: connection.name, + type: connection.config?.type, + host: (connection.config as any)?.host || (connection.config as any)?.addr || '', + })); + content = JSON.stringify(availableConnections); + success = true; + break; + } + case 'get_databases': { + const connection = findConnection(connections, args.connectionId); + if (!connection) { + content = 'Connection not found'; + break; + } + try { + const result = await mergedRuntime.getDatabases(buildRpcConnectionConfig(connection.config) as any); + if (result?.success && Array.isArray(result.data)) { + let databaseNames = result.data.map((row: any) => row.Database || row.database || Object.values(row)[0]); + if (databaseNames.length > 50) { + databaseNames = [...databaseNames.slice(0, 50), '...(截断)']; + } + content = JSON.stringify(databaseNames); + success = true; + } else { + content = result?.message || 'Failed to fetch DBs'; + } + } catch (error: any) { + content = `获取数据库列表失败: ${error?.message || error}`; + } + break; + } + case 'get_tables': { + const connection = findConnection(connections, args.connectionId); + if (!connection) { + content = 'Connection not found'; + break; + } + try { + const rawDbName = args.dbName || args.database; + const safeDbName = rawDbName ? String(rawDbName).trim() : ''; + const result = await mergedRuntime.getTables(buildRpcConnectionConfig(connection.config) as any, safeDbName); + if (result?.success && Array.isArray(result.data)) { + let tableNames = normalizeTableList(result.data); + if (tableNames.length > 150) { + tableNames = [...tableNames.slice(0, 150), '...(截断)']; + } + content = JSON.stringify(tableNames); + success = true; + toolContextMap.set(`${args.connectionId}:${safeDbName}`, { + connectionId: args.connectionId, + dbName: safeDbName, + tables: tableNames.filter((tableName) => tableName !== '...(截断)'), + }); + } else { + content = result?.message || 'Failed to fetch Tables'; + } + } catch (error: any) { + content = `获取表列表失败: ${error?.message || error}`; + } + break; + } + case 'get_columns': { + const connection = findConnection(connections, args.connectionId); + if (!connection) { + content = 'Connection not found'; + break; + } + try { + const safeDbName = args.dbName ? String(args.dbName).trim() : ''; + const safeTable = args.tableName ? String(args.tableName).trim() : ''; + const result = await mergedRuntime.getColumns(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable); + if (result?.success && Array.isArray(result.data)) { + const columns = normalizeColumns(result.data); + const fieldNames = columns.map((column) => column.field).join(', '); + content = `⚠️ 以下为 ${safeTable} 表的真实字段列表。生成 SQL 时只能使用这些 field 值作为列名,必须原样使用,禁止修改、缩写或自行拼凑字段名。\n可用字段:${fieldNames}\n详细信息:${JSON.stringify(columns)}`; + success = true; + } else { + content = result?.message || 'Failed to fetch columns'; + } + } catch (error: any) { + content = `获取字段列表失败: ${error?.message || error}`; + } + break; + } + case 'get_indexes': { + const connection = findConnection(connections, args.connectionId); + if (!connection) { + content = 'Connection not found'; + break; + } + try { + const safeDbName = args.dbName ? String(args.dbName).trim() : ''; + const safeTable = args.tableName ? String(args.tableName).trim() : ''; + const result = await mergedRuntime.getIndexes(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable); + if (result?.success && Array.isArray(result.data)) { + content = JSON.stringify(result.data); + success = true; + } else { + content = result?.message || 'Failed to fetch indexes'; + } + } catch (error: any) { + content = `获取索引定义失败: ${error?.message || error}`; + } + break; + } + case 'get_foreign_keys': { + const connection = findConnection(connections, args.connectionId); + if (!connection) { + content = 'Connection not found'; + break; + } + try { + const safeDbName = args.dbName ? String(args.dbName).trim() : ''; + const safeTable = args.tableName ? String(args.tableName).trim() : ''; + const result = await mergedRuntime.getForeignKeys(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable); + if (result?.success && Array.isArray(result.data)) { + content = JSON.stringify(result.data); + success = true; + } else { + content = result?.message || 'Failed to fetch foreign keys'; + } + } catch (error: any) { + content = `获取外键关系失败: ${error?.message || error}`; + } + break; + } + case 'get_triggers': { + const connection = findConnection(connections, args.connectionId); + if (!connection) { + content = 'Connection not found'; + break; + } + try { + const safeDbName = args.dbName ? String(args.dbName).trim() : ''; + const safeTable = args.tableName ? String(args.tableName).trim() : ''; + const result = await mergedRuntime.getTriggers(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable); + if (result?.success && Array.isArray(result.data)) { + content = JSON.stringify(result.data); + success = true; + } else { + content = result?.message || 'Failed to fetch triggers'; + } + } catch (error: any) { + content = `获取触发器定义失败: ${error?.message || error}`; + } + break; + } + case 'get_table_ddl': { + const connection = findConnection(connections, args.connectionId); + if (!connection) { + content = 'Connection not found'; + break; + } + try { + const safeDbName = args.dbName ? String(args.dbName).trim() : ''; + const safeTable = args.tableName ? String(args.tableName).trim() : ''; + const rpcConfig = buildRpcConnectionConfig(connection.config) as any; + const result = await resolveAITableSchemaToolResult({ + tableName: safeTable, + fetchDDL: () => mergedRuntime.showCreateTable(rpcConfig, safeDbName, safeTable), + fetchColumns: () => mergedRuntime.getColumns(rpcConfig, safeDbName, safeTable), + }); + content = result.content; + success = result.success; + } catch (error: any) { + content = `获取建表语句失败: ${error?.message || error}`; + } + break; + } + case 'execute_sql': { + const connection = findConnection(connections, args.connectionId); + if (!connection) { + content = 'Connection not found'; + break; + } + try { + const safeDbName = args.dbName ? String(args.dbName).trim() : ''; + const safeSql = args.sql ? String(args.sql).trim() : ''; + if (typeof mergedRuntime.checkSQL === 'function') { + const checkResult = await mergedRuntime.checkSQL(safeSql); + if (checkResult && checkResult.allowed === false) { + content = `安全策略拦截:当前安全级别不允许执行 ${checkResult.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`; + break; + } + } + const finalSql = buildAIReadonlyPreviewSQL(connection.config?.type || '', safeSql, 50, connection.config?.driver || ''); + const result = await mergedRuntime.query(buildRpcConnectionConfig(connection.config) as any, safeDbName, finalSql); + if (result?.success) { + const rows = Array.isArray(result.data) ? result.data : []; + content = JSON.stringify({ rowCount: rows.length, data: rows.slice(0, 50) }); + success = true; + } else { + content = result?.message || 'SQL 执行失败'; + } + } catch (error: any) { + content = `SQL 执行异常: ${error?.message || error}`; + } + break; + } + default: { + if (!descriptor) { + content = `Unknown function: ${toolCall.function.name}`; + break; + } + try { + const result = await mergedRuntime.callMCPTool?.(toolCall.function.name, toolCall.function.arguments || '{}'); + content = String(result?.content || (result?.isError ? 'MCP 工具调用失败' : '')); + success = !!result && !result.isError; + } catch (error: any) { + content = `MCP 工具调用失败: ${error?.message || error}`; + } + break; + } + } + } catch (error: any) { + content = error?.message || String(error); + } + + return { + content, + success, + toolName: buildToolName(toolCall, descriptor), + }; +} + +export function buildToolResultMessage(params: { + id: string; + timestamp: number; + toolCall: AIToolCall; + execution: ExecuteLocalAIToolCallResult; +}): AIChatMessage { + const { id, timestamp, toolCall, execution } = params; + return { + id, + role: 'tool', + content: execution.content, + timestamp, + tool_call_id: toolCall.id, + tool_name: execution.toolName, + success: execution.success, + }; +}