mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-23 08:34:28 +08:00
Refine existing implementation
This commit is contained in:
@@ -838,6 +838,8 @@ class MoviePilotAgent:
|
||||
detail = cls._exception_detail_text(error).lower()
|
||||
if "no endpoints found that support image input" in detail:
|
||||
return True
|
||||
if "not a vlm" in detail or "text-only prompts" in detail:
|
||||
return True
|
||||
if "unknown variant" in detail and "image_url" in detail:
|
||||
return True
|
||||
if "image input" not in detail and "images" not in detail:
|
||||
|
||||
@@ -691,7 +691,9 @@ class AgentCapabilityManager:
|
||||
@staticmethod
def supports_image_input() -> bool:
    """Return whether the current Agent has image input enabled.

    The decision is delegated to ``LLMHelper.supports_image_input()`` (called
    without arguments) so that the capability check lives in a single place.

    :return: the result of ``LLMHelper.supports_image_input()``
    """
    # Imported inside the function body rather than at module top.
    # NOTE(review): presumably to avoid a circular import — confirm.
    from app.agent.llm.helper import LLMHelper

    return LLMHelper.supports_image_input()
|
||||
|
||||
@staticmethod
|
||||
def supports_audio_input() -> bool:
|
||||
|
||||
@@ -5,7 +5,7 @@ import inspect
|
||||
import json
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
@@ -700,11 +700,85 @@ class LLMHelper:
|
||||
return {}
|
||||
|
||||
@staticmethod
def _metadata_supports_image_input(metadata: Any) -> Optional[bool]:
    """Read the image-input capability from a model metadata record.

    :param metadata: model metadata; only a dict carrying
        ``modalities["input"]`` (a modality name or a list/tuple of names)
        yields an answer
    :return: ``True`` when "image" is among the declared input modalities,
        ``False`` when input modalities are declared without it, and ``None``
        when the metadata does not declare them (capability unknown)
    """
    if not isinstance(metadata, dict):
        return None

    modalities = metadata.get("modalities")
    if not isinstance(modalities, dict):
        # Missing or malformed "modalities" means "unknown", not an error.
        return None

    input_modalities = modalities.get("input")
    if isinstance(input_modalities, str):
        # A single modality name is treated as a one-element list.
        input_modalities = [input_modalities]
    if not isinstance(input_modalities, (list, tuple)):
        return None

    # Compare case-insensitively and ignore surrounding whitespace / None items.
    normalized_modalities = {
        str(item or "").strip().lower() for item in input_modalities
    }
    return "image" in normalized_modalities
|
||||
|
||||
@classmethod
def _resolve_catalog_image_input_support(
    cls,
    provider: Optional[str] = None,
    model: Optional[str] = None,
    base_url: Optional[str] = None,
    base_url_preset: Optional[str] = None,
) -> Optional[bool]:
    """Resolve image-input support for a model from the cached provider catalog.

    Any argument left as ``None`` falls back to the matching ``settings.LLM_*``
    value. The lookup is best-effort: every failure is logged at debug level
    and reported as "unknown".

    :param provider: provider id; defaults to ``settings.LLM_PROVIDER``
    :param model: model id; defaults to ``settings.LLM_MODEL``
    :param base_url: endpoint base URL; defaults to ``settings.LLM_BASE_URL``
    :param base_url_preset: base URL preset id; defaults to
        ``settings.LLM_BASE_URL_PRESET``
    :return: ``True``/``False`` when the catalog metadata declares the input
        modalities, ``None`` when the provider/model is empty, the lookup
        fails, or the metadata does not say
    """
    # "or ''" keeps an unset (None) setting from becoming the literal "None".
    provider_name = str(
        (provider if provider is not None else settings.LLM_PROVIDER) or ""
    ).strip()
    model_name = str(
        (model if model is not None else settings.LLM_MODEL) or ""
    ).strip()
    if not provider_name or not model_name:
        return None

    try:
        from app.agent.llm.provider import LLMProviderManager

        metadata = LLMProviderManager().resolve_cached_model_metadata(
            provider_id=provider_name,
            model_id=model_name,
            base_url=base_url if base_url is not None else settings.LLM_BASE_URL,
            base_url_preset_id=(
                base_url_preset
                if base_url_preset is not None
                else settings.LLM_BASE_URL_PRESET
            ),
        )
        # Parsed inside the try so malformed catalog data degrades to "unknown"
        # instead of propagating to the caller.
        return cls._metadata_supports_image_input(metadata)
    except Exception as err:
        logger.debug(f"解析模型图片能力失败: {err}")
        return None
|
||||
|
||||
@classmethod
def supports_image_input(
    cls,
    provider: Optional[str] = None,
    model: Optional[str] = None,
    base_url: Optional[str] = None,
    base_url_preset: Optional[str] = None,
) -> bool:
    """
    Decide whether image input is enabled for the current model.

    The user switch (``settings.LLM_SUPPORT_IMAGE_INPUT``) is the master
    switch. When the built-in model catalog explicitly marks the model as not
    accepting image input, the result is downgraded to text-only even if the
    switch is on, so that a text model does not receive ``image_url`` content
    blocks and get rejected with a 400 by a compatible endpoint. A call
    without ``provider`` and ``model`` keeps the legacy "read the master
    switch only" semantics, and unknown custom models also keep the switch
    semantics.

    :param provider: provider id used for the catalog lookup
    :param model: model id used for the catalog lookup
    :param base_url: endpoint base URL forwarded to the catalog lookup
    :param base_url_preset: base URL preset id forwarded to the catalog lookup
    :return: ``False`` when the switch is off or the catalog says the model
        has no image input; ``True`` otherwise
    """
    if not settings.LLM_SUPPORT_IMAGE_INPUT:
        return False
    if provider is None and model is None:
        # Legacy no-argument call: only the master switch is consulted.
        return True

    image_support = cls._resolve_catalog_image_input_support(
        provider=provider,
        model=model,
        base_url=base_url,
        base_url_preset=base_url_preset,
    )
    # Only an explicit catalog answer overrides the switch; "unknown" keeps it on.
    return image_support if image_support is not None else True
