diff --git a/frontend/src/components/DataGrid.tsx b/frontend/src/components/DataGrid.tsx index 534cdfe..797a1ab 100644 --- a/frontend/src/components/DataGrid.tsx +++ b/frontend/src/components/DataGrid.tsx @@ -509,7 +509,17 @@ interface DataGridProps { onReload?: () => void; onSort?: (field: string, order: string) => void; onPageChange?: (page: number, size: number) => void; - pagination?: { current: number, pageSize: number, total: number, totalKnown?: boolean }; + pagination?: { + current: number, + pageSize: number, + total: number, + totalKnown?: boolean, + totalApprox?: boolean, + totalCountLoading?: boolean, + totalCountCancelled?: boolean, + }; + onRequestTotalCount?: () => void; + onCancelTotalCount?: () => void; sortInfoExternal?: { columnKey: string, order: string } | null; // Filtering showFilter?: boolean; @@ -534,7 +544,7 @@ type ColumnMeta = { const DataGrid: React.FC = ({ data, columnNames, loading, tableName, dbName, connectionId, pkColumns = [], readOnly = false, - onReload, onSort, onPageChange, pagination, sortInfoExternal, showFilter, onToggleFilter, onApplyFilter + onReload, onSort, onPageChange, pagination, onRequestTotalCount, onCancelTotalCount, sortInfoExternal, showFilter, onToggleFilter, onApplyFilter }) => { const connections = useStore(state => state.connections); const addSqlLog = useStore(state => state.addSqlLog); @@ -2527,6 +2537,26 @@ const DataGrid: React.FC = ({ )} + {isDuckDBConnection && onRequestTotalCount && ( + <> +
+ + + + + )} +
= ({ pageSize={pagination.pageSize} total={pagination.total} showTotal={(total, range) => { + const hasValidRange = Array.isArray(range) && range[0] > 0 && range[1] >= range[0]; + const currentCount = hasValidRange ? Math.max(0, range[1] - range[0] + 1) : 0; + if (pagination.totalKnown === false) { + if (isDuckDBConnection) { + if (pagination.totalCountLoading) return `当前 ${currentCount} 条 / 正在统计精确总数...`; + if (pagination.totalApprox && Number.isFinite(total) && total > 0) return `当前 ${currentCount} 条 / 约 ${total} 条`; + if (pagination.totalCountCancelled) return `当前 ${currentCount} 条 / 已取消统计`; + return `当前 ${currentCount} 条 / 总数未统计`; + } + return `当前 ${currentCount} 条 / 正在统计总数...`; + } if (isDuckDBConnection && (!Number.isFinite(total) || total <= 0)) { - if (pagination.totalKnown === false) return '当前 0 条 / 正在统计总数...'; return '当前 0 条 / 共 0 条'; } - const currentCount = Math.max(0, range[1] - range[0] + 1); - if (pagination.totalKnown === false) return `当前 ${currentCount} 条 / 正在统计总数...`; return `当前 ${currentCount} 条 / 共 ${total} 条`; }} showSizeChanger diff --git a/frontend/src/components/DataViewer.tsx b/frontend/src/components/DataViewer.tsx index 575ee0b..a9af795 100644 --- a/frontend/src/components/DataViewer.tsx +++ b/frontend/src/components/DataViewer.tsx @@ -2,10 +2,20 @@ import React, { useEffect, useState, useCallback, useRef } from 'react'; import { message } from 'antd'; import { TabData, ColumnDefinition } from '../types'; import { useStore } from '../store'; -import { DBQuery, DBGetColumns } from '../../wailsjs/go/app/App'; +import { DBQuery, DBGetColumns, DBQueryIsolated } from '../../wailsjs/go/app/App'; import DataGrid, { GONAVI_ROW_KEY } from './DataGrid'; import { buildOrderBySQL, buildWhereSQL, quoteQualifiedIdent, withSortBufferTuningSQL, type FilterCondition } from '../utils/sql'; +type ViewerPaginationState = { + current: number; + pageSize: number; + total: number; + totalKnown: boolean; + totalApprox: boolean; + totalCountLoading: boolean; + totalCountCancelled: boolean; +}; + const toNonNegativeFiniteNumber = (value: unknown): number | null => { if (typeof value === 'number') { return Number.isFinite(value) && value >= 0 ? value : null; @@ -43,6 +53,61 @@ const parseTotalFromCountRow = (row: any): number | null => { return null; }; +const parseDuckDBApproxTotalRow = (row: any): number | null => { + if (!row || typeof row !== 'object') return null; + const entries = Object.entries(row as Record); + if (entries.length === 0) return null; + + const preferredKeys = ['approx_total', 'estimated_size', 'estimated_rows', 'row_count', 'count', 'total']; + for (const preferred of preferredKeys) { + for (const [key, raw] of entries) { + if (String(key || '').trim().toLowerCase() !== preferred) continue; + const parsed = toNonNegativeFiniteNumber(raw); + if (parsed !== null) return parsed; + } + } + + for (const [key, raw] of entries) { + const normalized = String(key || '').trim().toLowerCase(); + if (normalized.includes('estimate') || normalized.includes('row') || normalized.includes('count') || normalized.includes('total')) { + const parsed = toNonNegativeFiniteNumber(raw); + if (parsed !== null) return parsed; + } + } + return null; +}; + +const normalizeDuckDBIdentifier = (raw: string): string => { + const text = String(raw || '').trim(); + if (text.length >= 2) { + const first = text[0]; + const last = text[text.length - 1]; + if ((first === '"' && last === '"') || (first === '`' && last === '`')) { + return text.slice(1, -1).trim(); + } + } + return text; +}; + +const resolveDuckDBSchemaAndTable = (dbName: string, tableName: string) => { + const rawTable = String(tableName || '').trim(); + if (!rawTable) return { schemaName: 'main', pureTableName: '' }; + + const parts = rawTable.split('.'); + if (parts.length >= 2) { + const pureTableName = normalizeDuckDBIdentifier(parts[parts.length - 1]); + const schemaName = normalizeDuckDBIdentifier(parts[parts.length - 2]); + if (schemaName && pureTableName) { + return { schemaName, pureTableName }; + } + } + + const fallbackSchema = normalizeDuckDBIdentifier(String(dbName || '').trim()) || 'main'; + return { schemaName: fallbackSchema, pureTableName: normalizeDuckDBIdentifier(rawTable) }; +}; + +const escapeSQLLiteral = (value: string): string => String(value || '').replace(/'/g, "''"); + const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { const [data, setData] = useState([]); const [columnNames, setColumnNames] = useState([]); @@ -53,14 +118,26 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { const fetchSeqRef = useRef(0); const countSeqRef = useRef(0); const countKeyRef = useRef(''); + const duckdbApproxSeqRef = useRef(0); + const duckdbApproxKeyRef = useRef(''); + const manualCountSeqRef = useRef(0); + const manualCountKeyRef = useRef(''); const pkSeqRef = useRef(0); const pkKeyRef = useRef(''); + const latestConfigRef = useRef(null); + const latestDbTypeRef = useRef(''); + const latestDbNameRef = useRef(''); + const latestCountSqlRef = useRef(''); + const latestCountKeyRef = useRef(''); - const [pagination, setPagination] = useState({ + const [pagination, setPagination] = useState({ current: 1, pageSize: 100, total: 0, - totalKnown: false + totalKnown: false, + totalApprox: false, + totalCountLoading: false, + totalCountCancelled: false, }); const [sortInfo, setSortInfo] = useState<{ columnKey: string, order: string } | null>(null); @@ -70,13 +147,106 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { const currentConnType = (connections.find(c => c.id === tab.connectionId)?.config?.type || '').toLowerCase(); const forceReadOnly = currentConnType === 'tdengine' || currentConnType === 'clickhouse'; + const runIsolatedQuery = useCallback(async (queryConfig: any, dbName: string, sql: string) => { + return DBQueryIsolated(queryConfig as any, dbName, sql); + }, []); + useEffect(() => { setPkColumns([]); pkKeyRef.current = ''; countKeyRef.current = ''; - setPagination(prev => ({ ...prev, current: 1, total: 0, totalKnown: false })); + duckdbApproxKeyRef.current = ''; + manualCountKeyRef.current = ''; + latestConfigRef.current = null; + latestDbTypeRef.current = ''; + latestDbNameRef.current = ''; + latestCountSqlRef.current = ''; + latestCountKeyRef.current = ''; + setPagination(prev => ({ + ...prev, + current: 1, + total: 0, + totalKnown: false, + totalApprox: false, + totalCountLoading: false, + totalCountCancelled: false, + })); }, [tab.connectionId, tab.dbName, tab.tableName]); + const handleDuckDBManualCount = useCallback(async () => { + if (latestDbTypeRef.current !== 'duckdb') { + return; + } + const config = latestConfigRef.current; + const dbName = latestDbNameRef.current; + const countSql = latestCountSqlRef.current; + const countKey = latestCountKeyRef.current; + + if (!config || !countSql || !countKey) { + message.warning('当前结果集尚未就绪,请先执行一次加载'); + return; + } + + manualCountKeyRef.current = countKey; + const countSeq = ++manualCountSeqRef.current; + const countStart = Date.now(); + setPagination(prev => ({ ...prev, totalCountLoading: true, totalCountCancelled: false })); + const countConfig: any = { ...(config as any), timeout: 120 }; + + try { + const resCount = await runIsolatedQuery(countConfig, dbName, countSql); + const countDuration = Date.now() - countStart; + addSqlLog({ + id: `log-${Date.now()}-duckdb-manual-count`, + timestamp: Date.now(), + sql: countSql, + status: resCount?.success ? 'success' : 'error', + duration: countDuration, + message: resCount?.success ? '' : String(resCount?.message || '统计失败'), + dbName + }); + + if (manualCountSeqRef.current !== countSeq) return; + if (manualCountKeyRef.current !== countKey) return; + + if (!resCount?.success) { + setPagination(prev => ({ ...prev, totalCountLoading: false })); + message.error(String(resCount?.message || '统计总数失败')); + return; + } + if (!Array.isArray(resCount.data) || resCount.data.length === 0) { + setPagination(prev => ({ ...prev, totalCountLoading: false })); + return; + } + + const total = parseTotalFromCountRow(resCount.data[0]); + if (total === null) { + setPagination(prev => ({ ...prev, totalCountLoading: false })); + message.error('统计结果解析失败'); + return; + } + + setPagination(prev => ({ + ...prev, + total, + totalKnown: true, + totalApprox: false, + totalCountLoading: false, + totalCountCancelled: false, + })); + } catch (e: any) { + if (manualCountSeqRef.current !== countSeq) return; + if (manualCountKeyRef.current !== countKey) return; + setPagination(prev => ({ ...prev, totalCountLoading: false })); + message.error(`统计总数失败: ${String(e?.message || e)}`); + } + }, [addSqlLog, runIsolatedQuery]); + + const handleDuckDBCancelManualCount = useCallback(() => { + manualCountSeqRef.current++; + setPagination(prev => ({ ...prev, totalCountLoading: false, totalCountCancelled: true })); + }, []); + const fetchData = useCallback(async (page = pagination.current, size = pagination.pageSize) => { const seq = ++fetchSeqRef.current; setLoading(true); @@ -197,10 +367,24 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { const isDuckDB = dbTypeLower === 'duckdb'; const minExpectedTotal = hasMore ? offset + resultData.length + 1 : offset + resultData.length; if (derivedTotalKnown) countKeyRef.current = countKey; + latestConfigRef.current = config; + latestDbTypeRef.current = dbTypeLower; + latestDbNameRef.current = dbName; + latestCountSqlRef.current = countSql; + latestCountKeyRef.current = countKey; setPagination(prev => { if (derivedTotalKnown) { - return { ...prev, current: page, pageSize: size, total: derivedTotal, totalKnown: true }; + return { + ...prev, + current: page, + pageSize: size, + total: derivedTotal, + totalKnown: true, + totalApprox: false, + totalCountLoading: false, + totalCountCancelled: false, + }; } if (prev.totalKnown && countKeyRef.current === countKey) { if (!isDuckDB) { @@ -212,16 +396,38 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { return { ...prev, current: page, pageSize: size }; } } - return { ...prev, current: page, pageSize: size, total: derivedTotal, totalKnown: false }; + const keepManualCounting = prev.totalCountLoading && manualCountKeyRef.current === countKey; + if (isDuckDB && prev.totalApprox && duckdbApproxKeyRef.current === countKey && Number.isFinite(prev.total) && prev.total >= minExpectedTotal) { + return { + ...prev, + current: page, + pageSize: size, + totalKnown: false, + totalApprox: true, + totalCountLoading: keepManualCounting, + totalCountCancelled: false, + }; + } + return { + ...prev, + current: page, + pageSize: size, + total: derivedTotal, + totalKnown: false, + totalApprox: false, + totalCountLoading: keepManualCounting, + totalCountCancelled: keepManualCounting ? false : prev.totalCountCancelled, + }; }); - if (!derivedTotalKnown) { + const shouldRunAsyncCount = !derivedTotalKnown && !isDuckDB; + if (shouldRunAsyncCount) { if (countKeyRef.current !== countKey) { countKeyRef.current = countKey; const countSeq = ++countSeqRef.current; const countStart = Date.now(); // 大表 COUNT(*) 可能非常慢,且在部分运行时环境下会影响后续操作响应; - // 这里为统计请求设置更短的超时,避免“后台统计”长期占用资源。 + // DuckDB 大文件场景下该统计会显著拖慢翻页,已禁用后台 COUNT。 const countConfig: any = { ...(config as any), timeout: 5 }; DBQuery(countConfig, dbName, countSql) @@ -245,17 +451,20 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { if (!Array.isArray(resCount.data) || resCount.data.length === 0) return; let total: number | null = null; - if (dbTypeLower === 'duckdb') { - total = parseTotalFromCountRow(resCount.data[0]); - } else { - const parsed = Number(resCount.data[0]?.['total']); - if (Number.isFinite(parsed) && parsed >= 0) { - total = parsed; - } + const parsed = Number(resCount.data[0]?.['total']); + if (Number.isFinite(parsed) && parsed >= 0) { + total = parsed; } if (total === null) return; - setPagination(prev => ({ ...prev, total, totalKnown: true })); + setPagination(prev => ({ + ...prev, + total, + totalKnown: true, + totalApprox: false, + totalCountLoading: false, + totalCountCancelled: false, + })); }) .catch(() => { if (countSeqRef.current !== countSeq) return; @@ -264,6 +473,50 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { }); } } + + if (isDuckDB && !derivedTotalKnown && whereSQL.trim() === '' && duckdbApproxKeyRef.current !== countKey) { + duckdbApproxKeyRef.current = countKey; + const approxSeq = ++duckdbApproxSeqRef.current; + const { schemaName, pureTableName } = resolveDuckDBSchemaAndTable(dbName, tableName); + const escapedSchema = escapeSQLLiteral(schemaName); + const escapedTable = escapeSQLLiteral(pureTableName); + const approxConfig: any = { ...(config as any), timeout: 3 }; + const approxSqlCandidates = [ + `SELECT estimated_size AS approx_total FROM duckdb_tables() WHERE schema_name='${escapedSchema}' AND table_name='${escapedTable}' LIMIT 1`, + `SELECT estimated_size AS approx_total FROM duckdb_tables() WHERE table_name='${escapedTable}' ORDER BY CASE WHEN schema_name='${escapedSchema}' THEN 0 ELSE 1 END LIMIT 1`, + ]; + + (async () => { + for (const approxSql of approxSqlCandidates) { + try { + const approxRes = await runIsolatedQuery(approxConfig, dbName, approxSql); + if (duckdbApproxSeqRef.current !== approxSeq) return; + if (countKeyRef.current !== countKey) return; + if (!approxRes?.success || !Array.isArray(approxRes.data) || approxRes.data.length === 0) continue; + + const approxTotal = parseDuckDBApproxTotalRow(approxRes.data[0]); + if (approxTotal === null) continue; + if (!Number.isFinite(approxTotal) || approxTotal < minExpectedTotal) continue; + + setPagination(prev => { + if (countKeyRef.current !== countKey) return prev; + if (prev.totalKnown) return prev; + return { + ...prev, + total: approxTotal, + totalKnown: false, + totalApprox: true, + totalCountCancelled: false, + }; + }); + return; + } catch { + if (duckdbApproxSeqRef.current !== approxSeq) return; + if (countKeyRef.current !== countKey) return; + } + } + })(); + } } else { message.error(String(resData.message || '查询失败')); } @@ -281,7 +534,7 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { }); } if (fetchSeqRef.current === seq) setLoading(false); - }, [connections, tab, sortInfo, filterConditions, pkColumns]); + }, [connections, tab, sortInfo, filterConditions, pkColumns, runIsolatedQuery]); // 依赖 pkColumns:在无手动排序时可回退到主键稳定排序。 // 主键信息只会在首次加载后更新一次,避免循环查询。 @@ -320,6 +573,8 @@ const DataViewer: React.FC<{ tab: TabData }> = ({ tab }) => { onSort={handleSort} onPageChange={handlePageChange} pagination={pagination} + onRequestTotalCount={currentConnType === 'duckdb' ? handleDuckDBManualCount : undefined} + onCancelTotalCount={currentConnType === 'duckdb' ? handleDuckDBCancelManualCount : undefined} showFilter={showFilter} onToggleFilter={handleToggleFilter} onApplyFilter={handleApplyFilter} diff --git a/frontend/wailsjs/go/app/App.d.ts b/frontend/wailsjs/go/app/App.d.ts index 0bf094a..72ad6a1 100755 --- a/frontend/wailsjs/go/app/App.d.ts +++ b/frontend/wailsjs/go/app/App.d.ts @@ -34,6 +34,8 @@ export function DBGetTriggers(arg1:connection.ConnectionConfig,arg2:string,arg3: export function DBQuery(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; +export function DBQueryIsolated(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; + export function DBShowCreateTable(arg1:connection.ConnectionConfig,arg2:string,arg3:string):Promise; export function DataSync(arg1:sync.SyncConfig):Promise; diff --git a/frontend/wailsjs/go/app/App.js b/frontend/wailsjs/go/app/App.js index 6879fc2..86f801f 100755 --- a/frontend/wailsjs/go/app/App.js +++ b/frontend/wailsjs/go/app/App.js @@ -62,6 +62,10 @@ export function DBQuery(arg1, arg2, arg3) { return window['go']['app']['App']['DBQuery'](arg1, arg2, arg3); } +export function DBQueryIsolated(arg1, arg2, arg3) { + return window['go']['app']['App']['DBQueryIsolated'](arg1, arg2, arg3); +} + export function DBShowCreateTable(arg1, arg2, arg3) { return window['go']['app']['App']['DBShowCreateTable'](arg1, arg2, arg3); } diff --git a/internal/app/app.go b/internal/app/app.go index a46726e..b8dd6a7 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -207,6 +207,32 @@ func (a *App) getDatabase(config connection.ConnectionConfig) (db.Database, erro return a.getDatabaseWithPing(config, false) } +func (a *App) openDatabaseIsolated(config connection.ConnectionConfig) (db.Database, error) { + effectiveConfig := applyGlobalProxyToConnection(config) + if supported, reason := db.DriverRuntimeSupportStatus(effectiveConfig.Type); !supported { + if strings.TrimSpace(reason) == "" { + reason = fmt.Sprintf("%s 驱动未启用,请先在驱动管理中安装启用", strings.TrimSpace(effectiveConfig.Type)) + } + return nil, withLogHint{err: fmt.Errorf("%s", reason), logPath: logger.Path()} + } + + dbInst, err := db.NewDatabase(effectiveConfig.Type) + if err != nil { + return nil, err + } + + connectConfig, proxyErr := resolveDialConfigWithProxy(effectiveConfig) + if proxyErr != nil { + _ = dbInst.Close() + return nil, wrapConnectError(effectiveConfig, proxyErr) + } + if err := dbInst.Connect(connectConfig); err != nil { + _ = dbInst.Close() + return nil, wrapConnectError(effectiveConfig, err) + } + return dbInst, nil +} + func (a *App) getDatabaseWithPing(config connection.ConnectionConfig, forcePing bool) (db.Database, error) { effectiveConfig := applyGlobalProxyToConnection(config) diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 4c4086b..8263b2b 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -7,6 +7,7 @@ import ( "time" "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/db" "GoNavi-Wails/internal/logger" "GoNavi-Wails/internal/utils" ) @@ -112,16 +113,39 @@ func resolveDDLDBType(config connection.ConnectionConfig) string { driver := strings.ToLower(strings.TrimSpace(config.Driver)) switch driver { - case "postgresql": + case "postgresql", "postgres", "pg", "pq", "pgx": return "postgres" - case "dm": + case "dm", "dameng", "dm8": return "dameng" - case "sqlite3": + case "sqlite3", "sqlite": return "sqlite" case "sphinxql": return "sphinx" case "diros", "doris": return "diros" + case "kingbase", "kingbase8", "kingbasees", "kingbasev8": + return "kingbase" + case "highgo": + return "highgo" + case "vastbase": + return "vastbase" + } + + switch { + case strings.Contains(driver, "postgres"): + return "postgres" + case strings.Contains(driver, "kingbase"): + return "kingbase" + case strings.Contains(driver, "highgo"): + return "highgo" + case strings.Contains(driver, "vastbase"): + return "vastbase" + case strings.Contains(driver, "sqlite"): + return "sqlite" + case strings.Contains(driver, "sphinx"): + return "sphinx" + case strings.Contains(driver, "diros"), strings.Contains(driver, "doris"): + return "diros" default: return driver } @@ -406,6 +430,66 @@ func (a *App) DBQuery(config connection.ConnectionConfig, dbName string, query s } } +func (a *App) DBQueryIsolated(config connection.ConnectionConfig, dbName string, query string) connection.QueryResult { + runConfig := normalizeRunConfig(config, dbName) + + dbInst, err := a.openDatabaseIsolated(runConfig) + if err != nil { + logger.Error(err, "DBQueryIsolated 获取连接失败:%s", formatConnSummary(runConfig)) + return connection.QueryResult{Success: false, Message: err.Error()} + } + defer func() { + if closeErr := dbInst.Close(); closeErr != nil { + logger.Error(closeErr, "DBQueryIsolated 关闭临时连接失败:%s", formatConnSummary(runConfig)) + } + }() + + query = sanitizeSQLForPgLike(runConfig.Type, query) + timeoutSeconds := runConfig.Timeout + if timeoutSeconds <= 0 { + timeoutSeconds = 30 + } + ctx, cancel := utils.ContextWithTimeout(time.Duration(timeoutSeconds) * time.Second) + defer cancel() + + lowerQuery := strings.TrimSpace(strings.ToLower(query)) + isReadQuery := strings.HasPrefix(lowerQuery, "select") || strings.HasPrefix(lowerQuery, "show") || strings.HasPrefix(lowerQuery, "describe") || strings.HasPrefix(lowerQuery, "explain") + if !isReadQuery && strings.ToLower(strings.TrimSpace(runConfig.Type)) == "mongodb" && strings.HasPrefix(strings.TrimSpace(query), "{") { + isReadQuery = true + } + + if isReadQuery { + var data []map[string]interface{} + var columns []string + if q, ok := dbInst.(interface { + QueryContext(context.Context, string) ([]map[string]interface{}, []string, error) + }); ok { + data, columns, err = q.QueryContext(ctx, query) + } else { + data, columns, err = dbInst.Query(query) + } + if err != nil { + logger.Error(err, "DBQueryIsolated 查询失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error()} + } + return connection.QueryResult{Success: true, Data: data, Fields: columns} + } + + var affected int64 + if e, ok := dbInst.(interface { + ExecContext(context.Context, string) (int64, error) + }); ok { + affected, err = e.ExecContext(ctx, query) + } else { + affected, err = dbInst.Exec(query) + } + if err != nil { + logger.Error(err, "DBQueryIsolated 执行失败:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + return connection.QueryResult{Success: false, Message: err.Error()} + } + return connection.QueryResult{Success: true, Data: map[string]int64{"affectedRows": affected}} +} + func sqlSnippet(query string) string { q := strings.TrimSpace(query) const max = 200 @@ -460,8 +544,8 @@ func (a *App) DBGetTables(config connection.ConnectionConfig, dbName string) con } func (a *App) DBShowCreateTable(config connection.ConnectionConfig, dbName string, tableName string) connection.QueryResult { - runConfig := normalizeRunConfig(config, dbName) dbType := resolveDDLDBType(config) + runConfig := buildRunConfigForDDL(config, dbType, dbName) dbInst, err := a.getDatabase(runConfig) if err != nil { @@ -469,35 +553,65 @@ func (a *App) DBShowCreateTable(config connection.ConnectionConfig, dbName strin return connection.QueryResult{Success: false, Message: err.Error()} } - schemaName, pureTableName := normalizeSchemaAndTable(config, dbName, tableName) - sqlStr, err := dbInst.GetCreateStatement(schemaName, pureTableName) + sqlStr, err := resolveCreateStatementWithFallback(dbInst, config, dbName, tableName) if err != nil { logger.Error(err, "DBShowCreateTable 获取建表语句失败:%s 表=%s", formatConnSummary(runConfig), tableName) return connection.QueryResult{Success: false, Message: err.Error()} } - if shouldFallbackCreateStatement(dbType, sqlStr) { - columns, colErr := dbInst.GetColumns(schemaName, pureTableName) - if colErr != nil { - logger.Error(colErr, "DBShowCreateTable 兜底加载字段失败:%s 表=%s", formatConnSummary(runConfig), tableName) - return connection.QueryResult{Success: false, Message: colErr.Error()} - } - fallbackDDL, buildErr := buildFallbackCreateStatement(dbType, schemaName, pureTableName, columns) - if buildErr != nil { - logger.Error(buildErr, "DBShowCreateTable 兜底生成 DDL 失败:%s 表=%s", formatConnSummary(runConfig), tableName) - return connection.QueryResult{Success: false, Message: buildErr.Error()} - } - sqlStr = fallbackDDL - } return connection.QueryResult{Success: true, Data: sqlStr} } -func shouldFallbackCreateStatement(dbType string, ddl string) bool { +func resolveCreateStatementWithFallback(dbInst db.Database, config connection.ConnectionConfig, dbName string, tableName string) (string, error) { + dbType := resolveDDLDBType(config) + schemaName, pureTableName := normalizeSchemaAndTableByType(dbType, dbName, tableName) + if pureTableName == "" { + return "", fmt.Errorf("表名不能为空") + } + + sqlStr, sourceErr := dbInst.GetCreateStatement(schemaName, pureTableName) + if sourceErr == nil && !shouldFallbackCreateStatement(dbType, sqlStr) { + return sqlStr, nil + } + + if !supportsCreateStatementFallback(dbType) { + if sourceErr != nil { + return "", sourceErr + } + return sqlStr, nil + } + + columns, colErr := dbInst.GetColumns(schemaName, pureTableName) + if colErr != nil { + if sourceErr != nil { + return "", sourceErr + } + return "", colErr + } + + fallbackDDL, buildErr := buildFallbackCreateStatement(dbType, schemaName, pureTableName, columns) + if buildErr != nil { + if sourceErr != nil { + return "", sourceErr + } + return "", buildErr + } + return fallbackDDL, nil +} + +func supportsCreateStatementFallback(dbType string) bool { switch dbType { case "postgres", "kingbase", "highgo", "vastbase": + return true default: return false } +} + +func shouldFallbackCreateStatement(dbType string, ddl string) bool { + if !supportsCreateStatementFallback(dbType) { + return false + } trimmed := strings.TrimSpace(ddl) if trimmed == "" { diff --git a/internal/app/methods_db_create_statement_test.go b/internal/app/methods_db_create_statement_test.go new file mode 100644 index 0000000..dfbf1fd --- /dev/null +++ b/internal/app/methods_db_create_statement_test.go @@ -0,0 +1,174 @@ +package app + +import ( + "errors" + "strings" + "testing" + + "GoNavi-Wails/internal/connection" +) + +type fakeCreateStatementDB struct { + createSQL string + createErr error + columns []connection.ColumnDefinition + columnsErr error + + createSchema string + createTable string + colsSchema string + colsTable string +} + +func (f *fakeCreateStatementDB) Connect(config connection.ConnectionConfig) error { return nil } +func (f *fakeCreateStatementDB) Close() error { return nil } +func (f *fakeCreateStatementDB) Ping() error { return nil } +func (f *fakeCreateStatementDB) Query(query string) ([]map[string]interface{}, []string, error) { + return nil, nil, nil +} +func (f *fakeCreateStatementDB) Exec(query string) (int64, error) { return 0, nil } +func (f *fakeCreateStatementDB) GetDatabases() ([]string, error) { return nil, nil } +func (f *fakeCreateStatementDB) GetTables(dbName string) ([]string, error) { return nil, nil } +func (f *fakeCreateStatementDB) GetCreateStatement(dbName, tableName string) (string, error) { + f.createSchema = dbName + f.createTable = tableName + return f.createSQL, f.createErr +} +func (f *fakeCreateStatementDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + f.colsSchema = dbName + f.colsTable = tableName + return f.columns, f.columnsErr +} +func (f *fakeCreateStatementDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + return nil, nil +} +func (f *fakeCreateStatementDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + return nil, nil +} +func (f *fakeCreateStatementDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + return nil, nil +} +func (f *fakeCreateStatementDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + return nil, nil +} + +func TestResolveDDLDBType_CustomDriverAlias(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + driver string + want string + }{ + {name: "postgresql alias", driver: "postgresql", want: "postgres"}, + {name: "pgx alias", driver: "pgx", want: "postgres"}, + {name: "kingbase8 alias", driver: "kingbase8", want: "kingbase"}, + {name: "kingbase contains alias", driver: "kingbasees", want: "kingbase"}, + {name: "dm alias", driver: "dm8", want: "dameng"}, + {name: "sqlite alias", driver: "sqlite3", want: "sqlite"}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + cfg := connection.ConnectionConfig{Type: "custom", Driver: tc.driver} + if got := resolveDDLDBType(cfg); got != tc.want { + t.Fatalf("resolveDDLDBType() mismatch, want=%q got=%q", tc.want, got) + } + }) + } +} + +func TestResolveCreateStatementWithFallback_CustomKingbaseUsesPublicSchema(t *testing.T) { + t.Parallel() + + dbInst := &fakeCreateStatementDB{ + createSQL: "SHOW CREATE TABLE not directly supported in Kingbase/Postgres via SQL", + columns: []connection.ColumnDefinition{ + {Name: "id", Type: "bigint", Nullable: "NO", Key: "PRI"}, + }, + } + + ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{ + Type: "custom", + Driver: "kingbase8", + }, "demo_db", "orders") + if err != nil { + t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err) + } + if dbInst.createSchema != "public" || dbInst.colsSchema != "public" { + t.Fatalf("expected fallback schema public, got create=%q columns=%q", dbInst.createSchema, dbInst.colsSchema) + } + if !strings.Contains(ddl, `CREATE TABLE "public"."orders"`) { + t.Fatalf("expected fallback DDL with public schema, got: %s", ddl) + } +} + +func TestResolveCreateStatementWithFallback_KeepQualifiedSchema(t *testing.T) { + t.Parallel() + + dbInst := &fakeCreateStatementDB{ + createSQL: "-- SHOW CREATE TABLE not fully supported for PostgreSQL in this MVP.", + columns: []connection.ColumnDefinition{ + {Name: "id", Type: "integer", Nullable: "NO", Key: "PRI"}, + }, + } + + ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{ + Type: "custom", + Driver: "postgresql", + }, "demo_db", "sales.orders") + if err != nil { + t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err) + } + if dbInst.createSchema != "sales" || dbInst.colsSchema != "sales" { + t.Fatalf("expected schema sales, got create=%q columns=%q", dbInst.createSchema, dbInst.colsSchema) + } + if !strings.Contains(ddl, `CREATE TABLE "sales"."orders"`) { + t.Fatalf("expected fallback DDL with sales schema, got: %s", ddl) + } +} + +func TestResolveCreateStatementWithFallback_NoFallbackForMySQL(t *testing.T) { + t.Parallel() + + dbInst := &fakeCreateStatementDB{ + createSQL: "SHOW CREATE TABLE not directly supported in Kingbase/Postgres via SQL", + columnsErr: errors.New("should not be called"), + } + + ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{ + Type: "mysql", + }, "demo_db", "orders") + if err != nil { + t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err) + } + if ddl != dbInst.createSQL { + t.Fatalf("expected original ddl for mysql, got: %s", ddl) + } + if dbInst.colsTable != "" { + t.Fatalf("mysql path should not call GetColumns, got table=%q", dbInst.colsTable) + } +} + +func TestResolveCreateStatementWithFallback_FallbackWhenCreateStatementError(t *testing.T) { + t.Parallel() + + dbInst := &fakeCreateStatementDB{ + createErr: errors.New("statement unsupported"), + columns: []connection.ColumnDefinition{ + {Name: "id", Type: "bigint", Nullable: "NO", Key: "PRI"}, + }, + } + + ddl, err := resolveCreateStatementWithFallback(dbInst, connection.ConnectionConfig{ + Type: "postgres", + }, "demo_db", "orders") + if err != nil { + t.Fatalf("resolveCreateStatementWithFallback() unexpected error: %v", err) + } + if !strings.Contains(ddl, `CREATE TABLE "public"."orders"`) { + t.Fatalf("expected fallback DDL for postgres error path, got: %s", ddl) + } +} diff --git a/internal/app/methods_file.go b/internal/app/methods_file.go index d80c251..561ef9b 100644 --- a/internal/app/methods_file.go +++ b/internal/app/methods_file.go @@ -1291,7 +1291,7 @@ func dumpTableSQL( createSQL = ddl } } else { - ddl, err := dbInst.GetCreateStatement(schemaName, pureTableName) + ddl, err := resolveCreateStatementWithFallback(dbInst, config, dbName, tableName) if err != nil { if viewDDL, ok := tryGetViewCreateStatement(dbInst, config, dbName, schemaName, pureTableName); ok { createSQL = viewDDL