mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-11 18:10:10 +08:00
- Updated import statements across multiple modules to use relative imports for better encapsulation. - Consolidated and organized the `__init__.py` files in various domain packages to expose necessary classes and functions. - Improved code readability and maintainability by grouping related imports and removing unused ones. - Ensured consistent import patterns across the domain, enhancing the overall structure of the codebase.
305 lines
11 KiB
Python
305 lines
11 KiB
Python
from typing import Annotated, Dict, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException, Path, Request
|
|
|
|
from api.response import success
|
|
from domain.audit import AuditAction, audit
|
|
from domain.auth import User, get_current_active_user
|
|
from .service import AIProviderService, VectorDBConfigManager, VectorDBService
|
|
from .types import (
|
|
AIDefaultsUpdate,
|
|
AIModelCreate,
|
|
AIModelUpdate,
|
|
AIProviderCreate,
|
|
AIProviderUpdate,
|
|
VectorDBConfigPayload,
|
|
)
|
|
from .vector_providers import get_provider_class, get_provider_entry, list_providers
|
|
|
|
router_ai = APIRouter(prefix="/api/ai", tags=["ai"])
|
|
router_vector_db = APIRouter(prefix="/api/vector-db", tags=["vector-db"])
|
|
|
|
|
|
@audit(action=AuditAction.READ, description="获取 AI 提供商列表")
|
|
@router_ai.get("/providers")
|
|
async def list_providers_endpoint(
|
|
request: Request,
|
|
current_user: Annotated[User, Depends(get_current_active_user)]
|
|
):
|
|
providers = await AIProviderService.list_providers()
|
|
return success({"providers": providers})
|
|
|
|
|
|
@audit(
|
|
action=AuditAction.CREATE,
|
|
description="创建 AI 提供商",
|
|
body_fields=["name", "identifier", "provider_type", "api_format", "base_url", "logo_url"],
|
|
redact_fields=["api_key"],
|
|
)
|
|
@router_ai.post("/providers")
|
|
async def create_provider(
|
|
request: Request,
|
|
payload: AIProviderCreate,
|
|
current_user: Annotated[User, Depends(get_current_active_user)]
|
|
):
|
|
provider = await AIProviderService.create_provider(payload.dict())
|
|
return success(provider)
|
|
|
|
|
|
@audit(action=AuditAction.READ, description="获取 AI 提供商详情")
|
|
@router_ai.get("/providers/{provider_id}")
|
|
async def get_provider(
|
|
request: Request,
|
|
provider_id: Annotated[int, Path(..., gt=0)],
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
provider = await AIProviderService.get_provider(provider_id, with_models=True)
|
|
return success(provider)
|
|
|
|
|
|
@audit(
|
|
action=AuditAction.UPDATE,
|
|
description="更新 AI 提供商",
|
|
body_fields=["name", "provider_type", "api_format", "base_url", "logo_url", "api_key"],
|
|
redact_fields=["api_key"],
|
|
)
|
|
@router_ai.put("/providers/{provider_id}")
|
|
async def update_provider(
|
|
request: Request,
|
|
provider_id: Annotated[int, Path(..., gt=0)],
|
|
payload: AIProviderUpdate,
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
data = {k: v for k, v in payload.dict().items() if v is not None}
|
|
if not data:
|
|
raise HTTPException(status_code=400, detail="No fields to update")
|
|
provider = await AIProviderService.update_provider(provider_id, data)
|
|
return success(provider)
|
|
|
|
|
|
@audit(action=AuditAction.DELETE, description="删除 AI 提供商")
|
|
@router_ai.delete("/providers/{provider_id}")
|
|
async def delete_provider(
|
|
request: Request,
|
|
provider_id: Annotated[int, Path(..., gt=0)],
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
await AIProviderService.delete_provider(provider_id)
|
|
return success({"id": provider_id})
|
|
|
|
|
|
@audit(action=AuditAction.UPDATE, description="同步模型列表")
|
|
@router_ai.post("/providers/{provider_id}/sync-models")
|
|
async def sync_models(
|
|
request: Request,
|
|
provider_id: Annotated[int, Path(..., gt=0)],
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
try:
|
|
result = await AIProviderService.sync_models(provider_id)
|
|
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
|
raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
|
|
except ValueError as exc:
|
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
return success(result)
|
|
|
|
|
|
@audit(action=AuditAction.READ, description="获取远程模型列表")
|
|
@router_ai.get("/providers/{provider_id}/remote-models")
|
|
async def fetch_remote_models(
|
|
request: Request,
|
|
provider_id: Annotated[int, Path(..., gt=0)],
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
try:
|
|
models = await AIProviderService.fetch_remote_models(provider_id)
|
|
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
|
raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
|
|
except ValueError as exc:
|
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
return success({"models": models})
|
|
|
|
|
|
@audit(action=AuditAction.READ, description="获取模型列表")
|
|
@router_ai.get("/providers/{provider_id}/models")
|
|
async def list_models(
|
|
request: Request,
|
|
provider_id: Annotated[int, Path(..., gt=0)],
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
models = await AIProviderService.list_models(provider_id)
|
|
return success({"models": models})
|
|
|
|
|
|
@audit(
|
|
action=AuditAction.CREATE,
|
|
description="创建模型",
|
|
body_fields=["name", "display_name", "capabilities", "context_window", "embedding_dimensions"],
|
|
)
|
|
@router_ai.post("/providers/{provider_id}/models")
|
|
async def create_model(
|
|
request: Request,
|
|
provider_id: Annotated[int, Path(..., gt=0)],
|
|
payload: AIModelCreate,
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
model = await AIProviderService.create_model(provider_id, payload.dict())
|
|
return success(model)
|
|
|
|
|
|
@audit(
|
|
action=AuditAction.UPDATE,
|
|
description="更新模型",
|
|
body_fields=["display_name", "description", "capabilities", "context_window", "embedding_dimensions"],
|
|
)
|
|
@router_ai.put("/models/{model_id}")
|
|
async def update_model(
|
|
request: Request,
|
|
model_id: Annotated[int, Path(..., gt=0)],
|
|
payload: AIModelUpdate,
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
data = {k: v for k, v in payload.dict().items() if v is not None}
|
|
if not data:
|
|
raise HTTPException(status_code=400, detail="No fields to update")
|
|
model = await AIProviderService.update_model(model_id, data)
|
|
return success(model)
|
|
|
|
|
|
@audit(action=AuditAction.DELETE, description="删除模型")
|
|
@router_ai.delete("/models/{model_id}")
|
|
async def delete_model(
|
|
request: Request,
|
|
model_id: Annotated[int, Path(..., gt=0)],
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
await AIProviderService.delete_model(model_id)
|
|
return success({"id": model_id})
|
|
|
|
|
|
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
|
|
if not entry:
|
|
return None
|
|
value = entry.get("embedding_dimensions")
|
|
return int(value) if value is not None else None
|
|
|
|
|
|
@audit(action=AuditAction.READ, description="获取默认模型")
|
|
@router_ai.get("/defaults")
|
|
async def get_defaults(
|
|
request: Request,
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
defaults = await AIProviderService.get_default_models()
|
|
return success(defaults)
|
|
|
|
|
|
@audit(
|
|
action=AuditAction.UPDATE,
|
|
description="更新默认模型",
|
|
body_fields=["chat", "vision", "embedding", "rerank", "voice", "tools"],
|
|
)
|
|
@router_ai.put("/defaults")
|
|
async def update_defaults(
|
|
request: Request,
|
|
payload: AIDefaultsUpdate,
|
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
|
):
|
|
previous = await AIProviderService.get_default_models()
|
|
try:
|
|
updated = await AIProviderService.set_default_models(payload.as_mapping())
|
|
except ValueError as exc:
|
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
prev_dim = _get_embedding_dimension(previous.get("embedding"))
|
|
next_dim = _get_embedding_dimension(updated.get("embedding"))
|
|
|
|
if prev_dim and next_dim and prev_dim != next_dim:
|
|
try:
|
|
await VectorDBService().clear_all_data()
|
|
except Exception as exc: # noqa: BLE001
|
|
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
|
|
|
|
return success(updated)
|
|
|
|
|
|
@audit(action=AuditAction.UPDATE, description="清空向量数据库")
|
|
@router_vector_db.post("/clear-all", summary="清空向量数据库")
|
|
async def clear_vector_db(request: Request, user: User = Depends(get_current_active_user)):
|
|
try:
|
|
service = VectorDBService()
|
|
await service.clear_all_data()
|
|
return success(msg="向量数据库已清空")
|
|
except Exception as e: # noqa: BLE001
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@audit(action=AuditAction.READ, description="获取向量数据库统计")
|
|
@router_vector_db.get("/stats", summary="获取向量数据库统计")
|
|
async def get_vector_db_stats(request: Request, user: User = Depends(get_current_active_user)):
|
|
try:
|
|
service = VectorDBService()
|
|
data = await service.get_all_stats()
|
|
return success(data=data)
|
|
except Exception as e: # noqa: BLE001
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@audit(action=AuditAction.READ, description="获取向量数据库提供者列表")
|
|
@router_vector_db.get("/providers", summary="列出可用向量数据库提供者")
|
|
async def list_vector_providers(request: Request, user: User = Depends(get_current_active_user)):
|
|
return success(list_providers())
|
|
|
|
|
|
@audit(action=AuditAction.READ, description="获取向量数据库配置")
|
|
@router_vector_db.get("/config", summary="获取当前向量数据库配置")
|
|
async def get_vector_db_config(request: Request, user: User = Depends(get_current_active_user)):
|
|
service = VectorDBService()
|
|
data = await service.current_provider()
|
|
return success(data)
|
|
|
|
|
|
@audit(action=AuditAction.UPDATE, description="更新向量数据库配置", body_fields=["type"])
|
|
@router_vector_db.post("/config", summary="更新向量数据库配置")
|
|
async def update_vector_db_config(
|
|
request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
|
|
):
|
|
entry = get_provider_entry(payload.type)
|
|
if not entry:
|
|
raise HTTPException(
|
|
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
|
|
if not entry.get("enabled", True):
|
|
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
|
|
|
|
provider_cls = get_provider_class(payload.type)
|
|
if not provider_cls:
|
|
raise HTTPException(
|
|
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
|
|
|
|
test_provider = provider_cls(payload.config)
|
|
try:
|
|
await test_provider.initialize()
|
|
except Exception as exc:
|
|
raise HTTPException(status_code=400, detail=str(exc))
|
|
finally:
|
|
client = getattr(test_provider, "client", None)
|
|
close_fn = getattr(client, "close", None)
|
|
if callable(close_fn):
|
|
try:
|
|
close_fn()
|
|
except Exception:
|
|
pass
|
|
|
|
await VectorDBConfigManager.save_config(payload.type, payload.config)
|
|
service = VectorDBService()
|
|
await service.reload()
|
|
config_data = await service.current_provider()
|
|
stats = await service.get_all_stats()
|
|
return success({"config": config_data, "stats": stats})
|
|
|
|
|
|
__all__ = ["router_ai", "router_vector_db"]
|