|
||||
|
||||
@staticmethod
|
||||
def _build_legacy_runtime(
|
||||
@@ -798,6 +872,41 @@ class LLMHelper:
|
||||
return True
|
||||
return None
|
||||
|
||||
@staticmethod
def _attach_runtime_metadata(model: Any, runtime: dict[str, Any]) -> None:
    """
    Attach the provider runtime info already resolved by MoviePilot to a model.

    The values are stored as private ``_moviepilot_llm_*`` attributes and, when
    the model exposes a dict ``profile``, mirrored into it. They are meant for
    internal middleware to recognise protocol capabilities and are not part of
    LangChain request serialization.

    :param model: the model instance to annotate
    :param runtime: resolved runtime info; the ``runtime``, ``provider_id`` and
        ``base_url`` keys are read (missing keys become ``None``)
    """
    runtime_value = runtime.get("runtime")
    provider_value = runtime.get("provider_id")
    base_url_value = runtime.get("base_url")

    def _force_setattr(attr_name: str, attr_value: Any) -> None:
        # Fall back to object.__setattr__ when the model's own __setattr__
        # refuses the attribute.
        # NOTE(review): presumably needed for pydantic-based models — confirm.
        try:
            setattr(model, attr_name, attr_value)
        except Exception:
            object.__setattr__(model, attr_name, attr_value)

    private_attrs = (
        ("_moviepilot_llm_runtime", runtime_value),
        ("_moviepilot_llm_provider_id", provider_value),
        ("_moviepilot_llm_base_url", base_url_value),
    )
    try:
        for attr_name, attr_value in private_attrs:
            _force_setattr(attr_name, attr_value)
    except Exception as err:
        # Best-effort: a failure is only logged; remaining attributes are skipped.
        logger.debug(f"LLM运行时元数据附加失败: {str(err)}")

    profile = getattr(model, "profile", None)
    if not isinstance(profile, dict):
        return
    profile["moviepilot_runtime"] = runtime_value
    profile["moviepilot_provider_id"] = provider_value
    profile["moviepilot_base_url"] = base_url_value
|
||||
|
||||
@classmethod
|
||||
def _resolve_thinking_level(
|
||||
cls,
|
||||
@@ -1011,6 +1120,7 @@ class LLMHelper:
|
||||
"max_input_tokens": int(max_input_tokens),
|
||||
}
|
||||
|
||||
cls._attach_runtime_metadata(model, runtime)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1564,6 +1564,70 @@ class LLMProviderManager(metaclass=Singleton):
|
||||
return models[candidate]
|
||||
return None
|
||||
|
||||
def _cached_models_dev_model(
    self,
    provider_id: str,
    model_id: str,
    base_url: Optional[str] = None,
    base_url_preset_id: Optional[str] = None,
) -> dict[str, Any] | None:
    """Synchronously read a model's metadata from cached or built-in models.dev data.

    :param provider_id: provider id passed to ``get_provider``
    :param model_id: model id to look up; a leading ``models/`` prefix is also
        tried without the prefix
    :param base_url: endpoint base URL used to resolve the models.dev provider id
    :param base_url_preset_id: base URL preset id used for the same resolution
    :return: the catalog entry for the model, or ``None`` when the provider is
        unknown, no models.dev provider id resolves, or the model is not listed
    """
    try:
        provider_spec = self.get_provider(provider_id)
    except LLMProviderError:
        return None

    catalog_provider_id = self._resolve_provider_models_dev_provider_id(
        provider_spec,
        base_url,
        base_url_preset_id=base_url_preset_id,
    )
    if not catalog_provider_id:
        return None

    provider_payload = (
        self._cached_models_dev_payload().get(catalog_provider_id, {}) or {}
    )
    catalog_models = (
        provider_payload.get("models") if isinstance(provider_payload, dict) else None
    )
    if not isinstance(catalog_models, dict):
        return None

    # Try the id as given first, then without the "models/" prefix.
    lookup_ids = [model_id]
    if model_id.startswith("models/"):
        lookup_ids.append(model_id.removeprefix("models/"))

    return next(
        (
            catalog_models[lookup_id]
            for lookup_id in lookup_ids
            if lookup_id in catalog_models
        ),
        None,
    )
|
||||
|
||||
def resolve_cached_model_metadata(
|
||||
self,
|
||||
provider_id: str,
|
||||
model_id: Optional[str],
|
||||
base_url: Optional[str] = None,
|
||||
base_url_preset_id: Optional[str] = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""同步解析缓存中的模型元数据,不触发远端 models.dev 刷新。"""
|
||||
if not model_id:
|
||||
return None
|
||||
metadata = self._cached_models_dev_model(
|
||||
provider_id,
|
||||
model_id,
|
||||
base_url=base_url,
|
||||
base_url_preset_id=base_url_preset_id,
|
||||
)
|
||||
if metadata:
|
||||
return metadata
|
||||
if provider_id == "chatgpt":
|
||||
return self._cached_models_dev_model("openai", model_id)
|
||||
if provider_id == "openai":
|
||||
return (
|
||||
self._cached_models_dev_payload()
|
||||
.get("openai", {})
|
||||
.get("models", {})
|
||||
.get(model_id)
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_record(
|
||||
model_id: str,
|
||||
|
||||
@@ -20,7 +20,7 @@ from langchain.agents.middleware.tool_selection import (
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
@@ -70,17 +70,13 @@ class ToolSelectionStateUpdate(TypedDict):
|
||||
|
||||
class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
使用 provider-neutral JSON 提示执行工具筛选。
|
||||
|
||||
LangChain 默认会通过 `with_structured_output()` 走 OpenAI 的
|
||||
`response_format=json_schema` 路径,但 DeepSeek 官方 OpenAI 兼容端点公开文档
|
||||
仅保证 `json_object` 模式可用。对于 `deepseek-reasoner`,这会在工具筛选阶段
|
||||
提前触发 400,导致 Agent 还没真正开始执行工具就失败。
|
||||
|
||||
因此这里仅在识别到 DeepSeek 模型/端点时,退回到显式 JSON 输出模式:
|
||||
1. 使用 `response_format={"type": "json_object"}`;
|
||||
2. 在提示词中明确约束返回 JSON 结构;
|
||||
3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。
|
||||
LangChain 默认会通过 `with_structured_output()` 走 provider-specific 的
|
||||
结构化输出能力,不同 OpenAI/Anthropic 兼容端点对 `response_format`、
|
||||
JSON schema 和工具绑定的支持并不一致。工具筛选只是 Agent 执行前的
|
||||
辅助优化,失败时也会恢复使用全部工具,因此这里统一使用文本提示约束
|
||||
模型返回 `{"tools": [...]}` 并手动解析,避免在筛选阶段引入额外兼容分支。
|
||||
|
||||
另外,LangChain 原生工具筛选挂在 `wrap_model_call` 上,会在同一条用户请求
|
||||
的每次“模型回合”前都重新筛选一次工具。对于会多轮调用工具的复杂任务,
|
||||
@@ -354,40 +350,13 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
request,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
|
||||
"""
|
||||
判断当前模型是否应当走 DeepSeek JSON 兼容分支。
|
||||
|
||||
除了官方 `langchain_deepseek`,用户也可能通过 OpenAI-compatible
|
||||
配置把 DeepSeek 端点接到 `ChatOpenAI`。因此这里同时检查模块名、模型名
|
||||
和 Base URL,避免只靠单一条件漏判。
|
||||
"""
|
||||
module_name = type(model).__module__.lower()
|
||||
model_name = (
|
||||
str(getattr(model, "model_name", "") or getattr(model, "model", ""))
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
base_url = (
|
||||
str(getattr(model, "openai_api_base", "") or getattr(model, "api_base", ""))
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
|
||||
return (
|
||||
"deepseek" in module_name
|
||||
or model_name.startswith("deepseek-")
|
||||
or "api.deepseek.com" in base_url
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_json_object(text: str) -> dict[str, Any]:
|
||||
"""
|
||||
解析模型返回的 JSON。
|
||||
|
||||
DeepSeek 在 JSON 模式下通常会返回纯 JSON,但这里仍做一层兜底,
|
||||
兼容模型偶发输出围栏或前后说明文本的情况。
|
||||
不同模型可能偶发输出 Markdown 围栏或前后说明文本,因此这里从
|
||||
响应中提取第一个 JSON 对象作为兜底。
|
||||
"""
|
||||
stripped_text = text.strip()
|
||||
if not stripped_text:
|
||||
@@ -440,12 +409,12 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
)
|
||||
return f"Capability groups from tool tags:\n{rendered_groups}\n\n"
|
||||
|
||||
def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
|
||||
def _build_json_selection_prompt(self, selection_request: Any) -> str:
|
||||
"""
|
||||
为 DeepSeek 生成显式 JSON 输出提示。
|
||||
生成显式 JSON 输出提示。
|
||||
|
||||
DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON
|
||||
约束,否则兼容端点可能返回空内容或无意义输出。
|
||||
使用纯提示约束可覆盖更多兼容端点,避免在工具筛选阶段依赖某个
|
||||
provider 专属的 `response_format` 或 schema 能力。
|
||||
"""
|
||||
limit_instruction = ""
|
||||
if self.max_tools:
|
||||
@@ -469,7 +438,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
|
||||
def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
|
||||
"""
|
||||
解析并标准化 DeepSeek JSON 模式的工具筛选结果。
|
||||
解析并标准化显式 JSON 模式的工具筛选结果。
|
||||
"""
|
||||
content = getattr(response, "content", response)
|
||||
text = LLMHelper.extract_text_content(content)
|
||||
@@ -486,22 +455,21 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
logger.debug(f"工具筛选标准化结果: {normalized_tools}")
|
||||
return {"tools": normalized_tools}
|
||||
|
||||
async def _aselect_tools_with_deepseek(
|
||||
async def _aselect_tools_with_json_prompt(
|
||||
self, selection_request: Any
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
使用 DeepSeek 兼容的 JSON 输出模式执行异步工具筛选。
|
||||
使用 JSON 提示执行异步工具筛选。
|
||||
|
||||
:param selection_request: LangChain 工具筛选请求
|
||||
:return: 标准化后的工具名列表
|
||||
"""
|
||||
logger.debug("工具筛选走 DeepSeek JSON 兼容分支")
|
||||
structured_model = selection_request.model.bind(
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
response = await structured_model.ainvoke(
|
||||
logger.debug("工具筛选走 JSON 提示分支")
|
||||
response = await selection_request.model.ainvoke(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": self._build_deepseek_selection_prompt(selection_request),
|
||||
},
|
||||
SystemMessage(
|
||||
content=self._build_json_selection_prompt(selection_request)
|
||||
),
|
||||
selection_request.last_user_message,
|
||||
]
|
||||
)
|
||||
@@ -550,26 +518,17 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
if selection_request is None:
|
||||
return request
|
||||
|
||||
if not self._is_deepseek_compatible_model(selection_request.model):
|
||||
captured_request: ModelRequest[ContextT] = request
|
||||
|
||||
async def _capture_handler(
|
||||
updated_request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
nonlocal captured_request
|
||||
captured_request = updated_request
|
||||
return updated_request
|
||||
|
||||
await super().awrap_model_call(request, _capture_handler)
|
||||
return captured_request
|
||||
|
||||
response = await self._aselect_tools_with_deepseek(selection_request)
|
||||
return self._process_selection_response(
|
||||
response,
|
||||
selection_request.available_tools,
|
||||
selection_request.valid_tool_names,
|
||||
request,
|
||||
)
|
||||
try:
|
||||
response = await self._aselect_tools_with_json_prompt(selection_request)
|
||||
return self._process_selection_response(
|
||||
response,
|
||||
selection_request.available_tools,
|
||||
selection_request.valid_tool_names,
|
||||
request,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warning(f"工具筛选失败,将恢复使用所有工具: {str(err)}")
|
||||
return request
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self,
|
||||
|
||||
@@ -243,7 +243,7 @@ class MessageChain(ChainBase):
|
||||
processing_status=processing_status,
|
||||
)
|
||||
finally:
|
||||
if continues_async is not True:
|
||||
if continues_async is not True:
|
||||
self._mark_message_processing_finished(
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -1278,7 +1278,10 @@ class MessageChain(ChainBase):
|
||||
# 将可直接输入给 LLM 的附件统一转换为 data URL
|
||||
original_images = images
|
||||
all_files = list(files or [])
|
||||
if images and LLMHelper.supports_image_input():
|
||||
if images and LLMHelper.supports_image_input(
|
||||
provider=settings.LLM_PROVIDER,
|
||||
model=settings.LLM_MODEL,
|
||||
):
|
||||
images = self._download_attachments_to_data_urls(
|
||||
images, channel, source
|
||||
)
|
||||
|
||||
101
tests/test_agent_image_capability.py
Normal file
101
tests/test_agent_image_capability.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent import MoviePilotAgent
|
||||
from app.agent.llm import AgentCapabilityManager, LLMHelper
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.config import settings
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
def test_llm_supports_image_input_uses_model_catalog_text_only(monkeypatch):
    """Image input is switched off automatically when the built-in catalog marks the model as text-only."""
    monkeypatch.setattr(settings, "LLM_SUPPORT_IMAGE_INPUT", True)

    image_enabled = LLMHelper.supports_image_input(
        provider="minimax",
        model="MiniMax-M2.7",
    )

    assert not image_enabled
|
||||
|
||||
|
||||
def test_llm_supports_image_input_keeps_known_vision_model(monkeypatch):
    """Image input stays allowed when the built-in catalog marks the model as a vision model."""
    monkeypatch.setattr(settings, "LLM_SUPPORT_IMAGE_INPUT", True)

    image_enabled = LLMHelper.supports_image_input(
        provider="zhipuai",
        model="glm-5v-turbo",
    )

    assert image_enabled
|
||||
|
||||
|
||||
def test_llm_supports_image_input_keeps_unknown_model_override(monkeypatch):
    """An unknown custom model keeps the user-switch semantics, so private vision models are not blocked."""
    monkeypatch.setattr(settings, "LLM_SUPPORT_IMAGE_INPUT", True)

    image_enabled = LLMHelper.supports_image_input(
        provider="custom-provider",
        model="custom-vlm-model",
    )

    assert image_enabled
|
||||
|
||||
|
||||
def test_agent_capability_manager_delegates_image_support():
    """The Agent capability manager reuses the shared model image-capability check."""
    with patch.object(
        LLMHelper, "supports_image_input", return_value=False
    ) as supports:
        image_enabled = AgentCapabilityManager.supports_image_input()

    assert not image_enabled
    # The delegation must be a plain no-argument call.
    supports.assert_called_once_with()
|
||||
|
||||
|
||||
def test_handle_ai_message_routes_text_only_model_images_to_files(monkeypatch):
    """A text-only model receiving an image message gets it as a file attachment, not an image_url content block."""
    chain = MessageChain()
    setting_overrides = (
        ("AI_AGENT_ENABLE", True),
        ("LLM_SUPPORT_IMAGE_INPUT", True),
        ("LLM_PROVIDER", "minimax"),
        ("LLM_MODEL", "MiniMax-M2.7"),
    )
    for setting_name, setting_value in setting_overrides:
        monkeypatch.setattr(settings, setting_name, setting_value)

    prepared_file = {
        "name": "image_1.jpg",
        "mime_type": "image/jpeg",
        "local_path": "/tmp/image_1.jpg",
        "status": "ready",
    }
    session_patch = patch.object(
        chain, "_get_or_create_session_id", return_value="session-1"
    )
    download_patch = patch.object(chain, "_download_attachments_to_data_urls")
    prepare_patch = patch.object(
        chain, "_prepare_agent_files", return_value=[prepared_file]
    )
    process_patch = patch(
        "app.chain.message.agent_manager.process_message", new_callable=AsyncMock
    )
    # The scheduled coroutine is closed immediately instead of being run.
    schedule_patch = patch(
        "app.chain.message.asyncio.run_coroutine_threadsafe",
        side_effect=lambda coro, _loop: coro.close(),
    )

    with session_patch, download_patch as download_images, \
            prepare_patch as prepare_files, \
            process_patch as process_message, schedule_patch:
        chain._handle_ai_message(
            text="/ai 帮我看看这张图",
            channel=MessageChannel.Telegram,
            source="telegram-test",
            userid="10001",
            username="tester",
            images=["tg://file_id/image-1"],
        )

    download_images.assert_not_called()
    prepare_files.assert_called_once()
    prepared_kwargs = prepare_files.call_args.kwargs
    processed_kwargs = process_message.call_args.kwargs
    assert prepared_kwargs["files"][0].ref == "tg://file_id/image-1"
    assert processed_kwargs["images"] is None
    assert processed_kwargs["files"][0]["local_path"] == "/tmp/image_1.jpg"
|
||||
|
||||
|
||||
def test_unsupported_image_error_recognizes_vlm_text_only_message():
    """A compatible endpoint answering "not a VLM" is recognised as an image-input capability error."""
    endpoint_message = (
        "Error code: 400 - {'code': 20041, 'message': "
        "'The model is not a VLM (Vision Language Model). "
        "Please use text-only prompts.'}"
    )
    error = Exception(endpoint_message)

    recognised = MoviePilotAgent._is_unsupported_image_input_error(error)

    assert recognised
|
||||
@@ -1,9 +1,7 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from app.agent.middleware import tool_selection as tool_selector_module
|
||||
from app.agent.tools.tags import ToolTag
|
||||
@@ -25,14 +23,22 @@ class _FakeBoundModel:
|
||||
|
||||
class _FakeModel:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
content='{"tools": ["calendar", "search"]}',
|
||||
model_name="deepseek-reasoner",
|
||||
base_url="https://api.deepseek.com",
|
||||
self,
|
||||
*,
|
||||
content='{"tools": ["calendar", "search"]}',
|
||||
model_name="gpt-4o-mini",
|
||||
base_url="https://api.openai.com/v1",
|
||||
runtime=None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.model = model_name
|
||||
self.openai_api_base = base_url
|
||||
self.api_base = base_url
|
||||
self.base_url = base_url
|
||||
self._moviepilot_llm_runtime = runtime
|
||||
self._moviepilot_llm_base_url = base_url
|
||||
self.messages = None
|
||||
self.ainvoke_calls = []
|
||||
self.bind_calls = []
|
||||
self.bound_model = _FakeBoundModel(content)
|
||||
|
||||
@@ -40,6 +46,11 @@ class _FakeModel:
|
||||
self.bind_calls.append(kwargs)
|
||||
return self.bound_model
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.messages = messages
|
||||
self.ainvoke_calls.append(messages)
|
||||
return SimpleNamespace(content=self.bound_model.content)
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
def __init__(self, *, tools, messages, model, state=None, runtime=None):
|
||||
@@ -66,418 +77,479 @@ def _tool(name, description, tags=None):
|
||||
return SimpleNamespace(name=name, description=description, tags=tags or [])
|
||||
|
||||
|
||||
class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel()
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
handled_requests = []
|
||||
def test_awrap_model_call_uses_json_prompt_for_all_models():
|
||||
"""工具筛选应统一使用 JSON 提示,不绑定 provider 专属参数。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel()
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
handled_requests = []
|
||||
|
||||
async def handler(updated_request):
|
||||
handled_requests.append(updated_request)
|
||||
return updated_request
|
||||
async def handler(updated_request):
|
||||
handled_requests.append(updated_request)
|
||||
return updated_request
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertEqual(
|
||||
model.bind_calls,
|
||||
[{"response_format": {"type": "json_object"}}],
|
||||
)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in result.tools],
|
||||
["search", "calendar"],
|
||||
)
|
||||
prompt = model.bound_model.messages[0]["content"]
|
||||
self.assertIn("Return the answer in JSON only.", prompt)
|
||||
self.assertIn('- search: Search for information', prompt)
|
||||
self.assertIn('- calendar: Manage events', prompt)
|
||||
self.assertIn("MoviePilot tool-chain hints:", prompt)
|
||||
self.assertEqual(len(handled_requests), 1)
|
||||
assert model.bind_calls == []
|
||||
assert [tool.name for tool in result.tools] == ["search", "calendar"]
|
||||
system_message = model.messages[0]
|
||||
assert isinstance(system_message, SystemMessage)
|
||||
prompt = system_message.content
|
||||
assert "Return the answer in JSON only." in prompt
|
||||
assert "- search: Search for information" in prompt
|
||||
assert "- calendar: Manage events" in prompt
|
||||
assert "MoviePilot tool-chain hints:" in prompt
|
||||
assert len(handled_requests) == 1
|
||||
|
||||
def test_awrap_model_call_reuses_first_selection_for_later_model_rounds(self):
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["calendar", "search"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
handled_requests = []
|
||||
|
||||
async def handler(updated_request):
|
||||
handled_requests.append(updated_request)
|
||||
return updated_request
|
||||
def test_awrap_model_call_uses_same_json_prompt_for_minimax():
|
||||
"""MiniMax 工具筛选也应复用同一套 JSON 提示路径。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(
|
||||
model_name="MiniMax-M2.7",
|
||||
base_url="https://api.minimaxi.com/anthropic/v1",
|
||||
runtime="anthropic_compatible",
|
||||
)
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
model.bind_calls,
|
||||
[{"response_format": {"type": "json_object"}}],
|
||||
)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in first_result.tools],
|
||||
["search", "calendar"],
|
||||
)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in second_result.tools],
|
||||
["search", "calendar"],
|
||||
)
|
||||
self.assertEqual(len(handled_requests), 2)
|
||||
assert state_update == {"selected_tool_names": ["search", "calendar"]}
|
||||
assert model.bind_calls == []
|
||||
system_message = model.messages[0]
|
||||
assert isinstance(system_message, SystemMessage)
|
||||
assert "Return the answer in JSON only." in system_message.content
|
||||
|
||||
def test_awrap_model_call_caches_non_deepseek_selection_too(self):
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(
|
||||
model_name="gpt-4o-mini",
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
|
||||
async def handler(updated_request):
|
||||
return updated_request
|
||||
def test_awrap_model_call_uses_prompt_json_for_anthropic_runtime():
|
||||
"""Anthropic-compatible runtime 不应触发额外 provider 分支。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(
|
||||
model_name="kimi-k2",
|
||||
base_url="https://example.com/anthropic/v1",
|
||||
runtime="anthropic_compatible",
|
||||
)
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
|
||||
parent_calls = 0
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
async def _fake_parent_awrap(self, request_arg, handler_arg):
|
||||
nonlocal parent_calls
|
||||
parent_calls += 1
|
||||
selected_request = request_arg.override(
|
||||
tools=[request_arg.tools[1], request_arg.tools[0]]
|
||||
)
|
||||
return await handler_arg(selected_request)
|
||||
assert state_update == {"selected_tool_names": ["search", "calendar"]}
|
||||
assert model.bind_calls == []
|
||||
system_message = model.messages[0]
|
||||
assert isinstance(system_message, SystemMessage)
|
||||
assert "Return the answer in JSON only." in system_message.content
|
||||
|
||||
with patch.object(
|
||||
tool_selector_module.LLMToolSelectorMiddleware,
|
||||
"awrap_model_call",
|
||||
_fake_parent_awrap,
|
||||
):
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertEqual(parent_calls, 1)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in first_result.tools],
|
||||
["calendar", "search"],
|
||||
)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in second_result.tools],
|
||||
["calendar", "search"],
|
||||
)
|
||||
def test_awrap_model_call_reuses_first_selection_for_later_model_rounds():
|
||||
"""多轮模型回合应复用首轮筛选出的工具集合。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["calendar", "search"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
handled_requests = []
|
||||
|
||||
def test_normalize_selection_response_accepts_code_fence_json(self):
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware()
|
||||
response = SimpleNamespace(
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": '```json\n{"tools": ["search"]}\n```',
|
||||
}
|
||||
]
|
||||
)
|
||||
async def handler(updated_request):
|
||||
handled_requests.append(updated_request)
|
||||
return updated_request
|
||||
|
||||
normalized = middleware._normalize_selection_response(response)
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertEqual(normalized, {"tools": ["search"]})
|
||||
assert model.bind_calls == []
|
||||
assert [tool.name for tool in first_result.tools] == ["search", "calendar"]
|
||||
assert [tool.name for tool in second_result.tools] == ["search", "calendar"]
|
||||
assert len(handled_requests) == 2
|
||||
assert len(model.ainvoke_calls) == 1
|
||||
|
||||
def test_deepseek_selection_uses_recent_conversation_context(self):
|
||||
"""多轮追问时工具筛选应看到上一轮用户需求和助手回复。"""
|
||||
tools = [
|
||||
_tool(
|
||||
"query_plugin_config",
|
||||
"Query plugin config",
|
||||
[ToolTag.Read, ToolTag.Plugin, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"update_plugin_config",
|
||||
"Update plugin config",
|
||||
[ToolTag.Write, ToolTag.Plugin, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"reload_plugin",
|
||||
"Reload plugin",
|
||||
[ToolTag.Write, ToolTag.Plugin],
|
||||
),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["query_plugin_config"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=3,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[
|
||||
HumanMessage(content="帮我检查插件 DemoPlugin 的配置"),
|
||||
AIMessage(content="我建议先查询插件配置,然后根据结果决定是否重载插件。"),
|
||||
HumanMessage(content="按你说的来"),
|
||||
],
|
||||
model=model,
|
||||
)
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
def test_awrap_model_call_caches_plain_json_prompt_selection_too():
|
||||
"""普通模型也应只调用一次 JSON 提示筛选并缓存结果。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(
|
||||
model_name="gpt-4o-mini",
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
|
||||
user_message = model.bound_model.messages[1]
|
||||
self.assertEqual(
|
||||
state_update,
|
||||
{"selected_tool_names": ["query_plugin_config", "update_plugin_config", "reload_plugin"]},
|
||||
)
|
||||
self.assertIsInstance(user_message, HumanMessage)
|
||||
self.assertIn(
|
||||
"Recent conversation context for tool selection",
|
||||
user_message.content,
|
||||
)
|
||||
self.assertIn("帮我检查插件 DemoPlugin 的配置", user_message.content)
|
||||
self.assertIn("我建议先查询插件配置", user_message.content)
|
||||
self.assertIn("按你说的来", user_message.content)
|
||||
async def handler(updated_request):
|
||||
return updated_request
|
||||
|
||||
def test_single_turn_selection_keeps_original_user_message(self):
|
||||
"""单轮对话不应额外包裹上下文提示。"""
|
||||
tools = [
|
||||
_tool("search", "Search for information", [ToolTag.Read, ToolTag.Web]),
|
||||
_tool("calendar", "Manage events", [ToolTag.Write]),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["search"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
original_message = HumanMessage(content="帮我查一下最近的更新")
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[original_message],
|
||||
model=model,
|
||||
)
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
asyncio.run(middleware.abefore_agent(request.state, runtime=None, config=None))
|
||||
assert model.bind_calls == []
|
||||
assert len(model.ainvoke_calls) == 1
|
||||
assert [tool.name for tool in first_result.tools] == ["search", "calendar"]
|
||||
assert [tool.name for tool in second_result.tools] == ["search", "calendar"]
|
||||
|
||||
user_message = model.bound_model.messages[1]
|
||||
self.assertIs(user_message, original_message)
|
||||
self.assertNotIn(
|
||||
"Recent conversation context for tool selection",
|
||||
user_message.content,
|
||||
)
|
||||
|
||||
def test_process_selection_response_completes_low_count_tool_group_by_tags(self):
|
||||
"""筛选结果过少时应按已命中的工具标签组补齐同组工具。"""
|
||||
tools = [
|
||||
_tool(
|
||||
"search_media",
|
||||
"Search media",
|
||||
[ToolTag.Read, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"search_torrents",
|
||||
"Search torrents",
|
||||
[ToolTag.Read, ToolTag.Resource, ToolTag.Site, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"get_search_results",
|
||||
"Get results",
|
||||
[ToolTag.Read, ToolTag.Resource],
|
||||
),
|
||||
_tool(
|
||||
"add_download_tasks",
|
||||
"Add downloads",
|
||||
[ToolTag.Write, ToolTag.Download, ToolTag.Resource],
|
||||
),
|
||||
_tool(
|
||||
"query_download_tasks",
|
||||
"Query downloads",
|
||||
[ToolTag.Read, ToolTag.Download],
|
||||
),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我下载流浪地球")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
def test_tool_selection_failure_falls_back_to_all_tools():
|
||||
"""筛选模型返回空响应时不应中断 Agent 请求。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
]
|
||||
model = _FakeModel(content=None)
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["search_torrents"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
self.assertEqual(len(result.tools), 4)
|
||||
self.assertEqual(
|
||||
{tool.name for tool in result.tools},
|
||||
assert state_update == {"selected_tool_names": ["search", "calendar"]}
|
||||
|
||||
|
||||
def test_normalize_selection_response_accepts_code_fence_json():
|
||||
"""工具筛选响应应兼容 Markdown 代码围栏包裹的 JSON。"""
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware()
|
||||
response = SimpleNamespace(
|
||||
content=[
|
||||
{
|
||||
"search_media",
|
||||
"search_torrents",
|
||||
"get_search_results",
|
||||
"add_download_tasks",
|
||||
},
|
||||
)
|
||||
|
||||
def test_process_selection_response_keeps_high_count_selection(self):
|
||||
"""筛选结果数量足够时不应额外补齐工具。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search_media", description="Search media"),
|
||||
SimpleNamespace(name="search_torrents", description="Search torrents"),
|
||||
SimpleNamespace(name="get_search_results", description="Get results"),
|
||||
SimpleNamespace(name="query_sites", description="Query sites"),
|
||||
"type": "text",
|
||||
"text": '```json\n{"tools": ["search"]}\n```',
|
||||
}
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我下载流浪地球")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{
|
||||
"tools": [
|
||||
"search_media",
|
||||
"search_torrents",
|
||||
"get_search_results",
|
||||
"query_sites",
|
||||
]
|
||||
},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
normalized = middleware._normalize_selection_response(response)
|
||||
|
||||
self.assertEqual(
|
||||
[tool.name for tool in result.tools],
|
||||
[
|
||||
assert normalized == {"tools": ["search"]}
|
||||
|
||||
|
||||
def test_json_prompt_selection_uses_recent_conversation_context():
|
||||
"""多轮追问时工具筛选应看到上一轮用户需求和助手回复。"""
|
||||
tools = [
|
||||
_tool(
|
||||
"query_plugin_config",
|
||||
"Query plugin config",
|
||||
[ToolTag.Read, ToolTag.Plugin, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"update_plugin_config",
|
||||
"Update plugin config",
|
||||
[ToolTag.Write, ToolTag.Plugin, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"reload_plugin",
|
||||
"Reload plugin",
|
||||
[ToolTag.Write, ToolTag.Plugin],
|
||||
),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["query_plugin_config"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=3,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[
|
||||
HumanMessage(content="帮我检查插件 DemoPlugin 的配置"),
|
||||
AIMessage(content="我建议先查询插件配置,然后根据结果决定是否重载插件。"),
|
||||
HumanMessage(content="按你说的来"),
|
||||
],
|
||||
model=model,
|
||||
)
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
|
||||
user_message = model.messages[1]
|
||||
assert state_update == {
|
||||
"selected_tool_names": [
|
||||
"query_plugin_config",
|
||||
"update_plugin_config",
|
||||
"reload_plugin",
|
||||
]
|
||||
}
|
||||
assert isinstance(user_message, HumanMessage)
|
||||
assert "Recent conversation context for tool selection" in user_message.content
|
||||
assert "帮我检查插件 DemoPlugin 的配置" in user_message.content
|
||||
assert "我建议先查询插件配置" in user_message.content
|
||||
assert "按你说的来" in user_message.content
|
||||
|
||||
|
||||
def test_single_turn_selection_keeps_original_user_message():
|
||||
"""单轮对话不应额外包裹上下文提示。"""
|
||||
tools = [
|
||||
_tool("search", "Search for information", [ToolTag.Read, ToolTag.Web]),
|
||||
_tool("calendar", "Manage events", [ToolTag.Write]),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["search"]}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
original_message = HumanMessage(content="帮我查一下最近的更新")
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[original_message],
|
||||
model=model,
|
||||
)
|
||||
|
||||
asyncio.run(middleware.abefore_agent(request.state, runtime=None, config=None))
|
||||
|
||||
user_message = model.messages[1]
|
||||
assert user_message is original_message
|
||||
assert "Recent conversation context for tool selection" not in user_message.content
|
||||
|
||||
|
||||
def test_process_selection_response_completes_low_count_tool_group_by_tags():
|
||||
"""筛选结果过少时应按已命中的工具标签组补齐同组工具。"""
|
||||
tools = [
|
||||
_tool(
|
||||
"search_media",
|
||||
"Search media",
|
||||
[ToolTag.Read, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"search_torrents",
|
||||
"Search torrents",
|
||||
[ToolTag.Read, ToolTag.Resource, ToolTag.Site, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"get_search_results",
|
||||
"Get results",
|
||||
[ToolTag.Read, ToolTag.Resource],
|
||||
),
|
||||
_tool(
|
||||
"add_download_tasks",
|
||||
"Add downloads",
|
||||
[ToolTag.Write, ToolTag.Download, ToolTag.Resource],
|
||||
),
|
||||
_tool(
|
||||
"query_download_tasks",
|
||||
"Query downloads",
|
||||
[ToolTag.Read, ToolTag.Download],
|
||||
),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我下载流浪地球")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["search_torrents"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert len(result.tools) == 4
|
||||
assert {tool.name for tool in result.tools} == {
|
||||
"search_media",
|
||||
"search_torrents",
|
||||
"get_search_results",
|
||||
"add_download_tasks",
|
||||
}
|
||||
|
||||
|
||||
def test_process_selection_response_keeps_high_count_selection():
|
||||
"""筛选结果数量足够时不应额外补齐工具。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search_media", description="Search media"),
|
||||
SimpleNamespace(name="search_torrents", description="Search torrents"),
|
||||
SimpleNamespace(name="get_search_results", description="Get results"),
|
||||
SimpleNamespace(name="query_sites", description="Query sites"),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我下载流浪地球")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{
|
||||
"tools": [
|
||||
"search_media",
|
||||
"search_torrents",
|
||||
"get_search_results",
|
||||
"query_sites",
|
||||
],
|
||||
)
|
||||
]
|
||||
},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
|
||||
def test_process_selection_response_respects_max_tools_when_completing(self):
|
||||
"""标签组补齐不应突破 max_tools 上限。"""
|
||||
tools = [
|
||||
_tool(
|
||||
"list_directory",
|
||||
"List directory",
|
||||
[ToolTag.Read, ToolTag.Directory, ToolTag.File],
|
||||
),
|
||||
_tool(
|
||||
"query_directory_settings",
|
||||
"Query settings",
|
||||
[ToolTag.Read, ToolTag.Directory, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"recognize_media",
|
||||
"Recognize media",
|
||||
[ToolTag.Read, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"transfer_file",
|
||||
"Transfer file",
|
||||
[ToolTag.Write, ToolTag.Transfer, ToolTag.Library, ToolTag.File],
|
||||
),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我整理这个目录")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
assert [tool.name for tool in result.tools] == [
|
||||
"search_media",
|
||||
"search_torrents",
|
||||
"get_search_results",
|
||||
"query_sites",
|
||||
]
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["transfer_file"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
|
||||
self.assertEqual(len(result.tools), 2)
|
||||
self.assertEqual(
|
||||
{tool.name for tool in result.tools},
|
||||
{"transfer_file", "list_directory"},
|
||||
)
|
||||
def test_process_selection_response_respects_max_tools_when_completing():
|
||||
"""标签组补齐不应突破 max_tools 上限。"""
|
||||
tools = [
|
||||
_tool(
|
||||
"list_directory",
|
||||
"List directory",
|
||||
[ToolTag.Read, ToolTag.Directory, ToolTag.File],
|
||||
),
|
||||
_tool(
|
||||
"query_directory_settings",
|
||||
"Query settings",
|
||||
[ToolTag.Read, ToolTag.Directory, ToolTag.Settings],
|
||||
),
|
||||
_tool(
|
||||
"recognize_media",
|
||||
"Recognize media",
|
||||
[ToolTag.Read, ToolTag.Media],
|
||||
),
|
||||
_tool(
|
||||
"transfer_file",
|
||||
"Transfer file",
|
||||
[ToolTag.Write, ToolTag.Transfer, ToolTag.Library, ToolTag.File],
|
||||
),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我整理这个目录")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
|
||||
def test_process_selection_response_ignores_generic_tags_when_completing(self):
|
||||
"""通用权限标签不应被当作工具组使用。"""
|
||||
tools = [
|
||||
_tool("read_one", "Read one", [ToolTag.Read]),
|
||||
_tool("read_two", "Read two", [ToolTag.Read]),
|
||||
_tool("write_one", "Write one", [ToolTag.Write, ToolTag.Admin]),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="查一下信息")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["transfer_file"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["read_one"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
assert len(result.tools) == 2
|
||||
assert {tool.name for tool in result.tools} == {"transfer_file", "list_directory"}
|
||||
|
||||
self.assertEqual([tool.name for tool in result.tools], ["read_one"])
|
||||
|
||||
def test_process_selection_response_ignores_generic_tags_when_completing():
|
||||
"""通用权限标签不应被当作工具组使用。"""
|
||||
tools = [
|
||||
_tool("read_one", "Read one", [ToolTag.Read]),
|
||||
_tool("read_two", "Read two", [ToolTag.Read]),
|
||||
_tool("write_one", "Write one", [ToolTag.Write, ToolTag.Admin]),
|
||||
]
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=4,
|
||||
selection_tools=tools,
|
||||
)
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="查一下信息")],
|
||||
model=_FakeModel(),
|
||||
)
|
||||
|
||||
result = middleware._process_selection_response(
|
||||
{"tools": ["read_one"]},
|
||||
available_tools=tools,
|
||||
valid_tool_names=[tool.name for tool in tools],
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert [tool.name for tool in result.tools] == ["read_one"]
|
||||
|
||||
@@ -589,6 +589,59 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(llm_calls[0].get("default_headers"), {"X-Test": "1"})
|
||||
|
||||
def test_get_llm_attaches_runtime_metadata(self):
|
||||
"""LLM 实例应带上内部 runtime 元数据,供 Agent 中间件判断兼容分支。"""
|
||||
|
||||
class _FakeProviderManager:
|
||||
async def resolve_runtime(self, **kwargs):
|
||||
return {
|
||||
"provider_id": kwargs["provider_id"],
|
||||
"runtime": "anthropic_compatible",
|
||||
"model_id": kwargs["model"],
|
||||
"api_key": kwargs["api_key"],
|
||||
"base_url": kwargs["base_url"],
|
||||
"default_headers": None,
|
||||
"use_responses_api": None,
|
||||
"model_record": None,
|
||||
"model_metadata": None,
|
||||
}
|
||||
|
||||
class _FakeChatAnthropic:
|
||||
def __init__(self, **kwargs):
|
||||
self.model = kwargs["model"]
|
||||
self.profile = None
|
||||
|
||||
provider_module = ModuleType("app.agent.llm.provider")
|
||||
provider_module.LLMProviderManager = _FakeProviderManager
|
||||
anthropic_module = ModuleType("langchain_anthropic")
|
||||
anthropic_module.ChatAnthropic = _FakeChatAnthropic
|
||||
|
||||
with patch.dict(
|
||||
sys.modules,
|
||||
{
|
||||
"app.agent.llm.provider": provider_module,
|
||||
"langchain_anthropic": anthropic_module,
|
||||
},
|
||||
):
|
||||
model = asyncio.run(
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="minimax",
|
||||
model="MiniMax-M2.7",
|
||||
api_key="sk-test",
|
||||
base_url="https://api.minimaxi.com/anthropic/v1",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
getattr(model, "_moviepilot_llm_runtime"),
|
||||
"anthropic_compatible",
|
||||
)
|
||||
self.assertEqual(getattr(model, "_moviepilot_llm_provider_id"), "minimax")
|
||||
self.assertEqual(
|
||||
getattr(model, "_moviepilot_llm_base_url"),
|
||||
"https://api.minimaxi.com/anthropic/v1",
|
||||
)
|
||||
|
||||
def test_get_llm_applies_proxy_only_when_enabled(self):
|
||||
"""LLM 构造时应按独立开关决定是否传入系统代理。"""
|
||||
calls = []
|
||||
|
||||
Reference in New Issue
Block a user