mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-30 12:39:52 +08:00
feat: Enhance vector database providers with source path handling and improved search functionality
This commit is contained in:
@@ -68,3 +68,46 @@ async def get_text_embedding(text: str) -> List[float]:
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
return result["data"][0]["embedding"]
|
||||
|
||||
|
||||
async def rerank_texts(query: str, documents: List[str]) -> List[float]:
|
||||
"""调用重排序模型,为一组文档返回得分。未配置时返回空列表。"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
api_url = await ConfigCenter.get("AI_RERANK_API_URL")
|
||||
model = await ConfigCenter.get("AI_RERANK_MODEL")
|
||||
api_key = await ConfigCenter.get("AI_RERANK_API_KEY")
|
||||
|
||||
if not api_url or not model or not api_key:
|
||||
return []
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
resp = await client.post(api_url, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError:
|
||||
return []
|
||||
data = resp.json()
|
||||
if isinstance(data, dict):
|
||||
results = data.get("results")
|
||||
if isinstance(results, list):
|
||||
scores = []
|
||||
for item in results:
|
||||
if isinstance(item, dict) and "score" in item:
|
||||
try:
|
||||
scores.append(float(item["score"]))
|
||||
except (TypeError, ValueError):
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
return []
|
||||
|
||||
@@ -1,11 +1,95 @@
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from fastapi.responses import Response
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
from services.ai import describe_image_base64, get_text_embedding
|
||||
from services.vector_db import VectorDBService, DEFAULT_VECTOR_DIMENSION
|
||||
from services.logging import LogService
|
||||
from services.config import ConfigCenter
|
||||
|
||||
try: # Pillow is optional but bundled with the project dependencies
|
||||
from PIL import Image
|
||||
except ImportError: # pragma: no cover - fallback when pillow missing
|
||||
Image = None
|
||||
|
||||
|
||||
CHUNK_SIZE = 800
|
||||
CHUNK_OVERLAP = 200
|
||||
MAX_IMAGE_EDGE = 1600
|
||||
JPEG_QUALITY = 85
|
||||
|
||||
|
||||
def _chunk_text(content: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[Tuple[int, str, int, int]]:
|
||||
"""按固定窗口拆分文本,返回(chunk_id, chunk_text, start, end)。"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = CHUNK_SIZE
|
||||
if overlap >= chunk_size:
|
||||
overlap = max(chunk_size // 4, 1)
|
||||
|
||||
chunks: List[Tuple[int, str, int, int]] = []
|
||||
step = chunk_size - overlap
|
||||
idx = 0
|
||||
start = 0
|
||||
length = len(content)
|
||||
|
||||
while start < length:
|
||||
end = min(length, start + chunk_size)
|
||||
chunk = content[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append((idx, chunk, start, end))
|
||||
idx += 1
|
||||
if end >= length:
|
||||
break
|
||||
start += step
|
||||
return chunks
|
||||
|
||||
|
||||
def _guess_mime(path: str) -> str:
|
||||
mime, _ = mimetypes.guess_type(path)
|
||||
return mime or "application/octet-stream"
|
||||
|
||||
|
||||
def _chunk_key(path: str, chunk_id: str) -> str:
|
||||
return f"{path}#chunk={chunk_id}"
|
||||
|
||||
|
||||
def _compress_image_for_embedding(input_bytes: bytes) -> Tuple[bytes, Dict[str, Any] | None]:
|
||||
"""压缩图片,降低发送到视觉模型的体积。"""
|
||||
if Image is None:
|
||||
return input_bytes, None
|
||||
|
||||
try:
|
||||
with Image.open(BytesIO(input_bytes)) as img:
|
||||
img = img.convert("RGB")
|
||||
width, height = img.size
|
||||
longest_edge = max(width, height)
|
||||
scale = 1.0
|
||||
if longest_edge > MAX_IMAGE_EDGE:
|
||||
scale = MAX_IMAGE_EDGE / float(longest_edge)
|
||||
new_size = (max(int(width * scale), 1), max(int(height * scale), 1))
|
||||
resample_mode = getattr(getattr(Image, "Resampling", Image), "LANCZOS")
|
||||
img = img.resize(new_size, resample=resample_mode)
|
||||
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="JPEG", quality=JPEG_QUALITY, optimize=True)
|
||||
compressed = buffer.getvalue()
|
||||
|
||||
if len(compressed) < len(input_bytes):
|
||||
return compressed, {
|
||||
"original_bytes": len(input_bytes),
|
||||
"compressed_bytes": len(compressed),
|
||||
"scaled": scale < 1.0,
|
||||
"width": img.width,
|
||||
"height": img.height,
|
||||
}
|
||||
except Exception: # pragma: no cover - 任意图像处理异常时回退
|
||||
return input_bytes, None
|
||||
|
||||
return input_bytes, None
|
||||
|
||||
|
||||
class VectorIndexProcessor:
|
||||
name = "向量索引"
|
||||
@@ -33,6 +117,7 @@ class VectorIndexProcessor:
|
||||
index_type = config.get("index_type", "vector")
|
||||
vector_db = VectorDBService()
|
||||
collection_name = "vector_collection"
|
||||
|
||||
if action == "destroy":
|
||||
await vector_db.delete_vector(collection_name, path)
|
||||
await LogService.info(
|
||||
@@ -42,9 +127,19 @@ class VectorIndexProcessor:
|
||||
)
|
||||
return Response(content=f"文件 {path} 的 {index_type} 索引已销毁", media_type="text/plain")
|
||||
|
||||
if index_type == 'simple':
|
||||
mime_type = _guess_mime(path)
|
||||
|
||||
if index_type == "simple":
|
||||
await vector_db.ensure_collection(collection_name, vector=False)
|
||||
await vector_db.upsert_vector(collection_name, {'path': path})
|
||||
await vector_db.delete_vector(collection_name, path)
|
||||
await vector_db.upsert_vector(collection_name, {
|
||||
"path": path,
|
||||
"source_path": path,
|
||||
"chunk_id": "filename",
|
||||
"mime": mime_type,
|
||||
"type": "filename",
|
||||
"name": os.path.basename(path),
|
||||
})
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Created simple index for {path}",
|
||||
@@ -53,24 +148,7 @@ class VectorIndexProcessor:
|
||||
return Response(content=f"文件 {path} 的普通索引已创建", media_type="text/plain")
|
||||
|
||||
file_ext = path.split('.')[-1].lower()
|
||||
description = ""
|
||||
embedding = None
|
||||
|
||||
if file_ext in ["jpg", "jpeg", "png", "bmp"]:
|
||||
base64_image = base64.b64encode(input_bytes).decode("utf-8")
|
||||
description = await describe_image_base64(base64_image)
|
||||
embedding = await get_text_embedding(description)
|
||||
log_message = f"Indexed image {path}"
|
||||
response_message = f"图片已索引,描述:{description}"
|
||||
elif file_ext in ["txt", "md"]:
|
||||
text = input_bytes.decode("utf-8")
|
||||
embedding = await get_text_embedding(text)
|
||||
description = text[:100] + "..." if len(text) > 100 else text
|
||||
log_message = f"Indexed text file {path}"
|
||||
response_message = f"文本文件已索引"
|
||||
|
||||
if embedding is None:
|
||||
return Response(content="不支持的文件类型", status_code=400)
|
||||
details: Dict[str, Any] = {"path": path, "action": "create", "index_type": "vector"}
|
||||
|
||||
raw_dim = await ConfigCenter.get('AI_EMBED_DIM', DEFAULT_VECTOR_DIMENSION)
|
||||
try:
|
||||
@@ -81,15 +159,103 @@ class VectorIndexProcessor:
|
||||
vector_dim = DEFAULT_VECTOR_DIMENSION
|
||||
|
||||
await vector_db.ensure_collection(collection_name, vector=True, dim=vector_dim)
|
||||
await vector_db.upsert_vector(
|
||||
collection_name, {'path': path, 'embedding': embedding})
|
||||
|
||||
await vector_db.delete_vector(collection_name, path)
|
||||
|
||||
if file_ext in ["jpg", "jpeg", "png", "bmp"]:
|
||||
processed_bytes, compression = _compress_image_for_embedding(input_bytes)
|
||||
base64_image = base64.b64encode(processed_bytes).decode("utf-8")
|
||||
description = await describe_image_base64(base64_image)
|
||||
embedding = await get_text_embedding(description)
|
||||
image_mime = "image/jpeg" if compression else mime_type
|
||||
await vector_db.upsert_vector(collection_name, {
|
||||
"path": _chunk_key(path, "image"),
|
||||
"source_path": path,
|
||||
"chunk_id": "image",
|
||||
"embedding": embedding,
|
||||
"text": description,
|
||||
"mime": image_mime,
|
||||
"type": "image",
|
||||
})
|
||||
details["description"] = description
|
||||
if compression:
|
||||
details["image_compression"] = compression
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Indexed image {path}",
|
||||
details=details,
|
||||
)
|
||||
return Response(content=f"图片已索引,描述:{description}", media_type="text/plain")
|
||||
|
||||
if file_ext in ["txt", "md"]:
|
||||
try:
|
||||
text = input_bytes.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return Response(content="文本文件解码失败", status_code=400)
|
||||
|
||||
chunks = _chunk_text(text)
|
||||
if not chunks:
|
||||
await vector_db.upsert_vector(collection_name, {
|
||||
"path": _chunk_key(path, "0"),
|
||||
"source_path": path,
|
||||
"chunk_id": "0",
|
||||
"embedding": await get_text_embedding(text or path),
|
||||
"text": text,
|
||||
"mime": mime_type,
|
||||
"type": "text",
|
||||
"start_offset": 0,
|
||||
"end_offset": len(text),
|
||||
})
|
||||
details["chunks"] = 1
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Indexed text file {path}",
|
||||
details=details,
|
||||
)
|
||||
return Response(content="文本文件已索引", media_type="text/plain")
|
||||
|
||||
chunk_count = 0
|
||||
for chunk_id, chunk_text, start, end in chunks:
|
||||
embedding = await get_text_embedding(chunk_text)
|
||||
await vector_db.upsert_vector(collection_name, {
|
||||
"path": _chunk_key(path, str(chunk_id)),
|
||||
"source_path": path,
|
||||
"chunk_id": str(chunk_id),
|
||||
"embedding": embedding,
|
||||
"text": chunk_text,
|
||||
"mime": mime_type,
|
||||
"type": "text",
|
||||
"start_offset": start,
|
||||
"end_offset": end,
|
||||
})
|
||||
chunk_count += 1
|
||||
|
||||
details["chunks"] = chunk_count
|
||||
sample = chunks[0][1]
|
||||
details["sample"] = sample[:120]
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
f"Indexed text file {path}",
|
||||
details=details,
|
||||
)
|
||||
return Response(content="文本文件已索引", media_type="text/plain")
|
||||
|
||||
# 其他类型暂未支持向量索引,回退为文件名索引
|
||||
await vector_db.delete_vector(collection_name, path)
|
||||
await vector_db.upsert_vector(collection_name, {
|
||||
"path": _chunk_key(path, "fallback"),
|
||||
"source_path": path,
|
||||
"chunk_id": "filename",
|
||||
"mime": mime_type,
|
||||
"type": "filename",
|
||||
"name": os.path.basename(path),
|
||||
"embedding": [0.0] * vector_dim,
|
||||
})
|
||||
await LogService.info(
|
||||
"processor:vector_index",
|
||||
log_message,
|
||||
details={"path": path, "description": description, "action": "create", "index_type": "vector"},
|
||||
f"File type fallback to simple index for {path}",
|
||||
details={"path": path, "action": "create", "index_type": "simple", "original_type": file_ext},
|
||||
)
|
||||
return Response(content=response_message, media_type="text/plain")
|
||||
return Response(content="暂不支持该类型的向量索引,已创建文件名索引", media_type="text/plain")
|
||||
|
||||
|
||||
PROCESSOR_TYPE = "vector_index"
|
||||
|
||||
@@ -50,15 +50,20 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
client = self._get_client()
|
||||
if client.has_collection(collection_name):
|
||||
return
|
||||
common_fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
|
||||
]
|
||||
|
||||
if vector:
|
||||
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
|
||||
if vector_dim <= 0:
|
||||
vector_dim = 4096
|
||||
fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
*common_fields,
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Image vector collection")
|
||||
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
index_params.add_index(
|
||||
@@ -70,38 +75,98 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
)
|
||||
client.create_index(collection_name, index_params=index_params)
|
||||
else:
|
||||
fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Simple file index")
|
||||
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
self._get_client().upsert(collection_name, data)
|
||||
payload = dict(data)
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
payload.setdefault("vector_id", payload.get("path"))
|
||||
self._get_client().upsert(collection_name, data=[payload])
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
self._get_client().delete(collection_name, ids=[path])
|
||||
client = self._get_client()
|
||||
escaped = path.replace('"', '\\"')
|
||||
client.delete(collection_name, filter=f'source_path == "{escaped}"')
|
||||
|
||||
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
||||
search_params = {"metric_type": "COSINE"}
|
||||
return self._get_client().search(
|
||||
output_fields = [
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
]
|
||||
raw_results = self._get_client().search(
|
||||
collection_name,
|
||||
data=[query_embedding],
|
||||
anns_field="embedding",
|
||||
search_params=search_params,
|
||||
limit=top_k,
|
||||
output_fields=["path"],
|
||||
output_fields=output_fields,
|
||||
)
|
||||
formatted: List[List[Dict[str, Any]]] = []
|
||||
for hits in raw_results:
|
||||
bucket: List[Dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
if hasattr(hit, "entity"):
|
||||
entity = dict(getattr(hit, "entity", {}) or {})
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
elif isinstance(hit, dict):
|
||||
entity = dict((hit.get("entity") or {}))
|
||||
hit_id = hit.get("id")
|
||||
distance = hit.get("distance")
|
||||
else:
|
||||
entity = {}
|
||||
hit_id = None
|
||||
distance = None
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
bucket.append({
|
||||
"id": hit_id,
|
||||
"distance": distance,
|
||||
"entity": entity,
|
||||
})
|
||||
formatted.append(bucket)
|
||||
return formatted
|
||||
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
filter_expr = f"path like '%{query_path}%'" if query_path else "path like '%%'"
|
||||
if query_path:
|
||||
escaped = query_path.replace('"', '\\"')
|
||||
filter_expr = f'source_path like "%{escaped}%"'
|
||||
else:
|
||||
filter_expr = "source_path like '%%'"
|
||||
results = self._get_client().query(
|
||||
collection_name,
|
||||
filter=filter_expr,
|
||||
limit=top_k,
|
||||
output_fields=["path"],
|
||||
output_fields=[
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
],
|
||||
)
|
||||
return [[{"id": r["path"], "distance": 1.0, "entity": {"path": r["path"]}} for r in results]]
|
||||
formatted = []
|
||||
for row in results:
|
||||
entity = dict(row)
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
formatted.append({
|
||||
"id": entity.get("path"),
|
||||
"distance": 1.0,
|
||||
"entity": entity,
|
||||
})
|
||||
return [formatted]
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Any]:
|
||||
client = self._get_client()
|
||||
|
||||
@@ -58,15 +58,19 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
client = self._get_client()
|
||||
if client.has_collection(collection_name):
|
||||
return
|
||||
common_fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
|
||||
]
|
||||
if vector:
|
||||
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
|
||||
if vector_dim <= 0:
|
||||
vector_dim = 4096
|
||||
fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
*common_fields,
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Image vector collection")
|
||||
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
index_params.add_index(
|
||||
@@ -78,38 +82,98 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
)
|
||||
client.create_index(collection_name, index_params=index_params)
|
||||
else:
|
||||
fields = [
|
||||
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Simple file index")
|
||||
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
|
||||
client.create_collection(collection_name, schema=schema)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
self._get_client().upsert(collection_name, data)
|
||||
payload = dict(data)
|
||||
payload.setdefault("source_path", payload.get("path"))
|
||||
payload.setdefault("vector_id", payload.get("path"))
|
||||
self._get_client().upsert(collection_name, data=[payload])
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
self._get_client().delete(collection_name, ids=[path])
|
||||
client = self._get_client()
|
||||
escaped = path.replace('"', '\\"')
|
||||
client.delete(collection_name, filter=f'source_path == "{escaped}"')
|
||||
|
||||
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
|
||||
search_params = {"metric_type": "COSINE"}
|
||||
return self._get_client().search(
|
||||
output_fields = [
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
]
|
||||
raw_results = self._get_client().search(
|
||||
collection_name,
|
||||
data=[query_embedding],
|
||||
anns_field="embedding",
|
||||
search_params=search_params,
|
||||
limit=top_k,
|
||||
output_fields=["path"],
|
||||
output_fields=output_fields,
|
||||
)
|
||||
formatted: List[List[Dict[str, Any]]] = []
|
||||
for hits in raw_results:
|
||||
bucket: List[Dict[str, Any]] = []
|
||||
for hit in hits:
|
||||
if hasattr(hit, "entity"):
|
||||
entity = dict(getattr(hit, "entity", {}) or {})
|
||||
hit_id = getattr(hit, "id", None)
|
||||
distance = getattr(hit, "distance", None)
|
||||
elif isinstance(hit, dict):
|
||||
entity = dict((hit.get("entity") or {}))
|
||||
hit_id = hit.get("id")
|
||||
distance = hit.get("distance")
|
||||
else:
|
||||
entity = {}
|
||||
hit_id = None
|
||||
distance = None
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
bucket.append({
|
||||
"id": hit_id,
|
||||
"distance": distance,
|
||||
"entity": entity,
|
||||
})
|
||||
formatted.append(bucket)
|
||||
return formatted
|
||||
|
||||
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
|
||||
filter_expr = f"path like '%{query_path}%'" if query_path else "path like '%%'"
|
||||
if query_path:
|
||||
escaped = query_path.replace('"', '\\"')
|
||||
filter_expr = f'source_path like "%{escaped}%"'
|
||||
else:
|
||||
filter_expr = "source_path like '%%'"
|
||||
results = self._get_client().query(
|
||||
collection_name,
|
||||
filter=filter_expr,
|
||||
limit=top_k,
|
||||
output_fields=["path"],
|
||||
output_fields=[
|
||||
"path",
|
||||
"source_path",
|
||||
"chunk_id",
|
||||
"mime",
|
||||
"text",
|
||||
"start_offset",
|
||||
"end_offset",
|
||||
"type",
|
||||
"name",
|
||||
],
|
||||
)
|
||||
return [[{"id": r["path"], "distance": 1.0, "entity": {"path": r["path"]}} for r in results]]
|
||||
formatted = []
|
||||
for row in results:
|
||||
entity = dict(row)
|
||||
entity.setdefault("path", entity.get("source_path"))
|
||||
formatted.append({
|
||||
"id": entity.get("path"),
|
||||
"distance": 1.0,
|
||||
"entity": entity,
|
||||
})
|
||||
return [formatted]
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Any]:
|
||||
client = self._get_client()
|
||||
|
||||
@@ -58,29 +58,59 @@ class QdrantProvider(BaseVectorProvider):
|
||||
size = dim if vector and isinstance(dim, int) and dim > 0 else 1
|
||||
return qmodels.VectorParams(size=size, distance=qmodels.Distance.COSINE)
|
||||
|
||||
def _ensure_payload_indexes(self, client: QdrantClient, collection_name: str) -> None:
|
||||
for field in ("path", "source_path"):
|
||||
try:
|
||||
client.create_payload_index(
|
||||
collection_name=collection_name,
|
||||
field_name=field,
|
||||
field_schema="keyword",
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
message = str(exc).lower()
|
||||
if "already exists" in message or "index exists" in message:
|
||||
continue
|
||||
# 旧版本 qdrant 可能返回带状态码的异常,这里容忍重复创建
|
||||
raise
|
||||
|
||||
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
|
||||
client = self._get_client()
|
||||
try:
|
||||
if client.collection_exists(collection_name):
|
||||
return
|
||||
exists = client.collection_exists(collection_name)
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
raise RuntimeError(f"Failed to check Qdrant collection '{collection_name}': {exc}") from exc
|
||||
|
||||
if exists:
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
vectors_config = self._vector_params(vector, dim)
|
||||
try:
|
||||
client.create_collection(collection_name=collection_name, vectors_config=vectors_config)
|
||||
except Exception as exc: # pragma: no cover
|
||||
if "already exists" in str(exc).lower():
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
raise RuntimeError(f"Failed to create Qdrant collection '{collection_name}': {exc}") from exc
|
||||
|
||||
try:
|
||||
self._ensure_payload_indexes(client, collection_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _point_id(path: str) -> str:
|
||||
return str(uuid5(NAMESPACE_URL, path))
|
||||
def _point_id(uid: str) -> str:
|
||||
return str(uuid5(NAMESPACE_URL, uid))
|
||||
|
||||
def _prepare_point(self, data: Dict[str, Any]) -> qmodels.PointStruct:
|
||||
path = data.get("path")
|
||||
if not path:
|
||||
uid = data.get("path")
|
||||
if not uid:
|
||||
raise ValueError("Qdrant upsert requires 'path' in data")
|
||||
|
||||
embedding = data.get("embedding")
|
||||
@@ -89,8 +119,11 @@ class QdrantProvider(BaseVectorProvider):
|
||||
else:
|
||||
vector = [float(x) for x in embedding]
|
||||
|
||||
payload = {"path": path}
|
||||
return qmodels.PointStruct(id=self._point_id(path), vector=vector, payload=payload)
|
||||
payload = {k: v for k, v in data.items() if k != "embedding"}
|
||||
payload.setdefault("vector_id", uid)
|
||||
source_path = payload.get("source_path") or payload.get("path")
|
||||
payload["path"] = source_path
|
||||
return qmodels.PointStruct(id=self._point_id(str(uid)), vector=vector, payload=payload)
|
||||
|
||||
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
|
||||
client = self._get_client()
|
||||
@@ -99,7 +132,12 @@ class QdrantProvider(BaseVectorProvider):
|
||||
|
||||
def delete_vector(self, collection_name: str, path: str) -> None:
|
||||
client = self._get_client()
|
||||
selector = qmodels.PointIdsList(points=[self._point_id(path)])
|
||||
condition = qmodels.FieldCondition(
|
||||
key="path",
|
||||
match=qmodels.MatchValue(value=path),
|
||||
)
|
||||
flt = qmodels.Filter(must=[condition])
|
||||
selector = qmodels.FilterSelector(filter=flt)
|
||||
client.delete(collection_name=collection_name, points_selector=selector, wait=True)
|
||||
|
||||
def _format_search_results(self, points: Sequence[qmodels.ScoredPoint]):
|
||||
@@ -107,7 +145,7 @@ class QdrantProvider(BaseVectorProvider):
|
||||
{
|
||||
"id": point.id,
|
||||
"distance": point.score,
|
||||
"entity": {"path": (point.payload or {}).get("path")},
|
||||
"entity": point.payload or {},
|
||||
}
|
||||
for point in points
|
||||
]
|
||||
@@ -141,11 +179,11 @@ class QdrantProvider(BaseVectorProvider):
|
||||
break
|
||||
|
||||
for record in records:
|
||||
path = (record.payload or {}).get("path")
|
||||
if query_path and path:
|
||||
if query_path not in path:
|
||||
continue
|
||||
results.append({"id": record.id, "distance": 1.0, "entity": {"path": path}})
|
||||
payload = record.payload or {}
|
||||
path = payload.get("path")
|
||||
if query_path and path and query_path not in path:
|
||||
continue
|
||||
results.append({"id": record.id, "distance": 1.0, "entity": payload})
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
|
||||
|
||||
Reference in New Issue
Block a user