mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
fix: preserve deepseek reasoning content in tool loops
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -74,6 +75,132 @@ def _get_httpx_proxy_key() -> str:
|
||||
return "proxies"
|
||||
|
||||
|
||||
def _deepseek_thinking_toggle(extra_body: Any) -> bool | None:
|
||||
"""
|
||||
解析 DeepSeek extra_body 中显式传入的 thinking 开关。
|
||||
"""
|
||||
if not isinstance(extra_body, dict):
|
||||
return None
|
||||
|
||||
thinking = extra_body.get("thinking")
|
||||
if not isinstance(thinking, dict):
|
||||
return None
|
||||
|
||||
thinking_type = str(thinking.get("type") or "").strip().lower()
|
||||
if thinking_type == "enabled":
|
||||
return True
|
||||
if thinking_type == "disabled":
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def _is_deepseek_thinking_enabled(model_name: str | None, extra_body: Any) -> bool:
|
||||
"""
|
||||
判断本次 DeepSeek 调用是否处于 thinking mode。
|
||||
"""
|
||||
explicit_toggle = _deepseek_thinking_toggle(extra_body)
|
||||
if explicit_toggle is not None:
|
||||
return explicit_toggle
|
||||
|
||||
normalized_model_name = str(model_name or "").strip().lower()
|
||||
if normalized_model_name == "deepseek-reasoner":
|
||||
return True
|
||||
if normalized_model_name.startswith("deepseek-v4-"):
|
||||
# DeepSeek V4 默认启用 thinking mode,除非显式关闭。
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _extract_input_messages(input_: Any) -> list[Any]:
|
||||
"""
|
||||
将 chat model 输入还原为原始 BaseMessage 序列。
|
||||
"""
|
||||
try:
|
||||
from langchain_core.messages import convert_to_messages
|
||||
|
||||
return list(convert_to_messages(input_))
|
||||
except Exception:
|
||||
if isinstance(input_, list):
|
||||
return list(input_)
|
||||
return []
|
||||
|
||||
|
||||
def _patch_deepseek_reasoning_content_support():
|
||||
"""
|
||||
修补 langchain-deepseek 在 tool-call 场景下遗漏 reasoning_content 回传的问题。
|
||||
|
||||
DeepSeek thinking mode 要求:若 assistant 历史消息包含 tool_calls,
|
||||
后续请求中必须带回该条消息的顶层 reasoning_content。
|
||||
某些 langchain-deepseek 版本虽然能从响应中拿到 reasoning_content,
|
||||
但不会在重放消息历史时写回请求载荷,导致 400。
|
||||
"""
|
||||
try:
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
except Exception as err:
|
||||
logger.debug(f"跳过 langchain-deepseek reasoning_content 修补:{err}")
|
||||
return
|
||||
|
||||
if getattr(ChatDeepSeek, "_moviepilot_reasoning_content_patched", False):
|
||||
return
|
||||
|
||||
original_get_request_payload = getattr(ChatDeepSeek, "_get_request_payload", None)
|
||||
if not callable(original_get_request_payload):
|
||||
logger.warning("langchain-deepseek 缺少 _get_request_payload,无法修补 reasoning_content")
|
||||
return
|
||||
|
||||
@wraps(original_get_request_payload)
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
|
||||
try:
|
||||
original_messages = _extract_input_messages(input_)
|
||||
payload_messages = payload.get("messages") or []
|
||||
model_name = getattr(self, "model_name", None) or getattr(
|
||||
self, "model", None
|
||||
)
|
||||
extra_body = kwargs.get("extra_body")
|
||||
if extra_body is None:
|
||||
extra_body = getattr(self, "extra_body", None)
|
||||
if extra_body is None:
|
||||
extra_body = getattr(self, "model_kwargs", {}).get("extra_body")
|
||||
|
||||
if not _is_deepseek_thinking_enabled(model_name, extra_body):
|
||||
return payload
|
||||
|
||||
for index, message in enumerate(payload_messages):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
if message.get("role") != "assistant":
|
||||
continue
|
||||
if not message.get("tool_calls"):
|
||||
continue
|
||||
if message.get("reasoning_content") is not None:
|
||||
continue
|
||||
|
||||
reasoning_content = ""
|
||||
if index < len(original_messages):
|
||||
additional_kwargs = (
|
||||
getattr(original_messages[index], "additional_kwargs", None)
|
||||
or {}
|
||||
)
|
||||
if isinstance(additional_kwargs, dict):
|
||||
captured_reasoning = additional_kwargs.get("reasoning_content")
|
||||
if isinstance(captured_reasoning, str):
|
||||
reasoning_content = captured_reasoning
|
||||
|
||||
message["reasoning_content"] = reasoning_content
|
||||
except Exception as err:
|
||||
logger.warning(
|
||||
f"修补 langchain-deepseek reasoning_content 请求载荷时失败,将继续使用原始载荷: {err}"
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
ChatDeepSeek._get_request_payload = _patched_get_request_payload
|
||||
ChatDeepSeek._moviepilot_reasoning_content_patched = True
|
||||
logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@@ -437,6 +564,7 @@ class LLMHelper:
|
||||
elif provider_name == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
_patch_deepseek_reasoning_content_support()
|
||||
model = ChatDeepSeek(
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
|
||||
146
tests/test_langchain_deepseek_compat.py
Normal file
146
tests/test_langchain_deepseek_compat.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
def __getattr__(self, _name):
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def _build_tool_call(name: str = "search", arguments: str = "{}"):
|
||||
return [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "tool_call",
|
||||
"name": name,
|
||||
"args": {},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class _FakeChatDeepSeek:
|
||||
def __init__(self, model_name: str, model_kwargs: dict | None = None):
|
||||
self.model_name = model_name
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
|
||||
def _get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
messages = []
|
||||
for message in input_:
|
||||
payload_message = {
|
||||
"role": message.type,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.type == "human":
|
||||
payload_message["role"] = "user"
|
||||
elif message.type == "ai":
|
||||
payload_message["role"] = "assistant"
|
||||
tool_calls = getattr(message, "tool_calls", None)
|
||||
if tool_calls:
|
||||
payload_message["tool_calls"] = tool_calls
|
||||
elif message.type == "tool":
|
||||
payload_message["role"] = "tool"
|
||||
payload_message["tool_call_id"] = message.tool_call_id
|
||||
messages.append(payload_message)
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
_ORIGINAL_GET_REQUEST_PAYLOAD = _FakeChatDeepSeek._get_request_payload
|
||||
|
||||
|
||||
sys.modules.pop("app.helper.llm", None)
|
||||
_stub_module(
|
||||
"app.core.config",
|
||||
settings=ModuleType("settings"),
|
||||
)
|
||||
sys.modules["app.core.config"].settings.LLM_PROVIDER = "deepseek"
|
||||
sys.modules["app.core.config"].settings.LLM_MODEL = "deepseek-v4-pro"
|
||||
sys.modules["app.core.config"].settings.LLM_API_KEY = "sk-test"
|
||||
sys.modules["app.core.config"].settings.LLM_BASE_URL = "https://api.deepseek.com"
|
||||
sys.modules["app.core.config"].settings.LLM_THINKING_LEVEL = None
|
||||
sys.modules["app.core.config"].settings.LLM_DISABLE_THINKING = False
|
||||
sys.modules["app.core.config"].settings.LLM_REASONING_EFFORT = None
|
||||
sys.modules["app.core.config"].settings.LLM_TEMPERATURE = 0.1
|
||||
sys.modules["app.core.config"].settings.LLM_MAX_CONTEXT_TOKENS = 64
|
||||
sys.modules["app.core.config"].settings.PROXY_HOST = None
|
||||
_stub_module("app.log", logger=_DummyLogger())
|
||||
_stub_module("langchain_deepseek", ChatDeepSeek=_FakeChatDeepSeek)
|
||||
|
||||
module_path = Path(__file__).resolve().parents[1] / "app" / "helper" / "llm.py"
|
||||
spec = importlib.util.spec_from_file_location("test_llm_module_for_deepseek_compat", module_path)
|
||||
llm_module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(llm_module)
|
||||
|
||||
|
||||
class DeepSeekCompatPatchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
_FakeChatDeepSeek._get_request_payload = _ORIGINAL_GET_REQUEST_PAYLOAD
|
||||
if hasattr(_FakeChatDeepSeek, "_moviepilot_reasoning_content_patched"):
|
||||
delattr(_FakeChatDeepSeek, "_moviepilot_reasoning_content_patched")
|
||||
llm_module._patch_deepseek_reasoning_content_support()
|
||||
|
||||
def test_injects_reasoning_content_for_assistant_tool_calls(self):
|
||||
llm = _FakeChatDeepSeek("deepseek-v4-pro")
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=_build_tool_call(),
|
||||
additional_kwargs={"reasoning_content": "先调用天气工具"},
|
||||
),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertEqual(
|
||||
payload["messages"][1]["reasoning_content"],
|
||||
"先调用天气工具",
|
||||
)
|
||||
|
||||
def test_falls_back_to_empty_reasoning_content_when_missing(self):
|
||||
llm = _FakeChatDeepSeek("deepseek-v4-flash")
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(content="", tool_calls=_build_tool_call()),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertIn("reasoning_content", payload["messages"][1])
|
||||
self.assertEqual(payload["messages"][1]["reasoning_content"], "")
|
||||
|
||||
def test_skips_injection_when_thinking_is_disabled(self):
|
||||
llm = _FakeChatDeepSeek(
|
||||
"deepseek-v4-pro",
|
||||
model_kwargs={"extra_body": {"thinking": {"type": "disabled"}}},
|
||||
)
|
||||
messages = [
|
||||
HumanMessage(content="天气如何?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=_build_tool_call(),
|
||||
additional_kwargs={"reasoning_content": "先调用天气工具"},
|
||||
),
|
||||
ToolMessage(content="晴天", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
payload = llm._get_request_payload(messages)
|
||||
|
||||
self.assertNotIn("reasoning_content", payload["messages"][1])
|
||||
@@ -144,6 +144,7 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
|
||||
def test_get_llm_uses_deepseek_thinking_level_controls(self):
|
||||
calls = []
|
||||
patch_calls = []
|
||||
|
||||
class _FakeChatDeepSeek:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -154,6 +155,10 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
with patch.dict(
|
||||
sys.modules,
|
||||
{"langchain_deepseek": SimpleNamespace(ChatDeepSeek=_FakeChatDeepSeek)},
|
||||
), patch.object(
|
||||
llm_module,
|
||||
"_patch_deepseek_reasoning_content_support",
|
||||
side_effect=lambda: patch_calls.append(True),
|
||||
):
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="deepseek",
|
||||
@@ -168,11 +173,13 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
calls[0].get("extra_body"),
|
||||
{"thinking": {"type": "enabled"}},
|
||||
)
|
||||
self.assertEqual(patch_calls, [True])
|
||||
self.assertEqual(calls[0].get("reasoning_effort"), "max")
|
||||
self.assertEqual(calls[0].get("api_base"), "https://api.deepseek.com")
|
||||
|
||||
def test_get_llm_disables_deepseek_thinking_via_thinking_level(self):
|
||||
calls = []
|
||||
patch_calls = []
|
||||
|
||||
class _FakeChatDeepSeek:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -183,6 +190,10 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
with patch.dict(
|
||||
sys.modules,
|
||||
{"langchain_deepseek": SimpleNamespace(ChatDeepSeek=_FakeChatDeepSeek)},
|
||||
), patch.object(
|
||||
llm_module,
|
||||
"_patch_deepseek_reasoning_content_support",
|
||||
side_effect=lambda: patch_calls.append(True),
|
||||
):
|
||||
llm_module.LLMHelper.get_llm(
|
||||
provider="deepseek",
|
||||
@@ -197,6 +208,7 @@ class LlmHelperTestCallTest(unittest.TestCase):
|
||||
calls[0].get("extra_body"),
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
self.assertEqual(patch_calls, [True])
|
||||
self.assertIsNone(calls[0].get("reasoning_effort"))
|
||||
self.assertEqual(calls[0].get("api_base"), "https://proxy.example.com")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user