♻️ refactor(ai-tools): 拆分 SQL 风险探针执行器

- 将 inspect_sql_risk 执行逻辑从聚合探针执行器中拆出

- 将 SQL 风险工具调用测试迁移到独立测试文件

- 保持本地工具调用行为不变并降低后续扩展成本
This commit is contained in:
Syngnat
2026-06-09 22:31:30 +08:00
parent 48de0b83c4
commit ce06bea744
4 changed files with 137 additions and 71 deletions

View File

@@ -0,0 +1,76 @@
import { describe, expect, it, vi } from 'vitest';
import type { AIToolCall, SavedConnection } from '../../types';
import { executeLocalAIToolCall } from './aiLocalToolExecutor';
const buildConnection = (): SavedConnection => ({
id: 'conn-1',
name: '主库',
config: {
type: 'mysql',
host: '127.0.0.1',
port: 3306,
user: 'root',
},
});
const buildToolCall = (name: string, args: Record<string, unknown>): AIToolCall => ({
id: `call-${name}`,
type: 'function',
function: {
name,
arguments: JSON.stringify(args),
},
});
describe('aiLocalToolExecutor sql risk inspection', () => {
it('inspects SQL risk from the active query tab and applies the AI safety check', async () => {
const checkSQL = vi.fn().mockResolvedValue({
allowed: false,
operationType: 'UPDATE',
});
const result = await executeLocalAIToolCall({
toolCall: buildToolCall('inspect_sql_risk', {}),
connections: [buildConnection()],
tabs: [{
id: 'tab-risk-1',
title: '批量更新',
type: 'query',
connectionId: 'conn-1',
dbName: 'crm',
query: 'UPDATE users SET status = 0',
}],
activeTabId: 'tab-risk-1',
mcpTools: [],
toolContextMap: new Map(),
runtime: {
getDatabases: vi.fn(),
getTables: vi.fn(),
checkSQL,
},
});
const payload = JSON.parse(result.content);
expect(result.success).toBe(true);
expect(checkSQL).toHaveBeenCalledWith('UPDATE users SET status = 0');
expect(payload).toMatchObject({
hasSql: true,
source: 'active_tab',
riskLevel: 'critical',
requiresUserConfirmation: true,
safetyCheck: {
allowed: false,
operationType: 'UPDATE',
},
activeTab: {
id: 'tab-risk-1',
connectionName: '主库',
dbName: 'crm',
},
});
expect(payload.activityKinds).toContain('write');
expect(payload.warnings).toContain('UPDATE 缺少 WHERE 条件,可能更新整表数据');
expect(payload.warnings).toContain('当前 AI 安全策略不允许执行 UPDATE 类型 SQL');
});
});

View File

