diff --git a/frontend/src/components/ConnectionModal.edit-password.test.tsx b/frontend/src/components/ConnectionModal.edit-password.test.tsx index 7300adf..88fab3f 100644 --- a/frontend/src/components/ConnectionModal.edit-password.test.tsx +++ b/frontend/src/components/ConnectionModal.edit-password.test.tsx @@ -19,14 +19,20 @@ describe('ConnectionModal edit password behavior', () => { describe('ConnectionModal data source registry', () => { it('exposes Elasticsearch in the create-connection picker with HTTP defaults', () => { - expect(source).toContain('case "elasticsearch":\n return 9200;'); + expect(source).toContain('case "elasticsearch":'); + expect(source).toContain('return 9200;'); expect(source).toContain('elasticsearch: ["http", "https"]'); expect(source).toContain('key: "elasticsearch"'); expect(source).toContain('name: "Elasticsearch"'); expect(source).toContain('getDbIcon("elasticsearch", undefined, 36)'); expect(source).toContain('type === "elasticsearch"'); - expect(source).toContain('"http://elastic:pass@127.0.0.1:9200/logs-*"'); - expect(source).toContain('label="默认索引(可选)"'); - expect(source).toContain('"显示索引 (留空显示全部)"'); + expect(source).toContain('return "支持索引浏览、Mapping 检查、JSON DSL 和 query_string 查询";'); + expect(source).toContain( + 'type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch") ? "" : "root";', + ); + expect(source).toContain( + 'placeholder={dbType === "elasticsearch" ? "未开启认证可留空" : undefined}', + ); + expect(source).toContain('label="显示数据库 (留空显示全部)"'); }); }); diff --git a/frontend/src/components/DataGrid.ddl.test.tsx b/frontend/src/components/DataGrid.ddl.test.tsx index 967371b..3a96c71 100644 --- a/frontend/src/components/DataGrid.ddl.test.tsx +++ b/frontend/src/components/DataGrid.ddl.test.tsx @@ -234,9 +234,31 @@ vi.mock('antd', () => { ) : null ); - Modal.useModal = () => [{ info: vi.fn(() => ({ destroy: vi.fn() })) }, null]; + Modal.useModal = () => { + const [infoConfig, setInfoConfig] = React.useState(null); + return [{ + info: vi.fn((config: any) => { + setInfoConfig(config); + return { + destroy: vi.fn(() => { + setInfoConfig(null); + }), + }; + }), + }, infoConfig ?
{infoConfig.content}
: null]; + }; const passthrough = ({ children }: any) => <>{children}; + const Dropdown = ({ children, menu, disabled }: any) => ( + <> + {children} + {!disabled && menu?.items?.map((item: any) => ( + item?.type === 'divider' + ? null + : + ))} + + ); const Space = ({ children }: any) =>
{children}
; const Tabs = ({ items = [], activeKey, onChange }: any) => { const resolvedActiveKey = activeKey ?? items[0]?.key; @@ -289,7 +311,7 @@ vi.mock('antd', () => { message: messageApi, Input, Button, - Dropdown: passthrough, + Dropdown, Form, Pagination: () => null, Select: ({ children }: any) =>
{children}
, @@ -806,6 +828,49 @@ describe('DataGrid DDL interactions', () => { renderer!.unmount(); }); + it('exports query-result rows from in-memory data without rerunning ExportQuery', async () => { + backendApp.ExportData.mockResolvedValue({ success: true }); + backendApp.ExportQuery.mockResolvedValue({ success: true }); + + let renderer: ReactTestRenderer; + await act(async () => { + renderer = create( + , + ); + }); + await waitForEffects(); + + await act(async () => { + await findButton(renderer!, 'HTML').props.onClick(); + }); + + const exportAllButton = findButton(renderer!, '全部导出'); + await act(async () => { + await exportAllButton.props.onClick(); + }); + await waitForEffects(); + + expect(backendApp.ExportData).toHaveBeenCalledTimes(1); + expect(backendApp.ExportData).toHaveBeenCalledWith( + [{ owner: 'sa' }, { owner: 'dbo' }], + ['owner'], + 'export', + 'html', + ); + expect(backendApp.ExportQuery).not.toHaveBeenCalled(); + }); + it('copies loaded column data from the v2 column header context menu', async () => { storeState.appearance.uiVersion = 'v2'; diff --git a/frontend/src/components/DatabaseIcons.test.tsx b/frontend/src/components/DatabaseIcons.test.tsx index 2e0071b..fbb5ee4 100644 --- a/frontend/src/components/DatabaseIcons.test.tsx +++ b/frontend/src/components/DatabaseIcons.test.tsx @@ -13,7 +13,9 @@ describe('DatabaseIcons', () => { it('includes Elasticsearch in the selectable database icons', () => { expect(DB_ICON_TYPES).toContain('elasticsearch'); expect(getDbIconLabel('elasticsearch')).toBe('Elasticsearch'); - expect(renderToStaticMarkup(<>{getDbIcon('elasticsearch', undefined, 22)})).toContain('ES'); + const markup = renderToStaticMarkup(<>{getDbIcon('elasticsearch', undefined, 22)}); + expect(markup).toContain('elasticsearch.svg'); + expect(markup).toContain('alt="elasticsearch"'); }); it('wraps database icons in a consistent frame for sidebar sizing', () => { diff --git a/frontend/src/components/JVMDiagnosticConsole.test.tsx b/frontend/src/components/JVMDiagnosticConsole.test.tsx index 2b3b952..4eab41a 100644 --- a/frontend/src/components/JVMDiagnosticConsole.test.tsx +++ b/frontend/src/components/JVMDiagnosticConsole.test.tsx @@ -28,6 +28,13 @@ const baseState = { ], jvmDiagnosticDrafts: {}, jvmDiagnosticOutputs: {}, + fontSize: 14, + appearance: { + uiVersion: "legacy", + dataTableFontSize: 14, + dataTableFontSizeFollowGlobal: true, + customMonoFontFamily: "", + }, setJVMDiagnosticDraft: vi.fn(), appendJVMDiagnosticOutput: vi.fn(), clearJVMDiagnosticOutput: vi.fn(), @@ -62,6 +69,7 @@ const mockMonaco = { KeyMod: { CtrlCmd: 2048 }, KeyCode: { Enter: 3 }, editor: { + defineTheme: vi.fn(), setTheme: vi.fn(), }, languages: { @@ -193,6 +201,7 @@ describe("JVMDiagnosticConsole", () => { removeEventListener: vi.fn(), }; mockMonaco.editor.setTheme.mockClear(); + mockMonaco.editor.defineTheme.mockClear(); mockMonaco.languages.register.mockClear(); mockMonaco.languages.registerCompletionItemProvider.mockClear(); mockEditor.addCommand.mockClear(); diff --git a/frontend/src/components/JVMResourceBrowser.interaction.test.tsx b/frontend/src/components/JVMResourceBrowser.interaction.test.tsx index b4f831d..1e61072 100644 --- a/frontend/src/components/JVMResourceBrowser.interaction.test.tsx +++ b/frontend/src/components/JVMResourceBrowser.interaction.test.tsx @@ -29,6 +29,13 @@ const storeState = vi.hoisted(() => ({ aiPanelVisible: false, setAIPanelVisible: vi.fn(), theme: "light", + fontSize: 14, + appearance: { + uiVersion: "legacy", + dataTableFontSize: 14, + dataTableFontSizeFollowGlobal: true, + customMonoFontFamily: "", + }, })); const backendApp = vi.hoisted(() => ({ diff --git a/frontend/src/components/QueryEditor.external-sql-save.test.tsx b/frontend/src/components/QueryEditor.external-sql-save.test.tsx index e1bbb32..878d84f 100644 --- a/frontend/src/components/QueryEditor.external-sql-save.test.tsx +++ b/frontend/src/components/QueryEditor.external-sql-save.test.tsx @@ -80,6 +80,10 @@ const dataGridState = vi.hoisted(() => ({ latestProps: null as any, })); +const tabsState = vi.hoisted(() => ({ + activeKey: undefined as string | undefined, +})); + const autoFetchState = vi.hoisted(() => ({ visible: false, })); @@ -345,11 +349,24 @@ vi.mock('antd', () => { ), Tooltip: ({ children }: any) => <>{children}, Select: () => null, - Tabs: ({ activeKey, items }: any) => { - const activeItem = items?.find((item: any) => item.key === activeKey) || items?.[0]; + Tabs: ({ activeKey, items, onChange }: any) => { + const resolvedActiveKey = tabsState.activeKey ?? activeKey ?? items?.[0]?.key; + const activeItem = items?.find((item: any) => item.key === resolvedActiveKey) || items?.[0]; return (
-
{items?.map((item: any) => {item.label})}
+
{items?.map((item: any) => ( + + ))}
{activeItem?.children}
); @@ -423,6 +440,7 @@ describe('QueryEditor external SQL save', () => { storeState.appearance.uiVersion = 'legacy'; autoFetchState.visible = false; dataGridState.latestProps = null; + tabsState.activeKey = undefined; editorState.value = ''; editorState.position = { lineNumber: 1, column: 1 }; editorState.selection = null; @@ -1788,6 +1806,198 @@ describe('QueryEditor external SQL save', () => { renderer?.unmount(); }); + it('renders result grid for sqlserver exec statements that return rows', async () => { + storeState.connections[0].config.type = 'sqlserver'; + storeState.connections[0].config.database = 'master'; + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ columns: ['SPID', 'STATUS'], rows: [{ SPID: 52, STATUS: 'RUNNABLE' }] }], + }); + + let renderer!: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(textContent(renderer!.toJSON())).toContain('结果 1'); + expect(textContent(renderer!.toJSON())).not.toContain('影响行数:'); + expect(dataGridState.latestProps?.columnNames).toEqual(['SPID', 'STATUS']); + expect(Array.isArray(dataGridState.latestProps?.data)).toBe(true); + expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ SPID: 52, STATUS: 'RUNNABLE' }); + }); + + it('renders standalone message result for sqlserver statistics statements', async () => { + storeState.connections[0].config.type = 'sqlserver'; + storeState.connections[0].config.database = 'master'; + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ + columns: [], + rows: [], + messages: ["Table 'users'. Scan count 1, logical reads 3."], + }], + }); + + let renderer!: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(textContent(renderer!.toJSON())).toContain('消息 1'); + expect(textContent(renderer!.toJSON())).toContain("Table 'users'. Scan count 1, logical reads 3."); + expect(dataGridState.latestProps?.columnNames).not.toEqual([]); + }); + + it('keeps multiple result sets from a single sqlserver statement', async () => { + storeState.connections[0].config.type = 'sqlserver'; + storeState.connections[0].config.database = 'master'; + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [ + { statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] }, + { statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] }, + ], + }); + + let renderer!: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(textContent(renderer!.toJSON())).toContain('结果 1'); + expect(textContent(renderer!.toJSON())).toContain('结果 2'); + expect(dataGridState.latestProps?.columnNames).toEqual(['name']); + }); + + it('keeps both tabs when rerunning the same single sqlserver statement with multiple result sets', async () => { + storeState.connections[0].config.type = 'sqlserver'; + storeState.connections[0].config.database = 'master'; + backendApp.DBQueryMulti + .mockResolvedValueOnce({ + success: true, + data: [ + { statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] }, + { statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] }, + ], + }) + .mockResolvedValueOnce({ + success: true, + data: [ + { statementIndex: 1, columns: ['name'], rows: [{ name: 'tempdb' }] }, + { statementIndex: 1, columns: ['owner'], rows: [{ owner: 'dbo' }] }, + ], + }); + + let renderer!: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + const tabLabels = renderer!.root.findAll((node) => { + const className = String(node.props?.className || ''); + return className.includes('query-result-tab-label'); + }); + expect(tabLabels).toHaveLength(2); + expect(dataGridState.latestProps?.columnNames).toEqual(['name']); + expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ name: 'tempdb' }); + }); + + it('reloads the active secondary result set for a single sqlserver statement', async () => { + storeState.connections[0].config.type = 'sqlserver'; + storeState.connections[0].config.database = 'master'; + backendApp.DBQueryMulti + .mockResolvedValueOnce({ + success: true, + data: [ + { statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] }, + { statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] }, + ], + }) + .mockResolvedValueOnce({ + success: true, + data: [ + { statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] }, + { statementIndex: 1, columns: ['owner'], rows: [{ owner: 'dbo' }] }, + ], + }); + + let renderer!: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + const resultTabButtons = renderer!.root.findAll((node) => node.type === 'button' && node.props['data-tab-key']); + expect(resultTabButtons).toHaveLength(2); + + await act(async () => { + resultTabButtons[1].props.onClick(); + }); + + expect(dataGridState.latestProps?.columnNames).toEqual(['owner']); + expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ owner: 'sa' }); + + await act(async () => { + await dataGridState.latestProps?.onReload?.(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(backendApp.DBQueryMulti).toHaveBeenCalledTimes(2); + expect(dataGridState.latestProps?.columnNames).toEqual(['owner']); + expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ owner: 'dbo' }); + expect(dataGridState.latestProps?.data).not.toEqual(expect.arrayContaining([expect.objectContaining({ name: 'master' })])); + }); + it('keeps non-Oracle query results read-only when no safe locator exists', async () => { backendApp.DBQueryMulti.mockResolvedValueOnce({ success: true, diff --git a/frontend/src/components/QueryEditor.tsx b/frontend/src/components/QueryEditor.tsx index 3cef7ee..6fa4920 100644 --- a/frontend/src/components/QueryEditor.tsx +++ b/frontend/src/components/QueryEditor.tsx @@ -1833,8 +1833,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc key: string; sql: string; exportSql?: string; + sourceStatementIndex?: number; + statementResultIndex?: number; rows: any[]; columns: string[]; + messages?: string[]; + resultType?: 'grid' | 'message'; tableName?: string; pkColumns: string[]; editLocator?: EditRowLocator; @@ -3924,6 +3928,13 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc return selected; }; + const buildResultSetMergeKey = (result: ResultSet): string => { + const sqlKey = normalizeExecutedSqlKey(result.exportSql || result.sql); + const sourceStatementIndex = Number(result.sourceStatementIndex || 1); + const statementResultIndex = Number(result.statementResultIndex || 1); + return `${sqlKey}::${sourceStatementIndex}::${statementResultIndex}`; + }; + const mergeResultSets = (previous: ResultSet[], next: ResultSet[], replaceAll: boolean): ResultSet[] => { if (replaceAll || previous.length === 0) { return next.map((result, index) => ({ ...result, key: `result-${index + 1}` })); @@ -3931,8 +3942,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc const merged = [...previous]; next.forEach((result) => { - const incomingKey = normalizeExecutedSqlKey(result.exportSql || result.sql); - const existingIndex = merged.findIndex((item) => normalizeExecutedSqlKey(item.exportSql || item.sql) === incomingKey); + const incomingKey = buildResultSetMergeKey(result); + const existingIndex = merged.findIndex((item) => buildResultSetMergeKey(item) === incomingKey); if (existingIndex >= 0) { merged[existingIndex] = { ...result, key: merged[existingIndex].key }; return; @@ -3947,8 +3958,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc if (!firstExecutedResult) { return ''; } - const executedSqlKey = normalizeExecutedSqlKey(firstExecutedResult.exportSql || firstExecutedResult.sql); - return merged.find((item) => normalizeExecutedSqlKey(item.exportSql || item.sql) === executedSqlKey)?.key + const executedSqlKey = buildResultSetMergeKey(firstExecutedResult); + return merged.find((item) => buildResultSetMergeKey(item) === executedSqlKey)?.key || firstExecutedResult.key || merged[0]?.key || ''; @@ -4024,6 +4035,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc if (!sql?.trim() || !currentDb) return; const conn = connections.find(c => c.id === currentConnectionId); if (!conn) return; + const currentResult = resultSets.find((item) => item.key === resultKey); + const statementResultIndex = Math.max(1, Number(currentResult?.statementResultIndex || 1)); const config = { ...conn.config, @@ -4049,10 +4062,9 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc return; } - // 取第一个结果集(单条 SQL 只有一个结果集) const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : []; - if (resultSetDataArray.length === 0) return; - const rsData = resultSetDataArray[0]; + const rsData = resultSetDataArray[Math.max(0, statementResultIndex - 1)]; + if (!rsData) return; const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1 && rsData.columns && rsData.columns.length === 1 && rsData.columns[0] === 'affectedRows'; @@ -4075,7 +4087,16 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc // 只更新匹配的结果集的 rows 和 columns,保留 tableName/pkColumns/readOnly 等元数据 setResultSets(prev => prev.map(rs => rs.key === resultKey - ? { ...rs, rows, columns: cols, truncated } + ? { + ...rs, + rows, + columns: cols, + messages: Array.isArray(rsData.messages) ? rsData.messages : [], + resultType: ((!Array.isArray(rsData.rows) || rsData.rows.length === 0) && (!Array.isArray(rsData.columns) || rsData.columns.length === 0) && Array.isArray(rsData.messages) && rsData.messages.length > 0) + ? 'message' + : 'grid', + truncated, + } : rs )); } catch (err: any) { @@ -4240,12 +4261,29 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc key: `result-${idx + 1}`, sql: rawStatement, exportSql: rawStatement, + sourceStatementIndex: idx + 1, + statementResultIndex: 1, rows, columns: cols, + messages: Array.isArray(res.messages) ? res.messages : [], pkColumns: [], readOnly: true, truncated }); + } else if (Array.isArray(res.messages) && res.messages.length > 0) { + nextResultSets.push({ + key: `result-${idx + 1}`, + sql: rawStatement, + exportSql: rawStatement, + sourceStatementIndex: idx + 1, + statementResultIndex: 1, + rows: [], + columns: [], + messages: res.messages, + resultType: 'message', + pkColumns: [], + readOnly: true, + }); } else { const affected = Number((res.data as any)?.affectedRows); if (Number.isFinite(affected)) { @@ -4255,8 +4293,11 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc key: `result-${idx + 1}`, sql: rawStatement, exportSql: rawStatement, + sourceStatementIndex: idx + 1, + statementResultIndex: 1, rows: [row], columns: ['affectedRows'], + messages: Array.isArray(res.messages) ? res.messages : [], pkColumns: [], readOnly: true }); @@ -4373,12 +4414,17 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc const nextResultSets: ResultSet[] = []; const maxRows = Number(queryOptions?.maxRows) || 0; let anyTruncated = false; + const statementResultCounts = new Map(); for (let idx = 0; idx < resultSetDataArray.length; idx++) { const rsData = resultSetDataArray[idx]; - const plan = executablePlans[idx]; + const sourceStatementIndex = Number(rsData?.statementIndex || idx + 1); + const statementResultIndex = (statementResultCounts.get(sourceStatementIndex) || 0) + 1; + statementResultCounts.set(sourceStatementIndex, statementResultIndex); + const plan = executablePlans[Math.max(0, sourceStatementIndex - 1)]; const originalSql = plan?.originalSql || ''; const executedSql = plan?.executedSql || originalSql; + const resultMessages = Array.isArray(rsData?.messages) ? rsData.messages : []; // 检查是否为 affectedRows 类结果集 const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1 @@ -4393,11 +4439,28 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc key: `result-${idx + 1}`, sql: executedSql, exportSql: originalSql, + sourceStatementIndex, + statementResultIndex, rows: [row], columns: ['affectedRows'], + messages: resultMessages, pkColumns: [], readOnly: true }); + } else if ((!Array.isArray(rsData.rows) || rsData.rows.length === 0) && (!Array.isArray(rsData.columns) || rsData.columns.length === 0) && resultMessages.length > 0) { + nextResultSets.push({ + key: `result-${idx + 1}`, + sql: executedSql, + exportSql: originalSql, + sourceStatementIndex, + statementResultIndex, + rows: [], + columns: [], + messages: resultMessages, + resultType: 'message', + pkColumns: [], + readOnly: true, + }); } else { let rows = Array.isArray(rsData.rows) ? rsData.rows : []; let truncated = false; @@ -4421,8 +4484,11 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc key: `result-${idx + 1}`, sql: executedSql, exportSql: originalSql, + sourceStatementIndex, + statementResultIndex, rows, columns: cols, + messages: resultMessages, tableName: tableRef?.tableName, pkColumns: plan?.pkColumns || [], editLocator, @@ -4432,6 +4498,22 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc } } + if (resultSetDataArray.length === 0 && Array.isArray(res.messages) && res.messages.length > 0) { + nextResultSets.push({ + key: 'result-1', + sql: fullSQL, + exportSql: sourceStatements.join(';\n'), + sourceStatementIndex: 1, + statementResultIndex: 1, + rows: [], + columns: [], + messages: res.messages, + resultType: 'message', + pkColumns: [], + readOnly: true, + }); + } + const shouldReplaceAllResults = didExecuteWholeEditor; setResultSets(prev => { const merged = mergeResultSets(prev, nextResultSets, shouldReplaceAllResults); @@ -5316,9 +5398,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc }} > - 结果 {idx + 1} + {rs.resultType === 'message' ? `消息 ${idx + 1}` : `结果 ${idx + 1}`} {(() => { + if (rs.resultType === 'message') { + return i; + } const isAffected = rs.columns.length === 1 && rs.columns[0] === 'affectedRows'; if (isAffected) { return ; @@ -5344,6 +5429,29 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc ), children: (() => { + if (rs.resultType === 'message') { + return ( +
+ 执行消息 +
+ {(rs.messages || []).join('\n')} +
+
+ ); + } // affectedRows 类型结果集(UPDATE/INSERT/DELETE):简洁提示 const isAffectedResult = rs.columns.length === 1 && rs.columns[0] === 'affectedRows'; if (isAffectedResult) { @@ -5356,11 +5464,44 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc 执行成功 影响行数:{affected} + {Array.isArray(rs.messages) && rs.messages.length > 0 && ( +
+ {rs.messages.join('\n')} +
+ )} ); } return (
+ {Array.isArray(rs.messages) && rs.messages.length > 0 && ( +
+ {rs.messages.join('\n')} +
+ )} - : ; +const renderRoot = async () => { + let rootComponent = ; + if (devHarnessMode === 'datagrid-perf') { + const { default: PerfDataGridHarness } = await import('./dev/PerfDataGridHarness'); + rootComponent = ; + } -ReactDOM.createRoot(rootNode).render( - - {rootComponent} - , -) + ReactDOM.createRoot(rootNode).render( + + {rootComponent} + , + ); +}; + +void renderRoot(); diff --git a/frontend/src/utils/dataSourceCapabilities.test.ts b/frontend/src/utils/dataSourceCapabilities.test.ts index d2b005b..db863ce 100644 --- a/frontend/src/utils/dataSourceCapabilities.test.ts +++ b/frontend/src/utils/dataSourceCapabilities.test.ts @@ -63,12 +63,12 @@ describe('dataSourceCapabilities', () => { supportsCreateDatabase: false, supportsRenameDatabase: false, supportsDropDatabase: false, - forceReadOnlyQueryResult: true, + forceReadOnlyQueryResult: false, }); expect(getDataSourceCapabilities({ type: 'custom', driver: 'elastic' })).toMatchObject({ type: 'elasticsearch', supportsQueryEditor: true, - forceReadOnlyQueryResult: true, + forceReadOnlyQueryResult: false, }); }); diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index 32611bb..b76630b 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -806,6 +806,7 @@ export namespace connection { message: string; data: any; fields?: string[]; + messages?: string[]; queryId?: string; static createFrom(source: any = {}) { @@ -818,6 +819,7 @@ export namespace connection { this.message = source["message"]; this.data = source["data"]; this.fields = source["fields"]; + this.messages = source["messages"]; this.queryId = source["queryId"]; } } diff --git a/go.mod b/go.mod index ea49a54..62fdf54 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect - github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/golang-sql/sqlexp v0.1.0 github.com/golang/snappy v1.0.0 // indirect github.com/google/flatbuffers v25.12.19+incompatible // indirect github.com/gorilla/websocket v1.5.3 diff --git a/internal/app/methods_data_root.go b/internal/app/methods_data_root.go index c68a1d4..5e46f81 100644 --- a/internal/app/methods_data_root.go +++ b/internal/app/methods_data_root.go @@ -1,6 +1,7 @@ package app import ( + "encoding/json" "fmt" "io" "os" @@ -17,12 +18,8 @@ import ( "github.com/wailsapp/wails/v2/pkg/runtime" ) -var migratableDataRootEntries = []string{ - "connections.json", - "global_proxy.json", - "ai_config.json", - "sessions", - "drivers", +var dataRootMigrationExcludedEntries = map[string]struct{}{ + "storage_root.json": {}, } func (a *App) GetDataRootDirectoryInfo() connection.QueryResult { @@ -122,22 +119,43 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error { if sourceRoot == "" || targetRoot == "" { return fmt.Errorf("数据目录不能为空") } - if filepath.Clean(sourceRoot) == filepath.Clean(targetRoot) { + sourceAbs, err := filepath.Abs(sourceRoot) + if err != nil { + return fmt.Errorf("解析源数据目录失败:%w", err) + } + targetAbs, err := filepath.Abs(targetRoot) + if err != nil { + return fmt.Errorf("解析目标数据目录失败:%w", err) + } + if filepath.Clean(sourceAbs) == filepath.Clean(targetAbs) { return nil } + if rel, err := filepath.Rel(sourceAbs, targetAbs); err == nil && rel != "." && rel != "" && !strings.HasPrefix(rel, "..") && !filepath.IsAbs(rel) { + return fmt.Errorf("目标数据目录不能位于源目录内部") + } + sourceRoot = sourceAbs + targetRoot = targetAbs if err := os.MkdirAll(targetRoot, 0o755); err != nil { return fmt.Errorf("创建目标数据目录失败:%w", err) } - for _, name := range migratableDataRootEntries { + entries, err := os.ReadDir(sourceRoot) + if err != nil { + return fmt.Errorf("读取源数据目录失败:%w", err) + } + for _, entry := range entries { + name := strings.TrimSpace(entry.Name()) + if name == "" { + continue + } + if _, excluded := dataRootMigrationExcludedEntries[name]; excluded { + continue + } sourcePath := filepath.Join(sourceRoot, name) - info, err := os.Stat(sourcePath) + targetPath := filepath.Join(targetRoot, name) + info, err := entry.Info() if err != nil { - if os.IsNotExist(err) { - continue - } return fmt.Errorf("读取源数据失败(%s):%w", name, err) } - targetPath := filepath.Join(targetRoot, name) if info.IsDir() { if err := copyDir(sourcePath, targetPath); err != nil { return fmt.Errorf("迁移目录失败(%s):%w", name, err) @@ -148,6 +166,75 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error { return fmt.Errorf("迁移文件失败(%s):%w", name, err) } } + if err := rewriteMigratedDataRootState(targetRoot); err != nil { + return err + } + return nil +} + +func rewriteMigratedDataRootState(targetRoot string) error { + if err := rewriteSecurityUpdateBackupPaths(targetRoot); err != nil { + return err + } + return nil +} + +func rewriteSecurityUpdateBackupPaths(targetRoot string) error { + repo := newSecurityUpdateStateRepository(targetRoot) + marker, err := repo.readMarker() + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("读取迁移后的安全更新状态失败:%w", err) + } + + migrationID := strings.TrimSpace(marker.MigrationID) + if migrationID == "" { + return nil + } + + targetBackupPath := repo.backupPath(migrationID) + marker.BackupPath = targetBackupPath + if err := repo.writeMarker(marker); err != nil { + return fmt.Errorf("写入迁移后的安全更新状态失败:%w", err) + } + + manifestPath := repo.manifestPath(migrationID) + manifestData, err := os.ReadFile(manifestPath) + if err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("读取迁移后的安全更新备份清单失败:%w", err) + } + } else { + var manifest securityUpdateBackupManifest + if err := json.Unmarshal(manifestData, &manifest); err != nil { + return fmt.Errorf("解析迁移后的安全更新备份清单失败:%w", err) + } + manifest.BackupPath = targetBackupPath + if err := securityUpdateWriteJSONFile(manifestPath, manifest); err != nil { + return fmt.Errorf("写入迁移后的安全更新备份清单失败:%w", err) + } + } + + resultPath := repo.resultPath(migrationID) + resultData, err := os.ReadFile(resultPath) + if err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("读取迁移后的安全更新结果失败:%w", err) + } + } else { + var result SecurityUpdateStatus + if err := json.Unmarshal(resultData, &result); err != nil { + return fmt.Errorf("解析迁移后的安全更新结果失败:%w", err) + } + result.BackupPath = targetBackupPath + result.BackupAvailable = strings.TrimSpace(targetBackupPath) != "" + if err := securityUpdateWriteJSONFile(resultPath, result); err != nil { + return fmt.Errorf("写入迁移后的安全更新结果失败:%w", err) + } + } + return nil } diff --git a/internal/app/methods_data_root_test.go b/internal/app/methods_data_root_test.go index 75ce8be..34dacac 100644 --- a/internal/app/methods_data_root_test.go +++ b/internal/app/methods_data_root_test.go @@ -105,6 +105,36 @@ func TestMigrateDataRootContentsCopiesSecurityUpdateStateAndRewritesBackupPaths( } } +func TestMigrateDataRootContentsToleratesMissingSecurityUpdateArtifacts(t *testing.T) { + sourceRoot := t.TempDir() + targetRoot := filepath.Join(t.TempDir(), "gonavi-data") + sourceRepo := newSecurityUpdateStateRepository(sourceRoot) + started, err := sourceRepo.StartRound(StartSecurityUpdateRequest{SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig}) + if err != nil { + t.Fatalf("start security update round failed: %v", err) + } + if err := os.Remove(sourceRepo.manifestPath(started.MigrationID)); err != nil { + t.Fatalf("remove source manifest failed: %v", err) + } + if err := os.Remove(sourceRepo.resultPath(started.MigrationID)); err != nil { + t.Fatalf("remove source result failed: %v", err) + } + + if err := migrateDataRootContents(sourceRoot, targetRoot); err != nil { + t.Fatalf("migrateDataRootContents should tolerate missing security update artifacts, got: %v", err) + } + + targetRepo := newSecurityUpdateStateRepository(targetRoot) + targetStatus, err := targetRepo.LoadMarker() + if err != nil { + t.Fatalf("load migrated marker failed: %v", err) + } + expectedBackupPath := filepath.Join(targetRoot, securityUpdateBackupRootDirName, started.MigrationID) + if targetStatus.BackupPath != expectedBackupPath { + t.Fatalf("expected migrated marker backupPath %q, got %q", expectedBackupPath, targetStatus.BackupPath) + } +} + func TestMigrateDataRootContentsCopiesDailySecretsForSavedConnections(t *testing.T) { sourceRoot := t.TempDir() targetRoot := filepath.Join(t.TempDir(), "gonavi-data") diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 6c960e2..08be907 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -623,6 +623,7 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin }() isReadQuery := isReadOnlySQLQuery(runConfig.Type, query) + tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query) runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) { if q, ok := inst.(interface { @@ -633,6 +634,14 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin return inst.Query(query) } + runReadQueryWithMessages := func(inst db.Database) ([]map[string]interface{}, []string, []string, error) { + if q, ok := inst.(db.QueryMessageExecer); ok { + return q.QueryContextWithMessages(ctx, query) + } + data, columns, err := runReadQuery(inst) + return data, columns, nil, err + } + runExecQuery := func(inst db.Database) (int64, error) { if e, ok := inst.(interface { ExecContext(context.Context, string) (int64, error) @@ -642,8 +651,8 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin return inst.Exec(query) } - if isReadQuery { - data, columns, err := runReadQuery(dbInst) + if isReadQuery || tryQueryFirst { + data, columns, messages, err := runReadQueryWithMessages(dbInst) if err != nil && shouldRefreshCachedConnection(err) { if a.invalidateCachedDatabase(runConfig, err) { retryInst, retryErr := a.getDatabaseForcePing(runConfig) @@ -651,32 +660,34 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) return connection.QueryResult{Success: false, Message: retryErr.Error()} } - data, columns, err = runReadQuery(retryInst) + data, columns, messages, err = runReadQueryWithMessages(retryInst) } } - if err != nil { + if err == nil { + return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages, QueryID: queryID} + } + if isReadQuery { logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} } - return connection.QueryResult{Success: true, Data: data, Fields: columns, QueryID: queryID} - } else { - affected, err := runExecQuery(dbInst) - if err != nil && shouldRefreshCachedConnection(err) { - if a.invalidateCachedDatabase(runConfig, err) { - retryInst, retryErr := a.getDatabaseForcePing(runConfig) - if retryErr != nil { - logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) - return connection.QueryResult{Success: false, Message: retryErr.Error()} - } - affected, err = runExecQuery(retryInst) - } - } - if err != nil { - logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) - return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} - } - return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID} } + + affected, err := runExecQuery(dbInst) + if err != nil && shouldRefreshCachedConnection(err) { + if a.invalidateCachedDatabase(runConfig, err) { + retryInst, retryErr := a.getDatabaseForcePing(runConfig) + if retryErr != nil { + logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: retryErr.Error()} + } + affected, err = runExecQuery(retryInst) + } + } + if err != nil { + logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} + } + return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID} } // DBQueryMulti 执行可能包含多条 SQL 语句的查询,返回多个结果集。 @@ -727,20 +738,25 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu } } - runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, error) { + runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, []string, error) { if !allReadOnly { - return nil, nil // 包含写操作,走逐条执行路径 + return nil, nil, nil // 包含写操作,走逐条执行路径 + } + if q, ok := inst.(db.MultiResultQueryMessageExecer); ok { + return q.QueryMultiContextWithMessages(ctx, query) } if q, ok := inst.(db.MultiResultQuerierContext); ok { - return q.QueryMultiContext(ctx, query) + results, err := q.QueryMultiContext(ctx, query) + return results, nil, err } if q, ok := inst.(db.MultiResultQuerier); ok { - return q.QueryMulti(query) + results, err := q.QueryMulti(query) + return results, nil, err } - return nil, nil // 返回 nil 表示不支持 + return nil, nil, nil // 返回 nil 表示不支持 } - results, err := runMultiQuery(dbInst) + results, resultMessages, err := runMultiQuery(dbInst) if err != nil && shouldRefreshCachedConnection(err) { if a.invalidateCachedDatabase(runConfig, err) { retryInst, retryErr := a.getDatabaseForcePing(runConfig) @@ -748,7 +764,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu logger.Error(retryErr, "DBQueryMulti 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) return connection.QueryResult{Success: false, Message: retryErr.Error(), QueryID: queryID} } - results, err = runMultiQuery(retryInst) + results, resultMessages, err = runMultiQuery(retryInst) } } if err != nil { @@ -758,7 +774,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu // 驱动支持多结果集,直接返回 if results != nil { - return connection.QueryResult{Success: true, Data: results, QueryID: queryID} + return connection.QueryResult{Success: true, Data: results, Messages: resultMessages, QueryID: queryID} } // 驱动不支持多结果集,回退到逐条执行 @@ -771,13 +787,50 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu } } + var sessionQueryTarget db.StatementQueryExecer + var sessionQueryMessageTarget db.StatementQueryMessageExecer + var sessionMultiQueryTarget db.StatementMultiResultQueryExecer + var sessionMultiQueryMessageTarget db.StatementMultiResultQueryMessageExecer + var sessionExecTarget db.StatementExecer + var sessionBatchTarget db.BatchWriteExecer + closeExecTarget := func() {} + if provider, ok := dbInst.(db.SessionExecerProvider); ok { + sessionExecer, sessionErr := provider.OpenSessionExecer(ctx) + if sessionErr != nil { + logger.Warnf("DBQueryMulti 打开会话级执行器失败,将回退共享连接:%s SQL片段=%q err=%v", formatConnSummary(runConfig), sqlSnippet(query), sessionErr) + } else { + if statementQueryExecer, ok := sessionExecer.(db.StatementQueryExecer); ok { + sessionQueryTarget = statementQueryExecer + } + if statementQueryMessageExecer, ok := sessionExecer.(db.StatementQueryMessageExecer); ok { + sessionQueryMessageTarget = statementQueryMessageExecer + } + if statementMultiResultQueryExecer, ok := sessionExecer.(db.StatementMultiResultQueryExecer); ok { + sessionMultiQueryTarget = statementMultiResultQueryExecer + } + if statementMultiResultQueryMessageExecer, ok := sessionExecer.(db.StatementMultiResultQueryMessageExecer); ok { + sessionMultiQueryMessageTarget = statementMultiResultQueryMessageExecer + } + sessionExecTarget = sessionExecer + if batcher, ok := sessionExecer.(db.BatchWriteExecer); ok { + sessionBatchTarget = batcher + } + closeExecTarget = func() { + if err := sessionExecer.Close(); err != nil { + logger.Warnf("DBQueryMulti 关闭会话级执行器失败:%v", err) + } + } + } + } + defer closeExecTarget() + // 全部为写操作且驱动支持批量 Exec → 一次性发送,大幅减少网络往返 // 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动 if !allReadOnly { allWrite := true containsPLSQLBlock := false for _, stmt := range statements { - if strings.TrimSpace(stmt) != "" && isReadOnlySQLQuery(runConfig.Type, stmt) { + if strings.TrimSpace(stmt) != "" && !isBatchableWriteSQLStatement(runConfig.Type, stmt) { allWrite = false } if isPLSQLBlockStatement(stmt) { @@ -785,7 +838,13 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu } } if allWrite && !containsPLSQLBlock { - if batcher, ok := dbInst.(db.BatchWriteExecer); ok { + batcher := sessionBatchTarget + if batcher == nil { + if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok { + batcher = fallbackBatcher + } + } + if batcher != nil { affected, batchErr := batcher.ExecBatchContext(ctx, query) if batchErr != nil && shouldRefreshCachedConnection(batchErr) { if a.invalidateCachedDatabase(runConfig, batchErr) { @@ -823,17 +882,80 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu continue } - if isReadOnlySQLQuery(runConfig.Type, stmt) { - var data []map[string]interface{} - var columns []string - if q, ok := dbInst.(interface { + isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt) + tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt) + if isReadStmt || tryQueryStmtFirst { + var ( + data []map[string]interface{} + columns []string + messages []string + statementResults []connection.ResultSetData + usedMultiResult bool + ) + if sessionMultiQueryMessageTarget != nil { + statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt) + usedMultiResult = true + } else if sessionMultiQueryTarget != nil { + statementResults, err = sessionMultiQueryTarget.QueryMultiContext(ctx, stmt) + usedMultiResult = true + } else if q, ok := dbInst.(db.MultiResultQueryMessageExecer); ok { + statementResults, messages, err = q.QueryMultiContextWithMessages(ctx, stmt) + usedMultiResult = true + } else if q, ok := dbInst.(db.MultiResultQuerierContext); ok { + statementResults, err = q.QueryMultiContext(ctx, stmt) + usedMultiResult = true + } else if q, ok := dbInst.(db.MultiResultQuerier); ok { + statementResults, err = q.QueryMulti(stmt) + usedMultiResult = true + } else if sessionQueryMessageTarget != nil { + data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt) + } else if sessionQueryTarget != nil { + data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt) + } else if q, ok := dbInst.(db.QueryMessageExecer); ok { + data, columns, messages, err = q.QueryContextWithMessages(ctx, stmt) + } else if q, ok := dbInst.(interface { QueryContext(context.Context, string) ([]map[string]interface{}, []string, error) }); ok { data, columns, err = q.QueryContext(ctx, stmt) } else { data, columns, err = dbInst.Query(stmt) } - if err != nil { + if err == nil { + if usedMultiResult { + if len(statementResults) == 0 && len(messages) > 0 { + statementResults = []connection.ResultSetData{{ + Rows: []map[string]interface{}{}, + Columns: []string{}, + Messages: append([]string(nil), messages...), + }} + } + for _, statementResult := range statementResults { + if statementResult.Rows == nil { + statementResult.Rows = []map[string]interface{}{} + } + if statementResult.Columns == nil { + statementResult.Columns = []string{} + } + statementResult.StatementIndex = idx + 1 + resultSets = append(resultSets, statementResult) + } + continue + } + if data == nil { + data = make([]map[string]interface{}, 0) + } + if columns == nil { + columns = []string{} + } + resultSets = append(resultSets, connection.ResultSetData{ + Rows: data, + Columns: columns, + Messages: messages, + StatementIndex: idx + 1, + }) + continue + } + if isReadStmt { logger.Error(err, "DBQueryMulti 逐条查询失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt)) errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err) if len(resultSets) > 0 { @@ -841,35 +963,30 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu } return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID} } - if data == nil { - data = make([]map[string]interface{}, 0) - } - if columns == nil { - columns = []string{} - } - resultSets = append(resultSets, connection.ResultSetData{Rows: data, Columns: columns}) - } else { - var affected int64 - if e, ok := dbInst.(interface { - ExecContext(context.Context, string) (int64, error) - }); ok { - affected, err = e.ExecContext(ctx, stmt) - } else { - affected, err = dbInst.Exec(stmt) - } - if err != nil { - logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt)) - errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err) - if len(resultSets) > 0 { - errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets)) - } - return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID} - } - resultSets = append(resultSets, connection.ResultSetData{ - Rows: []map[string]interface{}{{"affectedRows": affected}}, - Columns: []string{"affectedRows"}, - }) } + + var affected int64 + if sessionExecTarget != nil { + affected, err = sessionExecTarget.ExecContext(ctx, stmt) + } else if e, ok := dbInst.(interface { + ExecContext(context.Context, string) (int64, error) + }); ok { + affected, err = e.ExecContext(ctx, stmt) + } else { + affected, err = dbInst.Exec(stmt) + } + if err != nil { + logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt)) + errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err) + if len(resultSets) > 0 { + errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets)) + } + return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID} + } + resultSets = append(resultSets, connection.ResultSetData{ + Rows: []map[string]interface{}{{"affectedRows": affected}}, + Columns: []string{"affectedRows"}, + }) } if resultSets == nil { @@ -883,6 +1000,24 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu return connection.QueryResult{Success: true, Data: resultSets, QueryID: queryID, Message: fallbackMsg} } +func shouldTryQueryResultFirst(dbType string, query string) bool { + isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver") + keyword := leadingSQLKeyword(query) + switch keyword { + case "exec", "execute", "call": + return true + case "set", "print": + return isSQLServer + case "dbcc": + return isSQLServer + default: + if isSQLServer { + return strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_") + } + return false + } +} + func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult { runConfig := normalizeRunConfig(config, dbName) @@ -906,22 +1041,30 @@ func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, defer cancel() isReadQuery := isReadOnlySQLQuery(runConfig.Type, query) + tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query) - if isReadQuery { - var data []map[string]interface{} - var columns []string - if q, ok := dbInst.(interface { + if isReadQuery || tryQueryFirst { + var ( + data []map[string]interface{} + columns []string + messages []string + ) + if q, ok := dbInst.(db.QueryMessageExecer); ok { + data, columns, messages, err = q.QueryContextWithMessages(ctx, query) + } else if q, ok := dbInst.(interface { QueryContext(context.Context, string) ([]map[string]interface{}, []string, error) }); ok { data, columns, err = q.QueryContext(ctx, query) } else { data, columns, err = dbInst.Query(query) } - if err != nil { + if err == nil { + return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages} + } + if isReadQuery { logger.Error(err, "DBQueryIsolated 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) return connection.QueryResult{Success: false, Message: err.Error()} } - return connection.QueryResult{Success: true, Data: data, Fields: columns} } var affected int64 diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index ff30389..4250198 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -10,10 +10,17 @@ import ( ) type fakeBatchWriteDB struct { - batchCalls int - execCalls int + batchCalls int + execCalls int execQueries []string - lastQuery string + lastQuery string + queryCalls int + queryMap map[string][]map[string]interface{} + fieldMap map[string][]string + messageMap map[string][]string + multiResult map[string][]connection.ResultSetData + queryErr map[string]error + session *fakeBatchWriteSession } func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error { @@ -29,7 +36,16 @@ func (f *fakeBatchWriteDB) Ping() error { } func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []string, error) { - return nil, nil, nil + f.queryCalls++ + if err := f.queryErr[query]; err != nil { + return nil, nil, err + } + return f.queryMap[query], f.fieldMap[query], nil +} + +func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + rows, fields, err := f.Query(query) + return rows, fields, f.messageMap[query], err } func (f *fakeBatchWriteDB) Exec(query string) (int64, error) { @@ -76,12 +92,143 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64 return 1, nil } +func (f *fakeBatchWriteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + f.queryCalls++ + if err := f.queryErr[query]; err != nil { + return nil, nil, err + } + return f.queryMap[query], f.fieldMap[query], nil +} + +func (f *fakeBatchWriteDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + rows, fields, err := f.QueryContext(ctx, query) + return rows, fields, f.messageMap[query], err +} + func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) (int64, error) { f.batchCalls++ f.lastQuery = query return 500, nil } +func (f *fakeBatchWriteDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) { + f.session = &fakeBatchWriteSession{parent: f} + return f.session, nil +} + +type fakeBatchWriteSession struct { + parent *fakeBatchWriteDB + queryCalls int + execCalls int + batchCalls int + closed bool +} + +func (s *fakeBatchWriteSession) Query(query string) ([]map[string]interface{}, []string, error) { + return s.QueryContext(context.Background(), query) +} + +func (s *fakeBatchWriteSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + s.queryCalls++ + return s.parent.QueryContext(ctx, query) +} + +func (s *fakeBatchWriteSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return s.QueryContextWithMessages(context.Background(), query) +} + +func (s *fakeBatchWriteSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + s.queryCalls++ + return s.parent.QueryContextWithMessages(ctx, query) +} + +func (s *fakeBatchWriteSession) QueryMulti(query string) ([]connection.ResultSetData, error) { + return s.QueryMultiContext(context.Background(), query) +} + +func (s *fakeBatchWriteSession) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + if multi := s.parent.multiResult[query]; len(multi) > 0 { + s.queryCalls++ + return cloneResultSets(multi), nil + } + rows, columns, err := s.QueryContext(ctx, query) + if err != nil { + return nil, err + } + return []connection.ResultSetData{{Rows: rows, Columns: columns}}, nil +} + +func (s *fakeBatchWriteSession) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) { + return s.QueryMultiContextWithMessages(context.Background(), query) +} + +func (s *fakeBatchWriteSession) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) { + if err := s.parent.queryErr[query]; err != nil { + s.queryCalls++ + return nil, nil, err + } + if multi := s.parent.multiResult[query]; len(multi) > 0 { + s.queryCalls++ + return cloneResultSets(multi), append([]string(nil), s.parent.messageMap[query]...), nil + } + rows, columns, messages, err := s.QueryContextWithMessages(ctx, query) + if err != nil { + return nil, nil, err + } + return []connection.ResultSetData{{ + Rows: rows, + Columns: columns, + Messages: append([]string(nil), messages...), + }}, append([]string(nil), messages...), nil +} + +func (s *fakeBatchWriteSession) Exec(query string) (int64, error) { + return s.ExecContext(context.Background(), query) +} + +func (s *fakeBatchWriteSession) ExecContext(ctx context.Context, query string) (int64, error) { + s.execCalls++ + return s.parent.ExecContext(ctx, query) +} + +func (s *fakeBatchWriteSession) ExecBatchContext(ctx context.Context, query string) (int64, error) { + s.batchCalls++ + return s.parent.ExecBatchContext(ctx, query) +} + +func (s *fakeBatchWriteSession) Close() error { + s.closed = true + return nil +} + +func cloneResultSets(input []connection.ResultSetData) []connection.ResultSetData { + if len(input) == 0 { + return nil + } + cloned := make([]connection.ResultSetData, 0, len(input)) + for _, item := range input { + rows := make([]map[string]interface{}, 0, len(item.Rows)) + for _, row := range item.Rows { + if row == nil { + rows = append(rows, nil) + continue + } + rowCopy := make(map[string]interface{}, len(row)) + for key, value := range row { + rowCopy[key] = value + } + rows = append(rows, rowCopy) + } + cloned = append(cloned, connection.ResultSetData{ + Rows: rows, + Columns: append([]string(nil), item.Columns...), + Messages: append([]string(nil), item.Messages...), + StatementIndex: item.StatementIndex, + }) + } + return cloned +} + func TestDBQueryMultiKeepsOracleAnonymousBlockAsSingleStatement(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { @@ -122,6 +269,85 @@ END;` } var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil) +var _ db.SessionExecerProvider = (*fakeBatchWriteDB)(nil) +var _ db.QueryMessageExecer = (*fakeBatchWriteDB)(nil) +var _ db.StatementQueryMessageExecer = (*fakeBatchWriteSession)(nil) + +func TestDBQueryWithCancelReturnsResultSetForExecStoredProcedure(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "EXEC sp_who2" + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: { + {"SPID": 52, "STATUS": "RUNNABLE"}, + }, + }, + fieldMap: map[string][]string{ + query: {"SPID", "STATUS"}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryWithCancel(config, "master", query, "sp-who2-test") + if !result.Success { + t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message) + } + rows, ok := result.Data.([]map[string]interface{}) + if !ok { + t.Fatalf("expected []map[string]interface{}, got %T", result.Data) + } + if len(rows) != 1 || rows[0]["SPID"] != 52 { + t.Fatalf("unexpected rows: %#v", rows) + } + if fakeDB.execCalls != 0 { + t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls) + } +} + +func TestDBQueryWithCancelReturnsMessagesForSQLServerQuery(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "SET STATISTICS IO ON" + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: {}, + }, + fieldMap: map[string][]string{ + query: {}, + }, + messageMap: map[string][]string{ + query: {"Table 'users'. Scan count 1, logical reads 3."}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryWithCancel(config, "master", query, "statistics-io-test") + if !result.Success { + t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message) + } + if len(result.Messages) != 1 || result.Messages[0] == "" { + t.Fatalf("expected SQL Server messages to be returned, got %#v", result.Messages) + } +} func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc @@ -168,3 +394,206 @@ func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) { t.Fatalf("expected affectedRows=500, got %#v", got) } } + +func TestDBQueryMultiPrefersResultSetForExecStoredProcedure(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "EXEC sp_who2" + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: { + {"SPID": 77, "STATUS": "SUSPENDED"}, + }, + }, + fieldMap: map[string][]string{ + query: {"SPID", "STATUS"}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryMulti(config, "master", query, "sp-who2-multi-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 { + t.Fatalf("unexpected result sets: %#v", resultSets) + } + if got := resultSets[0].Rows[0]["SPID"]; got != 77 { + t.Fatalf("expected SPID=77, got %#v", got) + } + if fakeDB.execCalls != 0 { + t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls) + } +} + +func TestDBQueryMultiDoesNotBatchExecStoredProcedureAsWriteStatement(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "EXEC sp_who2" + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: { + {"SPID": 88, "STATUS": "RUNNING"}, + }, + }, + fieldMap: map[string][]string{ + query: {"SPID", "STATUS"}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryMulti(config, "master", query, "sp-who2-batch-guard-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.batchCalls != 0 { + t.Fatalf("expected stored procedure to skip batch write path, got batchCalls=%d", fakeDB.batchCalls) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 { + t.Fatalf("unexpected result sets: %#v", resultSets) + } + if got := resultSets[0].Rows[0]["SPID"]; got != 88 { + t.Fatalf("expected SPID=88, got %#v", got) + } +} + +func TestDBQueryMultiUsesPinnedSessionForSequentialFallback(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + fakeDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + "SELECT 1 AS value": { + {"value": 1}, + }, + }, + fieldMap: map[string][]string{ + "SELECT 1 AS value": {"value"}, + }, + messageMap: map[string][]string{ + "SET NOCOUNT ON": {"NOCOUNT 已开启"}, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryMulti(config, "master", "SET NOCOUNT ON;\nSELECT 1 AS value;", "session-fallback-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.session == nil { + t.Fatal("expected DBQueryMulti to open a pinned session for sequential fallback") + } + if !fakeDB.session.closed { + t.Fatal("expected DBQueryMulti to close the pinned session") + } + if fakeDB.session.execCalls != 0 { + t.Fatalf("expected SQL Server SET statement to avoid exec-only path, got execCalls=%d", fakeDB.session.execCalls) + } + if fakeDB.session.queryCalls != 2 { + t.Fatalf("expected both statements to query through pinned session, got queryCalls=%d", fakeDB.session.queryCalls) + } + if fakeDB.queryCalls != 2 { + t.Fatalf("expected exactly two underlying query calls, got %d", fakeDB.queryCalls) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 2 { + t.Fatalf("expected two result sets, got %#v", resultSets) + } + if len(resultSets[0].Messages) != 1 || resultSets[0].Messages[0] != "NOCOUNT 已开启" { + t.Fatalf("expected first result set to keep session message, got %#v", resultSets[0].Messages) + } + if got := resultSets[1].Rows[0]["value"]; got != 1 { + t.Fatalf("expected second result set value=1, got %#v", got) + } +} + +func TestDBQueryMultiKeepsAllResultSetsFromSingleSQLServerStatement(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "EXEC sp_helpdb" + fakeDB := &fakeBatchWriteDB{ + multiResult: map[string][]connection.ResultSetData{ + query: { + { + Rows: []map[string]interface{}{{"name": "master"}}, + Columns: []string{"name"}, + }, + { + Rows: []map[string]interface{}{{"owner": "sa"}}, + Columns: []string{"owner"}, + }, + }, + }, + queryErr: map[string]error{}, + } + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryMulti(config, "master", query, "sp-helpdb-multi-result-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 2 { + t.Fatalf("expected two result sets, got %#v", resultSets) + } + if got := resultSets[0].Rows[0]["name"]; got != "master" { + t.Fatalf("expected first result set to keep master row, got %#v", got) + } + if got := resultSets[1].Rows[0]["owner"]; got != "sa" { + t.Fatalf("expected second result set to keep owner row, got %#v", got) + } + if resultSets[0].StatementIndex != 1 || resultSets[1].StatementIndex != 1 { + t.Fatalf("expected both result sets to map to the first statement, got %#v", resultSets) + } + if fakeDB.execCalls != 0 { + t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls) + } +} diff --git a/internal/app/sql_sanitize.go b/internal/app/sql_sanitize.go index 00066de..06ae248 100644 --- a/internal/app/sql_sanitize.go +++ b/internal/app/sql_sanitize.go @@ -65,6 +65,19 @@ func isReadOnlySQLQuery(dbType string, query string) bool { } } +func isBatchableWriteSQLStatement(dbType string, query string) bool { + if isReadOnlySQLQuery(dbType, query) { + return false + } + + switch leadingSQLKeyword(query) { + case "insert", "update", "delete", "replace", "merge", "upsert": + return true + default: + return false + } +} + func sanitizeSQLForPgLike(dbType string, query string) string { normalizedType := strings.ToLower(strings.TrimSpace(dbType)) switch normalizedType { diff --git a/internal/app/sql_sanitize_test.go b/internal/app/sql_sanitize_test.go index 85d2d24..8825e89 100644 --- a/internal/app/sql_sanitize_test.go +++ b/internal/app/sql_sanitize_test.go @@ -62,3 +62,42 @@ func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) { t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out) } } + +func TestIsReadOnlySQLQuery_DoesNotTreatExecAsReadOnly(t *testing.T) { + if isReadOnlySQLQuery("sqlserver", "EXEC sp_who2") { + t.Fatal("EXEC should not be treated as read-only SQL") + } +} + +func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) { + if !isBatchableWriteSQLStatement("mysql", "INSERT INTO demo(id) VALUES (1)") { + t.Fatal("expected INSERT to be treated as batchable write") + } + if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") { + t.Fatal("EXEC should not be treated as batchable write") + } + if isBatchableWriteSQLStatement("sqlserver", "SET STATISTICS IO ON") { + t.Fatal("SET STATISTICS should not be treated as batchable write") + } +} + +func TestShouldTryQueryResultFirst_TreatsSQLServerSetAsQueryFirst(t *testing.T) { + if !shouldTryQueryResultFirst("sqlserver", "SET STATISTICS IO ON") { + t.Fatal("expected SQL Server SET STATISTICS to try query-first for notice capture") + } + if shouldTryQueryResultFirst("mysql", "SET sql_mode = ''") { + t.Fatal("non-SQLServer SET should not force query-first") + } +} + +func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *testing.T) { + if !shouldTryQueryResultFirst("sqlserver", "sp_who2") { + t.Fatal("expected bare SQL Server system procedure to try query-first") + } + if !shouldTryQueryResultFirst("sqlserver", "DBCC INPUTBUFFER(52)") { + t.Fatal("expected SQL Server DBCC command to try query-first") + } + if shouldTryQueryResultFirst("mysql", "sp_who2") { + t.Fatal("non-SQLServer system procedure name should not force query-first") + } +} diff --git a/internal/app/window_zoom_windows_test.go b/internal/app/window_zoom_windows_test.go index 83cb4b3..6f4abe1 100644 --- a/internal/app/window_zoom_windows_test.go +++ b/internal/app/window_zoom_windows_test.go @@ -45,7 +45,8 @@ func (p *panicChromium) PutZoomFactor(float64) { } type panicFrontend struct { - chromium *panicChromium + chromium *panicChromium + mainWindow *fakeWindow } // 测试必须用 wails 一致的 string key "frontend" 作为 context.WithValue 的 key, @@ -122,7 +123,10 @@ func TestResetWebViewZoomFactorErrorsWhenFrontendMissing(t *testing.T) { } func TestResetWebViewZoomFactorRecoversFromPutZoomFactorPanic(t *testing.T) { - ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{chromium: &panicChromium{}}) + ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{ + chromium: &panicChromium{}, + mainWindow: &fakeWindow{}, + }) err := resetWebViewZoomFactor(ctx, 1.0) if err == nil { diff --git a/internal/connection/types.go b/internal/connection/types.go index 8ecc96f..3b13565 100644 --- a/internal/connection/types.go +++ b/internal/connection/types.go @@ -122,17 +122,20 @@ type ConnectionConfig struct { // ResultSetData 表示一个查询结果集(行 + 列名),用于多结果集场景。 type ResultSetData struct { - Rows []map[string]interface{} `json:"rows"` - Columns []string `json:"columns"` + Rows []map[string]interface{} `json:"rows"` + Columns []string `json:"columns"` + Messages []string `json:"messages,omitempty"` + StatementIndex int `json:"statementIndex,omitempty"` } // QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。 type QueryResult struct { - Success bool `json:"success"` - Message string `json:"message"` - Data interface{} `json:"data"` - Fields []string `json:"fields,omitempty"` - QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation + Success bool `json:"success"` + Message string `json:"message"` + Data interface{} `json:"data"` + Fields []string `json:"fields,omitempty"` + Messages []string `json:"messages,omitempty"` + QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation } // ColumnDefinition 描述表的一个列定义。 diff --git a/internal/db/database.go b/internal/db/database.go index 7472e69..e1cd184 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -67,6 +67,51 @@ type StatementExecer interface { Close() error } +// StatementQueryExecer can run queries on a pinned session/connection. +// Drivers that return sqlConnStatementExecer automatically satisfy it. +type StatementQueryExecer interface { + StatementExecer + Query(query string) ([]map[string]interface{}, []string, error) + QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) +} + +// StatementQueryMessageExecer can run queries on a pinned session and return +// extra server messages/notices alongside rows. +type StatementQueryMessageExecer interface { + StatementQueryExecer + QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) + QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) +} + +// StatementMultiResultQueryExecer can run multi-result queries on a pinned session/connection. +type StatementMultiResultQueryExecer interface { + StatementExecer + QueryMulti(query string) ([]connection.ResultSetData, error) + QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) +} + +// StatementMultiResultQueryMessageExecer can run multi-result queries on a +// pinned session/connection and return server messages/notices. +type StatementMultiResultQueryMessageExecer interface { + StatementMultiResultQueryExecer + QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) + QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) +} + +// QueryMessageExecer is an optional database-level interface for returning +// informational server messages alongside one result set. +type QueryMessageExecer interface { + QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) + QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) +} + +// MultiResultQueryMessageExecer is an optional database-level interface for +// returning informational server messages alongside multi-result queries. +type MultiResultQueryMessageExecer interface { + QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) + QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) +} + // SessionExecerProvider is implemented by database/sql based drivers that can // pin a long-running job to one physical connection. type SessionExecerProvider interface { @@ -96,6 +141,38 @@ func (e *sqlConnStatementExecer) Exec(query string) (int64, error) { return e.ExecContext(context.Background(), query) } +func (e *sqlConnStatementExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if e == nil || e.conn == nil { + return nil, nil, fmt.Errorf("连接未打开") + } + rows, err := e.conn.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanRows(rows) +} + +func (e *sqlConnStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) { + return e.QueryContext(context.Background(), query) +} + +func (e *sqlConnStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + if e == nil || e.conn == nil { + return nil, fmt.Errorf("连接未打开") + } + rows, err := e.conn.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + return scanMultiRows(rows) +} + +func (e *sqlConnStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) { + return e.QueryMultiContext(context.Background(), query) +} + func (e *sqlConnStatementExecer) ExecBatchContext(ctx context.Context, query string) (int64, error) { return e.ExecContext(ctx, query) } diff --git a/internal/db/sqlserver_impl.go b/internal/db/sqlserver_impl.go index ffa927a..ef4597b 100644 --- a/internal/db/sqlserver_impl.go +++ b/internal/db/sqlserver_impl.go @@ -17,6 +17,7 @@ import ( "GoNavi-Wails/internal/ssh" "GoNavi-Wails/internal/utils" + "github.com/golang-sql/sqlexp" _ "github.com/microsoft/go-mssqldb" ) @@ -26,6 +27,85 @@ type SqlServerDB struct { forwarder *ssh.LocalForwarder } +type sqlServerSessionExecer struct { + conn *sql.Conn +} + +func scanSQLServerRowsWithMessages(ctx context.Context, rows *sql.Rows, retmsg *sqlexp.ReturnMessage) ([]connection.ResultSetData, []string, error) { + if rows == nil { + return []connection.ResultSetData{{Rows: []map[string]interface{}{}, Columns: []string{}}}, nil, nil + } + if ctx == nil { + ctx = context.Background() + } + + var ( + resultSets []connection.ResultSetData + messages []string + allMessages []string + ) + active := true + for active { + raw := retmsg.Message(ctx) + switch msg := raw.(type) { + case sqlexp.MsgNotice: + text := strings.TrimSpace(fmt.Sprint(msg.Message)) + if text != "" { + messages = append(messages, text) + allMessages = append(allMessages, text) + } + case sqlexp.MsgNext: + data, cols, err := scanRows(rows) + if err != nil { + return resultSets, messages, err + } + if data == nil { + data = []map[string]interface{}{} + } + if cols == nil { + cols = []string{} + } + resultSets = append(resultSets, connection.ResultSetData{ + Rows: data, + Columns: cols, + Messages: append([]string(nil), messages...), + }) + messages = nil + case sqlexp.MsgRowsAffected: + resultSets = append(resultSets, connection.ResultSetData{ + Rows: []map[string]interface{}{{"affectedRows": msg.Count}}, + Columns: []string{"affectedRows"}, + Messages: append([]string(nil), messages...), + }) + messages = nil + case sqlexp.MsgNextResultSet: + active = rows.NextResultSet() + case sqlexp.MsgError: + return resultSets, messages, msg.Error + default: + active = false + } + } + + if len(messages) > 0 { + resultSets = append(resultSets, connection.ResultSetData{ + Rows: []map[string]interface{}{}, + Columns: []string{}, + Messages: append([]string(nil), messages...), + }) + } + if len(resultSets) == 0 { + resultSets = []connection.ResultSetData{{ + Rows: []map[string]interface{}{}, + Columns: []string{}, + }} + } + if err := rows.Err(); err != nil { + return resultSets, allMessages, err + } + return resultSets, allMessages, nil +} + // quoteBracket escapes ] in identifiers for safe use in SQL Server [bracket] notation func quoteBracket(name string) string { return strings.ReplaceAll(name, "]", "]]") @@ -133,54 +213,76 @@ func (s *SqlServerDB) Ping() error { } func (s *SqlServerDB) QueryMulti(query string) ([]connection.ResultSetData, error) { + results, _, err := s.QueryMultiWithMessages(query) + return results, err +} + +func (s *SqlServerDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) { if s.conn == nil { - return nil, fmt.Errorf("连接未打开") + return nil, nil, fmt.Errorf("连接未打开") } - rows, err := s.conn.Query(query) + ctx := context.Background() + retmsg := &sqlexp.ReturnMessage{} + rows, err := s.conn.QueryContext(ctx, query, retmsg) if err != nil { - return nil, err + return nil, nil, err } defer rows.Close() - return scanMultiRows(rows) + return scanSQLServerRowsWithMessages(ctx, rows, retmsg) } func (s *SqlServerDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + results, _, err := s.QueryMultiContextWithMessages(ctx, query) + return results, err +} + +func (s *SqlServerDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) { if s.conn == nil { - return nil, fmt.Errorf("连接未打开") + return nil, nil, fmt.Errorf("连接未打开") } - rows, err := s.conn.QueryContext(ctx, query) + retmsg := &sqlexp.ReturnMessage{} + rows, err := s.conn.QueryContext(ctx, query, retmsg) if err != nil { - return nil, err + return nil, nil, err } defer rows.Close() - return scanMultiRows(rows) + return scanSQLServerRowsWithMessages(ctx, rows, retmsg) } func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + rows, columns, _, err := s.QueryContextWithMessages(ctx, query) + return rows, columns, err +} + +func (s *SqlServerDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { if s.conn == nil { - return nil, nil, fmt.Errorf("连接未打开") + return nil, nil, nil, fmt.Errorf("连接未打开") } - rows, err := s.conn.QueryContext(ctx, query) + resultSets, messages, err := s.QueryMultiContextWithMessages(ctx, query) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - defer rows.Close() - - return scanRows(rows) + if len(resultSets) == 0 { + return []map[string]interface{}{}, []string{}, messages, nil + } + first := resultSets[0] + if first.Rows == nil { + first.Rows = []map[string]interface{}{} + } + if first.Columns == nil { + first.Columns = []string{} + } + return first.Rows, first.Columns, messages, nil } func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, error) { - if s.conn == nil { - return nil, nil, fmt.Errorf("连接未打开") - } + rows, columns, _, err := s.QueryWithMessages(query) + return rows, columns, err +} - rows, err := s.conn.Query(query) - if err != nil { - return nil, nil, err - } - defer rows.Close() - return scanRows(rows) +func (s *SqlServerDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return s.QueryContextWithMessages(context.Background(), query) } func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, error) { @@ -213,7 +315,7 @@ func (s *SqlServerDB) OpenSessionExecer(ctx context.Context) (StatementExecer, e if err != nil { return nil, err } - return NewSQLConnStatementExecer(conn), nil + return &sqlServerSessionExecer{conn: conn}, nil } func (s *SqlServerDB) Exec(query string) (int64, error) { @@ -227,6 +329,87 @@ func (s *SqlServerDB) Exec(query string) (int64, error) { return res.RowsAffected() } +func (e *sqlServerSessionExecer) Exec(query string) (int64, error) { + return e.ExecContext(context.Background(), query) +} + +func (e *sqlServerSessionExecer) ExecContext(ctx context.Context, query string) (int64, error) { + if e == nil || e.conn == nil { + return 0, fmt.Errorf("连接未打开") + } + res, err := e.conn.ExecContext(ctx, query) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (e *sqlServerSessionExecer) Query(query string) ([]map[string]interface{}, []string, error) { + rows, columns, _, err := e.QueryWithMessages(query) + return rows, columns, err +} + +func (e *sqlServerSessionExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + rows, columns, _, err := e.QueryContextWithMessages(ctx, query) + return rows, columns, err +} + +func (e *sqlServerSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return e.QueryContextWithMessages(context.Background(), query) +} + +func (e *sqlServerSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + results, messages, err := e.QueryMultiContextWithMessages(ctx, query) + if err != nil { + return nil, nil, nil, err + } + if len(results) == 0 { + return []map[string]interface{}{}, []string{}, messages, nil + } + first := results[0] + if first.Rows == nil { + first.Rows = []map[string]interface{}{} + } + if first.Columns == nil { + first.Columns = []string{} + } + return first.Rows, first.Columns, messages, nil +} + +func (e *sqlServerSessionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) { + results, _, err := e.QueryMultiWithMessages(query) + return results, err +} + +func (e *sqlServerSessionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + results, _, err := e.QueryMultiContextWithMessages(ctx, query) + return results, err +} + +func (e *sqlServerSessionExecer) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) { + return e.QueryMultiContextWithMessages(context.Background(), query) +} + +func (e *sqlServerSessionExecer) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) { + if e == nil || e.conn == nil { + return nil, nil, fmt.Errorf("连接未打开") + } + retmsg := &sqlexp.ReturnMessage{} + rows, err := e.conn.QueryContext(ctx, query, retmsg) + if err != nil { + return nil, nil, err + } + defer rows.Close() + return scanSQLServerRowsWithMessages(ctx, rows, retmsg) +} + +func (e *sqlServerSessionExecer) Close() error { + if e == nil || e.conn == nil { + return nil + } + return e.conn.Close() +} + func (s *SqlServerDB) GetDatabases() ([]string, error) { query := "SELECT name FROM sys.databases WHERE state_desc = 'ONLINE' ORDER BY name" data, _, err := s.Query(query)