🐛 fix(query-export): 修复查询结果导出卡住并统一按数据源能力控制导出路径

- 查询结果页导出增加稳定兜底,异常时确保 loading 关闭避免持续转圈
- DataGrid 导出逻辑按数据源能力分流,优先走后端 ExportQuery 并保留结果集导出降级
- QueryEditor 传递结果导出 SQL,保证查询结果导出范围与当前结果一致
- 后端补充 ExportData/ExportQuery 关键日志,提升导出链路可观测性
This commit is contained in:
Syngnat
2026-03-02 14:18:44 +08:00
parent 84688e995a
commit 3ca898a950
15 changed files with 672 additions and 71 deletions

View File

@@ -379,7 +379,7 @@ jobs:
- name: List Assets
run: ls -R release-assets
- name: Verify DuckDB Driver Assets
- name: Verify Optional Driver Assets
shell: bash
run: |
set -euo pipefail
@@ -390,20 +390,24 @@ jobs:
"drivers/MacOS/duckdb-driver-agent-darwin-amd64"
"drivers/MacOS/duckdb-driver-agent-darwin-arm64"
"drivers/Linux/duckdb-driver-agent-linux-amd64"
"drivers/Windows/clickhouse-driver-agent-windows-amd64.exe"
"drivers/MacOS/clickhouse-driver-agent-darwin-amd64"
"drivers/MacOS/clickhouse-driver-agent-darwin-arm64"
"drivers/Linux/clickhouse-driver-agent-linux-amd64"
)
missing=0
for file in "${REQUIRED_FILES[@]}"; do
if [ ! -f "$file" ]; then
echo "❌ 缺少 DuckDB 驱动资产:$file"
echo "❌ 缺少驱动资产:$file"
missing=1
else
echo "✅ 已找到 DuckDB 驱动资产:$file"
echo "✅ 已找到驱动资产:$file"
fi
done
if [ "$missing" -ne 0 ]; then
echo "❌ DuckDB 驱动资产不完整,终止发布"
echo "❌ 可选驱动资产不完整,终止发布"
exit 1
fi

View File

@@ -2,11 +2,13 @@ package main
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"reflect"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
@@ -17,6 +19,7 @@ type agentRequest struct {
Method string `json:"method"`
Config *connection.ConnectionConfig `json:"config,omitempty"`
Query string `json:"query,omitempty"`
TimeoutMs int64 `json:"timeoutMs,omitempty"`
DBName string `json:"dbName,omitempty"`
TableName string `json:"tableName,omitempty"`
Changes *connection.ChangeSet `json:"changes,omitempty"`
@@ -48,6 +51,8 @@ const (
agentMethodApplyChanges = "applyChanges"
)
const legacyClickHouseDefaultTimeout = 2 * time.Hour
var (
agentDriverType string
agentDatabaseFactory func() db.Database
@@ -138,14 +143,14 @@ func handleRequest(inst *db.Database, req agentRequest) agentResponse {
return fail(resp, err.Error())
}
case agentMethodQuery:
data, fields, err := (*inst).Query(req.Query)
data, fields, err := queryWithOptionalTimeout(*inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
resp.Data = data
resp.Fields = fields
case agentMethodExec:
affected, err := (*inst).Exec(req.Query)
affected, err := execWithOptionalTimeout(*inst, req.Query, req.TimeoutMs)
if err != nil {
return fail(resp, err.Error())
}
@@ -287,3 +292,39 @@ func normalizeAgentResponseData(v interface{}) interface{} {
return v
}
}
func queryWithOptionalTimeout(inst db.Database, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs <= 0 {
return inst.Query(query)
}
if q, ok := inst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return q.QueryContext(ctx, query)
}
return inst.Query(query)
}
func execWithOptionalTimeout(inst db.Database, query string, timeoutMs int64) (int64, error) {
effectiveTimeoutMs := timeoutMs
if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") {
effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond)
}
if effectiveTimeoutMs <= 0 {
return inst.Exec(query)
}
if e, ok := inst.(interface {
ExecContext(context.Context, string) (int64, error)
}); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond)
defer cancel()
return e.ExecContext(ctx, query)
}
return inst.Exec(query)
}

View File

