Refine existing implementation

This commit is contained in:
jxxghp
2026-06-22 18:21:20 +08:00
parent e44a6f41b5
commit 3306d196b7
9 changed files with 832 additions and 466 deletions

View File

@@ -838,6 +838,8 @@ class MoviePilotAgent:
detail = cls._exception_detail_text(error).lower()
if "no endpoints found that support image input" in detail:
return True
if "not a vlm" in detail or "text-only prompts" in detail:
return True
if "unknown variant" in detail and "image_url" in detail:
return True
if "image input" not in detail and "images" not in detail:

View File

@@ -691,7 +691,9 @@ class AgentCapabilityManager:
@staticmethod
def supports_image_input() -> bool:
"""当前 Agent 是否启用图片输入能力。"""
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
from app.agent.llm.helper import LLMHelper
return LLMHelper.supports_image_input()
@staticmethod
def supports_audio_input() -> bool:

View File

@@ -5,7 +5,7 @@ import inspect
import json
import time
from functools import wraps
from typing import Any, List
from typing import Any, List, Optional
from langchain_core.messages import AIMessage, AIMessageChunk
@@ -700,11 +700,85 @@ class LLMHelper:
return {}
@staticmethod
def supports_image_input() -> bool:
def _metadata_supports_image_input(metadata: Any) -> Optional[bool]:
"""从模型元数据中读取图片输入能力,未知时返回 None。"""
if not isinstance(metadata, dict):
return None
modalities = metadata.get("modalities") or {}
input_modalities = modalities.get("input")
if isinstance(input_modalities, str):
input_modalities = [input_modalities]
if isinstance(input_modalities, list):
normalized_modalities = {
str(item or "").strip().lower() for item in input_modalities
}
return "image" in normalized_modalities
return None
@classmethod
def _resolve_catalog_image_input_support(
cls,
provider: Optional[str] = None,
model: Optional[str] = None,
base_url: Optional[str] = None,
base_url_preset: Optional[str] = None,
) -> Optional[bool]:
"""复用 provider 目录缓存解析当前模型是否支持图片输入。"""
provider_name = str(provider if provider is not None else settings.LLM_PROVIDER).strip()
model_name = str(model if model is not None else settings.LLM_MODEL).strip()
if not provider_name or not model_name:
return None
try:
from app.agent.llm.provider import LLMProviderManager
metadata = LLMProviderManager().resolve_cached_model_metadata(
provider_id=provider_name,
model_id=model_name,
base_url=base_url if base_url is not None else settings.LLM_BASE_URL,
base_url_preset_id=(
base_url_preset
if base_url_preset is not None
else settings.LLM_BASE_URL_PRESET
),
)
except Exception as err:
logger.debug(f"解析模型图片能力失败: {err}")
return None
return cls._metadata_supports_image_input(metadata)
@classmethod
def supports_image_input(
cls,
provider: Optional[str] = None,
model: Optional[str] = None,
base_url: Optional[str] = None,
base_url_preset: Optional[str] = None,
) -> bool:
"""
判断当前模型是否启用了图片输入能力。
用户开关为总开关;当内置模型目录明确标注当前模型不支持 image 输入时,
即使总开关开启也降级为纯文本,避免文本模型收到 `image_url` 内容块后
被兼容端点以 400 拒绝。无参调用保持旧版“只读总开关”语义,
未知自定义模型也保持原有开关语义。
"""
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
if not settings.LLM_SUPPORT_IMAGE_INPUT:
return False
if provider is None and model is None:
return True
image_support = cls._resolve_catalog_image_input_support(
provider=provider,
model=model,
base_url=base_url,
base_url_preset=base_url_preset,
)
if image_support is not None:
return image_support
return True
@staticmethod
def _build_legacy_runtime(
@@ -798,6 +872,41 @@ class LLMHelper:
return True
return None
@staticmethod
def _attach_runtime_metadata(model: Any, runtime: dict[str, Any]) -> None:
"""
将 MoviePilot 已解析出的 provider 运行时信息挂到模型实例上。
这些字段只供内部中间件识别协议能力,不参与 LangChain 请求序列化。
"""
runtime_metadata = {
"runtime": runtime.get("runtime"),
"provider_id": runtime.get("provider_id"),
"base_url": runtime.get("base_url"),
}
def _set_metadata_attr(name: str, value: Any) -> None:
try:
setattr(model, name, value)
except Exception:
object.__setattr__(model, name, value)
try:
_set_metadata_attr("_moviepilot_llm_runtime", runtime_metadata["runtime"])
_set_metadata_attr(
"_moviepilot_llm_provider_id",
runtime_metadata["provider_id"],
)
_set_metadata_attr("_moviepilot_llm_base_url", runtime_metadata["base_url"])
except Exception as err:
logger.debug(f"LLM运行时元数据附加失败: {str(err)}")
profile = getattr(model, "profile", None)
if isinstance(profile, dict):
profile["moviepilot_runtime"] = runtime_metadata["runtime"]
profile["moviepilot_provider_id"] = runtime_metadata["provider_id"]
profile["moviepilot_base_url"] = runtime_metadata["base_url"]
@classmethod
def _resolve_thinking_level(
cls,
@@ -1011,6 +1120,7 @@ class LLMHelper:
"max_input_tokens": int(max_input_tokens),
}
cls._attach_runtime_metadata(model, runtime)
return model
@staticmethod

View File

@@ -1564,6 +1564,70 @@ class LLMProviderManager(metaclass=Singleton):
return models[candidate]
return None
def _cached_models_dev_model(
self,
provider_id: str,
model_id: str,
base_url: Optional[str] = None,
base_url_preset_id: Optional[str] = None,
) -> dict[str, Any] | None:
"""从已缓存或内置的 models.dev 数据中同步读取模型元数据。"""
try:
spec = self.get_provider(provider_id)
except LLMProviderError:
return None
models_dev_provider_id = self._resolve_provider_models_dev_provider_id(
spec,
base_url,
base_url_preset_id=base_url_preset_id,
)
if not models_dev_provider_id:
return None
payload = self._cached_models_dev_payload().get(models_dev_provider_id, {}) or {}
models = payload.get("models") if isinstance(payload, dict) else None
if not isinstance(models, dict):
return None
candidates = [model_id]
if model_id.startswith("models/"):
candidates.append(model_id.removeprefix("models/"))
for candidate in candidates:
if candidate in models:
return models[candidate]
return None
def resolve_cached_model_metadata(
self,
provider_id: str,
model_id: Optional[str],
base_url: Optional[str] = None,
base_url_preset_id: Optional[str] = None,
) -> dict[str, Any] | None:
"""同步解析缓存中的模型元数据,不触发远端 models.dev 刷新。"""
if not model_id:
return None
metadata = self._cached_models_dev_model(
provider_id,
model_id,
base_url=base_url,
base_url_preset_id=base_url_preset_id,
)
if metadata:
return metadata
if provider_id == "chatgpt":
return self._cached_models_dev_model("openai", model_id)
if provider_id == "openai":
return (
self._cached_models_dev_payload()
.get("openai", {})
.get("models", {})
.get(model_id)
)
return None
@staticmethod
def _normalize_model_record(
model_id: str,

View File

@@ -20,7 +20,7 @@ from langchain.agents.middleware.tool_selection import (
LLMToolSelectorMiddleware,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
@@ -70,17 +70,13 @@ class ToolSelectionStateUpdate(TypedDict):
class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
"""
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现
使用 provider-neutral JSON 提示执行工具筛选。
LangChain 默认会通过 `with_structured_output()` 走 OpenAI
`response_format=json_schema` 路径,但 DeepSeek 官方 OpenAI 兼容端点公开文档
仅保证 `json_object` 模式可用。对于 `deepseek-reasoner`,这会在工具筛选阶段
提前触发 400导致 Agent 还没真正开始执行工具就失败。
因此这里仅在识别到 DeepSeek 模型/端点时,退回到显式 JSON 输出模式:
1. 使用 `response_format={"type": "json_object"}`
2. 在提示词中明确约束返回 JSON 结构;
3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。
LangChain 默认会通过 `with_structured_output()` 走 provider-specific
结构化输出能力,不同 OpenAI/Anthropic 兼容端点对 `response_format`、
JSON schema 和工具绑定的支持并不一致。工具筛选只是 Agent 执行前的
辅助优化,失败时也会恢复使用全部工具,因此这里统一使用文本提示约束
模型返回 `{"tools": [...]}` 并手动解析,避免在筛选阶段引入额外兼容分支。
另外LangChain 原生工具筛选挂在 `wrap_model_call` 上,会在同一条用户请求
的每次“模型回合”前都重新筛选一次工具。对于会多轮调用工具的复杂任务,
@@ -354,40 +350,13 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
request,
)
@staticmethod
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
"""
判断当前模型是否应当走 DeepSeek JSON 兼容分支。
除了官方 `langchain_deepseek`,用户也可能通过 OpenAI-compatible
配置把 DeepSeek 端点接到 `ChatOpenAI`。因此这里同时检查模块名、模型名
和 Base URL避免只靠单一条件漏判。
"""
module_name = type(model).__module__.lower()
model_name = (
str(getattr(model, "model_name", "") or getattr(model, "model", ""))
.strip()
.lower()
)
base_url = (
str(getattr(model, "openai_api_base", "") or getattr(model, "api_base", ""))
.strip()
.lower()
)
return (
"deepseek" in module_name
or model_name.startswith("deepseek-")
or "api.deepseek.com" in base_url
)
@staticmethod
def _parse_json_object(text: str) -> dict[str, Any]:
"""
解析模型返回的 JSON。
DeepSeek 在 JSON 模式下通常会返回纯 JSON但这里仍做一层兜底
兼容模型偶发输出围栏或前后说明文本的情况
不同模型可能偶发输出 Markdown 围栏或前后说明文本,因此这里从
响应中提取第一个 JSON 对象作为兜底
"""
stripped_text = text.strip()
if not stripped_text:
@@ -440,12 +409,12 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
)
return f"Capability groups from tool tags:\n{rendered_groups}\n\n"
def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
def _build_json_selection_prompt(self, selection_request: Any) -> str:
"""
为 DeepSeek 生成显式 JSON 输出提示。
生成显式 JSON 输出提示。
DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON
约束,否则兼容端点可能返回空内容或无意义输出
使用纯提示约束可覆盖更多兼容端点,避免在工具筛选阶段依赖某个
provider 专属的 `response_format` 或 schema 能力
"""
limit_instruction = ""
if self.max_tools:
@@ -469,7 +438,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
"""
解析并标准化 DeepSeek JSON 模式的工具筛选结果。
解析并标准化显式 JSON 模式的工具筛选结果。
"""
content = getattr(response, "content", response)
text = LLMHelper.extract_text_content(content)
@@ -486,22 +455,21 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
logger.debug(f"工具筛选标准化结果: {normalized_tools}")
return {"tools": normalized_tools}
async def _aselect_tools_with_deepseek(
async def _aselect_tools_with_json_prompt(
self, selection_request: Any
) -> dict[str, list[str]]:
"""
使用 DeepSeek 兼容的 JSON 输出模式执行异步工具筛选。
使用 JSON 提示执行异步工具筛选。
:param selection_request: LangChain 工具筛选请求
:return: 标准化后的工具名列表
"""
logger.debug("工具筛选走 DeepSeek JSON 兼容分支")
structured_model = selection_request.model.bind(
response_format={"type": "json_object"}
)
response = await structured_model.ainvoke(
logger.debug("工具筛选走 JSON 提示分支")
response = await selection_request.model.ainvoke(
[
{
"role": "system",
"content": self._build_deepseek_selection_prompt(selection_request),
},
SystemMessage(
content=self._build_json_selection_prompt(selection_request)
),
selection_request.last_user_message,
]
)
@@ -550,26 +518,17 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
if selection_request is None:
return request
if not self._is_deepseek_compatible_model(selection_request.model):
captured_request: ModelRequest[ContextT] = request
async def _capture_handler(
updated_request: ModelRequest[ContextT],
) -> ModelRequest[ContextT]:
nonlocal captured_request
captured_request = updated_request
return updated_request
await super().awrap_model_call(request, _capture_handler)
return captured_request
response = await self._aselect_tools_with_deepseek(selection_request)
return self._process_selection_response(
response,
selection_request.available_tools,
selection_request.valid_tool_names,
request,
)
try:
response = await self._aselect_tools_with_json_prompt(selection_request)
return self._process_selection_response(
response,
selection_request.available_tools,
selection_request.valid_tool_names,
request,
)
except Exception as err:
logger.warning(f"工具筛选失败,将恢复使用所有工具: {str(err)}")
return request
async def abefore_agent( # noqa
self,

View File

@@ -243,7 +243,7 @@ class MessageChain(ChainBase):
processing_status=processing_status,
)
finally:
if continues_async is not True:
if continues_async:
self._mark_message_processing_finished(
channel=channel,
source=source,
@@ -1278,7 +1278,10 @@ class MessageChain(ChainBase):
# 将可直接输入给 LLM 的附件统一转换为 data URL
original_images = images
all_files = list(files or [])
if images and LLMHelper.supports_image_input():
if images and LLMHelper.supports_image_input(
provider=settings.LLM_PROVIDER,
model=settings.LLM_MODEL,
):
images = self._download_attachments_to_data_urls(
images, channel, source
)

View File

@@ -0,0 +1,101 @@
from unittest.mock import AsyncMock, patch
from app.agent import MoviePilotAgent
from app.agent.llm import AgentCapabilityManager, LLMHelper
from app.chain.message import MessageChain
from app.core.config import settings
from app.schemas.types import MessageChannel
def test_llm_supports_image_input_uses_model_catalog_text_only(monkeypatch):
"""内置目录明确为纯文本模型时,应自动关闭图片输入。"""
monkeypatch.setattr(settings, "LLM_SUPPORT_IMAGE_INPUT", True)
assert not LLMHelper.supports_image_input(
provider="minimax",
model="MiniMax-M2.7",
)
def test_llm_supports_image_input_keeps_known_vision_model(monkeypatch):
"""内置目录明确为视觉模型时,应允许图片输入。"""
monkeypatch.setattr(settings, "LLM_SUPPORT_IMAGE_INPUT", True)
assert LLMHelper.supports_image_input(
provider="zhipuai",
model="glm-5v-turbo",
)
def test_llm_supports_image_input_keeps_unknown_model_override(monkeypatch):
"""未知自定义模型保持用户开关语义,避免误伤私有视觉模型。"""
monkeypatch.setattr(settings, "LLM_SUPPORT_IMAGE_INPUT", True)
assert LLMHelper.supports_image_input(
provider="custom-provider",
model="custom-vlm-model",
)
def test_agent_capability_manager_delegates_image_support():
"""Agent 能力管理器应复用统一的模型图片能力判断。"""
with patch.object(LLMHelper, "supports_image_input", return_value=False) as supports:
assert not AgentCapabilityManager.supports_image_input()
supports.assert_called_once_with()
def test_handle_ai_message_routes_text_only_model_images_to_files(monkeypatch):
"""纯文本模型收到图片消息时,应降级为文件附件而非 image_url 内容块。"""
chain = MessageChain()
monkeypatch.setattr(settings, "AI_AGENT_ENABLE", True)
monkeypatch.setattr(settings, "LLM_SUPPORT_IMAGE_INPUT", True)
monkeypatch.setattr(settings, "LLM_PROVIDER", "minimax")
monkeypatch.setattr(settings, "LLM_MODEL", "MiniMax-M2.7")
with patch.object(
chain, "_get_or_create_session_id", return_value="session-1"
), patch.object(
chain, "_download_attachments_to_data_urls"
) as download_images, patch.object(
chain,
"_prepare_agent_files",
return_value=[
{
"name": "image_1.jpg",
"mime_type": "image/jpeg",
"local_path": "/tmp/image_1.jpg",
"status": "ready",
}
],
) as prepare_files, patch(
"app.chain.message.agent_manager.process_message", new_callable=AsyncMock
) as process_message, patch(
"app.chain.message.asyncio.run_coroutine_threadsafe",
side_effect=lambda coro, _loop: coro.close(),
):
chain._handle_ai_message(
text="/ai 帮我看看这张图",
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
images=["tg://file_id/image-1"],
)
download_images.assert_not_called()
prepare_files.assert_called_once()
assert prepare_files.call_args.kwargs["files"][0].ref == "tg://file_id/image-1"
assert process_message.call_args.kwargs["images"] is None
assert process_message.call_args.kwargs["files"][0]["local_path"] == "/tmp/image_1.jpg"
def test_unsupported_image_error_recognizes_vlm_text_only_message():
"""兼容端点返回 not a VLM 时,应识别为图片输入能力错误。"""
error = Exception(
"Error code: 400 - {'code': 20041, 'message': "
"'The model is not a VLM (Vision Language Model). "
"Please use text-only prompts.'}"
)
assert MoviePilotAgent._is_unsupported_image_input_error(error)

View File

@@ -1,9 +1,7 @@
import asyncio
import unittest
from types import SimpleNamespace
from unittest.mock import patch
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from app.agent.middleware import tool_selection as tool_selector_module
from app.agent.tools.tags import ToolTag
@@ -25,14 +23,22 @@ class _FakeBoundModel:
class _FakeModel:
def __init__(
self,
*,
content='{"tools": ["calendar", "search"]}',
model_name="deepseek-reasoner",
base_url="https://api.deepseek.com",
self,
*,
content='{"tools": ["calendar", "search"]}',
model_name="gpt-4o-mini",
base_url="https://api.openai.com/v1",
runtime=None,
):
self.model_name = model_name
self.model = model_name
self.openai_api_base = base_url
self.api_base = base_url
self.base_url = base_url
self._moviepilot_llm_runtime = runtime
self._moviepilot_llm_base_url = base_url
self.messages = None
self.ainvoke_calls = []
self.bind_calls = []
self.bound_model = _FakeBoundModel(content)
@@ -40,6 +46,11 @@ class _FakeModel:
self.bind_calls.append(kwargs)
return self.bound_model
async def ainvoke(self, messages):
self.messages = messages
self.ainvoke_calls.append(messages)
return SimpleNamespace(content=self.bound_model.content)
class _FakeRequest:
def __init__(self, *, tools, messages, model, state=None, runtime=None):
@@ -66,418 +77,479 @@ def _tool(name, description, tags=None):
return SimpleNamespace(name=name, description=description, tags=tags or [])
class ToolSelectorMiddlewareTest(unittest.TestCase):
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel()
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
handled_requests = []
def test_awrap_model_call_uses_json_prompt_for_all_models():
"""工具筛选应统一使用 JSON 提示,不绑定 provider 专属参数。"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel()
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
handled_requests = []
async def handler(updated_request):
handled_requests.append(updated_request)
return updated_request
async def handler(updated_request):
handled_requests.append(updated_request)
return updated_request
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
if state_update:
request.state.update(state_update)
result = asyncio.run(middleware.awrap_model_call(request, handler))
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
if state_update:
request.state.update(state_update)
result = asyncio.run(middleware.awrap_model_call(request, handler))
self.assertEqual(
model.bind_calls,
[{"response_format": {"type": "json_object"}}],
)
self.assertEqual(
[tool.name for tool in result.tools],
["search", "calendar"],
)
prompt = model.bound_model.messages[0]["content"]
self.assertIn("Return the answer in JSON only.", prompt)
self.assertIn('- search: Search for information', prompt)
self.assertIn('- calendar: Manage events', prompt)
self.assertIn("MoviePilot tool-chain hints:", prompt)
self.assertEqual(len(handled_requests), 1)
assert model.bind_calls == []
assert [tool.name for tool in result.tools] == ["search", "calendar"]
system_message = model.messages[0]
assert isinstance(system_message, SystemMessage)
prompt = system_message.content
assert "Return the answer in JSON only." in prompt
assert "- search: Search for information" in prompt
assert "- calendar: Manage events" in prompt
assert "MoviePilot tool-chain hints:" in prompt
assert len(handled_requests) == 1
def test_awrap_model_call_reuses_first_selection_for_later_model_rounds(self):
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel(content='{"tools": ["calendar", "search"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
handled_requests = []
async def handler(updated_request):
handled_requests.append(updated_request)
return updated_request
def test_awrap_model_call_uses_same_json_prompt_for_minimax():
"""MiniMax 工具筛选也应复用同一套 JSON 提示路径。"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel(
model_name="MiniMax-M2.7",
base_url="https://api.minimaxi.com/anthropic/v1",
runtime="anthropic_compatible",
)
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
if state_update:
request.state.update(state_update)
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
self.assertEqual(
model.bind_calls,
[{"response_format": {"type": "json_object"}}],
)
self.assertEqual(
[tool.name for tool in first_result.tools],
["search", "calendar"],
)
self.assertEqual(
[tool.name for tool in second_result.tools],
["search", "calendar"],
)
self.assertEqual(len(handled_requests), 2)
assert state_update == {"selected_tool_names": ["search", "calendar"]}
assert model.bind_calls == []
system_message = model.messages[0]
assert isinstance(system_message, SystemMessage)
assert "Return the answer in JSON only." in system_message.content
def test_awrap_model_call_caches_non_deepseek_selection_too(self):
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel(
model_name="gpt-4o-mini",
base_url="https://api.openai.com/v1",
)
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
async def handler(updated_request):
return updated_request
def test_awrap_model_call_uses_prompt_json_for_anthropic_runtime():
"""Anthropic-compatible runtime 不应触发额外 provider 分支。"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel(
model_name="kimi-k2",
base_url="https://example.com/anthropic/v1",
runtime="anthropic_compatible",
)
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
parent_calls = 0
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
async def _fake_parent_awrap(self, request_arg, handler_arg):
nonlocal parent_calls
parent_calls += 1
selected_request = request_arg.override(
tools=[request_arg.tools[1], request_arg.tools[0]]
)
return await handler_arg(selected_request)
assert state_update == {"selected_tool_names": ["search", "calendar"]}
assert model.bind_calls == []
system_message = model.messages[0]
assert isinstance(system_message, SystemMessage)
assert "Return the answer in JSON only." in system_message.content
with patch.object(
tool_selector_module.LLMToolSelectorMiddleware,
"awrap_model_call",
_fake_parent_awrap,
):
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
if state_update:
request.state.update(state_update)
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
self.assertEqual(parent_calls, 1)
self.assertEqual(
[tool.name for tool in first_result.tools],
["calendar", "search"],
)
self.assertEqual(
[tool.name for tool in second_result.tools],
["calendar", "search"],
)
def test_awrap_model_call_reuses_first_selection_for_later_model_rounds():
"""多轮模型回合应复用首轮筛选出的工具集合。"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel(content='{"tools": ["calendar", "search"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
handled_requests = []
def test_normalize_selection_response_accepts_code_fence_json(self):
middleware = tool_selector_module.ToolSelectorMiddleware()
response = SimpleNamespace(
content=[
{
"type": "text",
"text": '```json\n{"tools": ["search"]}\n```',
}
]
)
async def handler(updated_request):
handled_requests.append(updated_request)
return updated_request
normalized = middleware._normalize_selection_response(response)
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
if state_update:
request.state.update(state_update)
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
self.assertEqual(normalized, {"tools": ["search"]})
assert model.bind_calls == []
assert [tool.name for tool in first_result.tools] == ["search", "calendar"]
assert [tool.name for tool in second_result.tools] == ["search", "calendar"]
assert len(handled_requests) == 2
assert len(model.ainvoke_calls) == 1
def test_deepseek_selection_uses_recent_conversation_context(self):
"""多轮追问时工具筛选应看到上一轮用户需求和助手回复。"""
tools = [
_tool(
"query_plugin_config",
"Query plugin config",
[ToolTag.Read, ToolTag.Plugin, ToolTag.Settings],
),
_tool(
"update_plugin_config",
"Update plugin config",
[ToolTag.Write, ToolTag.Plugin, ToolTag.Settings],
),
_tool(
"reload_plugin",
"Reload plugin",
[ToolTag.Write, ToolTag.Plugin],
),
]
model = _FakeModel(content='{"tools": ["query_plugin_config"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=3,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[
HumanMessage(content="帮我检查插件 DemoPlugin 的配置"),
AIMessage(content="我建议先查询插件配置,然后根据结果决定是否重载插件。"),
HumanMessage(content="按你说的来"),
],
model=model,
)
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
def test_awrap_model_call_caches_plain_json_prompt_selection_too():
"""普通模型也应只调用一次 JSON 提示筛选并缓存结果。"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
SimpleNamespace(name="translate", description="Translate text"),
]
model = _FakeModel(
model_name="gpt-4o-mini",
base_url="https://api.openai.com/v1",
)
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
user_message = model.bound_model.messages[1]
self.assertEqual(
state_update,
{"selected_tool_names": ["query_plugin_config", "update_plugin_config", "reload_plugin"]},
)
self.assertIsInstance(user_message, HumanMessage)
self.assertIn(
"Recent conversation context for tool selection",
user_message.content,
)
self.assertIn("帮我检查插件 DemoPlugin 的配置", user_message.content)
self.assertIn("我建议先查询插件配置", user_message.content)
self.assertIn("按你说的来", user_message.content)
async def handler(updated_request):
return updated_request
def test_single_turn_selection_keeps_original_user_message(self):
"""单轮对话不应额外包裹上下文提示。"""
tools = [
_tool("search", "Search for information", [ToolTag.Read, ToolTag.Web]),
_tool("calendar", "Manage events", [ToolTag.Write]),
]
model = _FakeModel(content='{"tools": ["search"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
original_message = HumanMessage(content="帮我查一下最近的更新")
request = _FakeRequest(
tools=tools,
messages=[original_message],
model=model,
)
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
if state_update:
request.state.update(state_update)
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
asyncio.run(middleware.abefore_agent(request.state, runtime=None, config=None))
assert model.bind_calls == []
assert len(model.ainvoke_calls) == 1
assert [tool.name for tool in first_result.tools] == ["search", "calendar"]
assert [tool.name for tool in second_result.tools] == ["search", "calendar"]
user_message = model.bound_model.messages[1]
self.assertIs(user_message, original_message)
self.assertNotIn(
"Recent conversation context for tool selection",
user_message.content,
)
def test_process_selection_response_completes_low_count_tool_group_by_tags(self):
"""筛选结果过少时应按已命中的工具标签组补齐同组工具"""
tools = [
_tool(
"search_media",
"Search media",
[ToolTag.Read, ToolTag.Media],
),
_tool(
"search_torrents",
"Search torrents",
[ToolTag.Read, ToolTag.Resource, ToolTag.Site, ToolTag.Media],
),
_tool(
"get_search_results",
"Get results",
[ToolTag.Read, ToolTag.Resource],
),
_tool(
"add_download_tasks",
"Add downloads",
[ToolTag.Write, ToolTag.Download, ToolTag.Resource],
),
_tool(
"query_download_tasks",
"Query downloads",
[ToolTag.Read, ToolTag.Download],
),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=4,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我下载流浪地球")],
model=_FakeModel(),
)
def test_tool_selection_failure_falls_back_to_all_tools():
"""筛选模型返回空响应时不应中断 Agent 请求"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
]
model = _FakeModel(content=None)
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
model=model,
)
result = middleware._process_selection_response(
{"tools": ["search_torrents"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
self.assertEqual(len(result.tools), 4)
self.assertEqual(
{tool.name for tool in result.tools},
assert state_update == {"selected_tool_names": ["search", "calendar"]}
def test_normalize_selection_response_accepts_code_fence_json():
"""工具筛选响应应兼容 Markdown 代码围栏包裹的 JSON。"""
middleware = tool_selector_module.ToolSelectorMiddleware()
response = SimpleNamespace(
content=[
{
"search_media",
"search_torrents",
"get_search_results",
"add_download_tasks",
},
)
def test_process_selection_response_keeps_high_count_selection(self):
"""筛选结果数量足够时不应额外补齐工具。"""
tools = [
SimpleNamespace(name="search_media", description="Search media"),
SimpleNamespace(name="search_torrents", description="Search torrents"),
SimpleNamespace(name="get_search_results", description="Get results"),
SimpleNamespace(name="query_sites", description="Query sites"),
"type": "text",
"text": '```json\n{"tools": ["search"]}\n```',
}
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=4,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我下载流浪地球")],
model=_FakeModel(),
)
)
result = middleware._process_selection_response(
{
"tools": [
"search_media",
"search_torrents",
"get_search_results",
"query_sites",
]
},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
normalized = middleware._normalize_selection_response(response)
self.assertEqual(
[tool.name for tool in result.tools],
[
assert normalized == {"tools": ["search"]}
def test_json_prompt_selection_uses_recent_conversation_context():
"""多轮追问时工具筛选应看到上一轮用户需求和助手回复。"""
tools = [
_tool(
"query_plugin_config",
"Query plugin config",
[ToolTag.Read, ToolTag.Plugin, ToolTag.Settings],
),
_tool(
"update_plugin_config",
"Update plugin config",
[ToolTag.Write, ToolTag.Plugin, ToolTag.Settings],
),
_tool(
"reload_plugin",
"Reload plugin",
[ToolTag.Write, ToolTag.Plugin],
),
]
model = _FakeModel(content='{"tools": ["query_plugin_config"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=3,
selection_tools=tools,
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[
HumanMessage(content="帮我检查插件 DemoPlugin 的配置"),
AIMessage(content="我建议先查询插件配置,然后根据结果决定是否重载插件。"),
HumanMessage(content="按你说的来"),
],
model=model,
)
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
user_message = model.messages[1]
assert state_update == {
"selected_tool_names": [
"query_plugin_config",
"update_plugin_config",
"reload_plugin",
]
}
assert isinstance(user_message, HumanMessage)
assert "Recent conversation context for tool selection" in user_message.content
assert "帮我检查插件 DemoPlugin 的配置" in user_message.content
assert "我建议先查询插件配置" in user_message.content
assert "按你说的来" in user_message.content
def test_single_turn_selection_keeps_original_user_message():
"""单轮对话不应额外包裹上下文提示。"""
tools = [
_tool("search", "Search for information", [ToolTag.Read, ToolTag.Web]),
_tool("calendar", "Manage events", [ToolTag.Write]),
]
model = _FakeModel(content='{"tools": ["search"]}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
middleware.model = model
original_message = HumanMessage(content="帮我查一下最近的更新")
request = _FakeRequest(
tools=tools,
messages=[original_message],
model=model,
)
asyncio.run(middleware.abefore_agent(request.state, runtime=None, config=None))
user_message = model.messages[1]
assert user_message is original_message
assert "Recent conversation context for tool selection" not in user_message.content
def test_process_selection_response_completes_low_count_tool_group_by_tags():
"""筛选结果过少时应按已命中的工具标签组补齐同组工具。"""
tools = [
_tool(
"search_media",
"Search media",
[ToolTag.Read, ToolTag.Media],
),
_tool(
"search_torrents",
"Search torrents",
[ToolTag.Read, ToolTag.Resource, ToolTag.Site, ToolTag.Media],
),
_tool(
"get_search_results",
"Get results",
[ToolTag.Read, ToolTag.Resource],
),
_tool(
"add_download_tasks",
"Add downloads",
[ToolTag.Write, ToolTag.Download, ToolTag.Resource],
),
_tool(
"query_download_tasks",
"Query downloads",
[ToolTag.Read, ToolTag.Download],
),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=4,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我下载流浪地球")],
model=_FakeModel(),
)
result = middleware._process_selection_response(
{"tools": ["search_torrents"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
assert len(result.tools) == 4
assert {tool.name for tool in result.tools} == {
"search_media",
"search_torrents",
"get_search_results",
"add_download_tasks",
}
def test_process_selection_response_keeps_high_count_selection():
"""筛选结果数量足够时不应额外补齐工具。"""
tools = [
SimpleNamespace(name="search_media", description="Search media"),
SimpleNamespace(name="search_torrents", description="Search torrents"),
SimpleNamespace(name="get_search_results", description="Get results"),
SimpleNamespace(name="query_sites", description="Query sites"),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=4,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我下载流浪地球")],
model=_FakeModel(),
)
result = middleware._process_selection_response(
{
"tools": [
"search_media",
"search_torrents",
"get_search_results",
"query_sites",
],
)
]
},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
def test_process_selection_response_respects_max_tools_when_completing(self):
"""标签组补齐不应突破 max_tools 上限。"""
tools = [
_tool(
"list_directory",
"List directory",
[ToolTag.Read, ToolTag.Directory, ToolTag.File],
),
_tool(
"query_directory_settings",
"Query settings",
[ToolTag.Read, ToolTag.Directory, ToolTag.Settings],
),
_tool(
"recognize_media",
"Recognize media",
[ToolTag.Read, ToolTag.Media],
),
_tool(
"transfer_file",
"Transfer file",
[ToolTag.Write, ToolTag.Transfer, ToolTag.Library, ToolTag.File],
),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我整理这个目录")],
model=_FakeModel(),
)
assert [tool.name for tool in result.tools] == [
"search_media",
"search_torrents",
"get_search_results",
"query_sites",
]
result = middleware._process_selection_response(
{"tools": ["transfer_file"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
self.assertEqual(len(result.tools), 2)
self.assertEqual(
{tool.name for tool in result.tools},
{"transfer_file", "list_directory"},
)
def test_process_selection_response_respects_max_tools_when_completing():
"""标签组补齐不应突破 max_tools 上限。"""
tools = [
_tool(
"list_directory",
"List directory",
[ToolTag.Read, ToolTag.Directory, ToolTag.File],
),
_tool(
"query_directory_settings",
"Query settings",
[ToolTag.Read, ToolTag.Directory, ToolTag.Settings],
),
_tool(
"recognize_media",
"Recognize media",
[ToolTag.Read, ToolTag.Media],
),
_tool(
"transfer_file",
"Transfer file",
[ToolTag.Write, ToolTag.Transfer, ToolTag.Library, ToolTag.File],
),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="帮我整理这个目录")],
model=_FakeModel(),
)
def test_process_selection_response_ignores_generic_tags_when_completing(self):
"""通用权限标签不应被当作工具组使用。"""
tools = [
_tool("read_one", "Read one", [ToolTag.Read]),
_tool("read_two", "Read two", [ToolTag.Read]),
_tool("write_one", "Write one", [ToolTag.Write, ToolTag.Admin]),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=4,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="查一下信息")],
model=_FakeModel(),
)
result = middleware._process_selection_response(
{"tools": ["transfer_file"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
result = middleware._process_selection_response(
{"tools": ["read_one"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
assert len(result.tools) == 2
assert {tool.name for tool in result.tools} == {"transfer_file", "list_directory"}
self.assertEqual([tool.name for tool in result.tools], ["read_one"])
def test_process_selection_response_ignores_generic_tags_when_completing():
"""通用权限标签不应被当作工具组使用。"""
tools = [
_tool("read_one", "Read one", [ToolTag.Read]),
_tool("read_two", "Read two", [ToolTag.Read]),
_tool("write_one", "Write one", [ToolTag.Write, ToolTag.Admin]),
]
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=4,
selection_tools=tools,
)
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="查一下信息")],
model=_FakeModel(),
)
result = middleware._process_selection_response(
{"tools": ["read_one"]},
available_tools=tools,
valid_tool_names=[tool.name for tool in tools],
request=request,
)
assert [tool.name for tool in result.tools] == ["read_one"]

View File

@@ -589,6 +589,59 @@ class LlmHelperTestCallTest(unittest.TestCase):
)
self.assertEqual(llm_calls[0].get("default_headers"), {"X-Test": "1"})
def test_get_llm_attaches_runtime_metadata(self):
"""LLM 实例应带上内部 runtime 元数据,供 Agent 中间件判断兼容分支。"""
class _FakeProviderManager:
async def resolve_runtime(self, **kwargs):
return {
"provider_id": kwargs["provider_id"],
"runtime": "anthropic_compatible",
"model_id": kwargs["model"],
"api_key": kwargs["api_key"],
"base_url": kwargs["base_url"],
"default_headers": None,
"use_responses_api": None,
"model_record": None,
"model_metadata": None,
}
class _FakeChatAnthropic:
def __init__(self, **kwargs):
self.model = kwargs["model"]
self.profile = None
provider_module = ModuleType("app.agent.llm.provider")
provider_module.LLMProviderManager = _FakeProviderManager
anthropic_module = ModuleType("langchain_anthropic")
anthropic_module.ChatAnthropic = _FakeChatAnthropic
with patch.dict(
sys.modules,
{
"app.agent.llm.provider": provider_module,
"langchain_anthropic": anthropic_module,
},
):
model = asyncio.run(
llm_module.LLMHelper.get_llm(
provider="minimax",
model="MiniMax-M2.7",
api_key="sk-test",
base_url="https://api.minimaxi.com/anthropic/v1",
)
)
self.assertEqual(
getattr(model, "_moviepilot_llm_runtime"),
"anthropic_compatible",
)
self.assertEqual(getattr(model, "_moviepilot_llm_provider_id"), "minimax")
self.assertEqual(
getattr(model, "_moviepilot_llm_base_url"),
"https://api.minimaxi.com/anthropic/v1",
)
def test_get_llm_applies_proxy_only_when_enabled(self):
"""LLM 构造时应按独立开关决定是否传入系统代理。"""
calls = []