Files
MyGoNavi/internal/db/chroma_impl.go
Syngnat 56126e22f2 feat(chroma): 新增 Chroma 向量库连接支持
- 后端新增 Chroma REST 连接、元数据浏览、JSON/SELECT 查询与 upsert/delete 写入

- 前端新增 Chroma 类型、连接配置、图标、方言和能力矩阵

- 测试覆盖 v1/v2 兼容、真实服务 smoke 和前端配置

Refs #560
2026-06-13 16:47:25 +08:00

1197 lines
34 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
proxytunnel "GoNavi-Wails/internal/proxy"
"GoNavi-Wails/internal/ssh"
)
const (
defaultChromaPort = 8000
defaultChromaTenant = "default_tenant"
defaultChromaDatabase = "default_database"
defaultChromaQueryTimeout = 30 * time.Second
)
type ChromaDB struct {
client *http.Client
baseURL string
tenant string
database string
apiVersion int
authHeaders map[string]string
forwarder *ssh.LocalForwarder
}
type chromaCollection struct {
ID string `json:"id"`
Name string `json:"name"`
Metadata map[string]interface{} `json:"metadata"`
Dimension int `json:"dimension"`
Tenant string `json:"tenant"`
Database string `json:"database"`
}
type chromaGetResponse struct {
IDs []string `json:"ids"`
Documents []interface{} `json:"documents"`
Metadatas []map[string]interface{} `json:"metadatas"`
Embeddings []interface{} `json:"embeddings"`
Included []string `json:"included"`
}
func (c *ChromaDB) Connect(config connection.ConnectionConfig) error {
if c.forwarder != nil {
_ = c.forwarder.Close()
c.forwarder = nil
}
c.client = nil
runConfig := normalizeChromaConfig(config)
if runConfig.UseSSH {
forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, runConfig.Host, runConfig.Port)
if err != nil {
return fmt.Errorf("创建 SSH 隧道失败:%w", err)
}
c.forwarder = forwarder
host, portText, err := net.SplitHostPort(forwarder.LocalAddr)
if err != nil {
return fmt.Errorf("解析本地转发地址失败:%w", err)
}
port, err := strconv.Atoi(portText)
if err != nil {
return fmt.Errorf("解析本地端口失败:%w", err)
}
runConfig.Host = host
runConfig.Port = port
runConfig.UseSSH = false
logger.Infof("Chroma 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, config.Host, config.Port)
}
c.tenant = chromaTenantFromConfig(runConfig)
c.database = chromaDatabaseFromConfig(runConfig)
c.baseURL = buildChromaBaseURL(runConfig)
c.authHeaders = chromaAuthHeaders(runConfig)
c.client = buildChromaHTTPClient(runConfig)
if err := c.Ping(); err != nil {
_ = c.Close()
return err
}
return nil
}
func (c *ChromaDB) Close() error {
if c.forwarder != nil {
if err := c.forwarder.Close(); err != nil {
logger.Warnf("关闭 Chroma SSH 端口转发失败:%v", err)
}
c.forwarder = nil
}
c.client = nil
return nil
}
func (c *ChromaDB) Ping() error {
if c.client == nil {
return fmt.Errorf("连接未打开")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := c.detectVersion(ctx); err != nil {
return err
}
return nil
}
func (c *ChromaDB) Query(query string) ([]map[string]interface{}, []string, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultChromaQueryTimeout)
defer cancel()
return c.QueryContext(ctx, query)
}
func (c *ChromaDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
if c.client == nil {
return nil, nil, fmt.Errorf("连接未打开")
}
text := strings.TrimSpace(query)
if text == "" {
return nil, nil, fmt.Errorf("查询语句不能为空")
}
if strings.HasPrefix(text, "{") {
return c.queryJSON(ctx, text)
}
if parsed, ok := parseChromaSQL(text); ok {
if parsed.Count {
total, err := c.countCollection(ctx, parsed.Collection, parsed.Where)
if err != nil {
return nil, nil, err
}
return []map[string]interface{}{{"total": total}}, []string{"total"}, nil
}
include := []string{"documents", "metadatas"}
if parsed.IncludeEmbeddings {
include = append(include, "embeddings")
}
return c.getCollectionRows(ctx, parsed.Collection, parsed.Limit, parsed.Offset, parsed.Where, include)
}
return nil, nil, fmt.Errorf("Chroma 查询仅支持 JSON 命令或简单 SELECT 预览")
}
func (c *ChromaDB) Exec(query string) (int64, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultChromaQueryTimeout)
defer cancel()
return c.ExecContext(ctx, query)
}
func (c *ChromaDB) ExecContext(ctx context.Context, query string) (int64, error) {
if c.client == nil {
return 0, fmt.Errorf("连接未打开")
}
var cmd map[string]interface{}
if err := decodeJSONWithUseNumber([]byte(strings.TrimSpace(query)), &cmd); err != nil {
return 0, fmt.Errorf("Chroma 写入命令必须是 JSON%w", err)
}
if name := firstStringValue(cmd, "create_collection", "createCollection", "collection"); name != "" && hasAnyKey(cmd, "create_collection", "createCollection") {
body := map[string]interface{}{"name": name}
if metadata, ok := cmd["metadata"]; ok {
body["metadata"] = metadata
}
if getOrBool(cmd, "get_or_create", "getOrCreate") {
body["get_or_create"] = true
}
return 1, c.createCollection(ctx, body)
}
if name := firstStringValue(cmd, "delete_collection", "deleteCollection"); name != "" {
return 1, c.deleteCollection(ctx, name)
}
if name := firstStringValue(cmd, "upsert", "collection"); name != "" && hasAnyKey(cmd, "upsert") {
return c.upsertCommand(ctx, name, cmd)
}
if name := firstStringValue(cmd, "delete", "collection"); name != "" && hasAnyKey(cmd, "delete") {
return c.deleteCommand(ctx, name, cmd)
}
return 0, fmt.Errorf("Chroma JSON 写入命令仅支持 create_collection/delete_collection/upsert/delete")
}
func (c *ChromaDB) GetDatabases() ([]string, error) {
if c.client == nil {
return nil, fmt.Errorf("连接未打开")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := c.ensureVersion(ctx); err != nil {
return nil, err
}
if c.apiVersion != 2 {
return []string{c.database}, nil
}
var raw []map[string]interface{}
err := c.doJSON(ctx, http.MethodGet, fmt.Sprintf("/api/v2/tenants/%s/databases", url.PathEscape(c.tenant)), nil, &raw)
if err != nil {
return []string{c.database}, nil
}
names := make([]string, 0, len(raw))
for _, item := range raw {
if name := mapString(item, "name"); name != "" {
names = append(names, name)
}
}
if len(names) == 0 {
names = append(names, c.database)
}
sort.Strings(names)
return names, nil
}
func (c *ChromaDB) GetTables(dbName string) ([]string, error) {
collections, err := c.listCollections(context.Background(), dbName)
if err != nil {
return nil, err
}
names := make([]string, 0, len(collections))
for _, item := range collections {
if strings.TrimSpace(item.Name) != "" {
names = append(names, item.Name)
}
}
sort.Strings(names)
return names, nil
}
func (c *ChromaDB) GetCreateStatement(dbName, tableName string) (string, error) {
coll, err := c.resolveCollection(context.Background(), dbName, tableName)
if err != nil {
return "", err
}
payload, _ := json.MarshalIndent(coll, "", " ")
return fmt.Sprintf("// Chroma collection: %s\n%s", coll.Name, string(payload)), nil
}
func (c *ChromaDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
rows, _, err := c.getCollectionRows(context.Background(), tableNameOrDB(dbName, tableName), 20, 0, nil, []string{"documents", "metadatas", "embeddings"})
if err != nil {
return nil, err
}
cols := []connection.ColumnDefinition{
{Name: "id", Type: "string", Nullable: "NO", Key: "PRI", Comment: "Chroma document id"},
{Name: "document", Type: "text", Nullable: "YES", Comment: "Document text"},
{Name: "metadata", Type: "json", Nullable: "YES", Comment: "Full metadata object"},
{Name: "embedding", Type: "vector<float>", Nullable: "YES", Comment: "Embedding vector"},
}
seen := map[string]struct{}{"id": {}, "document": {}, "metadata": {}, "embedding": {}}
for _, row := range rows {
for key, value := range row {
if _, exists := seen[key]; exists || !strings.HasPrefix(key, "metadata.") {
continue
}
seen[key] = struct{}{}
cols = append(cols, connection.ColumnDefinition{
Name: key,
Type: inferChromaValueType(value),
Nullable: "YES",
Comment: "Metadata field",
})
}
}
return cols, nil
}
func (c *ChromaDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
tables, err := c.GetTables(dbName)
if err != nil {
return nil, err
}
var result []connection.ColumnDefinitionWithTable
for _, table := range tables {
cols, err := c.GetColumns(dbName, table)
if err != nil {
continue
}
for _, col := range cols {
result = append(result, connection.ColumnDefinitionWithTable{
TableName: table,
Name: col.Name,
Type: col.Type,
Comment: col.Comment,
})
}
}
return result, nil
}
func (c *ChromaDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
return []connection.IndexDefinition{
{Name: "PRIMARY", ColumnName: "id", NonUnique: 0, SeqInIndex: 1, IndexType: "PRIMARY"},
{Name: "HNSW", ColumnName: "embedding", NonUnique: 1, SeqInIndex: 1, IndexType: "VECTOR"},
}, nil
}
func (c *ChromaDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
return []connection.ForeignKeyDefinition{}, nil
}
func (c *ChromaDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
return []connection.TriggerDefinition{}, nil
}
func (c *ChromaDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultChromaQueryTimeout)
defer cancel()
if len(changes.Deletes) > 0 {
ids := make([]string, 0, len(changes.Deletes))
for _, row := range changes.Deletes {
if id := chromaRowID(row); id != "" {
ids = append(ids, id)
}
}
if len(ids) > 0 {
if _, err := c.deleteCommand(ctx, tableName, map[string]interface{}{"ids": ids}); err != nil {
return err
}
}
}
if len(changes.Updates) > 0 {
rows := make([]map[string]interface{}, 0, len(changes.Updates))
for _, update := range changes.Updates {
row := make(map[string]interface{}, len(update.Keys)+len(update.Values))
for k, v := range update.Keys {
row[k] = v
}
for k, v := range update.Values {
row[k] = v
}
rows = append(rows, row)
}
if err := c.upsertRows(ctx, tableName, rows); err != nil {
return err
}
}
if len(changes.Inserts) > 0 {
if err := c.upsertRows(ctx, tableName, changes.Inserts); err != nil {
return err
}
}
return nil
}
func normalizeChromaConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
runConfig := applyChromaURI(config)
if strings.TrimSpace(runConfig.Host) == "" {
runConfig.Host = "localhost"
}
if runConfig.Port <= 0 {
runConfig.Port = defaultChromaPort
}
if strings.TrimSpace(runConfig.SSLMode) == "" && runConfig.UseSSL {
runConfig.SSLMode = "required"
}
return runConfig
}
func applyChromaURI(config connection.ConnectionConfig) connection.ConnectionConfig {
uriText := strings.TrimSpace(config.URI)
if uriText == "" {
return config
}
parsed, err := url.Parse(uriText)
if err != nil {
return config
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" && scheme != "chroma" {
return config
}
if parsed.User != nil {
if strings.TrimSpace(config.User) == "" {
config.User = parsed.User.Username()
}
if pass, ok := parsed.User.Password(); ok && config.Password == "" {
config.Password = pass
}
}
if scheme == "https" {
config.UseSSL = true
}
if host := strings.TrimSpace(parsed.Host); host != "" {
if h, port, ok := parseHostPortWithDefault(host, defaultChromaPort); ok {
config.Host = h
config.Port = port
}
}
if dbName := strings.Trim(strings.TrimSpace(parsed.Path), "/"); dbName != "" && !strings.HasPrefix(dbName, "api/") && strings.TrimSpace(config.Database) == "" {
config.Database = dbName
}
return config
}
func buildChromaBaseURL(config connection.ConnectionConfig) string {
scheme := "http"
if config.UseSSL {
scheme = "https"
}
return fmt.Sprintf("%s://%s:%d", scheme, strings.TrimSpace(config.Host), config.Port)
}
func chromaTenantFromConfig(config connection.ConnectionConfig) string {
params := chromaConnectionParams(config)
if tenant := strings.TrimSpace(params.Get("tenant")); tenant != "" {
return tenant
}
return defaultChromaTenant
}
func chromaDatabaseFromConfig(config connection.ConnectionConfig) string {
if dbName := strings.TrimSpace(config.Database); dbName != "" {
return dbName
}
params := chromaConnectionParams(config)
if dbName := strings.TrimSpace(params.Get("database")); dbName != "" {
return dbName
}
return defaultChromaDatabase
}
func chromaConnectionParams(config connection.ConnectionConfig) url.Values {
params := url.Values{}
mergeConnectionParamValues(params, connectionParamsFromURI(config.URI, "http", "https", "chroma"))
mergeConnectionParamValues(params, connectionParamsFromText(config.ConnectionParams))
return params
}
func chromaAuthHeaders(config connection.ConnectionConfig) map[string]string {
headers := make(map[string]string)
params := chromaConnectionParams(config)
token := firstNonEmpty(params.Get("apiKey"), params.Get("apikey"), params.Get("token"), params.Get("authToken"))
if token == "" && strings.TrimSpace(config.User) == "" {
token = strings.TrimSpace(config.Password)
}
if token != "" {
headers["Authorization"] = "Bearer " + token
} else if user := strings.TrimSpace(config.User); user != "" {
raw := user + ":" + config.Password
headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(raw))
}
if headerName := strings.TrimSpace(params.Get("authHeader")); headerName != "" {
if headerValue := strings.TrimSpace(params.Get("authHeaderValue")); headerValue != "" && isSafeConnectionParamKey(headerName) {
headers[headerName] = headerValue
}
}
return headers
}
func buildChromaHTTPClient(config connection.ConnectionConfig) *http.Client {
transport := http.DefaultTransport.(*http.Transport).Clone()
if tlsConfig, err := resolveGenericTLSConfig(config); err == nil && tlsConfig != nil {
transport.TLSClientConfig = tlsConfig
}
if config.UseProxy {
proxyCfg := config.Proxy
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return proxytunnel.DialContext(ctx, proxyCfg, network, addr)
}
}
return &http.Client{Transport: transport, Timeout: getConnectTimeout(config)}
}
func (c *ChromaDB) detectVersion(ctx context.Context) error {
if c.client == nil {
return fmt.Errorf("连接未打开")
}
if err := c.doJSON(ctx, http.MethodGet, "/api/v2/heartbeat", nil, nil); err == nil {
c.apiVersion = 2
return nil
}
if err := c.doJSON(ctx, http.MethodGet, "/api/v1/heartbeat", nil, nil); err == nil {
c.apiVersion = 1
return nil
}
return fmt.Errorf("Chroma 连接失败:无法访问 /api/v2/heartbeat 或 /api/v1/heartbeat")
}
func (c *ChromaDB) ensureVersion(ctx context.Context) error {
if c.apiVersion == 1 || c.apiVersion == 2 {
return nil
}
return c.detectVersion(ctx)
}
func (c *ChromaDB) doJSON(ctx context.Context, method, path string, body interface{}, out interface{}) error {
if c.client == nil {
return fmt.Errorf("连接未打开")
}
var reader io.Reader
if body != nil {
payload, err := json.Marshal(body)
if err != nil {
return err
}
reader = bytes.NewReader(payload)
}
req, err := http.NewRequestWithContext(ctx, method, strings.TrimRight(c.baseURL, "/")+path, reader)
if err != nil {
return err
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
req.Header.Set("Accept", "application/json")
if strings.TrimSpace(c.authHeaders["Authorization"]) == "" && strings.TrimSpace(req.Header.Get("Authorization")) == "" {
// Basic Auth remains useful for gateways even when Chroma itself uses token auth.
}
for key, value := range c.authHeaders {
if strings.TrimSpace(key) != "" && strings.TrimSpace(value) != "" {
req.Header.Set(key, value)
}
}
res, err := c.client.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
resBody, err := io.ReadAll(res.Body)
if err != nil {
return err
}
if res.StatusCode < 200 || res.StatusCode >= 300 {
message := strings.TrimSpace(string(resBody))
if message == "" {
message = res.Status
}
return fmt.Errorf("Chroma API %s %s 失败:%s", method, path, message)
}
if out == nil || len(bytes.TrimSpace(resBody)) == 0 {
return nil
}
if err := decodeJSONWithUseNumber(resBody, out); err != nil {
return fmt.Errorf("解析 Chroma 响应失败:%w", err)
}
return nil
}
func (c *ChromaDB) v2Path(dbName string, suffix string) string {
database := strings.TrimSpace(dbName)
if database == "" {
database = c.database
}
base := fmt.Sprintf("/api/v2/tenants/%s/databases/%s", url.PathEscape(c.tenant), url.PathEscape(database))
if suffix == "" {
return base
}
return base + suffix
}
func (c *ChromaDB) listCollections(ctx context.Context, dbName string) ([]chromaCollection, error) {
if c.client == nil {
return nil, fmt.Errorf("连接未打开")
}
if err := c.ensureVersion(ctx); err != nil {
return nil, err
}
var collections []chromaCollection
path := "/api/v1/collections"
if c.apiVersion == 2 {
path = c.v2Path(dbName, "/collections")
}
if err := c.doJSON(ctx, http.MethodGet, path, nil, &collections); err != nil {
return nil, err
}
return collections, nil
}
func (c *ChromaDB) resolveCollection(ctx context.Context, dbName, tableName string) (chromaCollection, error) {
name := tableNameOrDB(dbName, tableName)
if name == "" {
return chromaCollection{}, fmt.Errorf("collection 名称不能为空")
}
collections, err := c.listCollections(ctx, dbName)
if err != nil {
return chromaCollection{}, err
}
for _, item := range collections {
if strings.EqualFold(item.Name, name) || strings.EqualFold(item.ID, name) {
return item, nil
}
}
return chromaCollection{}, fmt.Errorf("未找到 Chroma collection%s", name)
}
func (c *ChromaDB) collectionActionPath(ctx context.Context, collectionName, action string) (string, error) {
if err := c.ensureVersion(ctx); err != nil {
return "", err
}
coll, err := c.resolveCollection(ctx, "", collectionName)
if err != nil {
return "", err
}
ident := coll.ID
if strings.TrimSpace(ident) == "" {
ident = coll.Name
}
if ident == "" {
return "", fmt.Errorf("collection 标识为空")
}
if c.apiVersion == 2 {
return c.v2Path(coll.Database, fmt.Sprintf("/collections/%s/%s", url.PathEscape(ident), action)), nil
}
return fmt.Sprintf("/api/v1/collections/%s/%s", url.PathEscape(ident), action), nil
}
func (c *ChromaDB) getCollectionRows(ctx context.Context, collection string, limit int, offset int, where interface{}, include []string) ([]map[string]interface{}, []string, error) {
if limit <= 0 {
limit = 200
}
path, err := c.collectionActionPath(ctx, collection, "get")
if err != nil {
return nil, nil, err
}
body := map[string]interface{}{
"limit": limit,
"offset": offset,
"include": include,
}
if where != nil {
body["where"] = where
}
var resp chromaGetResponse
if err := c.doJSON(ctx, http.MethodPost, path, body, &resp); err != nil {
return nil, nil, err
}
rows, columns := chromaGetResponseRows(resp)
return rows, columns, nil
}
func (c *ChromaDB) countCollection(ctx context.Context, collection string, where interface{}) (int64, error) {
path, err := c.collectionActionPath(ctx, collection, "count")
if err != nil {
return 0, err
}
if where == nil {
var raw interface{}
if err := c.doJSON(ctx, http.MethodGet, path, nil, &raw); err == nil {
return chromaCountValue(raw), nil
}
}
rows, _, err := c.getCollectionRows(ctx, collection, 1_000_000, 0, where, []string{"documents"})
if err != nil {
return 0, err
}
return int64(len(rows)), nil
}
func (c *ChromaDB) queryJSON(ctx context.Context, text string) ([]map[string]interface{}, []string, error) {
var cmd map[string]interface{}
if err := decodeJSONWithUseNumber([]byte(text), &cmd); err != nil {
return nil, nil, fmt.Errorf("Chroma JSON 命令解析失败:%w", err)
}
if hasAnyKey(cmd, "list_collections", "listCollections") {
cols, err := c.listCollections(ctx, "")
if err != nil {
return nil, nil, err
}
rows := make([]map[string]interface{}, 0, len(cols))
for _, col := range cols {
rows = append(rows, chromaStructRow(col))
}
return rows, collectColumns(rows), nil
}
if name := firstStringValue(cmd, "get", "collection"); name != "" && (hasAnyKey(cmd, "get") || !hasAnyKey(cmd, "query", "query_embeddings", "query_texts")) {
limit := intFromAny(cmd["limit"], 200)
offset := intFromAny(cmd["offset"], 0)
include := stringSliceFromAny(cmd["include"], []string{"documents", "metadatas"})
return c.getCollectionRows(ctx, name, limit, offset, cmd["where"], include)
}
if name := firstStringValue(cmd, "query", "collection"); name != "" {
return c.queryCollection(ctx, name, cmd)
}
return nil, nil, fmt.Errorf("Chroma JSON 查询命令仅支持 list_collections/get/query")
}
func (c *ChromaDB) queryCollection(ctx context.Context, collection string, cmd map[string]interface{}) ([]map[string]interface{}, []string, error) {
path, err := c.collectionActionPath(ctx, collection, "query")
if err != nil {
return nil, nil, err
}
body := make(map[string]interface{})
for _, key := range []string{"query_embeddings", "query_texts", "where", "where_document", "include"} {
if value, ok := cmd[key]; ok {
body[key] = value
}
}
body["n_results"] = intFromAny(firstExisting(cmd, "n_results", "limit"), 10)
if _, ok := body["include"]; !ok {
body["include"] = []string{"documents", "metadatas", "distances"}
}
var raw map[string]interface{}
if err := c.doJSON(ctx, http.MethodPost, path, body, &raw); err != nil {
return nil, nil, err
}
rows := chromaQueryResponseRows(raw)
return rows, collectColumns(rows), nil
}
func (c *ChromaDB) createCollection(ctx context.Context, body map[string]interface{}) error {
if err := c.ensureVersion(ctx); err != nil {
return err
}
path := "/api/v1/collections"
if c.apiVersion == 2 {
path = c.v2Path("", "/collections")
}
return c.doJSON(ctx, http.MethodPost, path, body, nil)
}
func (c *ChromaDB) deleteCollection(ctx context.Context, name string) error {
if err := c.ensureVersion(ctx); err != nil {
return err
}
path := fmt.Sprintf("/api/v1/collections/%s", url.PathEscape(name))
if c.apiVersion == 2 {
coll, err := c.resolveCollection(ctx, "", name)
if err != nil {
return err
}
ident := coll.ID
if ident == "" {
ident = coll.Name
}
path = c.v2Path(coll.Database, fmt.Sprintf("/collections/%s", url.PathEscape(ident)))
}
return c.doJSON(ctx, http.MethodDelete, path, nil, nil)
}
func (c *ChromaDB) upsertCommand(ctx context.Context, collection string, cmd map[string]interface{}) (int64, error) {
if rowsValue, ok := cmd["rows"].([]interface{}); ok {
rows := make([]map[string]interface{}, 0, len(rowsValue))
for _, raw := range rowsValue {
if row, ok := raw.(map[string]interface{}); ok {
rows = append(rows, row)
}
}
return int64(len(rows)), c.upsertRows(ctx, collection, rows)
}
body := make(map[string]interface{})
for _, key := range []string{"ids", "documents", "metadatas", "embeddings", "uris"} {
if value, ok := cmd[key]; ok {
body[key] = value
}
}
if _, ok := body["ids"]; !ok {
return 0, fmt.Errorf("Chroma upsert 命令缺少 ids")
}
path, err := c.collectionActionPath(ctx, collection, "upsert")
if err != nil {
return 0, err
}
return int64(len(anySlice(body["ids"]))), c.doJSON(ctx, http.MethodPost, path, body, nil)
}
func (c *ChromaDB) deleteCommand(ctx context.Context, collection string, cmd map[string]interface{}) (int64, error) {
body := make(map[string]interface{})
for _, key := range []string{"ids", "where", "where_document"} {
if value, ok := cmd[key]; ok {
body[key] = value
}
}
if len(body) == 0 {
return 0, fmt.Errorf("Chroma delete 命令缺少 ids/where/where_document")
}
path, err := c.collectionActionPath(ctx, collection, "delete")
if err != nil {
return 0, err
}
return int64(len(anySlice(body["ids"]))), c.doJSON(ctx, http.MethodPost, path, body, nil)
}
func (c *ChromaDB) upsertRows(ctx context.Context, collection string, rows []map[string]interface{}) error {
if len(rows) == 0 {
return nil
}
ids := make([]string, 0, len(rows))
docs := make([]interface{}, 0, len(rows))
metadatas := make([]map[string]interface{}, 0, len(rows))
embeddings := make([]interface{}, 0, len(rows))
hasEmbedding := false
for _, row := range rows {
id := chromaRowID(row)
if id == "" {
return fmt.Errorf("Chroma 写入行缺少 id")
}
ids = append(ids, id)
docs = append(docs, firstExisting(row, "document", "_document", "documents"))
meta := make(map[string]interface{})
if raw, ok := row["metadata"].(map[string]interface{}); ok {
for k, v := range raw {
meta[k] = v
}
}
for k, v := range row {
if isChromaReservedRowField(k) {
continue
}
if strings.HasPrefix(k, "metadata.") {
meta[strings.TrimPrefix(k, "metadata.")] = v
continue
}
meta[k] = v
}
metadatas = append(metadatas, meta)
if embedding := firstExisting(row, "embedding", "_embedding", "embeddings"); embedding != nil {
embeddings = append(embeddings, normalizeChromaEmbedding(embedding))
hasEmbedding = true
}
}
body := map[string]interface{}{
"ids": ids,
"documents": docs,
"metadatas": metadatas,
}
if hasEmbedding {
body["embeddings"] = embeddings
}
path, err := c.collectionActionPath(ctx, collection, "upsert")
if err != nil {
return err
}
return c.doJSON(ctx, http.MethodPost, path, body, nil)
}
type chromaParsedSQL struct {
Collection string
Limit int
Offset int
Where interface{}
Count bool
IncludeEmbeddings bool
}
var chromaSQLFromRE = regexp.MustCompile(`(?i)\bFROM\s+(?:"([^"]+)"|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z0-9_.\-]+))`)
var chromaSQLLimitRE = regexp.MustCompile(`(?i)\bLIMIT\s+(\d+)`)
var chromaSQLOffsetRE = regexp.MustCompile(`(?i)\bOFFSET\s+(\d+)`)
func parseChromaSQL(sqlText string) (chromaParsedSQL, bool) {
text := strings.TrimSpace(sqlText)
if !strings.HasPrefix(strings.ToLower(text), "select") {
return chromaParsedSQL{}, false
}
matches := chromaSQLFromRE.FindStringSubmatch(text)
if len(matches) == 0 {
return chromaParsedSQL{}, false
}
table := firstNonEmpty(matches[1], matches[2], matches[3])
if table == "" {
return chromaParsedSQL{}, false
}
parsed := chromaParsedSQL{Collection: table, Limit: 200}
lower := strings.ToLower(text)
parsed.Count = strings.Contains(lower, "count(")
parsed.IncludeEmbeddings = strings.Contains(lower, "embedding")
if m := chromaSQLLimitRE.FindStringSubmatch(text); len(m) > 1 {
parsed.Limit, _ = strconv.Atoi(m[1])
}
if m := chromaSQLOffsetRE.FindStringSubmatch(text); len(m) > 1 {
parsed.Offset, _ = strconv.Atoi(m[1])
}
return parsed, true
}
func chromaGetResponseRows(resp chromaGetResponse) ([]map[string]interface{}, []string) {
rows := make([]map[string]interface{}, 0, len(resp.IDs))
for index, id := range resp.IDs {
row := map[string]interface{}{"id": id}
if value := sliceValue(resp.Documents, index); value != nil {
row["document"] = value
}
if meta := sliceValueMap(resp.Metadatas, index); meta != nil {
row["metadata"] = meta
for k, v := range meta {
row["metadata."+k] = v
}
}
if value := sliceValue(resp.Embeddings, index); value != nil {
row["embedding"] = normalizeJSONLikeValue(value)
}
rows = append(rows, row)
}
return rows, collectColumns(rows)
}
func chromaQueryResponseRows(raw map[string]interface{}) []map[string]interface{} {
idGroups := nestedAnySlice(raw["ids"])
docGroups := nestedAnySlice(raw["documents"])
metaGroups := nestedAnySlice(raw["metadatas"])
distanceGroups := nestedAnySlice(raw["distances"])
var rows []map[string]interface{}
for groupIndex, group := range idGroups {
for itemIndex, id := range group {
row := map[string]interface{}{
"query_index": groupIndex,
"id": fmt.Sprintf("%v", id),
}
if doc := nestedValue(docGroups, groupIndex, itemIndex); doc != nil {
row["document"] = doc
}
if dist := nestedValue(distanceGroups, groupIndex, itemIndex); dist != nil {
row["distance"] = dist
}
if meta, ok := nestedValue(metaGroups, groupIndex, itemIndex).(map[string]interface{}); ok {
row["metadata"] = meta
for k, v := range meta {
row["metadata."+k] = v
}
}
rows = append(rows, row)
}
}
return rows
}
func collectColumns(rows []map[string]interface{}) []string {
set := make(map[string]struct{})
for _, row := range rows {
for key := range row {
set[key] = struct{}{}
}
}
cols := make([]string, 0, len(set))
for key := range set {
cols = append(cols, key)
}
sort.Strings(cols)
for _, priority := range []string{"id", "query_index", "document", "distance", "metadata", "embedding"} {
for i, col := range cols {
if col == priority && i > 0 {
cols = append(cols[:i], cols[i+1:]...)
cols = append([]string{priority}, cols...)
break
}
}
}
return cols
}
func tableNameOrDB(dbName, tableName string) string {
if name := strings.TrimSpace(tableName); name != "" {
return name
}
return strings.TrimSpace(dbName)
}
func chromaStructRow(col chromaCollection) map[string]interface{} {
row := map[string]interface{}{
"id": col.ID,
"name": col.Name,
"tenant": col.Tenant,
"database": col.Database,
}
if col.Dimension > 0 {
row["dimension"] = col.Dimension
}
if len(col.Metadata) > 0 {
row["metadata"] = col.Metadata
}
return row
}
func chromaCountValue(raw interface{}) int64 {
switch v := raw.(type) {
case json.Number:
n, _ := v.Int64()
return n
case float64:
return int64(v)
case int:
return int64(v)
case int64:
return v
case map[string]interface{}:
return chromaCountValue(firstExisting(v, "count", "total", "value"))
default:
return 0
}
}
func chromaRowID(row map[string]interface{}) string {
return strings.TrimSpace(fmt.Sprintf("%v", firstExisting(row, "id", "_id")))
}
func isChromaReservedRowField(key string) bool {
switch key {
case "id", "_id", "document", "_document", "documents", "metadata", "embedding", "_embedding", "embeddings":
return true
default:
return false
}
}
func normalizeChromaEmbedding(value interface{}) interface{} {
if text, ok := value.(string); ok {
var parsed interface{}
if err := decodeJSONWithUseNumber([]byte(text), &parsed); err == nil {
return parsed
}
}
return value
}
func inferChromaValueType(value interface{}) string {
switch value.(type) {
case bool:
return "bool"
case json.Number, float64, float32, int, int64:
return "number"
case map[string]interface{}:
return "json"
case []interface{}:
return "array"
default:
return "string"
}
}
func firstStringValue(m map[string]interface{}, keys ...string) string {
for _, key := range keys {
if value, ok := m[key]; ok {
text := strings.TrimSpace(fmt.Sprintf("%v", value))
if text != "" && text != "<nil>" {
return text
}
}
}
return ""
}
func firstExisting(m map[string]interface{}, keys ...string) interface{} {
for _, key := range keys {
if value, ok := m[key]; ok {
return value
}
}
return nil
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if text := strings.TrimSpace(value); text != "" {
return text
}
}
return ""
}
func hasAnyKey(m map[string]interface{}, keys ...string) bool {
for _, key := range keys {
if _, ok := m[key]; ok {
return true
}
}
return false
}
func getOrBool(m map[string]interface{}, keys ...string) bool {
for _, key := range keys {
switch v := m[key].(type) {
case bool:
return v
case string:
return strings.EqualFold(strings.TrimSpace(v), "true")
}
}
return false
}
func intFromAny(value interface{}, fallback int) int {
switch v := value.(type) {
case json.Number:
n, err := v.Int64()
if err == nil {
return int(n)
}
case float64:
return int(v)
case int:
return v
case int64:
return int(v)
case string:
n, err := strconv.Atoi(strings.TrimSpace(v))
if err == nil {
return n
}
}
return fallback
}
func mapString(m map[string]interface{}, key string) string {
return strings.TrimSpace(fmt.Sprintf("%v", m[key]))
}
func stringSliceFromAny(value interface{}, fallback []string) []string {
if value == nil {
return fallback
}
switch v := value.(type) {
case []string:
return v
case []interface{}:
result := make([]string, 0, len(v))
for _, item := range v {
if text := strings.TrimSpace(fmt.Sprintf("%v", item)); text != "" {
result = append(result, text)
}
}
if len(result) > 0 {
return result
}
}
return fallback
}
func anySlice(value interface{}) []interface{} {
switch v := value.(type) {
case []interface{}:
return v
case []string:
result := make([]interface{}, len(v))
for i, item := range v {
result[i] = item
}
return result
default:
return nil
}
}
func nestedAnySlice(value interface{}) [][]interface{} {
switch v := value.(type) {
case []interface{}:
result := make([][]interface{}, 0, len(v))
for _, group := range v {
result = append(result, anySlice(group))
}
return result
default:
return nil
}
}
func nestedValue(groups [][]interface{}, groupIndex, itemIndex int) interface{} {
if groupIndex < 0 || groupIndex >= len(groups) {
return nil
}
group := groups[groupIndex]
if itemIndex < 0 || itemIndex >= len(group) {
return nil
}
return group[itemIndex]
}
func sliceValue(items []interface{}, index int) interface{} {
if index < 0 || index >= len(items) {
return nil
}
return normalizeJSONLikeValue(items[index])
}
func sliceValueMap(items []map[string]interface{}, index int) map[string]interface{} {
if index < 0 || index >= len(items) {
return nil
}
return items[index]
}
func normalizeJSONLikeValue(value interface{}) interface{} {
switch value.(type) {
case map[string]interface{}, []interface{}:
payload, err := json.Marshal(value)
if err == nil {
return string(payload)
}
}
return value
}