Refine internal middleware LLM usage for streaming agents

Use a non-streaming model for middleware-only calls so internal outputs do not leak into user streams and model-based middleware stays consistent.
This commit is contained in:
jxxghp
2026-04-27 06:55:41 +08:00
parent 4208c79d72
commit 221eb21694
2 changed files with 130 additions and 3 deletions

View File

@@ -377,6 +377,11 @@ class MoviePilotAgent:
llm = self._initialize_llm(streaming=streaming)
self._sync_model_profile(llm)
# 为中间件内部模型调用准备非流式 LLM避免与用户流式回复复用同一实例。
non_streaming_llm = (
llm if not streaming else self._initialize_llm(streaming=False)
)
# 工具列表
tools = self._initialize_tools()
@@ -399,8 +404,9 @@ class MoviePilotAgent:
),
# 用量统计
UsageMiddleware(on_usage=self._record_usage),
# 上下文压缩
SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)),
SummarizationMiddleware(
model=non_streaming_llm, trigger=("fraction", 0.85)
),
# 错误工具调用修复
PatchToolCallsMiddleware(),
]
@@ -409,7 +415,8 @@ class MoviePilotAgent:
if settings.LLM_MAX_TOOLS > 0:
middlewares.append(
LLMToolSelectorMiddleware(
model=llm, max_tools=settings.LLM_MAX_TOOLS
model=non_streaming_llm,
max_tools=settings.LLM_MAX_TOOLS,
)
)

View File

@@ -0,0 +1,120 @@
import unittest
from unittest.mock import patch
from langchain.agents.middleware import SummarizationMiddleware
import app.agent as agent_module
class _FakeLLM:
_llm_type = "openai-chat"
def __init__(self, model: str):
self.model = model
self.profile = {"max_input_tokens": 64000}
class TestAgentSummarizationStreaming(unittest.TestCase):
    """Check which LLM instance ``_create_agent`` passes to each consumer.

    A streaming agent must keep its streaming model for the agent itself while
    giving model-based middleware a separate non-streaming instance; a
    non-streaming agent reuses its single model everywhere.
    """

    class _FakeToolSelectorMiddleware:
        """Replacement for ``LLMToolSelectorMiddleware`` that records its arguments."""

        def __init__(self, model, max_tools):
            self.model = model
            self.max_tools = max_tools

    def _run_create_agent(self, *, streaming, llms, max_tools=0):
        """Run ``_create_agent`` with every collaborator patched out.

        Args:
            streaming: Value forwarded to ``_create_agent``.
            llms: Fake models returned by successive ``_initialize_llm`` calls.
                A call beyond the end of this sequence makes the mock raise,
                so the sequence also bounds how many models may be built.
            max_tools: Value pinned on ``settings.LLM_MAX_TOOLS`` so the test
                does not depend on the ambient configuration.

        Returns:
            A ``(captured, init_llm)`` tuple: the keyword arguments received
            by the patched ``create_agent`` and the ``_initialize_llm`` mock.
        """
        agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
        captured: dict = {}

        def _fake_create_agent(**kwargs):
            captured.update(kwargs)
            return object()

        with (
            patch.object(agent, "_initialize_llm", side_effect=list(llms)) as init_llm,
            patch.object(agent, "_initialize_tools", return_value=[]),
            patch.object(
                agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
            ),
            # Always swap in the fake so the real selector is never built
            # around a fake model, whatever max_tools is.
            patch.object(
                agent_module,
                "LLMToolSelectorMiddleware",
                self._FakeToolSelectorMiddleware,
            ),
            patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
            patch.object(agent_module.settings, "LLM_MAX_TOOLS", max_tools),
        ):
            agent._create_agent(streaming=streaming)
        return captured, init_llm

    def _find_middleware(self, captured, middleware_type):
        """Return the first captured middleware of ``middleware_type``.

        Fails the test with a readable message when none was registered.
        """
        found = next(
            (
                middleware
                for middleware in captured["middleware"]
                if isinstance(middleware, middleware_type)
            ),
            None,
        )
        self.assertIsNotNone(
            found, f"{middleware_type.__name__} was not passed to create_agent"
        )
        return found

    def test_streaming_agent_uses_non_streaming_llm_for_summary(self):
        main_llm = _FakeLLM("main")
        non_streaming_llm = _FakeLLM("non-streaming")

        captured, init_llm = self._run_create_agent(
            streaming=True, llms=[main_llm, non_streaming_llm]
        )

        summary_middleware = self._find_middleware(captured, SummarizationMiddleware)
        self.assertIs(captured["model"], main_llm)
        self.assertIs(summary_middleware.model, non_streaming_llm)
        # The second model must have been requested as non-streaming.
        self.assertEqual(
            [recorded.kwargs for recorded in init_llm.call_args_list],
            [{"streaming": True}, {"streaming": False}],
        )

    def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares(self):
        main_llm = _FakeLLM("main")
        non_streaming_llm = _FakeLLM("non-streaming")

        captured, _ = self._run_create_agent(
            streaming=True, llms=[main_llm, non_streaming_llm], max_tools=3
        )

        tool_selector_middleware = self._find_middleware(
            captured, self._FakeToolSelectorMiddleware
        )
        self.assertIs(tool_selector_middleware.model, non_streaming_llm)
        self.assertEqual(tool_selector_middleware.max_tools, 3)

    def test_non_streaming_agent_reuses_main_llm_for_summary(self):
        main_llm = _FakeLLM("main")

        # Only one fake is supplied: building a second model would raise.
        captured, init_llm = self._run_create_agent(streaming=False, llms=[main_llm])

        summary_middleware = self._find_middleware(captured, SummarizationMiddleware)
        self.assertIs(captured["model"], main_llm)
        self.assertIs(summary_middleware.model, main_llm)
        self.assertEqual(
            [recorded.kwargs for recorded in init_llm.call_args_list],
            [{"streaming": False}],
        )
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()