From 45f5326fb4213d261dd0029988bd7eb0d179d061 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 30 Apr 2026 13:47:43 +0800 Subject: [PATCH] fix tool selection middleware --- app/agent/__init__.py | 148 +++++++------- app/agent/middleware/tool_selection.py | 196 +++++++++++++++++++ tests/test_agent_tool_selector_middleware.py | 141 +++++++++++++ 3 files changed, 411 insertions(+), 74 deletions(-) create mode 100644 app/agent/middleware/tool_selection.py create mode 100644 tests/test_agent_tool_selector_middleware.py diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 853ad9c9..6faa4d3b 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -11,7 +11,6 @@ from typing import Any, Callable, Dict, List, Optional from langchain.agents import create_agent from langchain.agents.middleware import ( SummarizationMiddleware, - LLMToolSelectorMiddleware, ) from langchain_core.messages import ( # noqa: F401 HumanMessage, @@ -20,6 +19,7 @@ from langchain_core.messages import ( # noqa: F401 from langgraph.checkpoint.memory import InMemorySaver from app.agent.callback import StreamingHandler +from app.agent.llm import LLMHelper from app.agent.memory import memory_manager from app.agent.middleware.activity_log import ActivityLogMiddleware from app.agent.middleware.jobs import JobsMiddleware @@ -27,13 +27,13 @@ from app.agent.middleware.memory import MemoryMiddleware from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware from app.agent.middleware.runtime_config import RuntimeConfigMiddleware from app.agent.middleware.skills import SkillsMiddleware +from app.agent.middleware.tool_selection import MoviePilotToolSelectorMiddleware from app.agent.middleware.usage import UsageMiddleware from app.agent.prompt import prompt_manager from app.agent.runtime import agent_runtime_manager from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase from app.core.config import settings -from app.agent.llm import LLMHelper from app.log import logger from app.schemas import Notification, NotificationType from app.schemas.message import ChannelCapabilityManager, ChannelCapability @@ -110,7 +110,7 @@ class _ThinkTagStripper: on_output(self.buffer[:start_idx]) emitted = True self.in_think_tag = True - self.buffer = self.buffer[start_idx + 7 :] + self.buffer = self.buffer[start_idx + 7:] else: # 检查是否以 的不完整前缀结尾 partial_match = False @@ -130,7 +130,7 @@ class _ThinkTagStripper: end_idx = self.buffer.find("") if end_idx != -1: self.in_think_tag = False - self.buffer = self.buffer[end_idx + 8 :] + self.buffer = self.buffer[end_idx + 8:] else: # 检查是否以 的不完整前缀结尾 partial_match = False @@ -166,12 +166,12 @@ class MoviePilotAgent: """ def __init__( - self, - session_id: str, - user_id: str = None, - channel: str = None, - source: str = None, - username: str = None, + self, + session_id: str, + user_id: str = None, + channel: str = None, + source: str = None, + username: str = None, ): self.session_id = session_id self.user_id = user_id @@ -200,16 +200,16 @@ class MoviePilotAgent: return None @classmethod - def _get_model_name(cls, llm: Any) -> Optional[str]: + def _get_model_name(cls, model: Any) -> Optional[str]: return ( - getattr(llm, "model", None) - or getattr(llm, "model_name", None) - or getattr(llm, "model_id", None) + getattr(model, "model", None) + or getattr(model, "model_name", None) + or getattr(model, "model_id", None) ) @classmethod - def _get_context_window_tokens(cls, llm: Any) -> Optional[int]: - profile = getattr(llm, "profile", None) + def _get_context_window_tokens(cls, model: Any) -> Optional[int]: + profile = getattr(model, "profile", None) if not profile: return None if isinstance(profile, dict): @@ -221,9 +221,9 @@ class MoviePilotAgent: or getattr(profile, "input_token_limit", None) ) - def _sync_model_profile(self, llm: Any) -> None: - model_name = self._get_model_name(llm) - context_window_tokens = self._get_context_window_tokens(llm) + def _sync_model_profile(self, model: Any) -> None: + model_name = self._get_model_name(model) + context_window_tokens = self._get_context_window_tokens(model) if model_name: self._session_usage.model = model_name if context_window_tokens: @@ -337,10 +337,10 @@ class MoviePilotAgent: if block.get("thought"): continue if block.get("type") in ( - "thinking", - "reasoning_content", - "reasoning", - "thought", + "thinking", + "reasoning_content", + "reasoning", + "thought", ): continue if block.get("type") == "text": @@ -397,8 +397,8 @@ class MoviePilotAgent: system_prompt = prompt_manager.get_agent_prompt(channel=self.channel) # LLM 模型(用于 agent 执行) - llm = await self._initialize_llm(streaming=streaming) - self._sync_model_profile(llm) + model = await self._initialize_llm(streaming=streaming) + self._sync_model_profile(model) # 为中间件内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。 non_streaming_llm = ( @@ -444,7 +444,7 @@ class MoviePilotAgent: # 工具选择 if max_tools > 0: middlewares.append( - LLMToolSelectorMiddleware( + MoviePilotToolSelectorMiddleware( model=non_streaming_llm, max_tools=max_tools, always_include=always_include_tools, @@ -463,10 +463,10 @@ class MoviePilotAgent: raise e async def process( - self, - message: str, - images: List[str] = None, - files: Optional[List[dict]] = None, + self, + message: str, + images: List[str] = None, + files: Optional[List[dict]] = None, ) -> str: """ 处理用户消息,流式推理并返回 Agent 回复 @@ -519,7 +519,7 @@ class MoviePilotAgent: return error_message async def _stream_agent_tokens( - self, agent, messages: dict, config: dict, on_token: Callable[[str], None] + self, agent, messages: dict, config: dict, on_token: Callable[[str], None] ): """ 流式运行智能体,过滤工具调用token和思考内容,将模型生成的内容通过回调输出。 @@ -531,11 +531,11 @@ class MoviePilotAgent: stripper = _ThinkTagStripper() async for chunk in agent.astream( - messages, - stream_mode="messages", - config=config, - subgraphs=False, - version="v2", + messages, + stream_mode="messages", + config=config, + subgraphs=False, + version="v2", ): if chunk["type"] == "messages": token, metadata = chunk["data"] @@ -621,21 +621,21 @@ class MoviePilotAgent: if remaining_text: unsent_text = remaining_text if self._streamed_output and remaining_text.startswith( - self._streamed_output + self._streamed_output ): - unsent_text = remaining_text[len(self._streamed_output) :] + unsent_text = remaining_text[len(self._streamed_output):] if unsent_text: self._emit_output(unsent_text) if ( - remaining_text - and self.should_dispatch_reply - and not self._tool_context.get("user_reply_sent") + remaining_text + and self.should_dispatch_reply + and not self._tool_context.get("user_reply_sent") ): await self.send_agent_message(remaining_text) elif ( - remaining_text - and self.persist_output_message - and not self._tool_context.get("user_reply_sent") + remaining_text + and self.persist_output_message + and not self._tool_context.get("user_reply_sent") ): title = "MoviePilot助手" if self.is_background else "" await self._save_agent_message_to_db( @@ -674,9 +674,9 @@ class MoviePilotAgent: self._emit_output(final_text) if ( - final_text - and self.should_dispatch_reply - and not self._tool_context.get("user_reply_sent") + final_text + and self.should_dispatch_reply + and not self._tool_context.get("user_reply_sent") ): if self.is_background: # 后台任务发送最终回复时统一带标题 @@ -687,9 +687,9 @@ class MoviePilotAgent: # 非流式渠道:发送最终回复 await self.send_agent_message(final_text) elif ( - final_text - and self.persist_output_message - and not self._tool_context.get("user_reply_sent") + final_text + and self.persist_output_message + and not self._tool_context.get("user_reply_sent") ): title = "MoviePilot助手" if self.is_background else "" await self._save_agent_message_to_db(final_text, title=title) @@ -810,8 +810,8 @@ class AgentManager: queue = self._session_queues.get(session_id) status["pending_messages"] = queue.qsize() if queue else 0 status["is_processing"] = ( - session_id in self._session_workers - and not self._session_workers[session_id].done() + session_id in self._session_workers + and not self._session_workers[session_id].done() ) return status @@ -843,16 +843,16 @@ class AgentManager: self.active_agents.clear() async def process_message( - self, - session_id: str, - user_id: str, - message: str, - images: List[str] = None, - files: Optional[List[dict]] = None, - channel: str = None, - source: str = None, - username: str = None, - reply_mode: ReplyMode = ReplyMode.DISPATCH, + self, + session_id: str, + user_id: str, + message: str, + images: List[str] = None, + files: Optional[List[dict]] = None, + channel: str = None, + source: str = None, + username: str = None, + reply_mode: ReplyMode = ReplyMode.DISPATCH, ) -> str: """ 处理用户消息:将消息放入会话队列,按顺序依次处理。 @@ -879,8 +879,8 @@ class AgentManager: # 如果队列中已有等待的消息,通知用户消息已排队 if queue_size > 0 or ( - session_id in self._session_workers - and not self._session_workers[session_id].done() + session_id in self._session_workers + and not self._session_workers[session_id].done() ): logger.info( f"会话 {session_id} 有任务正在处理,消息已排队等待 " @@ -892,8 +892,8 @@ class AgentManager: # 确保该会话有一个worker在运行 if ( - session_id not in self._session_workers - or self._session_workers[session_id].done() + session_id not in self._session_workers + or self._session_workers[session_id].done() ): self._session_workers[session_id] = asyncio.create_task( self._session_worker(session_id) @@ -934,8 +934,8 @@ class AgentManager: self._session_workers.pop(session_id, None) # noqa # 如果队列为空,清理队列 if ( - session_id in self._session_queues - and self._session_queues[session_id].empty() + session_id in self._session_queues + and self._session_queues[session_id].empty() ): self._session_queues.pop(session_id, None) @@ -1033,12 +1033,12 @@ class AgentManager: @staticmethod async def run_background_prompt( - message: str, - session_prefix: str = "__agent_background", - output_callback: Optional[Callable[[str], None]] = None, - reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY, - persist_output_message: bool = True, - allow_message_tools: Optional[bool] = None, + message: str, + session_prefix: str = "__agent_background", + output_callback: Optional[Callable[[str], None]] = None, + reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY, + persist_output_message: bool = True, + allow_message_tools: Optional[bool] = None, ) -> None: """ 以独立后台会话执行一段 prompt。 diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py new file mode 100644 index 00000000..890e8c38 --- /dev/null +++ b/app/agent/middleware/tool_selection.py @@ -0,0 +1,196 @@ +"""MoviePilot 自定义工具筛选中间件。""" + +from __future__ import annotations + +import json +from typing import Any + +from langchain.agents.middleware import LLMToolSelectorMiddleware +from langchain_core.language_models.chat_models import BaseChatModel + +from app.log import logger + + +class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware): + """ + 为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。 + + LangChain 默认会通过 `with_structured_output()` 走 OpenAI 的 + `response_format=json_schema` 路径,但 DeepSeek 官方 OpenAI 兼容端点公开文档 + 仅保证 `json_object` 模式可用。对于 `deepseek-reasoner`,这会在工具筛选阶段 + 提前触发 400,导致 Agent 还没真正开始执行工具就失败。 + + 因此这里仅在识别到 DeepSeek 模型/端点时,退回到显式 JSON 输出模式: + 1. 使用 `response_format={"type": "json_object"}`; + 2. 在提示词中明确约束返回 JSON 结构; + 3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。 + """ + + @staticmethod + def _is_deepseek_compatible_model(model: BaseChatModel) -> bool: + """ + 判断当前模型是否应当走 DeepSeek JSON 兼容分支。 + + 除了官方 `langchain_deepseek`,用户也可能通过 OpenAI-compatible + 配置把 DeepSeek 端点接到 `ChatOpenAI`。因此这里同时检查模块名、模型名 + 和 Base URL,避免只靠单一条件漏判。 + """ + module_name = type(model).__module__.lower() + model_name = str( + getattr(model, "model_name", "") or getattr(model, "model", "") + ).strip().lower() + base_url = str( + getattr(model, "openai_api_base", "") or getattr(model, "api_base", "") + ).strip().lower() + + return ( + "deepseek" in module_name + or model_name.startswith("deepseek-") + or "api.deepseek.com" in base_url + ) + + @staticmethod + def _extract_text_content(content: Any) -> str: + """ + 从模型响应中提取纯文本。 + + 这里不依赖上层 LLMHelper,避免中间件与 LLM 构造逻辑互相耦合。 + """ + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + text_parts: list[str] = [] + for block in content: + if isinstance(block, str): + text_parts.append(block) + continue + if isinstance(block, dict): + if block.get("type") == "text" and isinstance( + block.get("text"), str + ): + text_parts.append(block["text"]) + continue + if not block.get("type") and isinstance(block.get("text"), str): + text_parts.append(block["text"]) + return "".join(text_parts) + if isinstance(content, dict): + if content.get("type") == "text" and isinstance(content.get("text"), str): + return content["text"] + if not content.get("type") and isinstance(content.get("text"), str): + return content["text"] + return "" + + @staticmethod + def _parse_json_object(text: str) -> dict[str, Any]: + """ + 解析模型返回的 JSON。 + + DeepSeek 在 JSON 模式下通常会返回纯 JSON,但这里仍做一层兜底, + 兼容模型偶发输出围栏或前后说明文本的情况。 + """ + stripped_text = text.strip() + if not stripped_text: + raise ValueError("工具筛选返回了空响应") + + try: + payload = json.loads(stripped_text) + if isinstance(payload, dict): + return payload + except json.JSONDecodeError: + pass + + start = stripped_text.find("{") + end = stripped_text.rfind("}") + if start == -1 or end == -1 or end <= start: + raise ValueError(f"工具筛选返回的内容不是合法 JSON: {stripped_text}") + + payload = json.loads(stripped_text[start: end + 1]) + if not isinstance(payload, dict): + raise ValueError("工具筛选 JSON 顶层必须是对象") + return payload + + @staticmethod + def _render_tool_list(available_tools: list[Any]) -> str: + """把工具名和描述渲染成稳定的文本列表。""" + return "\n".join( + f"- {tool.name}: {tool.description}" for tool in available_tools + ) + + def _build_deepseek_selection_prompt(self, selection_request: Any) -> str: + """ + 为 DeepSeek 生成显式 JSON 输出提示。 + + DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON + 约束,否则兼容端点可能返回空内容或无意义输出。 + """ + return ( + f"{selection_request.system_message}\n\n" + "Return the answer in JSON only.\n" + 'Use exactly this shape: {"tools": ["tool_name_1", "tool_name_2"]}\n' + "Rules:\n" + "- The `tools` field must be a JSON array of strings.\n" + "- Only use tool names from the allowed list below.\n" + "- Order tools by relevance, with the most relevant first.\n" + "- Do not add explanations, markdown, or extra keys.\n\n" + f"Allowed tools:\n{self._render_tool_list(selection_request.available_tools)}" + ) + + def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]: + """ + 解析并标准化 DeepSeek JSON 模式的工具筛选结果。 + """ + content = getattr(response, "content", response) + text = self._extract_text_content(content) + payload = self._parse_json_object(text) + + tools = payload.get("tools") + if not isinstance(tools, list): + raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}") + + normalized_tools = [tool_name for tool_name in tools if isinstance(tool_name, str)] + return {"tools": normalized_tools} + + async def _aselect_tools_with_deepseek( + self, selection_request: Any + ) -> dict[str, list[str]]: + """ + 使用 DeepSeek 兼容的 JSON 输出模式执行异步工具筛选。 + """ + logger.debug("工具筛选走 DeepSeek JSON 兼容分支") + structured_model = selection_request.model.bind( + response_format={"type": "json_object"} + ) + response = await structured_model.ainvoke( + [ + { + "role": "system", + "content": self._build_deepseek_selection_prompt( + selection_request + ), + }, + selection_request.last_user_message, + ] + ) + return self._normalize_selection_response(response) + + async def awrap_model_call(self, request: Any, handler: Any) -> Any: + """ + 异步版本的 DeepSeek 工具筛选兼容分支。 + """ + selection_request = self._prepare_selection_request(request) + if selection_request is None: + return await handler(request) + + if not self._is_deepseek_compatible_model(selection_request.model): + return await super().awrap_model_call(request, handler) + + response = await self._aselect_tools_with_deepseek(selection_request) + modified_request = self._process_selection_response( + response, + selection_request.available_tools, + selection_request.valid_tool_names, + request, + ) + return await handler(modified_request) diff --git a/tests/test_agent_tool_selector_middleware.py b/tests/test_agent_tool_selector_middleware.py new file mode 100644 index 00000000..bc38a7e6 --- /dev/null +++ b/tests/test_agent_tool_selector_middleware.py @@ -0,0 +1,141 @@ +import asyncio +import importlib.util +import sys +import unittest +from pathlib import Path +from types import ModuleType, SimpleNamespace + +from langchain_core.messages import HumanMessage + + +def _stub_module(name: str, **attrs): + module = sys.modules.get(name) + if module is None: + module = ModuleType(name) + sys.modules[name] = module + for key, value in attrs.items(): + setattr(module, key, value) + return module + + +sys.modules.pop("app.agent.middleware.tool_selection", None) +_stub_module( + "app.log", + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), +) + +module_path = ( + Path(__file__).resolve().parents[1] + / "app" + / "agent" + / "middleware" + / "tool_selection.py" +) +spec = importlib.util.spec_from_file_location("test_tool_selector_module", module_path) +tool_selector_module = importlib.util.module_from_spec(spec) +assert spec and spec.loader +spec.loader.exec_module(tool_selector_module) + + +class _FakeBoundModel: + def __init__(self, content): + self.content = content + self.messages = None + + def invoke(self, messages): + self.messages = messages + return SimpleNamespace(content=self.content) + + async def ainvoke(self, messages): + self.messages = messages + return SimpleNamespace(content=self.content) + + +class _FakeModel: + def __init__( + self, + *, + content='{"tools": ["calendar", "search"]}', + model_name="deepseek-reasoner", + base_url="https://api.deepseek.com", + ): + self.model_name = model_name + self.openai_api_base = base_url + self.bind_calls = [] + self.bound_model = _FakeBoundModel(content) + + def bind(self, **kwargs): + self.bind_calls.append(kwargs) + return self.bound_model + + +class _FakeRequest: + def __init__(self, *, tools, messages, model): + self.tools = tools + self.messages = messages + self.model = model + + def override(self, **kwargs): + data = { + "tools": self.tools, + "messages": self.messages, + "model": self.model, + } + data.update(kwargs) + return _FakeRequest(**data) + + +class ToolSelectorMiddlewareTest(unittest.TestCase): + def test_awrap_model_call_uses_json_mode_for_deepseek(self): + middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(max_tools=2) + tools = [ + SimpleNamespace(name="search", description="Search for information"), + SimpleNamespace(name="calendar", description="Manage events"), + SimpleNamespace(name="translate", description="Translate text"), + ] + model = _FakeModel() + request = _FakeRequest( + tools=tools, + messages=[HumanMessage(content="帮我安排明天的行程并查天气")], + model=model, + ) + handled_requests = [] + + async def handler(updated_request): + handled_requests.append(updated_request) + return updated_request + + result = asyncio.run(middleware.awrap_model_call(request, handler)) + + self.assertEqual( + model.bind_calls, + [{"response_format": {"type": "json_object"}}], + ) + self.assertEqual( + [tool.name for tool in result.tools], + ["search", "calendar"], + ) + prompt = model.bound_model.messages[0]["content"] + self.assertIn("Return the answer in JSON only.", prompt) + self.assertIn('- search: Search for information', prompt) + self.assertIn('- calendar: Manage events', prompt) + self.assertEqual(len(handled_requests), 1) + + def test_normalize_selection_response_accepts_code_fence_json(self): + middleware = tool_selector_module.MoviePilotToolSelectorMiddleware() + response = SimpleNamespace( + content=[ + { + "type": "text", + "text": '```json\n{"tools": ["search"]}\n```', + } + ] + ) + + normalized = middleware._normalize_selection_response(response) + + self.assertEqual(normalized, {"tools": ["search"]}) + + +if __name__ == "__main__": + unittest.main()