♻️ refactor(ai-tools): 拆分本地工具执行链路

- 抽离 AI 本地工具默认运行时封装

- 拆分数据库工具执行器并保留现有返回协议

- 收敛 aiLocalToolExecutor 为工具路由与结果包装
This commit is contained in:
Syngnat
2026-06-09 00:16:57 +08:00
parent 20e817b203
commit f7ed6f8e61
3 changed files with 638 additions and 648 deletions

View File

@@ -0,0 +1,487 @@
import type { SavedConnection } from '../../types';
import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig';
import { buildAIReadonlyPreviewSQL } from '../../utils/aiSqlLimit';
import { buildPaginatedSelectSQL, quoteQualifiedIdent } from '../../utils/sql';
import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool';
import type { AILocalToolRuntime, AIToolContextEntry } from './aiLocalToolRuntime';
interface ExecuteDatabaseToolCallOptions {
toolName: string;
args: Record<string, any>;
connections: SavedConnection[];
toolContextMap: Map<string, AIToolContextEntry>;
runtime: AILocalToolRuntime;
}
interface ToolExecutionResult {
content: string;
success: boolean;
}
const findConnection = (connections: SavedConnection[], connectionId: string) =>
connections.find((connection) => connection.id === connectionId);
const normalizeTableList = (rows: any[]): string[] =>
rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string));
const normalizeColumns = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''),
nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '',
default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '',
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '',
};
});
const normalizeColumnsWithTable = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
tableName: column.TableName || column.tableName || column.TABLE_NAME || column.table_name || (keys.length > 0 ? column[keys[0]] : ''),
name: column.Name || column.name || column.COLUMN_NAME || column.column_name || (keys.length > 1 ? column[keys[1]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 2 ? column[keys[2]] : ''),
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || '',
};
});
const normalizePreviewLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 20);
if (value < 1) return 1;
if (value > 100) return 100;
return value;
};
const normalizeTableLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 80);
if (value < 1) return 1;
if (value > 200) return 200;
return value;
};
const normalizePerTableColumnLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 8);
if (value < 1) return 1;
if (value > 30) return 30;
return value;
};
const buildPreviewSQLForTable = (connection: SavedConnection, tableName: string, limit: number): string => {
const dbType = String(connection.config?.type || '').trim();
return buildPaginatedSelectSQL(
dbType,
`SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)}`,
'',
limit,
0,
);
};
const resolveConnectionOrFailure = (
connections: SavedConnection[],
connectionId: string,
): { connection: SavedConnection | null; failure?: ToolExecutionResult } => {
const connection = findConnection(connections, connectionId);
if (!connection) {
return {
connection: null,
failure: {
content: 'Connection not found',
success: false,
},
};
}
return { connection };
};
export async function executeDatabaseToolCall(
options: ExecuteDatabaseToolCallOptions,
): Promise<ToolExecutionResult | null> {
const { toolName, args, connections, toolContextMap, runtime } = options;
switch (toolName) {
case 'get_connections': {
const availableConnections = connections.map((connection) => ({
id: connection.id,
name: connection.name,
type: connection.config?.type,
host: (connection.config as any)?.host || (connection.config as any)?.addr || '',
}));
return {
content: JSON.stringify(availableConnections),
success: true,
};
}
case 'get_databases': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const result = await runtime.getDatabases(buildRpcConnectionConfig(resolved.connection.config) as any);
if (result?.success && Array.isArray(result.data)) {
let databaseNames = result.data.map((row: any) => row.Database || row.database || Object.values(row)[0]);
if (databaseNames.length > 50) {
databaseNames = [...databaseNames.slice(0, 50), '...(截断)'];
}
return { content: JSON.stringify(databaseNames), success: true };
}
return { content: result?.message || 'Failed to fetch DBs', success: false };
} catch (error: any) {
return { content: `获取数据库列表失败: ${error?.message || error}`, success: false };
}
}
case 'get_tables': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const rawDbName = args.dbName || args.database;
const safeDbName = rawDbName ? String(rawDbName).trim() : '';
const result = await runtime.getTables(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName);
if (result?.success && Array.isArray(result.data)) {
let tableNames = normalizeTableList(result.data);
if (tableNames.length > 150) {
tableNames = [...tableNames.slice(0, 150), '...(截断)'];
}
toolContextMap.set(`${args.connectionId}:${safeDbName}`, {
connectionId: args.connectionId,
dbName: safeDbName,
tables: tableNames.filter((tableName) => tableName !== '...(截断)'),
});
return { content: JSON.stringify(tableNames), success: true };
}
return { content: result?.message || 'Failed to fetch Tables', success: false };
} catch (error: any) {
return { content: `获取表列表失败: ${error?.message || error}`, success: false };
}
}
case 'get_all_columns': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const result = await runtime.getAllColumns(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName);
if (result?.success && Array.isArray(result.data)) {
const allColumns = normalizeColumnsWithTable(result.data);
const tableNames = Array.from(new Set(allColumns.map((column) => column.tableName).filter(Boolean)));
const limitedColumns = allColumns.slice(0, 400);
return {
content: JSON.stringify({
dbName: safeDbName,
tableCount: tableNames.length,
totalColumns: allColumns.length,
truncated: allColumns.length > limitedColumns.length,
columns: limitedColumns,
}),
success: true,
};
}
return { content: result?.message || 'Failed to fetch all columns', success: false };
} catch (error: any) {
return { content: `获取全库字段摘要失败: ${error?.message || error}`, success: false };
}
}
case 'get_columns': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await runtime.getColumns(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
const columns = normalizeColumns(result.data);
const fieldNames = columns.map((column) => column.field).join(', ');
return {
content: `⚠️ 以下为 ${safeTable} 表的真实字段列表。生成 SQL 时只能使用这些 field 值作为列名,必须原样使用,禁止修改、缩写或自行拼凑字段名。\n可用字段${fieldNames}\n详细信息${JSON.stringify(columns)}`,
success: true,
};
}
return { content: result?.message || 'Failed to fetch columns', success: false };
} catch (error: any) {
return { content: `获取字段列表失败: ${error?.message || error}`, success: false };
}
}
case 'get_indexes': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await runtime.getIndexes(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, safeTable);
return {
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch indexes'),
success: !!result?.success && Array.isArray(result.data),
};
} catch (error: any) {
return { content: `获取索引定义失败: ${error?.message || error}`, success: false };
}
}
case 'get_foreign_keys': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await runtime.getForeignKeys(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, safeTable);
return {
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch foreign keys'),
success: !!result?.success && Array.isArray(result.data),
};
} catch (error: any) {
return { content: `获取外键关系失败: ${error?.message || error}`, success: false };
}
}
case 'get_triggers': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await runtime.getTriggers(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, safeTable);
return {
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch triggers'),
success: !!result?.success && Array.isArray(result.data),
};
} catch (error: any) {
return { content: `获取触发器定义失败: ${error?.message || error}`, success: false };
}
}
case 'get_table_ddl': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const rpcConfig = buildRpcConnectionConfig(resolved.connection.config) as any;
const result = await resolveAITableSchemaToolResult({
tableName: safeTable,
fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable),
});
return { content: result.content, success: result.success };
} catch (error: any) {
return { content: `获取建表语句失败: ${error?.message || error}`, success: false };
}
}
case 'inspect_table_bundle': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
if (!safeTable) return { content: 'tableName 不能为空', success: false };
const includeSampleRows = args.includeSampleRows === true;
const sampleLimit = normalizePreviewLimit(args.sampleLimit ?? 10);
const rpcConfig = buildRpcConnectionConfig(resolved.connection.config) as any;
const results = await Promise.allSettled([
runtime.getColumns(rpcConfig, safeDbName, safeTable),
runtime.getIndexes(rpcConfig, safeDbName, safeTable),
runtime.getForeignKeys(rpcConfig, safeDbName, safeTable),
runtime.getTriggers(rpcConfig, safeDbName, safeTable),
resolveAITableSchemaToolResult({
tableName: safeTable,
fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable),
}),
includeSampleRows
? runtime.query(rpcConfig, safeDbName, buildPreviewSQLForTable(resolved.connection, safeTable, sampleLimit))
: Promise.resolve(undefined),
]);
const warnings: string[] = [];
const payload: Record<string, unknown> = {
dbName: safeDbName,
tableName: safeTable,
columns: [],
indexes: [],
foreignKeys: [],
triggers: [],
ddl: '',
};
const columnsResult = results[0];
const indexesResult = results[1];
const foreignKeysResult = results[2];
const triggersResult = results[3];
const ddlResult = results[4];
const sampleRowsResult = results[5];
if (columnsResult.status === 'fulfilled' && columnsResult.value?.success && Array.isArray(columnsResult.value.data)) {
payload.columns = normalizeColumns(columnsResult.value.data);
} else {
warnings.push(`字段列表获取失败:${columnsResult.status === 'fulfilled' ? (columnsResult.value?.message || '未知错误') : String(columnsResult.reason)}`);
}
if (indexesResult.status === 'fulfilled' && indexesResult.value?.success && Array.isArray(indexesResult.value.data)) {
payload.indexes = indexesResult.value.data;
} else {
warnings.push(`索引定义获取失败:${indexesResult.status === 'fulfilled' ? (indexesResult.value?.message || '未知错误') : String(indexesResult.reason)}`);
}
if (foreignKeysResult.status === 'fulfilled' && foreignKeysResult.value?.success && Array.isArray(foreignKeysResult.value.data)) {
payload.foreignKeys = foreignKeysResult.value.data;
} else {
warnings.push(`外键关系获取失败:${foreignKeysResult.status === 'fulfilled' ? (foreignKeysResult.value?.message || '未知错误') : String(foreignKeysResult.reason)}`);
}
if (triggersResult.status === 'fulfilled' && triggersResult.value?.success && Array.isArray(triggersResult.value.data)) {
payload.triggers = triggersResult.value.data;
} else {
warnings.push(`触发器获取失败:${triggersResult.status === 'fulfilled' ? (triggersResult.value?.message || '未知错误') : String(triggersResult.reason)}`);
}
if (ddlResult.status === 'fulfilled' && ddlResult.value?.success) {
payload.ddl = ddlResult.value.content;
} else {
warnings.push(`DDL 获取失败:${ddlResult.status === 'fulfilled' ? (ddlResult.value?.content || '未知错误') : String(ddlResult.reason)}`);
}
if (includeSampleRows) {
if (sampleRowsResult.status === 'fulfilled' && sampleRowsResult.value?.success) {
const rows = Array.isArray(sampleRowsResult.value.data) ? sampleRowsResult.value.data : [];
payload.sampleRows = {
limit: sampleLimit,
rowCount: rows.length,
rows: rows.slice(0, sampleLimit),
};
} else {
warnings.push(`样例数据获取失败:${sampleRowsResult.status === 'fulfilled' ? (sampleRowsResult.value?.message || '未知错误') : String(sampleRowsResult.reason)}`);
}
}
if (warnings.length > 0) {
payload.warnings = warnings;
}
return { content: JSON.stringify(payload), success: true };
} catch (error: any) {
return { content: `获取表结构快照失败: ${error?.message || error}`, success: false };
}
}
case 'inspect_database_bundle': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
if (!safeDbName) return { content: 'dbName 不能为空', success: false };
const includeColumns = args.includeColumns !== false;
const tableLimit = normalizeTableLimit(args.tableLimit);
const perTableColumnLimit = normalizePerTableColumnLimit(args.perTableColumnLimit);
const rpcConfig = buildRpcConnectionConfig(resolved.connection.config) as any;
const results = await Promise.allSettled([
runtime.getTables(rpcConfig, safeDbName),
includeColumns ? runtime.getAllColumns(rpcConfig, safeDbName) : Promise.resolve(undefined),
]);
const warnings: string[] = [];
const tablesResult = results[0];
const allColumnsResult = results[1];
const allColumns = allColumnsResult.status === 'fulfilled' && allColumnsResult.value?.success && Array.isArray(allColumnsResult.value.data)
? normalizeColumnsWithTable(allColumnsResult.value.data)
: [];
const tableNamesFromColumns = Array.from(new Set(allColumns.map((column) => column.tableName).filter(Boolean)));
let tableNames: string[] = [];
if (tablesResult.status === 'fulfilled' && tablesResult.value?.success && Array.isArray(tablesResult.value.data)) {
tableNames = normalizeTableList(tablesResult.value.data).filter(Boolean);
} else if (tableNamesFromColumns.length > 0) {
tableNames = tableNamesFromColumns;
warnings.push(`表列表获取失败,已退回字段摘要推断:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`);
} else {
warnings.push(`表列表获取失败:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`);
}
if (includeColumns && allColumnsResult.status === 'fulfilled' && (!allColumnsResult.value?.success || !Array.isArray(allColumnsResult.value.data))) {
warnings.push(`字段摘要获取失败:${allColumnsResult.value?.message || '未知错误'}`);
} else if (includeColumns && allColumnsResult.status === 'rejected') {
warnings.push(`字段摘要获取失败:${String(allColumnsResult.reason)}`);
}
const uniqueTableNames = Array.from(new Set(tableNames.filter(Boolean)));
const visibleTableNames = uniqueTableNames.slice(0, tableLimit);
const columnsByTable = new Map<string, ReturnType<typeof normalizeColumnsWithTable>>();
allColumns.forEach((column) => {
const tableName = String(column.tableName || '').trim();
if (!tableName) return;
const current = columnsByTable.get(tableName) || [];
current.push(column);
columnsByTable.set(tableName, current);
});
const payload: Record<string, unknown> = {
dbName: safeDbName,
tableCount: uniqueTableNames.length,
totalColumns: allColumns.length,
tables: visibleTableNames,
truncatedTables: uniqueTableNames.length > visibleTableNames.length,
};
if (includeColumns) {
payload.tableSummaries = visibleTableNames.map((tableName) => {
const tableColumns = columnsByTable.get(tableName) || [];
return {
tableName,
columnCount: tableColumns.length,
truncatedColumns: tableColumns.length > perTableColumnLimit,
columns: tableColumns.slice(0, perTableColumnLimit),
};
});
}
if (warnings.length > 0) {
payload.warnings = warnings;
}
return { content: JSON.stringify(payload), success: true };
} catch (error: any) {
return { content: `获取数据库结构总览失败: ${error?.message || error}`, success: false };
}
}
case 'preview_table_rows': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
if (!safeTable) return { content: 'tableName 不能为空', success: false };
const safeLimit = normalizePreviewLimit(args.limit);
const previewSQL = buildPreviewSQLForTable(resolved.connection, safeTable, safeLimit);
const result = await runtime.query(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, previewSQL);
if (result?.success) {
const rows = Array.isArray(result.data) ? result.data : [];
return {
content: JSON.stringify({
dbName: safeDbName,
tableName: safeTable,
limit: safeLimit,
rowCount: rows.length,
rows: rows.slice(0, safeLimit),
}),
success: true,
};
}
return { content: result?.message || 'Failed to preview table rows', success: false };
} catch (error: any) {
return { content: `预览表样例数据失败: ${error?.message || error}`, success: false };
}
}
case 'execute_sql': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeSql = args.sql ? String(args.sql).trim() : '';
if (typeof runtime.checkSQL === 'function') {
const checkResult = await runtime.checkSQL(safeSql);
if (checkResult && checkResult.allowed === false) {
return {
content: `安全策略拦截:当前安全级别不允许执行 ${checkResult.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`,
success: false,
};
}
}
const finalSql = buildAIReadonlyPreviewSQL(
resolved.connection.config?.type || '',
safeSql,
50,
resolved.connection.config?.driver || '',
);
const result = await runtime.query(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, finalSql);
if (result?.success) {
const rows = Array.isArray(result.data) ? result.data : [];
return {
content: JSON.stringify({ rowCount: rows.length, data: rows.slice(0, 50) }),
success: true,
};
}
return { content: result?.message || 'SQL 执行失败', success: false };
} catch (error: any) {
return { content: `SQL 执行异常: ${error?.message || error}`, success: false };
}
}
default:
return null;
}
}