@@ -3,8 +3,13 @@ package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"testing"
"time"
"GoNavi-Wails/internal/connection"
)
type duckMapLike map[any]any
@@ -60,3 +65,108 @@ func TestNormalizeAgentResponseData_KeepByteSlice(t *testing.T) {
t.Fatalf("[]byte 内容被意外改写: %v", out)
}
}
type fakeAgentTimeoutDB struct {
queryCalled bool
queryContextCalled bool
execCalled bool
execContextCalled bool
deadlineSet bool
}
func (f *fakeAgentTimeoutDB) Connect(config connection.ConnectionConfig) error { return nil }
func (f *fakeAgentTimeoutDB) Close() error { return nil }
func (f *fakeAgentTimeoutDB) Ping() error { return nil }
func (f *fakeAgentTimeoutDB) Query(query string) ([]map[string]interface{}, []string, error) {
f.queryCalled = true
return nil, nil, errors.New("query should not be called")
}
func (f *fakeAgentTimeoutDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.queryContextCalled = true
if _, ok := ctx.Deadline(); ok {
f.deadlineSet = true
}
return []map[string]interface{}{{"ok": 1}}, []string{"ok"}, nil
}
func (f *fakeAgentTimeoutDB) Exec(query string) (int64, error) {
f.execCalled = true
return 0, errors.New("exec should not be called")
}
func (f *fakeAgentTimeoutDB) ExecContext(ctx context.Context, query string) (int64, error) {
f.execContextCalled = true
if _, ok := ctx.Deadline(); ok {
f.deadlineSet = true
}
return 3, nil
}
func (f *fakeAgentTimeoutDB) GetDatabases() ([]string, error) { return nil, nil }
func (f *fakeAgentTimeoutDB) GetTables(dbName string) ([]string, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetCreateStatement(dbName, tableName string) (string, error) {
return "", nil
}
func (f *fakeAgentTimeoutDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return nil, nil
}
func (f *fakeAgentTimeoutDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return nil, nil
}
func TestQueryWithOptionalTimeout_UsesQueryContext(t *testing.T) {
fake := &fakeAgentTimeoutDB{}
data, fields, err := queryWithOptionalTimeout(fake, "SELECT 1", int64((2 * time.Second).Milliseconds()))
if err != nil {
t.Fatalf("queryWithOptionalTimeout 返回错误: %v", err)
}
if !fake.queryContextCalled || fake.queryCalled {
t.Fatalf("query 调用路径异常QueryContext=%v Query=%v", fake.queryContextCalled, fake.queryCalled)
}
if !fake.deadlineSet {
t.Fatal("queryWithOptionalTimeout 未设置 deadline")
}
if len(data) != 1 || len(fields) != 1 || fields[0] != "ok" {
t.Fatalf("queryWithOptionalTimeout 返回数据异常: data=%v fields=%v", data, fields)
}
}
func TestExecWithOptionalTimeout_UsesExecContext(t *testing.T) {
fake := &fakeAgentTimeoutDB{}
affected, err := execWithOptionalTimeout(fake, "DELETE FROM t", int64((2 * time.Second).Milliseconds()))
if err != nil {
t.Fatalf("execWithOptionalTimeout 返回错误: %v", err)
}
if !fake.execContextCalled || fake.execCalled {
t.Fatalf("exec 调用路径异常ExecContext=%v Exec=%v", fake.execContextCalled, fake.execCalled)
}
if !fake.deadlineSet {
t.Fatal("execWithOptionalTimeout 未设置 deadline")
}
if affected != 3 {
t.Fatalf("受影响行数异常want=3 got=%d", affected)
}
}
func TestQueryWithOptionalTimeout_ClickHouseLegacyModeUsesQueryContext(t *testing.T) {
old := agentDriverType
agentDriverType = "clickhouse"
defer func() { agentDriverType = old }()
fake := &fakeAgentTimeoutDB{}
_, _, err := queryWithOptionalTimeout(fake, "SELECT 1", 0)
if err != nil {
t.Fatalf("queryWithOptionalTimeout 返回错误: %v", err)
}
if !fake.queryContextCalled || fake.queryCalled {
t.Fatalf("clickhouse legacy query 调用路径异常QueryContext=%v Query=%v", fake.queryContextCalled, fake.queryCalled)
}
}

View File

@@ -75,7 +75,7 @@
},
"clickhouse": {
"engine": "go",
"version": "2.43.0",
"version": "2.43.1",
"checksumPolicy": "off",
"downloadUrl": "builtin://activate/clickhouse"
},

View File

