diff --git a/frontend/src/components/QueryEditor.external-sql-save.test.tsx b/frontend/src/components/QueryEditor.external-sql-save.test.tsx index 3ad5db5..5547a5f 100644 --- a/frontend/src/components/QueryEditor.external-sql-save.test.tsx +++ b/frontend/src/components/QueryEditor.external-sql-save.test.tsx @@ -43,6 +43,11 @@ const storeState = vi.hoisted(() => ({ showQueryResultsPanel: false, }, setQueryOptions: vi.fn(), + sqlEditorTransactionOptions: { + commitMode: 'manual' as 'manual' | 'auto', + autoCommitDelayMs: 5000, + }, + setSqlEditorTransactionOptions: vi.fn(), shortcutOptions: { runQuery: { mac: { enabled: false, combo: '' }, @@ -70,6 +75,9 @@ const backendApp = vi.hoisted(() => ({ DBQuery: vi.fn(), DBQueryWithCancel: vi.fn(), DBQueryMulti: vi.fn(), + DBQueryMultiTransactional: vi.fn(), + DBCommitTransaction: vi.fn(), + DBRollbackTransaction: vi.fn(), DBGetTables: vi.fn(), DBGetAllColumns: vi.fn(), DBGetDatabases: vi.fn(), @@ -449,6 +457,10 @@ describe('QueryEditor external SQL save', () => { showColumnType: true, showQueryResultsPanel: false, }; + storeState.sqlEditorTransactionOptions = { + commitMode: 'manual', + autoCommitDelayMs: 5000, + }; storeState.shortcutOptions = { runQuery: { mac: { enabled: false, combo: '' }, @@ -471,13 +483,21 @@ describe('QueryEditor external SQL save', () => { storeState.setQueryOptions.mockImplementation((options: Record) => { storeState.queryOptions = { ...storeState.queryOptions, ...options }; }); + storeState.setSqlEditorTransactionOptions.mockReset(); + storeState.setSqlEditorTransactionOptions.mockImplementation((options: Record) => { + storeState.sqlEditorTransactionOptions = { ...storeState.sqlEditorTransactionOptions, ...options }; + }); messageApi.success.mockReset(); messageApi.error.mockReset(); messageApi.warning.mockReset(); backendApp.DBQuery.mockResolvedValue({ success: true, data: [] }); backendApp.WriteSQLFile.mockResolvedValue({ success: true }); backendApp.ExportSQLFile.mockResolvedValue({ success: true }); + backendApp.DBQueryWithCancel.mockResolvedValue({ success: true, data: [] }); backendApp.DBQueryMulti.mockResolvedValue({ success: true, data: [] }); + backendApp.DBQueryMultiTransactional.mockResolvedValue({ success: true, data: [] }); + backendApp.DBCommitTransaction.mockResolvedValue({ success: true, message: '事务已提交' }); + backendApp.DBRollbackTransaction.mockResolvedValue({ success: true, message: '事务已回滚' }); backendApp.DBGetColumns.mockResolvedValue({ success: true, data: [] }); backendApp.DBGetIndexes.mockResolvedValue({ success: true, data: [] }); backendApp.DBGetAllColumns.mockResolvedValue({ success: true, data: [] }); @@ -2117,6 +2137,96 @@ describe('QueryEditor external SQL save', () => { expect(pageText).toContain('原始错误:pq: syntax error at or near "from"'); }); + it('runs SQL editor DML through a pending managed transaction and commits manually', async () => { + backendApp.DBQueryMultiTransactional.mockResolvedValueOnce({ + success: true, + transactionId: 'tx-1', + transactionPending: true, + data: [ + { columns: ['affectedRows'], rows: [{ affectedRows: 2 }], statementIndex: 1 }, + ], + }); + + let renderer!: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(backendApp.DBQueryMultiTransactional).toHaveBeenCalledWith( + expect.anything(), + 'main', + expect.stringContaining('UPDATE users SET name'), + 'query-1', + ); + expect(backendApp.DBQueryMulti).not.toHaveBeenCalled(); + expect(textContent(renderer!.root)).toContain('事务待提交'); + expect(textContent(renderer!.root)).toContain('影响行数:2'); + + await act(async () => { + await findExactButton(renderer!, '提交').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(backendApp.DBCommitTransaction).toHaveBeenCalledWith('tx-1'); + expect(textContent(renderer!.root)).not.toContain('事务待提交'); + }); + + it('auto commits SQL editor DML transactions after the configured delay', async () => { + vi.useFakeTimers(); + storeState.sqlEditorTransactionOptions = { + commitMode: 'auto', + autoCommitDelayMs: 3000, + }; + backendApp.DBQueryMultiTransactional.mockResolvedValueOnce({ + success: true, + transactionId: 'tx-auto', + transactionPending: true, + data: [ + { columns: ['affectedRows'], rows: [{ affectedRows: 1 }], statementIndex: 1 }, + ], + }); + + try { + let renderer!: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(textContent(renderer!.root)).toContain('事务待提交'); + expect(backendApp.DBCommitTransaction).not.toHaveBeenCalled(); + + await act(async () => { + vi.advanceTimersByTime(3000); + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(backendApp.DBCommitTransaction).toHaveBeenCalledWith('tx-auto'); + expect(backendApp.DBQueryMulti).not.toHaveBeenCalled(); + } finally { + vi.useRealTimers(); + } + }); + it('automatically appends hidden primary key locator columns for editable query results', async () => { storeState.connections[0].config.type = 'oracle'; storeState.connections[0].config.database = 'ORCLPDB1'; @@ -3334,6 +3444,8 @@ describe('QueryEditor external SQL save', () => { expect(source).toContain('gn-v2-query-toolbar-connection-select'); expect(source).toContain('gn-v2-query-toolbar-database-select'); expect(source).toContain('gn-v2-query-toolbar-max-rows-select'); + expect(source).toContain('gn-v2-query-toolbar-transaction-mode-select'); + expect(source).toContain('gn-v2-query-toolbar-transaction-delay-select'); expect(source).toContain('gn-v2-query-toolbar-action-group'); expect(source).toContain('style={isV2Ui ? undefined : { width: 150 }}'); expect(source).toContain('style={isV2Ui ? undefined : { width: 200 }}'); @@ -3348,10 +3460,12 @@ describe('QueryEditor external SQL save', () => { expect(css).toContain('display: inline-flex !important;'); expect(css).toContain('gap: 6px;'); expect(css).toContain('margin-left: 0 !important;'); - expect(css).toContain('max-width: 520px;'); + expect(css).toContain('max-width: 720px;'); expect(css).toContain('width: 140px !important;'); expect(css).toContain('width: 166px !important;'); expect(css).toContain('width: 132px !important;'); + expect(css).toContain('width: 112px !important;'); + expect(css).toContain('width: 82px !important;'); expect(css).toContain('width: 34px !important;'); expect(css).toContain('@media (max-width: 900px)'); diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index e04a850..236f9e1 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -6,7 +6,7 @@ import { format } from 'sql-formatter'; import { v4 as uuidv4 } from 'uuid'; import { TabData, ColumnDefinition, IndexDefinition } from '../types'; import { useStore } from '../store'; -import { DBQuery, DBQueryWithCancel, DBQueryMulti, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, DBGetIndexes, CancelQuery, GenerateQueryID, WriteSQLFile, ExportSQLFile } from '../../wailsjs/go/app/App'; +import { DBQuery, DBQueryWithCancel, DBQueryMulti, DBQueryMultiTransactional, DBCommitTransaction, DBRollbackTransaction, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns, DBGetIndexes, CancelQuery, GenerateQueryID, WriteSQLFile, ExportSQLFile } from '../../wailsjs/go/app/App'; import { GONAVI_ROW_KEY } from './DataGrid'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; import { applyMongoQueryAutoLimit, convertMongoShellToJsonCommand } from "../utils/mongodb"; @@ -750,6 +750,67 @@ const areSqlStatementListsEqual = (left: string[], right: string[]): boolean => && left.every((statement, index) => normalizeExecutedSqlKey(statement) === normalizeExecutedSqlKey(right[index])) ); +const SQL_EDITOR_DML_KEYWORDS = new Set(['insert', 'update', 'delete', 'replace', 'merge', 'upsert']); +const SQL_EDITOR_READ_KEYWORDS = new Set(['select', 'with', 'show', 'describe', 'desc', 'explain', 'pragma', 'values']); +const SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS = new Set(['begin', 'commit', 'rollback', 'savepoint', 'release']); +const SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS = [ + { value: 3000, label: '3 秒' }, + { value: 5000, label: '5 秒' }, + { value: 10000, label: '10 秒' }, + { value: 30000, label: '30 秒' }, +]; + +const resolveLeadingSqlKeyword = (statement: string): string => { + let text = String(statement || '').trim(); + while (text) { + if (text.startsWith('--') || text.startsWith('#')) { + const lineBreak = text.indexOf('\n'); + if (lineBreak < 0) return ''; + text = text.slice(lineBreak + 1).trimStart(); + continue; + } + if (text.startsWith('/*')) { + const blockEnd = text.indexOf('*/'); + if (blockEnd < 0) return ''; + text = text.slice(blockEnd + 2).trimStart(); + continue; + } + break; + } + const match = text.match(/^([A-Za-z0-9_]+)/); + return match?.[1]?.toLowerCase() || ''; +}; + +const isSqlEditorTransactionControlStatement = (statement: string): boolean => { + const keyword = resolveLeadingSqlKeyword(statement); + if (SQL_EDITOR_TRANSACTION_CONTROL_KEYWORDS.has(keyword)) return true; + return keyword === 'start' && /\btransaction\b/i.test(statement); +}; + +const shouldUseSqlEditorManagedTransaction = (statements: string[]): boolean => { + let hasManagedWrite = false; + for (const statement of statements) { + const trimmed = String(statement || '').trim(); + if (!trimmed) continue; + if (isSqlEditorTransactionControlStatement(trimmed)) return false; + const keyword = resolveLeadingSqlKeyword(trimmed); + if (SQL_EDITOR_READ_KEYWORDS.has(keyword)) continue; + if (SQL_EDITOR_DML_KEYWORDS.has(keyword)) { + hasManagedWrite = true; + continue; + } + return false; + } + return hasManagedWrite; +}; + +type PendingSqlEditorTransaction = { + id: string; + commitMode: 'manual' | 'auto'; + autoCommitDelayMs: number; + createdAt: number; +}; + const normalizeEditorPosition = (position: any): { lineNumber: number; column: number } | null => { if (!position) return null; const lineNumber = Number(position.positionLineNumber ?? position.lineNumber ?? position.endLineNumber ?? position.startLineNumber ?? position.selectionStartLineNumber); @@ -2031,9 +2092,16 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc const setSqlFormatOptions = useStore(state => state.setSqlFormatOptions); const queryOptions = useStore(state => state.queryOptions); const setQueryOptions = useStore(state => state.setQueryOptions); + const sqlEditorTransactionOptions = useStore(state => state.sqlEditorTransactionOptions); + const setSqlEditorTransactionOptions = useStore(state => state.setSqlEditorTransactionOptions); const [isResultPanelVisible, setIsResultPanelVisible] = useState( () => tab.resultPanelVisible === true ); + const [pendingSqlTransaction, setPendingSqlTransaction] = useState(null); + const pendingSqlTransactionRef = useRef(null); + const sqlEditorAutoCommitTimerRef = useRef | null>(null); + const sqlEditorAutoCommitCountdownRef = useRef | null>(null); + const [sqlEditorAutoCommitRemainingSeconds, setSqlEditorAutoCommitRemainingSeconds] = useState(null); const shortcutOptions = useStore(state => state.shortcutOptions); const activeShortcutPlatform = getShortcutPlatform(isMacLikePlatform()); const runQueryShortcutBinding = useMemo( @@ -2070,6 +2138,89 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc return nextVisible; }); }, [tab.id, updateQueryTabDraft]); + const sqlEditorCommitMode = sqlEditorTransactionOptions?.commitMode === 'auto' ? 'auto' : 'manual'; + const sqlEditorAutoCommitDelayMs = SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS.some((item) => item.value === sqlEditorTransactionOptions?.autoCommitDelayMs) + ? Number(sqlEditorTransactionOptions?.autoCommitDelayMs) + : 5000; + const clearSqlEditorAutoCommitTimer = useCallback(() => { + if (sqlEditorAutoCommitTimerRef.current) { + clearTimeout(sqlEditorAutoCommitTimerRef.current); + sqlEditorAutoCommitTimerRef.current = null; + } + if (sqlEditorAutoCommitCountdownRef.current) { + clearInterval(sqlEditorAutoCommitCountdownRef.current); + sqlEditorAutoCommitCountdownRef.current = null; + } + setSqlEditorAutoCommitRemainingSeconds(null); + }, []); + const updatePendingSqlTransaction = useCallback((transaction: PendingSqlEditorTransaction | null) => { + pendingSqlTransactionRef.current = transaction; + setPendingSqlTransaction(transaction); + }, []); + const finishPendingSqlTransaction = useCallback(async ( + action: 'commit' | 'rollback', + source: 'manual' | 'auto' = 'manual', + transactionId?: string, + ) => { + const transaction = pendingSqlTransactionRef.current; + if (!transaction || (transactionId && transaction.id !== transactionId)) { + return; + } + clearSqlEditorAutoCommitTimer(); + try { + const res = action === 'commit' + ? await DBCommitTransaction(transaction.id) + : await DBRollbackTransaction(transaction.id); + if (res?.success) { + updatePendingSqlTransaction(null); + if (action === 'commit') { + message.success(source === 'auto' ? 'SQL 事务已自动提交' : 'SQL 事务已提交'); + } else { + message.success('SQL 事务已回滚'); + } + return; + } + updatePendingSqlTransaction(null); + const fallback = action === 'commit' ? '提交失败' : '回滚失败'; + message.error(`${source === 'auto' ? '自动提交失败' : fallback}: ${formatSqlExecutionError(res?.message || '未知错误')}`); + } catch (err: any) { + updatePendingSqlTransaction(null); + const fallback = action === 'commit' ? '提交失败' : '回滚失败'; + message.error(`${source === 'auto' ? '自动提交失败' : fallback}: ${formatSqlExecutionError(err?.message || err || '未知错误')}`); + } + }, [clearSqlEditorAutoCommitTimer, updatePendingSqlTransaction]); + const activatePendingSqlTransaction = useCallback((transaction: PendingSqlEditorTransaction) => { + clearSqlEditorAutoCommitTimer(); + updatePendingSqlTransaction(transaction); + if (transaction.commitMode !== 'auto') { + return; + } + const dueAt = Date.now() + transaction.autoCommitDelayMs; + const updateRemaining = () => { + setSqlEditorAutoCommitRemainingSeconds(Math.max(1, Math.ceil((dueAt - Date.now()) / 1000))); + }; + updateRemaining(); + sqlEditorAutoCommitCountdownRef.current = setInterval(updateRemaining, 250); + sqlEditorAutoCommitTimerRef.current = setTimeout(() => { + sqlEditorAutoCommitTimerRef.current = null; + if (sqlEditorAutoCommitCountdownRef.current) { + clearInterval(sqlEditorAutoCommitCountdownRef.current); + sqlEditorAutoCommitCountdownRef.current = null; + } + setSqlEditorAutoCommitRemainingSeconds(null); + void finishPendingSqlTransaction('commit', 'auto', transaction.id); + }, transaction.autoCommitDelayMs); + }, [clearSqlEditorAutoCommitTimer, finishPendingSqlTransaction, updatePendingSqlTransaction]); + useEffect(() => { + return () => { + clearSqlEditorAutoCommitTimer(); + const transaction = pendingSqlTransactionRef.current; + if (transaction?.id) { + pendingSqlTransactionRef.current = null; + void DBRollbackTransaction(transaction.id); + } + }; + }, [clearSqlEditorAutoCommitTimer]); const autoFetchVisible = useAutoFetchVisibility(); const currentSavedQuery = useMemo(() => { @@ -4297,6 +4448,11 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc message.info('没有可执行的 SQL。'); return; } + const useManagedTransaction = shouldUseSqlEditorManagedTransaction(sourceStatements); + if (useManagedTransaction && pendingSqlTransactionRef.current) { + message.warning('当前 SQL 编辑器已有未提交事务,请先提交或回滚后再执行新的增删改语句。'); + return; + } const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult; const statementPlans: QueryStatementPlan[] = []; @@ -4331,7 +4487,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } setQueryId(queryId); - const res = await DBQueryMulti(buildRpcConnectionConfig(config) as any, currentDb, fullSQL, queryId); + const queryExecutor = useManagedTransaction ? DBQueryMultiTransactional : DBQueryMulti; + const res = await queryExecutor(buildRpcConnectionConfig(config) as any, currentDb, fullSQL, queryId); const duration = Date.now() - startTime; addSqlLog({ @@ -4373,6 +4530,15 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc return; } + if (useManagedTransaction && res.transactionPending && res.transactionId) { + activatePendingSqlTransaction({ + id: String(res.transactionId), + commitMode: sqlEditorCommitMode, + autoCommitDelayMs: sqlEditorAutoCommitDelayMs, + createdAt: Date.now(), + }); + } + // res.data 是 ResultSetData[] 数组 const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : []; const nextResultSets: ResultSet[] = []; @@ -5122,6 +5288,39 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc }, wasClosed ? 350 : 0); }; + const sqlEditorTransactionToolbar = pendingSqlTransaction ? ( +
+ + {pendingSqlTransaction.commitMode === 'auto' && sqlEditorAutoCommitRemainingSeconds !== null + ? `事务待提交,${sqlEditorAutoCommitRemainingSeconds}s 后自动提交` + : '事务待提交'} + + + +
+ ) : null; + return (
= ({ tab, isAc ]} /> + + setSqlEditorTransactionOptions({ autoCommitDelayMs: Number(delayMs) })} + options={SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS} + /> + )} + {pendingSqlTransaction && sqlEditorTransactionToolbar}
= ({ tab, isAc currentDb={currentDb} currentConnectionId={currentConnectionId} toggleShortcutLabel={toggleQueryResultsPanelShortcutLabel} + transactionToolbar={sqlEditorTransactionToolbar} onActiveResultKeyChange={setActiveResultKey} onHide={() => updateResultPanelVisibility(false)} onCloseResult={handleCloseResult} diff --git a/frontend/src/components/QueryEditorResultsPanel.tsx b/frontend/src/components/QueryEditorResultsPanel.tsx index 33877f4..7a710d9 100644 --- a/frontend/src/components/QueryEditorResultsPanel.tsx +++ b/frontend/src/components/QueryEditorResultsPanel.tsx @@ -33,6 +33,7 @@ interface QueryEditorResultsPanelProps { currentDb: string; currentConnectionId: string; toggleShortcutLabel: string; + transactionToolbar?: React.ReactNode; onActiveResultKeyChange: (key: string) => void; onHide: () => void; onCloseResult: (key: string) => void; @@ -57,6 +58,7 @@ const QueryEditorResultsPanel: React.FC = ({ currentDb, currentConnectionId, toggleShortcutLabel, + transactionToolbar, onActiveResultKeyChange, onHide, onCloseResult, @@ -132,6 +134,16 @@ const QueryEditorResultsPanel: React.FC = ({ /> ); + const tabsExtraContent = transactionToolbar || !activeResultUsesDataGrid + ? { + right: ( +
+ {transactionToolbar} + {!activeResultUsesDataGrid ? tabsHideButton : null} +
+ ), + } + : undefined; const toolbarHideButton = ( @@ -321,7 +333,7 @@ const QueryEditorResultsPanel: React.FC = ({ onChange={onActiveResultKeyChange} animated={false} style={{ flex: 1, minHeight: 0 }} - tabBarExtraContent={!activeResultUsesDataGrid ? { right: tabsHideButton } : undefined} + tabBarExtraContent={tabsExtraContent} items={resultSets.map((rs, idx) => ({ key: rs.key, label: ( diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 38b3ac0..06426fa 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -1119,6 +1119,11 @@ export interface DataEditTransactionOptions { autoCommitDelayMs: number; } +export interface SqlEditorTransactionOptions { + commitMode: "manual" | "auto"; + autoCommitDelayMs: number; +} + interface AppState { connections: SavedConnection[]; connectionTags: ConnectionTag[]; @@ -1137,6 +1142,7 @@ interface AppState { sqlFormatOptions: { keywordCase: "upper" | "lower" }; queryOptions: QueryOptions; dataEditTransactionOptions: DataEditTransactionOptions; + sqlEditorTransactionOptions: SqlEditorTransactionOptions; shortcutOptions: ShortcutOptions; sqlSnippets: SqlSnippet[]; sqlLogs: SqlLog[]; @@ -1245,6 +1251,9 @@ interface AppState { setDataEditTransactionOptions: ( options: Partial, ) => void; + setSqlEditorTransactionOptions: ( + options: Partial, + ) => void; updateShortcut: ( action: ShortcutAction, binding: Partial, @@ -1614,6 +1623,7 @@ const sanitizeQueryOptions = (value: unknown): QueryOptions => { }; const DATA_EDIT_AUTO_COMMIT_DELAY_OPTIONS = new Set([3000, 5000, 10000, 30000]); +const SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS = new Set([3000, 5000, 10000, 30000]); const sanitizeDataEditTransactionOptions = ( value: unknown, @@ -1631,6 +1641,22 @@ const sanitizeDataEditTransactionOptions = ( }; }; +const sanitizeSqlEditorTransactionOptions = ( + value: unknown, +): SqlEditorTransactionOptions => { + const raw = + value && typeof value === "object" + ? (value as Record) + : {}; + const autoCommitDelayMs = Number(raw.autoCommitDelayMs); + return { + commitMode: raw.commitMode === "auto" ? "auto" : "manual", + autoCommitDelayMs: SQL_EDITOR_AUTO_COMMIT_DELAY_OPTIONS.has(autoCommitDelayMs) + ? autoCommitDelayMs + : 5000, + }; +}; + const sanitizeTableAccessCount = (value: unknown): Record => { const raw = value && typeof value === "object" @@ -2021,6 +2047,10 @@ export const useStore = create()( commitMode: "manual", autoCommitDelayMs: 5000, }, + sqlEditorTransactionOptions: { + commitMode: "manual", + autoCommitDelayMs: 5000, + }, shortcutOptions: cloneShortcutOptions(DEFAULT_SHORTCUT_OPTIONS), sqlSnippets: DEFAULT_SQL_SNIPPETS, sqlLogs: [], @@ -2772,6 +2802,13 @@ export const useStore = create()( ...options, }), })), + setSqlEditorTransactionOptions: (options) => + set((state) => ({ + sqlEditorTransactionOptions: sanitizeSqlEditorTransactionOptions({ + ...state.sqlEditorTransactionOptions, + ...options, + }), + })), updateShortcut: (action, binding, platform) => { runWithExplicitShortcutPersistence(() => { const targetPlatform = platform ?? getShortcutPlatform(); @@ -3180,6 +3217,8 @@ export const useStore = create()( nextState.queryOptions = sanitizeQueryOptions(state.queryOptions); nextState.dataEditTransactionOptions = sanitizeDataEditTransactionOptions(state.dataEditTransactionOptions); + nextState.sqlEditorTransactionOptions = + sanitizeSqlEditorTransactionOptions(state.sqlEditorTransactionOptions); nextState.shortcutOptions = sanitizeShortcutOptions( state.shortcutOptions, ); @@ -3285,6 +3324,9 @@ export const useStore = create()( dataEditTransactionOptions: sanitizeDataEditTransactionOptions( state.dataEditTransactionOptions, ), + sqlEditorTransactionOptions: sanitizeSqlEditorTransactionOptions( + state.sqlEditorTransactionOptions, + ), shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions), sqlLogs: sanitizeSqlLogs(state.sqlLogs), sqlSnippets: sanitizeSqlSnippets(state.sqlSnippets), @@ -3316,6 +3358,7 @@ export const useStore = create()( sqlFormatOptions: state.sqlFormatOptions, queryOptions: state.queryOptions, dataEditTransactionOptions: state.dataEditTransactionOptions, + sqlEditorTransactionOptions: state.sqlEditorTransactionOptions, shortcutOptions: resolveShortcutOptionsForPersistence(state.shortcutOptions), sqlLogs: sanitizeSqlLogs(state.sqlLogs), sqlSnippets: state.sqlSnippets, diff --git a/frontend/src/v2-theme.css b/frontend/src/v2-theme.css index f64a9df..ebc94b0 100644 --- a/frontend/src/v2-theme.css +++ b/frontend/src/v2-theme.css @@ -4810,7 +4810,7 @@ body[data-ui-version="v2"] .gn-v2-query-toolbar-actions { body[data-ui-version="v2"] .gn-v2-query-toolbar-selects { flex: 0 1 auto !important; flex-wrap: nowrap; - max-width: 520px; + max-width: 720px; } body[data-ui-version="v2"] .gn-v2-query-toolbar-actions { @@ -4839,6 +4839,16 @@ body[data-ui-version="v2"] .gn-v2-query-toolbar-max-rows-select { flex: 0 0 132px !important; } +body[data-ui-version="v2"] .gn-v2-query-toolbar-transaction-mode-select { + width: 112px !important; + flex: 0 0 112px !important; +} + +body[data-ui-version="v2"] .gn-v2-query-toolbar-transaction-delay-select { + width: 82px !important; + flex: 0 0 82px !important; +} + body[data-ui-version="v2"] .gn-v2-query-toolbar .ant-select-selector { height: 32px !important; padding: 0 10px !important; diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index 9dc835f..8596db0 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -52,14 +52,20 @@ export function DBGetTriggers(arg1:connection.ConnectionConfig,arg2:string,arg3: export function DBQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; +export function DBCommitTransaction(arg1:string):Promise; + export function DBQueryIsolated(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; export function DBQueryMulti(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise; +export function DBQueryMultiTransactional(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise; + export function DBQueryWithCancel(arg1:connection.ConnectionConfig,arg2:string,arg3:string,arg4:string):Promise; export function DBShowCreateTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; +export function DBRollbackTransaction(arg1:string):Promise; + export function DataSync(arg1:sync.SyncConfig):Promise; export function DataSyncAnalyze(arg1:sync.SyncConfig):Promise; diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js index 5bf55fe..637e28a 100755 --- a/frontend/wailsjs/go/app/App.js +++ b/frontend/wailsjs/go/app/App.js @@ -94,6 +94,10 @@ export function DBQuery(arg1, arg2, arg3) { return window['go']['app']['App']['DBQuery'](arg1, arg2, arg3); } +export function DBCommitTransaction(arg1) { + return window['go']['app']['App']['DBCommitTransaction'](arg1); +} + export function DBQueryIsolated(arg1, arg2, arg3) { return window['go']['app']['App']['DBQueryIsolated'](arg1, arg2, arg3); } @@ -102,6 +106,10 @@ export function DBQueryMulti(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['DBQueryMulti'](arg1, arg2, arg3, arg4); } +export function DBQueryMultiTransactional(arg1, arg2, arg3, arg4) { + return window['go']['app']['App']['DBQueryMultiTransactional'](arg1, arg2, arg3, arg4); +} + export function DBQueryWithCancel(arg1, arg2, arg3, arg4) { return window['go']['app']['App']['DBQueryWithCancel'](arg1, arg2, arg3, arg4); } @@ -110,6 +118,10 @@ export function DBShowCreateTable(arg1, arg2, arg3) { return window['go']['app']['App']['DBShowCreateTable'](arg1, arg2, arg3); } +export function DBRollbackTransaction(arg1) { + return window['go']['app']['App']['DBRollbackTransaction'](arg1); +} + export function DataSync(arg1) { return window['go']['app']['App']['DataSync'](arg1); } diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index 7cf42f9..5d135ee 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -981,6 +981,8 @@ export namespace connection { fields?: string[]; messages?: string[]; queryId?: string; + transactionId?: string; + transactionPending?: boolean; static createFrom(source: any = {}) { return new QueryResult(source); @@ -994,6 +996,8 @@ export namespace connection { this.fields = source["fields"]; this.messages = source["messages"]; this.queryId = source["queryId"]; + this.transactionId = source["transactionId"]; + this.transactionPending = source["transactionPending"]; } } diff --git a/internal/app/app.go b/internal/app/app.go index d394f7d..886d0c1 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -55,6 +55,15 @@ type queryContext struct { started time.Time } +type managedSQLTransaction struct { + id string + execer db.StatementExecer + dbType string + commitSQL string + rollbackSQL string + createdAt time.Time +} + // App struct type App struct { ctx context.Context @@ -68,6 +77,8 @@ type App struct { configDir string secretStore secretstore.SecretStore runningQueries map[string]queryContext // queryID -> cancelFunc and start time + sqlTransactionMu sync.Mutex + sqlTransactions map[string]*managedSQLTransaction jvmPreviewTokenMu sync.Mutex jvmPreviewTokens map[string]jvmPreviewConfirmationToken jvmPreviewTokenTTL time.Duration @@ -86,6 +97,7 @@ func NewAppWithSecretStore(store secretstore.SecretStore) *App { dbCache: make(map[string]cachedDatabase), connectFailures: make(map[string]cachedConnectFailure), runningQueries: make(map[string]queryContext), + sqlTransactions: make(map[string]*managedSQLTransaction), configDir: resolveAppConfigDir(), secretStore: store, jvmPreviewTokens: make(map[string]jvmPreviewConfirmationToken), @@ -167,6 +179,7 @@ func (a *App) LogWindowDiagnostic(stage string, payload string) { // Shutdown is called when the app terminates func (a *App) Shutdown(ctx context.Context) { logger.Infof("应用开始关闭,准备释放资源") + a.rollbackPendingSQLTransactionsOnShutdown() a.mu.Lock() defer a.mu.Unlock() for _, dbInst := range a.dbCache { diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index 012ab2c..37f9957 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -2,6 +2,7 @@ package app import ( "context" + "errors" "testing" "GoNavi-Wails/internal/connection" @@ -21,6 +22,7 @@ type fakeBatchWriteDB struct { messageMap map[string][]string multiResult map[string][]connection.ResultSetData queryErr map[string]error + execErr map[string]error execAffected map[string]int64 session *fakeBatchWriteSession } @@ -53,6 +55,9 @@ func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interfa func (f *fakeBatchWriteDB) Exec(query string) (int64, error) { f.execCalls++ f.execQueries = append(f.execQueries, query) + if err := f.execErr[query]; err != nil { + return 0, err + } if affected, ok := f.execAffected[query]; ok { return affected, nil } @@ -95,6 +100,9 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64 f.lastCtx = ctx f.execCalls++ f.execQueries = append(f.execQueries, query) + if err := f.execErr[query]; err != nil { + return 0, err + } if affected, ok := f.execAffected[query]; ok { return affected, nil } @@ -506,6 +514,185 @@ func TestDBQueryMultiPreservesPerStatementResultsForMultipleWriteStatements(t *t } } +func TestDBQueryMultiTransactionalKeepsDMLTransactionOpenUntilCommit(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + firstStmt := "UPDATE users SET name = 'new' WHERE id = 1" + secondStmt := "DELETE FROM audit_logs WHERE user_id = 1" + fakeDB := &fakeBatchWriteDB{ + execAffected: map[string]int64{ + firstStmt: 1, + secondStmt: 3, + }, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root"} + + result := app.DBQueryMultiTransactional(config, "main", firstStmt+";\n"+secondStmt+";", "tx-query") + if !result.Success { + t.Fatalf("expected transactional query success, got failure: %s", result.Message) + } + if result.TransactionID == "" || !result.TransactionPending { + t.Fatalf("expected pending transaction metadata, got id=%q pending=%v", result.TransactionID, result.TransactionPending) + } + if fakeDB.session == nil { + t.Fatal("expected transactional query to open a pinned session") + } + if fakeDB.session.closed { + t.Fatal("expected transaction session to stay open before commit") + } + wantExecs := []string{"START TRANSACTION", firstStmt, secondStmt} + if len(fakeDB.execQueries) != len(wantExecs) { + t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries) + } + for i, want := range wantExecs { + if fakeDB.execQueries[i] != want { + t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i]) + } + } + + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 2 { + t.Fatalf("expected one affectedRows result per DML statement, got %#v", resultSets) + } + if got := resultSets[0].Rows[0]["affectedRows"]; got != int64(1) { + t.Fatalf("expected first affectedRows=1, got %#v", got) + } + if got := resultSets[1].Rows[0]["affectedRows"]; got != int64(3) { + t.Fatalf("expected second affectedRows=3, got %#v", got) + } + + commitResult := app.DBCommitTransaction(result.TransactionID) + if !commitResult.Success { + t.Fatalf("expected commit success, got failure: %s", commitResult.Message) + } + if !fakeDB.session.closed { + t.Fatal("expected transaction session to close after commit") + } + if got := fakeDB.execQueries[len(fakeDB.execQueries)-1]; got != "COMMIT" { + t.Fatalf("expected final exec to be COMMIT, got %q", got) + } +} + +func TestDBQueryMultiTransactionalRollsBackAndClosesOnDMLFailure(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + firstStmt := "UPDATE users SET name = 'new' WHERE id = 1" + secondStmt := "DELETE FROM audit_logs WHERE user_id = 1" + fakeDB := &fakeBatchWriteDB{ + execErr: map[string]error{ + secondStmt: errors.New("delete failed"), + }, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root"} + + result := app.DBQueryMultiTransactional(config, "main", firstStmt+";\n"+secondStmt+";", "tx-query") + if result.Success { + t.Fatal("expected transactional query failure") + } + if result.TransactionID != "" || result.TransactionPending { + t.Fatalf("expected failed transaction not to be exposed, got id=%q pending=%v", result.TransactionID, result.TransactionPending) + } + if fakeDB.session == nil || !fakeDB.session.closed { + t.Fatal("expected failed transaction session to close") + } + wantExecs := []string{"START TRANSACTION", firstStmt, secondStmt, "ROLLBACK"} + if len(fakeDB.execQueries) != len(wantExecs) { + t.Fatalf("expected exec queries %#v, got %#v", wantExecs, fakeDB.execQueries) + } + for i, want := range wantExecs { + if fakeDB.execQueries[i] != want { + t.Fatalf("expected exec query %d = %q, got %q", i, want, fakeDB.execQueries[i]) + } + } +} + +func TestDBQueryMultiTransactionalSkipsManagedTransactionForReadOnlySQL(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "SELECT 1 AS value" + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: {{"value": 1}}, + }, + fieldMap: map[string][]string{ + query: {"value"}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root"} + + result := app.DBQueryMultiTransactional(config, "main", query, "read-query") + if !result.Success { + t.Fatalf("expected read-only query success, got failure: %s", result.Message) + } + if result.TransactionID != "" || result.TransactionPending { + t.Fatalf("expected read-only query not to start managed transaction, got id=%q pending=%v", result.TransactionID, result.TransactionPending) + } + if len(fakeDB.execQueries) != 0 { + t.Fatalf("expected no transaction wrapper execs for read-only query, got %#v", fakeDB.execQueries) + } +} + +func TestDBQueryMultiTransactionalSkipsManagedTransactionForExplicitTransactionSQL(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + stmt := "UPDATE users SET name = 'new' WHERE id = 1" + fakeDB := &fakeBatchWriteDB{} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "mysql", Host: "127.0.0.1", Port: 3306, User: "root"} + + result := app.DBQueryMultiTransactional(config, "main", "BEGIN;\n"+stmt+";\nCOMMIT;", "explicit-tx-query") + if !result.Success { + t.Fatalf("expected explicit transaction SQL success, got failure: %s", result.Message) + } + if result.TransactionID != "" || result.TransactionPending { + t.Fatalf("expected explicit transaction SQL not to be managed, got id=%q pending=%v", result.TransactionID, result.TransactionPending) + } + if len(fakeDB.execQueries) != 3 { + t.Fatalf("expected explicit transaction statements only, got %#v", fakeDB.execQueries) + } + if fakeDB.execQueries[0] != "BEGIN" || fakeDB.execQueries[1] != stmt || fakeDB.execQueries[2] != "COMMIT" { + t.Fatalf("expected explicit transaction statements unchanged, got %#v", fakeDB.execQueries) + } + if fakeDB.session == nil || !fakeDB.session.closed { + t.Fatal("expected normal DBQueryMulti session to close after explicit transaction SQL") + } +} + func TestDBQueryMultiPrefersResultSetForExecStoredProcedure(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { diff --git a/internal/app/methods_db_transaction.go b/internal/app/methods_db_transaction.go new file mode 100644 index 0000000..2c6765a --- /dev/null +++ b/internal/app/methods_db_transaction.go @@ -0,0 +1,334 @@ +package app + +import ( + "context" + "fmt" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" + "GoNavi-Wails/internal/logger" + "github.com/google/uuid" +) + +const sqlEditorTransactionFinishTimeout = 30 * time.Second + +// DBQueryMultiTransactional executes SQL editor DML in a managed transaction. +// The transaction stays open until DBCommitTransaction or DBRollbackTransaction +// is called by the SQL editor UI. +func (a *App) DBQueryMultiTransactional(config connection.ConnectionConfig, dbName string, query string, queryID string) connection.QueryResult { + runConfig := normalizeRunConfig(config, dbName) + + if queryID == "" { + queryID = generateQueryID() + } + + query = sanitizeSQLForPgLike(resolveDDLDBType(config), query) + if !shouldUseManagedSQLTransaction(runConfig.Type, query) { + return a.DBQueryMulti(config, dbName, query, queryID) + } + + beginSQL, commitSQL, rollbackSQL, ok := sqlFileBatchTransactionSQL(runConfig.Type) + if !ok { + return connection.QueryResult{ + Success: false, + Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), + QueryID: queryID, + } + } + + dbInst, err := a.getDatabase(runConfig) + if err != nil { + logger.Error(err, "DBQueryMultiTransactional 获取连接失败:%s", formatConnSummary(runConfig)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } + + provider, ok := dbInst.(db.SessionExecerProvider) + if !ok { + return connection.QueryResult{ + Success: false, + Message: fmt.Sprintf("当前数据源(%s)不支持 SQL 编辑器托管事务", runConfig.Type), + QueryID: queryID, + } + } + + ctx, cancel := newQueryExecutionContext(runConfig) + defer cancel() + + a.queryMu.Lock() + a.runningQueries[queryID] = queryContext{ + cancel: cancel, + started: time.Now(), + } + a.queryMu.Unlock() + defer func() { + a.queryMu.Lock() + delete(a.runningQueries, queryID) + a.queryMu.Unlock() + }() + + sessionExecer, err := provider.OpenSessionExecer(ctx) + if err != nil { + logger.Error(err, "DBQueryMultiTransactional 打开事务会话失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } + + closeSession := true + defer func() { + if closeSession { + if err := sessionExecer.Close(); err != nil { + logger.Warnf("DBQueryMultiTransactional 关闭事务会话失败:%v", err) + } + } + }() + + if _, err := sessionExecer.ExecContext(ctx, beginSQL); err != nil { + logger.Error(err, "DBQueryMultiTransactional 开启事务失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } + + statements := splitSQLStatements(query) + resultSets, err := executeManagedSQLTransactionStatements(ctx, sessionExecer, runConfig, statements) + if err != nil { + if _, rollbackErr := sessionExecer.ExecContext(context.Background(), rollbackSQL); rollbackErr != nil { + logger.Error(rollbackErr, "DBQueryMultiTransactional 执行失败后回滚失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + err = fmt.Errorf("%v;回滚失败: %w", err, rollbackErr) + } + logger.Error(err, "DBQueryMultiTransactional 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } + + transactionID := "sql-editor-" + uuid.NewString() + a.sqlTransactionMu.Lock() + if a.sqlTransactions == nil { + a.sqlTransactions = make(map[string]*managedSQLTransaction) + } + a.sqlTransactions[transactionID] = &managedSQLTransaction{ + id: transactionID, + execer: sessionExecer, + dbType: runConfig.Type, + commitSQL: commitSQL, + rollbackSQL: rollbackSQL, + createdAt: time.Now(), + } + a.sqlTransactionMu.Unlock() + + closeSession = false + return connection.QueryResult{ + Success: true, + Data: resultSets, + QueryID: queryID, + TransactionID: transactionID, + TransactionPending: true, + } +} + +func executeManagedSQLTransactionStatements(ctx context.Context, session db.StatementExecer, runConfig connection.ConnectionConfig, statements []string) ([]connection.ResultSetData, error) { + var resultSets []connection.ResultSetData + sessionQueryTarget, _ := session.(db.StatementQueryExecer) + sessionQueryMessageTarget, _ := session.(db.StatementQueryMessageExecer) + sessionMultiQueryTarget, _ := session.(db.StatementMultiResultQueryExecer) + sessionMultiQueryMessageTarget, _ := session.(db.StatementMultiResultQueryMessageExecer) + + for idx, stmt := range statements { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + + isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt) + tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt) + if isReadStmt || tryQueryStmtFirst { + var ( + data []map[string]interface{} + columns []string + messages []string + statementResults []connection.ResultSetData + usedMultiResult bool + err error + ) + if sessionMultiQueryMessageTarget != nil { + statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt) + usedMultiResult = true + } else if sessionMultiQueryTarget != nil { + statementResults, err = sessionMultiQueryTarget.QueryMultiContext(ctx, stmt) + usedMultiResult = true + } else if sessionQueryMessageTarget != nil { + data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt) + } else if sessionQueryTarget != nil { + data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt) + } else { + err = fmt.Errorf("当前事务会话不支持查询语句") + } + if err == nil { + if usedMultiResult { + if len(statementResults) == 0 && len(messages) > 0 { + statementResults = []connection.ResultSetData{{ + Rows: []map[string]interface{}{}, + Columns: []string{}, + Messages: append([]string(nil), messages...), + }} + } + for _, statementResult := range statementResults { + if statementResult.Rows == nil { + statementResult.Rows = []map[string]interface{}{} + } + if statementResult.Columns == nil { + statementResult.Columns = []string{} + } + statementResult.StatementIndex = idx + 1 + resultSets = append(resultSets, statementResult) + } + continue + } + if data == nil { + data = make([]map[string]interface{}, 0) + } + if columns == nil { + columns = []string{} + } + resultSets = append(resultSets, connection.ResultSetData{ + Rows: data, + Columns: columns, + Messages: messages, + StatementIndex: idx + 1, + }) + continue + } + if isReadStmt { + return nil, fmt.Errorf("第 %d 条语句执行失败: %w", idx+1, err) + } + } + + affected, err := session.ExecContext(ctx, stmt) + if err != nil { + return nil, fmt.Errorf("第 %d 条语句执行失败: %w", idx+1, err) + } + resultSets = append(resultSets, connection.ResultSetData{ + Rows: []map[string]interface{}{{"affectedRows": affected}}, + Columns: []string{"affectedRows"}, + StatementIndex: idx + 1, + }) + } + + if resultSets == nil { + resultSets = []connection.ResultSetData{} + } + return resultSets, nil +} + +func shouldUseManagedSQLTransaction(dbType string, query string) bool { + statements := splitSQLStatements(query) + hasManagedWrite := false + for _, stmt := range statements { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + if isSQLTransactionControlStatement(stmt) { + return false + } + if isReadOnlySQLQuery(dbType, stmt) { + continue + } + if isBatchableWriteSQLStatement(dbType, stmt) { + hasManagedWrite = true + continue + } + return false + } + return hasManagedWrite +} + +func isSQLTransactionControlStatement(stmt string) bool { + switch leadingSQLKeyword(stmt) { + case "begin", "commit", "rollback", "savepoint", "release": + return true + case "start": + return strings.Contains(strings.ToLower(stmt), "transaction") + default: + return false + } +} + +func (a *App) DBCommitTransaction(transactionID string) connection.QueryResult { + return a.finishManagedSQLTransaction(transactionID, true) +} + +func (a *App) DBRollbackTransaction(transactionID string) connection.QueryResult { + return a.finishManagedSQLTransaction(transactionID, false) +} + +func (a *App) finishManagedSQLTransaction(transactionID string, commit bool) connection.QueryResult { + transactionID = strings.TrimSpace(transactionID) + if transactionID == "" { + return connection.QueryResult{Success: false, Message: "事务 ID 不能为空"} + } + + a.sqlTransactionMu.Lock() + tx, ok := a.sqlTransactions[transactionID] + if ok { + delete(a.sqlTransactions, transactionID) + } + a.sqlTransactionMu.Unlock() + if !ok || tx == nil || tx.execer == nil { + return connection.QueryResult{Success: false, Message: "事务不存在或已结束"} + } + + action := "回滚" + sqlText := tx.rollbackSQL + if commit { + action = "提交" + sqlText = tx.commitSQL + } + + ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout) + defer cancel() + + var execErr error + if strings.TrimSpace(sqlText) != "" { + _, execErr = tx.execer.ExecContext(ctx, sqlText) + } + closeErr := tx.execer.Close() + if execErr != nil { + logger.Error(execErr, "SQL 编辑器事务%s失败:id=%s dbType=%s", action, transactionID, tx.dbType) + return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s失败: %v", action, execErr)} + } + if closeErr != nil { + logger.Error(closeErr, "SQL 编辑器事务%s后关闭会话失败:id=%s dbType=%s", action, transactionID, tx.dbType) + return connection.QueryResult{Success: false, Message: fmt.Sprintf("事务%s成功,但关闭会话失败: %v", action, closeErr)} + } + + if commit { + return connection.QueryResult{Success: true, Message: "事务已提交"} + } + return connection.QueryResult{Success: true, Message: "事务已回滚"} +} + +func (a *App) rollbackPendingSQLTransactionsOnShutdown() { + a.sqlTransactionMu.Lock() + pending := make([]*managedSQLTransaction, 0, len(a.sqlTransactions)) + for id, tx := range a.sqlTransactions { + if tx != nil { + pending = append(pending, tx) + } + delete(a.sqlTransactions, id) + } + a.sqlTransactionMu.Unlock() + + for _, tx := range pending { + ctx, cancel := context.WithTimeout(context.Background(), sqlEditorTransactionFinishTimeout) + if strings.TrimSpace(tx.rollbackSQL) != "" && tx.execer != nil { + if _, err := tx.execer.ExecContext(ctx, tx.rollbackSQL); err != nil { + logger.Warnf("关闭应用时回滚 SQL 编辑器事务失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err) + } + } + cancel() + if tx.execer != nil { + if err := tx.execer.Close(); err != nil { + logger.Warnf("关闭应用时关闭 SQL 编辑器事务会话失败:id=%s dbType=%s err=%v", tx.id, tx.dbType, err) + } + } + } +} diff --git a/internal/connection/types.go b/internal/connection/types.go index 3b13565..31243fb 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -130,12 +130,14 @@ type ResultSetData struct { // QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。 type QueryResult struct { - Success bool `json:"success"` - Message string `json:"message"` - Data interface{} `json:"data"` - Fields []string `json:"fields,omitempty"` - Messages []string `json:"messages,omitempty"` - QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation + Success bool `json:"success"` + Message string `json:"message"` + Data interface{} `json:"data"` + Fields []string `json:"fields,omitempty"` + Messages []string `json:"messages,omitempty"` + QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation + TransactionID string `json:"transactionId,omitempty"` + TransactionPending bool `json:"transactionPending,omitempty"` } // ColumnDefinition 描述表的一个列定义。