From 8ac3acebb434aa970efeab8ab65398374774fc09 Mon Sep 17 00:00:00 2001
From: shiyu
Date: Sat, 27 Sep 2025 15:09:20 +0800
Subject: [PATCH] feat: add hit payload extraction method for Milvus providers

---
Review notes: this patch was recovered from a copy whose line breaks and
indentation had been flattened. Context-line indentation is reconstructed
from Python syntax -- verify it against the tree before applying. The blob
ids on the "index" lines predate the two review fixes made to
_extract_hit_payload (None-safe entity handling, falsy ids preserved).

 services/vector_db/providers/milvus_lite.py   | 51 ++++++++++++++-----
 services/vector_db/providers/milvus_server.py | 51 ++++++++++++++-----
 2 files changed, 76 insertions(+), 26 deletions(-)

diff --git a/services/vector_db/providers/milvus_lite.py b/services/vector_db/providers/milvus_lite.py
index e04a20b..1722925 100644
--- a/services/vector_db/providers/milvus_lite.py
+++ b/services/vector_db/providers/milvus_lite.py
@@ -39,6 +39,43 @@ class MilvusLiteProvider(BaseVectorProvider):
             raise RuntimeError("Milvus Lite client is not initialized")
         return self.client
 
+    @staticmethod
+    def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
+        """Normalize one search hit into ``(id, distance, payload)``.
+
+        Handles hit objects exposing ``id``/``distance``/``entity`` attributes
+        and plain ``dict`` hits. The payload always carries both ``path`` and
+        ``source_path``, each defaulting to the other (``None`` if neither).
+        """
+        hit_id = getattr(hit, "id", None)
+        distance = getattr(hit, "distance", None)
+        payload: Dict[str, Any] = {}
+
+        raw: Dict[str, Any] | None = None
+        if hasattr(hit, "entity"):
+            raw_entity = getattr(hit, "entity")
+            if hasattr(raw_entity, "to_dict"):
+                raw = dict(raw_entity.to_dict())
+            else:
+                # A hit whose entity is None must not raise; treat it as empty.
+                raw = dict(raw_entity or {})
+        elif isinstance(hit, dict):
+            raw = dict(hit)
+
+        if raw:
+            # Compare against None so falsy primary keys (0, "") are kept.
+            hit_id = hit_id if hit_id is not None else raw.get("id")
+            distance = distance if distance is not None else raw.get("distance")
+            inner = raw.get("entity")
+            if isinstance(inner, dict):
+                payload = dict(inner)
+            else:
+                payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
+
+        payload.setdefault("path", payload.get("source_path"))
+        payload.setdefault("source_path", payload.get("path"))
+        return hit_id, distance, payload
+
     @staticmethod
     def _to_int(value: Any) -> int:
         try:
@@ -114,19 +151,7 @@ class MilvusLiteProvider(BaseVectorProvider):
         for hits in raw_results:
             bucket: List[Dict[str, Any]] = []
             for hit in hits:
-                if hasattr(hit, "entity"):
-                    entity = dict(getattr(hit, "entity", {}) or {})
-                    hit_id = getattr(hit, "id", None)
-                    distance = getattr(hit, "distance", None)
-                elif isinstance(hit, dict):
-                    entity = dict((hit.get("entity") or {}))
-                    hit_id = hit.get("id")
-                    distance = hit.get("distance")
-                else:
-                    entity = {}
-                    hit_id = None
-                    distance = None
-                entity.setdefault("path", entity.get("source_path"))
+                hit_id, distance, entity = self._extract_hit_payload(hit)
                 bucket.append({
                     "id": hit_id,
                     "distance": distance,
diff --git a/services/vector_db/providers/milvus_server.py b/services/vector_db/providers/milvus_server.py
index b49e1e0..29244ce 100644
--- a/services/vector_db/providers/milvus_server.py
+++ b/services/vector_db/providers/milvus_server.py
@@ -47,6 +47,43 @@ class MilvusServerProvider(BaseVectorProvider):
             raise RuntimeError("Milvus Server client is not initialized")
         return self.client
 
+    @staticmethod
+    def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
+        """Normalize one search hit into ``(id, distance, payload)``.
+
+        Handles hit objects exposing ``id``/``distance``/``entity`` attributes
+        and plain ``dict`` hits. The payload always carries both ``path`` and
+        ``source_path``, each defaulting to the other (``None`` if neither).
+        """
+        hit_id = getattr(hit, "id", None)
+        distance = getattr(hit, "distance", None)
+        payload: Dict[str, Any] = {}
+
+        raw: Dict[str, Any] | None = None
+        if hasattr(hit, "entity"):
+            raw_entity = getattr(hit, "entity")
+            if hasattr(raw_entity, "to_dict"):
+                raw = dict(raw_entity.to_dict())
+            else:
+                # A hit whose entity is None must not raise; treat it as empty.
+                raw = dict(raw_entity or {})
+        elif isinstance(hit, dict):
+            raw = dict(hit)
+
+        if raw:
+            # Compare against None so falsy primary keys (0, "") are kept.
+            hit_id = hit_id if hit_id is not None else raw.get("id")
+            distance = distance if distance is not None else raw.get("distance")
+            inner = raw.get("entity")
+            if isinstance(inner, dict):
+                payload = dict(inner)
+            else:
+                payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
+
+        payload.setdefault("path", payload.get("source_path"))
+        payload.setdefault("source_path", payload.get("path"))
+        return hit_id, distance, payload
+
     @staticmethod
     def _to_int(value: Any) -> int:
         try:
@@ -121,19 +158,7 @@ class MilvusServerProvider(BaseVectorProvider):
         for hits in raw_results:
             bucket: List[Dict[str, Any]] = []
             for hit in hits:
-                if hasattr(hit, "entity"):
-                    entity = dict(getattr(hit, "entity", {}) or {})
-                    hit_id = getattr(hit, "id", None)
-                    distance = getattr(hit, "distance", None)
-                elif isinstance(hit, dict):
-                    entity = dict((hit.get("entity") or {}))
-                    hit_id = hit.get("id")
-                    distance = hit.get("distance")
-                else:
-                    entity = {}
-                    hit_id = None
-                    distance = None
-                entity.setdefault("path", entity.get("source_path"))
+                hit_id, distance, entity = self._extract_hit_payload(hit)
                 bucket.append({
                     "id": hit_id,
                     "distance": distance,