mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-14 10:29:52 +08:00
♻️ refactor(ai-tools): 拆分 SQL 风险探针执行器
- 将 inspect_sql_risk 执行逻辑从聚合探针执行器中拆出 - 将 SQL 风险工具调用测试迁移到独立测试文件 - 保持本地工具调用行为不变并降低后续扩展成本
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { AIToolCall, SavedConnection } from '../../types';
|
||||
import { executeLocalAIToolCall } from './aiLocalToolExecutor';
|
||||
|
||||
const buildConnection = (): SavedConnection => ({
|
||||
id: 'conn-1',
|
||||
name: '主库',
|
||||
config: {
|
||||
type: 'mysql',
|
||||
host: '127.0.0.1',
|
||||
port: 3306,
|
||||
user: 'root',
|
||||
},
|
||||
});
|
||||
|
||||
const buildToolCall = (name: string, args: Record<string, unknown>): AIToolCall => ({
|
||||
id: `call-${name}`,
|
||||
type: 'function',
|
||||
function: {
|
||||
name,
|
||||
arguments: JSON.stringify(args),
|
||||
},
|
||||
});
|
||||
|
||||
describe('aiLocalToolExecutor sql risk inspection', () => {
|
||||
it('inspects SQL risk from the active query tab and applies the AI safety check', async () => {
|
||||
const checkSQL = vi.fn().mockResolvedValue({
|
||||
allowed: false,
|
||||
operationType: 'UPDATE',
|
||||
});
|
||||
|
||||
const result = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('inspect_sql_risk', {}),
|
||||
connections: [buildConnection()],
|
||||
tabs: [{
|
||||
id: 'tab-risk-1',
|
||||
title: '批量更新',
|
||||
type: 'query',
|
||||
connectionId: 'conn-1',
|
||||
dbName: 'crm',
|
||||
query: 'UPDATE users SET status = 0',
|
||||
}],
|
||||
activeTabId: 'tab-risk-1',
|
||||
mcpTools: [],
|
||||
toolContextMap: new Map(),
|
||||
runtime: {
|
||||
getDatabases: vi.fn(),
|
||||
getTables: vi.fn(),
|
||||
checkSQL,
|
||||
},
|
||||
});
|
||||
|
||||
const payload = JSON.parse(result.content);
|
||||
expect(result.success).toBe(true);
|
||||
expect(checkSQL).toHaveBeenCalledWith('UPDATE users SET status = 0');
|
||||
expect(payload).toMatchObject({
|
||||
hasSql: true,
|
||||
source: 'active_tab',
|
||||
riskLevel: 'critical',
|
||||
requiresUserConfirmation: true,
|
||||
safetyCheck: {
|
||||
allowed: false,
|
||||
operationType: 'UPDATE',
|
||||
},
|
||||
activeTab: {
|
||||
id: 'tab-risk-1',
|
||||
connectionName: '主库',
|
||||
dbName: 'crm',
|
||||
},
|
||||
});
|
||||
expect(payload.activityKinds).toContain('write');
|
||||
expect(payload.warnings).toContain('UPDATE 缺少 WHERE 条件,可能更新整表数据');
|
||||
expect(payload.warnings).toContain('当前 AI 安全策略不允许执行 UPDATE 类型 SQL');
|
||||
});
|
||||
});
|
||||
@@ -794,56 +794,6 @@ describe('aiLocalToolExecutor', () => {
|
||||
expect(query).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('inspects SQL risk from the active query tab and applies the AI safety check', async () => {
|
||||
const checkSQL = vi.fn().mockResolvedValue({
|
||||
allowed: false,
|
||||
operationType: 'UPDATE',
|
||||
});
|
||||
|
||||
const result = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('inspect_sql_risk', {}),
|
||||
connections: [buildConnection()],
|
||||
tabs: [{
|
||||
id: 'tab-risk-1',
|
||||
title: '批量更新',
|
||||
type: 'query',
|
||||
connectionId: 'conn-1',
|
||||
dbName: 'crm',
|
||||
query: 'UPDATE users SET status = 0',
|
||||
}],
|
||||
activeTabId: 'tab-risk-1',
|
||||
mcpTools: [],
|
||||
toolContextMap: new Map(),
|
||||
runtime: {
|
||||
getDatabases: vi.fn(),
|
||||
getTables: vi.fn(),
|
||||
checkSQL,
|
||||
},
|
||||
});
|
||||
|
||||
const payload = JSON.parse(result.content);
|
||||
expect(result.success).toBe(true);
|
||||
expect(checkSQL).toHaveBeenCalledWith('UPDATE users SET status = 0');
|
||||
expect(payload).toMatchObject({
|
||||
hasSql: true,
|
||||
source: 'active_tab',
|
||||
riskLevel: 'critical',
|
||||
requiresUserConfirmation: true,
|
||||
safetyCheck: {
|
||||
allowed: false,
|
||||
operationType: 'UPDATE',
|
||||
},
|
||||
activeTab: {
|
||||
id: 'tab-risk-1',
|
||||
connectionName: '主库',
|
||||
dbName: 'crm',
|
||||
},
|
||||
});
|
||||
expect(payload.activityKinds).toContain('write');
|
||||
expect(payload.warnings).toContain('UPDATE 缺少 WHERE 条件,可能更新整表数据');
|
||||
expect(payload.warnings).toContain('当前 AI 安全策略不允许执行 UPDATE 类型 SQL');
|
||||
});
|
||||
|
||||
it('returns a cross-table column summary for get_all_columns', async () => {
|
||||
const result = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('get_all_columns', {
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import type { SavedConnection, TabData } from '../../types';
|
||||
import { buildSqlRiskSnapshot } from './aiSqlRiskInsights';
|
||||
import type {
|
||||
AISnapshotInspectionRuntime,
|
||||
SnapshotInspectionResult,
|
||||
} from './aiSnapshotInspectionToolTypes';
|
||||
|
||||
interface ExecuteSqlRiskInspectionToolCallOptions {
|
||||
toolName: string;
|
||||
args: Record<string, any>;
|
||||
connections: SavedConnection[];
|
||||
tabs?: TabData[];
|
||||
activeTabId?: string | null;
|
||||
runtime?: AISnapshotInspectionRuntime;
|
||||
}
|
||||
|
||||
export async function executeSqlRiskInspectionToolCall({
|
||||
toolName,
|
||||
args,
|
||||
connections,
|
||||
tabs = [],
|
||||
activeTabId = null,
|
||||
runtime,
|
||||
}: ExecuteSqlRiskInspectionToolCallOptions): Promise<SnapshotInspectionResult | null> {
|
||||
if (toolName !== 'inspect_sql_risk') {
|
||||
return null;
|
||||
}
|
||||
|
||||
const candidateSql = String(args.sql || '').trim();
|
||||
const activeTab = tabs.find((tab) => tab.id === activeTabId);
|
||||
const activeTabSql = activeTab?.type === 'query' ? String(activeTab.query || '').trim() : '';
|
||||
const sqlForCheck = candidateSql || activeTabSql;
|
||||
const safetyCheck = sqlForCheck && typeof runtime?.checkSQL === 'function'
|
||||
? await runtime.checkSQL(sqlForCheck)
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
content: JSON.stringify(buildSqlRiskSnapshot({
|
||||
sql: candidateSql,
|
||||
previewCharLimit: args.previewCharLimit,
|
||||
tabs,
|
||||
activeTabId,
|
||||
connections,
|
||||
safetyCheck,
|
||||
})),
|
||||
success: true,
|
||||
};
|
||||
}
|
||||
@@ -28,7 +28,6 @@ import {
|
||||
buildRecentSqlActivitySnapshot,
|
||||
buildRecentSqlLogsSnapshot,
|
||||
} from './aiSqlLogInsights';
|
||||
import { buildSqlRiskSnapshot } from './aiSqlRiskInsights';
|
||||
import {
|
||||
buildActiveTabSnapshot,
|
||||
buildWorkspaceTabsSnapshot,
|
||||
@@ -41,6 +40,7 @@ import type {
|
||||
AISnapshotInspectionRuntime,
|
||||
SnapshotInspectionResult,
|
||||
} from './aiSnapshotInspectionToolTypes';
|
||||
import { executeSqlRiskInspectionToolCall } from './aiSnapshotInspectionSqlRiskToolExecutor';
|
||||
|
||||
interface ExecuteSnapshotInspectionToolCallOptions {
|
||||
toolName: string;
|
||||
@@ -107,6 +107,18 @@ export async function executeSnapshotInspectionToolCall(
|
||||
return aiConfigResult;
|
||||
}
|
||||
|
||||
const sqlRiskResult = await executeSqlRiskInspectionToolCall({
|
||||
toolName,
|
||||
args,
|
||||
connections,
|
||||
tabs,
|
||||
activeTabId,
|
||||
runtime,
|
||||
});
|
||||
if (sqlRiskResult) {
|
||||
return sqlRiskResult;
|
||||
}
|
||||
|
||||
switch (toolName) {
|
||||
case 'inspect_current_connection':
|
||||
return {
|
||||
@@ -252,26 +264,6 @@ export async function executeSnapshotInspectionToolCall(
|
||||
})),
|
||||
success: true,
|
||||
};
|
||||
case 'inspect_sql_risk': {
|
||||
const candidateSql = String(args.sql || '').trim();
|
||||
const activeTab = tabs.find((tab) => tab.id === activeTabId);
|
||||
const activeTabSql = activeTab?.type === 'query' ? String(activeTab.query || '').trim() : '';
|
||||
const sqlForCheck = candidateSql || activeTabSql;
|
||||
const safetyCheck = sqlForCheck && typeof runtime?.checkSQL === 'function'
|
||||
? await runtime.checkSQL(sqlForCheck)
|
||||
: undefined;
|
||||
return {
|
||||
content: JSON.stringify(buildSqlRiskSnapshot({
|
||||
sql: candidateSql,
|
||||
previewCharLimit: args.previewCharLimit,
|
||||
tabs,
|
||||
activeTabId,
|
||||
connections,
|
||||
safetyCheck,
|
||||
})),
|
||||
success: true,
|
||||
};
|
||||
}
|
||||
case 'inspect_app_logs': {
|
||||
const readResult = typeof runtime?.readAppLogTail === 'function'
|
||||
? await runtime.readAppLogTail(Number(args.lineLimit) || 80, String(args.keyword || ''))
|
||||
|
||||
Reference in New Issue
Block a user