mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-10 17:43:35 +08:00
feat: enhance vector database configuration handling and improve provider initialization
This commit is contained in:
@@ -267,19 +267,24 @@ async def get_vector_db_config(request: Request, user: User = Depends(get_curren
|
||||
async def update_vector_db_config(
|
||||
request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
|
||||
):
|
||||
entry = get_provider_entry(payload.type)
|
||||
provider_type = str(payload.type or "").strip()
|
||||
if not provider_type:
|
||||
raise HTTPException(status_code=400, detail="向量数据库类型不能为空")
|
||||
normalized_config = VectorDBConfigManager.normalize_config(payload.config)
|
||||
|
||||
entry = get_provider_entry(provider_type)
|
||||
if not entry:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
|
||||
status_code=400, detail=f"未知的向量数据库类型: {provider_type}")
|
||||
if not entry.get("enabled", True):
|
||||
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
|
||||
|
||||
provider_cls = get_provider_class(payload.type)
|
||||
provider_cls = get_provider_class(provider_type)
|
||||
if not provider_cls:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
|
||||
status_code=400, detail=f"未找到类型 {provider_type} 对应的实现")
|
||||
|
||||
test_provider = provider_cls(payload.config)
|
||||
test_provider = provider_cls(normalized_config)
|
||||
try:
|
||||
await test_provider.initialize()
|
||||
except Exception as exc:
|
||||
@@ -293,7 +298,7 @@ async def update_vector_db_config(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await VectorDBConfigManager.save_config(payload.type, payload.config)
|
||||
await VectorDBConfigManager.save_config(provider_type, normalized_config)
|
||||
service = VectorDBService()
|
||||
await service.reload()
|
||||
config_data = await service.current_provider()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar
|
||||
|
||||
import httpx
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
@@ -28,16 +28,37 @@ OPENAI_EMBEDDING_DIMS = {
|
||||
"text-embedding-ada-002": 1536,
|
||||
}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class VectorDBConfigManager:
|
||||
TYPE_KEY = "VECTOR_DB_TYPE"
|
||||
CONFIG_KEY = "VECTOR_DB_CONFIG"
|
||||
DEFAULT_TYPE = "milvus_lite"
|
||||
|
||||
@classmethod
def normalize_type(cls, provider_type: Any) -> str:
    """Return *provider_type* as a stripped string, with a default fallback.

    Args:
        provider_type: Raw provider identifier; any value is accepted and
            converted with ``str()``.

    Returns:
        The stripped string form of *provider_type*, or ``cls.DEFAULT_TYPE``
        when the input is falsy or consists only of whitespace.
    """
    normalized = str(provider_type or cls.DEFAULT_TYPE).strip()
    # A whitespace-only value is truthy before stripping but empty after it.
    return normalized or cls.DEFAULT_TYPE
|
||||
|
||||
@classmethod
|
||||
def normalize_config(cls, config: Dict[str, Any] | None) -> Dict[str, Any]:
|
||||
normalized: Dict[str, Any] = {}
|
||||
for key, value in (config or {}).items():
|
||||
normalized_key = str(key).strip()
|
||||
if not normalized_key:
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
continue
|
||||
normalized[normalized_key] = value
|
||||
return normalized
|
||||
|
||||
@classmethod
|
||||
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
|
||||
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
|
||||
provider_type = str(raw_type or cls.DEFAULT_TYPE)
|
||||
provider_type = cls.normalize_type(raw_type)
|
||||
|
||||
raw_config = await ConfigService.get(cls.CONFIG_KEY)
|
||||
config_dict: Dict[str, Any] = {}
|
||||
@@ -48,12 +69,14 @@ class VectorDBConfigManager:
|
||||
config_dict = {}
|
||||
elif isinstance(raw_config, dict):
|
||||
config_dict = raw_config
|
||||
return provider_type, config_dict
|
||||
return provider_type, cls.normalize_config(config_dict)
|
||||
|
||||
@classmethod
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
    """Persist the provider type and its configuration in normalized form.

    The type is passed through ``cls.normalize_type`` and the configuration
    through ``cls.normalize_config`` before being written, so only the
    normalized values are stored. The configuration is stored as a JSON
    string under ``cls.CONFIG_KEY``; the type is stored under
    ``cls.TYPE_KEY``.

    Args:
        provider_type: Provider identifier to store.
        config: Provider configuration mapping to store.
    """
    normalized_type = cls.normalize_type(provider_type)
    normalized_config = cls.normalize_config(config)
    await ConfigService.set(cls.TYPE_KEY, normalized_type)
    await ConfigService.set(cls.CONFIG_KEY, json.dumps(normalized_config))
|
||||
|
||||
@classmethod
|
||||
async def get_type(cls) -> str:
|
||||
@@ -413,6 +436,7 @@ class VectorDBService:
|
||||
self._provider_type: Optional[str] = None
|
||||
self._provider_config: Dict[str, Any] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._operation_lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_provider(self) -> BaseVectorProvider:
|
||||
if self._provider is None:
|
||||
@@ -449,33 +473,38 @@ class VectorDBService:
|
||||
self._provider_config = normalized_config
|
||||
return provider
|
||||
|
||||
async def _run_provider_call(self, provider: BaseVectorProvider, method_name: str, *args, **kwargs) -> Any:
    """Run a provider method in a worker thread while holding the operation lock.

    Args:
        provider: Provider instance whose method is invoked.
        method_name: Name of the attribute looked up on *provider*.
        *args: Positional arguments forwarded to the method.
        **kwargs: Keyword arguments forwarded to the method.

    Returns:
        Whatever the provider method returns.

    Raises:
        AttributeError: If *provider* has no attribute named *method_name*.
    """
    # Resolve the method before taking the lock so a bad name fails fast.
    method = getattr(provider, method_name)
    # Calls made through this helper run one at a time per service instance,
    # and the call itself is moved off the event loop thread.
    async with self._operation_lock:
        return await asyncio.to_thread(method, *args, **kwargs)
|
||||
|
||||
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
    """Ask the active provider to ensure a collection exists.

    The provider's ``ensure_collection`` is invoked exactly once, through
    ``self._run_provider_call``, rather than directly on the event loop.

    Args:
        collection_name: Name of the collection.
        vector: Forwarded to the provider unchanged.
        dim: Forwarded to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "ensure_collection", collection_name, vector, dim)
|
||||
|
||||
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
    """Forward an upsert to the active provider.

    The provider's ``upsert_vector`` is invoked exactly once, through
    ``self._run_provider_call``, rather than directly on the event loop.

    Args:
        collection_name: Name of the target collection.
        data: Record passed to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "upsert_vector", collection_name, data)
