mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-11 18:10:10 +08:00
279 lines
10 KiB
Python
279 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
|
|
|
from .base import BaseVectorProvider
|
|
|
|
|
|
class MilvusServerProvider(BaseVectorProvider):
|
|
type = "milvus_server"
|
|
label = "Milvus Server"
|
|
description = "Remote Milvus instance accessed via URI."
|
|
enabled = True
|
|
config_schema: List[Dict[str, Any]] = [
|
|
{
|
|
"key": "uri",
|
|
"label": "Server URI",
|
|
"type": "text",
|
|
"required": True,
|
|
"placeholder": "http://localhost:19530",
|
|
},
|
|
{
|
|
"key": "token",
|
|
"label": "Token",
|
|
"type": "password",
|
|
"required": False,
|
|
"placeholder": "user:password",
|
|
},
|
|
]
|
|
|
|
def __init__(self, config: Dict[str, Any] | None = None):
|
|
super().__init__(config)
|
|
self.client: MilvusClient | None = None
|
|
|
|
async def initialize(self) -> None:
|
|
uri = self.config.get("uri")
|
|
if not uri:
|
|
raise RuntimeError("Milvus Server URI is required")
|
|
try:
|
|
self.client = MilvusClient(uri=uri, token=self.config.get("token"))
|
|
except Exception as exc: # pragma: no cover - depends on remote availability
|
|
raise RuntimeError(f"Failed to connect to Milvus Server {uri}: {exc}") from exc
|
|
|
|
def _get_client(self) -> MilvusClient:
|
|
if not self.client:
|
|
raise RuntimeError("Milvus Server client is not initialized")
|
|
return self.client
|
|
|
|
@staticmethod
|
|
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
|
|
hit_id = getattr(hit, "id", None)
|
|
distance = getattr(hit, "distance", None)
|
|
payload: Dict[str, Any] = {}
|
|
|
|
raw: Dict[str, Any] | None = None
|
|
if hasattr(hit, "entity"):
|
|
raw_entity = getattr(hit, "entity")
|
|
if hasattr(raw_entity, "to_dict"):
|
|
raw = dict(raw_entity.to_dict())
|
|
else:
|
|
raw = dict(raw_entity)
|
|
elif isinstance(hit, dict):
|
|
raw = dict(hit)
|
|
|
|
if raw:
|
|
hit_id = hit_id or raw.get("id")
|
|
distance = distance if distance is not None else raw.get("distance")
|
|
inner = raw.get("entity")
|
|
if isinstance(inner, dict):
|
|
payload = dict(inner)
|
|
else:
|
|
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
|
|
|
|
payload.setdefault("path", payload.get("source_path"))
|
|
payload.setdefault("source_path", payload.get("path"))
|
|
return hit_id, distance, payload
|
|
|
|
@staticmethod
|
|
def _to_int(value: Any) -> int:
|
|
try:
|
|
return int(value)
|
|
except (TypeError, ValueError):
|
|
return 0
|
|
|
|
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
|
|
client = self._get_client()
|
|
if client.has_collection(collection_name):
|
|
return
|
|
common_fields = [
|
|
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
|
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
|
|
]
|
|
if vector:
|
|
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
|
|
if vector_dim <= 0:
|
|
vector_dim = 4096
|
|
fields = [
|
|
*common_fields,
|
|
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
|
|
]
|
|
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
|
|
client.create_collection(collection_name, schema=schema)
|
|
index_params = MilvusClient.prepare_index_params()
|
|
index_params.add_index(
|
|
field_name="embedding",
|
|
index_type="IVF_FLAT",
|
|
index_name="vector_index",
|
|
metric_type="COSINE",
|
|
params={"nlist": 64},
|
|
)
|
|
client.create_index(collection_name, index_params=index_params)
|
|
else:
|
|
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
|
|
client.create_collection(collection_name, schema=schema)
|
|
|
|
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
|
payload = dict(data)
|
|
payload.setdefault("source_path", payload.get("path"))
|
|
payload.setdefault("vector_id", payload.get("path"))
|
|
self._get_client().upsert(collection_name, data=[payload])
|
|
|
|
def delete_vector(self, collection_name: str, path: str) -> None:
|
|
client = self._get_client()
|
|
escaped = path.replace('"', '\\"')
|
|
client.delete(collection_name, filter=f'source_path == "{escaped}"')
|
|
|
|
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
|
search_params = {"metric_type": "COSINE"}
|
|
output_fields = [
|
|
"path",
|
|
"source_path",
|
|
"chunk_id",
|
|
"mime",
|
|
"text",
|
|
"start_offset",
|
|
"end_offset",
|
|
"type",
|
|
"name",
|
|
]
|
|
raw_results = self._get_client().search(
|
|
collection_name,
|
|
data=[query_embedding],
|
|
anns_field="embedding",
|
|
search_params=search_params,
|
|
limit=top_k,
|
|
output_fields=output_fields,
|
|
)
|
|
formatted: List[List[Dict[str, Any]]] = []
|
|
for hits in raw_results:
|
|
bucket: List[Dict[str, Any]] = []
|
|
for hit in hits:
|
|
hit_id, distance, entity = self._extract_hit_payload(hit)
|
|
bucket.append({
|
|
"id": hit_id,
|
|
"distance": distance,
|
|
"entity": entity,
|
|
})
|
|
formatted.append(bucket)
|
|
return formatted
|
|
|
|
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
|
if query_path:
|
|
escaped = query_path.replace('"', '\\"')
|
|
filter_expr = f'source_path like "%{escaped}%"'
|
|
else:
|
|
filter_expr = "source_path like '%%'"
|
|
results = self._get_client().query(
|
|
collection_name,
|
|
filter=filter_expr,
|
|
limit=top_k,
|
|
output_fields=[
|
|
"path",
|
|
"source_path",
|
|
"chunk_id",
|
|
"mime",
|
|
"text",
|
|
"start_offset",
|
|
"end_offset",
|
|
"type",
|
|
"name",
|
|
],
|
|
)
|
|
formatted = []
|
|
for row in results:
|
|
entity = dict(row)
|
|
entity.setdefault("path", entity.get("source_path"))
|
|
formatted.append({
|
|
"id": entity.get("path"),
|
|
"distance": 1.0,
|
|
"entity": entity,
|
|
})
|
|
return [formatted]
|
|
|
|
def get_all_stats(self) -> Dict[str, Any]:
|
|
client = self._get_client()
|
|
try:
|
|
collection_names = client.list_collections()
|
|
except Exception as exc:
|
|
raise RuntimeError(f"Failed to list collections: {exc}") from exc
|
|
|
|
collections: List[Dict[str, Any]] = []
|
|
total_vectors = 0
|
|
total_estimated_memory = 0
|
|
|
|
for name in collection_names:
|
|
try:
|
|
stats = client.get_collection_stats(name) or {}
|
|
except Exception:
|
|
stats = {}
|
|
row_count = self._to_int(stats.get("row_count"))
|
|
total_vectors += row_count
|
|
|
|
dimension: Optional[int] = None
|
|
is_vector_collection = False
|
|
try:
|
|
description = client.describe_collection(name)
|
|
except Exception:
|
|
description = None
|
|
|
|
if description:
|
|
for field in description.get("fields", []):
|
|
if field.get("type") == DataType.FLOAT_VECTOR:
|
|
params = field.get("params") or {}
|
|
dimension = self._to_int(params.get("dim")) or 4096
|
|
is_vector_collection = True
|
|
break
|
|
|
|
estimated_memory = 0
|
|
if is_vector_collection and dimension:
|
|
estimated_memory = row_count * dimension * 4
|
|
total_estimated_memory += estimated_memory
|
|
|
|
indexes: List[Dict[str, Any]] = []
|
|
try:
|
|
index_names = client.list_indexes(name) or []
|
|
except Exception:
|
|
index_names = []
|
|
|
|
for index_name in index_names:
|
|
try:
|
|
detail = client.describe_index(name, index_name) or {}
|
|
except Exception:
|
|
detail = {}
|
|
indexes.append(
|
|
{
|
|
"index_name": index_name,
|
|
"index_type": detail.get("index_type"),
|
|
"metric_type": detail.get("metric_type"),
|
|
"indexed_rows": self._to_int(detail.get("indexed_rows")),
|
|
"pending_index_rows": self._to_int(detail.get("pending_index_rows")),
|
|
"state": detail.get("state"),
|
|
}
|
|
)
|
|
|
|
collections.append(
|
|
{
|
|
"name": name,
|
|
"row_count": row_count,
|
|
"dimension": dimension if is_vector_collection else None,
|
|
"estimated_memory_bytes": estimated_memory,
|
|
"is_vector_collection": is_vector_collection,
|
|
"indexes": indexes,
|
|
}
|
|
)
|
|
|
|
return {
|
|
"collections": collections,
|
|
"collection_count": len(collections),
|
|
"total_vectors": total_vectors,
|
|
"estimated_total_memory_bytes": total_estimated_memory,
|
|
"db_file_size_bytes": None,
|
|
}
|
|
|
|
def clear_all_data(self) -> None:
|
|
client = self._get_client()
|
|
for collection_name in client.list_collections():
|
|
client.drop_collection(collection_name)
|