diff --git a/app/agent/__init__.py b/app/agent/__init__.py index ee5f7c73..6c87706f 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -349,6 +349,14 @@ class MoviePilotAgent: except Exception as e: logger.debug(f"智能体输出回调失败: {e}") + def _handle_stream_text(self, text: str): + """ + 统一处理一段可见流式文本,确保工具统计注入后的内容会同时进入 + 消息缓冲区和外部流式回调。 + """ + emitted_text = self.stream_handler.emit(text) + self._emit_output(emitted_text) + def _initialize_tools(self) -> List: """ 初始化工具列表 @@ -572,12 +580,13 @@ class MoviePilotAgent: agent=agent, messages={"messages": messages}, config=agent_config, - on_token=lambda token: ( - self.stream_handler.emit(token), - self._emit_output(token), - ), + on_token=self._handle_stream_text, ) + trailing_tool_summary = self.stream_handler.flush_pending_tool_summary() + if trailing_tool_summary: + self._emit_output(trailing_tool_summary) + # 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本 ( all_sent_via_stream, @@ -588,8 +597,14 @@ class MoviePilotAgent: # 流式输出未能发送全部内容(发送失败等) # 通过常规方式发送剩余内容 remaining_text = await self.stream_handler.take() - if remaining_text and not self._streamed_output: - self._emit_output(remaining_text) + if remaining_text: + unsent_text = remaining_text + if self._streamed_output and remaining_text.startswith( + self._streamed_output + ): + unsent_text = remaining_text[len(self._streamed_output) :] + if unsent_text: + self._emit_output(unsent_text) if ( remaining_text and not self.suppress_user_reply diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index 0b081080..316ec835 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -1,6 +1,6 @@ import asyncio import threading -from typing import Optional, Tuple +from typing import Any, Optional, Tuple from fastapi.concurrency import run_in_threadpool @@ -62,16 +62,30 @@ class StreamingHandler: self._user_id: Optional[str] = None self._username: Optional[str] = None self._title: str = "" + # 非啰嗦模式下的待输出工具统计,等下一段文本到来时再统一补一句摘要 + 
self._pending_tool_stats: dict[str, dict[str, Any]] = {} - def emit(self, token: str): + def emit(self, token: str) -> str: """ 接收 LLM 流式 token,积累到缓冲区。 + 如果存在待输出的工具统计,则会先补上一句摘要再追加 token。 """ with self._lock: + emitted = token or "" + + if self._pending_tool_stats: + summary = self._consume_pending_tool_summary_locked() + if summary: + if emitted: + emitted = f"{summary}{emitted.lstrip(chr(10))}" + else: + emitted = summary + # 如果存量消息结束是两个换行,则去掉新消息前面的换行,避免过多空行 - if self._buffer.endswith("\n\n") and token.startswith("\n"): - token = token.lstrip("\n") - self._buffer += token + if self._buffer.endswith("\n\n") and emitted.startswith("\n"): + emitted = emitted.lstrip("\n") + self._buffer += emitted + return emitted async def take(self) -> str: """ @@ -82,6 +96,8 @@ class StreamingHandler: 注意:流式渠道不调用此方法,工具消息直接 emit 到 buffer 中。 """ + self.flush_pending_tool_summary() + with self._lock: if not self._buffer: return "" @@ -99,6 +115,7 @@ class StreamingHandler: self._sent_text = "" self._message_response = None self._msg_start_offset = 0 + self._pending_tool_stats = {} def reset(self): """ @@ -112,6 +129,7 @@ class StreamingHandler: self._buffer = "" self._sent_text = "" self._msg_start_offset = 0 + self._pending_tool_stats = {} async def start_streaming( self, @@ -141,6 +159,7 @@ class StreamingHandler: self._sent_text = "" self._message_response = None self._msg_start_offset = 0 + self._pending_tool_stats = {} # 检查渠道是否支持消息编辑,不支持则仅收集 token 到 buffer,不实时推送 if not self._can_stream(): @@ -176,6 +195,9 @@ class StreamingHandler: # 取消定时任务 await self._cancel_flush_task() + # 将未落地的工具统计补入缓冲区,避免流式结束时丢失这段执行信息 + self.flush_pending_tool_summary() + # 执行最后一次刷新 await self._flush() @@ -194,11 +216,172 @@ class StreamingHandler: self._sent_text = "" self._message_response = None self._msg_start_offset = 0 + self._pending_tool_stats = {} if all_sent: # 所有内容已通过流式发送,清空缓冲区 self._buffer = "" return all_sent, final_text + def record_tool_call( + self, + tool_name: str, + tool_message: 
Optional[str] = None, + tool_kwargs: Optional[dict[str, Any]] = None, + ): + """ + 记录一次工具调用,供非啰嗦模式下延迟汇总输出。 + """ + category, target = self._classify_tool_call( + tool_name=tool_name, + tool_message=tool_message, + tool_kwargs=tool_kwargs or {}, + ) + with self._lock: + bucket = self._pending_tool_stats.setdefault( + category, + { + "count": 0, + "targets": set(), + }, + ) + bucket["count"] += 1 + if target: + bucket["targets"].add(str(target)) + + def flush_pending_tool_summary(self) -> str: + """ + 将待输出的工具统计摘要补入缓冲区,并返回本次新增的摘要文本。 + """ + with self._lock: + summary = self._consume_pending_tool_summary_locked() + if summary: + self._buffer += summary + return summary + + @staticmethod + def _classify_tool_call( + tool_name: str, + tool_message: Optional[str], + tool_kwargs: dict[str, Any], + ) -> tuple[str, Optional[str]]: + tool_name = (tool_name or "").strip().lower() + tool_message = (tool_message or "").strip() + tool_message_lower = tool_message.lower() + + if tool_name == "read_file": + return "file_read", tool_kwargs.get("file_path") + if tool_name in {"write_file", "edit_file"}: + return "file_write", tool_kwargs.get("file_path") + if tool_name in {"list_directory", "query_directory_settings"}: + return "directory", tool_kwargs.get("path") + if tool_name == "browse_webpage": + return ( + "web_browse", + tool_kwargs.get("url") + or tool_kwargs.get("target_url") + or tool_kwargs.get("path"), + ) + if tool_name == "execute_command": + return "command", tool_kwargs.get("command") + if tool_name == "ask_user_choice": + return "interaction", tool_kwargs.get("message") + if tool_name.startswith("search_") or tool_name in {"get_search_results"}: + return ( + "search", + tool_kwargs.get("query") + or tool_kwargs.get("title") + or tool_kwargs.get("keyword"), + ) + if tool_name.startswith("query_") or tool_name.startswith("list_") or tool_name.startswith("get_"): + return "data_query", None + if tool_name.startswith(("add_", "update_", "delete_", "modify_", "run_")): + 
return "action", None + if tool_name in { + "recognize_media", + "scrape_metadata", + "transfer_file", + "test_site", + "send_message", + "send_local_file", + "send_voice_message", + }: + return "action", None + + if "读取文件" in tool_message or "read file" in tool_message_lower: + return "file_read", tool_kwargs.get("file_path") + if ( + "写入文件" in tool_message + or "编辑文件" in tool_message + or "write file" in tool_message_lower + or "edit file" in tool_message_lower + ): + return "file_write", tool_kwargs.get("file_path") + if "目录" in tool_message or "directory" in tool_message_lower: + return "directory", tool_kwargs.get("path") + if "搜索" in tool_message or "search" in tool_message_lower: + return ( + "search", + tool_kwargs.get("query") + or tool_kwargs.get("title") + or tool_kwargs.get("keyword"), + ) + if "网页" in tool_message or "browser" in tool_message_lower or "webpage" in tool_message_lower: + return "web_browse", tool_kwargs.get("url") + if "命令" in tool_message or "command" in tool_message_lower: + return "command", tool_kwargs.get("command") + + return "tool", None + + def _consume_pending_tool_summary_locked(self) -> str: + if not self._pending_tool_stats: + return "" + + parts = [] + for category, bucket in self._pending_tool_stats.items(): + value = bucket["count"] + if category in {"file_read", "file_write", "directory", "web_browse"} and bucket["targets"]: + value = len(bucket["targets"]) + part = self._format_tool_stat(category, value) + if part: + parts.append(part) + + self._pending_tool_stats = {} + if not parts: + return "" + + summary = f"({','.join(parts)})" + visible_buffer = self._buffer.rstrip(" \t") + last_char = visible_buffer[-1:] if visible_buffer.strip() else "" + prefix = "" + if self._buffer and last_char != "\n": + prefix = "\n" + return f"{prefix}{summary}\n" + + @staticmethod + def _format_tool_stat(category: str, count: int) -> str: + if count <= 0: + return "" + + if category == "search": + return f"执行了 {count} 次搜索" + if category 
== "file_read":
+            return f"读取了 {count} 个文件"
+        if category == "file_write":
+            return f"修改了 {count} 个文件"
+        if category == "directory":
+            return f"查看了 {count} 个目录"
+        if category == "web_browse":
+            return f"浏览了 {count} 个网页"
+        if category == "command":
+            return f"执行了 {count} 条命令"
+        if category == "data_query":
+            return f"查询了 {count} 次数据"
+        if category == "action":
+            return f"执行了 {count} 次操作"
+        if category == "interaction":
+            return f"发起了 {count} 次交互"
+        return f"调用了 {count} 次工具"
+
     def _can_stream(self) -> bool:
         """
         检查当前渠道是否支持流式输出(消息编辑)
diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py
index e0972d2d..17757903 100644
--- a/app/agent/tools/base.py
+++ b/app/agent/tools/base.py
@@ -124,9 +124,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
                 merged_message = "\n\n".join(messages)
                 await self.send_tool_message(merged_message)
             else:
-                # 非VERBOSE:工具边界至少补一个换行,避免工具前后的文本直接连在一起
-                if self._stream_handler.last_buffer_char not in ("", "\n"):
-                    self._stream_handler.emit("\n")
+                # 非VERBOSE:不逐条回显工具调用,转为在下一段文本前补一句聚合摘要
+                self._stream_handler.record_tool_call(
+                    tool_name=self.name,
+                    tool_message=tool_message,
+                    tool_kwargs=kwargs,
+                )
         else:
             # 未启用流式传输,不发送任何工具消息内容
             pass
diff --git a/app/api/endpoints/openai.py b/app/api/endpoints/openai.py
index 01a30185..6eb12e46 100644
--- a/app/api/endpoints/openai.py
+++ b/app/api/endpoints/openai.py
@@ -69,9 +69,16 @@ class _OpenAIStreamingHandler(StreamingHandler):
         self._event_queue = queue
 
-    def emit(self, token: str):
-        super().emit(token)
-        if token and self._event_queue is not None:
-            self._event_queue.put_nowait(token)
+    def emit(self, token: str) -> str:
+        emitted = super().emit(token)
+        if emitted and self._event_queue is not None:
+            self._event_queue.put_nowait(emitted)
+        return emitted
+
+    def flush_pending_tool_summary(self) -> str:
+        emitted = super().flush_pending_tool_summary()
+        if emitted and self._event_queue is not None:
+            self._event_queue.put_nowait(emitted)
+        return emitted
 
     async def start_streaming(
         self,
diff --git a/tests/test_agent_tool_streaming.py 
b/tests/test_agent_tool_streaming.py index bfeb556c..755fba04 100644 --- a/tests/test_agent_tool_streaming.py +++ b/tests/test_agent_tool_streaming.py @@ -9,6 +9,7 @@ if not hasattr(langchain_agents, "create_agent"): from app.agent.callback import StreamingHandler from app.agent.tools.base import MoviePilotTool +from app.api.endpoints.openai import _OpenAIStreamingHandler from app.core.config import settings from app.schemas.message import MessageResponse from app.schemas.types import MessageChannel @@ -37,23 +38,101 @@ class TestAgentToolStreaming(unittest.TestCase): buffered_message = await handler.take() return result, buffered_message - def test_non_verbose_tool_call_appends_newline_separator(self): + def test_non_verbose_tool_call_flushes_summary_on_take(self): result, buffered_message = asyncio.run(self._run_tool("prefix")) self.assertEqual(result, "ok") - self.assertEqual(buffered_message, "prefix\n") + self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n") - def test_non_verbose_tool_call_does_not_duplicate_newline(self): + def test_non_verbose_tool_call_reuses_existing_newline_before_summary(self): result, buffered_message = asyncio.run(self._run_tool("prefix\n")) self.assertEqual(result, "ok") - self.assertEqual(buffered_message, "prefix\n") + self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n") - def test_non_verbose_tool_call_keeps_empty_buffer_unchanged(self): + def test_non_verbose_tool_call_emits_summary_even_when_buffer_was_empty(self): result, buffered_message = asyncio.run(self._run_tool("")) self.assertEqual(result, "ok") - self.assertEqual(buffered_message, "") + self.assertEqual(buffered_message, "(调用了 1 次工具)\n") + + def test_non_verbose_tool_summary_is_inserted_before_next_text(self): + async def _run(): + tool = DummyTool(session_id="session-1", user_id="10001") + handler = StreamingHandler() + await handler.start_streaming() + handler.emit("让我来检查一下:") + tool.set_stream_handler(handler) + + with patch.object(settings, 
"AI_AGENT_VERBOSE", False): + await tool._arun(explanation="run test tool") + + handler.emit("已经拿到结果") + return await handler.take() + + buffered_message = asyncio.run(_run()) + + self.assertEqual( + buffered_message, + "让我来检查一下:\n(调用了 1 次工具)\n已经拿到结果", + ) + + def test_non_verbose_tool_summary_aggregates_multiple_categories(self): + async def _run(): + handler = StreamingHandler() + await handler.start_streaming() + handler.emit("处理中:") + handler.record_tool_call( + tool_name="search_web", + tool_message="搜索网络内容: MoviePilot", + tool_kwargs={"query": "MoviePilot"}, + ) + handler.record_tool_call( + tool_name="search_web", + tool_message="搜索网络内容: agent streaming", + tool_kwargs={"query": "agent streaming"}, + ) + handler.record_tool_call( + tool_name="read_file", + tool_message="读取文件: a.py", + tool_kwargs={"file_path": "/tmp/a.py"}, + ) + handler.record_tool_call( + tool_name="read_file", + tool_message="读取文件: b.py", + tool_kwargs={"file_path": "/tmp/b.py"}, + ) + handler.emit("继续分析") + return await handler.take() + + buffered_message = asyncio.run(_run()) + + self.assertEqual( + buffered_message, + "处理中:\n(执行了 2 次搜索,读取了 2 个文件)\n继续分析", + ) + + def test_openai_streaming_handler_flushes_pending_summary_to_queue(self): + async def _run(): + handler = _OpenAIStreamingHandler() + queue: asyncio.Queue = asyncio.Queue() + handler.bind_queue(queue) + await handler.start_streaming() + handler.record_tool_call( + tool_name="read_file", + tool_message="读取文件: app.py", + tool_kwargs={"file_path": "/tmp/app.py"}, + ) + emitted = handler.flush_pending_tool_summary() + queued = await queue.get() + buffered_message = await handler.take() + return emitted, queued, buffered_message + + emitted, queued, buffered_message = asyncio.run(_run()) + + self.assertEqual(emitted, "(读取了 1 个文件)\n") + self.assertEqual(queued, emitted) + self.assertEqual(buffered_message, emitted) def test_flush_sends_direct_message_via_threadpool(self): handler = StreamingHandler()