"""HTTP endpoints for LLM provider discovery, provider authorization and testing."""

import re
from typing import Annotated, Optional

from fastapi import APIRouter, Body, Depends, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from app import schemas
from app.agent.llm import (
    LLMHelper,
    LLMProviderManager,
    LLMTestTimeout,
    render_auth_result_html,
)
from app.core.config import settings
from app.db.models import User
from app.db.user_oper import (
    get_current_active_superuser_async,
    get_current_active_user_async,
)
from app.log import logger

router = APIRouter()

# Matches `api_key=...`, `api-key: ...`, `apikey = ...` and the JSON-style
# `"api_key": "..."`; group 1 is the field prefix that is kept, group 2 is the
# secret value that is masked.
_API_KEY_FIELD_RE = re.compile(
    r"""(api[_-]?key["']?\s*[:=]\s*["']?)([^\s,;"']+)""",
    re.IGNORECASE,
)

# Matches an `Authorization: Bearer <token>` header fragment.
_BEARER_HEADER_RE = re.compile(
    r"authorization\s*:\s*bearer\s+[^\s,;]+",
    re.IGNORECASE,
)


class LlmTestRequest(BaseModel):
    """Request body for the LLM test endpoint; every field is optional."""

    enabled: Optional[bool] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    thinking_level: Optional[str] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    base_url_preset: Optional[str] = None


class LlmProviderAuthStartRequest(BaseModel):
    """Request body for starting a provider authorization session."""

    provider: str
    method: str


def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
    """
    Strip sensitive fields from an error message so secrets are not echoed back.

    :param message: raw error text; an empty value yields a generic failure message
    :param api_key: the key used for the call, masked wherever it appears verbatim
    :return: the message with the key, `api_key` fields and bearer tokens masked
    """
    if not message:
        return "LLM 调用失败"

    sanitized = message

    # Strip before replacing: a whitespace-only key must not turn every space
    # in the message into "***", and a key with stray surrounding whitespace
    # should still mask the trimmed key if it appears in the message.
    secret = api_key.strip() if api_key else ""
    if secret:
        sanitized = sanitized.replace(secret, "***")

    sanitized = _API_KEY_FIELD_RE.sub(r"\1***", sanitized)
    return _BEARER_HEADER_RE.sub("Authorization: ***", sanitized)


@router.get("/models", summary="获取LLM模型列表", response_model=schemas.Response)
async def get_llm_models(
    provider: str,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    base_url_preset: Optional[str] = None,
    force_refresh: Optional[bool] = False,
    _: User = Depends(get_current_active_user_async),
):
    """
    Return the model catalogue of the given provider.

    The response data carries the provider id, its models and the provider's
    authorization status. Any failure is reported as an unsuccessful response
    whose message has been sanitized.
    """
    # NOTE(review): `api_key` is a query parameter on a GET route, so it may end
    # up in access logs; the interface is kept as-is for existing callers.
    try:
        provider_manager = LLMProviderManager()
        models = await LLMHelper().get_models(
            provider=provider,
            api_key=api_key,
            base_url=base_url,
            base_url_preset=base_url_preset,
            force_refresh=bool(force_refresh),
        )
        return schemas.Response(
            success=True,
            data={
                "provider": provider,
                "models": models,
                "auth_status": provider_manager.get_auth_status(provider),
            },
        )
    except Exception as err:
        # The caller-supplied key may be embedded in the upstream error text,
        # so mask it the same way the test endpoint does.
        return schemas.Response(
            success=False,
            message=_sanitize_llm_test_error(str(err), api_key),
        )


@router.get("/providers", summary="获取LLM提供商目录", response_model=schemas.Response)
async def get_llm_providers(
    _: User = Depends(get_current_active_user_async),
):
    """
    Return the provider catalogue in a form the frontend can render directly.
    """
    try:
        providers = await LLMProviderManager().list_providers_async()
        return schemas.Response(success=True, data=providers)
    except Exception as err:
        return schemas.Response(success=False, message=str(err))


@router.post(
    "/provider-auth/start",
    summary="启动LLM提供商授权",
    response_model=schemas.Response,
)
async def start_llm_provider_auth(
    payload: LlmProviderAuthStartRequest,
    request: Request,
    _: User = Depends(get_current_active_superuser_async),
):
    """
    Start a provider authorization session.

    A callback URL is only built for the `chatgpt` provider with the
    `browser_oauth` method; every other combination passes `None`.
    """
    try:
        callback_url = None
        if payload.provider == "chatgpt" and payload.method == "browser_oauth":
            # Resolve the absolute URL of the callback route registered below
            # under the name "llm_provider_auth_callback".
            callback_url = str(
                request.url_for(
                    "llm_provider_auth_callback",
                    provider_id=payload.provider,
                )
            )
        result = await LLMProviderManager().start_auth(
            payload.provider,
            payload.method,
            callback_url,
        )
        return schemas.Response(success=True, data=result)
    except Exception as err:
        return schemas.Response(success=False, message=str(err))


