mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-10 17:43:35 +08:00
feat: add AI Settings tab for managing providers and models
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .routes import adapters, virtual_fs, auth, config, processors, tasks, logs, share, backup, search, vector_db, offline_downloads
|
||||
from .routes import adapters, virtual_fs, auth, config, processors, tasks, logs, share, backup, search, vector_db, offline_downloads, ai_providers
|
||||
from .routes import webdav
|
||||
from .routes import plugins
|
||||
|
||||
@@ -18,6 +18,7 @@ def include_routers(app: FastAPI):
|
||||
app.include_router(share.public_router)
|
||||
app.include_router(backup.router)
|
||||
app.include_router(vector_db.router)
|
||||
app.include_router(ai_providers.router)
|
||||
app.include_router(plugins.router)
|
||||
app.include_router(webdav.router)
|
||||
app.include_router(offline_downloads.router)
|
||||
|
||||
177
api/routes/ai_providers.py
Normal file
177
api/routes/ai_providers.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from typing import Annotated, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
|
||||
from api.response import success
|
||||
from schemas.ai import (
|
||||
AIDefaultsUpdate,
|
||||
AIModelCreate,
|
||||
AIModelUpdate,
|
||||
AIProviderCreate,
|
||||
AIProviderUpdate,
|
||||
)
|
||||
from services.ai_providers import AIProviderService
|
||||
from services.auth import User, get_current_active_user
|
||||
from services.vector_db import VectorDBService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/ai", tags=["ai"])
|
||||
service = AIProviderService()
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def list_providers(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
providers = await service.list_providers()
|
||||
return success({"providers": providers})
|
||||
|
||||
|
||||
@router.post("/providers")
|
||||
async def create_provider(
|
||||
payload: AIProviderCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
provider = await service.create_provider(payload.dict())
|
||||
return success(provider)
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}")
|
||||
async def get_provider(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
provider = await service.get_provider(provider_id, with_models=True)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@router.put("/providers/{provider_id}")
|
||||
async def update_provider(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIProviderUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
provider = await service.update_provider(provider_id, data)
|
||||
return success(provider)
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id}")
|
||||
async def delete_provider(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await service.delete_provider(provider_id)
|
||||
return success({"id": provider_id})
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id}/sync-models")
|
||||
async def sync_models(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
result = await service.sync_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to synchronize models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success(result)
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/remote-models")
|
||||
async def fetch_remote_models(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
try:
|
||||
models = await service.fetch_remote_models(provider_id)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Failed to pull models: {exc}") from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/models")
|
||||
async def list_models(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
models = await service.list_models(provider_id)
|
||||
return success({"models": models})
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id}/models")
|
||||
async def create_model(
|
||||
provider_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
model = await service.create_model(provider_id, payload.dict())
|
||||
return success(model)
|
||||
|
||||
|
||||
@router.put("/models/{model_id}")
|
||||
async def update_model(
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
payload: AIModelUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
data = {k: v for k, v in payload.dict().items() if v is not None}
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No fields to update")
|
||||
model = await service.update_model(model_id, data)
|
||||
return success(model)
|
||||
|
||||
|
||||
@router.delete("/models/{model_id}")
|
||||
async def delete_model(
|
||||
model_id: Annotated[int, Path(..., gt=0)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
await service.delete_model(model_id)
|
||||
return success({"id": model_id})
|
||||
|
||||
|
||||
def _get_embedding_dimension(entry: Optional[Dict]) -> Optional[int]:
|
||||
if not entry:
|
||||
return None
|
||||
value = entry.get("embedding_dimensions")
|
||||
return int(value) if value is not None else None
|
||||
|
||||
|
||||
@router.get("/defaults")
|
||||
async def get_defaults(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
defaults = await service.get_default_models()
|
||||
return success(defaults)
|
||||
|
||||
|
||||
@router.put("/defaults")
|
||||
async def update_defaults(
|
||||
payload: AIDefaultsUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
):
|
||||
previous = await service.get_default_models()
|
||||
try:
|
||||
updated = await service.set_default_models(payload.as_mapping())
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
prev_dim = _get_embedding_dimension(previous.get("embedding"))
|
||||
next_dim = _get_embedding_dimension(updated.get("embedding"))
|
||||
|
||||
if prev_dim and next_dim and prev_dim != next_dim:
|
||||
try:
|
||||
await VectorDBService().clear_all_data()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}") from exc
|
||||
|
||||
return success(updated)
|
||||
@@ -1,11 +1,10 @@
|
||||
import httpx
|
||||
import time
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from typing import Annotated
|
||||
from services.config import ConfigCenter, VERSION
|
||||
from services.auth import get_current_active_user, User, has_users
|
||||
from api.response import success
|
||||
from services.vector_db import VectorDBService
|
||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||
|
||||
|
||||
@@ -24,27 +23,8 @@ async def set_config(
|
||||
key: str = Form(...),
|
||||
value: str = Form(...)
|
||||
):
|
||||
original_value = await ConfigCenter.get(key)
|
||||
value_to_save = value
|
||||
if key == "AI_EMBED_DIM":
|
||||
try:
|
||||
parsed_value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
raise HTTPException(status_code=400, detail="AI_EMBED_DIM must be an integer")
|
||||
if parsed_value <= 0:
|
||||
raise HTTPException(status_code=400, detail="AI_EMBED_DIM must be greater than zero")
|
||||
value_to_save = str(parsed_value)
|
||||
|
||||
await ConfigCenter.set(key, value_to_save)
|
||||
|
||||
if key == "AI_EMBED_DIM" and str(original_value) != value_to_save:
|
||||
try:
|
||||
service = VectorDBService()
|
||||
await service.clear_all_data()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear vector database: {exc}")
|
||||
|
||||
return success({"key": key, "value": value_to_save})
|
||||
await ConfigCenter.set(key, value)
|
||||
return success({"key": key, "value": value})
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
|
||||
Reference in New Issue
Block a user