Improve non-verbose agent tool summaries

This commit is contained in:
jxxghp
2026-04-29 07:07:33 +08:00
parent d4dec90e2f
commit 8789f35228
5 changed files with 309 additions and 23 deletions

View File

@@ -349,6 +349,14 @@ class MoviePilotAgent:
except Exception as e:
logger.debug(f"智能体输出回调失败: {e}")
def _handle_stream_text(self, text: str):
"""
统一处理一段可见流式文本,确保工具统计注入后的内容会同时进入
消息缓冲区和外部流式回调。
"""
emitted_text = self.stream_handler.emit(text)
self._emit_output(emitted_text)
def _initialize_tools(self) -> List:
"""
初始化工具列表
@@ -572,12 +580,13 @@ class MoviePilotAgent:
agent=agent,
messages={"messages": messages},
config=agent_config,
on_token=lambda token: (
self.stream_handler.emit(token),
self._emit_output(token),
),
on_token=self._handle_stream_text,
)
trailing_tool_summary = self.stream_handler.flush_pending_tool_summary()
if trailing_tool_summary:
self._emit_output(trailing_tool_summary)
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
(
all_sent_via_stream,
@@ -588,8 +597,14 @@ class MoviePilotAgent:
# 流式输出未能发送全部内容(发送失败等)
# 通过常规方式发送剩余内容
remaining_text = await self.stream_handler.take()
if remaining_text and not self._streamed_output:
self._emit_output(remaining_text)
if remaining_text:
unsent_text = remaining_text
if self._streamed_output and remaining_text.startswith(
self._streamed_output
):
unsent_text = remaining_text[len(self._streamed_output) :]
if unsent_text:
self._emit_output(unsent_text)
if (
remaining_text
and not self.suppress_user_reply

View File

@@ -1,6 +1,6 @@
import asyncio
import threading
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
from fastapi.concurrency import run_in_threadpool
@@ -62,16 +62,30 @@ class StreamingHandler:
self._user_id: Optional[str] = None
self._username: Optional[str] = None
self._title: str = ""
# 非啰嗦模式下的待输出工具统计,等下一段文本到来时再统一补一句摘要
self._pending_tool_stats: dict[str, dict[str, Any]] = {}
def emit(self, token: str) -> str:
    """
    Accept one streamed LLM token and append it to the buffer.

    If tool statistics are pending, a one-line summary is placed in front of
    the token first (leading newlines of the token are dropped in that case,
    because the summary already ends with a newline).

    :param token: the text chunk to append; ``None``/empty is treated as ""
    :return: the text that was actually appended to the buffer, which may
             differ from ``token`` (summary prepended, newlines trimmed)
    """
    with self._lock:
        emitted = token or ""
        if self._pending_tool_stats:
            summary = self._consume_pending_tool_summary_locked()
            if summary:
                emitted = summary + emitted.lstrip("\n")

        # When the buffer already ends with a blank line, strip the leading
        # newlines of the new text so blank lines do not pile up.
        if self._buffer.endswith("\n\n") and emitted.startswith("\n"):
            emitted = emitted.lstrip("\n")

        # Append exactly once and report what was appended.
        self._buffer += emitted
        return emitted
async def take(self) -> str:
"""
@@ -82,6 +96,8 @@ class StreamingHandler:
注意:流式渠道不调用此方法,工具消息直接 emit 到 buffer 中。
"""
self.flush_pending_tool_summary()
with self._lock:
if not self._buffer:
return ""
@@ -99,6 +115,7 @@ class StreamingHandler:
self._sent_text = ""
self._message_response = None
self._msg_start_offset = 0
self._pending_tool_stats = {}
def reset(self):
"""
@@ -112,6 +129,7 @@ class StreamingHandler:
self._buffer = ""
self._sent_text = ""
self._msg_start_offset = 0
self._pending_tool_stats = {}
async def start_streaming(
self,
@@ -141,6 +159,7 @@ class StreamingHandler:
self._sent_text = ""
self._message_response = None
self._msg_start_offset = 0
self._pending_tool_stats = {}
# Check whether the channel supports message editing; if it does not, only collect tokens into the buffer and do not push them in real time
if not self._can_stream():
@@ -176,6 +195,9 @@ class StreamingHandler:
# 取消定时任务
await self._cancel_flush_task()
# 将未落地的工具统计补入缓冲区,避免流式结束时丢失这段执行信息
self.flush_pending_tool_summary()
# 执行最后一次刷新
await self._flush()
@@ -194,11 +216,172 @@ class StreamingHandler:
self._sent_text = ""
self._message_response = None
self._msg_start_offset = 0
self._pending_tool_stats = {}
if all_sent:
# 所有内容已通过流式发送,清空缓冲区
self._buffer = ""
return all_sent, final_text
def record_tool_call(
    self,
    tool_name: str,
    tool_message: Optional[str] = None,
    tool_kwargs: Optional[dict[str, Any]] = None,
):
    """
    Register one tool invocation so it can be reported later as part of an
    aggregated summary (used in non-verbose mode).

    :param tool_name: name of the tool that was called
    :param tool_message: optional human-readable description of the call
    :param tool_kwargs: optional keyword arguments the tool was called with
    """
    call_kwargs = tool_kwargs if tool_kwargs else {}
    category, target = self._classify_tool_call(
        tool_name=tool_name,
        tool_message=tool_message,
        tool_kwargs=call_kwargs,
    )

    with self._lock:
        stats = self._pending_tool_stats
        if category not in stats:
            stats[category] = {"count": 0, "targets": set()}
        entry = stats[category]
        entry["count"] += 1
        # Targets are stored as strings in a set so repeated calls against
        # the same target collapse into one entry.
        if target:
            entry["targets"].add(str(target))
def flush_pending_tool_summary(self) -> str:
    """
    Move any pending tool statistics into the buffer as a summary line.

    :return: the summary text appended by this call, or "" when nothing
             was pending
    """
    with self._lock:
        pending_summary = self._consume_pending_tool_summary_locked()
        if not pending_summary:
            return pending_summary
        self._buffer = self._buffer + pending_summary
        return pending_summary
@staticmethod
def _classify_tool_call(
tool_name: str,
tool_message: Optional[str],
tool_kwargs: dict[str, Any],
) -> tuple[str, Optional[str]]:
tool_name = (tool_name or "").strip().lower()
tool_message = (tool_message or "").strip()
tool_message_lower = tool_message.lower()
if tool_name == "read_file":
return "file_read", tool_kwargs.get("file_path")
if tool_name in {"write_file", "edit_file"}:
return "file_write", tool_kwargs.get("file_path")
if tool_name in {"list_directory", "query_directory_settings"}:
return "directory", tool_kwargs.get("path")
if tool_name == "browse_webpage":
return (
"web_browse",
tool_kwargs.get("url")
or tool_kwargs.get("target_url")
or tool_kwargs.get("path"),
)
if tool_name == "execute_command":
return "command", tool_kwargs.get("command")
if tool_name == "ask_user_choice":
return "interaction", tool_kwargs.get("message")
if tool_name.startswith("search_") or tool_name in {"get_search_results"}:
return (
"search",
tool_kwargs.get("query")
or tool_kwargs.get("title")
or tool_kwargs.get("keyword"),
)
if tool_name.startswith("query_") or tool_name.startswith("list_") or tool_name.startswith("get_"):
return "data_query", None
if tool_name.startswith(("add_", "update_", "delete_", "modify_", "run_")):
return "action", None
if tool_name in {
"recognize_media",
"scrape_metadata",
"transfer_file",
"test_site",
"send_message",
"send_local_file",
"send_voice_message",
}:
return "action", None
if "读取文件" in tool_message or "read file" in tool_message_lower:
return "file_read", tool_kwargs.get("file_path")
if (
"写入文件" in tool_message
or "编辑文件" in tool_message
or "write file" in tool_message_lower
or "edit file" in tool_message_lower
):
return "file_write", tool_kwargs.get("file_path")
if "目录" in tool_message or "directory" in tool_message_lower:
return "directory", tool_kwargs.get("path")
if "搜索" in tool_message or "search" in tool_message_lower:
return (
"search",
tool_kwargs.get("query")
or tool_kwargs.get("title")
or tool_kwargs.get("keyword"),
)
if "网页" in tool_message or "browser" in tool_message_lower or "webpage" in tool_message_lower:
return "web_browse", tool_kwargs.get("url")
if "命令" in tool_message or "command" in tool_message_lower:
return "command", tool_kwargs.get("command")
return "tool", None
def _consume_pending_tool_summary_locked(self) -> str:
if not self._pending_tool_stats:
return ""
parts = []
for category, bucket in self._pending_tool_stats.items():
value = bucket["count"]
if category in {"file_read", "file_write", "directory", "web_browse"} and bucket["targets"]:
value = len(bucket["targets"])
part = self._format_tool_stat(category, value)
if part:
parts.append(part)
self._pending_tool_stats = {}
if not parts:
return ""
summary = f"{''.join(parts)}"
visible_buffer = self._buffer.rstrip(" \t")
last_char = visible_buffer[-1:] if visible_buffer.strip() else ""
prefix = ""
if self._buffer and last_char != "\n":
prefix = "\n"
return f"{prefix}{summary}\n"
@staticmethod
def _format_tool_stat(category: str, count: int) -> str:
if count <= 0:
return ""
if category == "search":
return f"执行了 {count} 次搜索"
if category == "file_read":
return f"读取了 {count} 个文件"
if category == "file_write":
return f"修改了 {count} 个文件"
if category == "directory":
return f"查看了 {count} 个目录"
if category == "web_browse":
return f"浏览了 {count} 个网页"
if category == "command":
return f"执行了 {count} 条命令"
if category == "data_query":
return f"查询了 {count} 次数据"
if category == "action":
return f"执行了 {count} 次操作"
if category == "interaction":
return f"发起了 {count} 次交互"
return f"调用了 {count} 次工具"
def _can_stream(self) -> bool:
"""
检查当前渠道是否支持流式输出(消息编辑)

View File

@@ -124,9 +124,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
merged_message = "\n\n".join(messages)
await self.send_tool_message(merged_message)
else:
# 非VERBOSE工具边界至少补一个换行,避免工具前后的文本直接连在一起
if self._stream_handler.last_buffer_char not in ("", "\n"):
self._stream_handler.emit("\n")
# 非VERBOSE不逐条回显工具调用,转为在下一段文本前补一句聚合摘要
self._stream_handler.record_tool_call(
tool_name=self.name,
tool_message=tool_message,
tool_kwargs=kwargs,
)
else:
# 未启用流式传输,不发送任何工具消息内容
pass

View File

@@ -69,9 +69,15 @@ class _OpenAIStreamingHandler(StreamingHandler):
self._event_queue = queue
def emit(self, token: str) -> str:
    """
    Buffer the token through the base handler and mirror the text that was
    actually emitted onto the bound event queue.

    The value returned by the base ``emit`` is what gets queued (not the raw
    ``token``), because the base handler may prepend a pending tool summary.

    :param token: text chunk produced by the model
    :return: the text appended to the buffer, as reported by the base class
    """
    # Call the base implementation exactly once; a second call would append
    # the token to the buffer twice.
    emitted = super().emit(token)
    if emitted and self._event_queue is not None:
        self._event_queue.put_nowait(emitted)
    # Propagate the result so callers that forward the emitted text receive
    # the same value the base class documents.
    return emitted
def flush_pending_tool_summary(self) -> str:
    """
    Flush pending tool statistics via the base handler and, when a summary
    was produced and a queue is bound, push that summary onto the queue.

    :return: the summary text reported by the base class ("" if none)
    """
    summary_text = super().flush_pending_tool_summary()
    queue = self._event_queue
    if queue is not None and summary_text:
        queue.put_nowait(summary_text)
    return summary_text
async def start_streaming(
self,

View File

@@ -9,6 +9,7 @@ if not hasattr(langchain_agents, "create_agent"):
from app.agent.callback import StreamingHandler
from app.agent.tools.base import MoviePilotTool
from app.api.endpoints.openai import _OpenAIStreamingHandler
from app.core.config import settings
from app.schemas.message import MessageResponse
from app.schemas.types import MessageChannel
@@ -37,23 +38,101 @@ class TestAgentToolStreaming(unittest.TestCase):
buffered_message = await handler.take()
return result, buffered_message
def test_non_verbose_tool_call_flushes_summary_on_take(self):
    """
    A tool call recorded after text that lacks a trailing newline is flushed
    by ``take()`` as a summary on its own line.
    """
    result, buffered_message = asyncio.run(self._run_tool("prefix"))

    self.assertEqual(result, "ok")
    # Only the summary-based expectation is kept; the earlier bare-newline
    # expectation contradicts it and cannot hold at the same time.
    self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n")
def test_non_verbose_tool_call_reuses_existing_newline_before_summary(self):
    """
    When the buffered text already ends with a newline, the summary is
    appended directly without inserting a second line break.
    """
    result, buffered_message = asyncio.run(self._run_tool("prefix\n"))

    self.assertEqual(result, "ok")
    # Only the summary-based expectation is kept; the earlier one
    # contradicts it.
    self.assertEqual(buffered_message, "prefix\n(调用了 1 次工具)\n")
def test_non_verbose_tool_call_emits_summary_even_when_buffer_was_empty(self):
    """
    With an empty buffer the summary is still produced, and without a
    leading newline.
    """
    result, buffered_message = asyncio.run(self._run_tool(""))

    self.assertEqual(result, "ok")
    # Only the summary-based expectation is kept; the earlier "buffer stays
    # empty" expectation contradicts it.
    self.assertEqual(buffered_message, "(调用了 1 次工具)\n")
def test_non_verbose_tool_summary_is_inserted_before_next_text(self):
    """
    In non-verbose mode a tool call made between two text chunks shows up as
    a summary line placed before the second chunk.
    """

    async def scenario():
        stream = StreamingHandler()
        dummy = DummyTool(session_id="session-1", user_id="10001")
        await stream.start_streaming()
        stream.emit("让我来检查一下:")
        dummy.set_stream_handler(stream)

        # Force non-verbose mode only for the duration of the tool run.
        with patch.object(settings, "AI_AGENT_VERBOSE", False):
            await dummy._arun(explanation="run test tool")

        stream.emit("已经拿到结果")
        return await stream.take()

    expected = "让我来检查一下:\n(调用了 1 次工具)\n已经拿到结果"
    self.assertEqual(asyncio.run(scenario()), expected)
def test_non_verbose_tool_summary_aggregates_multiple_categories(self):
    """
    Several recorded calls across two categories are merged into a single
    summary line, one part per category, in first-seen order.
    """
    recorded_calls = [
        ("search_web", "搜索网络内容: MoviePilot", {"query": "MoviePilot"}),
        ("search_web", "搜索网络内容: agent streaming", {"query": "agent streaming"}),
        ("read_file", "读取文件: a.py", {"file_path": "/tmp/a.py"}),
        ("read_file", "读取文件: b.py", {"file_path": "/tmp/b.py"}),
    ]

    async def scenario():
        stream = StreamingHandler()
        await stream.start_streaming()
        stream.emit("处理中:")
        for name, message, kwargs in recorded_calls:
            stream.record_tool_call(
                tool_name=name,
                tool_message=message,
                tool_kwargs=kwargs,
            )
        stream.emit("继续分析")
        return await stream.take()

    expected = "处理中:\n(执行了 2 次搜索,读取了 2 个文件)\n继续分析"
    self.assertEqual(asyncio.run(scenario()), expected)
def test_openai_streaming_handler_flushes_pending_summary_to_queue(self):
    """
    Flushing a pending summary on the OpenAI handler returns the summary,
    puts the same text on the bound queue, and leaves it in the buffer.
    """

    async def _run():
        handler = _OpenAIStreamingHandler()
        queue: asyncio.Queue = asyncio.Queue()
        handler.bind_queue(queue)
        await handler.start_streaming()
        handler.record_tool_call(
            tool_name="read_file",
            tool_message="读取文件: app.py",
            tool_kwargs={"file_path": "/tmp/app.py"},
        )
        emitted = handler.flush_pending_tool_summary()
        # The handler queues with put_nowait, so the item must already be
        # there. get_nowait raises QueueEmpty on a regression instead of
        # blocking the test run forever the way ``await queue.get()`` would.
        queued = queue.get_nowait()
        buffered_message = await handler.take()
        return emitted, queued, buffered_message

    emitted, queued, buffered_message = asyncio.run(_run())

    self.assertEqual(emitted, "(读取了 1 个文件)\n")
    self.assertEqual(queued, emitted)
    self.assertEqual(buffered_message, emitted)
def test_flush_sends_direct_message_via_threadpool(self):
handler = StreamingHandler()