View File

@@ -1,5 +1,3 @@
import { DBGetAllColumns, DBGetDatabases, DBGetTables } from '../../../wailsjs/go/app/App';
import type { SqlLog } from '../../store';
import type {
AIChatMessage,
@@ -13,34 +11,15 @@ import type {
SqlSnippet,
TabData,
} from '../../types';
import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig';
import { buildAIReadonlyPreviewSQL } from '../../utils/aiSqlLimit';
import { buildPaginatedSelectSQL, quoteQualifiedIdent } from '../../utils/sql';
import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool';
import { executeDatabaseToolCall } from './aiDatabaseToolExecutor';
import {
executeSnapshotInspectionToolCall,
type AISnapshotInspectionRuntime,
} from './aiSnapshotInspectionToolExecutor';
buildDefaultLocalToolRuntime,
type AILocalToolRuntime,
type AIToolContextEntry,
} from './aiLocalToolRuntime';
import { executeSnapshotInspectionToolCall } from './aiSnapshotInspectionToolExecutor';
export interface AIToolContextEntry {
connectionId: string;
dbName: string;
tables: string[];
}
export interface AILocalToolRuntime extends AISnapshotInspectionRuntime {
getDatabases: (config: any) => Promise<any>;
getTables: (config: any, dbName: string) => Promise<any>;
getAllColumns: (config: any, dbName: string) => Promise<any>;
getColumns: (config: any, dbName: string, tableName: string) => Promise<any>;
getIndexes: (config: any, dbName: string, tableName: string) => Promise<any>;
getForeignKeys: (config: any, dbName: string, tableName: string) => Promise<any>;
getTriggers: (config: any, dbName: string, tableName: string) => Promise<any>;
showCreateTable: (config: any, dbName: string, tableName: string) => Promise<any>;
query: (config: any, dbName: string, sql: string) => Promise<any>;
checkSQL?: (sql: string) => Promise<{ allowed?: boolean; operationType?: string } | undefined>;
callMCPTool?: (name: string, args: string) => Promise<{ content?: string; isError?: boolean } | undefined>;
}
export type { AILocalToolRuntime, AIToolContextEntry } from './aiLocalToolRuntime';
export interface ExecuteLocalAIToolCallOptions {
toolCall: AIToolCall;
@@ -66,146 +45,9 @@ export interface ExecuteLocalAIToolCallResult {
toolName: string;
}
const buildDefaultRuntime = (): AILocalToolRuntime => ({
getDatabases: DBGetDatabases,
getTables: DBGetTables,
getAllColumns: DBGetAllColumns,
getColumns: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetColumns(config, dbName, tableName);
},
getIndexes: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetIndexes(config, dbName, tableName);
},
getForeignKeys: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetForeignKeys(config, dbName, tableName);
},
getTriggers: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetTriggers(config, dbName, tableName);
},
showCreateTable: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBShowCreateTable(config, dbName, tableName);
},
query: async (config, dbName, sql) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBQuery(config, dbName, sql);
},
checkSQL: async (sql) => {
const service = (window as any).go?.aiservice?.Service;
if (typeof service?.AICheckSQL !== 'function') {
return undefined;
}
return service.AICheckSQL(sql);
},
callMCPTool: async (name, args) => {
const service = (window as any).go?.aiservice?.Service;
if (typeof service?.AICallMCPTool !== 'function') {
return undefined;
}
return service.AICallMCPTool(name, args);
},
getAIRuntimeState: async () => {
const service = (window as any).go?.aiservice?.Service;
if (!service) {
return undefined;
}
const [providers, activeProviderId, safetyLevel, contextLevel] = await Promise.all([
typeof service.AIGetProviders === 'function' ? service.AIGetProviders() : Promise.resolve([]),
typeof service.AIGetActiveProvider === 'function' ? service.AIGetActiveProvider() : Promise.resolve(''),
typeof service.AIGetSafetyLevel === 'function' ? service.AIGetSafetyLevel() : Promise.resolve(''),
typeof service.AIGetContextLevel === 'function' ? service.AIGetContextLevel() : Promise.resolve(''),
]);
return {
providers: Array.isArray(providers) ? providers : [],
activeProviderId: String(activeProviderId || '').trim(),
safetyLevel: String(safetyLevel || '').trim(),
contextLevel: String(contextLevel || '').trim(),
};
},
getMCPServers: async () => {
const service = (window as any).go?.aiservice?.Service;
if (typeof service?.AIGetMCPServers !== 'function') {
return undefined;
}
return service.AIGetMCPServers();
},
getMCPClientInstallStatuses: async () => {
const service = (window as any).go?.aiservice?.Service;
if (typeof service?.AIGetMCPClientInstallStatuses !== 'function') {
return undefined;
}
return service.AIGetMCPClientInstallStatuses();
},
});
const normalizeTableList = (rows: any[]): string[] =>
rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string));
const normalizeColumns = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''),
nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '',
default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '',
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '',
};
});
const normalizeColumnsWithTable = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
tableName: column.TableName || column.tableName || column.TABLE_NAME || column.table_name || (keys.length > 0 ? column[keys[0]] : ''),
name: column.Name || column.name || column.COLUMN_NAME || column.column_name || (keys.length > 1 ? column[keys[1]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 2 ? column[keys[2]] : ''),
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || '',
};
});
const buildToolName = (toolCall: AIToolCall, descriptor?: AIMCPToolDescriptor) =>
descriptor?.title || descriptor?.originalName || toolCall.function.name;
const findConnection = (connections: SavedConnection[], connectionId: string) =>
connections.find((connection) => connection.id === connectionId);
const normalizePreviewLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 20);
if (value < 1) return 1;
if (value > 100) return 100;
return value;
};
const normalizeTableLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 80);
if (value < 1) return 1;
if (value > 200) return 200;
return value;
};
const normalizePerTableColumnLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 8);
if (value < 1) return 1;
if (value > 30) return 30;
return value;
};
const buildPreviewSQLForTable = (connection: SavedConnection, tableName: string, limit: number): string => {
const dbType = String(connection.config?.type || '').trim();
return buildPaginatedSelectSQL(
dbType,
`SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)}`,
'',
limit,
0,
);
};
export async function executeLocalAIToolCall({
toolCall,
connections,
@@ -223,13 +65,12 @@ export async function executeLocalAIToolCall({
dynamicModels = [],
runtime,
}: ExecuteLocalAIToolCallOptions): Promise<ExecuteLocalAIToolCallResult> {
const mergedRuntime = { ...buildDefaultRuntime(), ...(runtime || {}) };
const mergedRuntime: AILocalToolRuntime = { ...buildDefaultLocalToolRuntime(), ...(runtime || {}) };
const descriptor = mcpTools.find((tool) => tool.alias === toolCall.function.name);
let content = '';
let success = false;
try {
const args = JSON.parse(toolCall.function.arguments || '{}');
const snapshotInspectionResult = await executeSnapshotInspectionToolCall({
toolName: toolCall.function.name,
args,
@@ -248,496 +89,57 @@ export async function executeLocalAIToolCall({
runtime: mergedRuntime,
});
if (snapshotInspectionResult) {
content = snapshotInspectionResult.content;
success = snapshotInspectionResult.success;
return {
content,
success,
content: snapshotInspectionResult.content,
success: snapshotInspectionResult.success,
toolName: buildToolName(toolCall, descriptor),
};
}
switch (toolCall.function.name) {
case 'get_connections': {
const availableConnections = connections.map((connection) => ({
id: connection.id,
name: connection.name,
type: connection.config?.type,
host: (connection.config as any)?.host || (connection.config as any)?.addr || '',
}));
content = JSON.stringify(availableConnections);
success = true;
break;
}
case 'get_databases': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const result = await mergedRuntime.getDatabases(buildRpcConnectionConfig(connection.config) as any);
if (result?.success && Array.isArray(result.data)) {
let databaseNames = result.data.map((row: any) => row.Database || row.database || Object.values(row)[0]);
if (databaseNames.length > 50) {
databaseNames = [...databaseNames.slice(0, 50), '...(截断)'];
}
content = JSON.stringify(databaseNames);
success = true;
} else {
content = result?.message || 'Failed to fetch DBs';
}
} catch (error: any) {
content = `获取数据库列表失败: ${error?.message || error}`;
}
break;
}
case 'get_tables': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const rawDbName = args.dbName || args.database;
const safeDbName = rawDbName ? String(rawDbName).trim() : '';
const result = await mergedRuntime.getTables(buildRpcConnectionConfig(connection.config) as any, safeDbName);
if (result?.success && Array.isArray(result.data)) {
let tableNames = normalizeTableList(result.data);
if (tableNames.length > 150) {
tableNames = [...tableNames.slice(0, 150), '...(截断)'];
}
content = JSON.stringify(tableNames);
success = true;
toolContextMap.set(`${args.connectionId}:${safeDbName}`, {
connectionId: args.connectionId,
dbName: safeDbName,
tables: tableNames.filter((tableName) => tableName !== '...(截断)'),
});
} else {
content = result?.message || 'Failed to fetch Tables';
}
} catch (error: any) {
content = `获取表列表失败: ${error?.message || error}`;
}
break;
}
case 'get_all_columns': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const result = await mergedRuntime.getAllColumns(buildRpcConnectionConfig(connection.config) as any, safeDbName);
if (result?.success && Array.isArray(result.data)) {
const allColumns = normalizeColumnsWithTable(result.data);
const tableNames = Array.from(new Set(allColumns.map((column) => column.tableName).filter(Boolean)));
const limitedColumns = allColumns.slice(0, 400);
content = JSON.stringify({
dbName: safeDbName,
tableCount: tableNames.length,
totalColumns: allColumns.length,
truncated: allColumns.length > limitedColumns.length,
columns: limitedColumns,
});
success = true;
} else {
content = result?.message || 'Failed to fetch all columns';
}
} catch (error: any) {
content = `获取全库字段摘要失败: ${error?.message || error}`;
}
break;
}
case 'get_columns': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await mergedRuntime.getColumns(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
const columns = normalizeColumns(result.data);
const fieldNames = columns.map((column) => column.field).join(', ');
content = `⚠️ 以下为 ${safeTable} 表的真实字段列表。生成 SQL 时只能使用这些 field 值作为列名,必须原样使用,禁止修改、缩写或自行拼凑字段名。\n可用字段${fieldNames}\n详细信息${JSON.stringify(columns)}`;
success = true;
} else {
content = result?.message || 'Failed to fetch columns';
}
} catch (error: any) {
content = `获取字段列表失败: ${error?.message || error}`;
}
break;
}
case 'get_indexes': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await mergedRuntime.getIndexes(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
content = JSON.stringify(result.data);
success = true;
} else {
content = result?.message || 'Failed to fetch indexes';
}
} catch (error: any) {
content = `获取索引定义失败: ${error?.message || error}`;
}
break;
}
case 'get_foreign_keys': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await mergedRuntime.getForeignKeys(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
content = JSON.stringify(result.data);
success = true;
} else {
content = result?.message || 'Failed to fetch foreign keys';
}
} catch (error: any) {
content = `获取外键关系失败: ${error?.message || error}`;
}
break;
}
case 'get_triggers': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const result = await mergedRuntime.getTriggers(buildRpcConnectionConfig(connection.config) as any, safeDbName, safeTable);
if (result?.success && Array.isArray(result.data)) {
content = JSON.stringify(result.data);
success = true;
} else {
content = result?.message || 'Failed to fetch triggers';
}
} catch (error: any) {
content = `获取触发器定义失败: ${error?.message || error}`;
}
break;
}
case 'get_table_ddl': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
const rpcConfig = buildRpcConnectionConfig(connection.config) as any;
const result = await resolveAITableSchemaToolResult({
tableName: safeTable,
fetchDDL: () => mergedRuntime.showCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => mergedRuntime.getColumns(rpcConfig, safeDbName, safeTable),
});
content = result.content;
success = result.success;
} catch (error: any) {
content = `获取建表语句失败: ${error?.message || error}`;
}
break;
}
case 'inspect_table_bundle': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
if (!safeTable) {
content = 'tableName 不能为空';
break;
}
const includeSampleRows = args.includeSampleRows === true;
const sampleLimit = normalizePreviewLimit(args.sampleLimit ?? 10);
const rpcConfig = buildRpcConnectionConfig(connection.config) as any;
const results = await Promise.allSettled([
mergedRuntime.getColumns(rpcConfig, safeDbName, safeTable),
mergedRuntime.getIndexes(rpcConfig, safeDbName, safeTable),
mergedRuntime.getForeignKeys(rpcConfig, safeDbName, safeTable),
mergedRuntime.getTriggers(rpcConfig, safeDbName, safeTable),
resolveAITableSchemaToolResult({
tableName: safeTable,
fetchDDL: () => mergedRuntime.showCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => mergedRuntime.getColumns(rpcConfig, safeDbName, safeTable),
}),
includeSampleRows
? mergedRuntime.query(rpcConfig, safeDbName, buildPreviewSQLForTable(connection, safeTable, sampleLimit))
: Promise.resolve(undefined),
]);
const warnings: string[] = [];
const columnsResult = results[0];
const indexesResult = results[1];
const foreignKeysResult = results[2];
const triggersResult = results[3];
const ddlResult = results[4];
const sampleRowsResult = results[5];
const databaseToolResult = await executeDatabaseToolCall({
toolName: toolCall.function.name,
args,
connections,
toolContextMap,
runtime: mergedRuntime,
});
if (databaseToolResult) {
return {
content: databaseToolResult.content,
success: databaseToolResult.success,
toolName: buildToolName(toolCall, descriptor),
};
}
const payload: Record<string, unknown> = {
dbName: safeDbName,
tableName: safeTable,
columns: [],
indexes: [],
foreignKeys: [],
triggers: [],
ddl: '',
};
if (!descriptor) {
return {
content: `Unknown function: ${toolCall.function.name}`,
success: false,
toolName: buildToolName(toolCall),
};
}
if (columnsResult.status === 'fulfilled' && columnsResult.value?.success && Array.isArray(columnsResult.value.data)) {
payload.columns = normalizeColumns(columnsResult.value.data);
} else {
warnings.push(`字段列表获取失败:${columnsResult.status === 'fulfilled' ? (columnsResult.value?.message || '未知错误') : String(columnsResult.reason)}`);
}
if (indexesResult.status === 'fulfilled' && indexesResult.value?.success && Array.isArray(indexesResult.value.data)) {
payload.indexes = indexesResult.value.data;
} else {
warnings.push(`索引定义获取失败:${indexesResult.status === 'fulfilled' ? (indexesResult.value?.message || '未知错误') : String(indexesResult.reason)}`);
}
if (foreignKeysResult.status === 'fulfilled' && foreignKeysResult.value?.success && Array.isArray(foreignKeysResult.value.data)) {
payload.foreignKeys = foreignKeysResult.value.data;
} else {
warnings.push(`外键关系获取失败:${foreignKeysResult.status === 'fulfilled' ? (foreignKeysResult.value?.message || '未知错误') : String(foreignKeysResult.reason)}`);
}
if (triggersResult.status === 'fulfilled' && triggersResult.value?.success && Array.isArray(triggersResult.value.data)) {
payload.triggers = triggersResult.value.data;
} else {
warnings.push(`触发器获取失败:${triggersResult.status === 'fulfilled' ? (triggersResult.value?.message || '未知错误') : String(triggersResult.reason)}`);
}
if (ddlResult.status === 'fulfilled' && ddlResult.value?.success) {
payload.ddl = ddlResult.value.content;
} else {
warnings.push(`DDL 获取失败:${ddlResult.status === 'fulfilled' ? (ddlResult.value?.content || '未知错误') : String(ddlResult.reason)}`);
}
if (includeSampleRows) {
if (sampleRowsResult.status === 'fulfilled' && sampleRowsResult.value?.success) {
const rows = Array.isArray(sampleRowsResult.value.data) ? sampleRowsResult.value.data : [];
payload.sampleRows = {
limit: sampleLimit,
rowCount: rows.length,
rows: rows.slice(0, sampleLimit),
};
} else {
warnings.push(`样例数据获取失败:${sampleRowsResult.status === 'fulfilled' ? (sampleRowsResult.value?.message || '未知错误') : String(sampleRowsResult.reason)}`);
}
}
if (warnings.length > 0) {
payload.warnings = warnings;
}
content = JSON.stringify(payload);
success = true;
} catch (error: any) {
content = `获取表结构快照失败: ${error?.message || error}`;
}
break;
}
case 'inspect_database_bundle': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
if (!safeDbName) {
content = 'dbName 不能为空';
break;
}
const includeColumns = args.includeColumns !== false;
const tableLimit = normalizeTableLimit(args.tableLimit);
const perTableColumnLimit = normalizePerTableColumnLimit(args.perTableColumnLimit);
const rpcConfig = buildRpcConnectionConfig(connection.config) as any;
const results = await Promise.allSettled([
mergedRuntime.getTables(rpcConfig, safeDbName),
includeColumns ? mergedRuntime.getAllColumns(rpcConfig, safeDbName) : Promise.resolve(undefined),
]);
const warnings: string[] = [];
const tablesResult = results[0];
const allColumnsResult = results[1];
const allColumns = allColumnsResult.status === 'fulfilled' && allColumnsResult.value?.success && Array.isArray(allColumnsResult.value.data)
? normalizeColumnsWithTable(allColumnsResult.value.data)
: [];
const tableNamesFromColumns = Array.from(new Set(allColumns.map((column) => column.tableName).filter(Boolean)));
let tableNames: string[] = [];
if (tablesResult.status === 'fulfilled' && tablesResult.value?.success && Array.isArray(tablesResult.value.data)) {
tableNames = normalizeTableList(tablesResult.value.data).filter(Boolean);
} else if (tableNamesFromColumns.length > 0) {
tableNames = tableNamesFromColumns;
warnings.push(`表列表获取失败,已退回字段摘要推断:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`);
} else {
warnings.push(`表列表获取失败:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`);
}
if (includeColumns && allColumnsResult.status === 'fulfilled' && (!allColumnsResult.value?.success || !Array.isArray(allColumnsResult.value.data))) {
warnings.push(`字段摘要获取失败:${allColumnsResult.value?.message || '未知错误'}`);
} else if (includeColumns && allColumnsResult.status === 'rejected') {
warnings.push(`字段摘要获取失败:${String(allColumnsResult.reason)}`);
}
const uniqueTableNames = Array.from(new Set(tableNames.filter(Boolean)));
const visibleTableNames = uniqueTableNames.slice(0, tableLimit);
const columnsByTable = new Map<string, ReturnType<typeof normalizeColumnsWithTable>>();
allColumns.forEach((column) => {
const tableName = String(column.tableName || '').trim();
if (!tableName) {
return;
}
const current = columnsByTable.get(tableName) || [];
current.push(column);
columnsByTable.set(tableName, current);
});
const payload: Record<string, unknown> = {
dbName: safeDbName,
tableCount: uniqueTableNames.length,
totalColumns: allColumns.length,
tables: visibleTableNames,
truncatedTables: uniqueTableNames.length > visibleTableNames.length,
};
if (includeColumns) {
payload.tableSummaries = visibleTableNames.map((tableName) => {
const tableColumns = columnsByTable.get(tableName) || [];
return {
tableName,
columnCount: tableColumns.length,
truncatedColumns: tableColumns.length > perTableColumnLimit,
columns: tableColumns.slice(0, perTableColumnLimit),
};
});
}
if (warnings.length > 0) {
payload.warnings = warnings;
}
content = JSON.stringify(payload);
success = true;
} catch (error: any) {
content = `获取数据库结构总览失败: ${error?.message || error}`;
}
break;
}
case 'preview_table_rows': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
if (!safeTable) {
content = 'tableName 不能为空';
break;
}
const safeLimit = normalizePreviewLimit(args.limit);
const previewSQL = buildPreviewSQLForTable(connection, safeTable, safeLimit);
const result = await mergedRuntime.query(buildRpcConnectionConfig(connection.config) as any, safeDbName, previewSQL);
if (result?.success) {
const rows = Array.isArray(result.data) ? result.data : [];
content = JSON.stringify({
dbName: safeDbName,
tableName: safeTable,
limit: safeLimit,
rowCount: rows.length,
rows: rows.slice(0, safeLimit),
});
success = true;
} else {
content = result?.message || 'Failed to preview table rows';
}
} catch (error: any) {
content = `预览表样例数据失败: ${error?.message || error}`;
}
break;
}
case 'execute_sql': {
const connection = findConnection(connections, args.connectionId);
if (!connection) {
content = 'Connection not found';
break;
}
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeSql = args.sql ? String(args.sql).trim() : '';
if (typeof mergedRuntime.checkSQL === 'function') {
const checkResult = await mergedRuntime.checkSQL(safeSql);
if (checkResult && checkResult.allowed === false) {
content = `安全策略拦截:当前安全级别不允许执行 ${checkResult.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`;
break;
}
}
const finalSql = buildAIReadonlyPreviewSQL(connection.config?.type || '', safeSql, 50, connection.config?.driver || '');
const result = await mergedRuntime.query(buildRpcConnectionConfig(connection.config) as any, safeDbName, finalSql);
if (result?.success) {
const rows = Array.isArray(result.data) ? result.data : [];
content = JSON.stringify({ rowCount: rows.length, data: rows.slice(0, 50) });
success = true;
} else {
content = result?.message || 'SQL 执行失败';
}
} catch (error: any) {
content = `SQL 执行异常: ${error?.message || error}`;
}
break;
}
default: {
if (!descriptor) {
content = `Unknown function: ${toolCall.function.name}`;
break;
}
try {
const result = await mergedRuntime.callMCPTool?.(toolCall.function.name, toolCall.function.arguments || '{}');
content = String(result?.content || (result?.isError ? 'MCP 工具调用失败' : ''));
success = !!result && !result.isError;
} catch (error: any) {
content = `MCP 工具调用失败: ${error?.message || error}`;
}
break;
}
try {
const result = await mergedRuntime.callMCPTool?.(toolCall.function.name, toolCall.function.arguments || '{}');
return {
content: String(result?.content || (result?.isError ? 'MCP 工具调用失败' : '')),
success: !!result && !result.isError,
toolName: buildToolName(toolCall, descriptor),
};
} catch (error: any) {
return {
content: `MCP 工具调用失败: ${error?.message || error}`,
success: false,
toolName: buildToolName(toolCall, descriptor),
};
}
} catch (error: any) {
content = error?.message || String(error);
return {
content: error?.message || String(error),
success: false,
toolName: buildToolName(toolCall, descriptor),
};
}
return {
content,
success,
toolName: buildToolName(toolCall, descriptor),
};
}
export function buildToolResultMessage(params: {

View File

@@ -0,0 +1,101 @@
import { DBGetAllColumns, DBGetDatabases, DBGetTables } from '../../../wailsjs/go/app/App';
import type { AISnapshotInspectionRuntime } from './aiSnapshotInspectionToolExecutor';
export interface AIToolContextEntry {
connectionId: string;
dbName: string;
tables: string[];
}
export interface AILocalToolRuntime extends AISnapshotInspectionRuntime {
getDatabases: (config: any) => Promise<any>;
getTables: (config: any, dbName: string) => Promise<any>;
getAllColumns: (config: any, dbName: string) => Promise<any>;
getColumns: (config: any, dbName: string, tableName: string) => Promise<any>;
getIndexes: (config: any, dbName: string, tableName: string) => Promise<any>;
getForeignKeys: (config: any, dbName: string, tableName: string) => Promise<any>;
getTriggers: (config: any, dbName: string, tableName: string) => Promise<any>;
showCreateTable: (config: any, dbName: string, tableName: string) => Promise<any>;
query: (config: any, dbName: string, sql: string) => Promise<any>;
checkSQL?: (sql: string) => Promise<{ allowed?: boolean; operationType?: string } | undefined>;
callMCPTool?: (name: string, args: string) => Promise<{ content?: string; isError?: boolean } | undefined>;
}
const getAIService = () => (window as any).go?.aiservice?.Service;
export const buildDefaultLocalToolRuntime = (): AILocalToolRuntime => ({
getDatabases: DBGetDatabases,
getTables: DBGetTables,
getAllColumns: DBGetAllColumns,
getColumns: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetColumns(config, dbName, tableName);
},
getIndexes: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetIndexes(config, dbName, tableName);
},
getForeignKeys: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetForeignKeys(config, dbName, tableName);
},
getTriggers: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBGetTriggers(config, dbName, tableName);
},
showCreateTable: async (config, dbName, tableName) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBShowCreateTable(config, dbName, tableName);
},
query: async (config, dbName, sql) => {
const mod = await import('../../../wailsjs/go/app/App');
return mod.DBQuery(config, dbName, sql);
},
checkSQL: async (sql) => {
const service = getAIService();
if (typeof service?.AICheckSQL !== 'function') {
return undefined;
}
return service.AICheckSQL(sql);
},
callMCPTool: async (name, args) => {
const service = getAIService();
if (typeof service?.AICallMCPTool !== 'function') {
return undefined;
}
return service.AICallMCPTool(name, args);
},
getAIRuntimeState: async () => {
const service = getAIService();
if (!service) {
return undefined;
}
const [providers, activeProviderId, safetyLevel, contextLevel] = await Promise.all([
typeof service.AIGetProviders === 'function' ? service.AIGetProviders() : Promise.resolve([]),
typeof service.AIGetActiveProvider === 'function' ? service.AIGetActiveProvider() : Promise.resolve(''),
typeof service.AIGetSafetyLevel === 'function' ? service.AIGetSafetyLevel() : Promise.resolve(''),
typeof service.AIGetContextLevel === 'function' ? service.AIGetContextLevel() : Promise.resolve(''),
]);
return {
providers: Array.isArray(providers) ? providers : [],
activeProviderId: String(activeProviderId || '').trim(),
safetyLevel: String(safetyLevel || '').trim(),
contextLevel: String(contextLevel || '').trim(),
};
},
getMCPServers: async () => {
const service = getAIService();
if (typeof service?.AIGetMCPServers !== 'function') {
return undefined;
}
return service.AIGetMCPServers();
},
getMCPClientInstallStatuses: async () => {
const service = getAIService();
if (typeof service?.AIGetMCPClientInstallStatuses !== 'function') {
return undefined;
}
return service.AIGetMCPClientInstallStatuses();
},
});