Files
Foxel/domain/ai/service.py
shiyu 6b2ada0b42 refactor: imports and reorganize domain structure
- Updated import statements across multiple modules to use relative imports for better encapsulation.
- Consolidated and organized the `__init__.py` files in various domain packages to expose necessary classes and functions.
- Improved code readability and maintainability by grouping related imports and removing unused ones.
- Ensured consistent import patterns across the domain, enhancing the overall structure of the codebase.
2026-01-09 17:28:10 +08:00

502 lines
19 KiB
Python

import asyncio
import json
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple
import httpx
from tortoise.exceptions import DoesNotExist
from tortoise.transactions import in_transaction
from domain.config import ConfigService
from models.database import AIDefaultModel, AIModel, AIProvider
from .types import ABILITIES, normalize_capabilities
from .vector_providers import (
BaseVectorProvider,
get_provider_class,
get_provider_entry,
list_providers,
)
# Fallback vector width; used as the default `dim` argument of
# VectorDBService.ensure_collection further down in this module.
DEFAULT_VECTOR_DIMENSION = 4096
# Collection names. They are not referenced in the visible part of this
# module -- presumably consumed by callers elsewhere (TODO confirm).
VECTOR_COLLECTION_NAME = "vector_collection"
FILE_COLLECTION_NAME = "file_collection"
# Embedding sizes keyed by exact model id; looked up by
# infer_openai_capabilities to pre-fill a model's embedding dimension.
OPENAI_EMBEDDING_DIMS = {
"text-embedding-3-large": 3072,
"text-embedding-3-small": 1536,
"text-embedding-ada-002": 1536,
}
class VectorDBConfigManager:
    """Load and persist the vector-database provider selection via ConfigService.

    Two configuration entries are used: ``TYPE_KEY`` holds the provider type
    name and ``CONFIG_KEY`` holds the provider options serialized as JSON.
    """

    TYPE_KEY = "VECTOR_DB_TYPE"
    CONFIG_KEY = "VECTOR_DB_CONFIG"
    DEFAULT_TYPE = "milvus_lite"

    @staticmethod
    def _parse_config(raw_config: Any) -> Dict[str, Any]:
        """Coerce a stored configuration value into a dictionary.

        A ``dict`` is returned as-is. A non-empty string is decoded as JSON and
        accepted only when it decodes to a JSON object. Everything else
        (invalid JSON, JSON arrays/scalars/null, empty strings, other types)
        yields an empty dictionary so callers can always rely on the declared
        ``Dict[str, Any]`` return type.
        """
        if isinstance(raw_config, dict):
            return raw_config
        if isinstance(raw_config, str) and raw_config:
            try:
                parsed = json.loads(raw_config)
            except json.JSONDecodeError:
                return {}
            # Valid JSON is not necessarily an object (e.g. "[1, 2]" or "null").
            if isinstance(parsed, dict):
                return parsed
        return {}

    @classmethod
    async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
        """Return ``(provider_type, provider_config)`` from the config store.

        The type falls back to ``DEFAULT_TYPE`` when the stored value is
        missing or falsy; the config falls back to ``{}`` when it is missing
        or cannot be interpreted as a JSON object.
        """
        raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
        provider_type = str(raw_type or cls.DEFAULT_TYPE)
        raw_config = await ConfigService.get(cls.CONFIG_KEY)
        return provider_type, cls._parse_config(raw_config)

    @classmethod
    async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
        """Persist the provider type and its options (JSON-encoded)."""
        await ConfigService.set(cls.TYPE_KEY, provider_type)
        await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))

    @classmethod
    async def get_type(cls) -> str:
        """Return only the configured provider type."""
        provider_type, _ = await cls.load_config()
        return provider_type

    @classmethod
    async def get_config(cls) -> Dict[str, Any]:
        """Return only the configured provider options."""
        _, config = await cls.load_config()
        return config
def _normalize_embedding_dim(value: Any) -> Optional[int]:
    """Convert *value* to a positive integer dimension, or ``None``.

    ``None``, values that cannot be converted with ``int()`` and results that
    are zero or negative all map to ``None``.
    """
    if value is None:
        return None
    try:
        casted = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) -- non-finite floats are not valid sizes.
        return None
    return casted if casted > 0 else None
