From ce06bea7445bf725423bfa09c0a349836a328ce4 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 9 Jun 2026 22:31:30 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(ai-tools):=20?= =?UTF-8?q?=E6=8B=86=E5=88=86=20SQL=20=E9=A3=8E=E9=99=A9=E6=8E=A2=E9=92=88?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 inspect_sql_risk 执行逻辑从聚合探针执行器中拆出 - 将 SQL 风险工具调用测试迁移到独立测试文件 - 保持本地工具调用行为不变并降低后续扩展成本 --- ...ocalToolExecutor.sqlRiskInspection.test.ts | 76 +++++++++++++++++++ .../components/ai/aiLocalToolExecutor.test.ts | 50 ------------ ...aiSnapshotInspectionSqlRiskToolExecutor.ts | 48 ++++++++++++ .../ai/aiSnapshotInspectionToolExecutor.ts | 34 ++++----- 4 files changed, 137 insertions(+), 71 deletions(-) create mode 100644 frontend/src/components/ai/aiLocalToolExecutor.sqlRiskInspection.test.ts create mode 100644 frontend/src/components/ai/aiSnapshotInspectionSqlRiskToolExecutor.ts diff --git a/frontend/src/components/ai/aiLocalToolExecutor.sqlRiskInspection.test.ts b/frontend/src/components/ai/aiLocalToolExecutor.sqlRiskInspection.test.ts new file mode 100644 index 0000000..bb2639c --- /dev/null +++ b/frontend/src/components/ai/aiLocalToolExecutor.sqlRiskInspection.test.ts @@ -0,0 +1,76 @@ +import { describe, expect, it, vi } from 'vitest'; + +import type { AIToolCall, SavedConnection } from '../../types'; +import { executeLocalAIToolCall } from './aiLocalToolExecutor'; + +const buildConnection = (): SavedConnection => ({ + id: 'conn-1', + name: '主库', + config: { + type: 'mysql', + host: '127.0.0.1', + port: 3306, + user: 'root', + }, +}); + +const buildToolCall = (name: string, args: Record): AIToolCall => ({ + id: `call-${name}`, + type: 'function', + function: { + name, + arguments: JSON.stringify(args), + }, +}); + +describe('aiLocalToolExecutor sql risk inspection', () => { + it('inspects SQL risk from the active query tab and applies the AI safety check', async () => { + const checkSQL = vi.fn().mockResolvedValue({ + allowed: false, + operationType: 'UPDATE', + }); + + const result = await executeLocalAIToolCall({ + toolCall: buildToolCall('inspect_sql_risk', {}), + connections: [buildConnection()], + tabs: [{ + id: 'tab-risk-1', + title: '批量更新', + type: 'query', + connectionId: 'conn-1', + dbName: 'crm', + query: 'UPDATE users SET status = 0', + }], + activeTabId: 'tab-risk-1', + mcpTools: [], + toolContextMap: new Map(), + runtime: { + getDatabases: vi.fn(), + getTables: vi.fn(), + checkSQL, + }, + }); + + const payload = JSON.parse(result.content); + expect(result.success).toBe(true); + expect(checkSQL).toHaveBeenCalledWith('UPDATE users SET status = 0'); + expect(payload).toMatchObject({ + hasSql: true, + source: 'active_tab', + riskLevel: 'critical', + requiresUserConfirmation: true, + safetyCheck: { + allowed: false, + operationType: 'UPDATE', + }, + activeTab: { + id: 'tab-risk-1', + connectionName: '主库', + dbName: 'crm', + }, + }); + expect(payload.activityKinds).toContain('write'); + expect(payload.warnings).toContain('UPDATE 缺少 WHERE 条件,可能更新整表数据'); + expect(payload.warnings).toContain('当前 AI 安全策略不允许执行 UPDATE 类型 SQL'); + }); +}); diff --git a/frontend/src/components/ai/aiLocalToolExecutor.test.ts b/frontend/src/components/ai/aiLocalToolExecutor.test.ts index b3f7aac..8586be9 100644 --- a/frontend/src/components/ai/aiLocalToolExecutor.test.ts +++ b/frontend/src/components/ai/aiLocalToolExecutor.test.ts @@ -794,56 +794,6 @@ describe('aiLocalToolExecutor', () => { expect(query).not.toHaveBeenCalled(); }); - it('inspects SQL risk from the active query tab and applies the AI safety check', async () => { - const checkSQL = vi.fn().mockResolvedValue({ - allowed: false, - operationType: 'UPDATE', - }); - - const result = await executeLocalAIToolCall({ - toolCall: buildToolCall('inspect_sql_risk', {}), - connections: [buildConnection()], - tabs: [{ - id: 'tab-risk-1', - title: '批量更新', - type: 'query', - connectionId: 'conn-1', - dbName: 'crm', - query: 'UPDATE users SET status = 0', - }], - activeTabId: 'tab-risk-1', - mcpTools: [], - toolContextMap: new Map(), - runtime: { - getDatabases: vi.fn(), - getTables: vi.fn(), - checkSQL, - }, - }); - - const payload = JSON.parse(result.content); - expect(result.success).toBe(true); - expect(checkSQL).toHaveBeenCalledWith('UPDATE users SET status = 0'); - expect(payload).toMatchObject({ - hasSql: true, - source: 'active_tab', - riskLevel: 'critical', - requiresUserConfirmation: true, - safetyCheck: { - allowed: false, - operationType: 'UPDATE', - }, - activeTab: { - id: 'tab-risk-1', - connectionName: '主库', - dbName: 'crm', - }, - }); - expect(payload.activityKinds).toContain('write'); - expect(payload.warnings).toContain('UPDATE 缺少 WHERE 条件,可能更新整表数据'); - expect(payload.warnings).toContain('当前 AI 安全策略不允许执行 UPDATE 类型 SQL'); - }); - it('returns a cross-table column summary for get_all_columns', async () => { const result = await executeLocalAIToolCall({ toolCall: buildToolCall('get_all_columns', { diff --git a/frontend/src/components/ai/aiSnapshotInspectionSqlRiskToolExecutor.ts b/frontend/src/components/ai/aiSnapshotInspectionSqlRiskToolExecutor.ts new file mode 100644 index 0000000..aa34c3c --- /dev/null +++ b/frontend/src/components/ai/aiSnapshotInspectionSqlRiskToolExecutor.ts @@ -0,0 +1,48 @@ +import type { SavedConnection, TabData } from '../../types'; +import { buildSqlRiskSnapshot } from './aiSqlRiskInsights'; +import type { + AISnapshotInspectionRuntime, + SnapshotInspectionResult, +} from './aiSnapshotInspectionToolTypes'; + +interface ExecuteSqlRiskInspectionToolCallOptions { + toolName: string; + args: Record; + connections: SavedConnection[]; + tabs?: TabData[]; + activeTabId?: string | null; + runtime?: AISnapshotInspectionRuntime; +} + +export async function executeSqlRiskInspectionToolCall({ + toolName, + args, + connections, + tabs = [], + activeTabId = null, + runtime, +}: ExecuteSqlRiskInspectionToolCallOptions): Promise { + if (toolName !== 'inspect_sql_risk') { + return null; + } + + const candidateSql = String(args.sql || '').trim(); + const activeTab = tabs.find((tab) => tab.id === activeTabId); + const activeTabSql = activeTab?.type === 'query' ? String(activeTab.query || '').trim() : ''; + const sqlForCheck = candidateSql || activeTabSql; + const safetyCheck = sqlForCheck && typeof runtime?.checkSQL === 'function' + ? await runtime.checkSQL(sqlForCheck) + : undefined; + + return { + content: JSON.stringify(buildSqlRiskSnapshot({ + sql: candidateSql, + previewCharLimit: args.previewCharLimit, + tabs, + activeTabId, + connections, + safetyCheck, + })), + success: true, + }; +} diff --git a/frontend/src/components/ai/aiSnapshotInspectionToolExecutor.ts b/frontend/src/components/ai/aiSnapshotInspectionToolExecutor.ts index b4dc58c..b7bedd0 100644 --- a/frontend/src/components/ai/aiSnapshotInspectionToolExecutor.ts +++ b/frontend/src/components/ai/aiSnapshotInspectionToolExecutor.ts @@ -28,7 +28,6 @@ import { buildRecentSqlActivitySnapshot, buildRecentSqlLogsSnapshot, } from './aiSqlLogInsights'; -import { buildSqlRiskSnapshot } from './aiSqlRiskInsights'; import { buildActiveTabSnapshot, buildWorkspaceTabsSnapshot, @@ -41,6 +40,7 @@ import type { AISnapshotInspectionRuntime, SnapshotInspectionResult, } from './aiSnapshotInspectionToolTypes'; +import { executeSqlRiskInspectionToolCall } from './aiSnapshotInspectionSqlRiskToolExecutor'; interface ExecuteSnapshotInspectionToolCallOptions { toolName: string; @@ -107,6 +107,18 @@ export async function executeSnapshotInspectionToolCall( return aiConfigResult; } + const sqlRiskResult = await executeSqlRiskInspectionToolCall({ + toolName, + args, + connections, + tabs, + activeTabId, + runtime, + }); + if (sqlRiskResult) { + return sqlRiskResult; + } + switch (toolName) { case 'inspect_current_connection': return { @@ -252,26 +264,6 @@ export async function executeSnapshotInspectionToolCall( })), success: true, }; - case 'inspect_sql_risk': { - const candidateSql = String(args.sql || '').trim(); - const activeTab = tabs.find((tab) => tab.id === activeTabId); - const activeTabSql = activeTab?.type === 'query' ? String(activeTab.query || '').trim() : ''; - const sqlForCheck = candidateSql || activeTabSql; - const safetyCheck = sqlForCheck && typeof runtime?.checkSQL === 'function' - ? await runtime.checkSQL(sqlForCheck) - : undefined; - return { - content: JSON.stringify(buildSqlRiskSnapshot({ - sql: candidateSql, - previewCharLimit: args.previewCharLimit, - tabs, - activeTabId, - connections, - safetyCheck, - })), - success: true, - }; - } case 'inspect_app_logs': { const readResult = typeof runtime?.readAppLogTail === 'function' ? await runtime.readAppLogTail(Number(args.lineLimit) || 80, String(args.keyword || ''))