@router.get(
    "/provider-auth/{session_id}",
    summary="获取LLM提供商授权会话状态",
    response_model=schemas.Response,
)
async def get_llm_provider_auth_session(
    session_id: str,
    _: User = Depends(get_current_active_superuser_async),
):
    """
    Look up the status of an authorization session.
    """
    try:
        result = LLMProviderManager().get_session_status(session_id)
        return schemas.Response(success=True, data=result)
    except Exception as err:
        return schemas.Response(success=False, message=str(err))


@router.post(
    "/provider-auth/{session_id}/poll",
    summary="轮询LLM提供商授权会话",
    response_model=schemas.Response,
)
async def poll_llm_provider_auth_session(
    session_id: str,
    _: User = Depends(get_current_active_superuser_async),
):
    """
    Poll the state of a device-code / OAuth authorization session.
    """
    try:
        result = await LLMProviderManager().poll_auth_session(session_id)
        return schemas.Response(success=True, data=result)
    except Exception as err:
        return schemas.Response(success=False, message=str(err))


@router.delete(
    "/provider-auth/{provider_id}",
    summary="断开LLM提供商授权",
    response_model=schemas.Response,
)
async def delete_llm_provider_auth(
    provider_id: str,
    _: User = Depends(get_current_active_superuser_async),
):
    """
    Delete the stored authorization data of a provider.
    """
    try:
        await LLMProviderManager().clear_auth(provider_id)
        return schemas.Response(success=True)
    except Exception as err:
        return schemas.Response(success=False, message=str(err))


@router.get(
    "/provider-auth/callback/{provider_id}",
    summary="LLM提供商OAuth回调",
    response_class=HTMLResponse,
    name="llm_provider_auth_callback",
)
async def llm_provider_auth_callback(
    provider_id: str,
    code: Optional[str] = None,
    state: Optional[str] = None,
    error: Optional[str] = None,
    error_description: Optional[str] = None,
):
    """
    Handle the browser redirect of OAuth providers that need one.

    The query parameters are forwarded to the provider manager and the outcome
    is rendered as an HTML page. This route declares no user dependency.
    """
    success, message = await LLMProviderManager().handle_chatgpt_callback(
        provider_id,
        code,
        state,
        error,
        error_description,
    )
    return HTMLResponse(content=render_auth_result_html(success, message))


@router.post("/test", summary="测试LLM调用", response_model=schemas.Response)
async def llm_test(
    payload: Annotated[Optional[LlmTestRequest], Body()] = None,
    _: User = Depends(get_current_active_superuser_async),
):
    """
    Run one minimal LLM call with the supplied or the currently saved settings.

    When no body is sent, the request is built from the saved settings. The
    configuration is validated step by step (provider, model, enabled flag,
    API key) and the first failing check is returned as an unsuccessful
    response.
    """
    payload = payload or LlmTestRequest(
        enabled=settings.AI_AGENT_ENABLE,
        provider=settings.LLM_PROVIDER,
        model=settings.LLM_MODEL,
        thinking_level=settings.LLM_THINKING_LEVEL,
        api_key=settings.LLM_API_KEY,
        base_url=settings.LLM_BASE_URL,
        base_url_preset=settings.LLM_BASE_URL_PRESET,
    )

    if not payload.provider:
        return schemas.Response(success=False, message="请配置LLM提供商和模型")

    if not payload.model or not payload.model.strip():
        return schemas.Response(success=False, message="请先配置 LLM 模型")

    data = {
        "provider": payload.provider,
        "model": payload.model,
    }

    if not payload.enabled:
        return schemas.Response(success=False, message="请先启用智能助手", data=data)

    # "chatgpt" and "github-copilot" are exempt from the API key requirement;
    # every other provider must supply a non-blank key.
    if (
        payload.provider not in {"chatgpt", "github-copilot"}
        and (not payload.api_key or not payload.api_key.strip())
    ):
        return schemas.Response(
            success=False,
            message="请先配置 LLM API Key",
            data=data,
        )

    try:
        result = await LLMHelper.test_current_settings(
            provider=payload.provider,
            model=payload.model,
            thinking_level=payload.thinking_level,
            api_key=payload.api_key,
            base_url=payload.base_url,
            base_url_preset=payload.base_url_preset,
        )
        # A call that returns without any reply preview counts as a failure.
        if not result.get("reply_preview"):
            return schemas.Response(
                success=False,
                message="模型响应为空",
                data=result,
            )
        return schemas.Response(success=True, data=result)
    except (LLMTestTimeout, TimeoutError) as err:
        logger.warning(err)
        return schemas.Response(
            success=False,
            message="LLM 调用超时",
        )
    except Exception as err:
        return schemas.Response(
            success=False,
            message=_sanitize_llm_test_error(str(err), payload.api_key),
        )