♻️ refactor(ai-tools): 拆分数据库工具执行器逻辑

This commit is contained in:
Syngnat
2026-06-09 11:19:15 +08:00
parent 83972d29b7
commit b5ba49ff8f
3 changed files with 256 additions and 207 deletions

View File

@@ -0,0 +1,186 @@
import type { SavedConnection } from '../../types';
import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig';
import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool';
import type { AILocalToolRuntime } from './aiLocalToolRuntime';
import {
buildPreviewSQLForTable,
normalizeColumns,
normalizeColumnsWithTable,
normalizePerTableColumnLimit,
normalizePreviewLimit,
normalizeTableLimit,
normalizeTableList,
} from './aiDatabaseToolHelpers';
interface DatabaseToolExecutionResult {
content: string;
success: boolean;
}
interface InspectDatabaseBundleOptions {
args: Record<string, any>;
connection: SavedConnection;
runtime: AILocalToolRuntime;
}
export const inspectTableBundle = async ({
args,
connection,
runtime,
}: InspectDatabaseBundleOptions): Promise<DatabaseToolExecutionResult> => {
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
if (!safeTable) return { content: 'tableName 不能为空', success: false };
const includeSampleRows = args.includeSampleRows === true;
const sampleLimit = normalizePreviewLimit(args.sampleLimit ?? 10);
const rpcConfig = buildRpcConnectionConfig(connection.config) as any;
const results = await Promise.allSettled([
runtime.getColumns(rpcConfig, safeDbName, safeTable),
runtime.getIndexes(rpcConfig, safeDbName, safeTable),
runtime.getForeignKeys(rpcConfig, safeDbName, safeTable),
runtime.getTriggers(rpcConfig, safeDbName, safeTable),
resolveAITableSchemaToolResult({
tableName: safeTable,
fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable),
}),
includeSampleRows
? runtime.query(rpcConfig, safeDbName, buildPreviewSQLForTable(connection, safeTable, sampleLimit))
: Promise.resolve(undefined),
]);
const warnings: string[] = [];
const payload: Record<string, unknown> = {
dbName: safeDbName,
tableName: safeTable,
columns: [],
indexes: [],
foreignKeys: [],
triggers: [],
ddl: '',
};
const columnsResult = results[0];
const indexesResult = results[1];
const foreignKeysResult = results[2];
const triggersResult = results[3];
const ddlResult = results[4];
const sampleRowsResult = results[5];
if (columnsResult.status === 'fulfilled' && columnsResult.value?.success && Array.isArray(columnsResult.value.data)) {
payload.columns = normalizeColumns(columnsResult.value.data);
} else {
warnings.push(`字段列表获取失败:${columnsResult.status === 'fulfilled' ? (columnsResult.value?.message || '未知错误') : String(columnsResult.reason)}`);
}
if (indexesResult.status === 'fulfilled' && indexesResult.value?.success && Array.isArray(indexesResult.value.data)) {
payload.indexes = indexesResult.value.data;
} else {
warnings.push(`索引定义获取失败:${indexesResult.status === 'fulfilled' ? (indexesResult.value?.message || '未知错误') : String(indexesResult.reason)}`);
}
if (foreignKeysResult.status === 'fulfilled' && foreignKeysResult.value?.success && Array.isArray(foreignKeysResult.value.data)) {
payload.foreignKeys = foreignKeysResult.value.data;
} else {
warnings.push(`外键关系获取失败:${foreignKeysResult.status === 'fulfilled' ? (foreignKeysResult.value?.message || '未知错误') : String(foreignKeysResult.reason)}`);
}
if (triggersResult.status === 'fulfilled' && triggersResult.value?.success && Array.isArray(triggersResult.value.data)) {
payload.triggers = triggersResult.value.data;
} else {
warnings.push(`触发器获取失败:${triggersResult.status === 'fulfilled' ? (triggersResult.value?.message || '未知错误') : String(triggersResult.reason)}`);
}
if (ddlResult.status === 'fulfilled' && ddlResult.value?.success) {
payload.ddl = ddlResult.value.content;
} else {
warnings.push(`DDL 获取失败:${ddlResult.status === 'fulfilled' ? (ddlResult.value?.content || '未知错误') : String(ddlResult.reason)}`);
}
if (includeSampleRows) {
if (sampleRowsResult.status === 'fulfilled' && sampleRowsResult.value?.success) {
const rows = Array.isArray(sampleRowsResult.value.data) ? sampleRowsResult.value.data : [];
payload.sampleRows = {
limit: sampleLimit,
rowCount: rows.length,
rows: rows.slice(0, sampleLimit),
};
} else {
warnings.push(`样例数据获取失败:${sampleRowsResult.status === 'fulfilled' ? (sampleRowsResult.value?.message || '未知错误') : String(sampleRowsResult.reason)}`);
}
}
if (warnings.length > 0) {
payload.warnings = warnings;
}
return { content: JSON.stringify(payload), success: true };
} catch (error: any) {
return { content: `获取表结构快照失败: ${error?.message || error}`, success: false };
}
};
export const inspectDatabaseBundle = async ({
args,
connection,
runtime,
}: InspectDatabaseBundleOptions): Promise<DatabaseToolExecutionResult> => {
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
if (!safeDbName) return { content: 'dbName 不能为空', success: false };
const includeColumns = args.includeColumns !== false;
const tableLimit = normalizeTableLimit(args.tableLimit);
const perTableColumnLimit = normalizePerTableColumnLimit(args.perTableColumnLimit);
const rpcConfig = buildRpcConnectionConfig(connection.config) as any;
const results = await Promise.allSettled([
runtime.getTables(rpcConfig, safeDbName),
includeColumns ? runtime.getAllColumns(rpcConfig, safeDbName) : Promise.resolve(undefined),
]);
const warnings: string[] = [];
const tablesResult = results[0];
const allColumnsResult = results[1];
const allColumns = allColumnsResult.status === 'fulfilled' && allColumnsResult.value?.success && Array.isArray(allColumnsResult.value.data)
? normalizeColumnsWithTable(allColumnsResult.value.data)
: [];
const tableNamesFromColumns = Array.from(new Set(allColumns.map((column) => column.tableName).filter(Boolean)));
let tableNames: string[] = [];
if (tablesResult.status === 'fulfilled' && tablesResult.value?.success && Array.isArray(tablesResult.value.data)) {
tableNames = normalizeTableList(tablesResult.value.data).filter(Boolean);
} else if (tableNamesFromColumns.length > 0) {
tableNames = tableNamesFromColumns;
warnings.push(`表列表获取失败,已退回字段摘要推断:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`);
} else {
warnings.push(`表列表获取失败:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`);
}
if (includeColumns && allColumnsResult.status === 'fulfilled' && (!allColumnsResult.value?.success || !Array.isArray(allColumnsResult.value.data))) {
warnings.push(`字段摘要获取失败:${allColumnsResult.value?.message || '未知错误'}`);
} else if (includeColumns && allColumnsResult.status === 'rejected') {
warnings.push(`字段摘要获取失败:${String(allColumnsResult.reason)}`);
}
const uniqueTableNames = Array.from(new Set(tableNames.filter(Boolean)));
const visibleTableNames = uniqueTableNames.slice(0, tableLimit);
const columnsByTable = new Map<string, ReturnType<typeof normalizeColumnsWithTable>>();
allColumns.forEach((column) => {
const tableName = String(column.tableName || '').trim();
if (!tableName) return;
const current = columnsByTable.get(tableName) || [];
current.push(column);
columnsByTable.set(tableName, current);
});
const payload: Record<string, unknown> = {
dbName: safeDbName,
tableCount: uniqueTableNames.length,
totalColumns: allColumns.length,
tables: visibleTableNames,
truncatedTables: uniqueTableNames.length > visibleTableNames.length,
};
if (includeColumns) {
payload.tableSummaries = visibleTableNames.map((tableName) => {
const tableColumns = columnsByTable.get(tableName) || [];
return {
tableName,
columnCount: tableColumns.length,
truncatedColumns: tableColumns.length > perTableColumnLimit,
columns: tableColumns.slice(0, perTableColumnLimit),
};
});
}
if (warnings.length > 0) {
payload.warnings = warnings;
}
return { content: JSON.stringify(payload), success: true };
} catch (error: any) {
return { content: `获取数据库结构总览失败: ${error?.message || error}`, success: false };
}
};