|
||||
|
||||
async def delete_vector(self, collection_name: str, path: str) -> None:
    """Forward a delete-by-path request to the active provider.

    The provider's ``delete_vector`` is invoked exactly once, through
    ``self._run_provider_call``, rather than directly on the event loop.

    Args:
        collection_name: Name of the target collection.
        path: Passed to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "delete_vector", collection_name, path)
|
||||
|
||||
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
    """Run a vector search on the active provider and return its result.

    The provider's ``search_vectors`` is invoked through
    ``self._run_provider_call`` rather than directly on the event loop.

    Args:
        collection_name: Name of the collection to search.
        query_embedding: Passed to the provider unchanged.
        top_k: Passed to the provider unchanged.

    Returns:
        Whatever the provider's ``search_vectors`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "search_vectors", collection_name, query_embedding, top_k)
|
||||
|
||||
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
    """Run a path-based search on the active provider and return its result.

    The provider's ``search_by_path`` is invoked through
    ``self._run_provider_call`` rather than directly on the event loop.

    Args:
        collection_name: Name of the collection to search.
        query_path: Passed to the provider unchanged.
        top_k: Passed to the provider unchanged.

    Returns:
        Whatever the provider's ``search_by_path`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "search_by_path", collection_name, query_path, top_k)
|
||||
|
||||
async def get_all_stats(self) -> Dict[str, Any]:
    """Return the statistics reported by the active provider.

    The provider's ``get_all_stats`` is invoked through
    ``self._run_provider_call`` rather than directly on the event loop.

    Returns:
        Whatever the provider's ``get_all_stats`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "get_all_stats")
|
||||
|
||||
async def clear_all_data(self) -> None:
    """Ask the active provider to clear all of its data.

    The provider's ``clear_all_data`` is invoked exactly once, through
    ``self._run_provider_call``, rather than directly on the event loop.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "clear_all_data")
|
||||
|
||||
async def current_provider(self) -> Dict[str, Any]:
|
||||
provider_type, provider_config = await VectorDBConfigManager.load_config()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -23,12 +24,14 @@ class MilvusLiteProvider(BaseVectorProvider):
|
||||
|
||||
def __init__(self, config: Dict[str, Any] | None = None):
    """Store the configuration and resolve the database file path.

    Args:
        config: Provider configuration; the ``"db_path"`` entry is read.
            A missing, ``None``, empty or whitespace-only ``db_path`` falls
            back to ``"data/db/milvus.db"``.
    """
    # presumably the base initializer populates ``self.config`` - verify
    # against BaseVectorProvider.
    super().__init__(config)
    raw_db_path = self.config.get("db_path")
    # Strip before the fallback so a blank path does not become Path("  ").
    db_path = str(raw_db_path).strip() if raw_db_path is not None else ""
    self.db_path = Path(db_path or "data/db/milvus.db")
    # No client yet; ``initialize`` assigns it.
    self.client: MilvusClient | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.client = MilvusClient(str(self.db_path))
|
||||
self.client = await asyncio.to_thread(MilvusClient, str(self.db_path))
|
||||
except Exception as exc: # pragma: no cover - depends on local environment
|
||||
raise RuntimeError(f"Failed to open Milvus Lite at {self.db_path}: {exc}") from exc
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
||||
@@ -32,11 +33,14 @@ class MilvusServerProvider(BaseVectorProvider):
|
||||
self.client: MilvusClient | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
uri = self.config.get("uri")
|
||||
uri = str(self.config.get("uri") or "").strip()
|
||||
if not uri:
|
||||
raise RuntimeError("Milvus Server URI is required")
|
||||
token = self.config.get("token")
|
||||
if isinstance(token, str):
|
||||
token = token.strip() or None
|
||||
try:
|
||||
self.client = MilvusClient(uri=uri, token=self.config.get("token"))
|
||||
self.client = await asyncio.to_thread(MilvusClient, uri=uri, token=token)
|
||||
except Exception as exc: # pragma: no cover - depends on remote availability
|
||||
raise RuntimeError(f"Failed to connect to Milvus Server {uri}: {exc}") from exc
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from uuid import NAMESPACE_URL, uuid5
|
||||
|
||||
@@ -40,7 +41,7 @@ class QdrantProvider(BaseVectorProvider):
|
||||
api_key = (self.config.get("api_key") or None) or None
|
||||
try:
|
||||
client = QdrantClient(url=url, api_key=api_key)
|
||||
client.get_collections()
|
||||
await asyncio.to_thread(client.get_collections)
|
||||
self.client = client
|
||||
except Exception as exc: # pragma: no cover - 依赖外部服务
|
||||
raise RuntimeError(f"Failed to connect to Qdrant at {url}: {exc}") from exc
|
||||
|
||||
Reference in New Issue
Block a user