mirror of
https://github.com/Syngnat/GoNavi.git
synced 2026-06-14 18:39:54 +08:00
🐛 fix(sql-editor): 修复结果消息展示与数据目录迁移稳定性
This commit is contained in:
@@ -19,14 +19,20 @@ describe('ConnectionModal edit password behavior', () => {
|
||||
|
||||
describe('ConnectionModal data source registry', () => {
|
||||
it('exposes Elasticsearch in the create-connection picker with HTTP defaults', () => {
|
||||
expect(source).toContain('case "elasticsearch":\n return 9200;');
|
||||
expect(source).toContain('case "elasticsearch":');
|
||||
expect(source).toContain('return 9200;');
|
||||
expect(source).toContain('elasticsearch: ["http", "https"]');
|
||||
expect(source).toContain('key: "elasticsearch"');
|
||||
expect(source).toContain('name: "Elasticsearch"');
|
||||
expect(source).toContain('getDbIcon("elasticsearch", undefined, 36)');
|
||||
expect(source).toContain('type === "elasticsearch"');
|
||||
expect(source).toContain('"http://elastic:pass@127.0.0.1:9200/logs-*"');
|
||||
expect(source).toContain('label="默认索引(可选)"');
|
||||
expect(source).toContain('"显示索引 (留空显示全部)"');
|
||||
expect(source).toContain('return "支持索引浏览、Mapping 检查、JSON DSL 和 query_string 查询";');
|
||||
expect(source).toContain(
|
||||
'type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch") ? "" : "root";',
|
||||
);
|
||||
expect(source).toContain(
|
||||
'placeholder={dbType === "elasticsearch" ? "未开启认证可留空" : undefined}',
|
||||
);
|
||||
expect(source).toContain('label="显示数据库 (留空显示全部)"');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -234,9 +234,31 @@ vi.mock('antd', () => {
|
||||
</section>
|
||||
) : null
|
||||
);
|
||||
Modal.useModal = () => [{ info: vi.fn(() => ({ destroy: vi.fn() })) }, null];
|
||||
Modal.useModal = () => {
|
||||
const [infoConfig, setInfoConfig] = React.useState<any>(null);
|
||||
return [{
|
||||
info: vi.fn((config: any) => {
|
||||
setInfoConfig(config);
|
||||
return {
|
||||
destroy: vi.fn(() => {
|
||||
setInfoConfig(null);
|
||||
}),
|
||||
};
|
||||
}),
|
||||
}, infoConfig ? <section data-modal-use-holder="true">{infoConfig.content}</section> : null];
|
||||
};
|
||||
|
||||
const passthrough = ({ children }: any) => <>{children}</>;
|
||||
const Dropdown = ({ children, menu, disabled }: any) => (
|
||||
<>
|
||||
{children}
|
||||
{!disabled && menu?.items?.map((item: any) => (
|
||||
item?.type === 'divider'
|
||||
? null
|
||||
: <button key={item.key} type="button" disabled={item.disabled} onClick={item.onClick}>{item.label}</button>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
const Space = ({ children }: any) => <div>{children}</div>;
|
||||
const Tabs = ({ items = [], activeKey, onChange }: any) => {
|
||||
const resolvedActiveKey = activeKey ?? items[0]?.key;
|
||||
@@ -289,7 +311,7 @@ vi.mock('antd', () => {
|
||||
message: messageApi,
|
||||
Input,
|
||||
Button,
|
||||
Dropdown: passthrough,
|
||||
Dropdown,
|
||||
Form,
|
||||
Pagination: () => null,
|
||||
Select: ({ children }: any) => <div>{children}</div>,
|
||||
@@ -806,6 +828,49 @@ describe('DataGrid DDL interactions', () => {
|
||||
renderer!.unmount();
|
||||
});
|
||||
|
||||
it('exports query-result rows from in-memory data without rerunning ExportQuery', async () => {
|
||||
backendApp.ExportData.mockResolvedValue({ success: true });
|
||||
backendApp.ExportQuery.mockResolvedValue({ success: true });
|
||||
|
||||
let renderer: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(
|
||||
<DataGrid
|
||||
data={[
|
||||
{ __gonavi_row_key__: 'row-1', owner: 'sa' },
|
||||
{ __gonavi_row_key__: 'row-2', owner: 'dbo' },
|
||||
]}
|
||||
columnNames={['owner']}
|
||||
loading={false}
|
||||
exportScope="queryResult"
|
||||
resultSql="EXEC sp_helpdb"
|
||||
dbName="master"
|
||||
connectionId="conn-1"
|
||||
/>,
|
||||
);
|
||||
});
|
||||
await waitForEffects();
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, 'HTML').props.onClick();
|
||||
});
|
||||
|
||||
const exportAllButton = findButton(renderer!, '全部导出');
|
||||
await act(async () => {
|
||||
await exportAllButton.props.onClick();
|
||||
});
|
||||
await waitForEffects();
|
||||
|
||||
expect(backendApp.ExportData).toHaveBeenCalledTimes(1);
|
||||
expect(backendApp.ExportData).toHaveBeenCalledWith(
|
||||
[{ owner: 'sa' }, { owner: 'dbo' }],
|
||||
['owner'],
|
||||
'export',
|
||||
'html',
|
||||
);
|
||||
expect(backendApp.ExportQuery).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('copies loaded column data from the v2 column header context menu', async () => {
|
||||
storeState.appearance.uiVersion = 'v2';
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@ describe('DatabaseIcons', () => {
|
||||
it('includes Elasticsearch in the selectable database icons', () => {
|
||||
expect(DB_ICON_TYPES).toContain('elasticsearch');
|
||||
expect(getDbIconLabel('elasticsearch')).toBe('Elasticsearch');
|
||||
expect(renderToStaticMarkup(<>{getDbIcon('elasticsearch', undefined, 22)}</>)).toContain('ES');
|
||||
const markup = renderToStaticMarkup(<>{getDbIcon('elasticsearch', undefined, 22)}</>);
|
||||
expect(markup).toContain('elasticsearch.svg');
|
||||
expect(markup).toContain('alt="elasticsearch"');
|
||||
});
|
||||
|
||||
it('wraps database icons in a consistent frame for sidebar sizing', () => {
|
||||
|
||||
@@ -28,6 +28,13 @@ const baseState = {
|
||||
],
|
||||
jvmDiagnosticDrafts: {},
|
||||
jvmDiagnosticOutputs: {},
|
||||
fontSize: 14,
|
||||
appearance: {
|
||||
uiVersion: "legacy",
|
||||
dataTableFontSize: 14,
|
||||
dataTableFontSizeFollowGlobal: true,
|
||||
customMonoFontFamily: "",
|
||||
},
|
||||
setJVMDiagnosticDraft: vi.fn(),
|
||||
appendJVMDiagnosticOutput: vi.fn(),
|
||||
clearJVMDiagnosticOutput: vi.fn(),
|
||||
@@ -62,6 +69,7 @@ const mockMonaco = {
|
||||
KeyMod: { CtrlCmd: 2048 },
|
||||
KeyCode: { Enter: 3 },
|
||||
editor: {
|
||||
defineTheme: vi.fn(),
|
||||
setTheme: vi.fn(),
|
||||
},
|
||||
languages: {
|
||||
@@ -193,6 +201,7 @@ describe("JVMDiagnosticConsole", () => {
|
||||
removeEventListener: vi.fn(),
|
||||
};
|
||||
mockMonaco.editor.setTheme.mockClear();
|
||||
mockMonaco.editor.defineTheme.mockClear();
|
||||
mockMonaco.languages.register.mockClear();
|
||||
mockMonaco.languages.registerCompletionItemProvider.mockClear();
|
||||
mockEditor.addCommand.mockClear();
|
||||
|
||||
@@ -29,6 +29,13 @@ const storeState = vi.hoisted(() => ({
|
||||
aiPanelVisible: false,
|
||||
setAIPanelVisible: vi.fn(),
|
||||
theme: "light",
|
||||
fontSize: 14,
|
||||
appearance: {
|
||||
uiVersion: "legacy",
|
||||
dataTableFontSize: 14,
|
||||
dataTableFontSizeFollowGlobal: true,
|
||||
customMonoFontFamily: "",
|
||||
},
|
||||
}));
|
||||
|
||||
const backendApp = vi.hoisted(() => ({
|
||||
|
||||
@@ -80,6 +80,10 @@ const dataGridState = vi.hoisted(() => ({
|
||||
latestProps: null as any,
|
||||
}));
|
||||
|
||||
const tabsState = vi.hoisted(() => ({
|
||||
activeKey: undefined as string | undefined,
|
||||
}));
|
||||
|
||||
const autoFetchState = vi.hoisted(() => ({
|
||||
visible: false,
|
||||
}));
|
||||
@@ -345,11 +349,24 @@ vi.mock('antd', () => {
|
||||
),
|
||||
Tooltip: ({ children }: any) => <>{children}</>,
|
||||
Select: () => null,
|
||||
Tabs: ({ activeKey, items }: any) => {
|
||||
const activeItem = items?.find((item: any) => item.key === activeKey) || items?.[0];
|
||||
Tabs: ({ activeKey, items, onChange }: any) => {
|
||||
const resolvedActiveKey = tabsState.activeKey ?? activeKey ?? items?.[0]?.key;
|
||||
const activeItem = items?.find((item: any) => item.key === resolvedActiveKey) || items?.[0];
|
||||
return (
|
||||
<div>
|
||||
<div>{items?.map((item: any) => <span key={item.key}>{item.label}</span>)}</div>
|
||||
<div>{items?.map((item: any) => (
|
||||
<button
|
||||
key={item.key}
|
||||
type="button"
|
||||
data-tab-key={item.key}
|
||||
onClick={() => {
|
||||
tabsState.activeKey = item.key;
|
||||
onChange?.(item.key);
|
||||
}}
|
||||
>
|
||||
{item.label}
|
||||
</button>
|
||||
))}</div>
|
||||
<div>{activeItem?.children}</div>
|
||||
</div>
|
||||
);
|
||||
@@ -423,6 +440,7 @@ describe('QueryEditor external SQL save', () => {
|
||||
storeState.appearance.uiVersion = 'legacy';
|
||||
autoFetchState.visible = false;
|
||||
dataGridState.latestProps = null;
|
||||
tabsState.activeKey = undefined;
|
||||
editorState.value = '';
|
||||
editorState.position = { lineNumber: 1, column: 1 };
|
||||
editorState.selection = null;
|
||||
@@ -1788,6 +1806,198 @@ describe('QueryEditor external SQL save', () => {
|
||||
renderer?.unmount();
|
||||
});
|
||||
|
||||
it('renders result grid for sqlserver exec statements that return rows', async () => {
|
||||
storeState.connections[0].config.type = 'sqlserver';
|
||||
storeState.connections[0].config.database = 'master';
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{ columns: ['SPID', 'STATUS'], rows: [{ SPID: 52, STATUS: 'RUNNABLE' }] }],
|
||||
});
|
||||
|
||||
let renderer!: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'EXEC sp_who2' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(textContent(renderer!.toJSON())).toContain('结果 1');
|
||||
expect(textContent(renderer!.toJSON())).not.toContain('影响行数:');
|
||||
expect(dataGridState.latestProps?.columnNames).toEqual(['SPID', 'STATUS']);
|
||||
expect(Array.isArray(dataGridState.latestProps?.data)).toBe(true);
|
||||
expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ SPID: 52, STATUS: 'RUNNABLE' });
|
||||
});
|
||||
|
||||
it('renders standalone message result for sqlserver statistics statements', async () => {
|
||||
storeState.connections[0].config.type = 'sqlserver';
|
||||
storeState.connections[0].config.database = 'master';
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [{
|
||||
columns: [],
|
||||
rows: [],
|
||||
messages: ["Table 'users'. Scan count 1, logical reads 3."],
|
||||
}],
|
||||
});
|
||||
|
||||
let renderer!: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'SET STATISTICS IO ON;' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(textContent(renderer!.toJSON())).toContain('消息 1');
|
||||
expect(textContent(renderer!.toJSON())).toContain("Table 'users'. Scan count 1, logical reads 3.");
|
||||
expect(dataGridState.latestProps?.columnNames).not.toEqual([]);
|
||||
});
|
||||
|
||||
it('keeps multiple result sets from a single sqlserver statement', async () => {
|
||||
storeState.connections[0].config.type = 'sqlserver';
|
||||
storeState.connections[0].config.database = 'master';
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [
|
||||
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] },
|
||||
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] },
|
||||
],
|
||||
});
|
||||
|
||||
let renderer!: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'EXEC sp_helpdb' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(textContent(renderer!.toJSON())).toContain('结果 1');
|
||||
expect(textContent(renderer!.toJSON())).toContain('结果 2');
|
||||
expect(dataGridState.latestProps?.columnNames).toEqual(['name']);
|
||||
});
|
||||
|
||||
it('keeps both tabs when rerunning the same single sqlserver statement with multiple result sets', async () => {
|
||||
storeState.connections[0].config.type = 'sqlserver';
|
||||
storeState.connections[0].config.database = 'master';
|
||||
backendApp.DBQueryMulti
|
||||
.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [
|
||||
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] },
|
||||
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] },
|
||||
],
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [
|
||||
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'tempdb' }] },
|
||||
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'dbo' }] },
|
||||
],
|
||||
});
|
||||
|
||||
let renderer!: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'EXEC sp_helpdb' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
const tabLabels = renderer!.root.findAll((node) => {
|
||||
const className = String(node.props?.className || '');
|
||||
return className.includes('query-result-tab-label');
|
||||
});
|
||||
expect(tabLabels).toHaveLength(2);
|
||||
expect(dataGridState.latestProps?.columnNames).toEqual(['name']);
|
||||
expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ name: 'tempdb' });
|
||||
});
|
||||
|
||||
it('reloads the active secondary result set for a single sqlserver statement', async () => {
|
||||
storeState.connections[0].config.type = 'sqlserver';
|
||||
storeState.connections[0].config.database = 'master';
|
||||
backendApp.DBQueryMulti
|
||||
.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [
|
||||
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] },
|
||||
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] },
|
||||
],
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
success: true,
|
||||
data: [
|
||||
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] },
|
||||
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'dbo' }] },
|
||||
],
|
||||
});
|
||||
|
||||
let renderer!: ReactTestRenderer;
|
||||
await act(async () => {
|
||||
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'EXEC sp_helpdb' })} />);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await findButton(renderer!, '运行').props.onClick();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
const resultTabButtons = renderer!.root.findAll((node) => node.type === 'button' && node.props['data-tab-key']);
|
||||
expect(resultTabButtons).toHaveLength(2);
|
||||
|
||||
await act(async () => {
|
||||
resultTabButtons[1].props.onClick();
|
||||
});
|
||||
|
||||
expect(dataGridState.latestProps?.columnNames).toEqual(['owner']);
|
||||
expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ owner: 'sa' });
|
||||
|
||||
await act(async () => {
|
||||
await dataGridState.latestProps?.onReload?.();
|
||||
});
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(backendApp.DBQueryMulti).toHaveBeenCalledTimes(2);
|
||||
expect(dataGridState.latestProps?.columnNames).toEqual(['owner']);
|
||||
expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ owner: 'dbo' });
|
||||
expect(dataGridState.latestProps?.data).not.toEqual(expect.arrayContaining([expect.objectContaining({ name: 'master' })]));
|
||||
});
|
||||
|
||||
it('keeps non-Oracle query results read-only when no safe locator exists', async () => {
|
||||
backendApp.DBQueryMulti.mockResolvedValueOnce({
|
||||
success: true,
|
||||
|
||||
@@ -1833,8 +1833,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
key: string;
|
||||
sql: string;
|
||||
exportSql?: string;
|
||||
sourceStatementIndex?: number;
|
||||
statementResultIndex?: number;
|
||||
rows: any[];
|
||||
columns: string[];
|
||||
messages?: string[];
|
||||
resultType?: 'grid' | 'message';
|
||||
tableName?: string;
|
||||
pkColumns: string[];
|
||||
editLocator?: EditRowLocator;
|
||||
@@ -3924,6 +3928,13 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
return selected;
|
||||
};
|
||||
|
||||
const buildResultSetMergeKey = (result: ResultSet): string => {
|
||||
const sqlKey = normalizeExecutedSqlKey(result.exportSql || result.sql);
|
||||
const sourceStatementIndex = Number(result.sourceStatementIndex || 1);
|
||||
const statementResultIndex = Number(result.statementResultIndex || 1);
|
||||
return `${sqlKey}::${sourceStatementIndex}::${statementResultIndex}`;
|
||||
};
|
||||
|
||||
const mergeResultSets = (previous: ResultSet[], next: ResultSet[], replaceAll: boolean): ResultSet[] => {
|
||||
if (replaceAll || previous.length === 0) {
|
||||
return next.map((result, index) => ({ ...result, key: `result-${index + 1}` }));
|
||||
@@ -3931,8 +3942,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
|
||||
const merged = [...previous];
|
||||
next.forEach((result) => {
|
||||
const incomingKey = normalizeExecutedSqlKey(result.exportSql || result.sql);
|
||||
const existingIndex = merged.findIndex((item) => normalizeExecutedSqlKey(item.exportSql || item.sql) === incomingKey);
|
||||
const incomingKey = buildResultSetMergeKey(result);
|
||||
const existingIndex = merged.findIndex((item) => buildResultSetMergeKey(item) === incomingKey);
|
||||
if (existingIndex >= 0) {
|
||||
merged[existingIndex] = { ...result, key: merged[existingIndex].key };
|
||||
return;
|
||||
@@ -3947,8 +3958,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
if (!firstExecutedResult) {
|
||||
return '';
|
||||
}
|
||||
const executedSqlKey = normalizeExecutedSqlKey(firstExecutedResult.exportSql || firstExecutedResult.sql);
|
||||
return merged.find((item) => normalizeExecutedSqlKey(item.exportSql || item.sql) === executedSqlKey)?.key
|
||||
const executedSqlKey = buildResultSetMergeKey(firstExecutedResult);
|
||||
return merged.find((item) => buildResultSetMergeKey(item) === executedSqlKey)?.key
|
||||
|| firstExecutedResult.key
|
||||
|| merged[0]?.key
|
||||
|| '';
|
||||
@@ -4024,6 +4035,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
if (!sql?.trim() || !currentDb) return;
|
||||
const conn = connections.find(c => c.id === currentConnectionId);
|
||||
if (!conn) return;
|
||||
const currentResult = resultSets.find((item) => item.key === resultKey);
|
||||
const statementResultIndex = Math.max(1, Number(currentResult?.statementResultIndex || 1));
|
||||
|
||||
const config = {
|
||||
...conn.config,
|
||||
@@ -4049,10 +4062,9 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
return;
|
||||
}
|
||||
|
||||
// 取第一个结果集(单条 SQL 只有一个结果集)
|
||||
const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : [];
|
||||
if (resultSetDataArray.length === 0) return;
|
||||
const rsData = resultSetDataArray[0];
|
||||
const rsData = resultSetDataArray[Math.max(0, statementResultIndex - 1)];
|
||||
if (!rsData) return;
|
||||
const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1
|
||||
&& rsData.columns && rsData.columns.length === 1
|
||||
&& rsData.columns[0] === 'affectedRows';
|
||||
@@ -4075,7 +4087,16 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
// 只更新匹配的结果集的 rows 和 columns,保留 tableName/pkColumns/readOnly 等元数据
|
||||
setResultSets(prev => prev.map(rs =>
|
||||
rs.key === resultKey
|
||||
? { ...rs, rows, columns: cols, truncated }
|
||||
? {
|
||||
...rs,
|
||||
rows,
|
||||
columns: cols,
|
||||
messages: Array.isArray(rsData.messages) ? rsData.messages : [],
|
||||
resultType: ((!Array.isArray(rsData.rows) || rsData.rows.length === 0) && (!Array.isArray(rsData.columns) || rsData.columns.length === 0) && Array.isArray(rsData.messages) && rsData.messages.length > 0)
|
||||
? 'message'
|
||||
: 'grid',
|
||||
truncated,
|
||||
}
|
||||
: rs
|
||||
));
|
||||
} catch (err: any) {
|
||||
@@ -4240,12 +4261,29 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: rawStatement,
|
||||
sourceStatementIndex: idx + 1,
|
||||
statementResultIndex: 1,
|
||||
rows,
|
||||
columns: cols,
|
||||
messages: Array.isArray(res.messages) ? res.messages : [],
|
||||
pkColumns: [],
|
||||
readOnly: true,
|
||||
truncated
|
||||
});
|
||||
} else if (Array.isArray(res.messages) && res.messages.length > 0) {
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: rawStatement,
|
||||
sourceStatementIndex: idx + 1,
|
||||
statementResultIndex: 1,
|
||||
rows: [],
|
||||
columns: [],
|
||||
messages: res.messages,
|
||||
resultType: 'message',
|
||||
pkColumns: [],
|
||||
readOnly: true,
|
||||
});
|
||||
} else {
|
||||
const affected = Number((res.data as any)?.affectedRows);
|
||||
if (Number.isFinite(affected)) {
|
||||
@@ -4255,8 +4293,11 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
key: `result-${idx + 1}`,
|
||||
sql: rawStatement,
|
||||
exportSql: rawStatement,
|
||||
sourceStatementIndex: idx + 1,
|
||||
statementResultIndex: 1,
|
||||
rows: [row],
|
||||
columns: ['affectedRows'],
|
||||
messages: Array.isArray(res.messages) ? res.messages : [],
|
||||
pkColumns: [],
|
||||
readOnly: true
|
||||
});
|
||||
@@ -4373,12 +4414,17 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
const nextResultSets: ResultSet[] = [];
|
||||
const maxRows = Number(queryOptions?.maxRows) || 0;
|
||||
let anyTruncated = false;
|
||||
const statementResultCounts = new Map<number, number>();
|
||||
|
||||
for (let idx = 0; idx < resultSetDataArray.length; idx++) {
|
||||
const rsData = resultSetDataArray[idx];
|
||||
const plan = executablePlans[idx];
|
||||
const sourceStatementIndex = Number(rsData?.statementIndex || idx + 1);
|
||||
const statementResultIndex = (statementResultCounts.get(sourceStatementIndex) || 0) + 1;
|
||||
statementResultCounts.set(sourceStatementIndex, statementResultIndex);
|
||||
const plan = executablePlans[Math.max(0, sourceStatementIndex - 1)];
|
||||
const originalSql = plan?.originalSql || '';
|
||||
const executedSql = plan?.executedSql || originalSql;
|
||||
const resultMessages = Array.isArray(rsData?.messages) ? rsData.messages : [];
|
||||
|
||||
// 检查是否为 affectedRows 类结果集
|
||||
const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1
|
||||
@@ -4393,11 +4439,28 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
key: `result-${idx + 1}`,
|
||||
sql: executedSql,
|
||||
exportSql: originalSql,
|
||||
sourceStatementIndex,
|
||||
statementResultIndex,
|
||||
rows: [row],
|
||||
columns: ['affectedRows'],
|
||||
messages: resultMessages,
|
||||
pkColumns: [],
|
||||
readOnly: true
|
||||
});
|
||||
} else if ((!Array.isArray(rsData.rows) || rsData.rows.length === 0) && (!Array.isArray(rsData.columns) || rsData.columns.length === 0) && resultMessages.length > 0) {
|
||||
nextResultSets.push({
|
||||
key: `result-${idx + 1}`,
|
||||
sql: executedSql,
|
||||
exportSql: originalSql,
|
||||
sourceStatementIndex,
|
||||
statementResultIndex,
|
||||
rows: [],
|
||||
columns: [],
|
||||
messages: resultMessages,
|
||||
resultType: 'message',
|
||||
pkColumns: [],
|
||||
readOnly: true,
|
||||
});
|
||||
} else {
|
||||
let rows = Array.isArray(rsData.rows) ? rsData.rows : [];
|
||||
let truncated = false;
|
||||
@@ -4421,8 +4484,11 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
key: `result-${idx + 1}`,
|
||||
sql: executedSql,
|
||||
exportSql: originalSql,
|
||||
sourceStatementIndex,
|
||||
statementResultIndex,
|
||||
rows,
|
||||
columns: cols,
|
||||
messages: resultMessages,
|
||||
tableName: tableRef?.tableName,
|
||||
pkColumns: plan?.pkColumns || [],
|
||||
editLocator,
|
||||
@@ -4432,6 +4498,22 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
}
|
||||
}
|
||||
|
||||
if (resultSetDataArray.length === 0 && Array.isArray(res.messages) && res.messages.length > 0) {
|
||||
nextResultSets.push({
|
||||
key: 'result-1',
|
||||
sql: fullSQL,
|
||||
exportSql: sourceStatements.join(';\n'),
|
||||
sourceStatementIndex: 1,
|
||||
statementResultIndex: 1,
|
||||
rows: [],
|
||||
columns: [],
|
||||
messages: res.messages,
|
||||
resultType: 'message',
|
||||
pkColumns: [],
|
||||
readOnly: true,
|
||||
});
|
||||
}
|
||||
|
||||
const shouldReplaceAllResults = didExecuteWholeEditor;
|
||||
setResultSets(prev => {
|
||||
const merged = mergeResultSets(prev, nextResultSets, shouldReplaceAllResults);
|
||||
@@ -5316,9 +5398,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
}}
|
||||
>
|
||||
<Tooltip title={rs.sql}>
|
||||
<span className="query-result-tab-text">结果 {idx + 1}</span>
|
||||
<span className="query-result-tab-text">{rs.resultType === 'message' ? `消息 ${idx + 1}` : `结果 ${idx + 1}`}</span>
|
||||
</Tooltip>
|
||||
{(() => {
|
||||
if (rs.resultType === 'message') {
|
||||
return <span className="query-result-tab-count">i</span>;
|
||||
}
|
||||
const isAffected = rs.columns.length === 1 && rs.columns[0] === 'affectedRows';
|
||||
if (isAffected) {
|
||||
return <span className="query-result-tab-count">✓</span>;
|
||||
@@ -5344,6 +5429,29 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
</Dropdown>
|
||||
),
|
||||
children: (() => {
|
||||
if (rs.resultType === 'message') {
|
||||
return (
|
||||
<div className={isV2Ui ? 'gn-v2-query-success' : undefined} style={{
|
||||
flex: 1, minHeight: 0, display: 'flex', justifyContent: 'center',
|
||||
flexDirection: 'column', gap: 12, padding: 24, color: '#666', userSelect: 'text',
|
||||
overflow: 'auto',
|
||||
}}>
|
||||
<span style={{ fontSize: 14, fontWeight: 600 }}>执行消息</span>
|
||||
<div style={{
|
||||
padding: 16,
|
||||
borderRadius: 8,
|
||||
border: darkMode ? '1px solid rgba(255,255,255,0.12)' : '1px solid rgba(0,0,0,0.08)',
|
||||
background: darkMode ? 'rgba(255,255,255,0.03)' : '#fff',
|
||||
whiteSpace: 'pre-wrap',
|
||||
wordBreak: 'break-word',
|
||||
fontFamily: 'var(--gn-font-mono)',
|
||||
fontSize: 'var(--gn-font-size-mono, 13px)',
|
||||
}}>
|
||||
{(rs.messages || []).join('\n')}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
// affectedRows 类型结果集(UPDATE/INSERT/DELETE):简洁提示
|
||||
const isAffectedResult = rs.columns.length === 1 && rs.columns[0] === 'affectedRows';
|
||||
if (isAffectedResult) {
|
||||
@@ -5356,11 +5464,44 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
|
||||
<span style={{ fontSize: 36, color: '#52c41a' }}>✓</span>
|
||||
<span style={{ fontSize: 14, fontWeight: 500 }}>执行成功</span>
|
||||
<span style={{ fontSize: 13, color: '#999' }}>影响行数:{affected}</span>
|
||||
{Array.isArray(rs.messages) && rs.messages.length > 0 && (
|
||||
<div style={{
|
||||
marginTop: 8,
|
||||
maxWidth: 720,
|
||||
padding: 12,
|
||||
borderRadius: 8,
|
||||
border: darkMode ? '1px solid rgba(255,255,255,0.12)' : '1px solid rgba(0,0,0,0.08)',
|
||||
background: darkMode ? 'rgba(255,255,255,0.03)' : '#fff',
|
||||
whiteSpace: 'pre-wrap',
|
||||
wordBreak: 'break-word',
|
||||
fontFamily: 'var(--gn-font-mono)',
|
||||
fontSize: 'var(--gn-font-size-mono, 12px)',
|
||||
}}>
|
||||
{rs.messages.join('\n')}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
|
||||
{Array.isArray(rs.messages) && rs.messages.length > 0 && (
|
||||
<div style={{
|
||||
flex: '0 0 auto',
|
||||
margin: '8px 8px 0',
|
||||
padding: '10px 12px',
|
||||
borderRadius: 8,
|
||||
border: darkMode ? '1px solid rgba(255,255,255,0.12)' : '1px solid rgba(0,0,0,0.08)',
|
||||
background: darkMode ? 'rgba(255,255,255,0.03)' : '#fff',
|
||||
whiteSpace: 'pre-wrap',
|
||||
wordBreak: 'break-word',
|
||||
fontFamily: 'var(--gn-font-mono)',
|
||||
fontSize: 'var(--gn-font-size-mono, 12px)',
|
||||
color: darkMode ? '#d4d4d4' : '#666',
|
||||
}}>
|
||||
{rs.messages.join('\n')}
|
||||
</div>
|
||||
)}
|
||||
<DataGrid
|
||||
data={rs.rows}
|
||||
columnNames={rs.columns}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import React from 'react'
|
||||
import ReactDOM from 'react-dom/client'
|
||||
import App from './App'
|
||||
import PerfDataGridHarness from './dev/PerfDataGridHarness'
|
||||
// import './index.css' // Optional global styles
|
||||
|
||||
// 全局配置 dayjs 使用中文 locale,使 Ant Design 的 DatePicker/TimePicker 等组件
|
||||
@@ -299,12 +298,18 @@ if (typeof window !== 'undefined' && !(window as any).go) {
|
||||
}
|
||||
const rootNode = document.getElementById('root')!;
|
||||
const devHarnessMode = import.meta.env.DEV ? resolveDevHarnessMode() : '';
|
||||
const rootComponent = devHarnessMode === 'datagrid-perf'
|
||||
? <PerfDataGridHarness />
|
||||
: <App />;
|
||||
const renderRoot = async () => {
|
||||
let rootComponent = <App />;
|
||||
if (devHarnessMode === 'datagrid-perf') {
|
||||
const { default: PerfDataGridHarness } = await import('./dev/PerfDataGridHarness');
|
||||
rootComponent = <PerfDataGridHarness />;
|
||||
}
|
||||
|
||||
ReactDOM.createRoot(rootNode).render(
|
||||
<React.StrictMode>
|
||||
{rootComponent}
|
||||
</React.StrictMode>,
|
||||
)
|
||||
ReactDOM.createRoot(rootNode).render(
|
||||
<React.StrictMode>
|
||||
{rootComponent}
|
||||
</React.StrictMode>,
|
||||
);
|
||||
};
|
||||
|
||||
void renderRoot();
|
||||
|
||||
@@ -63,12 +63,12 @@ describe('dataSourceCapabilities', () => {
|
||||
supportsCreateDatabase: false,
|
||||
supportsRenameDatabase: false,
|
||||
supportsDropDatabase: false,
|
||||
forceReadOnlyQueryResult: true,
|
||||
forceReadOnlyQueryResult: false,
|
||||
});
|
||||
expect(getDataSourceCapabilities({ type: 'custom', driver: 'elastic' })).toMatchObject({
|
||||
type: 'elasticsearch',
|
||||
supportsQueryEditor: true,
|
||||
forceReadOnlyQueryResult: true,
|
||||
forceReadOnlyQueryResult: false,
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -806,6 +806,7 @@ export namespace connection {
|
||||
message: string;
|
||||
data: any;
|
||||
fields?: string[];
|
||||
messages?: string[];
|
||||
queryId?: string;
|
||||
|
||||
static createFrom(source: any = {}) {
|
||||
@@ -818,6 +819,7 @@ export namespace connection {
|
||||
this.message = source["message"];
|
||||
this.data = source["data"];
|
||||
this.fields = source["fields"];
|
||||
this.messages = source["messages"];
|
||||
this.queryId = source["queryId"];
|
||||
}
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -64,7 +64,7 @@ require (
|
||||
github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect
|
||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
github.com/google/flatbuffers v25.12.19+incompatible // indirect
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -17,12 +18,8 @@ import (
|
||||
"github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
|
||||
var migratableDataRootEntries = []string{
|
||||
"connections.json",
|
||||
"global_proxy.json",
|
||||
"ai_config.json",
|
||||
"sessions",
|
||||
"drivers",
|
||||
var dataRootMigrationExcludedEntries = map[string]struct{}{
|
||||
"storage_root.json": {},
|
||||
}
|
||||
|
||||
func (a *App) GetDataRootDirectoryInfo() connection.QueryResult {
|
||||
@@ -122,22 +119,43 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error {
|
||||
if sourceRoot == "" || targetRoot == "" {
|
||||
return fmt.Errorf("数据目录不能为空")
|
||||
}
|
||||
if filepath.Clean(sourceRoot) == filepath.Clean(targetRoot) {
|
||||
sourceAbs, err := filepath.Abs(sourceRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析源数据目录失败:%w", err)
|
||||
}
|
||||
targetAbs, err := filepath.Abs(targetRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析目标数据目录失败:%w", err)
|
||||
}
|
||||
if filepath.Clean(sourceAbs) == filepath.Clean(targetAbs) {
|
||||
return nil
|
||||
}
|
||||
if rel, err := filepath.Rel(sourceAbs, targetAbs); err == nil && rel != "." && rel != "" && !strings.HasPrefix(rel, "..") && !filepath.IsAbs(rel) {
|
||||
return fmt.Errorf("目标数据目录不能位于源目录内部")
|
||||
}
|
||||
sourceRoot = sourceAbs
|
||||
targetRoot = targetAbs
|
||||
if err := os.MkdirAll(targetRoot, 0o755); err != nil {
|
||||
return fmt.Errorf("创建目标数据目录失败:%w", err)
|
||||
}
|
||||
for _, name := range migratableDataRootEntries {
|
||||
entries, err := os.ReadDir(sourceRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取源数据目录失败:%w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name())
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, excluded := dataRootMigrationExcludedEntries[name]; excluded {
|
||||
continue
|
||||
}
|
||||
sourcePath := filepath.Join(sourceRoot, name)
|
||||
info, err := os.Stat(sourcePath)
|
||||
targetPath := filepath.Join(targetRoot, name)
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("读取源数据失败(%s):%w", name, err)
|
||||
}
|
||||
targetPath := filepath.Join(targetRoot, name)
|
||||
if info.IsDir() {
|
||||
if err := copyDir(sourcePath, targetPath); err != nil {
|
||||
return fmt.Errorf("迁移目录失败(%s):%w", name, err)
|
||||
@@ -148,6 +166,75 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error {
|
||||
return fmt.Errorf("迁移文件失败(%s):%w", name, err)
|
||||
}
|
||||
}
|
||||
if err := rewriteMigratedDataRootState(targetRoot); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rewriteMigratedDataRootState(targetRoot string) error {
|
||||
if err := rewriteSecurityUpdateBackupPaths(targetRoot); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rewriteSecurityUpdateBackupPaths(targetRoot string) error {
|
||||
repo := newSecurityUpdateStateRepository(targetRoot)
|
||||
marker, err := repo.readMarker()
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("读取迁移后的安全更新状态失败:%w", err)
|
||||
}
|
||||
|
||||
migrationID := strings.TrimSpace(marker.MigrationID)
|
||||
if migrationID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetBackupPath := repo.backupPath(migrationID)
|
||||
marker.BackupPath = targetBackupPath
|
||||
if err := repo.writeMarker(marker); err != nil {
|
||||
return fmt.Errorf("写入迁移后的安全更新状态失败:%w", err)
|
||||
}
|
||||
|
||||
manifestPath := repo.manifestPath(migrationID)
|
||||
manifestData, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("读取迁移后的安全更新备份清单失败:%w", err)
|
||||
}
|
||||
} else {
|
||||
var manifest securityUpdateBackupManifest
|
||||
if err := json.Unmarshal(manifestData, &manifest); err != nil {
|
||||
return fmt.Errorf("解析迁移后的安全更新备份清单失败:%w", err)
|
||||
}
|
||||
manifest.BackupPath = targetBackupPath
|
||||
if err := securityUpdateWriteJSONFile(manifestPath, manifest); err != nil {
|
||||
return fmt.Errorf("写入迁移后的安全更新备份清单失败:%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
resultPath := repo.resultPath(migrationID)
|
||||
resultData, err := os.ReadFile(resultPath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("读取迁移后的安全更新结果失败:%w", err)
|
||||
}
|
||||
} else {
|
||||
var result SecurityUpdateStatus
|
||||
if err := json.Unmarshal(resultData, &result); err != nil {
|
||||
return fmt.Errorf("解析迁移后的安全更新结果失败:%w", err)
|
||||
}
|
||||
result.BackupPath = targetBackupPath
|
||||
result.BackupAvailable = strings.TrimSpace(targetBackupPath) != ""
|
||||
if err := securityUpdateWriteJSONFile(resultPath, result); err != nil {
|
||||
return fmt.Errorf("写入迁移后的安全更新结果失败:%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -105,6 +105,36 @@ func TestMigrateDataRootContentsCopiesSecurityUpdateStateAndRewritesBackupPaths(
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateDataRootContentsToleratesMissingSecurityUpdateArtifacts(t *testing.T) {
|
||||
sourceRoot := t.TempDir()
|
||||
targetRoot := filepath.Join(t.TempDir(), "gonavi-data")
|
||||
sourceRepo := newSecurityUpdateStateRepository(sourceRoot)
|
||||
started, err := sourceRepo.StartRound(StartSecurityUpdateRequest{SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig})
|
||||
if err != nil {
|
||||
t.Fatalf("start security update round failed: %v", err)
|
||||
}
|
||||
if err := os.Remove(sourceRepo.manifestPath(started.MigrationID)); err != nil {
|
||||
t.Fatalf("remove source manifest failed: %v", err)
|
||||
}
|
||||
if err := os.Remove(sourceRepo.resultPath(started.MigrationID)); err != nil {
|
||||
t.Fatalf("remove source result failed: %v", err)
|
||||
}
|
||||
|
||||
if err := migrateDataRootContents(sourceRoot, targetRoot); err != nil {
|
||||
t.Fatalf("migrateDataRootContents should tolerate missing security update artifacts, got: %v", err)
|
||||
}
|
||||
|
||||
targetRepo := newSecurityUpdateStateRepository(targetRoot)
|
||||
targetStatus, err := targetRepo.LoadMarker()
|
||||
if err != nil {
|
||||
t.Fatalf("load migrated marker failed: %v", err)
|
||||
}
|
||||
expectedBackupPath := filepath.Join(targetRoot, securityUpdateBackupRootDirName, started.MigrationID)
|
||||
if targetStatus.BackupPath != expectedBackupPath {
|
||||
t.Fatalf("expected migrated marker backupPath %q, got %q", expectedBackupPath, targetStatus.BackupPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateDataRootContentsCopiesDailySecretsForSavedConnections(t *testing.T) {
|
||||
sourceRoot := t.TempDir()
|
||||
targetRoot := filepath.Join(t.TempDir(), "gonavi-data")
|
||||
|
||||
@@ -623,6 +623,7 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
}()
|
||||
|
||||
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
|
||||
tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query)
|
||||
|
||||
runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) {
|
||||
if q, ok := inst.(interface {
|
||||
@@ -633,6 +634,14 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
return inst.Query(query)
|
||||
}
|
||||
|
||||
runReadQueryWithMessages := func(inst db.Database) ([]map[string]interface{}, []string, []string, error) {
|
||||
if q, ok := inst.(db.QueryMessageExecer); ok {
|
||||
return q.QueryContextWithMessages(ctx, query)
|
||||
}
|
||||
data, columns, err := runReadQuery(inst)
|
||||
return data, columns, nil, err
|
||||
}
|
||||
|
||||
runExecQuery := func(inst db.Database) (int64, error) {
|
||||
if e, ok := inst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
@@ -642,8 +651,8 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
return inst.Exec(query)
|
||||
}
|
||||
|
||||
if isReadQuery {
|
||||
data, columns, err := runReadQuery(dbInst)
|
||||
if isReadQuery || tryQueryFirst {
|
||||
data, columns, messages, err := runReadQueryWithMessages(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
@@ -651,32 +660,34 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
|
||||
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
data, columns, err = runReadQuery(retryInst)
|
||||
data, columns, messages, err = runReadQueryWithMessages(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages, QueryID: queryID}
|
||||
}
|
||||
if isReadQuery {
|
||||
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns, QueryID: queryID}
|
||||
} else {
|
||||
affected, err := runExecQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
affected, err = runExecQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
|
||||
}
|
||||
|
||||
affected, err := runExecQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
if retryErr != nil {
|
||||
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error()}
|
||||
}
|
||||
affected, err = runExecQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
|
||||
}
|
||||
|
||||
// DBQueryMulti 执行可能包含多条 SQL 语句的查询,返回多个结果集。
|
||||
@@ -727,20 +738,25 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
}
|
||||
|
||||
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, error) {
|
||||
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, []string, error) {
|
||||
if !allReadOnly {
|
||||
return nil, nil // 包含写操作,走逐条执行路径
|
||||
return nil, nil, nil // 包含写操作,走逐条执行路径
|
||||
}
|
||||
if q, ok := inst.(db.MultiResultQueryMessageExecer); ok {
|
||||
return q.QueryMultiContextWithMessages(ctx, query)
|
||||
}
|
||||
if q, ok := inst.(db.MultiResultQuerierContext); ok {
|
||||
return q.QueryMultiContext(ctx, query)
|
||||
results, err := q.QueryMultiContext(ctx, query)
|
||||
return results, nil, err
|
||||
}
|
||||
if q, ok := inst.(db.MultiResultQuerier); ok {
|
||||
return q.QueryMulti(query)
|
||||
results, err := q.QueryMulti(query)
|
||||
return results, nil, err
|
||||
}
|
||||
return nil, nil // 返回 nil 表示不支持
|
||||
return nil, nil, nil // 返回 nil 表示不支持
|
||||
}
|
||||
|
||||
results, err := runMultiQuery(dbInst)
|
||||
results, resultMessages, err := runMultiQuery(dbInst)
|
||||
if err != nil && shouldRefreshCachedConnection(err) {
|
||||
if a.invalidateCachedDatabase(runConfig, err) {
|
||||
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
|
||||
@@ -748,7 +764,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
logger.Error(retryErr, "DBQueryMulti 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: retryErr.Error(), QueryID: queryID}
|
||||
}
|
||||
results, err = runMultiQuery(retryInst)
|
||||
results, resultMessages, err = runMultiQuery(retryInst)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -758,7 +774,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
|
||||
// 驱动支持多结果集,直接返回
|
||||
if results != nil {
|
||||
return connection.QueryResult{Success: true, Data: results, QueryID: queryID}
|
||||
return connection.QueryResult{Success: true, Data: results, Messages: resultMessages, QueryID: queryID}
|
||||
}
|
||||
|
||||
// 驱动不支持多结果集,回退到逐条执行
|
||||
@@ -771,13 +787,50 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
}
|
||||
|
||||
var sessionQueryTarget db.StatementQueryExecer
|
||||
var sessionQueryMessageTarget db.StatementQueryMessageExecer
|
||||
var sessionMultiQueryTarget db.StatementMultiResultQueryExecer
|
||||
var sessionMultiQueryMessageTarget db.StatementMultiResultQueryMessageExecer
|
||||
var sessionExecTarget db.StatementExecer
|
||||
var sessionBatchTarget db.BatchWriteExecer
|
||||
closeExecTarget := func() {}
|
||||
if provider, ok := dbInst.(db.SessionExecerProvider); ok {
|
||||
sessionExecer, sessionErr := provider.OpenSessionExecer(ctx)
|
||||
if sessionErr != nil {
|
||||
logger.Warnf("DBQueryMulti 打开会话级执行器失败,将回退共享连接:%s SQL片段=%q err=%v", formatConnSummary(runConfig), sqlSnippet(query), sessionErr)
|
||||
} else {
|
||||
if statementQueryExecer, ok := sessionExecer.(db.StatementQueryExecer); ok {
|
||||
sessionQueryTarget = statementQueryExecer
|
||||
}
|
||||
if statementQueryMessageExecer, ok := sessionExecer.(db.StatementQueryMessageExecer); ok {
|
||||
sessionQueryMessageTarget = statementQueryMessageExecer
|
||||
}
|
||||
if statementMultiResultQueryExecer, ok := sessionExecer.(db.StatementMultiResultQueryExecer); ok {
|
||||
sessionMultiQueryTarget = statementMultiResultQueryExecer
|
||||
}
|
||||
if statementMultiResultQueryMessageExecer, ok := sessionExecer.(db.StatementMultiResultQueryMessageExecer); ok {
|
||||
sessionMultiQueryMessageTarget = statementMultiResultQueryMessageExecer
|
||||
}
|
||||
sessionExecTarget = sessionExecer
|
||||
if batcher, ok := sessionExecer.(db.BatchWriteExecer); ok {
|
||||
sessionBatchTarget = batcher
|
||||
}
|
||||
closeExecTarget = func() {
|
||||
if err := sessionExecer.Close(); err != nil {
|
||||
logger.Warnf("DBQueryMulti 关闭会话级执行器失败:%v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
defer closeExecTarget()
|
||||
|
||||
// 全部为写操作且驱动支持批量 Exec → 一次性发送,大幅减少网络往返
|
||||
// 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动
|
||||
if !allReadOnly {
|
||||
allWrite := true
|
||||
containsPLSQLBlock := false
|
||||
for _, stmt := range statements {
|
||||
if strings.TrimSpace(stmt) != "" && isReadOnlySQLQuery(runConfig.Type, stmt) {
|
||||
if strings.TrimSpace(stmt) != "" && !isBatchableWriteSQLStatement(runConfig.Type, stmt) {
|
||||
allWrite = false
|
||||
}
|
||||
if isPLSQLBlockStatement(stmt) {
|
||||
@@ -785,7 +838,13 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
}
|
||||
if allWrite && !containsPLSQLBlock {
|
||||
if batcher, ok := dbInst.(db.BatchWriteExecer); ok {
|
||||
batcher := sessionBatchTarget
|
||||
if batcher == nil {
|
||||
if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok {
|
||||
batcher = fallbackBatcher
|
||||
}
|
||||
}
|
||||
if batcher != nil {
|
||||
affected, batchErr := batcher.ExecBatchContext(ctx, query)
|
||||
if batchErr != nil && shouldRefreshCachedConnection(batchErr) {
|
||||
if a.invalidateCachedDatabase(runConfig, batchErr) {
|
||||
@@ -823,17 +882,80 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
continue
|
||||
}
|
||||
|
||||
if isReadOnlySQLQuery(runConfig.Type, stmt) {
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt)
|
||||
tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)
|
||||
if isReadStmt || tryQueryStmtFirst {
|
||||
var (
|
||||
data []map[string]interface{}
|
||||
columns []string
|
||||
messages []string
|
||||
statementResults []connection.ResultSetData
|
||||
usedMultiResult bool
|
||||
)
|
||||
if sessionMultiQueryMessageTarget != nil {
|
||||
statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if sessionMultiQueryTarget != nil {
|
||||
statementResults, err = sessionMultiQueryTarget.QueryMultiContext(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if q, ok := dbInst.(db.MultiResultQueryMessageExecer); ok {
|
||||
statementResults, messages, err = q.QueryMultiContextWithMessages(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if q, ok := dbInst.(db.MultiResultQuerierContext); ok {
|
||||
statementResults, err = q.QueryMultiContext(ctx, stmt)
|
||||
usedMultiResult = true
|
||||
} else if q, ok := dbInst.(db.MultiResultQuerier); ok {
|
||||
statementResults, err = q.QueryMulti(stmt)
|
||||
usedMultiResult = true
|
||||
} else if sessionQueryMessageTarget != nil {
|
||||
data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt)
|
||||
} else if sessionQueryTarget != nil {
|
||||
data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt)
|
||||
} else if q, ok := dbInst.(db.QueryMessageExecer); ok {
|
||||
data, columns, messages, err = q.QueryContextWithMessages(ctx, stmt)
|
||||
} else if q, ok := dbInst.(interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}); ok {
|
||||
data, columns, err = q.QueryContext(ctx, stmt)
|
||||
} else {
|
||||
data, columns, err = dbInst.Query(stmt)
|
||||
}
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
if usedMultiResult {
|
||||
if len(statementResults) == 0 && len(messages) > 0 {
|
||||
statementResults = []connection.ResultSetData{{
|
||||
Rows: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
Messages: append([]string(nil), messages...),
|
||||
}}
|
||||
}
|
||||
for _, statementResult := range statementResults {
|
||||
if statementResult.Rows == nil {
|
||||
statementResult.Rows = []map[string]interface{}{}
|
||||
}
|
||||
if statementResult.Columns == nil {
|
||||
statementResult.Columns = []string{}
|
||||
}
|
||||
statementResult.StatementIndex = idx + 1
|
||||
resultSets = append(resultSets, statementResult)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if data == nil {
|
||||
data = make([]map[string]interface{}, 0)
|
||||
}
|
||||
if columns == nil {
|
||||
columns = []string{}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: data,
|
||||
Columns: columns,
|
||||
Messages: messages,
|
||||
StatementIndex: idx + 1,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if isReadStmt {
|
||||
logger.Error(err, "DBQueryMulti 逐条查询失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
|
||||
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
|
||||
if len(resultSets) > 0 {
|
||||
@@ -841,35 +963,30 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
if data == nil {
|
||||
data = make([]map[string]interface{}, 0)
|
||||
}
|
||||
if columns == nil {
|
||||
columns = []string{}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{Rows: data, Columns: columns})
|
||||
} else {
|
||||
var affected int64
|
||||
if e, ok := dbInst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
affected, err = e.ExecContext(ctx, stmt)
|
||||
} else {
|
||||
affected, err = dbInst.Exec(stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
|
||||
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
|
||||
if len(resultSets) > 0 {
|
||||
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": affected}},
|
||||
Columns: []string{"affectedRows"},
|
||||
})
|
||||
}
|
||||
|
||||
var affected int64
|
||||
if sessionExecTarget != nil {
|
||||
affected, err = sessionExecTarget.ExecContext(ctx, stmt)
|
||||
} else if e, ok := dbInst.(interface {
|
||||
ExecContext(context.Context, string) (int64, error)
|
||||
}); ok {
|
||||
affected, err = e.ExecContext(ctx, stmt)
|
||||
} else {
|
||||
affected, err = dbInst.Exec(stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
|
||||
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
|
||||
if len(resultSets) > 0 {
|
||||
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
|
||||
}
|
||||
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": affected}},
|
||||
Columns: []string{"affectedRows"},
|
||||
})
|
||||
}
|
||||
|
||||
if resultSets == nil {
|
||||
@@ -883,6 +1000,24 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
|
||||
return connection.QueryResult{Success: true, Data: resultSets, QueryID: queryID, Message: fallbackMsg}
|
||||
}
|
||||
|
||||
func shouldTryQueryResultFirst(dbType string, query string) bool {
|
||||
isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver")
|
||||
keyword := leadingSQLKeyword(query)
|
||||
switch keyword {
|
||||
case "exec", "execute", "call":
|
||||
return true
|
||||
case "set", "print":
|
||||
return isSQLServer
|
||||
case "dbcc":
|
||||
return isSQLServer
|
||||
default:
|
||||
if isSQLServer {
|
||||
return strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_")
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
|
||||
runConfig := normalizeRunConfig(config, dbName)
|
||||
|
||||
@@ -906,22 +1041,30 @@ func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string,
|
||||
defer cancel()
|
||||
|
||||
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
|
||||
tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query)
|
||||
|
||||
if isReadQuery {
|
||||
var data []map[string]interface{}
|
||||
var columns []string
|
||||
if q, ok := dbInst.(interface {
|
||||
if isReadQuery || tryQueryFirst {
|
||||
var (
|
||||
data []map[string]interface{}
|
||||
columns []string
|
||||
messages []string
|
||||
)
|
||||
if q, ok := dbInst.(db.QueryMessageExecer); ok {
|
||||
data, columns, messages, err = q.QueryContextWithMessages(ctx, query)
|
||||
} else if q, ok := dbInst.(interface {
|
||||
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
|
||||
}); ok {
|
||||
data, columns, err = q.QueryContext(ctx, query)
|
||||
} else {
|
||||
data, columns, err = dbInst.Query(query)
|
||||
}
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages}
|
||||
}
|
||||
if isReadQuery {
|
||||
logger.Error(err, "DBQueryIsolated 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
|
||||
return connection.QueryResult{Success: false, Message: err.Error()}
|
||||
}
|
||||
return connection.QueryResult{Success: true, Data: data, Fields: columns}
|
||||
}
|
||||
|
||||
var affected int64
|
||||
|
||||
@@ -10,10 +10,17 @@ import (
|
||||
)
|
||||
|
||||
type fakeBatchWriteDB struct {
|
||||
batchCalls int
|
||||
execCalls int
|
||||
batchCalls int
|
||||
execCalls int
|
||||
execQueries []string
|
||||
lastQuery string
|
||||
lastQuery string
|
||||
queryCalls int
|
||||
queryMap map[string][]map[string]interface{}
|
||||
fieldMap map[string][]string
|
||||
messageMap map[string][]string
|
||||
multiResult map[string][]connection.ResultSetData
|
||||
queryErr map[string]error
|
||||
session *fakeBatchWriteSession
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error {
|
||||
@@ -29,7 +36,16 @@ func (f *fakeBatchWriteDB) Ping() error {
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return nil, nil, nil
|
||||
f.queryCalls++
|
||||
if err := f.queryErr[query]; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return f.queryMap[query], f.fieldMap[query], nil
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
rows, fields, err := f.Query(query)
|
||||
return rows, fields, f.messageMap[query], err
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
|
||||
@@ -76,12 +92,143 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
f.queryCalls++
|
||||
if err := f.queryErr[query]; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return f.queryMap[query], f.fieldMap[query], nil
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
rows, fields, err := f.QueryContext(ctx, query)
|
||||
return rows, fields, f.messageMap[query], err
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||||
f.batchCalls++
|
||||
f.lastQuery = query
|
||||
return 500, nil
|
||||
}
|
||||
|
||||
func (f *fakeBatchWriteDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) {
|
||||
f.session = &fakeBatchWriteSession{parent: f}
|
||||
return f.session, nil
|
||||
}
|
||||
|
||||
type fakeBatchWriteSession struct {
|
||||
parent *fakeBatchWriteDB
|
||||
queryCalls int
|
||||
execCalls int
|
||||
batchCalls int
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return s.QueryContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
s.queryCalls++
|
||||
return s.parent.QueryContext(ctx, query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return s.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
s.queryCalls++
|
||||
return s.parent.QueryContextWithMessages(ctx, query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
return s.QueryMultiContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
if multi := s.parent.multiResult[query]; len(multi) > 0 {
|
||||
s.queryCalls++
|
||||
return cloneResultSets(multi), nil
|
||||
}
|
||||
rows, columns, err := s.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []connection.ResultSetData{{Rows: rows, Columns: columns}}, nil
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
|
||||
return s.QueryMultiContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
|
||||
if err := s.parent.queryErr[query]; err != nil {
|
||||
s.queryCalls++
|
||||
return nil, nil, err
|
||||
}
|
||||
if multi := s.parent.multiResult[query]; len(multi) > 0 {
|
||||
s.queryCalls++
|
||||
return cloneResultSets(multi), append([]string(nil), s.parent.messageMap[query]...), nil
|
||||
}
|
||||
rows, columns, messages, err := s.QueryContextWithMessages(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return []connection.ResultSetData{{
|
||||
Rows: rows,
|
||||
Columns: columns,
|
||||
Messages: append([]string(nil), messages...),
|
||||
}}, append([]string(nil), messages...), nil
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) Exec(query string) (int64, error) {
|
||||
return s.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
s.execCalls++
|
||||
return s.parent.ExecContext(ctx, query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||||
s.batchCalls++
|
||||
return s.parent.ExecBatchContext(ctx, query)
|
||||
}
|
||||
|
||||
func (s *fakeBatchWriteSession) Close() error {
|
||||
s.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneResultSets(input []connection.ResultSetData) []connection.ResultSetData {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make([]connection.ResultSetData, 0, len(input))
|
||||
for _, item := range input {
|
||||
rows := make([]map[string]interface{}, 0, len(item.Rows))
|
||||
for _, row := range item.Rows {
|
||||
if row == nil {
|
||||
rows = append(rows, nil)
|
||||
continue
|
||||
}
|
||||
rowCopy := make(map[string]interface{}, len(row))
|
||||
for key, value := range row {
|
||||
rowCopy[key] = value
|
||||
}
|
||||
rows = append(rows, rowCopy)
|
||||
}
|
||||
cloned = append(cloned, connection.ResultSetData{
|
||||
Rows: rows,
|
||||
Columns: append([]string(nil), item.Columns...),
|
||||
Messages: append([]string(nil), item.Messages...),
|
||||
StatementIndex: item.StatementIndex,
|
||||
})
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func TestDBQueryMultiKeepsOracleAnonymousBlockAsSingleStatement(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
@@ -122,6 +269,85 @@ END;`
|
||||
}
|
||||
|
||||
var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil)
|
||||
var _ db.SessionExecerProvider = (*fakeBatchWriteDB)(nil)
|
||||
var _ db.QueryMessageExecer = (*fakeBatchWriteDB)(nil)
|
||||
var _ db.StatementQueryMessageExecer = (*fakeBatchWriteSession)(nil)
|
||||
|
||||
func TestDBQueryWithCancelReturnsResultSetForExecStoredProcedure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "EXEC sp_who2"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {
|
||||
{"SPID": 52, "STATUS": "RUNNABLE"},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"SPID", "STATUS"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryWithCancel(config, "master", query, "sp-who2-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message)
|
||||
}
|
||||
rows, ok := result.Data.([]map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected []map[string]interface{}, got %T", result.Data)
|
||||
}
|
||||
if len(rows) != 1 || rows[0]["SPID"] != 52 {
|
||||
t.Fatalf("unexpected rows: %#v", rows)
|
||||
}
|
||||
if fakeDB.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryWithCancelReturnsMessagesForSQLServerQuery(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "SET STATISTICS IO ON"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {},
|
||||
},
|
||||
messageMap: map[string][]string{
|
||||
query: {"Table 'users'. Scan count 1, logical reads 3."},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryWithCancel(config, "master", query, "statistics-io-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message)
|
||||
}
|
||||
if len(result.Messages) != 1 || result.Messages[0] == "" {
|
||||
t.Fatalf("expected SQL Server messages to be returned, got %#v", result.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
@@ -168,3 +394,206 @@ func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
|
||||
t.Fatalf("expected affectedRows=500, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiPrefersResultSetForExecStoredProcedure(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "EXEC sp_who2"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {
|
||||
{"SPID": 77, "STATUS": "SUSPENDED"},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"SPID", "STATUS"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", query, "sp-who2-multi-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
|
||||
t.Fatalf("unexpected result sets: %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["SPID"]; got != 77 {
|
||||
t.Fatalf("expected SPID=77, got %#v", got)
|
||||
}
|
||||
if fakeDB.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiDoesNotBatchExecStoredProcedureAsWriteStatement(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "EXEC sp_who2"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
query: {
|
||||
{"SPID": 88, "STATUS": "RUNNING"},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
query: {"SPID", "STATUS"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", query, "sp-who2-batch-guard-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.batchCalls != 0 {
|
||||
t.Fatalf("expected stored procedure to skip batch write path, got batchCalls=%d", fakeDB.batchCalls)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
|
||||
t.Fatalf("unexpected result sets: %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["SPID"]; got != 88 {
|
||||
t.Fatalf("expected SPID=88, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiUsesPinnedSessionForSequentialFallback(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
queryMap: map[string][]map[string]interface{}{
|
||||
"SELECT 1 AS value": {
|
||||
{"value": 1},
|
||||
},
|
||||
},
|
||||
fieldMap: map[string][]string{
|
||||
"SELECT 1 AS value": {"value"},
|
||||
},
|
||||
messageMap: map[string][]string{
|
||||
"SET NOCOUNT ON": {"NOCOUNT 已开启"},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", "SET NOCOUNT ON;\nSELECT 1 AS value;", "session-fallback-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
if fakeDB.session == nil {
|
||||
t.Fatal("expected DBQueryMulti to open a pinned session for sequential fallback")
|
||||
}
|
||||
if !fakeDB.session.closed {
|
||||
t.Fatal("expected DBQueryMulti to close the pinned session")
|
||||
}
|
||||
if fakeDB.session.execCalls != 0 {
|
||||
t.Fatalf("expected SQL Server SET statement to avoid exec-only path, got execCalls=%d", fakeDB.session.execCalls)
|
||||
}
|
||||
if fakeDB.session.queryCalls != 2 {
|
||||
t.Fatalf("expected both statements to query through pinned session, got queryCalls=%d", fakeDB.session.queryCalls)
|
||||
}
|
||||
if fakeDB.queryCalls != 2 {
|
||||
t.Fatalf("expected exactly two underlying query calls, got %d", fakeDB.queryCalls)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 2 {
|
||||
t.Fatalf("expected two result sets, got %#v", resultSets)
|
||||
}
|
||||
if len(resultSets[0].Messages) != 1 || resultSets[0].Messages[0] != "NOCOUNT 已开启" {
|
||||
t.Fatalf("expected first result set to keep session message, got %#v", resultSets[0].Messages)
|
||||
}
|
||||
if got := resultSets[1].Rows[0]["value"]; got != 1 {
|
||||
t.Fatalf("expected second result set value=1, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBQueryMultiKeepsAllResultSetsFromSingleSQLServerStatement(t *testing.T) {
|
||||
originalNewDatabaseFunc := newDatabaseFunc
|
||||
t.Cleanup(func() {
|
||||
newDatabaseFunc = originalNewDatabaseFunc
|
||||
})
|
||||
|
||||
query := "EXEC sp_helpdb"
|
||||
fakeDB := &fakeBatchWriteDB{
|
||||
multiResult: map[string][]connection.ResultSetData{
|
||||
query: {
|
||||
{
|
||||
Rows: []map[string]interface{}{{"name": "master"}},
|
||||
Columns: []string{"name"},
|
||||
},
|
||||
{
|
||||
Rows: []map[string]interface{}{{"owner": "sa"}},
|
||||
Columns: []string{"owner"},
|
||||
},
|
||||
},
|
||||
},
|
||||
queryErr: map[string]error{},
|
||||
}
|
||||
newDatabaseFunc = func(dbType string) (db.Database, error) {
|
||||
return fakeDB, nil
|
||||
}
|
||||
|
||||
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
|
||||
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
|
||||
|
||||
result := app.DBQueryMulti(config, "master", query, "sp-helpdb-multi-result-test")
|
||||
if !result.Success {
|
||||
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
|
||||
}
|
||||
resultSets, ok := result.Data.([]connection.ResultSetData)
|
||||
if !ok {
|
||||
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
|
||||
}
|
||||
if len(resultSets) != 2 {
|
||||
t.Fatalf("expected two result sets, got %#v", resultSets)
|
||||
}
|
||||
if got := resultSets[0].Rows[0]["name"]; got != "master" {
|
||||
t.Fatalf("expected first result set to keep master row, got %#v", got)
|
||||
}
|
||||
if got := resultSets[1].Rows[0]["owner"]; got != "sa" {
|
||||
t.Fatalf("expected second result set to keep owner row, got %#v", got)
|
||||
}
|
||||
if resultSets[0].StatementIndex != 1 || resultSets[1].StatementIndex != 1 {
|
||||
t.Fatalf("expected both result sets to map to the first statement, got %#v", resultSets)
|
||||
}
|
||||
if fakeDB.execCalls != 0 {
|
||||
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,19 @@ func isReadOnlySQLQuery(dbType string, query string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isBatchableWriteSQLStatement(dbType string, query string) bool {
|
||||
if isReadOnlySQLQuery(dbType, query) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch leadingSQLKeyword(query) {
|
||||
case "insert", "update", "delete", "replace", "merge", "upsert":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeSQLForPgLike(dbType string, query string) string {
|
||||
normalizedType := strings.ToLower(strings.TrimSpace(dbType))
|
||||
switch normalizedType {
|
||||
|
||||
@@ -62,3 +62,42 @@ func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) {
|
||||
t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReadOnlySQLQuery_DoesNotTreatExecAsReadOnly(t *testing.T) {
|
||||
if isReadOnlySQLQuery("sqlserver", "EXEC sp_who2") {
|
||||
t.Fatal("EXEC should not be treated as read-only SQL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
|
||||
if !isBatchableWriteSQLStatement("mysql", "INSERT INTO demo(id) VALUES (1)") {
|
||||
t.Fatal("expected INSERT to be treated as batchable write")
|
||||
}
|
||||
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
|
||||
t.Fatal("EXEC should not be treated as batchable write")
|
||||
}
|
||||
if isBatchableWriteSQLStatement("sqlserver", "SET STATISTICS IO ON") {
|
||||
t.Fatal("SET STATISTICS should not be treated as batchable write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsSQLServerSetAsQueryFirst(t *testing.T) {
|
||||
if !shouldTryQueryResultFirst("sqlserver", "SET STATISTICS IO ON") {
|
||||
t.Fatal("expected SQL Server SET STATISTICS to try query-first for notice capture")
|
||||
}
|
||||
if shouldTryQueryResultFirst("mysql", "SET sql_mode = ''") {
|
||||
t.Fatal("non-SQLServer SET should not force query-first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *testing.T) {
|
||||
if !shouldTryQueryResultFirst("sqlserver", "sp_who2") {
|
||||
t.Fatal("expected bare SQL Server system procedure to try query-first")
|
||||
}
|
||||
if !shouldTryQueryResultFirst("sqlserver", "DBCC INPUTBUFFER(52)") {
|
||||
t.Fatal("expected SQL Server DBCC command to try query-first")
|
||||
}
|
||||
if shouldTryQueryResultFirst("mysql", "sp_who2") {
|
||||
t.Fatal("non-SQLServer system procedure name should not force query-first")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,8 @@ func (p *panicChromium) PutZoomFactor(float64) {
|
||||
}
|
||||
|
||||
type panicFrontend struct {
|
||||
chromium *panicChromium
|
||||
chromium *panicChromium
|
||||
mainWindow *fakeWindow
|
||||
}
|
||||
|
||||
// 测试必须用 wails 一致的 string key "frontend" 作为 context.WithValue 的 key,
|
||||
@@ -122,7 +123,10 @@ func TestResetWebViewZoomFactorErrorsWhenFrontendMissing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResetWebViewZoomFactorRecoversFromPutZoomFactorPanic(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{chromium: &panicChromium{}})
|
||||
ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{
|
||||
chromium: &panicChromium{},
|
||||
mainWindow: &fakeWindow{},
|
||||
})
|
||||
|
||||
err := resetWebViewZoomFactor(ctx, 1.0)
|
||||
if err == nil {
|
||||
|
||||
@@ -122,17 +122,20 @@ type ConnectionConfig struct {
|
||||
|
||||
// ResultSetData 表示一个查询结果集(行 + 列名),用于多结果集场景。
|
||||
type ResultSetData struct {
|
||||
Rows []map[string]interface{} `json:"rows"`
|
||||
Columns []string `json:"columns"`
|
||||
Rows []map[string]interface{} `json:"rows"`
|
||||
Columns []string `json:"columns"`
|
||||
Messages []string `json:"messages,omitempty"`
|
||||
StatementIndex int `json:"statementIndex,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。
|
||||
type QueryResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
Messages []string `json:"messages,omitempty"`
|
||||
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
|
||||
}
|
||||
|
||||
// ColumnDefinition 描述表的一个列定义。
|
||||
|
||||
@@ -67,6 +67,51 @@ type StatementExecer interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
// StatementQueryExecer can run queries on a pinned session/connection.
|
||||
// Drivers that return sqlConnStatementExecer automatically satisfy it.
|
||||
type StatementQueryExecer interface {
|
||||
StatementExecer
|
||||
Query(query string) ([]map[string]interface{}, []string, error)
|
||||
QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error)
|
||||
}
|
||||
|
||||
// StatementQueryMessageExecer can run queries on a pinned session and return
|
||||
// extra server messages/notices alongside rows.
|
||||
type StatementQueryMessageExecer interface {
|
||||
StatementQueryExecer
|
||||
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
|
||||
QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error)
|
||||
}
|
||||
|
||||
// StatementMultiResultQueryExecer can run multi-result queries on a pinned session/connection.
|
||||
type StatementMultiResultQueryExecer interface {
|
||||
StatementExecer
|
||||
QueryMulti(query string) ([]connection.ResultSetData, error)
|
||||
QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error)
|
||||
}
|
||||
|
||||
// StatementMultiResultQueryMessageExecer can run multi-result queries on a
|
||||
// pinned session/connection and return server messages/notices.
|
||||
type StatementMultiResultQueryMessageExecer interface {
|
||||
StatementMultiResultQueryExecer
|
||||
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
|
||||
QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error)
|
||||
}
|
||||
|
||||
// QueryMessageExecer is an optional database-level interface for returning
|
||||
// informational server messages alongside one result set.
|
||||
type QueryMessageExecer interface {
|
||||
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
|
||||
QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error)
|
||||
}
|
||||
|
||||
// MultiResultQueryMessageExecer is an optional database-level interface for
|
||||
// returning informational server messages alongside multi-result queries.
|
||||
type MultiResultQueryMessageExecer interface {
|
||||
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
|
||||
QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error)
|
||||
}
|
||||
|
||||
// SessionExecerProvider is implemented by database/sql based drivers that can
|
||||
// pin a long-running job to one physical connection.
|
||||
type SessionExecerProvider interface {
|
||||
@@ -96,6 +141,38 @@ func (e *sqlConnStatementExecer) Exec(query string) (int64, error) {
|
||||
return e.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := e.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
return e.QueryContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := e.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
return e.QueryMultiContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlConnStatementExecer) ExecBatchContext(ctx context.Context, query string) (int64, error) {
|
||||
return e.ExecContext(ctx, query)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"GoNavi-Wails/internal/ssh"
|
||||
"GoNavi-Wails/internal/utils"
|
||||
|
||||
"github.com/golang-sql/sqlexp"
|
||||
_ "github.com/microsoft/go-mssqldb"
|
||||
)
|
||||
|
||||
@@ -26,6 +27,85 @@ type SqlServerDB struct {
|
||||
forwarder *ssh.LocalForwarder
|
||||
}
|
||||
|
||||
type sqlServerSessionExecer struct {
|
||||
conn *sql.Conn
|
||||
}
|
||||
|
||||
func scanSQLServerRowsWithMessages(ctx context.Context, rows *sql.Rows, retmsg *sqlexp.ReturnMessage) ([]connection.ResultSetData, []string, error) {
|
||||
if rows == nil {
|
||||
return []connection.ResultSetData{{Rows: []map[string]interface{}{}, Columns: []string{}}}, nil, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
var (
|
||||
resultSets []connection.ResultSetData
|
||||
messages []string
|
||||
allMessages []string
|
||||
)
|
||||
active := true
|
||||
for active {
|
||||
raw := retmsg.Message(ctx)
|
||||
switch msg := raw.(type) {
|
||||
case sqlexp.MsgNotice:
|
||||
text := strings.TrimSpace(fmt.Sprint(msg.Message))
|
||||
if text != "" {
|
||||
messages = append(messages, text)
|
||||
allMessages = append(allMessages, text)
|
||||
}
|
||||
case sqlexp.MsgNext:
|
||||
data, cols, err := scanRows(rows)
|
||||
if err != nil {
|
||||
return resultSets, messages, err
|
||||
}
|
||||
if data == nil {
|
||||
data = []map[string]interface{}{}
|
||||
}
|
||||
if cols == nil {
|
||||
cols = []string{}
|
||||
}
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: data,
|
||||
Columns: cols,
|
||||
Messages: append([]string(nil), messages...),
|
||||
})
|
||||
messages = nil
|
||||
case sqlexp.MsgRowsAffected:
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{{"affectedRows": msg.Count}},
|
||||
Columns: []string{"affectedRows"},
|
||||
Messages: append([]string(nil), messages...),
|
||||
})
|
||||
messages = nil
|
||||
case sqlexp.MsgNextResultSet:
|
||||
active = rows.NextResultSet()
|
||||
case sqlexp.MsgError:
|
||||
return resultSets, messages, msg.Error
|
||||
default:
|
||||
active = false
|
||||
}
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
resultSets = append(resultSets, connection.ResultSetData{
|
||||
Rows: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
Messages: append([]string(nil), messages...),
|
||||
})
|
||||
}
|
||||
if len(resultSets) == 0 {
|
||||
resultSets = []connection.ResultSetData{{
|
||||
Rows: []map[string]interface{}{},
|
||||
Columns: []string{},
|
||||
}}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return resultSets, allMessages, err
|
||||
}
|
||||
return resultSets, allMessages, nil
|
||||
}
|
||||
|
||||
// quoteBracket escapes ] in identifiers for safe use in SQL Server [bracket] notation
|
||||
func quoteBracket(name string) string {
|
||||
return strings.ReplaceAll(name, "]", "]]")
|
||||
@@ -133,54 +213,76 @@ func (s *SqlServerDB) Ping() error {
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
results, _, err := s.QueryMultiWithMessages(query)
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := s.conn.Query(query)
|
||||
ctx := context.Background()
|
||||
retmsg := &sqlexp.ReturnMessage{}
|
||||
rows, err := s.conn.QueryContext(ctx, query, retmsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
results, _, err := s.QueryMultiContextWithMessages(ctx, query)
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, fmt.Errorf("连接未打开")
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
retmsg := &sqlexp.ReturnMessage{}
|
||||
rows, err := s.conn.QueryContext(ctx, query, retmsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMultiRows(rows)
|
||||
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
rows, columns, _, err := s.QueryContextWithMessages(ctx, query)
|
||||
return rows, columns, err
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
return nil, nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
|
||||
rows, err := s.conn.QueryContext(ctx, query)
|
||||
resultSets, messages, err := s.QueryMultiContextWithMessages(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows)
|
||||
if len(resultSets) == 0 {
|
||||
return []map[string]interface{}{}, []string{}, messages, nil
|
||||
}
|
||||
first := resultSets[0]
|
||||
if first.Rows == nil {
|
||||
first.Rows = []map[string]interface{}{}
|
||||
}
|
||||
if first.Columns == nil {
|
||||
first.Columns = []string{}
|
||||
}
|
||||
return first.Rows, first.Columns, messages, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
if s.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
rows, columns, _, err := s.QueryWithMessages(query)
|
||||
return rows, columns, err
|
||||
}
|
||||
|
||||
rows, err := s.conn.Query(query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanRows(rows)
|
||||
func (s *SqlServerDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return s.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
@@ -213,7 +315,7 @@ func (s *SqlServerDB) OpenSessionExecer(ctx context.Context) (StatementExecer, e
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSQLConnStatementExecer(conn), nil
|
||||
return &sqlServerSessionExecer{conn: conn}, nil
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) Exec(query string) (int64, error) {
|
||||
@@ -227,6 +329,87 @@ func (s *SqlServerDB) Exec(query string) (int64, error) {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) Exec(query string) (int64, error) {
|
||||
return e.ExecContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) ExecContext(ctx context.Context, query string) (int64, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return 0, fmt.Errorf("连接未打开")
|
||||
}
|
||||
res, err := e.conn.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) Query(query string) ([]map[string]interface{}, []string, error) {
|
||||
rows, columns, _, err := e.QueryWithMessages(query)
|
||||
return rows, columns, err
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
|
||||
rows, columns, _, err := e.QueryContextWithMessages(ctx, query)
|
||||
return rows, columns, err
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
return e.QueryContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
|
||||
results, messages, err := e.QueryMultiContextWithMessages(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return []map[string]interface{}{}, []string{}, messages, nil
|
||||
}
|
||||
first := results[0]
|
||||
if first.Rows == nil {
|
||||
first.Rows = []map[string]interface{}{}
|
||||
}
|
||||
if first.Columns == nil {
|
||||
first.Columns = []string{}
|
||||
}
|
||||
return first.Rows, first.Columns, messages, nil
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
|
||||
results, _, err := e.QueryMultiWithMessages(query)
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
|
||||
results, _, err := e.QueryMultiContextWithMessages(ctx, query)
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
|
||||
return e.QueryMultiContextWithMessages(context.Background(), query)
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil, nil, fmt.Errorf("连接未打开")
|
||||
}
|
||||
retmsg := &sqlexp.ReturnMessage{}
|
||||
rows, err := e.conn.QueryContext(ctx, query, retmsg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
|
||||
}
|
||||
|
||||
func (e *sqlServerSessionExecer) Close() error {
|
||||
if e == nil || e.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return e.conn.Close()
|
||||
}
|
||||
|
||||
func (s *SqlServerDB) GetDatabases() ([]string, error) {
|
||||
query := "SELECT name FROM sys.databases WHERE state_desc = 'ONLINE' ORDER BY name"
|
||||
data, _, err := s.Query(query)
|
||||
|
||||
Reference in New Issue
Block a user