mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
feat: add llm test endpoint
This commit is contained in:
@@ -29,7 +29,7 @@ from app.db.user_oper import (
|
||||
get_current_active_superuser_async,
|
||||
get_current_active_user_async,
|
||||
)
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.llm import LLMHelper, LLMTestError, LLMTestTimeout
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
@@ -259,6 +259,59 @@ def _build_nettest_rules() -> list[dict[str, Any]]:
|
||||
return rules
|
||||
|
||||
|
||||
def _build_llm_test_data(
|
||||
duration_ms: Optional[int] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
构造 LLM 测试接口的基础返回数据。
|
||||
"""
|
||||
data = {
|
||||
"provider": provider if provider is not None else settings.LLM_PROVIDER,
|
||||
"model": model if model is not None else settings.LLM_MODEL,
|
||||
}
|
||||
if duration_ms is not None:
|
||||
data["duration_ms"] = duration_ms
|
||||
return data
|
||||
|
||||
|
||||
def _build_llm_test_snapshot() -> dict[str, Any]:
|
||||
"""
|
||||
冻结当前 LLM 测试所需配置,避免请求执行过程中被新的保存动作改写。
|
||||
"""
|
||||
return {
|
||||
"enabled": bool(settings.AI_AGENT_ENABLE),
|
||||
"provider": settings.LLM_PROVIDER,
|
||||
"model": settings.LLM_MODEL,
|
||||
"api_key": settings.LLM_API_KEY,
|
||||
"base_url": settings.LLM_BASE_URL,
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
|
||||
"""
|
||||
清理错误信息中的敏感字段,避免回显密钥。
|
||||
"""
|
||||
if not message:
|
||||
return "LLM 调用失败"
|
||||
|
||||
sanitized = message
|
||||
if api_key:
|
||||
sanitized = sanitized.replace(api_key, "***")
|
||||
sanitized = re.sub(
|
||||
r"(?i)(api[_-]?key\s*[:=]\s*)([^\s,;]+)",
|
||||
r"\1***",
|
||||
sanitized,
|
||||
)
|
||||
sanitized = re.sub(
|
||||
r"(?i)authorization\s*:\s*bearer\s+[^\s,;]+",
|
||||
"Authorization: ***",
|
||||
sanitized,
|
||||
)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _validate_nettest_url(url: str) -> Optional[str]:
|
||||
"""
|
||||
对实际请求地址做基础安全校验。
|
||||
@@ -625,6 +678,73 @@ async def get_llm_models(
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.post("/llm-test", summary="测试LLM调用", response_model=schemas.Response)
|
||||
async def llm_test(_: User = Depends(get_current_active_superuser_async)):
|
||||
"""
|
||||
使用当前已保存配置执行一次最小 LLM 调用。
|
||||
"""
|
||||
snapshot = _build_llm_test_snapshot()
|
||||
data = _build_llm_test_data(
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
)
|
||||
if not snapshot["enabled"]:
|
||||
return schemas.Response(success=False, message="请先启用智能助手", data=data)
|
||||
|
||||
if not snapshot["api_key"]:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM API Key",
|
||||
data=data,
|
||||
)
|
||||
|
||||
if not (snapshot["model"] or "").strip():
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM 模型",
|
||||
data=data,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await LLMHelper.test_current_settings(
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
api_key=snapshot["api_key"],
|
||||
base_url=snapshot["base_url"],
|
||||
)
|
||||
if not result.get("reply_preview"):
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="模型响应为空",
|
||||
data=_build_llm_test_data(
|
||||
result.get("duration_ms"),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
)
|
||||
return schemas.Response(success=True, data=result)
|
||||
except (LLMTestTimeout, TimeoutError) as err:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="LLM 调用超时",
|
||||
data=_build_llm_test_data(
|
||||
getattr(err, "duration_ms", None),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=_sanitize_llm_test_error(str(err), snapshot["api_key"]),
|
||||
data=_build_llm_test_data(
|
||||
getattr(err, "duration_ms", None),
|
||||
provider=snapshot["provider"],
|
||||
model=snapshot["model"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(
|
||||
request: Request,
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class LLMTestError(RuntimeError):
|
||||
"""LLM 测试调用异常,附带请求耗时。"""
|
||||
|
||||
def __init__(self, message: str, duration_ms: int | None = None):
|
||||
super().__init__(message)
|
||||
self.duration_ms = duration_ms
|
||||
|
||||
|
||||
class LLMTestTimeout(TimeoutError):
|
||||
"""LLM 测试调用超时,附带请求耗时。"""
|
||||
|
||||
def __init__(self, message: str, duration_ms: int | None = None):
|
||||
super().__init__(message)
|
||||
self.duration_ms = duration_ms
|
||||
|
||||
|
||||
def _patch_gemini_thought_signature():
|
||||
"""
|
||||
修复 langchain-google-genai 中 Gemini 2.5 思考模型的 thought_signature 兼容问题。
|
||||
@@ -67,19 +85,29 @@ class LLMHelper:
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False):
|
||||
def get_llm(
|
||||
streaming: bool = False,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
):
|
||||
"""
|
||||
获取LLM实例
|
||||
:param streaming: 是否启用流式输出
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
provider_name = str(
|
||||
provider if provider is not None else settings.LLM_PROVIDER
|
||||
).lower()
|
||||
model_name = model if model is not None else settings.LLM_MODEL
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
|
||||
if not api_key:
|
||||
if not api_key_value:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
if provider == "google":
|
||||
if provider_name == "google":
|
||||
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
|
||||
_patch_gemini_thought_signature()
|
||||
|
||||
@@ -94,19 +122,19 @@ class LLMHelper:
|
||||
client_args = {proxy_key: settings.PROXY_HOST}
|
||||
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
client_args=client_args,
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
elif provider_name == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
model = ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
@@ -116,10 +144,10 @@ class LLMHelper:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
model = ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
base_url=base_url_value,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
@@ -137,6 +165,93 @@ class LLMHelper:
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content) -> str:
|
||||
"""
|
||||
从响应内容中提取纯文本,仅保留真实文本块。
|
||||
"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
continue
|
||||
|
||||
if isinstance(block, dict) or hasattr(block, "get"):
|
||||
block_type = block.get("type")
|
||||
if block.get("thought") or block_type in (
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
):
|
||||
continue
|
||||
if block_type == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
continue
|
||||
if not block_type and isinstance(block.get("text"), str):
|
||||
text_parts.append(block.get("text", ""))
|
||||
return "".join(text_parts)
|
||||
if isinstance(content, dict) or hasattr(content, "get"):
|
||||
if content.get("thought"):
|
||||
return ""
|
||||
if content.get("type") == "text":
|
||||
return content.get("text", "")
|
||||
if not content.get("type") and isinstance(content.get("text"), str):
|
||||
return content.get("text", "")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
async def test_current_settings(
|
||||
prompt: str = "请只回复 OK",
|
||||
timeout: int = 20,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
使用当前已保存配置执行一次最小 LLM 调用。
|
||||
"""
|
||||
provider_name = provider if provider is not None else settings.LLM_PROVIDER
|
||||
model_name = model if model is not None else settings.LLM_MODEL
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
start = time.perf_counter()
|
||||
llm = LLMHelper.get_llm(
|
||||
streaming=False,
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
base_url=base_url_value,
|
||||
)
|
||||
try:
|
||||
response = await asyncio.wait_for(llm.ainvoke(prompt), timeout=timeout)
|
||||
except TimeoutError as err:
|
||||
duration_ms = round((time.perf_counter() - start) * 1000)
|
||||
raise LLMTestTimeout("LLM 调用超时", duration_ms=duration_ms) from err
|
||||
except Exception as err:
|
||||
duration_ms = round((time.perf_counter() - start) * 1000)
|
||||
raise LLMTestError(str(err), duration_ms=duration_ms) from err
|
||||
|
||||
reply_text = LLMHelper._extract_text_content(
|
||||
getattr(response, "content", response)
|
||||
).strip()
|
||||
duration_ms = round((time.perf_counter() - start) * 1000)
|
||||
|
||||
data = {
|
||||
"provider": provider_name,
|
||||
"model": model_name,
|
||||
"duration_ms": duration_ms,
|
||||
}
|
||||
if reply_text:
|
||||
data["reply_preview"] = reply_text[:120]
|
||||
return data
|
||||
|
||||
def get_models(
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
|
||||
114
tests/test_llm_helper_testcall.py
Normal file
114
tests/test_llm_helper_testcall.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
def __getattr__(self, _name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class _FakeModel:
|
||||
def __init__(self, content):
|
||||
self._content = content
|
||||
|
||||
async def ainvoke(self, _prompt):
|
||||
return SimpleNamespace(content=self._content)
|
||||
|
||||
|
||||
sys.modules.pop("app.helper.llm", None)
|
||||
_stub_module(
|
||||
"app.core.config",
|
||||
settings=SimpleNamespace(
|
||||
LLM_PROVIDER="global-provider",
|
||||
LLM_MODEL="global-model",
|
||||
LLM_API_KEY="global-key",
|
||||
LLM_BASE_URL="https://global.example.com",
|
||||
LLM_TEMPERATURE=0.1,
|
||||
LLM_MAX_CONTEXT_TOKENS=64,
|
||||
PROXY_HOST=None,
|
||||
),
|
||||
)
|
||||
_stub_module("app.log", logger=_DummyLogger())
|
||||
|
||||
module_path = "/Users/sdongmaker/VScode/MoviePilot/app/helper/llm.py"
|
||||
spec = importlib.util.spec_from_file_location("test_llm_module", module_path)
|
||||
llm_module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(llm_module)
|
||||
|
||||
|
||||
class LlmHelperTestCallTest(unittest.TestCase):
|
||||
def test_extract_text_content_ignores_non_text_blocks(self):
|
||||
content = [
|
||||
{"type": "reasoning", "text": "internal"},
|
||||
{"type": "tool_use", "name": "search"},
|
||||
{"type": "text", "text": "OK"},
|
||||
]
|
||||
|
||||
result = llm_module.LLMHelper._extract_text_content(content)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
|
||||
def test_test_current_settings_uses_explicit_snapshot(self):
|
||||
fake_model = _FakeModel("OK")
|
||||
get_llm_mock = Mock(return_value=fake_model)
|
||||
|
||||
with patch.object(llm_module.LLMHelper, "get_llm", get_llm_mock):
|
||||
result = asyncio.run(
|
||||
llm_module.LLMHelper.test_current_settings(
|
||||
provider="deepseek",
|
||||
model="deepseek-chat",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
)
|
||||
|
||||
get_llm_mock.assert_called_once_with(
|
||||
streaming=False,
|
||||
provider="deepseek",
|
||||
model="deepseek-chat",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
self.assertEqual(result["provider"], "deepseek")
|
||||
self.assertEqual(result["model"], "deepseek-chat")
|
||||
self.assertEqual(result["reply_preview"], "OK")
|
||||
|
||||
def test_test_current_settings_does_not_promote_non_text_blocks(self):
|
||||
fake_model = _FakeModel(
|
||||
[
|
||||
{"type": "tool_use", "name": "lookup"},
|
||||
{"type": "reasoning", "text": "thinking"},
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(llm_module.LLMHelper, "get_llm", return_value=fake_model):
|
||||
result = asyncio.run(
|
||||
llm_module.LLMHelper.test_current_settings(
|
||||
provider="deepseek",
|
||||
model="deepseek-chat",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertNotIn("reply_preview", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
203
tests/test_system_llm_test.py
Normal file
203
tests/test_system_llm_test.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
class _Dummy:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __getattr__(self, _name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class _DummyError(Exception):
|
||||
def __init__(self, message="", duration_ms=None):
|
||||
super().__init__(message)
|
||||
self.duration_ms = duration_ms
|
||||
|
||||
|
||||
for _module_name in ("pillow_avif", "aiofiles", "psutil"):
|
||||
_stub_module(_module_name)
|
||||
|
||||
_stub_module("app.helper.sites", SitesHelper=_Dummy)
|
||||
_stub_module("app.chain.mediaserver", MediaServerChain=_Dummy)
|
||||
_stub_module("app.chain.search", SearchChain=_Dummy)
|
||||
_stub_module("app.chain.system", SystemChain=_Dummy)
|
||||
_stub_module("app.core.event", eventmanager=_Dummy())
|
||||
_stub_module("app.core.metainfo", MetaInfo=_Dummy)
|
||||
_stub_module("app.core.module", ModuleManager=_Dummy)
|
||||
_stub_module(
|
||||
"app.core.security",
|
||||
verify_apitoken=_Dummy,
|
||||
verify_resource_token=_Dummy,
|
||||
verify_token=_Dummy,
|
||||
)
|
||||
_stub_module("app.db.models", User=_Dummy)
|
||||
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_Dummy)
|
||||
_stub_module(
|
||||
"app.db.user_oper",
|
||||
get_current_active_superuser=_Dummy,
|
||||
get_current_active_superuser_async=_Dummy,
|
||||
get_current_active_user_async=_Dummy,
|
||||
)
|
||||
_stub_module(
|
||||
"app.helper.llm",
|
||||
LLMHelper=_Dummy,
|
||||
LLMTestError=_DummyError,
|
||||
LLMTestTimeout=_DummyError,
|
||||
)
|
||||
_stub_module("app.helper.mediaserver", MediaServerHelper=_Dummy)
|
||||
_stub_module("app.helper.message", MessageHelper=_Dummy)
|
||||
_stub_module("app.helper.progress", ProgressHelper=_Dummy)
|
||||
_stub_module("app.helper.rule", RuleHelper=_Dummy)
|
||||
_stub_module("app.helper.subscribe", SubscribeHelper=_Dummy)
|
||||
_stub_module("app.helper.system", SystemHelper=_Dummy)
|
||||
_stub_module("app.helper.image", ImageHelper=_Dummy)
|
||||
_stub_module("app.scheduler", Scheduler=_Dummy)
|
||||
_stub_module(
|
||||
"app.log",
|
||||
logger=_Dummy(),
|
||||
log_settings=_Dummy(),
|
||||
LogConfigModel=type("LogConfigModel", (), {}),
|
||||
)
|
||||
_stub_module("app.utils.crypto", HashUtils=_Dummy)
|
||||
_stub_module("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy)
|
||||
_stub_module("version", APP_VERSION="test")
|
||||
|
||||
from app.api.endpoints import system as system_endpoint
|
||||
|
||||
|
||||
class LlmTestEndpointTest(unittest.TestCase):
|
||||
def test_llm_test_requires_ai_agent_enabled(self):
|
||||
with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", False):
|
||||
resp = asyncio.run(system_endpoint.llm_test(_="token"))
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertEqual(resp.message, "请先启用智能助手")
|
||||
|
||||
def test_llm_test_requires_api_key(self):
|
||||
with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
system_endpoint.settings, "LLM_API_KEY", None
|
||||
), patch.object(system_endpoint.settings, "LLM_MODEL", "deepseek-chat"):
|
||||
resp = asyncio.run(system_endpoint.llm_test(_="token"))
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertEqual(resp.message, "请先配置 LLM API Key")
|
||||
self.assertEqual(resp.data["model"], "deepseek-chat")
|
||||
|
||||
def test_llm_test_requires_model(self):
|
||||
with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
system_endpoint.settings, "LLM_API_KEY", "sk-test"
|
||||
), patch.object(system_endpoint.settings, "LLM_MODEL", ""):
|
||||
resp = asyncio.run(system_endpoint.llm_test(_="token"))
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertEqual(resp.message, "请先配置 LLM 模型")
|
||||
|
||||
def test_llm_test_returns_successful_reply_preview(self):
|
||||
llm_test_mock = AsyncMock(
|
||||
return_value={
|
||||
"provider": "deepseek",
|
||||
"model": "deepseek-chat",
|
||||
"duration_ms": 321,
|
||||
"reply_preview": "OK",
|
||||
}
|
||||
)
|
||||
with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
system_endpoint.settings, "LLM_PROVIDER", "deepseek"
|
||||
), patch.object(system_endpoint.settings, "LLM_MODEL", "deepseek-chat"), patch.object(
|
||||
system_endpoint.settings, "LLM_API_KEY", "sk-test"
|
||||
), patch.object(
|
||||
system_endpoint.settings, "LLM_BASE_URL", "https://api.deepseek.com"
|
||||
), patch.object(
|
||||
system_endpoint.LLMHelper,
|
||||
"test_current_settings",
|
||||
llm_test_mock,
|
||||
create=True,
|
||||
):
|
||||
resp = asyncio.run(system_endpoint.llm_test(_="token"))
|
||||
|
||||
llm_test_mock.assert_awaited_once_with(
|
||||
provider="deepseek",
|
||||
model="deepseek-chat",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
self.assertTrue(resp.success)
|
||||
self.assertEqual(resp.data["provider"], "deepseek")
|
||||
self.assertEqual(resp.data["model"], "deepseek-chat")
|
||||
self.assertEqual(resp.data["duration_ms"], 321)
|
||||
self.assertEqual(resp.data["reply_preview"], "OK")
|
||||
|
||||
def test_llm_test_rejects_empty_reply(self):
|
||||
with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
system_endpoint.settings, "LLM_PROVIDER", "deepseek"
|
||||
), patch.object(system_endpoint.settings, "LLM_MODEL", "deepseek-chat"), patch.object(
|
||||
system_endpoint.settings, "LLM_API_KEY", "sk-test"
|
||||
), patch.object(
|
||||
system_endpoint.LLMHelper,
|
||||
"test_current_settings",
|
||||
AsyncMock(return_value={"provider": "deepseek", "model": "deepseek-chat", "duration_ms": 12}),
|
||||
create=True,
|
||||
):
|
||||
resp = asyncio.run(system_endpoint.llm_test(_="token"))
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertEqual(resp.message, "模型响应为空")
|
||||
self.assertEqual(resp.data["duration_ms"], 12)
|
||||
|
||||
def test_llm_test_maps_timeout_error(self):
|
||||
with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
system_endpoint.settings, "LLM_PROVIDER", "deepseek"
|
||||
), patch.object(system_endpoint.settings, "LLM_MODEL", "deepseek-chat"), patch.object(
|
||||
system_endpoint.settings, "LLM_API_KEY", "sk-test"
|
||||
), patch.object(
|
||||
system_endpoint.LLMHelper,
|
||||
"test_current_settings",
|
||||
AsyncMock(side_effect=TimeoutError("request timed out")),
|
||||
create=True,
|
||||
):
|
||||
resp = asyncio.run(system_endpoint.llm_test(_="token"))
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertEqual(resp.message, "LLM 调用超时")
|
||||
|
||||
def test_llm_test_sanitizes_error_message(self):
|
||||
raw_error = (
|
||||
"request failed api_key=sk-secret "
|
||||
"Authorization: Bearer sk-secret "
|
||||
"base error sk-secret"
|
||||
)
|
||||
with patch.object(system_endpoint.settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
system_endpoint.settings, "LLM_API_KEY", "sk-secret"
|
||||
), patch.object(system_endpoint.settings, "LLM_PROVIDER", "deepseek"), patch.object(
|
||||
system_endpoint.settings, "LLM_MODEL", "deepseek-chat"
|
||||
), patch.object(
|
||||
system_endpoint.LLMHelper,
|
||||
"test_current_settings",
|
||||
AsyncMock(side_effect=RuntimeError(raw_error)),
|
||||
create=True,
|
||||
):
|
||||
resp = asyncio.run(system_endpoint.llm_test(_="token"))
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertNotIn("sk-secret", resp.message)
|
||||
self.assertNotIn("Authorization: Bearer", resp.message)
|
||||
self.assertIn("***", resp.message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -19,6 +19,15 @@ class _Dummy:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __getattr__(self, _name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
|
||||
class _DummyError(Exception):
|
||||
def __init__(self, message="", duration_ms=None):
|
||||
super().__init__(message)
|
||||
self.duration_ms = duration_ms
|
||||
|
||||
|
||||
for _module_name in ("pillow_avif", "aiofiles", "psutil"):
|
||||
_stub_module(_module_name)
|
||||
@@ -44,7 +53,12 @@ _stub_module(
|
||||
get_current_active_superuser_async=_Dummy,
|
||||
get_current_active_user_async=_Dummy,
|
||||
)
|
||||
_stub_module("app.helper.llm", LLMHelper=_Dummy)
|
||||
_stub_module(
|
||||
"app.helper.llm",
|
||||
LLMHelper=_Dummy,
|
||||
LLMTestError=_DummyError,
|
||||
LLMTestTimeout=_DummyError,
|
||||
)
|
||||
_stub_module("app.helper.mediaserver", MediaServerHelper=_Dummy)
|
||||
_stub_module("app.helper.message", MessageHelper=_Dummy)
|
||||
_stub_module("app.helper.progress", ProgressHelper=_Dummy)
|
||||
@@ -53,6 +67,12 @@ _stub_module("app.helper.subscribe", SubscribeHelper=_Dummy)
|
||||
_stub_module("app.helper.system", SystemHelper=_Dummy)
|
||||
_stub_module("app.helper.image", ImageHelper=_Dummy)
|
||||
_stub_module("app.scheduler", Scheduler=_Dummy)
|
||||
_stub_module(
|
||||
"app.log",
|
||||
logger=_Dummy(),
|
||||
log_settings=_Dummy(),
|
||||
LogConfigModel=type("LogConfigModel", (), {}),
|
||||
)
|
||||
_stub_module("app.utils.crypto", HashUtils=_Dummy)
|
||||
_stub_module("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy)
|
||||
_stub_module("version", APP_VERSION="test")
|
||||
|
||||
Reference in New Issue
Block a user