View File

@@ -1,9 +1,16 @@
import type { SavedConnection } from '../../types';
import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig';
import { buildAIReadonlyPreviewSQL } from '../../utils/aiSqlLimit';
import { buildPaginatedSelectSQL, quoteQualifiedIdent } from '../../utils/sql';
import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool';
import type { AILocalToolRuntime, AIToolContextEntry } from './aiLocalToolRuntime';
import {
buildPreviewSQLForTable,
normalizeColumns,
normalizeColumnsWithTable,
normalizePreviewLimit,
normalizeTableList,
} from './aiDatabaseToolHelpers';
import { inspectDatabaseBundle, inspectTableBundle } from './aiDatabaseBundleTools';
interface ExecuteDatabaseToolCallOptions {
toolName: string;
@@ -21,64 +28,6 @@ interface ToolExecutionResult {
const findConnection = (connections: SavedConnection[], connectionId: string) =>
connections.find((connection) => connection.id === connectionId);
const normalizeTableList = (rows: any[]): string[] =>
rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string));
const normalizeColumns = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''),
nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '',
default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '',
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '',
};
});
const normalizeColumnsWithTable = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
tableName: column.TableName || column.tableName || column.TABLE_NAME || column.table_name || (keys.length > 0 ? column[keys[0]] : ''),
name: column.Name || column.name || column.COLUMN_NAME || column.column_name || (keys.length > 1 ? column[keys[1]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 2 ? column[keys[2]] : ''),
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || '',
};
});
const normalizePreviewLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 20);
if (value < 1) return 1;
if (value > 100) return 100;
return value;
};
const normalizeTableLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 80);
if (value < 1) return 1;
if (value > 200) return 200;
return value;
};
const normalizePerTableColumnLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 8);
if (value < 1) return 1;
if (value > 30) return 30;
return value;
};
const buildPreviewSQLForTable = (connection: SavedConnection, tableName: string, limit: number): string => {
const dbType = String(connection.config?.type || '').trim();
return buildPaginatedSelectSQL(
dbType,
`SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)}`,
'',
limit,
0,
);
};
const resolveConnectionOrFailure = (
connections: SavedConnection[],
connectionId: string,
@@ -266,158 +215,12 @@ export async function executeDatabaseToolCall(
case 'inspect_table_bundle': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
const safeTable = args.tableName ? String(args.tableName).trim() : '';
if (!safeTable) return { content: 'tableName 不能为空', success: false };
const includeSampleRows = args.includeSampleRows === true;
const sampleLimit = normalizePreviewLimit(args.sampleLimit ?? 10);
const rpcConfig = buildRpcConnectionConfig(resolved.connection.config) as any;
const results = await Promise.allSettled([
runtime.getColumns(rpcConfig, safeDbName, safeTable),
runtime.getIndexes(rpcConfig, safeDbName, safeTable),
runtime.getForeignKeys(rpcConfig, safeDbName, safeTable),
runtime.getTriggers(rpcConfig, safeDbName, safeTable),
resolveAITableSchemaToolResult({
tableName: safeTable,
fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable),
}),
includeSampleRows
? runtime.query(rpcConfig, safeDbName, buildPreviewSQLForTable(resolved.connection, safeTable, sampleLimit))
: Promise.resolve(undefined),
]);
const warnings: string[] = [];
const payload: Record<string, unknown> = {
dbName: safeDbName,
tableName: safeTable,
columns: [],
indexes: [],
foreignKeys: [],
triggers: [],
ddl: '',
};
const columnsResult = results[0];
const indexesResult = results[1];
const foreignKeysResult = results[2];
const triggersResult = results[3];
const ddlResult = results[4];
const sampleRowsResult = results[5];
if (columnsResult.status === 'fulfilled' && columnsResult.value?.success && Array.isArray(columnsResult.value.data)) {
payload.columns = normalizeColumns(columnsResult.value.data);
} else {
warnings.push(`字段列表获取失败:${columnsResult.status === 'fulfilled' ? (columnsResult.value?.message || '未知错误') : String(columnsResult.reason)}`);
}
if (indexesResult.status === 'fulfilled' && indexesResult.value?.success && Array.isArray(indexesResult.value.data)) {
payload.indexes = indexesResult.value.data;
} else {
warnings.push(`索引定义获取失败:${indexesResult.status === 'fulfilled' ? (indexesResult.value?.message || '未知错误') : String(indexesResult.reason)}`);
}
if (foreignKeysResult.status === 'fulfilled' && foreignKeysResult.value?.success && Array.isArray(foreignKeysResult.value.data)) {
payload.foreignKeys = foreignKeysResult.value.data;
} else {
warnings.push(`外键关系获取失败:${foreignKeysResult.status === 'fulfilled' ? (foreignKeysResult.value?.message || '未知错误') : String(foreignKeysResult.reason)}`);
}
if (triggersResult.status === 'fulfilled' && triggersResult.value?.success && Array.isArray(triggersResult.value.data)) {
payload.triggers = triggersResult.value.data;
} else {
warnings.push(`触发器获取失败:${triggersResult.status === 'fulfilled' ? (triggersResult.value?.message || '未知错误') : String(triggersResult.reason)}`);
}
if (ddlResult.status === 'fulfilled' && ddlResult.value?.success) {
payload.ddl = ddlResult.value.content;
} else {
warnings.push(`DDL 获取失败:${ddlResult.status === 'fulfilled' ? (ddlResult.value?.content || '未知错误') : String(ddlResult.reason)}`);
}
if (includeSampleRows) {
if (sampleRowsResult.status === 'fulfilled' && sampleRowsResult.value?.success) {
const rows = Array.isArray(sampleRowsResult.value.data) ? sampleRowsResult.value.data : [];
payload.sampleRows = {
limit: sampleLimit,
rowCount: rows.length,
rows: rows.slice(0, sampleLimit),
};
} else {
warnings.push(`样例数据获取失败:${sampleRowsResult.status === 'fulfilled' ? (sampleRowsResult.value?.message || '未知错误') : String(sampleRowsResult.reason)}`);
}
}
if (warnings.length > 0) {
payload.warnings = warnings;
}
return { content: JSON.stringify(payload), success: true };
} catch (error: any) {
return { content: `获取表结构快照失败: ${error?.message || error}`, success: false };
}
return inspectTableBundle({ args, connection: resolved.connection, runtime });
}
case 'inspect_database_bundle': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);
if (resolved.failure || !resolved.connection) return resolved.failure || null;
try {
const safeDbName = args.dbName ? String(args.dbName).trim() : '';
if (!safeDbName) return { content: 'dbName 不能为空', success: false };
const includeColumns = args.includeColumns !== false;
const tableLimit = normalizeTableLimit(args.tableLimit);
const perTableColumnLimit = normalizePerTableColumnLimit(args.perTableColumnLimit);
const rpcConfig = buildRpcConnectionConfig(resolved.connection.config) as any;
const results = await Promise.allSettled([
runtime.getTables(rpcConfig, safeDbName),
includeColumns ? runtime.getAllColumns(rpcConfig, safeDbName) : Promise.resolve(undefined),
]);
const warnings: string[] = [];
const tablesResult = results[0];
const allColumnsResult = results[1];
const allColumns = allColumnsResult.status === 'fulfilled' && allColumnsResult.value?.success && Array.isArray(allColumnsResult.value.data)
? normalizeColumnsWithTable(allColumnsResult.value.data)
: [];
const tableNamesFromColumns = Array.from(new Set(allColumns.map((column) => column.tableName).filter(Boolean)));
let tableNames: string[] = [];
if (tablesResult.status === 'fulfilled' && tablesResult.value?.success && Array.isArray(tablesResult.value.data)) {
tableNames = normalizeTableList(tablesResult.value.data).filter(Boolean);
} else if (tableNamesFromColumns.length > 0) {
tableNames = tableNamesFromColumns;
warnings.push(`表列表获取失败,已退回字段摘要推断:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`);
} else {
warnings.push(`表列表获取失败:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`);
}
if (includeColumns && allColumnsResult.status === 'fulfilled' && (!allColumnsResult.value?.success || !Array.isArray(allColumnsResult.value.data))) {
warnings.push(`字段摘要获取失败:${allColumnsResult.value?.message || '未知错误'}`);
} else if (includeColumns && allColumnsResult.status === 'rejected') {
warnings.push(`字段摘要获取失败:${String(allColumnsResult.reason)}`);
}
const uniqueTableNames = Array.from(new Set(tableNames.filter(Boolean)));
const visibleTableNames = uniqueTableNames.slice(0, tableLimit);
const columnsByTable = new Map<string, ReturnType<typeof normalizeColumnsWithTable>>();
allColumns.forEach((column) => {
const tableName = String(column.tableName || '').trim();
if (!tableName) return;
const current = columnsByTable.get(tableName) || [];
current.push(column);
columnsByTable.set(tableName, current);
});
const payload: Record<string, unknown> = {
dbName: safeDbName,
tableCount: uniqueTableNames.length,
totalColumns: allColumns.length,
tables: visibleTableNames,
truncatedTables: uniqueTableNames.length > visibleTableNames.length,
};
if (includeColumns) {
payload.tableSummaries = visibleTableNames.map((tableName) => {
const tableColumns = columnsByTable.get(tableName) || [];
return {
tableName,
columnCount: tableColumns.length,
truncatedColumns: tableColumns.length > perTableColumnLimit,
columns: tableColumns.slice(0, perTableColumnLimit),
};
});
}
if (warnings.length > 0) {
payload.warnings = warnings;
}
return { content: JSON.stringify(payload), success: true };
} catch (error: any) {
return { content: `获取数据库结构总览失败: ${error?.message || error}`, success: false };
}
return inspectDatabaseBundle({ args, connection: resolved.connection, runtime });
}
case 'preview_table_rows': {
const resolved = resolveConnectionOrFailure(connections, args.connectionId);

View File

@@ -0,0 +1,60 @@
import type { SavedConnection } from '../../types';
import { buildPaginatedSelectSQL, quoteQualifiedIdent } from '../../utils/sql';
export const normalizeTableList = (rows: any[]): string[] =>
rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string));
export const normalizeColumns = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''),
nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '',
default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '',
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '',
};
});
export const normalizeColumnsWithTable = (rows: any[]) =>
rows.map((column) => {
const keys = Object.keys(column);
return {
tableName: column.TableName || column.tableName || column.TABLE_NAME || column.table_name || (keys.length > 0 ? column[keys[0]] : ''),
name: column.Name || column.name || column.COLUMN_NAME || column.column_name || (keys.length > 1 ? column[keys[1]] : ''),
type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 2 ? column[keys[2]] : ''),
comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || '',
};
});
export const normalizePreviewLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 20);
if (value < 1) return 1;
if (value > 100) return 100;
return value;
};
export const normalizeTableLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 80);
if (value < 1) return 1;
if (value > 200) return 200;
return value;
};
export const normalizePerTableColumnLimit = (input: unknown): number => {
const value = Math.floor(Number(input) || 8);
if (value < 1) return 1;
if (value > 30) return 30;
return value;
};
export const buildPreviewSQLForTable = (connection: SavedConnection, tableName: string, limit: number): string => {
const dbType = String(connection.config?.type || '').trim();
return buildPaginatedSelectSQL(
dbType,
`SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)}`,
'',
limit,
0,
);
};