@@ -12,6 +12,7 @@ import { v4 as uuidv4 } from 'uuid';
import 'react-resizable/css/styles.css';
import { buildOrderBySQL, buildWhereSQL, escapeLiteral, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
import { isMacLikePlatform, normalizeOpacityForPlatform } from '../utils/appearance';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
// --- Error Boundary ---
interface DataGridErrorBoundaryState {
@@ -302,6 +303,7 @@ const DataContext = React.createContext<{
copyToClipboard: (t: string) => void;
tableName?: string;
enableRowContextMenu: boolean;
supportsCopyInsert: boolean;
} | null>(null);
interface Item {
@@ -444,7 +446,7 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => {
if (!record || !context) return <tr {...props}>{children}</tr>;
const { selectedRowKeysRef, displayDataRef, handleCopyInsert, handleCopyJson, handleCopyCsv, handleExportSelected, copyToClipboard, enableRowContextMenu } = context;
const { selectedRowKeysRef, displayDataRef, handleCopyInsert, handleCopyJson, handleCopyCsv, handleExportSelected, copyToClipboard, enableRowContextMenu, supportsCopyInsert } = context;
if (!enableRowContextMenu) {
return <tr {...props}>{children}</tr>;
@@ -460,12 +462,12 @@ const ContextMenuRow = React.memo(({ children, record, ...props }: any) => {
};
const menuItems: MenuProps['items'] = [
{
key: 'insert',
label: `复制为 INSERT`,
icon: <ConsoleSqlOutlined />,
onClick: () => handleCopyInsert(record)
},
...(supportsCopyInsert ? [{
key: 'insert',
label: '复制为 INSERT',
icon: <ConsoleSqlOutlined />,
onClick: () => handleCopyInsert(record),
}] : []),
{ key: 'json', label: '复制为 JSON', icon: <FileTextOutlined />, onClick: () => handleCopyJson(record) },
{ key: 'csv', label: '复制为 CSV', icon: <FileTextOutlined />, onClick: () => handleCopyCsv(record) },
{ key: 'copy', label: '复制为 Markdown', icon: <CopyOutlined />, onClick: () => {
@@ -502,6 +504,8 @@ interface DataGridProps {
columnNames: string[];
loading: boolean;
tableName?: string;
exportScope?: 'table' | 'queryResult';
resultSql?: string;
dbName?: string;
connectionId?: string;
pkColumns?: string[];
@@ -543,7 +547,7 @@ type ColumnMeta = {
};
const DataGrid: React.FC<DataGridProps> = ({
data, columnNames, loading, tableName, dbName, connectionId, pkColumns = [], readOnly = false,
data, columnNames, loading, tableName, exportScope = 'table', resultSql, dbName, connectionId, pkColumns = [], readOnly = false,
onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, onApplyFilter
}) => {
const connections = useStore(state => state.connections);
@@ -559,8 +563,14 @@ const DataGrid: React.FC<DataGridProps> = ({
const showColumnComment = queryOptions?.showColumnComment !== false;
const showColumnType = queryOptions?.showColumnType !== false;
const selectionColumnWidth = 46;
const connTypeLower = String(connections.find(c => c.id === connectionId)?.config?.type || '').trim().toLowerCase();
const isDuckDBConnection = connTypeLower === 'duckdb';
const currentConnConfig = connections.find(c => c.id === connectionId)?.config;
const dataSourceCaps = getDataSourceCapabilities(currentConnConfig);
const isDuckDBConnection = dataSourceCaps.type === 'duckdb';
const supportsCopyInsert = dataSourceCaps.supportsCopyInsert;
const supportsSqlQueryExport = dataSourceCaps.supportsSqlQueryExport;
const isQueryResultExport = exportScope === 'queryResult';
const canImport = exportScope === 'table' && !!tableName;
const canExport = !!connectionId && (isQueryResultExport || !!tableName);
// Background Helper
const getBg = (darkHex: string) => {
@@ -687,11 +697,20 @@ const DataGrid: React.FC<DataGridProps> = ({
// Helper to export specific data
const exportData = async (rows: any[], format: string) => {
const hide = message.loading(`正在导出 ${rows.length} 条数据...`, 0);
const cleanRows = rows.map(({ [GONAVI_ROW_KEY]: _rowKey, ...rest }) => rest);
// Pass tableName (or 'export') as default filename
const res = await ExportData(cleanRows, columnNames, tableName || 'export', format);
hide();
if (res.success) { message.success("导出成功"); } else if (res.message !== "Cancelled") { message.error("导出失败: " + res.message); }
try {
const cleanRows = rows.map(({ [GONAVI_ROW_KEY]: _rowKey, ...rest }) => rest);
// Pass tableName (or 'export') as default filename
const res = await ExportData(cleanRows, columnNames, tableName || 'export', format);
if (res.success) {
message.success("导出成功");
} else if (res.message !== "Cancelled") {
message.error("导出失败: " + res.message);
}
} catch (e: any) {
message.error("导出失败: " + (e?.message || String(e)));
} finally {
hide();
}
};
const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null);
@@ -2101,6 +2120,10 @@ const DataGrid: React.FC<DataGridProps> = ({
}, []);
const handleCopyInsert = useCallback((record: any) => {
if (!supportsCopyInsert) {
message.warning("当前数据源不支持复制为 INSERT请使用 JSON/CSV/Markdown 复制。");
return;
}
const records = getTargets(record);
const sqls = records.map((r: any) => {
const { [GONAVI_ROW_KEY]: _rowKey, ...vals } = r;
@@ -2110,7 +2133,7 @@ const DataGrid: React.FC<DataGridProps> = ({
return `INSERT INTO \`${targetTable}\` (${cols.map(c => `\`${c}\``).join(', ')}) VALUES (${values.join(', ')});`;
});
copyToClipboard(sqls.join('\n'));
}, [tableName, getTargets, copyToClipboard]);
}, [supportsCopyInsert, tableName, getTargets, copyToClipboard]);
const handleCopyJson = useCallback((record: any) => {
const records = getTargets(record);
@@ -2149,12 +2172,17 @@ const DataGrid: React.FC<DataGridProps> = ({
const config = buildConnConfig();
if (!config) return;
const hide = message.loading(`正在导出...`, 0);
const res = await ExportQuery(config as any, dbName || '', sql, defaultName || 'export', format);
hide();
if (res.success) {
message.success("导出成功");
} else if (res.message !== "Cancelled") {
message.error("导出失败: " + res.message);
try {
const res = await ExportQuery(config as any, dbName || '', sql, defaultName || 'export', format);
if (res.success) {
message.success("导出成功");
} else if (res.message !== "Cancelled") {
message.error("导出失败: " + res.message);
}
} catch (e: any) {
message.error("导出失败: " + (e?.message || String(e)));
} finally {
hide();
}
}, [buildConnConfig, dbName]);
@@ -2198,6 +2226,10 @@ const DataGrid: React.FC<DataGridProps> = ({
// Context Menu Export
const handleExportSelected = useCallback(async (format: string, record: any) => {
const records = getTargets(record);
if (isQueryResultExport) {
await exportData(records, format);
return;
}
if (!connectionId || !tableName) {
await exportData(records, format);
return;
@@ -2225,11 +2257,11 @@ const DataGrid: React.FC<DataGridProps> = ({
const sql = `SELECT * FROM ${quoteQualifiedIdent(dbType, tableName)} WHERE ${pkWhere}`;
await exportByQuery(sql, format, tableName || 'export');
}, [getTargets, connectionId, tableName, hasChanges, exportData, buildConnConfig, buildPkWhereSql, exportByQuery]);
}, [getTargets, isQueryResultExport, connectionId, tableName, hasChanges, exportData, buildConnConfig, buildPkWhereSql, exportByQuery]);
// Export
const handleExport = async (format: string) => {
if (!connectionId || !tableName) return;
if (!connectionId) return;
// 1. Export Selected
if (selectedRowKeys.length > 0) {
@@ -2238,17 +2270,38 @@ const DataGrid: React.FC<DataGridProps> = ({
return;
}
// 查询结果页导出统一按当前结果集(已加载数据)导出,避免再次执行原 SQL 造成大数据导出或长时间阻塞。
if (isQueryResultExport) {
const sql = String(resultSql || '').trim();
if (!hasChanges && supportsSqlQueryExport && sql) {
await exportByQuery(sql, format, tableName || 'query_result');
} else {
await exportData(mergedDisplayData, format);
}
return;
}
// 2. Prompt for Current vs All
// Using a custom modal content with buttons to handle 3 states
let instance: any;
const handleAll = async () => {
instance.destroy();
if (!tableName) return;
const config = buildConnConfig();
if (!config) return;
const hide = message.loading(`正在导出全部数据...`, 0);
const res = await ExportTable(config as any, dbName || '', tableName, format);
hide();
if (res.success) { message.success("导出成功"); } else if (res.message !== "Cancelled") { message.error("导出失败: " + res.message); }
try {
const res = await ExportTable(config as any, dbName || '', tableName, format);
if (res.success) {
message.success("导出成功");
} else if (res.message !== "Cancelled") {
message.error("导出失败: " + res.message);
}
} catch (e: any) {
message.error("导出失败: " + (e?.message || String(e)));
} finally {
hide();
}
};
const handlePage = async () => {
instance.destroy();
@@ -2411,7 +2464,8 @@ const DataGrid: React.FC<DataGridProps> = ({
copyToClipboard,
tableName,
enableRowContextMenu: !canModifyData,
}), [handleCopyCsv, handleCopyInsert, handleCopyJson, handleExportSelected, copyToClipboard, tableName, canModifyData]);
supportsCopyInsert,
}), [handleCopyCsv, handleCopyInsert, handleCopyJson, handleExportSelected, copyToClipboard, tableName, canModifyData, supportsCopyInsert]);
const cellContextMenuValue = useMemo(() => ({
showMenu: showCellContextMenu,
@@ -2456,8 +2510,8 @@ const DataGrid: React.FC<DataGridProps> = ({
setSelectedRowKeys([]);
onReload();
}}></Button>}
{tableName && <Button icon={<ImportOutlined />} onClick={handleImport}></Button>}
{tableName && <Dropdown menu={{ items: exportMenu }}><Button icon={<ExportOutlined />}> <DownOutlined /></Button></Dropdown>}
{canImport && <Button icon={<ImportOutlined />} onClick={handleImport}></Button>}
{canExport && <Dropdown menu={{ items: exportMenu }}><Button icon={<ExportOutlined />}> <DownOutlined /></Button></Dropdown>}
{canModifyData && (
<>
@@ -2996,21 +3050,23 @@ const DataGrid: React.FC<DataGridProps> = ({
({selectedRowKeys.length})
</div>
<div style={{ height: 1, background: darkMode ? '#303030' : '#f0f0f0', margin: '4px 0' }} />
<div
style={{
padding: '8px 12px',
cursor: 'pointer',
transition: 'background 0.2s',
}}
onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'}
onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'}
onClick={() => {
if (cellContextMenu.record) handleCopyInsert(cellContextMenu.record);
setCellContextMenu(prev => ({ ...prev, visible: false }));
}}
>
INSERT
</div>
{supportsCopyInsert && (
<div
style={{
padding: '8px 12px',
cursor: 'pointer',
transition: 'background 0.2s',
}}
onMouseEnter={(e) => e.currentTarget.style.background = darkMode ? '#303030' : '#f5f5f5'}
onMouseLeave={(e) => e.currentTarget.style.background = 'transparent'}
onClick={() => {
if (cellContextMenu.record) handleCopyInsert(cellContextMenu.record);
setCellContextMenu(prev => ({ ...prev, visible: false }));
}}
>
INSERT
</div>
)}
<div
style={{
padding: '8px 12px',

View File

@@ -5,6 +5,7 @@ import { useStore } from '../store';
import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { buildOrderBySQL, buildWhereSQL, quoteIdentPart, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
type ViewerPaginationState = {
current: number;
@@ -172,8 +173,10 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
const [showFilter, setShowFilter] = useState(false);
const [filterConditions, setFilterConditions] = useState<FilterCondition[]>([]);
const duckdbSafeSelectCacheRef = useRef<Record<string, string>>({});
const currentConnType = (connections.find(c => c.id === tab.connectionId)?.config?.type || '').toLowerCase();
const forceReadOnly = currentConnType === 'tdengine' || currentConnType === 'clickhouse';
const currentConnConfig = connections.find(c => c.id === tab.connectionId)?.config;
const currentConnCaps = getDataSourceCapabilities(currentConnConfig);
const currentConnType = currentConnCaps.type;
const forceReadOnly = currentConnCaps.forceReadOnlyQueryResult;
useEffect(() => {
setPkColumns([]);
@@ -673,6 +676,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => {
columnNames={columnNames}
loading={loading}
tableName={tab.tableName}
exportScope="table"
dbName={tab.dbName}
connectionId={tab.connectionId}
pkColumns={pkColumns}

View File

@@ -1,4 +1,4 @@
import React, { useState, useEffect, useRef } from 'react';
import React, { useState, useEffect, useRef, useMemo } from 'react';
import Editor, { OnMount } from '@monaco-editor/react';
import { Button, message, Modal, Input, Form, Dropdown, MenuProps, Tooltip, Select, Tabs } from 'antd';
import { PlayCircleOutlined, SaveOutlined, FormatPainterOutlined, SettingOutlined, CloseOutlined } from '@ant-design/icons';
@@ -7,6 +7,7 @@ import { TabData, ColumnDefinition } from '../types';
import { useStore } from '../store';
import { DBQuery, DBGetTables, DBGetAllColumns, DBGetDatabases, DBGetColumns } from '../../wailsjs/go/app/App';
import DataGrid, { GONAVI_ROW_KEY } from './DataGrid';
import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities';
const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const [query, setQuery] = useState(tab.query || 'SELECT * FROM ');
@@ -14,6 +15,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
type ResultSet = {
key: string;
sql: string;
exportSql?: string;
rows: any[];
columns: string[];
tableName?: string;
@@ -47,6 +49,10 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const visibleDbsRef = useRef<string[]>([]); // Store visible databases for cross-db intellisense
const connections = useStore(state => state.connections);
const queryCapableConnections = useMemo(
() => connections.filter(c => getDataSourceCapabilities(c.config).supportsQueryEditor),
[connections]
);
const addSqlLog = useStore(state => state.addSqlLog);
const currentConnectionIdRef = useRef(currentConnectionId);
const currentDbRef = useRef(currentDb);
@@ -64,6 +70,16 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
currentConnectionIdRef.current = currentConnectionId;
}, [currentConnectionId]);
useEffect(() => {
if (!queryCapableConnections.some(c => c.id === currentConnectionId)) {
const fallback = queryCapableConnections[0]?.id || '';
if (fallback && fallback !== currentConnectionId) {
setCurrentConnectionId(fallback);
setCurrentDb('');
}
}
}, [queryCapableConnections, currentConnectionId]);
useEffect(() => {
currentDbRef.current = currentDb;
}, [currentDb]);
@@ -977,6 +993,12 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
if (runSeqRef.current === runSeq) setLoading(false);
return;
}
const connCaps = getDataSourceCapabilities(conn.config);
if (!connCaps.supportsQueryEditor) {
message.error("当前数据源不支持 SQL 查询编辑器,请使用对应专用页面。");
if (runSeqRef.current === runSeq) setLoading(false);
return;
}
const config = {
...conn.config,
@@ -1000,8 +1022,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
const nextResultSets: ResultSet[] = [];
const maxRows = Number(queryOptions?.maxRows) || 0;
const dbType = String((config as any).type || 'mysql');
const normalizedDbType = dbType.toLowerCase();
const forceReadOnlyResult = normalizedDbType === 'tdengine' || normalizedDbType === 'clickhouse';
const forceReadOnlyResult = connCaps.forceReadOnlyQueryResult;
const wantsLimitProbe = Number.isFinite(maxRows) && maxRows > 0;
const probeLimit = wantsLimitProbe ? (maxRows + 1) : 0;
let anyTruncated = false;
@@ -1066,6 +1087,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: limited.applied ? applyAutoLimit(rawStatement, dbType, Math.max(1, Number(maxRows) || 1)).sql : rawStatement,
rows,
columns: cols,
tableName: simpleTableName,
@@ -1082,6 +1104,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
nextResultSets.push({
key: `result-${idx + 1}`,
sql: rawStatement,
exportSql: rawStatement,
rows: [row],
columns: ['affectedRows'],
pkColumns: [],
@@ -1223,7 +1246,7 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
setCurrentConnectionId(val);
setCurrentDb('');
}}
options={connections.map(c => ({ label: c.name, value: c.id }))}
options={queryCapableConnections.map(c => ({ label: c.name, value: c.id }))}
showSearch
/>
<Select
@@ -1333,6 +1356,8 @@ const QueryEditor: React.FC<{ tab: TabData }> = ({ tab }) => {
columnNames={rs.columns}
loading={loading}
tableName={rs.tableName}
exportScope="queryResult"
resultSql={rs.exportSql || rs.sql}
dbName={currentDb}
connectionId={currentConnectionId}
pkColumns={rs.pkColumns}

View File

@@ -0,0 +1,86 @@
import type { ConnectionConfig } from '../types';
type ConnectionLike = Pick<ConnectionConfig, 'type' | 'driver'> | null | undefined;
const normalizeDataSourceToken = (raw: string): string => {
const normalized = String(raw || '').trim().toLowerCase();
switch (normalized) {
case 'doris':
return 'diros';
case 'postgresql':
return 'postgres';
case 'dm':
return 'dameng';
default:
return normalized;
}
};
export const resolveDataSourceType = (config: ConnectionLike): string => {
if (!config) return '';
const type = normalizeDataSourceToken(String(config.type || ''));
if (type === 'custom') {
const driver = normalizeDataSourceToken(String(config.driver || ''));
return driver || 'custom';
}
return type;
};
const SQL_QUERY_EXPORT_TYPES = new Set([
'mysql',
'mariadb',
'diros',
'sphinx',
'postgres',
'kingbase',
'highgo',
'vastbase',
'sqlserver',
'sqlite',
'duckdb',
'oracle',
'dameng',
'tdengine',
'clickhouse',
]);
const COPY_INSERT_TYPES = new Set([
'mysql',
'mariadb',
'diros',
'sphinx',
'postgres',
'kingbase',
'highgo',
'vastbase',
'sqlserver',
'sqlite',
'duckdb',
'oracle',
'dameng',
'tdengine',
'clickhouse',
]);
const QUERY_EDITOR_DISABLED_TYPES = new Set(['redis']);
const FORCE_READ_ONLY_QUERY_TYPES = new Set(['tdengine', 'clickhouse']);
export type DataSourceCapabilities = {
type: string;
supportsQueryEditor: boolean;
supportsSqlQueryExport: boolean;
supportsCopyInsert: boolean;
forceReadOnlyQueryResult: boolean;
};
export const getDataSourceCapabilities = (config: ConnectionLike): DataSourceCapabilities => {
const type = resolveDataSourceType(config);
return {
type,
supportsQueryEditor: !QUERY_EDITOR_DISABLED_TYPES.has(type),
supportsSqlQueryExport: SQL_QUERY_EXPORT_TYPES.has(type),
supportsCopyInsert: COPY_INSERT_TYPES.has(type),
forceReadOnlyQueryResult: FORCE_READ_ONLY_QUERY_TYPES.has(type),
};
};

View File

@@ -225,7 +225,7 @@ const builtinDriverManifestJSON = `{
"vastbase": { "engine": "go", "version": "1.11.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/vastbase" },
"mongodb": { "engine": "go", "version": "2.5.0", "checksumPolicy": "off", "downloadUrl": "builtin://activate/mongodb" },
"tdengine": { "engine": "go", "version": "3.7.8", "checksumPolicy": "off", "downloadUrl": "builtin://activate/tdengine" },
"clickhouse": { "engine": "go", "version": "2.43.0", "checksumPolicy": "off", "downloadUrl": "builtin://activate/clickhouse" }
"clickhouse": { "engine": "go", "version": "2.43.1", "checksumPolicy": "off", "downloadUrl": "builtin://activate/clickhouse" }
}
}`
@@ -278,7 +278,7 @@ var latestDriverVersionMap = map[string]string{
"vastbase": "1.11.2",
"mongodb": "2.5.0",
"tdengine": "3.7.8",
"clickhouse": "2.43.0",
"clickhouse": "2.43.1",
"oracle": "2.9.0",
"postgres": "1.11.2",
"redis": "9.17.3",

View File

@@ -2,6 +2,7 @@ package app
import (
"bufio"
"context"
"encoding/csv"
"encoding/json"
"fmt"
@@ -16,11 +17,16 @@ import (
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/db"
"GoNavi-Wails/internal/logger"
"GoNavi-Wails/internal/utils"
"github.com/wailsapp/wails/v2/pkg/runtime"
"github.com/xuri/excelize/v2"
)
const minExportQueryTimeout = 5 * time.Minute
const minClickHouseExportQueryTimeout = 2 * time.Hour
func (a *App) OpenSQLFile() connection.QueryResult {
selection, err := runtime.OpenFileDialog(a.ctx, runtime.OpenDialogOptions{
Title: "Select SQL File",
@@ -614,7 +620,7 @@ func (a *App) ExportTable(config connection.ConnectionConfig, dbName string, tab
query := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(runConfig.Type, tableName))
data, columns, err := dbInst.Query(query)
data, columns, err := queryDataForExport(dbInst, runConfig, query)
if err != nil {
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -945,7 +951,7 @@ func listViewNameLookup(dbInst db.Database, config connection.ConnectionConfig,
if strings.TrimSpace(query) == "" {
continue
}
rows, _, err := dbInst.Query(query)
rows, _, err := queryDataForExport(dbInst, config, query)
if err != nil {
continue
}
@@ -1056,7 +1062,7 @@ func tryGetViewCreateStatement(
if strings.TrimSpace(query) == "" {
continue
}
rows, _, err := dbInst.Query(query)
rows, _, err := queryDataForExport(dbInst, config, query)
if err != nil || len(rows) == 0 {
continue
}
@@ -1421,7 +1427,7 @@ func dumpTableSQL(
qualified := qualifyTable(schemaName, pureTableName)
selectSQL := fmt.Sprintf("SELECT * FROM %s", quoteQualifiedIdentByType(config.Type, qualified))
data, columns, err := dbInst.Query(selectSQL)
data, columns, err := queryDataForExport(dbInst, config, selectSQL)
if err != nil {
return err
}
@@ -1456,14 +1462,17 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
if defaultName == "" {
defaultName = "export"
}
logger.Infof("ExportData 开始rows=%d cols=%d format=%s defaultName=%s", len(data), len(columns), strings.ToLower(strings.TrimSpace(format)), strings.TrimSpace(defaultName))
filename, err := runtime.SaveFileDialog(a.ctx, runtime.SaveDialogOptions{
Title: "Export Data",
DefaultFilename: fmt.Sprintf("%s.%s", defaultName, strings.ToLower(format)),
})
if err != nil || filename == "" {
logger.Infof("ExportData 已取消或未选择文件err=%v", err)
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
logger.Infof("ExportData 选定文件:%s", filename)
f, err := os.Create(filename)
if err != nil {
@@ -1471,9 +1480,11 @@ func (a *App) ExportData(data []map[string]interface{}, columns []string, defaul
}
defer f.Close()
if err := writeRowsToFile(f, data, columns, format); err != nil {
logger.Warnf("ExportData 写入失败file=%s err=%v", filename, err)
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
logger.Infof("ExportData 完成file=%s rows=%d", filename, len(data))
return connection.QueryResult{Success: true, Message: "Export successful"}
}
@@ -1494,8 +1505,10 @@ func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, que
DefaultFilename: fmt.Sprintf("%s.%s", defaultName, strings.ToLower(format)),
})
if err != nil || filename == "" {
logger.Infof("ExportQuery 已取消或未选择文件err=%v", err)
return connection.QueryResult{Success: false, Message: "Cancelled"}
}
logger.Infof("ExportQuery 开始type=%s db=%s format=%s file=%s sql=%q", strings.TrimSpace(config.Type), strings.TrimSpace(dbName), strings.ToLower(strings.TrimSpace(format)), filename, sqlSnippet(query))
runConfig := normalizeRunConfig(config, dbName)
dbInst, err := a.getDatabase(runConfig)
@@ -1509,8 +1522,9 @@ func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, que
return connection.QueryResult{Success: false, Message: "Only SELECT/WITH queries are supported"}
}
data, columns, err := dbInst.Query(query)
data, columns, err := queryDataForExport(dbInst, runConfig, query)
if err != nil {
logger.Warnf("ExportQuery 查询失败type=%s db=%s err=%v sql=%q", strings.TrimSpace(config.Type), strings.TrimSpace(dbName), err, sqlSnippet(query))
return connection.QueryResult{Success: false, Message: err.Error()}
}
@@ -1521,12 +1535,55 @@ func (a *App) ExportQuery(config connection.ConnectionConfig, dbName string, que
defer f.Close()
if err := writeRowsToFile(f, data, columns, format); err != nil {
logger.Warnf("ExportQuery 写入失败file=%s err=%v", filename, err)
return connection.QueryResult{Success: false, Message: "Write error: " + err.Error()}
}
logger.Infof("ExportQuery 完成file=%s rows=%d cols=%d", filename, len(data), len(columns))
return connection.QueryResult{Success: true, Message: "Export successful"}
}
func queryDataForExport(dbInst db.Database, config connection.ConnectionConfig, query string) ([]map[string]interface{}, []string, error) {
timeout := getExportQueryTimeout(config)
dbType := resolveDDLDBType(config)
if dbType == "clickhouse" {
logger.Infof("ClickHouse 导出查询开始timeout=%s SQL片段=%q", timeout, sqlSnippet(query))
}
if q, ok := dbInst.(interface {
QueryContext(context.Context, string) ([]map[string]interface{}, []string, error)
}); ok {
ctx, cancel := utils.ContextWithTimeout(timeout)
defer cancel()
data, columns, err := q.QueryContext(ctx, query)
if err != nil && dbType == "clickhouse" {
logger.Warnf("ClickHouse 导出查询失败timeout=%s SQL片段=%q err=%v", timeout, sqlSnippet(query), err)
}
return data, columns, err
}
data, columns, err := dbInst.Query(query)
if err != nil && dbType == "clickhouse" {
logger.Warnf("ClickHouse 导出查询失败(无 QueryContexttimeout=%s SQL片段=%q err=%v", timeout, sqlSnippet(query), err)
}
return data, columns, err
}
func getExportQueryTimeout(config connection.ConnectionConfig) time.Duration {
timeout := time.Duration(config.Timeout) * time.Second
if timeout <= 0 {
timeout = minExportQueryTimeout
}
if resolveDDLDBType(config) == "clickhouse" {
if timeout < minClickHouseExportQueryTimeout {
timeout = minClickHouseExportQueryTimeout
}
return timeout
}
if timeout < minExportQueryTimeout {
timeout = minExportQueryTimeout
}
return timeout
}
func writeRowsToFile(f *os.File, data []map[string]interface{}, columns []string, format string) error {
format = strings.ToLower(strings.TrimSpace(format))
if f == nil {

View File

@@ -2,12 +2,65 @@ package app
import (
"bytes"
"context"
"encoding/json"
"os"
"strings"
"testing"
"time"
"GoNavi-Wails/internal/connection"
)
type fakeExportQueryDB struct {
data []map[string]interface{}
cols []string
err error
lastQuery string
lastContextTimeout time.Duration
hasContextDeadline bool
}
func (f *fakeExportQueryDB) Connect(config connection.ConnectionConfig) error { return nil }
func (f *fakeExportQueryDB) Close() error { return nil }
func (f *fakeExportQueryDB) Ping() error { return nil }
func (f *fakeExportQueryDB) Query(query string) ([]map[string]interface{}, []string, error) {
f.lastQuery = query
return f.data, f.cols, f.err
}
func (f *fakeExportQueryDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
f.lastQuery = query
if deadline, ok := ctx.Deadline(); ok {
f.hasContextDeadline = true
f.lastContextTimeout = time.Until(deadline)
}
return f.data, f.cols, f.err
}
func (f *fakeExportQueryDB) Exec(query string) (int64, error) { return 0, nil }
func (f *fakeExportQueryDB) GetDatabases() ([]string, error) { return nil, nil }
func (f *fakeExportQueryDB) GetTables(dbName string) ([]string, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetCreateStatement(dbName, tableName string) (string, error) {
return "", nil
}
func (f *fakeExportQueryDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return nil, nil
}
func (f *fakeExportQueryDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return nil, nil
}
func TestFormatExportCellText_FloatNoScientificNotation(t *testing.T) {
got := formatExportCellText(1.445663e+06)
if strings.Contains(strings.ToLower(got), "e+") || strings.Contains(strings.ToLower(got), "e-") {
@@ -87,3 +140,66 @@ func TestWriteRowsToFile_JSON_NumberKeepPlainText(t *testing.T) {
t.Fatalf("json 数值格式异常want=1445663 got=%s", decoded[0]["id"].String())
}
}
func TestQueryDataForExport_UsesMinimumTimeout(t *testing.T) {
fake := &fakeExportQueryDB{
data: []map[string]interface{}{{"v": 1}},
cols: []string{"v"},
}
_, _, err := queryDataForExport(fake, connection.ConnectionConfig{Timeout: 10}, "SELECT 1")
if err != nil {
t.Fatalf("queryDataForExport 返回错误: %v", err)
}
if !fake.hasContextDeadline {
t.Fatal("queryDataForExport 应设置 context deadline")
}
if fake.lastQuery != "SELECT 1" {
t.Fatalf("queryDataForExport 查询语句异常want=%q got=%q", "SELECT 1", fake.lastQuery)
}
lowerBound := minExportQueryTimeout - 5*time.Second
upperBound := minExportQueryTimeout + 5*time.Second
if fake.lastContextTimeout < lowerBound || fake.lastContextTimeout > upperBound {
t.Fatalf("导出最小超时异常want≈%s got=%s", minExportQueryTimeout, fake.lastContextTimeout)
}
}
func TestQueryDataForExport_UsesLargerConfiguredTimeout(t *testing.T) {
fake := &fakeExportQueryDB{
data: []map[string]interface{}{{"v": 1}},
cols: []string{"v"},
}
_, _, err := queryDataForExport(fake, connection.ConnectionConfig{Timeout: 900}, "SELECT 1")
if err != nil {
t.Fatalf("queryDataForExport 返回错误: %v", err)
}
if !fake.hasContextDeadline {
t.Fatal("queryDataForExport 应设置 context deadline")
}
expected := 900 * time.Second
lowerBound := expected - 5*time.Second
upperBound := expected + 5*time.Second
if fake.lastContextTimeout < lowerBound || fake.lastContextTimeout > upperBound {
t.Fatalf("导出配置超时异常want≈%s got=%s", expected, fake.lastContextTimeout)
}
}
func TestGetExportQueryTimeout_ClickHouseUsesLongerMinimum(t *testing.T) {
timeout := getExportQueryTimeout(connection.ConnectionConfig{
Type: "clickhouse",
Timeout: 30,
})
if timeout != minClickHouseExportQueryTimeout {
t.Fatalf("clickhouse 导出超时下限异常want=%s got=%s", minClickHouseExportQueryTimeout, timeout)
}
}
func TestGetExportQueryTimeout_CustomClickHouseUsesLongerMinimum(t *testing.T) {
timeout := getExportQueryTimeout(connection.ConnectionConfig{
Type: "custom",
Driver: "clickhouse",
Timeout: 30,
})
if timeout != minClickHouseExportQueryTimeout {
t.Fatalf("custom clickhouse 导出超时下限异常want=%s got=%s", minClickHouseExportQueryTimeout, timeout)
}
}

View File

@@ -24,6 +24,7 @@ const (
defaultClickHousePort = 9000
defaultClickHouseUser = "default"
defaultClickHouseDatabase = "default"
minClickHouseReadTimeout = 5 * time.Minute
)
type ClickHouseDB struct {
@@ -101,7 +102,11 @@ func applyClickHouseURI(config connection.ConnectionConfig) connection.Connectio
}
func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig) *clickhouse.Options {
timeout := getConnectTimeout(config)
connectTimeout := getConnectTimeout(config)
readTimeout := connectTimeout
if readTimeout < minClickHouseReadTimeout {
readTimeout = minClickHouseReadTimeout
}
return &clickhouse.Options{
Addr: []string{
net.JoinHostPort(config.Host, strconv.Itoa(config.Port)),
@@ -111,8 +116,8 @@ func (c *ClickHouseDB) buildClickHouseOptions(config connection.ConnectionConfig
Username: strings.TrimSpace(config.User),
Password: config.Password,
},
DialTimeout: timeout,
ReadTimeout: timeout,
DialTimeout: connectTimeout,
ReadTimeout: readTimeout,
}
}

View File

@@ -147,7 +147,7 @@ func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
if opts.DialTimeout != 15*time.Second {
t.Fatalf("dial timeout 不符合预期:%s", opts.DialTimeout)
}
if opts.ReadTimeout != 15*time.Second {
if opts.ReadTimeout != minClickHouseReadTimeout {
t.Fatalf("read timeout 不符合预期:%s", opts.ReadTimeout)
}
if _, ok := opts.Settings["write_timeout"]; ok {
@@ -160,3 +160,27 @@ func TestClickHouseOptions_UsesStructuredTimeoutAndAuth(t *testing.T) {
t.Fatalf("options 不应通过 settings 传递 dial_timeout%v", opts.Settings)
}
}
func TestClickHouseOptions_ReadTimeoutUsesLargerConfiguredTimeout(t *testing.T) {
c := &ClickHouseDB{}
cfg := normalizeClickHouseConfig(connection.ConnectionConfig{
Type: "clickhouse",
Host: "127.0.0.1",
Port: 9000,
User: "default",
Password: "secret",
Database: "analytics",
Timeout: 900,
})
opts := c.buildClickHouseOptions(cfg)
if opts == nil {
t.Fatal("options 为空")
}
if opts.DialTimeout != 900*time.Second {
t.Fatalf("dial timeout 不符合预期:%s", opts.DialTimeout)
}
if opts.ReadTimeout != 900*time.Second {
t.Fatalf("read timeout 不符合预期:%s", opts.ReadTimeout)
}
}

View File

@@ -11,8 +11,10 @@ import (
"os/exec"
"strings"
"sync"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
)
const (
@@ -38,6 +40,7 @@ type optionalAgentRequest struct {
Method string `json:"method"`
Config *connection.ConnectionConfig `json:"config,omitempty"`
Query string `json:"query,omitempty"`
TimeoutMs int64 `json:"timeoutMs,omitempty"`
DBName string `json:"dbName,omitempty"`
TableName string `json:"tableName,omitempty"`
Changes *connection.ChangeSet `json:"changes,omitempty"`
@@ -223,6 +226,7 @@ func (d *OptionalDriverAgentDB) Connect(config connection.ConnectionConfig) erro
if err != nil {
return err
}
logger.Infof("%s 驱动代理路径:%s", driverDisplayName(d.driverType), executablePath)
client, err := newOptionalDriverAgentClient(d.driverType, executablePath)
if err != nil {
return err
@@ -260,7 +264,20 @@ func (d *OptionalDriverAgentDB) QueryContext(ctx context.Context, query string)
if err := ctx.Err(); err != nil {
return nil, nil, err
}
return d.Query(query)
client, err := d.requireClient()
if err != nil {
return nil, nil, err
}
var data []map[string]interface{}
var fields []string
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodQuery,
Query: query,
TimeoutMs: timeoutMsFromContext(ctx),
}, &data, &fields, nil); err != nil {
return nil, nil, err
}
return data, fields, nil
}
func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, []string, error) {
@@ -283,7 +300,19 @@ func (d *OptionalDriverAgentDB) ExecContext(ctx context.Context, query string) (
if err := ctx.Err(); err != nil {
return 0, err
}
return d.Exec(query)
client, err := d.requireClient()
if err != nil {
return 0, err
}
var affected int64
if err := client.call(optionalAgentRequest{
Method: optionalAgentMethodExec,
Query: query,
TimeoutMs: timeoutMsFromContext(ctx),
}, nil, nil, &affected); err != nil {
return 0, err
}
return affected, nil
}
func (d *OptionalDriverAgentDB) Exec(query string) (int64, error) {
@@ -443,3 +472,15 @@ func (d *OptionalDriverAgentDB) requireClient() (*optionalDriverAgentClient, err
}
return d.client, nil
}
func timeoutMsFromContext(ctx context.Context) int64 {
deadline, ok := ctx.Deadline()
if !ok {
return 0
}
remaining := time.Until(deadline).Milliseconds()
if remaining <= 0 {
return 1
}
return remaining
}

View File

@@ -0,0 +1,32 @@
package db
import (
"context"
"testing"
"time"
)
func TestTimeoutMsFromContext_NoDeadline(t *testing.T) {
if got := timeoutMsFromContext(context.Background()); got != 0 {
t.Fatalf("无 deadline 时应返回 0got=%d", got)
}
}
func TestTimeoutMsFromContext_WithDeadline(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
got := timeoutMsFromContext(ctx)
if got <= 0 {
t.Fatalf("有 deadline 时应返回正值got=%d", got)
}
}
func TestTimeoutMsFromContext_ExpiredDeadline(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
defer cancel()
if got := timeoutMsFromContext(ctx); got != 1 {
t.Fatalf("过期 deadline 应返回 1got=%d", got)
}
}