From eeaf3c658bcca9dfd4349b3f87cfaa0226dc0ede Mon Sep 17 00:00:00 2001 From: Syngnat Date: Tue, 2 Jun 2026 21:12:59 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(duckdb):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=94=AF=E4=B8=80=E7=B4=A2=E5=BC=95=E8=AF=86=E5=88=AB?= =?UTF-8?q?=E4=B8=8E=E5=A4=9A=E5=BA=93=E5=AF=B9=E8=B1=A1=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 合并 DuckDB 约束与索引元数据,恢复唯一索引表的可编辑判定 - 修复 attach 多库场景下 catalog/schema/table 定位混乱问题 - 统一前后端 qualified name 解析,支持带点和带引号对象名 - 补充 DuckDB 元数据与编辑链路回归测试 --- frontend/src/components/DataViewer.tsx | 31 +- frontend/src/components/DefinitionViewer.tsx | 9 +- frontend/src/components/Sidebar.tsx | 40 +- frontend/src/components/TableDesigner.tsx | 22 +- frontend/src/components/TriggerViewer.tsx | 9 +- .../src/components/tableDesignerSchemaSql.ts | 10 +- frontend/src/utils/qualifiedName.ts | 126 +++++ frontend/src/utils/sidebarMetadata.ts | 16 +- frontend/src/utils/sql.test.ts | 14 +- frontend/src/utils/sql.ts | 21 +- internal/app/db_context.go | 4 + internal/app/db_context_test.go | 15 + internal/app/methods_db.go | 4 + internal/db/duckdb_impl.go | 219 ++++---- internal/db/duckdb_metadata.go | 483 ++++++++++++++++++ internal/db/duckdb_metadata_test.go | 165 ++++++ 16 files changed, 969 insertions(+), 219 deletions(-) create mode 100644 frontend/src/utils/qualifiedName.ts create mode 100644 internal/db/duckdb_metadata.go create mode 100644 internal/db/duckdb_metadata_test.go diff --git a/frontend/src/components/DataViewer.tsx b/frontend/src/components/DataViewer.tsx index bcae350..1a82962 100644 --- a/frontend/src/components/DataViewer.tsx +++ b/frontend/src/components/DataViewer.tsx @@ -26,6 +26,7 @@ import { getColumnDefinitionName, getColumnDefinitionType, } from '../utils/columnDefinition'; +import { splitQualifiedNameLast, splitQualifiedNameSegments } from '../utils/qualifiedName'; type ViewerPaginationState = { current: number; @@ -171,33 +172,21 @@ const buildDataViewerBaseSelectSQL = ( return `SELECT ${alias}.*, ${alias}.ROWID AS ${rowIDAlias} FROM ${quotedTableName} ${alias} ${whereSQL}`; }; -const normalizeDuckDBIdentifier = (raw: string): string => { - const text = String(raw || '').trim(); - if (text.length >= 2) { - const first = text[0]; - const last = text[text.length - 1]; - if ((first === '"' && last === '"') || (first === '`' && last === '`')) { - return text.slice(1, -1).trim(); - } - } - return text; -}; - const resolveDuckDBSchemaAndTable = (dbName: string, tableName: string) => { const rawTable = String(tableName || '').trim(); if (!rawTable) return { schemaName: 'main', pureTableName: '' }; - const parts = rawTable.split('.'); - if (parts.length >= 2) { - const pureTableName = normalizeDuckDBIdentifier(parts[parts.length - 1]); - const schemaName = normalizeDuckDBIdentifier(parts[parts.length - 2]); - if (schemaName && pureTableName) { - return { schemaName, pureTableName }; - } + const segments = splitQualifiedNameSegments(rawTable); + if (segments.length >= 2) { + return { + schemaName: segments[segments.length - 2], + pureTableName: segments[segments.length - 1], + }; } - const fallbackSchema = normalizeDuckDBIdentifier(String(dbName || '').trim()) || 'main'; - return { schemaName: fallbackSchema, pureTableName: normalizeDuckDBIdentifier(rawTable) }; + const fallbackParsed = splitQualifiedNameLast(String(dbName || '').trim()); + const fallbackSchema = fallbackParsed.objectName || String(dbName || '').trim() || 'main'; + return { schemaName: fallbackSchema, pureTableName: segments[0] || rawTable }; }; const escapeSQLLiteral = (value: string): string => String(value || '').replace(/'/g, "''"); diff --git a/frontend/src/components/DefinitionViewer.tsx b/frontend/src/components/DefinitionViewer.tsx index e063ee6..13a2b83 100644 --- a/frontend/src/components/DefinitionViewer.tsx +++ b/frontend/src/components/DefinitionViewer.tsx @@ -6,6 +6,7 @@ import { useStore } from '../store'; import { DBQuery } from '../../wailsjs/go/app/App'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { normalizeOceanBaseProtocol } from '../utils/oceanBaseProtocol'; +import { splitQualifiedNameLast } from '../utils/qualifiedName'; interface DefinitionViewerProps { tab: TabData; @@ -63,12 +64,8 @@ const DefinitionViewer: React.FC = ({ tab }) => { }; const parseSchemaAndName = (fullName: string): { schema: string; name: string } => { - const raw = String(fullName || '').trim(); - const idx = raw.lastIndexOf('.'); - if (idx > 0 && idx < raw.length - 1) { - return { schema: raw.substring(0, idx), name: raw.substring(idx + 1) }; - } - return { schema: '', name: raw }; + const parsed = splitQualifiedNameLast(fullName); + return { schema: parsed.parentPath, name: parsed.objectName }; }; const getCaseInsensitiveRawValue = (row: Record, candidateKeys: string[]): any => { diff --git a/frontend/src/components/Sidebar.tsx b/frontend/src/components/Sidebar.tsx index 49fda2c..d2a7f8b 100644 --- a/frontend/src/components/Sidebar.tsx +++ b/frontend/src/components/Sidebar.tsx @@ -65,6 +65,7 @@ import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { getDataSourceCapabilities } from '../utils/dataSourceCapabilities'; import { noAutoCapInputProps } from '../utils/inputAutoCap'; import { normalizeSidebarViewName, resolveSidebarRuntimeDatabase } from '../utils/sidebarMetadata'; +import { splitQualifiedNameLast } from '../utils/qualifiedName'; import { buildStarRocksMaterializedViewPreviewSql } from './tableDesignerSchemaSql'; import { normalizeOceanBaseProtocol } from '../utils/oceanBaseProtocol'; import { resolveConnectionHostSummary, resolveConnectionHostTokens } from '../utils/tabDisplay'; @@ -1796,9 +1797,8 @@ const Sidebar: React.FC<{ const rawName = String(tableName || '').trim(); if (!rawName) return rawName; if (!shouldHideSchemaPrefix(conn)) return rawName; - const lastDotIndex = rawName.lastIndexOf('.'); - if (lastDotIndex <= 0 || lastDotIndex >= rawName.length - 1) return rawName; - return rawName.substring(lastDotIndex + 1); + const parsed = splitQualifiedName(rawName); + return parsed.objectName || rawName; }; const getMetadataDialect = (conn: SavedConnection | undefined): string => { @@ -1984,15 +1984,10 @@ const Sidebar: React.FC<{ }; const splitQualifiedName = (qualifiedName: string): { schemaName: string; objectName: string } => { - const raw = String(qualifiedName || '').trim(); - if (!raw) return { schemaName: '', objectName: '' }; - const idx = raw.lastIndexOf('.'); - if (idx <= 0 || idx >= raw.length - 1) { - return { schemaName: '', objectName: raw }; - } + const parsed = splitQualifiedNameLast(qualifiedName); return { - schemaName: raw.substring(0, idx), - objectName: raw.substring(idx + 1), + schemaName: parsed.parentPath, + objectName: parsed.objectName, }; }; @@ -4785,12 +4780,7 @@ const Sidebar: React.FC<{ }; const extractObjectName = (fullName: string) => { - const raw = String(fullName || '').trim(); - const idx = raw.lastIndexOf('.'); - if (idx >= 0 && idx < raw.length - 1) { - return raw.substring(idx + 1); - } - return raw; + return splitQualifiedName(String(fullName || '').trim()).objectName || String(fullName || '').trim(); }; const handleRenameDatabase = async () => { @@ -5012,9 +5002,9 @@ const Sidebar: React.FC<{ query = `SHOW CREATE VIEW \`${viewName.replace(/`/g, '``')}\``; break; case 'postgres': case 'kingbase': case 'highgo': case 'vastbase': case 'opengauss': { - const parts = viewName.split('.'); - const schema = parts.length > 1 ? parts[0] : 'public'; - const name = parts.length > 1 ? parts[1] : viewName; + const parts = splitQualifiedName(viewName); + const schema = parts.schemaName || 'public'; + const name = parts.objectName || viewName; query = `SELECT pg_get_viewdef('${escapeSQLLiteral(schema)}.${escapeSQLLiteral(name)}'::regclass, true) AS view_definition`; break; } @@ -5133,7 +5123,10 @@ const Sidebar: React.FC<{ const conn = node.dataRef; const { tableName, dbName, id } = conn; const safeTable = String(tableName || 'table_name').trim(); - const quotedTable = safeTable.includes('`') ? safeTable : safeTable.split('.').map(part => `\`${part.replace(/`/g, '``')}\``).join('.'); + const safeTableParts = [splitQualifiedName(safeTable).schemaName, splitQualifiedName(safeTable).objectName].filter(Boolean); + const quotedTable = safeTable.includes('`') + ? safeTable + : (safeTableParts.length > 0 ? safeTableParts : [safeTable]).map(part => `\`${part.replace(/`/g, '``')}\``).join('.'); addTab({ id: `query-create-starrocks-rollup-${Date.now()}`, title: '新增 Rollup', @@ -6504,7 +6497,7 @@ const Sidebar: React.FC<{ const dbName = String(conn?.dbName || '').trim(); const tableName = String(conn?.tableName || node?.title || '').trim(); const objectName = extractObjectName(tableName); - const schemaName = String(conn?.schemaName || (tableName.includes('.') ? tableName.split('.').slice(0, -1).join('.') : '')).trim(); + const schemaName = String(conn?.schemaName || splitQualifiedName(tableName).schemaName || '').trim(); switch (dialect) { case 'mysql': case 'starrocks': @@ -8140,8 +8133,7 @@ const Sidebar: React.FC<{ const rawTableName = String(node?.dataRef?.tableName || node?.dataRef?.viewName || node?.dataRef?.eventName || '').trim(); const conn = node?.dataRef as SavedConnection | undefined; if (rawTableName && shouldHideSchemaPrefix(conn)) { - const lastDotIndex = rawTableName.lastIndexOf('.'); - if (lastDotIndex > 0 && lastDotIndex < rawTableName.length - 1) { + if (splitQualifiedName(rawTableName).schemaName) { hoverTitle = rawTableName; } } diff --git a/frontend/src/components/TableDesigner.tsx b/frontend/src/components/TableDesigner.tsx index 74b7a6a..c15260f 100644 --- a/frontend/src/components/TableDesigner.tsx +++ b/frontend/src/components/TableDesigner.tsx @@ -29,6 +29,7 @@ import { resolveColumnTypeOptions, resolveSqlDialect, } from '../utils/sqlDialect'; +import { splitQualifiedNameLast, stripIdentifierQuotes } from '../utils/qualifiedName'; interface EditableColumn extends ColumnDefinition { _key: string; @@ -1390,26 +1391,11 @@ ${selectedTrigger.statement}`; const escapeDoubleQuoteIdentifier = (name: string) => String(name || '').replace(/"/g, '""'); const escapeSqlString = (value: string) => String(value || '').replace(/'/g, "''"); - const stripIdentifierQuotes = (part: string): string => { - const text = String(part || '').trim(); - if (!text) return ''; - if ((text.startsWith('`') && text.endsWith('`')) || (text.startsWith('"') && text.endsWith('"'))) { - return text.slice(1, -1).trim(); - } - if (text.startsWith('[') && text.endsWith(']')) { - return text.slice(1, -1).trim(); - } - return text; - }; - const splitQualifiedName = (qualifiedName: string): { schemaName: string; objectName: string } => { - const raw = String(qualifiedName || '').trim(); - if (!raw) return { schemaName: '', objectName: '' }; - const idx = raw.lastIndexOf('.'); - if (idx <= 0 || idx >= raw.length - 1) return { schemaName: '', objectName: raw }; + const parsed = splitQualifiedNameLast(qualifiedName); return { - schemaName: stripIdentifierQuotes(raw.substring(0, idx)), - objectName: stripIdentifierQuotes(raw.substring(idx + 1)), + schemaName: parsed.parentPath, + objectName: parsed.objectName, }; }; diff --git a/frontend/src/components/TriggerViewer.tsx b/frontend/src/components/TriggerViewer.tsx index 3029580..8bf9651 100644 --- a/frontend/src/components/TriggerViewer.tsx +++ b/frontend/src/components/TriggerViewer.tsx @@ -6,6 +6,7 @@ import { useStore } from '../store'; import { DBQuery } from '../../wailsjs/go/app/App'; import { buildRpcConnectionConfig } from '../utils/connectionRpcConfig'; import { normalizeOceanBaseProtocol } from '../utils/oceanBaseProtocol'; +import { splitQualifiedNameLast } from '../utils/qualifiedName'; interface TriggerViewerProps { tab: TabData; @@ -25,12 +26,8 @@ const TriggerViewer: React.FC = ({ tab }) => { const escapeSQLLiteral = (raw: string): string => String(raw || '').replace(/'/g, "''"); const quoteSqlServerIdentifier = (raw: string): string => `[${String(raw || '').replace(/]/g, ']]')}]`; const parseSchemaAndName = (fullName: string): { schema: string; name: string } => { - const raw = String(fullName || '').trim(); - const idx = raw.lastIndexOf('.'); - if (idx > 0 && idx < raw.length - 1) { - return { schema: raw.substring(0, idx), name: raw.substring(idx + 1) }; - } - return { schema: '', name: raw }; + const parsed = splitQualifiedNameLast(fullName); + return { schema: parsed.parentPath, name: parsed.objectName }; }; const getMetadataDialect = (conn: any): string => { diff --git a/frontend/src/components/tableDesignerSchemaSql.ts b/frontend/src/components/tableDesignerSchemaSql.ts index c362e9f..0bf690a 100644 --- a/frontend/src/components/tableDesignerSchemaSql.ts +++ b/frontend/src/components/tableDesignerSchemaSql.ts @@ -10,6 +10,7 @@ import { unquoteSqlIdentifierPart, unquoteSqlIdentifierPath, } from '../utils/sqlDialect'; +import { splitQualifiedNameLast } from '../utils/qualifiedName'; export interface EditableColumnSnapshot { _key: string; @@ -83,13 +84,10 @@ const escapeSqlString = (value: string) => String(value || '').replace(/'/g, "'' const stripIdentifierQuotes = unquoteSqlIdentifierPart; const splitQualifiedName = (qualifiedName: string): { schemaName: string; objectName: string } => { - const raw = String(qualifiedName || '').trim(); - if (!raw) return { schemaName: '', objectName: '' }; - const idx = raw.lastIndexOf('.'); - if (idx <= 0 || idx >= raw.length - 1) return { schemaName: '', objectName: raw }; + const parsed = splitQualifiedNameLast(qualifiedName); return { - schemaName: stripIdentifierQuotes(raw.substring(0, idx)), - objectName: stripIdentifierQuotes(raw.substring(idx + 1)), + schemaName: parsed.parentPath, + objectName: parsed.objectName, }; }; diff --git a/frontend/src/utils/qualifiedName.ts b/frontend/src/utils/qualifiedName.ts new file mode 100644 index 0000000..11bb01f --- /dev/null +++ b/frontend/src/utils/qualifiedName.ts @@ -0,0 +1,126 @@ +export type QualifiedNameParts = { + parentPath: string; + objectName: string; +}; + +const normalizeIdentifierEscapes = (raw: string): string => { + let value = String(raw || '').trim(); + for (let i = 0; i < 4; i += 1) { + const next = String(value || '').trim() + .replace(/\\\\"/g, '\\"') + .replace(/\\"/g, '"'); + if (next === value) break; + value = next; + } + return String(value || '').trim(); +}; + +export const stripIdentifierQuotes = (part: string): string => { + const text = normalizeIdentifierEscapes(part); + if (!text) return ''; + if (text.length >= 2) { + const first = text[0]; + const last = text[text.length - 1]; + if (first === '"' && last === '"') { + return text.slice(1, -1).replace(/""/g, '"').trim(); + } + if (first === '`' && last === '`') { + return text.slice(1, -1).replace(/``/g, '`').trim(); + } + if (first === '[' && last === ']') { + return text.slice(1, -1).replace(/]]/g, ']').trim(); + } + } + return text; +}; + +export const splitQualifiedNameSegments = (qualifiedName: string): string[] => { + const text = normalizeIdentifierEscapes(qualifiedName); + if (!text) return []; + + const segments: string[] = []; + let current = ''; + let inDouble = false; + let inBacktick = false; + let inBracket = false; + + const flush = () => { + const value = current.trim(); + current = ''; + if (!value) return; + segments.push(stripIdentifierQuotes(value)); + }; + + for (let i = 0; i < text.length; i += 1) { + const ch = text[i]; + + if (inDouble) { + current += ch; + if (ch === '"' && text[i + 1] === '"') { + current += text[i + 1]; + i += 1; + continue; + } + if (ch === '"') inDouble = false; + continue; + } + + if (inBacktick) { + current += ch; + if (ch === '`' && text[i + 1] === '`') { + current += text[i + 1]; + i += 1; + continue; + } + if (ch === '`') inBacktick = false; + continue; + } + + if (inBracket) { + current += ch; + if (ch === ']' && text[i + 1] === ']') { + current += text[i + 1]; + i += 1; + continue; + } + if (ch === ']') inBracket = false; + continue; + } + + if (ch === '"') { + inDouble = true; + current += ch; + continue; + } + if (ch === '`') { + inBacktick = true; + current += ch; + continue; + } + if (ch === '[') { + inBracket = true; + current += ch; + continue; + } + if (ch === '.') { + flush(); + continue; + } + current += ch; + } + + flush(); + return segments; +}; + +export const splitQualifiedName = (qualifiedName: string): QualifiedNameParts => { + const segments = splitQualifiedNameSegments(qualifiedName); + if (segments.length === 0) return { parentPath: '', objectName: '' }; + if (segments.length === 1) return { parentPath: '', objectName: segments[0] }; + return { + parentPath: segments.slice(0, -1).join('.'), + objectName: segments[segments.length - 1], + }; +}; + +export const splitQualifiedNameLast = splitQualifiedName; diff --git a/frontend/src/utils/sidebarMetadata.ts b/frontend/src/utils/sidebarMetadata.ts index 3bd5d8a..2950749 100644 --- a/frontend/src/utils/sidebarMetadata.ts +++ b/frontend/src/utils/sidebarMetadata.ts @@ -1,17 +1,5 @@ import { normalizeOceanBaseProtocol } from './oceanBaseProtocol'; - -const splitQualifiedName = (qualifiedName: string): { schemaName: string; objectName: string } => { - const raw = String(qualifiedName || '').trim(); - if (!raw) return { schemaName: '', objectName: '' }; - const idx = raw.lastIndexOf('.'); - if (idx <= 0 || idx >= raw.length - 1) { - return { schemaName: '', objectName: raw }; - } - return { - schemaName: raw.substring(0, idx), - objectName: raw.substring(idx + 1), - }; -}; +import { splitQualifiedNameLast } from './qualifiedName'; const normalizeSidebarConnectionDialect = (type: string, driver: string, oceanBaseProtocol?: string): string => { const normalizedType = String(type || '').trim().toLowerCase(); @@ -45,7 +33,7 @@ export const normalizeSidebarViewName = (dialect: string, dbName: string, schema } if (normalizedDialect === 'mysql') { - const parsed = splitQualifiedName(normalizedViewName); + const parsed = splitQualifiedNameLast(normalizedViewName); if (parsed.objectName) { return parsed.objectName; } diff --git a/frontend/src/utils/sql.test.ts b/frontend/src/utils/sql.test.ts index 7f06dff..e435e93 100644 --- a/frontend/src/utils/sql.test.ts +++ b/frontend/src/utils/sql.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it } from 'vitest'; -import { buildOrderBySQL, buildPaginatedSelectSQL, reverseOrderBySQL } from './sql'; +import { buildOrderBySQL, buildPaginatedSelectSQL, quoteQualifiedIdent, reverseOrderBySQL } from './sql'; describe('buildOrderBySQL', () => { it('does not add fallback ORDER BY for DuckDB without explicit sort', () => { @@ -52,3 +52,15 @@ describe('reverseOrderBySQL', () => { .toBe(' ORDER BY COALESCE([a], [b]) DESC, [id] ASC'); }); }); + +describe('quoteQualifiedIdent', () => { + it('does not split dots inside quoted DuckDB identifiers', () => { + expect(quoteQualifiedIdent('duckdb', '"daily.events"."2026.06"')) + .toBe('"daily.events"."2026.06"'); + }); + + it('preserves three-part DuckDB names with quoted dots', () => { + expect(quoteQualifiedIdent('duckdb', '"analytics.catalog"."main.schema"."daily.events"')) + .toBe('"analytics.catalog"."main.schema"."daily.events"'); + }); +}); diff --git a/frontend/src/utils/sql.ts b/frontend/src/utils/sql.ts index d55f80d..4abfd59 100644 --- a/frontend/src/utils/sql.ts +++ b/frontend/src/utils/sql.ts @@ -1,3 +1,5 @@ +import { splitQualifiedNameSegments, stripIdentifierQuotes } from './qualifiedName'; + export type FilterCondition = { id?: number; enabled?: boolean; @@ -8,17 +10,7 @@ export type FilterCondition = { value2?: string; }; -const normalizeIdentPart = (ident: string) => { - let raw = (ident || '').trim(); - if (!raw) return raw; - const first = raw[0]; - const last = raw[raw.length - 1]; - if ((first === '"' && last === '"') || (first === '`' && last === '`')) { - raw = raw.slice(1, -1).trim(); - } - raw = raw.replace(/["`]/g, '').trim(); - return raw; -}; +const normalizeIdentPart = (ident: string) => stripIdentifierQuotes(ident); // 检查标识符是否需要引号(包含特殊字符或是保留字) const needsQuote = (ident: string): boolean => { @@ -62,9 +54,10 @@ export const quoteIdentPart = (dbType: string, ident: string) => { export const quoteQualifiedIdent = (dbType: string, ident: string) => { const raw = (ident || '').trim(); if (!raw) return raw; - const parts = raw.split('.').map(normalizeIdentPart).filter(Boolean); - if (parts.length <= 1) return quoteIdentPart(dbType, raw); - return parts.map(p => quoteIdentPart(dbType, p)).join('.'); + const parts = splitQualifiedNameSegments(raw).filter(Boolean); + if (parts.length === 0) return quoteIdentPart(dbType, raw); + if (parts.length === 1 && parts[0] === normalizeIdentPart(raw)) return quoteIdentPart(dbType, raw); + return parts.map((part) => quoteIdentPart(dbType, part)).join('.'); }; export const escapeLiteral = (val: string) => (val || '').replace(/'/g, "''"); diff --git a/internal/app/db_context.go b/internal/app/db_context.go index 3303e34..f0fd243 100644 --- a/internal/app/db_context.go +++ b/internal/app/db_context.go @@ -58,6 +58,10 @@ func normalizeSchemaAndTable(config connection.ConnectionConfig, dbName string, return targetDB, rawTable } + if dbType == "duckdb" { + return rawDB, rawTable + } + if dbType == "kingbase" { schema, table := db.SplitKingbaseQualifiedName(rawTable) if schema != "" && table != "" { diff --git a/internal/app/db_context_test.go b/internal/app/db_context_test.go index ee02ea7..48a6d19 100644 --- a/internal/app/db_context_test.go +++ b/internal/app/db_context_test.go @@ -140,6 +140,21 @@ func TestNormalizeSchemaAndTable_OceanBaseOracleUsesSchemaFromDatabaseTree(t *te } } +func TestNormalizeSchemaAndTable_DuckDBPreservesQuotedQualifiedName(t *testing.T) { + t.Parallel() + + schemaOrDb, table := normalizeSchemaAndTable(connection.ConnectionConfig{ + Type: "duckdb", + }, `"analytics.catalog"."main.schema"`, `"daily.events"."2026.06"`) + + if schemaOrDb != `"analytics.catalog"."main.schema"` { + t.Fatalf("expected duckdb dbName/catalog path preserved, got %q", schemaOrDb) + } + if table != `"daily.events"."2026.06"` { + t.Fatalf("expected duckdb qualified table preserved, got %q", table) + } +} + func TestQuoteTableIdentByType_KingbaseNormalizesQuotedQualifiedTable(t *testing.T) { t.Parallel() diff --git a/internal/app/methods_db.go b/internal/app/methods_db.go index 2519d9c..1cf99b5 100644 --- a/internal/app/methods_db.go +++ b/internal/app/methods_db.go @@ -320,6 +320,10 @@ func normalizeSchemaAndTableByType(dbType string, dbName string, tableName strin } } + if dbType == "duckdb" { + return rawDB, rawTable + } + if parts := strings.SplitN(rawTable, ".", 2); len(parts) == 2 { schema := strings.TrimSpace(parts[0]) table := strings.TrimSpace(parts[1]) diff --git a/internal/db/duckdb_impl.go b/internal/db/duckdb_impl.go index e0e36b9..82d6e6b 100644 --- a/internal/db/duckdb_impl.go +++ b/internal/db/duckdb_impl.go @@ -160,12 +160,22 @@ func (d *DuckDB) GetDatabases() ([]string, error) { } func (d *DuckDB) GetTables(dbName string) ([]string, error) { + path := normalizeDuckDBObjectPath(dbName, "") query := ` -SELECT table_schema, table_name +SELECT table_catalog, table_schema, table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('information_schema', 'pg_catalog') -ORDER BY table_schema, table_name` +ORDER BY table_catalog, table_schema, table_name` + if path.Catalog != "" { + query = fmt.Sprintf(` +SELECT table_catalog, table_schema, table_name +FROM information_schema.tables +WHERE table_type = 'BASE TABLE' + AND table_schema NOT IN ('information_schema', 'pg_catalog') + AND table_catalog = '%s' +ORDER BY table_catalog, table_schema, table_name`, escapeDuckDBLiteral(path.Catalog)) + } data, _, err := d.Query(query) if err != nil { @@ -175,15 +185,19 @@ ORDER BY table_schema, table_name` seen := map[string]struct{}{} var tables []string for _, row := range data { + catalog := strings.TrimSpace(duckDBRowString(row, "table_catalog", "database_name")) schema := strings.TrimSpace(duckDBRowString(row, "table_schema")) name := strings.TrimSpace(duckDBRowString(row, "table_name")) if name == "" { continue } qualified := name - if schema != "" && !strings.EqualFold(schema, "main") { + if schema != "" { qualified = schema + "." + name } + if catalog != "" && !strings.EqualFold(catalog, "memory") && !strings.EqualFold(catalog, "main") { + qualified = catalog + "." + qualified + } if _, exists := seen[qualified]; exists { continue } @@ -194,18 +208,29 @@ ORDER BY table_schema, table_name` } func (d *DuckDB) GetCreateStatement(dbName, tableName string) (string, error) { - schema, pureTable := normalizeDuckDBSchemaAndTable(dbName, tableName) - if pureTable == "" { + path := normalizeDuckDBObjectPath(dbName, tableName) + if path.Object == "" { return "", fmt.Errorf("表名不能为空") } - escapedTable := escapeDuckDBLiteral(pureTable) - escapedSchema := escapeDuckDBLiteral(schema) + escapedTable := escapeDuckDBLiteral(path.Object) + escapedSchema := escapeDuckDBLiteral(path.Schema) + escapedCatalog := escapeDuckDBLiteral(path.Catalog) - queryCandidates := []string{ + queryCandidates := make([]string, 0, 4) + if path.Catalog != "" { + queryCandidates = append(queryCandidates, fmt.Sprintf("SELECT sql FROM duckdb_tables() WHERE table_name = '%s' AND schema_name = '%s' AND database_name = '%s' LIMIT 1", escapedTable, escapedSchema, escapedCatalog)) + } + queryCandidates = append(queryCandidates, fmt.Sprintf("SELECT sql FROM duckdb_tables() WHERE table_name = '%s' AND schema_name = '%s' LIMIT 1", escapedTable, escapedSchema), fmt.Sprintf("SELECT sql FROM duckdb_tables() WHERE table_name = '%s' LIMIT 1", escapedTable), - fmt.Sprintf("SHOW CREATE TABLE %s", quoteDuckDBQualifiedTable(schema, pureTable)), + fmt.Sprintf("SHOW CREATE TABLE %s", quoteDuckDBQualifiedTable(path.Schema, path.Object)), + ) + + if path.Catalog != "" { + queryCandidates = append([]string{ + fmt.Sprintf("SHOW CREATE TABLE %s.%s", quoteDuckDBIdentifier(path.Catalog), quoteDuckDBQualifiedTable(path.Schema, path.Object)), + }, queryCandidates...) } for _, query := range queryCandidates { @@ -230,8 +255,8 @@ func (d *DuckDB) GetCreateStatement(dbName, tableName string) (string, error) { } func (d *DuckDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { - schema, pureTable := normalizeDuckDBSchemaAndTable(dbName, tableName) - if pureTable == "" { + path := normalizeDuckDBObjectPath(dbName, tableName) + if path.Object == "" { return nil, fmt.Errorf("表名不能为空") } @@ -239,52 +264,62 @@ func (d *DuckDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefini SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_name = '%s' AND table_schema = '%s' -ORDER BY ordinal_position`, escapeDuckDBLiteral(pureTable), escapeDuckDBLiteral(schema)) +ORDER BY ordinal_position`, escapeDuckDBLiteral(path.Object), escapeDuckDBLiteral(path.Schema)) + if path.Catalog != "" { + query = fmt.Sprintf(` +SELECT column_name, data_type, is_nullable, column_default +FROM information_schema.columns +WHERE table_name = '%s' AND table_schema = '%s' AND table_catalog = '%s' +ORDER BY ordinal_position`, escapeDuckDBLiteral(path.Object), escapeDuckDBLiteral(path.Schema), escapeDuckDBLiteral(path.Catalog)) + } data, _, err := d.Query(query) if err != nil { return nil, err } - if len(data) == 0 && schema != "main" { + if len(data) == 0 && path.Schema != "main" { fallbackQuery := fmt.Sprintf(` SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_name = '%s' -ORDER BY ordinal_position`, escapeDuckDBLiteral(pureTable)) +ORDER BY ordinal_position`, escapeDuckDBLiteral(path.Object)) data, _, err = d.Query(fallbackQuery) if err != nil { return nil, err } } - var columns []connection.ColumnDefinition - for _, row := range data { - column := connection.ColumnDefinition{ - Name: duckDBRowString(row, "column_name"), - Type: duckDBRowString(row, "data_type"), - Nullable: strings.ToUpper(strings.TrimSpace(duckDBRowString(row, "is_nullable"))), - Key: "", - Extra: "", - Comment: "", - } - if column.Nullable == "" { - column.Nullable = "YES" - } - if defaultVal := strings.TrimSpace(duckDBRowString(row, "column_default")); defaultVal != "" && defaultVal != "" { - def := defaultVal - column.Default = &def - } - columns = append(columns, column) + constraintQuery := buildDuckDBConstraintMetadataQuery(path, true) + constraintRows, _, constraintErr := d.Query(constraintQuery) + if constraintErr != nil { + return nil, constraintErr } - return columns, nil + if len(constraintRows) == 0 && path.Schema != "main" { + fallbackConstraintQuery := buildDuckDBConstraintMetadataQuery(path, false) + constraintRows, _, constraintErr = d.Query(fallbackConstraintQuery) + if constraintErr != nil { + return nil, constraintErr + } + } + + return buildDuckDBColumnDefinitions(data, constraintRows), nil } func (d *DuckDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + path := normalizeDuckDBObjectPath(dbName, "") query := ` -SELECT table_schema, table_name, column_name, data_type +SELECT table_catalog, table_schema, table_name, column_name, data_type FROM information_schema.columns WHERE table_schema NOT IN ('information_schema', 'pg_catalog') -ORDER BY table_schema, table_name, ordinal_position` +ORDER BY table_catalog, table_schema, table_name, ordinal_position` + if path.Catalog != "" { + query = fmt.Sprintf(` +SELECT table_catalog, table_schema, table_name, column_name, data_type +FROM information_schema.columns +WHERE table_schema NOT IN ('information_schema', 'pg_catalog') + AND table_catalog = '%s' +ORDER BY table_catalog, table_schema, table_name, ordinal_position`, escapeDuckDBLiteral(path.Catalog)) + } data, _, err := d.Query(query) if err != nil { @@ -293,14 +328,18 @@ ORDER BY table_schema, table_name, ordinal_position` columns := make([]connection.ColumnDefinitionWithTable, 0, len(data)) for _, row := range data { + catalog := strings.TrimSpace(duckDBRowString(row, "table_catalog", "database_name")) schema := strings.TrimSpace(duckDBRowString(row, "table_schema")) tableName := strings.TrimSpace(duckDBRowString(row, "table_name")) if tableName == "" { continue } - if schema != "" && !strings.EqualFold(schema, "main") { + if schema != "" { tableName = schema + "." + tableName } + if catalog != "" && !strings.EqualFold(catalog, "memory") && !strings.EqualFold(catalog, "main") { + tableName = catalog + "." + tableName + } columns = append(columns, connection.ColumnDefinitionWithTable{ TableName: tableName, @@ -312,7 +351,38 @@ ORDER BY table_schema, table_name, ordinal_position` } func (d *DuckDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { - return []connection.IndexDefinition{}, nil + path := normalizeDuckDBObjectPath(dbName, tableName) + if path.Object == "" { + return nil, fmt.Errorf("表名不能为空") + } + + constraintQuery := buildDuckDBConstraintMetadataQuery(path, true) + constraintRows, _, err := d.Query(constraintQuery) + if err != nil { + return nil, err + } + if len(constraintRows) == 0 && path.Schema != "main" { + fallbackQuery := buildDuckDBConstraintMetadataQuery(path, false) + constraintRows, _, err = d.Query(fallbackQuery) + if err != nil { + return nil, err + } + } + + indexQuery := buildDuckDBIndexMetadataQuery(path, true) + indexRows, _, indexErr := d.Query(indexQuery) + if indexErr != nil { + return nil, indexErr + } + if len(indexRows) == 0 && path.Schema != "main" { + fallbackIndexQuery := buildDuckDBIndexMetadataQuery(path, false) + indexRows, _, indexErr = d.Query(fallbackIndexQuery) + if indexErr != nil { + return nil, indexErr + } + } + + return buildDuckDBIndexDefinitions(constraintRows, indexRows), nil } func (d *DuckDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { @@ -344,12 +414,9 @@ func (d *DuckDB) ApplyChanges(tableName string, changes connection.ChangeSet) er return `"` + n + `"` } - schema := "" - table := strings.TrimSpace(tableName) - if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { - schema = strings.TrimSpace(parts[0]) - table = strings.TrimSpace(parts[1]) - } + path := normalizeDuckDBObjectPath("", tableName) + schema := path.Schema + table := path.Object qualifiedTable := quoteIdent(table) if schema != "" { @@ -413,69 +480,3 @@ func (d *DuckDB) ApplyChanges(tableName string, changes connection.ChangeSet) er return tx.Commit() } - -func normalizeDuckDBSchemaAndTable(dbName string, tableName string) (string, string) { - schema := strings.TrimSpace(dbName) - table := strings.TrimSpace(tableName) - if table == "" { - if schema == "" { - schema = "main" - } - return schema, table - } - - if parts := strings.SplitN(table, ".", 2); len(parts) == 2 { - left := strings.TrimSpace(parts[0]) - right := strings.TrimSpace(parts[1]) - if left != "" && right != "" { - return normalizeDuckDBIdentifier(left), normalizeDuckDBIdentifier(right) - } - } - - if schema == "" { - schema = "main" - } - return normalizeDuckDBIdentifier(schema), normalizeDuckDBIdentifier(table) -} - -func normalizeDuckDBIdentifier(raw string) string { - text := strings.TrimSpace(raw) - if len(text) >= 2 { - first := text[0] - last := text[len(text)-1] - if (first == '"' && last == '"') || (first == '`' && last == '`') { - text = strings.TrimSpace(text[1 : len(text)-1]) - } - } - return text -} - -func quoteDuckDBIdentifier(raw string) string { - text := normalizeDuckDBIdentifier(raw) - return `"` + strings.ReplaceAll(text, `"`, `""`) + `"` -} - -func quoteDuckDBQualifiedTable(schema string, table string) string { - s := strings.TrimSpace(schema) - t := strings.TrimSpace(table) - if s == "" { - return quoteDuckDBIdentifier(t) - } - return quoteDuckDBIdentifier(s) + "." + quoteDuckDBIdentifier(t) -} - -func duckDBRowString(row map[string]interface{}, keys ...string) string { - for _, key := range keys { - for rowKey, value := range row { - if !strings.EqualFold(rowKey, key) || value == nil { - continue - } - return fmt.Sprintf("%v", value) - } - } - return "" -} - -func escapeDuckDBLiteral(raw string) string { - return strings.ReplaceAll(raw, "'", "''") -} diff --git a/internal/db/duckdb_metadata.go b/internal/db/duckdb_metadata.go new file mode 100644 index 0000000..2747723 --- /dev/null +++ b/internal/db/duckdb_metadata.go @@ -0,0 +1,483 @@ +package db + +import ( + "fmt" + "strings" + + "GoNavi-Wails/internal/connection" +) + +type duckDBObjectPath struct { + Catalog string + Schema string + Object string +} + +func buildDuckDBConstraintMetadataQuery(path duckDBObjectPath, exact bool) string { + base := ` +SELECT + database_name, + schema_name, + table_name, + constraint_name, + constraint_type, + constraint_column_names +FROM duckdb_constraints() +WHERE table_name = '%s' + AND constraint_type IN ('PRIMARY KEY', 'UNIQUE')` + args := []any{escapeDuckDBLiteral(path.Object)} + if exact && path.Schema != "" { + base += "\n AND schema_name = '%s'" + args = append(args, escapeDuckDBLiteral(path.Schema)) + } + if exact && path.Catalog != "" { + base += "\n AND database_name = '%s'" + args = append(args, escapeDuckDBLiteral(path.Catalog)) + } + base += "\nORDER BY database_name, schema_name, table_name, constraint_type, constraint_name" + return fmt.Sprintf(base, args...) +} + +func buildDuckDBIndexMetadataQuery(path duckDBObjectPath, exact bool) string { + base := ` +SELECT + database_name, + schema_name, + table_name, + index_name, + is_unique, + expressions +FROM duckdb_indexes() +WHERE table_name = '%s'` + args := []any{escapeDuckDBLiteral(path.Object)} + if exact && path.Schema != "" { + base += "\n AND schema_name = '%s'" + args = append(args, escapeDuckDBLiteral(path.Schema)) + } + if exact && path.Catalog != "" { + base += "\n AND database_name = '%s'" + args = append(args, escapeDuckDBLiteral(path.Catalog)) + } + base += "\nORDER BY database_name, schema_name, table_name, index_name" + return fmt.Sprintf(base, args...) +} + +func buildDuckDBColumnDefinitions(rows []map[string]interface{}, constraintRows []map[string]interface{}) []connection.ColumnDefinition { + primaryKeyColumns := make(map[string]struct{}) + uniqueColumns := make(map[string]struct{}) + + for _, row := range constraintRows { + columnNames := parseDuckDBIdentifierList(duckDBRowString(row, "constraint_column_names")) + switch strings.ToUpper(strings.TrimSpace(duckDBRowString(row, "constraint_type"))) { + case "PRIMARY KEY": + for _, columnName := range columnNames { + primaryKeyColumns[strings.ToLower(columnName)] = struct{}{} + } + case "UNIQUE": + for _, columnName := range columnNames { + uniqueColumns[strings.ToLower(columnName)] = struct{}{} + } + } + } + + columns := make([]connection.ColumnDefinition, 0, len(rows)) + for _, row := range rows { + columnName := strings.TrimSpace(duckDBRowString(row, "column_name")) + column := connection.ColumnDefinition{ + Name: columnName, + Type: duckDBRowString(row, "data_type"), + Nullable: strings.ToUpper(strings.TrimSpace(duckDBRowString(row, "is_nullable"))), + Key: "", + Extra: "", + Comment: "", + } + if column.Nullable == "" { + column.Nullable = "YES" + } + if _, ok := primaryKeyColumns[strings.ToLower(columnName)]; ok { + column.Key = "PRI" + } else if _, ok := uniqueColumns[strings.ToLower(columnName)]; ok { + column.Key = "UNI" + } + if defaultVal := strings.TrimSpace(duckDBRowString(row, "column_default")); defaultVal != "" && defaultVal != "" { + def := defaultVal + column.Default = &def + } + columns = append(columns, column) + } + + return columns +} + +func buildDuckDBIndexDefinitions(constraintRows []map[string]interface{}, indexRows []map[string]interface{}) []connection.IndexDefinition { + indexes := make([]connection.IndexDefinition, 0, len(constraintRows)+len(indexRows)) + + for _, row := range constraintRows { + name := strings.TrimSpace(duckDBRowString(row, "constraint_name")) + constraintType := strings.ToUpper(strings.TrimSpace(duckDBRowString(row, "constraint_type"))) + columnNames := parseDuckDBIdentifierList(duckDBRowString(row, "constraint_column_names")) + if name == "" || len(columnNames) == 0 { + continue + } + for idx, columnName := range columnNames { + indexes = append(indexes, connection.IndexDefinition{ + Name: name, + ColumnName: columnName, + NonUnique: 0, + SeqInIndex: idx + 1, + IndexType: constraintType, + }) + } + } + + for _, row := range indexRows { + name := strings.TrimSpace(duckDBRowString(row, "index_name")) + columnNames := parseDuckDBExpressionList(duckDBRowString(row, "expressions")) + if name == "" || len(columnNames) == 0 { + continue + } + nonUnique := 1 + if duckDBRowBool(row, "is_unique") { + nonUnique = 0 + } + for idx, columnName := range columnNames { + indexes = append(indexes, connection.IndexDefinition{ + Name: name, + ColumnName: columnName, + NonUnique: nonUnique, + SeqInIndex: idx + 1, + IndexType: "INDEX", + }) + } + } + + return indexes +} + +func normalizeDuckDBObjectPath(dbName string, tableName string) duckDBObjectPath { + rawDB := strings.TrimSpace(dbName) + rawTable := strings.TrimSpace(tableName) + if rawTable == "" { + if rawDB == "" { + return duckDBObjectPath{Schema: "main"} + } + dbParts := splitDuckDBQualifiedName(rawDB) + switch len(dbParts) { + case 0: + return duckDBObjectPath{Schema: "main"} + case 1: + return duckDBObjectPath{Catalog: normalizeDuckDBIdentifier(dbParts[0])} + default: + return duckDBObjectPath{ + Catalog: normalizeDuckDBIdentifier(dbParts[0]), + Schema: normalizeDuckDBIdentifier(dbParts[len(dbParts)-1]), + } + } + } + + parts := splitDuckDBQualifiedName(rawTable) + switch len(parts) { + case 0: + return duckDBObjectPath{Schema: "main"} + case 1: + schema := "main" + if rawDB != "" { + dbParts := splitDuckDBQualifiedName(rawDB) + if len(dbParts) >= 2 { + return duckDBObjectPath{ + Catalog: normalizeDuckDBIdentifier(dbParts[0]), + Schema: normalizeDuckDBIdentifier(dbParts[len(dbParts)-1]), + Object: normalizeDuckDBIdentifier(parts[0]), + } + } + schema = normalizeDuckDBIdentifier(rawDB) + } + return duckDBObjectPath{ + Schema: schema, + Object: normalizeDuckDBIdentifier(parts[0]), + } + case 2: + if rawDB != "" { + dbParts := splitDuckDBQualifiedName(rawDB) + if len(dbParts) == 1 { + return duckDBObjectPath{ + Catalog: normalizeDuckDBIdentifier(dbParts[0]), + Schema: normalizeDuckDBIdentifier(parts[0]), + Object: normalizeDuckDBIdentifier(parts[1]), + } + } + if len(dbParts) >= 2 { + return duckDBObjectPath{ + Catalog: normalizeDuckDBIdentifier(dbParts[0]), + Schema: normalizeDuckDBIdentifier(parts[0]), + Object: normalizeDuckDBIdentifier(parts[1]), + } + } + } + return duckDBObjectPath{ + Schema: normalizeDuckDBIdentifier(parts[0]), + Object: normalizeDuckDBIdentifier(parts[1]), + } + default: + return duckDBObjectPath{ + Catalog: normalizeDuckDBIdentifier(parts[len(parts)-3]), + Schema: normalizeDuckDBIdentifier(parts[len(parts)-2]), + Object: normalizeDuckDBIdentifier(parts[len(parts)-1]), + } + } +} + +func normalizeDuckDBSchemaAndTable(dbName string, tableName string) (string, string) { + path := normalizeDuckDBObjectPath(dbName, tableName) + schema := path.Schema + if schema == "" { + schema = "main" + } + return schema, path.Object +} + +func normalizeDuckDBIdentifier(raw string) string { + text := strings.TrimSpace(normalizeSQLIdentifierEscapes(raw)) + if len(text) >= 2 { + first := text[0] + last := text[len(text)-1] + if (first == '"' && last == '"') || (first == '`' && last == '`') { + text = strings.TrimSpace(text[1 : len(text)-1]) + } + } + return text +} + +func quoteDuckDBIdentifier(raw string) string { + text := normalizeDuckDBIdentifier(raw) + return `"` + strings.ReplaceAll(text, `"`, `""`) + `"` +} + +func quoteDuckDBQualifiedTable(schema string, table string) string { + s := strings.TrimSpace(schema) + t := strings.TrimSpace(table) + if s == "" { + return quoteDuckDBIdentifier(t) + } + return quoteDuckDBIdentifier(s) + "." + quoteDuckDBIdentifier(t) +} + +func duckDBRowString(row map[string]interface{}, keys ...string) string { + for _, key := range keys { + for rowKey, value := range row { + if !strings.EqualFold(rowKey, key) || value == nil { + continue + } + return fmt.Sprintf("%v", value) + } + } + return "" +} + +func duckDBRowBool(row map[string]interface{}, keys ...string) bool { + value := strings.TrimSpace(strings.ToLower(duckDBRowString(row, keys...))) + return value == "true" || value == "1" || value == "yes" +} + +func duckDBRowInt(row map[string]interface{}, keys ...string) int { + raw := strings.TrimSpace(duckDBRowString(row, keys...)) + if raw == "" { + return 0 + } + var value int + _, _ = fmt.Sscanf(raw, "%d", &value) + return value +} + +func parseDuckDBIdentifierList(raw string) []string { + return parseDuckDBList(raw, true) +} + +func parseDuckDBExpressionList(raw string) []string { + values := parseDuckDBList(raw, false) + if len(values) == 0 { + return values + } + normalized := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + switch { + case trimmed == "": + continue + case isDuckDBSimpleIdentifierExpression(trimmed): + normalized = append(normalized, normalizeDuckDBIdentifier(trimmed)) + default: + normalized = append(normalized, trimmed) + } + } + return normalized +} + +func parseDuckDBList(raw string, normalize bool) []string { + text := strings.TrimSpace(normalizeSQLIdentifierEscapes(raw)) + if text == "" { + return nil + } + if strings.HasPrefix(text, "[") && strings.HasSuffix(text, "]") { + text = text[1 : len(text)-1] + } + values := make([]string, 0) + var current strings.Builder + inDouble := false + inBacktick := false + depth := 0 + + flush := func() { + value := strings.TrimSpace(current.String()) + current.Reset() + if value == "" { + return + } + if normalize { + value = normalizeDuckDBIdentifier(value) + } + values = append(values, value) + } + + for i := 0; i < len(text); i++ { + ch := text[i] + switch ch { + case '"': + current.WriteByte(ch) + if inDouble && i+1 < len(text) && text[i+1] == '"' { + current.WriteByte(text[i+1]) + i++ + continue + } + if !inBacktick { + inDouble = !inDouble + } + case '`': + current.WriteByte(ch) + if inBacktick && i+1 < len(text) && text[i+1] == '`' { + current.WriteByte(text[i+1]) + i++ + continue + } + if !inDouble { + inBacktick = !inBacktick + } + case '(': + current.WriteByte(ch) + if !inDouble && !inBacktick { + depth++ + } + case ')': + current.WriteByte(ch) + if !inDouble && !inBacktick && depth > 0 { + depth-- + } + case ',': + if !inDouble && !inBacktick && depth == 0 { + flush() + continue + } + current.WriteByte(ch) + default: + current.WriteByte(ch) + } + } + flush() + + return values +} + +func splitDuckDBQualifiedName(raw string) []string { + text := strings.TrimSpace(normalizeSQLIdentifierEscapes(raw)) + if text == "" { + return nil + } + parts := make([]string, 0, 3) + var current strings.Builder + inDouble := false + inBacktick := false + inBracket := false + + flush := func() { + value := strings.TrimSpace(current.String()) + current.Reset() + if value == "" { + return + } + parts = append(parts, value) + } + + for i := 0; i < len(text); i++ { + ch := text[i] + if inDouble { + current.WriteByte(ch) + if ch == '"' && i+1 < len(text) && text[i+1] == '"' { + current.WriteByte(text[i+1]) + i++ + continue + } + if ch == '"' { + inDouble = false + } + continue + } + if inBacktick { + current.WriteByte(ch) + if ch == '`' && i+1 < len(text) && text[i+1] == '`' { + current.WriteByte(text[i+1]) + i++ + continue + } + if ch == '`' { + inBacktick = false + } + continue + } + if inBracket { + current.WriteByte(ch) + if ch == ']' && i+1 < len(text) && text[i+1] == ']' { + current.WriteByte(text[i+1]) + i++ + continue + } + if ch == ']' { + inBracket = false + } + continue + } + + switch ch { + case '"': + inDouble = true + current.WriteByte(ch) + case '`': + inBacktick = true + current.WriteByte(ch) + case '[': + inBracket = true + current.WriteByte(ch) + case '.': + flush() + default: + current.WriteByte(ch) + } + } + flush() + + return parts +} + +func isDuckDBSimpleIdentifierExpression(raw string) bool { + text := strings.TrimSpace(raw) + if text == "" { + return false + } + if strings.ContainsAny(text, "() +-/*%") { + return false + } + return true +} + +func escapeDuckDBLiteral(raw string) string { + return strings.ReplaceAll(raw, "'", "''") +} diff --git a/internal/db/duckdb_metadata_test.go b/internal/db/duckdb_metadata_test.go new file mode 100644 index 0000000..3644837 --- /dev/null +++ b/internal/db/duckdb_metadata_test.go @@ -0,0 +1,165 @@ +package db + +import ( + "strings" + "testing" +) + +func TestBuildDuckDBConstraintMetadataQuery_UsesDuckDBConstraints(t *testing.T) { + t.Parallel() + + query := buildDuckDBConstraintMetadataQuery(duckDBObjectPath{ + Catalog: "analytics", + Schema: "main", + Object: "events", + }, true) + + if !containsAll(query, + "FROM duckdb_constraints()", + "constraint_type IN ('PRIMARY KEY', 'UNIQUE')", + "database_name = 'analytics'", + "schema_name = 'main'", + "table_name = 'events'", + ) { + t.Fatalf("DuckDB 约束查询未正确包含 catalog/schema/table 过滤: %s", query) + } +} + +func TestBuildDuckDBIndexMetadataQuery_UsesDuckDBIndexes(t *testing.T) { + t.Parallel() + + query := buildDuckDBIndexMetadataQuery(duckDBObjectPath{ + Catalog: "analytics", + Schema: "main", + Object: "events", + }, true) + + if !containsAll(query, + "FROM duckdb_indexes()", + "database_name = 'analytics'", + "schema_name = 'main'", + "table_name = 'events'", + ) { + t.Fatalf("DuckDB 索引查询未正确包含 catalog/schema/table 过滤: %s", query) + } +} + +func TestBuildDuckDBColumnDefinitions_MarksPrimaryAndUniqueColumns(t *testing.T) { + t.Parallel() + + columns := buildDuckDBColumnDefinitions( + []map[string]interface{}{ + { + "column_name": "id", + "data_type": "BIGINT", + "is_nullable": "NO", + "column_default": nil, + }, + { + "column_name": "email", + "data_type": "VARCHAR", + "is_nullable": "YES", + "column_default": "'guest@example.com'", + }, + }, + []map[string]interface{}{ + { + "constraint_name": "events_pkey", + "constraint_type": "PRIMARY KEY", + "constraint_column_names": "[id]", + }, + { + "constraint_name": "events_email_key", + "constraint_type": "UNIQUE", + "constraint_column_names": "[email]", + }, + }, + ) + + if len(columns) != 2 { + t.Fatalf("unexpected column count: %d", len(columns)) + } + if columns[0].Name != "id" || columns[0].Key != "PRI" { + t.Fatalf("主键列未正确标记: %+v", columns[0]) + } + if columns[1].Name != "email" || columns[1].Key != "UNI" { + t.Fatalf("唯一键列未正确标记: %+v", columns[1]) + } + if columns[1].Default == nil || *columns[1].Default != "'guest@example.com'" { + t.Fatalf("默认值未保留: %+v", columns[1]) + } +} + +func TestBuildDuckDBIndexDefinitions_MergesConstraintsAndUniqueIndexes(t *testing.T) { + t.Parallel() + + indexes := buildDuckDBIndexDefinitions( + []map[string]interface{}{ + { + "constraint_name": "events_pkey", + "constraint_type": "PRIMARY KEY", + "constraint_column_names": "[id]", + }, + { + "constraint_name": "events_business_key", + "constraint_type": "UNIQUE", + "constraint_column_names": "[email, region]", + }, + }, + []map[string]interface{}{ + { + "index_name": "idx_events_slug", + "is_unique": true, + "expressions": "[slug]", + }, + }, + ) + + if len(indexes) != 4 { + t.Fatalf("unexpected index row count: %d", len(indexes)) + } + if indexes[0].Name != "events_pkey" || indexes[0].ColumnName != "id" || indexes[0].NonUnique != 0 { + t.Fatalf("主键索引映射异常: %+v", indexes[0]) + } + if indexes[1].Name != "events_business_key" || indexes[1].ColumnName != "email" || indexes[1].SeqInIndex != 1 { + t.Fatalf("约束唯一索引首列映射异常: %+v", indexes[1]) + } + if indexes[2].Name != "events_business_key" || indexes[2].ColumnName != "region" || indexes[2].SeqInIndex != 2 { + t.Fatalf("约束唯一索引次列映射异常: %+v", indexes[2]) + } + if indexes[3].Name != "idx_events_slug" || indexes[3].ColumnName != "slug" || indexes[3].NonUnique != 0 || indexes[3].IndexType != "INDEX" { + t.Fatalf("显式唯一索引映射异常: %+v", indexes[3]) + } +} + +func TestNormalizeDuckDBObjectPath_PreservesCatalogSchemaAndQuotedDots(t *testing.T) { + t.Parallel() + + path := normalizeDuckDBObjectPath(`"analytics.catalog"."main.schema"`, `"daily.events"."2026.06"`) + if path.Catalog != "analytics.catalog" || path.Schema != "daily.events" || path.Object != "2026.06" { + t.Fatalf("unexpected duckdb path: %+v", path) + } + + qualified := normalizeDuckDBObjectPath(`analytics`, `"main.schema"."daily.events"`) + if qualified.Catalog != "analytics" || qualified.Schema != "main.schema" || qualified.Object != "daily.events" { + t.Fatalf("unexpected duckdb qualified path without catalog: %+v", qualified) + } +} + +func TestParseDuckDBExpressionList_KeepsQuotedExpressionsIntact(t *testing.T) { + t.Parallel() + + parts := parseDuckDBExpressionList(`["slug", lower("name.with.dot")]`) + if len(parts) != 2 || parts[0] != `slug` || parts[1] != `lower("name.with.dot")` { + t.Fatalf("unexpected expression list: %#v", parts) + } +} + +func containsAll(source string, needles ...string) bool { + for _, needle := range needles { + if !strings.Contains(source, needle) { + return false + } + } + return true +}