diff --git a/app/agent/__init__.py b/app/agent/__init__.py index b4afc341..07cd2e70 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -377,6 +377,11 @@ class MoviePilotAgent: llm = self._initialize_llm(streaming=streaming) self._sync_model_profile(llm) + # 为中间件内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。 + non_streaming_llm = ( + llm if not streaming else self._initialize_llm(streaming=False) + ) + # 工具列表 tools = self._initialize_tools() @@ -399,8 +404,9 @@ class MoviePilotAgent: ), # 用量统计 UsageMiddleware(on_usage=self._record_usage), - # 上下文压缩 - SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)), + SummarizationMiddleware( + model=non_streaming_llm, trigger=("fraction", 0.85) + ), # 错误工具调用修复 PatchToolCallsMiddleware(), ] @@ -409,7 +415,8 @@ class MoviePilotAgent: if settings.LLM_MAX_TOOLS > 0: middlewares.append( LLMToolSelectorMiddleware( - model=llm, max_tools=settings.LLM_MAX_TOOLS + model=non_streaming_llm, + max_tools=settings.LLM_MAX_TOOLS, ) ) diff --git a/tests/test_agent_summarization_streaming.py b/tests/test_agent_summarization_streaming.py new file mode 100644 index 00000000..5803524e --- /dev/null +++ b/tests/test_agent_summarization_streaming.py @@ -0,0 +1,120 @@ +import unittest +from unittest.mock import patch + +from langchain.agents.middleware import SummarizationMiddleware + +import app.agent as agent_module + + +class _FakeLLM: + _llm_type = "openai-chat" + + def __init__(self, model: str): + self.model = model + self.profile = {"max_input_tokens": 64000} + + +class TestAgentSummarizationStreaming(unittest.TestCase): + def test_streaming_agent_uses_non_streaming_llm_for_summary(self): + agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") + main_llm = _FakeLLM("main") + non_streaming_llm = _FakeLLM("non-streaming") + captured: dict = {} + + def _fake_create_agent(**kwargs): + captured.update(kwargs) + return object() + + with ( + patch.object( + agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm] + ), + patch.object(agent, "_initialize_tools", return_value=[]), + patch.object( + agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" + ), + patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), + ): + agent._create_agent(streaming=True) + + summary_middleware = next( + middleware + for middleware in captured["middleware"] + if isinstance(middleware, SummarizationMiddleware) + ) + + self.assertIs(captured["model"], main_llm) + self.assertIs(summary_middleware.model, non_streaming_llm) + + def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares(self): + agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") + main_llm = _FakeLLM("main") + non_streaming_llm = _FakeLLM("non-streaming") + captured: dict = {} + + class _FakeToolSelectorMiddleware: + def __init__(self, model, max_tools): + self.model = model + self.max_tools = max_tools + + def _fake_create_agent(**kwargs): + captured.update(kwargs) + return object() + + with ( + patch.object( + agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm] + ), + patch.object(agent, "_initialize_tools", return_value=[]), + patch.object( + agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" + ), + patch.object( + agent_module, + "LLMToolSelectorMiddleware", + _FakeToolSelectorMiddleware, + ), + patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), + patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3), + ): + agent._create_agent(streaming=True) + + tool_selector_middleware = next( + middleware + for middleware in captured["middleware"] + if isinstance(middleware, _FakeToolSelectorMiddleware) + ) + + self.assertIs(tool_selector_middleware.model, non_streaming_llm) + + def test_non_streaming_agent_reuses_main_llm_for_summary(self): + agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") + main_llm = _FakeLLM("main") + captured: dict = {} + + def _fake_create_agent(**kwargs): + captured.update(kwargs) + return object() + + with ( + patch.object(agent, "_initialize_llm", return_value=main_llm), + patch.object(agent, "_initialize_tools", return_value=[]), + patch.object( + agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" + ), + patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), + ): + agent._create_agent(streaming=False) + + summary_middleware = next( + middleware + for middleware in captured["middleware"] + if isinstance(middleware, SummarizationMiddleware) + ) + + self.assertIs(captured["model"], main_llm) + self.assertIs(summary_middleware.model, main_llm) + + +if __name__ == "__main__": + unittest.main()