diff --git a/frontend/src/components/ConnectionModal.edit-password.test.tsx b/frontend/src/components/ConnectionModal.edit-password.test.tsx
index 9dd7edc..2b570f4 100644
--- a/frontend/src/components/ConnectionModal.edit-password.test.tsx
+++ b/frontend/src/components/ConnectionModal.edit-password.test.tsx
@@ -33,13 +33,25 @@ describe('ConnectionModal data source registry', () => {
expect(source).toContain('type === "elasticsearch"');
expect(source).toContain("return '支持索引浏览、Mapping 检查、JSON DSL 和 query_string 查询';");
expect(source).toContain(
- 'type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch") ? "" : "root";',
+ 'type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch" || type === "chroma") ? "" : "root";',
);
expect(source).toContain(
- 'placeholder={dbType === "elasticsearch" ? "未开启认证可留空" : undefined}',
+ 'placeholder={(dbType === "elasticsearch" || dbType === "chroma") ? "未开启认证可留空" : undefined}',
);
expect(source).toContain('label="显示数据库 (留空显示全部)"');
});
+
+ it('exposes Chroma in the create-connection picker with vector defaults', () => {
+ expect(source).toContain("case 'chroma':");
+ expect(source).toContain('return 8000;');
+ expect(source).toContain('chroma: ["http", "https", "chroma"]');
+ expect(source).toContain("key: 'chroma'");
+ expect(source).toContain("name: 'Chroma'");
+ expect(source).toContain('type === "chroma"');
+ expect(source).toContain("return 'Collection 浏览、向量检索和元数据过滤';");
+ expect(source).toContain('return "http://127.0.0.1:8000/default_database?tenant=default_tenant";');
+ expect(source).toContain('return "tenant=default_tenant&apiKey=...";');
+ });
});
describe('ConnectionModal Redis Sentinel configuration', () => {
diff --git a/frontend/src/components/ConnectionModal.tsx b/frontend/src/components/ConnectionModal.tsx
index 27c60e2..c1f1fc8 100644
--- a/frontend/src/components/ConnectionModal.tsx
+++ b/frontend/src/components/ConnectionModal.tsx
@@ -1789,6 +1789,22 @@ const ConnectionModal: React.FC<{
parsedValues.useSSL = false;
parsedValues.sslMode = "disable";
}
+ } else if (type === "chroma") {
+ const tls = String(
+ parsed.params.get("tls") ||
+ parsed.params.get("ssl") ||
+ parsed.params.get("useSSL") ||
+ parsed.params.get("use_ssl") ||
+ "",
+ )
+ .trim()
+ .toLowerCase();
+ const skipVerify = normalizeBool(
+ parsed.params.get("skip_verify") || parsed.params.get("skipVerify"),
+ );
+ const enabled = tls ? normalizeBool(tls) : trimmedUri.toLowerCase().startsWith("https://");
+ parsedValues.useSSL = enabled;
+ parsedValues.sslMode = enabled ? (skipVerify ? "skip-verify" : "required") : "disable";
}
}
return parsedValues;
@@ -1851,6 +1867,9 @@ const ConnectionModal: React.FC<{
if (dbType === "clickhouse") {
return "clickhouse://default:pass@127.0.0.1:9000/default";
}
+ if (dbType === "chroma") {
+ return "http://127.0.0.1:8000/default_database?tenant=default_tenant";
+ }
if (dbType === "redis") {
return "redis://:pass@127.0.0.1:6379,127.0.0.2:6379/0?topology=cluster 或 redis://:pass@10.0.0.1:26379,10.0.0.2:26379/0?topology=sentinel&master=mymaster";
}
@@ -1892,6 +1911,8 @@ const ConnectionModal: React.FC<{
return "max_execution_time=60&compress=lz4";
case "mongodb":
return "retryWrites=true&readPreference=secondaryPreferred";
+ case "chroma":
+ return "tenant=default_tenant&apiKey=...";
case "dameng":
return "schema=SYSDBA";
case "tdengine":
@@ -2033,6 +2054,10 @@ const ConnectionModal: React.FC<{
const scheme =
type === "postgres"
? "postgresql"
+ : type === "chroma"
+ ? values.useSSL
+ ? "https"
+ : "http"
: type === "clickhouse" && clickHouseProtocol === "http"
? values.useSSL
? "https"
@@ -2083,6 +2108,11 @@ const ConnectionModal: React.FC<{
if (mode === "skip-verify" || mode === "preferred") {
params.set("skip_verify", "true");
}
+ } else if (type === "chroma") {
+ if (mode === "skip-verify" || mode === "preferred") {
+ params.set("skip_verify", "true");
+ }
+ appendSSLPathParamsForUri(params, type, values);
}
} else if (supportsSSLForType(type)) {
if (isPostgresCompatibleSSLType(type)) {
@@ -3675,7 +3705,7 @@ const ConnectionModal: React.FC<{
});
} else if (type !== "custom") {
const defaultUser =
- type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch") ? "" : "root";
+ type === "clickhouse" ? "default" : (type === "redis" || type === "elasticsearch" || type === "chroma") ? "" : "root";
const sslCapableType = supportsSSLForType(type);
setUseSSL(false);
setUseHttpTunnel(false);
@@ -4971,13 +5001,13 @@ const ConnectionModal: React.FC<{
name="user"
label="用户名"
rules={
- (dbType === "mongodb" || dbType === "elasticsearch")
+ (dbType === "mongodb" || dbType === "elasticsearch" || dbType === "chroma")
? []
: [createUriAwareRequiredRule("请输入用户名")]
}
style={{ marginBottom: 0 }}
>
-
+
{
expect(markup).toContain('alt="elasticsearch"');
});
+ it('includes Chroma in the selectable database icons', () => {
+ expect(DB_ICON_TYPES).toContain('chroma');
+ expect(getDbIconLabel('chroma')).toBe('Chroma');
+ const markup = renderToStaticMarkup(<>{getDbIcon('chroma', undefined, 22)}>);
+ expect(markup).toContain('>Ch');
+ });
+
it('wraps database icons in a consistent frame for sidebar sizing', () => {
const mysqlMarkup = renderToStaticMarkup(<>{getDbIcon('mysql', undefined, 22)}>);
const jvmMarkup = renderToStaticMarkup(<>{getDbIcon('jvm', undefined, 22)}>);
diff --git a/frontend/src/components/DatabaseIcons.tsx b/frontend/src/components/DatabaseIcons.tsx
index 9d8d97d..ae7280c 100644
--- a/frontend/src/components/DatabaseIcons.tsx
+++ b/frontend/src/components/DatabaseIcons.tsx
@@ -49,6 +49,7 @@ const DB_DEFAULT_COLORS: Record = {
highgo: '#00A86B',
iris: '#1F6FEB',
tdengine: '#2962FF',
+ chroma: '#7C3AED',
diros: '#0050B3',
starrocks: '#00A6A6',
sphinx: '#2F5D62',
@@ -178,6 +179,9 @@ const IrisIcon: React.FC = ({ size = 16, color }) => (
const TDengineIcon: React.FC = ({ size = 16, color }) => (
);
+const ChromaIcon: React.FC = ({ size = 16, color }) => (
+
+);
const JVMIcon: React.FC = ({ size = 16, color }) => (
);
@@ -231,6 +235,7 @@ const DB_ICON_MAP: Record> = {
highgo: HighGoIcon,
iris: IrisIcon,
tdengine: TDengineIcon,
+ chroma: ChromaIcon,
elasticsearch: ElasticsearchIcon,
custom: CustomIcon,
};
@@ -239,7 +244,7 @@ const DB_ICON_MAP: Record> = {
export const DB_ICON_TYPES: string[] = [
'mysql', 'mariadb', 'oceanbase', 'postgres', 'redis', 'mongodb', 'jvm',
'oracle', 'sqlserver', 'sqlite', 'duckdb', 'clickhouse', 'starrocks',
- 'kingbase', 'dameng', 'vastbase', 'opengauss', 'highgo', 'iris', 'tdengine', 'elasticsearch', 'custom',
+ 'kingbase', 'dameng', 'vastbase', 'opengauss', 'highgo', 'iris', 'tdengine', 'chroma', 'elasticsearch', 'custom',
];
/** 该类型是否有品牌 SVG 文件 */
@@ -262,6 +267,7 @@ export const getDbIconLabel = (type: string): string => {
starrocks: 'StarRocks',
duckdb: 'DuckDB', kingbase: '金仓', dameng: '达梦',
vastbase: 'VastBase', opengauss: 'OpenGauss', highgo: '瀚高', iris: 'InterSystems IRIS', tdengine: 'TDengine',
+ chroma: 'Chroma',
elasticsearch: 'Elasticsearch',
custom: '自定义',
};
diff --git a/frontend/src/utils/connectionDriverType.test.ts b/frontend/src/utils/connectionDriverType.test.ts
index 1d9230d..2b3da04 100644
--- a/frontend/src/utils/connectionDriverType.test.ts
+++ b/frontend/src/utils/connectionDriverType.test.ts
@@ -12,6 +12,8 @@ describe('connectionDriverType', () => {
expect(normalizeDriverType('postgresql')).toBe('postgres');
expect(normalizeDriverType('pgx')).toBe('postgres');
expect(normalizeDriverType('elastic')).toBe('elasticsearch');
+ expect(normalizeDriverType('chromadb')).toBe('chroma');
+ expect(normalizeDriverType('chroma-db')).toBe('chroma');
expect(normalizeDriverType('doris')).toBe('diros');
expect(normalizeDriverType('open-gauss')).toBe('opengauss');
expect(normalizeDriverType('InterSystemsIRIS')).toBe('iris');
diff --git a/frontend/src/utils/connectionDriverType.ts b/frontend/src/utils/connectionDriverType.ts
index 0753a2e..7f9a93f 100644
--- a/frontend/src/utils/connectionDriverType.ts
+++ b/frontend/src/utils/connectionDriverType.ts
@@ -15,6 +15,7 @@ export const normalizeDriverType = (value: string): string => {
const normalized = String(value || '').trim().toLowerCase();
if (normalized === 'postgresql' || normalized === 'pg' || normalized === 'pq' || normalized === 'pgx') return 'postgres';
if (normalized === 'elastic') return 'elasticsearch';
+ if (normalized === 'chromadb' || normalized === 'chroma-db') return 'chroma';
if (normalized === 'doris') return 'diros';
if (
normalized === 'open_gauss' ||
diff --git a/frontend/src/utils/connectionModalPresentation.test.ts b/frontend/src/utils/connectionModalPresentation.test.ts
index c45a58c..e80a9be 100644
--- a/frontend/src/utils/connectionModalPresentation.test.ts
+++ b/frontend/src/utils/connectionModalPresentation.test.ts
@@ -87,6 +87,7 @@ describe('connectionModalPresentation', () => {
'iris',
'mongodb',
'elasticsearch',
+ 'chroma',
'redis',
'tdengine',
'custom',
@@ -156,11 +157,20 @@ describe('connectionModalPresentation', () => {
'credentials',
'databaseScope',
]);
+ expect(resolveConnectionConfigLayout('chroma').sections).toEqual([
+ 'identity',
+ 'uri',
+ 'target',
+ 'service',
+ 'credentials',
+ 'databaseScope',
+ ]);
});
it('uses localized labels for layout kinds shown in the modal', () => {
expect(getConnectionConfigLayoutKindLabel('mysql-compatible')).toBe('MySQL 兼容');
expect(getConnectionConfigLayoutKindLabel('file')).toBe('文件型数据库');
expect(getConnectionConfigLayoutKindLabel('search')).toBe('搜索引擎');
+ expect(getConnectionConfigLayoutKindLabel('vector')).toBe('向量数据库');
});
});
diff --git a/frontend/src/utils/connectionModalPresentation.ts b/frontend/src/utils/connectionModalPresentation.ts
index ec6462a..7dd7b3c 100644
--- a/frontend/src/utils/connectionModalPresentation.ts
+++ b/frontend/src/utils/connectionModalPresentation.ts
@@ -40,6 +40,7 @@ export type ConnectionConfigLayoutKind =
| 'oracle'
| 'file'
| 'search'
+ | 'vector'
| 'custom'
| 'jvm'
| 'generic-sql';
@@ -160,6 +161,8 @@ export const getConnectionConfigLayoutKindLabel = (
return '文件型数据库';
case 'search':
return '搜索引擎';
+ case 'vector':
+ return '向量数据库';
case 'custom':
return '自定义连接';
case 'jvm':
@@ -249,6 +252,19 @@ export const resolveConnectionConfigLayout = (
],
};
}
+ if (type === 'chroma') {
+ return {
+ kind: 'vector',
+ sections: [
+ 'identity',
+ 'uri',
+ 'target',
+ 'service',
+ 'credentials',
+ 'databaseScope',
+ ],
+ };
+ }
if (postgresCompatibleTypes.has(type)) {
return {
kind: 'postgres-compatible',
diff --git a/frontend/src/utils/connectionTypeCapabilities.test.ts b/frontend/src/utils/connectionTypeCapabilities.test.ts
index bdad996..9301efe 100644
--- a/frontend/src/utils/connectionTypeCapabilities.test.ts
+++ b/frontend/src/utils/connectionTypeCapabilities.test.ts
@@ -17,6 +17,7 @@ describe('connectionTypeCapabilities', () => {
expect(singleHostUriSchemesByType.opengauss).toContain('jdbc:opengauss');
expect(singleHostUriSchemesByType.dameng).toEqual(['dameng', 'dm']);
expect(singleHostUriSchemesByType.elasticsearch).toEqual(['http', 'https']);
+ expect(singleHostUriSchemesByType.chroma).toEqual(['http', 'https', 'chroma']);
expect(singleHostUriSchemesByType.redis).toEqual(['redis']);
});
@@ -24,6 +25,7 @@ describe('connectionTypeCapabilities', () => {
expect(supportsSSLForType('redis')).toBe(true);
expect(supportsSSLForType('MongoDB')).toBe(true);
expect(supportsSSLForType('elasticsearch')).toBe(true);
+ expect(supportsSSLForType('chroma')).toBe(true);
expect(supportsSSLForType('tdengine')).toBe(true);
expect(supportsSSLForType('dameng')).toBe(true);
expect(supportsSSLForType('sqlite')).toBe(false);
@@ -36,6 +38,8 @@ describe('connectionTypeCapabilities', () => {
expect(supportsSSLClientCertificateForType('sqlserver')).toBe(false);
expect(supportsSSLCAPathForType('redis')).toBe(true);
expect(supportsSSLClientCertificateForType('redis')).toBe(true);
+ expect(supportsSSLCAPathForType('chroma')).toBe(true);
+ expect(supportsSSLClientCertificateForType('chroma')).toBe(false);
});
it('detects postgres-compatible SSL parameter dialects', () => {
@@ -63,6 +67,7 @@ describe('connectionTypeCapabilities', () => {
expect(supportsConnectionParamsForType('dameng')).toBe(true);
expect(supportsConnectionParamsForType('tdengine')).toBe(true);
expect(supportsConnectionParamsForType('elasticsearch')).toBe(true);
+ expect(supportsConnectionParamsForType('chroma')).toBe(true);
expect(supportsConnectionParamsForType('redis')).toBe(false);
expect(supportsConnectionParamsForType('sqlite')).toBe(false);
expect(supportsConnectionParamsForType('jvm')).toBe(false);
diff --git a/frontend/src/utils/connectionTypeCapabilities.ts b/frontend/src/utils/connectionTypeCapabilities.ts
index e6dff16..e158a13 100644
--- a/frontend/src/utils/connectionTypeCapabilities.ts
+++ b/frontend/src/utils/connectionTypeCapabilities.ts
@@ -12,6 +12,7 @@ export const singleHostUriSchemesByType: Record = {
highgo: ["highgo"],
vastbase: ["vastbase"],
elasticsearch: ["http", "https"],
+ chroma: ["http", "https", "chroma"],
};
const normalizeConnectionType = (type: string) =>
@@ -40,6 +41,7 @@ const sslSupportedTypes = new Set([
"redis",
"tdengine",
"elasticsearch",
+ "chroma",
]);
export const supportsSSLForType = (type: string) =>
@@ -62,6 +64,7 @@ const sslCAPathSupportedTypes = new Set([
"mongodb",
"redis",
"elasticsearch",
+ "chroma",
]);
const sslClientCertificateSupportedTypes = new Set([
@@ -123,4 +126,5 @@ export const supportsConnectionParamsForType = (type: string) =>
type === "mongodb" ||
type === "dameng" ||
type === "tdengine" ||
- type === "elasticsearch";
+ type === "elasticsearch" ||
+ type === "chroma";
diff --git a/frontend/src/utils/connectionTypeCatalog.test.ts b/frontend/src/utils/connectionTypeCatalog.test.ts
index c6909d0..1a841ed 100644
--- a/frontend/src/utils/connectionTypeCatalog.test.ts
+++ b/frontend/src/utils/connectionTypeCatalog.test.ts
@@ -13,6 +13,7 @@ describe('connectionTypeCatalog', () => {
'关系型数据库',
'国产数据库',
'NoSQL',
+ '向量数据库',
'时序数据库',
'其他',
]);
@@ -23,6 +24,7 @@ describe('connectionTypeCatalog', () => {
expect(keys).toContain('mongodb');
expect(keys).toContain('redis');
expect(keys).toContain('elasticsearch');
+ expect(keys).toContain('chroma');
expect(keys).toContain('jvm');
expect(keys).toContain('custom');
expect(new Set(keys).size).toBe(keys.length);
@@ -37,6 +39,7 @@ describe('connectionTypeCatalog', () => {
expect(getConnectionTypeDefaultPort('oracle')).toBe(1521);
expect(getConnectionTypeDefaultPort('mongodb')).toBe(27017);
expect(getConnectionTypeDefaultPort('elasticsearch')).toBe(9200);
+ expect(getConnectionTypeDefaultPort('chroma')).toBe(8000);
expect(getConnectionTypeDefaultPort('sqlite')).toBe(0);
expect(getConnectionTypeDefaultPort('duckdb')).toBe(0);
expect(getConnectionTypeDefaultPort('unknown')).toBe(3306);
@@ -46,6 +49,7 @@ describe('connectionTypeCatalog', () => {
expect(getConnectionTypeHint('redis')).toBe('单机 / 哨兵 / 集群');
expect(getConnectionTypeHint('mongodb')).toBe('单机 / 副本集');
expect(getConnectionTypeHint('elasticsearch')).toContain('Mapping');
+ expect(getConnectionTypeHint('chroma')).toContain('向量');
expect(getConnectionTypeHint('oceanbase')).toBe('MySQL / Oracle 租户');
expect(getConnectionTypeHint('duckdb')).toBe('本地文件连接');
expect(getConnectionTypeHint('mysql')).toBe('标准连接配置');
diff --git a/frontend/src/utils/connectionTypeCatalog.ts b/frontend/src/utils/connectionTypeCatalog.ts
index c900c28..cd508a5 100644
--- a/frontend/src/utils/connectionTypeCatalog.ts
+++ b/frontend/src/utils/connectionTypeCatalog.ts
@@ -45,6 +45,12 @@ export const CONNECTION_TYPE_GROUPS: ConnectionTypeCatalogGroup[] = [
{ key: 'elasticsearch', name: 'Elasticsearch' },
],
},
+ {
+ label: '向量数据库',
+ items: [
+ { key: 'chroma', name: 'Chroma' },
+ ],
+ },
{
label: '时序数据库',
items: [
@@ -97,6 +103,8 @@ export const getConnectionTypeDefaultPort = (type: string): number => {
return 27017;
case 'elasticsearch':
return 9200;
+ case 'chroma':
+ return 8000;
case 'highgo':
return 5866;
case 'mariadb':
@@ -123,6 +131,8 @@ export const getConnectionTypeHint = (type: string): string => {
return '单机 / 副本集';
case 'elasticsearch':
return '支持索引浏览、Mapping 检查、JSON DSL 和 query_string 查询';
+ case 'chroma':
+ return 'Collection 浏览、向量检索和元数据过滤';
case 'oceanbase':
return 'MySQL / Oracle 租户';
case 'sqlite':
diff --git a/frontend/src/utils/dataSourceCapabilities.test.ts b/frontend/src/utils/dataSourceCapabilities.test.ts
index db863ce..0fed0ed 100644
--- a/frontend/src/utils/dataSourceCapabilities.test.ts
+++ b/frontend/src/utils/dataSourceCapabilities.test.ts
@@ -72,6 +72,24 @@ describe('dataSourceCapabilities', () => {
});
});
+ it('treats Chroma as a queryable vector datasource without SQL export actions', () => {
+ expect(getDataSourceCapabilities({ type: 'chroma' })).toMatchObject({
+ type: 'chroma',
+ supportsQueryEditor: true,
+ supportsSqlQueryExport: false,
+ supportsCopyInsert: false,
+ supportsCreateDatabase: false,
+ supportsRenameDatabase: false,
+ supportsDropDatabase: false,
+ forceReadOnlyQueryResult: false,
+ });
+ expect(getDataSourceCapabilities({ type: 'custom', driver: 'chromadb' })).toMatchObject({
+ type: 'chroma',
+ supportsQueryEditor: true,
+ supportsCopyInsert: false,
+ });
+ });
+
it('treats OceanBase Oracle protocol as Oracle capabilities', () => {
expect(getDataSourceCapabilities({
type: 'oceanbase',
diff --git a/frontend/src/utils/dataSourceCapabilities.ts b/frontend/src/utils/dataSourceCapabilities.ts
index 7543178..2e2ea5f 100644
--- a/frontend/src/utils/dataSourceCapabilities.ts
+++ b/frontend/src/utils/dataSourceCapabilities.ts
@@ -21,6 +21,9 @@ const normalizeDataSourceToken = (raw: string): string => {
case 'elastic':
case 'elasticsearch':
return 'elasticsearch';
+ case 'chromadb':
+ case 'chroma-db':
+ return 'chroma';
case 'intersystems':
case 'intersystemsiris':
case 'inter-systems':
diff --git a/frontend/src/utils/sqlDialect.test.ts b/frontend/src/utils/sqlDialect.test.ts
index 5c279c0..ce4a7a6 100644
--- a/frontend/src/utils/sqlDialect.test.ts
+++ b/frontend/src/utils/sqlDialect.test.ts
@@ -30,6 +30,8 @@ describe('sqlDialect', () => {
expect(resolveSqlDialect('custom', 'open_gauss')).toBe('opengauss');
expect(resolveSqlDialect('Elasticsearch')).toBe('elasticsearch');
expect(resolveSqlDialect('custom', 'elastic')).toBe('elasticsearch');
+ expect(resolveSqlDialect('ChromaDB')).toBe('chroma');
+ expect(resolveSqlDialect('custom', 'chroma-db')).toBe('chroma');
expect(resolveSqlDialect('OceanBase', '', { oceanBaseProtocol: 'oracle' })).toBe('oracle');
expect(resolveSqlDialect('custom', 'oceanbase', { oceanBaseProtocol: 'oracle' })).toBe('oracle');
expect(isMysqlFamilyDialect('mariadb')).toBe(true);
diff --git a/frontend/src/utils/sqlDialect.ts b/frontend/src/utils/sqlDialect.ts
index ee58138..bdb9b08 100644
--- a/frontend/src/utils/sqlDialect.ts
+++ b/frontend/src/utils/sqlDialect.ts
@@ -30,6 +30,7 @@ export type SqlDialect =
| 'mongodb'
| 'redis'
| 'elasticsearch'
+ | 'chroma'
| 'unknown'
| string;
@@ -115,6 +116,10 @@ export const resolveSqlDialect = (
return source;
case 'elastic':
return 'elasticsearch';
+ case 'chromadb':
+ case 'chroma-db':
+ case 'chroma':
+ return 'chroma';
default:
break;
}
@@ -140,6 +145,7 @@ export const resolveSqlDialect = (
if (source.includes('sqlserver') || source.includes('mssql')) return 'sqlserver';
if (source.includes('iris') || source.includes('intersystems')) return 'iris';
if (source.includes('elastic')) return 'elasticsearch';
+ if (source.includes('chroma')) return 'chroma';
return source;
};
diff --git a/internal/app/db_proxy.go b/internal/app/db_proxy.go
index a5e847b..63be9be 100644
--- a/internal/app/db_proxy.go
+++ b/internal/app/db_proxy.go
@@ -233,6 +233,8 @@ func defaultPortByType(driverType string) int {
return 5866
case "iris":
return 1972
+ case "chroma":
+ return 8000
default:
return 0
}
diff --git a/internal/db/chroma_impl.go b/internal/db/chroma_impl.go
new file mode 100644
index 0000000..9e7fd7f
--- /dev/null
+++ b/internal/db/chroma_impl.go
@@ -0,0 +1,1196 @@
+package db
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+
+ "GoNavi-Wails/internal/connection"
+ "GoNavi-Wails/internal/logger"
+ proxytunnel "GoNavi-Wails/internal/proxy"
+ "GoNavi-Wails/internal/ssh"
+)
+
+const (
+ defaultChromaPort = 8000
+ defaultChromaTenant = "default_tenant"
+ defaultChromaDatabase = "default_database"
+ defaultChromaQueryTimeout = 30 * time.Second
+)
+
+type ChromaDB struct {
+ client *http.Client
+ baseURL string
+ tenant string
+ database string
+ apiVersion int
+ authHeaders map[string]string
+ forwarder *ssh.LocalForwarder
+}
+
+type chromaCollection struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Metadata map[string]interface{} `json:"metadata"`
+ Dimension int `json:"dimension"`
+ Tenant string `json:"tenant"`
+ Database string `json:"database"`
+}
+
+type chromaGetResponse struct {
+ IDs []string `json:"ids"`
+ Documents []interface{} `json:"documents"`
+ Metadatas []map[string]interface{} `json:"metadatas"`
+ Embeddings []interface{} `json:"embeddings"`
+ Included []string `json:"included"`
+}
+
+func (c *ChromaDB) Connect(config connection.ConnectionConfig) error {
+ if c.forwarder != nil {
+ _ = c.forwarder.Close()
+ c.forwarder = nil
+ }
+ c.client = nil
+
+ runConfig := normalizeChromaConfig(config)
+ if runConfig.UseSSH {
+ forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, runConfig.Host, runConfig.Port)
+ if err != nil {
+ return fmt.Errorf("创建 SSH 隧道失败:%w", err)
+ }
+ c.forwarder = forwarder
+
+ host, portText, err := net.SplitHostPort(forwarder.LocalAddr)
+ if err != nil {
+ return fmt.Errorf("解析本地转发地址失败:%w", err)
+ }
+ port, err := strconv.Atoi(portText)
+ if err != nil {
+ return fmt.Errorf("解析本地端口失败:%w", err)
+ }
+ runConfig.Host = host
+ runConfig.Port = port
+ runConfig.UseSSH = false
+ logger.Infof("Chroma 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
+ }
+
+ c.tenant = chromaTenantFromConfig(runConfig)
+ c.database = chromaDatabaseFromConfig(runConfig)
+ c.baseURL = buildChromaBaseURL(runConfig)
+ c.authHeaders = chromaAuthHeaders(runConfig)
+ c.client = buildChromaHTTPClient(runConfig)
+
+ if err := c.Ping(); err != nil {
+ _ = c.Close()
+ return err
+ }
+ return nil
+}
+
+func (c *ChromaDB) Close() error {
+ if c.forwarder != nil {
+ if err := c.forwarder.Close(); err != nil {
+ logger.Warnf("关闭 Chroma SSH 端口转发失败:%v", err)
+ }
+ c.forwarder = nil
+ }
+ c.client = nil
+ return nil
+}
+
+func (c *ChromaDB) Ping() error {
+ if c.client == nil {
+ return fmt.Errorf("连接未打开")
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ if err := c.detectVersion(ctx); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (c *ChromaDB) Query(query string) ([]map[string]interface{}, []string, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), defaultChromaQueryTimeout)
+ defer cancel()
+ return c.QueryContext(ctx, query)
+}
+
+func (c *ChromaDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
+ if c.client == nil {
+ return nil, nil, fmt.Errorf("连接未打开")
+ }
+ text := strings.TrimSpace(query)
+ if text == "" {
+ return nil, nil, fmt.Errorf("查询语句不能为空")
+ }
+
+ if strings.HasPrefix(text, "{") {
+ return c.queryJSON(ctx, text)
+ }
+
+ if parsed, ok := parseChromaSQL(text); ok {
+ if parsed.Count {
+ total, err := c.countCollection(ctx, parsed.Collection, parsed.Where)
+ if err != nil {
+ return nil, nil, err
+ }
+ return []map[string]interface{}{{"total": total}}, []string{"total"}, nil
+ }
+ include := []string{"documents", "metadatas"}
+ if parsed.IncludeEmbeddings {
+ include = append(include, "embeddings")
+ }
+ return c.getCollectionRows(ctx, parsed.Collection, parsed.Limit, parsed.Offset, parsed.Where, include)
+ }
+
+ return nil, nil, fmt.Errorf("Chroma 查询仅支持 JSON 命令或简单 SELECT 预览")
+}
+
+func (c *ChromaDB) Exec(query string) (int64, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), defaultChromaQueryTimeout)
+ defer cancel()
+ return c.ExecContext(ctx, query)
+}
+
+func (c *ChromaDB) ExecContext(ctx context.Context, query string) (int64, error) {
+ if c.client == nil {
+ return 0, fmt.Errorf("连接未打开")
+ }
+ var cmd map[string]interface{}
+ if err := decodeJSONWithUseNumber([]byte(strings.TrimSpace(query)), &cmd); err != nil {
+ return 0, fmt.Errorf("Chroma 写入命令必须是 JSON:%w", err)
+ }
+ if name := firstStringValue(cmd, "create_collection", "createCollection", "collection"); name != "" && hasAnyKey(cmd, "create_collection", "createCollection") {
+ body := map[string]interface{}{"name": name}
+ if metadata, ok := cmd["metadata"]; ok {
+ body["metadata"] = metadata
+ }
+ if getOrBool(cmd, "get_or_create", "getOrCreate") {
+ body["get_or_create"] = true
+ }
+ return 1, c.createCollection(ctx, body)
+ }
+ if name := firstStringValue(cmd, "delete_collection", "deleteCollection"); name != "" {
+ return 1, c.deleteCollection(ctx, name)
+ }
+ if name := firstStringValue(cmd, "upsert", "collection"); name != "" && hasAnyKey(cmd, "upsert") {
+ return c.upsertCommand(ctx, name, cmd)
+ }
+ if name := firstStringValue(cmd, "delete", "collection"); name != "" && hasAnyKey(cmd, "delete") {
+ return c.deleteCommand(ctx, name, cmd)
+ }
+ return 0, fmt.Errorf("Chroma JSON 写入命令仅支持 create_collection/delete_collection/upsert/delete")
+}
+
+func (c *ChromaDB) GetDatabases() ([]string, error) {
+ if c.client == nil {
+ return nil, fmt.Errorf("连接未打开")
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ if err := c.ensureVersion(ctx); err != nil {
+ return nil, err
+ }
+ if c.apiVersion != 2 {
+ return []string{c.database}, nil
+ }
+
+ var raw []map[string]interface{}
+ err := c.doJSON(ctx, http.MethodGet, fmt.Sprintf("/api/v2/tenants/%s/databases", url.PathEscape(c.tenant)), nil, &raw)
+ if err != nil {
+ return []string{c.database}, nil
+ }
+ names := make([]string, 0, len(raw))
+ for _, item := range raw {
+ if name := mapString(item, "name"); name != "" {
+ names = append(names, name)
+ }
+ }
+ if len(names) == 0 {
+ names = append(names, c.database)
+ }
+ sort.Strings(names)
+ return names, nil
+}
+
+func (c *ChromaDB) GetTables(dbName string) ([]string, error) {
+ collections, err := c.listCollections(context.Background(), dbName)
+ if err != nil {
+ return nil, err
+ }
+ names := make([]string, 0, len(collections))
+ for _, item := range collections {
+ if strings.TrimSpace(item.Name) != "" {
+ names = append(names, item.Name)
+ }
+ }
+ sort.Strings(names)
+ return names, nil
+}
+
+func (c *ChromaDB) GetCreateStatement(dbName, tableName string) (string, error) {
+ coll, err := c.resolveCollection(context.Background(), dbName, tableName)
+ if err != nil {
+ return "", err
+ }
+ payload, _ := json.MarshalIndent(coll, "", " ")
+ return fmt.Sprintf("// Chroma collection: %s\n%s", coll.Name, string(payload)), nil
+}
+
+func (c *ChromaDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
+ rows, _, err := c.getCollectionRows(context.Background(), tableNameOrDB(dbName, tableName), 20, 0, nil, []string{"documents", "metadatas", "embeddings"})
+ if err != nil {
+ return nil, err
+ }
+ cols := []connection.ColumnDefinition{
+ {Name: "id", Type: "string", Nullable: "NO", Key: "PRI", Comment: "Chroma document id"},
+ {Name: "document", Type: "text", Nullable: "YES", Comment: "Document text"},
+ {Name: "metadata", Type: "json", Nullable: "YES", Comment: "Full metadata object"},
+ {Name: "embedding", Type: "vector", Nullable: "YES", Comment: "Embedding vector"},
+ }
+ seen := map[string]struct{}{"id": {}, "document": {}, "metadata": {}, "embedding": {}}
+ for _, row := range rows {
+ for key, value := range row {
+ if _, exists := seen[key]; exists || !strings.HasPrefix(key, "metadata.") {
+ continue
+ }
+ seen[key] = struct{}{}
+ cols = append(cols, connection.ColumnDefinition{
+ Name: key,
+ Type: inferChromaValueType(value),
+ Nullable: "YES",
+ Comment: "Metadata field",
+ })
+ }
+ }
+ return cols, nil
+}
+
+func (c *ChromaDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
+ tables, err := c.GetTables(dbName)
+ if err != nil {
+ return nil, err
+ }
+ var result []connection.ColumnDefinitionWithTable
+ for _, table := range tables {
+ cols, err := c.GetColumns(dbName, table)
+ if err != nil {
+ continue
+ }
+ for _, col := range cols {
+ result = append(result, connection.ColumnDefinitionWithTable{
+ TableName: table,
+ Name: col.Name,
+ Type: col.Type,
+ Comment: col.Comment,
+ })
+ }
+ }
+ return result, nil
+}
+
+func (c *ChromaDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
+ return []connection.IndexDefinition{
+ {Name: "PRIMARY", ColumnName: "id", NonUnique: 0, SeqInIndex: 1, IndexType: "PRIMARY"},
+ {Name: "HNSW", ColumnName: "embedding", NonUnique: 1, SeqInIndex: 1, IndexType: "VECTOR"},
+ }, nil
+}
+
+func (c *ChromaDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
+ return []connection.ForeignKeyDefinition{}, nil
+}
+
+func (c *ChromaDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
+ return []connection.TriggerDefinition{}, nil
+}
+
+func (c *ChromaDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
+ ctx, cancel := context.WithTimeout(context.Background(), defaultChromaQueryTimeout)
+ defer cancel()
+
+ if len(changes.Deletes) > 0 {
+ ids := make([]string, 0, len(changes.Deletes))
+ for _, row := range changes.Deletes {
+ if id := chromaRowID(row); id != "" {
+ ids = append(ids, id)
+ }
+ }
+ if len(ids) > 0 {
+ if _, err := c.deleteCommand(ctx, tableName, map[string]interface{}{"ids": ids}); err != nil {
+ return err
+ }
+ }
+ }
+
+ if len(changes.Updates) > 0 {
+ rows := make([]map[string]interface{}, 0, len(changes.Updates))
+ for _, update := range changes.Updates {
+ row := make(map[string]interface{}, len(update.Keys)+len(update.Values))
+ for k, v := range update.Keys {
+ row[k] = v
+ }
+ for k, v := range update.Values {
+ row[k] = v
+ }
+ rows = append(rows, row)
+ }
+ if err := c.upsertRows(ctx, tableName, rows); err != nil {
+ return err
+ }
+ }
+ if len(changes.Inserts) > 0 {
+ if err := c.upsertRows(ctx, tableName, changes.Inserts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func normalizeChromaConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
+ runConfig := applyChromaURI(config)
+ if strings.TrimSpace(runConfig.Host) == "" {
+ runConfig.Host = "localhost"
+ }
+ if runConfig.Port <= 0 {
+ runConfig.Port = defaultChromaPort
+ }
+ if strings.TrimSpace(runConfig.SSLMode) == "" && runConfig.UseSSL {
+ runConfig.SSLMode = "required"
+ }
+ return runConfig
+}
+
+func applyChromaURI(config connection.ConnectionConfig) connection.ConnectionConfig {
+ uriText := strings.TrimSpace(config.URI)
+ if uriText == "" {
+ return config
+ }
+ parsed, err := url.Parse(uriText)
+ if err != nil {
+ return config
+ }
+ scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
+ if scheme != "http" && scheme != "https" && scheme != "chroma" {
+ return config
+ }
+ if parsed.User != nil {
+ if strings.TrimSpace(config.User) == "" {
+ config.User = parsed.User.Username()
+ }
+ if pass, ok := parsed.User.Password(); ok && config.Password == "" {
+ config.Password = pass
+ }
+ }
+ if scheme == "https" {
+ config.UseSSL = true
+ }
+ if host := strings.TrimSpace(parsed.Host); host != "" {
+ if h, port, ok := parseHostPortWithDefault(host, defaultChromaPort); ok {
+ config.Host = h
+ config.Port = port
+ }
+ }
+ if dbName := strings.Trim(strings.TrimSpace(parsed.Path), "/"); dbName != "" && !strings.HasPrefix(dbName, "api/") && strings.TrimSpace(config.Database) == "" {
+ config.Database = dbName
+ }
+ return config
+}
+
+func buildChromaBaseURL(config connection.ConnectionConfig) string {
+ scheme := "http"
+ if config.UseSSL {
+ scheme = "https"
+ }
+ return fmt.Sprintf("%s://%s:%d", scheme, strings.TrimSpace(config.Host), config.Port)
+}
+
+func chromaTenantFromConfig(config connection.ConnectionConfig) string {
+ params := chromaConnectionParams(config)
+ if tenant := strings.TrimSpace(params.Get("tenant")); tenant != "" {
+ return tenant
+ }
+ return defaultChromaTenant
+}
+
+func chromaDatabaseFromConfig(config connection.ConnectionConfig) string {
+ if dbName := strings.TrimSpace(config.Database); dbName != "" {
+ return dbName
+ }
+ params := chromaConnectionParams(config)
+ if dbName := strings.TrimSpace(params.Get("database")); dbName != "" {
+ return dbName
+ }
+ return defaultChromaDatabase
+}
+
+func chromaConnectionParams(config connection.ConnectionConfig) url.Values {
+ params := url.Values{}
+ mergeConnectionParamValues(params, connectionParamsFromURI(config.URI, "http", "https", "chroma"))
+ mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
+ return params
+}
+
+func chromaAuthHeaders(config connection.ConnectionConfig) map[string]string {
+ headers := make(map[string]string)
+ params := chromaConnectionParams(config)
+ token := firstNonEmpty(params.Get("apiKey"), params.Get("apikey"), params.Get("token"), params.Get("authToken"))
+ if token == "" && strings.TrimSpace(config.User) == "" {
+ token = strings.TrimSpace(config.Password)
+ }
+ if token != "" {
+ headers["Authorization"] = "Bearer " + token
+ } else if user := strings.TrimSpace(config.User); user != "" {
+ raw := user + ":" + config.Password
+ headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(raw))
+ }
+ if headerName := strings.TrimSpace(params.Get("authHeader")); headerName != "" {
+ if headerValue := strings.TrimSpace(params.Get("authHeaderValue")); headerValue != "" && isSafeConnectionParamKey(headerName) {
+ headers[headerName] = headerValue
+ }
+ }
+ return headers
+}
+
+func buildChromaHTTPClient(config connection.ConnectionConfig) *http.Client {
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+ if tlsConfig, err := resolveGenericTLSConfig(config); err == nil && tlsConfig != nil {
+ transport.TLSClientConfig = tlsConfig
+ }
+ if config.UseProxy {
+ proxyCfg := config.Proxy
+ transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return proxytunnel.DialContext(ctx, proxyCfg, network, addr)
+ }
+ }
+ return &http.Client{Transport: transport, Timeout: getConnectTimeout(config)}
+}
+
+func (c *ChromaDB) detectVersion(ctx context.Context) error {
+ if c.client == nil {
+ return fmt.Errorf("连接未打开")
+ }
+ if err := c.doJSON(ctx, http.MethodGet, "/api/v2/heartbeat", nil, nil); err == nil {
+ c.apiVersion = 2
+ return nil
+ }
+ if err := c.doJSON(ctx, http.MethodGet, "/api/v1/heartbeat", nil, nil); err == nil {
+ c.apiVersion = 1
+ return nil
+ }
+ return fmt.Errorf("Chroma 连接失败:无法访问 /api/v2/heartbeat 或 /api/v1/heartbeat")
+}
+
+func (c *ChromaDB) ensureVersion(ctx context.Context) error {
+ if c.apiVersion == 1 || c.apiVersion == 2 {
+ return nil
+ }
+ return c.detectVersion(ctx)
+}
+
+func (c *ChromaDB) doJSON(ctx context.Context, method, path string, body interface{}, out interface{}) error {
+ if c.client == nil {
+ return fmt.Errorf("连接未打开")
+ }
+ var reader io.Reader
+ if body != nil {
+ payload, err := json.Marshal(body)
+ if err != nil {
+ return err
+ }
+ reader = bytes.NewReader(payload)
+ }
+ req, err := http.NewRequestWithContext(ctx, method, strings.TrimRight(c.baseURL, "/")+path, reader)
+ if err != nil {
+ return err
+ }
+ if body != nil {
+ req.Header.Set("Content-Type", "application/json")
+ }
+ req.Header.Set("Accept", "application/json")
+ if strings.TrimSpace(c.authHeaders["Authorization"]) == "" && strings.TrimSpace(req.Header.Get("Authorization")) == "" {
+ // Basic Auth remains useful for gateways even when Chroma itself uses token auth.
+ }
+ for key, value := range c.authHeaders {
+ if strings.TrimSpace(key) != "" && strings.TrimSpace(value) != "" {
+ req.Header.Set(key, value)
+ }
+ }
+ res, err := c.client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer res.Body.Close()
+ resBody, err := io.ReadAll(res.Body)
+ if err != nil {
+ return err
+ }
+ if res.StatusCode < 200 || res.StatusCode >= 300 {
+ message := strings.TrimSpace(string(resBody))
+ if message == "" {
+ message = res.Status
+ }
+ return fmt.Errorf("Chroma API %s %s 失败:%s", method, path, message)
+ }
+ if out == nil || len(bytes.TrimSpace(resBody)) == 0 {
+ return nil
+ }
+ if err := decodeJSONWithUseNumber(resBody, out); err != nil {
+ return fmt.Errorf("解析 Chroma 响应失败:%w", err)
+ }
+ return nil
+}
+
+func (c *ChromaDB) v2Path(dbName string, suffix string) string {
+ database := strings.TrimSpace(dbName)
+ if database == "" {
+ database = c.database
+ }
+ base := fmt.Sprintf("/api/v2/tenants/%s/databases/%s", url.PathEscape(c.tenant), url.PathEscape(database))
+ if suffix == "" {
+ return base
+ }
+ return base + suffix
+}
+
+func (c *ChromaDB) listCollections(ctx context.Context, dbName string) ([]chromaCollection, error) {
+ if c.client == nil {
+ return nil, fmt.Errorf("连接未打开")
+ }
+ if err := c.ensureVersion(ctx); err != nil {
+ return nil, err
+ }
+ var collections []chromaCollection
+ path := "/api/v1/collections"
+ if c.apiVersion == 2 {
+ path = c.v2Path(dbName, "/collections")
+ }
+ if err := c.doJSON(ctx, http.MethodGet, path, nil, &collections); err != nil {
+ return nil, err
+ }
+ return collections, nil
+}
+
+func (c *ChromaDB) resolveCollection(ctx context.Context, dbName, tableName string) (chromaCollection, error) {
+ name := tableNameOrDB(dbName, tableName)
+ if name == "" {
+ return chromaCollection{}, fmt.Errorf("collection 名称不能为空")
+ }
+ collections, err := c.listCollections(ctx, dbName)
+ if err != nil {
+ return chromaCollection{}, err
+ }
+ for _, item := range collections {
+ if strings.EqualFold(item.Name, name) || strings.EqualFold(item.ID, name) {
+ return item, nil
+ }
+ }
+ return chromaCollection{}, fmt.Errorf("未找到 Chroma collection:%s", name)
+}
+
+func (c *ChromaDB) collectionActionPath(ctx context.Context, collectionName, action string) (string, error) {
+ if err := c.ensureVersion(ctx); err != nil {
+ return "", err
+ }
+ coll, err := c.resolveCollection(ctx, "", collectionName)
+ if err != nil {
+ return "", err
+ }
+ ident := coll.ID
+ if strings.TrimSpace(ident) == "" {
+ ident = coll.Name
+ }
+ if ident == "" {
+ return "", fmt.Errorf("collection 标识为空")
+ }
+ if c.apiVersion == 2 {
+ return c.v2Path(coll.Database, fmt.Sprintf("/collections/%s/%s", url.PathEscape(ident), action)), nil
+ }
+ return fmt.Sprintf("/api/v1/collections/%s/%s", url.PathEscape(ident), action), nil
+}
+
+func (c *ChromaDB) getCollectionRows(ctx context.Context, collection string, limit int, offset int, where interface{}, include []string) ([]map[string]interface{}, []string, error) {
+ if limit <= 0 {
+ limit = 200
+ }
+ path, err := c.collectionActionPath(ctx, collection, "get")
+ if err != nil {
+ return nil, nil, err
+ }
+ body := map[string]interface{}{
+ "limit": limit,
+ "offset": offset,
+ "include": include,
+ }
+ if where != nil {
+ body["where"] = where
+ }
+ var resp chromaGetResponse
+ if err := c.doJSON(ctx, http.MethodPost, path, body, &resp); err != nil {
+ return nil, nil, err
+ }
+ rows, columns := chromaGetResponseRows(resp)
+ return rows, columns, nil
+}
+
+func (c *ChromaDB) countCollection(ctx context.Context, collection string, where interface{}) (int64, error) {
+ path, err := c.collectionActionPath(ctx, collection, "count")
+ if err != nil {
+ return 0, err
+ }
+ if where == nil {
+ var raw interface{}
+ if err := c.doJSON(ctx, http.MethodGet, path, nil, &raw); err == nil {
+ return chromaCountValue(raw), nil
+ }
+ }
+ rows, _, err := c.getCollectionRows(ctx, collection, 1_000_000, 0, where, []string{"documents"})
+ if err != nil {
+ return 0, err
+ }
+ return int64(len(rows)), nil
+}
+
+func (c *ChromaDB) queryJSON(ctx context.Context, text string) ([]map[string]interface{}, []string, error) {
+ var cmd map[string]interface{}
+ if err := decodeJSONWithUseNumber([]byte(text), &cmd); err != nil {
+ return nil, nil, fmt.Errorf("Chroma JSON 命令解析失败:%w", err)
+ }
+ if hasAnyKey(cmd, "list_collections", "listCollections") {
+ cols, err := c.listCollections(ctx, "")
+ if err != nil {
+ return nil, nil, err
+ }
+ rows := make([]map[string]interface{}, 0, len(cols))
+ for _, col := range cols {
+ rows = append(rows, chromaStructRow(col))
+ }
+ return rows, collectColumns(rows), nil
+ }
+ if name := firstStringValue(cmd, "get", "collection"); name != "" && (hasAnyKey(cmd, "get") || !hasAnyKey(cmd, "query", "query_embeddings", "query_texts")) {
+ limit := intFromAny(cmd["limit"], 200)
+ offset := intFromAny(cmd["offset"], 0)
+ include := stringSliceFromAny(cmd["include"], []string{"documents", "metadatas"})
+ return c.getCollectionRows(ctx, name, limit, offset, cmd["where"], include)
+ }
+ if name := firstStringValue(cmd, "query", "collection"); name != "" {
+ return c.queryCollection(ctx, name, cmd)
+ }
+ return nil, nil, fmt.Errorf("Chroma JSON 查询命令仅支持 list_collections/get/query")
+}
+
+func (c *ChromaDB) queryCollection(ctx context.Context, collection string, cmd map[string]interface{}) ([]map[string]interface{}, []string, error) {
+ path, err := c.collectionActionPath(ctx, collection, "query")
+ if err != nil {
+ return nil, nil, err
+ }
+ body := make(map[string]interface{})
+ for _, key := range []string{"query_embeddings", "query_texts", "where", "where_document", "include"} {
+ if value, ok := cmd[key]; ok {
+ body[key] = value
+ }
+ }
+ body["n_results"] = intFromAny(firstExisting(cmd, "n_results", "limit"), 10)
+ if _, ok := body["include"]; !ok {
+ body["include"] = []string{"documents", "metadatas", "distances"}
+ }
+ var raw map[string]interface{}
+ if err := c.doJSON(ctx, http.MethodPost, path, body, &raw); err != nil {
+ return nil, nil, err
+ }
+ rows := chromaQueryResponseRows(raw)
+ return rows, collectColumns(rows), nil
+}
+
+func (c *ChromaDB) createCollection(ctx context.Context, body map[string]interface{}) error {
+ if err := c.ensureVersion(ctx); err != nil {
+ return err
+ }
+ path := "/api/v1/collections"
+ if c.apiVersion == 2 {
+ path = c.v2Path("", "/collections")
+ }
+ return c.doJSON(ctx, http.MethodPost, path, body, nil)
+}
+
+func (c *ChromaDB) deleteCollection(ctx context.Context, name string) error {
+ if err := c.ensureVersion(ctx); err != nil {
+ return err
+ }
+ path := fmt.Sprintf("/api/v1/collections/%s", url.PathEscape(name))
+ if c.apiVersion == 2 {
+ coll, err := c.resolveCollection(ctx, "", name)
+ if err != nil {
+ return err
+ }
+ ident := coll.ID
+ if ident == "" {
+ ident = coll.Name
+ }
+ path = c.v2Path(coll.Database, fmt.Sprintf("/collections/%s", url.PathEscape(ident)))
+ }
+ return c.doJSON(ctx, http.MethodDelete, path, nil, nil)
+}
+
+func (c *ChromaDB) upsertCommand(ctx context.Context, collection string, cmd map[string]interface{}) (int64, error) {
+ if rowsValue, ok := cmd["rows"].([]interface{}); ok {
+ rows := make([]map[string]interface{}, 0, len(rowsValue))
+ for _, raw := range rowsValue {
+ if row, ok := raw.(map[string]interface{}); ok {
+ rows = append(rows, row)
+ }
+ }
+ return int64(len(rows)), c.upsertRows(ctx, collection, rows)
+ }
+ body := make(map[string]interface{})
+ for _, key := range []string{"ids", "documents", "metadatas", "embeddings", "uris"} {
+ if value, ok := cmd[key]; ok {
+ body[key] = value
+ }
+ }
+ if _, ok := body["ids"]; !ok {
+ return 0, fmt.Errorf("Chroma upsert 命令缺少 ids")
+ }
+ path, err := c.collectionActionPath(ctx, collection, "upsert")
+ if err != nil {
+ return 0, err
+ }
+ return int64(len(anySlice(body["ids"]))), c.doJSON(ctx, http.MethodPost, path, body, nil)
+}
+
+func (c *ChromaDB) deleteCommand(ctx context.Context, collection string, cmd map[string]interface{}) (int64, error) {
+ body := make(map[string]interface{})
+ for _, key := range []string{"ids", "where", "where_document"} {
+ if value, ok := cmd[key]; ok {
+ body[key] = value
+ }
+ }
+ if len(body) == 0 {
+ return 0, fmt.Errorf("Chroma delete 命令缺少 ids/where/where_document")
+ }
+ path, err := c.collectionActionPath(ctx, collection, "delete")
+ if err != nil {
+ return 0, err
+ }
+ return int64(len(anySlice(body["ids"]))), c.doJSON(ctx, http.MethodPost, path, body, nil)
+}
+
+func (c *ChromaDB) upsertRows(ctx context.Context, collection string, rows []map[string]interface{}) error {
+ if len(rows) == 0 {
+ return nil
+ }
+ ids := make([]string, 0, len(rows))
+ docs := make([]interface{}, 0, len(rows))
+ metadatas := make([]map[string]interface{}, 0, len(rows))
+ embeddings := make([]interface{}, 0, len(rows))
+ hasEmbedding := false
+ for _, row := range rows {
+ id := chromaRowID(row)
+ if id == "" {
+ return fmt.Errorf("Chroma 写入行缺少 id")
+ }
+ ids = append(ids, id)
+ docs = append(docs, firstExisting(row, "document", "_document", "documents"))
+ meta := make(map[string]interface{})
+ if raw, ok := row["metadata"].(map[string]interface{}); ok {
+ for k, v := range raw {
+ meta[k] = v
+ }
+ }
+ for k, v := range row {
+ if isChromaReservedRowField(k) {
+ continue
+ }
+ if strings.HasPrefix(k, "metadata.") {
+ meta[strings.TrimPrefix(k, "metadata.")] = v
+ continue
+ }
+ meta[k] = v
+ }
+ metadatas = append(metadatas, meta)
+ if embedding := firstExisting(row, "embedding", "_embedding", "embeddings"); embedding != nil {
+ embeddings = append(embeddings, normalizeChromaEmbedding(embedding))
+ hasEmbedding = true
+ }
+ }
+ body := map[string]interface{}{
+ "ids": ids,
+ "documents": docs,
+ "metadatas": metadatas,
+ }
+ if hasEmbedding {
+ body["embeddings"] = embeddings
+ }
+ path, err := c.collectionActionPath(ctx, collection, "upsert")
+ if err != nil {
+ return err
+ }
+ return c.doJSON(ctx, http.MethodPost, path, body, nil)
+}
+
+type chromaParsedSQL struct {
+ Collection string
+ Limit int
+ Offset int
+ Where interface{}
+ Count bool
+ IncludeEmbeddings bool
+}
+
+var chromaSQLFromRE = regexp.MustCompile(`(?i)\bFROM\s+(?:"([^"]+)"|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z0-9_.\-]+))`)
+var chromaSQLLimitRE = regexp.MustCompile(`(?i)\bLIMIT\s+(\d+)`)
+var chromaSQLOffsetRE = regexp.MustCompile(`(?i)\bOFFSET\s+(\d+)`)
+
+func parseChromaSQL(sqlText string) (chromaParsedSQL, bool) {
+ text := strings.TrimSpace(sqlText)
+ if !strings.HasPrefix(strings.ToLower(text), "select") {
+ return chromaParsedSQL{}, false
+ }
+ matches := chromaSQLFromRE.FindStringSubmatch(text)
+ if len(matches) == 0 {
+ return chromaParsedSQL{}, false
+ }
+ table := firstNonEmpty(matches[1], matches[2], matches[3])
+ if table == "" {
+ return chromaParsedSQL{}, false
+ }
+ parsed := chromaParsedSQL{Collection: table, Limit: 200}
+ lower := strings.ToLower(text)
+ parsed.Count = strings.Contains(lower, "count(")
+ parsed.IncludeEmbeddings = strings.Contains(lower, "embedding")
+ if m := chromaSQLLimitRE.FindStringSubmatch(text); len(m) > 1 {
+ parsed.Limit, _ = strconv.Atoi(m[1])
+ }
+ if m := chromaSQLOffsetRE.FindStringSubmatch(text); len(m) > 1 {
+ parsed.Offset, _ = strconv.Atoi(m[1])
+ }
+ return parsed, true
+}
+
+func chromaGetResponseRows(resp chromaGetResponse) ([]map[string]interface{}, []string) {
+ rows := make([]map[string]interface{}, 0, len(resp.IDs))
+ for index, id := range resp.IDs {
+ row := map[string]interface{}{"id": id}
+ if value := sliceValue(resp.Documents, index); value != nil {
+ row["document"] = value
+ }
+ if meta := sliceValueMap(resp.Metadatas, index); meta != nil {
+ row["metadata"] = meta
+ for k, v := range meta {
+ row["metadata."+k] = v
+ }
+ }
+ if value := sliceValue(resp.Embeddings, index); value != nil {
+ row["embedding"] = normalizeJSONLikeValue(value)
+ }
+ rows = append(rows, row)
+ }
+ return rows, collectColumns(rows)
+}
+
+func chromaQueryResponseRows(raw map[string]interface{}) []map[string]interface{} {
+ idGroups := nestedAnySlice(raw["ids"])
+ docGroups := nestedAnySlice(raw["documents"])
+ metaGroups := nestedAnySlice(raw["metadatas"])
+ distanceGroups := nestedAnySlice(raw["distances"])
+ var rows []map[string]interface{}
+ for groupIndex, group := range idGroups {
+ for itemIndex, id := range group {
+ row := map[string]interface{}{
+ "query_index": groupIndex,
+ "id": fmt.Sprintf("%v", id),
+ }
+ if doc := nestedValue(docGroups, groupIndex, itemIndex); doc != nil {
+ row["document"] = doc
+ }
+ if dist := nestedValue(distanceGroups, groupIndex, itemIndex); dist != nil {
+ row["distance"] = dist
+ }
+ if meta, ok := nestedValue(metaGroups, groupIndex, itemIndex).(map[string]interface{}); ok {
+ row["metadata"] = meta
+ for k, v := range meta {
+ row["metadata."+k] = v
+ }
+ }
+ rows = append(rows, row)
+ }
+ }
+ return rows
+}
+
+func collectColumns(rows []map[string]interface{}) []string {
+ set := make(map[string]struct{})
+ for _, row := range rows {
+ for key := range row {
+ set[key] = struct{}{}
+ }
+ }
+ cols := make([]string, 0, len(set))
+ for key := range set {
+ cols = append(cols, key)
+ }
+ sort.Strings(cols)
+ for _, priority := range []string{"id", "query_index", "document", "distance", "metadata", "embedding"} {
+ for i, col := range cols {
+ if col == priority && i > 0 {
+ cols = append(cols[:i], cols[i+1:]...)
+ cols = append([]string{priority}, cols...)
+ break
+ }
+ }
+ }
+ return cols
+}
+
+func tableNameOrDB(dbName, tableName string) string {
+ if name := strings.TrimSpace(tableName); name != "" {
+ return name
+ }
+ return strings.TrimSpace(dbName)
+}
+
+func chromaStructRow(col chromaCollection) map[string]interface{} {
+ row := map[string]interface{}{
+ "id": col.ID,
+ "name": col.Name,
+ "tenant": col.Tenant,
+ "database": col.Database,
+ }
+ if col.Dimension > 0 {
+ row["dimension"] = col.Dimension
+ }
+ if len(col.Metadata) > 0 {
+ row["metadata"] = col.Metadata
+ }
+ return row
+}
+
+func chromaCountValue(raw interface{}) int64 {
+ switch v := raw.(type) {
+ case json.Number:
+ n, _ := v.Int64()
+ return n
+ case float64:
+ return int64(v)
+ case int:
+ return int64(v)
+ case int64:
+ return v
+ case map[string]interface{}:
+ return chromaCountValue(firstExisting(v, "count", "total", "value"))
+ default:
+ return 0
+ }
+}
+
+func chromaRowID(row map[string]interface{}) string {
+ return strings.TrimSpace(fmt.Sprintf("%v", firstExisting(row, "id", "_id")))
+}
+
+func isChromaReservedRowField(key string) bool {
+ switch key {
+ case "id", "_id", "document", "_document", "documents", "metadata", "embedding", "_embedding", "embeddings":
+ return true
+ default:
+ return false
+ }
+}
+
+func normalizeChromaEmbedding(value interface{}) interface{} {
+ if text, ok := value.(string); ok {
+ var parsed interface{}
+ if err := decodeJSONWithUseNumber([]byte(text), &parsed); err == nil {
+ return parsed
+ }
+ }
+ return value
+}
+
+func inferChromaValueType(value interface{}) string {
+ switch value.(type) {
+ case bool:
+ return "bool"
+ case json.Number, float64, float32, int, int64:
+ return "number"
+ case map[string]interface{}:
+ return "json"
+ case []interface{}:
+ return "array"
+ default:
+ return "string"
+ }
+}
+
+func firstStringValue(m map[string]interface{}, keys ...string) string {
+ for _, key := range keys {
+ if value, ok := m[key]; ok {
+ text := strings.TrimSpace(fmt.Sprintf("%v", value))
+ if text != "" && text != "" {
+ return text
+ }
+ }
+ }
+ return ""
+}
+
+func firstExisting(m map[string]interface{}, keys ...string) interface{} {
+ for _, key := range keys {
+ if value, ok := m[key]; ok {
+ return value
+ }
+ }
+ return nil
+}
+
+func firstNonEmpty(values ...string) string {
+ for _, value := range values {
+ if text := strings.TrimSpace(value); text != "" {
+ return text
+ }
+ }
+ return ""
+}
+
+func hasAnyKey(m map[string]interface{}, keys ...string) bool {
+ for _, key := range keys {
+ if _, ok := m[key]; ok {
+ return true
+ }
+ }
+ return false
+}
+
+func getOrBool(m map[string]interface{}, keys ...string) bool {
+ for _, key := range keys {
+ switch v := m[key].(type) {
+ case bool:
+ return v
+ case string:
+ return strings.EqualFold(strings.TrimSpace(v), "true")
+ }
+ }
+ return false
+}
+
+func intFromAny(value interface{}, fallback int) int {
+ switch v := value.(type) {
+ case json.Number:
+ n, err := v.Int64()
+ if err == nil {
+ return int(n)
+ }
+ case float64:
+ return int(v)
+ case int:
+ return v
+ case int64:
+ return int(v)
+ case string:
+ n, err := strconv.Atoi(strings.TrimSpace(v))
+ if err == nil {
+ return n
+ }
+ }
+ return fallback
+}
+
+func mapString(m map[string]interface{}, key string) string {
+ return strings.TrimSpace(fmt.Sprintf("%v", m[key]))
+}
+
+func stringSliceFromAny(value interface{}, fallback []string) []string {
+ if value == nil {
+ return fallback
+ }
+ switch v := value.(type) {
+ case []string:
+ return v
+ case []interface{}:
+ result := make([]string, 0, len(v))
+ for _, item := range v {
+ if text := strings.TrimSpace(fmt.Sprintf("%v", item)); text != "" {
+ result = append(result, text)
+ }
+ }
+ if len(result) > 0 {
+ return result
+ }
+ }
+ return fallback
+}
+
+func anySlice(value interface{}) []interface{} {
+ switch v := value.(type) {
+ case []interface{}:
+ return v
+ case []string:
+ result := make([]interface{}, len(v))
+ for i, item := range v {
+ result[i] = item
+ }
+ return result
+ default:
+ return nil
+ }
+}
+
+func nestedAnySlice(value interface{}) [][]interface{} {
+ switch v := value.(type) {
+ case []interface{}:
+ result := make([][]interface{}, 0, len(v))
+ for _, group := range v {
+ result = append(result, anySlice(group))
+ }
+ return result
+ default:
+ return nil
+ }
+}
+
+func nestedValue(groups [][]interface{}, groupIndex, itemIndex int) interface{} {
+ if groupIndex < 0 || groupIndex >= len(groups) {
+ return nil
+ }
+ group := groups[groupIndex]
+ if itemIndex < 0 || itemIndex >= len(group) {
+ return nil
+ }
+ return group[itemIndex]
+}
+
+func sliceValue(items []interface{}, index int) interface{} {
+ if index < 0 || index >= len(items) {
+ return nil
+ }
+ return normalizeJSONLikeValue(items[index])
+}
+
+func sliceValueMap(items []map[string]interface{}, index int) map[string]interface{} {
+ if index < 0 || index >= len(items) {
+ return nil
+ }
+ return items[index]
+}
+
+func normalizeJSONLikeValue(value interface{}) interface{} {
+ switch value.(type) {
+ case map[string]interface{}, []interface{}:
+ payload, err := json.Marshal(value)
+ if err == nil {
+ return string(payload)
+ }
+ }
+ return value
+}
diff --git a/internal/db/chroma_impl_test.go b/internal/db/chroma_impl_test.go
new file mode 100644
index 0000000..94f71ce
--- /dev/null
+++ b/internal/db/chroma_impl_test.go
@@ -0,0 +1,292 @@
+package db
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "strings"
+ "testing"
+
+ "GoNavi-Wails/internal/connection"
+)
+
+func newMockChromaServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
+ t.Helper()
+ server := httptest.NewServer(handler)
+ t.Cleanup(server.Close)
+ return server
+}
+
+func newTestChromaDB(t *testing.T, serverURL string) *ChromaDB {
+ t.Helper()
+ parsed, err := url.Parse(serverURL)
+ if err != nil {
+ t.Fatalf("parse server URL: %v", err)
+ }
+ host, port, ok := parseHostPortWithDefault(parsed.Host, defaultChromaPort)
+ if !ok {
+ t.Fatalf("parse host port failed: %s", parsed.Host)
+ }
+ db := &ChromaDB{}
+ if err := db.Connect(connection.ConnectionConfig{
+ Type: "chroma",
+ Host: host,
+ Port: port,
+ }); err != nil {
+ t.Fatalf("connect chroma: %v", err)
+ }
+ t.Cleanup(func() { _ = db.Close() })
+ return db
+}
+
+func writeChromaJSON(w http.ResponseWriter, value interface{}) {
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(value)
+}
+
+func TestChromaConnectDetectsV2(t *testing.T) {
+ server := newMockChromaServer(t, func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == http.MethodGet && r.URL.Path == "/api/v2/heartbeat" {
+ writeChromaJSON(w, map[string]interface{}{"nanosecond heartbeat": 1})
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ })
+
+ db := newTestChromaDB(t, server.URL)
+ if db.apiVersion != 2 {
+ t.Fatalf("apiVersion = %d, want 2", db.apiVersion)
+ }
+}
+
+func TestChromaConnectFallsBackToV1(t *testing.T) {
+ server := newMockChromaServer(t, func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/v2/heartbeat" {
+ w.WriteHeader(http.StatusNotFound)
+ return
+ }
+ if r.Method == http.MethodGet && r.URL.Path == "/api/v1/heartbeat" {
+ writeChromaJSON(w, map[string]interface{}{"nanosecond heartbeat": 1})
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ })
+
+ db := newTestChromaDB(t, server.URL)
+ if db.apiVersion != 1 {
+ t.Fatalf("apiVersion = %d, want 1", db.apiVersion)
+ }
+}
+
+func TestChromaGetDatabasesAndTablesV2(t *testing.T) {
+ server := newMockChromaServer(t, func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v2/heartbeat":
+ writeChromaJSON(w, map[string]interface{}{"ok": true})
+ case "/api/v2/tenants/default_tenant/databases":
+ writeChromaJSON(w, []map[string]interface{}{
+ {"name": "analytics"},
+ {"name": "default_database"},
+ })
+ case "/api/v2/tenants/default_tenant/databases/default_database/collections":
+ writeChromaJSON(w, []chromaCollection{
+ {ID: "col-products", Name: "products", Database: "default_database", Tenant: "default_tenant"},
+ {ID: "col-logs", Name: "logs", Database: "default_database", Tenant: "default_tenant"},
+ })
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ })
+
+ db := newTestChromaDB(t, server.URL)
+ dbs, err := db.GetDatabases()
+ if err != nil {
+ t.Fatalf("GetDatabases failed: %v", err)
+ }
+ if strings.Join(dbs, ",") != "analytics,default_database" {
+ t.Fatalf("databases = %v", dbs)
+ }
+ tables, err := db.GetTables("")
+ if err != nil {
+ t.Fatalf("GetTables failed: %v", err)
+ }
+ if strings.Join(tables, ",") != "logs,products" {
+ t.Fatalf("tables = %v", tables)
+ }
+}
+
+func TestChromaSelectConvertsToGetRows(t *testing.T) {
+ var capturedPath string
+ var capturedBody map[string]interface{}
+ server := newMockChromaServer(t, func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case r.URL.Path == "/api/v2/heartbeat":
+ writeChromaJSON(w, map[string]interface{}{"ok": true})
+ case r.URL.Path == "/api/v2/tenants/default_tenant/databases/default_database/collections":
+ writeChromaJSON(w, []chromaCollection{{ID: "col-products", Name: "products", Database: "default_database"}})
+ case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/collections/col-products/get"):
+ capturedPath = r.URL.Path
+ _ = json.NewDecoder(r.Body).Decode(&capturedBody)
+ writeChromaJSON(w, chromaGetResponse{
+ IDs: []string{"p1"},
+ Documents: []interface{}{"first product"},
+ Metadatas: []map[string]interface{}{{"category": "book", "price": json.Number("19.5")}},
+ })
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ })
+
+ db := newTestChromaDB(t, server.URL)
+ rows, columns, err := db.Query(`SELECT * FROM "products" LIMIT 10 OFFSET 5`)
+ if err != nil {
+ t.Fatalf("Query failed: %v", err)
+ }
+ if capturedPath == "" {
+ t.Fatal("expected get endpoint to be called")
+ }
+ if intFromAny(capturedBody["limit"], 0) != 10 || intFromAny(capturedBody["offset"], -1) != 5 {
+ t.Fatalf("captured body = %#v", capturedBody)
+ }
+ if len(rows) != 1 || rows[0]["id"] != "p1" || rows[0]["metadata.category"] != "book" {
+ t.Fatalf("rows = %#v", rows)
+ }
+ if !containsString(columns, "metadata.category") {
+ t.Fatalf("columns missing metadata.category: %v", columns)
+ }
+}
+
+func TestChromaJSONQueryFlattensResults(t *testing.T) {
+ server := newMockChromaServer(t, func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case r.URL.Path == "/api/v2/heartbeat":
+ writeChromaJSON(w, map[string]interface{}{"ok": true})
+ case r.URL.Path == "/api/v2/tenants/default_tenant/databases/default_database/collections":
+ writeChromaJSON(w, []chromaCollection{{ID: "col-products", Name: "products", Database: "default_database"}})
+ case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/collections/col-products/query"):
+ writeChromaJSON(w, map[string]interface{}{
+ "ids": [][]string{{"p1", "p2"}},
+ "documents": [][]string{{"first", "second"}},
+ "distances": [][]float64{{0.1, 0.2}},
+ "metadatas": [][]map[string]interface{}{{{"category": "book"}, {"category": "tool"}}},
+ })
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ })
+
+ db := newTestChromaDB(t, server.URL)
+ rows, columns, err := db.Query(`{"query":"products","query_embeddings":[[0.1,0.2]],"n_results":2}`)
+ if err != nil {
+ t.Fatalf("Query failed: %v", err)
+ }
+ if len(rows) != 2 || rows[1]["id"] != "p2" || rows[1]["distance"] == nil {
+ t.Fatalf("rows = %#v", rows)
+ }
+ if !containsString(columns, "distance") || !containsString(columns, "metadata.category") {
+ t.Fatalf("columns = %v", columns)
+ }
+}
+
+func TestChromaApplyChangesUpsertAndDelete(t *testing.T) {
+ var upsertBody map[string]interface{}
+ var deleteBody map[string]interface{}
+ server := newMockChromaServer(t, func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case r.URL.Path == "/api/v2/heartbeat":
+ writeChromaJSON(w, map[string]interface{}{"ok": true})
+ case r.URL.Path == "/api/v2/tenants/default_tenant/databases/default_database/collections":
+ writeChromaJSON(w, []chromaCollection{{ID: "col-products", Name: "products", Database: "default_database"}})
+ case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/collections/col-products/upsert"):
+ _ = json.NewDecoder(r.Body).Decode(&upsertBody)
+ writeChromaJSON(w, map[string]interface{}{"ok": true})
+ case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/collections/col-products/delete"):
+ _ = json.NewDecoder(r.Body).Decode(&deleteBody)
+ writeChromaJSON(w, map[string]interface{}{"ok": true})
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ })
+
+ db := newTestChromaDB(t, server.URL)
+ err := db.ApplyChanges("products", connection.ChangeSet{
+ Deletes: []map[string]interface{}{{"id": "old"}},
+ Inserts: []map[string]interface{}{
+ {"id": "new", "document": "hello", "metadata.kind": "demo", "score": 9},
+ },
+ })
+ if err != nil {
+ t.Fatalf("ApplyChanges failed: %v", err)
+ }
+ if ids := anySlice(deleteBody["ids"]); len(ids) != 1 || ids[0] != "old" {
+ t.Fatalf("delete body = %#v", deleteBody)
+ }
+ if ids := anySlice(upsertBody["ids"]); len(ids) != 1 || ids[0] != "new" {
+ t.Fatalf("upsert body = %#v", upsertBody)
+ }
+ metas := anySlice(upsertBody["metadatas"])
+ if len(metas) != 1 {
+ t.Fatalf("metadatas = %#v", upsertBody["metadatas"])
+ }
+ meta, _ := metas[0].(map[string]interface{})
+ if meta["kind"] != "demo" || meta["score"] == nil {
+ t.Fatalf("metadata = %#v", meta)
+ }
+}
+
+func TestChromaLiveSmoke(t *testing.T) {
+ serverURL := strings.TrimSpace(os.Getenv("GONAVI_CHROMA_TEST_URL"))
+ if serverURL == "" {
+ t.Skip("set GONAVI_CHROMA_TEST_URL to run live Chroma smoke test")
+ }
+
+ db := newTestChromaDB(t, serverURL)
+ collection := "gonavi_smoke_live"
+ _, _ = db.Exec(fmt.Sprintf(`{"delete_collection":%q}`, collection))
+ if _, err := db.Exec(fmt.Sprintf(`{"create_collection":%q,"get_or_create":true}`, collection)); err != nil {
+ t.Fatalf("create live collection: %v", err)
+ }
+ t.Cleanup(func() { _, _ = db.Exec(fmt.Sprintf(`{"delete_collection":%q}`, collection)) })
+
+ if err := db.ApplyChanges(collection, connection.ChangeSet{
+ Inserts: []map[string]interface{}{{
+ "id": "doc-1",
+ "document": "GoNavi Chroma live smoke",
+ "metadata.kind": "smoke",
+ "embedding": []float64{0.1, 0.2, 0.3},
+ }},
+ }); err != nil {
+ t.Fatalf("upsert live row: %v", err)
+ }
+
+ rows, columns, err := db.Query(fmt.Sprintf(`SELECT * FROM "%s" LIMIT 5`, collection))
+ if err != nil {
+ t.Fatalf("select live rows: %v", err)
+ }
+ if len(rows) == 0 || rows[0]["id"] != "doc-1" || rows[0]["metadata.kind"] != "smoke" {
+ t.Fatalf("live rows = %#v", rows)
+ }
+ if !containsString(columns, "metadata.kind") {
+ t.Fatalf("live columns missing metadata.kind: %v", columns)
+ }
+
+ queryRows, queryColumns, err := db.Query(fmt.Sprintf(`{"query":%q,"query_embeddings":[[0.1,0.2,0.3]],"n_results":1}`, collection))
+ if err != nil {
+ t.Fatalf("query live rows: %v", err)
+ }
+ if len(queryRows) == 0 || queryRows[0]["id"] != "doc-1" || !containsString(queryColumns, "distance") {
+ t.Fatalf("live query rows = %#v columns = %v", queryRows, queryColumns)
+ }
+}
+
+func containsString(items []string, target string) bool {
+ for _, item := range items {
+ if item == target {
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/db/database.go b/internal/db/database.go
index 1272599..0305e9e 100644
--- a/internal/db/database.go
+++ b/internal/db/database.go
@@ -480,6 +480,9 @@ var databaseFactories = map[string]databaseFactory{
"custom": func() Database {
return &CustomDB{}
},
+ "chroma": func() Database {
+ return &ChromaDB{}
+ },
}
func init() {
@@ -512,6 +515,8 @@ func normalizeDatabaseType(dbType string) string {
return "opengauss"
case "intersystems", "intersystemsiris", "inter-systems-iris", "inter-systems":
return "iris"
+ case "chromadb", "chroma-db":
+ return "chroma"
default:
return normalized
}
diff --git a/internal/db/driver_support.go b/internal/db/driver_support.go
index 87b97f5..b70623d 100644
--- a/internal/db/driver_support.go
+++ b/internal/db/driver_support.go
@@ -16,28 +16,29 @@ var coreBuiltinDrivers = map[string]struct{}{
"redis": {},
"oracle": {},
"postgres": {},
+ "chroma": {},
}
// optionalGoDrivers 表示需要用户“安装启用”后才能使用的纯 Go 驱动。
// 注意:这是一种运行时门控(installed.json 标记),并不减少主二进制体积。
var optionalGoDrivers = map[string]struct{}{
- "mariadb": {},
- "oceanbase": {},
- "diros": {},
- "starrocks": {},
- "sphinx": {},
- "sqlserver": {},
- "sqlite": {},
- "duckdb": {},
- "dameng": {},
- "kingbase": {},
- "highgo": {},
- "vastbase": {},
- "opengauss": {},
- "iris": {},
- "mongodb": {},
- "tdengine": {},
- "clickhouse": {},
+ "mariadb": {},
+ "oceanbase": {},
+ "diros": {},
+ "starrocks": {},
+ "sphinx": {},
+ "sqlserver": {},
+ "sqlite": {},
+ "duckdb": {},
+ "dameng": {},
+ "kingbase": {},
+ "highgo": {},
+ "vastbase": {},
+ "opengauss": {},
+ "iris": {},
+ "mongodb": {},
+ "tdengine": {},
+ "clickhouse": {},
"elasticsearch": {},
}
@@ -66,6 +67,8 @@ func normalizeRuntimeDriverType(driverType string) string {
return "iris"
case "elastic":
return "elasticsearch"
+ case "chromadb", "chroma-db":
+ return "chroma"
default:
return normalized
}
@@ -117,6 +120,8 @@ func driverDisplayName(driverType string) string {
return "ClickHouse"
case "elasticsearch":
return "Elasticsearch"
+ case "chroma":
+ return "Chroma"
default:
return strings.ToUpper(strings.TrimSpace(driverType))
}