♻️ refactor(ai-chat): 抽离本地工具执行器并补齐测试

- 将 AIChatPanel 中本地工具执行 switch 抽成独立 helper,降低主面板耦合度
- 为表结构、SQL 安全拦截和工具结果映射补充独立单测
- 保持 AI 面板交互不变,补做构建与浏览器冒烟验证
This commit is contained in:
Syngnat
2026-06-07 22:40:07 +08:00
parent eff2f7f63a
commit 802385464d
4 changed files with 515 additions and 231 deletions

View File

@@ -23,16 +23,15 @@ describe('AIChatPanel message render isolation', () => {
it('loads MCP tools and skills into the runtime tool chain', () => {
expect(source).toContain('AIListMCPTools');
expect(source).toContain('AIGetSkills');
expect(source).toContain('AICallMCPTool');
expect(source).toContain('executeLocalAIToolCall');
expect(source).toContain('以下是当前启用的 Skill');
expect(source).toContain('buildAvailableAIChatTools');
});
it('teaches the runtime to use deeper schema tools when analyzing structure details', () => {
expect(source).toContain('get_indexes、get_foreign_keys、get_triggers、get_table_ddl');
expect(source).toContain("case 'get_indexes':");
expect(source).toContain("case 'get_foreign_keys':");
expect(source).toContain("case 'get_triggers':");
expect(source).toContain('toolContextMap: toolContextMapRef.current');
expect(source).toContain('buildToolResultMessage');
});
it('keeps the v2 history mode sorted by the latest updated session first', () => {

View File

@@ -2,7 +2,6 @@ import React, { useState, useRef, useEffect, useCallback, useMemo } from 'react'
import { createPortal } from 'react-dom';
import { useStore, loadAISessionsFromBackend, loadAISessionFromBackend } from '../store';
import { EventsOn, EventsOff } from '../../wailsjs/runtime';
import { DBGetDatabases, DBGetTables } from '../../wailsjs/go/app/App';
import type { OverlayWorkbenchTheme } from '../utils/overlayWorkbenchTheme';
import type {
AIChatMessage,
@@ -28,13 +27,16 @@ import {
buildMissingProviderNotice,
buildModelFetchFailedNotice,
} from '../utils/aiComposerNotice';
import { buildAIReadonlyPreviewSQL } from '../utils/aiSqlLimit';
import { resolveAITableSchemaToolResult } from '../utils/aiTableSchemaTool';
import { consumeAIChatSendShortcutOnKeyDown } from '../utils/aiChatSendShortcut';
import { toAIRequestMessage } from '../utils/aiMessagePayload';
import { getShortcutPlatform, resolveShortcutBinding } from '../utils/shortcuts';
import { isMacLikePlatform } from '../utils/appearance';
import { buildAvailableAIChatTools } from '../utils/aiToolRegistry';
import {
buildToolResultMessage,
executeLocalAIToolCall,
type AIToolContextEntry,
} from './ai/aiLocalToolExecutor';
interface AIChatPanelProps {
width?: number;
@@ -1226,7 +1228,7 @@ SELECT * FROM users WHERE status = 1;
}, [availableTools, skills, userPromptSettings]);
// 记录所有成功的 get_tables 调用结果,用于表级精确匹配
const toolContextMapRef = useRef<Map<string, { connectionId: string; dbName: string; tables: string[] }>>(new Map());
const toolContextMapRef = useRef<Map<string, AIToolContextEntry>>(new Map());
const executeLocalTools = useCallback(async (toolCalls: AIToolCall[], currentAsstMsgId: string) => {
const currentAsstMsg = (useStore.getState().aiChatHistory[sid] || []).find(m => m.id === currentAsstMsgId);
@@ -1253,233 +1255,21 @@ SELECT * FROM users WHERE status = 1;
}
const results: AIChatMessage[] = [];
const mcpToolMap = new Map(mcpTools.map((tool) => [tool.alias, tool]));
const currentConnections = useStore.getState().connections;
// 【串行逐条执行 + 实时写入 store】
for (const tc of toolCalls) {
let resStr = '';
let success = false;
try {
const args = JSON.parse(tc.function.arguments || '{}');
const mcpToolDescriptor = mcpToolMap.get(tc.function.name);
switch (tc.function.name) {
case 'get_connections':
const conns = useStore.getState().connections.map(c => ({
id: c.id,
name: c.name,
type: c.config?.type,
host: (c.config as any)?.host || (c.config as any)?.addr || ''
}));
resStr = JSON.stringify(conns);
success = true;
break;
case 'get_databases': {
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
if (conn) {
try {
const dbRes = await DBGetDatabases(buildRpcConnectionConfig(conn.config) as any);
if (dbRes?.success && Array.isArray(dbRes.data)) {
let dNames = dbRes.data.map((r: any) => r.Database || r.database || Object.values(r)[0]);
if (dNames.length > 50) dNames = [...dNames.slice(0, 50), '...(截断)'];
resStr = JSON.stringify(dNames);
success = true;
} else {
resStr = dbRes?.message || 'Failed to fetch DBs';
}
} catch (e: any) {
resStr = `获取数据库列表失败: ${e?.message || e}`;
}
} else { resStr = 'Connection not found'; }
break;
}
case 'get_tables': {
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
if (conn) {
try {
const rawDbName = args.dbName || args.database;
const safeDbName = rawDbName ? String(rawDbName).trim() : '';
const tbRes = await DBGetTables(buildRpcConnectionConfig(conn.config) as any, safeDbName);
if (tbRes?.success && Array.isArray(tbRes.data)) {
let tNames = tbRes.data.map((r: any) => r.Table || r.table || Object.values(r)[0] as string);
if (tNames.length > 150) tNames = [...tNames.slice(0, 150), '...(截断)'];
resStr = JSON.stringify(tNames);
success = true;
// 🔑 记录已验证的上下文参数和表列表(用于后续表级精确匹配)
toolContextMapRef.current.set(`${args.connectionId}:${safeDbName}`, {
connectionId: args.connectionId,
dbName: safeDbName,
tables: tNames.filter((t: string) => t !== '...(截断)')
});
} else { resStr = tbRes?.message || 'Failed to fetch Tables'; }
} catch (e: any) {
resStr = `获取表列表失败: ${e?.message || e}`;
}
} else { resStr = 'Connection not found'; }
break;
}
case 'get_columns': {
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
if (conn) {
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const { DBGetColumns } = await import('../../wailsjs/go/app/App');
const colRes = await DBGetColumns(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
if (colRes?.success && Array.isArray(colRes.data)) {
// 只保留关键字段信息,减少 token 占用
const cols = colRes.data.map((c: any) => {
const keys = Object.keys(c);
return {
field: c.Field || c.field || c.COLUMN_NAME || c.column_name || c.Name || c.name || (keys.length > 0 ? c[keys[0]] : ''),
type: c.Type || c.type || c.DATA_TYPE || c.data_type || (keys.length > 1 ? c[keys[1]] : ''),
nullable: c.Null || c.null || c.IS_NULLABLE || c.is_nullable || c.Nullable || c.nullable || '',
default: c.Default || c.default || c.COLUMN_DEFAULT || c.column_default || c.DefaultValue || '',
comment: c.Comment || c.comment || c.COLUMN_COMMENT || c.column_comment || c.Description || '',
};
});
// ⚠️ 在工具返回结果中直接注入强制警告,确保模型使用精确字段名
const fieldNames = cols.map((c: any) => c.field).join(', ');
resStr = `⚠️ 以下为 ${safeTable} 表的真实字段列表。生成 SQL 时只能使用这些 field 值作为列名,必须原样使用,禁止修改、缩写或自行拼凑字段名。\n可用字段${fieldNames}\n详细信息${JSON.stringify(cols)}`;
success = true;
} else { resStr = colRes?.message || 'Failed to fetch columns'; }
} catch (e: any) {
resStr = `获取字段列表失败: ${e?.message || e}`;
}
} else { resStr = 'Connection not found'; }
break;
}
case 'get_indexes': {
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
if (conn) {
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const { DBGetIndexes } = await import('../../wailsjs/go/app/App');
const indexRes = await DBGetIndexes(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
if (indexRes?.success && Array.isArray(indexRes.data)) {
resStr = JSON.stringify(indexRes.data);
success = true;
} else { resStr = indexRes?.message || 'Failed to fetch indexes'; }
} catch (e: any) {
resStr = `获取索引定义失败: ${e?.message || e}`;
}
} else { resStr = 'Connection not found'; }
break;
}
case 'get_foreign_keys': {
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
if (conn) {
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const { DBGetForeignKeys } = await import('../../wailsjs/go/app/App');
const foreignKeyRes = await DBGetForeignKeys(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
if (foreignKeyRes?.success && Array.isArray(foreignKeyRes.data)) {
resStr = JSON.stringify(foreignKeyRes.data);
success = true;
} else { resStr = foreignKeyRes?.message || 'Failed to fetch foreign keys'; }
} catch (e: any) {
resStr = `获取外键关系失败: ${e?.message || e}`;
}
} else { resStr = 'Connection not found'; }
break;
}
case 'get_triggers': {
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
if (conn) {
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const { DBGetTriggers } = await import('../../wailsjs/go/app/App');
const triggerRes = await DBGetTriggers(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable);
if (triggerRes?.success && Array.isArray(triggerRes.data)) {
resStr = JSON.stringify(triggerRes.data);
success = true;
} else { resStr = triggerRes?.message || 'Failed to fetch triggers'; }
} catch (e: any) {
resStr = `获取触发器定义失败: ${e?.message || e}`;
}
} else { resStr = 'Connection not found'; }
break;
}
case 'get_table_ddl': {
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
if (conn) {
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const { DBShowCreateTable, DBGetColumns } = await import('../../wailsjs/go/app/App');
const rpcConfig = buildRpcConnectionConfig(conn.config) as any;
const toolResult = await resolveAITableSchemaToolResult({
tableName: safeTable,
fetchDDL: () => DBShowCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => DBGetColumns(rpcConfig, safeDbName, safeTable),
});
resStr = toolResult.content;
success = toolResult.success;
} catch (e: any) {
resStr = `获取建表语句失败: ${e?.message || e}`;
}
} else { resStr = 'Connection not found'; }
break;
}
case 'execute_sql': {
const conn = useStore.getState().connections.find(c => c.id === args.connectionId);
if (conn) {
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeSql = args.sql ? String(args.sql).trim() : '';
// 安全级别检查
const Service = (window as any).go?.aiservice?.Service;
if (Service?.AICheckSQL) {
const check = await Service.AICheckSQL(safeSql);
if (!check.allowed) {
resStr = `安全策略拦截:当前安全级别不允许执行 ${check.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`;
break;
}
}
const { DBQuery } = await import('../../wailsjs/go/app/App');
const finalSql = buildAIReadonlyPreviewSQL(conn.config?.type || '', safeSql, 50, conn.config?.driver || '');
const qRes = await DBQuery(buildRpcConnectionConfig(conn.config) as any, safeDbName, finalSql);
if (qRes?.success) {
const rows = Array.isArray(qRes.data) ? qRes.data : [];
const limitedRows = rows.slice(0, 50);
resStr = JSON.stringify({ rowCount: rows.length, data: limitedRows });
success = true;
} else { resStr = qRes?.message || 'SQL 执行失败'; }
} catch (e: any) {
resStr = `SQL 执行异常: ${e?.message || e}`;
}
} else { resStr = 'Connection not found'; }
break;
}
default:
if (mcpToolDescriptor) {
try {
const Service = (window as any).go?.aiservice?.Service;
const toolResult = await Service?.AICallMCPTool?.(tc.function.name, tc.function.arguments || '{}');
resStr = String(toolResult?.content || (toolResult?.isError ? 'MCP 工具调用失败' : ''));
success = !!toolResult && !toolResult.isError;
} catch (e: any) {
resStr = `MCP 工具调用失败: ${e?.message || e}`;
}
} else {
resStr = `Unknown function: ${tc.function.name}`;
}
}
} catch (e: any) {
resStr = e.message;
}
const resolvedToolDescriptor = mcpToolMap.get(tc.function.name);
const toolResultMsg: AIChatMessage = {
const execution = await executeLocalAIToolCall({
toolCall: tc,
connections: currentConnections,
mcpTools,
toolContextMap: toolContextMapRef.current,
});
const toolResultMsg: AIChatMessage = buildToolResultMessage({
id: genId(),
role: 'tool',
content: resStr,
timestamp: Date.now(),
tool_call_id: tc.id,
tool_name: resolvedToolDescriptor?.title || resolvedToolDescriptor?.originalName || tc.function.name,
success
};
toolCall: tc,
execution,
});
results.push(toolResultMsg);
// 【实时写入】每执行完一条立即写入 store让 UI 能实时看到进度打勾

View File

@@ -0,0 +1,131 @@
import { describe, expect, it, vi } from 'vitest';
import type { AIMCPToolDescriptor, AIToolCall, SavedConnection } from '../../types';
import { buildToolResultMessage, executeLocalAIToolCall } from './aiLocalToolExecutor';
const buildConnection = (): SavedConnection => ({
id: 'conn-1',
name: '主库',
config: {
type: 'mysql',
host: '127.0.0.1',
port: 3306,
user: 'root',
},
});
const buildToolCall = (name: string, args: Record<string, unknown>): AIToolCall => ({
id: `call-${name}`,
type: 'function',
function: {
name,
arguments: JSON.stringify(args),
},
});
describe('aiLocalToolExecutor', () => {
it('caches validated table context after get_tables succeeds', async () => {
const toolContextMap = new Map();
const result = await executeLocalAIToolCall({
toolCall: buildToolCall('get_tables', { connectionId: 'conn-1', dbName: 'crm' }),
connections: [buildConnection()],
mcpTools: [],
toolContextMap,
runtime: {
getDatabases: vi.fn(),
getTables: vi.fn().mockResolvedValue({
success: true,
data: [{ Table: 'users' }, { Table: 'orders' }],
}),
},
});
expect(result.success).toBe(true);
expect(result.content).toContain('users');
expect(toolContextMap.get('conn-1:crm')).toEqual({
connectionId: 'conn-1',
dbName: 'crm',
tables: ['users', 'orders'],
});
});
it('blocks execute_sql when the AI safety check rejects the statement', async () => {
const query = vi.fn();
const result = await executeLocalAIToolCall({
toolCall: buildToolCall('execute_sql', {
connectionId: 'conn-1',
dbName: 'crm',
sql: 'DELETE FROM users',
}),
connections: [buildConnection()],
mcpTools: [],
toolContextMap: new Map(),
runtime: {
getDatabases: vi.fn(),
getTables: vi.fn(),
getColumns: vi.fn(),
getIndexes: vi.fn(),
getForeignKeys: vi.fn(),
getTriggers: vi.fn(),
showCreateTable: vi.fn(),
query,
checkSQL: vi.fn().mockResolvedValue({
allowed: false,
operationType: 'DELETE',
}),
},
});
expect(result.success).toBe(false);
expect(result.content).toContain('安全策略拦截');
expect(query).not.toHaveBeenCalled();
});
it('returns index definitions and resolves the tool label for MCP descriptors', async () => {
const mcpTools: AIMCPToolDescriptor[] = [{
alias: 'custom_tool',
originalName: 'custom_tool',
serverId: 'server-1',
serverName: 'demo',
title: '自定义探针',
description: '',
}];
const indexResult = await executeLocalAIToolCall({
toolCall: buildToolCall('get_indexes', {
connectionId: 'conn-1',
dbName: 'crm',
tableName: 'users',
}),
connections: [buildConnection()],
mcpTools,
toolContextMap: new Map(),
runtime: {
getDatabases: vi.fn(),
getTables: vi.fn(),
getColumns: vi.fn(),
getIndexes: vi.fn().mockResolvedValue({
success: true,
data: [{ keyName: 'idx_users_email', nonUnique: 0 }],
}),
getForeignKeys: vi.fn(),
getTriggers: vi.fn(),
showCreateTable: vi.fn(),
query: vi.fn(),
},
});
const message = buildToolResultMessage({
id: 'msg-1',
timestamp: 1,
toolCall: buildToolCall('custom_tool', {}),
execution: {
content: 'ok',
success: true,
toolName: '自定义探针',
},
});
expect(indexResult.success).toBe(true);
expect(indexResult.content).toContain('idx_users_email');
expect(message.tool_name).toBe('自定义探针');
});
});

View File

@@ -0,0 +1,364 @@
import { DBGetDatabases, DBGetTables } from '../../../wailsjs/go/app/App';
import type { AIChatMessage, AIMCPToolDescriptor, AIToolCall, SavedConnection } from '../../types';
import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig';
import { buildAIReadonlyPreviewSQL } from '../../utils/aiSqlLimit';
import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool';
export interface AIToolContextEntry {
connectionId: string;
dbName: string;
tables: string[];
}
interface AILocalToolRuntime {
getDatabases: (config: any) => Promise<any>;
getTables: (config: any, dbName: string) => Promise<any>;
getColumns: (config: any, dbName: string, tableName: string) => Promise<any>;
getIndexes: (config: any, dbName: string, tableName: string) => Promise<any>;
getForeignKeys: (config: any, dbName: string, tableName: string) => Promise<any>;
getTriggers: (config: any, dbName: string, tableName: string) => Promise<any>;
showCreateTable: (config: any, dbName: string, tableName: string) => Promise<any>;
query: (config: any, dbName: string, sql: string) => Promise<any>;
checkSQL?: (sql: string) => Promise<{ allowed?: boolean; operationType?: string } | undefined>;
callMCPTool?: (name: string, args: string) => Promise<{ content?: string; isError?: boolean } | undefined>;
}
export interface ExecuteLocalAIToolCallOptions {
toolCall: AIToolCall;
connections: SavedConnection[];
mcpTools: AIMCPToolDescriptor[];
toolContextMap: Map<string, AIToolContextEntry>;
runtime?: Partial<AILocalToolRuntime>;
}
export interface ExecuteLocalAIToolCallResult {
content: string;
success: boolean;
toolName: string;
}
const buildDefaultRuntime = (): AILocalToolRuntime => ({
getDatabases: DBGetDatabases,
getTables: DBGetTables,
getColumns: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetColumns(config, dbName, tableName);
},
getIndexes: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetIndexes(config, dbName, tableName);
},
getForeignKeys: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetForeignKeys(config, dbName, tableName);
},
getTriggers: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetTriggers(config, dbName, tableName);
},
showCreateTable: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBShowCreateTable(config, dbName, tableName);
},
query: async (config, dbName, sql) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBQuery(config, dbName, sql);
},
checkSQL: async (sql) => {
const service = (window as any).go?.aiservice?.Service;
if (typeof service?.AICheckSQL !== 'function') {
return undefined;
}
return service.AICheckSQL(sql);
},
callMCPTool: async (name, args) => {
const service = (window as any).go?.aiservice?.Service;
if (typeof service?.AICallMCPTool !== 'function') {
return undefined;
}
return service.AICallMCPTool(name, args);
},
});
const normalizeTableList = (rows: any[]): string[] =>
rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string));
const normalizeColumns = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''),
nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '',
default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '',
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '',
};
});
const buildToolName = (toolCall: AIToolCall, descriptor?: AIMCPToolDescriptor) =>
descriptor?.title || descriptor?.originalName || toolCall.function.name;
const findConnection = (connections: SavedConnection[], connectionId: string) =>
connections.find((connection) => connection.id === connectionId);
export async function executeLocalAIToolCall({
toolCall,
connections,
mcpTools,
toolContextMap,
runtime,
}: ExecuteLocalAIToolCallOptions): Promise<ExecuteLocalAIToolCallResult> {
const mergedRuntime = { ...buildDefaultRuntime(), ...(runtime || {}) };
const descriptor = mcpTools.find((tool) => tool.alias === toolCall.function.name);
let content = '';
let success = false;
try {
const args = JSON.parse(toolCall.function.arguments || '{}');
switch (toolCall.function.name) {
case 'get_connections': {
const availableConnections = connections.map((connection) => ({
id: connection.id,
name: connection.name,
type: connection.config?.type,
host: (connection.config as any)?.host || (connection.config as any)?.addr || '',
}));
content = JSON.stringify(availableConnections);
success = true;
break;
}
case 'get_databases': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const result = await mergedRuntime.getDatabases(buildRpcConnectionConfig(connection.config) as any);
if (result?.success && Array.isArray(result.data)) {
let databaseNames = result.data.map((row: any) => row.Database || row.database || Object.values(row)[0]);
if (databaseNames.length > 50) {
databaseNames = [...databaseNames.slice(0, 50), '...(截断)'];
}
content = JSON.stringify(databaseNames);
success = true;
} else {
content = result?.message || 'Failed to fetch DBs';
}
} catch (error: any) {
content = `获取数据库列表失败: ${error?.message || error}`;
}
break;
}
case 'get_tables': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const rawDbName = args.dbName || args.database;
const safeDbName = rawDbName ? String(rawDbName).trim() : '';
const result = await mergedRuntime.getTables(buildRpcConnectionConfig(connection.config) as any, safeDbName);
if (result?.success && Array.isArray(result.data)) {
let tableNames = normalizeTableList(result.data);
if (tableNames.length > 150) {
tableNames = [...tableNames.slice(0, 150), '...(截断)'];
}
content = JSON.stringify(tableNames);
success = true;
toolContextMap.set(`${args.connectionId}:${safeDbName}`, {
connectionId: args.connectionId,
dbName: safeDbName,
tables: tableNames.filter((tableName) => tableName !== '...(截断)'),
});
} else {
content = result?.message || 'Failed to fetch Tables';
}
} catch (error: any) {
content = `获取表列表失败: ${error?.message || error}`;
}
break;
}
case 'get_columns': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await mergedRuntime.getColumns(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
const columns = normalizeColumns(result.data);
const fieldNames = columns.map((column) => column.field).join(', ');
content = `⚠️ 以下为 ${safeTable} 表的真实字段列表。生成 SQL 时只能使用这些 field 值作为列名,必须原样使用,禁止修改、缩写或自行拼凑字段名。\n可用字段${fieldNames}\n详细信息${JSON.stringify(columns)}`;
success = true;
} else {
content = result?.message || 'Failed to fetch columns';
}
} catch (error: any) {
content = `获取字段列表失败: ${error?.message || error}`;
}
break;
}
case 'get_indexes': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await mergedRuntime.getIndexes(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
content = JSON.stringify(result.data);
success = true;
} else {
content = result?.message || 'Failed to fetch indexes';
}
} catch (error: any) {
content = `获取索引定义失败: ${error?.message || error}`;
}
break;
}
case 'get_foreign_keys': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await mergedRuntime.getForeignKeys(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
content = JSON.stringify(result.data);
success = true;
} else {
content = result?.message || 'Failed to fetch foreign keys';
}
} catch (error: any) {
content = `获取外键关系失败: ${error?.message || error}`;
}
break;
}
case 'get_triggers': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await mergedRuntime.getTriggers(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
content = JSON.stringify(result.data);
success = true;
} else {
content = result?.message || 'Failed to fetch triggers';
}
} catch (error: any) {
content = `获取触发器定义失败: ${error?.message || error}`;
}
break;
}
case 'get_table_ddl': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const rpcConfig = buildRpcConnectionConfig(connection.config) as any;
const result = await resolveAITableSchemaToolResult({
tableName: safeTable,
fetchDDL: () => mergedRuntime.showCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => mergedRuntime.getColumns(rpcConfig, safeDbName, safeTable),
});
content = result.content;
success = result.success;
} catch (error: any) {
content = `获取建表语句失败: ${error?.message || error}`;
}
break;
}
case 'execute_sql': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeSql = args.sql ? String(args.sql).trim() : '';
if (typeof mergedRuntime.checkSQL === 'function') {
const checkResult = await mergedRuntime.checkSQL(safeSql);
if (checkResult && checkResult.allowed === false) {
content = `安全策略拦截:当前安全级别不允许执行 ${checkResult.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`;
break;
}
}
const finalSql = buildAIReadonlyPreviewSQL(connection.config?.type || '', safeSql, 50, connection.config?.driver || '');
const result = await mergedRuntime.query(buildRpcConnectionConfig(connection.config) as any, safeDbName, finalSql);
if (result?.success) {
const rows = Array.isArray(result.data) ? result.data : [];
content = JSON.stringify({ rowCount: rows.length, data: rows.slice(0, 50) });
success = true;
} else {
content = result?.message || 'SQL 执行失败';
}
} catch (error: any) {
content = `SQL 执行异常: ${error?.message || error}`;
}
break;
}
default: {
if (!descriptor) {
content = `Unknown function: ${toolCall.function.name}`;
break;
}
try {
const result = await mergedRuntime.callMCPTool?.(toolCall.function.name, toolCall.function.arguments || '{}');
content = String(result?.content || (result?.isError ? 'MCP 工具调用失败' : ''));
success = !!result && !result.isError;
} catch (error: any) {
content = `MCP 工具调用失败: ${error?.message || error}`;
}
break;
}
}
} catch (error: any) {
content = error?.message || String(error);
}
return {
content,
success,
toolName: buildToolName(toolCall, descriptor),
};
}
export function buildToolResultMessage(params: {
id: string;
timestamp: number;
toolCall: AIToolCall;
execution: ExecuteLocalAIToolCallResult;
}): AIChatMessage {
const { id, timestamp, toolCall, execution } = params;
return {
id,
role: 'tool',
content: execution.content,
timestamp,
tool_call_id: toolCall.id,
tool_name: execution.toolName,
success: execution.success,
};
}