mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-20 21:43:56 +08:00
🐛 fix(ai-safety): 修正完全模式执行口径与本地工具失败判定
- 修正完全模式下 DML 与过程调用的安全提示和限制说明 - 区分连接探针失败与可恢复 SQL 执行错误,避免数据探针被误终止 - 修复本地 execute_sql 写语句结果返回 affectedRows - 补充 AI 安全、本地工具执行与 SQL 限制回归测试
This commit is contained in:
@@ -7,7 +7,7 @@ import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme';
|
||||
const SAFETY_OPTIONS: { label: string; value: AISafetyLevel; desc: string; color: string; icon: string }[] = [
|
||||
{ label: '只读模式', value: 'readonly', desc: 'AI 仅可执行 SELECT 等查询操作,最安全', color: '#22c55e', icon: '🔒' },
|
||||
{ label: '读写模式', value: 'readwrite', desc: 'AI 可执行 INSERT/UPDATE/DELETE,危险操作需二次确认', color: '#f59e0b', icon: '⚠️' },
|
||||
{ label: '完全模式', value: 'full', desc: 'AI 可执行所有操作(含 DDL),高危操作自动告警', color: '#ef4444', icon: '🔓' },
|
||||
{ label: '完全模式', value: 'full', desc: 'AI 可执行所有操作(含 DDL/过程调用),高危或未识别操作会告警', color: '#ef4444', icon: '🔓' },
|
||||
];
|
||||
|
||||
interface AISettingsSafetySectionProps {
|
||||
|
||||
@@ -23,6 +23,7 @@ interface ExecuteDatabaseToolCallOptions {
|
||||
interface ToolExecutionResult {
|
||||
content: string;
|
||||
success: boolean;
|
||||
countsAsProbeFailure?: boolean;
|
||||
}
|
||||
|
||||
const findConnection = (connections: SavedConnection[], connectionId: string) =>
|
||||
@@ -39,12 +40,45 @@ const resolveConnectionOrFailure = (
|
||||
failure: {
|
||||
content: 'Connection not found',
|
||||
success: false,
|
||||
countsAsProbeFailure: true,
|
||||
},
|
||||
};
|
||||
}
|
||||
return { connection };
|
||||
};
|
||||
|
||||
const CONNECTION_ERROR_KEYWORDS = [
|
||||
'connection not found',
|
||||
'invalid connection',
|
||||
'bad connection',
|
||||
'driver: bad connection',
|
||||
'connection refused',
|
||||
'connection reset',
|
||||
'closed network connection',
|
||||
'server has gone away',
|
||||
'broken pipe',
|
||||
'no such host',
|
||||
'network is unreachable',
|
||||
'context deadline exceeded',
|
||||
'i/o timeout',
|
||||
'timeout',
|
||||
'eof',
|
||||
'连接失败',
|
||||
'连接异常',
|
||||
'连接超时',
|
||||
'连接已关闭',
|
||||
'网络超时',
|
||||
'网络异常',
|
||||
];
|
||||
|
||||
const countsAsProbeFailure = (message: unknown): boolean => {
|
||||
const text = String(message || '').trim().toLowerCase();
|
||||
if (!text) {
|
||||
return true;
|
||||
}
|
||||
return CONNECTION_ERROR_KEYWORDS.some((keyword) => text.includes(keyword));
|
||||
};
|
||||
|
||||
export async function executeDatabaseToolCall(
|
||||
options: ExecuteDatabaseToolCallOptions,
|
||||
): Promise<ToolExecutionResult | null> {
|
||||
@@ -75,9 +109,14 @@ export async function executeDatabaseToolCall(
|
||||
}
|
||||
return { content: JSON.stringify(databaseNames), success: true };
|
||||
}
|
||||
return { content: result?.message || 'Failed to fetch DBs', success: false };
|
||||
return {
|
||||
content: result?.message || 'Failed to fetch DBs',
|
||||
success: false,
|
||||
countsAsProbeFailure: countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `获取数据库列表失败: ${error?.message || error}`, success: false };
|
||||
const message = `获取数据库列表失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'get_tables': {
|
||||
@@ -99,9 +138,14 @@ export async function executeDatabaseToolCall(
|
||||
});
|
||||
return { content: JSON.stringify(tableNames), success: true };
|
||||
}
|
||||
return { content: result?.message || 'Failed to fetch Tables', success: false };
|
||||
return {
|
||||
content: result?.message || 'Failed to fetch Tables',
|
||||
success: false,
|
||||
countsAsProbeFailure: countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `获取表列表失败: ${error?.message || error}`, success: false };
|
||||
const message = `获取表列表失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'get_all_columns': {
|
||||
@@ -125,9 +169,14 @@ export async function executeDatabaseToolCall(
|
||||
success: true,
|
||||
};
|
||||
}
|
||||
return { content: result?.message || 'Failed to fetch all columns', success: false };
|
||||
return {
|
||||
content: result?.message || 'Failed to fetch all columns',
|
||||
success: false,
|
||||
countsAsProbeFailure: countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `获取全库字段摘要失败: ${error?.message || error}`, success: false };
|
||||
const message = `获取全库字段摘要失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'get_columns': {
|
||||
@@ -145,9 +194,14 @@ export async function executeDatabaseToolCall(
|
||||
success: true,
|
||||
};
|
||||
}
|
||||
return { content: result?.message || 'Failed to fetch columns', success: false };
|
||||
return {
|
||||
content: result?.message || 'Failed to fetch columns',
|
||||
success: false,
|
||||
countsAsProbeFailure: countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `获取字段列表失败: ${error?.message || error}`, success: false };
|
||||
const message = `获取字段列表失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'get_indexes': {
|
||||
@@ -160,9 +214,11 @@ export async function executeDatabaseToolCall(
|
||||
return {
|
||||
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch indexes'),
|
||||
success: !!result?.success && Array.isArray(result.data),
|
||||
countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `获取索引定义失败: ${error?.message || error}`, success: false };
|
||||
const message = `获取索引定义失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'get_foreign_keys': {
|
||||
@@ -175,9 +231,11 @@ export async function executeDatabaseToolCall(
|
||||
return {
|
||||
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch foreign keys'),
|
||||
success: !!result?.success && Array.isArray(result.data),
|
||||
countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `获取外键关系失败: ${error?.message || error}`, success: false };
|
||||
const message = `获取外键关系失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'get_triggers': {
|
||||
@@ -190,9 +248,11 @@ export async function executeDatabaseToolCall(
|
||||
return {
|
||||
content: result?.success && Array.isArray(result.data) ? JSON.stringify(result.data) : (result?.message || 'Failed to fetch triggers'),
|
||||
success: !!result?.success && Array.isArray(result.data),
|
||||
countsAsProbeFailure: result?.success ? false : countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `获取触发器定义失败: ${error?.message || error}`, success: false };
|
||||
const message = `获取触发器定义失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'get_table_ddl': {
|
||||
@@ -207,9 +267,14 @@ export async function executeDatabaseToolCall(
|
||||
fetchDDL: () => runtime.showCreateTable(rpcConfig, safeDbName, safeTable),
|
||||
fetchColumns: () => runtime.getColumns(rpcConfig, safeDbName, safeTable),
|
||||
});
|
||||
return { content: result.content, success: result.success };
|
||||
return {
|
||||
content: result.content,
|
||||
success: result.success,
|
||||
countsAsProbeFailure: result.success ? false : countsAsProbeFailure(result.content),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `获取建表语句失败: ${error?.message || error}`, success: false };
|
||||
const message = `获取建表语句失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'inspect_table_bundle': {
|
||||
@@ -245,9 +310,14 @@ export async function executeDatabaseToolCall(
|
||||
success: true,
|
||||
};
|
||||
}
|
||||
return { content: result?.message || 'Failed to preview table rows', success: false };
|
||||
return {
|
||||
content: result?.message || 'Failed to preview table rows',
|
||||
success: false,
|
||||
countsAsProbeFailure: countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `预览表样例数据失败: ${error?.message || error}`, success: false };
|
||||
const message = `预览表样例数据失败: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
case 'execute_sql': {
|
||||
@@ -262,6 +332,7 @@ export async function executeDatabaseToolCall(
|
||||
return {
|
||||
content: `安全策略拦截:当前安全级别不允许执行 ${checkResult.operationType} 类型的 SQL。请将 SQL 展示给用户,让用户手动执行。`,
|
||||
success: false,
|
||||
countsAsProbeFailure: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -270,18 +341,31 @@ export async function executeDatabaseToolCall(
|
||||
safeSql,
|
||||
50,
|
||||
resolved.connection.config?.driver || '',
|
||||
{ oceanBaseProtocol: resolved.connection.config?.oceanBaseProtocol },
|
||||
);
|
||||
const result = await runtime.query(buildRpcConnectionConfig(resolved.connection.config) as any, safeDbName, finalSql);
|
||||
if (result?.success) {
|
||||
const affectedRows = Number((result.data as Record<string, unknown> | null | undefined)?.affectedRows);
|
||||
if (Number.isFinite(affectedRows)) {
|
||||
return {
|
||||
content: JSON.stringify({ affectedRows }),
|
||||
success: true,
|
||||
};
|
||||
}
|
||||
const rows = Array.isArray(result.data) ? result.data : [];
|
||||
return {
|
||||
content: JSON.stringify({ rowCount: rows.length, data: rows.slice(0, 50) }),
|
||||
success: true,
|
||||
};
|
||||
}
|
||||
return { content: result?.message || 'SQL 执行失败', success: false };
|
||||
return {
|
||||
content: result?.message || 'SQL 执行失败',
|
||||
success: false,
|
||||
countsAsProbeFailure: countsAsProbeFailure(result?.message),
|
||||
};
|
||||
} catch (error: any) {
|
||||
return { content: `SQL 执行异常: ${error?.message || error}`, success: false };
|
||||
const message = `SQL 执行异常: ${error?.message || error}`;
|
||||
return { content: message, success: false, countsAsProbeFailure: countsAsProbeFailure(message) };
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -201,6 +201,86 @@ describe('aiLocalToolExecutor', () => {
|
||||
expect(query).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('treats OceanBase Oracle SQL execution errors as recoverable and uses Oracle readonly preview SQL', async () => {
|
||||
const query = vi.fn().mockResolvedValue({
|
||||
success: false,
|
||||
message: "oceanbase: error 900 (42000): ORA-00900 near '50 OFFSET 0'",
|
||||
});
|
||||
const result = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('execute_sql', {
|
||||
connectionId: 'conn-1',
|
||||
dbName: 'SYS',
|
||||
sql: 'SELECT 1 FROM DUAL',
|
||||
}),
|
||||
connections: [{
|
||||
...buildConnection(),
|
||||
config: {
|
||||
type: 'oceanbase',
|
||||
host: '127.0.0.1',
|
||||
port: 2881,
|
||||
user: 'sys',
|
||||
driver: 'oceanbase',
|
||||
oceanBaseProtocol: 'oracle',
|
||||
},
|
||||
}],
|
||||
mcpTools: [],
|
||||
toolContextMap: new Map(),
|
||||
runtime: {
|
||||
getDatabases: vi.fn(),
|
||||
getTables: vi.fn(),
|
||||
getColumns: vi.fn(),
|
||||
getIndexes: vi.fn(),
|
||||
getForeignKeys: vi.fn(),
|
||||
getTriggers: vi.fn(),
|
||||
showCreateTable: vi.fn(),
|
||||
query,
|
||||
checkSQL: vi.fn().mockResolvedValue({ allowed: true, operationType: 'query' }),
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.countsAsProbeFailure).toBe(false);
|
||||
expect(query).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
'SYS',
|
||||
'SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50',
|
||||
);
|
||||
});
|
||||
|
||||
it('returns affectedRows for execute_sql write statements instead of pretending rowCount is zero', async () => {
|
||||
const query = vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
data: {
|
||||
affectedRows: 100000,
|
||||
},
|
||||
});
|
||||
const result = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('execute_sql', {
|
||||
connectionId: 'conn-1',
|
||||
dbName: 'crm',
|
||||
sql: 'INSERT INTO orders_archive SELECT * FROM orders',
|
||||
}),
|
||||
connections: [buildConnection()],
|
||||
mcpTools: [],
|
||||
toolContextMap: new Map(),
|
||||
runtime: {
|
||||
getDatabases: vi.fn(),
|
||||
getTables: vi.fn(),
|
||||
getColumns: vi.fn(),
|
||||
getIndexes: vi.fn(),
|
||||
getForeignKeys: vi.fn(),
|
||||
getTriggers: vi.fn(),
|
||||
showCreateTable: vi.fn(),
|
||||
query,
|
||||
checkSQL: vi.fn().mockResolvedValue({ allowed: true, operationType: 'INSERT' }),
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
expect(query).toHaveBeenCalledWith(expect.anything(), 'crm', 'INSERT INTO orders_archive SELECT * FROM orders');
|
||||
expect(JSON.parse(result.content)).toEqual({ affectedRows: 100000 });
|
||||
});
|
||||
|
||||
it('returns a cross-table column summary for get_all_columns', async () => {
|
||||
const result = await executeLocalAIToolCall({
|
||||
toolCall: buildToolCall('get_all_columns', {
|
||||
|
||||
@@ -48,6 +48,7 @@ export interface ExecuteLocalAIToolCallResult {
|
||||
content: string;
|
||||
success: boolean;
|
||||
toolName: string;
|
||||
countsAsProbeFailure?: boolean;
|
||||
}
|
||||
|
||||
const buildToolName = (toolCall: AIToolCall, descriptor?: AIMCPToolDescriptor) =>
|
||||
@@ -106,6 +107,7 @@ export async function executeLocalAIToolCall({
|
||||
content: snapshotInspectionResult.content,
|
||||
success: snapshotInspectionResult.success,
|
||||
toolName: buildToolName(toolCall, descriptor),
|
||||
countsAsProbeFailure: snapshotInspectionResult.countsAsProbeFailure,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -121,6 +123,7 @@ export async function executeLocalAIToolCall({
|
||||
content: databaseToolResult.content,
|
||||
success: databaseToolResult.success,
|
||||
toolName: buildToolName(toolCall, descriptor),
|
||||
countsAsProbeFailure: databaseToolResult.countsAsProbeFailure,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -62,4 +62,17 @@ describe('buildAISafetySnapshot', () => {
|
||||
expect(snapshot.effectiveRestrictions.join('\n')).toContain('allowMutating=true');
|
||||
expect(snapshot.effectiveRestrictions.join('\n')).toContain('当前 JVM 诊断明确禁止 mutating 命令');
|
||||
});
|
||||
|
||||
it('describes full safety mode as allowing other statements with confirmation', () => {
|
||||
const snapshot = buildAISafetySnapshot({
|
||||
safetyLevel: 'full',
|
||||
connections: [],
|
||||
});
|
||||
|
||||
expect(snapshot.safetyLevel).toBe('full');
|
||||
expect(snapshot.permissionMatrix.allowDML).toBe(true);
|
||||
expect(snapshot.permissionMatrix.allowDDL).toBe(true);
|
||||
expect(snapshot.sqlRuleText).toContain('允许所有 SQL 操作');
|
||||
expect(snapshot.effectiveRestrictions.join('\n')).toContain('高风险或未识别语句仍会要求确认');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,7 +9,7 @@ const SAFETY_LEVEL_LABELS: Record<string, string> = {
|
||||
const SAFETY_RULE_TEXTS: Record<string, string> = {
|
||||
readonly: '只读模式仅允许查询语句。',
|
||||
readwrite: '读写模式允许查询和 DML,DDL 仍会被阻止。',
|
||||
full: '完全开放模式允许查询、DML 和 DDL;未识别操作仍会被阻止。',
|
||||
full: '完全开放模式允许所有 SQL 操作;高风险或未识别语句仍会要求确认。',
|
||||
};
|
||||
|
||||
const normalizeSafetyLevel = (value: AISafetyLevel | string | undefined): string => {
|
||||
|
||||
@@ -46,4 +46,5 @@ export interface AISnapshotInspectionRuntime {
|
||||
export interface SnapshotInspectionResult {
|
||||
content: string;
|
||||
success: boolean;
|
||||
countsAsProbeFailure?: boolean;
|
||||
}
|
||||
|
||||
@@ -211,6 +211,7 @@ const HighlightedCodeBlock: React.FC<HighlightedCodeBlockProps> = ({
|
||||
displayText,
|
||||
50,
|
||||
activeConnectionConfig?.driver || '',
|
||||
{ oceanBaseProtocol: activeConnectionConfig?.oceanBaseProtocol },
|
||||
);
|
||||
const response = await DBQuery(activeConnectionConfig, activeDbName || '', previewSql);
|
||||
if (response.success && Array.isArray(response.data)) {
|
||||
|
||||
@@ -11,6 +11,7 @@ const executeLocalAIToolCallMock = vi.hoisted(() => vi.fn(async ({ toolCall }: {
|
||||
content: `result:${toolCall.function.name}`,
|
||||
success: true,
|
||||
toolName: toolCall.function.name,
|
||||
countsAsProbeFailure: true,
|
||||
})));
|
||||
|
||||
vi.mock('./aiChatPayloadDispatch', () => ({
|
||||
@@ -191,4 +192,35 @@ describe('useAIChatLocalTools', () => {
|
||||
renderer?.unmount();
|
||||
});
|
||||
});
|
||||
|
||||
it('does not auto-stop the probe after three recoverable SQL execution errors', async () => {
|
||||
executeLocalAIToolCallMock.mockResolvedValue({
|
||||
content: "oceanbase: error 900 (42000): ORA-00900 near '50 OFFSET 0'",
|
||||
success: false,
|
||||
toolName: 'execute_sql',
|
||||
countsAsProbeFailure: false,
|
||||
});
|
||||
|
||||
let renderer: ReactTestRenderer | undefined;
|
||||
await act(async () => {
|
||||
renderer = create(<LocalToolsHarness />);
|
||||
});
|
||||
|
||||
expect(latestHook).toBeDefined();
|
||||
for (let i = 0; i < 3; i += 1) {
|
||||
const run = latestHook!.executeLocalTools([buildToolCall('execute_sql')], 'assistant-1');
|
||||
await act(async () => {
|
||||
await vi.advanceTimersByTimeAsync(150);
|
||||
await run;
|
||||
});
|
||||
}
|
||||
|
||||
const messages = useStore.getState().aiChatHistory[SESSION_ID] || [];
|
||||
expect(messages.some((message) => message.content.includes('探针连续 3 轮执行失败'))).toBe(false);
|
||||
expect(dispatchAIChatPayloadMock).toHaveBeenCalledTimes(3);
|
||||
|
||||
await act(async () => {
|
||||
renderer?.unmount();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -18,6 +18,7 @@ import { dispatchAIChatPayload } from './aiChatPayloadDispatch';
|
||||
import {
|
||||
buildToolResultMessage,
|
||||
executeLocalAIToolCall,
|
||||
type ExecuteLocalAIToolCallResult,
|
||||
type AIToolContextEntry,
|
||||
} from './aiLocalToolExecutor';
|
||||
|
||||
@@ -96,6 +97,7 @@ export const useAIChatLocalTools = ({
|
||||
}
|
||||
|
||||
const results: AIChatMessage[] = [];
|
||||
const executions: ExecuteLocalAIToolCallResult[] = [];
|
||||
const currentConnections = useStore.getState().connections;
|
||||
for (const toolCall of toolCalls) {
|
||||
const currentState = useStore.getState();
|
||||
@@ -119,6 +121,7 @@ export const useAIChatLocalTools = ({
|
||||
userPromptSettings,
|
||||
dynamicModels,
|
||||
});
|
||||
executions.push(execution);
|
||||
const toolResultMsg: AIChatMessage = buildToolResultMessage({
|
||||
id: nextMessageId(),
|
||||
timestamp: Date.now(),
|
||||
@@ -130,8 +133,9 @@ export const useAIChatLocalTools = ({
|
||||
await new Promise((resolve) => setTimeout(resolve, 150));
|
||||
}
|
||||
|
||||
const anySuccess = results.some((message) => message.success === true);
|
||||
if (anySuccess) {
|
||||
const roundCountsAsFailure = executions.length > 0
|
||||
&& executions.every((execution) => execution.success !== true && execution.countsAsProbeFailure !== false);
|
||||
if (!roundCountsAsFailure) {
|
||||
toolCallRoundRef.current = 0;
|
||||
} else {
|
||||
toolCallRoundRef.current += 1;
|
||||
|
||||
@@ -22,6 +22,11 @@ describe('buildAIReadonlyPreviewSQL', () => {
|
||||
.toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50');
|
||||
});
|
||||
|
||||
it('treats OceanBase Oracle as Oracle dialect when building readonly preview SQL', () => {
|
||||
expect(buildAIReadonlyPreviewSQL('oceanbase', 'SELECT 1 FROM DUAL;', 50, 'oceanbase', { oceanBaseProtocol: 'oracle' }))
|
||||
.toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50');
|
||||
});
|
||||
|
||||
it('keeps MySQL-family SQL on LIMIT syntax', () => {
|
||||
expect(buildAIReadonlyPreviewSQL('mysql', 'SELECT * FROM users', 50))
|
||||
.toBe('SELECT * FROM users LIMIT 50 OFFSET 0');
|
||||
|
||||
@@ -20,10 +20,16 @@ const hasExistingRowLimit = (dialect: string, sql: string): boolean => {
|
||||
return (dialect === 'oracle' || dialect === 'dameng') && /\brownum\b/.test(text);
|
||||
};
|
||||
|
||||
export const buildAIReadonlyPreviewSQL = (dbType: string, sql: string, limit = 50, driver = ''): string => {
|
||||
export const buildAIReadonlyPreviewSQL = (
|
||||
dbType: string,
|
||||
sql: string,
|
||||
limit = 50,
|
||||
driver = '',
|
||||
options?: { oceanBaseProtocol?: unknown },
|
||||
): string => {
|
||||
const baseSQL = trimSQLStatement(sql);
|
||||
const safeLimit = Math.max(0, Math.floor(Number(limit) || 0));
|
||||
const dialect = resolveSqlDialect(dbType, driver);
|
||||
const dialect = resolveSqlDialect(dbType, driver, options);
|
||||
if (!baseSQL || safeLimit <= 0 || !isAIReadonlySQL(baseSQL) || hasExistingRowLimit(dialect, baseSQL)) {
|
||||
return baseSQL;
|
||||
}
|
||||
|
||||
@@ -124,6 +124,7 @@ func TestGuard_Full(t *testing.T) {
|
||||
{"INSERT INTO t VALUES (1)", true},
|
||||
{"DROP TABLE t", true},
|
||||
{"CREATE TABLE t (id INT)", true},
|
||||
{"CALL bulk_insert_users(100000)", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := g.Check(tt.sql)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (g *Guard) isAllowed(opType ai.SQLOperationType) bool {
|
||||
case ai.PermissionReadWrite:
|
||||
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML
|
||||
case ai.PermissionFull:
|
||||
return opType == ai.SQLOpQuery || opType == ai.SQLOpDML || opType == ai.SQLOpDDL
|
||||
return true
|
||||
default:
|
||||
return opType == ai.SQLOpQuery
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user