🐛 fix(sql-editor): 修复结果消息展示与数据目录迁移稳定性

This commit is contained in:
Syngnat
2026-06-04 07:09:42 +08:00
parent 23ac30086f
commit f5166ac3fc
21 changed files with 1608 additions and 153 deletions

View File

@@ -19,14 +19,20 @@ describe('ConnectionModal edit password behavior', () => {
describe('ConnectionModal data source registry', () => {
it('exposes Elasticsearch in the create-connection picker with HTTP defaults', () => {
expect(source).toContain('case "elasticsearch":\n return 9200;');
expect(source).toContain('case "elasticsearch":');
expect(source).toContain('return 9200;');
expect(source).toContain('elasticsearch: ["http", "https"]');
expect(source).toContain('key: "elasticsearch"');
expect(source).toContain('name: "Elasticsearch"');
expect(source).toContain('getDbIcon("elasticsearch", undefined, 36)');
expect(source).toContain('type === "elasticsearch"');
expect(source).toContain('"http://elastic:pass@127.0.0.1:9200/logs-*"');
expect(source).toContain('label="默认索引(可选)"');
expect(source).toContain('"显示索引 (留空显示全部)"');
expect(source).toContain('return "支持索引浏览、Mapping 检查、JSON DSL 和 query_string 查询";');
expect(source).toContain(
'type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch") ? "" : "root";',
);
expect(source).toContain(
'placeholder={dbType === "elasticsearch" ? "未开启认证可留空" : undefined}',
);
expect(source).toContain('label="显示数据库 (留空显示全部)"');
});
});

View File

@@ -234,9 +234,31 @@ vi.mock('antd', () => {
</section>
) : null
);
Modal.useModal = () => [{ info: vi.fn(() => ({ destroy: vi.fn() })) }, null];
Modal.useModal = () => {
const [infoConfig, setInfoConfig] = React.useState<any>(null);
return [{
info: vi.fn((config: any) => {
setInfoConfig(config);
return {
destroy: vi.fn(() => {
setInfoConfig(null);
}),
};
}),
}, infoConfig ? <section data-modal-use-holder="true">{infoConfig.content}</section> : null];
};
const passthrough = ({ children }: any) => <>{children}</>;
const Dropdown = ({ children, menu, disabled }: any) => (
<>
{children}
{!disabled && menu?.items?.map((item: any) => (
item?.type === 'divider'
? null
: <button key={item.key} type="button" disabled={item.disabled} onClick={item.onClick}>{item.label}</button>
))}
</>
);
const Space = ({ children }: any) => <div>{children}</div>;
const Tabs = ({ items = [], activeKey, onChange }: any) => {
const resolvedActiveKey = activeKey ?? items[0]?.key;
@@ -289,7 +311,7 @@ vi.mock('antd', () => {
message: messageApi,
Input,
Button,
Dropdown: passthrough,
Dropdown,
Form,
Pagination: () => null,
Select: ({ children }: any) => <div>{children}</div>,
@@ -806,6 +828,49 @@ describe('DataGrid DDL interactions', () => {
renderer!.unmount();
});
it('exports query-result rows from in-memory data without rerunning ExportQuery', async () => {
backendApp.ExportData.mockResolvedValue({ success: true });
backendApp.ExportQuery.mockResolvedValue({ success: true });
let renderer: ReactTestRenderer;
await act(async () => {
renderer = create(
<DataGrid
data={[
{ __gonavi_row_key__: 'row-1', owner: 'sa' },
{ __gonavi_row_key__: 'row-2', owner: 'dbo' },
]}
columnNames={['owner']}
loading={false}
exportScope="queryResult"
resultSql="EXEC sp_helpdb"
dbName="master"
connectionId="conn-1"
/>,
);
});
await waitForEffects();
await act(async () => {
await findButton(renderer!, 'HTML').props.onClick();
});
const exportAllButton = findButton(renderer!, '全部导出');
await act(async () => {
await exportAllButton.props.onClick();
});
await waitForEffects();
expect(backendApp.ExportData).toHaveBeenCalledTimes(1);
expect(backendApp.ExportData).toHaveBeenCalledWith(
[{ owner: 'sa' }, { owner: 'dbo' }],
['owner'],
'export',
'html',
);
expect(backendApp.ExportQuery).not.toHaveBeenCalled();
});
it('copies loaded column data from the v2 column header context menu', async () => {
storeState.appearance.uiVersion = 'v2';

View File

@@ -13,7 +13,9 @@ describe('DatabaseIcons', () => {
it('includes Elasticsearch in the selectable database icons', () => {
expect(DB_ICON_TYPES).toContain('elasticsearch');
expect(getDbIconLabel('elasticsearch')).toBe('Elasticsearch');
expect(renderToStaticMarkup(<>{getDbIcon('elasticsearch', undefined, 22)}</>)).toContain('ES');
const markup = renderToStaticMarkup(<>{getDbIcon('elasticsearch', undefined, 22)}</>);
expect(markup).toContain('elasticsearch.svg');
expect(markup).toContain('alt="elasticsearch"');
});
it('wraps database icons in a consistent frame for sidebar sizing', () => {

View File

@@ -28,6 +28,13 @@ const baseState = {
],
jvmDiagnosticDrafts: {},
jvmDiagnosticOutputs: {},
fontSize: 14,
appearance: {
uiVersion: "legacy",
dataTableFontSize: 14,
dataTableFontSizeFollowGlobal: true,
customMonoFontFamily: "",
},
setJVMDiagnosticDraft: vi.fn(),
appendJVMDiagnosticOutput: vi.fn(),
clearJVMDiagnosticOutput: vi.fn(),
@@ -62,6 +69,7 @@ const mockMonaco = {
KeyMod: { CtrlCmd: 2048 },
KeyCode: { Enter: 3 },
editor: {
defineTheme: vi.fn(),
setTheme: vi.fn(),
},
languages: {
@@ -193,6 +201,7 @@ describe("JVMDiagnosticConsole", () => {
removeEventListener: vi.fn(),
};
mockMonaco.editor.setTheme.mockClear();
mockMonaco.editor.defineTheme.mockClear();
mockMonaco.languages.register.mockClear();
mockMonaco.languages.registerCompletionItemProvider.mockClear();
mockEditor.addCommand.mockClear();

View File

@@ -29,6 +29,13 @@ const storeState = vi.hoisted(() => ({
aiPanelVisible: false,
setAIPanelVisible: vi.fn(),
theme: "light",
fontSize: 14,
appearance: {
uiVersion: "legacy",
dataTableFontSize: 14,
dataTableFontSizeFollowGlobal: true,
customMonoFontFamily: "",
},
}));
const backendApp = vi.hoisted(() => ({

View File

@@ -80,6 +80,10 @@ const dataGridState = vi.hoisted(() => ({
latestProps: null as any,
}));
const tabsState = vi.hoisted(() => ({
activeKey: undefined as string | undefined,
}));
const autoFetchState = vi.hoisted(() => ({
visible: false,
}));
@@ -345,11 +349,24 @@ vi.mock('antd', () => {
),
Tooltip: ({ children }: any) => <>{children}</>,
Select: () => null,
Tabs: ({ activeKey, items }: any) => {
const activeItem = items?.find((item: any) => item.key === activeKey) || items?.[0];
Tabs: ({ activeKey, items, onChange }: any) => {
const resolvedActiveKey = tabsState.activeKey ?? activeKey ?? items?.[0]?.key;
const activeItem = items?.find((item: any) => item.key === resolvedActiveKey) || items?.[0];
return (
<div>
<div>{items?.map((item: any) => <span key={item.key}>{item.label}</span>)}</div>
<div>{items?.map((item: any) => (
<button
key={item.key}
type="button"
data-tab-key={item.key}
onClick={() => {
tabsState.activeKey = item.key;
onChange?.(item.key);
}}
>
{item.label}
</button>
))}</div>
<div>{activeItem?.children}</div>
</div>
);
@@ -423,6 +440,7 @@ describe('QueryEditor external SQL save', () => {
storeState.appearance.uiVersion = 'legacy';
autoFetchState.visible = false;
dataGridState.latestProps = null;
tabsState.activeKey = undefined;
editorState.value = '';
editorState.position = { lineNumber: 1, column: 1 };
editorState.selection = null;
@@ -1788,6 +1806,198 @@ describe('QueryEditor external SQL save', () => {
renderer?.unmount();
});
it('renders result grid for sqlserver exec statements that return rows', async () => {
storeState.connections[0].config.type = 'sqlserver';
storeState.connections[0].config.database = 'master';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{ columns: ['SPID', 'STATUS'], rows: [{ SPID: 52, STATUS: 'RUNNABLE' }] }],
});
let renderer!: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'EXEC sp_who2' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(textContent(renderer!.toJSON())).toContain('结果 1');
expect(textContent(renderer!.toJSON())).not.toContain('影响行数:');
expect(dataGridState.latestProps?.columnNames).toEqual(['SPID', 'STATUS']);
expect(Array.isArray(dataGridState.latestProps?.data)).toBe(true);
expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ SPID: 52, STATUS: 'RUNNABLE' });
});
it('renders standalone message result for sqlserver statistics statements', async () => {
storeState.connections[0].config.type = 'sqlserver';
storeState.connections[0].config.database = 'master';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [{
columns: [],
rows: [],
messages: ["Table 'users'. Scan count 1, logical reads 3."],
}],
});
let renderer!: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'SET STATISTICS IO ON;' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(textContent(renderer!.toJSON())).toContain('消息 1');
expect(textContent(renderer!.toJSON())).toContain("Table 'users'. Scan count 1, logical reads 3.");
expect(dataGridState.latestProps?.columnNames).not.toEqual([]);
});
it('keeps multiple result sets from a single sqlserver statement', async () => {
storeState.connections[0].config.type = 'sqlserver';
storeState.connections[0].config.database = 'master';
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,
data: [
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] },
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] },
],
});
let renderer!: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'EXEC sp_helpdb' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(textContent(renderer!.toJSON())).toContain('结果 1');
expect(textContent(renderer!.toJSON())).toContain('结果 2');
expect(dataGridState.latestProps?.columnNames).toEqual(['name']);
});
it('keeps both tabs when rerunning the same single sqlserver statement with multiple result sets', async () => {
storeState.connections[0].config.type = 'sqlserver';
storeState.connections[0].config.database = 'master';
backendApp.DBQueryMulti
.mockResolvedValueOnce({
success: true,
data: [
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] },
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] },
],
})
.mockResolvedValueOnce({
success: true,
data: [
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'tempdb' }] },
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'dbo' }] },
],
});
let renderer!: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'EXEC sp_helpdb' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
const tabLabels = renderer!.root.findAll((node) => {
const className = String(node.props?.className || '');
return className.includes('query-result-tab-label');
});
expect(tabLabels).toHaveLength(2);
expect(dataGridState.latestProps?.columnNames).toEqual(['name']);
expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ name: 'tempdb' });
});
it('reloads the active secondary result set for a single sqlserver statement', async () => {
storeState.connections[0].config.type = 'sqlserver';
storeState.connections[0].config.database = 'master';
backendApp.DBQueryMulti
.mockResolvedValueOnce({
success: true,
data: [
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] },
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'sa' }] },
],
})
.mockResolvedValueOnce({
success: true,
data: [
{ statementIndex: 1, columns: ['name'], rows: [{ name: 'master' }] },
{ statementIndex: 1, columns: ['owner'], rows: [{ owner: 'dbo' }] },
],
});
let renderer!: ReactTestRenderer;
await act(async () => {
renderer = create(<QueryEditor tab={createTab({ dbName: 'master', query: 'EXEC sp_helpdb' })} />);
});
await act(async () => {
await findButton(renderer!, '运行').props.onClick();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
const resultTabButtons = renderer!.root.findAll((node) => node.type === 'button' && node.props['data-tab-key']);
expect(resultTabButtons).toHaveLength(2);
await act(async () => {
resultTabButtons[1].props.onClick();
});
expect(dataGridState.latestProps?.columnNames).toEqual(['owner']);
expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ owner: 'sa' });
await act(async () => {
await dataGridState.latestProps?.onReload?.();
});
await act(async () => {
await Promise.resolve();
await Promise.resolve();
});
expect(backendApp.DBQueryMulti).toHaveBeenCalledTimes(2);
expect(dataGridState.latestProps?.columnNames).toEqual(['owner']);
expect(dataGridState.latestProps?.data?.[0]).toMatchObject({ owner: 'dbo' });
expect(dataGridState.latestProps?.data).not.toEqual(expect.arrayContaining([expect.objectContaining({ name: 'master' })]));
});
it('keeps non-Oracle query results read-only when no safe locator exists', async () => {
backendApp.DBQueryMulti.mockResolvedValueOnce({
success: true,

View File

@@ -1833,8 +1833,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
key: string;
sql: string;
exportSql?: string;
sourceStatementIndex?: number;
statementResultIndex?: number;
rows: any[];
columns: string[];
messages?: string[];
resultType?: 'grid' | 'message';
tableName?: string;
pkColumns: string[];
editLocator?: EditRowLocator;
@@ -3924,6 +3928,13 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
return selected;
};
const buildResultSetMergeKey = (result: ResultSet): string => {
const sqlKey = normalizeExecutedSqlKey(result.exportSql || result.sql);
const sourceStatementIndex = Number(result.sourceStatementIndex || 1);
const statementResultIndex = Number(result.statementResultIndex || 1);
return `${sqlKey}::${sourceStatementIndex}::${statementResultIndex}`;
};
const mergeResultSets = (previous: ResultSet[], next: ResultSet[], replaceAll: boolean): ResultSet[] => {
if (replaceAll || previous.length === 0) {
return next.map((result, index) => ({ ...result, key: `result-${index + 1}` }));
@@ -3931,8 +3942,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
const merged = [...previous];
next.forEach((result) => {
const incomingKey = normalizeExecutedSqlKey(result.exportSql || result.sql);
const existingIndex = merged.findIndex((item) => normalizeExecutedSqlKey(item.exportSql || item.sql) === incomingKey);
const incomingKey = buildResultSetMergeKey(result);
const existingIndex = merged.findIndex((item) => buildResultSetMergeKey(item) === incomingKey);
if (existingIndex >= 0) {
merged[existingIndex] = { ...result, key: merged[existingIndex].key };
return;
@@ -3947,8 +3958,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
if (!firstExecutedResult) {
return '';
}
const executedSqlKey = normalizeExecutedSqlKey(firstExecutedResult.exportSql || firstExecutedResult.sql);
return merged.find((item) => normalizeExecutedSqlKey(item.exportSql || item.sql) === executedSqlKey)?.key
const executedSqlKey = buildResultSetMergeKey(firstExecutedResult);
return merged.find((item) => buildResultSetMergeKey(item) === executedSqlKey)?.key
|| firstExecutedResult.key
|| merged[0]?.key
|| '';
@@ -4024,6 +4035,8 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
if (!sql?.trim() || !currentDb) return;
const conn = connections.find(c => c.id === currentConnectionId);
if (!conn) return;
const currentResult = resultSets.find((item) => item.key === resultKey);
const statementResultIndex = Math.max(1, Number(currentResult?.statementResultIndex || 1));
const config = {
...conn.config,
@@ -4049,10 +4062,9 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
return;
}
// 取第一个结果集(单条 SQL 只有一个结果集)
const resultSetDataArray = Array.isArray(res.data) ? (res.data as any[]) : [];
if (resultSetDataArray.length === 0) return;
const rsData = resultSetDataArray[0];
const rsData = resultSetDataArray[Math.max(0, statementResultIndex - 1)];
if (!rsData) return;
const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1
&& rsData.columns && rsData.columns.length === 1
&& rsData.columns[0] === 'affectedRows';
@@ -4075,7 +4087,16 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
// 只更新匹配的结果集的 rows 和 columns保留 tableName/pkColumns/readOnly 等元数据
setResultSets(prev => prev.map(rs =>
rs.key === resultKey
? { ...rs, rows, columns: cols, truncated }
? {
...rs,
rows,
columns: cols,
messages: Array.isArray(rsData.messages) ? rsData.messages : [],
resultType: ((!Array.isArray(rsData.rows) || rsData.rows.length === 0) && (!Array.isArray(rsData.columns) || rsData.columns.length === 0) && Array.isArray(rsData.messages) && rsData.messages.length > 0)
? 'message'
: 'grid',
truncated,
}
: rs
));
} catch (err: any) {
@@ -4240,12 +4261,29 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
sourceStatementIndex: idx + 1,
statementResultIndex: 1,
rows,
columns: cols,
messages: Array.isArray(res.messages) ? res.messages : [],
pkColumns: [],
readOnly: true,
truncated
});
} else if (Array.isArray(res.messages) && res.messages.length > 0) {
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
sourceStatementIndex: idx + 1,
statementResultIndex: 1,
rows: [],
columns: [],
messages: res.messages,
resultType: 'message',
pkColumns: [],
readOnly: true,
});
} else {
const affected = Number((res.data as any)?.affectedRows);
if (Number.isFinite(affected)) {
@@ -4255,8 +4293,11 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
sourceStatementIndex: idx + 1,
statementResultIndex: 1,
rows: [row],
columns: ['affectedRows'],
messages: Array.isArray(res.messages) ? res.messages : [],
pkColumns: [],
readOnly: true
});
@@ -4373,12 +4414,17 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
const nextResultSets: ResultSet[] = [];
const maxRows = Number(queryOptions?.maxRows) || 0;
let anyTruncated = false;
const statementResultCounts = new Map<number, number>();
for (let idx = 0; idx < resultSetDataArray.length; idx++) {
const rsData = resultSetDataArray[idx];
const plan = executablePlans[idx];
const sourceStatementIndex = Number(rsData?.statementIndex || idx + 1);
const statementResultIndex = (statementResultCounts.get(sourceStatementIndex) || 0) + 1;
statementResultCounts.set(sourceStatementIndex, statementResultIndex);
const plan = executablePlans[Math.max(0, sourceStatementIndex - 1)];
const originalSql = plan?.originalSql || '';
const executedSql = plan?.executedSql || originalSql;
const resultMessages = Array.isArray(rsData?.messages) ? rsData.messages : [];
// 检查是否为 affectedRows 类结果集
const isAffectedResult = Array.isArray(rsData.rows) && rsData.rows.length === 1
@@ -4393,11 +4439,28 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
key: `result-${idx + 1}`,
sql: executedSql,
exportSql: originalSql,
sourceStatementIndex,
statementResultIndex,
rows: [row],
columns: ['affectedRows'],
messages: resultMessages,
pkColumns: [],
readOnly: true
});
} else if ((!Array.isArray(rsData.rows) || rsData.rows.length === 0) && (!Array.isArray(rsData.columns) || rsData.columns.length === 0) && resultMessages.length > 0) {
nextResultSets.push({
key: `result-${idx + 1}`,
sql: executedSql,
exportSql: originalSql,
sourceStatementIndex,
statementResultIndex,
rows: [],
columns: [],
messages: resultMessages,
resultType: 'message',
pkColumns: [],
readOnly: true,
});
} else {
let rows = Array.isArray(rsData.rows) ? rsData.rows : [];
let truncated = false;
@@ -4421,8 +4484,11 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
key: `result-${idx + 1}`,
sql: executedSql,
exportSql: originalSql,
sourceStatementIndex,
statementResultIndex,
rows,
columns: cols,
messages: resultMessages,
tableName: tableRef?.tableName,
pkColumns: plan?.pkColumns || [],
editLocator,
@@ -4432,6 +4498,22 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
}
}
if (resultSetDataArray.length === 0 && Array.isArray(res.messages) && res.messages.length > 0) {
nextResultSets.push({
key: 'result-1',
sql: fullSQL,
exportSql: sourceStatements.join(';\n'),
sourceStatementIndex: 1,
statementResultIndex: 1,
rows: [],
columns: [],
messages: res.messages,
resultType: 'message',
pkColumns: [],
readOnly: true,
});
}
const shouldReplaceAllResults = didExecuteWholeEditor;
setResultSets(prev => {
const merged = mergeResultSets(prev, nextResultSets, shouldReplaceAllResults);
@@ -5316,9 +5398,12 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
}}
>
<Tooltip title={rs.sql}>
<span className="query-result-tab-text"> {idx + 1}</span>
<span className="query-result-tab-text">{rs.resultType === 'message' ? `消息 ${idx + 1}` : `结果 ${idx + 1}`}</span>
</Tooltip>
{(() => {
if (rs.resultType === 'message') {
return <span className="query-result-tab-count">i</span>;
}
const isAffected = rs.columns.length === 1 && rs.columns[0] === 'affectedRows';
if (isAffected) {
return <span className="query-result-tab-count"></span>;
@@ -5344,6 +5429,29 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
</Dropdown>
),
children: (() => {
if (rs.resultType === 'message') {
return (
<div className={isV2Ui ? 'gn-v2-query-success' : undefined} style={{
flex: 1, minHeight: 0, display: 'flex', justifyContent: 'center',
flexDirection: 'column', gap: 12, padding: 24, color: '#666', userSelect: 'text',
overflow: 'auto',
}}>
<span style={{ fontSize: 14, fontWeight: 600 }}></span>
<div style={{
padding: 16,
borderRadius: 8,
border: darkMode ? '1px solid rgba(255,255,255,0.12)' : '1px solid rgba(0,0,0,0.08)',
background: darkMode ? 'rgba(255,255,255,0.03)' : '#fff',
whiteSpace: 'pre-wrap',
wordBreak: 'break-word',
fontFamily: 'var(--gn-font-mono)',
fontSize: 'var(--gn-font-size-mono, 13px)',
}}>
{(rs.messages || []).join('\n')}
</div>
</div>
);
}
// affectedRows 类型结果集UPDATE/INSERT/DELETE简洁提示
const isAffectedResult = rs.columns.length === 1 && rs.columns[0] === 'affectedRows';
if (isAffectedResult) {
@@ -5356,11 +5464,44 @@ const QueryEditor: React.FC<{ tab: TabData; isActive?: boolean }> = ({ tab, isAc
<span style={{ fontSize: 36, color: '#52c41a' }}></span>
<span style={{ fontSize: 14, fontWeight: 500 }}></span>
<span style={{ fontSize: 13, color: '#999' }}>{affected}</span>
{Array.isArray(rs.messages) && rs.messages.length > 0 && (
<div style={{
marginTop: 8,
maxWidth: 720,
padding: 12,
borderRadius: 8,
border: darkMode ? '1px solid rgba(255,255,255,0.12)' : '1px solid rgba(0,0,0,0.08)',
background: darkMode ? 'rgba(255,255,255,0.03)' : '#fff',
whiteSpace: 'pre-wrap',
wordBreak: 'break-word',
fontFamily: 'var(--gn-font-mono)',
fontSize: 'var(--gn-font-size-mono, 12px)',
}}>
{rs.messages.join('\n')}
</div>
)}
</div>
);
}
return (
<div style={{ flex: 1, minHeight: 0, overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
{Array.isArray(rs.messages) && rs.messages.length > 0 && (
<div style={{
flex: '0 0 auto',
margin: '8px 8px 0',
padding: '10px 12px',
borderRadius: 8,
border: darkMode ? '1px solid rgba(255,255,255,0.12)' : '1px solid rgba(0,0,0,0.08)',
background: darkMode ? 'rgba(255,255,255,0.03)' : '#fff',
whiteSpace: 'pre-wrap',
wordBreak: 'break-word',
fontFamily: 'var(--gn-font-mono)',
fontSize: 'var(--gn-font-size-mono, 12px)',
color: darkMode ? '#d4d4d4' : '#666',
}}>
{rs.messages.join('\n')}
</div>
)}
<DataGrid
data={rs.rows}
columnNames={rs.columns}

View File

@@ -1,7 +1,6 @@
import React from 'react'
import ReactDOM from 'react-dom/client'
import App from './App'
import PerfDataGridHarness from './dev/PerfDataGridHarness'
// import './index.css' // Optional global styles
// 全局配置 dayjs 使用中文 locale使 Ant Design 的 DatePicker/TimePicker 等组件
@@ -299,12 +298,18 @@ if (typeof window !== 'undefined' && !(window as any).go) {
}
const rootNode = document.getElementById('root')!;
const devHarnessMode = import.meta.env.DEV ? resolveDevHarnessMode() : '';
const rootComponent = devHarnessMode === 'datagrid-perf'
? <PerfDataGridHarness />
: <App />;
const renderRoot = async () => {
let rootComponent = <App />;
if (devHarnessMode === 'datagrid-perf') {
const { default: PerfDataGridHarness } = await import('./dev/PerfDataGridHarness');
rootComponent = <PerfDataGridHarness />;
}
ReactDOM.createRoot(rootNode).render(
<React.StrictMode>
{rootComponent}
</React.StrictMode>,
)
ReactDOM.createRoot(rootNode).render(
<React.StrictMode>
{rootComponent}
</React.StrictMode>,
);
};
void renderRoot();

View File

@@ -63,12 +63,12 @@ describe('dataSourceCapabilities', () => {
supportsCreateDatabase: false,
supportsRenameDatabase: false,
supportsDropDatabase: false,
forceReadOnlyQueryResult: true,
forceReadOnlyQueryResult: false,
});
expect(getDataSourceCapabilities({ type: 'custom', driver: 'elastic' })).toMatchObject({
type: 'elasticsearch',
supportsQueryEditor: true,
forceReadOnlyQueryResult: true,
forceReadOnlyQueryResult: false,
});
});

View File

@@ -806,6 +806,7 @@ export namespace connection {
message: string;
data: any;
fields?: string[];
messages?: string[];
queryId?: string;
static createFrom(source: any = {}) {
@@ -818,6 +819,7 @@ export namespace connection {
this.message = source["message"];
this.data = source["data"];
this.fields = source["fields"];
this.messages = source["messages"];
this.queryId = source["queryId"];
}
}

2
go.mod
View File

@@ -64,7 +64,7 @@ require (
github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang-sql/sqlexp v0.1.0
github.com/golang/snappy v1.0.0 // indirect
github.com/google/flatbuffers v25.12.19+incompatible // indirect
github.com/gorilla/websocket v1.5.3

View File

@@ -1,6 +1,7 @@
package app
import (
"encoding/json"
"fmt"
"io"
"os"
@@ -17,12 +18,8 @@ import (
"github.com/wailsapp/wails/v2/pkg/runtime"
)
var migratableDataRootEntries = []string{
"connections.json",
"global_proxy.json",
"ai_config.json",
"sessions",
"drivers",
var dataRootMigrationExcludedEntries = map[string]struct{}{
"storage_root.json": {},
}
func (a *App) GetDataRootDirectoryInfo() connection.QueryResult {
@@ -122,22 +119,43 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error {
if sourceRoot == "" || targetRoot == "" {
return fmt.Errorf("数据目录不能为空")
}
if filepath.Clean(sourceRoot) == filepath.Clean(targetRoot) {
sourceAbs, err := filepath.Abs(sourceRoot)
if err != nil {
return fmt.Errorf("解析源数据目录失败:%w", err)
}
targetAbs, err := filepath.Abs(targetRoot)
if err != nil {
return fmt.Errorf("解析目标数据目录失败:%w", err)
}
if filepath.Clean(sourceAbs) == filepath.Clean(targetAbs) {
return nil
}
if rel, err := filepath.Rel(sourceAbs, targetAbs); err == nil && rel != "." && rel != "" && !strings.HasPrefix(rel, "..") && !filepath.IsAbs(rel) {
return fmt.Errorf("目标数据目录不能位于源目录内部")
}
sourceRoot = sourceAbs
targetRoot = targetAbs
if err := os.MkdirAll(targetRoot, 0o755); err != nil {
return fmt.Errorf("创建目标数据目录失败:%w", err)
}
for _, name := range migratableDataRootEntries {
entries, err := os.ReadDir(sourceRoot)
if err != nil {
return fmt.Errorf("读取源数据目录失败:%w", err)
}
for _, entry := range entries {
name := strings.TrimSpace(entry.Name())
if name == "" {
continue
}
if _, excluded := dataRootMigrationExcludedEntries[name]; excluded {
continue
}
sourcePath := filepath.Join(sourceRoot, name)
info, err := os.Stat(sourcePath)
targetPath := filepath.Join(targetRoot, name)
info, err := entry.Info()
if err != nil {
if os.IsNotExist(err) {
continue
}
return fmt.Errorf("读取源数据失败(%s%w", name, err)
}
targetPath := filepath.Join(targetRoot, name)
if info.IsDir() {
if err := copyDir(sourcePath, targetPath); err != nil {
return fmt.Errorf("迁移目录失败(%s%w", name, err)
@@ -148,6 +166,75 @@ func migrateDataRootContents(sourceRoot string, targetRoot string) error {
return fmt.Errorf("迁移文件失败(%s%w", name, err)
}
}
if err := rewriteMigratedDataRootState(targetRoot); err != nil {
return err
}
return nil
}
func rewriteMigratedDataRootState(targetRoot string) error {
if err := rewriteSecurityUpdateBackupPaths(targetRoot); err != nil {
return err
}
return nil
}
func rewriteSecurityUpdateBackupPaths(targetRoot string) error {
repo := newSecurityUpdateStateRepository(targetRoot)
marker, err := repo.readMarker()
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("读取迁移后的安全更新状态失败:%w", err)
}
migrationID := strings.TrimSpace(marker.MigrationID)
if migrationID == "" {
return nil
}
targetBackupPath := repo.backupPath(migrationID)
marker.BackupPath = targetBackupPath
if err := repo.writeMarker(marker); err != nil {
return fmt.Errorf("写入迁移后的安全更新状态失败:%w", err)
}
manifestPath := repo.manifestPath(migrationID)
manifestData, err := os.ReadFile(manifestPath)
if err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("读取迁移后的安全更新备份清单失败:%w", err)
}
} else {
var manifest securityUpdateBackupManifest
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return fmt.Errorf("解析迁移后的安全更新备份清单失败:%w", err)
}
manifest.BackupPath = targetBackupPath
if err := securityUpdateWriteJSONFile(manifestPath, manifest); err != nil {
return fmt.Errorf("写入迁移后的安全更新备份清单失败:%w", err)
}
}
resultPath := repo.resultPath(migrationID)
resultData, err := os.ReadFile(resultPath)
if err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("读取迁移后的安全更新结果失败:%w", err)
}
} else {
var result SecurityUpdateStatus
if err := json.Unmarshal(resultData, &result); err != nil {
return fmt.Errorf("解析迁移后的安全更新结果失败:%w", err)
}
result.BackupPath = targetBackupPath
result.BackupAvailable = strings.TrimSpace(targetBackupPath) != ""
if err := securityUpdateWriteJSONFile(resultPath, result); err != nil {
return fmt.Errorf("写入迁移后的安全更新结果失败:%w", err)
}
}
return nil
}

View File

@@ -105,6 +105,36 @@ func TestMigrateDataRootContentsCopiesSecurityUpdateStateAndRewritesBackupPaths(
}
}
func TestMigrateDataRootContentsToleratesMissingSecurityUpdateArtifacts(t *testing.T) {
sourceRoot := t.TempDir()
targetRoot := filepath.Join(t.TempDir(), "gonavi-data")
sourceRepo := newSecurityUpdateStateRepository(sourceRoot)
started, err := sourceRepo.StartRound(StartSecurityUpdateRequest{SourceType: SecurityUpdateSourceTypeCurrentAppSavedConfig})
if err != nil {
t.Fatalf("start security update round failed: %v", err)
}
if err := os.Remove(sourceRepo.manifestPath(started.MigrationID)); err != nil {
t.Fatalf("remove source manifest failed: %v", err)
}
if err := os.Remove(sourceRepo.resultPath(started.MigrationID)); err != nil {
t.Fatalf("remove source result failed: %v", err)
}
if err := migrateDataRootContents(sourceRoot, targetRoot); err != nil {
t.Fatalf("migrateDataRootContents should tolerate missing security update artifacts, got: %v", err)
}
targetRepo := newSecurityUpdateStateRepository(targetRoot)
targetStatus, err := targetRepo.LoadMarker()
if err != nil {
t.Fatalf("load migrated marker failed: %v", err)
}
expectedBackupPath := filepath.Join(targetRoot, securityUpdateBackupRootDirName, started.MigrationID)
if targetStatus.BackupPath != expectedBackupPath {
t.Fatalf("expected migrated marker backupPath %q, got %q", expectedBackupPath, targetStatus.BackupPath)
}
}
func TestMigrateDataRootContentsCopiesDailySecretsForSavedConnections(t *testing.T) {
sourceRoot := t.TempDir()
targetRoot := filepath.Join(t.TempDir(), "gonavi-data")

View File

@@ -623,6 +623,7 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
}()
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query)
runReadQuery := func(inst db.Database) ([]map[string]interface{}, []string, error) {
if q, ok := inst.(interface {
@@ -633,6 +634,14 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
return inst.Query(query)
}
runReadQueryWithMessages := func(inst db.Database) ([]map[string]interface{}, []string, []string, error) {
if q, ok := inst.(db.QueryMessageExecer); ok {
return q.QueryContextWithMessages(ctx, query)
}
data, columns, err := runReadQuery(inst)
return data, columns, nil, err
}
runExecQuery := func(inst db.Database) (int64, error) {
if e, ok := inst.(interface {
ExecContext(context.Context, string) (int64, error)
@@ -642,8 +651,8 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
return inst.Exec(query)
}
if isReadQuery {
data, columns, err := runReadQuery(dbInst)
if isReadQuery || tryQueryFirst {
data, columns, messages, err := runReadQueryWithMessages(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
@@ -651,32 +660,34 @@ func (a *App) DBQueryWithCancel(config connection.ConnectionConfig, dbName strin
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
data, columns, err = runReadQuery(retryInst)
data, columns, messages, err = runReadQueryWithMessages(retryInst)
}
}
if err != nil {
if err == nil {
return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages, QueryID: queryID}
}
if isReadQuery {
logger.Error(err, "DBQuery 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
return connection.QueryResult{Success: true, Data: data, Fields: columns, QueryID: queryID}
} else {
affected, err := runExecQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
affected, err = runExecQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
}
affected, err := runExecQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
if retryErr != nil {
logger.Error(retryErr, "DBQuery 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error()}
}
affected, err = runExecQuery(retryInst)
}
}
if err != nil {
logger.Error(err, "DBQuery 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID}
}
return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}, QueryID: queryID}
}
// DBQueryMulti 执行可能包含多条 SQL 语句的查询,返回多个结果集。
@@ -727,20 +738,25 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
}
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, error) {
runMultiQuery := func(inst db.Database) ([]connection.ResultSetData, []string, error) {
if !allReadOnly {
return nil, nil // 包含写操作,走逐条执行路径
return nil, nil, nil // 包含写操作,走逐条执行路径
}
if q, ok := inst.(db.MultiResultQueryMessageExecer); ok {
return q.QueryMultiContextWithMessages(ctx, query)
}
if q, ok := inst.(db.MultiResultQuerierContext); ok {
return q.QueryMultiContext(ctx, query)
results, err := q.QueryMultiContext(ctx, query)
return results, nil, err
}
if q, ok := inst.(db.MultiResultQuerier); ok {
return q.QueryMulti(query)
results, err := q.QueryMulti(query)
return results, nil, err
}
return nil, nil // 返回 nil 表示不支持
return nil, nil, nil // 返回 nil 表示不支持
}
results, err := runMultiQuery(dbInst)
results, resultMessages, err := runMultiQuery(dbInst)
if err != nil && shouldRefreshCachedConnection(err) {
if a.invalidateCachedDatabase(runConfig, err) {
retryInst, retryErr := a.getDatabaseForcePing(runConfig)
@@ -748,7 +764,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
logger.Error(retryErr, "DBQueryMulti 重建连接失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: retryErr.Error(), QueryID: queryID}
}
results, err = runMultiQuery(retryInst)
results, resultMessages, err = runMultiQuery(retryInst)
}
}
if err != nil {
@@ -758,7 +774,7 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
// 驱动支持多结果集,直接返回
if results != nil {
return connection.QueryResult{Success: true, Data: results, QueryID: queryID}
return connection.QueryResult{Success: true, Data: results, Messages: resultMessages, QueryID: queryID}
}
// 驱动不支持多结果集,回退到逐条执行
@@ -771,13 +787,50 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
}
var sessionQueryTarget db.StatementQueryExecer
var sessionQueryMessageTarget db.StatementQueryMessageExecer
var sessionMultiQueryTarget db.StatementMultiResultQueryExecer
var sessionMultiQueryMessageTarget db.StatementMultiResultQueryMessageExecer
var sessionExecTarget db.StatementExecer
var sessionBatchTarget db.BatchWriteExecer
closeExecTarget := func() {}
if provider, ok := dbInst.(db.SessionExecerProvider); ok {
sessionExecer, sessionErr := provider.OpenSessionExecer(ctx)
if sessionErr != nil {
logger.Warnf("DBQueryMulti 打开会话级执行器失败,将回退共享连接:%s SQL片段=%q err=%v", formatConnSummary(runConfig), sqlSnippet(query), sessionErr)
} else {
if statementQueryExecer, ok := sessionExecer.(db.StatementQueryExecer); ok {
sessionQueryTarget = statementQueryExecer
}
if statementQueryMessageExecer, ok := sessionExecer.(db.StatementQueryMessageExecer); ok {
sessionQueryMessageTarget = statementQueryMessageExecer
}
if statementMultiResultQueryExecer, ok := sessionExecer.(db.StatementMultiResultQueryExecer); ok {
sessionMultiQueryTarget = statementMultiResultQueryExecer
}
if statementMultiResultQueryMessageExecer, ok := sessionExecer.(db.StatementMultiResultQueryMessageExecer); ok {
sessionMultiQueryMessageTarget = statementMultiResultQueryMessageExecer
}
sessionExecTarget = sessionExecer
if batcher, ok := sessionExecer.(db.BatchWriteExecer); ok {
sessionBatchTarget = batcher
}
closeExecTarget = func() {
if err := sessionExecer.Close(); err != nil {
logger.Warnf("DBQueryMulti 关闭会话级执行器失败:%v", err)
}
}
}
}
defer closeExecTarget()
// 全部为写操作且驱动支持批量 Exec → 一次性发送,大幅减少网络往返
// 适用于 MySQL/MariaDB/Doris/PostgreSQL/SQLite/DuckDB 等支持多语句 Exec 的驱动
if !allReadOnly {
allWrite := true
containsPLSQLBlock := false
for _, stmt := range statements {
if strings.TrimSpace(stmt) != "" && isReadOnlySQLQuery(runConfig.Type, stmt) {
if strings.TrimSpace(stmt) != "" && !isBatchableWriteSQLStatement(runConfig.Type, stmt) {
allWrite = false
}
if isPLSQLBlockStatement(stmt) {
@@ -785,7 +838,13 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
}
if allWrite && !containsPLSQLBlock {
if batcher, ok := dbInst.(db.BatchWriteExecer); ok {
batcher := sessionBatchTarget
if batcher == nil {
if fallbackBatcher, ok := dbInst.(db.BatchWriteExecer); ok {
batcher = fallbackBatcher
}
}
if batcher != nil {
affected, batchErr := batcher.ExecBatchContext(ctx, query)
if batchErr != nil && shouldRefreshCachedConnection(batchErr) {
if a.invalidateCachedDatabase(runConfig, batchErr) {
@@ -823,17 +882,80 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
continue
}
if isReadOnlySQLQuery(runConfig.Type, stmt) {
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
isReadStmt := isReadOnlySQLQuery(runConfig.Type, stmt)
tryQueryStmtFirst := shouldTryQueryResultFirst(runConfig.Type, stmt)
if isReadStmt || tryQueryStmtFirst {
var (
data []map[string]interface{}
columns []string
messages []string
statementResults []connection.ResultSetData
usedMultiResult bool
)
if sessionMultiQueryMessageTarget != nil {
statementResults, messages, err = sessionMultiQueryMessageTarget.QueryMultiContextWithMessages(ctx, stmt)
usedMultiResult = true
} else if sessionMultiQueryTarget != nil {
statementResults, err = sessionMultiQueryTarget.QueryMultiContext(ctx, stmt)
usedMultiResult = true
} else if q, ok := dbInst.(db.MultiResultQueryMessageExecer); ok {
statementResults, messages, err = q.QueryMultiContextWithMessages(ctx, stmt)
usedMultiResult = true
} else if q, ok := dbInst.(db.MultiResultQuerierContext); ok {
statementResults, err = q.QueryMultiContext(ctx, stmt)
usedMultiResult = true
} else if q, ok := dbInst.(db.MultiResultQuerier); ok {
statementResults, err = q.QueryMulti(stmt)
usedMultiResult = true
} else if sessionQueryMessageTarget != nil {
data, columns, messages, err = sessionQueryMessageTarget.QueryContextWithMessages(ctx, stmt)
} else if sessionQueryTarget != nil {
data, columns, err = sessionQueryTarget.QueryContext(ctx, stmt)
} else if q, ok := dbInst.(db.QueryMessageExecer); ok {
data, columns, messages, err = q.QueryContextWithMessages(ctx, stmt)
} else if q, ok := dbInst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, stmt)
} else {
data, columns, err = dbInst.Query(stmt)
}
if err != nil {
if err == nil {
if usedMultiResult {
if len(statementResults) == 0 && len(messages) > 0 {
statementResults = []connection.ResultSetData{{
Rows: []map[string]interface{}{},
Columns: []string{},
Messages: append([]string(nil), messages...),
}}
}
for _, statementResult := range statementResults {
if statementResult.Rows == nil {
statementResult.Rows = []map[string]interface{}{}
}
if statementResult.Columns == nil {
statementResult.Columns = []string{}
}
statementResult.StatementIndex = idx + 1
resultSets = append(resultSets, statementResult)
}
continue
}
if data == nil {
data = make([]map[string]interface{}, 0)
}
if columns == nil {
columns = []string{}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: data,
Columns: columns,
Messages: messages,
StatementIndex: idx + 1,
})
continue
}
if isReadStmt {
logger.Error(err, "DBQueryMulti 逐条查询失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
if len(resultSets) > 0 {
@@ -841,35 +963,30 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
}
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
if data == nil {
data = make([]map[string]interface{}, 0)
}
if columns == nil {
columns = []string{}
}
resultSets = append(resultSets, connection.ResultSetData{Rows: data, Columns: columns})
} else {
var affected int64
if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, stmt)
} else {
affected, err = dbInst.Exec(stmt)
}
if err != nil {
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
if len(resultSets) > 0 {
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
}
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": affected}},
Columns: []string{"affectedRows"},
})
}
var affected int64
if sessionExecTarget != nil {
affected, err = sessionExecTarget.ExecContext(ctx, stmt)
} else if e, ok := dbInst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
affected, err = e.ExecContext(ctx, stmt)
} else {
affected, err = dbInst.Exec(stmt)
}
if err != nil {
logger.Error(err, "DBQueryMulti 逐条执行失败(第 %d/%d 条):%s SQL片段=%q", idx+1, len(statements), formatConnSummary(runConfig), sqlSnippet(stmt))
errMsg := fmt.Sprintf("第 %d 条语句执行失败: %v", idx+1, err)
if len(resultSets) > 0 {
errMsg += fmt.Sprintf("(前 %d 条已执行成功)", len(resultSets))
}
return connection.QueryResult{Success: false, Message: errMsg, QueryID: queryID}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": affected}},
Columns: []string{"affectedRows"},
})
}
if resultSets == nil {
@@ -883,6 +1000,24 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu
return connection.QueryResult{Success: true, Data: resultSets, QueryID: queryID, Message: fallbackMsg}
}
func shouldTryQueryResultFirst(dbType string, query string) bool {
isSQLServer := strings.EqualFold(strings.TrimSpace(dbType), "sqlserver")
keyword := leadingSQLKeyword(query)
switch keyword {
case "exec", "execute", "call":
return true
case "set", "print":
return isSQLServer
case "dbcc":
return isSQLServer
default:
if isSQLServer {
return strings.HasPrefix(keyword, "sp_") || strings.HasPrefix(keyword, "xp_")
}
return false
}
}
func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult {
runConfig := normalizeRunConfig(config, dbName)
@@ -906,22 +1041,30 @@ func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string,
defer cancel()
isReadQuery := isReadOnlySQLQuery(runConfig.Type, query)
tryQueryFirst := shouldTryQueryResultFirst(runConfig.Type, query)
if isReadQuery {
var data []map[string]interface{}
var columns []string
if q, ok := dbInst.(interface {
if isReadQuery || tryQueryFirst {
var (
data []map[string]interface{}
columns []string
messages []string
)
if q, ok := dbInst.(db.QueryMessageExecer); ok {
data, columns, messages, err = q.QueryContextWithMessages(ctx, query)
} else if q, ok := dbInst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
data, columns, err = q.QueryContext(ctx, query)
} else {
data, columns, err = dbInst.Query(query)
}
if err != nil {
if err == nil {
return connection.QueryResult{Success: true, Data: data, Fields: columns, Messages: messages}
}
if isReadQuery {
logger.Error(err, "DBQueryIsolated 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}
}
return connection.QueryResult{Success: true, Data: data, Fields: columns}
}
var affected int64

View File

@@ -10,10 +10,17 @@ import (
)
type fakeBatchWriteDB struct {
batchCalls int
execCalls int
batchCalls int
execCalls int
execQueries []string
lastQuery string
lastQuery string
queryCalls int
queryMap map[string][]map[string]interface{}
fieldMap map[string][]string
messageMap map[string][]string
multiResult map[string][]connection.ResultSetData
queryErr map[string]error
session *fakeBatchWriteSession
}
func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error {
@@ -29,7 +36,16 @@ func (f *fakeBatchWriteDB) Ping() error {
}
func (f *fakeBatchWriteDB) Query(query string) ([]map[string]interface{}, []string, error) {
return nil, nil, nil
f.queryCalls++
if err := f.queryErr[query]; err != nil {
return nil, nil, err
}
return f.queryMap[query], f.fieldMap[query], nil
}
func (f *fakeBatchWriteDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
rows, fields, err := f.Query(query)
return rows, fields, f.messageMap[query], err
}
func (f *fakeBatchWriteDB) Exec(query string) (int64, error) {
@@ -76,12 +92,143 @@ func (f *fakeBatchWriteDB) ExecContext(ctx context.Context, query string) (int64
return 1, nil
}
func (f *fakeBatchWriteDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.queryCalls++
if err := f.queryErr[query]; err != nil {
return nil, nil, err
}
return f.queryMap[query], f.fieldMap[query], nil
}
func (f *fakeBatchWriteDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
rows, fields, err := f.QueryContext(ctx, query)
return rows, fields, f.messageMap[query], err
}
func (f *fakeBatchWriteDB) ExecBatchContext(ctx context.Context, query string) (int64, error) {
f.batchCalls++
f.lastQuery = query
return 500, nil
}
func (f *fakeBatchWriteDB) OpenSessionExecer(ctx context.Context) (db.StatementExecer, error) {
f.session = &fakeBatchWriteSession{parent: f}
return f.session, nil
}
type fakeBatchWriteSession struct {
parent *fakeBatchWriteDB
queryCalls int
execCalls int
batchCalls int
closed bool
}
func (s *fakeBatchWriteSession) Query(query string) ([]map[string]interface{}, []string, error) {
return s.QueryContext(context.Background(), query)
}
func (s *fakeBatchWriteSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
s.queryCalls++
return s.parent.QueryContext(ctx, query)
}
func (s *fakeBatchWriteSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return s.QueryContextWithMessages(context.Background(), query)
}
func (s *fakeBatchWriteSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
s.queryCalls++
return s.parent.QueryContextWithMessages(ctx, query)
}
func (s *fakeBatchWriteSession) QueryMulti(query string) ([]connection.ResultSetData, error) {
return s.QueryMultiContext(context.Background(), query)
}
func (s *fakeBatchWriteSession) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
if multi := s.parent.multiResult[query]; len(multi) > 0 {
s.queryCalls++
return cloneResultSets(multi), nil
}
rows, columns, err := s.QueryContext(ctx, query)
if err != nil {
return nil, err
}
return []connection.ResultSetData{{Rows: rows, Columns: columns}}, nil
}
func (s *fakeBatchWriteSession) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
return s.QueryMultiContextWithMessages(context.Background(), query)
}
func (s *fakeBatchWriteSession) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if err := s.parent.queryErr[query]; err != nil {
s.queryCalls++
return nil, nil, err
}
if multi := s.parent.multiResult[query]; len(multi) > 0 {
s.queryCalls++
return cloneResultSets(multi), append([]string(nil), s.parent.messageMap[query]...), nil
}
rows, columns, messages, err := s.QueryContextWithMessages(ctx, query)
if err != nil {
return nil, nil, err
}
return []connection.ResultSetData{{
Rows: rows,
Columns: columns,
Messages: append([]string(nil), messages...),
}}, append([]string(nil), messages...), nil
}
func (s *fakeBatchWriteSession) Exec(query string) (int64, error) {
return s.ExecContext(context.Background(), query)
}
func (s *fakeBatchWriteSession) ExecContext(ctx context.Context, query string) (int64, error) {
s.execCalls++
return s.parent.ExecContext(ctx, query)
}
func (s *fakeBatchWriteSession) ExecBatchContext(ctx context.Context, query string) (int64, error) {
s.batchCalls++
return s.parent.ExecBatchContext(ctx, query)
}
func (s *fakeBatchWriteSession) Close() error {
s.closed = true
return nil
}
func cloneResultSets(input []connection.ResultSetData) []connection.ResultSetData {
if len(input) == 0 {
return nil
}
cloned := make([]connection.ResultSetData, 0, len(input))
for _, item := range input {
rows := make([]map[string]interface{}, 0, len(item.Rows))
for _, row := range item.Rows {
if row == nil {
rows = append(rows, nil)
continue
}
rowCopy := make(map[string]interface{}, len(row))
for key, value := range row {
rowCopy[key] = value
}
rows = append(rows, rowCopy)
}
cloned = append(cloned, connection.ResultSetData{
Rows: rows,
Columns: append([]string(nil), item.Columns...),
Messages: append([]string(nil), item.Messages...),
StatementIndex: item.StatementIndex,
})
}
return cloned
}
func TestDBQueryMultiKeepsOracleAnonymousBlockAsSingleStatement(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
@@ -122,6 +269,85 @@ END;`
}
var _ db.BatchWriteExecer = (*fakeBatchWriteDB)(nil)
var _ db.SessionExecerProvider = (*fakeBatchWriteDB)(nil)
var _ db.QueryMessageExecer = (*fakeBatchWriteDB)(nil)
var _ db.StatementQueryMessageExecer = (*fakeBatchWriteSession)(nil)
func TestDBQueryWithCancelReturnsResultSetForExecStoredProcedure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "EXEC sp_who2"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"SPID": 52, "STATUS": "RUNNABLE"},
},
},
fieldMap: map[string][]string{
query: {"SPID", "STATUS"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryWithCancel(config, "master", query, "sp-who2-test")
if !result.Success {
t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message)
}
rows, ok := result.Data.([]map[string]interface{})
if !ok {
t.Fatalf("expected []map[string]interface{}, got %T", result.Data)
}
if len(rows) != 1 || rows[0]["SPID"] != 52 {
t.Fatalf("unexpected rows: %#v", rows)
}
if fakeDB.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
}
}
func TestDBQueryWithCancelReturnsMessagesForSQLServerQuery(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "SET STATISTICS IO ON"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {},
},
fieldMap: map[string][]string{
query: {},
},
messageMap: map[string][]string{
query: {"Table 'users'. Scan count 1, logical reads 3."},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryWithCancel(config, "master", query, "statistics-io-test")
if !result.Success {
t.Fatalf("expected DBQueryWithCancel success, got failure: %s", result.Message)
}
if len(result.Messages) != 1 || result.Messages[0] == "" {
t.Fatalf("expected SQL Server messages to be returned, got %#v", result.Messages)
}
}
func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
@@ -168,3 +394,206 @@ func TestDBQueryMultiUsesBatchWriteExecerForAllWriteStatements(t *testing.T) {
t.Fatalf("expected affectedRows=500, got %#v", got)
}
}
func TestDBQueryMultiPrefersResultSetForExecStoredProcedure(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "EXEC sp_who2"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"SPID": 77, "STATUS": "SUSPENDED"},
},
},
fieldMap: map[string][]string{
query: {"SPID", "STATUS"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sp-who2-multi-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
t.Fatalf("unexpected result sets: %#v", resultSets)
}
if got := resultSets[0].Rows[0]["SPID"]; got != 77 {
t.Fatalf("expected SPID=77, got %#v", got)
}
if fakeDB.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
}
}
func TestDBQueryMultiDoesNotBatchExecStoredProcedureAsWriteStatement(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "EXEC sp_who2"
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
query: {
{"SPID": 88, "STATUS": "RUNNING"},
},
},
fieldMap: map[string][]string{
query: {"SPID", "STATUS"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sp-who2-batch-guard-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.batchCalls != 0 {
t.Fatalf("expected stored procedure to skip batch write path, got batchCalls=%d", fakeDB.batchCalls)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 1 || len(resultSets[0].Rows) != 1 {
t.Fatalf("unexpected result sets: %#v", resultSets)
}
if got := resultSets[0].Rows[0]["SPID"]; got != 88 {
t.Fatalf("expected SPID=88, got %#v", got)
}
}
func TestDBQueryMultiUsesPinnedSessionForSequentialFallback(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
fakeDB := &fakeBatchWriteDB{
queryMap: map[string][]map[string]interface{}{
"SELECT 1 AS value": {
{"value": 1},
},
},
fieldMap: map[string][]string{
"SELECT 1 AS value": {"value"},
},
messageMap: map[string][]string{
"SET NOCOUNT ON": {"NOCOUNT 已开启"},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", "SET NOCOUNT ON;\nSELECT 1 AS value;", "session-fallback-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
if fakeDB.session == nil {
t.Fatal("expected DBQueryMulti to open a pinned session for sequential fallback")
}
if !fakeDB.session.closed {
t.Fatal("expected DBQueryMulti to close the pinned session")
}
if fakeDB.session.execCalls != 0 {
t.Fatalf("expected SQL Server SET statement to avoid exec-only path, got execCalls=%d", fakeDB.session.execCalls)
}
if fakeDB.session.queryCalls != 2 {
t.Fatalf("expected both statements to query through pinned session, got queryCalls=%d", fakeDB.session.queryCalls)
}
if fakeDB.queryCalls != 2 {
t.Fatalf("expected exactly two underlying query calls, got %d", fakeDB.queryCalls)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 2 {
t.Fatalf("expected two result sets, got %#v", resultSets)
}
if len(resultSets[0].Messages) != 1 || resultSets[0].Messages[0] != "NOCOUNT 已开启" {
t.Fatalf("expected first result set to keep session message, got %#v", resultSets[0].Messages)
}
if got := resultSets[1].Rows[0]["value"]; got != 1 {
t.Fatalf("expected second result set value=1, got %#v", got)
}
}
func TestDBQueryMultiKeepsAllResultSetsFromSingleSQLServerStatement(t *testing.T) {
originalNewDatabaseFunc := newDatabaseFunc
t.Cleanup(func() {
newDatabaseFunc = originalNewDatabaseFunc
})
query := "EXEC sp_helpdb"
fakeDB := &fakeBatchWriteDB{
multiResult: map[string][]connection.ResultSetData{
query: {
{
Rows: []map[string]interface{}{{"name": "master"}},
Columns: []string{"name"},
},
{
Rows: []map[string]interface{}{{"owner": "sa"}},
Columns: []string{"owner"},
},
},
},
queryErr: map[string]error{},
}
newDatabaseFunc = func(dbType string) (db.Database, error) {
return fakeDB, nil
}
app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test"))
config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"}
result := app.DBQueryMulti(config, "master", query, "sp-helpdb-multi-result-test")
if !result.Success {
t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message)
}
resultSets, ok := result.Data.([]connection.ResultSetData)
if !ok {
t.Fatalf("expected []connection.ResultSetData, got %T", result.Data)
}
if len(resultSets) != 2 {
t.Fatalf("expected two result sets, got %#v", resultSets)
}
if got := resultSets[0].Rows[0]["name"]; got != "master" {
t.Fatalf("expected first result set to keep master row, got %#v", got)
}
if got := resultSets[1].Rows[0]["owner"]; got != "sa" {
t.Fatalf("expected second result set to keep owner row, got %#v", got)
}
if resultSets[0].StatementIndex != 1 || resultSets[1].StatementIndex != 1 {
t.Fatalf("expected both result sets to map to the first statement, got %#v", resultSets)
}
if fakeDB.execCalls != 0 {
t.Fatalf("expected exec path to be skipped, got execCalls=%d", fakeDB.execCalls)
}
}

View File

@@ -65,6 +65,19 @@ func isReadOnlySQLQuery(dbType string, query string) bool {
}
}
func isBatchableWriteSQLStatement(dbType string, query string) bool {
if isReadOnlySQLQuery(dbType, query) {
return false
}
switch leadingSQLKeyword(query) {
case "insert", "update", "delete", "replace", "merge", "upsert":
return true
default:
return false
}
}
func sanitizeSQLForPgLike(dbType string, query string) string {
normalizedType := strings.ToLower(strings.TrimSpace(dbType))
switch normalizedType {

View File

@@ -62,3 +62,42 @@ func TestSanitizeSQLForPgLike_DoesNotModifyOtherDBTypes(t *testing.T) {
t.Fatalf("non-PG-like db should not be sanitized:\nIN: %s\nOUT: %s", in, out)
}
}
func TestIsReadOnlySQLQuery_DoesNotTreatExecAsReadOnly(t *testing.T) {
if isReadOnlySQLQuery("sqlserver", "EXEC sp_who2") {
t.Fatal("EXEC should not be treated as read-only SQL")
}
}
func TestIsBatchableWriteSQLStatement_OnlyMatchesRealWriteStatements(t *testing.T) {
if !isBatchableWriteSQLStatement("mysql", "INSERT INTO demo(id) VALUES (1)") {
t.Fatal("expected INSERT to be treated as batchable write")
}
if isBatchableWriteSQLStatement("sqlserver", "EXEC sp_who2") {
t.Fatal("EXEC should not be treated as batchable write")
}
if isBatchableWriteSQLStatement("sqlserver", "SET STATISTICS IO ON") {
t.Fatal("SET STATISTICS should not be treated as batchable write")
}
}
func TestShouldTryQueryResultFirst_TreatsSQLServerSetAsQueryFirst(t *testing.T) {
if !shouldTryQueryResultFirst("sqlserver", "SET STATISTICS IO ON") {
t.Fatal("expected SQL Server SET STATISTICS to try query-first for notice capture")
}
if shouldTryQueryResultFirst("mysql", "SET sql_mode = ''") {
t.Fatal("non-SQLServer SET should not force query-first")
}
}
func TestShouldTryQueryResultFirst_TreatsSQLServerSystemCommandsAsQueryFirst(t *testing.T) {
if !shouldTryQueryResultFirst("sqlserver", "sp_who2") {
t.Fatal("expected bare SQL Server system procedure to try query-first")
}
if !shouldTryQueryResultFirst("sqlserver", "DBCC INPUTBUFFER(52)") {
t.Fatal("expected SQL Server DBCC command to try query-first")
}
if shouldTryQueryResultFirst("mysql", "sp_who2") {
t.Fatal("non-SQLServer system procedure name should not force query-first")
}
}

View File

@@ -45,7 +45,8 @@ func (p *panicChromium) PutZoomFactor(float64) {
}
type panicFrontend struct {
chromium *panicChromium
chromium *panicChromium
mainWindow *fakeWindow
}
// 测试必须用 wails 一致的 string key "frontend" 作为 context.WithValue 的 key
@@ -122,7 +123,10 @@ func TestResetWebViewZoomFactorErrorsWhenFrontendMissing(t *testing.T) {
}
func TestResetWebViewZoomFactorRecoversFromPutZoomFactorPanic(t *testing.T) {
ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{chromium: &panicChromium{}})
ctx := context.WithValue(context.Background(), stringContextKey("frontend"), &panicFrontend{
chromium: &panicChromium{},
mainWindow: &fakeWindow{},
})
err := resetWebViewZoomFactor(ctx, 1.0)
if err == nil {

View File

@@ -122,17 +122,20 @@ type ConnectionConfig struct {
// ResultSetData 表示一个查询结果集(行 + 列名),用于多结果集场景。
type ResultSetData struct {
Rows []map[string]interface{} `json:"rows"`
Columns []string `json:"columns"`
Rows []map[string]interface{} `json:"rows"`
Columns []string `json:"columns"`
Messages []string `json:"messages,omitempty"`
StatementIndex int `json:"statementIndex,omitempty"`
}
// QueryResult 是 Wails 绑定方法的统一响应格式,前端通过此结构体接收后端结果。
type QueryResult struct {
Success bool `json:"success"`
Message string `json:"message"`
Data interface{} `json:"data"`
Fields []string `json:"fields,omitempty"`
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
Success bool `json:"success"`
Message string `json:"message"`
Data interface{} `json:"data"`
Fields []string `json:"fields,omitempty"`
Messages []string `json:"messages,omitempty"`
QueryID string `json:"queryId,omitempty"` // Unique ID for query cancellation
}
// ColumnDefinition 描述表的一个列定义。

View File

@@ -67,6 +67,51 @@ type StatementExecer interface {
Close() error
}
// StatementQueryExecer can run queries on a pinned session/connection.
// Drivers that return sqlConnStatementExecer automatically satisfy it.
type StatementQueryExecer interface {
StatementExecer
Query(query string) ([]map[string]interface{}, []string, error)
QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error)
}
// StatementQueryMessageExecer can run queries on a pinned session and return
// extra server messages/notices alongside rows.
type StatementQueryMessageExecer interface {
StatementQueryExecer
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error)
}
// StatementMultiResultQueryExecer can run multi-result queries on a pinned session/connection.
type StatementMultiResultQueryExecer interface {
StatementExecer
QueryMulti(query string) ([]connection.ResultSetData, error)
QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error)
}
// StatementMultiResultQueryMessageExecer can run multi-result queries on a
// pinned session/connection and return server messages/notices.
type StatementMultiResultQueryMessageExecer interface {
StatementMultiResultQueryExecer
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error)
}
// QueryMessageExecer is an optional database-level interface for returning
// informational server messages alongside one result set.
type QueryMessageExecer interface {
QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error)
QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error)
}
// MultiResultQueryMessageExecer is an optional database-level interface for
// returning informational server messages alongside multi-result queries.
type MultiResultQueryMessageExecer interface {
QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error)
QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error)
}
// SessionExecerProvider is implemented by database/sql based drivers that can
// pin a long-running job to one physical connection.
type SessionExecerProvider interface {
@@ -96,6 +141,38 @@ func (e *sqlConnStatementExecer) Exec(query string) (int64, error) {
return e.ExecContext(context.Background(), query)
}
func (e *sqlConnStatementExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if e == nil || e.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := e.conn.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
}
func (e *sqlConnStatementExecer) Query(query string) ([]map[string]interface{}, []string, error) {
return e.QueryContext(context.Background(), query)
}
func (e *sqlConnStatementExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
if e == nil || e.conn == nil {
return nil, fmt.Errorf("连接未打开")
}
rows, err := e.conn.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
return scanMultiRows(rows)
}
func (e *sqlConnStatementExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
return e.QueryMultiContext(context.Background(), query)
}
func (e *sqlConnStatementExecer) ExecBatchContext(ctx context.Context, query string) (int64, error) {
return e.ExecContext(ctx, query)
}

View File

@@ -17,6 +17,7 @@ import (
"GoNavi-Wails/internal/ssh"
"GoNavi-Wails/internal/utils"
"github.com/golang-sql/sqlexp"
_ "github.com/microsoft/go-mssqldb"
)
@@ -26,6 +27,85 @@ type SqlServerDB struct {
forwarder *ssh.LocalForwarder
}
type sqlServerSessionExecer struct {
conn *sql.Conn
}
func scanSQLServerRowsWithMessages(ctx context.Context, rows *sql.Rows, retmsg *sqlexp.ReturnMessage) ([]connection.ResultSetData, []string, error) {
if rows == nil {
return []connection.ResultSetData{{Rows: []map[string]interface{}{}, Columns: []string{}}}, nil, nil
}
if ctx == nil {
ctx = context.Background()
}
var (
resultSets []connection.ResultSetData
messages []string
allMessages []string
)
active := true
for active {
raw := retmsg.Message(ctx)
switch msg := raw.(type) {
case sqlexp.MsgNotice:
text := strings.TrimSpace(fmt.Sprint(msg.Message))
if text != "" {
messages = append(messages, text)
allMessages = append(allMessages, text)
}
case sqlexp.MsgNext:
data, cols, err := scanRows(rows)
if err != nil {
return resultSets, messages, err
}
if data == nil {
data = []map[string]interface{}{}
}
if cols == nil {
cols = []string{}
}
resultSets = append(resultSets, connection.ResultSetData{
Rows: data,
Columns: cols,
Messages: append([]string(nil), messages...),
})
messages = nil
case sqlexp.MsgRowsAffected:
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{{"affectedRows": msg.Count}},
Columns: []string{"affectedRows"},
Messages: append([]string(nil), messages...),
})
messages = nil
case sqlexp.MsgNextResultSet:
active = rows.NextResultSet()
case sqlexp.MsgError:
return resultSets, messages, msg.Error
default:
active = false
}
}
if len(messages) > 0 {
resultSets = append(resultSets, connection.ResultSetData{
Rows: []map[string]interface{}{},
Columns: []string{},
Messages: append([]string(nil), messages...),
})
}
if len(resultSets) == 0 {
resultSets = []connection.ResultSetData{{
Rows: []map[string]interface{}{},
Columns: []string{},
}}
}
if err := rows.Err(); err != nil {
return resultSets, allMessages, err
}
return resultSets, allMessages, nil
}
// quoteBracket escapes ] in identifiers for safe use in SQL Server [bracket] notation
func quoteBracket(name string) string {
return strings.ReplaceAll(name, "]", "]]")
@@ -133,54 +213,76 @@ func (s *SqlServerDB) Ping() error {
}
func (s *SqlServerDB) QueryMulti(query string) ([]connection.ResultSetData, error) {
results, _, err := s.QueryMultiWithMessages(query)
return results, err
}
func (s *SqlServerDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
if s.conn == nil {
return nil, fmt.Errorf("连接未打开")
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := s.conn.Query(query)
ctx := context.Background()
retmsg := &sqlexp.ReturnMessage{}
rows, err := s.conn.QueryContext(ctx, query, retmsg)
if err != nil {
return nil, err
return nil, nil, err
}
defer rows.Close()
return scanMultiRows(rows)
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
}
func (s *SqlServerDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
results, _, err := s.QueryMultiContextWithMessages(ctx, query)
return results, err
}
func (s *SqlServerDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if s.conn == nil {
return nil, fmt.Errorf("连接未打开")
return nil, nil, fmt.Errorf("连接未打开")
}
rows, err := s.conn.QueryContext(ctx, query)
retmsg := &sqlexp.ReturnMessage{}
rows, err := s.conn.QueryContext(ctx, query, retmsg)
if err != nil {
return nil, err
return nil, nil, err
}
defer rows.Close()
return scanMultiRows(rows)
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
}
func (s *SqlServerDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
rows, columns, _, err := s.QueryContextWithMessages(ctx, query)
return rows, columns, err
}
func (s *SqlServerDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
if s.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
return nil, nil, nil, fmt.Errorf("连接未打开")
}
rows, err := s.conn.QueryContext(ctx, query)
resultSets, messages, err := s.QueryMultiContextWithMessages(ctx, query)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
defer rows.Close()
return scanRows(rows)
if len(resultSets) == 0 {
return []map[string]interface{}{}, []string{}, messages, nil
}
first := resultSets[0]
if first.Rows == nil {
first.Rows = []map[string]interface{}{}
}
if first.Columns == nil {
first.Columns = []string{}
}
return first.Rows, first.Columns, messages, nil
}
func (s *SqlServerDB) Query(query string) ([]map[string]interface{}, []string, error) {
if s.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
rows, columns, _, err := s.QueryWithMessages(query)
return rows, columns, err
}
rows, err := s.conn.Query(query)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanRows(rows)
func (s *SqlServerDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return s.QueryContextWithMessages(context.Background(), query)
}
func (s *SqlServerDB) ExecContext(ctx context.Context, query string) (int64, error) {
@@ -213,7 +315,7 @@ func (s *SqlServerDB) OpenSessionExecer(ctx context.Context) (StatementExecer, e
if err != nil {
return nil, err
}
return NewSQLConnStatementExecer(conn), nil
return &sqlServerSessionExecer{conn: conn}, nil
}
func (s *SqlServerDB) Exec(query string) (int64, error) {
@@ -227,6 +329,87 @@ func (s *SqlServerDB) Exec(query string) (int64, error) {
return res.RowsAffected()
}
func (e *sqlServerSessionExecer) Exec(query string) (int64, error) {
return e.ExecContext(context.Background(), query)
}
func (e *sqlServerSessionExecer) ExecContext(ctx context.Context, query string) (int64, error) {
if e == nil || e.conn == nil {
return 0, fmt.Errorf("连接未打开")
}
res, err := e.conn.ExecContext(ctx, query)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (e *sqlServerSessionExecer) Query(query string) ([]map[string]interface{}, []string, error) {
rows, columns, _, err := e.QueryWithMessages(query)
return rows, columns, err
}
func (e *sqlServerSessionExecer) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
rows, columns, _, err := e.QueryContextWithMessages(ctx, query)
return rows, columns, err
}
func (e *sqlServerSessionExecer) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) {
return e.QueryContextWithMessages(context.Background(), query)
}
func (e *sqlServerSessionExecer) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) {
results, messages, err := e.QueryMultiContextWithMessages(ctx, query)
if err != nil {
return nil, nil, nil, err
}
if len(results) == 0 {
return []map[string]interface{}{}, []string{}, messages, nil
}
first := results[0]
if first.Rows == nil {
first.Rows = []map[string]interface{}{}
}
if first.Columns == nil {
first.Columns = []string{}
}
return first.Rows, first.Columns, messages, nil
}
func (e *sqlServerSessionExecer) QueryMulti(query string) ([]connection.ResultSetData, error) {
results, _, err := e.QueryMultiWithMessages(query)
return results, err
}
func (e *sqlServerSessionExecer) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) {
results, _, err := e.QueryMultiContextWithMessages(ctx, query)
return results, err
}
func (e *sqlServerSessionExecer) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) {
return e.QueryMultiContextWithMessages(context.Background(), query)
}
func (e *sqlServerSessionExecer) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) {
if e == nil || e.conn == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
retmsg := &sqlexp.ReturnMessage{}
rows, err := e.conn.QueryContext(ctx, query, retmsg)
if err != nil {
return nil, nil, err
}
defer rows.Close()
return scanSQLServerRowsWithMessages(ctx, rows, retmsg)
}
func (e *sqlServerSessionExecer) Close() error {
if e == nil || e.conn == nil {
return nil
}
return e.conn.Close()
}
func (s *SqlServerDB) GetDatabases() ([]string, error) {
query := "SELECT name FROM sys.databases WHERE state_desc = 'ONLINE' ORDER BY name"
data, _, err := s.Query(query)