Files
MyGoNavi/internal/db/qdrant_impl.go
Syngnat c805b16fcd feat(qdrant): 新增 Qdrant 向量库连接支持
- 后端新增 Qdrant REST 连接、collection 元数据、scroll/search 查询与 upsert/delete/payload 更新

- 前端新增 Qdrant 类型、连接配置、图标、方言和能力矩阵

- 测试覆盖 mock REST、真实服务 smoke 和前端配置

Refs #555
2026-06-13 17:03:20 +08:00

1050 lines
32 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package db
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
"GoNavi-Wails/internal/connection"
"GoNavi-Wails/internal/logger"
proxytunnel "GoNavi-Wails/internal/proxy"
"GoNavi-Wails/internal/ssh"
)
// Defaults applied when the connection configuration leaves a value unset.
const (
// defaultQdrantPort replaces a zero or negative configured port and is the
// fallback port when a connection URI carries none.
defaultQdrantPort = 6333
// defaultQdrantDatabase is the database name used when none is configured.
defaultQdrantDatabase = "default"
// defaultQdrantQueryTimeout bounds Query, Exec and ApplyChanges, which create
// their own context.
defaultQdrantQueryTimeout = 30 * time.Second
)
// QdrantDB is a connection to a Qdrant server that is driven through HTTP
// requests carrying JSON bodies.
type QdrantDB struct {
// client is built by Connect; a nil client means the connection is not open.
client *http.Client
// baseURL is the prefix every request path is appended to.
baseURL string
// database is the single name returned by GetDatabases.
database string
// authHeaders are added to every outgoing request.
authHeaders map[string]string
// forwarder is the SSH local port forwarder; it is set only when Connect was
// asked to use SSH.
forwarder *ssh.LocalForwarder
}
// qdrantCollectionInfo is one entry of the collection list response.
type qdrantCollectionInfo struct {
Name string `json:"name"`
}
// qdrantListCollectionsResponse is the decoded body of GET /collections.
type qdrantListCollectionsResponse struct {
Result struct {
Collections []qdrantCollectionInfo `json:"collections"`
} `json:"result"`
}
// qdrantCollectionResponse is the decoded body of GET /collections/{name}.
// The result is kept as a generic map and inspected by the index helpers.
type qdrantCollectionResponse struct {
Result map[string]interface{} `json:"result"`
}
// qdrantPoint is a single point as decoded from scroll and search responses.
// Fields are loosely typed; qdrantPointRows copies Score, Version and Vector
// into a row only when they are non-nil.
type qdrantPoint struct {
ID interface{} `json:"id"`
Payload map[string]interface{} `json:"payload"`
Vector interface{} `json:"vector"`
Score interface{} `json:"score"`
Version interface{} `json:"version"`
}
// qdrantScrollResponse is the decoded body of the points/scroll endpoint.
type qdrantScrollResponse struct {
Result struct {
Points []qdrantPoint `json:"points"`
// NextPageOffset is nil when the response carries no next_page_offset.
NextPageOffset interface{} `json:"next_page_offset"`
} `json:"result"`
}
// qdrantSearchResponse is the decoded body of the points/search endpoint.
type qdrantSearchResponse struct {
Result []qdrantPoint `json:"result"`
}
// qdrantCountResponse is the decoded body of the points/count endpoint.
type qdrantCountResponse struct {
Result struct {
Count int64 `json:"count"`
} `json:"result"`
}
// Connect opens the connection described by config and verifies it with Ping.
//
// A forwarder left over from a previous Connect is closed first. When SSH is
// enabled, a local port forwarder is obtained and the HTTP client is pointed at
// the forwarder's local address instead of the remote host. If any step after
// the forwarder has been obtained fails, the connection is closed again so the
// forwarder does not stay attached to a QdrantDB that never became usable.
//
// It returns an error when the tunnel cannot be created, when the forwarder's
// local address cannot be parsed, or when the initial Ping fails.
func (q *QdrantDB) Connect(config connection.ConnectionConfig) error {
	if q.forwarder != nil {
		// Best effort: the old tunnel is being replaced regardless of the outcome.
		_ = q.forwarder.Close()
		q.forwarder = nil
	}
	q.client = nil

	runConfig := normalizeQdrantConfig(config)
	if runConfig.UseSSH {
		// Remember the normalized target before it is overwritten with the
		// forwarder's local address, so the log line shows the real endpoint.
		remoteHost, remotePort := runConfig.Host, runConfig.Port
		forwarder, err := ssh.GetOrCreateLocalForwarder(runConfig.SSH, remoteHost, remotePort)
		if err != nil {
			return fmt.Errorf("创建 SSH 隧道失败:%w", err)
		}
		q.forwarder = forwarder

		host, portText, err := net.SplitHostPort(forwarder.LocalAddr)
		if err != nil {
			_ = q.Close()
			return fmt.Errorf("解析本地转发地址失败:%w", err)
		}
		port, err := strconv.Atoi(portText)
		if err != nil {
			_ = q.Close()
			return fmt.Errorf("解析本地端口失败:%w", err)
		}
		runConfig.Host = host
		runConfig.Port = port
		runConfig.UseSSH = false
		logger.Infof("Qdrant 通过本地端口转发连接:%s -> %s:%d", forwarder.LocalAddr, remoteHost, remotePort)
	}

	q.database = qdrantDatabaseFromConfig(runConfig)
	q.baseURL = buildQdrantBaseURL(runConfig)
	q.authHeaders = qdrantAuthHeaders(runConfig)
	q.client = buildQdrantHTTPClient(runConfig)
	if err := q.Ping(); err != nil {
		_ = q.Close()
		return err
	}
	return nil
}
// Close releases the SSH forwarder, if one is attached, and drops the HTTP
// client. A failure to close the forwarder is logged as a warning and not
// returned; Close always returns nil.
func (q *QdrantDB) Close() error {
	if fwd := q.forwarder; fwd != nil {
		closeErr := fwd.Close()
		if closeErr != nil {
			logger.Warnf("关闭 Qdrant SSH 端口转发失败:%v", closeErr)
		}
		q.forwarder = nil
	}
	q.client = nil
	return nil
}
// Ping checks that the server answers by listing collections under a
// ten-second deadline. It fails when the connection has not been opened.
func (q *QdrantDB) Ping() error {
	if q.client == nil {
		return fmt.Errorf("连接未打开")
	}
	pingCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
	defer release()

	listing := new(qdrantListCollectionsResponse)
	return q.doJSON(pingCtx, http.MethodGet, "/collections", nil, listing)
}
// Query runs QueryContext under the default query timeout.
func (q *QdrantDB) Query(query string) ([]map[string]interface{}, []string, error) {
	bounded, release := context.WithTimeout(context.Background(), defaultQdrantQueryTimeout)
	defer release()

	rows, columns, err := q.QueryContext(bounded, query)
	return rows, columns, err
}
// QueryContext executes a read query and returns the rows together with their
// column names.
//
// Text starting with "{" is handled as a JSON command. Otherwise the text must
// be recognised by parseQdrantSQL: a count query yields a single row with a
// "total" column, anything else scrolls the collection with payloads included.
// Blank text, an unopened connection and unrecognised text are errors.
func (q *QdrantDB) QueryContext(ctx context.Context, query string) ([]map[string]interface{}, []string, error) {
	if q.client == nil {
		return nil, nil, fmt.Errorf("连接未打开")
	}

	statement := strings.TrimSpace(query)
	switch {
	case statement == "":
		return nil, nil, fmt.Errorf("查询语句不能为空")
	case strings.HasPrefix(statement, "{"):
		return q.queryJSON(ctx, statement)
	}

	sel, recognised := parseQdrantSQL(statement)
	if !recognised {
		return nil, nil, fmt.Errorf("Qdrant 查询仅支持 JSON 命令或简单 SELECT 预览")
	}
	if !sel.Count {
		return q.scrollPoints(ctx, sel.Collection, sel.Limit, sel.Offset, nil, true, sel.IncludeVector)
	}

	total, err := q.countPoints(ctx, sel.Collection, nil)
	if err != nil {
		return nil, nil, err
	}
	summary := map[string]interface{}{"total": total}
	return []map[string]interface{}{summary}, []string{"total"}, nil
}
// Exec runs ExecContext under the default query timeout.
func (q *QdrantDB) Exec(query string) (int64, error) {
	bounded, release := context.WithTimeout(context.Background(), defaultQdrantQueryTimeout)
	defer release()

	affected, err := q.ExecContext(bounded, query)
	return affected, err
}
// ExecContext executes a write command given as a JSON object.
//
// The command kind is chosen by the first matching key, checked in this order:
// create_collection, delete_collection, upsert, delete, create_payload_index,
// delete_payload_index (camelCase spellings are accepted where listed below).
// For most commands the collection name may be the value of the command key
// or of a separate "collection" key; delete_collection takes it from its own
// key only.
//
// The returned count is 1 for collection and index commands and the number of
// submitted rows/points for upsert and delete. An error is returned when the
// connection is not open, the text is not valid JSON, a required field is
// missing, no supported command key is present, or the request fails.
func (q *QdrantDB) ExecContext(ctx context.Context, query string) (int64, error) {
	if q.client == nil {
		return 0, fmt.Errorf("连接未打开")
	}
	var cmd map[string]interface{}
	if err := decodeJSONWithUseNumber([]byte(strings.TrimSpace(query)), &cmd); err != nil {
		// The separator before %w was missing, which glued the decoder error
		// directly onto "JSON"; every other wrapped message here uses ":".
		return 0, fmt.Errorf("Qdrant 写入命令必须是 JSON:%w", err)
	}
	if name := firstStringValue(cmd, "create_collection", "createCollection", "collection"); name != "" && hasAnyKey(cmd, "create_collection", "createCollection") {
		return 1, q.createCollection(ctx, name, cmd)
	}
	if name := firstStringValue(cmd, "delete_collection", "deleteCollection"); name != "" {
		return 1, q.deleteCollection(ctx, name)
	}
	if name := firstStringValue(cmd, "upsert", "collection"); name != "" && hasAnyKey(cmd, "upsert") {
		return q.upsertCommand(ctx, name, cmd)
	}
	if name := firstStringValue(cmd, "delete", "collection"); name != "" && hasAnyKey(cmd, "delete") {
		return q.deleteCommand(ctx, name, cmd)
	}
	if name := firstStringValue(cmd, "create_payload_index", "createPayloadIndex", "collection"); name != "" && hasAnyKey(cmd, "create_payload_index", "createPayloadIndex") {
		return 1, q.createPayloadIndex(ctx, name, cmd)
	}
	if name := firstStringValue(cmd, "delete_payload_index", "deletePayloadIndex", "collection"); name != "" && hasAnyKey(cmd, "delete_payload_index", "deletePayloadIndex") {
		fieldName := firstStringValue(cmd, "field_name", "fieldName", "field")
		if fieldName == "" {
			return 0, fmt.Errorf("Qdrant 删除 payload index 命令缺少 field_name")
		}
		return 1, q.deletePayloadIndex(ctx, name, fieldName)
	}
	return 0, fmt.Errorf("Qdrant JSON 写入命令仅支持 create_collection/delete_collection/upsert/delete/create_payload_index/delete_payload_index")
}
// GetDatabases returns the single database name configured for this
// connection. It fails when the connection has not been opened.
func (q *QdrantDB) GetDatabases() ([]string, error) {
	if q.client != nil {
		return []string{q.database}, nil
	}
	return nil, fmt.Errorf("连接未打开")
}
// GetTables lists the collection names in ascending order, skipping entries
// whose name is blank. dbName is not used.
func (q *QdrantDB) GetTables(dbName string) ([]string, error) {
	listed, err := q.listCollections(context.Background())
	if err != nil {
		return nil, err
	}

	tables := make([]string, 0, len(listed))
	for i := range listed {
		if name := listed[i].Name; strings.TrimSpace(name) != "" {
			tables = append(tables, name)
		}
	}
	sort.Strings(tables)
	return tables, nil
}
// GetCreateStatement renders the collection's info as indented JSON preceded
// by a "// Qdrant collection: <name>" comment line. The collection name is
// resolved from dbName and tableName by tableNameOrDB.
func (q *QdrantDB) GetCreateStatement(dbName, tableName string) (string, error) {
	info, err := q.getCollectionInfo(context.Background(), tableNameOrDB(dbName, tableName))
	if err != nil {
		return "", err
	}
	// A marshalling failure is ignored, leaving the JSON part empty.
	pretty, _ := json.MarshalIndent(info, "", " ")
	header := "// Qdrant collection: " + tableNameOrDB(dbName, tableName)
	return header + "\n" + string(pretty), nil
}
// GetColumns describes the columns of a collection.
//
// The fixed columns id, vector and payload always come first. Payload fields
// are then discovered by sampling up to 20 points: every row key with the
// "payload." prefix becomes a nullable column whose type is inferred from the
// first sampled value. Keys of each sampled row are visited in sorted order so
// the resulting column order is stable across calls rather than following
// Go's randomized map iteration.
func (q *QdrantDB) GetColumns(dbName, tableName string) ([]connection.ColumnDefinition, error) {
	rows, _, err := q.scrollPoints(context.Background(), tableNameOrDB(dbName, tableName), 20, nil, nil, true, true)
	if err != nil {
		return nil, err
	}
	cols := []connection.ColumnDefinition{
		{Name: "id", Type: "point_id", Nullable: "NO", Key: "PRI", Comment: "Qdrant point id"},
		{Name: "vector", Type: "vector<float>", Nullable: "YES", Comment: "Vector or named vectors"},
		{Name: "payload", Type: "json", Nullable: "YES", Comment: "Full payload object"},
	}
	seen := map[string]struct{}{"id": {}, "vector": {}, "payload": {}}
	for _, row := range rows {
		keys := make([]string, 0, len(row))
		for key := range row {
			if strings.HasPrefix(key, "payload.") {
				keys = append(keys, key)
			}
		}
		sort.Strings(keys)
		for _, key := range keys {
			if _, exists := seen[key]; exists {
				continue
			}
			seen[key] = struct{}{}
			cols = append(cols, connection.ColumnDefinition{
				Name:     key,
				Type:     inferChromaValueType(row[key]),
				Nullable: "YES",
				Comment:  "Payload field",
			})
		}
	}
	return cols, nil
}
// GetAllColumns collects the columns of every collection. Collections whose
// columns cannot be read are skipped instead of failing the whole call.
func (q *QdrantDB) GetAllColumns(dbName string) ([]connection.ColumnDefinitionWithTable, error) {
	tables, err := q.GetTables(dbName)
	if err != nil {
		return nil, err
	}

	var all []connection.ColumnDefinitionWithTable
	for _, table := range tables {
		cols, colErr := q.GetColumns(dbName, table)
		if colErr != nil {
			continue
		}
		for i := range cols {
			entry := connection.ColumnDefinitionWithTable{
				TableName: table,
				Name:      cols[i].Name,
				Type:      cols[i].Type,
				Comment:   cols[i].Comment,
			}
			all = append(all, entry)
		}
	}
	return all, nil
}
// GetIndexes reports a PRIMARY index on id followed by the vector and payload
// indexes derived from the collection info. When the info cannot be fetched,
// or yields no further indexes, a single generic VECTOR index is reported
// instead. The error result is always nil.
func (q *QdrantDB) GetIndexes(dbName, tableName string) ([]connection.IndexDefinition, error) {
	primary := connection.IndexDefinition{Name: "PRIMARY", ColumnName: "id", NonUnique: 0, SeqInIndex: 1, IndexType: "PRIMARY"}
	found := []connection.IndexDefinition{primary}

	if info, err := q.getCollectionInfo(context.Background(), tableNameOrDB(dbName, tableName)); err == nil {
		found = append(found, qdrantVectorIndexes(info)...)
		found = append(found, qdrantPayloadIndexes(info)...)
	}
	if len(found) != 1 {
		return found, nil
	}

	fallback := connection.IndexDefinition{Name: "VECTOR", ColumnName: "vector", NonUnique: 1, SeqInIndex: 1, IndexType: "VECTOR"}
	return append(found, fallback), nil
}
// GetForeignKeys always returns an empty, non-nil list.
func (q *QdrantDB) GetForeignKeys(dbName, tableName string) ([]connection.ForeignKeyDefinition, error) {
	none := make([]connection.ForeignKeyDefinition, 0)
	return none, nil
}
// GetTriggers always returns an empty, non-nil list.
func (q *QdrantDB) GetTriggers(dbName, tableName string) ([]connection.TriggerDefinition, error) {
	none := make([]connection.TriggerDefinition, 0)
	return none, nil
}
// ApplyChanges writes a change set to the collection named tableName under the
// default query timeout, in three phases: deletes, then updates, then inserts.
//
//   - Deletes: rows without a usable id are skipped; the remaining ids are
//     removed with one delete request.
//   - Updates: keys and values are merged into one row (values win). A row
//     that carries a vector is queued and upserted in one batch after the
//     loop; a row without one has only its payload set, immediately.
//   - Inserts: all rows are upserted in one batch.
//
// The first error stops processing and is returned.
func (q *QdrantDB) ApplyChanges(tableName string, changes connection.ChangeSet) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultQdrantQueryTimeout)
	defer cancel()

	doomed := make([]interface{}, 0, len(changes.Deletes))
	for _, row := range changes.Deletes {
		id, ok := qdrantRowID(row)
		if !ok {
			continue
		}
		doomed = append(doomed, id)
	}
	if len(doomed) > 0 {
		_, err := q.deleteCommand(ctx, tableName, map[string]interface{}{"points": doomed})
		if err != nil {
			return err
		}
	}

	var withVector []map[string]interface{}
	for _, update := range changes.Updates {
		merged := make(map[string]interface{}, len(update.Keys)+len(update.Values))
		for field, value := range update.Keys {
			merged[field] = value
		}
		for field, value := range update.Values {
			merged[field] = value
		}
		if _, ok := qdrantRowVector(merged); !ok {
			if err := q.setPayloadFromRow(ctx, tableName, merged); err != nil {
				return err
			}
			continue
		}
		withVector = append(withVector, merged)
	}
	if len(withVector) > 0 {
		if err := q.upsertRows(ctx, tableName, withVector); err != nil {
			return err
		}
	}

	if len(changes.Inserts) == 0 {
		return nil
	}
	return q.upsertRows(ctx, tableName, changes.Inserts)
}
// normalizeQdrantConfig applies the connection URI to config and then fills in
// defaults: host "localhost", the default Qdrant port, and SSL mode "required"
// when SSL is enabled without an explicit mode.
func normalizeQdrantConfig(config connection.ConnectionConfig) connection.ConnectionConfig {
	normalized := applyQdrantURI(config)

	if host := strings.TrimSpace(normalized.Host); len(host) == 0 {
		normalized.Host = "localhost"
	}
	if normalized.Port < 1 {
		normalized.Port = defaultQdrantPort
	}
	if normalized.UseSSL && strings.TrimSpace(normalized.SSLMode) == "" {
		normalized.SSLMode = "required"
	}
	return normalized
}
// applyQdrantURI merges the parts of config.URI into a copy of config.
//
// The URI is ignored when it is blank, cannot be parsed, or has a scheme other
// than http, https or qdrant. Otherwise:
//   - user and password are taken from the URI only where config has none;
//   - an https scheme turns SSL on;
//   - host and port from the URI replace the configured ones;
//   - the path becomes the database name when config has none and the path
//     does not begin with "collections".
func applyQdrantURI(config connection.ConnectionConfig) connection.ConnectionConfig {
	raw := strings.TrimSpace(config.URI)
	if raw == "" {
		return config
	}
	u, err := url.Parse(raw)
	if err != nil {
		return config
	}

	scheme := strings.ToLower(strings.TrimSpace(u.Scheme))
	switch scheme {
	case "http", "https", "qdrant":
		// supported
	default:
		return config
	}

	if info := u.User; info != nil {
		if strings.TrimSpace(config.User) == "" {
			config.User = info.Username()
		}
		if config.Password == "" {
			if secret, present := info.Password(); present {
				config.Password = secret
			}
		}
	}
	if scheme == "https" {
		config.UseSSL = true
	}
	if authority := strings.TrimSpace(u.Host); authority != "" {
		host, port, ok := parseHostPortWithDefault(authority, defaultQdrantPort)
		if ok {
			config.Host = host
			config.Port = port
		}
	}

	pathName := strings.Trim(strings.TrimSpace(u.Path), "/")
	usable := pathName != "" && !strings.HasPrefix(pathName, "collections")
	if usable && strings.TrimSpace(config.Database) == "" {
		config.Database = pathName
	}
	return config
}
// buildQdrantBaseURL returns the "scheme://host:port" prefix for API requests.
// The scheme is https when SSL is enabled and http otherwise.
//
// The authority is assembled with net.JoinHostPort so that an IPv6 literal
// such as "::1" is wrapped in brackets ("http://[::1]:6333"); plain
// "%s:%d" formatting produced an unparseable URL for such hosts. A host that
// is already bracketed is unwrapped first so it is not bracketed twice. For
// host names and IPv4 addresses the result is unchanged.
func buildQdrantBaseURL(config connection.ConnectionConfig) string {
	scheme := "http"
	if config.UseSSL {
		scheme = "https"
	}
	host := strings.TrimSpace(config.Host)
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return scheme + "://" + net.JoinHostPort(host, strconv.Itoa(config.Port))
}
// qdrantDatabaseFromConfig returns the trimmed configured database name, or
// the default name when it is blank.
func qdrantDatabaseFromConfig(config connection.ConnectionConfig) string {
	configured := strings.TrimSpace(config.Database)
	if configured == "" {
		return defaultQdrantDatabase
	}
	return configured
}
// qdrantConnectionParams gathers connection parameters from the URI first and
// from the free-text connection parameters second, merging both into one set.
func qdrantConnectionParams(config connection.ConnectionConfig) url.Values {
	merged := make(url.Values)

	fromURI := connectionParamsFromURI(config.URI, "http", "https", "qdrant")
	mergeConnectionParamValues(merged, fromURI)

	fromText := connectionParamsFromText(config.ConnectionParams)
	mergeConnectionParamValues(merged, fromText)

	return merged
}
// qdrantAuthHeaders builds the authentication headers for a connection.
//
// An API key is looked up in the connection parameters (apiKey, apikey,
// api-key, token, authToken); if none is given and no user is configured, the
// password is used as the API key. A non-empty key is sent as "api-key";
// otherwise a configured user results in HTTP Basic credentials. Independently
// of that, the authHeader/authHeaderValue parameters add one custom header
// when the value is non-blank and the name passes isSafeConnectionParamKey.
func qdrantAuthHeaders(config connection.ConnectionConfig) map[string]string {
	result := map[string]string{}
	params := qdrantConnectionParams(config)
	user := strings.TrimSpace(config.User)

	key := firstNonEmpty(params.Get("apiKey"), params.Get("apikey"), params.Get("api-key"), params.Get("token"), params.Get("authToken"))
	if key == "" && user == "" {
		key = strings.TrimSpace(config.Password)
	}

	switch {
	case key != "":
		result["api-key"] = key
	case user != "":
		credentials := []byte(user + ":" + config.Password)
		result["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString(credentials)
	}

	customName := strings.TrimSpace(params.Get("authHeader"))
	if customName == "" {
		return result
	}
	customValue := strings.TrimSpace(params.Get("authHeaderValue"))
	if customValue != "" && isSafeConnectionParamKey(customName) {
		result[customName] = customValue
	}
	return result
}
// buildQdrantHTTPClient creates the HTTP client for a connection. It starts
// from a clone of the default transport, installs the resolved TLS settings
// when resolving them succeeds with a non-nil result, routes dialing through
// the configured proxy when one is enabled, and uses the connect timeout as
// the client timeout.
func buildQdrantHTTPClient(config connection.ConnectionConfig) *http.Client {
	base := http.DefaultTransport.(*http.Transport)
	roundTripper := base.Clone()

	tlsSettings, tlsErr := resolveGenericTLSConfig(config)
	if tlsErr == nil && tlsSettings != nil {
		roundTripper.TLSClientConfig = tlsSettings
	}

	if config.UseProxy {
		via := config.Proxy
		roundTripper.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			return proxytunnel.DialContext(ctx, via, network, addr)
		}
	}

	client := new(http.Client)
	client.Transport = roundTripper
	client.Timeout = getConnectTimeout(config)
	return client
}
// doJSON performs one API request and optionally decodes the JSON response.
//
// body, when non-nil, is marshalled to JSON and sent with a JSON content type.
// path is appended to the base URL. Configured auth headers with a non-blank
// name and value are set on the request. A status outside 2xx is returned as
// an error carrying the response body, or the status line when the body is
// blank. out, when non-nil, receives the decoded response unless the response
// body is blank.
func (q *QdrantDB) doJSON(ctx context.Context, method, path string, body interface{}, out interface{}) error {
	if q.client == nil {
		return fmt.Errorf("连接未打开")
	}

	hasBody := body != nil
	var payload io.Reader
	if hasBody {
		encoded, err := json.Marshal(body)
		if err != nil {
			return err
		}
		payload = bytes.NewReader(encoded)
	}

	endpoint := strings.TrimRight(q.baseURL, "/") + path
	req, err := http.NewRequestWithContext(ctx, method, endpoint, payload)
	if err != nil {
		return err
	}
	if hasBody {
		req.Header.Set("Content-Type", "application/json")
	}
	req.Header.Set("Accept", "application/json")
	for name, value := range q.authHeaders {
		if strings.TrimSpace(name) == "" || strings.TrimSpace(value) == "" {
			continue
		}
		req.Header.Set(name, value)
	}

	resp, err := q.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	if success := resp.StatusCode >= 200 && resp.StatusCode < 300; !success {
		detail := strings.TrimSpace(string(raw))
		if detail == "" {
			detail = resp.Status
		}
		return fmt.Errorf("Qdrant API %s %s 失败:%s", method, path, detail)
	}

	if out == nil || len(bytes.TrimSpace(raw)) == 0 {
		return nil
	}
	if decodeErr := decodeJSONWithUseNumber(raw, out); decodeErr != nil {
		return fmt.Errorf("解析 Qdrant 响应失败:%w", decodeErr)
	}
	return nil
}
// listCollections fetches the collection list from GET /collections.
func (q *QdrantDB) listCollections(ctx context.Context) ([]qdrantCollectionInfo, error) {
	listing := new(qdrantListCollectionsResponse)
	err := q.doJSON(ctx, http.MethodGet, "/collections", nil, listing)
	if err != nil {
		return nil, err
	}
	return listing.Result.Collections, nil
}
// getCollectionInfo fetches the "result" object describing one collection.
// The collection name is trimmed and must not be blank.
func (q *QdrantDB) getCollectionInfo(ctx context.Context, collection string) (map[string]interface{}, error) {
	trimmed := strings.TrimSpace(collection)
	if len(trimmed) == 0 {
		return nil, fmt.Errorf("collection 名称不能为空")
	}

	endpoint := fmt.Sprintf("/collections/%s", url.PathEscape(trimmed))
	var described qdrantCollectionResponse
	if err := q.doJSON(ctx, http.MethodGet, endpoint, nil, &described); err != nil {
		return nil, err
	}
	return described.Result, nil
}
// scrollPoints pages through a collection via the points/scroll endpoint and
// returns the points as rows with their column names.
//
// A non-positive limit becomes 200. offset is sent, after point-id
// normalization, only when it is non-nil and does not format to blank text.
// filter is sent only when non-nil. When the response carries a next page
// offset, it is copied into every returned row as "next_page_offset".
func (q *QdrantDB) scrollPoints(ctx context.Context, collection string, limit int, offset interface{}, filter interface{}, withPayload bool, withVector bool) ([]map[string]interface{}, []string, error) {
	trimmed := strings.TrimSpace(collection)
	if trimmed == "" {
		return nil, nil, fmt.Errorf("collection 名称不能为空")
	}
	if limit < 1 {
		limit = 200
	}

	request := make(map[string]interface{}, 5)
	request["limit"] = limit
	request["with_payload"] = withPayload
	request["with_vector"] = withVector
	if offset != nil {
		if rendered := fmt.Sprintf("%v", offset); strings.TrimSpace(rendered) != "" {
			request["offset"] = qdrantNormalizePointID(offset)
		}
	}
	if filter != nil {
		request["filter"] = filter
	}

	endpoint := fmt.Sprintf("/collections/%s/points/scroll", url.PathEscape(trimmed))
	var page qdrantScrollResponse
	if err := q.doJSON(ctx, http.MethodPost, endpoint, request, &page); err != nil {
		return nil, nil, err
	}

	rows := qdrantPointRows(page.Result.Points)
	if next := page.Result.NextPageOffset; next != nil {
		for i := range rows {
			rows[i]["next_page_offset"] = next
		}
	}
	return rows, collectColumns(rows), nil
}
// searchPoints runs a vector search through the points/search endpoint.
//
// The query vector is read from "vector", "query_vector" or "queryVector" and
// is required. The limit defaults to 10; payload and vector inclusion default
// to true. The keys filter, params, score_threshold and offset are forwarded
// unchanged when present in cmd.
func (q *QdrantDB) searchPoints(ctx context.Context, collection string, cmd map[string]interface{}) ([]map[string]interface{}, []string, error) {
	trimmed := strings.TrimSpace(collection)
	if trimmed == "" {
		return nil, nil, fmt.Errorf("collection 名称不能为空")
	}
	queryVector := firstExisting(cmd, "vector", "query_vector", "queryVector")
	if queryVector == nil {
		return nil, nil, fmt.Errorf("Qdrant search 命令缺少 vector")
	}

	request := make(map[string]interface{}, 8)
	request["vector"] = normalizeQdrantVector(queryVector)
	request["limit"] = intFromAny(firstExisting(cmd, "limit", "n_results", "nResults"), 10)
	request["with_payload"] = qdrantBoolValue(firstExisting(cmd, "with_payload", "withPayload"), true)
	request["with_vector"] = qdrantBoolValue(firstExisting(cmd, "with_vector", "withVector"), true)
	for _, passthrough := range [...]string{"filter", "params", "score_threshold", "offset"} {
		value, present := cmd[passthrough]
		if !present {
			continue
		}
		request[passthrough] = value
	}

	endpoint := fmt.Sprintf("/collections/%s/points/search", url.PathEscape(trimmed))
	var hits qdrantSearchResponse
	if err := q.doJSON(ctx, http.MethodPost, endpoint, request, &hits); err != nil {
		return nil, nil, err
	}
	rows := qdrantPointRows(hits.Result)
	return rows, collectColumns(rows), nil
}
// countPoints requests an exact point count from the points/count endpoint,
// restricted by filter when filter is non-nil.
func (q *QdrantDB) countPoints(ctx context.Context, collection string, filter interface{}) (int64, error) {
	trimmed := strings.TrimSpace(collection)
	if trimmed == "" {
		return 0, fmt.Errorf("collection 名称不能为空")
	}

	request := make(map[string]interface{}, 2)
	request["exact"] = true
	if filter != nil {
		request["filter"] = filter
	}

	endpoint := fmt.Sprintf("/collections/%s/points/count", url.PathEscape(trimmed))
	var counted qdrantCountResponse
	if err := q.doJSON(ctx, http.MethodPost, endpoint, request, &counted); err != nil {
		return 0, err
	}
	return counted.Result.Count, nil
}
// queryJSON executes a read command given as a JSON object. The checks run in
// a fixed order and the first match wins:
//
//  1. list_collections / listCollections: one row per collection ("name").
//  2. get_collection / getCollection: the collection info as a single row.
//  3. count: a single row with "total", optionally filtered.
//  4. search: taken when any of search, query, vector, query_vector or
//     queryVector is present and a collection name can be found.
//  5. scroll / get / collection: pages through points; limit defaults to 200,
//     payload and vector inclusion default to true.
//
// Anything else is reported as an unsupported command.
func (q *QdrantDB) queryJSON(ctx context.Context, text string) ([]map[string]interface{}, []string, error) {
	var cmd map[string]interface{}
	if err := decodeJSONWithUseNumber([]byte(text), &cmd); err != nil {
		return nil, nil, fmt.Errorf("Qdrant JSON 命令解析失败:%w", err)
	}

	if hasAnyKey(cmd, "list_collections", "listCollections") {
		listed, err := q.listCollections(ctx)
		if err != nil {
			return nil, nil, err
		}
		rows := make([]map[string]interface{}, len(listed))
		for i := range listed {
			rows[i] = map[string]interface{}{"name": listed[i].Name}
		}
		return rows, collectColumns(rows), nil
	}

	if target := firstStringValue(cmd, "get_collection", "getCollection"); target != "" {
		info, err := q.getCollectionInfo(ctx, target)
		if err != nil {
			return nil, nil, err
		}
		rows := []map[string]interface{}{info}
		return rows, collectColumns([]map[string]interface{}{info}), nil
	}

	if target := firstStringValue(cmd, "count", "collection"); target != "" && hasAnyKey(cmd, "count") {
		total, err := q.countPoints(ctx, target, cmd["filter"])
		if err != nil {
			return nil, nil, err
		}
		summary := map[string]interface{}{"total": total}
		return []map[string]interface{}{summary}, []string{"total"}, nil
	}

	if target := firstStringValue(cmd, "search", "query", "collection"); target != "" && hasAnyKey(cmd, "search", "query", "vector", "query_vector", "queryVector") {
		return q.searchPoints(ctx, target, cmd)
	}

	target := firstStringValue(cmd, "scroll", "get", "collection")
	if target == "" {
		return nil, nil, fmt.Errorf("Qdrant JSON 查询命令仅支持 list_collections/get_collection/count/scroll/search")
	}
	pageSize := intFromAny(cmd["limit"], 200)
	startAt := firstExisting(cmd, "offset", "next_page_offset", "nextPageOffset")
	includePayload := qdrantBoolValue(firstExisting(cmd, "with_payload", "withPayload"), true)
	includeVector := qdrantBoolValue(firstExisting(cmd, "with_vector", "withVector"), true)
	return q.scrollPoints(ctx, target, pageSize, startAt, cmd["filter"], includePayload, includeVector)
}
// createCollection creates a collection with PUT /collections/{name}.
//
// The vector configuration is taken verbatim from cmd["vectors"] when present.
// Otherwise it is built from a positive size (size, vector_size, vectorSize)
// and a distance (distance, metric) that defaults to "Cosine". A fixed set of
// optional collection settings is forwarded unchanged when present in cmd.
func (q *QdrantDB) createCollection(ctx context.Context, name string, cmd map[string]interface{}) error {
	trimmed := strings.TrimSpace(name)
	if trimmed == "" {
		return fmt.Errorf("collection 名称不能为空")
	}

	request := map[string]interface{}{}
	if explicit, given := cmd["vectors"]; given {
		request["vectors"] = explicit
	} else {
		dimension := intFromAny(firstExisting(cmd, "size", "vector_size", "vectorSize"), 0)
		if dimension <= 0 {
			return fmt.Errorf("Qdrant create_collection 命令缺少 vectors 或 size")
		}
		metric := firstStringValue(cmd, "distance", "metric")
		if metric == "" {
			metric = "Cosine"
		}
		request["vectors"] = map[string]interface{}{"size": dimension, "distance": metric}
	}

	passthrough := [...]string{
		"sparse_vectors",
		"shard_number",
		"replication_factor",
		"write_consistency_factor",
		"on_disk_payload",
		"hnsw_config",
		"optimizers_config",
		"wal_config",
		"quantization_config",
		"strict_mode_config",
		"init_from",
	}
	for _, option := range passthrough {
		value, given := cmd[option]
		if !given {
			continue
		}
		request[option] = value
	}

	endpoint := fmt.Sprintf("/collections/%s", url.PathEscape(trimmed))
	return q.doJSON(ctx, http.MethodPut, endpoint, request, nil)
}
// deleteCollection removes a collection with DELETE /collections/{name}. The
// name is trimmed and must not be blank.
func (q *QdrantDB) deleteCollection(ctx context.Context, name string) error {
	trimmed := strings.TrimSpace(name)
	if len(trimmed) == 0 {
		return fmt.Errorf("collection 名称不能为空")
	}
	endpoint := fmt.Sprintf("/collections/%s", url.PathEscape(trimmed))
	return q.doJSON(ctx, http.MethodDelete, endpoint, nil, nil)
}
// createPayloadIndex creates a payload index with PUT /collections/{name}/index.
// The field name (field_name, fieldName, field) is required; the schema
// (field_schema, fieldSchema, schema) defaults to "keyword".
func (q *QdrantDB) createPayloadIndex(ctx context.Context, collection string, cmd map[string]interface{}) error {
	field := firstStringValue(cmd, "field_name", "fieldName", "field")
	if field == "" {
		return fmt.Errorf("Qdrant create_payload_index 命令缺少 field_name")
	}

	schema := firstExisting(cmd, "field_schema", "fieldSchema", "schema")
	if schema == nil {
		schema = "keyword"
	}

	request := make(map[string]interface{}, 2)
	request["field_name"] = field
	request["field_schema"] = schema
	endpoint := fmt.Sprintf("/collections/%s/index", url.PathEscape(collection))
	return q.doJSON(ctx, http.MethodPut, endpoint, request, nil)
}
// deletePayloadIndex removes the payload index on fieldName with
// DELETE /collections/{name}/index/{field}.
func (q *QdrantDB) deletePayloadIndex(ctx context.Context, collection, fieldName string) error {
	escapedCollection := url.PathEscape(collection)
	escapedField := url.PathEscape(fieldName)
	endpoint := fmt.Sprintf("/collections/%s/index/%s", escapedCollection, escapedField)
	return q.doJSON(ctx, http.MethodDelete, endpoint, nil, nil)
}
// upsertCommand handles the JSON "upsert" command.
//
// With a "rows" array, every element that is a JSON object is converted to a
// point by upsertRows; other elements are dropped. With "points", the value is
// sent to the server unchanged. The returned count is the number of rows or
// points submitted and is returned alongside any request error.
func (q *QdrantDB) upsertCommand(ctx context.Context, collection string, cmd map[string]interface{}) (int64, error) {
	if raw, isList := cmd["rows"].([]interface{}); isList {
		rows := make([]map[string]interface{}, 0, len(raw))
		for _, item := range raw {
			row, isObject := item.(map[string]interface{})
			if !isObject {
				continue
			}
			rows = append(rows, row)
		}
		err := q.upsertRows(ctx, collection, rows)
		return int64(len(rows)), err
	}

	points, given := cmd["points"]
	if !given {
		return 0, fmt.Errorf("Qdrant upsert 命令缺少 rows 或 points")
	}
	submitted := int64(len(anySlice(points)))
	endpoint := fmt.Sprintf("/collections/%s/points?wait=true", url.PathEscape(collection))
	err := q.doJSON(ctx, http.MethodPut, endpoint, map[string]interface{}{"points": points}, nil)
	return submitted, err
}
// deleteCommand handles the JSON "delete" command through the points/delete
// endpoint. The selector is the first of "points", "ids" (both sent as
// normalized point ids) or "filter" found in cmd; one of them is required.
// The returned count is the number of ids submitted, which is zero for a
// filter delete.
func (q *QdrantDB) deleteCommand(ctx context.Context, collection string, cmd map[string]interface{}) (int64, error) {
	points, hasPoints := cmd["points"]
	ids, hasIDs := cmd["ids"]
	filter, hasFilter := cmd["filter"]

	selector := map[string]interface{}{}
	switch {
	case hasPoints:
		selector["points"] = qdrantPointIDSlice(points)
	case hasIDs:
		selector["points"] = qdrantPointIDSlice(ids)
	case hasFilter:
		selector["filter"] = filter
	default:
		return 0, fmt.Errorf("Qdrant delete 命令缺少 points/ids/filter")
	}

	submitted := int64(len(anySlice(firstExisting(selector, "points"))))
	endpoint := fmt.Sprintf("/collections/%s/points/delete?wait=true", url.PathEscape(collection))
	err := q.doJSON(ctx, http.MethodPost, endpoint, selector, nil)
	return submitted, err
}
// upsertRows converts rows to points and writes them in one PUT request that
// waits for completion. Every row must carry an id and a vector; the first
// row missing either aborts the call before anything is sent. An empty rows
// slice is a no-op.
func (q *QdrantDB) upsertRows(ctx context.Context, collection string, rows []map[string]interface{}) error {
	if len(rows) == 0 {
		return nil
	}

	batch := make([]map[string]interface{}, 0, len(rows))
	for i := range rows {
		row := rows[i]
		pointID, hasID := qdrantRowID(row)
		if !hasID {
			return fmt.Errorf("Qdrant 写入行缺少 id")
		}
		vector, hasVector := qdrantRowVector(row)
		if !hasVector {
			return fmt.Errorf("Qdrant upsert 行缺少 vector/embedding")
		}
		point := map[string]interface{}{}
		point["id"] = pointID
		point["vector"] = vector
		point["payload"] = qdrantPayloadFromRow(row)
		batch = append(batch, point)
	}

	endpoint := fmt.Sprintf("/collections/%s/points?wait=true", url.PathEscape(collection))
	return q.doJSON(ctx, http.MethodPut, endpoint, map[string]interface{}{"points": batch}, nil)
}
// setPayloadFromRow sets the payload fields derived from row on the point
// identified by the row's id, using the points/payload endpoint. A row without
// an id is an error; a row that yields no payload fields is a no-op.
func (q *QdrantDB) setPayloadFromRow(ctx context.Context, collection string, row map[string]interface{}) error {
	pointID, hasID := qdrantRowID(row)
	if !hasID {
		return fmt.Errorf("Qdrant payload 更新缺少 id")
	}
	fields := qdrantPayloadFromRow(row)
	if len(fields) == 0 {
		return nil
	}

	request := make(map[string]interface{}, 2)
	request["points"] = []interface{}{pointID}
	request["payload"] = fields
	endpoint := fmt.Sprintf("/collections/%s/points/payload?wait=true", url.PathEscape(collection))
	return q.doJSON(ctx, http.MethodPost, endpoint, request, nil)
}
// qdrantParsedSQL is what parseQdrantSQL extracts from a SELECT-like
// statement.
type qdrantParsedSQL struct {
// Collection is the name following FROM.
Collection string
// Limit is the LIMIT value; it starts at 200 when no LIMIT is given.
Limit int
// Offset is the OFFSET token after point-id normalization, or nil.
Offset interface{}
// Count is true when the statement contains "count(".
Count bool
// IncludeVector is true when the statement contains "vector".
IncludeVector bool
}
// qdrantSQLFromRE captures the collection name after FROM, which may be
// double-quoted (group 1), backtick-quoted (group 2) or a bare identifier made
// of letters, digits, "_", "." and "-" (group 3). Matching is case-insensitive.
var qdrantSQLFromRE = regexp.MustCompile(`(?i)\bFROM\s+(?:"([^"]+)"|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z0-9_.\-]+))`)
// qdrantSQLLimitRE captures the decimal number after LIMIT.
var qdrantSQLLimitRE = regexp.MustCompile(`(?i)\bLIMIT\s+(\d+)`)
// qdrantSQLOffsetRE captures the token after OFFSET; letters, "_", "." and "-"
// are allowed in addition to digits, so the token need not be numeric.
var qdrantSQLOffsetRE = regexp.MustCompile(`(?i)\bOFFSET\s+([a-zA-Z0-9_.\-]+)`)
// parseQdrantSQL extracts the collection, limit, offset and flags from a
// simple SELECT statement. It reports false when the trimmed text does not
// start with "select" (case-insensitively) or has no usable FROM clause.
//
// The flags are plain substring checks on the lower-cased statement: Count is
// set by "count(" and IncludeVector by "vector" anywhere in the text. Limit
// starts at 200 and is replaced by the LIMIT value when one is present.
func parseQdrantSQL(sqlText string) (qdrantParsedSQL, bool) {
	var none qdrantParsedSQL

	statement := strings.TrimSpace(sqlText)
	folded := strings.ToLower(statement)
	if !strings.HasPrefix(folded, "select") {
		return none, false
	}

	from := qdrantSQLFromRE.FindStringSubmatch(statement)
	if len(from) == 0 {
		return none, false
	}
	target := firstNonEmpty(from[1], from[2], from[3])
	if target == "" {
		return none, false
	}

	result := qdrantParsedSQL{
		Collection:    target,
		Limit:         200,
		Count:         strings.Contains(folded, "count("),
		IncludeVector: strings.Contains(folded, "vector"),
	}
	if limit := qdrantSQLLimitRE.FindStringSubmatch(statement); len(limit) > 1 {
		// The conversion error is deliberately dropped; Limit takes whatever
		// strconv.Atoi returned.
		result.Limit, _ = strconv.Atoi(limit[1])
	}
	if offset := qdrantSQLOffsetRE.FindStringSubmatch(statement); len(offset) > 1 {
		result.Offset = qdrantNormalizePointID(offset[1])
	}
	return result, true
}
// qdrantPointRows flattens points into result rows. Each row always has "id";
// "score", "version" and "vector" are added only when non-nil. A non-nil
// payload is stored whole under "payload" and each of its fields is also
// exposed as "payload.<field>".
func qdrantPointRows(points []qdrantPoint) []map[string]interface{} {
	out := make([]map[string]interface{}, len(points))
	for i := range points {
		p := &points[i]
		row := make(map[string]interface{})
		row["id"] = p.ID
		if p.Score != nil {
			row["score"] = p.Score
		}
		if p.Version != nil {
			row["version"] = p.Version
		}
		if p.Vector != nil {
			row["vector"] = normalizeJSONLikeValue(p.Vector)
		}
		if p.Payload != nil {
			row["payload"] = p.Payload
			for field, value := range p.Payload {
				row["payload."+field] = value
			}
		}
		out[i] = row
	}
	return out
}
// qdrantRowID reads the point id from the row's "id" or "_id" field and
// normalizes it. It reports false when neither field yields a value or the
// value formats to blank text or "<nil>".
func qdrantRowID(row map[string]interface{}) (interface{}, bool) {
	candidate := firstExisting(row, "id", "_id")
	if candidate == nil {
		return nil, false
	}
	switch strings.TrimSpace(fmt.Sprintf("%v", candidate)) {
	case "", "<nil>":
		return nil, false
	}
	return qdrantNormalizePointID(candidate), true
}
// qdrantNormalizePointID converts a loosely typed id into the form sent to
// Qdrant, whose point ids are unsigned 64-bit integers or UUID strings.
//
//   - json.Number: int64 when it parses as one, otherwise returned unchanged.
//   - float64 / float32: int64 when the value is integral and inside the
//     int64 range, otherwise returned unchanged. The range is checked before
//     converting because an out-of-range float-to-int conversion has an
//     implementation-defined result in Go.
//   - Go integer types: returned unchanged.
//   - string: trimmed, then int64 when it parses as one, uint64 when it only
//     fits the unsigned range (so large numeric ids are sent as numbers, not
//     as strings), otherwise the trimmed text.
//
// Any other type is returned unchanged.
func qdrantNormalizePointID(value interface{}) interface{} {
	switch v := value.(type) {
	case json.Number:
		if n, err := v.Int64(); err == nil {
			return n
		}
	case float64:
		if v >= -(1<<63) && v < 1<<63 && v == float64(int64(v)) {
			return int64(v)
		}
	case float32:
		if v >= -(1<<63) && v < 1<<63 && v == float32(int64(v)) {
			return int64(v)
		}
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return v
	case string:
		text := strings.TrimSpace(v)
		if n, err := strconv.ParseInt(text, 10, 64); err == nil {
			return n
		}
		if u, err := strconv.ParseUint(text, 10, 64); err == nil {
			return u
		}
		return text
	}
	return value
}
// qdrantPointIDSlice converts value to a slice via anySlice and normalizes
// every element as a point id. The result is never nil.
func qdrantPointIDSlice(value interface{}) []interface{} {
	raw := anySlice(value)
	ids := make([]interface{}, len(raw))
	for i := range raw {
		ids[i] = qdrantNormalizePointID(raw[i])
	}
	return ids
}
// qdrantRowVector reads the vector from the first of the row fields "vector",
// "_vector", "vectors", "embedding" or "_embedding" that yields a value, and
// normalizes it. It reports false when none does.
func qdrantRowVector(row map[string]interface{}) (interface{}, bool) {
	if found := firstExisting(row, "vector", "_vector", "vectors", "embedding", "_embedding"); found != nil {
		return normalizeQdrantVector(found), true
	}
	return nil, false
}
// normalizeQdrantVector decodes a vector supplied as JSON text. A string that
// decodes as JSON is replaced by the decoded value; a string that does not,
// and any non-string value, is returned unchanged.
func normalizeQdrantVector(value interface{}) interface{} {
	text, isText := value.(string)
	if !isText {
		return value
	}
	var decoded interface{}
	if decodeJSONWithUseNumber([]byte(text), &decoded) != nil {
		return value
	}
	return decoded
}
// qdrantPayloadFromRow builds the payload object for a row. It starts from the
// row's "payload" object, if any, and then overlays every non-reserved row
// field, dropping a leading "payload." from the field name. The result is
// never nil.
func qdrantPayloadFromRow(row map[string]interface{}) map[string]interface{} {
	out := map[string]interface{}{}
	if nested, isObject := row["payload"].(map[string]interface{}); isObject {
		for field, value := range nested {
			out[field] = value
		}
	}
	for field, value := range row {
		if isQdrantReservedRowField(field) {
			continue
		}
		// TrimPrefix leaves names without the prefix untouched, so one
		// assignment covers both prefixed and plain fields.
		out[strings.TrimPrefix(field, "payload.")] = value
	}
	return out
}
// isQdrantReservedRowField reports whether key names a row field that is not
// part of the payload: the id and vector aliases, the whole-payload field and
// the result-only fields score, version and next_page_offset. The comparison
// is exact and case-sensitive.
func isQdrantReservedRowField(key string) bool {
	reserved := [...]string{
		"id", "_id",
		"vector", "_vector", "vectors",
		"embedding", "_embedding",
		"payload",
		"score", "version", "next_page_offset",
	}
	for _, name := range reserved {
		if key == name {
			return true
		}
	}
	return false
}
// qdrantBoolValue interprets a loosely typed flag. A bool is returned as is.
// A string is true for "1", "true", "yes" or "on" (case-insensitive,
// surrounding space ignored), fallback when blank, and false otherwise. nil
// and every other type yield fallback.
func qdrantBoolValue(value interface{}, fallback bool) bool {
	if flag, isBool := value.(bool); isBool {
		return flag
	}
	text, isText := value.(string)
	if !isText {
		return fallback
	}
	switch strings.TrimSpace(strings.ToLower(text)) {
	case "":
		return fallback
	case "1", "true", "yes", "on":
		return true
	}
	return false
}
// qdrantVectorIndexes derives vector indexes from the collection info found at
// config.params.vectors. When that object has a "size" key it is treated as a
// single unnamed vector and yields one VECTOR index; otherwise each key is
// treated as a named vector and yields a "VECTOR_<name>" index, in name order.
// It returns nil when the object is absent or empty.
func qdrantVectorIndexes(info map[string]interface{}) []connection.IndexDefinition {
	vectors := nestedMapValue(info, "config", "params", "vectors")
	if len(vectors) == 0 {
		return nil
	}
	if _, unnamed := vectors["size"]; unnamed {
		single := connection.IndexDefinition{Name: "VECTOR", ColumnName: "vector", NonUnique: 1, SeqInIndex: 1, IndexType: "VECTOR"}
		return []connection.IndexDefinition{single}
	}

	names := make([]string, 0, len(vectors))
	for name := range vectors {
		names = append(names, name)
	}
	sort.Strings(names)

	out := make([]connection.IndexDefinition, len(names))
	for i, name := range names {
		out[i] = connection.IndexDefinition{
			Name:       "VECTOR_" + name,
			ColumnName: "vector." + name,
			NonUnique:  1,
			SeqInIndex: i + 1,
			IndexType:  "VECTOR",
		}
	}
	return out
}
// qdrantPayloadIndexes derives payload indexes from the "payload_schema"
// object of the collection info: every key yields a "PAYLOAD_<name>" index on
// column "payload.<name>", in name order. It returns nil when the object is
// absent or empty.
//
// The former fallback to payload_schema.schema was removed: it only ran when
// payload_schema was missing, not an object, or empty, and in each of those
// cases the nested lookup can only return nil, so it never had any effect.
func qdrantPayloadIndexes(info map[string]interface{}) []connection.IndexDefinition {
	schema := nestedMapValue(info, "payload_schema")
	if len(schema) == 0 {
		return nil
	}
	names := make([]string, 0, len(schema))
	for name := range schema {
		names = append(names, name)
	}
	sort.Strings(names)
	indexes := make([]connection.IndexDefinition, 0, len(names))
	for index, name := range names {
		indexes = append(indexes, connection.IndexDefinition{
			Name:       "PAYLOAD_" + name,
			ColumnName: "payload." + name,
			NonUnique:  1,
			SeqInIndex: index + 1,
			IndexType:  "PAYLOAD",
		})
	}
	return indexes
}
// nestedMapValue walks value along path, one object key per step, and returns
// the object found at the end. It returns nil as soon as a step is not an
// object, or when the final value is not an object. With an empty path, value
// itself is returned if it is an object.
func nestedMapValue(value interface{}, path ...string) map[string]interface{} {
	cursor := value
	for i := 0; i < len(path); i++ {
		level, isObject := cursor.(map[string]interface{})
		if !isObject {
			return nil
		}
		cursor = level[path[i]]
	}
	leaf, _ := cursor.(map[string]interface{})
	return leaf
}