fix tool selection middleware

This commit is contained in:
jxxghp
2026-04-30 13:47:43 +08:00
parent 2ccea2da39
commit 45f5326fb4
3 changed files with 411 additions and 74 deletions

View File

@@ -11,7 +11,6 @@ from typing import Any, Callable, Dict, List, Optional
from langchain.agents import create_agent
from langchain.agents.middleware import (
SummarizationMiddleware,
LLMToolSelectorMiddleware,
)
from langchain_core.messages import ( # noqa: F401
HumanMessage,
@@ -20,6 +19,7 @@ from langchain_core.messages import ( # noqa: F401
from langgraph.checkpoint.memory import InMemorySaver
from app.agent.callback import StreamingHandler
from app.agent.llm import LLMHelper
from app.agent.memory import memory_manager
from app.agent.middleware.activity_log import ActivityLogMiddleware
from app.agent.middleware.jobs import JobsMiddleware
@@ -27,13 +27,13 @@ from app.agent.middleware.memory import MemoryMiddleware
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
from app.agent.middleware.runtime_config import RuntimeConfigMiddleware
from app.agent.middleware.skills import SkillsMiddleware
from app.agent.middleware.tool_selection import MoviePilotToolSelectorMiddleware
from app.agent.middleware.usage import UsageMiddleware
from app.agent.prompt import prompt_manager
from app.agent.runtime import agent_runtime_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
from app.agent.llm import LLMHelper
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
@@ -110,7 +110,7 @@ class _ThinkTagStripper:
on_output(self.buffer[:start_idx])
emitted = True
self.in_think_tag = True
self.buffer = self.buffer[start_idx + 7 :]
self.buffer = self.buffer[start_idx + 7:]
else:
# 检查是否以 <think> 的不完整前缀结尾
partial_match = False
@@ -130,7 +130,7 @@ class _ThinkTagStripper:
end_idx = self.buffer.find("</think>")
if end_idx != -1:
self.in_think_tag = False
self.buffer = self.buffer[end_idx + 8 :]
self.buffer = self.buffer[end_idx + 8:]
else:
# 检查是否以 </think> 的不完整前缀结尾
partial_match = False
@@ -166,12 +166,12 @@ class MoviePilotAgent:
"""
def __init__(
self,
session_id: str,
user_id: str = None,
channel: str = None,
source: str = None,
username: str = None,
self,
session_id: str,
user_id: str = None,
channel: str = None,
source: str = None,
username: str = None,
):
self.session_id = session_id
self.user_id = user_id
@@ -200,16 +200,16 @@ class MoviePilotAgent:
return None
@classmethod
def _get_model_name(cls, llm: Any) -> Optional[str]:
def _get_model_name(cls, model: Any) -> Optional[str]:
return (
getattr(llm, "model", None)
or getattr(llm, "model_name", None)
or getattr(llm, "model_id", None)
getattr(model, "model", None)
or getattr(model, "model_name", None)
or getattr(model, "model_id", None)
)
@classmethod
def _get_context_window_tokens(cls, llm: Any) -> Optional[int]:
profile = getattr(llm, "profile", None)
def _get_context_window_tokens(cls, model: Any) -> Optional[int]:
profile = getattr(model, "profile", None)
if not profile:
return None
if isinstance(profile, dict):
@@ -221,9 +221,9 @@ class MoviePilotAgent:
or getattr(profile, "input_token_limit", None)
)
def _sync_model_profile(self, llm: Any) -> None:
model_name = self._get_model_name(llm)
context_window_tokens = self._get_context_window_tokens(llm)
def _sync_model_profile(self, model: Any) -> None:
model_name = self._get_model_name(model)
context_window_tokens = self._get_context_window_tokens(model)
if model_name:
self._session_usage.model = model_name
if context_window_tokens:
@@ -337,10 +337,10 @@ class MoviePilotAgent:
if block.get("thought"):
continue
if block.get("type") in (
"thinking",
"reasoning_content",
"reasoning",
"thought",
"thinking",
"reasoning_content",
"reasoning",
"thought",
):
continue
if block.get("type") == "text":
@@ -397,8 +397,8 @@ class MoviePilotAgent:
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
# LLM 模型(用于 agent 执行)
llm = await self._initialize_llm(streaming=streaming)
self._sync_model_profile(llm)
model = await self._initialize_llm(streaming=streaming)
self._sync_model_profile(model)
# 为中间件内部模型调用准备非流式 LLM避免与用户流式回复复用同一实例。
non_streaming_llm = (
@@ -444,7 +444,7 @@ class MoviePilotAgent:
# 工具选择
if max_tools > 0:
middlewares.append(
LLMToolSelectorMiddleware(
MoviePilotToolSelectorMiddleware(
model=non_streaming_llm,
max_tools=max_tools,
always_include=always_include_tools,
@@ -463,10 +463,10 @@ class MoviePilotAgent:
raise e
async def process(
self,
message: str,
images: List[str] = None,
files: Optional[List[dict]] = None,
self,
message: str,
images: List[str] = None,
files: Optional[List[dict]] = None,
) -> str:
"""
处理用户消息,流式推理并返回 Agent 回复
@@ -519,7 +519,7 @@ class MoviePilotAgent:
return error_message
async def _stream_agent_tokens(
self, agent, messages: dict, config: dict, on_token: Callable[[str], None]
self, agent, messages: dict, config: dict, on_token: Callable[[str], None]
):
"""
流式运行智能体过滤工具调用token和思考内容将模型生成的内容通过回调输出。
@@ -531,11 +531,11 @@ class MoviePilotAgent:
stripper = _ThinkTagStripper()
async for chunk in agent.astream(
messages,
stream_mode="messages",
config=config,
subgraphs=False,
version="v2",
messages,
stream_mode="messages",
config=config,
subgraphs=False,
version="v2",
):
if chunk["type"] == "messages":
token, metadata = chunk["data"]
@@ -621,21 +621,21 @@ class MoviePilotAgent:
if remaining_text:
unsent_text = remaining_text
if self._streamed_output and remaining_text.startswith(
self._streamed_output
self._streamed_output
):
unsent_text = remaining_text[len(self._streamed_output) :]
unsent_text = remaining_text[len(self._streamed_output):]
if unsent_text:
self._emit_output(unsent_text)
if (
remaining_text
and self.should_dispatch_reply
and not self._tool_context.get("user_reply_sent")
remaining_text
and self.should_dispatch_reply
and not self._tool_context.get("user_reply_sent")
):
await self.send_agent_message(remaining_text)
elif (
remaining_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
remaining_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
):
title = "MoviePilot助手" if self.is_background else ""
await self._save_agent_message_to_db(
@@ -674,9 +674,9 @@ class MoviePilotAgent:
self._emit_output(final_text)
if (
final_text
and self.should_dispatch_reply
and not self._tool_context.get("user_reply_sent")
final_text
and self.should_dispatch_reply
and not self._tool_context.get("user_reply_sent")
):
if self.is_background:
# 后台任务发送最终回复时统一带标题
@@ -687,9 +687,9 @@ class MoviePilotAgent:
# 非流式渠道:发送最终回复
await self.send_agent_message(final_text)
elif (
final_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
final_text
and self.persist_output_message
and not self._tool_context.get("user_reply_sent")
):
title = "MoviePilot助手" if self.is_background else ""
await self._save_agent_message_to_db(final_text, title=title)
@@ -810,8 +810,8 @@ class AgentManager:
queue = self._session_queues.get(session_id)
status["pending_messages"] = queue.qsize() if queue else 0
status["is_processing"] = (
session_id in self._session_workers
and not self._session_workers[session_id].done()
session_id in self._session_workers
and not self._session_workers[session_id].done()
)
return status
@@ -843,16 +843,16 @@ class AgentManager:
self.active_agents.clear()
async def process_message(
self,
session_id: str,
user_id: str,
message: str,
images: List[str] = None,
files: Optional[List[dict]] = None,
channel: str = None,
source: str = None,
username: str = None,
reply_mode: ReplyMode = ReplyMode.DISPATCH,
self,
session_id: str,
user_id: str,
message: str,
images: List[str] = None,
files: Optional[List[dict]] = None,
channel: str = None,
source: str = None,
username: str = None,
reply_mode: ReplyMode = ReplyMode.DISPATCH,
) -> str:
"""
处理用户消息:将消息放入会话队列,按顺序依次处理。
@@ -879,8 +879,8 @@ class AgentManager:
# 如果队列中已有等待的消息,通知用户消息已排队
if queue_size > 0 or (
session_id in self._session_workers
and not self._session_workers[session_id].done()
session_id in self._session_workers
and not self._session_workers[session_id].done()
):
logger.info(
f"会话 {session_id} 有任务正在处理,消息已排队等待 "
@@ -892,8 +892,8 @@ class AgentManager:
# 确保该会话有一个worker在运行
if (
session_id not in self._session_workers
or self._session_workers[session_id].done()
session_id not in self._session_workers
or self._session_workers[session_id].done()
):
self._session_workers[session_id] = asyncio.create_task(
self._session_worker(session_id)
@@ -934,8 +934,8 @@ class AgentManager:
self._session_workers.pop(session_id, None) # noqa
# 如果队列为空,清理队列
if (
session_id in self._session_queues
and self._session_queues[session_id].empty()
session_id in self._session_queues
and self._session_queues[session_id].empty()
):
self._session_queues.pop(session_id, None)
@@ -1033,12 +1033,12 @@ class AgentManager:
@staticmethod
async def run_background_prompt(
message: str,
session_prefix: str = "__agent_background",
output_callback: Optional[Callable[[str], None]] = None,
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
persist_output_message: bool = True,
allow_message_tools: Optional[bool] = None,
message: str,
session_prefix: str = "__agent_background",
output_callback: Optional[Callable[[str], None]] = None,
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
persist_output_message: bool = True,
allow_message_tools: Optional[bool] = None,
) -> None:
"""
以独立后台会话执行一段 prompt。

View File

@@ -0,0 +1,196 @@
"""MoviePilot 自定义工具筛选中间件。"""
from __future__ import annotations
import json
from typing import Any
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain_core.language_models.chat_models import BaseChatModel
from app.log import logger
class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
    """
    Tool-selection middleware with a safer code path for DeepSeek-compatible endpoints.

    By default LangChain goes through `with_structured_output()`, which uses
    OpenAI's `response_format=json_schema` path, whereas the public
    documentation of DeepSeek's official OpenAI-compatible endpoint only
    guarantees that `json_object` mode is available. For `deepseek-reasoner`
    this triggers an HTTP 400 during the tool-selection stage, so the agent
    fails before it has actually started executing any tool.

    Therefore, only when a DeepSeek model/endpoint is detected, this class
    falls back to an explicit JSON output mode:
    1. use `response_format={"type": "json_object"}`;
    2. state the required JSON structure explicitly in the prompt;
    3. parse `{"tools": [...]}` manually; every other model keeps using the
       default LangChain implementation.
    """
@staticmethod
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
"""
判断当前模型是否应当走 DeepSeek JSON 兼容分支。
除了官方 `langchain_deepseek`,用户也可能通过 OpenAI-compatible
配置把 DeepSeek 端点接到 `ChatOpenAI`。因此这里同时检查模块名、模型名
和 Base URL避免只靠单一条件漏判。
"""
module_name = type(model).__module__.lower()
model_name = str(
getattr(model, "model_name", "") or getattr(model, "model", "")
).strip().lower()
base_url = str(
getattr(model, "openai_api_base", "") or getattr(model, "api_base", "")
).strip().lower()
return (
"deepseek" in module_name
or model_name.startswith("deepseek-")
or "api.deepseek.com" in base_url
)
@staticmethod
def _extract_text_content(content: Any) -> str:
"""
从模型响应中提取纯文本。
这里不依赖上层 LLMHelper避免中间件与 LLM 构造逻辑互相耦合。
"""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
text_parts: list[str] = []
for block in content:
if isinstance(block, str):
text_parts.append(block)
continue
if isinstance(block, dict):
if block.get("type") == "text" and isinstance(
block.get("text"), str
):
text_parts.append(block["text"])
continue
if not block.get("type") and isinstance(block.get("text"), str):
text_parts.append(block["text"])
return "".join(text_parts)
if isinstance(content, dict):
if content.get("type") == "text" and isinstance(content.get("text"), str):
return content["text"]
if not content.get("type") and isinstance(content.get("text"), str):
return content["text"]
return ""
@staticmethod
def _parse_json_object(text: str) -> dict[str, Any]:
"""
解析模型返回的 JSON。
DeepSeek 在 JSON 模式下通常会返回纯 JSON但这里仍做一层兜底
兼容模型偶发输出围栏或前后说明文本的情况。
"""
stripped_text = text.strip()
if not stripped_text:
raise ValueError("工具筛选返回了空响应")
try:
payload = json.loads(stripped_text)
if isinstance(payload, dict):
return payload
except json.JSONDecodeError:
pass
start = stripped_text.find("{")
end = stripped_text.rfind("}")
if start == -1 or end == -1 or end <= start:
raise ValueError(f"工具筛选返回的内容不是合法 JSON: {stripped_text}")
payload = json.loads(stripped_text[start: end + 1])
if not isinstance(payload, dict):
raise ValueError("工具筛选 JSON 顶层必须是对象")
return payload
@staticmethod
def _render_tool_list(available_tools: list[Any]) -> str:
"""把工具名和描述渲染成稳定的文本列表。"""
return "\n".join(
f"- {tool.name}: {tool.description}" for tool in available_tools
)
def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
"""
为 DeepSeek 生成显式 JSON 输出提示。
DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON
约束,否则兼容端点可能返回空内容或无意义输出。
"""
return (
f"{selection_request.system_message}\n\n"
"Return the answer in JSON only.\n"
'Use exactly this shape: {"tools": ["tool_name_1", "tool_name_2"]}\n'
"Rules:\n"
"- The `tools` field must be a JSON array of strings.\n"
"- Only use tool names from the allowed list below.\n"
"- Order tools by relevance, with the most relevant first.\n"
"- Do not add explanations, markdown, or extra keys.\n\n"
f"Allowed tools:\n{self._render_tool_list(selection_request.available_tools)}"
)
def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
"""
解析并标准化 DeepSeek JSON 模式的工具筛选结果。
"""
content = getattr(response, "content", response)
text = self._extract_text_content(content)
payload = self._parse_json_object(text)
tools = payload.get("tools")
if not isinstance(tools, list):
raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}")
normalized_tools = [tool_name for tool_name in tools if isinstance(tool_name, str)]
return {"tools": normalized_tools}
async def _aselect_tools_with_deepseek(
    self, selection_request: Any
) -> dict[str, list[str]]:
    """
    Run asynchronous tool selection using DeepSeek-compatible JSON output mode.

    The request's model is bound with ``response_format={"type":
    "json_object"}`` and invoked with two messages: a system message carrying
    the explicit JSON prompt, followed by the request's last user message.

    :param selection_request: object exposing ``model``,
        ``last_user_message`` and the fields needed to build the prompt
    :return: the normalised ``{"tools": [...]}`` selection
    :raises ValueError: if the model's reply cannot be normalised
    """
    logger.debug("工具筛选走 DeepSeek JSON 兼容分支")

    json_mode_model = selection_request.model.bind(
        response_format={"type": "json_object"}
    )
    system_message = {
        "role": "system",
        "content": self._build_deepseek_selection_prompt(selection_request),
    }
    conversation = [system_message, selection_request.last_user_message]

    raw_response = await json_mode_model.ainvoke(conversation)
    return self._normalize_selection_response(raw_response)
async def awrap_model_call(self, request: Any, handler: Any) -> Any:
    """
    Async model-call wrapper that adds the DeepSeek-compatible selection branch.

    Flow:

    * if no selection request can be prepared, the request is passed to
      ``handler`` unchanged;
    * if the selecting model is not DeepSeek-compatible, the call is
      delegated to the parent class implementation;
    * otherwise tools are selected via JSON mode, the selection is applied to
      the request, and the resulting request is passed to ``handler``.

    NOTE(review): ``_prepare_selection_request`` and
    ``_process_selection_response`` are not defined in this class; presumably
    they are inherited from ``LLMToolSelectorMiddleware`` - confirm against
    the installed LangChain version.

    :param request: the incoming model request
    :param handler: awaitable callable that performs the actual model call
    :return: whatever ``handler`` (or the parent implementation) returns
    """
    selection_request = self._prepare_selection_request(request)

    # Nothing to select: run the model call as-is.
    if selection_request is None:
        return await handler(request)

    if self._is_deepseek_compatible_model(selection_request.model):
        selection = await self._aselect_tools_with_deepseek(selection_request)
        narrowed_request = self._process_selection_response(
            selection,
            selection_request.available_tools,
            selection_request.valid_tool_names,
            request,
        )
        return await handler(narrowed_request)

    # Every other model keeps the default structured-output path.
    return await super().awrap_model_call(request, handler)

View File

@@ -0,0 +1,141 @@
import asyncio
import importlib.util
import sys
import unittest
from pathlib import Path
from types import ModuleType, SimpleNamespace
from langchain_core.messages import HumanMessage
def _stub_module(name: str, **attrs):
module = sys.modules.get(name)
if module is None:
module = ModuleType(name)
sys.modules[name] = module
for key, value in attrs.items():
setattr(module, key, value)
return module
# Drop any previously imported copy so the middleware source is executed
# afresh below.
sys.modules.pop("app.agent.middleware.tool_selection", None)

# Replace `app.log` with a stub whose logger only provides a no-op `debug`,
# which is the only logger method the middleware calls.
_stub_module(
    "app.log",
    logger=SimpleNamespace(debug=lambda *args, **kwargs: None),
)

# Load the middleware directly from its file path instead of importing the
# `app` package.
module_path = (
    Path(__file__).resolve().parents[1]
    / "app"
    / "agent"
    / "middleware"
    / "tool_selection.py"
)
spec = importlib.util.spec_from_file_location("test_tool_selector_module", module_path)
# Validate the spec BEFORE using it: `module_from_spec(None)` would otherwise
# fail with an unrelated AttributeError, and an `assert` would be stripped
# under `python -O`.
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot load tool selection middleware from {module_path}")
tool_selector_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(tool_selector_module)
class _FakeBoundModel:
def __init__(self, content):
self.content = content
self.messages = None
def invoke(self, messages):
self.messages = messages
return SimpleNamespace(content=self.content)
async def ainvoke(self, messages):
self.messages = messages
return SimpleNamespace(content=self.content)
class _FakeModel:
    """
    Test double for a chat model that looks like a DeepSeek endpoint by default.

    ``bind`` records the keyword arguments of every call in ``bind_calls`` and
    always hands back the same ``bound_model`` instance, so a test can inspect
    both how the model was bound and which messages it later received.
    """

    def __init__(
        self,
        *,
        content='{"tools": ["calendar", "search"]}',
        model_name="deepseek-reasoner",
        base_url="https://api.deepseek.com",
    ):
        # Identification attributes inspected by the middleware.
        self.model_name = model_name
        self.openai_api_base = base_url
        # One kwargs dict per bind() call, in call order.
        self.bind_calls = []
        # Shared bound model that replies with the canned content.
        self.bound_model = _FakeBoundModel(content)

    def bind(self, **kwargs):
        """Record the bind arguments and return the shared bound model."""
        self.bind_calls.append(kwargs)
        return self.bound_model
class _FakeRequest:
def __init__(self, *, tools, messages, model):
self.tools = tools
self.messages = messages
self.model = model
def override(self, **kwargs):
data = {
"tools": self.tools,
"messages": self.messages,
"model": self.model,
}
data.update(kwargs)
return _FakeRequest(**data)
class ToolSelectorMiddlewareTest(unittest.TestCase):
    """Tests for the DeepSeek JSON-mode branch of the tool selector middleware."""

    def test_awrap_model_call_uses_json_mode_for_deepseek(self):
        """A DeepSeek-looking model is bound in JSON mode and the tools are narrowed."""
        selector = tool_selector_module.MoviePilotToolSelectorMiddleware(max_tools=2)
        catalog = (
            ("search", "Search for information"),
            ("calendar", "Manage events"),
            ("translate", "Translate text"),
        )
        tools = [
            SimpleNamespace(name=tool_name, description=tool_description)
            for tool_name, tool_description in catalog
        ]
        fake_model = _FakeModel()
        request = _FakeRequest(
            tools=tools,
            messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
            model=fake_model,
        )
        seen_requests = []

        async def handler(updated_request):
            # Echo the request back so the test can inspect what was forwarded.
            seen_requests.append(updated_request)
            return updated_request

        result = asyncio.run(selector.awrap_model_call(request, handler))

        # The model must have been bound exactly once, in json_object mode.
        self.assertEqual(
            fake_model.bind_calls,
            [{"response_format": {"type": "json_object"}}],
        )
        # The fake model answers ["calendar", "search"]; the forwarded request
        # is expected to list the kept tools as search, calendar.
        self.assertEqual(
            [tool.name for tool in result.tools],
            ["search", "calendar"],
        )
        prompt = fake_model.bound_model.messages[0]["content"]
        for expected_fragment in (
            "Return the answer in JSON only.",
            "- search: Search for information",
            "- calendar: Manage events",
        ):
            self.assertIn(expected_fragment, prompt)
        self.assertEqual(len(seen_requests), 1)

    def test_normalize_selection_response_accepts_code_fence_json(self):
        """JSON wrapped in a Markdown code fence inside a text block is still parsed."""
        selector = tool_selector_module.MoviePilotToolSelectorMiddleware()
        fenced_block = {
            "type": "text",
            "text": '```json\n{"tools": ["search"]}\n```',
        }
        response = SimpleNamespace(content=[fenced_block])

        normalized = selector._normalize_selection_response(response)

        self.assertEqual(normalized, {"tools": ["search"]})
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()