Files
Foxel/domain/ai/vector_providers/milvus_server.py
2025-12-08 17:46:45 +08:00

279 lines
10 KiB
Python

from __future__ import annotations
from typing import Any, Dict, List, Optional
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
from .base import BaseVectorProvider
class MilvusServerProvider(BaseVectorProvider):
    """Vector provider that talks to a remote Milvus instance through a URI.

    The client is created lazily by :meth:`initialize`; every other method
    raises ``RuntimeError`` until that has happened.
    """

    type = "milvus_server"
    label = "Milvus Server"
    description = "Remote Milvus instance accessed via URI."
    enabled = True
    config_schema: List[Dict[str, Any]] = [
        {
            "key": "uri",
            "label": "Server URI",
            "type": "text",
            "required": True,
            "placeholder": "http://localhost:19530",
        },
        {
            "key": "token",
            "label": "Token",
            "type": "password",
            "required": False,
            "placeholder": "user:password",
        },
    ]

    # Fallback embedding dimension, used when the caller supplies no usable
    # dimension or the server reports none for a vector field.
    _DEFAULT_DIM = 4096

    # Fields requested from Milvus by both search entry points.
    _OUTPUT_FIELDS = (
        "path",
        "source_path",
        "chunk_id",
        "mime",
        "text",
        "start_offset",
        "end_offset",
        "type",
        "name",
    )

    def __init__(self, config: Dict[str, Any] | None = None):
        """Forward *config* to the base class and start without a client."""
        super().__init__(config)
        self.client: MilvusClient | None = None

    async def initialize(self) -> None:
        """Create the Milvus client from the ``uri`` / ``token`` config keys.

        Raises:
            RuntimeError: If ``uri`` is missing or the client cannot be built.
        """
        uri = self.config.get("uri")
        if not uri:
            raise RuntimeError("Milvus Server URI is required")
        try:
            self.client = MilvusClient(uri=uri, token=self.config.get("token"))
        except Exception as exc:  # pragma: no cover - depends on remote availability
            raise RuntimeError(f"Failed to connect to Milvus Server {uri}: {exc}") from exc

    def _get_client(self) -> MilvusClient:
        """Return the initialized client or raise ``RuntimeError``."""
        if self.client is None:
            raise RuntimeError("Milvus Server client is not initialized")
        return self.client

    @staticmethod
    def _escape_filter_literal(value: str) -> str:
        """Escape *value* for use inside a double-quoted filter string literal."""
        # Backslashes must be doubled first; otherwise the backslashes added
        # for the quotes would themselves be doubled, and a trailing backslash
        # in the input could swallow the closing quote.
        return value.replace("\\", "\\\\").replace('"', '\\"')

    @staticmethod
    def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
        """Normalize one search hit into ``(id, distance, payload)``.

        Accepts either an object exposing ``id`` / ``distance`` / ``entity``
        attributes or a plain ``dict``. The returned payload always contains
        both ``path`` and ``source_path`` keys, each defaulting to the other
        (or ``None`` when neither is present).
        """
        hit_id = getattr(hit, "id", None)
        distance = getattr(hit, "distance", None)
        payload: Dict[str, Any] = {}
        raw: Dict[str, Any] | None = None

        if hasattr(hit, "entity"):
            raw_entity = getattr(hit, "entity")
            if hasattr(raw_entity, "to_dict"):
                raw = dict(raw_entity.to_dict())
            else:
                raw = dict(raw_entity)
        elif isinstance(hit, dict):
            raw = dict(hit)

        if raw:
            hit_id = hit_id or raw.get("id")
            distance = distance if distance is not None else raw.get("distance")
            inner = raw.get("entity")
            if isinstance(inner, dict):
                payload = dict(inner)
            else:
                # Flat mapping: everything except the bookkeeping keys is payload.
                payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}

        payload.setdefault("path", payload.get("source_path"))
        payload.setdefault("source_path", payload.get("path"))
        return hit_id, distance, payload

    @staticmethod
    def _to_int(value: Any) -> int:
        """Convert *value* to ``int``, returning ``0`` when it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
        """Create *collection_name* if it does not exist yet.

        Args:
            collection_name: Name of the collection to create.
            vector: When true, add a ``FLOAT_VECTOR`` ``embedding`` field and
                an ``IVF_FLAT`` / ``COSINE`` index on it.
            dim: Embedding dimension; anything that is not a positive ``int``
                falls back to the default dimension (4096).
        """
        client = self._get_client()
        if client.has_collection(collection_name):
            return

        common_fields = [
            FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
            FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
        ]

        if not vector:
            schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
            client.create_collection(collection_name, schema=schema)
            return

        vector_dim = dim if isinstance(dim, int) and dim > 0 else self._DEFAULT_DIM
        fields = [
            *common_fields,
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
        ]
        schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
        client.create_collection(collection_name, schema=schema)

        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            index_type="IVF_FLAT",
            index_name="vector_index",
            metric_type="COSINE",
            params={"nlist": 64},
        )
        # NOTE(review): the collection is not explicitly loaded after the index
        # is created - confirm that callers (or the server) load it before
        # the first search.
        client.create_index(collection_name, index_params=index_params)

    def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
        """Upsert a single row, defaulting ``source_path`` / ``vector_id`` to ``path``.

        *data* is copied first, so the caller's dict is not mutated.
        """
        payload = dict(data)
        payload.setdefault("source_path", payload.get("path"))
        payload.setdefault("vector_id", payload.get("path"))
        self._get_client().upsert(collection_name, data=[payload])

    def delete_vector(self, collection_name: str, path: str) -> None:
        """Delete every row whose ``source_path`` equals *path*."""
        escaped = self._escape_filter_literal(path)
        self._get_client().delete(collection_name, filter=f'source_path == "{escaped}"')

    def search_vectors(self, collection_name: str, query_embedding, top_k: int) -> List[List[Dict[str, Any]]]:
        """Run a COSINE similarity search on the ``embedding`` field.

        Returns:
            One list per query (a single query is sent), each holding dicts
            with ``id``, ``distance`` and ``entity`` keys.
        """
        raw_results = self._get_client().search(
            collection_name,
            data=[query_embedding],
            anns_field="embedding",
            search_params={"metric_type": "COSINE"},
            limit=top_k,
            output_fields=list(self._OUTPUT_FIELDS),
        )

        formatted: List[List[Dict[str, Any]]] = []
        for hits in raw_results:
            bucket: List[Dict[str, Any]] = []
            for hit in hits:
                hit_id, distance, entity = self._extract_hit_payload(hit)
                bucket.append({"id": hit_id, "distance": distance, "entity": entity})
            formatted.append(bucket)
        return formatted

    def search_by_path(self, collection_name: str, query_path: str, top_k: int) -> List[List[Dict[str, Any]]]:
        """Return up to *top_k* rows whose ``source_path`` contains *query_path*.

        An empty *query_path* matches every row. Results use the same shape as
        :meth:`search_vectors`, with a fixed ``distance`` of ``1.0``.
        """
        if query_path:
            escaped = self._escape_filter_literal(query_path)
            filter_expr = f'source_path like "%{escaped}%"'
        else:
            filter_expr = "source_path like '%%'"

        results = self._get_client().query(
            collection_name,
            filter=filter_expr,
            limit=top_k,
            output_fields=list(self._OUTPUT_FIELDS),
        )

        formatted: List[Dict[str, Any]] = []
        for row in results:
            entity = dict(row)
            entity.setdefault("path", entity.get("source_path"))
            formatted.append({"id": entity.get("path"), "distance": 1.0, "entity": entity})
        return [formatted]

    def get_all_stats(self) -> Dict[str, Any]:
        """Collect per-collection statistics plus overall totals.

        Failures while inspecting an individual collection are tolerated and
        reported as empty/zero values; only a failure to list the collections
        is fatal.

        Raises:
            RuntimeError: If the collection list cannot be retrieved.
        """
        client = self._get_client()
        try:
            collection_names = client.list_collections()
        except Exception as exc:
            raise RuntimeError(f"Failed to list collections: {exc}") from exc

        collections = [self._collection_stats(client, name) for name in collection_names]
        return {
            "collections": collections,
            "collection_count": len(collections),
            "total_vectors": sum(item["row_count"] for item in collections),
            "estimated_total_memory_bytes": sum(item["estimated_memory_bytes"] for item in collections),
            "db_file_size_bytes": None,
        }

    def _collection_stats(self, client: MilvusClient, name: str) -> Dict[str, Any]:
        """Build the statistics entry for a single collection (best effort)."""
        try:
            stats = client.get_collection_stats(name) or {}
        except Exception:
            # Best effort: a collection whose stats cannot be read counts as empty.
            stats = {}
        row_count = self._to_int(stats.get("row_count"))

        dimension = self._vector_dimension(client, name)
        # Estimate computed as 4 bytes per vector component.
        estimated_memory = row_count * dimension * 4 if dimension else 0

        return {
            "name": name,
            "row_count": row_count,
            "dimension": dimension,
            "estimated_memory_bytes": estimated_memory,
            "is_vector_collection": dimension is not None,
            "indexes": self._index_stats(client, name),
        }

    def _vector_dimension(self, client: MilvusClient, name: str) -> Optional[int]:
        """Return the dimension of the first FLOAT_VECTOR field, or ``None`` if there is none."""
        try:
            description = client.describe_collection(name)
        except Exception:
            return None
        if not description:
            return None

        for field in description.get("fields", []):
            if field.get("type") == DataType.FLOAT_VECTOR:
                params = field.get("params") or {}
                return self._to_int(params.get("dim")) or self._DEFAULT_DIM
        return None

    def _index_stats(self, client: MilvusClient, name: str) -> List[Dict[str, Any]]:
        """Describe every index of collection *name* (best effort)."""
        try:
            index_names = client.list_indexes(name) or []
        except Exception:
            return []

        indexes: List[Dict[str, Any]] = []
        for index_name in index_names:
            try:
                # The index name is required here; without it the call fails
                # and every index would be reported with empty details.
                detail = client.describe_index(name, index_name=index_name) or {}
            except Exception:
                detail = {}
            indexes.append(
                {
                    "index_name": index_name,
                    "index_type": detail.get("index_type"),
                    "metric_type": detail.get("metric_type"),
                    "indexed_rows": self._to_int(detail.get("indexed_rows")),
                    "pending_index_rows": self._to_int(detail.get("pending_index_rows")),
                    "state": detail.get("state"),
                }
            )
        return indexes

    def clear_all_data(self) -> None:
        """Drop every collection returned by ``list_collections``."""
        client = self._get_client()
        for collection_name in client.list_collections():
            client.drop_collection(collection_name)