feat: add hit payload extraction method for Milvus providers

This commit is contained in:
shiyu
2025-09-27 15:09:20 +08:00
parent 5625f2d8bf
commit 8ac3acebb4
2 changed files with 60 additions and 26 deletions

View File

@@ -39,6 +39,35 @@ class MilvusLiteProvider(BaseVectorProvider):
raise RuntimeError("Milvus Lite client is not initialized")
return self.client
@staticmethod
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
hit_id = getattr(hit, "id", None)
distance = getattr(hit, "distance", None)
payload: Dict[str, Any] = {}
raw: Dict[str, Any] | None = None
if hasattr(hit, "entity"):
raw_entity = getattr(hit, "entity")
if hasattr(raw_entity, "to_dict"):
raw = dict(raw_entity.to_dict())
else:
raw = dict(raw_entity)
elif isinstance(hit, dict):
raw = dict(hit)
if raw:
hit_id = hit_id or raw.get("id")
distance = distance if distance is not None else raw.get("distance")
inner = raw.get("entity")
if isinstance(inner, dict):
payload = dict(inner)
else:
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
payload.setdefault("path", payload.get("source_path"))
payload.setdefault("source_path", payload.get("path"))
return hit_id, distance, payload
@staticmethod
def _to_int(value: Any) -> int:
try:
@@ -114,19 +143,7 @@ class MilvusLiteProvider(BaseVectorProvider):
for hits in raw_results:
bucket: List[Dict[str, Any]] = []
for hit in hits:
if hasattr(hit, "entity"):
entity = dict(getattr(hit, "entity", {}) or {})
hit_id = getattr(hit, "id", None)
distance = getattr(hit, "distance", None)
elif isinstance(hit, dict):
entity = dict((hit.get("entity") or {}))
hit_id = hit.get("id")
distance = hit.get("distance")
else:
entity = {}
hit_id = None
distance = None
entity.setdefault("path", entity.get("source_path"))
hit_id, distance, entity = self._extract_hit_payload(hit)
bucket.append({
"id": hit_id,
"distance": distance,

View File

@@ -47,6 +47,35 @@ class MilvusServerProvider(BaseVectorProvider):
raise RuntimeError("Milvus Server client is not initialized")
return self.client
@staticmethod
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
hit_id = getattr(hit, "id", None)
distance = getattr(hit, "distance", None)
payload: Dict[str, Any] = {}
raw: Dict[str, Any] | None = None
if hasattr(hit, "entity"):
raw_entity = getattr(hit, "entity")
if hasattr(raw_entity, "to_dict"):
raw = dict(raw_entity.to_dict())
else:
raw = dict(raw_entity)
elif isinstance(hit, dict):
raw = dict(hit)
if raw:
hit_id = hit_id or raw.get("id")
distance = distance if distance is not None else raw.get("distance")
inner = raw.get("entity")
if isinstance(inner, dict):
payload = dict(inner)
else:
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
payload.setdefault("path", payload.get("source_path"))
payload.setdefault("source_path", payload.get("path"))
return hit_id, distance, payload
@staticmethod
def _to_int(value: Any) -> int:
try:
@@ -121,19 +150,7 @@ class MilvusServerProvider(BaseVectorProvider):
for hits in raw_results:
bucket: List[Dict[str, Any]] = []
for hit in hits:
if hasattr(hit, "entity"):
entity = dict(getattr(hit, "entity", {}) or {})
hit_id = getattr(hit, "id", None)
distance = getattr(hit, "distance", None)
elif isinstance(hit, dict):
entity = dict((hit.get("entity") or {}))
hit_id = hit.get("id")
distance = hit.get("distance")
else:
entity = {}
hit_id = None
distance = None
entity.setdefault("path", entity.get("source_path"))
hit_id, distance, entity = self._extract_hit_payload(hit)
bucket.append({
"id": hit_id,
"distance": distance,