feat: enhance vector database configuration handling and improve provider initialization

This commit is contained in:
shiyu
2026-04-10 19:40:41 +08:00
parent 0609cf6971
commit 398dbcf8ae
5 changed files with 65 additions and 23 deletions

View File

@@ -267,19 +267,24 @@ async def get_vector_db_config(request: Request, user: User = Depends(get_curren
async def update_vector_db_config(
request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
):
entry = get_provider_entry(payload.type)
provider_type = str(payload.type or "").strip()
if not provider_type:
raise HTTPException(status_code=400, detail="向量数据库类型不能为空")
normalized_config = VectorDBConfigManager.normalize_config(payload.config)
entry = get_provider_entry(provider_type)
if not entry:
raise HTTPException(
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
status_code=400, detail=f"未知的向量数据库类型: {provider_type}")
if not entry.get("enabled", True):
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
provider_cls = get_provider_class(payload.type)
provider_cls = get_provider_class(provider_type)
if not provider_cls:
raise HTTPException(
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
status_code=400, detail=f"未找到类型 {provider_type} 对应的实现")
test_provider = provider_cls(payload.config)
test_provider = provider_cls(normalized_config)
try:
await test_provider.initialize()
except Exception as exc:
@@ -293,7 +298,7 @@ async def update_vector_db_config(
except Exception:
pass
await VectorDBConfigManager.save_config(payload.type, payload.config)
await VectorDBConfigManager.save_config(provider_type, normalized_config)
service = VectorDBService()
await service.reload()
config_data = await service.current_provider()

View File

@@ -1,7 +1,7 @@
import asyncio
import json
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TypeVar
import httpx
from tortoise.exceptions import DoesNotExist
@@ -28,16 +28,37 @@ OPENAI_EMBEDDING_DIMS = {
"text-embedding-ada-002": 1536,
}
T = TypeVar("T")
class VectorDBConfigManager:
TYPE_KEY = "VECTOR_DB_TYPE"
CONFIG_KEY = "VECTOR_DB_CONFIG"
DEFAULT_TYPE = "milvus_lite"
@classmethod
def normalize_type(cls, provider_type: Any) -> str:
normalized = str(provider_type or cls.DEFAULT_TYPE).strip()
return normalized or cls.DEFAULT_TYPE
@classmethod
def normalize_config(cls, config: Dict[str, Any] | None) -> Dict[str, Any]:
normalized: Dict[str, Any] = {}
for key, value in (config or {}).items():
normalized_key = str(key).strip()
if not normalized_key:
continue
if isinstance(value, str):
value = value.strip()
if not value:
continue
normalized[normalized_key] = value
return normalized
@classmethod
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
provider_type = str(raw_type or cls.DEFAULT_TYPE)
provider_type = cls.normalize_type(raw_type)
raw_config = await ConfigService.get(cls.CONFIG_KEY)
config_dict: Dict[str, Any] = {}
@@ -48,12 +69,14 @@ class VectorDBConfigManager:
config_dict = {}
elif isinstance(raw_config, dict):
config_dict = raw_config
return provider_type, config_dict
return provider_type, cls.normalize_config(config_dict)
@classmethod
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
    """Persist the provider type and its configuration.

    Both values are normalized first (``normalize_type`` /
    ``normalize_config``) and each key is written exactly once, so the
    stored data never contains the raw, un-normalized input.

    Args:
        provider_type: Provider identifier to store under ``TYPE_KEY``.
        config: Provider configuration; stored under ``CONFIG_KEY`` as a
            JSON string.
    """
    normalized_type = cls.normalize_type(provider_type)
    normalized_config = cls.normalize_config(config)
    await ConfigService.set(cls.TYPE_KEY, normalized_type)
    await ConfigService.set(cls.CONFIG_KEY, json.dumps(normalized_config))
@classmethod
async def get_type(cls) -> str:
@@ -413,6 +436,7 @@ class VectorDBService:
self._provider_type: Optional[str] = None
self._provider_config: Dict[str, Any] | None = None
self._lock = asyncio.Lock()
self._operation_lock = asyncio.Lock()
async def _ensure_provider(self) -> BaseVectorProvider:
if self._provider is None:
@@ -449,33 +473,38 @@ class VectorDBService:
self._provider_config = normalized_config
return provider
async def _run_provider_call(self, provider: BaseVectorProvider, method_name: str, *args, **kwargs) -> T:
    """Run ``provider.<method_name>(*args, **kwargs)`` in a worker thread.

    The attribute lookup happens before the lock is taken, so an unknown
    ``method_name`` raises ``AttributeError`` without waiting. The call
    itself runs via ``asyncio.to_thread`` while ``self._operation_lock`` is
    held, so calls issued through this helper are awaited one at a time.

    Returns:
        Whatever the provider method returns.
    """
    # NOTE(review): cancelling the awaiting task releases the lock, but
    # asyncio.to_thread does not stop the worker thread - confirm whether an
    # overlapping provider call is acceptable in that case.
    target = getattr(provider, method_name)
    async with self._operation_lock:
        outcome = await asyncio.to_thread(target, *args, **kwargs)
    return outcome
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
    """Forward an ``ensure_collection`` request to the active provider.

    The provider method is invoked exactly once, through
    ``_run_provider_call``, instead of also being called directly on the
    event loop.

    Args:
        collection_name: Passed through to the provider unchanged.
        vector: Passed through to the provider unchanged.
        dim: Passed through to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "ensure_collection", collection_name, vector, dim)
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
    """Forward an ``upsert_vector`` request to the active provider.

    The provider method is invoked exactly once, through
    ``_run_provider_call``, instead of also being called directly on the
    event loop.

    Args:
        collection_name: Passed through to the provider unchanged.
        data: Passed through to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "upsert_vector", collection_name, data)
async def delete_vector(self, collection_name: str, path: str) -> None:
    """Forward a ``delete_vector`` request to the active provider.

    The provider method is invoked exactly once, through
    ``_run_provider_call``, instead of also being called directly on the
    event loop.

    Args:
        collection_name: Passed through to the provider unchanged.
        path: Passed through to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "delete_vector", collection_name, path)
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
    """Forward a ``search_vectors`` request to the active provider.

    The call goes through ``_run_provider_call`` only; there is no direct
    provider call on the event loop ahead of it.

    Args:
        collection_name: Passed through to the provider unchanged.
        query_embedding: Passed through to the provider unchanged.
        top_k: Passed through to the provider unchanged.

    Returns:
        Whatever the provider's ``search_vectors`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "search_vectors", collection_name, query_embedding, top_k)
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
    """Forward a ``search_by_path`` request to the active provider.

    The call goes through ``_run_provider_call`` only; there is no direct
    provider call on the event loop ahead of it.

    Args:
        collection_name: Passed through to the provider unchanged.
        query_path: Passed through to the provider unchanged.
        top_k: Passed through to the provider unchanged.

    Returns:
        Whatever the provider's ``search_by_path`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "search_by_path", collection_name, query_path, top_k)
async def get_all_stats(self) -> Dict[str, Any]:
    """Forward a ``get_all_stats`` request to the active provider.

    The call goes through ``_run_provider_call`` only; there is no direct
    provider call on the event loop ahead of it.

    Returns:
        Whatever the provider's ``get_all_stats`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "get_all_stats")
async def clear_all_data(self) -> None:
    """Forward a ``clear_all_data`` request to the active provider.

    The provider method is invoked exactly once, through
    ``_run_provider_call``, instead of also being called directly on the
    event loop.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "clear_all_data")
async def current_provider(self) -> Dict[str, Any]:
provider_type, provider_config = await VectorDBConfigManager.load_config()

View File

@@ -1,3 +1,4 @@
import asyncio
from pathlib import Path
from typing import Any, Dict, List, Optional
@@ -23,12 +24,14 @@ class MilvusLiteProvider(BaseVectorProvider):
def __init__(self, config: Dict[str, Any] | None = None):
    """Store the configuration and resolve the Milvus Lite database path.

    ``db_path`` is read from ``self.config``, converted with ``str()`` and
    stripped; a missing or blank value falls back to
    ``"data/db/milvus.db"``. The path is computed once, from the normalized
    value only, so a non-string ``db_path`` no longer reaches ``Path()``
    unconverted.

    Args:
        config: Provider configuration, forwarded to the base class.
    """
    super().__init__(config)
    raw_db_path = self.config.get("db_path")
    db_path = str(raw_db_path).strip() if raw_db_path is not None else ""
    self.db_path = Path(db_path or "data/db/milvus.db")
    # Created lazily by initialize().
    self.client: MilvusClient | None = None
async def initialize(self) -> None:
    """Open the Milvus Lite database at ``self.db_path``.

    The client is constructed exactly once, in a worker thread via
    ``asyncio.to_thread``, so the constructor does not run on the event
    loop and no throwaway client is created first.

    Raises:
        RuntimeError: If constructing ``MilvusClient`` fails; the original
            exception is chained as the cause.
    """
    try:
        self.client = await asyncio.to_thread(MilvusClient, str(self.db_path))
    except Exception as exc:  # pragma: no cover - depends on local environment
        raise RuntimeError(f"Failed to open Milvus Lite at {self.db_path}: {exc}") from exc

View File

@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Dict, List, Optional
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
@@ -32,11 +33,14 @@ class MilvusServerProvider(BaseVectorProvider):
self.client: MilvusClient | None = None
async def initialize(self) -> None:
    """Connect to the Milvus server named in the configuration.

    ``uri`` is converted with ``str()`` and stripped; a missing or blank
    value is rejected before any connection attempt. A string ``token`` is
    stripped and becomes ``None`` when blank. The client is constructed
    exactly once, in a worker thread via ``asyncio.to_thread``, using the
    cleaned values.

    Raises:
        RuntimeError: If no URI is configured, or if constructing
            ``MilvusClient`` fails (original exception chained as cause).
    """
    uri = str(self.config.get("uri") or "").strip()
    if not uri:
        raise RuntimeError("Milvus Server URI is required")

    token = self.config.get("token")
    if isinstance(token, str):
        token = token.strip() or None

    try:
        self.client = await asyncio.to_thread(MilvusClient, uri=uri, token=token)
    except Exception as exc:  # pragma: no cover - depends on remote availability
        raise RuntimeError(f"Failed to connect to Milvus Server {uri}: {exc}") from exc

View File

@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Dict, List, Optional, Sequence
from uuid import NAMESPACE_URL, uuid5
@@ -40,7 +41,7 @@ class QdrantProvider(BaseVectorProvider):
api_key = (self.config.get("api_key") or None) or None
try:
client = QdrantClient(url=url, api_key=api_key)
client.get_collections()
await asyncio.to_thread(client.get_collections)
self.client = client
except Exception as exc: # pragma: no cover - 依赖外部服务
raise RuntimeError(f"Failed to connect to Qdrant at {url}: {exc}") from exc