mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-05 15:39:40 +08:00
feat: add llm test endpoint
This commit is contained in:
@@ -29,7 +29,7 @@ from app.db.user_oper import (
|
||||
get_current_active_superuser_async,
|
||||
get_current_active_user_async,
|
||||
)
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.llm import LLMHelper, LLMTestError, LLMTestTimeout
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
@@ -259,6 +259,59 @@ def _build_nettest_rules() -> list[dict[str, Any]]:
|
||||
return rules
|
||||
|
||||
|
||||
def _build_llm_test_data(
|
||||
duration_ms: Optional[int] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
构造 LLM 测试接口的基础返回数据。
|
||||
"""
|
||||
data = {
|
||||
"provider": provider if provider is not None else settings.LLM_PROVIDER,
|
||||
"model": model if model is not None else settings.LLM_MODEL,
|
||||
}
|
||||
if duration_ms is not None:
|
||||
data["duration_ms"] = duration_ms
|
||||
return data
|
||||
|
||||
|
||||
def _build_llm_test_snapshot() -> dict[str, Any]:
|
||||
"""
|
||||
冻结当前 LLM 测试所需配置,避免请求执行过程中被新的保存动作改写。
|
||||
"""
|
||||
return {
|
||||
"enabled": bool(settings.AI_AGENT_ENABLE),
|
||||
"provider": settings.LLM_PROVIDER,
|
||||
"model": settings.LLM_MODEL,
|
||||
"api_key": settings.LLM_API_KEY,
|
||||
"base_url": settings.LLM_BASE_URL,
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
|
||||
"""
|
||||
清理错误信息中的敏感字段,避免回显密钥。
|
||||
"""
|
||||
if not message:
|
||||
return "LLM 调用失败"
|
||||
|
||||
sanitized = message
|
||||
if api_key:
|
||||
sanitized = sanitized.replace(api_key, "***")
|
||||
sanitized = re.sub(
|
||||
r"(?i)(api[_-]?key\s*[:=]\s*)([^\s,;]+)",
|
||||
r"\1***",
|
||||
sanitized,
|
||||
)
|
||||
sanitized = re.sub(
|
||||
r"(?i)authorization\s*:\s*bearer\s+[^\s,;]+",
|
||||
"Authorization: ***",
|
||||
sanitized,
|
||||
)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _validate_nettest_url(url: str) -> Optional[str]:
|
||||
"""
|
||||
对实际请求地址做基础安全校验。
|
||||
@@ -625,6 +678,73 @@ async def get_llm_models(
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.post("/llm-test", summary="测试LLM调用", response_model=schemas.Response)
|
||||
async def llm_test(_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
使用当前已保存配置执行一次最小 LLM 调用。
|
||||
"""
|
||||
snapshot = _build_llm_test_snapshot()
|
||||
data = _build_llm_test_data(
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
)
|
||||
if not snapshot["enabled"]:
|
||||
return schemas.Response(success=False, message="请先启用智能助手", data=data)
|
||||
|
||||
if not snapshot["api_key"]:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM API Key",
|
||||
data=data,
|
||||
)
|
||||
|
||||
if not (snapshot["model"] or "").strip():
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM 模型",
|
||||
data=data,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await LLMHelper.test_current_settings(
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
api_key=snapshot["api_key"],
|
||||
base_url=snapshot["base_url"],
|
||||
)
|
||||
if not result.get("reply_preview"):
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="模型响应为空",
|
||||
data=_build_llm_test_data(
|
||||
result.get("duration_ms"),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
)
|
||||
return schemas.Response(success=True, data=result)
|
||||
except (LLMTestTimeout, TimeoutError) as err:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="LLM 调用超时",
|
||||
data=_build_llm_test_data(
|
||||
getattr(err, "duration_ms", None),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=_sanitize_llm_test_error(str(err), snapshot["api_key"]),
|
||||
data=_build_llm_test_data(
|
||||
getattr(err, "duration_ms", None),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(
|
||||
request: Request,
|
||||
|
||||
Reference in New Issue
Block a user