From 8f1e6cf379671773b080bff2e732dcc2b10bb0ef Mon Sep 17 00:00:00 2001 From: Syngnat Date: Mon, 22 Jun 2026 22:36:39 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20perf(frontend):=20?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E9=95=BF=E6=97=B6=E8=BF=90=E8=A1=8C=E4=B8=8B?= =?UTF-8?q?=E7=9A=84=E6=90=9C=E7=B4=A2=E4=B8=8E=E7=BC=93=E5=AD=98=E5=8D=A0?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 V2 cmd+k 搜索预建索引并限制初始/宽泛结果数量 - 清理冷数据库树和 DataViewer 长生命周期快照缓存 - 收紧运行时 SQL 日志预算并在 hydration 时压缩旧缓存 --- .../DataViewer.primary-key.test.tsx | 8 + frontend/src/components/DataViewer.tsx | 26 +++- .../Sidebar.locate-toolbar.test.tsx | 9 +- frontend/src/components/Sidebar.tsx | 62 +++++++- .../sidebar/useSidebarSearchModel.tsx | 24 ++- .../sidebar/useSidebarTreeLoaders.tsx | 46 +----- .../sidebarV2Utils.command-search.test.ts | 116 ++++++++++++++ frontend/src/components/sidebarV2Utils.ts | 145 +++++++++++++++--- frontend/src/store.test.ts | 75 +++++++-- frontend/src/store.ts | 129 +++++++++++----- 10 files changed, 506 insertions(+), 134 deletions(-) create mode 100644 frontend/src/components/sidebarV2Utils.command-search.test.ts diff --git a/frontend/src/components/DataViewer.primary-key.test.tsx b/frontend/src/components/DataViewer.primary-key.test.tsx index 4cb1d30..8f257b9 100644 --- a/frontend/src/components/DataViewer.primary-key.test.tsx +++ b/frontend/src/components/DataViewer.primary-key.test.tsx @@ -164,6 +164,14 @@ describe('DataViewer safe editing locator', () => { expect(source).toContain('data_viewer.sql_log.phase.sort_buffer_retry'); }); + it('caps viewer filter snapshots so long-running sessions do not retain unbounded table state', () => { + const source = readFileSync(new URL('./DataViewer.tsx', import.meta.url), 'utf8'); + + expect(source).toContain('const MAX_VIEWER_FILTER_SNAPSHOTS = 64;'); + expect(source).toContain('const trimViewerFilterSnapshots = () => {'); + expect(source).toContain('setViewerFilterSnapshot(normalizedTabId, {'); + }); + it('enables table preview editing after primary keys are loaded', async () => { backendApp.DBGetColumns.mockResolvedValue({ success: true, diff --git a/frontend/src/components/DataViewer.tsx b/frontend/src/components/DataViewer.tsx index 7453fe9..746b42e 100644 --- a/frontend/src/components/DataViewer.tsx +++ b/frontend/src/components/DataViewer.tsx @@ -280,8 +280,32 @@ type ViewerScrollSnapshot = { }; const viewerFilterSnapshotsByTab = new Map(); +const MAX_VIEWER_FILTER_SNAPSHOTS = 64; const VIEWER_SCROLL_SNAPSHOT_PERSIST_DELAY_MS = 160; +const trimViewerFilterSnapshots = () => { + while (viewerFilterSnapshotsByTab.size > MAX_VIEWER_FILTER_SNAPSHOTS) { + const oldestKey = viewerFilterSnapshotsByTab.keys().next().value; + if (!oldestKey) { + break; + } + viewerFilterSnapshotsByTab.delete(oldestKey); + } +}; + +const setViewerFilterSnapshot = ( + tabId: string, + snapshot: ViewerFilterSnapshot, +) => { + const normalizedTabId = String(tabId || '').trim(); + if (!normalizedTabId) return; + if (viewerFilterSnapshotsByTab.has(normalizedTabId)) { + viewerFilterSnapshotsByTab.delete(normalizedTabId); + } + viewerFilterSnapshotsByTab.set(normalizedTabId, snapshot); + trimViewerFilterSnapshots(); +}; + const normalizeViewerFilterConditions = (conditions: FilterCondition[] | undefined): FilterCondition[] => { if (!Array.isArray(conditions)) return []; return conditions.map((cond) => ({ @@ -380,7 +404,7 @@ const DataViewer: React.FC<{ tab: TabData; isActive?: boolean }> = React.memo(({ const persistViewerSnapshot = useCallback((tabId: string, overrides?: Partial) => { const normalizedTabId = String(tabId || '').trim(); if (!normalizedTabId) return; - viewerFilterSnapshotsByTab.set(normalizedTabId, { + setViewerFilterSnapshot(normalizedTabId, { showFilter, conditions: normalizeViewerFilterConditions(filterConditions), quickWhereCondition: normalizeQuickWhereCondition(quickWhereCondition), diff --git a/frontend/src/components/Sidebar.locate-toolbar.test.tsx b/frontend/src/components/Sidebar.locate-toolbar.test.tsx index ba63e6a..31dd385 100644 --- a/frontend/src/components/Sidebar.locate-toolbar.test.tsx +++ b/frontend/src/components/Sidebar.locate-toolbar.test.tsx @@ -2344,9 +2344,6 @@ describe('Sidebar locate toolbar', () => { const loadTablesStart = source.indexOf('const loadTables = async (node: any) => {'); const loadTablesEnd = source.indexOf('const config = {', loadTablesStart); const loadTablesSource = source.slice(loadTablesStart, loadTablesEnd); - const externalSqlReadStart = source.indexOf('const externalSQLDirectoryResults = await Promise.all(', loadTablesStart); - const externalSqlReadEnd = source.indexOf('const externalSQLTrees = externalSQLDirectoryResults.reduce', externalSqlReadStart); - const externalSqlReadSource = source.slice(externalSqlReadStart, externalSqlReadEnd); const externalSqlFlowStart = source.indexOf('const handleAddExternalSQLDirectory = async (node: any) => {'); const externalSqlFlowEnd = source.indexOf('const cancelSQLFileExecution = () => {', externalSqlFlowStart); const externalSqlFlowSource = source.slice(externalSqlFlowStart, externalSqlFlowEnd); @@ -2369,8 +2366,6 @@ describe('Sidebar locate toolbar', () => { [ loadTablesStart, loadTablesEnd, - externalSqlReadStart, - externalSqlReadEnd, externalSqlFlowStart, externalSqlFlowEnd, treeTitleStart, @@ -2387,9 +2382,7 @@ describe('Sidebar locate toolbar', () => { expect(loadTablesSource).toContain("title: t('sidebar.tree.saved_queries')"); expect(loadTablesSource).not.toContain("title: '已存查询'"); - - expect(externalSqlReadSource).toContain("t('sidebar.message.external_sql_directory_read_failed'"); - expect(externalSqlReadSource).not.toContain('SQL 目录读取失败'); + expect(source).not.toContain('const externalSQLDirectoryResults = await Promise.all('); expect(loadTablesSource).not.toContain('SQL 目录读取失败'); expect(loadTablesSource).not.toContain("'SQL目录'"); diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 29c3c5f..76c2381 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -164,6 +164,7 @@ import { resolveSidebarDropInsertBefore, resolveSidebarDropNodeFromDomEvent, resolveSidebarDropTargetMetricsFromDomEvent, + resolveSidebarDatabaseTreePruneKeys, resolveSidebarNodeConnectionId, resolveSidebarTagDropInsertBefore, resolveV2ActiveConnectionId, @@ -190,6 +191,7 @@ export { resolveSidebarDropInsertBefore, resolveSidebarDropNodeFromDomEvent, resolveSidebarDropTargetMetricsFromDomEvent, + resolveSidebarDatabaseTreePruneKeys, resolveSidebarNodeConnectionId, resolveSidebarTagDropInsertBefore, resolveV2ActiveConnectionId, @@ -205,6 +207,7 @@ export type { V2CommandSearchItem, V2RailConnectionGroup } from './sidebarV2Util const { Search } = Input; const SIDEBAR_LOCATE_LOAD_WAIT_INTERVAL_MS = 50; const SIDEBAR_LOCATE_LOAD_WAIT_ATTEMPTS = 160; +const SIDEBAR_CACHED_DATABASE_TREE_LIMIT = 12; // resolveV2ObjectGroupTitle 已迁移到 ./sidebar/sidebarHelpers @@ -506,6 +509,8 @@ const Sidebar: React.FC<{ const [selectedKeys, setSelectedKeys] = useState([]); const selectedNodesRef = useRef([]); const loadingNodesRef = useRef>(new Set()); + const databaseTreeTouchedAtRef = useRef>({}); + const pruneLoadedDatabaseTreesRef = useRef<() => void>(() => {}); const clickTimerRef = useRef | null>(null); const treeDragSelectSuppressUntilRef = useRef(0); const treeDragSelectionSnapshotRef = useRef<{ @@ -544,6 +549,7 @@ const Sidebar: React.FC<{ }, [setActiveContext]); const openV2CommandSearch = useCallback(() => { + pruneLoadedDatabaseTreesRef.current(); setIsV2CommandSearchOpen(true); setV2CommandActiveIndex(0); }, []); @@ -984,6 +990,55 @@ const Sidebar: React.FC<{ return nextTreeData; }; + const clearTreeNodeChildrenByKeys = useCallback((keysToClear: string[]) => { + const keysToClearSet = new Set(keysToClear.map((key) => String(key || '').trim()).filter(Boolean)); + if (keysToClearSet.size === 0) { + return; + } + + const clearChildren = (nodes: TreeNode[]): TreeNode[] => ( + nodes.map((node) => { + const nodeKey = String(node.key || '').trim(); + if (keysToClearSet.has(nodeKey)) { + return { ...node, children: undefined }; + } + if (node.children?.length) { + return { ...node, children: clearChildren(node.children) }; + } + return node; + }) + ); + + setTreeData((prev) => { + const nextTreeData = clearChildren(prev); + treeDataRef.current = nextTreeData; + return nextTreeData; + }); + setLoadedKeys((prev) => prev.filter((key) => !keysToClearSet.has(String(key)))); + keysToClearSet.forEach((key) => { + delete databaseTreeTouchedAtRef.current[key]; + }); + }, []); + + const pruneLoadedDatabaseTrees = useCallback(() => { + const activeDatabaseKey = activeContext?.connectionId && activeContext?.dbName + ? `${activeContext.connectionId}-${activeContext.dbName}` + : ''; + const keysToClear = resolveSidebarDatabaseTreePruneKeys({ + treeData: treeDataRef.current, + expandedKeys, + selectedKeys, + activeDatabaseKey, + touchedAtByDatabaseKey: databaseTreeTouchedAtRef.current, + maxLoadedDatabases: SIDEBAR_CACHED_DATABASE_TREE_LIMIT, + }); + if (keysToClear.length === 0) { + return; + } + clearTreeNodeChildrenByKeys(keysToClear); + }, [activeContext?.connectionId, activeContext?.dbName, clearTreeNodeChildrenByKeys, expandedKeys, selectedKeys]); + pruneLoadedDatabaseTreesRef.current = pruneLoadedDatabaseTrees; + const mergeExpandedTreeKeys = (requiredKeys: React.Key[]) => { setExpandedKeys(prev => { const merged = [...prev]; @@ -1727,7 +1782,6 @@ const Sidebar: React.FC<{ loadTables, } = useSidebarTreeLoaders({ savedQueries, - externalSQLDirectories, tableSortPreference, tableAccessCount, pinnedSidebarTables, @@ -1740,7 +1794,10 @@ const Sidebar: React.FC<{ buildJVMRuntimeConfig, buildJVMDiagnosticTreeNodes, resolveSavedQueryDisplayName, - decorateExternalSQLTreeNode, + onDatabaseTreeLoaded: (databaseKey: string) => { + databaseTreeTouchedAtRef.current[databaseKey] = Date.now(); + pruneLoadedDatabaseTrees(); + }, }); const { @@ -1950,6 +2007,7 @@ const Sidebar: React.FC<{ treeViewportWidth, treeHeight, isV2Ui, + isV2CommandSearchOpen, connections, connectionIds, selectedKeys, diff --git a/frontend/src/components/sidebar/useSidebarSearchModel.tsx b/frontend/src/components/sidebar/useSidebarSearchModel.tsx index c36c636..4ebdb57 100644 --- a/frontend/src/components/sidebar/useSidebarSearchModel.tsx +++ b/frontend/src/components/sidebar/useSidebarSearchModel.tsx @@ -30,6 +30,7 @@ import { } from './sidebarHelpers'; import type { SearchScope } from '../sidebarCoreUtils'; import { + buildV2CommandSearchTreeIndex, V2_TREE_HORIZONTAL_SCROLL_BOTTOM_RESERVE, estimateV2TreeHorizontalScrollWidth, filterV2CommandSearchTreeItems, @@ -74,6 +75,7 @@ type SidebarSearchModelArgs = { treeViewportWidth: number; treeHeight: number; isV2Ui: boolean; + isV2CommandSearchOpen: boolean; connections: SavedConnection[]; connectionIds: string[]; selectedKeys: React.Key[]; @@ -111,6 +113,7 @@ export const useSidebarSearchModel = ({ treeViewportWidth, treeHeight, isV2Ui, + isV2CommandSearchOpen, connections, connectionIds, selectedKeys, @@ -179,6 +182,10 @@ export const useSidebarSearchModel = ({ }; const currentLanguage = getCurrentLanguage(); + const connectionById = useMemo( + () => new Map(connections.map((connection) => [connection.id, connection])), + [connections], + ); const searchScopeSummary = useMemo(() => { if (searchScopes.includes('smart')) { @@ -360,6 +367,9 @@ export const useSidebarSearchModel = ({ }, [deferredSearchValue, searchScopes, treeData]); const commandSearchTreeItems = useMemo(() => { + if (!isV2CommandSearchOpen) { + return []; + } const result: V2CommandSearchItem[] = []; const visit = (nodes: TreeNode[]) => { nodes.forEach((node) => { @@ -375,7 +385,7 @@ export const useSidebarSearchModel = ({ node, }); } else if (node.type === 'database') { - const conn = connections.find((item) => item.id === dataRef.id); + const conn = connectionById.get(String(dataRef.id || '')); result.push({ key: `node-${node.key}`, kind: 'node', @@ -392,7 +402,7 @@ export const useSidebarSearchModel = ({ || node.type === 'db-event' || node.type === 'routine' ) { - const conn = connections.find((item) => item.id === dataRef.id); + const conn = connectionById.get(String(dataRef.id || '')); const objectName = String(dataRef.tableName || dataRef.viewName || dataRef.triggerName || dataRef.eventName || dataRef.routineName || node.title || '').trim(); const displayName = String(node.title || extractObjectName(objectName) || objectName).trim(); result.push({ @@ -412,7 +422,11 @@ export const useSidebarSearchModel = ({ visit(treeData); return result; - }, [connections, treeData]); + }, [connectionById, extractObjectName, isV2CommandSearchOpen, treeData]); + const commandSearchTreeIndex = useMemo( + () => buildV2CommandSearchTreeIndex(commandSearchTreeItems), + [commandSearchTreeItems], + ); const commandSearchRecentItems = useMemo(() => { return sqlLogs.slice(0, 5).map((log) => ({ @@ -473,8 +487,8 @@ export const useSidebarSearchModel = ({ const v2CommandSearchObjectMode = v2CommandSearchQuery.mode === 'object'; const v2CommandSearchAiMode = v2CommandSearchQuery.mode === 'ai'; const filteredCommandSearchTreeItems = useMemo(() => { - return filterV2CommandSearchTreeItems(commandSearchTreeItems, v2CommandSearchQuery); - }, [commandSearchTreeItems, v2CommandSearchQuery]); + return filterV2CommandSearchTreeItems(commandSearchTreeIndex, v2CommandSearchQuery); + }, [commandSearchTreeIndex, v2CommandSearchQuery]); const filteredCommandSearchActionItems = useMemo(() => { if (v2CommandSearchObjectMode || v2CommandSearchAiMode) return []; diff --git a/frontend/src/components/sidebar/useSidebarTreeLoaders.tsx b/frontend/src/components/sidebar/useSidebarTreeLoaders.tsx index e392c9c..72b0afa 100644 --- a/frontend/src/components/sidebar/useSidebarTreeLoaders.tsx +++ b/frontend/src/components/sidebar/useSidebarTreeLoaders.tsx @@ -13,14 +13,13 @@ import { TableOutlined, ThunderboltOutlined, } from '@ant-design/icons'; -import type { SavedConnection, SavedQuery, ExternalSQLDirectory, ExternalSQLTreeEntry, JVMCapability, JVMResourceSummary } from '../../types'; +import type { SavedConnection, SavedQuery, JVMCapability, JVMResourceSummary } from '../../types'; import { useStore } from '../../store'; import { t } from '../../i18n'; import { buildRpcConnectionConfig } from '../../utils/connectionRpcConfig'; import { buildRedisDbNodeLabel, getRedisDbAlias } from '../../utils/redisDbAlias'; import { buildJVMMonitoringActionDescriptors } from '../../utils/jvmSidebarActions'; import { type SidebarViewMetadataEntry } from '../../utils/sidebarMetadata'; -import { buildExternalSQLRootNode, type ExternalSQLTreeNode } from '../../utils/externalSqlTree'; import { buildQualifiedName, buildSidebarObjectKeyName, @@ -47,7 +46,7 @@ import { sortSidebarTableEntries, type SidebarTreeNode as TreeNode, } from '../sidebarV2Utils'; -import { DBGetDatabases, DBGetTables, DBQuery, GetDriverStatusList, JVMProbeCapabilities, ListSQLDirectory } from '../../../wailsjs/go/app/App'; +import { DBGetDatabases, DBGetTables, DBQuery, GetDriverStatusList, JVMProbeCapabilities } from '../../../wailsjs/go/app/App'; type DriverStatusSnapshot = { type: string; @@ -119,7 +118,6 @@ const resolveSavedConnectionDriverType = (conn: SavedConnection | undefined): st type UseSidebarTreeLoadersOptions = { savedQueries: SavedQuery[]; - externalSQLDirectories: ExternalSQLDirectory[]; tableSortPreference: Record; tableAccessCount: Record; pinnedSidebarTables: any[]; @@ -132,12 +130,11 @@ type UseSidebarTreeLoadersOptions = { buildJVMRuntimeConfig: (conn: SavedConnection & { dbName?: string }, providerMode: string) => any; buildJVMDiagnosticTreeNodes: (conn: SavedConnection) => TreeNode[]; resolveSavedQueryDisplayName: (name: string | null | undefined) => string; - decorateExternalSQLTreeNode: (node: ExternalSQLTreeNode) => TreeNode; + onDatabaseTreeLoaded?: (databaseKey: string) => void; }; export const useSidebarTreeLoaders = ({ savedQueries, - externalSQLDirectories, tableSortPreference, tableAccessCount, pinnedSidebarTables, @@ -150,7 +147,7 @@ export const useSidebarTreeLoaders = ({ buildJVMRuntimeConfig, buildJVMDiagnosticTreeNodes, resolveSavedQueryDisplayName, - decorateExternalSQLTreeNode, + onDatabaseTreeLoaded, }: UseSidebarTreeLoadersOptions) => { const driverStatusCacheRef = useRef<{ fetchedAt: number; @@ -516,40 +513,6 @@ export const useSidebarTreeLoaders = ({ loadFunctions(conn, conn.dbName), loadDatabaseEvents(conn, conn.dbName), ]); - const externalSQLDirectoryResults = await Promise.all( - externalSQLDirectories.map(async (directory: ExternalSQLDirectory) => { - const directoryRes = await ListSQLDirectory(directory.path); - if (!directoryRes.success) { - message.warning({ - key: `external-sql-${directory.id}`, - content: t('sidebar.message.external_sql_directory_read_failed', { - name: directory.name, - error: directoryRes.message, - }), - }); - return { id: directory.id, entries: [] as ExternalSQLTreeEntry[] }; - } - return { - id: directory.id, - entries: Array.isArray(directoryRes.data) ? directoryRes.data as ExternalSQLTreeEntry[] : [], - }; - }), - ); - const externalSQLTrees = externalSQLDirectoryResults.reduce>((accumulator, item) => { - accumulator[item.id] = item.entries; - return accumulator; - }, {}); - const externalSQLRootNode = decorateExternalSQLTreeNode(buildExternalSQLRootNode({ - dbNodeKey: String(key), - connectionId: String(conn.id), - dbName: String(conn.dbName), - directories: externalSQLDirectories, - directoryTrees: externalSQLTrees, - labels: { - root: t('sidebar.external_sql.root'), - directoryFallback: t('sidebar.external_sql.directory_fallback'), - }, - })); const viewRows: SidebarViewMetadataEntry[] = Array.isArray(viewsResult.views) ? viewsResult.views : []; const materializedViewRows: SidebarViewMetadataEntry[] = Array.isArray(materializedViewsResult.views) ? materializedViewsResult.views : []; const triggerRows: any[] = Array.isArray(triggersResult.triggers) ? triggersResult.triggers : []; @@ -855,6 +818,7 @@ export const useSidebarTreeLoaders = ({ replaceTreeNodeChildren(key, [queriesNode, ...groupedNodes]); } + onDatabaseTreeLoaded?.(String(key)); } else { setConnectionStates(prev => ({ ...prev, [key as string]: 'error' })); message.error({ content: res.message, key: `db-${key}-tables` }); diff --git a/frontend/src/components/sidebarV2Utils.command-search.test.ts b/frontend/src/components/sidebarV2Utils.command-search.test.ts new file mode 100644 index 0000000..1f7219c --- /dev/null +++ b/frontend/src/components/sidebarV2Utils.command-search.test.ts @@ -0,0 +1,116 @@ +import { describe, expect, it } from 'vitest'; + +import { + V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT, + V2_COMMAND_SEARCH_MAX_TREE_RESULTS, + buildV2CommandSearchTreeIndex, + filterV2CommandSearchTreeItems, + parseV2CommandSearchQuery, + resolveSidebarDatabaseTreePruneKeys, + type V2CommandSearchItem, +} from './sidebarV2Utils'; + +const buildNodeItems = (count: number): V2CommandSearchItem[] => { + return Array.from({ length: count }, (_, index) => ({ + key: `node-table-${index}`, + kind: 'node' as const, + title: `fs_order_${index}`, + meta: `开发240 · front_end_sys_${index % 4}`, + icon: null, + node: { + type: index % 6 === 0 ? 'view' : 'table', + key: `table-${index}`, + title: `fs_order_${index}`, + dataRef: { + tableName: `fs_order_${index}`, + viewName: index % 6 === 0 ? `v_order_${index}` : undefined, + dbName: `front_end_sys_${index % 4}`, + name: `obj_${index}`, + config: { + host: `10.0.0.${index % 16}`, + }, + }, + }, + })); +}; + +describe('sidebarV2 command search performance helpers', () => { + it('keeps the initial tree result limit when the query is empty', () => { + const items = buildNodeItems(V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT + 80); + + expect( + filterV2CommandSearchTreeItems(items, parseV2CommandSearchQuery('')), + ).toHaveLength(V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT); + }); + + it('caps broad keyword matches to avoid rendering the full loaded tree', () => { + const items = buildNodeItems(V2_COMMAND_SEARCH_MAX_TREE_RESULTS + 160); + + const result = filterV2CommandSearchTreeItems( + items, + parseV2CommandSearchQuery('fs_order'), + ); + + expect(result).toHaveLength(V2_COMMAND_SEARCH_MAX_TREE_RESULTS); + expect(result[0]?.key).toBe('node-table-0'); + expect(result[result.length - 1]?.key).toBe(`node-table-${V2_COMMAND_SEARCH_MAX_TREE_RESULTS - 1}`); + }); + + it('returns the same matches when filtering with a prebuilt search index', () => { + const items = buildNodeItems(200); + const index = buildV2CommandSearchTreeIndex(items); + const query = parseV2CommandSearchQuery('@fs_order_1'); + + expect(filterV2CommandSearchTreeItems(index, query)).toEqual( + filterV2CommandSearchTreeItems(items, query), + ); + }); + + it('prunes only cold collapsed database trees when too many object trees stay loaded', () => { + expect(resolveSidebarDatabaseTreePruneKeys({ + treeData: [ + { + key: 'conn-1', + title: 'conn-1', + type: 'connection', + children: [ + { + key: 'conn-1-db-a', + title: 'db-a', + type: 'database', + children: [{ key: 'a-tables', title: '表', type: 'object-group' }], + }, + { + key: 'conn-1-db-b', + title: 'db-b', + type: 'database', + children: [{ key: 'b-tables', title: '表', type: 'object-group' }], + }, + { + key: 'conn-1-db-c', + title: 'db-c', + type: 'database', + children: [{ key: 'c-tables', title: '表', type: 'object-group' }], + }, + { + key: 'conn-1-db-d', + title: 'db-d', + type: 'database', + children: [{ key: 'd-tables', title: '表', type: 'object-group' }], + }, + ], + }, + ], + expandedKeys: ['conn-1-db-c'], + selectedKeys: [], + activeDatabaseKey: 'conn-1-db-d', + touchedAtByDatabaseKey: { + 'conn-1-db-a': 10, + 'conn-1-db-b': 20, + 'conn-1-db-c': 30, + 'conn-1-db-d': 40, + }, + maxLoadedDatabases: 2, + })).toEqual(['conn-1-db-a', 'conn-1-db-b']); + }); +}); diff --git a/frontend/src/components/sidebarV2Utils.ts b/frontend/src/components/sidebarV2Utils.ts index cb24057..c2fab47 100644 --- a/frontend/src/components/sidebarV2Utils.ts +++ b/frontend/src/components/sidebarV2Utils.ts @@ -415,6 +415,13 @@ export type V2CommandSearchItem = dbName?: string; }; +export interface V2CommandSearchTreeIndexEntry { + item: Extract; + normalizedSearchText: string; + normalizedObjectText: string; + objectNode: boolean; +} + export type V2CommandSearchMode = 'default' | 'object' | 'ai'; export interface V2CommandSearchQuery { @@ -467,40 +474,69 @@ const isV2CommandSearchObjectNode = (node: SidebarTreeNode): boolean => { || node.type === 'materialized-view'; }; -const V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT = 24; +export const V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT = 24; +export const V2_COMMAND_SEARCH_MAX_TREE_RESULTS = 120; + +export const buildV2CommandSearchTreeIndex = ( + items: V2CommandSearchItem[], +): V2CommandSearchTreeIndexEntry[] => { + return items.flatMap((item) => { + if (item.kind !== 'node') { + return []; + } + const dataRef = item.node.dataRef || {}; + const normalizedTitle = String(item.title || '').toLowerCase(); + const normalizedPrimaryObjectText = String( + dataRef.tableName || dataRef.viewName || item.title || '', + ).toLowerCase(); + + return [{ + item, + normalizedSearchText: [ + item.title, + item.meta, + dataRef.tableName, + dataRef.viewName, + dataRef.dbName, + dataRef.name, + dataRef.config?.host, + ].filter(Boolean).join(' ').toLowerCase(), + normalizedObjectText: `${normalizedPrimaryObjectText} ${normalizedTitle}`.trim(), + objectNode: isV2CommandSearchObjectNode(item.node), + }]; + }); +}; export const filterV2CommandSearchTreeItems = ( - items: V2CommandSearchItem[], + items: V2CommandSearchItem[] | V2CommandSearchTreeIndexEntry[], query: V2CommandSearchQuery, ): V2CommandSearchItem[] => { if (query.mode === 'ai') return []; + const index = items.length > 0 && 'item' in items[0] + ? items as V2CommandSearchTreeIndexEntry[] + : buildV2CommandSearchTreeIndex(items as V2CommandSearchItem[]); const normalizedKeyword = query.normalizedKeyword; const objectMode = query.mode === 'object'; - const matchedItems = items.filter((item) => { - if (item.kind !== 'node') return false; - const node = item.node; - const dataRef = node.dataRef || {}; - if (objectMode && !isV2CommandSearchObjectNode(node)) { - return false; + const result: V2CommandSearchItem[] = []; + const maxResults = normalizedKeyword + ? V2_COMMAND_SEARCH_MAX_TREE_RESULTS + : V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT; + + for (const entry of index) { + if (objectMode && !entry.objectNode) { + continue; } - if (!normalizedKeyword) return true; - const objectName = String(dataRef.tableName || dataRef.viewName || item.title || '').toLowerCase(); - if (objectMode) { - return objectName.includes(normalizedKeyword) - || String(item.title || '').toLowerCase().includes(normalizedKeyword); + if (!normalizedKeyword) { + result.push(entry.item); + } else if (objectMode ? entry.normalizedObjectText.includes(normalizedKeyword) : entry.normalizedSearchText.includes(normalizedKeyword)) { + result.push(entry.item); } - const haystack = [ - item.title, - item.meta, - dataRef.tableName, - dataRef.viewName, - dataRef.dbName, - dataRef.name, - dataRef.config?.host, - ].filter(Boolean).join(' ').toLowerCase(); - return haystack.includes(normalizedKeyword); - }); - return normalizedKeyword ? matchedItems : matchedItems.slice(0, V2_COMMAND_SEARCH_INITIAL_TREE_LIMIT); + if (result.length >= maxResults) { + break; + } + } + + return result; }; export interface V2CommandSearchEnterState { @@ -765,4 +801,63 @@ export const resolveV2ActiveConnectionId = ({ || ''; }; +export const resolveSidebarDatabaseTreePruneKeys = ({ + treeData, + expandedKeys, + selectedKeys, + activeDatabaseKey, + touchedAtByDatabaseKey, + maxLoadedDatabases, +}: { + treeData: SidebarTreeNode[]; + expandedKeys: React.Key[]; + selectedKeys: React.Key[]; + activeDatabaseKey?: string; + touchedAtByDatabaseKey?: Record; + maxLoadedDatabases: number; +}): string[] => { + if (!Number.isFinite(maxLoadedDatabases) || maxLoadedDatabases <= 0) { + return []; + } + + const loadedDatabaseKeys: string[] = []; + const visit = (nodes: SidebarTreeNode[]) => { + nodes.forEach((node) => { + if (node.type === 'database' && Array.isArray(node.children) && node.children.length > 0) { + loadedDatabaseKeys.push(String(node.key || '').trim()); + return; + } + if (node.children?.length) { + visit(node.children); + } + }); + }; + visit(treeData); + + if (loadedDatabaseKeys.length <= maxLoadedDatabases) { + return []; + } + + const expandedKeySet = new Set(expandedKeys.map((key) => String(key || '').trim()).filter(Boolean)); + const selectedKeySet = new Set(selectedKeys.map((key) => String(key || '').trim()).filter(Boolean)); + const protectedDatabaseKeys = new Set(); + if (activeDatabaseKey) { + protectedDatabaseKeys.add(String(activeDatabaseKey).trim()); + } + + const candidates = loadedDatabaseKeys + .filter((key) => key && !expandedKeySet.has(key) && !selectedKeySet.has(key) && !protectedDatabaseKeys.has(key)) + .sort((left, right) => { + const leftTouchedAt = Number(touchedAtByDatabaseKey?.[left] || 0); + const rightTouchedAt = Number(touchedAtByDatabaseKey?.[right] || 0); + if (leftTouchedAt !== rightTouchedAt) { + return leftTouchedAt - rightTouchedAt; + } + return left.localeCompare(right); + }); + + const pruneCount = loadedDatabaseKeys.length - maxLoadedDatabases; + return candidates.slice(0, pruneCount); +}; + export const shouldClearSidebarActiveContextOnEmptySelect = (isV2Ui: boolean): boolean => !isV2Ui; diff --git a/frontend/src/store.test.ts b/frontend/src/store.test.ts index 6d0d8bc..684a1e4 100644 --- a/frontend/src/store.test.ts +++ b/frontend/src/store.test.ts @@ -1253,33 +1253,80 @@ describe('store appearance persistence', () => { expect(useStore.getState().activeTabId).toBe('query-1'); }); - it('persists recent SQL execution logs and trims oversized entries', async () => { + it('keeps only the most recent runtime SQL logs and trims oversized entries', async () => { const { useStore } = await importStore(); - const longSql = `select '${'x'.repeat(120 * 1024)}'`; + const longSql = `select '${'x'.repeat(20 * 1024)}'`; - useStore.getState().addSqlLog({ - id: 'log-1', - timestamp: 100, - sql: longSql, - status: 'success', - duration: 12, + for (let i = 0; i < 140; i += 1) { + useStore.getState().addSqlLog({ + id: `log-${i}`, + timestamp: 100 + i, + sql: longSql, + status: 'success', + duration: 12 + i, + dbName: 'main', + }); + } + + expect(useStore.getState().sqlLogs).toHaveLength(120); + expect(useStore.getState().sqlLogs[0]).toEqual(expect.objectContaining({ + id: 'log-139', dbName: 'main', - }); + })); + expect(useStore.getState().sqlLogs[119]).toEqual(expect.objectContaining({ + id: 'log-20', + })); + expect(useStore.getState().sqlLogs[0]?.sql.length).toBe(12 * 1024); const persisted = JSON.parse(storage.getItem('lite-db-storage') || '{}'); - expect(persisted.state.sqlLogs).toHaveLength(1); - expect(persisted.state.sqlLogs[0].sql.length).toBe(100 * 1024); + expect(persisted.state.sqlLogs).toHaveLength(120); + expect(persisted.state.sqlLogs[0].sql.length).toBe(12 * 1024); expect(persisted.state.sqlLogs[0].dbName).toBe('main'); vi.resetModules(); const reloaded = await importStore(); expect(reloaded.useStore.getState().sqlLogs[0]).toEqual(expect.objectContaining({ - id: 'log-1', + id: 'log-139', status: 'success', - duration: 12, + duration: 151, dbName: 'main', })); - expect(reloaded.useStore.getState().sqlLogs[0]?.sql.length).toBe(100 * 1024); + expect(reloaded.useStore.getState().sqlLogs).toHaveLength(120); + expect(reloaded.useStore.getState().sqlLogs[119]).toEqual(expect.objectContaining({ + id: 'log-20', + })); + expect(reloaded.useStore.getState().sqlLogs[0]?.sql.length).toBe(12 * 1024); + }); + + it('shrinks oversized SQL logs from older persisted snapshots during hydration', async () => { + storage.setItem('lite-db-storage', JSON.stringify({ + state: { + sqlLogs: Array.from({ length: 200 }, (_, index) => ({ + id: `legacy-log-${index}`, + timestamp: 500 + index, + sql: `select '${'x'.repeat(18 * 1024)}'`, + status: index % 2 === 0 ? 'success' : 'error', + duration: index, + dbName: 'legacy', + message: 'm'.repeat(3 * 1024), + })), + }, + version: 12, + })); + + const { useStore } = await importStore(); + const sqlLogs = useStore.getState().sqlLogs; + + expect(sqlLogs).toHaveLength(120); + expect(sqlLogs[0]).toEqual(expect.objectContaining({ + id: 'legacy-log-0', + dbName: 'legacy', + })); + expect(sqlLogs[119]).toEqual(expect.objectContaining({ + id: 'legacy-log-119', + })); + expect(sqlLogs[0]?.sql.length).toBe(12 * 1024); + expect(sqlLogs[0]?.message?.length).toBe(1024); }); it('defaults AI chat send shortcut to Enter in shared shortcut options', async () => { diff --git a/frontend/src/store.ts b/frontend/src/store.ts index 0d5dee3..b9a3b39 100644 --- a/frontend/src/store.ts +++ b/frontend/src/store.ts @@ -137,14 +137,16 @@ const MIN_KEEPALIVE_INTERVAL_MINUTES = 1; const MAX_KEEPALIVE_INTERVAL_MINUTES = 1440; const DEFAULT_DIAGNOSTIC_TIMEOUT_SECONDS = 15; const MAX_DIAGNOSTIC_TIMEOUT_SECONDS = 300; -const PERSIST_VERSION = 12; +const PERSIST_VERSION = 13; const PERSIST_STORAGE_KEY = "lite-db-storage"; const PERSIST_WRITE_DEBOUNCE_MS = 160; const MAX_PERSISTED_QUERY_TABS = 20; const MAX_PERSISTED_QUERY_LENGTH = 1024 * 1024; -const MAX_SQL_LOGS = 1000; +const MAX_RUNTIME_SQL_LOGS = 120; +const MAX_RUNTIME_SQL_LOG_LENGTH = 12 * 1024; +const MAX_RUNTIME_SQL_LOG_MESSAGE_LENGTH = 1024; const MAX_PERSISTED_SQL_LOGS = 200; -const MAX_PERSISTED_SQL_LOG_LENGTH = 100 * 1024; +const MAX_PERSISTED_SQL_LOG_LENGTH = 24 * 1024; const MAX_PERSISTED_SQL_LOG_MESSAGE_LENGTH = 2 * 1024; const MAX_TABLE_EXPORT_HISTORY_PER_TARGET = 20; const MAX_TABLE_EXPORT_HISTORY_TARGETS = 200; @@ -1708,50 +1710,101 @@ const resolveActiveContextForTabId = ( return fallbackContext; }; -const sanitizeSqlLogs = (value: unknown, limit = MAX_PERSISTED_SQL_LOGS): SqlLog[] => { +type SqlLogSanitizeOptions = { + limit: number; + sqlLength: number; + messageLength: number; +}; + +const RUNTIME_SQL_LOG_SANITIZE_OPTIONS: SqlLogSanitizeOptions = { + limit: MAX_RUNTIME_SQL_LOGS, + sqlLength: MAX_RUNTIME_SQL_LOG_LENGTH, + messageLength: MAX_RUNTIME_SQL_LOG_MESSAGE_LENGTH, +}; + +const PERSISTED_SQL_LOG_SANITIZE_OPTIONS: SqlLogSanitizeOptions = { + limit: MAX_PERSISTED_SQL_LOGS, + sqlLength: MAX_PERSISTED_SQL_LOG_LENGTH, + messageLength: MAX_PERSISTED_SQL_LOG_MESSAGE_LENGTH, +}; + +const sanitizeSqlLogEntry = ( + entry: unknown, + index: number, + options: SqlLogSanitizeOptions, +): SqlLog | null => { + if (!entry || typeof entry !== "object") return null; + const raw = entry as Record; + const sql = typeof raw.sql === "string" ? raw.sql.slice(0, options.sqlLength) : ""; + if (!sql.trim()) return null; + + const status = raw.status === "error" ? "error" : "success"; + const timestamp = Number(raw.timestamp); + const duration = Number(raw.duration); + const affectedRows = Number(raw.affectedRows); + const message = typeof raw.message === "string" + ? raw.message.slice(0, options.messageLength) + : ""; + + const log: SqlLog = { + id: toTrimmedString(raw.id, `log-${index + 1}`) || `log-${index + 1}`, + timestamp: Number.isFinite(timestamp) && timestamp > 0 ? timestamp : Date.now(), + sql, + status, + duration: Number.isFinite(duration) && duration >= 0 ? duration : 0, + dbName: toTrimmedString(raw.dbName) || undefined, + }; + + if (message) { + log.message = message; + } + if (Number.isFinite(affectedRows)) { + log.affectedRows = affectedRows; + } + + return log; +}; + +const sanitizeSqlLogs = ( + value: unknown, + options: SqlLogSanitizeOptions = PERSISTED_SQL_LOG_SANITIZE_OPTIONS, +): SqlLog[] => { if (!Array.isArray(value)) return []; const result: SqlLog[] = []; const seenIds = new Set(); value.forEach((entry, index) => { - if (!entry || typeof entry !== "object") return; - const raw = entry as Record; - const sql = typeof raw.sql === "string" ? raw.sql.slice(0, MAX_PERSISTED_SQL_LOG_LENGTH) : ""; - if (!sql.trim()) return; + const log = sanitizeSqlLogEntry(entry, index, options); + if (!log) return; - let id = toTrimmedString(raw.id, `log-${index + 1}`) || `log-${index + 1}`; + let id = log.id; if (seenIds.has(id)) { id = `${id}-${index + 1}`; } seenIds.add(id); - const status = raw.status === "error" ? "error" : "success"; - const timestamp = Number(raw.timestamp); - const duration = Number(raw.duration); - const affectedRows = Number(raw.affectedRows); - const log: SqlLog = { - id, - timestamp: Number.isFinite(timestamp) && timestamp > 0 ? timestamp : Date.now(), - sql, - status, - duration: Number.isFinite(duration) && duration >= 0 ? duration : 0, - dbName: toTrimmedString(raw.dbName) || undefined, - }; - - const message = typeof raw.message === "string" - ? raw.message.slice(0, MAX_PERSISTED_SQL_LOG_MESSAGE_LENGTH) - : ""; - if (message) { - log.message = message; - } - if (Number.isFinite(affectedRows)) { - log.affectedRows = affectedRows; - } - - result.push(log); + result.push(id === log.id ? log : { ...log, id }); }); - return result.slice(0, limit); + return result.slice(0, options.limit); +}; + +const sanitizeRuntimeSqlLogs = (value: unknown) => + sanitizeSqlLogs(value, RUNTIME_SQL_LOG_SANITIZE_OPTIONS); + +const sanitizePersistedSqlLogs = (value: unknown) => + sanitizeSqlLogs(value, PERSISTED_SQL_LOG_SANITIZE_OPTIONS); + +const appendRuntimeSqlLog = (existing: SqlLog[], entry: SqlLog): SqlLog[] => { + const nextEntry = sanitizeSqlLogEntry(entry, 0, RUNTIME_SQL_LOG_SANITIZE_OPTIONS); + if (!nextEntry) { + return existing; + } + + const nextLogs = [nextEntry, ...existing.slice(0, MAX_RUNTIME_SQL_LOGS - 1)]; + return existing.some((item) => item.id === nextEntry.id) + ? sanitizeRuntimeSqlLogs(nextLogs) + : nextLogs; }; const hasLegacyConnectionSecrets = ( @@ -3155,7 +3208,7 @@ export const useStore = create()( }), addSqlLog: (log) => - set((state) => ({ sqlLogs: sanitizeSqlLogs([log, ...state.sqlLogs], MAX_SQL_LOGS) })), + set((state) => ({ sqlLogs: appendRuntimeSqlLog(state.sqlLogs, log) })), clearSqlLogs: () => set({ sqlLogs: [] }), upsertTableExportHistory: (historyKey, entry) => set((state) => { @@ -3552,7 +3605,7 @@ export const useStore = create()( nextState.shortcutOptions = sanitizeShortcutOptions( state.shortcutOptions, ); - nextState.sqlLogs = sanitizeSqlLogs(state.sqlLogs); + nextState.sqlLogs = sanitizeRuntimeSqlLogs(state.sqlLogs); nextState.tableExportHistories = sanitizeTableExportHistories( state.tableExportHistories, ); @@ -3665,7 +3718,7 @@ export const useStore = create()( state.sqlEditorTransactionOptions, ), shortcutOptions: sanitizeShortcutOptions(state.shortcutOptions), - sqlLogs: sanitizeSqlLogs(state.sqlLogs), + sqlLogs: sanitizeRuntimeSqlLogs(state.sqlLogs), sqlSnippets: sanitizeSqlSnippets(state.sqlSnippets), tableAccessCount: sanitizeTableAccessCount(state.tableAccessCount), @@ -3697,7 +3750,7 @@ export const useStore = create()( dataEditTransactionOptions: state.dataEditTransactionOptions, sqlEditorTransactionOptions: state.sqlEditorTransactionOptions, shortcutOptions: resolveShortcutOptionsForPersistence(state.shortcutOptions), - sqlLogs: sanitizeSqlLogs(state.sqlLogs), + sqlLogs: sanitizePersistedSqlLogs(state.sqlLogs), tableExportHistories: sanitizeTableExportHistories( state.tableExportHistories, ), From 495a985ae1b13649d6bf967bcc410099a9f6cfe3 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 23 Jun 2026 08:48:42 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=F0=9F=90=9B=20fix(sqlserver):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=8F=AF=E9=80=89=E9=A9=B1=E5=8A=A8=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E9=80=8F=E4=BC=A0=E7=BC=BA=E5=A4=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 optional-driver-agent 的 query 和 queryMulti 响应补充 messages 字段 - 在可选驱动 DB 客户端透传 SQL Server 查询提示信息与多结果集 - 补充 agent 与数据库层回归测试并更新 driver agent revision --- cmd/optional-driver-agent/main.go | 153 +++++++++++++- cmd/optional-driver-agent/main_test.go | 103 ++++++++++ internal/db/driver_agent_revisions_gen.go | 42 ++-- internal/db/optional_driver_agent_impl.go | 186 +++++++++++++----- .../db/optional_driver_agent_impl_test.go | 65 ++++++ 5 files changed, 472 insertions(+), 77 deletions(-) diff --git a/cmd/optional-driver-agent/main.go b/cmd/optional-driver-agent/main.go index 063e6ca..03d45e6 100644 --- a/cmd/optional-driver-agent/main.go +++ b/cmd/optional-driver-agent/main.go @@ -36,6 +36,7 @@ type agentResponse struct { Error string `json:"error,omitempty"` Data interface{} `json:"data,omitempty"` Fields []string `json:"fields,omitempty"` + Messages []string `json:"messages,omitempty"` ChunkType string `json:"chunkType,omitempty"` RowsAffected int64 `json:"rowsAffected,omitempty"` } @@ -48,6 +49,7 @@ const ( agentMethodOpenSession = "openSession" agentMethodCloseSession = "closeSession" agentMethodQuery = "query" + agentMethodQueryMulti = "queryMulti" agentMethodStreamQuery = "streamQuery" agentMethodExec = "exec" agentMethodGetDatabases = "getDatabases" @@ -64,9 +66,9 @@ const ( const legacyClickHouseDefaultTimeout = 2 * time.Hour const ( - agentChunkColumns = "columns" - agentChunkRows = "rows" - agentChunkDone = "done" + agentChunkColumns = "columns" + agentChunkRows = "rows" + agentChunkDone = "done" // agentStreamBatchSize 控制 driver-agent 向主进程发送 row chunk 的批次大小。 // 调小到 64:单批 JSON 编码 + 主进程解码的瞬时内存峰值降为原来的 1/4, // 代价是 IPC 次数变为 4 倍,但每批仅一次 stdin/stdout 行读写,整体影响可忽略。 @@ -236,12 +238,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse { } else if ok { switch method { case agentMethodQuery: - data, fields, err := queryStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs) + data, fields, messages, err := queryStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs) if err != nil { return fail(resp, err.Error()) } resp.Data = data resp.Fields = fields + resp.Messages = messages + case agentMethodQueryMulti: + data, messages, supported, err := queryMultiStatementWithMessagesOptionalTimeout(session, req.Query, req.TimeoutMs) + if err != nil { + return fail(resp, err.Error()) + } + if !supported { + return fail(resp, "当前事务会话不支持多结果集查询") + } + resp.Data = data + resp.Messages = messages case agentMethodExec: affected, err := execStatementWithOptionalTimeout(session, req.Query, req.TimeoutMs) if err != nil { @@ -260,12 +273,23 @@ func handleRequest(runtimeState *agentRuntime, req agentRequest) agentResponse { return fail(resp, err.Error()) } case agentMethodQuery: - data, fields, err := queryWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs) + data, fields, messages, err := queryWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs) if err != nil { return fail(resp, err.Error()) } resp.Data = data resp.Fields = fields + resp.Messages = messages + case agentMethodQueryMulti: + data, messages, supported, err := queryMultiWithMessagesOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs) + if err != nil { + return fail(resp, err.Error()) + } + if !supported { + return fail(resp, "当前驱动不支持原生多结果集查询") + } + resp.Data = data + resp.Messages = messages case agentMethodExec: affected, err := execWithOptionalTimeout(runtimeState.inst, req.Query, req.TimeoutMs) if err != nil { @@ -581,6 +605,30 @@ type agentQueryContextRunner interface { QueryContext(context.Context, string) ([]map[string]interface{}, []string, error) } +type agentQueryMessageRunner interface { + QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) +} + +type agentQueryMessageContextRunner interface { + QueryContextWithMessages(context.Context, string) ([]map[string]interface{}, []string, []string, error) +} + +type agentMultiResultMessageRunner interface { + QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) +} + +type agentMultiResultMessageContextRunner interface { + QueryMultiContextWithMessages(context.Context, string) ([]connection.ResultSetData, []string, error) +} + +type agentMultiResultRunner interface { + QueryMulti(query string) ([]connection.ResultSetData, error) +} + +type agentMultiResultContextRunner interface { + QueryMultiContext(context.Context, string) ([]connection.ResultSetData, error) +} + type agentExecRunner interface { Exec(string) (int64, error) } @@ -589,20 +637,39 @@ type agentExecContextRunner interface { ExecContext(context.Context, string) (int64, error) } -func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) { +func queryWithMessagesOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) { effectiveTimeoutMs := timeoutMs if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond) } if effectiveTimeoutMs <= 0 { - return inst.Query(query) + if q, ok := inst.(agentQueryMessageRunner); ok { + return q.QueryWithMessages(query) + } + data, fields, err := inst.Query(query) + return data, fields, nil, err + } + if q, ok := inst.(agentQueryMessageContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + return q.QueryContextWithMessages(ctx, query) } if q, ok := inst.(agentQueryContextRunner); ok { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) defer cancel() - return q.QueryContext(ctx, query) + data, fields, err := q.QueryContext(ctx, query) + return data, fields, nil, err } - return inst.Query(query) + if q, ok := inst.(agentQueryMessageRunner); ok { + return q.QueryWithMessages(query) + } + data, fields, err := inst.Query(query) + return data, fields, nil, err +} + +func queryWithOptionalTimeout(inst agentQueryRunner, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) { + data, fields, _, err := queryWithMessagesOptionalTimeout(inst, query, timeoutMs) + return data, fields, err } func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, error) { @@ -613,6 +680,74 @@ func queryStatementWithOptionalTimeout(inst db.StatementExecer, query string, ti return queryWithOptionalTimeout(queryRunner, query, timeoutMs) } +func queryStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]map[string]interface{}, []string, []string, error) { + queryRunner, ok := inst.(agentQueryRunner) + if !ok { + return nil, nil, nil, fmt.Errorf("当前事务会话不支持查询语句") + } + return queryWithMessagesOptionalTimeout(queryRunner, query, timeoutMs) +} + +func queryMultiWithMessagesOptionalTimeout(inst db.Database, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) { + effectiveTimeoutMs := timeoutMs + if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { + effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond) + } + if effectiveTimeoutMs > 0 { + if q, ok := inst.(agentMultiResultMessageContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + data, messages, err := q.QueryMultiContextWithMessages(ctx, query) + return data, messages, true, err + } + if q, ok := inst.(agentMultiResultContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + data, err := q.QueryMultiContext(ctx, query) + return data, nil, true, err + } + } + if q, ok := inst.(agentMultiResultMessageRunner); ok { + data, messages, err := q.QueryMultiWithMessages(query) + return data, messages, true, err + } + if q, ok := inst.(agentMultiResultRunner); ok { + data, err := q.QueryMulti(query) + return data, nil, true, err + } + return nil, nil, false, nil +} + +func queryMultiStatementWithMessagesOptionalTimeout(inst db.StatementExecer, query string, timeoutMs int64) ([]connection.ResultSetData, []string, bool, error) { + effectiveTimeoutMs := timeoutMs + if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { + effectiveTimeoutMs = int64(legacyClickHouseDefaultTimeout / time.Millisecond) + } + if effectiveTimeoutMs > 0 { + if q, ok := inst.(agentMultiResultMessageContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + data, messages, err := q.QueryMultiContextWithMessages(ctx, query) + return data, messages, true, err + } + if q, ok := inst.(agentMultiResultContextRunner); ok { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(effectiveTimeoutMs)*time.Millisecond) + defer cancel() + data, err := q.QueryMultiContext(ctx, query) + return data, nil, true, err + } + } + if q, ok := inst.(agentMultiResultMessageRunner); ok { + data, messages, err := q.QueryMultiWithMessages(query) + return data, messages, true, err + } + if q, ok := inst.(agentMultiResultRunner); ok { + data, err := q.QueryMulti(query) + return data, nil, true, err + } + return nil, nil, false, nil +} + func streamWithOptionalTimeout(inst db.StreamQueryExecer, query string, timeoutMs int64, consumer db.QueryStreamConsumer) error { effectiveTimeoutMs := timeoutMs if effectiveTimeoutMs <= 0 && strings.EqualFold(strings.TrimSpace(agentDriverType), "clickhouse") { diff --git a/cmd/optional-driver-agent/main_test.go b/cmd/optional-driver-agent/main_test.go index eac28d6..5db34e9 100644 --- a/cmd/optional-driver-agent/main_test.go +++ b/cmd/optional-driver-agent/main_test.go @@ -101,6 +101,9 @@ type fakeAgentTimeoutDB struct { execCalled bool execContextCalled bool deadlineSet bool + queryMessages []string + multiResults []connection.ResultSetData + multiMessages []string } func (f *fakeAgentTimeoutDB) Connect(config connection.ConnectionConfig) error { return nil } @@ -117,6 +120,14 @@ func (f *fakeAgentTimeoutDB) QueryContext(ctx context.Context, query string) ([] } return []map[string]interface{}{{"ok": 1}}, []string{"ok"}, nil } +func (f *fakeAgentTimeoutDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + data, fields, err := f.QueryContext(context.Background(), query) + return data, fields, append([]string(nil), f.queryMessages...), err +} +func (f *fakeAgentTimeoutDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + data, fields, err := f.QueryContext(ctx, query) + return data, fields, append([]string(nil), f.queryMessages...), err +} func (f *fakeAgentTimeoutDB) Exec(query string) (int64, error) { f.execCalled = true return 0, errors.New("exec should not be called") @@ -150,6 +161,15 @@ func (f *fakeAgentTimeoutDB) GetForeignKeys(dbName, tableName string) ([]connect func (f *fakeAgentTimeoutDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { return nil, nil } +func (f *fakeAgentTimeoutDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) { + return append([]connection.ResultSetData(nil), f.multiResults...), append([]string(nil), f.multiMessages...), nil +} +func (f *fakeAgentTimeoutDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) { + if _, ok := ctx.Deadline(); ok { + f.deadlineSet = true + } + return f.QueryMultiWithMessages(query) +} type fakeAgentSessionDB struct { fakeAgentTimeoutDB @@ -165,6 +185,7 @@ type fakeAgentStatementSession struct { queryCalls int execCalls int closed bool + messages []string } func (f *fakeAgentStatementSession) Query(query string) ([]map[string]interface{}, []string, error) { @@ -175,6 +196,14 @@ func (f *fakeAgentStatementSession) QueryContext(ctx context.Context, query stri f.queryCalls++ return []map[string]interface{}{{"session_ok": 1}}, []string{"session_ok"}, nil } +func (f *fakeAgentStatementSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + data, fields, err := f.QueryContext(context.Background(), query) + return data, fields, append([]string(nil), f.messages...), err +} +func (f *fakeAgentStatementSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + data, fields, err := f.QueryContext(ctx, query) + return data, fields, append([]string(nil), f.messages...), err +} func (f *fakeAgentStatementSession) Exec(query string) (int64, error) { return f.ExecContext(context.Background(), query) @@ -297,6 +326,77 @@ func TestQueryWithOptionalTimeout_ClickHouseLegacyModeUsesQueryContext(t *testin } } +func TestHandleRequest_QueryIncludesServerMessages(t *testing.T) { + old := agentDriverType + defer func() { agentDriverType = old }() + agentDriverType = "sqlserver" + + fake := &fakeAgentTimeoutDB{ + queryMessages: []string{"PRINT sql line 1", "PRINT sql line 2"}, + } + runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)} + + resp := handleRequest(runtimeState, agentRequest{ + ID: 11, + Method: agentMethodQuery, + Query: "exec dbo.p_get_select", + TimeoutMs: int64((2 * time.Second).Milliseconds()), + }) + if !resp.Success { + t.Fatalf("query request failed: %s", resp.Error) + } + if len(resp.Messages) != 2 || resp.Messages[0] != "PRINT sql line 1" { + t.Fatalf("expected query messages to be preserved, got %#v", resp.Messages) + } +} + +func TestHandleRequest_QueryMultiIncludesResultSetsAndMessages(t *testing.T) { + old := agentDriverType + defer func() { agentDriverType = old }() + agentDriverType = "sqlserver" + + fake := &fakeAgentTimeoutDB{ + multiResults: []connection.ResultSetData{ + { + StatementIndex: 1, + Rows: []map[string]interface{}{{"name": "master"}}, + Columns: []string{"name"}, + }, + { + StatementIndex: 1, + Rows: []map[string]interface{}{}, + Columns: []string{}, + Messages: []string{"PRINT generated sql"}, + }, + }, + multiMessages: []string{"batch top-level message"}, + } + runtimeState := &agentRuntime{inst: fake, sessions: make(map[string]db.StatementExecer)} + + resp := handleRequest(runtimeState, agentRequest{ + ID: 12, + Method: agentMethodQueryMulti, + Query: "exec dbo.p_get_select", + TimeoutMs: int64((2 * time.Second).Milliseconds()), + }) + if !resp.Success { + t.Fatalf("queryMulti request failed: %s", resp.Error) + } + if len(resp.Messages) != 1 || resp.Messages[0] != "batch top-level message" { + t.Fatalf("expected top-level messages to be preserved, got %#v", resp.Messages) + } + resultSets, ok := resp.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", resp.Data) + } + if len(resultSets) != 2 { + t.Fatalf("expected 2 result sets, got %#v", resultSets) + } + if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" { + t.Fatalf("expected message-only result set to be preserved, got %#v", resultSets[1]) + } +} + func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing.T) { old := agentDriverType defer func() { agentDriverType = old }() @@ -329,6 +429,9 @@ func TestHandleRequest_UsesPinnedSessionForSessionScopedQueryAndExec(t *testing. if !queryResp.Success { t.Fatalf("session query failed: %s", queryResp.Error) } + if len(queryResp.Messages) != 0 { + t.Fatalf("expected empty default session messages, got %#v", queryResp.Messages) + } if fake.queryCalled || fake.queryContextCalled { t.Fatalf("expected session query to bypass database-level query path, got Query=%v QueryContext=%v", fake.queryCalled, fake.queryContextCalled) } diff --git a/internal/db/driver_agent_revisions_gen.go b/internal/db/driver_agent_revisions_gen.go index e05b8c9..12e7049 100644 --- a/internal/db/driver_agent_revisions_gen.go +++ b/internal/db/driver_agent_revisions_gen.go @@ -4,26 +4,26 @@ package db func init() { optionalDriverAgentRevisions = map[string]string{ - "mariadb": "src-b23e2ce1581a5064", - "oceanbase": "src-5067dbdf0ca7b9c4", - "diros": "src-db43faca6bf15d9b", - "starrocks": "src-01e9f06c0fab09d5", - "sphinx": "src-38ee5cae952cc809", - "sqlserver": "src-7a87f6deb816f110", - "sqlite": "src-d3d439cd788880e2", - "duckdb": "src-b11506b8706bfb73", - "dameng": "src-1638124bfd7fce09", - "kingbase": "src-fb3a404cf4eb1bd9", - "highgo": "src-72fe51afa884f6bc", - "vastbase": "src-3d48607603bfd8b7", - "opengauss": "src-709acf442f016e30", - "gaussdb": "src-f6beccc924d71031", - "iris": "src-9ebf5b970a73b341", - "mongodb": "src-367d11cd04e982c1", - "tdengine": "src-3c13c42f18ba01e1", - "iotdb": "src-5ba9da13c6a272f9", - "clickhouse": "src-99c8babfefdf142c", - "elasticsearch": "src-36b2e2b5f49db9d1", - "trino": "src-d264ceca132c185c", + "mariadb": "src-cc133d2524ceb634", + "oceanbase": "src-ac17327184366ff0", + "diros": "src-7d4fe439271d0c56", + "starrocks": "src-ce9ee22641a32f46", + "sphinx": "src-08f5ae54efb3d9df", + "sqlserver": "src-33b3b2c6dad5b3e6", + "sqlite": "src-96dfa25b3042b2d5", + "duckdb": "src-8804eb2cdbc89433", + "dameng": "src-016e77082aea6718", + "kingbase": "src-17728b2ebda94dc9", + "highgo": "src-da2e8a9d2e661d3b", + "vastbase": "src-da186ac367206c16", + "opengauss": "src-54dc852e4c502947", + "gaussdb": "src-3bbbffc6991dc8ae", + "iris": "src-e798713e492e9a09", + "mongodb": "src-2610395b35c2e708", + "tdengine": "src-779b9b537f08856f", + "iotdb": "src-7edea4aba8d4869e", + "clickhouse": "src-0197342ca5afa8b5", + "elasticsearch": "src-08e8e80cb17a409a", + "trino": "src-ba947f211ce7b19f", } } diff --git a/internal/db/optional_driver_agent_impl.go b/internal/db/optional_driver_agent_impl.go index 3ec52b8..07b0762 100644 --- a/internal/db/optional_driver_agent_impl.go +++ b/internal/db/optional_driver_agent_impl.go @@ -29,6 +29,7 @@ const ( optionalAgentMethodOpenSession = "openSession" optionalAgentMethodCloseSession = "closeSession" optionalAgentMethodQuery = "query" + optionalAgentMethodQueryMulti = "queryMulti" optionalAgentMethodStreamQuery = "streamQuery" optionalAgentMethodExec = "exec" optionalAgentMethodGetDatabases = "getDatabases" @@ -75,6 +76,7 @@ type optionalAgentResponse struct { Error string `json:"error,omitempty"` Data json.RawMessage `json:"data,omitempty"` Fields []string `json:"fields,omitempty"` + Messages []string `json:"messages,omitempty"` ChunkType string `json:"chunkType,omitempty"` RowsAffected int64 `json:"rowsAffected,omitempty"` } @@ -106,7 +108,7 @@ func ProbeOptionalDriverAgentMetadata(driverType string, executablePath string) }() var metadata OptionalDriverAgentMetadata - if err := client.callWithTimeout(optionalAgentRequest{Method: optionalAgentMethodMetadata}, &metadata, nil, nil, optionalAgentMetadataProbeTimeout); err != nil { + if err := client.callWithTimeout(optionalAgentRequest{Method: optionalAgentMethodMetadata}, &metadata, nil, nil, nil, optionalAgentMetadataProbeTimeout); err != nil { return OptionalDriverAgentMetadata{}, err } metadata.DriverType = normalizeRuntimeDriverType(metadata.DriverType) @@ -208,7 +210,7 @@ func (c *optionalDriverAgentClient) stderrText() string { return strings.TrimSpace(c.stderr.String()) } -func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface{}, fields *[]string, rowsAffected *int64) error { +func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface{}, fields *[]string, messages *[]string, rowsAffected *int64) error { c.mu.Lock() defer c.mu.Unlock() @@ -252,6 +254,9 @@ func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface if fields != nil { *fields = resp.Fields } + if messages != nil { + *messages = append((*messages)[:0], resp.Messages...) + } if rowsAffected != nil { *rowsAffected = resp.RowsAffected } @@ -263,14 +268,14 @@ func (c *optionalDriverAgentClient) call(req optionalAgentRequest, out interface return nil } -func (c *optionalDriverAgentClient) callWithTimeout(req optionalAgentRequest, out interface{}, fields *[]string, rowsAffected *int64, timeout time.Duration) error { +func (c *optionalDriverAgentClient) callWithTimeout(req optionalAgentRequest, out interface{}, fields *[]string, messages *[]string, rowsAffected *int64, timeout time.Duration) error { if timeout <= 0 { - return c.call(req, out, fields, rowsAffected) + return c.call(req, out, fields, messages, rowsAffected) } errCh := make(chan error, 1) go func() { - errCh <- c.call(req, out, fields, rowsAffected) + errCh <- c.call(req, out, fields, messages, rowsAffected) }() timer := time.NewTimer(timeout) @@ -469,7 +474,7 @@ func (d *OptionalDriverAgentDB) Connect(config connection.ConnectionConfig) erro if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodConnect, Config: &config, - }, nil, nil, nil); err != nil { + }, nil, nil, nil, nil); err != nil { _ = client.close() return err } @@ -482,7 +487,7 @@ func (d *OptionalDriverAgentDB) Close() error { if d.client == nil { return nil } - _ = d.client.call(optionalAgentRequest{Method: optionalAgentMethodClose}, nil, nil, nil) + _ = d.client.call(optionalAgentRequest{Method: optionalAgentMethodClose}, nil, nil, nil, nil) err := d.client.close() d.client = nil return err @@ -493,10 +498,87 @@ func (d *OptionalDriverAgentDB) Ping() error { if err != nil { return err } - return client.call(optionalAgentRequest{Method: optionalAgentMethodPing}, nil, nil, nil) + return client.call(optionalAgentRequest{Method: optionalAgentMethodPing}, nil, nil, nil, nil) } func (d *OptionalDriverAgentDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + data, fields, _, err := d.QueryContextWithMessages(ctx, query) + return data, fields, err +} + +func (d *OptionalDriverAgentDB) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { + if err := ctx.Err(); err != nil { + return nil, nil, nil, err + } + client, err := d.requireClient() + if err != nil { + return nil, nil, nil, err + } + var data []map[string]interface{} + var fields []string + var messages []string + if err := client.call(optionalAgentRequest{ + Method: optionalAgentMethodQuery, + Query: query, + TimeoutMs: timeoutMsFromContext(ctx), + }, &data, &fields, &messages, nil); err != nil { + return nil, nil, nil, err + } + return data, fields, messages, nil +} + +func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, []string, error) { + data, fields, _, err := d.QueryWithMessages(query) + return data, fields, err +} + +func (d *OptionalDriverAgentDB) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + client, err := d.requireClient() + if err != nil { + return nil, nil, nil, err + } + var data []map[string]interface{} + var fields []string + var messages []string + if err := client.call(optionalAgentRequest{ + Method: optionalAgentMethodQuery, + Query: query, + }, &data, &fields, &messages, nil); err != nil { + return nil, nil, nil, err + } + return data, fields, messages, nil +} + +func (d *OptionalDriverAgentDB) QueryMulti(query string) ([]connection.ResultSetData, error) { + results, _, err := d.QueryMultiWithMessages(query) + return results, err +} + +func (d *OptionalDriverAgentDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) { + client, err := d.requireClient() + if err != nil { + return nil, nil, err + } + var results []connection.ResultSetData + var messages []string + if err := client.call(optionalAgentRequest{ + Method: optionalAgentMethodQueryMulti, + Query: query, + }, &results, nil, &messages, nil); err != nil { + if isOptionalAgentMultiResultUnsupportedError(err) { + return nil, nil, nil + } + return nil, nil, err + } + return results, messages, nil +} + +func (d *OptionalDriverAgentDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + results, _, err := d.QueryMultiContextWithMessages(ctx, query) + return results, err +} + +func (d *OptionalDriverAgentDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) { if err := ctx.Err(); err != nil { return nil, nil, err } @@ -504,32 +586,19 @@ func (d *OptionalDriverAgentDB) QueryContext(ctx context.Context, query string) if err != nil { return nil, nil, err } - var data []map[string]interface{} - var fields []string + var results []connection.ResultSetData + var messages []string if err := client.call(optionalAgentRequest{ - Method: optionalAgentMethodQuery, + Method: optionalAgentMethodQueryMulti, Query: query, TimeoutMs: timeoutMsFromContext(ctx), - }, &data, &fields, nil); err != nil { + }, &results, nil, &messages, nil); err != nil { + if isOptionalAgentMultiResultUnsupportedError(err) { + return nil, nil, nil + } return nil, nil, err } - return data, fields, nil -} - -func (d *OptionalDriverAgentDB) Query(query string) ([]map[string]interface{}, []string, error) { - client, err := d.requireClient() - if err != nil { - return nil, nil, err - } - var data []map[string]interface{} - var fields []string - if err := client.call(optionalAgentRequest{ - Method: optionalAgentMethodQuery, - Query: query, - }, &data, &fields, nil); err != nil { - return nil, nil, err - } - return data, fields, nil + return results, messages, nil } func (d *OptionalDriverAgentDB) StreamQuery(query string, consumer QueryStreamConsumer) error { @@ -581,7 +650,7 @@ func (d *OptionalDriverAgentDB) ExecContext(ctx context.Context, query string) ( Method: optionalAgentMethodExec, Query: query, TimeoutMs: timeoutMsFromContext(ctx), - }, nil, nil, &affected); err != nil { + }, nil, nil, nil, &affected); err != nil { return 0, err } return affected, nil @@ -596,7 +665,7 @@ func (d *OptionalDriverAgentDB) Exec(query string) (int64, error) { if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodExec, Query: query, - }, nil, nil, &affected); err != nil { + }, nil, nil, nil, &affected); err != nil { return 0, err } return affected, nil @@ -611,7 +680,7 @@ func (d *OptionalDriverAgentDB) OpenSessionExecer(ctx context.Context) (Statemen if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodOpenSession, TimeoutMs: timeoutMsFromContext(ctx), - }, &sessionID, nil, nil); err != nil { + }, &sessionID, nil, nil, nil); err != nil { return nil, err } sessionID = strings.TrimSpace(sessionID) @@ -629,6 +698,10 @@ func (s *optionalDriverAgentSession) Query(query string) ([]map[string]interface return s.QueryContext(context.Background(), query) } +func (s *optionalDriverAgentSession) QueryWithMessages(query string) ([]map[string]interface{}, []string, []string, error) { + return s.QueryContextWithMessages(context.Background(), query) +} + func (s *optionalDriverAgentSession) StreamQuery(query string, consumer QueryStreamConsumer) error { return s.StreamQueryContext(context.Background(), query, consumer) } @@ -663,20 +736,26 @@ func (s *optionalDriverAgentSession) StreamQueryContext(ctx context.Context, que } func (s *optionalDriverAgentSession) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + data, fields, _, err := s.QueryContextWithMessages(ctx, query) + return data, fields, err +} + +func (s *optionalDriverAgentSession) QueryContextWithMessages(ctx context.Context, query string) ([]map[string]interface{}, []string, []string, error) { if err := s.ensureOpen(); err != nil { - return nil, nil, err + return nil, nil, nil, err } var data []map[string]interface{} var fields []string + var messages []string if err := s.client.call(optionalAgentRequest{ Method: optionalAgentMethodQuery, SessionID: s.sessionID, Query: query, TimeoutMs: timeoutMsFromContext(ctx), - }, &data, &fields, nil); err != nil { - return nil, nil, err + }, &data, &fields, &messages, nil); err != nil { + return nil, nil, nil, err } - return data, fields, nil + return data, fields, messages, nil } func (s *optionalDriverAgentSession) Exec(query string) (int64, error) { @@ -693,7 +772,7 @@ func (s *optionalDriverAgentSession) ExecContext(ctx context.Context, query stri SessionID: s.sessionID, Query: query, TimeoutMs: timeoutMsFromContext(ctx), - }, nil, nil, &affected); err != nil { + }, nil, nil, nil, &affected); err != nil { return 0, err } return affected, nil @@ -714,7 +793,7 @@ func (s *optionalDriverAgentSession) Close() error { return s.client.call(optionalAgentRequest{ Method: optionalAgentMethodCloseSession, SessionID: sessionID, - }, nil, nil, nil) + }, nil, nil, nil, nil) } func (s *optionalDriverAgentSession) ensureOpen() error { @@ -740,6 +819,19 @@ func isOptionalAgentStreamUnsupportedError(err error) bool { return strings.Contains(text, "不支持的方法") || strings.Contains(text, "不支持流式查询") } +func isOptionalAgentMultiResultUnsupportedError(err error) bool { + if err == nil { + return false + } + text := strings.TrimSpace(err.Error()) + if text == "" { + return false + } + return strings.Contains(text, "不支持的方法") || + strings.Contains(text, "不支持原生多结果集查询") || + strings.Contains(text, "不支持多结果集查询") +} + func (d *OptionalDriverAgentDB) GetDatabases() ([]string, error) { client, err := d.requireClient() if err != nil { @@ -748,7 +840,7 @@ func (d *OptionalDriverAgentDB) GetDatabases() ([]string, error) { var dbs []string if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodGetDatabases, - }, &dbs, nil, nil); err != nil { + }, &dbs, nil, nil, nil); err != nil { return nil, err } return dbs, nil @@ -763,7 +855,7 @@ func (d *OptionalDriverAgentDB) GetTables(dbName string) ([]string, error) { if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodGetTables, DBName: dbName, - }, &tables, nil, nil); err != nil { + }, &tables, nil, nil, nil); err != nil { return nil, err } return tables, nil @@ -779,7 +871,7 @@ func (d *OptionalDriverAgentDB) GetCreateStatement(dbName, tableName string) (st Method: optionalAgentMethodGetCreateStmt, DBName: dbName, TableName: tableName, - }, &sqlText, nil, nil); err != nil { + }, &sqlText, nil, nil, nil); err != nil { return "", err } return sqlText, nil @@ -795,7 +887,7 @@ func (d *OptionalDriverAgentDB) GetColumns(dbName, tableName string) ([]connecti Method: optionalAgentMethodGetColumns, DBName: dbName, TableName: tableName, - }, &columns, nil, nil); err != nil { + }, &columns, nil, nil, nil); err != nil { return nil, err } return columns, nil @@ -810,7 +902,7 @@ func (d *OptionalDriverAgentDB) GetAllColumns(dbName string) ([]connection.Colum if err := client.call(optionalAgentRequest{ Method: optionalAgentMethodGetAllColumns, DBName: dbName, - }, &columns, nil, nil); err != nil { + }, &columns, nil, nil, nil); err != nil { return nil, err } return columns, nil @@ -826,7 +918,7 @@ func (d *OptionalDriverAgentDB) GetIndexes(dbName, tableName string) ([]connecti Method: optionalAgentMethodGetIndexes, DBName: dbName, TableName: tableName, - }, &indexes, nil, nil); err != nil { + }, &indexes, nil, nil, nil); err != nil { return nil, err } return indexes, nil @@ -842,7 +934,7 @@ func (d *OptionalDriverAgentDB) GetForeignKeys(dbName, tableName string) ([]conn Method: optionalAgentMethodGetForeignKeys, DBName: dbName, TableName: tableName, - }, &keys, nil, nil); err != nil { + }, &keys, nil, nil, nil); err != nil { return nil, err } return keys, nil @@ -858,7 +950,7 @@ func (d *OptionalDriverAgentDB) GetTriggers(dbName, tableName string) ([]connect Method: optionalAgentMethodGetTriggers, DBName: dbName, TableName: tableName, - }, &triggers, nil, nil); err != nil { + }, &triggers, nil, nil, nil); err != nil { return nil, err } return triggers, nil @@ -883,7 +975,7 @@ func (d *OptionalDriverAgentDB) ApplyChanges(tableName string, changes connectio Method: optionalAgentMethodApplyChanges, TableName: tableName, Changes: &changes, - }, nil, nil, nil) + }, nil, nil, nil, nil) } func (d *OptionalDriverAgentDB) requireClient() (*optionalDriverAgentClient, error) { diff --git a/internal/db/optional_driver_agent_impl_test.go b/internal/db/optional_driver_agent_impl_test.go index 90600db..e1612ac 100644 --- a/internal/db/optional_driver_agent_impl_test.go +++ b/internal/db/optional_driver_agent_impl_test.go @@ -136,3 +136,68 @@ func TestOptionalDriverAgentClientCallStreamQueryConsumesChunks(t *testing.T) { t.Fatalf("请求未使用 streamQuery 方法: %s", stdin.String()) } } + +func TestOptionalDriverAgentDBQueryWithMessagesParsesAgentMessages(t *testing.T) { + var stdin optionalAgentTestWriteCloser + stdout := `{"id":1,"success":true,"data":[{"sql_text":"select 1"}],"fields":["sql_text"],"messages":["PRINT sql line 1","PRINT sql line 2"]}` + "\n" + + dbInst := &OptionalDriverAgentDB{ + driverType: "sqlserver", + client: &optionalDriverAgentClient{ + stdin: &stdin, + reader: bufio.NewReader(strings.NewReader(stdout)), + driver: "sqlserver", + }, + } + + rows, fields, messages, err := dbInst.QueryWithMessages("exec dbo.p_get_select") + if err != nil { + t.Fatalf("QueryWithMessages 返回错误: %v", err) + } + if len(rows) != 1 || rows[0]["sql_text"] != "select 1" { + t.Fatalf("查询结果异常: %#v", rows) + } + if len(fields) != 1 || fields[0] != "sql_text" { + t.Fatalf("字段异常: %#v", fields) + } + if len(messages) != 2 || messages[0] != "PRINT sql line 1" { + t.Fatalf("消息异常: %#v", messages) + } + if !strings.Contains(stdin.String(), `"method":"query"`) { + t.Fatalf("请求未使用 query 方法: %s", stdin.String()) + } +} + +func TestOptionalDriverAgentDBQueryMultiWithMessagesParsesResultSets(t *testing.T) { + var stdin optionalAgentTestWriteCloser + stdout := `{"id":1,"success":true,"data":[{"statementIndex":1,"rows":[{"name":"master"}],"columns":["name"]},{"statementIndex":1,"rows":[],"columns":[],"messages":["PRINT generated sql"]}],"messages":["batch top-level message"]}` + "\n" + + dbInst := &OptionalDriverAgentDB{ + driverType: "sqlserver", + client: &optionalDriverAgentClient{ + stdin: &stdin, + reader: bufio.NewReader(strings.NewReader(stdout)), + driver: "sqlserver", + }, + } + + resultSets, messages, err := dbInst.QueryMultiWithMessages("exec dbo.p_get_select") + if err != nil { + t.Fatalf("QueryMultiWithMessages 返回错误: %v", err) + } + if len(resultSets) != 2 { + t.Fatalf("结果集数量异常: %#v", resultSets) + } + if got := resultSets[0].Rows[0]["name"]; got != "master" { + t.Fatalf("首个结果集异常,got=%v", got) + } + if len(resultSets[1].Messages) != 1 || resultSets[1].Messages[0] != "PRINT generated sql" { + t.Fatalf("消息结果集异常: %#v", resultSets[1]) + } + if len(messages) != 1 || messages[0] != "batch top-level message" { + t.Fatalf("顶层消息异常: %#v", messages) + } + if !strings.Contains(stdin.String(), `"method":"queryMulti"`) { + t.Fatalf("请求未使用 queryMulti 方法: %s", stdin.String()) + } +} From e8cad189be8ed55713633a7b42dee0b6a049052d Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 23 Jun 2026 09:46:44 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=F0=9F=90=9B=20fix(sqlserver):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E6=99=AE=E9=80=9A=E6=9F=A5=E8=AF=A2=E7=BB=93=E6=9E=9C?= =?UTF-8?q?=E8=A2=AB=E5=8E=9F=E7=94=9F=E5=A4=9A=E7=BB=93=E6=9E=9C=E9=9B=86?= =?UTF-8?q?=E5=90=83=E7=A9=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 对只读 SQL 的原生多结果集空返回增加顺序回退兜底 - 避免 optional driver-agent 成功返回空结果时前端只剩日志无结果集 - 补充 SQLServer 读查询空结果回退回归测试 --- internal/app/methods_db.go | 7 +++ internal/app/methods_db_multi_test.go | 78 +++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 6c79970..ebb27d5 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -1000,6 +1000,13 @@ func (a *App) DBQueryMulti(config connection.ConnectionConfig, dbName string, qu return connection.QueryResult{Success: false, Message: err.Error(), QueryID: queryID} } + // 某些 optional driver-agent 的原生多结果集路径会异常返回“成功但无任何结果集”。 + // 对只读查询这是不可信信号,回退到逐条执行可以避免普通 SELECT 在结果面板中被吃空。 + if useNativeMultiResult && allReadOnly && results != nil && len(results) == 0 && len(resultMessages) == 0 { + logger.Warnf("DBQueryMulti 原生多结果集返回空结果,将回退逐条执行:%s SQL片段=%q", formatConnSummary(runConfig), sqlSnippet(query)) + results = nil + } + // 驱动支持多结果集,直接返回 if results != nil { return connection.QueryResult{Success: true, Data: results, Messages: resultMessages, QueryID: queryID} diff --git a/internal/app/methods_db_multi_test.go b/internal/app/methods_db_multi_test.go index 7bc13a5..0b4a382 100644 --- a/internal/app/methods_db_multi_test.go +++ b/internal/app/methods_db_multi_test.go @@ -34,6 +34,11 @@ type fakeNativeMultiResultDB struct { multiCalls int } +type fakeEmptyNativeMultiResultDB struct { + *fakeBatchWriteDB + multiCalls int +} + func (f *fakeNativeMultiResultDB) QueryMulti(query string) ([]connection.ResultSetData, error) { results, _, err := f.QueryMultiWithMessages(query) return results, err @@ -67,6 +72,28 @@ func (f *fakeNativeMultiResultDB) QueryMultiContextWithMessages(ctx context.Cont }}, append([]string(nil), messages...), nil } +func (f *fakeEmptyNativeMultiResultDB) QueryMulti(query string) ([]connection.ResultSetData, error) { + results, _, err := f.QueryMultiWithMessages(query) + return results, err +} + +func (f *fakeEmptyNativeMultiResultDB) QueryMultiWithMessages(query string) ([]connection.ResultSetData, []string, error) { + return f.QueryMultiContextWithMessages(context.Background(), query) +} + +func (f *fakeEmptyNativeMultiResultDB) QueryMultiContext(ctx context.Context, query string) ([]connection.ResultSetData, error) { + results, _, err := f.QueryMultiContextWithMessages(ctx, query) + return results, err +} + +func (f *fakeEmptyNativeMultiResultDB) QueryMultiContextWithMessages(ctx context.Context, query string) ([]connection.ResultSetData, []string, error) { + f.multiCalls++ + if err := f.queryErr[query]; err != nil { + return nil, nil, err + } + return []connection.ResultSetData{}, nil, nil +} + func (f *fakeBatchWriteDB) Connect(config connection.ConnectionConfig) error { return nil } @@ -1332,6 +1359,57 @@ func TestDBQueryMultiRunsSQLServerStatisticsBatchNatively(t *testing.T) { } } +func TestDBQueryMultiFallsBackWhenNativeReadOnlyBatchReturnsEmptyResults(t *testing.T) { + originalNewDatabaseFunc := newDatabaseFunc + t.Cleanup(func() { + newDatabaseFunc = originalNewDatabaseFunc + }) + + query := "SELECT 1 AS value" + baseDB := &fakeBatchWriteDB{ + queryMap: map[string][]map[string]interface{}{ + query: { + {"value": 1}, + }, + }, + fieldMap: map[string][]string{ + query: {"value"}, + }, + queryErr: map[string]error{}, + } + fakeDB := &fakeEmptyNativeMultiResultDB{fakeBatchWriteDB: baseDB} + newDatabaseFunc = func(dbType string) (db.Database, error) { + return fakeDB, nil + } + + app := NewAppWithSecretStore(secretstore.NewUnavailableStore("test")) + config := connection.ConnectionConfig{Type: "sqlserver", Host: "127.0.0.1", Port: 1433, User: "sa"} + + result := app.DBQueryMulti(config, "master", query, "sqlserver-empty-native-read-fallback-test") + if !result.Success { + t.Fatalf("expected DBQueryMulti success, got failure: %s", result.Message) + } + if fakeDB.multiCalls != 1 { + t.Fatalf("expected one native multi-result attempt, got %d", fakeDB.multiCalls) + } + if baseDB.session == nil { + t.Fatal("expected empty native result to fall back to pinned session query") + } + if baseDB.session.queryCalls != 1 { + t.Fatalf("expected fallback to query through pinned session once, got %d", baseDB.session.queryCalls) + } + resultSets, ok := result.Data.([]connection.ResultSetData) + if !ok { + t.Fatalf("expected []connection.ResultSetData, got %T", result.Data) + } + if len(resultSets) != 1 { + t.Fatalf("expected one fallback result set, got %#v", resultSets) + } + if got := resultSets[0].Rows[0]["value"]; got != 1 { + t.Fatalf("expected fallback SELECT result value=1, got %#v", got) + } +} + func TestDBQueryMultiUsesPinnedSessionForSequentialFallback(t *testing.T) { originalNewDatabaseFunc := newDatabaseFunc t.Cleanup(func() { From 3a00ae1f441f64d7ffdd8dff4610ca24f3259628 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 23 Jun 2026 09:55:41 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=F0=9F=94=A7=20chore(wails):=20=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E6=95=B0=E6=8D=AE=E5=90=8C=E6=AD=A5=20targetSchema=20?= =?UTF-8?q?TS=20=E7=BB=91=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 sync 请求模型补充 targetSchema 字段映射 - 同步自动生成的 frontend wailsjs models.ts 绑定 - 更新 package.json.md5 生成校验文件 --- frontend/package.json.md5 | 2 +- frontend/wailsjs/go/models.ts | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/frontend/package.json.md5 b/frontend/package.json.md5 index 72cd156..50f303f 100755 --- a/frontend/package.json.md5 +++ b/frontend/package.json.md5 @@ -1 +1 @@ -1d8f9adbde8018f90d013cc740e0405b \ No newline at end of file +84ec3a6d42105c92f224232f0d83a33b \ No newline at end of file diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index b870e6a..113c599 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -1401,6 +1401,7 @@ export namespace sync { targetConfig: connection.ConnectionConfig; sourceDatabase?: string; targetDatabase?: string; + targetSchema?: string; tables: string[]; sourceQuery?: string; content?: string; @@ -1422,6 +1423,7 @@ export namespace sync { this.targetConfig = this.convertValues(source["targetConfig"], connection.ConnectionConfig); this.sourceDatabase = source["sourceDatabase"]; this.targetDatabase = source["targetDatabase"]; + this.targetSchema = source["targetSchema"]; this.tables = source["tables"]; this.sourceQuery = source["sourceQuery"]; this.content = source["content"]; From bc63311003a4f51161a7e0eae180d558b1123793 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 23 Jun 2026 10:46:35 +0800 Subject: [PATCH 5/6] =?UTF-8?q?=F0=9F=90=9B=20fix(oracle):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E6=99=AE=E9=80=9A=E6=9F=A5=E8=AF=A2=E9=87=8D=E5=A4=8D?= =?UTF-8?q?=E5=88=97=E8=87=AA=E5=8A=A8=E5=88=AB=E5=90=8D=E7=BC=BA=E5=A4=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 QueryEditor 查询计划阶段识别显式列与 alias.* 的重复列冲突 - Oracle 执行前自动为冲突显式列补充 _1 风格唯一别名 - 让 locator 与后续追加表达式复用改写后的可执行 SQL - 补充普通查询重复列自动别名的 Oracle 回归测试 --- .../QueryEditor.results-and-drop.test.tsx | 56 +++++++++++++ .../queryEditor/QueryEditorHelpers.ts | 80 +++++++++++++++++-- 2 files changed, 130 insertions(+), 6 deletions(-) diff --git a/frontend/src/components/QueryEditor.results-and-drop.test.tsx b/frontend/src/components/QueryEditor.results-and-drop.test.tsx index 313d9f2..fadc7aa 100644 --- a/frontend/src/components/QueryEditor.results-and-drop.test.tsx +++ b/frontend/src/components/QueryEditor.results-and-drop.test.tsx @@ -2878,6 +2878,62 @@ describe('QueryEditor external SQL save', () => { expect(messageApi.warning).not.toHaveBeenCalled(); }); + it('auto aliases Oracle duplicate explicit columns before alias star expansion', async () => { + storeState.connections[0].config.type = 'oracle'; + storeState.connections[0].config.database = 'APP'; + backendApp.DBQueryMulti.mockResolvedValueOnce({ + success: true, + data: [{ + columns: ['EHR_USERID_1', 'USERID', 'EHR_USERID', 'USERNAME'], + rows: [{ + EHR_USERID_1: 'emp-1', + USERID: 7, + EHR_USERID: 'emp-1', + USERNAME: 'alice', + }], + }], + }); + backendApp.DBGetColumns.mockResolvedValueOnce({ + success: true, + data: [ + { name: 'USERID', key: 'PRI' }, + { name: 'EHR_USERID', key: '' }, + { name: 'USERNAME', key: '' }, + ], + }); + + let renderer: ReactTestRenderer; + await act(async () => { + renderer = create(); + }); + + await act(async () => { + await findButton(renderer!, '运行').props.onClick(); + }); + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + + expect(dataGridState.latestProps?.readOnly).toBe(false); + expect(dataGridState.latestProps?.editLocator).toMatchObject({ + strategy: 'primary-key', + columns: ['USERID'], + valueColumns: ['USERID'], + writableColumns: { + USERID: 'USERID', + EHR_USERID: 'EHR_USERID', + USERNAME: 'USERNAME', + }, + readOnly: false, + }); + expect(String(backendApp.DBQueryMulti.mock.calls[0][2])).toContain('EHR_USERID AS EHR_USERID_1, a.*'); + expect(messageApi.warning).not.toHaveBeenCalled(); + }); + it.each([ 'mysql', 'mariadb', diff --git a/frontend/src/components/queryEditor/QueryEditorHelpers.ts b/frontend/src/components/queryEditor/QueryEditorHelpers.ts index a0c4444..817aa57 100644 --- a/frontend/src/components/queryEditor/QueryEditorHelpers.ts +++ b/frontend/src/components/queryEditor/QueryEditorHelpers.ts @@ -210,7 +210,13 @@ export const getLastIdentifierPart = (path: string): string => { return parts[parts.length - 1] || ''; }; -export const resolveSimpleSelectItemColumn = (item: string): { resultName: string; sourceName: string } | 'all' | undefined => { +export type SelectItemInfo = { + expression: string; + resultName: string; + sourceName?: string; +}; + +export const resolveSelectItemInfo = (item: string): SelectItemInfo | 'all' | undefined => { const text = String(item || '').trim(); if (!text) return undefined; if (text === '*' || /\.\s*\*$/.test(text)) return 'all'; @@ -232,10 +238,16 @@ export const resolveSimpleSelectItemColumn = (item: string): { resultName: strin } } - if (!SIMPLE_IDENTIFIER_PATH_RE.test(expr)) return undefined; - const sourceName = getLastIdentifierPart(expr); + if (!alias && !SIMPLE_IDENTIFIER_PATH_RE.test(expr)) return undefined; + const sourceName = SIMPLE_IDENTIFIER_PATH_RE.test(expr) ? getLastIdentifierPart(expr) : ''; const resultName = alias || sourceName; - return sourceName && resultName ? { resultName, sourceName } : undefined; + return resultName ? { expression: expr, resultName, sourceName: sourceName || undefined } : undefined; +}; + +export const resolveSimpleSelectItemColumn = (item: string): { resultName: string; sourceName: string } | 'all' | undefined => { + const resolved = resolveSelectItemInfo(item); + if (!resolved || resolved === 'all' || !resolved.sourceName) return resolved === 'all' ? 'all' : undefined; + return { resultName: resolved.resultName, sourceName: resolved.sourceName }; }; export const parseSimpleSelectInfo = (sql: string): SimpleSelectInfo | undefined => { @@ -354,6 +366,57 @@ export const rewriteOracleSelectAllWithExpressions = (sql: string, expressions: return `${prefix}${finalSelectItems.join(', ')}${fromKeyword}${tableText}${aliasClause}${parsedAlias.remainder}`; }; +export const rewriteOracleDuplicateSelectColumns = (sql: string, tableColumnNames: string[]): string | undefined => { + const metadataNames = new Set( + tableColumnNames + .map((name) => String(name || '').trim().toLowerCase()) + .filter(Boolean), + ); + if (metadataNames.size === 0) return undefined; + + const match = String(sql || '').match(/^(\s*SELECT\s+)([\s\S]+?)(\s+FROM\s+[\s\S]*)$/i); + if (!match) return undefined; + + const prefix = match[1]; + const selectList = match[2].trim(); + const rest = match[3]; + const selectItems = splitTopLevelComma(selectList); + if (selectItems.length === 0) return undefined; + + const parsedItems = selectItems.map((item) => ({ + raw: String(item || '').trimEnd(), + info: resolveSelectItemInfo(item), + })); + const hasWildcard = parsedItems.some(({ info }) => info === 'all'); + if (!hasWildcard) return undefined; + + const usedResultNames = new Set(metadataNames); + parsedItems.forEach(({ info }) => { + if (!info || info === 'all') return; + const normalizedResult = String(info.resultName || '').trim().toLowerCase(); + if (normalizedResult) usedResultNames.add(normalizedResult); + }); + + let changed = false; + const rewrittenItems = parsedItems.map(({ raw, info }) => { + if (!info || info === 'all') return raw; + const normalizedResult = String(info.resultName || '').trim().toLowerCase(); + if (!metadataNames.has(normalizedResult)) return raw; + + let nextIndex = 1; + let alias = `${info.resultName}_${nextIndex}`; + while (usedResultNames.has(alias.toLowerCase())) { + nextIndex++; + alias = `${info.resultName}_${nextIndex}`; + } + usedResultNames.add(alias.toLowerCase()); + changed = true; + return `${info.expression} AS ${alias}`; + }); + + return changed ? `${prefix}${rewrittenItems.join(', ')}${rest}` : undefined; +}; + export const findWritableResultColumnForSource = (writableColumns: Record, target: string): string | undefined => { const normalizedTarget = String(target || '').trim().toLowerCase(); return Object.entries(writableColumns || {}).find(([, sourceColumn]) => ( @@ -1968,6 +2031,11 @@ export const resolveQueryLocatorPlan = async ({ const tableColumns = resCols.data as ColumnDefinition[]; const tableColumnNames = tableColumns.map(getColumnDefinitionName).filter(Boolean); + let executableStatement = statement; + if (isOracleLikeDialect(dbType) && selectInfo.selectsAll) { + const rewritten = rewriteOracleDuplicateSelectColumns(executableStatement, tableColumnNames); + if (rewritten) executableStatement = rewritten; + } const primaryKeys = tableColumns .filter((column: any) => getColumnDefinitionKey(column) === 'PRI') .map(getColumnDefinitionName) @@ -2058,7 +2126,7 @@ export const resolveQueryLocatorPlan = async ({ ]; if (executableAppendExpressions.length > 0 && isOracleLikeDialect(dbType) && selectInfo.selectsBareAll) { - const rewritten = rewriteOracleSelectAllWithExpressions(statement, executableAppendExpressions); + const rewritten = rewriteOracleSelectAllWithExpressions(executableStatement, executableAppendExpressions); if (rewritten) { plan.executedSql = rewritten; return plan; @@ -2070,7 +2138,7 @@ export const resolveQueryLocatorPlan = async ({ return plan; } - plan.executedSql = appendQuerySelectExpressions(statement, executableAppendExpressions); + plan.executedSql = appendQuerySelectExpressions(executableStatement, executableAppendExpressions); return plan; } catch { const reason = translate('query_editor.message.read_only_table_locator_metadata_unavailable', { From 8da8cc7f9146f7e750b31f11ec5e0d6d676a1b10 Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 23 Jun 2026 12:14:27 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=F0=9F=90=9B=20fix(mongodb):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20DataGrid=20=E7=BC=96=E8=BE=91=E5=90=8E=20BSON=20?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E4=B8=A2=E5=A4=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 MongoDB 结果展示、单元格编辑和行编辑接入类型感知格式化与解析 - 支持 ObjectId、日期、Int32、Int64、Double、Decimal128、UUID 等常见类型保真 - 统一 v1/v2 驱动查询结果的 Extended JSON 输出与 ApplyChanges BSON 恢复 - 补充前端提交链路与后端类型转换回归测试 --- frontend/src/components/DataGrid.ddl.test.tsx | 42 +++ frontend/src/components/DataGrid.tsx | 75 +++- frontend/src/components/DataGridCore.tsx | 8 + frontend/src/utils/mongodb.test.ts | 42 ++- frontend/src/utils/mongodb.ts | 328 +++++++++++++++++- internal/db/mongodb_impl.go | 118 ++++--- internal/db/mongodb_impl_uri_test.go | 135 +++++++ internal/db/mongodb_impl_v1.go | 118 ++++--- internal/db/mongodb_impl_v1_uri_test.go | 135 +++++++ 9 files changed, 899 insertions(+), 102 deletions(-) diff --git a/frontend/src/components/DataGrid.ddl.test.tsx b/frontend/src/components/DataGrid.ddl.test.tsx index 5f7f937..1d1961b 100644 --- a/frontend/src/components/DataGrid.ddl.test.tsx +++ b/frontend/src/components/DataGrid.ddl.test.tsx @@ -12,6 +12,7 @@ import DataGrid, { import DataGridToolbarFrame from './DataGridToolbarFrame'; import { V2CellContextMenuView, V2ColumnHeaderContextMenuView, V2TableGroupContextMenuView } from './V2TableContextMenu'; import { setCurrentLanguage, t } from '../i18n'; +import { parseMongoEditedValue } from '../utils/mongodb'; import { DUCKDB_ROWID_LOCATOR_COLUMN, ORACLE_ROWID_LOCATOR_COLUMN } from '../utils/rowLocator'; const storeState = vi.hoisted(() => ({ @@ -648,6 +649,47 @@ describe('DataGrid commit change set', () => { }); }); + it('keeps MongoDB explicit typed edit values in the final commit payload', () => { + const result = buildDataGridCommitChangeSet({ + addedRows: [{ + [GONAVI_ROW_KEY]: 'new-1', + _id: '507f1f77bcf86cd799439013', + age: '{"$numberLong":"12"}', + ratio: '1.5', + }], + modifiedRows: {}, + deletedRowKeys: new Set(), + data: [], + editLocator: { + strategy: 'primary-key', + columns: ['_id'], + valueColumns: ['_id'], + readOnly: false, + }, + visibleColumnNames: ['_id', 'age', 'ratio'], + rowKeyToString, + normalizeCommitCellValue: (columnName, value) => parseMongoEditedValue( + columnName, + value, + columnName === 'ratio' ? { $numberDouble: '0.5' } : undefined, + ), + shouldCommitColumn: commitColumnGuard, + }); + + expect(result).toEqual({ + ok: true, + changes: { + inserts: [{ + _id: { $oid: '507f1f77bcf86cd799439013' }, + age: { $numberLong: '12' }, + ratio: { $numberDouble: '1.5' }, + }], + updates: [], + deletes: [], + }, + }); + }); + it('fails closed when no safe locator is available', () => { const result = buildDataGridCommitChangeSet({ addedRows: [], diff --git a/frontend/src/components/DataGrid.tsx b/frontend/src/components/DataGrid.tsx index 38a1f04..d78dbc9 100644 --- a/frontend/src/components/DataGrid.tsx +++ b/frontend/src/components/DataGrid.tsx @@ -158,6 +158,7 @@ import { useDataGridColumnResize } from './useDataGridColumnResize'; import { useDataGridPreviewPanel } from './useDataGridPreviewPanel'; import { buildTableExportTab } from '../utils/tableExportTab'; import { buildDataGridCssText } from './dataGridStyles'; +import { formatMongoEditableValue, parseMongoEditedValue } from '../utils/mongodb'; // --- Error Boundary --- import { @@ -533,6 +534,7 @@ const DataGrid: React.FC = ({ const supportsApproximateTableCount = dataSourceCaps.supportsApproximateTableCount; const supportsApproximateTotalPages = dataSourceCaps.supportsApproximateTotalPages; const dbType = dataSourceCaps.type; + const isMongoDBConnection = dbType === 'mongodb'; const isDuckDBConnection = dataSourceCaps.type === 'duckdb'; const supportsCopyInsert = dataSourceCaps.supportsCopyInsert; const supportsSqlQueryExport = dataSourceCaps.supportsSqlQueryExport; @@ -544,6 +546,33 @@ const DataGrid: React.FC = ({ const filteredExportSql = useMemo(() => String(exportSqlWithFilter || '').trim(), [exportSqlWithFilter]); const hasFilteredExportSql = exportScope === 'table' && filteredExportSql.length > 0; + const mongoAwareEditableText = useCallback((value: any): string => ( + isMongoDBConnection ? formatMongoEditableValue(value) : toEditableText(value) + ), [isMongoDBConnection]); + + const mongoAwareFormText = useCallback((value: any): string => ( + isMongoDBConnection ? formatMongoEditableValue(value) : toFormText(value) + ), [isMongoDBConnection]); + + const normalizeMongoEditedCellValue = useCallback((columnName: string, value: any, currentValue?: any) => ( + isMongoDBConnection ? parseMongoEditedValue(columnName, value, currentValue) : value + ), [isMongoDBConnection]); + + const normalizeMongoEditedRow = useCallback((row: any, currentRow?: any) => { + if (!isMongoDBConnection || !row || typeof row !== 'object') return row; + let changed = false; + const nextRow: any = { ...row }; + Object.keys(row).forEach((columnName) => { + if (columnName === GONAVI_ROW_KEY) return; + const normalizedValue = normalizeMongoEditedCellValue(columnName, row[columnName], currentRow?.[columnName]); + if (normalizedValue !== row[columnName]) { + nextRow[columnName] = normalizedValue; + changed = true; + } + }); + return changed ? nextRow : row; + }, [isMongoDBConnection, normalizeMongoEditedCellValue]); + // --- 主题样式变量(仅在 darkMode / opacity / blur 变化时重算) --- const themeStyles = useMemo(() => { const _getBg = (darkHex: string) => { @@ -679,7 +708,7 @@ const DataGrid: React.FC = ({ openBatchEditModal, closeBatchEditModal, } = useDataGridModalEditors({ - toEditableText, + toEditableText: mongoAwareEditableText, looksLikeJsonText, }); const [virtualEditingCell, setVirtualEditingCell] = useState(null); @@ -699,7 +728,7 @@ const DataGrid: React.FC = ({ updateFocusedCell, handleDataPanelFormatJson, } = useDataGridPreviewPanel({ - toEditableText, + toEditableText: mongoAwareEditableText, looksLikeJsonText, normalizeDateTimeString, }); @@ -954,6 +983,9 @@ const DataGrid: React.FC = ({ const normalizeCommitCellValue = useCallback( (columnName: string, value: any, mode: 'insert' | 'update') => { if (value === undefined) return undefined; + if (isMongoDBConnection) { + return parseMongoEditedValue(columnName, value, undefined); + } const normalizedName = String(columnName || '').trim(); const meta = columnMetaMap[normalizedName] || columnMetaMapByLowerName[normalizedName.toLowerCase()]; const temporal = isTemporalColumnType(meta?.type, dbType); @@ -977,7 +1009,7 @@ const DataGrid: React.FC = ({ return value; }, - [columnMetaMap, columnMetaMapByLowerName, dbType] + [columnMetaMap, columnMetaMapByLowerName, dbType, isMongoDBConnection] ); const openTableByName = useCallback((nextTableName: string) => { @@ -1567,19 +1599,23 @@ const DataGrid: React.FC = ({ const keyStr = rowKeyStr(rowKey); const isAdded = addedRows.some(r => r?.[GONAVI_ROW_KEY] === rowKey); if (isAdded) { - setAddedRows(prev => prev.map(r => r?.[GONAVI_ROW_KEY] === rowKey ? { ...r, ...row } : r)); + const currentAddedRow = addedRows.find(r => r?.[GONAVI_ROW_KEY] === rowKey); + const normalizedRow = normalizeMongoEditedRow(row, currentAddedRow); + setAddedRows(prev => prev.map(r => r?.[GONAVI_ROW_KEY] === rowKey ? { ...r, ...normalizedRow } : r)); return; } if (deletedRowKeys.has(keyStr)) return; // 查找原始行数据,对比是否真正有值变更 const originalRow = data.find(r => r?.[GONAVI_ROW_KEY] === rowKey); if (originalRow) { + const currentRow = modifiedRows[keyStr] ? { ...originalRow, ...modifiedRows[keyStr] } : originalRow; + const normalizedRow = normalizeMongoEditedRow(row, currentRow); const changedFields: Record = {}; - for (const col of Object.keys(row)) { + for (const col of Object.keys(normalizedRow)) { if (col === GONAVI_ROW_KEY) continue; if (!isWritableResultColumn(col, effectiveEditLocator)) continue; - if (!isCellValueEqualForDiff(originalRow[col], row[col])) { - changedFields[col] = row[col]; + if (!isCellValueEqualForDiff(originalRow[col], normalizedRow[col])) { + changedFields[col] = normalizedRow[col]; } } if (Object.keys(changedFields).length === 0) { @@ -1609,9 +1645,9 @@ const DataGrid: React.FC = ({ } return { ...prev, [keyStr]: newCols }; }); - setModifiedRows(prev => ({ ...prev, [keyStr]: row })); + setModifiedRows(prev => ({ ...prev, [keyStr]: normalizedRow })); } - }, [addedRows, data, rowKeyStr, deletedRowKeys, effectiveEditLocator]); + }, [addedRows, data, rowKeyStr, deletedRowKeys, effectiveEditLocator, modifiedRows, normalizeMongoEditedRow]); const handleDataPanelSave = useCallback(() => { if (!focusedCellInfo) return; @@ -1729,7 +1765,9 @@ const DataGrid: React.FC = ({ if (isDateTimeField) { setCellFieldValue(form, fieldName, parseToDayjs(raw, pickerType)); } else { - const initialValue = typeof raw === 'string' ? normalizeDateTimeString(raw) : raw; + const initialValue = isMongoDBConnection + ? mongoAwareEditableText(raw) + : (typeof raw === 'string' ? normalizeDateTimeString(raw) : raw); setCellFieldValue(form, fieldName, initialValue); } setVirtualEditingCell({ @@ -1738,7 +1776,7 @@ const DataGrid: React.FC = ({ title, columnType, }); - }, [canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, form, openCellEditor, rowKeyStr]); + }, [canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, form, isMongoDBConnection, mongoAwareEditableText, openCellEditor, rowKeyStr]); const handleVirtualCellActivate = useCallback((record: Item, dataIndex: string, title: React.ReactNode) => { if (!canModifyData) return; @@ -2014,7 +2052,7 @@ const DataGrid: React.FC = ({ const baseVal = (baseRow as any)?.[col]; const displayVal = (displayRow as any)?.[col]; baseRawMap[col] = baseVal; - displayMap[col] = toFormText(displayVal); + displayMap[col] = mongoAwareFormText(displayVal); // 日期时间类型: 将字符串值转为 dayjs 对象供 DatePicker 使用 const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()]; const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig); @@ -2022,7 +2060,7 @@ const DataGrid: React.FC = ({ const dVal = parseToDayjs(displayVal, rowPickerType); formMap[col] = dVal; } else { - formMap[col] = displayVal === null || displayVal === undefined ? undefined : toFormText(displayVal); + formMap[col] = displayVal === null || displayVal === undefined ? undefined : mongoAwareFormText(displayVal); } if (baseVal === null || baseVal === undefined) nullCols.add(col); }); @@ -2034,7 +2072,7 @@ const DataGrid: React.FC = ({ nullCols, formValues: formMap, }); - }, [addedRows, canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, data, dbType, mergedDisplayData, openRowEditor, rowKeyStr, translateDataGrid, visibleColumnNames]); + }, [addedRows, canModifyData, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, data, dbType, mergedDisplayData, mongoAwareFormText, openRowEditor, rowKeyStr, translateDataGrid, visibleColumnNames]); const openCurrentViewRowEditor = useCallback(() => { if (!canModifyData) return; @@ -2192,6 +2230,7 @@ const DataGrid: React.FC = ({ const keyStr = rowEditorRowKey; if (!keyStr) return; const values = rowEditorForm.getFieldsValue(true) || {}; + const baseRawMap = rowEditorBaseRawRef.current || {}; const isAdded = addedRows.some(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr); if (isAdded) { @@ -2199,12 +2238,13 @@ const DataGrid: React.FC = ({ const convertedValues: Record = {}; Object.entries(values).forEach(([col, val]) => { if (!isWritableResultColumn(col, effectiveEditLocator)) return; + const baseVal = baseRawMap[col]; if (val && dayjs.isDayjs(val)) { const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()]; const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig); convertedValues[col] = formatFromDayjs(val as dayjs.Dayjs, rowPickerType); } else { - convertedValues[col] = val; + convertedValues[col] = normalizeMongoEditedCellValue(col, val, baseVal); } }); setAddedRows(prev => prev.map(r => rowKeyStr(r?.[GONAVI_ROW_KEY]) === keyStr ? { ...r, ...convertedValues } : r)); @@ -2212,7 +2252,6 @@ const DataGrid: React.FC = ({ return; } - const baseRawMap = rowEditorBaseRawRef.current || {}; const patch: Record = {}; visibleColumnNames.forEach((col) => { if (!isWritableResultColumn(col, effectiveEditLocator)) return; @@ -2222,6 +2261,8 @@ const DataGrid: React.FC = ({ const colMeta = columnMetaMap[col] || columnMetaMapByLowerName[col.toLowerCase()]; const rowPickerType = getTemporalPickerType(colMeta?.type, dbType, currentConnConfig); nextVal = formatFromDayjs(nextVal as dayjs.Dayjs, rowPickerType); + } else { + nextVal = normalizeMongoEditedCellValue(col, nextVal, baseRawMap[col]); } const baseVal = baseRawMap[col]; if (!isCellValueEqualForDiff(baseVal, nextVal)) patch[col] = nextVal; @@ -2235,7 +2276,7 @@ const DataGrid: React.FC = ({ }); closeRowEditor(); - }, [addedRows, closeRowEditor, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, effectiveEditLocator, rowEditorForm, rowEditorRowKey, rowKeyStr, visibleColumnNames]); + }, [addedRows, closeRowEditor, columnMetaMap, columnMetaMapByLowerName, currentConnConfig, dbType, effectiveEditLocator, normalizeMongoEditedCellValue, rowEditorForm, rowEditorRowKey, rowKeyStr, visibleColumnNames]); const enableVirtual = isTableSurfaceActive; diff --git a/frontend/src/components/DataGridCore.tsx b/frontend/src/components/DataGridCore.tsx index 047a3fa..b341b23 100644 --- a/frontend/src/components/DataGridCore.tsx +++ b/frontend/src/components/DataGridCore.tsx @@ -73,6 +73,7 @@ import { } from './dataGridClipboardExport'; import { applyNoAutoCapAttributesWithin, noAutoCapInputProps } from '../utils/inputAutoCap'; import { DEFAULT_SHORTCUT_OPTIONS, getShortcutPlatform, resolveShortcutDisplay } from '../utils/shortcuts'; +import { formatMongoValueForDisplay } from '../utils/mongodb'; import { TEMPORAL_FORMATS, formatFromDayjs, @@ -355,6 +356,10 @@ export const formatCellDisplayText = (val: any, columnType?: string, connectionC if (val === null) return 'NULL'; const bitText = normalizeBitHexDisplayText(val, columnType); if (bitText !== null) return bitText; + if (String(connectionConfig?.type || '').trim().toLowerCase() === 'mongodb') { + const mongoText = formatMongoValueForDisplay(val); + return mongoText.length > TABLE_CELL_PREVIEW_MAX_CHARS ? `${mongoText.slice(0, TABLE_CELL_PREVIEW_MAX_CHARS)}…` : mongoText; + } if (typeof val === 'object') { if (!Array.isArray(val) && !isPlainObject(val)) { return String(val); @@ -398,6 +403,9 @@ const formatClipboardCellText = (val: any, columnType?: string, connectionConfig if (val === null || val === undefined) return 'NULL'; const bitText = normalizeBitHexDisplayText(val, columnType); if (bitText !== null) return bitText; + if (String(connectionConfig?.type || '').trim().toLowerCase() === 'mongodb') { + return formatMongoValueForDisplay(val); + } if (typeof val === 'string') { const oceanBaseDateOnly = normalizeOceanBaseOracleDateDisplayText(val, columnType, connectionConfig); if (oceanBaseDateOnly !== null) return oceanBaseDateOnly; diff --git a/frontend/src/utils/mongodb.test.ts b/frontend/src/utils/mongodb.test.ts index 52dcbf7..25d9a3d 100644 --- a/frontend/src/utils/mongodb.test.ts +++ b/frontend/src/utils/mongodb.test.ts @@ -1,6 +1,12 @@ import { describe, expect, it } from 'vitest'; -import { applyMongoQueryAutoLimit, buildMongoFindCommand, convertMongoShellToJsonCommand } from './mongodb'; +import { + applyMongoQueryAutoLimit, + buildMongoFindCommand, + convertMongoShellToJsonCommand, + formatMongoEditableValue, + parseMongoEditedValue, +} from './mongodb'; const parseCommand = (command: string | undefined) => JSON.parse(command || '{}'); @@ -134,3 +140,37 @@ describe('buildMongoFindCommand', () => { }); }); }); + +describe('Mongo edit value helpers', () => { + it('formats common extended JSON wrappers to editable literals', () => { + expect(formatMongoEditableValue({ $oid: '507f1f77bcf86cd799439011' })).toBe('ObjectId("507f1f77bcf86cd799439011")'); + expect(formatMongoEditableValue({ $date: { $numberLong: '1719100800000' } })).toBe('ISODate("2024-06-23T00:00:00.000Z")'); + expect(formatMongoEditableValue({ $numberInt: '7' })).toBe('NumberInt(7)'); + expect(formatMongoEditableValue({ $numberLong: '8' })).toBe('NumberLong("8")'); + expect(formatMongoEditableValue({ $numberDouble: '1.5' })).toBe('1.5'); + expect(formatMongoEditableValue({ $numberDecimal: '9.99' })).toBe('NumberDecimal("9.99")'); + expect(formatMongoEditableValue({ + $binary: { + base64: 'EjRWeBI0RniSNFZ4EjRWeA==', + subType: '04', + }, + })).toBe('UUID("12345678-1234-4678-9234-567812345678")'); + }); + + it('parses typed Mongo edit text back to extended JSON wrappers', () => { + expect(parseMongoEditedValue('_id', '507f1f77bcf86cd799439011')).toEqual({ $oid: '507f1f77bcf86cd799439011' }); + expect(parseMongoEditedValue('createdAt', '2024-06-23T00:00:00.000Z', { $date: { $numberLong: '1719100800000' } })).toEqual({ + $date: { $numberLong: '1719100800000' }, + }); + expect(parseMongoEditedValue('count32', '7', { $numberInt: '1' })).toEqual({ $numberInt: '7' }); + expect(parseMongoEditedValue('count64', '8', { $numberLong: '1' })).toEqual({ $numberLong: '8' }); + expect(parseMongoEditedValue('ratio', '1.5', { $numberDouble: '0.5' })).toEqual({ $numberDouble: '1.5' }); + expect(parseMongoEditedValue('price', '9.99', { $numberDecimal: '1.23' })).toEqual({ $numberDecimal: '9.99' }); + expect(parseMongoEditedValue('uid', 'UUID("12345678-1234-4678-9234-567812345678")')).toEqual({ + $binary: { + base64: 'EjRWeBI0RniSNFZ4EjRWeA==', + subType: '04', + }, + }); + }); +}); diff --git a/frontend/src/utils/mongodb.ts b/frontend/src/utils/mongodb.ts index 08bd14e..db30bf6 100644 --- a/frontend/src/utils/mongodb.ts +++ b/frontend/src/utils/mongodb.ts @@ -16,8 +16,168 @@ type ShellConvertResult = { }; const HEX24_RE = /^[0-9a-fA-F]{24}$/; +const UUID_RE = /^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$/; const INTEGER_RE = /^[+-]?\d+$/; const FLOAT_RE = /^[+-]?(?:\d+\.\d+|\d+\.|\.\d+)$/; +const SCIENTIFIC_RE = /^[+-]?(?:\d+(?:\.\d+)?|\.\d+)[eE][+-]?\d+$/; + +const isPlainMongoObject = (value: unknown): value is Record => ( + !!value && typeof value === 'object' && !Array.isArray(value) +); + +const getSingleMongoOperatorEntry = (value: unknown): [string, unknown] | null => { + if (!isPlainMongoObject(value)) return null; + const entries = Object.entries(value); + if (entries.length !== 1) return null; + return entries[0] || null; +}; + +const byteArrayToBase64 = (bytes: Uint8Array): string => { + const BufferCtor = (globalThis as any)?.Buffer; + if (BufferCtor) { + return BufferCtor.from(bytes).toString('base64'); + } + let binary = ''; + bytes.forEach((byte) => { + binary += String.fromCharCode(byte); + }); + return globalThis.btoa(binary); +}; + +const base64ToByteArray = (base64: string): Uint8Array => { + const BufferCtor = (globalThis as any)?.Buffer; + if (BufferCtor) { + return Uint8Array.from(BufferCtor.from(base64, 'base64')); + } + const binary = globalThis.atob(base64); + const bytes = new Uint8Array(binary.length); + for (let index = 0; index < binary.length; index += 1) { + bytes[index] = binary.charCodeAt(index); + } + return bytes; +}; + +const uuidToBytes = (uuid: string): Uint8Array => { + const hex = String(uuid || '').trim().replace(/-/g, '').toLowerCase(); + const bytes = new Uint8Array(16); + for (let index = 0; index < 16; index += 1) { + bytes[index] = Number.parseInt(hex.slice(index * 2, index * 2 + 2), 16); + } + return bytes; +}; + +const bytesToUuid = (bytes: Uint8Array): string => { + const hex = Array.from(bytes).map((byte) => byte.toString(16).padStart(2, '0')).join(''); + if (hex.length !== 32) return ''; + return [ + hex.slice(0, 8), + hex.slice(8, 12), + hex.slice(12, 16), + hex.slice(16, 20), + hex.slice(20, 32), + ].join('-'); +}; + +const buildMongoBinaryUUID = (uuidText: string): { $binary: { base64: string; subType: string } } => ({ + $binary: { + base64: byteArrayToBase64(uuidToBytes(uuidText)), + subType: '04', + }, +}); + +const buildMongoDateLiteralText = (raw?: unknown): string => { + const millis = typeof raw === 'object' && raw && !Array.isArray(raw) + ? parseMongoDateToMillis((raw as Record)?.$numberLong ?? raw) + : parseMongoDateToMillis(raw); + if (millis !== null) { + return new Date(millis).toISOString(); + } + return String(raw ?? ''); +}; + +const buildMongoBinaryLiteralText = (raw: unknown): string | null => { + if (!isPlainMongoObject(raw)) return null; + const binary = raw.$binary; + if (!isPlainMongoObject(binary)) return null; + const subType = String(binary.subType ?? '').trim().toLowerCase(); + const base64 = String(binary.base64 ?? '').trim(); + if (subType !== '04' || !base64) return null; + try { + const uuidText = bytesToUuid(base64ToByteArray(base64)); + return UUID_RE.test(uuidText) ? `UUID("${uuidText}")` : null; + } catch { + return null; + } +}; + +const looksLikeExplicitMongoTypedLiteral = (raw: string): boolean => ( + /^(?:ObjectId|ISODate|NumberInt|NumberLong|NumberDouble|NumberDecimal|UUID|MaxKey|MinKey)\s*\(/i.test(String(raw || '').trim()) +); + +const looksLikeMongoStructuredLiteral = (raw: string): boolean => { + const text = String(raw || '').trim(); + if (!text) return false; + const first = text[0]; + const last = text[text.length - 1]; + return (first === '{' && last === '}') || (first === '[' && last === ']'); +}; + +type MongoValueKind = + | 'nullish' + | 'string' + | 'boolean' + | 'number' + | 'object' + | 'array' + | 'objectId' + | 'date' + | 'int32' + | 'int64' + | 'double' + | 'decimal128' + | 'uuid' + | 'binary' + | 'maxKey' + | 'minKey'; + +const resolveMongoValueKind = (value: unknown): MongoValueKind => { + if (value === null || typeof value === 'undefined') return 'nullish'; + if (Array.isArray(value)) return 'array'; + if (typeof value === 'string') return 'string'; + if (typeof value === 'boolean') return 'boolean'; + if (typeof value === 'number') return 'number'; + const singleEntry = getSingleMongoOperatorEntry(value); + if (singleEntry) { + switch (singleEntry[0]) { + case '$oid': + return 'objectId'; + case '$date': + return 'date'; + case '$numberInt': + return 'int32'; + case '$numberLong': + return 'int64'; + case '$numberDouble': + return 'double'; + case '$numberDecimal': + return 'decimal128'; + case '$binary': { + const binary = singleEntry[1]; + if (isPlainMongoObject(binary) && String(binary.subType ?? '').trim().toLowerCase() === '04') { + return 'uuid'; + } + return 'binary'; + } + case '$maxKey': + return 'maxKey'; + case '$minKey': + return 'minKey'; + default: + break; + } + } + return typeof value === 'object' ? 'object' : 'string'; +}; const escapeRegex = (raw: string) => raw.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); @@ -69,13 +229,31 @@ const parseBooleanLiteral = (raw: string): boolean | null => { return null; }; +const normalizeMongoDoubleLiteral = (raw: string): string | null => { + const text = String(raw || '').trim(); + if (!text) return null; + const lower = text.toLowerCase(); + if (lower === 'nan') return 'NaN'; + if (lower === 'infinity' || lower === '+infinity') return 'Infinity'; + if (lower === '-infinity') return '-Infinity'; + if (INTEGER_RE.test(text) || FLOAT_RE.test(text) || SCIENTIFIC_RE.test(text)) { + const parsed = Number(text); + return Number.isFinite(parsed) ? String(parsed) : null; + } + return null; +}; + const normalizeExtendedJSON = (raw: string): string => { let text = String(raw || ''); text = text.replace(/ObjectId\s*\(\s*["']([0-9a-fA-F]{24})["']\s*\)/g, (_m, oid: string) => JSON.stringify({ $oid: oid })); text = text.replace(/ISODate\s*\(\s*["']([^"']+)["']\s*\)/g, (_m, dateText: string) => JSON.stringify(buildMongoExtendedDate(dateText))); text = text.replace(/NumberLong\s*\(\s*["']?([+-]?\d+)["']?\s*\)/g, '{"$numberLong":"$1"}'); text = text.replace(/NumberInt\s*\(\s*["']?([+-]?\d+)["']?\s*\)/g, '{"$numberInt":"$1"}'); + text = text.replace(/NumberDouble\s*\(\s*["']?([^"')]+)["']?\s*\)/g, '{"$numberDouble":"$1"}'); text = text.replace(/NumberDecimal\s*\(\s*["']?([+-]?(?:\d+(?:\.\d+)?|\.\d+))["']?\s*\)/g, '{"$numberDecimal":"$1"}'); + text = text.replace(/UUID\s*\(\s*["']([0-9a-fA-F-]{36})["']\s*\)/g, (_m, uuidText: string) => JSON.stringify(buildMongoBinaryUUID(uuidText))); + text = text.replace(/MaxKey\s*\(\s*\)/g, '{"$maxKey":1}'); + text = text.replace(/MinKey\s*\(\s*\)/g, '{"$minKey":1}'); return text; }; @@ -130,21 +308,39 @@ const evalMongoLikeLiteral = (raw: string): unknown => { if (!INTEGER_RE.test(text)) throw new Error(`NumberLong invalid value: ${text}`); return { $numberLong: text }; }; + const NumberDouble = (value: unknown) => { + const normalized = normalizeMongoDoubleLiteral(String(value ?? '').trim()); + if (!normalized) throw new Error(`NumberDouble invalid value: ${String(value)}`); + return { $numberDouble: normalized }; + }; const NumberDecimal = (value: unknown) => { const text = String(value ?? '').trim(); if (!text) throw new Error('NumberDecimal invalid value'); return { $numberDecimal: text }; }; + const UUID = (value: unknown) => { + const text = String(value ?? '').trim().replace(/^['"]|['"]$/g, ''); + if (!UUID_RE.test(text)) { + throw new Error(`UUID invalid value: ${text}`); + } + return buildMongoBinaryUUID(text.toLowerCase()); + }; + const MaxKey = () => ({ $maxKey: 1 }); + const MinKey = () => ({ $minKey: 1 }); const parser = new Function( 'ObjectId', 'ISODate', 'NumberInt', 'NumberLong', + 'NumberDouble', 'NumberDecimal', + 'UUID', + 'MaxKey', + 'MinKey', '"use strict"; return (' + expression + ');', ); - const evaluated = parser(ObjectId, ISODate, NumberInt, NumberLong, NumberDecimal); + const evaluated = parser(ObjectId, ISODate, NumberInt, NumberLong, NumberDouble, NumberDecimal, UUID, MaxKey, MinKey); return normalizeEvaluatedMongoValue(evaluated); }; @@ -183,6 +379,135 @@ const parseMongoJSONValue = (raw: string): unknown => { } }; +export const formatMongoValueForDisplay = (value: unknown): string => { + if (value === null) return 'NULL'; + if (typeof value === 'undefined') return ''; + const singleEntry = getSingleMongoOperatorEntry(value); + if (singleEntry) { + switch (singleEntry[0]) { + case '$oid': + return `ObjectId("${String(singleEntry[1] ?? '')}")`; + case '$date': + return `ISODate("${buildMongoDateLiteralText(singleEntry[1])}")`; + case '$numberInt': + return `NumberInt(${String(singleEntry[1] ?? '')})`; + case '$numberLong': + return `NumberLong("${String(singleEntry[1] ?? '')}")`; + case '$numberDouble': + return String(singleEntry[1] ?? ''); + case '$numberDecimal': + return `NumberDecimal("${String(singleEntry[1] ?? '')}")`; + case '$binary': { + const binaryText = buildMongoBinaryLiteralText(value); + if (binaryText) return binaryText; + break; + } + case '$maxKey': + return 'MaxKey()'; + case '$minKey': + return 'MinKey()'; + default: + break; + } + } + if (Array.isArray(value) || isPlainMongoObject(value)) { + try { + return JSON.stringify(value); + } catch { + return String(value); + } + } + return String(value); +}; + +export const formatMongoEditableValue = (value: unknown): string => { + if (value === null || typeof value === 'undefined') return ''; + const singleEntry = getSingleMongoOperatorEntry(value); + if (singleEntry) { + return formatMongoValueForDisplay(value); + } + if (Array.isArray(value) || isPlainMongoObject(value)) { + try { + return JSON.stringify(value, null, 2); + } catch { + return String(value); + } + } + return String(value); +}; + +export const parseMongoEditedValue = ( + columnName: string, + rawValue: unknown, + currentValue?: unknown, +): unknown => { + if (typeof rawValue !== 'string') return rawValue; + + const currentKind = resolveMongoValueKind(currentValue); + const text = rawValue.trim(); + const structuredLiteral = looksLikeMongoStructuredLiteral(rawValue); + const explicitLiteral = looksLikeExplicitMongoTypedLiteral(rawValue); + + if (structuredLiteral || explicitLiteral) { + return parseMongoJSONValue(rawValue); + } + + switch (currentKind) { + case 'objectId': + if (HEX24_RE.test(text)) return { $oid: text.toLowerCase() }; + return rawValue; + case 'date': + if (!text) return rawValue; + return buildMongoExtendedDate(text); + case 'int32': + if (INTEGER_RE.test(text)) return { $numberInt: String(Number.parseInt(text, 10)) }; + if (text.toLowerCase() === 'null') return null; + return rawValue; + case 'int64': + if (INTEGER_RE.test(text)) return { $numberLong: text }; + if (text.toLowerCase() === 'null') return null; + return rawValue; + case 'double': { + const normalized = normalizeMongoDoubleLiteral(text); + if (normalized !== null) return { $numberDouble: normalized }; + if (text.toLowerCase() === 'null') return null; + return rawValue; + } + case 'decimal128': + if (INTEGER_RE.test(text) || FLOAT_RE.test(text)) return { $numberDecimal: text }; + if (text.toLowerCase() === 'null') return null; + return rawValue; + case 'boolean': { + const boolValue = parseBooleanLiteral(text); + if (boolValue !== null) return boolValue; + if (text.toLowerCase() === 'null') return null; + return rawValue; + } + case 'number': + if (INTEGER_RE.test(text) || FLOAT_RE.test(text)) { + const parsed = Number(text); + return Number.isFinite(parsed) ? parsed : rawValue; + } + if (text.toLowerCase() === 'null') return null; + return rawValue; + case 'array': + case 'object': + case 'uuid': + case 'binary': + case 'maxKey': + case 'minKey': + if (text.toLowerCase() === 'null') return null; + return rawValue; + case 'string': + case 'nullish': + default: + if (String(columnName || '').trim() === '_id' && HEX24_RE.test(text)) { + return { $oid: text.toLowerCase() }; + } + return rawValue; + } +}; + const splitTopLevelComma = (raw: string): string[] => { const text = String(raw || ''); const result: string[] = []; @@ -1098,4 +1423,3 @@ export const convertMongoShellToJsonCommand = (raw: string): ShellConvertResult }; } }; - diff --git a/internal/db/mongodb_impl.go b/internal/db/mongodb_impl.go index e6e2e19..5d7f4c3 100644 --- a/internal/db/mongodb_impl.go +++ b/internal/db/mongodb_impl.go @@ -4,6 +4,7 @@ package db import ( "context" + "encoding/json" "fmt" "net" "net/url" @@ -1058,7 +1059,16 @@ func (m *MongoDB) execCount(ctx context.Context, cmd bson.D) ([]map[string]inter // convertBsonValue 将 BSON 特殊类型转换为前端可读的 JSON 友好值 func convertBsonValue(v interface{}) interface{} { switch val := v.(type) { + case map[string]interface{}: + result := make(map[string]interface{}, len(val)) + for k, v2 := range val { + result[k] = convertBsonValue(v2) + } + return result case bson.ObjectID: + if converted, ok := encodeMongoExtendedJSONFieldValue(val); ok { + return converted + } return val.Hex() case bson.M: result := make(map[string]interface{}, len(val)) @@ -1078,11 +1088,75 @@ func convertBsonValue(v interface{}) interface{} { result[i] = convertBsonValue(v2) } return result + case []interface{}: + result := make([]interface{}, len(val)) + for i, v2 := range val { + result[i] = convertBsonValue(v2) + } + return result default: + if !shouldEncodeMongoExtendedJSONFieldValue(v) { + return v + } + if converted, ok := encodeMongoExtendedJSONFieldValue(v); ok { + return converted + } return v } } +func shouldEncodeMongoExtendedJSONFieldValue(v interface{}) bool { + switch v.(type) { + case bson.DateTime, + bson.Decimal128, + bson.Binary, + bson.Regex, + bson.Timestamp, + bson.MaxKey, + bson.MinKey, + bson.Undefined, + int32, + int64, + []byte, + time.Time: + return true + default: + return false + } +} + +func encodeMongoExtendedJSONFieldValue(v interface{}) (interface{}, bool) { + payload, err := bson.MarshalExtJSON(bson.M{"v": v}, true, false) + if err != nil { + return nil, false + } + + var wrapped map[string]interface{} + if err := json.Unmarshal(payload, &wrapped); err != nil { + return nil, false + } + + converted, ok := wrapped["v"] + return converted, ok +} + +func decodeMongoExtendedJSONFieldValue(v interface{}) interface{} { + payload, err := json.Marshal(map[string]interface{}{"v": v}) + if err != nil { + return v + } + + var wrapped bson.M + if err := bson.UnmarshalExtJSON(payload, false, &wrapped); err != nil { + return v + } + + if converted, ok := wrapped["v"]; ok { + return converted + } + return v +} + func (m *MongoDB) Exec(query string) (int64, error) { _, _, err := m.Query(query) if err != nil { @@ -1220,7 +1294,7 @@ func (m *MongoDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDef func copyMongoChangeDocument(row map[string]interface{}) bson.M { doc := bson.M{} for k, v := range row { - doc[k] = v + doc[k] = decodeMongoExtendedJSONFieldValue(v) } return doc } @@ -1228,46 +1302,11 @@ func copyMongoChangeDocument(row map[string]interface{}) bson.M { func buildMongoChangeFilter(row map[string]interface{}) bson.M { filter := bson.M{} for k, v := range row { - filter[k] = normalizeMongoChangeFilterValue(k, v) + filter[k] = decodeMongoExtendedJSONFieldValue(v) } return filter } -func normalizeMongoChangeFilterValue(key string, value interface{}) interface{} { - if strings.TrimSpace(key) != "_id" { - return value - } - - switch val := value.(type) { - case map[string]interface{}: - if raw, ok := val["$oid"]; ok { - if oid, parsed := parseMongoObjectIDHex(fmt.Sprintf("%v", raw)); parsed { - return oid - } - } - case bson.M: - if raw, ok := val["$oid"]; ok { - if oid, parsed := parseMongoObjectIDHex(fmt.Sprintf("%v", raw)); parsed { - return oid - } - } - } - return value -} - -func parseMongoObjectIDHex(value string) (bson.ObjectID, bool) { - text := strings.TrimSpace(value) - var zero bson.ObjectID - if len(text) != 24 { - return zero, false - } - oid, err := bson.ObjectIDFromHex(text) - if err != nil { - return zero, false - } - return oid, true -} - // ApplyChanges implements batch changes for MongoDB func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { if m.client == nil { @@ -1300,10 +1339,7 @@ func (m *MongoDB) ApplyChanges(tableName string, changes connection.ChangeSet) e return fmt.Errorf("更新操作需要主键条件") } - updateDoc := bson.M{"$set": bson.M{}} - for k, v := range update.Values { - updateDoc["$set"].(bson.M)[k] = v - } + updateDoc := bson.M{"$set": copyMongoChangeDocument(update.Values)} result, err := collection.UpdateOne(ctx, filter, updateDoc) if err != nil { diff --git a/internal/db/mongodb_impl_uri_test.go b/internal/db/mongodb_impl_uri_test.go index d69292e..ed20d01 100644 --- a/internal/db/mongodb_impl_uri_test.go +++ b/internal/db/mongodb_impl_uri_test.go @@ -128,3 +128,138 @@ func TestCopyMongoChangeDocument_LeavesInsertIDStringUntouched(t *testing.T) { t.Fatalf("insert _id string should stay string, got %T %v", doc["_id"], doc["_id"]) } } + +func TestConvertBsonValue_EncodesMongoTypedValues(t *testing.T) { + const oidHex = "507f1f77bcf86cd799439011" + oid, err := bson.ObjectIDFromHex(oidHex) + if err != nil { + t.Fatal(err) + } + decimalValue, err := bson.ParseDecimal128("12.34") + if err != nil { + t.Fatal(err) + } + + converted, ok := convertBsonValue(bson.M{ + "_id": oid, + "createdAt": bson.DateTime(1719100800000), + "count32": int32(7), + "count64": int64(8), + "ratio": 1.5, + "price": decimalValue, + "uid": bson.Binary{ + Subtype: 0x04, + Data: []byte{0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78}, + }, + "nested": bson.M{ + "innerId": oid, + }, + "items": bson.A{int32(1), int64(2)}, + }).(map[string]interface{}) + if !ok { + t.Fatalf("expected converted document map, got %T", converted) + } + + if converted["_id"].(map[string]interface{})["$oid"] != oidHex { + t.Fatalf("unexpected ObjectID wrapper: %#v", converted["_id"]) + } + if converted["createdAt"].(map[string]interface{})["$date"].(map[string]interface{})["$numberLong"] != "1719100800000" { + t.Fatalf("unexpected date wrapper: %#v", converted["createdAt"]) + } + if converted["count32"].(map[string]interface{})["$numberInt"] != "7" { + t.Fatalf("unexpected int32 wrapper: %#v", converted["count32"]) + } + if converted["count64"].(map[string]interface{})["$numberLong"] != "8" { + t.Fatalf("unexpected int64 wrapper: %#v", converted["count64"]) + } + if converted["ratio"] != 1.5 { + t.Fatalf("plain double should stay float64, got %T %#v", converted["ratio"], converted["ratio"]) + } + if converted["price"].(map[string]interface{})["$numberDecimal"] != "12.34" { + t.Fatalf("unexpected decimal wrapper: %#v", converted["price"]) + } + if converted["uid"].(map[string]interface{})["$binary"].(map[string]interface{})["base64"] != "EjRWeBI0VngSNFZ4EjRWeA==" { + t.Fatalf("unexpected binary wrapper: %#v", converted["uid"]) + } + + nestedDoc, ok := converted["nested"].(map[string]interface{}) + if !ok { + t.Fatalf("expected nested map, got %T", converted["nested"]) + } + if nestedDoc["innerId"].(map[string]interface{})["$oid"] != oidHex { + t.Fatalf("unexpected nested ObjectID wrapper: %#v", nestedDoc["innerId"]) + } + + items, ok := converted["items"].([]interface{}) + if !ok || len(items) != 2 { + t.Fatalf("unexpected items wrapper: %#v", converted["items"]) + } + if items[0].(map[string]interface{})["$numberInt"] != "1" || items[1].(map[string]interface{})["$numberLong"] != "2" { + t.Fatalf("unexpected numeric array wrappers: %#v", items) + } +} + +func TestCopyMongoChangeDocument_DecodesExtendedJSONWrappers(t *testing.T) { + doc := copyMongoChangeDocument(map[string]interface{}{ + "_id": map[string]interface{}{"$oid": "507f1f77bcf86cd799439011"}, + "createdAt": map[string]interface{}{"$date": map[string]interface{}{"$numberLong": "1719100800000"}}, + "count32": map[string]interface{}{"$numberInt": "7"}, + "count64": map[string]interface{}{"$numberLong": "8"}, + "ratio": map[string]interface{}{"$numberDouble": "1.5"}, + "price": map[string]interface{}{"$numberDecimal": "12.34"}, + "uid": map[string]interface{}{ + "$binary": map[string]interface{}{ + "base64": "EjRWeBI0VngSNFZ4EjRWeA==", + "subType": "04", + }, + }, + "nested": map[string]interface{}{ + "innerId": map[string]interface{}{"$oid": "507f1f77bcf86cd799439012"}, + }, + "items": []interface{}{ + map[string]interface{}{"$numberInt": "1"}, + map[string]interface{}{"$numberLong": "2"}, + }, + }) + + if _, ok := doc["_id"].(bson.ObjectID); !ok { + t.Fatalf("expected _id to decode to bson.ObjectID, got %T", doc["_id"]) + } + if got, ok := doc["createdAt"].(bson.DateTime); !ok || got != bson.DateTime(1719100800000) { + t.Fatalf("expected createdAt bson.DateTime, got %T %#v", doc["createdAt"], doc["createdAt"]) + } + if got, ok := doc["count32"].(int32); !ok || got != 7 { + t.Fatalf("expected count32 int32, got %T %#v", doc["count32"], doc["count32"]) + } + if got, ok := doc["count64"].(int64); !ok || got != 8 { + t.Fatalf("expected count64 int64, got %T %#v", doc["count64"], doc["count64"]) + } + if got, ok := doc["ratio"].(float64); !ok || got != 1.5 { + t.Fatalf("expected ratio float64, got %T %#v", doc["ratio"], doc["ratio"]) + } + if _, ok := doc["price"].(bson.Decimal128); !ok { + t.Fatalf("expected price bson.Decimal128, got %T", doc["price"]) + } + if binaryValue, ok := doc["uid"].(bson.Binary); !ok || binaryValue.Subtype != 0x04 || len(binaryValue.Data) != 16 { + t.Fatalf("expected uid bson.Binary UUID, got %T %#v", doc["uid"], doc["uid"]) + } + + nestedDoc, ok := doc["nested"].(bson.D) + if !ok || len(nestedDoc) != 1 || nestedDoc[0].Key != "innerId" { + t.Fatalf("expected nested bson.D, got %T %#v", doc["nested"], doc["nested"]) + } + if _, ok := nestedDoc[0].Value.(bson.ObjectID); !ok { + t.Fatalf("expected nested innerId ObjectID, got %T", nestedDoc[0].Value) + } + + items, ok := doc["items"].(bson.A) + if !ok || len(items) != 2 { + t.Fatalf("expected items bson.A, got %T %#v", doc["items"], doc["items"]) + } + if got, ok := items[0].(int32); !ok || got != 1 { + t.Fatalf("expected items[0] int32, got %T %#v", items[0], items[0]) + } + if got, ok := items[1].(int64); !ok || got != 2 { + t.Fatalf("expected items[1] int64, got %T %#v", items[1], items[1]) + } +} diff --git a/internal/db/mongodb_impl_v1.go b/internal/db/mongodb_impl_v1.go index ed49ed6..accf9b5 100644 --- a/internal/db/mongodb_impl_v1.go +++ b/internal/db/mongodb_impl_v1.go @@ -4,6 +4,7 @@ package db import ( "context" + "encoding/json" "fmt" "net" "net/url" @@ -1061,7 +1062,16 @@ func (m *MongoDBV1) execCount(ctx context.Context, cmd bson.D) ([]map[string]int // convertBsonValue 将 BSON 特殊类型转换为前端可读的 JSON 友好值 func convertBsonValue(v interface{}) interface{} { switch val := v.(type) { + case map[string]interface{}: + result := make(map[string]interface{}, len(val)) + for k, v2 := range val { + result[k] = convertBsonValue(v2) + } + return result case primitive.ObjectID: + if converted, ok := encodeMongoExtendedJSONFieldValue(val); ok { + return converted + } return val.Hex() case bson.M: result := make(map[string]interface{}, len(val)) @@ -1081,11 +1091,75 @@ func convertBsonValue(v interface{}) interface{} { result[i] = convertBsonValue(v2) } return result + case []interface{}: + result := make([]interface{}, len(val)) + for i, v2 := range val { + result[i] = convertBsonValue(v2) + } + return result default: + if !shouldEncodeMongoExtendedJSONFieldValue(v) { + return v + } + if converted, ok := encodeMongoExtendedJSONFieldValue(v); ok { + return converted + } return v } } +func shouldEncodeMongoExtendedJSONFieldValue(v interface{}) bool { + switch v.(type) { + case primitive.DateTime, + primitive.Decimal128, + primitive.Binary, + primitive.Regex, + primitive.Timestamp, + primitive.MaxKey, + primitive.MinKey, + primitive.Undefined, + int32, + int64, + []byte, + time.Time: + return true + default: + return false + } +} + +func encodeMongoExtendedJSONFieldValue(v interface{}) (interface{}, bool) { + payload, err := bson.MarshalExtJSON(bson.M{"v": v}, true, false) + if err != nil { + return nil, false + } + + var wrapped map[string]interface{} + if err := json.Unmarshal(payload, &wrapped); err != nil { + return nil, false + } + + converted, ok := wrapped["v"] + return converted, ok +} + +func decodeMongoExtendedJSONFieldValue(v interface{}) interface{} { + payload, err := json.Marshal(map[string]interface{}{"v": v}) + if err != nil { + return v + } + + var wrapped bson.M + if err := bson.UnmarshalExtJSON(payload, false, &wrapped); err != nil { + return v + } + + if converted, ok := wrapped["v"]; ok { + return converted + } + return v +} + func (m *MongoDBV1) Exec(query string) (int64, error) { _, _, err := m.Query(query) if err != nil { @@ -1223,7 +1297,7 @@ func (m *MongoDBV1) GetTriggers(dbName, tableName string) ([]connection.TriggerD func copyMongoChangeDocument(row map[string]interface{}) bson.M { doc := bson.M{} for k, v := range row { - doc[k] = v + doc[k] = decodeMongoExtendedJSONFieldValue(v) } return doc } @@ -1231,46 +1305,11 @@ func copyMongoChangeDocument(row map[string]interface{}) bson.M { func buildMongoChangeFilter(row map[string]interface{}) bson.M { filter := bson.M{} for k, v := range row { - filter[k] = normalizeMongoChangeFilterValue(k, v) + filter[k] = decodeMongoExtendedJSONFieldValue(v) } return filter } -func normalizeMongoChangeFilterValue(key string, value interface{}) interface{} { - if strings.TrimSpace(key) != "_id" { - return value - } - - switch val := value.(type) { - case map[string]interface{}: - if raw, ok := val["$oid"]; ok { - if oid, parsed := parseMongoObjectIDHex(fmt.Sprintf("%v", raw)); parsed { - return oid - } - } - case bson.M: - if raw, ok := val["$oid"]; ok { - if oid, parsed := parseMongoObjectIDHex(fmt.Sprintf("%v", raw)); parsed { - return oid - } - } - } - return value -} - -func parseMongoObjectIDHex(value string) (primitive.ObjectID, bool) { - text := strings.TrimSpace(value) - var zero primitive.ObjectID - if len(text) != 24 { - return zero, false - } - oid, err := primitive.ObjectIDFromHex(text) - if err != nil { - return zero, false - } - return oid, true -} - // ApplyChanges implements batch changes for MongoDB func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet) error { if m.client == nil { @@ -1303,10 +1342,7 @@ func (m *MongoDBV1) ApplyChanges(tableName string, changes connection.ChangeSet) return fmt.Errorf("更新操作需要主键条件") } - updateDoc := bson.M{"$set": bson.M{}} - for k, v := range update.Values { - updateDoc["$set"].(bson.M)[k] = v - } + updateDoc := bson.M{"$set": copyMongoChangeDocument(update.Values)} result, err := collection.UpdateOne(ctx, filter, updateDoc) if err != nil { diff --git a/internal/db/mongodb_impl_v1_uri_test.go b/internal/db/mongodb_impl_v1_uri_test.go index b90deac..3a95726 100644 --- a/internal/db/mongodb_impl_v1_uri_test.go +++ b/internal/db/mongodb_impl_v1_uri_test.go @@ -87,3 +87,138 @@ func TestCopyMongoChangeDocumentV1_LeavesInsertIDStringUntouched(t *testing.T) { t.Fatalf("insert _id string should stay string, got %T %v", doc["_id"], doc["_id"]) } } + +func TestConvertBsonValueV1_EncodesMongoTypedValues(t *testing.T) { + const oidHex = "507f1f77bcf86cd799439011" + oid, err := primitive.ObjectIDFromHex(oidHex) + if err != nil { + t.Fatal(err) + } + decimalValue, err := primitive.ParseDecimal128("12.34") + if err != nil { + t.Fatal(err) + } + + converted, ok := convertBsonValue(bson.M{ + "_id": oid, + "createdAt": primitive.DateTime(1719100800000), + "count32": int32(7), + "count64": int64(8), + "ratio": 1.5, + "price": decimalValue, + "uid": primitive.Binary{ + Subtype: 0x04, + Data: []byte{0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78}, + }, + "nested": bson.M{ + "innerId": oid, + }, + "items": bson.A{int32(1), int64(2)}, + }).(map[string]interface{}) + if !ok { + t.Fatalf("expected converted document map, got %T", converted) + } + + if converted["_id"].(map[string]interface{})["$oid"] != oidHex { + t.Fatalf("unexpected ObjectID wrapper: %#v", converted["_id"]) + } + if converted["createdAt"].(map[string]interface{})["$date"].(map[string]interface{})["$numberLong"] != "1719100800000" { + t.Fatalf("unexpected date wrapper: %#v", converted["createdAt"]) + } + if converted["count32"].(map[string]interface{})["$numberInt"] != "7" { + t.Fatalf("unexpected int32 wrapper: %#v", converted["count32"]) + } + if converted["count64"].(map[string]interface{})["$numberLong"] != "8" { + t.Fatalf("unexpected int64 wrapper: %#v", converted["count64"]) + } + if converted["ratio"] != 1.5 { + t.Fatalf("plain double should stay float64, got %T %#v", converted["ratio"], converted["ratio"]) + } + if converted["price"].(map[string]interface{})["$numberDecimal"] != "12.34" { + t.Fatalf("unexpected decimal wrapper: %#v", converted["price"]) + } + if converted["uid"].(map[string]interface{})["$binary"].(map[string]interface{})["base64"] != "EjRWeBI0VngSNFZ4EjRWeA==" { + t.Fatalf("unexpected binary wrapper: %#v", converted["uid"]) + } + + nestedDoc, ok := converted["nested"].(map[string]interface{}) + if !ok { + t.Fatalf("expected nested map, got %T", converted["nested"]) + } + if nestedDoc["innerId"].(map[string]interface{})["$oid"] != oidHex { + t.Fatalf("unexpected nested ObjectID wrapper: %#v", nestedDoc["innerId"]) + } + + items, ok := converted["items"].([]interface{}) + if !ok || len(items) != 2 { + t.Fatalf("unexpected items wrapper: %#v", converted["items"]) + } + if items[0].(map[string]interface{})["$numberInt"] != "1" || items[1].(map[string]interface{})["$numberLong"] != "2" { + t.Fatalf("unexpected numeric array wrappers: %#v", items) + } +} + +func TestCopyMongoChangeDocumentV1_DecodesExtendedJSONWrappers(t *testing.T) { + doc := copyMongoChangeDocument(map[string]interface{}{ + "_id": map[string]interface{}{"$oid": "507f1f77bcf86cd799439011"}, + "createdAt": map[string]interface{}{"$date": map[string]interface{}{"$numberLong": "1719100800000"}}, + "count32": map[string]interface{}{"$numberInt": "7"}, + "count64": map[string]interface{}{"$numberLong": "8"}, + "ratio": map[string]interface{}{"$numberDouble": "1.5"}, + "price": map[string]interface{}{"$numberDecimal": "12.34"}, + "uid": map[string]interface{}{ + "$binary": map[string]interface{}{ + "base64": "EjRWeBI0VngSNFZ4EjRWeA==", + "subType": "04", + }, + }, + "nested": map[string]interface{}{ + "innerId": map[string]interface{}{"$oid": "507f1f77bcf86cd799439012"}, + }, + "items": []interface{}{ + map[string]interface{}{"$numberInt": "1"}, + map[string]interface{}{"$numberLong": "2"}, + }, + }) + + if _, ok := doc["_id"].(primitive.ObjectID); !ok { + t.Fatalf("expected _id to decode to primitive.ObjectID, got %T", doc["_id"]) + } + if got, ok := doc["createdAt"].(primitive.DateTime); !ok || got != primitive.DateTime(1719100800000) { + t.Fatalf("expected createdAt primitive.DateTime, got %T %#v", doc["createdAt"], doc["createdAt"]) + } + if got, ok := doc["count32"].(int32); !ok || got != 7 { + t.Fatalf("expected count32 int32, got %T %#v", doc["count32"], doc["count32"]) + } + if got, ok := doc["count64"].(int64); !ok || got != 8 { + t.Fatalf("expected count64 int64, got %T %#v", doc["count64"], doc["count64"]) + } + if got, ok := doc["ratio"].(float64); !ok || got != 1.5 { + t.Fatalf("expected ratio float64, got %T %#v", doc["ratio"], doc["ratio"]) + } + if _, ok := doc["price"].(primitive.Decimal128); !ok { + t.Fatalf("expected price primitive.Decimal128, got %T", doc["price"]) + } + if binaryValue, ok := doc["uid"].(primitive.Binary); !ok || binaryValue.Subtype != 0x04 || len(binaryValue.Data) != 16 { + t.Fatalf("expected uid primitive.Binary UUID, got %T %#v", doc["uid"], doc["uid"]) + } + + nestedDoc, ok := doc["nested"].(primitive.M) + if !ok { + t.Fatalf("expected nested primitive.M, got %T %#v", doc["nested"], doc["nested"]) + } + if _, ok := nestedDoc["innerId"].(primitive.ObjectID); !ok { + t.Fatalf("expected nested innerId ObjectID, got %T", nestedDoc["innerId"]) + } + + items, ok := doc["items"].(bson.A) + if !ok || len(items) != 2 { + t.Fatalf("expected items bson.A, got %T %#v", doc["items"], doc["items"]) + } + if got, ok := items[0].(int32); !ok || got != 1 { + t.Fatalf("expected items[0] int32, got %T %#v", items[0], items[0]) + } + if got, ok := items[1].(int64); !ok || got != 2 { + t.Fatalf("expected items[1] int64, got %T %#v", items[1], items[1]) + } +}