mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-06-01 13:40:22 +08:00
feat: add AI Settings tab for managing providers and models
This commit is contained in:
310
services/ai.py
310
services/ai.py
@@ -1,113 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
from typing import List
|
||||
from services.config import ConfigCenter
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from models.database import AIModel, AIProvider
|
||||
from services.ai_providers import AIProviderService
|
||||
|
||||
|
||||
provider_service = AIProviderService()
|
||||
|
||||
|
||||
class MissingModelError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
async def describe_image_base64(base64_image: str, detail: str = "high") -> str:
|
||||
"""
|
||||
传入base64图片和文本提示,返回图片描述文本。
|
||||
传入 base64 图片并返回描述文本。缺省时返回错误提示。
|
||||
"""
|
||||
OAI_API_URL = await ConfigCenter.get("AI_VISION_API_URL")
|
||||
VISION_MODEL = await ConfigCenter.get("AI_VISION_MODEL")
|
||||
API_KEY = await ConfigCenter.get("AI_VISION_API_KEY")
|
||||
payload = {
|
||||
"model": VISION_MODEL,
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": detail
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "描述这个图片"
|
||||
}
|
||||
]}
|
||||
]
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(OAI_API_URL, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
model, provider = await _require_model("vision")
|
||||
if provider.api_format == "openai":
|
||||
return await _describe_with_openai(provider, model, base64_image, detail)
|
||||
return await _describe_with_gemini(provider, model, base64_image, detail)
|
||||
except MissingModelError as exc:
|
||||
return str(exc)
|
||||
except httpx.ReadTimeout:
|
||||
return "请求超时,请稍后重试。"
|
||||
except Exception as e:
|
||||
return f"请求失败: {str(e)}"
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return f"请求失败: {exc}"
|
||||
|
||||
|
||||
async def get_text_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
传入文本,返回嵌入向量。
|
||||
传入文本,返回嵌入向量。若未配置模型则抛出异常。
|
||||
"""
|
||||
OAI_API_URL = await ConfigCenter.get("AI_EMBED_API_URL")
|
||||
EMBED_MODEL = await ConfigCenter.get("AI_EMBED_MODEL")
|
||||
API_KEY = await ConfigCenter.get("AI_EMBED_API_KEY")
|
||||
payload = {
|
||||
"model": EMBED_MODEL,
|
||||
"input": text
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
if OAI_API_URL.endswith("chat/completions"):
|
||||
url = OAI_API_URL.replace("chat/completions", "embeddings")
|
||||
else:
|
||||
url = OAI_API_URL
|
||||
resp = await client.post(url, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
return result["data"][0]["embedding"]
|
||||
model, provider = await _require_model("embedding")
|
||||
if provider.api_format == "openai":
|
||||
return await _embedding_with_openai(provider, model, text)
|
||||
return await _embedding_with_gemini(provider, model, text)
|
||||
|
||||
|
||||
async def rerank_texts(query: str, documents: List[str]) -> List[float]:
|
||||
async def rerank_texts(query: str, documents: Sequence[str]) -> List[float]:
|
||||
"""调用重排序模型,为一组文档返回得分。未配置时返回空列表。"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
api_url = await ConfigCenter.get("AI_RERANK_API_URL")
|
||||
model = await ConfigCenter.get("AI_RERANK_MODEL")
|
||||
api_key = await ConfigCenter.get("AI_RERANK_API_KEY")
|
||||
|
||||
if not api_url or not model or not api_key:
|
||||
try:
|
||||
model, provider = await _require_model("rerank")
|
||||
except MissingModelError:
|
||||
return []
|
||||
|
||||
try:
|
||||
if provider.api_format == "openai":
|
||||
return await _rerank_with_openai(provider, model, query, documents)
|
||||
return await _rerank_with_gemini(provider, model, query, documents)
|
||||
except Exception: # noqa: BLE001
|
||||
return []
|
||||
|
||||
|
||||
async def _require_model(ability: str) -> Tuple[AIModel, AIProvider]:
|
||||
model = await provider_service.get_default_model(ability)
|
||||
if not model:
|
||||
raise MissingModelError(f"未配置默认 {ability} 模型,请前往系统设置完成配置。")
|
||||
provider = getattr(model, "provider", None)
|
||||
if provider is None:
|
||||
await model.fetch_related("provider")
|
||||
provider = model.provider
|
||||
if provider is None:
|
||||
raise MissingModelError("模型缺少关联的提供商配置。")
|
||||
if not provider.base_url:
|
||||
raise MissingModelError("该提供商未设置 API 地址。")
|
||||
return model, provider
|
||||
|
||||
|
||||
def _openai_endpoint(provider: AIProvider, path: str) -> str:
|
||||
base = (provider.base_url or "").rstrip("/")
|
||||
if not base:
|
||||
raise MissingModelError("提供商 API 地址未配置。")
|
||||
return f"{base}/{path.lstrip('/')}"
|
||||
|
||||
|
||||
def _openai_headers(provider: AIProvider) -> dict:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if provider.api_key:
|
||||
headers["Authorization"] = f"Bearer {provider.api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
def _gemini_endpoint(provider: AIProvider, path: str) -> str:
|
||||
base = (provider.base_url or "").rstrip("/")
|
||||
if not base:
|
||||
raise MissingModelError("提供商 API 地址未配置。")
|
||||
url = f"{base}/{path.lstrip('/')}"
|
||||
if provider.api_key:
|
||||
connector = "&" if "?" in url else "?"
|
||||
url = f"{url}{connector}key={provider.api_key}"
|
||||
return url
|
||||
|
||||
|
||||
async def _describe_with_openai(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str:
|
||||
url = _openai_endpoint(provider, "/chat/completions")
|
||||
payload = {
|
||||
"model": model,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"model": model.name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"detail": detail,
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "描述这个图片"},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(url, headers=_openai_headers(provider), json=payload)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
return body["choices"][0]["message"]["content"]
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
||||
async def _describe_with_gemini(provider: AIProvider, model: AIModel, base64_image: str, detail: str) -> str:
|
||||
detail_text = f"描述这个图片,细节等级:{detail}"
|
||||
model_name = model.name if model.name.startswith("models/") else f"models/{model.name}"
|
||||
url = _gemini_endpoint(provider, f"{model_name}:generateContent")
|
||||
payload = {
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": base64_image,
|
||||
}
|
||||
},
|
||||
{"text": detail_text},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
candidates = body.get("candidates") or []
|
||||
if not candidates:
|
||||
return ""
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
text_parts = [part.get("text") for part in parts if isinstance(part, dict) and part.get("text")]
|
||||
return "\n".join(text_parts)
|
||||
|
||||
|
||||
async def _embedding_with_openai(provider: AIProvider, model: AIModel, text: str) -> List[float]:
|
||||
url = _openai_endpoint(provider, "/embeddings")
|
||||
payload = {
|
||||
"model": model.name,
|
||||
"input": text,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(url, headers=_openai_headers(provider), json=payload)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
return body["data"][0]["embedding"]
|
||||
|
||||
|
||||
async def _embedding_with_gemini(provider: AIProvider, model: AIModel, text: str) -> List[float]:
|
||||
model_name = model.name if model.name.startswith("models/") else f"models/{model.name}"
|
||||
url = _gemini_endpoint(provider, f"{model_name}:embedContent")
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"content": {
|
||||
"parts": [{"text": text}],
|
||||
},
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
embedding = body.get("embedding") or {}
|
||||
return embedding.get("values") or []
|
||||
|
||||
|
||||
async def _rerank_with_openai(
|
||||
provider: AIProvider,
|
||||
model: AIModel,
|
||||
query: str,
|
||||
documents: Sequence[str],
|
||||
) -> List[float]:
|
||||
url = _openai_endpoint(provider, "/rerank")
|
||||
payload = {
|
||||
"model": model.name,
|
||||
"query": query,
|
||||
"documents": [
|
||||
{"id": str(idx), "text": content}
|
||||
for idx, content in enumerate(documents)
|
||||
],
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(url, headers=_openai_headers(provider), json=payload)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
results = body.get("results") or body.get("data") or []
|
||||
scores: List[float] = []
|
||||
for item in results:
|
||||
try:
|
||||
scores.append(float(item.get("score", 0.0)))
|
||||
except (TypeError, ValueError):
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
|
||||
async def _rerank_with_gemini(
|
||||
provider: AIProvider,
|
||||
model: AIModel,
|
||||
query: str,
|
||||
documents: Sequence[str],
|
||||
) -> List[float]:
|
||||
model_name = model.name if model.name.startswith("models/") else f"models/{model.name}"
|
||||
url = _gemini_endpoint(provider, f"{model_name}:rankContent")
|
||||
payload = {
|
||||
"query": {"text": query},
|
||||
"documents": [
|
||||
{"id": str(idx), "content": {"parts": [{"text": content}]}}
|
||||
for idx, content in enumerate(documents)
|
||||
],
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
|
||||
scores: List[float] = []
|
||||
ranked = body.get("rankedDocuments") or body.get("results") or []
|
||||
for item in ranked:
|
||||
raw_score = item.get("relevanceScore") or item.get("score") or item.get("confidenceScore")
|
||||
try:
|
||||
resp = await client.post(api_url, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError:
|
||||
return []
|
||||
data = resp.json()
|
||||
if isinstance(data, dict):
|
||||
results = data.get("results")
|
||||
if isinstance(results, list):
|
||||
scores = []
|
||||
for item in results:
|
||||
if isinstance(item, dict) and "score" in item:
|
||||
try:
|
||||
scores.append(float(item["score"]))
|
||||
except (TypeError, ValueError):
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
return []
|
||||
scores.append(float(raw_score))
|
||||
except (TypeError, ValueError):
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
347
services/ai_providers.py
Normal file
347
services/ai_providers.py
Normal file
@@ -0,0 +1,347 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from tortoise.exceptions import DoesNotExist
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
from models.database import AIDefaultModel, AIModel, AIProvider
|
||||
|
||||
|
||||
ABILITIES = ["chat", "vision", "embedding", "rerank", "voice", "tools"]
|
||||
|
||||
OPENAI_EMBEDDING_DIMS = {
|
||||
"text-embedding-3-large": 3072,
|
||||
"text-embedding-3-small": 1536,
|
||||
"text-embedding-ada-002": 1536,
|
||||
}
|
||||
|
||||
|
||||
def _normalize_embedding_dim(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
casted = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return casted if casted > 0 else None
|
||||
|
||||
|
||||
def _apply_embedding_dim_to_metadata(
|
||||
data: Dict[str, Any],
|
||||
embedding_dim: Optional[int],
|
||||
base_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
source = base_metadata if isinstance(base_metadata, dict) else {}
|
||||
metadata: Dict[str, Any] = dict(source)
|
||||
override = data.get("metadata")
|
||||
if isinstance(override, dict) and override:
|
||||
metadata.update(override)
|
||||
if embedding_dim is None:
|
||||
metadata.pop("embedding_dimensions", None)
|
||||
else:
|
||||
metadata["embedding_dimensions"] = embedding_dim
|
||||
data["metadata"] = metadata or None
|
||||
return data
|
||||
|
||||
|
||||
def normalize_capabilities(items: Optional[Iterable[str]]) -> List[str]:
|
||||
if not items:
|
||||
return []
|
||||
normalized = []
|
||||
for cap in items:
|
||||
key = str(cap).strip().lower()
|
||||
if key in ABILITIES and key not in normalized:
|
||||
normalized.append(key)
|
||||
return normalized
|
||||
|
||||
|
||||
def infer_openai_capabilities(model_id: str) -> Tuple[List[str], Optional[int]]:
|
||||
lower = model_id.lower()
|
||||
caps = set()
|
||||
|
||||
if any(keyword in lower for keyword in ["gpt", "chat", "turbo", "o1", "sonnet", "haiku", "thinking"]):
|
||||
caps.update({"chat", "tools"})
|
||||
|
||||
if any(keyword in lower for keyword in ["vision", "gpt-4o", "gpt-4.1", "o1", "vision-preview", "omni"]):
|
||||
caps.add("vision")
|
||||
|
||||
if any(keyword in lower for keyword in ["embed", "embedding"]):
|
||||
caps.add("embedding")
|
||||
|
||||
if "rerank" in lower or "re-rank" in lower:
|
||||
caps.add("rerank")
|
||||
|
||||
if any(keyword in lower for keyword in ["tts", "speech", "audio"]):
|
||||
caps.add("voice")
|
||||
|
||||
embedding_dim = OPENAI_EMBEDDING_DIMS.get(model_id)
|
||||
return normalize_capabilities(caps), embedding_dim
|
||||
|
||||
|
||||
def infer_gemini_capabilities(methods: Iterable[str]) -> List[str]:
|
||||
caps = set()
|
||||
for method in methods:
|
||||
m = method.lower()
|
||||
if m in {"generatecontent", "counttokens"}:
|
||||
caps.update({"chat", "tools", "vision"})
|
||||
if m == "embedcontent":
|
||||
caps.add("embedding")
|
||||
if m in {"generatespeech", "audiogeneration"}:
|
||||
caps.add("voice")
|
||||
if m == "rerank":
|
||||
caps.add("rerank")
|
||||
return normalize_capabilities(caps)
|
||||
|
||||
|
||||
def serialize_provider(provider: AIProvider) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"identifier": provider.identifier,
|
||||
"provider_type": provider.provider_type,
|
||||
"api_format": provider.api_format,
|
||||
"base_url": provider.base_url,
|
||||
"api_key": provider.api_key,
|
||||
"logo_url": provider.logo_url,
|
||||
"extra_config": provider.extra_config or {},
|
||||
"created_at": provider.created_at,
|
||||
"updated_at": provider.updated_at,
|
||||
}
|
||||
|
||||
|
||||
def model_to_dict(model: AIModel, provider: Optional[AIProvider] = None) -> Dict[str, Any]:
|
||||
provider_obj = provider or getattr(model, "provider", None)
|
||||
provider_data = serialize_provider(provider_obj) if provider_obj else None
|
||||
return {
|
||||
"id": model.id,
|
||||
"provider_id": model.provider_id,
|
||||
"name": model.name,
|
||||
"display_name": model.display_name,
|
||||
"description": model.description,
|
||||
"capabilities": normalize_capabilities(model.capabilities),
|
||||
"context_window": model.context_window,
|
||||
"embedding_dimensions": model.embedding_dimensions,
|
||||
"metadata": model.metadata or {},
|
||||
"created_at": model.created_at,
|
||||
"updated_at": model.updated_at,
|
||||
"provider": provider_data,
|
||||
}
|
||||
|
||||
|
||||
def provider_to_dict(provider: AIProvider, models: Optional[List[AIModel]] = None) -> Dict[str, Any]:
|
||||
data = serialize_provider(provider)
|
||||
if models is not None:
|
||||
data["models"] = [model_to_dict(m, provider=provider) for m in models]
|
||||
return data
|
||||
|
||||
|
||||
class AIProviderService:
|
||||
async def list_providers(self) -> List[Dict[str, Any]]:
|
||||
providers = await AIProvider.all().order_by("id").prefetch_related("models")
|
||||
return [provider_to_dict(p, models=list(p.models)) for p in providers]
|
||||
|
||||
async def get_provider(self, provider_id: int, with_models: bool = False) -> Dict[str, Any]:
|
||||
if with_models:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
models = await provider.models.all()
|
||||
return provider_to_dict(provider, models=models)
|
||||
else:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
async def create_provider(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data.setdefault("extra_config", {})
|
||||
provider = await AIProvider.create(**data)
|
||||
return provider_to_dict(provider)
|
||||
|
||||
async def update_provider(self, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
for field, value in payload.items():
|
||||
setattr(provider, field, value)
|
||||
await provider.save()
|
||||
return provider_to_dict(provider)
|
||||
|
||||
async def delete_provider(self, provider_id: int) -> None:
|
||||
await AIProvider.filter(id=provider_id).delete()
|
||||
|
||||
async def list_models(self, provider_id: int) -> List[Dict[str, Any]]:
|
||||
models = await AIModel.filter(provider_id=provider_id).order_by("id").prefetch_related("provider")
|
||||
return [model_to_dict(m) for m in models]
|
||||
|
||||
async def create_model(self, provider_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
data = payload.copy()
|
||||
data["provider_id"] = provider_id
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
|
||||
data = _apply_embedding_dim_to_metadata(data, embedding_dim)
|
||||
model = await AIModel.create(**data)
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
async def update_model(self, model_id: int, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model = await AIModel.get(id=model_id)
|
||||
data = payload.copy()
|
||||
if "capabilities" in data:
|
||||
data["capabilities"] = normalize_capabilities(data.get("capabilities"))
|
||||
embedding_dim = None
|
||||
if "embedding_dimensions" in data:
|
||||
embedding_dim = _normalize_embedding_dim(data.pop("embedding_dimensions", None))
|
||||
_apply_embedding_dim_to_metadata(data, embedding_dim, base_metadata=model.metadata)
|
||||
for field, value in data.items():
|
||||
setattr(model, field, value)
|
||||
if embedding_dim is not None or ("embedding_dimensions" in payload and embedding_dim is None):
|
||||
model.embedding_dimensions = embedding_dim
|
||||
await model.save()
|
||||
await model.fetch_related("provider")
|
||||
return model_to_dict(model)
|
||||
|
||||
async def delete_model(self, model_id: int) -> None:
|
||||
await AIModel.filter(id=model_id).delete()
|
||||
|
||||
async def fetch_remote_models(self, provider_id: int) -> List[Dict[str, Any]]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
return await self._get_remote_models(provider)
|
||||
|
||||
async def _get_remote_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
if not provider.base_url:
|
||||
raise ValueError("Provider base_url is required for syncing models")
|
||||
|
||||
fmt = (provider.api_format or "").lower()
|
||||
if fmt not in {"openai", "gemini"}:
|
||||
raise ValueError(f"Unsupported api_format '{provider.api_format}' for syncing models")
|
||||
|
||||
if fmt == "openai":
|
||||
return await self._fetch_openai_models(provider)
|
||||
return await self._fetch_gemini_models(provider)
|
||||
|
||||
async def sync_models(self, provider_id: int) -> Dict[str, int]:
|
||||
provider = await AIProvider.get(id=provider_id)
|
||||
remote_models = await self._get_remote_models(provider)
|
||||
|
||||
created = 0
|
||||
updated = 0
|
||||
for entry in remote_models:
|
||||
defaults = entry.copy()
|
||||
model_id = defaults.pop("name")
|
||||
defaults["capabilities"] = normalize_capabilities(defaults.get("capabilities"))
|
||||
embedding_dim = _normalize_embedding_dim(defaults.pop("embedding_dimensions", None))
|
||||
defaults = _apply_embedding_dim_to_metadata(defaults, embedding_dim)
|
||||
obj, is_created = await AIModel.get_or_create(
|
||||
provider_id=provider.id,
|
||||
name=model_id,
|
||||
defaults=defaults,
|
||||
)
|
||||
if is_created:
|
||||
created += 1
|
||||
continue
|
||||
for field, value in defaults.items():
|
||||
setattr(obj, field, value)
|
||||
if embedding_dim is not None or ("embedding_dimensions" in entry and embedding_dim is None):
|
||||
obj.embedding_dimensions = embedding_dim
|
||||
await obj.save()
|
||||
updated += 1
|
||||
|
||||
return {"created": created, "updated": updated}
|
||||
|
||||
async def get_default_models(self) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
defaults = await AIDefaultModel.all().prefetch_related("model__provider")
|
||||
result: Dict[str, Optional[Dict[str, Any]]] = {ability: None for ability in ABILITIES}
|
||||
for item in defaults:
|
||||
result[item.ability] = model_to_dict(item.model, provider=item.model.provider) # type: ignore[attr-defined]
|
||||
return result
|
||||
|
||||
async def set_default_models(self, mapping: Dict[str, Optional[int]]) -> Dict[str, Optional[Dict[str, Any]]]:
|
||||
normalized = {ability: mapping.get(ability) for ability in ABILITIES}
|
||||
async with in_transaction() as connection:
|
||||
for ability, model_id in normalized.items():
|
||||
record = await AIDefaultModel.get_or_none(ability=ability)
|
||||
if model_id:
|
||||
try:
|
||||
model = await AIModel.get(id=model_id)
|
||||
except DoesNotExist:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
if record:
|
||||
record.model_id = model_id
|
||||
await record.save(using_db=connection)
|
||||
else:
|
||||
await AIDefaultModel.create(ability=ability, model_id=model_id)
|
||||
elif record:
|
||||
await record.delete(using_db=connection)
|
||||
return await self.get_default_models()
|
||||
|
||||
async def get_default_model(self, ability: str) -> Optional[AIModel]:
|
||||
ability_key = ability.lower()
|
||||
if ability_key not in ABILITIES:
|
||||
return None
|
||||
record = await AIDefaultModel.get_or_none(ability=ability_key)
|
||||
if not record:
|
||||
return None
|
||||
model = await AIModel.get_or_none(id=record.model_id)
|
||||
if model:
|
||||
await model.fetch_related("provider")
|
||||
return model
|
||||
|
||||
async def _fetch_openai_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
url = f"{base_url}/models"
|
||||
headers = {}
|
||||
if provider.api_key:
|
||||
headers["Authorization"] = f"Bearer {provider.api_key}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
data = payload.get("data", [])
|
||||
entries: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
model_id = item.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
capabilities, embedding_dim = infer_openai_capabilities(model_id)
|
||||
entries.append({
|
||||
"name": model_id,
|
||||
"display_name": item.get("display_name"),
|
||||
"description": item.get("description"),
|
||||
"capabilities": capabilities,
|
||||
"context_window": item.get("context_window"),
|
||||
"embedding_dimensions": embedding_dim,
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
|
||||
async def _fetch_gemini_models(self, provider: AIProvider) -> List[Dict[str, Any]]:
|
||||
base_url = provider.base_url.rstrip("/")
|
||||
suffix = "/models"
|
||||
if provider.api_key:
|
||||
suffix += f"?key={provider.api_key}"
|
||||
url = f"{base_url}{suffix}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
data = payload.get("models", [])
|
||||
entries: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
model_id = item.get("name")
|
||||
if not model_id:
|
||||
continue
|
||||
methods = item.get("supportedGenerationMethods") or []
|
||||
capabilities = infer_gemini_capabilities(methods)
|
||||
entries.append({
|
||||
"name": model_id,
|
||||
"display_name": item.get("displayName"),
|
||||
"description": item.get("description"),
|
||||
"capabilities": capabilities,
|
||||
"context_window": item.get("inputTokenLimit"),
|
||||
"embedding_dimensions": item.get("embeddingDimensions"),
|
||||
"metadata": item,
|
||||
})
|
||||
return entries
|
||||
@@ -5,15 +5,11 @@ import mimetypes
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
from services.ai import describe_image_base64, get_text_embedding
|
||||
from services.ai import describe_image_base64, get_text_embedding, provider_service
|
||||
from services.vector_db import VectorDBService, DEFAULT_VECTOR_DIMENSION
|
||||
from services.logging import LogService
|
||||
from services.config import ConfigCenter
|
||||
from PIL import Image
|
||||
|
||||
try: # Pillow is optional but bundled with the project dependencies
|
||||
from PIL import Image
|
||||
except ImportError: # pragma: no cover - fallback when pillow missing
|
||||
Image = None
|
||||
|
||||
|
||||
CHUNK_SIZE = 800
|
||||
@@ -150,13 +146,15 @@ class VectorIndexProcessor:
|
||||
file_ext = path.split('.')[-1].lower()
|
||||
details: Dict[str, Any] = {"path": path, "action": "create", "index_type": "vector"}
|
||||
|
||||
raw_dim = await ConfigCenter.get('AI_EMBED_DIM', DEFAULT_VECTOR_DIMENSION)
|
||||
try:
|
||||
vector_dim = int(raw_dim)
|
||||
except (TypeError, ValueError):
|
||||
vector_dim = DEFAULT_VECTOR_DIMENSION
|
||||
if vector_dim <= 0:
|
||||
vector_dim = DEFAULT_VECTOR_DIMENSION
|
||||
embedding_model = await provider_service.get_default_model("embedding")
|
||||
vector_dim = DEFAULT_VECTOR_DIMENSION
|
||||
if embedding_model and getattr(embedding_model, "embedding_dimensions", None):
|
||||
try:
|
||||
vector_dim = int(embedding_model.embedding_dimensions)
|
||||
except (TypeError, ValueError):
|
||||
vector_dim = DEFAULT_VECTOR_DIMENSION
|
||||
if vector_dim <= 0:
|
||||
vector_dim = DEFAULT_VECTOR_DIMENSION
|
||||
|
||||
await vector_db.ensure_collection(collection_name, vector=True, dim=vector_dim)
|
||||
await vector_db.delete_vector(collection_name, path)
|
||||
|
||||
Reference in New Issue
Block a user