mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-07 04:42:50 +08:00
feat: add hit payload extraction method for Milvus providers
This commit is contained in:
@@ -39,6 +39,35 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
raise RuntimeError("Milvus Lite client is not initialized")
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
raw: Dict[str, Any] | None = None
|
||||
if hasattr(hit, "entity"):
|
||||
raw_entity = getattr(hit, "entity")
|
||||
if hasattr(raw_entity, "to_dict"):
|
||||
raw = dict(raw_entity.to_dict())
|
||||
else:
|
||||
raw = dict(raw_entity)
|
||||
elif isinstance(hit, dict):
|
||||
raw = dict(hit)
|
||||
|
||||
if raw:
|
||||
hit_id = hit_id or raw.get("id")
|
||||
distance = distance if distance is not None else raw.get("distance")
|
||||
inner = raw.get("entity")
|
||||
if isinstance(inner, dict):
|
||||
payload = dict(inner)
|
||||
else:
|
||||
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
|
||||
|
||||
payload.setdefault("path", payload.get("source_path"))
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
return hit_id, distance, payload
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value: Any) -> int:
|
||||
try:
|
||||
@@ -114,19 +143,7 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
for hits in raw_results:
|
||||
bucket: List[Dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
if hasattr(hit, "entity"):
|
||||
entity = dict(getattr(hit, "entity", {}) or {})
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
elif isinstance(hit, dict):
|
||||
entity = dict((hit.get("entity") or {}))
|
||||
hit_id = hit.get("id")
|
||||
distance = hit.get("distance")
|
||||
else:
|
||||
entity = {}
|
||||
hit_id = None
|
||||
distance = None
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
hit_id, distance, entity = self._extract_hit_payload(hit)
|
||||
bucket.append({
|
||||
"id": hit_id,
|
||||
"distance": distance,
|
||||
|
||||
@@ -47,6 +47,35 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
raise RuntimeError("Milvus Server client is not initialized")
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def _extract_hit_payload(hit: Any) -> tuple[Any, Any, Dict[str, Any]]:
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
raw: Dict[str, Any] | None = None
|
||||
if hasattr(hit, "entity"):
|
||||
raw_entity = getattr(hit, "entity")
|
||||
if hasattr(raw_entity, "to_dict"):
|
||||
raw = dict(raw_entity.to_dict())
|
||||
else:
|
||||
raw = dict(raw_entity)
|
||||
elif isinstance(hit, dict):
|
||||
raw = dict(hit)
|
||||
|
||||
if raw:
|
||||
hit_id = hit_id or raw.get("id")
|
||||
distance = distance if distance is not None else raw.get("distance")
|
||||
inner = raw.get("entity")
|
||||
if isinstance(inner, dict):
|
||||
payload = dict(inner)
|
||||
else:
|
||||
payload = {k: v for k, v in raw.items() if k not in {"id", "distance", "entity"}}
|
||||
|
||||
payload.setdefault("path", payload.get("source_path"))
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
return hit_id, distance, payload
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value: Any) -> int:
|
||||
try:
|
||||
@@ -121,19 +150,7 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
for hits in raw_results:
|
||||
bucket: List[Dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
if hasattr(hit, "entity"):
|
||||
entity = dict(getattr(hit, "entity", {}) or {})
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
elif isinstance(hit, dict):
|
||||
entity = dict((hit.get("entity") or {}))
|
||||
hit_id = hit.get("id")
|
||||
distance = hit.get("distance")
|
||||
else:
|
||||
entity = {}
|
||||
hit_id = None
|
||||
distance = None
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
hit_id, distance, entity = self._extract_hit_payload(hit)
|
||||
bucket.append({
|
||||
"id": hit_id,
|
||||
"distance": distance,
|
||||
|
||||
Reference in New Issue
Block a user