diff --git a/frontend/src/components/ai/aiDatabaseBundleTools.ts b/frontend/src/components/ai/aiDatabaseBundleTools.ts new file mode 100644 index 0000000..7492287 --- /dev/null +++ b/frontend/src/components/ai/aiDatabaseBundleTools.ts @@ -0,0 +1,186 @@ +import type { SavedConnection } from '../../types'; +import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig'; +import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool'; +import type { AILocalToolRuntime } from './aiLocalToolRuntime'; +import { + buildPreviewSQLForTable, + normalizeColumns, + normalizeColumnsWithTable, + normalizePerTableColumnLimit, + normalizePreviewLimit, + normalizeTableLimit, + normalizeTableList, +} from './aiDatabaseToolHelpers'; + +interface DatabaseToolExecutionResult { + content: string; + success: boolean; +} + +interface InspectDatabaseBundleOptions { + args: Record; + connection: SavedConnection; + runtime: AILocalToolRuntime; +} + +export const inspectTableBundle = async ({ + args, + connection, + runtime, +}: InspectDatabaseBundleOptions): Promise => { + try { + const safeDbName = args.dbName ? String(args.dbName).trim() : ''; + const safeTable = args.tableName ? String(args.tableName).trim() : ''; + if (!safeTable) return { content: 'tableName 不能为空', success: false }; + const includeSampleRows = args.includeSampleRows === true; + const sampleLimit = normalizePreviewLimit(args.sampleLimit ?? 10); + const rpcConfig = buildRpcConnectionConfig(connection.config) as any; + const results = await Promise.allSettled([ + runtime.getColumns(rpcConfig, safeDbName, safeTable), + runtime.getIndexes(rpcConfig, safeDbName, safeTable), + runtime.getForeignKeys(rpcConfig, safeDbName, safeTable), + runtime.getTriggers(rpcConfig, safeDbName, safeTable), + resolveAITableSchemaToolResult({ + tableName: safeTable, + fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable), + fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable), + }), + includeSampleRows + ? runtime.query(rpcConfig, safeDbName, buildPreviewSQLForTable(connection, safeTable, sampleLimit)) + : Promise.resolve(undefined), + ]); + const warnings: string[] = []; + const payload: Record = { + dbName: safeDbName, + tableName: safeTable, + columns: [], + indexes: [], + foreignKeys: [], + triggers: [], + ddl: '', + }; + const columnsResult = results[0]; + const indexesResult = results[1]; + const foreignKeysResult = results[2]; + const triggersResult = results[3]; + const ddlResult = results[4]; + const sampleRowsResult = results[5]; + if (columnsResult.status === 'fulfilled' && columnsResult.value?.success && Array.isArray(columnsResult.value.data)) { + payload.columns = normalizeColumns(columnsResult.value.data); + } else { + warnings.push(`字段列表获取失败:${columnsResult.status === 'fulfilled' ? (columnsResult.value?.message || '未知错误') : String(columnsResult.reason)}`); + } + if (indexesResult.status === 'fulfilled' && indexesResult.value?.success && Array.isArray(indexesResult.value.data)) { + payload.indexes = indexesResult.value.data; + } else { + warnings.push(`索引定义获取失败:${indexesResult.status === 'fulfilled' ? (indexesResult.value?.message || '未知错误') : String(indexesResult.reason)}`); + } + if (foreignKeysResult.status === 'fulfilled' && foreignKeysResult.value?.success && Array.isArray(foreignKeysResult.value.data)) { + payload.foreignKeys = foreignKeysResult.value.data; + } else { + warnings.push(`外键关系获取失败:${foreignKeysResult.status === 'fulfilled' ? (foreignKeysResult.value?.message || '未知错误') : String(foreignKeysResult.reason)}`); + } + if (triggersResult.status === 'fulfilled' && triggersResult.value?.success && Array.isArray(triggersResult.value.data)) { + payload.triggers = triggersResult.value.data; + } else { + warnings.push(`触发器获取失败:${triggersResult.status === 'fulfilled' ? (triggersResult.value?.message || '未知错误') : String(triggersResult.reason)}`); + } + if (ddlResult.status === 'fulfilled' && ddlResult.value?.success) { + payload.ddl = ddlResult.value.content; + } else { + warnings.push(`DDL 获取失败:${ddlResult.status === 'fulfilled' ? (ddlResult.value?.content || '未知错误') : String(ddlResult.reason)}`); + } + if (includeSampleRows) { + if (sampleRowsResult.status === 'fulfilled' && sampleRowsResult.value?.success) { + const rows = Array.isArray(sampleRowsResult.value.data) ? sampleRowsResult.value.data : []; + payload.sampleRows = { + limit: sampleLimit, + rowCount: rows.length, + rows: rows.slice(0, sampleLimit), + }; + } else { + warnings.push(`样例数据获取失败:${sampleRowsResult.status === 'fulfilled' ? (sampleRowsResult.value?.message || '未知错误') : String(sampleRowsResult.reason)}`); + } + } + if (warnings.length > 0) { + payload.warnings = warnings; + } + return { content: JSON.stringify(payload), success: true }; + } catch (error: any) { + return { content: `获取表结构快照失败: ${error?.message || error}`, success: false }; + } +}; + +export const inspectDatabaseBundle = async ({ + args, + connection, + runtime, +}: InspectDatabaseBundleOptions): Promise => { + try { + const safeDbName = args.dbName ? String(args.dbName).trim() : ''; + if (!safeDbName) return { content: 'dbName 不能为空', success: false }; + const includeColumns = args.includeColumns !== false; + const tableLimit = normalizeTableLimit(args.tableLimit); + const perTableColumnLimit = normalizePerTableColumnLimit(args.perTableColumnLimit); + const rpcConfig = buildRpcConnectionConfig(connection.config) as any; + const results = await Promise.allSettled([ + runtime.getTables(rpcConfig, safeDbName), + includeColumns ? runtime.getAllColumns(rpcConfig, safeDbName) : Promise.resolve(undefined), + ]); + const warnings: string[] = []; + const tablesResult = results[0]; + const allColumnsResult = results[1]; + const allColumns = allColumnsResult.status === 'fulfilled' && allColumnsResult.value?.success && Array.isArray(allColumnsResult.value.data) + ? normalizeColumnsWithTable(allColumnsResult.value.data) + : []; + const tableNamesFromColumns = Array.from(new Set(allColumns.map((column) => column.tableName).filter(Boolean))); + let tableNames: string[] = []; + if (tablesResult.status === 'fulfilled' && tablesResult.value?.success && Array.isArray(tablesResult.value.data)) { + tableNames = normalizeTableList(tablesResult.value.data).filter(Boolean); + } else if (tableNamesFromColumns.length > 0) { + tableNames = tableNamesFromColumns; + warnings.push(`表列表获取失败,已退回字段摘要推断:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`); + } else { + warnings.push(`表列表获取失败:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`); + } + if (includeColumns && allColumnsResult.status === 'fulfilled' && (!allColumnsResult.value?.success || !Array.isArray(allColumnsResult.value.data))) { + warnings.push(`字段摘要获取失败:${allColumnsResult.value?.message || '未知错误'}`); + } else if (includeColumns && allColumnsResult.status === 'rejected') { + warnings.push(`字段摘要获取失败:${String(allColumnsResult.reason)}`); + } + const uniqueTableNames = Array.from(new Set(tableNames.filter(Boolean))); + const visibleTableNames = uniqueTableNames.slice(0, tableLimit); + const columnsByTable = new Map>(); + allColumns.forEach((column) => { + const tableName = String(column.tableName || '').trim(); + if (!tableName) return; + const current = columnsByTable.get(tableName) || []; + current.push(column); + columnsByTable.set(tableName, current); + }); + const payload: Record = { + dbName: safeDbName, + tableCount: uniqueTableNames.length, + totalColumns: allColumns.length, + tables: visibleTableNames, + truncatedTables: uniqueTableNames.length > visibleTableNames.length, + }; + if (includeColumns) { + payload.tableSummaries = visibleTableNames.map((tableName) => { + const tableColumns = columnsByTable.get(tableName) || []; + return { + tableName, + columnCount: tableColumns.length, + truncatedColumns: tableColumns.length > perTableColumnLimit, + columns: tableColumns.slice(0, perTableColumnLimit), + }; + }); + } + if (warnings.length > 0) { + payload.warnings = warnings; + } + return { content: JSON.stringify(payload), success: true }; + } catch (error: any) { + return { content: `获取数据库结构总览失败: ${error?.message || error}`, success: false }; + } +}; diff --git a/frontend/src/components/ai/aiDatabaseToolExecutor.ts b/frontend/src/components/ai/aiDatabaseToolExecutor.ts index 5eac7bb..b6dbbd5 100644 --- a/frontend/src/components/ai/aiDatabaseToolExecutor.ts +++ b/frontend/src/components/ai/aiDatabaseToolExecutor.ts @@ -1,9 +1,16 @@ import type { SavedConnection } from '../../types'; import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig'; import { buildAIReadonlyPreviewSQL } from '../../utils/aiSqlLimit'; -import { buildPaginatedSelectSQL, quoteQualifiedIdent } from '../../utils/sql'; import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool'; import type { AILocalToolRuntime, AIToolContextEntry } from './aiLocalToolRuntime'; +import { + buildPreviewSQLForTable, + normalizeColumns, + normalizeColumnsWithTable, + normalizePreviewLimit, + normalizeTableList, +} from './aiDatabaseToolHelpers'; +import { inspectDatabaseBundle, inspectTableBundle } from './aiDatabaseBundleTools'; interface ExecuteDatabaseToolCallOptions { toolName: string; @@ -21,64 +28,6 @@ interface ToolExecutionResult { const findConnection = (connections: SavedConnection[], connectionId: string) => connections.find((connection) => connection.id === connectionId); -const normalizeTableList = (rows: any[]): string[] => - rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string)); - -const normalizeColumns = (rows: any[]) => - rows.map((column) => { - const keys = Object.keys(column); - return { - field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''), - type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''), - nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '', - default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '', - comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '', - }; - }); - -const normalizeColumnsWithTable = (rows: any[]) => - rows.map((column) => { - const keys = Object.keys(column); - return { - tableName: column.TableName || column.tableName || column.TABLE_NAME || column.table_name || (keys.length > 0 ? column[keys[0]] : ''), - name: column.Name || column.name || column.COLUMN_NAME || column.column_name || (keys.length > 1 ? column[keys[1]] : ''), - type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 2 ? column[keys[2]] : ''), - comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || '', - }; - }); - -const normalizePreviewLimit = (input: unknown): number => { - const value = Math.floor(Number(input) || 20); - if (value < 1) return 1; - if (value > 100) return 100; - return value; -}; - -const normalizeTableLimit = (input: unknown): number => { - const value = Math.floor(Number(input) || 80); - if (value < 1) return 1; - if (value > 200) return 200; - return value; -}; - -const normalizePerTableColumnLimit = (input: unknown): number => { - const value = Math.floor(Number(input) || 8); - if (value < 1) return 1; - if (value > 30) return 30; - return value; -}; - -const buildPreviewSQLForTable = (connection: SavedConnection, tableName: string, limit: number): string => { - const dbType = String(connection.config?.type || '').trim(); - return buildPaginatedSelectSQL( - dbType, - `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)}`, - '', - limit, - 0, - ); -}; - const resolveConnectionOrFailure = ( connections: SavedConnection[], connectionId: string, @@ -266,158 +215,12 @@ export async function executeDatabaseToolCall( case 'inspect_table_bundle': { const resolved = resolveConnectionOrFailure(connections, args.connectionId); if (resolved.failure || !resolved.connection) return resolved.failure || null; - try { - const safeDbName = args.dbName ? String(args.dbName).trim() : ''; - const safeTable = args.tableName ? String(args.tableName).trim() : ''; - if (!safeTable) return { content: 'tableName 不能为空', success: false }; - const includeSampleRows = args.includeSampleRows === true; - const sampleLimit = normalizePreviewLimit(args.sampleLimit ?? 10); - const rpcConfig = buildRpcConnectionConfig(resolved.connection.config) as any; - const results = await Promise.allSettled([ - runtime.getColumns(rpcConfig, safeDbName, safeTable), - runtime.getIndexes(rpcConfig, safeDbName, safeTable), - runtime.getForeignKeys(rpcConfig, safeDbName, safeTable), - runtime.getTriggers(rpcConfig, safeDbName, safeTable), - resolveAITableSchemaToolResult({ - tableName: safeTable, - fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable), - fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable), - }), - includeSampleRows - ? runtime.query(rpcConfig, safeDbName, buildPreviewSQLForTable(resolved.connection, safeTable, sampleLimit)) - : Promise.resolve(undefined), - ]); - const warnings: string[] = []; - const payload: Record = { - dbName: safeDbName, - tableName: safeTable, - columns: [], - indexes: [], - foreignKeys: [], - triggers: [], - ddl: '', - }; - const columnsResult = results[0]; - const indexesResult = results[1]; - const foreignKeysResult = results[2]; - const triggersResult = results[3]; - const ddlResult = results[4]; - const sampleRowsResult = results[5]; - if (columnsResult.status === 'fulfilled' && columnsResult.value?.success && Array.isArray(columnsResult.value.data)) { - payload.columns = normalizeColumns(columnsResult.value.data); - } else { - warnings.push(`字段列表获取失败:${columnsResult.status === 'fulfilled' ? (columnsResult.value?.message || '未知错误') : String(columnsResult.reason)}`); - } - if (indexesResult.status === 'fulfilled' && indexesResult.value?.success && Array.isArray(indexesResult.value.data)) { - payload.indexes = indexesResult.value.data; - } else { - warnings.push(`索引定义获取失败:${indexesResult.status === 'fulfilled' ? (indexesResult.value?.message || '未知错误') : String(indexesResult.reason)}`); - } - if (foreignKeysResult.status === 'fulfilled' && foreignKeysResult.value?.success && Array.isArray(foreignKeysResult.value.data)) { - payload.foreignKeys = foreignKeysResult.value.data; - } else { - warnings.push(`外键关系获取失败:${foreignKeysResult.status === 'fulfilled' ? (foreignKeysResult.value?.message || '未知错误') : String(foreignKeysResult.reason)}`); - } - if (triggersResult.status === 'fulfilled' && triggersResult.value?.success && Array.isArray(triggersResult.value.data)) { - payload.triggers = triggersResult.value.data; - } else { - warnings.push(`触发器获取失败:${triggersResult.status === 'fulfilled' ? (triggersResult.value?.message || '未知错误') : String(triggersResult.reason)}`); - } - if (ddlResult.status === 'fulfilled' && ddlResult.value?.success) { - payload.ddl = ddlResult.value.content; - } else { - warnings.push(`DDL 获取失败:${ddlResult.status === 'fulfilled' ? (ddlResult.value?.content || '未知错误') : String(ddlResult.reason)}`); - } - if (includeSampleRows) { - if (sampleRowsResult.status === 'fulfilled' && sampleRowsResult.value?.success) { - const rows = Array.isArray(sampleRowsResult.value.data) ? sampleRowsResult.value.data : []; - payload.sampleRows = { - limit: sampleLimit, - rowCount: rows.length, - rows: rows.slice(0, sampleLimit), - }; - } else { - warnings.push(`样例数据获取失败:${sampleRowsResult.status === 'fulfilled' ? (sampleRowsResult.value?.message || '未知错误') : String(sampleRowsResult.reason)}`); - } - } - if (warnings.length > 0) { - payload.warnings = warnings; - } - return { content: JSON.stringify(payload), success: true }; - } catch (error: any) { - return { content: `获取表结构快照失败: ${error?.message || error}`, success: false }; - } + return inspectTableBundle({ args, connection: resolved.connection, runtime }); } case 'inspect_database_bundle': { const resolved = resolveConnectionOrFailure(connections, args.connectionId); if (resolved.failure || !resolved.connection) return resolved.failure || null; - try { - const safeDbName = args.dbName ? String(args.dbName).trim() : ''; - if (!safeDbName) return { content: 'dbName 不能为空', success: false }; - const includeColumns = args.includeColumns !== false; - const tableLimit = normalizeTableLimit(args.tableLimit); - const perTableColumnLimit = normalizePerTableColumnLimit(args.perTableColumnLimit); - const rpcConfig = buildRpcConnectionConfig(resolved.connection.config) as any; - const results = await Promise.allSettled([ - runtime.getTables(rpcConfig, safeDbName), - includeColumns ? runtime.getAllColumns(rpcConfig, safeDbName) : Promise.resolve(undefined), - ]); - const warnings: string[] = []; - const tablesResult = results[0]; - const allColumnsResult = results[1]; - const allColumns = allColumnsResult.status === 'fulfilled' && allColumnsResult.value?.success && Array.isArray(allColumnsResult.value.data) - ? normalizeColumnsWithTable(allColumnsResult.value.data) - : []; - const tableNamesFromColumns = Array.from(new Set(allColumns.map((column) => column.tableName).filter(Boolean))); - let tableNames: string[] = []; - if (tablesResult.status === 'fulfilled' && tablesResult.value?.success && Array.isArray(tablesResult.value.data)) { - tableNames = normalizeTableList(tablesResult.value.data).filter(Boolean); - } else if (tableNamesFromColumns.length > 0) { - tableNames = tableNamesFromColumns; - warnings.push(`表列表获取失败,已退回字段摘要推断:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`); - } else { - warnings.push(`表列表获取失败:${tablesResult.status === 'fulfilled' ? (tablesResult.value?.message || '未知错误') : String(tablesResult.reason)}`); - } - if (includeColumns && allColumnsResult.status === 'fulfilled' && (!allColumnsResult.value?.success || !Array.isArray(allColumnsResult.value.data))) { - warnings.push(`字段摘要获取失败:${allColumnsResult.value?.message || '未知错误'}`); - } else if (includeColumns && allColumnsResult.status === 'rejected') { - warnings.push(`字段摘要获取失败:${String(allColumnsResult.reason)}`); - } - const uniqueTableNames = Array.from(new Set(tableNames.filter(Boolean))); - const visibleTableNames = uniqueTableNames.slice(0, tableLimit); - const columnsByTable = new Map>(); - allColumns.forEach((column) => { - const tableName = String(column.tableName || '').trim(); - if (!tableName) return; - const current = columnsByTable.get(tableName) || []; - current.push(column); - columnsByTable.set(tableName, current); - }); - const payload: Record = { - dbName: safeDbName, - tableCount: uniqueTableNames.length, - totalColumns: allColumns.length, - tables: visibleTableNames, - truncatedTables: uniqueTableNames.length > visibleTableNames.length, - }; - if (includeColumns) { - payload.tableSummaries = visibleTableNames.map((tableName) => { - const tableColumns = columnsByTable.get(tableName) || []; - return { - tableName, - columnCount: tableColumns.length, - truncatedColumns: tableColumns.length > perTableColumnLimit, - columns: tableColumns.slice(0, perTableColumnLimit), - }; - }); - } - if (warnings.length > 0) { - payload.warnings = warnings; - } - return { content: JSON.stringify(payload), success: true }; - } catch (error: any) { - return { content: `获取数据库结构总览失败: ${error?.message || error}`, success: false }; - } + return inspectDatabaseBundle({ args, connection: resolved.connection, runtime }); } case 'preview_table_rows': { const resolved = resolveConnectionOrFailure(connections, args.connectionId); diff --git a/frontend/src/components/ai/aiDatabaseToolHelpers.ts b/frontend/src/components/ai/aiDatabaseToolHelpers.ts new file mode 100644 index 0000000..89b2324 --- /dev/null +++ b/frontend/src/components/ai/aiDatabaseToolHelpers.ts @@ -0,0 +1,60 @@ +import type { SavedConnection } from '../../types'; +import { buildPaginatedSelectSQL, quoteQualifiedIdent } from '../../utils/sql'; + +export const normalizeTableList = (rows: any[]): string[] => + rows.map((row) => row.Table || row.table || (Object.values(row)[0] as string)); + +export const normalizeColumns = (rows: any[]) => + rows.map((column) => { + const keys = Object.keys(column); + return { + field: column.Field || column.field || column.COLUMN_NAME || column.column_name || column.Name || column.name || (keys.length > 0 ? column[keys[0]] : ''), + type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 1 ? column[keys[1]] : ''), + nullable: column.Null || column.null || column.IS_NULLABLE || column.is_nullable || column.Nullable || column.nullable || '', + default: column.Default || column.default || column.COLUMN_DEFAULT || column.column_default || column.DefaultValue || '', + comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || column.Description || '', + }; + }); + +export const normalizeColumnsWithTable = (rows: any[]) => + rows.map((column) => { + const keys = Object.keys(column); + return { + tableName: column.TableName || column.tableName || column.TABLE_NAME || column.table_name || (keys.length > 0 ? column[keys[0]] : ''), + name: column.Name || column.name || column.COLUMN_NAME || column.column_name || (keys.length > 1 ? column[keys[1]] : ''), + type: column.Type || column.type || column.DATA_TYPE || column.data_type || (keys.length > 2 ? column[keys[2]] : ''), + comment: column.Comment || column.comment || column.COLUMN_COMMENT || column.column_comment || '', + }; + }); + +export const normalizePreviewLimit = (input: unknown): number => { + const value = Math.floor(Number(input) || 20); + if (value < 1) return 1; + if (value > 100) return 100; + return value; +}; + +export const normalizeTableLimit = (input: unknown): number => { + const value = Math.floor(Number(input) || 80); + if (value < 1) return 1; + if (value > 200) return 200; + return value; +}; + +export const normalizePerTableColumnLimit = (input: unknown): number => { + const value = Math.floor(Number(input) || 8); + if (value < 1) return 1; + if (value > 30) return 30; + return value; +}; + +export const buildPreviewSQLForTable = (connection: SavedConnection, tableName: string, limit: number): string => { + const dbType = String(connection.config?.type || '').trim(); + return buildPaginatedSelectSQL( + dbType, + `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)}`, + '', + limit, + 0, + ); +};