From 3e140c1bc6cfec81f39d1f37e4e3adbe992e5024 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 17 Jun 2026 09:49:59 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(ai-safety):=20=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3=E5=AE=8C=E5=85=A8=E6=A8=A1=E5=BC=8F=E6=89=A7=E8=A1=8C?= =?UTF-8?q?=E5=8F=A3=E5=BE=84=E4=B8=8E=E6=9C=AC=E5=9C=B0=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E5=A4=B1=E8=B4=A5=E5=88=A4=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修正完全模式下 DML 与过程调用的安全提示和限制说明 - 区分连接探针失败与可恢复 SQL 执行错误,避免数据探针被误终止 - 修复本地 execute_sql 写语句结果返回 affectedRows - 补充 AI 安全、本地工具执行与 SQL 限制回归测试 --- .../components/ai/AISettingsSafetySection.tsx | 2 +- .../components/ai/aiDatabaseToolExecutor.ts | 118 +++++++++++++++--- .../components/ai/aiLocalToolExecutor.test.ts | 80 ++++++++++++ .../src/components/ai/aiLocalToolExecutor.ts | 3 + .../components/ai/aiSafetyInsights.test.ts | 13 ++ .../src/components/ai/aiSafetyInsights.ts | 2 +- .../ai/aiSnapshotInspectionToolTypes.ts | 1 + .../ai/messageBubble/AIMessageCodeBlock.tsx | 1 + .../ai/useAIChatLocalTools.test.tsx | 32 +++++ .../src/components/ai/useAIChatLocalTools.ts | 8 +- frontend/src/utils/aiSqlLimit.test.ts | 5 + frontend/src/utils/aiSqlLimit.ts | 10 +- internal/ai/safety/classifier_test.go | 1 + internal/ai/safety/guard.go | 2 +- 14 files changed, 254 insertions(+), 24 deletions(-) diff --git a/frontend/src/components/ai/AISettingsSafetySection.tsx b/frontend/src/components/ai/AISettingsSafetySection.tsx index 861f188..9da29fe 100644 --- a/frontend/src/components/ai/AISettingsSafetySection.tsx +++ b/frontend/src/components/ai/AISettingsSafetySection.tsx @@ -7,7 +7,7 @@ import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme'; const SAFETY_OPTIONS: { label: string; value: AISafetyLevel; desc: string; color: string; icon: string }[] = [ { label: '只读模式', value: 'readonly', desc: 'AI 仅可执行 SELECT 等查询操作,最安全', color: '#22c55e', icon: '🔒' }, { label: '读写模式', value: 'readwrite', desc: 'AI 可执行 INSERT/UPDATE/DELETE,危险操作需二次确认', color: '#f59e0b', icon: '⚠️' }, - { label: '完全模式', value: 'full', desc: 'AI 可执行所有操作(含 DDL),高危操作自动告警', color: '#ef4444', icon: '🔓' }, + { label: '完全模式', value: 'full', desc: 'AI 可执行所有操作(含 DDL/过程调用),高危或未识别操作会告警', color: '#ef4444', icon: '🔓' }, ]; interface AISettingsSafetySectionProps { diff --git a/frontend/src/components/ai/aiDatabaseToolExecutor.ts b/frontend/src/components/ai/aiDatabaseToolExecutor.ts index b6dbbd5..492864b 100644 --- a/frontend/src/components/ai/aiDatabaseToolExecutor.ts +++ b/frontend/src/components/ai/aiDatabaseToolExecutor.ts @@ -23,6 +23,7 @@ interface ExecuteDatabaseToolCallOptions { interface ToolExecutionResult { content: string; success: boolean; + countsAsProbeFailure?: boolean; } const findConnection = (connections: SavedConnection[], connectionId: string) => @@ -39,12 +40,45 @@ const resolveConnectionOrFailure = ( failure: { content: 'Connection not found', success: false, + countsAsProbeFailure: true, }, }; } return { connection }; }; +const CONNECTION_ERROR_KEYWORDS = [ + 'connection not found', + 'invalid connection', + 'bad connection', + 'driver: bad connection', + 'connection refused', + 'connection reset', + 'closed network connection', + 'server has gone away', + 'broken pipe', + 'no such host', + 'network is unreachable', + 'context deadline exceeded', + 'i/o timeout', + 'timeout', + 'eof', + '连接失败', + '连接异常', + '连接超时', + '连接已关闭', + '网络超时', + '网络异常', +]; + +const countsAsProbeFailure = (message: unknown): boolean => { + const text = String(message || '').trim().toLowerCase(); + if (!text) { + return true; + } + return CONNECTION_ERROR_KEYWORDS.some((keyword) => text.includes(keyword)); +}; + export async function executeDatabaseToolCall( options: ExecuteDatabaseToolCallOptions, ): Promise { @@ -75,9 +109,14 @@ export async function executeDatabaseToolCall( } return { content: JSON.stringify(databaseNames), success: true }; } - return { content: result?.message || 'Failed to fetch DBs', success: false }; + return { + content: result?.message || 'Failed to fetch DBs', + success: false, + countsAsProbeFailure: countsAsProbeFailure(result?.message), + }; } catch (error: any) { - return { content: `获取数据库列表失败: ${error?.message || error}`, success: false }; + const message = `获取数据库列表失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'get_tables': { @@ -99,9 +138,14 @@ export async function executeDatabaseToolCall( }); return { content: JSON.stringify(tableNames), success: true }; } - return { content: result?.message || 'Failed to fetch Tables', success: false }; + return { + content: result?.message || 'Failed to fetch Tables', + success: false, + countsAsProbeFailure: countsAsProbeFailure(result?.message), + }; } catch (error: any) { - return { content: `获取表列表失败: ${error?.message || error}`, success: false }; + const message = `获取表列表失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'get_all_columns': { @@ -125,9 +169,14 @@ export async function executeDatabaseToolCall( success: true, }; } - return { content: result?.message || 'Failed to fetch all columns', success: false }; + return { + content: result?.message || 'Failed to fetch all columns', + success: false, + countsAsProbeFailure: countsAsProbeFailure(result?.message), + }; } catch (error: any) { - return { content: `获取全库字段摘要失败: ${error?.message || error}`, success: false }; + const message = `获取全库字段摘要失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'get_columns': { @@ -145,9 +194,14 @@ export async function executeDatabaseToolCall( success: true, }; } - return { content: result?.message || 'Failed to fetch columns', success: false }; + return { + content: result?.message || 'Failed to fetch columns', + success: false, + countsAsProbeFailure: countsAsProbeFailure(result?.message), + }; } catch (error: any) { - return { content: `获取字段列表失败: ${error?.message || error}`, success: false }; + const message = `获取字段列表失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'get_indexes': { @@ -160,9 +214,11 @@ export async function executeDatabaseToolCall( return { content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch indexes'), success: !!result?.success && Array.isArray(result.data), + countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message), }; } catch (error: any) { - return { content: `获取索引定义失败: ${error?.message || error}`, success: false }; + const message = `获取索引定义失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'get_foreign_keys': { @@ -175,9 +231,11 @@ export async function executeDatabaseToolCall( return { content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch foreign keys'), success: !!result?.success && Array.isArray(result.data), + countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message), }; } catch (error: any) { - return { content: `获取外键关系失败: ${error?.message || error}`, success: false }; + const message = `获取外键关系失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'get_triggers': { @@ -190,9 +248,11 @@ export async function executeDatabaseToolCall( return { content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch triggers'), success: !!result?.success && Array.isArray(result.data), + countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message), }; } catch (error: any) { - return { content: `获取触发器定义失败: ${error?.message || error}`, success: false }; + const message = `获取触发器定义失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'get_table_ddl': { @@ -207,9 +267,14 @@ export async function executeDatabaseToolCall( fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable), fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable), }); - return { content: result.content, success: result.success }; + return { + content: result.content, + success: result.success, + countsAsProbeFailure: result.success ? false : countsAsProbeFailure(result.content), + }; } catch (error: any) { - return { content: `获取建表语句失败: ${error?.message || error}`, success: false }; + const message = `获取建表语句失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'inspect_table_bundle': { @@ -245,9 +310,14 @@ export async function executeDatabaseToolCall( success: true, }; } - return { content: result?.message || 'Failed to preview table rows', success: false }; + return { + content: result?.message || 'Failed to preview table rows', + success: false, + countsAsProbeFailure: countsAsProbeFailure(result?.message), + }; } catch (error: any) { - return { content: `预览表样例数据失败: ${error?.message || error}`, success: false }; + const message = `预览表样例数据失败: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } case 'execute_sql': { @@ -262,6 +332,7 @@ export async function executeDatabaseToolCall( return { content: `安全策略拦截:当前安全级别不允许执行 ${checkResult.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`, success: false, + countsAsProbeFailure: false, }; } } @@ -270,18 +341,31 @@ export async function executeDatabaseToolCall( safeSql, 50, resolved.connection.config?.driver || '', + { oceanBaseProtocol: resolved.connection.config?.oceanBaseProtocol }, ); const result = await runtime.query(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, finalSql); if (result?.success) { + const affectedRows = Number((result.data as Record | null | undefined)?.affectedRows); + if (Number.isFinite(affectedRows)) { + return { + content: JSON.stringify({ affectedRows }), + success: true, + }; + } const rows = Array.isArray(result.data) ? result.data : []; return { content: JSON.stringify({ rowCount: rows.length, data: rows.slice(0, 50) }), success: true, }; } - return { content: result?.message || 'SQL 执行失败', success: false }; + return { + content: result?.message || 'SQL 执行失败', + success: false, + countsAsProbeFailure: countsAsProbeFailure(result?.message), + }; } catch (error: any) { - return { content: `SQL 执行异常: ${error?.message || error}`, success: false }; + const message = `SQL 执行异常: ${error?.message || error}`; + return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) }; } } default: diff --git a/frontend/src/components/ai/aiLocalToolExecutor.test.ts b/frontend/src/components/ai/aiLocalToolExecutor.test.ts index ef39f0b..1f50ca5 100644 --- a/frontend/src/components/ai/aiLocalToolExecutor.test.ts +++ b/frontend/src/components/ai/aiLocalToolExecutor.test.ts @@ -201,6 +201,86 @@ describe('aiLocalToolExecutor', () => { expect(query).not.toHaveBeenCalled(); }); + it('treats OceanBase Oracle SQL execution errors as recoverable and uses Oracle readonly preview SQL', async () => { + const query = vi.fn().mockResolvedValue({ + success: false, + message: "oceanbase: error 900 (42000): ORA-00900 near '50 OFFSET 0'", + }); + const result = await executeLocalAIToolCall({ + toolCall: buildToolCall('execute_sql', { + connectionId: 'conn-1', + dbName: 'SYS', + sql: 'SELECT 1 FROM DUAL', + }), + connections: [{ + ...buildConnection(), + config: { + type: 'oceanbase', + host: '127.0.0.1', + port: 2881, + user: 'sys', + driver: 'oceanbase', + oceanBaseProtocol: 'oracle', + }, + }], + mcpTools: [], + toolContextMap: new Map(), + runtime: { + getDatabases: vi.fn(), + getTables: vi.fn(), + getColumns: vi.fn(), + getIndexes: vi.fn(), + getForeignKeys: vi.fn(), + getTriggers: vi.fn(), + showCreateTable: vi.fn(), + query, + checkSQL: vi.fn().mockResolvedValue({ allowed: true, operationType: 'query' }), + }, + }); + + expect(result.success).toBe(false); + expect(result.countsAsProbeFailure).toBe(false); + expect(query).toHaveBeenCalledWith( + expect.anything(), + 'SYS', + 'SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50', + ); + }); + + it('returns affectedRows for execute_sql write statements instead of pretending rowCount is zero', async () => { + const query = vi.fn().mockResolvedValue({ + success: true, + data: { + affectedRows: 100000, + }, + }); + const result = await executeLocalAIToolCall({ + toolCall: buildToolCall('execute_sql', { + connectionId: 'conn-1', + dbName: 'crm', + sql: 'INSERT INTO orders_archive SELECT * FROM orders', + }), + connections: [buildConnection()], + mcpTools: [], + toolContextMap: new Map(), + runtime: { + getDatabases: vi.fn(), + getTables: vi.fn(), + getColumns: vi.fn(), + getIndexes: vi.fn(), + getForeignKeys: vi.fn(), + getTriggers: vi.fn(), + showCreateTable: vi.fn(), + query, + checkSQL: vi.fn().mockResolvedValue({ allowed: true, operationType: 'INSERT' }), + }, + }); + + expect(result.success).toBe(true); + expect(query).toHaveBeenCalledWith(expect.anything(), 'crm', 'INSERT INTO orders_archive SELECT * FROM orders'); + expect(JSON.parse(result.content)).toEqual({ affectedRows: 100000 }); + }); + it('returns a cross-table column summary for get_all_columns', async () => { const result = await executeLocalAIToolCall({ toolCall: buildToolCall('get_all_columns', { diff --git a/frontend/src/components/ai/aiLocalToolExecutor.ts b/frontend/src/components/ai/aiLocalToolExecutor.ts index a3a3dba..131cdb4 100644 --- a/frontend/src/components/ai/aiLocalToolExecutor.ts +++ b/frontend/src/components/ai/aiLocalToolExecutor.ts @@ -48,6 +48,7 @@ export interface ExecuteLocalAIToolCallResult { content: string; success: boolean; toolName: string; + countsAsProbeFailure?: boolean; } const buildToolName = (toolCall: AIToolCall, descriptor?: AIMCPToolDescriptor) => @@ -106,6 +107,7 @@ export async function executeLocalAIToolCall({ content: snapshotInspectionResult.content, success: snapshotInspectionResult.success, toolName: buildToolName(toolCall, descriptor), + countsAsProbeFailure: snapshotInspectionResult.countsAsProbeFailure, }; } @@ -121,6 +123,7 @@ export async function executeLocalAIToolCall({ content: databaseToolResult.content, success: databaseToolResult.success, toolName: buildToolName(toolCall, descriptor), + countsAsProbeFailure: databaseToolResult.countsAsProbeFailure, }; } diff --git a/frontend/src/components/ai/aiSafetyInsights.test.ts b/frontend/src/components/ai/aiSafetyInsights.test.ts index b437943..d954c07 100644 --- a/frontend/src/components/ai/aiSafetyInsights.test.ts +++ b/frontend/src/components/ai/aiSafetyInsights.test.ts @@ -62,4 +62,17 @@ describe('buildAISafetySnapshot', () => { expect(snapshot.effectiveRestrictions.join('\n')).toContain('allowMutating=true'); expect(snapshot.effectiveRestrictions.join('\n')).toContain('当前 JVM 诊断明确禁止 mutating 命令'); }); + + it('describes full safety mode as allowing other statements with confirmation', () => { + const snapshot = buildAISafetySnapshot({ + safetyLevel: 'full', + connections: [], + }); + + expect(snapshot.safetyLevel).toBe('full'); + expect(snapshot.permissionMatrix.allowDML).toBe(true); + expect(snapshot.permissionMatrix.allowDDL).toBe(true); + expect(snapshot.sqlRuleText).toContain('允许所有 SQL 操作'); + expect(snapshot.effectiveRestrictions.join('\n')).toContain('高风险或未识别语句仍会要求确认'); + }); }); diff --git a/frontend/src/components/ai/aiSafetyInsights.ts b/frontend/src/components/ai/aiSafetyInsights.ts index b9d5390..6f1a958 100644 --- a/frontend/src/components/ai/aiSafetyInsights.ts +++ b/frontend/src/components/ai/aiSafetyInsights.ts @@ -9,7 +9,7 @@ const SAFETY_LEVEL_LABELS: Record = { const SAFETY_RULE_TEXTS: Record = { readonly: '只读模式仅允许查询语句。', readwrite: '读写模式允许查询和 DML,DDL 仍会被阻止。', - full: '完全开放模式允许查询、DML 和 DDL;未识别操作仍会被阻止。', + full: '完全开放模式允许所有 SQL 操作;高风险或未识别语句仍会要求确认。', }; const normalizeSafetyLevel = (value: AISafetyLevel | string | undefined): string => { diff --git a/frontend/src/components/ai/aiSnapshotInspectionToolTypes.ts b/frontend/src/components/ai/aiSnapshotInspectionToolTypes.ts index b04c3c3..15577e4 100644 --- a/frontend/src/components/ai/aiSnapshotInspectionToolTypes.ts +++ b/frontend/src/components/ai/aiSnapshotInspectionToolTypes.ts @@ -46,4 +46,5 @@ export interface AISnapshotInspectionRuntime { export interface SnapshotInspectionResult { content: string; success: boolean; + countsAsProbeFailure?: boolean; } diff --git a/frontend/src/components/ai/messageBubble/AIMessageCodeBlock.tsx b/frontend/src/components/ai/messageBubble/AIMessageCodeBlock.tsx index 8ba0f3d..19e475e 100644 --- a/frontend/src/components/ai/messageBubble/AIMessageCodeBlock.tsx +++ b/frontend/src/components/ai/messageBubble/AIMessageCodeBlock.tsx @@ -211,6 +211,7 @@ const HighlightedCodeBlock: React.FC = ({ displayText, 50, activeConnectionConfig?.driver || '', + { oceanBaseProtocol: activeConnectionConfig?.oceanBaseProtocol }, ); const response = await DBQuery(activeConnectionConfig, activeDbName || '', previewSql); if (response.success && Array.isArray(response.data)) { diff --git a/frontend/src/components/ai/useAIChatLocalTools.test.tsx b/frontend/src/components/ai/useAIChatLocalTools.test.tsx index ef69374..627977b 100644 --- a/frontend/src/components/ai/useAIChatLocalTools.test.tsx +++ b/frontend/src/components/ai/useAIChatLocalTools.test.tsx @@ -11,6 +11,7 @@ const executeLocalAIToolCallMock = vi.hoisted(() => vi.fn(async ({ toolCall }: { content: `result:${toolCall.function.name}`, success: true, toolName: toolCall.function.name, + countsAsProbeFailure: true, }))); vi.mock('./aiChatPayloadDispatch', () => ({ @@ -191,4 +192,35 @@ describe('useAIChatLocalTools', () => { renderer?.unmount(); }); }); + + it('does not auto-stop the probe after three recoverable SQL execution errors', async () => { + executeLocalAIToolCallMock.mockResolvedValue({ + content: "oceanbase: error 900 (42000): ORA-00900 near '50 OFFSET 0'", + success: false, + toolName: 'execute_sql', + countsAsProbeFailure: false, + }); + + let renderer: ReactTestRenderer | undefined; + await act(async () => { + renderer = create(); + }); + + expect(latestHook).toBeDefined(); + for (let i = 0; i < 3; i += 1) { + const run = latestHook!.executeLocalTools([buildToolCall('execute_sql')], 'assistant-1'); + await act(async () => { + await vi.advanceTimersByTimeAsync(150); + await run; + }); + } + + const messages = useStore.getState().aiChatHistory[SESSION_ID] || []; + expect(messages.some((message) => message.content.includes('探针连续 3 轮执行失败'))).toBe(false); + expect(dispatchAIChatPayloadMock).toHaveBeenCalledTimes(3); + + await act(async () => { + renderer?.unmount(); + }); + }); }); diff --git a/frontend/src/components/ai/useAIChatLocalTools.ts b/frontend/src/components/ai/useAIChatLocalTools.ts index aaa22b7..8f2caf2 100644 --- a/frontend/src/components/ai/useAIChatLocalTools.ts +++ b/frontend/src/components/ai/useAIChatLocalTools.ts @@ -18,6 +18,7 @@ import { dispatchAIChatPayload } from './aiChatPayloadDispatch'; import { buildToolResultMessage, executeLocalAIToolCall, + type ExecuteLocalAIToolCallResult, type AIToolContextEntry, } from './aiLocalToolExecutor'; @@ -96,6 +97,7 @@ export const useAIChatLocalTools = ({ } const results: AIChatMessage[] = []; + const executions: ExecuteLocalAIToolCallResult[] = []; const currentConnections = useStore.getState().connections; for (const toolCall of toolCalls) { const currentState = useStore.getState(); @@ -119,6 +121,7 @@ export const useAIChatLocalTools = ({ userPromptSettings, dynamicModels, }); + executions.push(execution); const toolResultMsg: AIChatMessage = buildToolResultMessage({ id: nextMessageId(), timestamp: Date.now(), @@ -130,8 +133,9 @@ export const useAIChatLocalTools = ({ await new Promise((resolve) => setTimeout(resolve, 150)); } - const anySuccess = results.some((message) => message.success === true); - if (anySuccess) { + const roundCountsAsFailure = executions.length > 0 + && executions.every((execution) => execution.success !== true && execution.countsAsProbeFailure !== false); + if (!roundCountsAsFailure) { toolCallRoundRef.current = 0; } else { toolCallRoundRef.current += 1; diff --git a/frontend/src/utils/aiSqlLimit.test.ts b/frontend/src/utils/aiSqlLimit.test.ts index 179fd66..58ae68b 100644 --- a/frontend/src/utils/aiSqlLimit.test.ts +++ b/frontend/src/utils/aiSqlLimit.test.ts @@ -22,6 +22,11 @@ describe('buildAIReadonlyPreviewSQL', () => { .toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50'); }); + it('treats OceanBase Oracle as Oracle dialect when building readonly preview SQL', () => { + expect(buildAIReadonlyPreviewSQL('oceanbase', 'SELECT 1 FROM DUAL;', 50, 'oceanbase', { oceanBaseProtocol: 'oracle' })) + .toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50'); + }); + it('keeps MySQL-family SQL on LIMIT syntax', () => { expect(buildAIReadonlyPreviewSQL('mysql', 'SELECT * FROM users', 50)) .toBe('SELECT * FROM users LIMIT 50 OFFSET 0'); diff --git a/frontend/src/utils/aiSqlLimit.ts b/frontend/src/utils/aiSqlLimit.ts index daf884c..35410a1 100644 --- a/frontend/src/utils/aiSqlLimit.ts +++ b/frontend/src/utils/aiSqlLimit.ts @@ -20,10 +20,16 @@ const hasExistingRowLimit = (dialect: string, sql: string): boolean => { return (dialect === 'oracle' || dialect === 'dameng') && /\brownum\b/.test(text); }; -export const buildAIReadonlyPreviewSQL = (dbType: string, sql: string, limit = 50, driver = ''): string => { +export const buildAIReadonlyPreviewSQL = ( + dbType: string, + sql: string, + limit = 50, + driver = '', + options?: { oceanBaseProtocol?: unknown }, +): string => { const baseSQL = trimSQLStatement(sql); const safeLimit = Math.max(0, Math.floor(Number(limit) || 0)); - const dialect = resolveSqlDialect(dbType, driver); + const dialect = resolveSqlDialect(dbType, driver, options); if (!baseSQL || safeLimit <= 0 || !isAIReadonlySQL(baseSQL) || hasExistingRowLimit(dialect, baseSQL)) { return baseSQL; } diff --git a/internal/ai/safety/classifier_test.go b/internal/ai/safety/classifier_test.go index 280b261..5a71c4c 100644 --- a/internal/ai/safety/classifier_test.go +++ b/internal/ai/safety/classifier_test.go @@ -124,6 +124,7 @@ func TestGuard_Full(t *testing.T) { {"INSERT INTO t VALUES (1)", true}, {"DROP TABLE t", true}, {"CREATE TABLE t (id INT)", true}, + {"CALL bulk_insert_users(100000)", true}, } for _, tt := range tests { result := g.Check(tt.sql) diff --git a/internal/ai/safety/guard.go b/internal/ai/safety/guard.go index ca31bf2..3a958ab 100644 --- a/internal/ai/safety/guard.go +++ b/internal/ai/safety/guard.go @@ -51,7 +51,7 @@ func (g *Guard) isAllowed(opType ai.SQLOperationType) bool { case ai.PermissionReadWrite: return opType == ai.SQLOpQuery || opType == ai.SQLOpDML case ai.PermissionFull: - return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL + return true default: return opType == ai.SQLOpQuery }