From c805b16fcd5e188665afacb18276e19779635ded Mon Sep 17 00:00:00 2001 From: Syngnat Date: Sat, 13 Jun 2026 17:03:20 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(qdrant):=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=20Qdrant=20=E5=90=91=E9=87=8F=E5=BA=93=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 后端新增 Qdrant REST 连接、collection 元数据、scroll/search 查询与 upsert/delete/payload 更新 - 前端新增 Qdrant 类型、连接配置、图标、方言和能力矩阵 - 测试覆盖 mock REST、真实服务 smoke 和前端配置 Refs #555 --- .../ConnectionModal.edit-password.test.tsx | 16 +- frontend/src/components/ConnectionModal.tsx | 17 +- .../src/components/DatabaseIcons.test.tsx | 7 + frontend/src/components/DatabaseIcons.tsx | 8 +- .../src/utils/connectionDriverType.test.ts | 2 + frontend/src/utils/connectionDriverType.ts | 1 + .../utils/connectionModalPresentation.test.ts | 9 + .../src/utils/connectionModalPresentation.ts | 2 +- .../utils/connectionTypeCapabilities.test.ts | 5 + .../src/utils/connectionTypeCapabilities.ts | 6 +- .../src/utils/connectionTypeCatalog.test.ts | 3 + frontend/src/utils/connectionTypeCatalog.ts | 5 + .../src/utils/dataSourceCapabilities.test.ts | 18 + frontend/src/utils/dataSourceCapabilities.ts | 3 + frontend/src/utils/sqlDialect.test.ts | 2 + frontend/src/utils/sqlDialect.ts | 6 + internal/app/db_proxy.go | 2 + internal/db/database.go | 5 + internal/db/driver_support.go | 5 + internal/db/qdrant_impl.go | 1049 +++++++++++++++++ internal/db/qdrant_impl_test.go | 274 +++++ 21 files changed, 1434 insertions(+), 11 deletions(-) create mode 100644 internal/db/qdrant_impl.go create mode 100644 internal/db/qdrant_impl_test.go diff --git a/frontend/src/components/ConnectionModal.edit-password.test.tsx b/frontend/src/components/ConnectionModal.edit-password.test.tsx index 2b570f4..125bd70 100644 --- a/frontend/src/components/ConnectionModal.edit-password.test.tsx +++ b/frontend/src/components/ConnectionModal.edit-password.test.tsx @@ -33,10 +33,10 @@ describe('ConnectionModal data source registry', () => { expect(source).toContain('type === "elasticsearch"'); expect(source).toContain("return '支持索引浏览、Mapping 检查、JSON DSL 和 query_string 查询';"); expect(source).toContain( - 'type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch" || type === "chroma") ? "" : "root";', + 'type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch" || type === "chroma" || type === "qdrant") ? "" : "root";', ); expect(source).toContain( - 'placeholder={(dbType === "elasticsearch" || dbType === "chroma") ? "未开启认证可留空" : undefined}', + 'placeholder={(dbType === "elasticsearch" || dbType === "chroma" || dbType === "qdrant") ? "未开启认证可留空" : undefined}', ); expect(source).toContain('label="显示数据库 (留空显示全部)"'); }); @@ -52,6 +52,18 @@ describe('ConnectionModal data source registry', () => { expect(source).toContain('return "http://127.0.0.1:8000/default_database?tenant=default_tenant";'); expect(source).toContain('return "tenant=default_tenant&apiKey=...";'); }); + + it('exposes Qdrant in the create-connection picker with vector defaults', () => { + expect(source).toContain("case 'qdrant':"); + expect(source).toContain('return 6333;'); + expect(source).toContain('qdrant: ["http", "https", "qdrant"]'); + expect(source).toContain("key: 'qdrant'"); + expect(source).toContain("name: 'Qdrant'"); + expect(source).toContain('type === "qdrant"'); + expect(source).toContain("return 'Collection 浏览、向量搜索和 Payload 过滤';"); + expect(source).toContain('return "http://127.0.0.1:6333";'); + expect(source).toContain('return "apiKey=...";'); + }); }); describe('ConnectionModal Redis Sentinel configuration', () => { diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx index c1f1fc8..876cfe3 100644 --- a/frontend/src/components/ConnectionModal.tsx +++ b/frontend/src/components/ConnectionModal.tsx @@ -1789,7 +1789,7 @@ const ConnectionModal: React.FC<{ parsedValues.useSSL = false; parsedValues.sslMode = "disable"; } - } else if (type === "chroma") { + } else if (type === "chroma" || type === "qdrant") { const tls = String( parsed.params.get("tls") || parsed.params.get("ssl") || @@ -1870,6 +1870,9 @@ const ConnectionModal: React.FC<{ if (dbType === "chroma") { return "http://127.0.0.1:8000/default_database?tenant=default_tenant"; } + if (dbType === "qdrant") { + return "http://127.0.0.1:6333"; + } if (dbType === "redis") { return "redis://:pass@127.0.0.1:6379,127.0.0.2:6379/0?topology=cluster 或 redis://:pass@10.0.0.1:26379,10.0.0.2:26379/0?topology=sentinel&master=mymaster"; } @@ -1913,6 +1916,8 @@ const ConnectionModal: React.FC<{ return "retryWrites=true&readPreference=secondaryPreferred"; case "chroma": return "tenant=default_tenant&apiKey=..."; + case "qdrant": + return "apiKey=..."; case "dameng": return "schema=SYSDBA"; case "tdengine": @@ -2054,7 +2059,7 @@ const ConnectionModal: React.FC<{ const scheme = type === "postgres" ? "postgresql" - : type === "chroma" + : type === "chroma" || type === "qdrant" ? values.useSSL ? "https" : "http" @@ -2108,7 +2113,7 @@ const ConnectionModal: React.FC<{ if (mode === "skip-verify" || mode === "preferred") { params.set("skip_verify", "true"); } - } else if (type === "chroma") { + } else if (type === "chroma" || type === "qdrant") { if (mode === "skip-verify" || mode === "preferred") { params.set("skip_verify", "true"); } @@ -3705,7 +3710,7 @@ const ConnectionModal: React.FC<{ }); } else if (type !== "custom") { const defaultUser = - type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch" || type === "chroma") ? "" : "root"; + type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch" || type === "chroma" || type === "qdrant") ? "" : "root"; const sslCapableType = supportsSSLForType(type); setUseSSL(false); setUseHttpTunnel(false); @@ -5001,13 +5006,13 @@ const ConnectionModal: React.FC<{ name="user" label="用户名" rules={ - (dbType === "mongodb" || dbType === "elasticsearch" || dbType === "chroma") + (dbType === "mongodb" || dbType === "elasticsearch" || dbType === "chroma" || dbType === "qdrant") ? [] : [createUriAwareRequiredRule("请输入用户名")] } style={{ marginBottom: 0 }} > - + { expect(markup).toContain('>Ch'); }); + it('includes Qdrant in the selectable database icons', () => { + expect(DB_ICON_TYPES).toContain('qdrant'); + expect(getDbIconLabel('qdrant')).toBe('Qdrant'); + const markup = renderToStaticMarkup(<>{getDbIcon('qdrant', undefined, 22)}); + expect(markup).toContain('>Qd'); + }); + it('wraps database icons in a consistent frame for sidebar sizing', () => { const mysqlMarkup = renderToStaticMarkup(<>{getDbIcon('mysql', undefined, 22)}); const jvmMarkup = renderToStaticMarkup(<>{getDbIcon('jvm', undefined, 22)}); diff --git a/frontend/src/components/DatabaseIcons.tsx b/frontend/src/components/DatabaseIcons.tsx index ae7280c..e224cf6 100644 --- a/frontend/src/components/DatabaseIcons.tsx +++ b/frontend/src/components/DatabaseIcons.tsx @@ -50,6 +50,7 @@ const DB_DEFAULT_COLORS: Record = { iris: '#1F6FEB', tdengine: '#2962FF', chroma: '#7C3AED', + qdrant: '#DC244C', diros: '#0050B3', starrocks: '#00A6A6', sphinx: '#2F5D62', @@ -182,6 +183,9 @@ const TDengineIcon: React.FC = ({ size = 16, color }) => ( const ChromaIcon: React.FC = ({ size = 16, color }) => ( ); +const QdrantIcon: React.FC = ({ size = 16, color }) => ( + +); const JVMIcon: React.FC = ({ size = 16, color }) => ( ); @@ -236,6 +240,7 @@ const DB_ICON_MAP: Record> = { iris: IrisIcon, tdengine: TDengineIcon, chroma: ChromaIcon, + qdrant: QdrantIcon, elasticsearch: ElasticsearchIcon, custom: CustomIcon, }; @@ -244,7 +249,7 @@ const DB_ICON_MAP: Record> = { export const DB_ICON_TYPES: string[] = [ 'mysql', 'mariadb', 'oceanbase', 'postgres', 'redis', 'mongodb', 'jvm', 'oracle', 'sqlserver', 'sqlite', 'duckdb', 'clickhouse', 'starrocks', - 'kingbase', 'dameng', 'vastbase', 'opengauss', 'highgo', 'iris', 'tdengine', 'chroma', 'elasticsearch', 'custom', + 'kingbase', 'dameng', 'vastbase', 'opengauss', 'highgo', 'iris', 'tdengine', 'chroma', 'qdrant', 'elasticsearch', 'custom', ]; /** 该类型是否有品牌 SVG 文件 */ @@ -268,6 +273,7 @@ export const getDbIconLabel = (type: string): string => { duckdb: 'DuckDB', kingbase: '金仓', dameng: '达梦', vastbase: 'VastBase', opengauss: 'OpenGauss', highgo: '瀚高', iris: 'InterSystems IRIS', tdengine: 'TDengine', chroma: 'Chroma', + qdrant: 'Qdrant', elasticsearch: 'Elasticsearch', custom: '自定义', }; diff --git a/frontend/src/utils/connectionDriverType.test.ts b/frontend/src/utils/connectionDriverType.test.ts index 2b3da04..e1252dd 100644 --- a/frontend/src/utils/connectionDriverType.test.ts +++ b/frontend/src/utils/connectionDriverType.test.ts @@ -14,6 +14,8 @@ describe('connectionDriverType', () => { expect(normalizeDriverType('elastic')).toBe('elasticsearch'); expect(normalizeDriverType('chromadb')).toBe('chroma'); expect(normalizeDriverType('chroma-db')).toBe('chroma'); + expect(normalizeDriverType('qdrantdb')).toBe('qdrant'); + expect(normalizeDriverType('qdrant-db')).toBe('qdrant'); expect(normalizeDriverType('doris')).toBe('diros'); expect(normalizeDriverType('open-gauss')).toBe('opengauss'); expect(normalizeDriverType('InterSystemsIRIS')).toBe('iris'); diff --git a/frontend/src/utils/connectionDriverType.ts b/frontend/src/utils/connectionDriverType.ts index 7f9a93f..9cb0659 100644 --- a/frontend/src/utils/connectionDriverType.ts +++ b/frontend/src/utils/connectionDriverType.ts @@ -16,6 +16,7 @@ export const normalizeDriverType = (value: string): string => { if (normalized === 'postgresql' || normalized === 'pg' || normalized === 'pq' || normalized === 'pgx') return 'postgres'; if (normalized === 'elastic') return 'elasticsearch'; if (normalized === 'chromadb' || normalized === 'chroma-db') return 'chroma'; + if (normalized === 'qdrantdb' || normalized === 'qdrant-db') return 'qdrant'; if (normalized === 'doris') return 'diros'; if ( normalized === 'open_gauss' || diff --git a/frontend/src/utils/connectionModalPresentation.test.ts b/frontend/src/utils/connectionModalPresentation.test.ts index e80a9be..6a6f424 100644 --- a/frontend/src/utils/connectionModalPresentation.test.ts +++ b/frontend/src/utils/connectionModalPresentation.test.ts @@ -88,6 +88,7 @@ describe('connectionModalPresentation', () => { 'mongodb', 'elasticsearch', 'chroma', + 'qdrant', 'redis', 'tdengine', 'custom', @@ -165,6 +166,14 @@ describe('connectionModalPresentation', () => { 'credentials', 'databaseScope', ]); + expect(resolveConnectionConfigLayout('qdrant').sections).toEqual([ + 'identity', + 'uri', + 'target', + 'service', + 'credentials', + 'databaseScope', + ]); }); it('uses localized labels for layout kinds shown in the modal', () => { diff --git a/frontend/src/utils/connectionModalPresentation.ts b/frontend/src/utils/connectionModalPresentation.ts index 7dd7b3c..5ce552c 100644 --- a/frontend/src/utils/connectionModalPresentation.ts +++ b/frontend/src/utils/connectionModalPresentation.ts @@ -252,7 +252,7 @@ export const resolveConnectionConfigLayout = ( ], }; } - if (type === 'chroma') { + if (type === 'chroma' || type === 'qdrant') { return { kind: 'vector', sections: [ diff --git a/frontend/src/utils/connectionTypeCapabilities.test.ts b/frontend/src/utils/connectionTypeCapabilities.test.ts index 9301efe..14ccd89 100644 --- a/frontend/src/utils/connectionTypeCapabilities.test.ts +++ b/frontend/src/utils/connectionTypeCapabilities.test.ts @@ -18,6 +18,7 @@ describe('connectionTypeCapabilities', () => { expect(singleHostUriSchemesByType.dameng).toEqual(['dameng', 'dm']); expect(singleHostUriSchemesByType.elasticsearch).toEqual(['http', 'https']); expect(singleHostUriSchemesByType.chroma).toEqual(['http', 'https', 'chroma']); + expect(singleHostUriSchemesByType.qdrant).toEqual(['http', 'https', 'qdrant']); expect(singleHostUriSchemesByType.redis).toEqual(['redis']); }); @@ -26,6 +27,7 @@ describe('connectionTypeCapabilities', () => { expect(supportsSSLForType('MongoDB')).toBe(true); expect(supportsSSLForType('elasticsearch')).toBe(true); expect(supportsSSLForType('chroma')).toBe(true); + expect(supportsSSLForType('qdrant')).toBe(true); expect(supportsSSLForType('tdengine')).toBe(true); expect(supportsSSLForType('dameng')).toBe(true); expect(supportsSSLForType('sqlite')).toBe(false); @@ -40,6 +42,8 @@ describe('connectionTypeCapabilities', () => { expect(supportsSSLClientCertificateForType('redis')).toBe(true); expect(supportsSSLCAPathForType('chroma')).toBe(true); expect(supportsSSLClientCertificateForType('chroma')).toBe(false); + expect(supportsSSLCAPathForType('qdrant')).toBe(true); + expect(supportsSSLClientCertificateForType('qdrant')).toBe(false); }); it('detects postgres-compatible SSL parameter dialects', () => { @@ -68,6 +72,7 @@ describe('connectionTypeCapabilities', () => { expect(supportsConnectionParamsForType('tdengine')).toBe(true); expect(supportsConnectionParamsForType('elasticsearch')).toBe(true); expect(supportsConnectionParamsForType('chroma')).toBe(true); + expect(supportsConnectionParamsForType('qdrant')).toBe(true); expect(supportsConnectionParamsForType('redis')).toBe(false); expect(supportsConnectionParamsForType('sqlite')).toBe(false); expect(supportsConnectionParamsForType('jvm')).toBe(false); diff --git a/frontend/src/utils/connectionTypeCapabilities.ts b/frontend/src/utils/connectionTypeCapabilities.ts index e158a13..7f3032d 100644 --- a/frontend/src/utils/connectionTypeCapabilities.ts +++ b/frontend/src/utils/connectionTypeCapabilities.ts @@ -13,6 +13,7 @@ export const singleHostUriSchemesByType: Record = { vastbase: ["vastbase"], elasticsearch: ["http", "https"], chroma: ["http", "https", "chroma"], + qdrant: ["http", "https", "qdrant"], }; const normalizeConnectionType = (type: string) => @@ -42,6 +43,7 @@ const sslSupportedTypes = new Set([ "tdengine", "elasticsearch", "chroma", + "qdrant", ]); export const supportsSSLForType = (type: string) => @@ -65,6 +67,7 @@ const sslCAPathSupportedTypes = new Set([ "redis", "elasticsearch", "chroma", + "qdrant", ]); const sslClientCertificateSupportedTypes = new Set([ @@ -127,4 +130,5 @@ export const supportsConnectionParamsForType = (type: string) => type === "dameng" || type === "tdengine" || type === "elasticsearch" || - type === "chroma"; + type === "chroma" || + type === "qdrant"; diff --git a/frontend/src/utils/connectionTypeCatalog.test.ts b/frontend/src/utils/connectionTypeCatalog.test.ts index 1a841ed..6e357a1 100644 --- a/frontend/src/utils/connectionTypeCatalog.test.ts +++ b/frontend/src/utils/connectionTypeCatalog.test.ts @@ -25,6 +25,7 @@ describe('connectionTypeCatalog', () => { expect(keys).toContain('redis'); expect(keys).toContain('elasticsearch'); expect(keys).toContain('chroma'); + expect(keys).toContain('qdrant'); expect(keys).toContain('jvm'); expect(keys).toContain('custom'); expect(new Set(keys).size).toBe(keys.length); @@ -40,6 +41,7 @@ describe('connectionTypeCatalog', () => { expect(getConnectionTypeDefaultPort('mongodb')).toBe(27017); expect(getConnectionTypeDefaultPort('elasticsearch')).toBe(9200); expect(getConnectionTypeDefaultPort('chroma')).toBe(8000); + expect(getConnectionTypeDefaultPort('qdrant')).toBe(6333); expect(getConnectionTypeDefaultPort('sqlite')).toBe(0); expect(getConnectionTypeDefaultPort('duckdb')).toBe(0); expect(getConnectionTypeDefaultPort('unknown')).toBe(3306); @@ -50,6 +52,7 @@ describe('connectionTypeCatalog', () => { expect(getConnectionTypeHint('mongodb')).toBe('单机 / 副本集'); expect(getConnectionTypeHint('elasticsearch')).toContain('Mapping'); expect(getConnectionTypeHint('chroma')).toContain('向量'); + expect(getConnectionTypeHint('qdrant')).toContain('Payload'); expect(getConnectionTypeHint('oceanbase')).toBe('MySQL / Oracle 租户'); expect(getConnectionTypeHint('duckdb')).toBe('本地文件连接'); expect(getConnectionTypeHint('mysql')).toBe('标准连接配置'); diff --git a/frontend/src/utils/connectionTypeCatalog.ts b/frontend/src/utils/connectionTypeCatalog.ts index cd508a5..c32e0ea 100644 --- a/frontend/src/utils/connectionTypeCatalog.ts +++ b/frontend/src/utils/connectionTypeCatalog.ts @@ -49,6 +49,7 @@ export const CONNECTION_TYPE_GROUPS: ConnectionTypeCatalogGroup[] = [ label: '向量数据库', items: [ { key: 'chroma', name: 'Chroma' }, + { key: 'qdrant', name: 'Qdrant' }, ], }, { @@ -105,6 +106,8 @@ export const getConnectionTypeDefaultPort = (type: string): number => { return 9200; case 'chroma': return 8000; + case 'qdrant': + return 6333; case 'highgo': return 5866; case 'mariadb': @@ -133,6 +136,8 @@ export const getConnectionTypeHint = (type: string): string => { return '支持索引浏览、Mapping 检查、JSON DSL 和 query_string 查询'; case 'chroma': return 'Collection 浏览、向量检索和元数据过滤'; + case 'qdrant': + return 'Collection 浏览、向量搜索和 Payload 过滤'; case 'oceanbase': return 'MySQL / Oracle 租户'; case 'sqlite': diff --git a/frontend/src/utils/dataSourceCapabilities.test.ts b/frontend/src/utils/dataSourceCapabilities.test.ts index 0fed0ed..d7f21f5 100644 --- a/frontend/src/utils/dataSourceCapabilities.test.ts +++ b/frontend/src/utils/dataSourceCapabilities.test.ts @@ -90,6 +90,24 @@ describe('dataSourceCapabilities', () => { }); }); + it('treats Qdrant as a queryable vector datasource without SQL export actions', () => { + expect(getDataSourceCapabilities({ type: 'qdrant' })).toMatchObject({ + type: 'qdrant', + supportsQueryEditor: true, + supportsSqlQueryExport: false, + supportsCopyInsert: false, + supportsCreateDatabase: false, + supportsRenameDatabase: false, + supportsDropDatabase: false, + forceReadOnlyQueryResult: false, + }); + expect(getDataSourceCapabilities({ type: 'custom', driver: 'qdrantdb' })).toMatchObject({ + type: 'qdrant', + supportsQueryEditor: true, + supportsCopyInsert: false, + }); + }); + it('treats OceanBase Oracle protocol as Oracle capabilities', () => { expect(getDataSourceCapabilities({ type: 'oceanbase', diff --git a/frontend/src/utils/dataSourceCapabilities.ts b/frontend/src/utils/dataSourceCapabilities.ts index 2e2ea5f..f92025a 100644 --- a/frontend/src/utils/dataSourceCapabilities.ts +++ b/frontend/src/utils/dataSourceCapabilities.ts @@ -24,6 +24,9 @@ const normalizeDataSourceToken = (raw: string): string => { case 'chromadb': case 'chroma-db': return 'chroma'; + case 'qdrantdb': + case 'qdrant-db': + return 'qdrant'; case 'intersystems': case 'intersystemsiris': case 'inter-systems': diff --git a/frontend/src/utils/sqlDialect.test.ts b/frontend/src/utils/sqlDialect.test.ts index ce4a7a6..ec496a4 100644 --- a/frontend/src/utils/sqlDialect.test.ts +++ b/frontend/src/utils/sqlDialect.test.ts @@ -32,6 +32,8 @@ describe('sqlDialect', () => { expect(resolveSqlDialect('custom', 'elastic')).toBe('elasticsearch'); expect(resolveSqlDialect('ChromaDB')).toBe('chroma'); expect(resolveSqlDialect('custom', 'chroma-db')).toBe('chroma'); + expect(resolveSqlDialect('QdrantDB')).toBe('qdrant'); + expect(resolveSqlDialect('custom', 'qdrant-db')).toBe('qdrant'); expect(resolveSqlDialect('OceanBase', '', { oceanBaseProtocol: 'oracle' })).toBe('oracle'); expect(resolveSqlDialect('custom', 'oceanbase', { oceanBaseProtocol: 'oracle' })).toBe('oracle'); expect(isMysqlFamilyDialect('mariadb')).toBe(true); diff --git a/frontend/src/utils/sqlDialect.ts b/frontend/src/utils/sqlDialect.ts index bdb9b08..c8461a7 100644 --- a/frontend/src/utils/sqlDialect.ts +++ b/frontend/src/utils/sqlDialect.ts @@ -31,6 +31,7 @@ export type SqlDialect = | 'redis' | 'elasticsearch' | 'chroma' + | 'qdrant' | 'unknown' | string; @@ -120,6 +121,10 @@ export const resolveSqlDialect = ( case 'chroma-db': case 'chroma': return 'chroma'; + case 'qdrantdb': + case 'qdrant-db': + case 'qdrant': + return 'qdrant'; default: break; } @@ -146,6 +151,7 @@ export const resolveSqlDialect = ( if (source.includes('iris') || source.includes('intersystems')) return 'iris'; if (source.includes('elastic')) return 'elasticsearch'; if (source.includes('chroma')) return 'chroma'; + if (source.includes('qdrant')) return 'qdrant'; return source; }; diff --git a/internal/app/db_proxy.go b/internal/app/db_proxy.go index 63be9be..4b3120f 100644 --- a/internal/app/db_proxy.go +++ b/internal/app/db_proxy.go @@ -235,6 +235,8 @@ func defaultPortByType(driverType string) int { return 1972 case "chroma": return 8000 + case "qdrant": + return 6333 default: return 0 } diff --git a/internal/db/database.go b/internal/db/database.go index 0305e9e..3ce7dc4 100644 --- a/internal/db/database.go +++ b/internal/db/database.go @@ -483,6 +483,9 @@ var databaseFactories = map[string]databaseFactory{ "chroma": func() Database { return &ChromaDB{} }, + "qdrant": func() Database { + return &QdrantDB{} + }, } func init() { @@ -517,6 +520,8 @@ func normalizeDatabaseType(dbType string) string { return "iris" case "chromadb", "chroma-db": return "chroma" + case "qdrantdb", "qdrant-db": + return "qdrant" default: return normalized } diff --git a/internal/db/driver_support.go b/internal/db/driver_support.go index b70623d..0a9a688 100644 --- a/internal/db/driver_support.go +++ b/internal/db/driver_support.go @@ -17,6 +17,7 @@ var coreBuiltinDrivers = map[string]struct{}{ "oracle": {}, "postgres": {}, "chroma": {}, + "qdrant": {}, } // optionalGoDrivers 表示需要用户“安装启用”后才能使用的纯 Go 驱动。 @@ -69,6 +70,8 @@ func normalizeRuntimeDriverType(driverType string) string { return "elasticsearch" case "chromadb", "chroma-db": return "chroma" + case "qdrantdb", "qdrant-db": + return "qdrant" default: return normalized } @@ -122,6 +125,8 @@ func driverDisplayName(driverType string) string { return "Elasticsearch" case "chroma": return "Chroma" + case "qdrant": + return "Qdrant" default: return strings.ToUpper(strings.TrimSpace(driverType)) } diff --git a/internal/db/qdrant_impl.go b/internal/db/qdrant_impl.go new file mode 100644 index 0000000..adca105 --- /dev/null +++ b/internal/db/qdrant_impl.go @@ -0,0 +1,1049 @@ +package db + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "time" + + "GoNavi-Wails/internal/connection" + "GoNavi-Wails/internal/logger" + proxytunnel "GoNavi-Wails/internal/proxy" + "GoNavi-Wails/internal/ssh" +) + +const ( + defaultQdrantPort = 6333 + defaultQdrantDatabase = "default" + defaultQdrantQueryTimeout = 30 * time.Second +) + +type QdrantDB struct { + client *http.Client + baseURL string + database string + authHeaders map[string]string + forwarder *ssh.LocalForwarder +} + +type qdrantCollectionInfo struct { + Name string `json:"name"` +} + +type qdrantListCollectionsResponse struct { + Result struct { + Collections []qdrantCollectionInfo `json:"collections"` + } `json:"result"` +} + +type qdrantCollectionResponse struct { + Result map[string]interface{} `json:"result"` +} + +type qdrantPoint struct { + ID interface{} `json:"id"` + Payload map[string]interface{} `json:"payload"` + Vector interface{} `json:"vector"` + Score interface{} `json:"score"` + Version interface{} `json:"version"` +} + +type qdrantScrollResponse struct { + Result struct { + Points []qdrantPoint `json:"points"` + NextPageOffset interface{} `json:"next_page_offset"` + } `json:"result"` +} + +type qdrantSearchResponse struct { + Result []qdrantPoint `json:"result"` +} + +type qdrantCountResponse struct { + Result struct { + Count int64 `json:"count"` + } `json:"result"` +} + +func (q *QdrantDB) Connect(config connection.ConnectionConfig) error { + if q.forwarder != nil { + _ = q.forwarder.Close() + q.forwarder = nil + } + q.client = nil + + runConfig := normalizeQdrantConfig(config) + if runConfig.UseSSH { + forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, runConfig.Host, runConfig.Port) + if err != nil { + return fmt.Errorf("创建 SSH 隧道失败:%w", err) + } + q.forwarder = forwarder + + host, portText, err := net.SplitHostPort(forwarder.LocalAddr) + if err != nil { + return fmt.Errorf("解析本地转发地址失败:%w", err) + } + port, err := strconv.Atoi(portText) + if err != nil { + return fmt.Errorf("解析本地端口失败:%w", err) + } + runConfig.Host = host + runConfig.Port = port + runConfig.UseSSH = false + logger.Infof("Qdrant 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port) + } + + q.database = qdrantDatabaseFromConfig(runConfig) + q.baseURL = buildQdrantBaseURL(runConfig) + q.authHeaders = qdrantAuthHeaders(runConfig) + q.client = buildQdrantHTTPClient(runConfig) + + if err := q.Ping(); err != nil { + _ = q.Close() + return err + } + return nil +} + +func (q *QdrantDB) Close() error { + if q.forwarder != nil { + if err := q.forwarder.Close(); err != nil { + logger.Warnf("关闭 Qdrant SSH 端口转发失败:%v", err) + } + q.forwarder = nil + } + q.client = nil + return nil +} + +func (q *QdrantDB) Ping() error { + if q.client == nil { + return fmt.Errorf("连接未打开") + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + var resp qdrantListCollectionsResponse + return q.doJSON(ctx, http.MethodGet, "/collections", nil, &resp) +} + +func (q *QdrantDB) Query(query string) ([]map[string]interface{}, []string, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultQdrantQueryTimeout) + defer cancel() + return q.QueryContext(ctx, query) +} + +func (q *QdrantDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) { + if q.client == nil { + return nil, nil, fmt.Errorf("连接未打开") + } + text := strings.TrimSpace(query) + if text == "" { + return nil, nil, fmt.Errorf("查询语句不能为空") + } + + if strings.HasPrefix(text, "{") { + return q.queryJSON(ctx, text) + } + + if parsed, ok := parseQdrantSQL(text); ok { + if parsed.Count { + total, err := q.countPoints(ctx, parsed.Collection, nil) + if err != nil { + return nil, nil, err + } + return []map[string]interface{}{{"total": total}}, []string{"total"}, nil + } + return q.scrollPoints(ctx, parsed.Collection, parsed.Limit, parsed.Offset, nil, true, parsed.IncludeVector) + } + + return nil, nil, fmt.Errorf("Qdrant 查询仅支持 JSON 命令或简单 SELECT 预览") +} + +func (q *QdrantDB) Exec(query string) (int64, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultQdrantQueryTimeout) + defer cancel() + return q.ExecContext(ctx, query) +} + +func (q *QdrantDB) ExecContext(ctx context.Context, query string) (int64, error) { + if q.client == nil { + return 0, fmt.Errorf("连接未打开") + } + var cmd map[string]interface{} + if err := decodeJSONWithUseNumber([]byte(strings.TrimSpace(query)), &cmd); err != nil { + return 0, fmt.Errorf("Qdrant 写入命令必须是 JSON:%w", err) + } + if name := firstStringValue(cmd, "create_collection", "createCollection", "collection"); name != "" && hasAnyKey(cmd, "create_collection", "createCollection") { + return 1, q.createCollection(ctx, name, cmd) + } + if name := firstStringValue(cmd, "delete_collection", "deleteCollection"); name != "" { + return 1, q.deleteCollection(ctx, name) + } + if name := firstStringValue(cmd, "upsert", "collection"); name != "" && hasAnyKey(cmd, "upsert") { + return q.upsertCommand(ctx, name, cmd) + } + if name := firstStringValue(cmd, "delete", "collection"); name != "" && hasAnyKey(cmd, "delete") { + return q.deleteCommand(ctx, name, cmd) + } + if name := firstStringValue(cmd, "create_payload_index", "createPayloadIndex", "collection"); name != "" && hasAnyKey(cmd, "create_payload_index", "createPayloadIndex") { + return 1, q.createPayloadIndex(ctx, name, cmd) + } + if name := firstStringValue(cmd, "delete_payload_index", "deletePayloadIndex", "collection"); name != "" && hasAnyKey(cmd, "delete_payload_index", "deletePayloadIndex") { + fieldName := firstStringValue(cmd, "field_name", "fieldName", "field") + if fieldName == "" { + return 0, fmt.Errorf("Qdrant 删除 payload index 命令缺少 field_name") + } + return 1, q.deletePayloadIndex(ctx, name, fieldName) + } + return 0, fmt.Errorf("Qdrant JSON 写入命令仅支持 create_collection/delete_collection/upsert/delete/create_payload_index/delete_payload_index") +} + +func (q *QdrantDB) GetDatabases() ([]string, error) { + if q.client == nil { + return nil, fmt.Errorf("连接未打开") + } + return []string{q.database}, nil +} + +func (q *QdrantDB) GetTables(dbName string) ([]string, error) { + collections, err := q.listCollections(context.Background()) + if err != nil { + return nil, err + } + names := make([]string, 0, len(collections)) + for _, item := range collections { + if strings.TrimSpace(item.Name) != "" { + names = append(names, item.Name) + } + } + sort.Strings(names) + return names, nil +} + +func (q *QdrantDB) GetCreateStatement(dbName, tableName string) (string, error) { + info, err := q.getCollectionInfo(context.Background(), tableNameOrDB(dbName, tableName)) + if err != nil { + return "", err + } + payload, _ := json.MarshalIndent(info, "", " ") + return fmt.Sprintf("// Qdrant collection: %s\n%s", tableNameOrDB(dbName, tableName), string(payload)), nil +} + +func (q *QdrantDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) { + rows, _, err := q.scrollPoints(context.Background(), tableNameOrDB(dbName, tableName), 20, nil, nil, true, true) + if err != nil { + return nil, err + } + cols := []connection.ColumnDefinition{ + {Name: "id", Type: "point_id", Nullable: "NO", Key: "PRI", Comment: "Qdrant point id"}, + {Name: "vector", Type: "vector", Nullable: "YES", Comment: "Vector or named vectors"}, + {Name: "payload", Type: "json", Nullable: "YES", Comment: "Full payload object"}, + } + seen := map[string]struct{}{"id": {}, "vector": {}, "payload": {}} + for _, row := range rows { + for key, value := range row { + if _, exists := seen[key]; exists || !strings.HasPrefix(key, "payload.") { + continue + } + seen[key] = struct{}{} + cols = append(cols, connection.ColumnDefinition{ + Name: key, + Type: inferChromaValueType(value), + Nullable: "YES", + Comment: "Payload field", + }) + } + } + return cols, nil +} + +func (q *QdrantDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) { + tables, err := q.GetTables(dbName) + if err != nil { + return nil, err + } + var result []connection.ColumnDefinitionWithTable + for _, table := range tables { + cols, err := q.GetColumns(dbName, table) + if err != nil { + continue + } + for _, col := range cols { + result = append(result, connection.ColumnDefinitionWithTable{ + TableName: table, + Name: col.Name, + Type: col.Type, + Comment: col.Comment, + }) + } + } + return result, nil +} + +func (q *QdrantDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) { + indexes := []connection.IndexDefinition{ + {Name: "PRIMARY", ColumnName: "id", NonUnique: 0, SeqInIndex: 1, IndexType: "PRIMARY"}, + } + info, err := q.getCollectionInfo(context.Background(), tableNameOrDB(dbName, tableName)) + if err == nil { + indexes = append(indexes, qdrantVectorIndexes(info)...) + indexes = append(indexes, qdrantPayloadIndexes(info)...) + } + if len(indexes) == 1 { + indexes = append(indexes, connection.IndexDefinition{Name: "VECTOR", ColumnName: "vector", NonUnique: 1, SeqInIndex: 1, IndexType: "VECTOR"}) + } + return indexes, nil +} + +func (q *QdrantDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) { + return []connection.ForeignKeyDefinition{}, nil +} + +func (q *QdrantDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) { + return []connection.TriggerDefinition{}, nil +} + +func (q *QdrantDB) ApplyChanges(tableName string, changes connection.ChangeSet) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultQdrantQueryTimeout) + defer cancel() + + if len(changes.Deletes) > 0 { + ids := make([]interface{}, 0, len(changes.Deletes)) + for _, row := range changes.Deletes { + if id, ok := qdrantRowID(row); ok { + ids = append(ids, id) + } + } + if len(ids) > 0 { + if _, err := q.deleteCommand(ctx, tableName, map[string]interface{}{"points": ids}); err != nil { + return err + } + } + } + + if len(changes.Updates) > 0 { + var upserts []map[string]interface{} + for _, update := range changes.Updates { + row := make(map[string]interface{}, len(update.Keys)+len(update.Values)) + for k, v := range update.Keys { + row[k] = v + } + for k, v := range update.Values { + row[k] = v + } + if _, hasVector := qdrantRowVector(row); hasVector { + upserts = append(upserts, row) + continue + } + if err := q.setPayloadFromRow(ctx, tableName, row); err != nil { + return err + } + } + if len(upserts) > 0 { + if err := q.upsertRows(ctx, tableName, upserts); err != nil { + return err + } + } + } + + if len(changes.Inserts) > 0 { + if err := q.upsertRows(ctx, tableName, changes.Inserts); err != nil { + return err + } + } + return nil +} + +func normalizeQdrantConfig(config connection.ConnectionConfig) connection.ConnectionConfig { + runConfig := applyQdrantURI(config) + if strings.TrimSpace(runConfig.Host) == "" { + runConfig.Host = "localhost" + } + if runConfig.Port <= 0 { + runConfig.Port = defaultQdrantPort + } + if strings.TrimSpace(runConfig.SSLMode) == "" && runConfig.UseSSL { + runConfig.SSLMode = "required" + } + return runConfig +} + +func applyQdrantURI(config connection.ConnectionConfig) connection.ConnectionConfig { + uriText := strings.TrimSpace(config.URI) + if uriText == "" { + return config + } + parsed, err := url.Parse(uriText) + if err != nil { + return config + } + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + if scheme != "http" && scheme != "https" && scheme != "qdrant" { + return config + } + if parsed.User != nil { + if strings.TrimSpace(config.User) == "" { + config.User = parsed.User.Username() + } + if pass, ok := parsed.User.Password(); ok && config.Password == "" { + config.Password = pass + } + } + if scheme == "https" { + config.UseSSL = true + } + if host := strings.TrimSpace(parsed.Host); host != "" { + if h, port, ok := parseHostPortWithDefault(host, defaultQdrantPort); ok { + config.Host = h + config.Port = port + } + } + if dbName := strings.Trim(strings.TrimSpace(parsed.Path), "/"); dbName != "" && !strings.HasPrefix(dbName, "collections") && strings.TrimSpace(config.Database) == "" { + config.Database = dbName + } + return config +} + +func buildQdrantBaseURL(config connection.ConnectionConfig) string { + scheme := "http" + if config.UseSSL { + scheme = "https" + } + return fmt.Sprintf("%s://%s:%d", scheme, strings.TrimSpace(config.Host), config.Port) +} + +func qdrantDatabaseFromConfig(config connection.ConnectionConfig) string { + if dbName := strings.TrimSpace(config.Database); dbName != "" { + return dbName + } + return defaultQdrantDatabase +} + +func qdrantConnectionParams(config connection.ConnectionConfig) url.Values { + params := url.Values{} + mergeConnectionParamValues(params, connectionParamsFromURI(config.URI, "http", "https", "qdrant")) + mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams)) + return params +} + +func qdrantAuthHeaders(config connection.ConnectionConfig) map[string]string { + headers := make(map[string]string) + params := qdrantConnectionParams(config) + apiKey := firstNonEmpty(params.Get("apiKey"), params.Get("apikey"), params.Get("api-key"), params.Get("token"), params.Get("authToken")) + if apiKey == "" && strings.TrimSpace(config.User) == "" { + apiKey = strings.TrimSpace(config.Password) + } + if apiKey != "" { + headers["api-key"] = apiKey + } else if user := strings.TrimSpace(config.User); user != "" { + raw := user + ":" + config.Password + headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(raw)) + } + if headerName := strings.TrimSpace(params.Get("authHeader")); headerName != "" { + if headerValue := strings.TrimSpace(params.Get("authHeaderValue")); headerValue != "" && isSafeConnectionParamKey(headerName) { + headers[headerName] = headerValue + } + } + return headers +} + +func buildQdrantHTTPClient(config connection.ConnectionConfig) *http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() + if tlsConfig, err := resolveGenericTLSConfig(config); err == nil && tlsConfig != nil { + transport.TLSClientConfig = tlsConfig + } + if config.UseProxy { + proxyCfg := config.Proxy + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return proxytunnel.DialContext(ctx, proxyCfg, network, addr) + } + } + return &http.Client{Transport: transport, Timeout: getConnectTimeout(config)} +} + +func (q *QdrantDB) doJSON(ctx context.Context, method, path string, body interface{}, out interface{}) error { + if q.client == nil { + return fmt.Errorf("连接未打开") + } + var reader io.Reader + if body != nil { + payload, err := json.Marshal(body) + if err != nil { + return err + } + reader = bytes.NewReader(payload) + } + req, err := http.NewRequestWithContext(ctx, method, strings.TrimRight(q.baseURL, "/")+path, reader) + if err != nil { + return err + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + for key, value := range q.authHeaders { + if strings.TrimSpace(key) != "" && strings.TrimSpace(value) != "" { + req.Header.Set(key, value) + } + } + res, err := q.client.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + resBody, err := io.ReadAll(res.Body) + if err != nil { + return err + } + if res.StatusCode < 200 || res.StatusCode >= 300 { + message := strings.TrimSpace(string(resBody)) + if message == "" { + message = res.Status + } + return fmt.Errorf("Qdrant API %s %s 失败:%s", method, path, message) + } + if out == nil || len(bytes.TrimSpace(resBody)) == 0 { + return nil + } + if err := decodeJSONWithUseNumber(resBody, out); err != nil { + return fmt.Errorf("解析 Qdrant 响应失败:%w", err) + } + return nil +} + +func (q *QdrantDB) listCollections(ctx context.Context) ([]qdrantCollectionInfo, error) { + var resp qdrantListCollectionsResponse + if err := q.doJSON(ctx, http.MethodGet, "/collections", nil, &resp); err != nil { + return nil, err + } + return resp.Result.Collections, nil +} + +func (q *QdrantDB) getCollectionInfo(ctx context.Context, collection string) (map[string]interface{}, error) { + name := strings.TrimSpace(collection) + if name == "" { + return nil, fmt.Errorf("collection 名称不能为空") + } + var resp qdrantCollectionResponse + if err := q.doJSON(ctx, http.MethodGet, fmt.Sprintf("/collections/%s", url.PathEscape(name)), nil, &resp); err != nil { + return nil, err + } + return resp.Result, nil +} + +func (q *QdrantDB) scrollPoints(ctx context.Context, collection string, limit int, offset interface{}, filter interface{}, withPayload bool, withVector bool) ([]map[string]interface{}, []string, error) { + name := strings.TrimSpace(collection) + if name == "" { + return nil, nil, fmt.Errorf("collection 名称不能为空") + } + if limit <= 0 { + limit = 200 + } + body := map[string]interface{}{ + "limit": limit, + "with_payload": withPayload, + "with_vector": withVector, + } + if offset != nil && strings.TrimSpace(fmt.Sprintf("%v", offset)) != "" { + body["offset"] = qdrantNormalizePointID(offset) + } + if filter != nil { + body["filter"] = filter + } + var resp qdrantScrollResponse + if err := q.doJSON(ctx, http.MethodPost, fmt.Sprintf("/collections/%s/points/scroll", url.PathEscape(name)), body, &resp); err != nil { + return nil, nil, err + } + rows := qdrantPointRows(resp.Result.Points) + if resp.Result.NextPageOffset != nil { + for _, row := range rows { + row["next_page_offset"] = resp.Result.NextPageOffset + } + } + return rows, collectColumns(rows), nil +} + +func (q *QdrantDB) searchPoints(ctx context.Context, collection string, cmd map[string]interface{}) ([]map[string]interface{}, []string, error) { + name := strings.TrimSpace(collection) + if name == "" { + return nil, nil, fmt.Errorf("collection 名称不能为空") + } + vector := firstExisting(cmd, "vector", "query_vector", "queryVector") + if vector == nil { + return nil, nil, fmt.Errorf("Qdrant search 命令缺少 vector") + } + body := map[string]interface{}{ + "vector": normalizeQdrantVector(vector), + "limit": intFromAny(firstExisting(cmd, "limit", "n_results", "nResults"), 10), + "with_payload": qdrantBoolValue(firstExisting(cmd, "with_payload", "withPayload"), true), + "with_vector": qdrantBoolValue(firstExisting(cmd, "with_vector", "withVector"), true), + } + for _, key := range []string{"filter", "params", "score_threshold", "offset"} { + if value, ok := cmd[key]; ok { + body[key] = value + } + } + var resp qdrantSearchResponse + if err := q.doJSON(ctx, http.MethodPost, fmt.Sprintf("/collections/%s/points/search", url.PathEscape(name)), body, &resp); err != nil { + return nil, nil, err + } + rows := qdrantPointRows(resp.Result) + return rows, collectColumns(rows), nil +} + +func (q *QdrantDB) countPoints(ctx context.Context, collection string, filter interface{}) (int64, error) { + name := strings.TrimSpace(collection) + if name == "" { + return 0, fmt.Errorf("collection 名称不能为空") + } + body := map[string]interface{}{"exact": true} + if filter != nil { + body["filter"] = filter + } + var resp qdrantCountResponse + if err := q.doJSON(ctx, http.MethodPost, fmt.Sprintf("/collections/%s/points/count", url.PathEscape(name)), body, &resp); err != nil { + return 0, err + } + return resp.Result.Count, nil +} + +func (q *QdrantDB) queryJSON(ctx context.Context, text string) ([]map[string]interface{}, []string, error) { + var cmd map[string]interface{} + if err := decodeJSONWithUseNumber([]byte(text), &cmd); err != nil { + return nil, nil, fmt.Errorf("Qdrant JSON 命令解析失败:%w", err) + } + if hasAnyKey(cmd, "list_collections", "listCollections") { + collections, err := q.listCollections(ctx) + if err != nil { + return nil, nil, err + } + rows := make([]map[string]interface{}, 0, len(collections)) + for _, collection := range collections { + rows = append(rows, map[string]interface{}{"name": collection.Name}) + } + return rows, collectColumns(rows), nil + } + if name := firstStringValue(cmd, "get_collection", "getCollection"); name != "" { + info, err := q.getCollectionInfo(ctx, name) + if err != nil { + return nil, nil, err + } + return []map[string]interface{}{info}, collectColumns([]map[string]interface{}{info}), nil + } + if name := firstStringValue(cmd, "count", "collection"); name != "" && hasAnyKey(cmd, "count") { + total, err := q.countPoints(ctx, name, cmd["filter"]) + if err != nil { + return nil, nil, err + } + return []map[string]interface{}{{"total": total}}, []string{"total"}, nil + } + if name := firstStringValue(cmd, "search", "query", "collection"); name != "" && hasAnyKey(cmd, "search", "query", "vector", "query_vector", "queryVector") { + return q.searchPoints(ctx, name, cmd) + } + if name := firstStringValue(cmd, "scroll", "get", "collection"); name != "" { + limit := intFromAny(cmd["limit"], 200) + offset := firstExisting(cmd, "offset", "next_page_offset", "nextPageOffset") + return q.scrollPoints( + ctx, + name, + limit, + offset, + cmd["filter"], + qdrantBoolValue(firstExisting(cmd, "with_payload", "withPayload"), true), + qdrantBoolValue(firstExisting(cmd, "with_vector", "withVector"), true), + ) + } + return nil, nil, fmt.Errorf("Qdrant JSON 查询命令仅支持 list_collections/get_collection/count/scroll/search") +} + +func (q *QdrantDB) createCollection(ctx context.Context, name string, cmd map[string]interface{}) error { + collection := strings.TrimSpace(name) + if collection == "" { + return fmt.Errorf("collection 名称不能为空") + } + body := make(map[string]interface{}) + if vectors, ok := cmd["vectors"]; ok { + body["vectors"] = vectors + } else { + size := intFromAny(firstExisting(cmd, "size", "vector_size", "vectorSize"), 0) + if size <= 0 { + return fmt.Errorf("Qdrant create_collection 命令缺少 vectors 或 size") + } + distance := firstStringValue(cmd, "distance", "metric") + if distance == "" { + distance = "Cosine" + } + body["vectors"] = map[string]interface{}{"size": size, "distance": distance} + } + for _, key := range []string{ + "sparse_vectors", + "shard_number", + "replication_factor", + "write_consistency_factor", + "on_disk_payload", + "hnsw_config", + "optimizers_config", + "wal_config", + "quantization_config", + "strict_mode_config", + "init_from", + } { + if value, ok := cmd[key]; ok { + body[key] = value + } + } + return q.doJSON(ctx, http.MethodPut, fmt.Sprintf("/collections/%s", url.PathEscape(collection)), body, nil) +} + +func (q *QdrantDB) deleteCollection(ctx context.Context, name string) error { + collection := strings.TrimSpace(name) + if collection == "" { + return fmt.Errorf("collection 名称不能为空") + } + return q.doJSON(ctx, http.MethodDelete, fmt.Sprintf("/collections/%s", url.PathEscape(collection)), nil, nil) +} + +func (q *QdrantDB) createPayloadIndex(ctx context.Context, collection string, cmd map[string]interface{}) error { + fieldName := firstStringValue(cmd, "field_name", "fieldName", "field") + if fieldName == "" { + return fmt.Errorf("Qdrant create_payload_index 命令缺少 field_name") + } + fieldSchema := firstExisting(cmd, "field_schema", "fieldSchema", "schema") + if fieldSchema == nil { + fieldSchema = "keyword" + } + body := map[string]interface{}{ + "field_name": fieldName, + "field_schema": fieldSchema, + } + return q.doJSON(ctx, http.MethodPut, fmt.Sprintf("/collections/%s/index", url.PathEscape(collection)), body, nil) +} + +func (q *QdrantDB) deletePayloadIndex(ctx context.Context, collection, fieldName string) error { + return q.doJSON(ctx, http.MethodDelete, fmt.Sprintf("/collections/%s/index/%s", url.PathEscape(collection), url.PathEscape(fieldName)), nil, nil) +} + +func (q *QdrantDB) upsertCommand(ctx context.Context, collection string, cmd map[string]interface{}) (int64, error) { + if rowsValue, ok := cmd["rows"].([]interface{}); ok { + rows := make([]map[string]interface{}, 0, len(rowsValue)) + for _, raw := range rowsValue { + if row, ok := raw.(map[string]interface{}); ok { + rows = append(rows, row) + } + } + return int64(len(rows)), q.upsertRows(ctx, collection, rows) + } + if points, ok := cmd["points"]; ok { + body := map[string]interface{}{"points": points} + return int64(len(anySlice(points))), q.doJSON(ctx, http.MethodPut, fmt.Sprintf("/collections/%s/points?wait=true", url.PathEscape(collection)), body, nil) + } + return 0, fmt.Errorf("Qdrant upsert 命令缺少 rows 或 points") +} + +func (q *QdrantDB) deleteCommand(ctx context.Context, collection string, cmd map[string]interface{}) (int64, error) { + body := make(map[string]interface{}) + if points, ok := cmd["points"]; ok { + body["points"] = qdrantPointIDSlice(points) + } else if ids, ok := cmd["ids"]; ok { + body["points"] = qdrantPointIDSlice(ids) + } else if filter, ok := cmd["filter"]; ok { + body["filter"] = filter + } + if len(body) == 0 { + return 0, fmt.Errorf("Qdrant delete 命令缺少 points/ids/filter") + } + count := int64(len(anySlice(firstExisting(body, "points")))) + return count, q.doJSON(ctx, http.MethodPost, fmt.Sprintf("/collections/%s/points/delete?wait=true", url.PathEscape(collection)), body, nil) +} + +func (q *QdrantDB) upsertRows(ctx context.Context, collection string, rows []map[string]interface{}) error { + if len(rows) == 0 { + return nil + } + points := make([]map[string]interface{}, 0, len(rows)) + for _, row := range rows { + id, ok := qdrantRowID(row) + if !ok { + return fmt.Errorf("Qdrant 写入行缺少 id") + } + vector, hasVector := qdrantRowVector(row) + if !hasVector { + return fmt.Errorf("Qdrant upsert 行缺少 vector/embedding") + } + points = append(points, map[string]interface{}{ + "id": id, + "vector": vector, + "payload": qdrantPayloadFromRow(row), + }) + } + body := map[string]interface{}{"points": points} + return q.doJSON(ctx, http.MethodPut, fmt.Sprintf("/collections/%s/points?wait=true", url.PathEscape(collection)), body, nil) +} + +func (q *QdrantDB) setPayloadFromRow(ctx context.Context, collection string, row map[string]interface{}) error { + id, ok := qdrantRowID(row) + if !ok { + return fmt.Errorf("Qdrant payload 更新缺少 id") + } + payload := qdrantPayloadFromRow(row) + if len(payload) == 0 { + return nil + } + body := map[string]interface{}{ + "points": []interface{}{id}, + "payload": payload, + } + return q.doJSON(ctx, http.MethodPost, fmt.Sprintf("/collections/%s/points/payload?wait=true", url.PathEscape(collection)), body, nil) +} + +type qdrantParsedSQL struct { + Collection string + Limit int + Offset interface{} + Count bool + IncludeVector bool +} + +var qdrantSQLFromRE = regexp.MustCompile(`(?i)\bFROM\s+(?:"([^"]+)"|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z0-9_.\-]+))`) +var qdrantSQLLimitRE = regexp.MustCompile(`(?i)\bLIMIT\s+(\d+)`) +var qdrantSQLOffsetRE = regexp.MustCompile(`(?i)\bOFFSET\s+([a-zA-Z0-9_.\-]+)`) + +func parseQdrantSQL(sqlText string) (qdrantParsedSQL, bool) { + text := strings.TrimSpace(sqlText) + if !strings.HasPrefix(strings.ToLower(text), "select") { + return qdrantParsedSQL{}, false + } + matches := qdrantSQLFromRE.FindStringSubmatch(text) + if len(matches) == 0 { + return qdrantParsedSQL{}, false + } + collection := firstNonEmpty(matches[1], matches[2], matches[3]) + if collection == "" { + return qdrantParsedSQL{}, false + } + parsed := qdrantParsedSQL{Collection: collection, Limit: 200} + lower := strings.ToLower(text) + parsed.Count = strings.Contains(lower, "count(") + parsed.IncludeVector = strings.Contains(lower, "vector") + if m := qdrantSQLLimitRE.FindStringSubmatch(text); len(m) > 1 { + parsed.Limit, _ = strconv.Atoi(m[1]) + } + if m := qdrantSQLOffsetRE.FindStringSubmatch(text); len(m) > 1 { + parsed.Offset = qdrantNormalizePointID(m[1]) + } + return parsed, true +} + +func qdrantPointRows(points []qdrantPoint) []map[string]interface{} { + rows := make([]map[string]interface{}, 0, len(points)) + for _, point := range points { + row := map[string]interface{}{"id": point.ID} + if point.Score != nil { + row["score"] = point.Score + } + if point.Version != nil { + row["version"] = point.Version + } + if point.Vector != nil { + row["vector"] = normalizeJSONLikeValue(point.Vector) + } + if point.Payload != nil { + row["payload"] = point.Payload + for key, value := range point.Payload { + row["payload."+key] = value + } + } + rows = append(rows, row) + } + return rows +} + +func qdrantRowID(row map[string]interface{}) (interface{}, bool) { + raw := firstExisting(row, "id", "_id") + if raw == nil { + return nil, false + } + text := strings.TrimSpace(fmt.Sprintf("%v", raw)) + if text == "" || text == "" { + return nil, false + } + return qdrantNormalizePointID(raw), true +} + +func qdrantNormalizePointID(value interface{}) interface{} { + switch v := value.(type) { + case json.Number: + if n, err := v.Int64(); err == nil { + return n + } + case float64: + if v == float64(int64(v)) { + return int64(v) + } + case float32: + if v == float32(int64(v)) { + return int64(v) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return v + case string: + text := strings.TrimSpace(v) + if n, err := strconv.ParseInt(text, 10, 64); err == nil { + return n + } + return text + } + return value +} + +func qdrantPointIDSlice(value interface{}) []interface{} { + items := anySlice(value) + result := make([]interface{}, 0, len(items)) + for _, item := range items { + result = append(result, qdrantNormalizePointID(item)) + } + return result +} + +func qdrantRowVector(row map[string]interface{}) (interface{}, bool) { + vector := firstExisting(row, "vector", "_vector", "vectors", "embedding", "_embedding") + if vector == nil { + return nil, false + } + return normalizeQdrantVector(vector), true +} + +func normalizeQdrantVector(value interface{}) interface{} { + if text, ok := value.(string); ok { + var parsed interface{} + if err := decodeJSONWithUseNumber([]byte(text), &parsed); err == nil { + return parsed + } + } + return value +} + +func qdrantPayloadFromRow(row map[string]interface{}) map[string]interface{} { + payload := make(map[string]interface{}) + if raw, ok := row["payload"].(map[string]interface{}); ok { + for key, value := range raw { + payload[key] = value + } + } + for key, value := range row { + if isQdrantReservedRowField(key) { + continue + } + if strings.HasPrefix(key, "payload.") { + payload[strings.TrimPrefix(key, "payload.")] = value + continue + } + payload[key] = value + } + return payload +} + +func isQdrantReservedRowField(key string) bool { + switch key { + case "id", "_id", "vector", "_vector", "vectors", "embedding", "_embedding", "payload", "score", "version", "next_page_offset": + return true + default: + return false + } +} + +func qdrantBoolValue(value interface{}, fallback bool) bool { + if value == nil { + return fallback + } + switch v := value.(type) { + case bool: + return v + case string: + text := strings.TrimSpace(strings.ToLower(v)) + if text == "" { + return fallback + } + return text == "1" || text == "true" || text == "yes" || text == "on" + default: + return fallback + } +} + +func qdrantVectorIndexes(info map[string]interface{}) []connection.IndexDefinition { + vectors := nestedMapValue(info, "config", "params", "vectors") + if len(vectors) == 0 { + return nil + } + if _, ok := vectors["size"]; ok { + return []connection.IndexDefinition{{Name: "VECTOR", ColumnName: "vector", NonUnique: 1, SeqInIndex: 1, IndexType: "VECTOR"}} + } + var indexes []connection.IndexDefinition + names := make([]string, 0, len(vectors)) + for name := range vectors { + names = append(names, name) + } + sort.Strings(names) + for index, name := range names { + indexes = append(indexes, connection.IndexDefinition{ + Name: "VECTOR_" + name, + ColumnName: "vector." + name, + NonUnique: 1, + SeqInIndex: index + 1, + IndexType: "VECTOR", + }) + } + return indexes +} + +func qdrantPayloadIndexes(info map[string]interface{}) []connection.IndexDefinition { + schema := nestedMapValue(info, "payload_schema") + if len(schema) == 0 { + schema = nestedMapValue(info, "payload_schema", "schema") + } + if len(schema) == 0 { + return nil + } + names := make([]string, 0, len(schema)) + for name := range schema { + names = append(names, name) + } + sort.Strings(names) + indexes := make([]connection.IndexDefinition, 0, len(names)) + for index, name := range names { + indexes = append(indexes, connection.IndexDefinition{ + Name: "PAYLOAD_" + name, + ColumnName: "payload." + name, + NonUnique: 1, + SeqInIndex: index + 1, + IndexType: "PAYLOAD", + }) + } + return indexes +} + +func nestedMapValue(value interface{}, path ...string) map[string]interface{} { + current := value + for _, key := range path { + m, ok := current.(map[string]interface{}) + if !ok { + return nil + } + current = m[key] + } + if m, ok := current.(map[string]interface{}); ok { + return m + } + return nil +} diff --git a/internal/db/qdrant_impl_test.go b/internal/db/qdrant_impl_test.go new file mode 100644 index 0000000..8dab118 --- /dev/null +++ b/internal/db/qdrant_impl_test.go @@ -0,0 +1,274 @@ +package db + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + + "GoNavi-Wails/internal/connection" +) + +func newMockQdrantServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + return server +} + +func newTestQdrantDB(t *testing.T, serverURL string) *QdrantDB { + t.Helper() + parsed, err := url.Parse(serverURL) + if err != nil { + t.Fatalf("parse server URL: %v", err) + } + host, port, ok := parseHostPortWithDefault(parsed.Host, defaultQdrantPort) + if !ok { + t.Fatalf("parse host port failed: %s", parsed.Host) + } + db := &QdrantDB{} + if err := db.Connect(connection.ConnectionConfig{ + Type: "qdrant", + Host: host, + Port: port, + UseSSL: strings.EqualFold(parsed.Scheme, "https"), + }); err != nil { + t.Fatalf("connect qdrant: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db +} + +func writeQdrantJSON(w http.ResponseWriter, value interface{}) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(value) +} + +func TestQdrantGetTables(t *testing.T) { + server := newMockQdrantServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && r.URL.Path == "/collections" { + writeQdrantJSON(w, map[string]interface{}{ + "result": map[string]interface{}{ + "collections": []map[string]interface{}{ + {"name": "products"}, + {"name": "logs"}, + }, + }, + }) + return + } + w.WriteHeader(http.StatusNotFound) + }) + + db := newTestQdrantDB(t, server.URL) + tables, err := db.GetTables("") + if err != nil { + t.Fatalf("GetTables failed: %v", err) + } + if strings.Join(tables, ",") != "logs,products" { + t.Fatalf("tables = %v", tables) + } +} + +func TestQdrantCreateCollectionBuildsVectorsBody(t *testing.T) { + var capturedBody map[string]interface{} + server := newMockQdrantServer(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/collections": + writeQdrantJSON(w, map[string]interface{}{"result": map[string]interface{}{"collections": []interface{}{}}}) + case r.Method == http.MethodPut && r.URL.Path == "/collections/products": + _ = json.NewDecoder(r.Body).Decode(&capturedBody) + writeQdrantJSON(w, map[string]interface{}{"result": true}) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + db := newTestQdrantDB(t, server.URL) + if _, err := db.Exec(`{"create_collection":"products","size":3,"distance":"Cosine","on_disk_payload":true}`); err != nil { + t.Fatalf("create collection failed: %v", err) + } + vectors, _ := capturedBody["vectors"].(map[string]interface{}) + if intFromAny(vectors["size"], 0) != 3 || vectors["distance"] != "Cosine" || capturedBody["on_disk_payload"] != true { + t.Fatalf("captured body = %#v", capturedBody) + } +} + +func TestQdrantSelectConvertsToScroll(t *testing.T) { + var capturedBody map[string]interface{} + server := newMockQdrantServer(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/collections": + writeQdrantJSON(w, map[string]interface{}{"result": map[string]interface{}{"collections": []interface{}{}}}) + case r.Method == http.MethodPost && r.URL.Path == "/collections/products/points/scroll": + _ = json.NewDecoder(r.Body).Decode(&capturedBody) + writeQdrantJSON(w, map[string]interface{}{ + "result": map[string]interface{}{ + "points": []map[string]interface{}{ + { + "id": 1, + "payload": map[string]interface{}{"category": "book", "price": 19.5}, + "vector": []float64{0.1, 0.2, 0.3}, + }, + }, + }, + }) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + db := newTestQdrantDB(t, server.URL) + rows, columns, err := db.Query(`SELECT id, vector FROM "products" LIMIT 10 OFFSET 5`) + if err != nil { + t.Fatalf("Query failed: %v", err) + } + if intFromAny(capturedBody["limit"], 0) != 10 || capturedBody["offset"] != float64(5) && capturedBody["offset"] != int64(5) { + t.Fatalf("captured body = %#v", capturedBody) + } + if len(rows) != 1 || rows[0]["id"] == nil || rows[0]["payload.category"] != "book" { + t.Fatalf("rows = %#v", rows) + } + if !containsString(columns, "payload.category") || !containsString(columns, "vector") { + t.Fatalf("columns = %v", columns) + } +} + +func TestQdrantJSONSearchFlattensResults(t *testing.T) { + server := newMockQdrantServer(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/collections": + writeQdrantJSON(w, map[string]interface{}{"result": map[string]interface{}{"collections": []interface{}{}}}) + case r.Method == http.MethodPost && r.URL.Path == "/collections/products/points/search": + writeQdrantJSON(w, map[string]interface{}{ + "result": []map[string]interface{}{ + { + "id": 1, + "score": 0.98, + "payload": map[string]interface{}{"category": "book"}, + "vector": []float64{0.1, 0.2, 0.3}, + }, + }, + }) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + db := newTestQdrantDB(t, server.URL) + rows, columns, err := db.Query(`{"search":"products","vector":[0.1,0.2,0.3],"limit":1}`) + if err != nil { + t.Fatalf("Query failed: %v", err) + } + if len(rows) != 1 || rows[0]["score"] == nil || rows[0]["payload.category"] != "book" { + t.Fatalf("rows = %#v", rows) + } + if !containsString(columns, "score") || !containsString(columns, "payload.category") { + t.Fatalf("columns = %v", columns) + } +} + +func TestQdrantApplyChangesUpsertPayloadAndDelete(t *testing.T) { + var upsertBody map[string]interface{} + var payloadBody map[string]interface{} + var deleteBody map[string]interface{} + server := newMockQdrantServer(t, func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/collections": + writeQdrantJSON(w, map[string]interface{}{"result": map[string]interface{}{"collections": []interface{}{}}}) + case r.Method == http.MethodPut && r.URL.Path == "/collections/products/points": + _ = json.NewDecoder(r.Body).Decode(&upsertBody) + writeQdrantJSON(w, map[string]interface{}{"result": map[string]interface{}{"operation_id": 1}}) + case r.Method == http.MethodPost && r.URL.Path == "/collections/products/points/payload": + _ = json.NewDecoder(r.Body).Decode(&payloadBody) + writeQdrantJSON(w, map[string]interface{}{"result": map[string]interface{}{"operation_id": 2}}) + case r.Method == http.MethodPost && r.URL.Path == "/collections/products/points/delete": + _ = json.NewDecoder(r.Body).Decode(&deleteBody) + writeQdrantJSON(w, map[string]interface{}{"result": map[string]interface{}{"operation_id": 3}}) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + db := newTestQdrantDB(t, server.URL) + err := db.ApplyChanges("products", connection.ChangeSet{ + Deletes: []map[string]interface{}{{"id": 9}}, + Updates: []connection.UpdateRow{{ + Keys: map[string]interface{}{"id": 1}, + Values: map[string]interface{}{"payload.category": "updated"}, + }}, + Inserts: []map[string]interface{}{ + {"id": 2, "vector": []float64{0.1, 0.2, 0.3}, "payload.kind": "new"}, + }, + }) + if err != nil { + t.Fatalf("ApplyChanges failed: %v", err) + } + if points := anySlice(deleteBody["points"]); len(points) != 1 || intFromAny(points[0], 0) != 9 { + t.Fatalf("delete body = %#v", deleteBody) + } + if points := anySlice(payloadBody["points"]); len(points) != 1 || intFromAny(points[0], 0) != 1 { + t.Fatalf("payload body = %#v", payloadBody) + } + payload, _ := payloadBody["payload"].(map[string]interface{}) + if payload["category"] != "updated" { + t.Fatalf("payload body = %#v", payloadBody) + } + points := anySlice(upsertBody["points"]) + if len(points) != 1 { + t.Fatalf("upsert body = %#v", upsertBody) + } + point, _ := points[0].(map[string]interface{}) + pointPayload, _ := point["payload"].(map[string]interface{}) + if intFromAny(point["id"], 0) != 2 || pointPayload["kind"] != "new" { + t.Fatalf("upsert body = %#v", upsertBody) + } +} + +func TestQdrantLiveSmoke(t *testing.T) { + serverURL := strings.TrimSpace(os.Getenv("GONAVI_QDRANT_TEST_URL")) + if serverURL == "" { + t.Skip("set GONAVI_QDRANT_TEST_URL to run live Qdrant smoke test") + } + + db := newTestQdrantDB(t, serverURL) + collection := "gonavi_smoke_live" + _, _ = db.Exec(fmt.Sprintf(`{"delete_collection":%q}`, collection)) + if _, err := db.Exec(fmt.Sprintf(`{"create_collection":%q,"size":3,"distance":"Cosine"}`, collection)); err != nil { + t.Fatalf("create live collection: %v", err) + } + t.Cleanup(func() { _, _ = db.Exec(fmt.Sprintf(`{"delete_collection":%q}`, collection)) }) + + if err := db.ApplyChanges(collection, connection.ChangeSet{ + Inserts: []map[string]interface{}{{ + "id": 1, + "vector": []float64{0.1, 0.2, 0.3}, + "payload.kind": "smoke", + }}, + }); err != nil { + t.Fatalf("upsert live row: %v", err) + } + + rows, columns, err := db.Query(fmt.Sprintf(`SELECT id, vector FROM "%s" LIMIT 5`, collection)) + if err != nil { + t.Fatalf("select live rows: %v", err) + } + if len(rows) == 0 || intFromAny(rows[0]["id"], 0) != 1 || rows[0]["payload.kind"] != "smoke" { + t.Fatalf("live rows = %#v", rows) + } + if !containsString(columns, "payload.kind") { + t.Fatalf("live columns missing payload.kind: %v", columns) + } + + queryRows, queryColumns, err := db.Query(fmt.Sprintf(`{"search":%q,"vector":[0.1,0.2,0.3],"limit":1}`, collection)) + if err != nil { + t.Fatalf("search live rows: %v", err) + } + if len(queryRows) == 0 || intFromAny(queryRows[0]["id"], 0) != 1 || !containsString(queryColumns, "score") { + t.Fatalf("live query rows = %#v columns = %v", queryRows, queryColumns) + } +}