Mirror of https://github.com/jxxghp/MoviePilot.git, synced 2026-05-06 20:42:43 +08:00.
Allow LLM test to use request payload
This commit is contained in:
@@ -45,6 +45,7 @@ from app.utils.crypto import HashUtils
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.security import SecurityUtils
|
||||
from app.utils.url import UrlUtils
|
||||
from pydantic import BaseModel
|
||||
from version import APP_VERSION
|
||||
|
||||
router = APIRouter()
|
||||
@@ -52,6 +53,14 @@ router = APIRouter()
|
||||
_NETTEST_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
|
||||
|
||||
|
||||
class LlmTestRequest(BaseModel):
    """Optional per-request overrides for a one-off LLM connectivity test.

    Every field defaults to None; a None field means "fall back to the
    saved settings value" when the test snapshot is built.
    """

    # Override for settings.AI_AGENT_ENABLE; None keeps the saved flag.
    enabled: Optional[bool] = None
    # Override for settings.LLM_PROVIDER; None keeps the saved provider.
    provider: Optional[str] = None
    # Override for settings.LLM_MODEL; None keeps the saved model name.
    model: Optional[str] = None
    # Override for settings.LLM_API_KEY; None keeps the saved key.
    api_key: Optional[str] = None
    # Override for settings.LLM_BASE_URL; None keeps the saved base URL.
    base_url: Optional[str] = None
|
||||
|
||||
|
||||
def _match_nettest_prefix(url: str, prefix: str) -> bool:
|
||||
"""
|
||||
判断目标URL是否仍然落在允许的协议、主机、端口和路径前缀内。
|
||||
@@ -276,16 +285,50 @@ def _build_llm_test_data(
|
||||
return data
|
||||
|
||||
|
||||
def _build_llm_test_snapshot() -> dict[str, Any]:
|
||||
def _normalize_llm_test_value(
|
||||
value: Optional[str], *, empty_as_none: bool = False
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
冻结当前 LLM 测试所需配置,避免请求执行过程中被新的保存动作改写。
|
||||
清理来自前端的 LLM 测试字段。
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
stripped = value.strip()
|
||||
if empty_as_none and not stripped:
|
||||
return None
|
||||
return stripped
|
||||
|
||||
|
||||
def _build_llm_test_snapshot(payload: Optional[LlmTestRequest] = None) -> dict[str, Any]:
    """Freeze the configuration needed for the current LLM test.

    Values supplied by the frontend in *payload* take precedence; any field
    left unset falls back to the saved settings, keeping older callers
    (which pass no payload) working unchanged.

    :param payload: optional per-request overrides from the frontend
    :return: dict with keys "enabled", "provider", "model", "api_key",
        "base_url" describing the configuration to test
    """
    # Start from the saved settings as the baseline.
    provider = settings.LLM_PROVIDER
    model = settings.LLM_MODEL
    api_key = settings.LLM_API_KEY
    base_url = settings.LLM_BASE_URL
    enabled = bool(settings.AI_AGENT_ENABLE)

    if payload:
        # Only fields explicitly present (not None) override the baseline.
        if payload.enabled is not None:
            enabled = bool(payload.enabled)
        if payload.provider is not None:
            # A blank provider/model is normalized to "" rather than None.
            provider = _normalize_llm_test_value(payload.provider) or ""
        if payload.model is not None:
            model = _normalize_llm_test_value(payload.model) or ""
        if payload.api_key is not None:
            # Blank key/URL means "not provided" → None.
            api_key = _normalize_llm_test_value(payload.api_key, empty_as_none=True)
        if payload.base_url is not None:
            base_url = _normalize_llm_test_value(payload.base_url, empty_as_none=True)

    return {
        "enabled": enabled,
        "provider": provider,
        "model": model,
        "api_key": api_key,
        "base_url": base_url,
    }
|
||||
|
||||
|
||||
@@ -679,11 +722,14 @@ async def get_llm_models(
|
||||
|
||||
|
||||
@router.post("/llm-test", summary="测试LLM调用", response_model=schemas.Response)
|
||||
async def llm_test(_: User = Depends(get_current_active_superuser_async)):
|
||||
async def llm_test(
|
||||
payload: Annotated[Optional[LlmTestRequest], Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
使用当前已保存配置执行一次最小 LLM 调用。
|
||||
使用传入配置或当前已保存配置执行一次最小 LLM 调用。
|
||||
"""
|
||||
snapshot = _build_llm_test_snapshot()
|
||||
snapshot = _build_llm_test_snapshot(payload)
|
||||
data = _build_llm_test_data(
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
|
||||
@@ -142,6 +142,47 @@ class LlmTestEndpointTest(unittest.TestCase):
|
||||
self.assertEqual(resp.data["duration_ms"], 321)
|
||||
self.assertEqual(resp.data["reply_preview"], "OK")
|
||||
|
||||
def test_llm_test_prefers_request_payload_over_saved_settings(self):
    """Payload values must override saved settings when the endpoint calls the LLM helper."""
    # Fake a successful LLM round-trip so no network call is made.
    llm_test_mock = AsyncMock(
        return_value={
            "provider": "openai",
            "model": "gpt-4.1-mini",
            "duration_ms": 123,
            "reply_preview": "OK",
        }
    )
    # Request payload deliberately different from the saved settings below.
    payload = system_endpoint.LlmTestRequest(
        enabled=True,
        provider="openai",
        model="gpt-4.1-mini",
        api_key="sk-live",
        base_url="https://example.com/v1",
    )

    # Saved settings point at a different provider so any leak-through
    # of saved values would be caught by the assertions below.
    with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", False), patch.object(
        system_endpoint.settings, "LLM_PROVIDER", "deepseek"
    ), patch.object(system_endpoint.settings, "LLM_MODEL", "deepseek-chat"), patch.object(
        system_endpoint.settings, "LLM_API_KEY", "sk-saved"
    ), patch.object(
        system_endpoint.settings, "LLM_BASE_URL", "https://api.deepseek.com"
    ), patch.object(
        system_endpoint.LLMHelper,
        "test_current_settings",
        llm_test_mock,
        create=True,
    ):
        resp = asyncio.run(system_endpoint.llm_test(payload=payload, _="token"))

    # The helper must have received the payload values, not the saved ones.
    llm_test_mock.assert_awaited_once_with(
        provider="openai",
        model="gpt-4.1-mini",
        api_key="sk-live",
        base_url="https://example.com/v1",
    )
    self.assertTrue(resp.success)
    self.assertEqual(resp.data["provider"], "openai")
    self.assertEqual(resp.data["model"], "gpt-4.1-mini")
|
||||
|
||||
def test_llm_test_rejects_empty_reply(self):
|
||||
with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
system_endpoint.settings, "LLM_PROVIDER", "deepseek"
|
||||
|
||||
Reference in New Issue
Block a user