From f51dbcfb2ce1a48be6339299f5b90c987fa6047d Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 29 Apr 2026 09:41:25 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=F0=9F=90=9B=20fix(oracle):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E6=9F=A5=E8=AF=A2=E7=BB=93=E6=9E=9C=E7=BC=96=E8=BE=91?= =?UTF-8?q?=E6=8F=90=E4=BA=A4=E5=90=8E=E6=95=B0=E6=8D=AE=E8=BF=98=E5=8E=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Oracle GetColumns 未返回主键列标记,前端 pkColumns 为空后退化为 全列 WHERE 条件,Oracle 空字符串即 NULL 语义导致 UPDATE 匹配 0 行。 LEFT JOIN all_constraints + all_cons_columns 检测主键列并赋值 Key="PRI", 与达梦驱动实现方式一致。 --- internal/db/oracle_impl.go | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 9efb1b6..6f8947f 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -263,16 +263,31 @@ func (o *OracleDB) GetCreateStatement(dbName, tableName string) (string, error) } func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { - query := fmt.Sprintf(`SELECT column_name, data_type, nullable, data_default - FROM all_tab_columns - WHERE owner = '%s' AND table_name = '%s' - ORDER BY column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName)) + query := fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default, + CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key + FROM all_tab_columns c + LEFT JOIN ( + SELECT cols.owner, cols.table_name, cols.column_name + FROM all_constraints cons + JOIN all_cons_columns cols + ON cons.owner = cols.owner AND cons.constraint_name = cols.constraint_name + WHERE cons.constraint_type = 'P' + ) pk ON c.owner = pk.owner AND c.table_name = pk.table_name AND c.column_name = pk.column_name + WHERE c.owner = '%s' AND c.table_name = '%s' + ORDER BY c.column_id`, strings.ToUpper(dbName), strings.ToUpper(tableName)) if dbName == "" { - query = fmt.Sprintf(`SELECT column_name, data_type, nullable, data_default - FROM user_tab_columns - WHERE table_name = '%s' - ORDER BY column_id`, strings.ToUpper(tableName)) + query = fmt.Sprintf(`SELECT c.column_name, c.data_type, c.nullable, c.data_default, + CASE WHEN pk.column_name IS NOT NULL THEN 'PRI' ELSE '' END AS column_key + FROM user_tab_columns c + LEFT JOIN ( + SELECT cols.table_name, cols.column_name + FROM user_constraints cons + JOIN user_cons_columns cols USING (constraint_name) + WHERE cons.constraint_type = 'P' + ) pk ON c.table_name = pk.table_name AND c.column_name = pk.column_name + WHERE c.table_name = '%s' + ORDER BY c.column_id`, strings.ToUpper(tableName)) } data, _, err := o.Query(query) @@ -286,6 +301,7 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi Name: fmt.Sprintf("%v", row["COLUMN_NAME"]), Type: fmt.Sprintf("%v", row["DATA_TYPE"]), Nullable: fmt.Sprintf("%v", row["NULLABLE"]), + Key: fmt.Sprintf("%v", row["COLUMN_KEY"]), } if row["DATA_DEFAULT"] != nil { From 05a913ccb20d557fd9bcf1c4ade46f2a7c191621 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 29 Apr 2026 10:29:19 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=F0=9F=90=9B=20fix(query-editor):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=A4=9A=E6=95=B0=E6=8D=AE=E6=BA=90=E5=A4=A7?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E9=99=90=E6=B5=81=E5=A4=B1=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SQL限流:抽取查询自动限流工具,修复 SELECT 判断大小写不一致导致限制未生效 - 方言适配:按 Oracle/Dameng、SQL Server、MySQL/PostgreSQL 等方言分别注入行数限制 - 自定义驱动:支持 custom 连接根据 driver 解析 Oracle、PostgreSQL、SQL Server 等方言 - MongoDB修复:修正 db.collection.find() 解析边界,并对 find/只读 aggregate 下推 limit - Oracle优化:DSN 增加 PREFETCH_ROWS 和 LOB FETCH 参数,减少大结果集拉取开销 - 测试覆盖:补充 SQL 方言矩阵、MongoDB 限流和 Oracle DSN 参数测试 Refs #424 --- frontend/src/components/QueryEditor.tsx | 370 +--------------------- frontend/src/utils/mongodb.test.ts | 105 +++++- frontend/src/utils/mongodb.ts | 43 ++- frontend/src/utils/queryAutoLimit.test.ts | 110 +++++++ frontend/src/utils/queryAutoLimit.ts | 336 ++++++++++++++++++++ internal/db/oracle_dsn_test.go | 32 ++ internal/db/oracle_impl.go | 4 + 7 files changed, 640 insertions(+), 360 deletions(-) create mode 100644 frontend/src/utils/queryAutoLimit.test.ts create mode 100644 frontend/src/utils/queryAutoLimit.ts create mode 100644 internal/db/oracle_dsn_test.go diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index ac7a4f4..b50627c 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -9,11 +9,12 @@ import { useStore } from '../store'; import { DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID, WriteSQLFile } from '../../wailsjs/go/app/App'; import DataGrid, { GONAVI_ROW_KEY } from './DataGrid'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; -import { convertMongoShellToJsonCommand } from '../utils/mongodb'; +import { applyMongoQueryAutoLimit, convertMongoShellToJsonCommand } from '../utils/mongodb'; import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts'; import { useAutoFetchVisibility } from '../utils/autoFetchVisibility'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { resolveSqlDialect, resolveSqlFunctions, resolveSqlKeywords } from '../utils/sqlDialect'; +import { applyQueryAutoLimit } from '../utils/queryAutoLimit'; const SQL_KEYWORDS = [ 'SELECT', 'FROM', 'WHERE', 'LIMIT', 'INSERT', 'UPDATE', 'DELETE', 'JOIN', 'LEFT', 'RIGHT', @@ -1184,359 +1185,6 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc return statements; }; - const getLeadingKeyword = (sql: string): string => { - const text = (sql || '').replace(/\r\n/g, '\n'); - const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r'; - const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch); - - let inSingle = false; - let inDouble = false; - let inBacktick = false; - let escaped = false; - let inLineComment = false; - let inBlockComment = false; - let dollarTag: string | null = null; - - for (let i = 0; i < text.length; i++) { - const ch = text[i]; - const next = i + 1 < text.length ? text[i + 1] : ''; - const prev = i > 0 ? text[i - 1] : ''; - const next2 = i + 2 < text.length ? text[i + 2] : ''; - - if (!inSingle && !inDouble && !inBacktick) { - if (inLineComment) { - if (ch === '\n') inLineComment = false; - continue; - } - if (inBlockComment) { - if (ch === '*' && next === '/') { - i++; - inBlockComment = false; - } - continue; - } - - if (ch === '/' && next === '*') { - i++; - inBlockComment = true; - continue; - } - if (ch === '#') { - inLineComment = true; - continue; - } - if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) { - i++; - inLineComment = true; - continue; - } - - if (dollarTag) { - if (text.startsWith(dollarTag, i)) { - i += dollarTag.length - 1; - dollarTag = null; - } - continue; - } - if (ch === '$') { - const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/); - if (m && m[0]) { - dollarTag = m[0]; - i += dollarTag.length - 1; - continue; - } - } - } - - if (escaped) { - escaped = false; - continue; - } - if ((inSingle || inDouble) && ch === '\\') { - escaped = true; - continue; - } - - if (!inDouble && !inBacktick && ch === '\'') { - inSingle = !inSingle; - continue; - } - if (!inSingle && !inBacktick && ch === '"') { - inDouble = !inDouble; - continue; - } - if (!inSingle && !inDouble && ch === '`') { - inBacktick = !inBacktick; - continue; - } - - if (inSingle || inDouble || inBacktick || dollarTag) continue; - if (isWS(ch)) continue; - - if (isWord(ch)) { - let j = i; - while (j < text.length && isWord(text[j])) j++; - return text.slice(i, j).toLowerCase(); - } - return ''; - } - return ''; - }; - - const splitSqlTail = (sql: string): { main: string; tail: string } => { - const text = (sql || '').replace(/\r\n/g, '\n'); - const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r'; - - let inSingle = false; - let inDouble = false; - let inBacktick = false; - let escaped = false; - let inLineComment = false; - let inBlockComment = false; - let dollarTag: string | null = null; - let lastMeaningful = -1; - - for (let i = 0; i < text.length; i++) { - const ch = text[i]; - const next = i + 1 < text.length ? text[i + 1] : ''; - const prev = i > 0 ? text[i - 1] : ''; - const next2 = i + 2 < text.length ? text[i + 2] : ''; - - if (!inSingle && !inDouble && !inBacktick) { - if (dollarTag) { - if (text.startsWith(dollarTag, i)) { - lastMeaningful = i + dollarTag.length - 1; - i += dollarTag.length - 1; - dollarTag = null; - } else if (!isWS(ch)) { - lastMeaningful = i; - } - continue; - } - if (inLineComment) { - if (ch === '\n') inLineComment = false; - continue; - } - if (inBlockComment) { - if (ch === '*' && next === '/') { - i++; - inBlockComment = false; - } - continue; - } - - // Start comments - if (ch === '/' && next === '*') { - i++; - inBlockComment = true; - continue; - } - if (ch === '#') { - inLineComment = true; - continue; - } - if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) { - i++; - inLineComment = true; - continue; - } - - if (ch === '$') { - const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/); - if (m && m[0]) { - dollarTag = m[0]; - lastMeaningful = i + dollarTag.length - 1; - i += dollarTag.length - 1; - continue; - } - } - } - - if (escaped) { - escaped = false; - } else if ((inSingle || inDouble) && ch === '\\') { - escaped = true; - } else { - if (!inDouble && !inBacktick && ch === '\'') inSingle = !inSingle; - else if (!inSingle && !inBacktick && ch === '"') inDouble = !inDouble; - else if (!inSingle && !inDouble && ch === '`') inBacktick = !inBacktick; - } - - if (!inLineComment && !inBlockComment && !isWS(ch)) { - lastMeaningful = i; - } - } - - if (lastMeaningful < 0) return { main: '', tail: text }; - return { main: text.slice(0, lastMeaningful + 1), tail: text.slice(lastMeaningful + 1) }; - }; - - const findTopLevelKeyword = (sql: string, keyword: string): number => { - const text = sql; - const kw = keyword.toLowerCase(); - const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r'; - const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch); - - let inSingle = false; - let inDouble = false; - let inBacktick = false; - let escaped = false; - let inLineComment = false; - let inBlockComment = false; - let dollarTag: string | null = null; - let parenDepth = 0; - - for (let i = 0; i < text.length; i++) { - const ch = text[i]; - const next = i + 1 < text.length ? text[i + 1] : ''; - const prev = i > 0 ? text[i - 1] : ''; - const next2 = i + 2 < text.length ? text[i + 2] : ''; - - if (!inSingle && !inDouble && !inBacktick) { - if (inLineComment) { - if (ch === '\n') inLineComment = false; - continue; - } - if (inBlockComment) { - if (ch === '*' && next === '/') { - i++; - inBlockComment = false; - } - continue; - } - - if (ch === '/' && next === '*') { - i++; - inBlockComment = true; - continue; - } - if (ch === '#') { - inLineComment = true; - continue; - } - if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) { - i++; - inLineComment = true; - continue; - } - - if (dollarTag) { - if (text.startsWith(dollarTag, i)) { - i += dollarTag.length - 1; - dollarTag = null; - } - continue; - } - if (ch === '$') { - const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/); - if (m && m[0]) { - dollarTag = m[0]; - i += dollarTag.length - 1; - continue; - } - } - } - - if (escaped) { - escaped = false; - continue; - } - if ((inSingle || inDouble) && ch === '\\') { - escaped = true; - continue; - } - - if (!inDouble && !inBacktick && ch === '\'') { - inSingle = !inSingle; - continue; - } - if (!inSingle && !inBacktick && ch === '"') { - inDouble = !inDouble; - continue; - } - if (!inSingle && !inDouble && ch === '`') { - inBacktick = !inBacktick; - continue; - } - - if (inSingle || inDouble || inBacktick || dollarTag) continue; - - if (ch === '(') { parenDepth++; continue; } - if (ch === ')') { if (parenDepth > 0) parenDepth--; continue; } - if (parenDepth !== 0) continue; - - if (!isWord(ch)) continue; - - if (text.slice(i, i + kw.length).toLowerCase() !== kw) continue; - const before = i - 1 >= 0 ? text[i - 1] : ''; - const after = i + kw.length < text.length ? text[i + kw.length] : ''; - if ((before && isWord(before)) || (after && isWord(after))) continue; - return i; - } - return -1; - }; - - const applyAutoLimit = (sql: string, dbType: string, maxRows: number): { sql: string; applied: boolean; maxRows: number } => { - if (!Number.isFinite(maxRows) || maxRows <= 0) return { sql, applied: false, maxRows }; - const normalizedType = (dbType || 'mysql').toLowerCase(); - - // 只对 SELECT 语句自动加限制 - const keyword = getLeadingKeyword(sql); - if (keyword !== 'SELECT') return { sql, applied: false, maxRows }; - - const { main, tail } = splitSqlTail(sql); - if (!main.trim()) return { sql, applied: false, maxRows }; - - const fromPos = findTopLevelKeyword(main, 'from'); - const limitPos = findTopLevelKeyword(main, 'limit'); - // 已有 LIMIT → 不注入 - if (limitPos >= 0 && (fromPos < 0 || limitPos > fromPos)) return { sql, applied: false, maxRows }; - const fetchPos = findTopLevelKeyword(main, 'fetch'); - // 已有 FETCH → 不注入 - if (fetchPos >= 0 && (fromPos < 0 || fetchPos > fromPos)) return { sql, applied: false, maxRows }; - - // SQL Server / mssql: 检查是否已有 TOP,未有则注入 SELECT TOP N - if (normalizedType === 'sqlserver' || normalizedType === 'mssql') { - const topPos = findTopLevelKeyword(main, 'top'); - if (topPos >= 0) return { sql, applied: false, maxRows }; // 已有 TOP - // 在 SELECT 关键字之后插入 TOP N - const selectPos = findTopLevelKeyword(main, 'select'); - if (selectPos < 0) return { sql, applied: false, maxRows }; - const afterSelect = selectPos + 'SELECT'.length; - // 处理 SELECT DISTINCT 的情况 - const restAfterSelect = main.slice(afterSelect); - const distinctMatch = restAfterSelect.match(/^(\s+DISTINCT\b)/i); - const insertOffset = distinctMatch ? afterSelect + distinctMatch[1].length : afterSelect; - const nextMain = main.slice(0, insertOffset) + ` TOP ${maxRows}` + main.slice(insertOffset); - return { sql: nextMain + tail, applied: true, maxRows }; - } - - // Oracle / Dameng: 使用 FETCH FIRST N ROWS ONLY(Oracle 12c+ 标准语法) - if (normalizedType === 'oracle' || normalizedType === 'dameng') { - // 检查是否已有 ROWNUM 限制 - const rownumPos = findTopLevelKeyword(main, 'rownum'); - if (rownumPos >= 0) return { sql, applied: false, maxRows }; - const offsetPos = findTopLevelKeyword(main, 'offset'); - if (offsetPos >= 0 && (fromPos < 0 || offsetPos > fromPos)) return { sql, applied: false, maxRows }; - const nextMain = main.trimEnd() + ` FETCH FIRST ${maxRows} ROWS ONLY`; - return { sql: nextMain + tail, applied: true, maxRows }; - } - - // 通用 LIMIT 语法(MySQL, PostgreSQL, SQLite, ClickHouse, DuckDB 等) - const offsetPos = findTopLevelKeyword(main, 'offset'); - const forPos = findTopLevelKeyword(main, 'for'); - const lockPos = findTopLevelKeyword(main, 'lock'); - - const candidates = [offsetPos, forPos, lockPos] - .filter(pos => pos >= 0 && (fromPos < 0 || pos > fromPos)); - - const insertAt = candidates.length > 0 ? Math.min(...candidates) : main.length; - const before = main.slice(0, insertAt).trimEnd(); - const after = main.slice(insertAt).trimStart(); - const nextMain = [before, `LIMIT ${maxRows}`, after].filter(Boolean).join(' ').trim(); - return { sql: nextMain + tail, applied: true, maxRows }; - }; - const getSelectedSQL = (): string => { const editor = editorRef.current; if (!editor) return ''; @@ -1662,8 +1310,10 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc try { const rawSQL = getSelectedSQL() || currentQuery; - const dbType = String((buildRpcConnectionConfig(config) as any).type || 'mysql'); - const normalizedDbType = dbType.trim().toLowerCase(); + const rpcConfig = buildRpcConnectionConfig(config) as any; + const dbType = String(rpcConfig.type || 'mysql'); + const driver = String((config as any).driver || ''); + const normalizedDbType = String(resolveSqlDialect(dbType, driver)).trim().toLowerCase(); const normalizedRawSQL = String(rawSQL || '').replace(/;/g, ';'); // MongoDB 仍走逐条执行的旧路径 @@ -1703,6 +1353,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc executedSql = shellConvert.command; } } + if (wantsLimitProbe) { + const limitResult = applyMongoQueryAutoLimit(executedSql, maxRows); + if (limitResult.applied) { + executedSql = limitResult.command; + } + } const startTime = Date.now(); let queryId: string; try { @@ -1797,7 +1453,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc if (Number.isFinite(maxRowsForLimit) && maxRowsForLimit > 0) { const stmts = splitSQLStatements(fullSQL); const limitedStmts = stmts.map(s => { - const result = applyAutoLimit(s, normalizedDbType, maxRowsForLimit); + const result = applyQueryAutoLimit(s, normalizedDbType, maxRowsForLimit, driver); if (result.applied) anyLimitApplied = true; return result.sql; }); diff --git a/frontend/src/utils/mongodb.test.ts b/frontend/src/utils/mongodb.test.ts index 2e8ef64..8d0f7ed 100644 --- a/frontend/src/utils/mongodb.test.ts +++ b/frontend/src/utils/mongodb.test.ts @@ -1,6 +1,8 @@ import { describe, expect, it } from 'vitest'; -import { convertMongoShellToJsonCommand } from './mongodb'; +import { applyMongoQueryAutoLimit, convertMongoShellToJsonCommand } from './mongodb'; + +const parseCommand = (command: string | undefined) => JSON.parse(command || '{}'); describe('convertMongoShellToJsonCommand', () => { it('converts show dbs shell shortcut to listDatabases command', () => { @@ -16,4 +18,105 @@ describe('convertMongoShellToJsonCommand', () => { command: JSON.stringify({ listCollections: 1, filter: {}, nameOnly: true }), }); }); + + it('converts find shell commands without adding implicit limit', () => { + const result = convertMongoShellToJsonCommand('db.users.find({ active: true })'); + + expect(result.recognized).toBe(true); + expect(parseCommand(result.command)).toEqual({ + find: 'users', + filter: { active: true }, + }); + }); + + it('keeps explicit find limit values from shell commands', () => { + const result = convertMongoShellToJsonCommand('db.users.find({}).limit(10)'); + + expect(parseCommand(result.command)).toEqual({ + find: 'users', + filter: {}, + limit: 10, + }); + }); + + it('keeps explicit zero limit values from shell commands', () => { + const result = convertMongoShellToJsonCommand('db.users.find({}).limit(0)'); + + expect(parseCommand(result.command)).toEqual({ + find: 'users', + filter: {}, + limit: 0, + }); + }); +}); + +describe('applyMongoQueryAutoLimit', () => { + it('adds limit to raw Mongo find commands', () => { + const result = applyMongoQueryAutoLimit('{"find":"users","filter":{}}', 500); + + expect(result.applied).toBe(true); + expect(parseCommand(result.command)).toEqual({ + find: 'users', + filter: {}, + limit: 500, + }); + }); + + it('adds limit after shell find conversion', () => { + const shell = convertMongoShellToJsonCommand('db.users.find({ active: true })'); + const result = applyMongoQueryAutoLimit(shell.command || '', 500); + + expect(result.applied).toBe(true); + expect(parseCommand(result.command)).toEqual({ + find: 'users', + filter: { active: true }, + limit: 500, + }); + }); + + it('does not replace explicit find limits', () => { + const result = applyMongoQueryAutoLimit('{"find":"users","filter":{},"limit":10}', 500); + + expect(result.applied).toBe(false); + expect(parseCommand(result.command)).toEqual({ + find: 'users', + filter: {}, + limit: 10, + }); + }); + + it('adds $limit to read-only aggregate pipelines', () => { + const result = applyMongoQueryAutoLimit('{"aggregate":"users","pipeline":[{"$match":{"active":true}}],"cursor":{}}', 500); + + expect(result.applied).toBe(true); + expect(parseCommand(result.command)).toEqual({ + aggregate: 'users', + pipeline: [ + { $match: { active: true } }, + { $limit: 500 }, + ], + cursor: {}, + }); + }); + + it('does not add another aggregate $limit', () => { + const command = '{"aggregate":"users","pipeline":[{"$limit":10}],"cursor":{}}'; + const result = applyMongoQueryAutoLimit(command, 500); + + expect(result.applied).toBe(false); + expect(result.command).toBe(command); + }); + + it('does not alter aggregate write pipelines', () => { + const command = '{"aggregate":"users","pipeline":[{"$match":{}},{"$out":"tmp_users"}],"cursor":{}}'; + const result = applyMongoQueryAutoLimit(command, 500); + + expect(result.applied).toBe(false); + expect(result.command).toBe(command); + }); + + it('does not limit non-read or invalid commands', () => { + expect(applyMongoQueryAutoLimit('{"count":"users","query":{}}', 500).applied).toBe(false); + expect(applyMongoQueryAutoLimit('db.users.find({})', 500).applied).toBe(false); + }); }); diff --git a/frontend/src/utils/mongodb.ts b/frontend/src/utils/mongodb.ts index d2b399d..9504a4b 100644 --- a/frontend/src/utils/mongodb.ts +++ b/frontend/src/utils/mongodb.ts @@ -321,7 +321,7 @@ const parseCollectionAndMethod = (raw: string): { pos = nextPos; } else { let end = pos; - while (end < input.length && /[A-Za-z0-9_$.-]/.test(input[end])) end++; + while (end < input.length && /[A-Za-z0-9_$-]/.test(input[end])) end++; collection = input.slice(pos, end).trim(); pos = end; } @@ -662,7 +662,7 @@ export const buildMongoFindCommand = (params: { if (params.sort && Object.keys(params.sort).length > 0) { command.sort = params.sort; } - if (Number.isFinite(params.limit) && Number(params.limit) > 0) { + if (Number.isFinite(params.limit) && Number(params.limit) >= 0) { command.limit = Math.floor(Number(params.limit)); } if (Number.isFinite(params.skip) && Number(params.skip) > 0) { @@ -678,6 +678,45 @@ export const buildMongoCountCommand = (collection: string, filter: Record, key: string) => Object.prototype.hasOwnProperty.call(obj, key); + +const isMongoCommandObject = (value: unknown): value is Record => ( + !!value && typeof value === 'object' && !Array.isArray(value) +); + +export const applyMongoQueryAutoLimit = ( + command: string, + maxRows: number, +): { command: string; applied: boolean; maxRows: number } => { + if (!Number.isFinite(maxRows) || maxRows <= 0) return { command, applied: false, maxRows }; + + let parsed: unknown; + try { + parsed = JSON.parse(String(command || '').trim()); + } catch { + return { command, applied: false, maxRows }; + } + if (!isMongoCommandObject(parsed)) return { command, applied: false, maxRows }; + + const nextMaxRows = Math.floor(Number(maxRows)); + if (hasOwn(parsed, 'find')) { + if (hasOwn(parsed, 'limit')) return { command, applied: false, maxRows }; + parsed.limit = nextMaxRows; + return { command: JSON.stringify(parsed), applied: true, maxRows }; + } + + if (hasOwn(parsed, 'aggregate') && Array.isArray(parsed.pipeline)) { + const pipeline = parsed.pipeline as unknown[]; + const hasExplicitLimit = pipeline.some((stage) => isMongoCommandObject(stage) && hasOwn(stage, '$limit')); + const hasWriteStage = pipeline.some((stage) => isMongoCommandObject(stage) && (hasOwn(stage, '$out') || hasOwn(stage, '$merge'))); + if (hasExplicitLimit || hasWriteStage) return { command, applied: false, maxRows }; + pipeline.push({ $limit: nextMaxRows }); + return { command: JSON.stringify(parsed), applied: true, maxRows }; + } + + return { command, applied: false, maxRows }; +}; + const buildMongoInsertCommand = ( collection: string, documents: Record[], diff --git a/frontend/src/utils/queryAutoLimit.test.ts b/frontend/src/utils/queryAutoLimit.test.ts new file mode 100644 index 0000000..32964a7 --- /dev/null +++ b/frontend/src/utils/queryAutoLimit.test.ts @@ -0,0 +1,110 @@ +import { describe, expect, it } from 'vitest'; + +import { applyQueryAutoLimit } from './queryAutoLimit'; + +describe('applyQueryAutoLimit', () => { + const limitDialects = [ + 'mysql', + 'mariadb', + 'diros', + 'doris', + 'sphinx', + 'postgres', + 'postgresql', + 'kingbase', + 'kingbase8', + 'highgo', + 'vastbase', + 'sqlite', + 'sqlite3', + 'duckdb', + 'clickhouse', + 'tdengine', + ]; + + it.each(limitDialects)('adds generic LIMIT for %s connections', (dbType) => { + expect(applyQueryAutoLimit('SELECT * FROM users', dbType, 500).sql) + .toBe('SELECT * FROM users LIMIT 500'); + }); + + it.each([ + ['oracle'], + ['dameng'], + ['dm'], + ['dm8'], + ])('adds FETCH FIRST limit for %s connections', (dbType) => { + expect(applyQueryAutoLimit('SELECT * FROM MYCIMLED.EDC_LOG', dbType, 500).sql) + .toBe('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY'); + }); + + it.each([ + ['sqlserver'], + ['mssql'], + ['sql_server'], + ['sql-server'], + ])('adds TOP limit for %s connections', (dbType) => { + expect(applyQueryAutoLimit('SELECT * FROM users', dbType, 500).sql) + .toBe('SELECT TOP 500 * FROM users'); + }); + + it('adds SQL Server TOP after DISTINCT', () => { + expect(applyQueryAutoLimit('SELECT DISTINCT name FROM users', 'sqlserver', 500).sql) + .toBe('SELECT DISTINCT TOP 500 name FROM users'); + }); + + it.each([ + ['oracle', 'SELECT * FROM users FETCH FIRST 500 ROWS ONLY'], + ['dm8', 'SELECT * FROM users FETCH FIRST 500 ROWS ONLY'], + ['mssql', 'SELECT TOP 500 * FROM users'], + ['postgresql', 'SELECT * FROM users LIMIT 500'], + ['doris', 'SELECT * FROM users LIMIT 500'], + ['sqlite3', 'SELECT * FROM users LIMIT 500'], + ])('uses custom driver dialect %s', (driver, expected) => { + expect(applyQueryAutoLimit('SELECT * FROM users', 'custom', 500, driver).sql) + .toBe(expected); + }); + + it('keeps trailing semicolon and comments after injected Oracle limit', () => { + expect(applyQueryAutoLimit('SELECT * FROM MYCIMLED.EDC_LOG; -- preview', 'oracle', 500).sql) + .toBe('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY; -- preview'); + }); + + it('does not add another generic limit when SQL already limits rows', () => { + expect(applyQueryAutoLimit('SELECT * FROM users LIMIT 10', 'mysql', 500).applied) + .toBe(false); + expect(applyQueryAutoLimit('SELECT * FROM users OFFSET 10 LIMIT 10', 'postgres', 500).applied) + .toBe(false); + }); + + it('does not treat nested LIMIT as the outer query limit', () => { + expect(applyQueryAutoLimit('SELECT * FROM (SELECT * FROM users LIMIT 10) t', 'postgres', 500).sql) + .toBe('SELECT * FROM (SELECT * FROM users LIMIT 10) t LIMIT 500'); + }); + + it('does not add another Oracle limit when Oracle SQL already limits rows', () => { + expect(applyQueryAutoLimit('SELECT * FROM users WHERE ROWNUM <= 10', 'oracle', 500).applied) + .toBe(false); + expect(applyQueryAutoLimit('SELECT * FROM users FETCH FIRST 10 ROWS ONLY', 'oracle', 500).applied) + .toBe(false); + }); + + it('does not add another SQL Server limit when SQL already uses TOP', () => { + expect(applyQueryAutoLimit('SELECT TOP 10 * FROM users', 'sqlserver', 500).applied) + .toBe(false); + }); + + it('adds generic LIMIT before locking clauses', () => { + expect(applyQueryAutoLimit('SELECT * FROM users FOR UPDATE', 'mysql', 500).sql) + .toBe('SELECT * FROM users LIMIT 500 FOR UPDATE'); + }); + + it('adds generic LIMIT before OFFSET clauses', () => { + expect(applyQueryAutoLimit('SELECT * FROM users OFFSET 10', 'postgres', 500).sql) + .toBe('SELECT * FROM users LIMIT 500 OFFSET 10'); + }); + + it('does not limit non-select statements', () => { + expect(applyQueryAutoLimit('UPDATE users SET name = \'a\'', 'mysql', 500).applied) + .toBe(false); + }); +}); diff --git a/frontend/src/utils/queryAutoLimit.ts b/frontend/src/utils/queryAutoLimit.ts new file mode 100644 index 0000000..8b8560c --- /dev/null +++ b/frontend/src/utils/queryAutoLimit.ts @@ -0,0 +1,336 @@ +import { resolveSqlDialect } from './sqlDialect'; + +const isWS = (ch: string) => ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r'; +const isWord = (ch: string) => /[A-Za-z0-9_]/.test(ch); + +const getLeadingKeyword = (sql: string): string => { + const text = (sql || '').replace(/\r\n/g, '\n'); + let inSingle = false; + let inDouble = false; + let inBacktick = false; + let escaped = false; + let inLineComment = false; + let inBlockComment = false; + let dollarTag: string | null = null; + + for (let i = 0; i < text.length; i++) { + const ch = text[i]; + const next = i + 1 < text.length ? text[i + 1] : ''; + const prev = i > 0 ? text[i - 1] : ''; + const next2 = i + 2 < text.length ? text[i + 2] : ''; + + if (!inSingle && !inDouble && !inBacktick) { + if (inLineComment) { + if (ch === '\n') inLineComment = false; + continue; + } + if (inBlockComment) { + if (ch === '*' && next === '/') { + i++; + inBlockComment = false; + } + continue; + } + if (ch === '/' && next === '*') { + i++; + inBlockComment = true; + continue; + } + if (ch === '#') { + inLineComment = true; + continue; + } + if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) { + i++; + inLineComment = true; + continue; + } + if (dollarTag) { + if (text.startsWith(dollarTag, i)) { + i += dollarTag.length - 1; + dollarTag = null; + } + continue; + } + if (ch === '$') { + const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/); + if (m && m[0]) { + dollarTag = m[0]; + i += dollarTag.length - 1; + continue; + } + } + } + + if (escaped) { + escaped = false; + continue; + } + if ((inSingle || inDouble) && ch === '\\') { + escaped = true; + continue; + } + if (!inDouble && !inBacktick && ch === "'") { + inSingle = !inSingle; + continue; + } + if (!inSingle && !inBacktick && ch === '"') { + inDouble = !inDouble; + continue; + } + if (!inSingle && !inDouble && ch === '`') { + inBacktick = !inBacktick; + continue; + } + if (inSingle || inDouble || inBacktick || dollarTag) continue; + if (isWS(ch)) continue; + if (isWord(ch)) { + let j = i; + while (j < text.length && isWord(text[j])) j++; + return text.slice(i, j).toLowerCase(); + } + return ''; + } + return ''; +}; + +const splitSqlTail = (sql: string): { main: string; tail: string } => { + const text = (sql || '').replace(/\r\n/g, '\n'); + let inSingle = false; + let inDouble = false; + let inBacktick = false; + let escaped = false; + let inLineComment = false; + let inBlockComment = false; + let dollarTag: string | null = null; + let lastMeaningful = -1; + + for (let i = 0; i < text.length; i++) { + const ch = text[i]; + const next = i + 1 < text.length ? text[i + 1] : ''; + const prev = i > 0 ? text[i - 1] : ''; + const next2 = i + 2 < text.length ? text[i + 2] : ''; + + if (!inSingle && !inDouble && !inBacktick) { + if (dollarTag) { + if (text.startsWith(dollarTag, i)) { + lastMeaningful = i + dollarTag.length - 1; + i += dollarTag.length - 1; + dollarTag = null; + } else if (!isWS(ch)) { + lastMeaningful = i; + } + continue; + } + if (inLineComment) { + if (ch === '\n') inLineComment = false; + continue; + } + if (inBlockComment) { + if (ch === '*' && next === '/') { + i++; + inBlockComment = false; + } + continue; + } + if (ch === '/' && next === '*') { + i++; + inBlockComment = true; + continue; + } + if (ch === '#') { + inLineComment = true; + continue; + } + if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) { + i++; + inLineComment = true; + continue; + } + if (ch === '$') { + const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/); + if (m && m[0]) { + dollarTag = m[0]; + lastMeaningful = i + dollarTag.length - 1; + i += dollarTag.length - 1; + continue; + } + } + } + + if (escaped) { + escaped = false; + } else if ((inSingle || inDouble) && ch === '\\') { + escaped = true; + } else { + if (!inDouble && !inBacktick && ch === "'") inSingle = !inSingle; + else if (!inSingle && !inBacktick && ch === '"') inDouble = !inDouble; + else if (!inSingle && !inDouble && ch === '`') inBacktick = !inBacktick; + } + + if (!inLineComment && !inBlockComment && !isWS(ch)) { + lastMeaningful = i; + } + } + + if (lastMeaningful < 0) return { main: '', tail: text }; + let mainEnd = lastMeaningful + 1; + while (mainEnd > 0 && (isWS(text[mainEnd - 1]) || text[mainEnd - 1] === ';' || text[mainEnd - 1] === ';')) { + mainEnd--; + } + return { main: text.slice(0, mainEnd), tail: text.slice(mainEnd) }; +}; + +const findTopLevelKeyword = (sql: string, keyword: string): number => { + const text = sql; + const kw = keyword.toLowerCase(); + let inSingle = false; + let inDouble = false; + let inBacktick = false; + let escaped = false; + let inLineComment = false; + let inBlockComment = false; + let dollarTag: string | null = null; + let parenDepth = 0; + + for (let i = 0; i < text.length; i++) { + const ch = text[i]; + const next = i + 1 < text.length ? text[i + 1] : ''; + const prev = i > 0 ? text[i - 1] : ''; + const next2 = i + 2 < text.length ? text[i + 2] : ''; + + if (!inSingle && !inDouble && !inBacktick) { + if (inLineComment) { + if (ch === '\n') inLineComment = false; + continue; + } + if (inBlockComment) { + if (ch === '*' && next === '/') { + i++; + inBlockComment = false; + } + continue; + } + if (ch === '/' && next === '*') { + i++; + inBlockComment = true; + continue; + } + if (ch === '#') { + inLineComment = true; + continue; + } + if (ch === '-' && next === '-' && (i === 0 || isWS(prev)) && (next2 === '' || isWS(next2))) { + i++; + inLineComment = true; + continue; + } + if (dollarTag) { + if (text.startsWith(dollarTag, i)) { + i += dollarTag.length - 1; + dollarTag = null; + } + continue; + } + if (ch === '$') { + const m = text.slice(i).match(/^\$[A-Za-z0-9_]*\$/); + if (m && m[0]) { + dollarTag = m[0]; + i += dollarTag.length - 1; + continue; + } + } + } + + if (escaped) { + escaped = false; + continue; + } + if ((inSingle || inDouble) && ch === '\\') { + escaped = true; + continue; + } + if (!inDouble && !inBacktick && ch === "'") { + inSingle = !inSingle; + continue; + } + if (!inSingle && !inBacktick && ch === '"') { + inDouble = !inDouble; + continue; + } + if (!inSingle && !inDouble && ch === '`') { + inBacktick = !inBacktick; + continue; + } + if (inSingle || inDouble || inBacktick || dollarTag) continue; + if (ch === '(') { + parenDepth++; + continue; + } + if (ch === ')') { + if (parenDepth > 0) parenDepth--; + continue; + } + if (parenDepth !== 0) continue; + if (!isWord(ch)) continue; + if (text.slice(i, i + kw.length).toLowerCase() !== kw) continue; + const before = i - 1 >= 0 ? text[i - 1] : ''; + const after = i + kw.length < text.length ? text[i + kw.length] : ''; + if ((before && isWord(before)) || (after && isWord(after))) continue; + return i; + } + return -1; +}; + +export const applyQueryAutoLimit = ( + sql: string, + dbType: string, + maxRows: number, + driver = '', +): { sql: string; applied: boolean; maxRows: number } => { + if (!Number.isFinite(maxRows) || maxRows <= 0) return { sql, applied: false, maxRows }; + const normalizedType = String(resolveSqlDialect(dbType || 'mysql', driver)).toLowerCase(); + const keyword = getLeadingKeyword(sql); + if (keyword !== 'select') return { sql, applied: false, maxRows }; + + const { main, tail } = splitSqlTail(sql); + if (!main.trim()) return { sql, applied: false, maxRows }; + + const fromPos = findTopLevelKeyword(main, 'from'); + const limitPos = findTopLevelKeyword(main, 'limit'); + if (limitPos >= 0 && (fromPos < 0 || limitPos > fromPos)) return { sql, applied: false, maxRows }; + const fetchPos = findTopLevelKeyword(main, 'fetch'); + if (fetchPos >= 0 && (fromPos < 0 || fetchPos > fromPos)) return { sql, applied: false, maxRows }; + + if (normalizedType === 'sqlserver' || normalizedType === 'mssql') { + const topPos = findTopLevelKeyword(main, 'top'); + if (topPos >= 0) return { sql, applied: false, maxRows }; + const selectPos = findTopLevelKeyword(main, 'select'); + if (selectPos < 0) return { sql, applied: false, maxRows }; + const afterSelect = selectPos + 'SELECT'.length; + const restAfterSelect = main.slice(afterSelect); + const distinctMatch = restAfterSelect.match(/^(\s+DISTINCT\b)/i); + const insertOffset = distinctMatch ? afterSelect + distinctMatch[1].length : afterSelect; + const nextMain = main.slice(0, insertOffset) + ` TOP ${maxRows}` + main.slice(insertOffset); + return { sql: nextMain + tail, applied: true, maxRows }; + } + + if (normalizedType === 'oracle' || normalizedType === 'dameng') { + const rownumPos = findTopLevelKeyword(main, 'rownum'); + if (rownumPos >= 0) return { sql, applied: false, maxRows }; + const offsetPos = findTopLevelKeyword(main, 'offset'); + if (offsetPos >= 0 && (fromPos < 0 || offsetPos > fromPos)) return { sql, applied: false, maxRows }; + return { sql: `${main.trimEnd()} FETCH FIRST ${maxRows} ROWS ONLY${tail}`, applied: true, maxRows }; + } + + const offsetPos = findTopLevelKeyword(main, 'offset'); + const forPos = findTopLevelKeyword(main, 'for'); + const lockPos = findTopLevelKeyword(main, 'lock'); + const candidates = [offsetPos, forPos, lockPos] + .filter(pos => pos >= 0 && (fromPos < 0 || pos > fromPos)); + const insertAt = candidates.length > 0 ? Math.min(...candidates) : main.length; + const before = main.slice(0, insertAt).trimEnd(); + const after = main.slice(insertAt).trimStart(); + const nextMain = [before, `LIMIT ${maxRows}`, after].filter(Boolean).join(' ').trim(); + return { sql: nextMain + tail, applied: true, maxRows }; +}; diff --git a/internal/db/oracle_dsn_test.go b/internal/db/oracle_dsn_test.go new file mode 100644 index 0000000..c4557b6 --- /dev/null +++ b/internal/db/oracle_dsn_test.go @@ -0,0 +1,32 @@ +package db + +import ( + "net/url" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func TestOracleGetDSNIncludesQueryPerformanceOptions(t *testing.T) { + t.Parallel() + + dsn := (&OracleDB{}).getDSN(connection.ConnectionConfig{ + Host: "db.example.com", + Port: 1521, + User: "scott", + Password: "tiger", + Database: "ORCLPDB1", + }) + + parsed, err := url.Parse(dsn) + if err != nil { + t.Fatalf("解析 Oracle DSN 失败: %v", err) + } + query := parsed.Query() + if got := query.Get("PREFETCH_ROWS"); got != "10000" { + t.Fatalf("PREFETCH_ROWS = %q, want 10000", got) + } + if got := query.Get("LOB FETCH"); got != "POST" { + t.Fatalf("LOB FETCH = %q, want POST", got) + } +} diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 6f8947f..7d9c3ec 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -44,6 +44,10 @@ func (o *OracleDB) getDSN(config connection.ConnectionConfig) string { q.Set("SSL", "TRUE") q.Set("SSL VERIFY", "FALSE") } + // 提高 prefetch 行数,减少大结果集的网络往返次数(默认仅 25 行/次) + q.Set("PREFETCH_ROWS", "10000") + // LOB 数据延迟加载,避免大 LOB 列影响普通查询性能 + q.Set("LOB FETCH", "POST") if encoded := q.Encode(); encoded != "" { u.RawQuery = encoded } From b1ef52f62efa49a098db5b99cfa8e00c03858187 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 29 Apr 2026 12:33:35 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E2=9C=A8=20feat(data-grid):=20=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=97=A0=E4=B8=BB=E9=94=AE=E8=A1=A8=E5=AE=89=E5=85=A8?= =?UTF-8?q?=E7=BC=96=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 定位策略:新增主键、唯一索引和 Oracle ROWID 三类安全行定位能力 - 查询编辑器:简单单表 SELECT 自动补充隐藏定位列,复杂结果保持只读 - 表预览:无主键表可通过唯一索引或 Oracle ROWID 安全编辑 - 提交流程:移除无主键整行 WHERE fallback,隐藏定位列不参与展示和写入 - 后端保护:Oracle、MySQL、PostgreSQL 更新删除必须恰好影响 1 行 - 测试覆盖:补充 QueryEditor、DataViewer、DataGrid 和 ApplyChanges 相关用例 Refs #419 --- frontend/src/components/DataGrid.ddl.test.tsx | 112 ++++- .../src/components/DataGrid.layout.test.tsx | 1 + frontend/src/components/DataGrid.tsx | 274 +++++++++---- .../DataViewer.primary-key.test.tsx | 199 +++++++++ frontend/src/components/DataViewer.tsx | 166 ++++++-- .../QueryEditor.external-sql-save.test.tsx | 173 +++++++- frontend/src/components/QueryEditor.tsx | 388 +++++++++++++++--- frontend/src/utils/queryResultTable.test.ts | 44 ++ frontend/src/utils/queryResultTable.ts | 64 +++ frontend/src/utils/rowLocator.test.ts | 146 +++++++ frontend/src/utils/rowLocator.ts | 133 ++++++ frontend/wailsjs/go/models.ts | 2 + internal/connection/types.go | 7 +- internal/db/database.go | 15 + internal/db/mysql_impl.go | 8 +- internal/db/oracle_applychanges_test.go | 179 +++++++- internal/db/oracle_impl.go | 122 ++++-- internal/db/postgres_impl.go | 18 +- 18 files changed, 1823 insertions(+), 228 deletions(-) create mode 100644 frontend/src/components/DataViewer.primary-key.test.tsx create mode 100644 frontend/src/utils/queryResultTable.test.ts create mode 100644 frontend/src/utils/queryResultTable.ts create mode 100644 frontend/src/utils/rowLocator.test.ts create mode 100644 frontend/src/utils/rowLocator.ts diff --git a/frontend/src/components/DataGrid.ddl.test.tsx b/frontend/src/components/DataGrid.ddl.test.tsx index c1f727b..0ca01fe 100644 --- a/frontend/src/components/DataGrid.ddl.test.tsx +++ b/frontend/src/components/DataGrid.ddl.test.tsx @@ -2,7 +2,8 @@ import React from 'react'; import { act, create, type ReactTestRenderer } from 'react-test-renderer'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import DataGrid from './DataGrid'; +import DataGrid, { buildDataGridCommitChangeSet, GONAVI_ROW_KEY } from './DataGrid'; +import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator'; const storeState = vi.hoisted(() => ({ connections: [ @@ -216,6 +217,115 @@ const waitForEffects = async () => { }); }; +const normalizeValue = (_columnName: string, value: any) => value; +const rowKeyToString = (key: any) => String(key); + +const commitColumnGuard = (columnName: string) => ( + columnName !== GONAVI_ROW_KEY && columnName !== ORACLE_ROWID_LOCATOR_COLUMN +); + +describe('DataGrid commit change set', () => { + it('uses unique locator values instead of falling back to the whole row', () => { + const result = buildDataGridCommitChangeSet({ + addedRows: [], + modifiedRows: { + 'row-1': { [GONAVI_ROW_KEY]: 'row-1', EMAIL: 'a@example.com', NAME: 'new-name', AGE: 42 }, + }, + deletedRowKeys: new Set(), + data: [{ [GONAVI_ROW_KEY]: 'row-1', EMAIL: 'a@example.com', NAME: 'old-name', AGE: 42 }], + editLocator: { + strategy: 'unique-key', + columns: ['EMAIL'], + valueColumns: ['EMAIL'], + readOnly: false, + }, + visibleColumnNames: ['EMAIL', 'NAME', 'AGE'], + rowKeyToString, + normalizeCommitCellValue: normalizeValue, + shouldCommitColumn: commitColumnGuard, + }); + + expect(result).toEqual({ + ok: true, + changes: { + inserts: [], + updates: [{ keys: { EMAIL: 'a@example.com' }, values: { NAME: 'new-name' } }], + deletes: [], + }, + }); + }); + + it('uses hidden Oracle ROWID only as locator and excludes it from update values', () => { + const result = buildDataGridCommitChangeSet({ + addedRows: [], + modifiedRows: { + 'row-1': { [GONAVI_ROW_KEY]: 'row-1', NAME: 'new-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'BBBB' }, + }, + deletedRowKeys: new Set(), + data: [{ [GONAVI_ROW_KEY]: 'row-1', NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }], + editLocator: { + strategy: 'oracle-rowid', + columns: ['ROWID'], + valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + readOnly: false, + }, + visibleColumnNames: ['NAME'], + rowKeyToString, + normalizeCommitCellValue: normalizeValue, + shouldCommitColumn: commitColumnGuard, + }); + + expect(result).toEqual({ + ok: true, + changes: { + inserts: [], + updates: [{ keys: { ROWID: 'AAAA' }, values: { NAME: 'new-name' } }], + deletes: [], + }, + }); + }); + + it('fails closed when no safe locator is available', () => { + const result = buildDataGridCommitChangeSet({ + addedRows: [], + modifiedRows: { + 'row-1': { [GONAVI_ROW_KEY]: 'row-1', NAME: 'new-name' }, + }, + deletedRowKeys: new Set(), + data: [{ [GONAVI_ROW_KEY]: 'row-1', NAME: 'old-name' }], + editLocator: undefined, + visibleColumnNames: ['NAME'], + rowKeyToString, + normalizeCommitCellValue: normalizeValue, + shouldCommitColumn: commitColumnGuard, + }); + + expect(result).toEqual({ ok: false, error: '当前结果没有可用的安全行定位方式,无法提交修改。' }); + }); + + it('rejects delete rows when unique locator value is null', () => { + const result = buildDataGridCommitChangeSet({ + addedRows: [], + modifiedRows: {}, + deletedRowKeys: new Set(['row-1']), + data: [{ [GONAVI_ROW_KEY]: 'row-1', EMAIL: null, NAME: 'old-name' }], + editLocator: { + strategy: 'unique-key', + columns: ['EMAIL'], + valueColumns: ['EMAIL'], + readOnly: false, + }, + visibleColumnNames: ['EMAIL', 'NAME'], + rowKeyToString, + normalizeCommitCellValue: normalizeValue, + shouldCommitColumn: commitColumnGuard, + }); + + expect(result).toEqual({ ok: false, error: '定位列 EMAIL 的值为空,无法安全提交修改。' }); + }); +}); + describe('DataGrid DDL interactions', () => { beforeEach(() => { backendApp.DBGetColumns.mockResolvedValue({ success: true, data: [] }); diff --git a/frontend/src/components/DataGrid.layout.test.tsx b/frontend/src/components/DataGrid.layout.test.tsx index abdf1b9..f9797d7 100644 --- a/frontend/src/components/DataGrid.layout.test.tsx +++ b/frontend/src/components/DataGrid.layout.test.tsx @@ -159,6 +159,7 @@ describe('DataGrid layout', () => { columnNames={['id', 'name']} loading={false} tableName="users" + pkColumns={['id']} />, ); diff --git a/frontend/src/components/DataGrid.tsx b/frontend/src/components/DataGrid.tsx index 6cd48ff..70f4b69 100644 --- a/frontend/src/components/DataGrid.tsx +++ b/frontend/src/components/DataGrid.tsx @@ -79,6 +79,12 @@ import { type DataGridFindMatch, type DataGridFindNavigationDirection, } from '../utils/dataGridFind'; +import { + filterHiddenLocatorColumns, + isHiddenLocatorColumn, + resolveRowLocatorValues, + type EditRowLocator, +} from '../utils/rowLocator'; // --- Error Boundary --- interface DataGridErrorBoundaryState { @@ -916,6 +922,7 @@ interface DataGridProps { dbName?: string; connectionId?: string; pkColumns?: string[]; + editLocator?: EditRowLocator; readOnly?: boolean; onReload?: () => void; onSort?: (field: string, order: string) => void; @@ -960,12 +967,110 @@ type ColumnMeta = { comment: string; }; +type NormalizeCommitCellValue = (columnName: string, value: any, mode: 'insert' | 'update') => any; + +type DataGridCommitChangeSet = { + inserts: any[]; + updates: any[]; + deletes: any[]; +}; + +export const buildDataGridCommitChangeSet = ({ + addedRows, + modifiedRows, + deletedRowKeys, + data, + editLocator, + visibleColumnNames, + rowKeyToString, + normalizeCommitCellValue, + shouldCommitColumn, +}: { + addedRows: any[]; + modifiedRows: Record; + deletedRowKeys: Set; + data: any[]; + editLocator?: EditRowLocator; + visibleColumnNames: string[]; + rowKeyToString: (key: any) => string; + normalizeCommitCellValue: NormalizeCommitCellValue; + shouldCommitColumn: (columnName: string) => boolean; +}): { ok: true; changes: DataGridCommitChangeSet } | { ok: false; error: string } => { + if (!editLocator || editLocator.readOnly || editLocator.strategy === 'none') { + return { ok: false, error: editLocator?.reason || '当前结果没有可用的安全行定位方式,无法提交修改。' }; + } + + const normalizeValues = (values: Record, mode: 'insert' | 'update') => { + const normalizedValues: Record = {}; + Object.entries(values).forEach(([col, val]) => { + if (!shouldCommitColumn(col)) return; + const normalizedVal = normalizeCommitCellValue(col, val, mode); + if (normalizedVal !== undefined) { + normalizedValues[col] = normalizedVal; + } + }); + return normalizedValues; + }; + + const originalRowsByKey = new Map(); + data.forEach((row) => { + const key = row?.[GONAVI_ROW_KEY]; + if (key === undefined || key === null) return; + originalRowsByKey.set(rowKeyToString(key), row); + }); + + const inserts: any[] = []; + const updates: any[] = []; + const deletes: any[] = []; + + addedRows.forEach(row => { + const key = row?.[GONAVI_ROW_KEY]; + if (key !== undefined && key !== null && deletedRowKeys.has(rowKeyToString(key))) return; + inserts.push(normalizeValues(row, 'insert')); + }); + + for (const keyStr of deletedRowKeys) { + const originalRow = originalRowsByKey.get(keyStr); + if (!originalRow) continue; + const locatorValues = resolveRowLocatorValues(editLocator, originalRow); + if (!locatorValues.ok) return { ok: false, error: locatorValues.error }; + deletes.push(locatorValues.values); + } + + for (const [keyStr, newRow] of Object.entries(modifiedRows)) { + if (deletedRowKeys.has(keyStr)) continue; + const originalRow = originalRowsByKey.get(keyStr); + if (!originalRow) continue; + + const locatorValues = resolveRowLocatorValues(editLocator, originalRow); + if (!locatorValues.ok) return { ok: false, error: locatorValues.error }; + + const hasRowKey = Object.prototype.hasOwnProperty.call(newRow as any, GONAVI_ROW_KEY); + let values: Record = {}; + if (!hasRowKey) { + values = { ...(newRow as any) }; + } else { + visibleColumnNames.forEach((col) => { + const nextVal = (newRow as any)?.[col]; + const prevVal = (originalRow as any)?.[col]; + if (!isCellValueEqualForDiff(prevVal, nextVal)) values[col] = nextVal; + }); + } + + const normalizedValues = normalizeValues(values, 'update'); + if (Object.keys(normalizedValues).length === 0) continue; + updates.push({ keys: locatorValues.values, values: normalizedValues }); + } + + return { ok: true, changes: { inserts, updates, deletes } }; +}; + // P2 性能优化:提取内联 style 对象为模块级常量,避免每次 render 创建新对象 const CELL_ELLIPSIS_STYLE: React.CSSProperties = { overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }; const VIRTUAL_CELL_WRAPPER_STYLE: React.CSSProperties = { margin: -8, padding: '8px 8px 8px 8px' }; const DataGrid: React.FC = ({ - data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], readOnly = false, + data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], editLocator, readOnly = false, onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, exportSqlWithFilter, onApplyFilter, appliedFilterConditions, quickWhereCondition, onApplyQuickWhereCondition, scrollSnapshot, onScrollSnapshotChange @@ -999,7 +1104,25 @@ const DataGrid: React.FC = ({ darkMode, visible: showDataTableVerticalBorders, }); - const canModifyData = !readOnly && !!tableName; + const effectiveEditLocator = useMemo(() => { + if (editLocator) return editLocator; + if (pkColumns.length === 0) return undefined; + return { + strategy: 'primary-key', + columns: pkColumns, + valueColumns: pkColumns, + readOnly: false, + }; + }, [editLocator, pkColumns]); + const visibleColumnNames = useMemo( + () => filterHiddenLocatorColumns(columnNames, effectiveEditLocator), + [columnNames, effectiveEditLocator] + ); + const shouldCommitColumn = useCallback((columnName: string): boolean => { + const normalized = String(columnName || '').trim(); + return normalized !== GONAVI_ROW_KEY && !isHiddenLocatorColumn(normalized, effectiveEditLocator); + }, [effectiveEditLocator]); + const canModifyData = !readOnly && !!tableName && !!effectiveEditLocator && !effectiveEditLocator.readOnly && effectiveEditLocator.strategy !== 'none'; const showColumnComment = queryOptions?.showColumnComment ?? true; const showColumnType = queryOptions?.showColumnType ?? true; @@ -1053,7 +1176,7 @@ const DataGrid: React.FC = ({ // Sync display order from incoming prop and store memory useEffect(() => { - let nextOrder = [...columnNames]; + let nextOrder = [...visibleColumnNames]; if (enableColumnOrderMemory && connectionId && dbName && tableName) { const storedOrder = tableColumnOrders[`${connectionId}-${dbName}-${tableName}`]; if (Array.isArray(storedOrder) && storedOrder.length > 0) { @@ -1066,7 +1189,7 @@ const DataGrid: React.FC = ({ } } setAllOrderedColumnNames(nextOrder); - }, [columnNames, tableColumnOrders, enableColumnOrderMemory, connectionId, dbName, tableName]); + }, [visibleColumnNames, tableColumnOrders, enableColumnOrderMemory, connectionId, dbName, tableName]); // Compute final display columns useEffect(() => { @@ -1378,7 +1501,13 @@ const DataGrid: React.FC = ({ const exportData = async (rows: any[], format: string) => { const hide = message.loading(`正在导出 ${rows.length} 条数据...`, 0); try { - const cleanRows = rows.map(({ [GONAVI_ROW_KEY]: _rowKey, ...rest }) => rest); + const cleanRows = rows.map((row) => { + const next: Record = {}; + displayColumnNames.forEach((columnName) => { + next[columnName] = row?.[columnName]; + }); + return next; + }); // Pass tableName (or 'export') as default filename const res = await ExportData(cleanRows, displayColumnNames, tableName || 'export', format); if (res.success) { @@ -1538,10 +1667,10 @@ const DataGrid: React.FC = ({ return metaColumns; } if (exportScope === 'table') { - return columnNames.filter((columnName) => columnName !== GONAVI_ROW_KEY); + return visibleColumnNames.filter((columnName) => columnName !== GONAVI_ROW_KEY); } return []; - }, [columnMetaMap, exportScope, columnNames]); + }, [columnMetaMap, exportScope, visibleColumnNames]); const normalizeCommitCellValue = useCallback( (columnName: string, value: any, mode: 'insert' | 'update') => { @@ -3298,19 +3427,25 @@ const DataGrid: React.FC = ({ const jsonViewText = useMemo(() => { if (viewMode !== 'json') return ''; const cleanRows = mergedDisplayData.map((row) => { - const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = row || {}; - return normalizeValueForJsonView(rest); + const next: Record = {}; + visibleColumnNames.forEach((columnName) => { + next[columnName] = row?.[columnName]; + }); + return normalizeValueForJsonView(next); }); return JSON.stringify(cleanRows, null, 2); - }, [viewMode, mergedDisplayData]); + }, [viewMode, mergedDisplayData, visibleColumnNames]); const textViewRows = useMemo(() => { if (viewMode !== 'text') return []; return mergedDisplayData.map((row) => { - const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = row || {}; - return rest; + const next: Record = {}; + visibleColumnNames.forEach((columnName) => { + next[columnName] = row?.[columnName]; + }); + return next; }); - }, [viewMode, mergedDisplayData]); + }, [viewMode, mergedDisplayData, visibleColumnNames]); const currentTextRow = useMemo(() => { if (viewMode !== 'text') return null; @@ -3363,7 +3498,7 @@ const DataGrid: React.FC = ({ const formMap: Record = {}; const nullCols = new Set(); - columnNames.forEach((col) => { + visibleColumnNames.forEach((col) => { const baseVal = (baseRow as any)?.[col]; const displayVal = (displayRow as any)?.[col]; baseRawMap[col] = baseVal; @@ -3511,7 +3646,7 @@ const DataGrid: React.FC = ({ const keyStr = rowKeyStr(rowKey); const normalizedNext: Record = {}; let hasAnyVisibleChange = false; - columnNames.forEach((col) => { + visibleColumnNames.forEach((col) => { const currentVal = (currentRow as any)?.[col]; const editedVal = Object.prototype.hasOwnProperty.call(nextItem, col) ? (nextItem as any)[col] : currentVal; if (!isJsonViewValueEqual(currentVal, editedVal)) hasAnyVisibleChange = true; @@ -3530,7 +3665,7 @@ const DataGrid: React.FC = ({ const originalRow = originalMap.get(keyStr); if (!originalRow) continue; const patch: Record = {}; - columnNames.forEach((col) => { + visibleColumnNames.forEach((col) => { const prevVal = (originalRow as any)?.[col]; const nextVal = normalizedNext[col]; if (!isCellValueEqualForDiff(prevVal, nextVal)) patch[col] = nextVal; @@ -3595,7 +3730,7 @@ const DataGrid: React.FC = ({ const baseRawMap = rowEditorBaseRawRef.current || {}; const patch: Record = {}; - columnNames.forEach((col) => { + visibleColumnNames.forEach((col) => { let nextVal = values[col]; // 日期时间类型: 将 dayjs 对象转回格式化字符串 if (nextVal && dayjs.isDayjs(nextVal)) { @@ -3615,7 +3750,7 @@ const DataGrid: React.FC = ({ }); closeRowEditor(); - }, [rowEditorRowKey, rowEditorForm, addedRows, columnNames, rowKeyStr, closeRowEditor]); + }, [rowEditorRowKey, rowEditorForm, addedRows, visibleColumnNames, rowKeyStr, closeRowEditor]); const enableVirtual = viewMode === 'table'; @@ -3761,7 +3896,7 @@ const DataGrid: React.FC = ({ const handleAddRow = () => { const newKey = `new-${Date.now()}`; const newRow: any = { [GONAVI_ROW_KEY]: newKey }; - columnNames.forEach(col => newRow[col] = ''); + visibleColumnNames.forEach(col => newRow[col] = ''); pendingScrollToBottomRef.current = true; setAddedRows(prev => [...prev, newRow]); }; @@ -3775,7 +3910,7 @@ const DataGrid: React.FC = ({ const copiedRows = buildCopiedRowsForPaste({ rows: mergedDisplayData as Array>, selectedRowKeys, - columnNames, + columnNames: visibleColumnNames, rowKeyField: GONAVI_ROW_KEY, rowKeyToString: rowKeyStr, }); @@ -3786,7 +3921,7 @@ const DataGrid: React.FC = ({ setCopiedRowsForPaste(copiedRows); void message.success(`已复制 ${copiedRows.length} 行,可粘贴为新增行`); - }, [selectedRowKeys, mergedDisplayData, columnNames, rowKeyStr]); + }, [selectedRowKeys, mergedDisplayData, visibleColumnNames, rowKeyStr]); const handlePasteCopiedRowsAsNew = useCallback(() => { if (copiedRowsForPaste.length === 0) { @@ -3796,7 +3931,7 @@ const DataGrid: React.FC = ({ const nextRows = buildPastedRowsFromCopiedRows({ rows: copiedRowsForPaste, - columnNames, + columnNames: visibleColumnNames, rowKeyField: GONAVI_ROW_KEY, createRowKey: (index) => { pastedRowSequenceRef.current += 1; @@ -3812,7 +3947,7 @@ const DataGrid: React.FC = ({ setAddedRows(prev => [...prev, ...nextRows]); setSelectedRowKeys(nextRows.map(row => row[GONAVI_ROW_KEY])); void message.success(`已粘贴 ${nextRows.length} 行为新增行,请检查后提交事务`); - }, [copiedRowsForPaste, columnNames]); + }, [copiedRowsForPaste, visibleColumnNames]); const handleDeleteSelected = () => { setDeletedRowKeys(prev => { @@ -3827,66 +3962,23 @@ const DataGrid: React.FC = ({ if (!connectionId || !tableName) return; const conn = connections.find(c => c.id === connectionId); if (!conn) return; - - const inserts: any[] = []; - const updates: any[] = []; - const deletes: any[] = []; - - addedRows.forEach(row => { - const { [GONAVI_ROW_KEY]: _rowKey, ...vals } = row; - const normalizedValues: Record = {}; - Object.entries(vals).forEach(([col, val]) => { - const normalizedVal = normalizeCommitCellValue(col, val, 'insert'); - if (normalizedVal !== undefined) { - normalizedValues[col] = normalizedVal; - } - }); - inserts.push(normalizedValues); - }); - deletedRowKeys.forEach(keyStr => { - // Find original data - const originalRow = data.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr) || addedRows.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr); - if (originalRow) { - const pkData: any = {}; - if (pkColumns.length > 0) pkColumns.forEach(k => pkData[k] = originalRow[k]); - else { const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = originalRow; Object.assign(pkData, rest); } - deletes.push(pkData); - } - }); - Object.entries(modifiedRows).forEach(([keyStr, newRow]) => { - if (deletedRowKeys.has(keyStr)) return; - const originalRow = data.find(d => rowKeyStr(d?.[GONAVI_ROW_KEY]) === keyStr); - if (!originalRow) return; // Should not happen for modified rows unless deleted - - const pkData: any = {}; - if (pkColumns.length > 0) pkColumns.forEach(k => pkData[k] = originalRow[k]); - else { const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = originalRow; Object.assign(pkData, rest); } - - const hasRowKey = Object.prototype.hasOwnProperty.call(newRow as any, GONAVI_ROW_KEY); - let values: any = {}; - - if (!hasRowKey) { - values = { ...(newRow as any) }; - } else { - columnNames.forEach((col) => { - const nextVal = (newRow as any)?.[col]; - const prevVal = (originalRow as any)?.[col]; - if (!isCellValueEqualForDiff(prevVal, nextVal)) values[col] = nextVal; - }); - } - - const normalizedValues: Record = {}; - Object.entries(values).forEach(([col, val]) => { - const normalizedVal = normalizeCommitCellValue(col, val, 'update'); - if (normalizedVal !== undefined) { - normalizedValues[col] = normalizedVal; - } - }); - - if (Object.keys(normalizedValues).length === 0) return; - updates.push({ keys: pkData, values: normalizedValues }); + const changeSetResult = buildDataGridCommitChangeSet({ + addedRows, + modifiedRows, + deletedRowKeys, + data, + editLocator: effectiveEditLocator, + visibleColumnNames, + rowKeyToString: rowKeyStr, + normalizeCommitCellValue, + shouldCommitColumn, }); + if (!changeSetResult.ok) { + void message.error(changeSetResult.error); + return; + } + const { inserts, updates, deletes } = changeSetResult.changes; if (inserts.length === 0 && updates.length === 0 && deletes.length === 0) { void message.info("没有可提交的变更"); return; @@ -3902,7 +3994,7 @@ const DataGrid: React.FC = ({ }; const startTime = Date.now(); - const res = await ApplyChanges(buildRpcConnectionConfig(config) as any, dbName || '', tableName, { inserts, updates, deletes } as any); + const res = await ApplyChanges(buildRpcConnectionConfig(config) as any, dbName || '', tableName, { inserts, updates, deletes, locatorStrategy: effectiveEditLocator?.strategy } as any); const duration = Date.now() - startTime; // Construct a pseudo-SQL representation for the log @@ -4051,7 +4143,7 @@ const DataGrid: React.FC = ({ return null; } const records = getTargets(record); - const orderedCols = columnNames.filter(c => c !== GONAVI_ROW_KEY); + const orderedCols = visibleColumnNames.filter(c => c !== GONAVI_ROW_KEY); if (mode === 'insert') { return records.map((row: any) => buildCopyInsertSQL({ dbType, @@ -4100,7 +4192,7 @@ const DataGrid: React.FC = ({ }, [ supportsCopyInsert, getTargets, - columnNames, + visibleColumnNames, dbType, tableName, columnTypeMapByLowerName, @@ -4130,16 +4222,18 @@ const DataGrid: React.FC = ({ const handleCopyJson = useCallback((record: any) => { const records = getTargets(record); const cleanRecords = records.map((r: any) => { - const { [GONAVI_ROW_KEY]: _rowKey, ...rest } = r; - return rest; + const next: Record = {}; + visibleColumnNames.forEach((columnName) => { + next[columnName] = r?.[columnName]; + }); + return next; }); copyToClipboard(JSON.stringify(cleanRecords, null, 2)); - }, [getTargets, copyToClipboard]); + }, [getTargets, visibleColumnNames, copyToClipboard]); const handleCopyCsv = useCallback((record: any) => { const records = getTargets(record); - // 使用 columnNames 保持表定义的字段顺序 - const orderedCols = columnNames.filter(c => c !== GONAVI_ROW_KEY); + const orderedCols = visibleColumnNames.filter(c => c !== GONAVI_ROW_KEY); const header = orderedCols.map(c => `"${c}"`).join(','); const lines = records.map((r: any) => { const values = orderedCols.map(c => { @@ -4152,7 +4246,7 @@ const DataGrid: React.FC = ({ return values.join(','); }); copyToClipboard([header, ...lines].join('\n')); - }, [getTargets, columnNames, copyToClipboard]); + }, [getTargets, visibleColumnNames, copyToClipboard]); const buildConnConfig = useCallback(() => { if (!connectionId) return null; diff --git a/frontend/src/components/DataViewer.primary-key.test.tsx b/frontend/src/components/DataViewer.primary-key.test.tsx new file mode 100644 index 0000000..76f5b1d --- /dev/null +++ b/frontend/src/components/DataViewer.primary-key.test.tsx @@ -0,0 +1,199 @@ +import React from 'react'; +import { act, create, type ReactTestRenderer } from 'react-test-renderer'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { TabData } from '../types'; +import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator'; +import DataViewer from './DataViewer'; + +const storeState = vi.hoisted(() => ({ + connections: [ + { + id: 'conn-1', + name: 'oracle', + config: { + type: 'oracle', + host: '127.0.0.1', + port: 1521, + user: 'scott', + password: '', + database: 'ORCLPDB1', + }, + }, + ], + addSqlLog: vi.fn(), +})); + +const backendApp = vi.hoisted(() => ({ + DBQuery: vi.fn(), + DBGetColumns: vi.fn(), + DBGetIndexes: vi.fn(), +})); + +const messageApi = vi.hoisted(() => ({ + error: vi.fn(), + warning: vi.fn(), +})); + +const dataGridState = vi.hoisted(() => ({ + latestProps: null as any, +})); + +vi.mock('../store', () => { + const useStore = Object.assign( + (selector: (state: typeof storeState) => any) => selector(storeState), + { getState: () => storeState }, + ); + return { useStore }; +}); + +vi.mock('../../wailsjs/go/app/App', () => backendApp); + +vi.mock('antd', () => ({ + message: messageApi, +})); + +vi.mock('./DataGrid', () => ({ + default: (props: any) => { + dataGridState.latestProps = props; + return
; + }, + GONAVI_ROW_KEY: '__gonavi_row_key__', +})); + +const createTab = (overrides: Partial = {}): TabData => ({ + id: 'tab-1', + title: 'EDC_LOG', + type: 'table', + connectionId: 'conn-1', + dbName: 'MYCIMLED', + tableName: 'EDC_LOG', + ...overrides, +}); + +const flushPromises = async () => { + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); +}; + +describe('DataViewer safe editing locator', () => { + const renderAndReload = async (tab: TabData = createTab()) => { + let renderer: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await dataGridState.latestProps.onReload(); + }); + await flushPromises(); + return renderer!; + }; + + beforeEach(() => { + vi.clearAllMocks(); + dataGridState.latestProps = null; + storeState.connections[0].config.type = 'oracle'; + storeState.connections[0].config.database = 'ORCLPDB1'; + backendApp.DBQuery.mockResolvedValue({ + success: true, + fields: ['ID', 'NAME'], + data: [{ ID: 7, NAME: 'old-name' }], + }); + backendApp.DBGetIndexes.mockResolvedValue({ success: true, data: [] }); + }); + + it('enables table preview editing after primary keys are loaded', async () => { + backendApp.DBGetColumns.mockResolvedValue({ + success: true, + data: [{ name: 'ID', key: 'PRI' }, { name: 'NAME', key: '' }], + }); + + const renderer = await renderAndReload(); + + expect(dataGridState.latestProps?.pkColumns).toEqual(['ID']); + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'primary-key', + columns: ['ID'], + valueColumns: ['ID'], + readOnly: false, + }); + expect(dataGridState.latestProps?.readOnly).toBe(false); + expect(messageApi.warning).not.toHaveBeenCalled(); + renderer.unmount(); + }); + + it('uses a unique index when the table has no primary key', async () => { + backendApp.DBGetColumns.mockResolvedValue({ + success: true, + data: [{ name: 'EMAIL', key: '' }, { name: 'NAME', key: '' }], + }); + backendApp.DBGetIndexes.mockResolvedValue({ + success: true, + data: [{ name: 'UK_EMAIL', columnName: 'EMAIL', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' }], + }); + + const renderer = await renderAndReload(); + + expect(dataGridState.latestProps?.pkColumns).toEqual([]); + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'unique-key', + columns: ['EMAIL'], + valueColumns: ['EMAIL'], + readOnly: false, + }); + expect(dataGridState.latestProps?.readOnly).toBe(false); + expect(messageApi.warning).not.toHaveBeenCalled(); + renderer.unmount(); + }); + + it('uses hidden Oracle ROWID when no primary or unique key is available', async () => { + backendApp.DBGetColumns.mockResolvedValue({ + success: true, + data: [{ name: 'ID', key: '' }, { name: 'NAME', key: '' }], + }); + backendApp.DBQuery.mockResolvedValue({ + success: true, + fields: ['ID', 'NAME', ORACLE_ROWID_LOCATOR_COLUMN], + data: [{ ID: 7, NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }], + }); + + const renderer = await renderAndReload(); + + expect(dataGridState.latestProps?.pkColumns).toEqual([]); + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'oracle-rowid', + columns: ['ROWID'], + valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + readOnly: false, + }); + expect(dataGridState.latestProps?.readOnly).toBe(false); + expect(messageApi.warning).not.toHaveBeenCalled(); + expect(backendApp.DBQuery.mock.calls.some((call: any[]) => String(call[2]).includes(`ROWID AS "${ORACLE_ROWID_LOCATOR_COLUMN}"`))).toBe(true); + renderer.unmount(); + }); + + it('keeps non-Oracle table preview read-only when no safe locator exists', async () => { + storeState.connections[0].config.type = 'mysql'; + storeState.connections[0].config.database = 'main'; + backendApp.DBGetColumns.mockResolvedValue({ + success: true, + data: [{ name: 'ID', key: '' }, { name: 'NAME', key: '' }], + }); + + const renderer = await renderAndReload(createTab({ dbName: 'main', tableName: 'users', title: 'users' })); + + expect(dataGridState.latestProps?.pkColumns).toEqual([]); + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'none', + readOnly: true, + reason: '未检测到主键或可用唯一索引,无法安全提交修改。', + }); + expect(dataGridState.latestProps?.readOnly).toBe(true); + expect(messageApi.warning).toHaveBeenCalledWith('表 main.users 保持只读:未检测到主键或可用唯一索引,无法安全提交修改。'); + renderer.unmount(); + }); +}); diff --git a/frontend/src/components/DataViewer.tsx b/frontend/src/components/DataViewer.tsx index cb259c1..b417309 100644 --- a/frontend/src/components/DataViewer.tsx +++ b/frontend/src/components/DataViewer.tsx @@ -1,8 +1,8 @@ import React, { useEffect, useState, useCallback, useRef, useMemo } from 'react'; import { message } from 'antd'; -import { TabData, ColumnDefinition } from '../types'; +import { TabData, ColumnDefinition, IndexDefinition } from '../types'; import { useStore } from '../store'; -import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App'; +import { DBQuery, DBGetColumns, DBGetIndexes } from '../../wailsjs/go/app/App'; import DataGrid, { GONAVI_ROW_KEY } from './DataGrid'; import { buildOrderBySQL, buildPaginatedSelectSQL, buildWhereSQL, hasExplicitSort, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql'; import { buildMongoCountCommand, buildMongoFilter, buildMongoFindCommand, buildMongoSort } from '../utils/mongodb'; @@ -15,6 +15,12 @@ import { normalizeQuickWhereCondition, validateQuickWhereCondition, } from '../utils/dataGridWhereFilter'; +import { + ORACLE_ROWID_LOCATOR_COLUMN, + resolveEditRowLocator, + type EditRowLocator, +} from '../utils/rowLocator'; +import { isOracleLikeDialect } from '../utils/sqlDialect'; type ViewerPaginationState = { current: number; @@ -79,6 +85,47 @@ const parseTotalFromCountRow = (row: any): number | null => { return null; }; +const buildDataViewerReadOnlyLocator = (reason: string): EditRowLocator => ({ + strategy: 'none', + columns: [], + valueColumns: [], + readOnly: true, + reason, +}); + +const formatDataViewerTableName = (dbName: string, tableName: string): string => ( + dbName ? `${dbName}.${tableName}` : tableName +); + +const getTableColumnNames = (columns: ColumnDefinition[] | undefined): string[] => ( + (columns || []) + .map((column) => String(column?.name || '').trim()) + .filter(Boolean) +); + +const resolveDataViewerOrderFallbackColumns = (locator: EditRowLocator | undefined, pkColumns: string[]): string[] => { + if (locator && !locator.readOnly && locator.strategy !== 'oracle-rowid') { + return locator.valueColumns.length > 0 ? locator.valueColumns : locator.columns; + } + return pkColumns; +}; + +const buildDataViewerBaseSelectSQL = ( + dbType: string, + tableName: string, + whereSQL: string, + locator?: EditRowLocator, +): string => { + const quotedTableName = quoteQualifiedIdent(dbType, tableName); + if (locator?.strategy !== 'oracle-rowid') { + return `SELECT * FROM ${quotedTableName} ${whereSQL}`; + } + + const alias = 'gonavi_row_source'; + const rowIDAlias = quoteIdentPart(dbType, ORACLE_ROWID_LOCATOR_COLUMN); + return `SELECT ${alias}.*, ${alias}.ROWID AS ${rowIDAlias} FROM ${quotedTableName} ${alias} ${whereSQL}`; +}; + const normalizeDuckDBIdentifier = (raw: string): string => { const text = String(raw || '').trim(); if (text.length >= 2) { @@ -193,6 +240,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct const [data, setData] = useState([]); const [columnNames, setColumnNames] = useState([]); const [pkColumns, setPkColumns] = useState([]); + const [editLocator, setEditLocator] = useState(undefined); const [loading, setLoading] = useState(false); const connections = useStore(state => state.connections); const addSqlLog = useStore(state => state.addSqlLog); @@ -280,6 +328,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct useEffect(() => { const snapshot = getViewerFilterSnapshot(tab.id); setPkColumns([]); + setEditLocator(undefined); pkKeyRef.current = ''; countKeyRef.current = ''; duckdbApproxKeyRef.current = ''; @@ -435,10 +484,84 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct const whereSQL = isMongoDB ? JSON.stringify(mongoFilter || {}) : buildWhereSQL(dbType, effectiveFilterConditions); + + let pkColumnsForQuery = pkColumns; + let editLocatorForQuery = editLocator; + if (!isMongoDB && !forceReadOnly && tableName) { + const locatorKey = `${tab.connectionId}|${dbTypeLower}|${dbName}|${tableName}`; + if (pkKeyRef.current !== locatorKey || !editLocatorForQuery) { + pkKeyRef.current = locatorKey; + const locatorSeq = ++pkSeqRef.current; + try { + const [resCols, resIndexes] = await Promise.all([ + DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName), + DBGetIndexes(buildRpcConnectionConfig(config) as any, dbName, tableName) + .catch((error: any) => ({ success: false, message: String(error?.message || error || '加载索引失败'), data: [] })), + ]); + if (fetchSeqRef.current !== seq) return; + if (pkSeqRef.current !== locatorSeq) return; + if (pkKeyRef.current !== locatorKey) return; + + if (!resCols?.success || !Array.isArray(resCols.data)) { + const nextLocator = buildDataViewerReadOnlyLocator('无法加载主键/唯一索引元数据,无法安全提交修改。'); + pkColumnsForQuery = []; + editLocatorForQuery = nextLocator; + setPkColumns([]); + setEditLocator(nextLocator); + message.warning(`表 ${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason}`); + } else { + const columnDefs = resCols.data as ColumnDefinition[]; + const primaryKeys = columnDefs + .filter((column: any) => column?.key === 'PRI') + .map((column: any) => String(column?.name || '').trim()) + .filter(Boolean); + const indexes = resIndexes?.success && Array.isArray(resIndexes.data) + ? resIndexes.data as IndexDefinition[] + : []; + const resultColumns = getTableColumnNames(columnDefs); + const locatorColumns = isOracleLikeDialect(dbType) + ? [...resultColumns, ORACLE_ROWID_LOCATOR_COLUMN] + : resultColumns; + let nextLocator = resolveEditRowLocator({ + dbType, + resultColumns: locatorColumns, + primaryKeys, + indexes, + allowOracleRowID: true, + }); + + if (nextLocator.readOnly && primaryKeys.length === 0 && !resIndexes?.success && !isOracleLikeDialect(dbType)) { + nextLocator = buildDataViewerReadOnlyLocator('无法加载唯一索引元数据,无法安全提交修改。'); + } + + pkColumnsForQuery = primaryKeys; + editLocatorForQuery = nextLocator; + setPkColumns(primaryKeys); + setEditLocator(nextLocator); + if (nextLocator.readOnly) { + message.warning(`表 ${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason || '当前结果没有可用的安全行定位方式,无法提交修改。'}`); + } + } + } catch { + if (fetchSeqRef.current !== seq) return; + if (pkSeqRef.current !== locatorSeq) return; + if (pkKeyRef.current !== locatorKey) return; + const nextLocator = buildDataViewerReadOnlyLocator('无法加载主键/唯一索引元数据,无法安全提交修改。'); + pkColumnsForQuery = []; + editLocatorForQuery = nextLocator; + setPkColumns([]); + setEditLocator(nextLocator); + message.warning(`表 ${formatDataViewerTableName(dbName, tableName)} 保持只读:${nextLocator.reason}`); + } + } + } + const countSql = isMongoDB ? buildMongoCountCommand(tableName, mongoFilter || {}) : `SELECT COUNT(*) as total FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; - const orderBySQL = isMongoDB ? '' : buildOrderBySQL(dbType, sortInfo, pkColumns); + const orderBySQL = isMongoDB + ? '' + : buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocatorForQuery, pkColumnsForQuery)); const totalRows = Number(pagination.total); const hasFiniteTotal = Number.isFinite(totalRows) && totalRows >= 0; const totalKnown = pagination.totalKnown && hasFiniteTotal; @@ -469,7 +592,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct skip: offset, }); } else { - const baseSql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; + const baseSql = buildDataViewerBaseSelectSQL(dbType, tableName, whereSQL, editLocatorForQuery); sql = `${baseSql}${orderBySQL}`; // ClickHouse 深分页在超大 OFFSET 下容易超时。对于总数已知且存在 ORDER BY 的场景, // 当“尾部偏移”小于“头部偏移”时,改为反向 ORDER BY + 小 OFFSET,并在前端翻转结果。 @@ -557,7 +680,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct if (safeSelect) { let fallbackSql = `SELECT ${safeSelect} FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; - fallbackSql = buildPaginatedSelectSQL(dbType, fallbackSql, buildOrderBySQL(dbType, sortInfo, pkColumns), size + 1, offset); + fallbackSql = buildPaginatedSelectSQL(dbType, fallbackSql, buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocatorForQuery, pkColumnsForQuery)), size + 1, offset); executedSql = fallbackSql; resData = await executeDataQuery(fallbackSql, '复杂类型降级重试'); } @@ -580,26 +703,6 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct message.warning('已自动提升排序缓冲并重试成功。'); } } - - if (pkColumns.length === 0) { - const pkKey = `${tab.connectionId}|${dbName}|${tableName}`; - if (pkKeyRef.current !== pkKey) { - pkKeyRef.current = pkKey; - const pkSeq = ++pkSeqRef.current; - DBGetColumns(buildRpcConnectionConfig(config) as any, dbName, tableName) - .then((resCols: any) => { - if (pkSeqRef.current !== pkSeq) return; - if (pkKeyRef.current !== pkKey) return; - if (!resCols?.success) return; - const pks = (resCols.data as ColumnDefinition[]).filter((c: any) => c.key === 'PRI').map((c: any) => c.name); - setPkColumns(pks); - }) - .catch(() => { - if (pkSeqRef.current !== pkSeq) return; - if (pkKeyRef.current !== pkKey) return; - }); - } - } if (resData.success) { let resultData = resData.data as any[]; @@ -842,9 +945,9 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct }); } if (fetchSeqRef.current === seq) setLoading(false); - }, [connections, tab, sortInfo, filterConditions, quickWhereCondition, pkColumns, pagination.total, pagination.totalKnown, pagination.totalApprox, pagination.approximateTotal, preferManualTotalCount, supportsApproximateTableCount, supportsApproximateTotalPages]); - // 依赖 pkColumns:在无手动排序时可回退到主键稳定排序。 - // 主键信息只会在首次加载后更新一次,避免循环查询。 + }, [connections, tab, sortInfo, filterConditions, quickWhereCondition, pkColumns, editLocator, forceReadOnly, pagination.total, pagination.totalKnown, pagination.totalApprox, pagination.approximateTotal, preferManualTotalCount, supportsApproximateTableCount, supportsApproximateTotalPages]); + // 依赖定位列:在无手动排序时可回退到安全定位列稳定排序。 + // 定位信息只会在表上下文变化后重新加载,避免循环查询。 // Handlers memoized const handleReload = useCallback(() => { @@ -890,14 +993,14 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct if (!whereSQL) return ''; let sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} ${whereSQL}`; - sql += buildOrderBySQL(dbType, sortInfo, pkColumns); + sql += buildOrderBySQL(dbType, sortInfo, resolveDataViewerOrderFallbackColumns(editLocator, pkColumns)); const normalizedType = dbType.toLowerCase(); const hasSortForBuffer = hasExplicitSort(sortInfo); if (hasSortForBuffer && (normalizedType === 'mysql' || normalizedType === 'mariadb')) { sql = withSortBufferTuningSQL(normalizedType, sql, 32 * 1024 * 1024); } return sql; - }, [tab.tableName, currentConnConfig?.type, currentConnConfig?.driver, filterConditions, quickWhereCondition, sortInfo, pkColumns]); + }, [tab.tableName, currentConnConfig?.type, currentConnConfig?.driver, filterConditions, quickWhereCondition, sortInfo, editLocator, pkColumns]); useEffect(() => { const action = resolveDataViewerAutoFetchAction({ @@ -927,6 +1030,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct dbName={tab.dbName} connectionId={tab.connectionId} pkColumns={pkColumns} + editLocator={editLocator} onReload={handleReload} onSort={handleSort} onPageChange={handlePageChange} @@ -939,7 +1043,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAct appliedFilterConditions={filterConditions} quickWhereCondition={quickWhereCondition} onApplyQuickWhereCondition={handleApplyQuickWhereCondition} - readOnly={forceReadOnly} + readOnly={forceReadOnly || !editLocator || editLocator.readOnly} sortInfoExternal={sortInfo} exportSqlWithFilter={exportSqlWithFilter || undefined} scrollSnapshot={scrollSnapshotRef.current} diff --git a/frontend/src/components/QueryEditor.external-sql-save.test.tsx b/frontend/src/components/QueryEditor.external-sql-save.test.tsx index 711ea7c..6f64767 100644 --- a/frontend/src/components/QueryEditor.external-sql-save.test.tsx +++ b/frontend/src/components/QueryEditor.external-sql-save.test.tsx @@ -3,6 +3,7 @@ import { act, create, type ReactTestRenderer } from 'react-test-renderer'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import type { SavedQuery, TabData } from '../types'; +import { ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator'; import QueryEditor from './QueryEditor'; const storeState = vi.hoisted(() => ({ @@ -44,6 +45,7 @@ const backendApp = vi.hoisted(() => ({ DBGetAllColumns: vi.fn(), DBGetDatabases: vi.fn(), DBGetColumns: vi.fn(), + DBGetIndexes: vi.fn(), CancelQuery: vi.fn(), GenerateQueryID: vi.fn(), WriteSQLFile: vi.fn(), @@ -56,6 +58,10 @@ const messageApi = vi.hoisted(() => ({ warning: vi.fn(), })); +const dataGridState = vi.hoisted(() => ({ + latestProps: null as any, +})); + const editorState = vi.hoisted(() => { const state = { value: '', @@ -114,7 +120,10 @@ vi.mock('@monaco-editor/react', () => ({ })); vi.mock('./DataGrid', () => ({ - default: () => null, + default: (props: any) => { + dataGridState.latestProps = props; + return
; + }, GONAVI_ROW_KEY: '__gonavi_row_key__', })); @@ -152,7 +161,7 @@ vi.mock('antd', () => { Dropdown: ({ children }: any) => <>{children}, Tooltip: ({ children }: any) => <>{children}, Select: () => null, - Tabs: () => null, + Tabs: ({ items }: any) =>
{items?.[0]?.children}
, }; }); @@ -187,7 +196,15 @@ describe('QueryEditor external SQL save', () => { storeState.activeTabId = 'tab-1'; messageApi.success.mockReset(); messageApi.error.mockReset(); + messageApi.warning.mockReset(); backendApp.WriteSQLFile.mockResolvedValue({ success: true }); + backendApp.DBQueryMulti.mockResolvedValue({ success: true, data: [] }); + backendApp.DBGetColumns.mockResolvedValue({ success: true, data: [] }); + backendApp.DBGetIndexes.mockResolvedValue({ success: true, data: [] }); + backendApp.GenerateQueryID.mockResolvedValue('query-1'); + storeState.connections[0].config.type = 'mysql'; + storeState.connections[0].config.database = 'main'; + dataGridState.latestProps = null; editorState.value = ''; editorState.editor.getValue.mockClear(); editorState.editor.setValue.mockClear(); @@ -276,4 +293,156 @@ describe('QueryEditor external SQL save', () => { createdAt: 100, })); }); + + it('automatically appends hidden primary key locator columns for editable query results', async () => { + storeState.connections[0].config.type = 'oracle'; + storeState.connections[0].config.database = 'ORCLPDB1'; + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ columns: ['NAME', '__gonavi_locator_1_ID'], rows: [{ NAME: 'old-name', __gonavi_locator_1_ID: 7 }] }], + }); + backendApp.DBGetColumns.mockResolvedValueOnce({ + success: true, + data: [{ name: 'ID', key: 'PRI' }, { name: 'NAME', key: '' }], + }); + + let renderer: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(dataGridState.latestProps?.tableName).toBe('MYCIMLED.EDC_LOG'); + expect(dataGridState.latestProps?.pkColumns).toEqual(['ID']); + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'primary-key', + columns: ['ID'], + valueColumns: ['__gonavi_locator_1_ID'], + hiddenColumns: ['__gonavi_locator_1_ID'], + readOnly: false, + }); + expect(dataGridState.latestProps?.readOnly).toBe(false); + expect(dataGridState.latestProps?.resultSql).toBe('SELECT NAME FROM MYCIMLED.EDC_LOG'); + expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('"ID" AS "__gonavi_locator_1_ID"'); + expect(messageApi.warning).not.toHaveBeenCalled(); + }); + + it('uses a unique index locator for query results without primary keys', async () => { + storeState.connections[0].config.type = 'oracle'; + storeState.connections[0].config.database = 'ORCLPDB1'; + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ columns: ['NAME', '__gonavi_locator_1_EMAIL'], rows: [{ NAME: 'old-name', __gonavi_locator_1_EMAIL: 'a@example.com' }] }], + }); + backendApp.DBGetColumns.mockResolvedValueOnce({ + success: true, + data: [{ name: 'EMAIL', key: '' }, { name: 'NAME', key: '' }], + }); + backendApp.DBGetIndexes.mockResolvedValueOnce({ + success: true, + data: [{ name: 'UK_EMAIL', columnName: 'EMAIL', nonUnique: 0, seqInIndex: 1, indexType: 'BTREE' }], + }); + + let renderer: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'unique-key', + columns: ['EMAIL'], + valueColumns: ['__gonavi_locator_1_EMAIL'], + hiddenColumns: ['__gonavi_locator_1_EMAIL'], + readOnly: false, + }); + expect(dataGridState.latestProps?.readOnly).toBe(false); + expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('"EMAIL" AS "__gonavi_locator_1_EMAIL"'); + expect(messageApi.warning).not.toHaveBeenCalled(); + }); + + it('uses hidden Oracle ROWID for query results without primary or unique keys', async () => { + storeState.connections[0].config.type = 'oracle'; + storeState.connections[0].config.database = 'ORCLPDB1'; + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ columns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN], rows: [{ NAME: 'old-name', [ORACLE_ROWID_LOCATOR_COLUMN]: 'AAAA' }] }], + }); + backendApp.DBGetColumns.mockResolvedValueOnce({ + success: true, + data: [{ name: 'NAME', key: '' }], + }); + + let renderer: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'oracle-rowid', + columns: ['ROWID'], + valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + readOnly: false, + }); + expect(dataGridState.latestProps?.readOnly).toBe(false); + expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain(`ROWID AS "${ORACLE_ROWID_LOCATOR_COLUMN}"`); + expect(messageApi.warning).not.toHaveBeenCalled(); + }); + + it('keeps non-Oracle query results read-only when no safe locator exists', async () => { + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ columns: ['NAME'], rows: [{ NAME: 'old-name' }] }], + }); + backendApp.DBGetColumns.mockResolvedValueOnce({ + success: true, + data: [{ name: 'NAME', key: '' }], + }); + + let renderer: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(dataGridState.latestProps?.tableName).toBe('users'); + expect(dataGridState.latestProps?.pkColumns).toEqual([]); + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'none', + readOnly: true, + reason: '未检测到主键或可用唯一索引,无法安全提交修改。', + }); + expect(dataGridState.latestProps?.readOnly).toBe(true); + expect(messageApi.warning).toHaveBeenCalledWith('查询结果保持只读:main.users 未检测到主键或可用唯一索引,无法安全提交修改。'); + }); }); diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index b50627c..c3c7093 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -4,17 +4,21 @@ import { Button, message, Modal, Input, Form, Dropdown, MenuProps, Tooltip, Sele import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutlined, CloseOutlined, StopOutlined, RobotOutlined } from '@ant-design/icons'; import { format } from 'sql-formatter'; import { v4 as uuidv4 } from 'uuid'; -import { TabData, ColumnDefinition } from '../types'; +import { TabData, ColumnDefinition, IndexDefinition } from '../types'; import { useStore } from '../store'; -import { DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, CancelQuery, GenerateQueryID, WriteSQLFile } from '../../wailsjs/go/app/App'; +import { DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, DBGetIndexes, CancelQuery, GenerateQueryID, WriteSQLFile } from '../../wailsjs/go/app/App'; import DataGrid, { GONAVI_ROW_KEY } from './DataGrid'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; import { applyMongoQueryAutoLimit, convertMongoShellToJsonCommand } from '../utils/mongodb'; import { getShortcutDisplay, isEditableElement, isShortcutMatch } from '../utils/shortcuts'; import { useAutoFetchVisibility } from '../utils/autoFetchVisibility'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; -import { resolveSqlDialect, resolveSqlFunctions, resolveSqlKeywords } from '../utils/sqlDialect'; +import { isOracleLikeDialect, resolveSqlDialect, resolveSqlFunctions, resolveSqlKeywords } from '../utils/sqlDialect'; import { applyQueryAutoLimit } from '../utils/queryAutoLimit'; +import { extractQueryResultTableRef, type QueryResultTableRef } from '../utils/queryResultTable'; +import { quoteIdentPart } from '../utils/sql'; +import { resolveUniqueKeyGroupsFromIndexes } from './dataGridCopyInsert'; +import { ORACLE_ROWID_LOCATOR_COLUMN, type EditRowLocator } from '../utils/rowLocator'; const SQL_KEYWORDS = [ 'SELECT', 'FROM', 'WHERE', 'LIMIT', 'INSERT', 'UPDATE', 'DELETE', 'JOIN', 'LEFT', 'RIGHT', @@ -187,6 +191,290 @@ let sharedAllColumnsData: {dbName: string, tableName: string, name: string, type let sharedVisibleDbs: string[] = []; let sharedColumnsCacheData: Record = {}; +const QUERY_LOCATOR_ALIAS_PREFIX = '__gonavi_locator_'; + +const buildQueryReadOnlyLocator = (reason: string): EditRowLocator => ({ + strategy: 'none', + columns: [], + valueColumns: [], + readOnly: true, + reason, +}); + +type SimpleSelectInfo = { + selectsAll: boolean; + resultColumns: string[]; +}; + +type QueryStatementPlan = { + originalSql: string; + executedSql: string; + tableRef?: QueryResultTableRef; + pkColumns: string[]; + editLocator?: EditRowLocator; + warning?: string; +}; + +const stripQueryIdentifierQuotes = (part: string): string => { + const text = String(part || '').trim(); + if (!text) return ''; + if ((text.startsWith('`') && text.endsWith('`')) || (text.startsWith('"') && text.endsWith('"'))) { + return text.slice(1, -1).trim(); + } + if (text.startsWith('[') && text.endsWith(']')) { + return text.slice(1, -1).trim(); + } + return text; +}; + +const splitTopLevelComma = (text: string): string[] => { + const parts: string[] = []; + let current = ''; + let parenDepth = 0; + let inSingle = false; + let inDouble = false; + let inBacktick = false; + let escaped = false; + + for (let index = 0; index < text.length; index++) { + const ch = text[index]; + if (escaped) { + current += ch; + escaped = false; + continue; + } + if ((inSingle || inDouble) && ch === '\\') { + current += ch; + escaped = true; + continue; + } + if (!inDouble && !inBacktick && ch === "'") { + inSingle = !inSingle; + current += ch; + continue; + } + if (!inSingle && !inBacktick && ch === '"') { + inDouble = !inDouble; + current += ch; + continue; + } + if (!inSingle && !inDouble && ch === '`') { + inBacktick = !inBacktick; + current += ch; + continue; + } + if (!inSingle && !inDouble && !inBacktick) { + if (ch === '(') parenDepth++; + if (ch === ')' && parenDepth > 0) parenDepth--; + if (ch === ',' && parenDepth === 0) { + parts.push(current.trim()); + current = ''; + continue; + } + } + current += ch; + } + + if (current.trim()) parts.push(current.trim()); + return parts; +}; + +const SIMPLE_IDENTIFIER_PATH_RE = /^(?:[`"\[]?[A-Za-z_][\w$]*[`"\]]?\s*\.\s*){0,2}[`"\[]?[A-Za-z_][\w$]*[`"\]]?$/; +const QUERY_ALIAS_RESERVED = new Set([ + 'where', 'group', 'order', 'having', 'limit', 'fetch', 'offset', 'join', 'left', 'right', 'inner', 'outer', 'on', 'union', +]); + +const getLastIdentifierPart = (path: string): string => { + const parts = String(path || '').split('.').map((part) => stripQueryIdentifierQuotes(part.trim())).filter(Boolean); + return parts[parts.length - 1] || ''; +}; + +const resolveSimpleSelectItemColumn = (item: string): { name: string } | 'all' | undefined => { + const text = String(item || '').trim(); + if (!text) return undefined; + if (text === '*' || /\.\s*\*$/.test(text)) return 'all'; + + let expr = text; + let alias = ''; + const asMatch = text.match(/^(.*?)\s+AS\s+([`"\[]?[A-Za-z_][\w$]*[`"\]]?)$/i); + if (asMatch) { + expr = asMatch[1].trim(); + alias = stripQueryIdentifierQuotes(asMatch[2]); + } else { + const bareAliasMatch = text.match(/^(.*?)\s+([`"\[]?[A-Za-z_][\w$]*[`"\]]?)$/); + if (bareAliasMatch && SIMPLE_IDENTIFIER_PATH_RE.test(bareAliasMatch[1].trim())) { + const candidateAlias = stripQueryIdentifierQuotes(bareAliasMatch[2]); + if (candidateAlias && !QUERY_ALIAS_RESERVED.has(candidateAlias.toLowerCase())) { + expr = bareAliasMatch[1].trim(); + alias = candidateAlias; + } + } + } + + if (!SIMPLE_IDENTIFIER_PATH_RE.test(expr)) return undefined; + const name = alias || getLastIdentifierPart(expr); + return name ? { name } : undefined; +}; + +const parseSimpleSelectInfo = (sql: string): SimpleSelectInfo | undefined => { + const match = String(sql || '').match(/^\s*SELECT\s+([\s\S]+?)\s+FROM\s+/i); + if (!match) return undefined; + const selectList = match[1].trim(); + if (!selectList || /^DISTINCT\b/i.test(selectList)) return undefined; + + const resultColumns: string[] = []; + let selectsAll = false; + for (const item of splitTopLevelComma(selectList)) { + const resolved = resolveSimpleSelectItemColumn(item); + if (!resolved) return undefined; + if (resolved === 'all') { + selectsAll = true; + continue; + } + resultColumns.push(resolved.name); + } + return { selectsAll, resultColumns }; +}; + +const appendQuerySelectExpressions = (sql: string, expressions: string[]): string => { + if (expressions.length === 0) return sql; + return String(sql || '').replace( + /^(\s*SELECT\s+)([\s\S]+?)(\s+FROM\s+[\s\S]*)$/i, + (_match, prefix, selectList, rest) => `${prefix}${String(selectList).trimEnd()}, ${expressions.join(', ')}${rest}`, + ); +}; + +const findQueryResultColumn = (columns: string[], target: string): string | undefined => { + const normalizedTarget = String(target || '').trim().toLowerCase(); + return (columns || []).find((column) => String(column || '').trim().toLowerCase() === normalizedTarget); +}; + +const buildQueryLocatorAlias = (column: string, index: number): string => { + const normalized = String(column || '').trim().replace(/[^A-Za-z0-9_]/g, '_').slice(0, 48) || 'column'; + return `${QUERY_LOCATOR_ALIAS_PREFIX}${index}_${normalized}`; +}; + +const buildQueryLocatorColumnExpression = (dbType: string, column: string, alias: string): string => ( + `${quoteIdentPart(dbType, column)} AS ${quoteIdentPart(dbType, alias)}` +); + +const buildQueryRowIDExpression = (dbType: string): string => ( + `ROWID AS ${quoteIdentPart(dbType, ORACLE_ROWID_LOCATOR_COLUMN)}` +); + +const resolveQueryLocatorPlan = async ({ + statement, + dbType, + currentDb, + config, + forceReadOnly, +}: { + statement: string; + dbType: string; + currentDb: string; + config: any; + forceReadOnly: boolean; +}): Promise => { + const plan: QueryStatementPlan = { + originalSql: statement, + executedSql: statement, + pkColumns: [], + }; + if (forceReadOnly) return plan; + + const tableRef = extractQueryResultTableRef(statement, dbType, currentDb); + if (!tableRef) return plan; + plan.tableRef = tableRef; + + const selectInfo = parseSimpleSelectInfo(statement); + if (!selectInfo) { + const reason = '当前 SELECT 列表不是简单列或 *,无法安全提交修改。'; + plan.editLocator = buildQueryReadOnlyLocator(reason); + plan.warning = `查询结果保持只读:${reason}`; + return plan; + } + + try { + const [resCols, resIndexes] = await Promise.all([ + DBGetColumns(buildRpcConnectionConfig(config) as any, tableRef.metadataDbName, tableRef.metadataTableName), + DBGetIndexes(buildRpcConnectionConfig(config) as any, tableRef.metadataDbName, tableRef.metadataTableName) + .catch((error: any) => ({ success: false, message: String(error?.message || error || '加载索引失败'), data: [] })), + ]); + if (!resCols?.success || !Array.isArray(resCols.data)) { + const reason = `无法加载 ${tableRef.metadataDbName}.${tableRef.metadataTableName} 的主键/唯一索引元数据,无法安全提交修改。`; + plan.editLocator = buildQueryReadOnlyLocator(reason); + plan.warning = `查询结果保持只读:${reason}`; + return plan; + } + + const tableColumns = resCols.data as ColumnDefinition[]; + const primaryKeys = tableColumns + .filter((column: any) => column?.key === 'PRI') + .map((column: any) => String(column?.name || '').trim()) + .filter(Boolean); + const indexes = resIndexes?.success && Array.isArray(resIndexes.data) + ? resIndexes.data as IndexDefinition[] + : []; + const selectedColumns = selectInfo.selectsAll + ? tableColumns.map((column) => String(column?.name || '').trim()).filter(Boolean) + : selectInfo.resultColumns; + const appendExpressions: string[] = []; + const hiddenColumns: string[] = []; + + const buildColumnLocator = (strategy: 'primary-key' | 'unique-key', locatorColumns: string[]): EditRowLocator => { + const valueColumns = locatorColumns.map((column, index) => { + const selectedColumn = findQueryResultColumn(selectedColumns, column); + if (selectedColumn) return selectedColumn; + const alias = buildQueryLocatorAlias(column, index + 1); + appendExpressions.push(buildQueryLocatorColumnExpression(dbType, column, alias)); + hiddenColumns.push(alias); + return alias; + }); + return { + strategy, + columns: locatorColumns, + valueColumns, + hiddenColumns: hiddenColumns.length > 0 ? [...hiddenColumns] : undefined, + readOnly: false, + }; + }; + + if (primaryKeys.length > 0) { + plan.pkColumns = primaryKeys; + plan.editLocator = buildColumnLocator('primary-key', primaryKeys); + } else { + const uniqueKeyGroups = resolveUniqueKeyGroupsFromIndexes(indexes); + const uniqueKeyGroup = uniqueKeyGroups.find((group) => group.length > 0); + if (uniqueKeyGroup) { + plan.editLocator = buildColumnLocator('unique-key', uniqueKeyGroup); + } else if (isOracleLikeDialect(dbType)) { + appendExpressions.push(buildQueryRowIDExpression(dbType)); + plan.editLocator = { + strategy: 'oracle-rowid', + columns: ['ROWID'], + valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + readOnly: false, + }; + } else { + const reason = !resIndexes?.success + ? '无法加载唯一索引元数据,无法安全提交修改。' + : '未检测到主键或可用唯一索引,无法安全提交修改。'; + plan.editLocator = buildQueryReadOnlyLocator(reason); + plan.warning = `查询结果保持只读:${tableRef.metadataDbName}.${tableRef.metadataTableName} ${reason}`; + } + } + + plan.executedSql = appendQuerySelectExpressions(statement, appendExpressions); + return plan; + } catch { + const reason = `无法加载 ${tableRef.metadataDbName}.${tableRef.metadataTableName} 的主键/唯一索引元数据,无法安全提交修改。`; + plan.editLocator = buildQueryReadOnlyLocator(reason); + plan.warning = `查询结果保持只读:${reason}`; + return plan; + } +}; + const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isActive = true }) => { const [query, setQuery] = useState(tab.query || 'SELECT * FROM '); @@ -198,6 +486,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc columns: string[]; tableName?: string; pkColumns: string[]; + editLocator?: EditRowLocator; readOnly: boolean; truncated?: boolean; pkLoading?: boolean; @@ -1439,26 +1728,36 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } else { // 非 MongoDB:使用 DBQueryMulti 一次性执行多条 SQL,后端返回多结果集 - let fullSQL = normalizedRawSQL; - if (!fullSQL.trim()) { + const sourceStatements = splitSQLStatements(normalizedRawSQL); + if (sourceStatements.length === 0) { message.info('没有可执行的 SQL。'); setResultSets([]); setActiveResultKey(''); return; } + const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult; + const statementPlans: QueryStatementPlan[] = []; + for (const statement of sourceStatements) { + statementPlans.push(await resolveQueryLocatorPlan({ + statement, + dbType: normalizedDbType, + currentDb, + config, + forceReadOnly: forceReadOnlyResult, + })); + } + // 自动给 SELECT 语句注入行数限制(防止大结果集卡死) const maxRowsForLimit = Number(queryOptions?.maxRows) || 0; let anyLimitApplied = false; - if (Number.isFinite(maxRowsForLimit) && maxRowsForLimit > 0) { - const stmts = splitSQLStatements(fullSQL); - const limitedStmts = stmts.map(s => { - const result = applyQueryAutoLimit(s, normalizedDbType, maxRowsForLimit, driver); - if (result.applied) anyLimitApplied = true; - return result.sql; - }); - fullSQL = limitedStmts.join(';\n'); - } + const executablePlans = statementPlans.map((plan) => { + if (!Number.isFinite(maxRowsForLimit) || maxRowsForLimit <= 0) return plan; + const result = applyQueryAutoLimit(plan.executedSql, normalizedDbType, maxRowsForLimit, driver); + if (result.applied) anyLimitApplied = true; + return { ...plan, executedSql: result.sql }; + }); + const fullSQL = executablePlans.map((plan) => plan.executedSql).join(';\n'); const startTime = Date.now(); let queryId: string; @@ -1515,16 +1814,13 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : []; const nextResultSets: ResultSet[] = []; const maxRows = Number(queryOptions?.maxRows) || 0; - const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult; let anyTruncated = false; - const pendingPk: Array<{ resultKey: string; tableName: string }> = []; - - // 前端也拆分语句用于匹配原始 SQL(展示和表名检测) - const statements = splitSQLStatements(fullSQL); for (let idx = 0; idx < resultSetDataArray.length; idx++) { const rsData = resultSetDataArray[idx]; - const rawStatement = (idx < statements.length) ? statements[idx] : ''; + const plan = executablePlans[idx]; + const originalSql = plan?.originalSql || ''; + const executedSql = plan?.executedSql || originalSql; // 检查是否为 affectedRows 类结果集 const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1 @@ -1537,8 +1833,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc (row as any)[GONAVI_ROW_KEY] = 0; nextResultSets.push({ key: `result-${idx + 1}`, - sql: rawStatement, - exportSql: rawStatement, + sql: executedSql, + exportSql: originalSql, rows: [row], columns: ['affectedRows'], pkColumns: [], @@ -1561,32 +1857,18 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc if (row && typeof row === 'object') row[GONAVI_ROW_KEY] = i; }); - let simpleTableName: string | undefined = undefined; - if (rawStatement) { - // 支持多行 SQL:SELECT [cols] FROM [schema.]table [WHERE...] [ORDER BY...] [LIMIT...] 等 - // JOIN 查询表名歧义,不提取 - const hasJoin = /\bJOIN\b/i.test(rawStatement); - const tableMatch = !hasJoin - ? rawStatement.match(/^\s*SELECT\s+.+?\s+FROM\s+(?:[\w`"\[\].]+\.)?[`"\[]?(\w+)[`"\]]?\s*(?:$|[\s;])/im) - : null; - if (tableMatch) { - simpleTableName = tableMatch[1]; - if (!forceReadOnlyResult) { - pendingPk.push({ resultKey: `result-${idx + 1}`, tableName: simpleTableName }); - } - } - } - + const tableRef = plan?.tableRef; + const editLocator = plan?.editLocator; nextResultSets.push({ key: `result-${idx + 1}`, - sql: rawStatement, - exportSql: rawStatement, + sql: executedSql, + exportSql: originalSql, rows, columns: cols, - tableName: simpleTableName, - pkColumns: [], - readOnly: true, - pkLoading: !!simpleTableName, + tableName: tableRef?.tableName, + pkColumns: plan?.pkColumns || [], + editLocator, + readOnly: forceReadOnlyResult || !editLocator || editLocator.readOnly, truncated }); } @@ -1595,21 +1877,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc setResultSets(nextResultSets); setActiveResultKey(nextResultSets[0]?.key || ''); - pendingPk.forEach(({ resultKey, tableName }) => { - DBGetColumns(buildRpcConnectionConfig(config) as any, currentDb, tableName) - .then((resCols: any) => { - if (runSeqRef.current !== runSeq) return; - if (!resCols?.success) { - setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs)); - return; - } - const primaryKeys = (resCols.data as ColumnDefinition[]).filter(c => c.key === 'PRI').map(c => c.name); - setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkColumns: primaryKeys, pkLoading: false, readOnly: false } : rs)); - }) - .catch(() => { - if (runSeqRef.current !== runSeq) return; - setResultSets(prev => prev.map(rs => rs.key === resultKey ? { ...rs, pkLoading: false, readOnly: false } : rs)); - }); + executablePlans.forEach((plan) => { + if (plan.warning) message.warning(plan.warning); }); // 后端附带的提示信息(如数据源不支持原生多语句执行的回退提示) @@ -2142,6 +2411,7 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc dbName={currentDb} connectionId={currentConnectionId} pkColumns={rs.pkColumns} + editLocator={rs.editLocator} onReload={() => handleReloadResult(rs.key, rs.sql)} readOnly={rs.readOnly} /> diff --git a/frontend/src/utils/queryResultTable.test.ts b/frontend/src/utils/queryResultTable.test.ts new file mode 100644 index 0000000..d9059d3 --- /dev/null +++ b/frontend/src/utils/queryResultTable.test.ts @@ -0,0 +1,44 @@ +import { describe, expect, it } from 'vitest'; + +import { extractQueryResultTableRef } from './queryResultTable'; + +describe('extractQueryResultTableRef', () => { + it('preserves Oracle schema-qualified table names for editing', () => { + expect(extractQueryResultTableRef('SELECT * FROM MYCIMLED.EDC_LOG FETCH FIRST 500 ROWS ONLY', 'oracle', 'ANONYMOUS')) + .toEqual({ + tableName: 'MYCIMLED.EDC_LOG', + metadataDbName: 'MYCIMLED', + metadataTableName: 'EDC_LOG', + }); + }); + + it('uses current schema for unqualified Oracle tables', () => { + expect(extractQueryResultTableRef('SELECT * FROM EDC_LOG', 'oracle', 'MYCIMLED')) + .toEqual({ + tableName: 'EDC_LOG', + metadataDbName: 'MYCIMLED', + metadataTableName: 'EDC_LOG', + }); + }); + + it('keeps existing simple table behavior for MySQL-style qualified names', () => { + expect(extractQueryResultTableRef('SELECT * FROM app.users LIMIT 500', 'mysql', 'app')) + .toEqual({ + tableName: 'users', + metadataDbName: 'app', + metadataTableName: 'users', + }); + }); + + it('does not mark join results as editable table refs', () => { + expect(extractQueryResultTableRef('SELECT * FROM users u JOIN orders o ON u.id = o.user_id', 'oracle', 'APP')) + .toBeUndefined(); + }); + + it('does not mark grouped or distinct results as editable table refs', () => { + expect(extractQueryResultTableRef('SELECT ID FROM users GROUP BY ID', 'mysql', 'app')) + .toBeUndefined(); + expect(extractQueryResultTableRef('SELECT DISTINCT ID FROM users', 'mysql', 'app')) + .toBeUndefined(); + }); +}); diff --git a/frontend/src/utils/queryResultTable.ts b/frontend/src/utils/queryResultTable.ts new file mode 100644 index 0000000..60bd909 --- /dev/null +++ b/frontend/src/utils/queryResultTable.ts @@ -0,0 +1,64 @@ +export type QueryResultTableRef = { + tableName: string; + metadataDbName: string; + metadataTableName: string; +}; + +const stripIdentifierQuotes = (part: string): string => { + const text = String(part || '').trim(); + if (!text) return ''; + if ((text.startsWith('`') && text.endsWith('`')) || (text.startsWith('"') && text.endsWith('"'))) { + return text.slice(1, -1).trim(); + } + if (text.startsWith('[') && text.endsWith(']')) { + return text.slice(1, -1).trim(); + } + return text; +}; + +const normalizeQualifiedName = (raw: string): string => ( + String(raw || '') + .split('.') + .map((part) => stripIdentifierQuotes(part.trim())) + .filter(Boolean) + .join('.') +); + +const isOracleLikeDialect = (dialect: string): boolean => { + const normalized = String(dialect || '').trim().toLowerCase(); + return normalized === 'oracle' || normalized === 'dameng' || normalized === 'dm' || normalized === 'dm8'; +}; + +export const extractQueryResultTableRef = ( + sql: string, + dialect: string, + currentDb: string, +): QueryResultTableRef | undefined => { + const text = String(sql || '').trim(); + if (!text) return undefined; + if (/\b(JOIN|UNION|INTERSECT|EXCEPT|MINUS)\b/i.test(text)) return undefined; + if (/^\s*SELECT\s+DISTINCT\b/i.test(text)) return undefined; + if (/\bGROUP\s+BY\b|\bHAVING\b/i.test(text)) return undefined; + + const tableMatch = text.match(/^\s*SELECT\s+.+?\s+FROM\s+((?:[`"\[]?\w+[`"\]]?)(?:\s*\.\s*(?:[`"\[]?\w+[`"\]]?)){0,2})\s*(?:$|[\s;])/im); + if (!tableMatch) return undefined; + + const qualifiedName = normalizeQualifiedName(tableMatch[1]); + if (!qualifiedName) return undefined; + + const parts = qualifiedName.split('.').filter(Boolean); + const metadataTableName = parts[parts.length - 1] || ''; + if (!metadataTableName) return undefined; + + const owner = parts.length >= 2 ? parts[parts.length - 2] : ''; + const metadataDbName = owner || currentDb || ''; + const tableName = isOracleLikeDialect(dialect) && owner + ? `${owner}.${metadataTableName}` + : metadataTableName; + + return { + tableName, + metadataDbName, + metadataTableName, + }; +}; diff --git a/frontend/src/utils/rowLocator.test.ts b/frontend/src/utils/rowLocator.test.ts new file mode 100644 index 0000000..5c0916d --- /dev/null +++ b/frontend/src/utils/rowLocator.test.ts @@ -0,0 +1,146 @@ +import { describe, expect, it } from 'vitest'; + +import { + ORACLE_ROWID_LOCATOR_COLUMN, + filterHiddenLocatorColumns, + resolveEditRowLocator, + resolveRowLocatorValues, +} from './rowLocator'; + +const uniqueIndex = (name: string, columnName: string, seqInIndex = 1) => ({ + name, + columnName, + seqInIndex, + nonUnique: 0, + indexType: 'BTREE', +}); + +const normalIndex = (name: string, columnName: string, seqInIndex = 1) => ({ + name, + columnName, + seqInIndex, + nonUnique: 1, + indexType: 'BTREE', +}); + +describe('resolveEditRowLocator', () => { + it('prefers primary keys over unique indexes', () => { + expect(resolveEditRowLocator({ + dbType: 'mysql', + resultColumns: ['ID', 'EMAIL'], + primaryKeys: ['ID'], + indexes: [uniqueIndex('uk_email', 'EMAIL')], + })).toEqual({ + strategy: 'primary-key', + columns: ['ID'], + valueColumns: ['ID'], + readOnly: false, + }); + }); + + it('uses a unique index when there is no primary key', () => { + expect(resolveEditRowLocator({ + dbType: 'mysql', + resultColumns: ['EMAIL', 'NAME'], + indexes: [uniqueIndex('uk_email', 'EMAIL')], + })).toEqual({ + strategy: 'unique-key', + columns: ['EMAIL'], + valueColumns: ['EMAIL'], + readOnly: false, + }); + }); + + it('sorts composite unique index columns by sequence', () => { + expect(resolveEditRowLocator({ + dbType: 'postgres', + resultColumns: ['TENANT_ID', 'CODE', 'NAME'], + indexes: [ + uniqueIndex('uk_tenant_code', 'CODE', 2), + uniqueIndex('uk_tenant_code', 'TENANT_ID', 1), + ], + })).toMatchObject({ + strategy: 'unique-key', + columns: ['TENANT_ID', 'CODE'], + valueColumns: ['TENANT_ID', 'CODE'], + readOnly: false, + }); + }); + + it('ignores non-unique indexes', () => { + expect(resolveEditRowLocator({ + dbType: 'mysql', + resultColumns: ['NAME'], + indexes: [normalIndex('idx_name', 'NAME')], + })).toMatchObject({ + strategy: 'none', + readOnly: true, + }); + }); + + it('keeps results read-only when primary key columns are missing from result columns', () => { + expect(resolveEditRowLocator({ + dbType: 'oracle', + resultColumns: ['NAME'], + primaryKeys: ['ID'], + })).toMatchObject({ + strategy: 'none', + readOnly: true, + reason: '结果集中缺少主键列 ID,无法安全提交修改。', + }); + }); + + it('uses Oracle ROWID when no primary or unique key is available', () => { + expect(resolveEditRowLocator({ + dbType: 'oracle', + resultColumns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN], + allowOracleRowID: true, + })).toEqual({ + strategy: 'oracle-rowid', + columns: ['ROWID'], + valueColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + hiddenColumns: [ORACLE_ROWID_LOCATOR_COLUMN], + readOnly: false, + }); + }); +}); + +describe('resolveRowLocatorValues', () => { + it('extracts locator values from the original row', () => { + const locator = resolveEditRowLocator({ + dbType: 'mysql', + resultColumns: ['EMAIL', 'NAME'], + indexes: [uniqueIndex('uk_email', 'EMAIL')], + }); + + expect(resolveRowLocatorValues(locator, { EMAIL: 'a@example.com', NAME: 'A' })).toEqual({ + ok: true, + values: { EMAIL: 'a@example.com' }, + }); + }); + + it('rejects nullable unique locator values', () => { + const locator = resolveEditRowLocator({ + dbType: 'mysql', + resultColumns: ['EMAIL', 'NAME'], + indexes: [uniqueIndex('uk_email', 'EMAIL')], + }); + + expect(resolveRowLocatorValues(locator, { EMAIL: null, NAME: 'A' })).toEqual({ + ok: false, + error: '定位列 EMAIL 的值为空,无法安全提交修改。', + }); + }); +}); + +describe('filterHiddenLocatorColumns', () => { + it('removes hidden Oracle ROWID columns from displayed columns', () => { + const locator = resolveEditRowLocator({ + dbType: 'oracle', + resultColumns: ['NAME', ORACLE_ROWID_LOCATOR_COLUMN], + allowOracleRowID: true, + }); + + expect(filterHiddenLocatorColumns(['NAME', ORACLE_ROWID_LOCATOR_COLUMN], locator)).toEqual(['NAME']); + }); +}); diff --git a/frontend/src/utils/rowLocator.ts b/frontend/src/utils/rowLocator.ts new file mode 100644 index 0000000..f4be8a8 --- /dev/null +++ b/frontend/src/utils/rowLocator.ts @@ -0,0 +1,133 @@ +import type { IndexDefinition } from '../types'; +import { resolveUniqueKeyGroupsFromIndexes } from '../components/dataGridCopyInsert'; +import { isOracleLikeDialect } from './sqlDialect'; + +export const ORACLE_ROWID_LOCATOR_COLUMN = '__gonavi_oracle_rowid__'; + +export type RowLocatorStrategy = 'primary-key' | 'unique-key' | 'oracle-rowid' | 'none'; + +export type EditRowLocator = { + strategy: RowLocatorStrategy; + columns: string[]; + valueColumns: string[]; + hiddenColumns?: string[]; + readOnly: boolean; + reason?: string; +}; + +export type ResolveEditRowLocatorParams = { + dbType: string; + resultColumns: string[]; + primaryKeys?: string[]; + indexes?: IndexDefinition[]; + allowOracleRowID?: boolean; +}; + +export type ResolveRowLocatorValuesResult = + | { ok: true; values: Record } + | { ok: false; error: string }; + +const normalizeColumnName = (value: string): string => String(value || '').trim(); + +const hasColumn = (columns: string[], target: string): boolean => { + const normalizedTarget = normalizeColumnName(target).toLowerCase(); + return columns.some((column) => normalizeColumnName(column).toLowerCase() === normalizedTarget); +}; + +const findColumn = (columns: string[], target: string): string => { + const normalizedTarget = normalizeColumnName(target).toLowerCase(); + return columns.find((column) => normalizeColumnName(column).toLowerCase() === normalizedTarget) || target; +}; + +const buildReadOnlyLocator = (reason: string): EditRowLocator => ({ + strategy: 'none', + columns: [], + valueColumns: [], + readOnly: true, + reason, +}); + +export const resolveEditRowLocator = ({ + dbType, + resultColumns, + primaryKeys = [], + indexes, + allowOracleRowID = false, +}: ResolveEditRowLocatorParams): EditRowLocator => { + const columns = (resultColumns || []).map(normalizeColumnName).filter(Boolean); + const primaryKeyColumns = (primaryKeys || []).map(normalizeColumnName).filter(Boolean); + + if (primaryKeyColumns.length > 0) { + const missing = primaryKeyColumns.filter((column) => !hasColumn(columns, column)); + if (missing.length === 0) { + return { + strategy: 'primary-key', + columns: primaryKeyColumns, + valueColumns: primaryKeyColumns.map((column) => findColumn(columns, column)), + readOnly: false, + }; + } + return buildReadOnlyLocator(`结果集中缺少主键列 ${missing.join(', ')},无法安全提交修改。`); + } + + const uniqueKeyGroups = resolveUniqueKeyGroupsFromIndexes(indexes); + const uniqueKeyGroup = uniqueKeyGroups.find((group) => group.length > 0 && group.every((column) => hasColumn(columns, column))); + if (uniqueKeyGroup) { + return { + strategy: 'unique-key', + columns: uniqueKeyGroup, + valueColumns: uniqueKeyGroup.map((column) => findColumn(columns, column)), + readOnly: false, + }; + } + + if (allowOracleRowID && isOracleLikeDialect(dbType) && hasColumn(columns, ORACLE_ROWID_LOCATOR_COLUMN)) { + const rowIDColumn = findColumn(columns, ORACLE_ROWID_LOCATOR_COLUMN); + return { + strategy: 'oracle-rowid', + columns: ['ROWID'], + valueColumns: [rowIDColumn], + hiddenColumns: [rowIDColumn], + readOnly: false, + }; + } + + if (allowOracleRowID && isOracleLikeDialect(dbType)) { + return buildReadOnlyLocator('未检测到主键或可用唯一索引,且结果中缺少 Oracle ROWID,无法安全提交修改。'); + } + + return buildReadOnlyLocator('未检测到主键或可用唯一索引,无法安全提交修改。'); +}; + +export const resolveRowLocatorValues = ( + locator: EditRowLocator | undefined, + row: Record, +): ResolveRowLocatorValuesResult => { + if (!locator || locator.readOnly || locator.strategy === 'none') { + return { ok: false, error: '当前结果没有可用的安全行定位方式,无法提交修改。' }; + } + + const values: Record = {}; + for (let index = 0; index < locator.columns.length; index++) { + const column = locator.columns[index]; + const valueColumn = locator.valueColumns[index] || column; + const value = row?.[valueColumn]; + if (value === null || value === undefined || value === '') { + return { ok: false, error: `定位列 ${column} 的值为空,无法安全提交修改。` }; + } + values[column] = value; + } + + return { ok: true, values }; +}; + +export const filterHiddenLocatorColumns = (columns: string[], locator?: EditRowLocator): string[] => { + const hidden = new Set((locator?.hiddenColumns || []).map((column) => normalizeColumnName(column).toLowerCase())); + if (hidden.size === 0) return columns; + return (columns || []).filter((column) => !hidden.has(normalizeColumnName(column).toLowerCase())); +}; + +export const isHiddenLocatorColumn = (column: string, locator?: EditRowLocator): boolean => { + const normalized = normalizeColumnName(column).toLowerCase(); + return (locator?.hiddenColumns || []).some((hidden) => normalizeColumnName(hidden).toLowerCase() === normalized); +}; diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index fff3fea..d2d8608 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -426,6 +426,7 @@ export namespace connection { inserts: any[]; updates: UpdateRow[]; deletes: any[]; + locatorStrategy?: string; static createFrom(source: any = {}) { return new ChangeSet(source); @@ -436,6 +437,7 @@ export namespace connection { this.inserts = source["inserts"]; this.updates = this.convertValues(source["updates"], UpdateRow); this.deletes = source["deletes"]; + this.locatorStrategy = source["locatorStrategy"]; } convertValues(a: any, classs: any, asMap: boolean = false): any { diff --git a/internal/connection/types.go b/internal/connection/types.go index 9db299a..1927f15 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -184,9 +184,10 @@ type UpdateRow struct { // ChangeSet 表示一组批量变更,包含新增、修改和删除操作。 type ChangeSet struct { - Inserts []map[string]interface{} `json:"inserts"` - Updates []UpdateRow `json:"updates"` - Deletes []map[string]interface{} `json:"deletes"` + Inserts []map[string]interface{} `json:"inserts"` + Updates []UpdateRow `json:"updates"` + Deletes []map[string]interface{} `json:"deletes"` + LocatorStrategy string `json:"locatorStrategy,omitempty"` } // MongoMemberInfo 描述 MongoDB 副本集成员的信息。 diff --git a/internal/db/database.go b/internal/db/database.go index 4e6f228..429f7d7 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -3,6 +3,7 @@ package db import ( "GoNavi-Wails/internal/connection" "context" + "database/sql" "fmt" "strings" ) @@ -64,6 +65,20 @@ type BatchApplier interface { ApplyChanges(tableName string, changes connection.ChangeSet) error } +func requireSingleRowAffected(result sql.Result, action string) error { + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("%s未生效:无法确认影响行数:%v", action, err) + } + if affected == 0 { + return fmt.Errorf("%s未生效:未匹配到任何行", action) + } + if affected != 1 { + return fmt.Errorf("%s未生效:影响了 %d 行,期望只影响 1 行", action, affected) + } + return nil +} + type databaseFactory func() Database var databaseFactories = map[string]databaseFactory{ diff --git a/internal/db/mysql_impl.go b/internal/db/mysql_impl.go index 0e9071f..744f948 100644 --- a/internal/db/mysql_impl.go +++ b/internal/db/mysql_impl.go @@ -624,8 +624,8 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e if err != nil { return fmt.Errorf("删除失败:%v", err) } - if affected, err := res.RowsAffected(); err == nil && affected == 0 { - return fmt.Errorf("删除未生效:未匹配到任何行") + if err := requireSingleRowAffected(res, "删除"); err != nil { + return err } } @@ -658,8 +658,8 @@ func (m *MySQLDB) ApplyChanges(tableName string, changes connection.ChangeSet) e if err != nil { return fmt.Errorf("更新失败:%v", err) } - if affected, err := res.RowsAffected(); err == nil && affected == 0 { - return fmt.Errorf("更新未生效:未匹配到任何行") + if err := requireSingleRowAffected(res, "更新"); err != nil { + return err } } diff --git a/internal/db/oracle_applychanges_test.go b/internal/db/oracle_applychanges_test.go index 88ff7e5..fcc801a 100644 --- a/internal/db/oracle_applychanges_test.go +++ b/internal/db/oracle_applychanges_test.go @@ -24,8 +24,16 @@ var ( ) type oracleRecordingState struct { - mu sync.Mutex - execArgs [][]driver.NamedValue + mu sync.Mutex + execQueries []string + execArgs [][]driver.NamedValue + rowsAffected int64 +} + +func (s *oracleRecordingState) snapshotExecQueries() []string { + s.mu.Lock() + defer s.mu.Unlock() + return append([]string(nil), s.execQueries...) } func (s *oracleRecordingState) snapshotExecArgs() [][]driver.NamedValue { @@ -63,11 +71,12 @@ func (c *oracleRecordingConn) Close() error { return nil } func (c *oracleRecordingConn) Begin() (driver.Tx, error) { return oracleRecordingTx{}, nil } -func (c *oracleRecordingConn) ExecContext(_ context.Context, _ string, args []driver.NamedValue) (driver.Result, error) { +func (c *oracleRecordingConn) ExecContext(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) { c.state.mu.Lock() defer c.state.mu.Unlock() + c.state.execQueries = append(c.state.execQueries, query) c.state.execArgs = append(c.state.execArgs, append([]driver.NamedValue(nil), args...)) - return driver.RowsAffected(1), nil + return driver.RowsAffected(c.state.rowsAffected), nil } func (c *oracleRecordingConn) QueryContext(_ context.Context, query string, _ []driver.NamedValue) (driver.Rows, error) { @@ -126,7 +135,7 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) { oracleRecordingDriverMu.Lock() oracleRecordingDriverSeq++ dsn := fmt.Sprintf("oracle-recording-%d", oracleRecordingDriverSeq) - state := &oracleRecordingState{} + state := &oracleRecordingState{rowsAffected: 1} oracleRecordingDriverStates[dsn] = state oracleRecordingDriverMu.Unlock() @@ -145,6 +154,82 @@ func openOracleRecordingDB(t *testing.T) (*sql.DB, *oracleRecordingState) { return dbConn, state } +func TestOracleApplyChangesReturnsErrorWhenUpdateMatchesNoRows(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + state.rowsAffected = 0 + oracleDB := &OracleDB{conn: dbConn} + + changes := connection.ChangeSet{ + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{ + "ID": 7, + }, + Values: map[string]interface{}{ + "NAME": "new-name", + }, + }}, + } + + err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes) + if err == nil { + t.Fatal("期望更新未匹配到行时返回错误,实际为 nil") + } + if !strings.Contains(err.Error(), "更新未生效") { + t.Fatalf("错误信息应提示更新未生效,实际=%v", err) + } +} + +func TestOracleApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + state.rowsAffected = 2 + oracleDB := &OracleDB{conn: dbConn} + + changes := connection.ChangeSet{ + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{ + "ID": 7, + }, + Values: map[string]interface{}{ + "NAME": "new-name", + }, + }}, + } + + err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes) + if err == nil { + t.Fatal("期望更新影响多行时返回错误,实际为 nil") + } + if !strings.Contains(err.Error(), "影响了 2 行") { + t.Fatalf("错误信息应提示影响多行,实际=%v", err) + } +} + +func TestOracleApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + state.rowsAffected = 2 + oracleDB := &OracleDB{conn: dbConn} + + changes := connection.ChangeSet{ + Deletes: []map[string]interface{}{{ + "STATUS": "stale", + }}, + } + + err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes) + if err == nil { + t.Fatal("期望删除影响多行时返回错误,实际为 nil") + } + if !strings.Contains(err.Error(), "影响了 2 行") { + t.Fatalf("错误信息应提示影响多行,实际=%v", err) + } +} + func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) { t.Parallel() @@ -181,3 +266,87 @@ func TestOracleApplyChangesNormalizesTemporalStringsForUpdate(t *testing.T) { t.Fatalf("日期主键字段应绑定为 time.Time,实际=%#v(%T)", args[1].Value, args[1].Value) } } + +func TestOracleApplyChangesUsesUnquotedRowIDLocator(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + oracleDB := &OracleDB{conn: dbConn} + + changes := connection.ChangeSet{ + LocatorStrategy: "oracle-rowid", + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{ + "ROWID": "AAAA", + }, + Values: map[string]interface{}{ + "NAME": "new-name", + }, + }}, + } + + if err := oracleDB.ApplyChanges("MYCIMLED.EDC_LOG", changes); err != nil { + t.Fatalf("ApplyChanges 返回错误: %v", err) + } + + executions := state.snapshotExecQueries() + if len(executions) != 1 { + t.Fatalf("期望执行 1 条更新,实际 %d 条", len(executions)) + } + query := executions[0] + if !strings.Contains(query, "ROWID = :2") { + t.Fatalf("ROWID 定位条件不正确: %s", query) + } + if strings.Contains(query, "\"ROWID\" =") { + t.Fatalf("ROWID 不应被当作普通列引用: %s", query) + } +} + +func TestMySQLApplyChangesReturnsErrorWhenUpdateAffectsMultipleRows(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + state.rowsAffected = 2 + mysqlDB := &MySQLDB{conn: dbConn} + + changes := connection.ChangeSet{ + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{ + "id": 7, + }, + Values: map[string]interface{}{ + "name": "new-name", + }, + }}, + } + + err := mysqlDB.ApplyChanges("users", changes) + if err == nil { + t.Fatal("期望 MySQL 更新影响多行时返回错误,实际为 nil") + } + if !strings.Contains(err.Error(), "影响了 2 行") { + t.Fatalf("错误信息应提示影响多行,实际=%v", err) + } +} + +func TestPostgresApplyChangesReturnsErrorWhenDeleteAffectsMultipleRows(t *testing.T) { + t.Parallel() + + dbConn, state := openOracleRecordingDB(t) + state.rowsAffected = 2 + postgresDB := &PostgresDB{conn: dbConn} + + changes := connection.ChangeSet{ + Deletes: []map[string]interface{}{{ + "id": 7, + }}, + } + + err := postgresDB.ApplyChanges("public.users", changes) + if err == nil { + t.Fatal("期望 PostgreSQL 删除影响多行时返回错误,实际为 nil") + } + if !strings.Contains(err.Error(), "影响了 2 行") { + t.Fatalf("错误信息应提示影响多行,实际=%v", err) + } +} diff --git a/internal/db/oracle_impl.go b/internal/db/oracle_impl.go index 7d9c3ec..21f53b0 100644 --- a/internal/db/oracle_impl.go +++ b/internal/db/oracle_impl.go @@ -319,17 +319,31 @@ func (o *OracleDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefi } func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { - query := fmt.Sprintf(`SELECT index_name, column_name, uniqueness - FROM all_ind_columns - JOIN all_indexes USING (index_name, owner) - WHERE table_owner = '%s' AND table_name = '%s'`, - strings.ToUpper(dbName), strings.ToUpper(tableName)) + esc := func(s string) string { return strings.ReplaceAll(strings.ToUpper(strings.TrimSpace(s)), "'", "''") } + table := esc(tableName) + if table == "" { + return nil, fmt.Errorf("表名不能为空") + } - if dbName == "" { - query = fmt.Sprintf(`SELECT index_name, column_name, uniqueness - FROM user_ind_columns - JOIN user_indexes USING (index_name) - WHERE table_name = '%s'`, strings.ToUpper(tableName)) + query := fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type + FROM all_ind_columns c + JOIN all_indexes i ON i.owner = c.index_owner AND i.index_name = c.index_name + WHERE c.table_owner = '%s' + AND c.table_name = '%s' + AND c.column_name IS NOT NULL + AND c.column_name NOT LIKE 'SYS_NC%%$' + AND i.index_type NOT LIKE 'FUNCTION-BASED%%' + ORDER BY c.index_name, c.column_position`, esc(dbName), table) + + if strings.TrimSpace(dbName) == "" { + query = fmt.Sprintf(`SELECT c.index_name, c.column_name, i.uniqueness, c.column_position, i.index_type + FROM user_ind_columns c + JOIN user_indexes i ON i.index_name = c.index_name + WHERE c.table_name = '%s' + AND c.column_name IS NOT NULL + AND c.column_name NOT LIKE 'SYS_NC%%$' + AND i.index_type NOT LIKE 'FUNCTION-BASED%%' + ORDER BY c.index_name, c.column_position`, table) } data, _, err := o.Query(query) @@ -337,19 +351,46 @@ func (o *OracleDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefin return nil, err } + getValue := func(row map[string]interface{}, names ...string) interface{} { + for _, name := range names { + if value, ok := row[name]; ok { + return value + } + for key, value := range row { + if strings.EqualFold(key, name) { + return value + } + } + } + return nil + } + parseInt := func(value interface{}) int { + var n int + _, _ = fmt.Sscanf(strings.TrimSpace(fmt.Sprintf("%v", value)), "%d", &n) + return n + } + var indexes []connection.IndexDefinition for _, row := range data { - unique := 1 - if val, ok := row["UNIQUENESS"]; ok && val == "UNIQUE" { - unique = 0 + uniqueness := strings.ToUpper(strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "UNIQUENESS")))) + nonUnique := 1 + if uniqueness == "UNIQUE" { + nonUnique = 0 + } + indexType := strings.ToUpper(strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "INDEX_TYPE")))) + if indexType == "" || indexType == "" { + indexType = "BTREE" } idx := connection.IndexDefinition{ - Name: fmt.Sprintf("%v", row["INDEX_NAME"]), - ColumnName: fmt.Sprintf("%v", row["COLUMN_NAME"]), - NonUnique: unique, - // SeqInIndex is harder to get in simple join, omitting or estimating - IndexType: "BTREE", // Default assumption + Name: strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "INDEX_NAME"))), + ColumnName: strings.TrimSpace(fmt.Sprintf("%v", getValue(row, "COLUMN_NAME"))), + NonUnique: nonUnique, + SeqInIndex: parseInt(getValue(row, "COLUMN_POSITION")), + IndexType: indexType, + } + if idx.Name == "" || idx.ColumnName == "" || strings.EqualFold(idx.ColumnName, "") { + continue } indexes = append(indexes, idx) } @@ -551,23 +592,38 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) qualifiedTable = quoteIdent(table) } - // 1. Deletes - for _, pk := range changes.Deletes { + isOracleRowIDLocator := strings.EqualFold(strings.TrimSpace(changes.LocatorStrategy), "oracle-rowid") + buildWhere := func(keys map[string]interface{}, startIndex int) ([]string, []interface{}, int) { var wheres []string var args []interface{} - idx := 0 - for k, v := range pk { + idx := startIndex + for k, v := range keys { idx++ + if isOracleRowIDLocator && strings.EqualFold(strings.TrimSpace(k), "ROWID") { + wheres = append(wheres, fmt.Sprintf("ROWID = :%d", idx)) + args = append(args, v) + continue + } wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx)) args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) } + return wheres, args, idx + } + + // 1. Deletes + for _, pk := range changes.Deletes { + wheres, args, _ := buildWhere(pk, 0) if len(wheres) == 0 { continue } query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) - if _, err := tx.Exec(query, args...); err != nil { + res, err := tx.Exec(query, args...) + if err != nil { return fmt.Errorf("删除失败:%v", err) } + if err := requireSingleRowAffected(res, "删除"); err != nil { + return err + } } // 2. Updates @@ -586,21 +642,21 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) continue } - var wheres []string - for k, v := range update.Keys { - idx++ - wheres = append(wheres, fmt.Sprintf("%s = :%d", quoteIdent(k), idx)) - args = append(args, normalizeOracleValueForWrite(k, v, columnTypeMap)) - } + wheres, whereArgs, _ := buildWhere(update.Keys, idx) + args = append(args, whereArgs...) if len(wheres) == 0 { return fmt.Errorf("更新操作需要主键条件") } query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) - if _, err := tx.Exec(query, args...); err != nil { + res, err := tx.Exec(query, args...) + if err != nil { return fmt.Errorf("更新失败:%v", err) } + if err := requireSingleRowAffected(res, "更新"); err != nil { + return err + } } // 3. Inserts @@ -622,9 +678,13 @@ func (o *OracleDB) ApplyChanges(tableName string, changes connection.ChangeSet) } query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qualifiedTable, strings.Join(cols, ", "), strings.Join(placeholders, ", ")) - if _, err := tx.Exec(query, args...); err != nil { + res, err := tx.Exec(query, args...) + if err != nil { return fmt.Errorf("插入失败:%v", err) } + if affected, err := res.RowsAffected(); err == nil && affected == 0 { + return fmt.Errorf("插入未生效:未影响任何行") + } } return tx.Commit() diff --git a/internal/db/postgres_impl.go b/internal/db/postgres_impl.go index 5970003..e5cf94f 100644 --- a/internal/db/postgres_impl.go +++ b/internal/db/postgres_impl.go @@ -408,6 +408,12 @@ JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = x.attnum WHERE t.relkind IN ('r', 'p') AND t.relname = '%s' AND n.nspname = '%s' + AND ix.indisvalid + AND ix.indpred IS NULL + AND x.ordinality <= ix.indnkeyatts + AND NOT EXISTS ( + SELECT 1 FROM unnest(ix.indkey) AS expr_key(attnum) WHERE expr_key.attnum <= 0 + ) ORDER BY i.relname, x.ordinality`, esc(table), esc(schema)) data, _, err := p.Query(query) @@ -758,9 +764,13 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet continue } query := fmt.Sprintf("DELETE FROM %s WHERE %s", qualifiedTable, strings.Join(wheres, " AND ")) - if _, err := tx.Exec(query, args...); err != nil { + res, err := tx.Exec(query, args...) + if err != nil { return fmt.Errorf("删除失败:%v", err) } + if err := requireSingleRowAffected(res, "删除"); err != nil { + return err + } } // 2. Updates @@ -791,9 +801,13 @@ func (p *PostgresDB) ApplyChanges(tableName string, changes connection.ChangeSet } query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", qualifiedTable, strings.Join(sets, ", "), strings.Join(wheres, " AND ")) - if _, err := tx.Exec(query, args...); err != nil { + res, err := tx.Exec(query, args...) + if err != nil { return fmt.Errorf("更新失败:%v", err) } + if err := requireSingleRowAffected(res, "更新"); err != nil { + return err + } } // 3. Inserts From 0c1586d7a48312605cece58810f7b8c1019a0aa4 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Wed, 29 Apr 2026 17:16:37 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=F0=9F=90=9B=20fix(clickhouse):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=8D=8F=E8=AE=AE=E9=80=89=E6=8B=A9=E4=B8=8E=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E9=94=99=E8=AF=AF=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 支持 ClickHouse 手动 HTTP/Native 协议优先级,避免 URI scheme 覆盖用户选择 - Auto 模式识别 Native/HTTP 协议误配错误并自动尝试备用协议 - 净化连接失败中的二进制乱码,补充测试连接参数校验和排查日志 - 前端表单增加 ClickHouse 协议选择并同步类型、缓存 key 与持久化兼容 Refs #425 --- frontend/src/components/ConnectionModal.tsx | 185 ++++++++++++++-- frontend/src/store.test.ts | 23 ++ frontend/src/store.ts | 15 ++ frontend/src/types.ts | 1 + .../src/utils/connectionRpcConfig.test.ts | 13 ++ frontend/wailsjs/go/models.ts | 2 + internal/app/app.go | 7 + internal/app/app_cache_key_test.go | 19 ++ internal/app/methods_db.go | 19 ++ internal/app/methods_db_conn_test.go | 20 ++ internal/connection/types.go | 1 + internal/db/clickhouse_impl.go | 199 ++++++++++++++++-- internal/db/clickhouse_impl_test.go | 197 +++++++++++++++++ 13 files changed, 672 insertions(+), 29 deletions(-) diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index 1784394..a1f154c 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -95,6 +95,7 @@ type ChoiceCardOption = { label: string; description?: string; }; +type ClickHouseProtocolChoice = "auto" | "http" | "native"; const MAX_URI_LENGTH = 4096; const MAX_URI_HOSTS = 32; const MAX_TIMEOUT_SECONDS = 3600; @@ -102,6 +103,25 @@ const CONNECTION_MODAL_WIDTH = 960; const CONNECTION_MODAL_BODY_HEIGHT = 620; const STEP1_SIDEBAR_DIVIDER_DARK = "rgba(255, 255, 255, 0.16)"; const STEP1_SIDEBAR_DIVIDER_LIGHT = "rgba(0, 0, 0, 0.08)"; +const CLICKHOUSE_PROTOCOL_OPTIONS: Array<{ + value: ClickHouseProtocolChoice; + label: string; +}> = [ + { value: "auto", label: "自动" }, + { value: "http", label: "HTTP" }, + { value: "native", label: "Native" }, +]; + +const normalizeClickHouseProtocolValue = ( + value: unknown, +): ClickHouseProtocolChoice => { + const text = String(value || "") + .trim() + .toLowerCase(); + if (text === "http" || text === "https") return "http"; + if (text === "native" || text === "tcp") return "native"; + return "auto"; +}; type ConnectionSecretKey = | "primaryPassword" | "sshPassword" @@ -848,9 +868,7 @@ const ConnectionModal: React.FC<{ } }; - const resolveDriverUnavailableReason = async ( - type: string, - ): Promise => { + const resolveDriverUnavailableReason = async (type: string): Promise => { const normalized = normalizeDriverType(type); if (!normalized || normalized === "custom") { return ""; @@ -1000,6 +1018,13 @@ const ConnectionModal: React.FC<{ } }; + const normalizeUriBool = (raw: unknown) => { + const text = String(raw ?? "") + .trim() + .toLowerCase(); + return text === "1" || text === "true" || text === "yes" || text === "on"; + }; + const normalizeFileDbPath = (rawPath: string): string => { let pathText = String(rawPath || "").trim(); if (!pathText) { @@ -1117,6 +1142,44 @@ const ConnectionModal: React.FC<{ }; }; + const parseClickHouseHTTPUriToValues = ( + uriText: string, + fallbackPort?: number, + ): Record | null => { + const trimmed = String(uriText || "").trim(); + const lower = trimmed.toLowerCase(); + const isHttps = lower.startsWith("https://"); + const isHttp = lower.startsWith("http://"); + if (!isHttp && !isHttps) { + return null; + } + const defaultPort = + Number.isFinite(Number(fallbackPort)) && Number(fallbackPort) > 0 + ? Number(fallbackPort) + : isHttps + ? 8443 + : 8123; + const parsed = parseSingleHostUri( + trimmed, + [isHttps ? "https" : "http"], + defaultPort, + ); + if (!parsed) { + return null; + } + const skipVerify = normalizeUriBool(parsed.params.get("skip_verify")); + return { + host: parsed.host, + port: parsed.port, + user: parsed.username, + password: parsed.password, + database: parsed.database || "", + clickHouseProtocol: "http", + useSSL: isHttps, + sslMode: isHttps ? (skipVerify ? "skip-verify" : "required") : "disable", + }; + }; + const parseUriToValues = ( uriText: string, type: string, @@ -1337,6 +1400,13 @@ const ConnectionModal: React.FC<{ }; } + if (type === "clickhouse") { + const httpValues = parseClickHouseHTTPUriToValues(trimmedUri); + if (httpValues) { + return httpValues; + } + } + const singleHostSchemes = singleHostUriSchemesByType[type]; if (singleHostSchemes && singleHostSchemes.length > 0) { const parsed = parseSingleHostUri( @@ -1412,6 +1482,9 @@ const ConnectionModal: React.FC<{ parsedValues.sslMode = "disable"; } } else if (type === "clickhouse") { + parsedValues.clickHouseProtocol = normalizeClickHouseProtocolValue( + parsed.params.get("protocol"), + ); const secure = String( parsed.params.get("secure") || parsed.params.get("tls") || "", ) @@ -1707,7 +1780,18 @@ const ConnectionModal: React.FC<{ return `${scheme}://${encodedAuth}${hosts.join(",")}${dbPath}${query ? `?${query}` : ""}`; } - const scheme = type === "postgres" ? "postgresql" : type; + const clickHouseProtocol = + type === "clickhouse" + ? normalizeClickHouseProtocolValue(values.clickHouseProtocol) + : "auto"; + const scheme = + type === "postgres" + ? "postgresql" + : type === "clickhouse" && clickHouseProtocol === "http" + ? values.useSSL + ? "https" + : "http" + : type; const dbPath = database ? `/${encodeURIComponent(database)}` : ""; const params = new URLSearchParams(); if (supportsSSLForType(type) && values.useSSL) { @@ -1728,9 +1812,15 @@ const ConnectionModal: React.FC<{ mode === "skip-verify" || mode === "preferred" ? "true" : "false", ); } else if (type === "clickhouse") { - params.set("secure", "true"); - if (mode === "skip-verify" || mode === "preferred") { - params.set("skip_verify", "true"); + if (clickHouseProtocol === "http") { + if (mode === "skip-verify" || mode === "preferred") { + params.set("skip_verify", "true"); + } + } else { + params.set("secure", "true"); + if (mode === "skip-verify" || mode === "preferred") { + params.set("skip_verify", "true"); + } } } else if (type === "dameng") { const certPath = String(values.sslCertPath || "").trim(); @@ -1761,6 +1851,9 @@ const ConnectionModal: React.FC<{ params.set("protocol", "ws"); } } + if (type === "clickhouse" && clickHouseProtocol !== "auto") { + params.set("protocol", clickHouseProtocol); + } const query = params.toString(); return `${scheme}://${encodedAuth}${toAddress(host, port, defaultPort)}${dbPath}${query ? `?${query}` : ""}`; }; @@ -1967,6 +2060,10 @@ const ConnectionModal: React.FC<{ password: config.password, database: config.database, uri: config.uri || "", + clickHouseProtocol: + configType === "clickhouse" + ? normalizeClickHouseProtocolValue(config.clickHouseProtocol) + : "auto", includeDatabases: initialValues.includeDatabases, includeRedisDatabases: initialValues.includeRedisDatabases, useSSL: !!config.useSSL, @@ -2285,9 +2382,7 @@ const ConnectionModal: React.FC<{ try { await form.validateFields(); const values = form.getFieldsValue(true); - const unavailableReason = await resolveDriverUnavailableReason( - values.type, - ); + const unavailableReason = await resolveDriverUnavailableReason(values.type); if (unavailableReason) { message.warning(unavailableReason); promptInstallDriver(values.type, unavailableReason); @@ -2443,9 +2538,7 @@ const ConnectionModal: React.FC<{ try { await form.validateFields(); const values = form.getFieldsValue(true); - const unavailableReason = await resolveDriverUnavailableReason( - values.type, - ); + const unavailableReason = await resolveDriverUnavailableReason(values.type); if (unavailableReason) { applyTestFailureFeedback( resolveConnectionTestFailureFeedback({ @@ -2740,6 +2833,15 @@ const ConnectionModal: React.FC<{ (Array.isArray(value) && value.length === 0); if (parsedUriValues) { Object.entries(parsedUriValues).forEach(([key, value]) => { + if ( + key === "clickHouseProtocol" && + normalizeClickHouseProtocolValue((mergedValues as any)[key]) === + "auto" && + normalizeClickHouseProtocolValue(value) !== "auto" + ) { + (mergedValues as any)[key] = value; + return; + } if (isEmptyField((mergedValues as any)[key])) { (mergedValues as any)[key] = value; } @@ -2748,6 +2850,35 @@ const ConnectionModal: React.FC<{ const type = String(mergedValues.type || "").toLowerCase(); const defaultPort = getDefaultPortByType(type); + if (type === "clickhouse") { + const requestedProtocol = normalizeClickHouseProtocolValue( + mergedValues.clickHouseProtocol, + ); + const hostSchemeValues = parseClickHouseHTTPUriToValues( + mergedValues.host, + Number(mergedValues.port || defaultPort), + ); + if (hostSchemeValues) { + mergedValues.host = hostSchemeValues.host; + mergedValues.port = hostSchemeValues.port; + if (requestedProtocol !== "native") { + mergedValues.clickHouseProtocol = "http"; + mergedValues.useSSL = hostSchemeValues.useSSL; + mergedValues.sslMode = hostSchemeValues.sslMode; + } else { + mergedValues.clickHouseProtocol = "native"; + } + if (isEmptyField(mergedValues.user)) { + mergedValues.user = hostSchemeValues.user; + } + if (isEmptyField(mergedValues.password)) { + mergedValues.password = hostSchemeValues.password; + } + if (isEmptyField(mergedValues.database)) { + mergedValues.database = hostSchemeValues.database; + } + } + } const isFileDbType = isFileDatabaseType(type); const sslCapableType = supportsSSLForType(type); @@ -2990,6 +3121,10 @@ const ConnectionModal: React.FC<{ ? Math.max(0, Math.min(15, Math.trunc(Number(mergedValues.redisDB)))) : 0, uri: String(mergedValues.uri || "").trim(), + clickHouseProtocol: + type === "clickhouse" + ? normalizeClickHouseProtocolValue(mergedValues.clickHouseProtocol) + : undefined, hosts: hosts, topology: topology, mysqlReplicaUser: mysqlReplicaUser, @@ -3017,7 +3152,10 @@ const ConnectionModal: React.FC<{ } setTypeSelectWarning(null); setDbType(type); - form.setFieldsValue({ type: type }); + form.setFieldsValue({ + type: type, + clickHouseProtocol: type === "clickhouse" ? "auto" : undefined, + }); const defaultPort = getDefaultPortByType(type); if (type === "jvm") { @@ -4294,6 +4432,25 @@ const ConnectionModal: React.FC<{ ), })} + {dbType === "clickhouse" && + renderConfigSectionCard({ + sectionKey: "connectionMode", + icon: , + children: ( + +