@@ -794,56 +794,6 @@ describe('aiLocalToolExecutor', () => {
expect(query).not.toHaveBeenCalled();
});
it('inspects SQL risk from the active query tab and applies the AI safety check', async () => {
const checkSQL = vi.fn().mockResolvedValue({
allowed: false,
operationType: 'UPDATE',
});
const result = await executeLocalAIToolCall({
toolCall: buildToolCall('inspect_sql_risk', {}),
connections: [buildConnection()],
tabs: [{
id: 'tab-risk-1',
title: '批量更新',
type: 'query',
connectionId: 'conn-1',
dbName: 'crm',
query: 'UPDATE users SET status = 0',
}],
activeTabId: 'tab-risk-1',
mcpTools: [],
toolContextMap: new Map(),
runtime: {
getDatabases: vi.fn(),
getTables: vi.fn(),
checkSQL,
},
});
const payload = JSON.parse(result.content);
expect(result.success).toBe(true);
expect(checkSQL).toHaveBeenCalledWith('UPDATE users SET status = 0');
expect(payload).toMatchObject({
hasSql: true,
source: 'active_tab',
riskLevel: 'critical',
requiresUserConfirmation: true,
safetyCheck: {
allowed: false,
operationType: 'UPDATE',
},
activeTab: {
id: 'tab-risk-1',
connectionName: '主库',
dbName: 'crm',
},
});
expect(payload.activityKinds).toContain('write');
expect(payload.warnings).toContain('UPDATE 缺少 WHERE 条件,可能更新整表数据');
expect(payload.warnings).toContain('当前 AI 安全策略不允许执行 UPDATE 类型 SQL');
});
it('returns a cross-table column summary for get_all_columns', async () => {
const result = await executeLocalAIToolCall({
toolCall: buildToolCall('get_all_columns', {

View File

@@ -0,0 +1,48 @@
import type { SavedConnection, TabData } from '../../types';
import { buildSqlRiskSnapshot } from './aiSqlRiskInsights';
import type {
AISnapshotInspectionRuntime,
SnapshotInspectionResult,
} from './aiSnapshotInspectionToolTypes';
interface ExecuteSqlRiskInspectionToolCallOptions {
toolName: string;
args: Record<string, any>;
connections: SavedConnection[];
tabs?: TabData[];
activeTabId?: string | null;
runtime?: AISnapshotInspectionRuntime;
}
export async function executeSqlRiskInspectionToolCall({
toolName,
args,
connections,
tabs = [],
activeTabId = null,
runtime,
}: ExecuteSqlRiskInspectionToolCallOptions): Promise<SnapshotInspectionResult | null> {
if (toolName !== 'inspect_sql_risk') {
return null;
}
const candidateSql = String(args.sql || '').trim();
const activeTab = tabs.find((tab) => tab.id === activeTabId);
const activeTabSql = activeTab?.type === 'query' ? String(activeTab.query || '').trim() : '';
const sqlForCheck = candidateSql || activeTabSql;
const safetyCheck = sqlForCheck && typeof runtime?.checkSQL === 'function'
? await runtime.checkSQL(sqlForCheck)
: undefined;
return {
content: JSON.stringify(buildSqlRiskSnapshot({
sql: candidateSql,
previewCharLimit: args.previewCharLimit,
tabs,
activeTabId,
connections,
safetyCheck,
})),
success: true,
};
}

View File

@@ -28,7 +28,6 @@ import {
buildRecentSqlActivitySnapshot,
buildRecentSqlLogsSnapshot,
} from './aiSqlLogInsights';
import { buildSqlRiskSnapshot } from './aiSqlRiskInsights';
import {
buildActiveTabSnapshot,
buildWorkspaceTabsSnapshot,
@@ -41,6 +40,7 @@ import type {
AISnapshotInspectionRuntime,
SnapshotInspectionResult,
} from './aiSnapshotInspectionToolTypes';
import { executeSqlRiskInspectionToolCall } from './aiSnapshotInspectionSqlRiskToolExecutor';
interface ExecuteSnapshotInspectionToolCallOptions {
toolName: string;
@@ -107,6 +107,18 @@ export async function executeSnapshotInspectionToolCall(
return aiConfigResult;
}
const sqlRiskResult = await executeSqlRiskInspectionToolCall({
toolName,
args,
connections,
tabs,
activeTabId,
runtime,
});
if (sqlRiskResult) {
return sqlRiskResult;
}
switch (toolName) {
case 'inspect_current_connection':
return {
@@ -252,26 +264,6 @@ export async function executeSnapshotInspectionToolCall(
})),
success: true,
};
case 'inspect_sql_risk': {
const candidateSql = String(args.sql || '').trim();
const activeTab = tabs.find((tab) => tab.id === activeTabId);
const activeTabSql = activeTab?.type === 'query' ? String(activeTab.query || '').trim() : '';
const sqlForCheck = candidateSql || activeTabSql;
const safetyCheck = sqlForCheck && typeof runtime?.checkSQL === 'function'
? await runtime.checkSQL(sqlForCheck)
: undefined;
return {
content: JSON.stringify(buildSqlRiskSnapshot({
sql: candidateSql,
previewCharLimit: args.previewCharLimit,
tabs,
activeTabId,
connections,
safetyCheck,
})),
success: true,
};
}
case 'inspect_app_logs': {
const readResult = typeof runtime?.readAppLogTail === 'function'
? await runtime.readAppLogTail(Number(args.lineLimit) || 80, String(args.keyword || ''))