🐛 fix(ai-safety): 修正完全模式执行口径与本地工具失败判定

- 修正完全模式下 DML 与过程调用的安全提示和限制说明
- 区分连接探针失败与可恢复 SQL 执行错误,避免数据探针被误终止
- 修复本地 execute_sql 写语句结果返回 affectedRows
- 补充 AI 安全、本地工具执行与 SQL 限制回归测试
This commit is contained in:
Syngnat
2026-06-17 09:49:59 +08:00
parent 7ff3e00759
commit 3e140c1bc6
14 changed files with 254 additions and 24 deletions

View File

@@ -7,7 +7,7 @@ import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
const SAFETY_OPTIONS: { label: string; value: AISafetyLevel; desc: string; color: string; icon: string }[] = [
{ label: '只读模式', value: 'readonly', desc: 'AI 仅可执行 SELECT 等查询操作,最安全', color: '#22c55e', icon: '🔒' },
{ label: '读写模式', value: 'readwrite', desc: 'AI 可执行 INSERT/UPDATE/DELETE危险操作需二次确认', color: '#f59e0b', icon: '⚠️' },
{ label: '完全模式', value: 'full', desc: 'AI 可执行所有操作(含 DDL),高危操作自动告警', color: '#ef4444', icon: '🔓' },
{ label: '完全模式', value: 'full', desc: 'AI 可执行所有操作(含 DDL/过程调用),高危或未识别操作会告警', color: '#ef4444', icon: '🔓' },
];
interface AISettingsSafetySectionProps {

View File

@@ -23,6 +23,7 @@ interface ExecuteDatabaseToolCallOptions {
interface ToolExecutionResult {
content: string;
success: boolean;
countsAsProbeFailure?: boolean;
}
const findConnection = (connections: SavedConnection[], connectionId: string) =>
@@ -39,12 +40,45 @@ const resolveConnectionOrFailure = (
failure: {
content: 'Connection not found',
success: false,
countsAsProbeFailure: true,
},
};
}
return { connection };
};
const CONNECTION_ERROR_KEYWORDS = [
'connection not found',
'invalid connection',
'bad connection',
'driver: bad connection',
'connection refused',
'connection reset',
'closed network connection',
'server has gone away',
'broken pipe',
'no such host',
'network is unreachable',
'context deadline exceeded',
'i/o timeout',
'timeout',
'eof',
'连接失败',
'连接异常',
'连接超时',
'连接已关闭',
'网络超时',
'网络异常',
];
const countsAsProbeFailure = (message: unknown): boolean => {
const text = String(message || '').trim().toLowerCase();
if (!text) {
return true;
}
return CONNECTION_ERROR_KEYWORDS.some((keyword) => text.includes(keyword));
};
export async function executeDatabaseToolCall(
options: ExecuteDatabaseToolCallOptions,
): Promise<ToolExecutionResult | null> {
@@ -75,9 +109,14 @@ export async function executeDatabaseToolCall(
}
return { content: JSON.stringify(databaseNames), success: true };
}
return { content: result?.message || 'Failed to fetch DBs', success: false };
return {
content: result?.message || 'Failed to fetch DBs',
success: false,
countsAsProbeFailure: countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `获取数据库列表失败: ${error?.message || error}`, success: false };
const message = `获取数据库列表失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'get_tables': {
@@ -99,9 +138,14 @@ export async function executeDatabaseToolCall(
});
return { content: JSON.stringify(tableNames), success: true };
}
return { content: result?.message || 'Failed to fetch Tables', success: false };
return {
content: result?.message || 'Failed to fetch Tables',
success: false,
countsAsProbeFailure: countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `获取表列表失败: ${error?.message || error}`, success: false };
const message = `获取表列表失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'get_all_columns': {
@@ -125,9 +169,14 @@ export async function executeDatabaseToolCall(
success: true,
};
}
return { content: result?.message || 'Failed to fetch all columns', success: false };
return {
content: result?.message || 'Failed to fetch all columns',
success: false,
countsAsProbeFailure: countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `获取全库字段摘要失败: ${error?.message || error}`, success: false };
const message = `获取全库字段摘要失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'get_columns': {
@@ -145,9 +194,14 @@ export async function executeDatabaseToolCall(
success: true,
};
}
return { content: result?.message || 'Failed to fetch columns', success: false };
return {
content: result?.message || 'Failed to fetch columns',
success: false,
countsAsProbeFailure: countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `获取字段列表失败: ${error?.message || error}`, success: false };
const message = `获取字段列表失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'get_indexes': {
@@ -160,9 +214,11 @@ export async function executeDatabaseToolCall(
return {
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch indexes'),
success: !!result?.success && Array.isArray(result.data),
countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `获取索引定义失败: ${error?.message || error}`, success: false };
const message = `获取索引定义失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'get_foreign_keys': {
@@ -175,9 +231,11 @@ export async function executeDatabaseToolCall(
return {
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch foreign keys'),
success: !!result?.success && Array.isArray(result.data),
countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `获取外键关系失败: ${error?.message || error}`, success: false };
const message = `获取外键关系失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'get_triggers': {
@@ -190,9 +248,11 @@ export async function executeDatabaseToolCall(
return {
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch triggers'),
success: !!result?.success && Array.isArray(result.data),
countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `获取触发器定义失败: ${error?.message || error}`, success: false };
const message = `获取触发器定义失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'get_table_ddl': {
@@ -207,9 +267,14 @@ export async function executeDatabaseToolCall(
fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable),
fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable),
});
return { content: result.content, success: result.success };
return {
content: result.content,
success: result.success,
countsAsProbeFailure: result.success ? false : countsAsProbeFailure(result.content),
};
} catch (error: any) {
return { content: `获取建表语句失败: ${error?.message || error}`, success: false };
const message = `获取建表语句失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'inspect_table_bundle': {
@@ -245,9 +310,14 @@ export async function executeDatabaseToolCall(
success: true,
};
}
return { content: result?.message || 'Failed to preview table rows', success: false };
return {
content: result?.message || 'Failed to preview table rows',
success: false,
countsAsProbeFailure: countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `预览表样例数据失败: ${error?.message || error}`, success: false };
const message = `预览表样例数据失败: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
case 'execute_sql': {
@@ -262,6 +332,7 @@ export async function executeDatabaseToolCall(
return {
content: `安全策略拦截:当前安全级别不允许执行 ${checkResult.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`,
success: false,
countsAsProbeFailure: false,
};
}
}
@@ -270,18 +341,31 @@ export async function executeDatabaseToolCall(
safeSql,
50,
resolved.connection.config?.driver || '',
{ oceanBaseProtocol: resolved.connection.config?.oceanBaseProtocol },
);
const result = await runtime.query(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, finalSql);
if (result?.success) {
const affectedRows = Number((result.data as Record<string, unknown> | null | undefined)?.affectedRows);
if (Number.isFinite(affectedRows)) {
return {
content: JSON.stringify({ affectedRows }),
success: true,
};
}
const rows = Array.isArray(result.data) ? result.data : [];
return {
content: JSON.stringify({ rowCount: rows.length, data: rows.slice(0, 50) }),
success: true,
};
}
return { content: result?.message || 'SQL 执行失败', success: false };
return {
content: result?.message || 'SQL 执行失败',
success: false,
countsAsProbeFailure: countsAsProbeFailure(result?.message),
};
} catch (error: any) {
return { content: `SQL 执行异常: ${error?.message || error}`, success: false };
const message = `SQL 执行异常: ${error?.message || error}`;
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
}
}
default:

View File

@@ -201,6 +201,86 @@ describe('aiLocalToolExecutor', () => {
expect(query).not.toHaveBeenCalled();
});
it('treats OceanBase Oracle SQL execution errors as recoverable and uses Oracle readonly preview SQL', async () => {
const query = vi.fn().mockResolvedValue({
success: false,
message: "oceanbase: error 900 (42000): ORA-00900 near '50 OFFSET 0'",
});
const result = await executeLocalAIToolCall({
toolCall: buildToolCall('execute_sql', {
connectionId: 'conn-1',
dbName: 'SYS',
sql: 'SELECT 1 FROM DUAL',
}),
connections: [{
...buildConnection(),
config: {
type: 'oceanbase',
host: '127.0.0.1',
port: 2881,
user: 'sys',
driver: 'oceanbase',
oceanBaseProtocol: 'oracle',
},
}],
mcpTools: [],
toolContextMap: new Map(),
runtime: {
getDatabases: vi.fn(),
getTables: vi.fn(),
getColumns: vi.fn(),
getIndexes: vi.fn(),
getForeignKeys: vi.fn(),
getTriggers: vi.fn(),
showCreateTable: vi.fn(),
query,
checkSQL: vi.fn().mockResolvedValue({ allowed: true, operationType: 'query' }),
},
});
expect(result.success).toBe(false);
expect(result.countsAsProbeFailure).toBe(false);
expect(query).toHaveBeenCalledWith(
expect.anything(),
'SYS',
'SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50',
);
});
it('returns affectedRows for execute_sql write statements instead of pretending rowCount is zero', async () => {
const query = vi.fn().mockResolvedValue({
success: true,
data: {
affectedRows: 100000,
},
});
const result = await executeLocalAIToolCall({
toolCall: buildToolCall('execute_sql', {
connectionId: 'conn-1',
dbName: 'crm',
sql: 'INSERT INTO orders_archive SELECT * FROM orders',
}),
connections: [buildConnection()],
mcpTools: [],
toolContextMap: new Map(),
runtime: {
getDatabases: vi.fn(),
getTables: vi.fn(),
getColumns: vi.fn(),
getIndexes: vi.fn(),
getForeignKeys: vi.fn(),
getTriggers: vi.fn(),
showCreateTable: vi.fn(),
query,
checkSQL: vi.fn().mockResolvedValue({ allowed: true, operationType: 'INSERT' }),
},
});
expect(result.success).toBe(true);
expect(query).toHaveBeenCalledWith(expect.anything(), 'crm', 'INSERT INTO orders_archive SELECT * FROM orders');
expect(JSON.parse(result.content)).toEqual({ affectedRows: 100000 });
});
it('returns a cross-table column summary for get_all_columns', async () => {
const result = await executeLocalAIToolCall({
toolCall: buildToolCall('get_all_columns', {

View File

@@ -48,6 +48,7 @@ export interface ExecuteLocalAIToolCallResult {
content: string;
success: boolean;
toolName: string;
countsAsProbeFailure?: boolean;
}
const buildToolName = (toolCall: AIToolCall, descriptor?: AIMCPToolDescriptor) =>
@@ -106,6 +107,7 @@ export async function executeLocalAIToolCall({
content: snapshotInspectionResult.content,
success: snapshotInspectionResult.success,
toolName: buildToolName(toolCall, descriptor),
countsAsProbeFailure: snapshotInspectionResult.countsAsProbeFailure,
};
}
@@ -121,6 +123,7 @@ export async function executeLocalAIToolCall({
content: databaseToolResult.content,
success: databaseToolResult.success,
toolName: buildToolName(toolCall, descriptor),
countsAsProbeFailure: databaseToolResult.countsAsProbeFailure,
};
}

View File

@@ -62,4 +62,17 @@ describe('buildAISafetySnapshot', () => {
expect(snapshot.effectiveRestrictions.join('\n')).toContain('allowMutating=true');
expect(snapshot.effectiveRestrictions.join('\n')).toContain('当前 JVM 诊断明确禁止 mutating 命令');
});
it('describes full safety mode as allowing other statements with confirmation', () => {
const snapshot = buildAISafetySnapshot({
safetyLevel: 'full',
connections: [],
});
expect(snapshot.safetyLevel).toBe('full');
expect(snapshot.permissionMatrix.allowDML).toBe(true);
expect(snapshot.permissionMatrix.allowDDL).toBe(true);
expect(snapshot.sqlRuleText).toContain('允许所有 SQL 操作');
expect(snapshot.effectiveRestrictions.join('\n')).toContain('高风险或未识别语句仍会要求确认');
});
});

View File

@@ -9,7 +9,7 @@ const SAFETY_LEVEL_LABELS: Record<string, string> = {
const SAFETY_RULE_TEXTS: Record<string, string> = {
readonly: '只读模式仅允许查询语句。',
readwrite: '读写模式允许查询和 DMLDDL 仍会被阻止。',
full: '完全开放模式允许查询、DML 和 DDL未识别操作仍会被阻止。',
full: '完全开放模式允许所有 SQL 操作;高风险或未识别语句仍会要求确认。',
};
const normalizeSafetyLevel = (value: AISafetyLevel | string | undefined): string => {

View File

@@ -46,4 +46,5 @@ export interface AISnapshotInspectionRuntime {
export interface SnapshotInspectionResult {
content: string;
success: boolean;
countsAsProbeFailure?: boolean;
}

View File

@@ -211,6 +211,7 @@ const HighlightedCodeBlock: React.FC<HighlightedCodeBlockProps> = ({
displayText,
50,
activeConnectionConfig?.driver || '',
{ oceanBaseProtocol: activeConnectionConfig?.oceanBaseProtocol },
);
const response = await DBQuery(activeConnectionConfig, activeDbName || '', previewSql);
if (response.success && Array.isArray(response.data)) {

View File

@@ -11,6 +11,7 @@ const executeLocalAIToolCallMock = vi.hoisted(() => vi.fn(async ({ toolCall }: {
content: `result:${toolCall.function.name}`,
success: true,
toolName: toolCall.function.name,
countsAsProbeFailure: true,
})));
vi.mock('./aiChatPayloadDispatch', () => ({
@@ -191,4 +192,35 @@ describe('useAIChatLocalTools', () => {
renderer?.unmount();
});
});
it('does not auto-stop the probe after three recoverable SQL execution errors', async () => {
executeLocalAIToolCallMock.mockResolvedValue({
content: "oceanbase: error 900 (42000): ORA-00900 near '50 OFFSET 0'",
success: false,
toolName: 'execute_sql',
countsAsProbeFailure: false,
});
let renderer: ReactTestRenderer | undefined;
await act(async () => {
renderer = create(<LocalToolsHarness />);
});
expect(latestHook).toBeDefined();
for (let i = 0; i < 3; i += 1) {
const run = latestHook!.executeLocalTools([buildToolCall('execute_sql')], 'assistant-1');
await act(async () => {
await vi.advanceTimersByTimeAsync(150);
await run;
});
}
const messages = useStore.getState().aiChatHistory[SESSION_ID] || [];
expect(messages.some((message) => message.content.includes('探针连续 3 轮执行失败'))).toBe(false);
expect(dispatchAIChatPayloadMock).toHaveBeenCalledTimes(3);
await act(async () => {
renderer?.unmount();
});
});
});

View File

@@ -18,6 +18,7 @@ import { dispatchAIChatPayload } from './aiChatPayloadDispatch';
import {
buildToolResultMessage,
executeLocalAIToolCall,
type ExecuteLocalAIToolCallResult,
type AIToolContextEntry,
} from './aiLocalToolExecutor';
@@ -96,6 +97,7 @@ export const useAIChatLocalTools = ({
}
const results: AIChatMessage[] = [];
const executions: ExecuteLocalAIToolCallResult[] = [];
const currentConnections = useStore.getState().connections;
for (const toolCall of toolCalls) {
const currentState = useStore.getState();
@@ -119,6 +121,7 @@ export const useAIChatLocalTools = ({
userPromptSettings,
dynamicModels,
});
executions.push(execution);
const toolResultMsg: AIChatMessage = buildToolResultMessage({
id: nextMessageId(),
timestamp: Date.now(),
@@ -130,8 +133,9 @@ export const useAIChatLocalTools = ({
await new Promise((resolve) => setTimeout(resolve, 150));
}
const anySuccess = results.some((message) => message.success === true);
if (anySuccess) {
const roundCountsAsFailure = executions.length > 0
&& executions.every((execution) => execution.success !== true && execution.countsAsProbeFailure !== false);
if (!roundCountsAsFailure) {
toolCallRoundRef.current = 0;
} else {
toolCallRoundRef.current += 1;

View File

@@ -22,6 +22,11 @@ describe('buildAIReadonlyPreviewSQL', () => {
.toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50');
});
it('treats OceanBase Oracle as Oracle dialect when building readonly preview SQL', () => {
expect(buildAIReadonlyPreviewSQL('oceanbase', 'SELECT 1 FROM DUAL;', 50, 'oceanbase', { oceanBaseProtocol: 'oracle' }))
.toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50');
});
it('keeps MySQL-family SQL on LIMIT syntax', () => {
expect(buildAIReadonlyPreviewSQL('mysql', 'SELECT * FROM users', 50))
.toBe('SELECT * FROM users LIMIT 50 OFFSET 0');

View File

@@ -20,10 +20,16 @@ const hasExistingRowLimit = (dialect: string, sql: string): boolean => {
return (dialect === 'oracle' || dialect === 'dameng') && /\brownum\b/.test(text);
};
export const buildAIReadonlyPreviewSQL = (dbType: string, sql: string, limit = 50, driver = ''): string => {
export const buildAIReadonlyPreviewSQL = (
dbType: string,
sql: string,
limit = 50,
driver = '',
options?: { oceanBaseProtocol?: unknown },
): string => {
const baseSQL = trimSQLStatement(sql);
const safeLimit = Math.max(0, Math.floor(Number(limit) || 0));
const dialect = resolveSqlDialect(dbType, driver);
const dialect = resolveSqlDialect(dbType, driver, options);
if (!baseSQL || safeLimit <= 0 || !isAIReadonlySQL(baseSQL) || hasExistingRowLimit(dialect, baseSQL)) {
return baseSQL;
}

View File

@@ -124,6 +124,7 @@ func TestGuard_Full(t *testing.T) {
{"INSERT INTO t VALUES (1)", true},
{"DROP TABLE t", true},
{"CREATE TABLE t (id INT)", true},
{"CALL bulk_insert_users(100000)", true},
}
for _, tt := range tests {
result := g.Check(tt.sql)

View File

@@ -51,7 +51,7 @@ func (g *Guard) isAllowed(opType ai.SQLOperationType) bool {
case ai.PermissionReadWrite:
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
case ai.PermissionFull:
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
return true
default:
return opType == ai.SQLOpQuery
}