feat: Enhance vector database providers with source path handling and improved search functionality

This commit is contained in:
shiyu
2025-09-27 13:34:18 +08:00
parent ee6e570ccb
commit a4af9475ef
10 changed files with 1082 additions and 353 deletions

View File

@@ -68,3 +68,46 @@ async def get_text_embedding(text: str) -> List[float]:
resp.raise_for_status()
result = resp.json()
return result["data"][0]["embedding"]
async def rerank_texts(query: str, documents: List[str]) -> List[float]:
"""调用重排序模型,为一组文档返回得分。未配置时返回空列表。"""
if not documents:
return []
api_url = await ConfigCenter.get("AI_RERANK_API_URL")
model = await ConfigCenter.get("AI_RERANK_MODEL")
api_key = await ConfigCenter.get("AI_RERANK_API_KEY")
if not api_url or not model or not api_key:
return []
payload = {
"model": model,
"query": query,
"documents": documents,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
async with httpx.AsyncClient() as client:
try:
resp = await client.post(api_url, headers=headers, json=payload)
resp.raise_for_status()
except httpx.HTTPStatusError:
return []
data = resp.json()
if isinstance(data, dict):
results = data.get("results")
if isinstance(results, list):
scores = []
for item in results:
if isinstance(item, dict) and "score" in item:
try:
scores.append(float(item["score"]))
except (TypeError, ValueError):
scores.append(0.0)
return scores
return []

View File

@@ -1,11 +1,95 @@
from typing import Dict, Any
from typing import Dict, Any, List, Tuple
from fastapi.responses import Response
import base64
import mimetypes
import os
from io import BytesIO
from services.ai import describe_image_base64, get_text_embedding
from services.vector_db import VectorDBService, DEFAULT_VECTOR_DIMENSION
from services.logging import LogService
from services.config import ConfigCenter
try: # Pillow is optional but bundled with the project dependencies
from PIL import Image
except ImportError: # pragma: no cover - fallback when pillow missing
Image = None
CHUNK_SIZE = 800
CHUNK_OVERLAP = 200
MAX_IMAGE_EDGE = 1600
JPEG_QUALITY = 85
def _chunk_text(content: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[Tuple[int, str, int, int]]:
"""按固定窗口拆分文本,返回(chunk_id, chunk_text, start, end)。"""
if chunk_size <= 0:
chunk_size = CHUNK_SIZE
if overlap >= chunk_size:
overlap = max(chunk_size // 4, 1)
chunks: List[Tuple[int, str, int, int]] = []
step = chunk_size - overlap
idx = 0
start = 0
length = len(content)
while start < length:
end = min(length, start + chunk_size)
chunk = content[start:end].strip()
if chunk:
chunks.append((idx, chunk, start, end))
idx += 1
if end >= length:
break
start += step
return chunks
def _guess_mime(path: str) -> str:
mime, _ = mimetypes.guess_type(path)
return mime or "application/octet-stream"
def _chunk_key(path: str, chunk_id: str) -> str:
return f"{path}#chunk={chunk_id}"
def _compress_image_for_embedding(input_bytes: bytes) -> Tuple[bytes, Dict[str, Any] | None]:
"""压缩图片,降低发送到视觉模型的体积。"""
if Image is None:
return input_bytes, None
try:
with Image.open(BytesIO(input_bytes)) as img:
img = img.convert("RGB")
width, height = img.size
longest_edge = max(width, height)
scale = 1.0
if longest_edge > MAX_IMAGE_EDGE:
scale = MAX_IMAGE_EDGE / float(longest_edge)
new_size = (max(int(width * scale), 1), max(int(height * scale), 1))
resample_mode = getattr(getattr(Image, "Resampling", Image), "LANCZOS")
img = img.resize(new_size, resample=resample_mode)
buffer = BytesIO()
img.save(buffer, format="JPEG", quality=JPEG_QUALITY, optimize=True)
compressed = buffer.getvalue()
if len(compressed) < len(input_bytes):
return compressed, {
"original_bytes": len(input_bytes),
"compressed_bytes": len(compressed),
"scaled": scale < 1.0,
"width": img.width,
"height": img.height,
}
except Exception: # pragma: no cover - 任意图像处理异常时回退
return input_bytes, None
return input_bytes, None
class VectorIndexProcessor:
name = "向量索引"
@@ -33,6 +117,7 @@ class VectorIndexProcessor:
index_type = config.get("index_type", "vector")
vector_db = VectorDBService()
collection_name = "vector_collection"
if action == "destroy":
await vector_db.delete_vector(collection_name, path)
await LogService.info(
@@ -42,9 +127,19 @@ class VectorIndexProcessor:
)
return Response(content=f"文件 {path}{index_type} 索引已销毁", media_type="text/plain")
if index_type == 'simple':
mime_type = _guess_mime(path)
if index_type == "simple":
await vector_db.ensure_collection(collection_name, vector=False)
await vector_db.upsert_vector(collection_name, {'path': path})
await vector_db.delete_vector(collection_name, path)
await vector_db.upsert_vector(collection_name, {
"path": path,
"source_path": path,
"chunk_id": "filename",
"mime": mime_type,
"type": "filename",
"name": os.path.basename(path),
})
await LogService.info(
"processor:vector_index",
f"Created simple index for {path}",
@@ -53,24 +148,7 @@ class VectorIndexProcessor:
return Response(content=f"文件 {path} 的普通索引已创建", media_type="text/plain")
file_ext = path.split('.')[-1].lower()
description = ""
embedding = None
if file_ext in ["jpg", "jpeg", "png", "bmp"]:
base64_image = base64.b64encode(input_bytes).decode("utf-8")
description = await describe_image_base64(base64_image)
embedding = await get_text_embedding(description)
log_message = f"Indexed image {path}"
response_message = f"图片已索引,描述:{description}"
elif file_ext in ["txt", "md"]:
text = input_bytes.decode("utf-8")
embedding = await get_text_embedding(text)
description = text[:100] + "..." if len(text) > 100 else text
log_message = f"Indexed text file {path}"
response_message = f"文本文件已索引"
if embedding is None:
return Response(content="不支持的文件类型", status_code=400)
details: Dict[str, Any] = {"path": path, "action": "create", "index_type": "vector"}
raw_dim = await ConfigCenter.get('AI_EMBED_DIM', DEFAULT_VECTOR_DIMENSION)
try:
@@ -81,15 +159,103 @@ class VectorIndexProcessor:
vector_dim = DEFAULT_VECTOR_DIMENSION
await vector_db.ensure_collection(collection_name, vector=True, dim=vector_dim)
await vector_db.upsert_vector(
collection_name, {'path': path, 'embedding': embedding})
await vector_db.delete_vector(collection_name, path)
if file_ext in ["jpg", "jpeg", "png", "bmp"]:
processed_bytes, compression = _compress_image_for_embedding(input_bytes)
base64_image = base64.b64encode(processed_bytes).decode("utf-8")
description = await describe_image_base64(base64_image)
embedding = await get_text_embedding(description)
image_mime = "image/jpeg" if compression else mime_type
await vector_db.upsert_vector(collection_name, {
"path": _chunk_key(path, "image"),
"source_path": path,
"chunk_id": "image",
"embedding": embedding,
"text": description,
"mime": image_mime,
"type": "image",
})
details["description"] = description
if compression:
details["image_compression"] = compression
await LogService.info(
"processor:vector_index",
f"Indexed image {path}",
details=details,
)
return Response(content=f"图片已索引,描述:{description}", media_type="text/plain")
if file_ext in ["txt", "md"]:
try:
text = input_bytes.decode("utf-8")
except UnicodeDecodeError:
return Response(content="文本文件解码失败", status_code=400)
chunks = _chunk_text(text)
if not chunks:
await vector_db.upsert_vector(collection_name, {
"path": _chunk_key(path, "0"),
"source_path": path,
"chunk_id": "0",
"embedding": await get_text_embedding(text or path),
"text": text,
"mime": mime_type,
"type": "text",
"start_offset": 0,
"end_offset": len(text),
})
details["chunks"] = 1
await LogService.info(
"processor:vector_index",
f"Indexed text file {path}",
details=details,
)
return Response(content="文本文件已索引", media_type="text/plain")
chunk_count = 0
for chunk_id, chunk_text, start, end in chunks:
embedding = await get_text_embedding(chunk_text)
await vector_db.upsert_vector(collection_name, {
"path": _chunk_key(path, str(chunk_id)),
"source_path": path,
"chunk_id": str(chunk_id),
"embedding": embedding,
"text": chunk_text,
"mime": mime_type,
"type": "text",
"start_offset": start,
"end_offset": end,
})
chunk_count += 1
details["chunks"] = chunk_count
sample = chunks[0][1]
details["sample"] = sample[:120]
await LogService.info(
"processor:vector_index",
f"Indexed text file {path}",
details=details,
)
return Response(content="文本文件已索引", media_type="text/plain")
# 其他类型暂未支持向量索引,回退为文件名索引
await vector_db.delete_vector(collection_name, path)
await vector_db.upsert_vector(collection_name, {
"path": _chunk_key(path, "fallback"),
"source_path": path,
"chunk_id": "filename",
"mime": mime_type,
"type": "filename",
"name": os.path.basename(path),
"embedding": [0.0] * vector_dim,
})
await LogService.info(
"processor:vector_index",
log_message,
details={"path": path, "description": description, "action": "create", "index_type": "vector"},
f"File type fallback to simple index for {path}",
details={"path": path, "action": "create", "index_type": "simple", "original_type": file_ext},
)
return Response(content=response_message, media_type="text/plain")
return Response(content="暂不支持该类型的向量索引,已创建文件名索引", media_type="text/plain")
PROCESSOR_TYPE = "vector_index"

View File

@@ -50,15 +50,20 @@ class MilvusLiteProvider(BaseVectorProvider):
client = self._get_client()
if client.has_collection(collection_name):
return
common_fields = [
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
]
if vector:
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
if vector_dim <= 0:
vector_dim = 4096
fields = [
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
*common_fields,
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
]
schema = CollectionSchema(fields, description="Image vector collection")
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
client.create_collection(collection_name, schema=schema)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
@@ -70,38 +75,98 @@ class MilvusLiteProvider(BaseVectorProvider):
)
client.create_index(collection_name, index_params=index_params)
else:
fields = [
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
]
schema = CollectionSchema(fields, description="Simple file index")
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
client.create_collection(collection_name, schema=schema)
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
self._get_client().upsert(collection_name, data)
payload = dict(data)
payload.setdefault("source_path", payload.get("path"))
payload.setdefault("vector_id", payload.get("path"))
self._get_client().upsert(collection_name, data=[payload])
def delete_vector(self, collection_name: str, path: str) -> None:
self._get_client().delete(collection_name, ids=[path])
client = self._get_client()
escaped = path.replace('"', '\\"')
client.delete(collection_name, filter=f'source_path == "{escaped}"')
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
search_params = {"metric_type": "COSINE"}
return self._get_client().search(
output_fields = [
"path",
"source_path",
"chunk_id",
"mime",
"text",
"start_offset",
"end_offset",
"type",
"name",
]
raw_results = self._get_client().search(
collection_name,
data=[query_embedding],
anns_field="embedding",
search_params=search_params,
limit=top_k,
output_fields=["path"],
output_fields=output_fields,
)
formatted: List[List[Dict[str, Any]]] = []
for hits in raw_results:
bucket: List[Dict[str, Any]] = []
for hit in hits:
if hasattr(hit, "entity"):
entity = dict(getattr(hit, "entity", {}) or {})
hit_id = getattr(hit, "id", None)
distance = getattr(hit, "distance", None)
elif isinstance(hit, dict):
entity = dict((hit.get("entity") or {}))
hit_id = hit.get("id")
distance = hit.get("distance")
else:
entity = {}
hit_id = None
distance = None
entity.setdefault("path", entity.get("source_path"))
bucket.append({
"id": hit_id,
"distance": distance,
"entity": entity,
})
formatted.append(bucket)
return formatted
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
filter_expr = f"path like '%{query_path}%'" if query_path else "path like '%%'"
if query_path:
escaped = query_path.replace('"', '\\"')
filter_expr = f'source_path like "%{escaped}%"'
else:
filter_expr = "source_path like '%%'"
results = self._get_client().query(
collection_name,
filter=filter_expr,
limit=top_k,
output_fields=["path"],
output_fields=[
"path",
"source_path",
"chunk_id",
"mime",
"text",
"start_offset",
"end_offset",
"type",
"name",
],
)
return [[{"id": r["path"], "distance": 1.0, "entity": {"path": r["path"]}} for r in results]]
formatted = []
for row in results:
entity = dict(row)
entity.setdefault("path", entity.get("source_path"))
formatted.append({
"id": entity.get("path"),
"distance": 1.0,
"entity": entity,
})
return [formatted]
def get_all_stats(self) -> Dict[str, Any]:
client = self._get_client()

View File

@@ -58,15 +58,19 @@ class MilvusServerProvider(BaseVectorProvider):
client = self._get_client()
if client.has_collection(collection_name):
return
common_fields = [
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
FieldSchema(name="source_path", dtype=DataType.VARCHAR, max_length=512, is_primary=False, auto_id=False),
]
if vector:
vector_dim = dim if isinstance(dim, int) and dim > 0 else 0
if vector_dim <= 0:
vector_dim = 4096
fields = [
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
*common_fields,
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=vector_dim),
]
schema = CollectionSchema(fields, description="Image vector collection")
schema = CollectionSchema(fields, description="Vector collection", enable_dynamic_field=True)
client.create_collection(collection_name, schema=schema)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
@@ -78,38 +82,98 @@ class MilvusServerProvider(BaseVectorProvider):
)
client.create_index(collection_name, index_params=index_params)
else:
fields = [
FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=512, is_primary=True, auto_id=False),
]
schema = CollectionSchema(fields, description="Simple file index")
schema = CollectionSchema(common_fields, description="Simple file index", enable_dynamic_field=True)
client.create_collection(collection_name, schema=schema)
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
self._get_client().upsert(collection_name, data)
payload = dict(data)
payload.setdefault("source_path", payload.get("path"))
payload.setdefault("vector_id", payload.get("path"))
self._get_client().upsert(collection_name, data=[payload])
def delete_vector(self, collection_name: str, path: str) -> None:
self._get_client().delete(collection_name, ids=[path])
client = self._get_client()
escaped = path.replace('"', '\\"')
client.delete(collection_name, filter=f'source_path == "{escaped}"')
def search_vectors(self, collection_name: str, query_embedding, top_k: int):
search_params = {"metric_type": "COSINE"}
return self._get_client().search(
output_fields = [
"path",
"source_path",
"chunk_id",
"mime",
"text",
"start_offset",
"end_offset",
"type",
"name",
]
raw_results = self._get_client().search(
collection_name,
data=[query_embedding],
anns_field="embedding",
search_params=search_params,
limit=top_k,
output_fields=["path"],
output_fields=output_fields,
)
formatted: List[List[Dict[str, Any]]] = []
for hits in raw_results:
bucket: List[Dict[str, Any]] = []
for hit in hits:
if hasattr(hit, "entity"):
entity = dict(getattr(hit, "entity", {}) or {})
hit_id = getattr(hit, "id", None)
distance = getattr(hit, "distance", None)
elif isinstance(hit, dict):
entity = dict((hit.get("entity") or {}))
hit_id = hit.get("id")
distance = hit.get("distance")
else:
entity = {}
hit_id = None
distance = None
entity.setdefault("path", entity.get("source_path"))
bucket.append({
"id": hit_id,
"distance": distance,
"entity": entity,
})
formatted.append(bucket)
return formatted
def search_by_path(self, collection_name: str, query_path: str, top_k: int):
filter_expr = f"path like '%{query_path}%'" if query_path else "path like '%%'"
if query_path:
escaped = query_path.replace('"', '\\"')
filter_expr = f'source_path like "%{escaped}%"'
else:
filter_expr = "source_path like '%%'"
results = self._get_client().query(
collection_name,
filter=filter_expr,
limit=top_k,
output_fields=["path"],
output_fields=[
"path",
"source_path",
"chunk_id",
"mime",
"text",
"start_offset",
"end_offset",
"type",
"name",
],
)
return [[{"id": r["path"], "distance": 1.0, "entity": {"path": r["path"]}} for r in results]]
formatted = []
for row in results:
entity = dict(row)
entity.setdefault("path", entity.get("source_path"))
formatted.append({
"id": entity.get("path"),
"distance": 1.0,
"entity": entity,
})
return [formatted]
def get_all_stats(self) -> Dict[str, Any]:
client = self._get_client()

View File

@@ -58,29 +58,59 @@ class QdrantProvider(BaseVectorProvider):
size = dim if vector and isinstance(dim, int) and dim > 0 else 1
return qmodels.VectorParams(size=size, distance=qmodels.Distance.COSINE)
def _ensure_payload_indexes(self, client: QdrantClient, collection_name: str) -> None:
for field in ("path", "source_path"):
try:
client.create_payload_index(
collection_name=collection_name,
field_name=field,
field_schema="keyword",
)
except Exception as exc: # pragma: no cover - 依赖外部服务
message = str(exc).lower()
if "already exists" in message or "index exists" in message:
continue
# 旧版本 qdrant 可能返回带状态码的异常,这里容忍重复创建
raise
def ensure_collection(self, collection_name: str, vector: bool, dim: int) -> None:
client = self._get_client()
try:
if client.collection_exists(collection_name):
return
exists = client.collection_exists(collection_name)
except Exception as exc: # pragma: no cover - 依赖外部服务
raise RuntimeError(f"Failed to check Qdrant collection '{collection_name}': {exc}") from exc
if exists:
try:
self._ensure_payload_indexes(client, collection_name)
except Exception:
pass
return
vectors_config = self._vector_params(vector, dim)
try:
client.create_collection(collection_name=collection_name, vectors_config=vectors_config)
except Exception as exc: # pragma: no cover
if "already exists" in str(exc).lower():
try:
self._ensure_payload_indexes(client, collection_name)
except Exception:
pass
return
raise RuntimeError(f"Failed to create Qdrant collection '{collection_name}': {exc}") from exc
try:
self._ensure_payload_indexes(client, collection_name)
except Exception:
pass
@staticmethod
def _point_id(path: str) -> str:
return str(uuid5(NAMESPACE_URL, path))
def _point_id(uid: str) -> str:
return str(uuid5(NAMESPACE_URL, uid))
def _prepare_point(self, data: Dict[str, Any]) -> qmodels.PointStruct:
path = data.get("path")
if not path:
uid = data.get("path")
if not uid:
raise ValueError("Qdrant upsert requires 'path' in data")
embedding = data.get("embedding")
@@ -89,8 +119,11 @@ class QdrantProvider(BaseVectorProvider):
else:
vector = [float(x) for x in embedding]
payload = {"path": path}
return qmodels.PointStruct(id=self._point_id(path), vector=vector, payload=payload)
payload = {k: v for k, v in data.items() if k != "embedding"}
payload.setdefault("vector_id", uid)
source_path = payload.get("source_path") or payload.get("path")
payload["path"] = source_path
return qmodels.PointStruct(id=self._point_id(str(uid)), vector=vector, payload=payload)
def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
client = self._get_client()
@@ -99,7 +132,12 @@ class QdrantProvider(BaseVectorProvider):
def delete_vector(self, collection_name: str, path: str) -> None:
client = self._get_client()
selector = qmodels.PointIdsList(points=[self._point_id(path)])
condition = qmodels.FieldCondition(
key="path",
match=qmodels.MatchValue(value=path),
)
flt = qmodels.Filter(must=[condition])
selector = qmodels.FilterSelector(filter=flt)
client.delete(collection_name=collection_name, points_selector=selector, wait=True)
def _format_search_results(self, points: Sequence[qmodels.ScoredPoint]):
@@ -107,7 +145,7 @@ class QdrantProvider(BaseVectorProvider):
{
"id": point.id,
"distance": point.score,
"entity": {"path": (point.payload or {}).get("path")},
"entity": point.payload or {},
}
for point in points
]
@@ -141,11 +179,11 @@ class QdrantProvider(BaseVectorProvider):
break
for record in records:
path = (record.payload or {}).get("path")
if query_path and path:
if query_path not in path:
continue
results.append({"id": record.id, "distance": 1.0, "entity": {"path": path}})
payload = record.payload or {}
path = payload.get("path")
if query_path and path and query_path not in path:
continue
results.append({"id": record.id, "distance": 1.0, "entity": payload})
if len(results) >= top_k:
break