From 51675f9d0506276bd163f317091771c454609971 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 28 Apr 2026 14:03:48 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(ai):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=A4=9A=E6=96=B9=E8=A8=80=E6=89=A7=E8=A1=8C=E4=B8=8E=20DDL=20?= =?UTF-8?q?=E9=99=8D=E7=BA=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SQL 执行:移除 AI 工具和代码块预览中硬编码的 LIMIT 50 - 方言适配:按连接类型和自定义驱动别名生成只读 SQL 预览限流语句 - Oracle 兼容:Oracle、自定义 Oracle 和达梦改用 ROWNUM 语法限制行数 - 权限降级:获取表 DDL 失败时自动降级为字段元数据摘要 - 上下文优化:手动添加表结构上下文时复用同一套 DDL 降级逻辑 - 测试覆盖:新增 AI SQL 限流和表结构降级单元测试 Refs #418 --- frontend/src/components/AIChatPanel.tsx | 27 ++++---- frontend/src/components/ai/AIChatInput.tsx | 28 ++++---- .../src/components/ai/AIMessageBubble.tsx | 9 ++- frontend/src/utils/aiSqlLimit.test.ts | 48 +++++++++++++ frontend/src/utils/aiSqlLimit.ts | 31 +++++++++ frontend/src/utils/aiTableSchemaTool.test.ts | 51 ++++++++++++++ frontend/src/utils/aiTableSchemaTool.ts | 69 +++++++++++++++++++ frontend/src/utils/sql.ts | 3 +- 8 files changed, 235 insertions(+), 31 deletions(-) create mode 100644 frontend/src/utils/aiSqlLimit.test.ts create mode 100644 frontend/src/utils/aiSqlLimit.ts create mode 100644 frontend/src/utils/aiTableSchemaTool.test.ts create mode 100644 frontend/src/utils/aiTableSchemaTool.ts diff --git a/frontend/src/components/AIChatPanel.tsx b/frontend/src/components/AIChatPanel.tsx index 8f722b0..43a140b 100644 --- a/frontend/src/components/AIChatPanel.tsx +++ b/frontend/src/components/AIChatPanel.tsx @@ -25,6 +25,8 @@ import { buildMissingProviderNotice, buildModelFetchFailedNotice, } from '../utils/aiComposerNotice'; +import { buildAIReadonlyPreviewSQL } from '../utils/aiSqlLimit'; +import { resolveAITableSchemaToolResult } from '../utils/aiTableSchemaTool'; interface AIChatPanelProps { width?: number; @@ -1145,12 +1147,15 @@ SELECT * FROM users WHERE status = 1; try { const safeDbName = args.dbName ? String(args.dbName).trim() : ''; const safeTable = args.tableName ? String(args.tableName).trim() : ''; - const { DBShowCreateTable } = await import('../../wailsjs/go/app/App'); - const ddlRes = await DBShowCreateTable(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeTable); - if (ddlRes?.success) { - resStr = typeof ddlRes.data === 'string' ? ddlRes.data : JSON.stringify(ddlRes.data); - success = true; - } else { resStr = ddlRes?.message || 'Failed to fetch DDL'; } + const { DBShowCreateTable, DBGetColumns } = await import('../../wailsjs/go/app/App'); + const rpcConfig = buildRpcConnectionConfig(conn.config) as any; + const toolResult = await resolveAITableSchemaToolResult({ + tableName: safeTable, + fetchDDL: () => DBShowCreateTable(rpcConfig, safeDbName, safeTable), + fetchColumns: () => DBGetColumns(rpcConfig, safeDbName, safeTable), + }); + resStr = toolResult.content; + success = toolResult.success; } catch (e: any) { resStr = `获取建表语句失败: ${e?.message || e}`; } @@ -1173,14 +1178,8 @@ SELECT * FROM users WHERE status = 1; } } const { DBQuery } = await import('../../wailsjs/go/app/App'); - // 只对只读查询自动追加 LIMIT,写操作(UPDATE/DELETE/INSERT等)不追加 - const sqlTrimmed = safeSql.replace(/;\s*$/, ''); // 去掉末尾分号防止拼接出 "; LIMIT 50" - const sqlFirstWord = sqlTrimmed.trimStart().split(/\s/)[0]?.toLowerCase() || ''; - const isReadQuery = ['select', 'show', 'describe', 'desc', 'explain', 'with'].includes(sqlFirstWord); - const finalSql = (isReadQuery && !sqlTrimmed.toLowerCase().includes('limit')) - ? sqlTrimmed + ' LIMIT 50' - : sqlTrimmed; - const qRes = await DBQuery(buildRpcConnectionConfig(conn.config) as any, safeDbName, safeSql + (safeSql.toLowerCase().includes('limit') ? '' : ' LIMIT 50')); + const finalSql = buildAIReadonlyPreviewSQL(conn.config?.type || '', safeSql, 50, conn.config?.driver || ''); + const qRes = await DBQuery(buildRpcConnectionConfig(conn.config) as any, safeDbName, finalSql); if (qRes?.success) { const rows = Array.isArray(qRes.data) ? qRes.data : []; const limitedRows = rows.slice(0, 50); diff --git a/frontend/src/components/ai/AIChatInput.tsx b/frontend/src/components/ai/AIChatInput.tsx index 0640cfc..e497005 100644 --- a/frontend/src/components/ai/AIChatInput.tsx +++ b/frontend/src/components/ai/AIChatInput.tsx @@ -2,10 +2,11 @@ import React from 'react'; import { Input, Select, AutoComplete, Tooltip, Modal, Checkbox, Spin, message, Button, Tag } from 'antd'; import { DatabaseOutlined, SendOutlined, TableOutlined, SearchOutlined, PictureOutlined, ExclamationCircleFilled } from '@ant-design/icons'; import { useStore } from '../../store'; -import { DBGetTables, DBShowCreateTable, DBGetDatabases } from '../../../wailsjs/go/app/App'; +import { DBGetTables, DBShowCreateTable, DBGetDatabases, DBGetColumns } from '../../../wailsjs/go/app/App'; import type { OverlayWorkbenchTheme } from '../../utils/overlayWorkbenchTheme'; import type { AIComposerNotice } from '../../utils/aiComposerNotice'; import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig'; +import { resolveAITableSchemaToolResult } from '../../utils/aiTableSchemaTool'; interface AIChatInputProps { input: string; @@ -202,24 +203,21 @@ export const AIChatInput: React.FC = ({ if (activeContextItems.find(c => c.dbName === dbName && c.tableName === tableName)) { continue; } - const res = await DBShowCreateTable(buildRpcConnectionConfig(conn.config) as any, dbName, tableName); - let createSql = ''; - if (res.success && res.data) { - if (typeof res.data === 'string') { - createSql = res.data; - } else if (Array.isArray(res.data) && res.data.length > 0) { - const row = res.data[0]; - createSql = (Object.values(row).find(v => typeof v === 'string' && (v.toUpperCase().includes('CREATE TABLE') || v.toUpperCase().includes('CREATE'))) || Object.values(row)[1] || Object.values(row)[0]) as string; - } - } else { - message.error(`获取表 ${dbName}.${tableName} 结构失败: ` + (res.message || '未知错误')); + const rpcConfig = buildRpcConnectionConfig(conn.config) as any; + const schemaResult = await resolveAITableSchemaToolResult({ + tableName, + fetchDDL: () => DBShowCreateTable(rpcConfig, dbName, tableName), + fetchColumns: () => DBGetColumns(rpcConfig, dbName, tableName), + }); + if (!schemaResult.success) { + message.error(`获取表 ${dbName}.${tableName} 结构失败: ${schemaResult.content}`); } - - if (createSql) { + + if (schemaResult.success && schemaResult.content) { addAIContext(connectionKey, { dbName: dbName, tableName: tableName, - ddl: createSql + ddl: schemaResult.content }); addedCount++; } diff --git a/frontend/src/components/ai/AIMessageBubble.tsx b/frontend/src/components/ai/AIMessageBubble.tsx index b449142..3cfad53 100644 --- a/frontend/src/components/ai/AIMessageBubble.tsx +++ b/frontend/src/components/ai/AIMessageBubble.tsx @@ -15,6 +15,7 @@ import { parseJVMDiagnosticPlan, resolveJVMDiagnosticPlanTargetTabId, } from '../../utils/jvmDiagnosticPlan'; +import { buildAIReadonlyPreviewSQL } from '../../utils/aiSqlLimit'; // 🔧 性能优化:将 ReactMarkdown 包装为 Memo 组件并提取固定的 plugins const remarkPlugins = [remarkGfm]; @@ -260,7 +261,13 @@ const AIBlockHashRender = ({ match, darkMode, overlayTheme, children, activeConn setPreviewData(null); try { const { DBQuery } = await import('../../../wailsjs/go/app/App'); - const res = await DBQuery(activeConnectionConfig, activeDbName || '', displayText + ' LIMIT 50'); + const previewSql = buildAIReadonlyPreviewSQL( + activeConnectionConfig?.type || '', + displayText, + 50, + activeConnectionConfig?.driver || '', + ); + const res = await DBQuery(activeConnectionConfig, activeDbName || '', previewSql); if (res.success && Array.isArray(res.data)) { const rows = res.data as any[]; const cols = rows.length > 0 ? Object.keys(rows[0]) : []; diff --git a/frontend/src/utils/aiSqlLimit.test.ts b/frontend/src/utils/aiSqlLimit.test.ts new file mode 100644 index 0000000..9e3dfd3 --- /dev/null +++ b/frontend/src/utils/aiSqlLimit.test.ts @@ -0,0 +1,48 @@ +import { describe, expect, it } from 'vitest'; + +import { buildAIReadonlyPreviewSQL } from './aiSqlLimit'; + +describe('buildAIReadonlyPreviewSQL', () => { + it('limits Oracle readonly SQL with ROWNUM instead of MySQL LIMIT', () => { + const sql = buildAIReadonlyPreviewSQL('oracle', 'SELECT 1 FROM DUAL;', 50); + + expect(sql).toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50'); + expect(sql.toLowerCase()).not.toContain('limit'); + }); + + it('does not add another limit when Oracle SQL already limits rows', () => { + expect(buildAIReadonlyPreviewSQL('oracle', 'SELECT * FROM users WHERE ROWNUM <= 10', 50)) + .toBe('SELECT * FROM users WHERE ROWNUM <= 10'); + expect(buildAIReadonlyPreviewSQL('oracle', 'SELECT * FROM users FETCH FIRST 10 ROWS ONLY', 50)) + .toBe('SELECT * FROM users FETCH FIRST 10 ROWS ONLY'); + }); + + it('resolves custom Oracle drivers from the driver alias', () => { + expect(buildAIReadonlyPreviewSQL('custom', 'SELECT 1 FROM DUAL;', 50, 'oracle')) + .toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50'); + }); + + it('keeps MySQL-family SQL on LIMIT syntax', () => { + expect(buildAIReadonlyPreviewSQL('mysql', 'SELECT * FROM users', 50)) + .toBe('SELECT * FROM users LIMIT 50 OFFSET 0'); + }); + + it('keeps PostgreSQL-compatible and ClickHouse SQL on LIMIT syntax', () => { + expect(buildAIReadonlyPreviewSQL('postgres', 'SELECT * FROM users', 50)) + .toBe('SELECT * FROM users LIMIT 50 OFFSET 0'); + expect(buildAIReadonlyPreviewSQL('kingbase', 'SELECT * FROM users', 50)) + .toBe('SELECT * FROM users LIMIT 50 OFFSET 0'); + expect(buildAIReadonlyPreviewSQL('clickhouse', 'SELECT * FROM events', 50)) + .toBe('SELECT * FROM events LIMIT 50 OFFSET 0'); + }); + + it('limits Dameng readonly SQL with Oracle-compatible ROWNUM syntax', () => { + expect(buildAIReadonlyPreviewSQL('dameng', 'SELECT 1 FROM DUAL;', 50)) + .toBe('SELECT * FROM (SELECT 1 FROM DUAL) WHERE ROWNUM <= 50'); + }); + + it('does not limit non-readonly SQL', () => { + expect(buildAIReadonlyPreviewSQL('oracle', 'UPDATE users SET name = \'a\';', 50)) + .toBe('UPDATE users SET name = \'a\''); + }); +}); diff --git a/frontend/src/utils/aiSqlLimit.ts b/frontend/src/utils/aiSqlLimit.ts new file mode 100644 index 0000000..daf884c --- /dev/null +++ b/frontend/src/utils/aiSqlLimit.ts @@ -0,0 +1,31 @@ +import { buildPaginatedSelectSQL } from './sql'; +import { resolveSqlDialect } from './sqlDialect'; + +const AI_READONLY_SQL_KEYWORDS = new Set(['select', 'show', 'describe', 'desc', 'explain', 'with', 'pragma', 'values']); + +const trimSQLStatement = (sql: string): string => String(sql || '').trim().replace(/;\s*$/, '').trim(); + +const isAIReadonlySQL = (sql: string): boolean => { + const firstWord = trimSQLStatement(sql).trimStart().split(/\s+/)[0]?.toLowerCase() || ''; + return AI_READONLY_SQL_KEYWORDS.has(firstWord); +}; + +const hasExistingRowLimit = (dialect: string, sql: string): boolean => { + const text = trimSQLStatement(sql).toLowerCase(); + if (!text) return false; + if (/\blimit\s+\d+\b/.test(text)) return true; + if (/\bfetch\s+(first|next)\s+\d+\s+rows?\b/.test(text)) return true; + if (/\btop\s*\(?\s*\d+\s*\)?\b/.test(text)) return true; + + return (dialect === 'oracle' || dialect === 'dameng') && /\brownum\b/.test(text); +}; + +export const buildAIReadonlyPreviewSQL = (dbType: string, sql: string, limit = 50, driver = ''): string => { + const baseSQL = trimSQLStatement(sql); + const safeLimit = Math.max(0, Math.floor(Number(limit) || 0)); + const dialect = resolveSqlDialect(dbType, driver); + if (!baseSQL || safeLimit <= 0 || !isAIReadonlySQL(baseSQL) || hasExistingRowLimit(dialect, baseSQL)) { + return baseSQL; + } + return buildPaginatedSelectSQL(dialect, baseSQL, '', safeLimit, 0); +}; diff --git a/frontend/src/utils/aiTableSchemaTool.test.ts b/frontend/src/utils/aiTableSchemaTool.test.ts new file mode 100644 index 0000000..419bdec --- /dev/null +++ b/frontend/src/utils/aiTableSchemaTool.test.ts @@ -0,0 +1,51 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { resolveAITableSchemaToolResult } from './aiTableSchemaTool'; + +describe('resolveAITableSchemaToolResult', () => { + it('returns DDL directly when DDL fetch succeeds', async () => { + const fetchColumns = vi.fn(); + + const result = await resolveAITableSchemaToolResult({ + tableName: 'USERS', + fetchDDL: vi.fn().mockResolvedValue({ success: true, data: 'CREATE TABLE USERS (ID NUMBER)' }), + fetchColumns, + }); + + expect(result).toEqual({ success: true, content: 'CREATE TABLE USERS (ID NUMBER)' }); + expect(fetchColumns).not.toHaveBeenCalled(); + }); + + it('falls back to column metadata when DDL fetch fails due to permissions', async () => { + const result = await resolveAITableSchemaToolResult({ + tableName: 'USERS', + fetchDDL: vi.fn().mockResolvedValue({ success: false, message: 'ORA-31603: object not found or insufficient privileges' }), + fetchColumns: vi.fn().mockResolvedValue({ + success: true, + data: [ + { Name: 'ID', Type: 'NUMBER', Nullable: 'NO', Default: null, Comment: '主键' }, + { Name: 'NAME', Type: 'VARCHAR2(64)', Nullable: 'YES' }, + ], + }), + }); + + expect(result.success).toBe(true); + expect(result.content).toContain('DDL 获取失败,已降级为字段元数据摘要'); + expect(result.content).toContain('ORA-31603'); + expect(result.content).toContain('可用字段:ID, NAME'); + expect(result.content).toContain('"field":"ID"'); + expect(result.content).toContain('"type":"NUMBER"'); + }); + + it('returns a combined failure when both DDL and column metadata fail', async () => { + const result = await resolveAITableSchemaToolResult({ + tableName: 'USERS', + fetchDDL: vi.fn().mockResolvedValue({ success: false, message: 'DDL permission denied' }), + fetchColumns: vi.fn().mockResolvedValue({ success: false, message: 'columns permission denied' }), + }); + + expect(result.success).toBe(false); + expect(result.content).toContain('DDL permission denied'); + expect(result.content).toContain('columns permission denied'); + }); +}); diff --git a/frontend/src/utils/aiTableSchemaTool.ts b/frontend/src/utils/aiTableSchemaTool.ts new file mode 100644 index 0000000..598d78e --- /dev/null +++ b/frontend/src/utils/aiTableSchemaTool.ts @@ -0,0 +1,69 @@ +type ToolQueryResult = { + success?: boolean; + data?: unknown; + message?: string; +}; + +type ResolveAITableSchemaToolResultParams = { + tableName: string; + fetchDDL: () => Promise; + fetchColumns: () => Promise; +}; + +const stringifyToolData = (data: unknown): string => ( + typeof data === 'string' ? data : JSON.stringify(data) +); + +const firstStringValue = (row: Record, keys: string[]): string => { + for (const key of keys) { + const value = row[key]; + if (value !== undefined && value !== null) { + return String(value); + } + } + return ''; +}; + +const normalizeAIColumn = (raw: unknown) => { + const row = (raw && typeof raw === 'object') ? raw as Record : {}; + const keys = Object.keys(row); + return { + field: firstStringValue(row, ['Field', 'field', 'COLUMN_NAME', 'column_name', 'Name', 'name']) || (keys.length > 0 ? String(row[keys[0]] ?? '') : ''), + type: firstStringValue(row, ['Type', 'type', 'DATA_TYPE', 'data_type']) || (keys.length > 1 ? String(row[keys[1]] ?? '') : ''), + nullable: firstStringValue(row, ['Null', 'null', 'IS_NULLABLE', 'is_nullable', 'Nullable', 'nullable']), + default: firstStringValue(row, ['Default', 'default', 'COLUMN_DEFAULT', 'column_default', 'DefaultValue']), + comment: firstStringValue(row, ['Comment', 'comment', 'COLUMN_COMMENT', 'column_comment', 'Description']), + }; +}; + +const buildColumnFallbackContent = (tableName: string, ddlError: string, columns: unknown[]): string => { + const normalizedColumns = columns.map(normalizeAIColumn).filter((column) => column.field.trim()); + const fieldNames = normalizedColumns.map((column) => column.field).join(', '); + return [ + `⚠️ 表 ${tableName} 的 DDL 获取失败,已降级为字段元数据摘要。`, + `DDL 错误:${ddlError || '未知错误'}`, + '该结果不包含完整索引、约束、触发器等 DDL 信息;请基于字段列表继续分析,不要因为 DDL 权限失败而停止。', + `可用字段:${fieldNames || '无'}`, + `详细信息:${JSON.stringify(normalizedColumns)}`, + ].join('\n'); +}; + +export const resolveAITableSchemaToolResult = async ({ + tableName, + fetchDDL, + fetchColumns, +}: ResolveAITableSchemaToolResultParams): Promise<{ success: boolean; content: string }> => { + const ddlResult = await fetchDDL(); + if (ddlResult?.success) { + return { success: true, content: stringifyToolData(ddlResult.data) }; + } + + const ddlError = ddlResult?.message || 'Failed to fetch DDL'; + const columnResult = await fetchColumns(); + if (columnResult?.success && Array.isArray(columnResult.data)) { + return { success: true, content: buildColumnFallbackContent(tableName, ddlError, columnResult.data) }; + } + + const columnError = columnResult?.message || 'Failed to fetch columns'; + return { success: false, content: `获取建表语句失败:${ddlError};降级获取字段列表也失败:${columnError}` }; +}; diff --git a/frontend/src/utils/sql.ts b/frontend/src/utils/sql.ts index 0864683..c7b251a 100644 --- a/frontend/src/utils/sql.ts +++ b/frontend/src/utils/sql.ts @@ -192,7 +192,8 @@ export const buildPaginatedSelectSQL = ( } switch (normalizedType) { - case 'oracle': { + case 'oracle': + case 'dameng': { const orderedSql = `${base}${orderBy}`; const upperBound = safeOffset + safeLimit; if (safeOffset <= 0) {