mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-27 18:52:23 +08:00
fix tool selection middleware
This commit is contained in:
@@ -11,7 +11,6 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
SummarizationMiddleware,
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.messages import ( # noqa: F401
|
||||
HumanMessage,
|
||||
@@ -20,6 +19,7 @@ from langchain_core.messages import ( # noqa: F401
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.agent.callback import StreamingHandler
|
||||
from app.agent.llm import LLMHelper
|
||||
from app.agent.memory import memory_manager
|
||||
from app.agent.middleware.activity_log import ActivityLogMiddleware
|
||||
from app.agent.middleware.jobs import JobsMiddleware
|
||||
@@ -27,13 +27,13 @@ from app.agent.middleware.memory import MemoryMiddleware
|
||||
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from app.agent.middleware.runtime_config import RuntimeConfigMiddleware
|
||||
from app.agent.middleware.skills import SkillsMiddleware
|
||||
from app.agent.middleware.tool_selection import MoviePilotToolSelectorMiddleware
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.agent.llm import LLMHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
@@ -110,7 +110,7 @@ class _ThinkTagStripper:
|
||||
on_output(self.buffer[:start_idx])
|
||||
emitted = True
|
||||
self.in_think_tag = True
|
||||
self.buffer = self.buffer[start_idx + 7 :]
|
||||
self.buffer = self.buffer[start_idx + 7:]
|
||||
else:
|
||||
# 检查是否以 <think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
@@ -130,7 +130,7 @@ class _ThinkTagStripper:
|
||||
end_idx = self.buffer.find("</think>")
|
||||
if end_idx != -1:
|
||||
self.in_think_tag = False
|
||||
self.buffer = self.buffer[end_idx + 8 :]
|
||||
self.buffer = self.buffer[end_idx + 8:]
|
||||
else:
|
||||
# 检查是否以 </think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
@@ -166,12 +166,12 @@ class MoviePilotAgent:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
@@ -200,16 +200,16 @@ class MoviePilotAgent:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_model_name(cls, llm: Any) -> Optional[str]:
|
||||
def _get_model_name(cls, model: Any) -> Optional[str]:
|
||||
return (
|
||||
getattr(llm, "model", None)
|
||||
or getattr(llm, "model_name", None)
|
||||
or getattr(llm, "model_id", None)
|
||||
getattr(model, "model", None)
|
||||
or getattr(model, "model_name", None)
|
||||
or getattr(model, "model_id", None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_context_window_tokens(cls, llm: Any) -> Optional[int]:
|
||||
profile = getattr(llm, "profile", None)
|
||||
def _get_context_window_tokens(cls, model: Any) -> Optional[int]:
|
||||
profile = getattr(model, "profile", None)
|
||||
if not profile:
|
||||
return None
|
||||
if isinstance(profile, dict):
|
||||
@@ -221,9 +221,9 @@ class MoviePilotAgent:
|
||||
or getattr(profile, "input_token_limit", None)
|
||||
)
|
||||
|
||||
def _sync_model_profile(self, llm: Any) -> None:
|
||||
model_name = self._get_model_name(llm)
|
||||
context_window_tokens = self._get_context_window_tokens(llm)
|
||||
def _sync_model_profile(self, model: Any) -> None:
|
||||
model_name = self._get_model_name(model)
|
||||
context_window_tokens = self._get_context_window_tokens(model)
|
||||
if model_name:
|
||||
self._session_usage.model = model_name
|
||||
if context_window_tokens:
|
||||
@@ -337,10 +337,10 @@ class MoviePilotAgent:
|
||||
if block.get("thought"):
|
||||
continue
|
||||
if block.get("type") in (
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
@@ -397,8 +397,8 @@ class MoviePilotAgent:
|
||||
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = await self._initialize_llm(streaming=streaming)
|
||||
self._sync_model_profile(llm)
|
||||
model = await self._initialize_llm(streaming=streaming)
|
||||
self._sync_model_profile(model)
|
||||
|
||||
# 为中间件内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。
|
||||
non_streaming_llm = (
|
||||
@@ -444,7 +444,7 @@ class MoviePilotAgent:
|
||||
# 工具选择
|
||||
if max_tools > 0:
|
||||
middlewares.append(
|
||||
LLMToolSelectorMiddleware(
|
||||
MoviePilotToolSelectorMiddleware(
|
||||
model=non_streaming_llm,
|
||||
max_tools=max_tools,
|
||||
always_include=always_include_tools,
|
||||
@@ -463,10 +463,10 @@ class MoviePilotAgent:
|
||||
raise e
|
||||
|
||||
async def process(
|
||||
self,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
self,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
@@ -519,7 +519,7 @@ class MoviePilotAgent:
|
||||
return error_message
|
||||
|
||||
async def _stream_agent_tokens(
|
||||
self, agent, messages: dict, config: dict, on_token: Callable[[str], None]
|
||||
self, agent, messages: dict, config: dict, on_token: Callable[[str], None]
|
||||
):
|
||||
"""
|
||||
流式运行智能体,过滤工具调用token和思考内容,将模型生成的内容通过回调输出。
|
||||
@@ -531,11 +531,11 @@ class MoviePilotAgent:
|
||||
stripper = _ThinkTagStripper()
|
||||
|
||||
async for chunk in agent.astream(
|
||||
messages,
|
||||
stream_mode="messages",
|
||||
config=config,
|
||||
subgraphs=False,
|
||||
version="v2",
|
||||
messages,
|
||||
stream_mode="messages",
|
||||
config=config,
|
||||
subgraphs=False,
|
||||
version="v2",
|
||||
):
|
||||
if chunk["type"] == "messages":
|
||||
token, metadata = chunk["data"]
|
||||
@@ -621,21 +621,21 @@ class MoviePilotAgent:
|
||||
if remaining_text:
|
||||
unsent_text = remaining_text
|
||||
if self._streamed_output and remaining_text.startswith(
|
||||
self._streamed_output
|
||||
self._streamed_output
|
||||
):
|
||||
unsent_text = remaining_text[len(self._streamed_output) :]
|
||||
unsent_text = remaining_text[len(self._streamed_output):]
|
||||
if unsent_text:
|
||||
self._emit_output(unsent_text)
|
||||
if (
|
||||
remaining_text
|
||||
and self.should_dispatch_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
remaining_text
|
||||
and self.should_dispatch_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif (
|
||||
remaining_text
|
||||
and self.persist_output_message
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
remaining_text
|
||||
and self.persist_output_message
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
title = "MoviePilot助手" if self.is_background else ""
|
||||
await self._save_agent_message_to_db(
|
||||
@@ -674,9 +674,9 @@ class MoviePilotAgent:
|
||||
self._emit_output(final_text)
|
||||
|
||||
if (
|
||||
final_text
|
||||
and self.should_dispatch_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
final_text
|
||||
and self.should_dispatch_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
if self.is_background:
|
||||
# 后台任务发送最终回复时统一带标题
|
||||
@@ -687,9 +687,9 @@ class MoviePilotAgent:
|
||||
# 非流式渠道:发送最终回复
|
||||
await self.send_agent_message(final_text)
|
||||
elif (
|
||||
final_text
|
||||
and self.persist_output_message
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
final_text
|
||||
and self.persist_output_message
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
title = "MoviePilot助手" if self.is_background else ""
|
||||
await self._save_agent_message_to_db(final_text, title=title)
|
||||
@@ -810,8 +810,8 @@ class AgentManager:
|
||||
queue = self._session_queues.get(session_id)
|
||||
status["pending_messages"] = queue.qsize() if queue else 0
|
||||
status["is_processing"] = (
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
)
|
||||
return status
|
||||
|
||||
@@ -843,16 +843,16 @@ class AgentManager:
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH,
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
@@ -879,8 +879,8 @@ class AgentManager:
|
||||
|
||||
# 如果队列中已有等待的消息,通知用户消息已排队
|
||||
if queue_size > 0 or (
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
):
|
||||
logger.info(
|
||||
f"会话 {session_id} 有任务正在处理,消息已排队等待 "
|
||||
@@ -892,8 +892,8 @@ class AgentManager:
|
||||
|
||||
# 确保该会话有一个worker在运行
|
||||
if (
|
||||
session_id not in self._session_workers
|
||||
or self._session_workers[session_id].done()
|
||||
session_id not in self._session_workers
|
||||
or self._session_workers[session_id].done()
|
||||
):
|
||||
self._session_workers[session_id] = asyncio.create_task(
|
||||
self._session_worker(session_id)
|
||||
@@ -934,8 +934,8 @@ class AgentManager:
|
||||
self._session_workers.pop(session_id, None) # noqa
|
||||
# 如果队列为空,清理队列
|
||||
if (
|
||||
session_id in self._session_queues
|
||||
and self._session_queues[session_id].empty()
|
||||
session_id in self._session_queues
|
||||
and self._session_queues[session_id].empty()
|
||||
):
|
||||
self._session_queues.pop(session_id, None)
|
||||
|
||||
@@ -1033,12 +1033,12 @@ class AgentManager:
|
||||
|
||||
@staticmethod
|
||||
async def run_background_prompt(
|
||||
message: str,
|
||||
session_prefix: str = "__agent_background",
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message: bool = True,
|
||||
allow_message_tools: Optional[bool] = None,
|
||||
message: str,
|
||||
session_prefix: str = "__agent_background",
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message: bool = True,
|
||||
allow_message_tools: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""
|
||||
以独立后台会话执行一段 prompt。
|
||||
|
||||
196
app/agent/middleware/tool_selection.py
Normal file
196
app/agent/middleware/tool_selection.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""MoviePilot 自定义工具筛选中间件。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
|
||||
LangChain 默认会通过 `with_structured_output()` 走 OpenAI 的
|
||||
`response_format=json_schema` 路径,但 DeepSeek 官方 OpenAI 兼容端点公开文档
|
||||
仅保证 `json_object` 模式可用。对于 `deepseek-reasoner`,这会在工具筛选阶段
|
||||
提前触发 400,导致 Agent 还没真正开始执行工具就失败。
|
||||
|
||||
因此这里仅在识别到 DeepSeek 模型/端点时,退回到显式 JSON 输出模式:
|
||||
1. 使用 `response_format={"type": "json_object"}`;
|
||||
2. 在提示词中明确约束返回 JSON 结构;
|
||||
3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
|
||||
"""
|
||||
判断当前模型是否应当走 DeepSeek JSON 兼容分支。
|
||||
|
||||
除了官方 `langchain_deepseek`,用户也可能通过 OpenAI-compatible
|
||||
配置把 DeepSeek 端点接到 `ChatOpenAI`。因此这里同时检查模块名、模型名
|
||||
和 Base URL,避免只靠单一条件漏判。
|
||||
"""
|
||||
module_name = type(model).__module__.lower()
|
||||
model_name = str(
|
||||
getattr(model, "model_name", "") or getattr(model, "model", "")
|
||||
).strip().lower()
|
||||
base_url = str(
|
||||
getattr(model, "openai_api_base", "") or getattr(model, "api_base", "")
|
||||
).strip().lower()
|
||||
|
||||
return (
|
||||
"deepseek" in module_name
|
||||
or model_name.startswith("deepseek-")
|
||||
or "api.deepseek.com" in base_url
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content: Any) -> str:
|
||||
"""
|
||||
从模型响应中提取纯文本。
|
||||
|
||||
这里不依赖上层 LLMHelper,避免中间件与 LLM 构造逻辑互相耦合。
|
||||
"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
continue
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "text" and isinstance(
|
||||
block.get("text"), str
|
||||
):
|
||||
text_parts.append(block["text"])
|
||||
continue
|
||||
if not block.get("type") and isinstance(block.get("text"), str):
|
||||
text_parts.append(block["text"])
|
||||
return "".join(text_parts)
|
||||
if isinstance(content, dict):
|
||||
if content.get("type") == "text" and isinstance(content.get("text"), str):
|
||||
return content["text"]
|
||||
if not content.get("type") and isinstance(content.get("text"), str):
|
||||
return content["text"]
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _parse_json_object(text: str) -> dict[str, Any]:
|
||||
"""
|
||||
解析模型返回的 JSON。
|
||||
|
||||
DeepSeek 在 JSON 模式下通常会返回纯 JSON,但这里仍做一层兜底,
|
||||
兼容模型偶发输出围栏或前后说明文本的情况。
|
||||
"""
|
||||
stripped_text = text.strip()
|
||||
if not stripped_text:
|
||||
raise ValueError("工具筛选返回了空响应")
|
||||
|
||||
try:
|
||||
payload = json.loads(stripped_text)
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
start = stripped_text.find("{")
|
||||
end = stripped_text.rfind("}")
|
||||
if start == -1 or end == -1 or end <= start:
|
||||
raise ValueError(f"工具筛选返回的内容不是合法 JSON: {stripped_text}")
|
||||
|
||||
payload = json.loads(stripped_text[start: end + 1])
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("工具筛选 JSON 顶层必须是对象")
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _render_tool_list(available_tools: list[Any]) -> str:
|
||||
"""把工具名和描述渲染成稳定的文本列表。"""
|
||||
return "\n".join(
|
||||
f"- {tool.name}: {tool.description}" for tool in available_tools
|
||||
)
|
||||
|
||||
def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
|
||||
"""
|
||||
为 DeepSeek 生成显式 JSON 输出提示。
|
||||
|
||||
DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON
|
||||
约束,否则兼容端点可能返回空内容或无意义输出。
|
||||
"""
|
||||
return (
|
||||
f"{selection_request.system_message}\n\n"
|
||||
"Return the answer in JSON only.\n"
|
||||
'Use exactly this shape: {"tools": ["tool_name_1", "tool_name_2"]}\n'
|
||||
"Rules:\n"
|
||||
"- The `tools` field must be a JSON array of strings.\n"
|
||||
"- Only use tool names from the allowed list below.\n"
|
||||
"- Order tools by relevance, with the most relevant first.\n"
|
||||
"- Do not add explanations, markdown, or extra keys.\n\n"
|
||||
f"Allowed tools:\n{self._render_tool_list(selection_request.available_tools)}"
|
||||
)
|
||||
|
||||
def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
|
||||
"""
|
||||
解析并标准化 DeepSeek JSON 模式的工具筛选结果。
|
||||
"""
|
||||
content = getattr(response, "content", response)
|
||||
text = self._extract_text_content(content)
|
||||
payload = self._parse_json_object(text)
|
||||
|
||||
tools = payload.get("tools")
|
||||
if not isinstance(tools, list):
|
||||
raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}")
|
||||
|
||||
normalized_tools = [tool_name for tool_name in tools if isinstance(tool_name, str)]
|
||||
return {"tools": normalized_tools}
|
||||
|
||||
async def _aselect_tools_with_deepseek(
|
||||
self, selection_request: Any
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
使用 DeepSeek 兼容的 JSON 输出模式执行异步工具筛选。
|
||||
"""
|
||||
logger.debug("工具筛选走 DeepSeek JSON 兼容分支")
|
||||
structured_model = selection_request.model.bind(
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
response = await structured_model.ainvoke(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": self._build_deepseek_selection_prompt(
|
||||
selection_request
|
||||
),
|
||||
},
|
||||
selection_request.last_user_message,
|
||||
]
|
||||
)
|
||||
return self._normalize_selection_response(response)
|
||||
|
||||
async def awrap_model_call(self, request: Any, handler: Any) -> Any:
|
||||
"""
|
||||
异步版本的 DeepSeek 工具筛选兼容分支。
|
||||
"""
|
||||
selection_request = self._prepare_selection_request(request)
|
||||
if selection_request is None:
|
||||
return await handler(request)
|
||||
|
||||
if not self._is_deepseek_compatible_model(selection_request.model):
|
||||
return await super().awrap_model_call(request, handler)
|
||||
|
||||
response = await self._aselect_tools_with_deepseek(selection_request)
|
||||
modified_request = self._process_selection_response(
|
||||
response,
|
||||
selection_request.available_tools,
|
||||
selection_request.valid_tool_names,
|
||||
request,
|
||||
)
|
||||
return await handler(modified_request)
|
||||
141
tests/test_agent_tool_selector_middleware.py
Normal file
141
tests/test_agent_tool_selector_middleware.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
sys.modules.pop("app.agent.middleware.tool_selection", None)
|
||||
_stub_module(
|
||||
"app.log",
|
||||
logger=SimpleNamespace(debug=lambda *args, **kwargs: None),
|
||||
)
|
||||
|
||||
module_path = (
|
||||
Path(__file__).resolve().parents[1]
|
||||
/ "app"
|
||||
/ "agent"
|
||||
/ "middleware"
|
||||
/ "tool_selection.py"
|
||||
)
|
||||
spec = importlib.util.spec_from_file_location("test_tool_selector_module", module_path)
|
||||
tool_selector_module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(tool_selector_module)
|
||||
|
||||
|
||||
class _FakeBoundModel:
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
self.messages = None
|
||||
|
||||
def invoke(self, messages):
|
||||
self.messages = messages
|
||||
return SimpleNamespace(content=self.content)
|
||||
|
||||
async def ainvoke(self, messages):
|
||||
self.messages = messages
|
||||
return SimpleNamespace(content=self.content)
|
||||
|
||||
|
||||
class _FakeModel:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
content='{"tools": ["calendar", "search"]}',
|
||||
model_name="deepseek-reasoner",
|
||||
base_url="https://api.deepseek.com",
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.openai_api_base = base_url
|
||||
self.bind_calls = []
|
||||
self.bound_model = _FakeBoundModel(content)
|
||||
|
||||
def bind(self, **kwargs):
|
||||
self.bind_calls.append(kwargs)
|
||||
return self.bound_model
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
def __init__(self, *, tools, messages, model):
|
||||
self.tools = tools
|
||||
self.messages = messages
|
||||
self.model = model
|
||||
|
||||
def override(self, **kwargs):
|
||||
data = {
|
||||
"tools": self.tools,
|
||||
"messages": self.messages,
|
||||
"model": self.model,
|
||||
}
|
||||
data.update(kwargs)
|
||||
return _FakeRequest(**data)
|
||||
|
||||
|
||||
class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(max_tools=2)
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel()
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
handled_requests = []
|
||||
|
||||
async def handler(updated_request):
|
||||
handled_requests.append(updated_request)
|
||||
return updated_request
|
||||
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertEqual(
|
||||
model.bind_calls,
|
||||
[{"response_format": {"type": "json_object"}}],
|
||||
)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in result.tools],
|
||||
["search", "calendar"],
|
||||
)
|
||||
prompt = model.bound_model.messages[0]["content"]
|
||||
self.assertIn("Return the answer in JSON only.", prompt)
|
||||
self.assertIn('- search: Search for information', prompt)
|
||||
self.assertIn('- calendar: Manage events', prompt)
|
||||
self.assertEqual(len(handled_requests), 1)
|
||||
|
||||
def test_normalize_selection_response_accepts_code_fence_json(self):
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware()
|
||||
response = SimpleNamespace(
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": '```json\n{"tools": ["search"]}\n```',
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
normalized = middleware._normalize_selection_response(response)
|
||||
|
||||
self.assertEqual(normalized, {"tools": ["search"]})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user