def _apply_embedding_dim_to_metadata(
    data: Dict[str, Any],
    embedding_dim: Optional[int],
    base_metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Rebuild ``data["metadata"]`` so it reflects *embedding_dim*.

    The new metadata starts as a copy of *base_metadata* (when it is a dict),
    is overlaid with ``data["metadata"]`` (when that is a dict), and then has
    its ``"embedding_dimensions"`` entry set to *embedding_dim* or removed when
    *embedding_dim* is ``None``. An empty result is stored as ``None``.

    *data* is modified in place and also returned; *base_metadata* is not
    modified.
    """
    merged: Dict[str, Any] = {}
    if isinstance(base_metadata, dict):
        merged.update(base_metadata)

    incoming = data.get("metadata")
    if isinstance(incoming, dict):
        merged.update(incoming)

    if embedding_dim is not None:
        merged["embedding_dimensions"] = embedding_dim
    else:
        merged.pop("embedding_dimensions", None)

    data["metadata"] = merged if merged else None
    return data
def infer_openai_capabilities(model_id: str) -> Tuple[List[str], Optional[int]]:
    """Guess capabilities and embedding size from an OpenAI-style model id.

    Capabilities are inferred from keywords found in the lower-cased id. The
    embedding dimension is looked up in ``OPENAI_EMBEDDING_DIMS`` using the id
    exactly as given and is ``None`` for unknown ids.

    Returns:
        ``(capabilities, embedding_dimension)`` where capabilities have been
        passed through ``normalize_capabilities``.
    """
    needle = model_id.lower()

    def mentions(*keywords: str) -> bool:
        # True when at least one keyword occurs anywhere in the model id.
        return any(keyword in needle for keyword in keywords)

    detected = set()
    if mentions("gpt", "chat", "turbo", "o1", "sonnet", "haiku", "thinking"):
        detected |= {"chat", "tools"}
    if mentions("vision", "gpt-4o", "gpt-4.1", "o1", "vision-preview", "omni"):
        detected.add("vision")
    if mentions("embed", "embedding"):
        detected.add("embedding")
    if mentions("rerank", "re-rank"):
        detected.add("rerank")
    if mentions("tts", "speech", "audio"):
        detected.add("voice")

    known_dim = OPENAI_EMBEDDING_DIMS.get(model_id)
    return normalize_capabilities(detected), known_dim
def infer_gemini_capabilities(methods: Iterable[str]) -> List[str]:
    """Map a Gemini model's supported method names to capability labels.

    Method names are compared case-insensitively. The collected labels are
    returned through ``normalize_capabilities``.
    """
    seen = {method.lower() for method in methods}

    detected = set()
    if seen & {"generatecontent", "counttokens"}:
        detected.update(("chat", "tools", "vision"))
    if "embedcontent" in seen:
        detected.add("embedding")
    if seen & {"generatespeech", "audiogeneration"}:
        detected.add("voice")
    if "rerank" in seen:
        detected.add("rerank")
    return normalize_capabilities(detected)
def serialize_provider(provider: AIProvider) -> Dict[str, Any]:
    """Return the provider's fields as a plain dictionary.

    ``extra_config`` is replaced by an empty dict when it is falsy; all other
    attributes, including ``api_key``, are copied unchanged.
    """
    plain_fields = (
        "id",
        "name",
        "identifier",
        "provider_type",
        "api_format",
        "base_url",
        "api_key",
        "logo_url",
    )
    payload: Dict[str, Any] = {field: getattr(provider, field) for field in plain_fields}
    payload["extra_config"] = provider.extra_config or {}
    payload["created_at"] = provider.created_at
    payload["updated_at"] = provider.updated_at
    return payload
def model_to_dict(model: AIModel, provider: Optional[AIProvider] = None) -> Dict[str, Any]:
    """Return the model's fields as a plain dictionary.

    The ``"provider"`` entry is the serialized *provider* argument when given,
    otherwise the serialized ``model.provider`` attribute when that is present
    and truthy, otherwise ``None``. Capabilities are normalized and a falsy
    ``metadata`` becomes an empty dict.
    """
    owner = provider or getattr(model, "provider", None)
    owner_payload = serialize_provider(owner) if owner else None

    payload: Dict[str, Any] = {}
    payload["id"] = model.id
    payload["provider_id"] = model.provider_id
    payload["name"] = model.name
    payload["display_name"] = model.display_name
    payload["description"] = model.description
    payload["capabilities"] = normalize_capabilities(model.capabilities)
    payload["context_window"] = model.context_window
    payload["embedding_dimensions"] = model.embedding_dimensions
    payload["metadata"] = model.metadata or {}
    payload["created_at"] = model.created_at
    payload["updated_at"] = model.updated_at
    payload["provider"] = owner_payload
    return payload
def provider_to_dict(provider: AIProvider, models: Optional[List[AIModel]] = None) -> Dict[str, Any]:
    """Serialize a provider, optionally embedding its models.

    A ``"models"`` list is added only when *models* is not ``None`` (an empty
    list still produces an empty ``"models"`` entry).
    """
    payload = serialize_provider(provider)
    if models is None:
        return payload
    payload["models"] = [model_to_dict(item, provider=provider) for item in models]
    return payload
class AIProviderService:
    """Service layer for AI providers, their models and per-ability defaults.

    Every method is a classmethod. Persistence goes through the ``AIProvider``,
    ``AIModel`` and ``AIDefaultModel`` ORM models; results are returned as
    plain dictionaries unless noted otherwise.
    """

    @classmethod
    async def list_providers(cls) -> List[Dict[str, Any]]:
        """Return all providers ordered by id, each with its models embedded."""
        providers = await AIProvider.all().order_by("id").prefetch_related("models")
        return [provider_to_dict(p, models=list(p.models)) for p in providers]

    @classmethod
    async def get_provider(cls, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
        """Return one provider, including its models when *with_models* is true."""
        provider = await AIProvider.get(id=provider_id)
        if not with_models:
            return provider_to_dict(provider)
        models = await provider.models.all()
        return provider_to_dict(provider, models=models)

    @classmethod
    async def create_provider(cls, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Create a provider from *payload*; ``extra_config`` defaults to ``{}``."""
        data = payload.copy()
        data.setdefault("extra_config", {})
        provider = await AIProvider.create(**data)
        return provider_to_dict(provider)

    @classmethod
    async def update_provider(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Assign every key of *payload* onto the provider and save it."""
        provider = await AIProvider.get(id=provider_id)
        for field, value in payload.items():
            setattr(provider, field, value)
        await provider.save()
        return provider_to_dict(provider)

    @classmethod
    async def delete_provider(cls, provider_id: int) -> None:
        """Issue a delete for the provider row(s) matching *provider_id*."""
        await AIProvider.filter(id=provider_id).delete()

    @classmethod
    async def list_models(cls, provider_id: int) -> List[Dict[str, Any]]:
        """Return the models of one provider ordered by id."""
        models = await AIModel.filter(provider_id=provider_id).order_by("id").prefetch_related("provider")
        return [model_to_dict(m) for m in models]

    @classmethod
    async def create_model(cls, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Create a model under *provider_id*.

        Capabilities are normalized and ``embedding_dimensions`` is moved from
        the payload into the model's metadata before the row is created.
        """
        data = payload.copy()
        data["provider_id"] = provider_id
        data["capabilities"] = normalize_capabilities(data.get("capabilities"))
        embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
        data = _apply_embedding_dim_to_metadata(data, embedding_dim)
        model = await AIModel.create(**data)
        await model.fetch_related("provider")
        return model_to_dict(model)

    @classmethod
    async def update_model(cls, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Apply a partial update to a model.

        Only keys present in *payload* are changed. In particular, when
        ``embedding_dimensions`` is absent from the payload the dimension
        already recorded in the (merged) metadata is carried over instead of
        being erased; pass ``embedding_dimensions=None`` explicitly to clear it.
        """
        model = await AIModel.get(id=model_id)
        data = payload.copy()
        if "capabilities" in data:
            data["capabilities"] = normalize_capabilities(data.get("capabilities"))

        dim_supplied = "embedding_dimensions" in data
        if dim_supplied:
            embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions"))
        else:
            # Keep whatever dimension the stored/incoming metadata already holds.
            known: Dict[str, Any] = {}
            for source in (model.metadata, data.get("metadata")):
                if isinstance(source, dict):
                    known.update(source)
            embedding_dim = _normalize_embedding_dim(known.get("embedding_dimensions"))

        # Leave metadata untouched unless the payload actually addresses it.
        if dim_supplied or "metadata" in data:
            _apply_embedding_dim_to_metadata(data, embedding_dim, base_metadata=model.metadata)

        for field, value in data.items():
            setattr(model, field, value)
        if dim_supplied:
            model.embedding_dimensions = embedding_dim
        await model.save()
        await model.fetch_related("provider")
        return model_to_dict(model)

    @classmethod
    async def delete_model(cls, model_id: int) -> None:
        """Issue a delete for the model row(s) matching *model_id*."""
        await AIModel.filter(id=model_id).delete()

    @classmethod
    async def fetch_remote_models(cls, provider_id: int) -> List[Dict[str, Any]]:
        """Return the model list advertised by the provider's remote API."""
        provider = await AIProvider.get(id=provider_id)
        return await cls._get_remote_models(provider)

    @classmethod
    async def _get_remote_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
        """Dispatch to the fetcher matching ``provider.api_format``.

        Raises:
            ValueError: if the provider has no ``base_url`` or its
                ``api_format`` is neither ``openai`` nor ``gemini``.
        """
        if not provider.base_url:
            raise ValueError("Provider base_url is required for syncing models")
        fmt = (provider.api_format or "").lower()
        if fmt not in {"openai", "gemini"}:
            raise ValueError(f"Unsupported api_format '{provider.api_format}' for syncing models")
        if fmt == "openai":
            return await cls._fetch_openai_models(provider)
        return await cls._fetch_gemini_models(provider)

    @classmethod
    async def sync_models(cls, provider_id: int) -> Dict[str, int]:
        """Create or overwrite local models from the provider's remote list.

        Models are matched on ``(provider_id, name)``. Existing rows have every
        remote-derived field overwritten.

        Returns:
            ``{"created": <count>, "updated": <count>}``.
        """
        provider = await AIProvider.get(id=provider_id)
        remote_models = await cls._get_remote_models(provider)
        created = 0
        updated = 0
        for entry in remote_models:
            defaults = entry.copy()
            model_id = defaults.pop("name")
            defaults["capabilities"] = normalize_capabilities(defaults.get("capabilities"))
            embedding_dim = _normalize_embedding_dim(defaults.pop("embedding_dimensions", None))
            defaults = _apply_embedding_dim_to_metadata(defaults, embedding_dim)
            obj, is_created = await AIModel.get_or_create(
                provider_id=provider.id,
                name=model_id,
                defaults=defaults,
            )
            if is_created:
                created += 1
                continue
            for field, value in defaults.items():
                setattr(obj, field, value)
            if embedding_dim is not None or "embedding_dimensions" in entry:
                obj.embedding_dimensions = embedding_dim
            await obj.save()
            updated += 1
        return {"created": created, "updated": updated}

    @classmethod
    async def get_default_models(cls) -> Dict[str, Optional[Dict[str, Any]]]:
        """Return the default model per ability; abilities without one map to ``None``."""
        defaults = await AIDefaultModel.all().prefetch_related("model__provider")
        result: Dict[str, Optional[Dict[str, Any]]] = {ability: None for ability in ABILITIES}
        for item in defaults:
            result[item.ability] = model_to_dict(item.model, provider=item.model.provider)  # type: ignore[attr-defined]
        return result

    @classmethod
    async def set_default_models(cls, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
        """Replace the default-model assignments with *mapping*.

        Every ability in ``ABILITIES`` is processed: a truthy model id is
        stored, while a missing or falsy id removes any existing default. All
        statements run on the same transaction connection.

        Raises:
            ValueError: if a referenced model id does not exist.
        """
        normalized = {ability: mapping.get(ability) for ability in ABILITIES}
        async with in_transaction() as connection:
            for ability, model_id in normalized.items():
                record = await AIDefaultModel.get_or_none(ability=ability, using_db=connection)
                if model_id:
                    try:
                        # Existence check only; the fetched row itself is not needed.
                        await AIModel.get(id=model_id, using_db=connection)
                    except DoesNotExist as exc:
                        raise ValueError(f"Model {model_id} not found") from exc
                    if record:
                        record.model_id = model_id
                        await record.save(using_db=connection)
                    else:
                        await AIDefaultModel.create(
                            ability=ability,
                            model_id=model_id,
                            using_db=connection,
                        )
                elif record:
                    await record.delete(using_db=connection)
        return await cls.get_default_models()

    @classmethod
    async def get_default_model(cls, ability: str) -> Optional[AIModel]:
        """Return the default ``AIModel`` for *ability* with its provider loaded.

        Returns ``None`` when the ability is unknown, has no default, or the
        referenced model no longer exists.
        """
        ability_key = ability.lower()
        if ability_key not in ABILITIES:
            return None
        record = await AIDefaultModel.get_or_none(ability=ability_key)
        if not record:
            return None
        model = await AIModel.get_or_none(id=record.model_id)
        if model:
            await model.fetch_related("provider")
        return model

    @classmethod
    async def _fetch_openai_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
        """List models from an OpenAI-compatible ``GET {base_url}/models`` endpoint.

        The API key, when set, is sent as a bearer token. Entries without an
        ``id`` are skipped; the raw item is kept as ``metadata``.
        """
        base_url = provider.base_url.rstrip("/")
        url = f"{base_url}/models"
        headers = {}
        if provider.api_key:
            headers["Authorization"] = f"Bearer {provider.api_key}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            payload = response.json()
        # Tolerate servers that answer with `"data": null`.
        data = payload.get("data") or []
        entries: List[Dict[str, Any]] = []
        for item in data:
            model_id = item.get("id")
            if not model_id:
                continue
            capabilities, embedding_dim = infer_openai_capabilities(model_id)
            entries.append({
                "name": model_id,
                "display_name": item.get("display_name"),
                "description": item.get("description"),
                "capabilities": capabilities,
                "context_window": item.get("context_window"),
                "embedding_dimensions": embedding_dim,
                "metadata": item,
            })
        return entries

    @classmethod
    async def _fetch_gemini_models(cls, provider: AIProvider) -> List[Dict[str, Any]]:
        """List models from a Gemini-style ``GET {base_url}/models`` endpoint.

        The API key, when set, is sent as the ``key`` query parameter (encoded
        by the HTTP client). Result pages are followed through
        ``nextPageToken`` up to a fixed page limit. Entries without a ``name``
        are skipped; the raw item is kept as ``metadata``.
        """
        base_url = provider.base_url.rstrip("/")
        url = f"{base_url}/models"
        params: Dict[str, Any] = {}
        if provider.api_key:
            params["key"] = provider.api_key

        max_pages = 100  # safety bound against a server that keeps returning tokens
        entries: List[Dict[str, Any]] = []
        async with httpx.AsyncClient(timeout=30.0) as client:
            for _ in range(max_pages):
                response = await client.get(url, params=params or None)
                response.raise_for_status()
                payload = response.json()
                for item in payload.get("models") or []:
                    model_id = item.get("name")
                    if not model_id:
                        continue
                    methods = item.get("supportedGenerationMethods") or []
                    capabilities = infer_gemini_capabilities(methods)
                    entries.append({
                        "name": model_id,
                        "display_name": item.get("displayName"),
                        "description": item.get("description"),
                        "capabilities": capabilities,
                        "context_window": item.get("inputTokenLimit"),
                        "embedding_dimensions": item.get("embeddingDimensions"),
                        "metadata": item,
                    })
                next_token = payload.get("nextPageToken")
                if not next_token:
                    break
                params["pageToken"] = next_token
        return entries
class VectorDBService:
    """Singleton facade that forwards vector operations to the configured provider.

    The provider instance is created on first use from the settings returned by
    ``VectorDBConfigManager.load_config`` and is rebuilt by ``reload`` when
    those settings differ from the ones it was created with.
    """

    _instance: Optional["VectorDBService"] = None

    def __new__(cls, *args, **kwargs):
        # Hand out the single shared instance, creating it on first call.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every VectorDBService() call; initialise state once.
        if hasattr(self, "_provider"):
            return
        self._provider: Optional[BaseVectorProvider] = None
        self._provider_type: Optional[str] = None
        self._provider_config: Dict[str, Any] | None = None
        self._lock = asyncio.Lock()

    async def _ensure_provider(self) -> BaseVectorProvider:
        """Return the active provider, loading it first when none exists yet."""
        if self._provider is None:
            await self.reload()
        active = self._provider
        assert active is not None
        return active

    async def reload(self) -> BaseVectorProvider:
        """Re-read the configuration and (re)build the provider if it changed.

        Runs under the instance lock. The current provider is returned
        unchanged when both provider type and configuration are equal to the
        ones it was built with.

        Raises:
            RuntimeError: if the provider type is unknown, disabled, or has no
                registered class.
        """
        async with self._lock:
            provider_type, provider_config = await VectorDBConfigManager.load_config()
            normalized_config = dict(provider_config or {})

            current = self._provider
            if current:
                same_type = self._provider_type == provider_type
                if same_type and self._provider_config == normalized_config:
                    return current

            entry = get_provider_entry(provider_type)
            if not entry:
                raise RuntimeError(f"Unknown vector database provider: {provider_type}")
            if not entry.get("enabled", True):
                raise RuntimeError(f"Vector database provider '{provider_type}' is disabled")

            provider_cls = get_provider_class(provider_type)
            if not provider_cls:
                raise RuntimeError(f"Provider class not found for '{provider_type}'")

            fresh = provider_cls(provider_config)
            await fresh.initialize()

            self._provider = fresh
            self._provider_type = provider_type
            self._provider_config = normalized_config
            return fresh

    async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
        """Forward to the provider's ``ensure_collection``."""
        backend = await self._ensure_provider()
        backend.ensure_collection(collection_name, vector, dim)

    async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
        """Forward to the provider's ``upsert_vector``."""
        backend = await self._ensure_provider()
        backend.upsert_vector(collection_name, data)

    async def delete_vector(self, collection_name: str, path: str) -> None:
        """Forward to the provider's ``delete_vector``."""
        backend = await self._ensure_provider()
        backend.delete_vector(collection_name, path)

    async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
        """Forward to the provider's ``search_vectors`` and return its result."""
        backend = await self._ensure_provider()
        return backend.search_vectors(collection_name, query_embedding, top_k)

    async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
        """Forward to the provider's ``search_by_path`` and return its result."""
        backend = await self._ensure_provider()
        return backend.search_by_path(collection_name, query_path, top_k)

    async def get_all_stats(self) -> Dict[str, Any]:
        """Forward to the provider's ``get_all_stats`` and return its result."""
        backend = await self._ensure_provider()
        return backend.get_all_stats()

    async def clear_all_data(self) -> None:
        """Forward to the provider's ``clear_all_data``."""
        backend = await self._ensure_provider()
        backend.clear_all_data()

    async def current_provider(self) -> Dict[str, Any]:
        """Describe the configured provider without instantiating it.

        Returns the stored type and config plus the registry entry's ``label``
        and ``enabled`` flag (``enabled`` defaults to ``True`` when the entry
        is missing or does not specify it).
        """
        provider_type, provider_config = await VectorDBConfigManager.load_config()
        registry_entry = get_provider_entry(provider_type) or {}
        summary: Dict[str, Any] = {
            "type": provider_type,
            "config": provider_config,
        }
        summary["label"] = registry_entry.get("label")
        summary["enabled"] = registry_entry.get("enabled", True)
        return summary
# Public API of this module. Besides the names defined here, it re-exports
# helpers imported from .vector_providers (list_providers, get_provider_entry,
# get_provider_class) and from .types (normalize_capabilities, ABILITIES).
__all__ = [
"AIProviderService",
"VectorDBService",
"VectorDBConfigManager",
"DEFAULT_VECTOR_DIMENSION",
"list_providers",
"get_provider_entry",
"get_provider_class",
"normalize_capabilities",
"ABILITIES",
]