diff --git a/domain/ai/api.py b/domain/ai/api.py
index 7da7231..bbd8af7 100644
--- a/domain/ai/api.py
+++ b/domain/ai/api.py
@@ -267,19 +267,24 @@ async def get_vector_db_config(request: Request, user: User = Depends(get_curren
 async def update_vector_db_config(
     request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
 ):
-    entry = get_provider_entry(payload.type)
+    provider_type = str(payload.type or "").strip()
+    if not provider_type:
+        raise HTTPException(status_code=400, detail="向量数据库类型不能为空")
+    normalized_config = VectorDBConfigManager.normalize_config(payload.config)
+
+    entry = get_provider_entry(provider_type)
     if not entry:
         raise HTTPException(
-            status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
+            status_code=400, detail=f"未知的向量数据库类型: {provider_type}")
     if not entry.get("enabled", True):
         raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
 
-    provider_cls = get_provider_class(payload.type)
+    provider_cls = get_provider_class(provider_type)
     if not provider_cls:
         raise HTTPException(
-            status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
+            status_code=400, detail=f"未找到类型 {provider_type} 对应的实现")
 
-    test_provider = provider_cls(payload.config)
+    test_provider = provider_cls(normalized_config)
     try:
         await test_provider.initialize()
     except Exception as exc:
@@ -293,7 +298,7 @@ async def update_vector_db_config(
             except Exception:
                 pass
 
-    await VectorDBConfigManager.save_config(payload.type, payload.config)
+    await VectorDBConfigManager.save_config(provider_type, normalized_config)
     service = VectorDBService()
     await service.reload()
     config_data = await service.current_provider()
diff --git a/domain/ai/service.py b/domain/ai/service.py
index b44d143..afb8d20 100644
--- a/domain/ai/service.py
+++ b/domain/ai/service.py
@@ -28,16 +28,37 @@ OPENAI_EMBEDDING_DIMS = {
     "text-embedding-ada-002": 1536,
 }
 
 
 class VectorDBConfigManager:
     TYPE_KEY = "VECTOR_DB_TYPE"
     CONFIG_KEY = "VECTOR_DB_CONFIG"
     DEFAULT_TYPE = "milvus_lite"
 
+    @classmethod
+    def normalize_type(cls, provider_type: Any) -> str:
+        """Return provider_type as a stripped string, or DEFAULT_TYPE when it is falsy or blank."""
+        normalized = str(provider_type or cls.DEFAULT_TYPE).strip()
+        return normalized or cls.DEFAULT_TYPE
+
+    @classmethod
+    def normalize_config(cls, config: Dict[str, Any] | None) -> Dict[str, Any]:
+        """Return a copy of config with stripped string keys/values; blank keys and blank string values are dropped."""
+        normalized: Dict[str, Any] = {}
+        for key, value in (config or {}).items():
+            normalized_key = str(key).strip()
+            if not normalized_key:
+                continue
+            if isinstance(value, str):
+                value = value.strip()
+                if not value:
+                    continue
+            normalized[normalized_key] = value
+        return normalized
+
     @classmethod
     async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
         raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
-        provider_type = str(raw_type or cls.DEFAULT_TYPE)
+        provider_type = cls.normalize_type(raw_type)
         raw_config = await ConfigService.get(cls.CONFIG_KEY)
         config_dict: Dict[str, Any] = {}
 
@@ -48,12 +69,14 @@ class VectorDBConfigManager:
                 config_dict = {}
         elif isinstance(raw_config, dict):
             config_dict = raw_config
-        return provider_type, config_dict
+        return provider_type, cls.normalize_config(config_dict)
 
     @classmethod
     async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
-        await ConfigService.set(cls.TYPE_KEY, provider_type)
-        await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))
+        normalized_type = cls.normalize_type(provider_type)
+        normalized_config = cls.normalize_config(config)
+        await ConfigService.set(cls.TYPE_KEY, normalized_type)
+        await ConfigService.set(cls.CONFIG_KEY, json.dumps(normalized_config))
 
     @classmethod
     async def get_type(cls) -> str:
@@ -413,6 +436,7 @@ class VectorDBService:
         self._provider_type: Optional[str] = None
         self._provider_config: Dict[str, Any] | None = None
         self._lock = asyncio.Lock()
+        self._operation_lock = asyncio.Lock()
 
     async def _ensure_provider(self) -> BaseVectorProvider:
         if self._provider is None:
@@ -449,33 +473,43 @@ class VectorDBService:
             self._provider_config = normalized_config
             return provider
 
+    async def _run_provider_call(self, provider: BaseVectorProvider, method_name: str, *args, **kwargs) -> Any:
+        """Run ``provider.<method_name>(*args, **kwargs)`` in a worker thread.
+
+        The call is made while holding ``self._operation_lock``, so calls
+        routed through this helper on the same lock do not overlap.
+        """
+        method = getattr(provider, method_name)
+        async with self._operation_lock:
+            return await asyncio.to_thread(method, *args, **kwargs)
+
     async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
         provider = await self._ensure_provider()
-        provider.ensure_collection(collection_name, vector, dim)
+        await self._run_provider_call(provider, "ensure_collection", collection_name, vector, dim)
 
     async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
         provider = await self._ensure_provider()
-        provider.upsert_vector(collection_name, data)
+        await self._run_provider_call(provider, "upsert_vector", collection_name, data)
 
     async def delete_vector(self, collection_name: str, path: str) -> None:
         provider = await self._ensure_provider()
-        provider.delete_vector(collection_name, path)
+        await self._run_provider_call(provider, "delete_vector", collection_name, path)
 
     async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
         provider = await self._ensure_provider()
-        return provider.search_vectors(collection_name, query_embedding, top_k)
+        return await self._run_provider_call(provider, "search_vectors", collection_name, query_embedding, top_k)
 
     async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
         provider = await self._ensure_provider()
-        return provider.search_by_path(collection_name, query_path, top_k)
+        return await self._run_provider_call(provider, "search_by_path", collection_name, query_path, top_k)
 
     async def get_all_stats(self) -> Dict[str, Any]:
         provider = await self._ensure_provider()
-        return provider.get_all_stats()
+        return await self._run_provider_call(provider, "get_all_stats")
 
     async def clear_all_data(self) -> None:
         provider = await self._ensure_provider()
-        provider.clear_all_data()
+        await self._run_provider_call(provider, "clear_all_data")
 
     async def current_provider(self) -> Dict[str, Any]:
         provider_type, provider_config = await VectorDBConfigManager.load_config()
diff --git a/domain/ai/vector_providers/milvus_lite.py b/domain/ai/vector_providers/milvus_lite.py
index fd71217..c358b90 100644
--- a/domain/ai/vector_providers/milvus_lite.py
+++ b/domain/ai/vector_providers/milvus_lite.py
@@ -1,3 +1,4 @@
+import asyncio
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
@@ -23,12 +24,14 @@ class MilvusLiteProvider(BaseVectorProvider):
 
     def __init__(self, config: Dict[str, Any] | None = None):
         super().__init__(config)
-        self.db_path = Path(self.config.get("db_path") or "data/db/milvus.db")
+        raw_db_path = self.config.get("db_path")
+        db_path = str(raw_db_path).strip() if raw_db_path is not None else ""
+        self.db_path = Path(db_path or "data/db/milvus.db")
         self.client: MilvusClient | None = None
 
     async def initialize(self) -> None:
         try:
-            self.client = MilvusClient(str(self.db_path))
+            self.client = await asyncio.to_thread(MilvusClient, str(self.db_path))
         except Exception as exc:  # pragma: no cover - depends on local environment
             raise RuntimeError(f"Failed to open Milvus Lite at {self.db_path}: {exc}") from exc
 
diff --git a/domain/ai/vector_providers/milvus_server.py b/domain/ai/vector_providers/milvus_server.py
index 73be064..06dd3a0 100644
--- a/domain/ai/vector_providers/milvus_server.py
+++ b/domain/ai/vector_providers/milvus_server.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any, Dict, List, Optional
 
 from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
@@ -32,11 +33,14 @@ class MilvusServerProvider(BaseVectorProvider):
         self.client: MilvusClient | None = None
 
     async def initialize(self) -> None:
-        uri = self.config.get("uri")
+        uri = str(self.config.get("uri") or "").strip()
         if not uri:
             raise RuntimeError("Milvus Server URI is required")
+        token = self.config.get("token")
+        if isinstance(token, str):
+            token = token.strip() or None
         try:
-            self.client = MilvusClient(uri=uri, token=self.config.get("token"))
+            self.client = await asyncio.to_thread(MilvusClient, uri=uri, token=token)
         except Exception as exc:  # pragma: no cover - depends on remote availability
             raise RuntimeError(f"Failed to connect to Milvus Server {uri}: {exc}") from exc
 
diff --git a/domain/ai/vector_providers/qdrant.py b/domain/ai/vector_providers/qdrant.py
index cb55abb..f18bb52 100644
--- a/domain/ai/vector_providers/qdrant.py
+++ b/domain/ai/vector_providers/qdrant.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any, Dict, List, Optional, Sequence
 from uuid import NAMESPACE_URL, uuid5
 
@@ -40,7 +41,7 @@ class QdrantProvider(BaseVectorProvider):
         api_key = (self.config.get("api_key") or None) or None
         try:
             client = QdrantClient(url=url, api_key=api_key)
-            client.get_collections()
+            await asyncio.to_thread(client.get_collections)
             self.client = client
         except Exception as exc:  # pragma: no cover - 依赖外部服务
             raise RuntimeError(f"Failed to connect to Qdrant at {url}: {exc}") from exc