diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 739f01d0..d4f0d254 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -45,6 +45,7 @@ from app.utils.crypto import HashUtils from app.utils.http import RequestUtils, AsyncRequestUtils from app.utils.security import SecurityUtils from app.utils.url import UrlUtils +from pydantic import BaseModel from version import APP_VERSION router = APIRouter() @@ -52,6 +53,14 @@ router = APIRouter() _NETTEST_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308} +class LlmTestRequest(BaseModel): + enabled: Optional[bool] = None + provider: Optional[str] = None + model: Optional[str] = None + api_key: Optional[str] = None + base_url: Optional[str] = None + + def _match_nettest_prefix(url: str, prefix: str) -> bool: """ 判断目标URL是否仍然落在允许的协议、主机、端口和路径前缀内。 @@ -276,16 +285,50 @@ def _build_llm_test_data( return data -def _build_llm_test_snapshot() -> dict[str, Any]: +def _normalize_llm_test_value( + value: Optional[str], *, empty_as_none: bool = False +) -> Optional[str]: """ - 冻结当前 LLM 测试所需配置,避免请求执行过程中被新的保存动作改写。 + 清理来自前端的 LLM 测试字段。 """ + if value is None: + return None + stripped = value.strip() + if empty_as_none and not stripped: + return None + return stripped + + +def _build_llm_test_snapshot(payload: Optional[LlmTestRequest] = None) -> dict[str, Any]: + """ + 冻结当前 LLM 测试所需配置。 + + 优先使用前端传入的临时参数;未传入时回退到已保存配置,兼容旧调用。 + """ + provider = settings.LLM_PROVIDER + model = settings.LLM_MODEL + api_key = settings.LLM_API_KEY + base_url = settings.LLM_BASE_URL + enabled = bool(settings.AI_AGENT_ENABLE) + + if payload: + if payload.enabled is not None: + enabled = bool(payload.enabled) + if payload.provider is not None: + provider = _normalize_llm_test_value(payload.provider) or "" + if payload.model is not None: + model = _normalize_llm_test_value(payload.model) or "" + if payload.api_key is not None: + api_key = _normalize_llm_test_value(payload.api_key, empty_as_none=True) + if payload.base_url is not None: + base_url = _normalize_llm_test_value(payload.base_url, empty_as_none=True) + return { - "enabled": bool(settings.AI_AGENT_ENABLE), - "provider": settings.LLM_PROVIDER, - "model": settings.LLM_MODEL, - "api_key": settings.LLM_API_KEY, - "base_url": settings.LLM_BASE_URL, + "enabled": enabled, + "provider": provider, + "model": model, + "api_key": api_key, + "base_url": base_url, } @@ -679,11 +722,14 @@ async def get_llm_models( @router.post("/llm-test", summary="测试LLM调用", response_model=schemas.Response) -async def llm_test(_: User = Depends(get_current_active_superuser_async)): +async def llm_test( + payload: Annotated[Optional[LlmTestRequest], Body()] = None, + _: User = Depends(get_current_active_superuser_async), +): """ - 使用当前已保存配置执行一次最小 LLM 调用。 + 使用传入配置或当前已保存配置执行一次最小 LLM 调用。 """ - snapshot = _build_llm_test_snapshot() + snapshot = _build_llm_test_snapshot(payload) data = _build_llm_test_data( provider=snapshot["provider"], model=snapshot["model"], diff --git a/tests/test_system_llm_test.py b/tests/test_system_llm_test.py index 792116fb..e185cc34 100644 --- a/tests/test_system_llm_test.py +++ b/tests/test_system_llm_test.py @@ -142,6 +142,47 @@ class LlmTestEndpointTest(unittest.TestCase): self.assertEqual(resp.data["duration_ms"], 321) self.assertEqual(resp.data["reply_preview"], "OK") + def test_llm_test_prefers_request_payload_over_saved_settings(self): + llm_test_mock = AsyncMock( + return_value={ + "provider": "openai", + "model": "gpt-4.1-mini", + "duration_ms": 123, + "reply_preview": "OK", + } + ) + payload = system_endpoint.LlmTestRequest( + enabled=True, + provider="openai", + model="gpt-4.1-mini", + api_key="sk-live", + base_url="https://example.com/v1", + ) + + with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", False), patch.object( + system_endpoint.settings, "LLM_PROVIDER", "deepseek" + ), patch.object(system_endpoint.settings, "LLM_MODEL", "deepseek-chat"), patch.object( + system_endpoint.settings, "LLM_API_KEY", "sk-saved" + ), patch.object( + system_endpoint.settings, "LLM_BASE_URL", "https://api.deepseek.com" + ), patch.object( + system_endpoint.LLMHelper, + "test_current_settings", + llm_test_mock, + create=True, + ): + resp = asyncio.run(system_endpoint.llm_test(payload=payload, _="token")) + + llm_test_mock.assert_awaited_once_with( + provider="openai", + model="gpt-4.1-mini", + api_key="sk-live", + base_url="https://example.com/v1", + ) + self.assertTrue(resp.success) + self.assertEqual(resp.data["provider"], "openai") + self.assertEqual(resp.data["model"], "gpt-4.1-mini") + def test_llm_test_rejects_empty_reply(self): with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", True), patch.object( system_endpoint.settings, "LLM_PROVIDER", "deepseek"