mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
Improve non-verbose agent tool summaries
This commit is contained in:
@@ -349,6 +349,14 @@ class MoviePilotAgent:
|
||||
except Exception as e:
|
||||
logger.debug(f"智能体输出回调失败: {e}")
|
||||
|
||||
def _handle_stream_text(self, text: str):
    """
    Route one chunk of visible streamed text through the stream handler
    and on to the external output callback.

    The handler's ``emit`` may put a pending tool-call summary in front of
    the chunk, so the text it returns (not the raw ``text``) is what gets
    forwarded through ``_emit_output``. This keeps the message buffer and
    the external stream in step.

    :param text: the streamed chunk (token) to process
    """
    emitted_text = self.stream_handler.emit(text)
    if emitted_text is None:
        # An ``emit`` override that has no return value would otherwise
        # make every chunk disappear from the external stream; fall back
        # to the raw chunk in that case.
        emitted_text = text
    if emitted_text:
        # Nothing to forward when the handler swallowed the whole chunk
        # (e.g. only redundant newlines).
        self._emit_output(emitted_text)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""
|
||||
初始化工具列表
|
||||
@@ -572,12 +580,13 @@ class MoviePilotAgent:
|
||||
agent=agent,
|
||||
messages={"messages": messages},
|
||||
config=agent_config,
|
||||
on_token=lambda token: (
|
||||
self.stream_handler.emit(token),
|
||||
self._emit_output(token),
|
||||
),
|
||||
on_token=self._handle_stream_text,
|
||||
)
|
||||
|
||||
trailing_tool_summary = self.stream_handler.flush_pending_tool_summary()
|
||||
if trailing_tool_summary:
|
||||
self._emit_output(trailing_tool_summary)
|
||||
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
|
||||
(
|
||||
all_sent_via_stream,
|
||||
@@ -588,8 +597,14 @@ class MoviePilotAgent:
|
||||
# 流式输出未能发送全部内容(发送失败等)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text and not self._streamed_output:
|
||||
self._emit_output(remaining_text)
|
||||
if remaining_text:
|
||||
unsent_text = remaining_text
|
||||
if self._streamed_output and remaining_text.startswith(
|
||||
self._streamed_output
|
||||
):
|
||||
unsent_text = remaining_text[len(self._streamed_output) :]
|
||||
if unsent_text:
|
||||
self._emit_output(unsent_text)
|
||||
if (
|
||||
remaining_text
|
||||
and not self.suppress_user_reply
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
@@ -62,16 +62,30 @@ class StreamingHandler:
|
||||
self._user_id: Optional[str] = None
|
||||
self._username: Optional[str] = None
|
||||
self._title: str = ""
|
||||
# 非啰嗦模式下的待输出工具统计,等下一段文本到来时再统一补一句摘要
|
||||
self._pending_tool_stats: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def emit(self, token: str) -> str:
    """
    Accept one streamed LLM token and append it to the buffer.

    When tool calls are waiting to be summarised, the summary is consumed
    and placed in front of the token; the token's own leading newlines are
    dropped in that case. Leading newlines are likewise dropped when the
    buffer already ends with a blank line, to avoid piling up empty lines.

    :param token: the streamed token; a falsy value is treated as ""
    :return: the text that was actually appended to the buffer
    """
    with self._lock:
        chunk = token if token else ""

        summary = (
            self._consume_pending_tool_summary_locked()
            if self._pending_tool_stats
            else ""
        )
        if summary:
            # With an empty chunk this leaves just the summary.
            chunk = summary + chunk.lstrip("\n")

        # The buffer already closes with a blank line: do not add more.
        if chunk.startswith("\n") and self._buffer.endswith("\n\n"):
            chunk = chunk.lstrip("\n")

        self._buffer += chunk
        return chunk
|
||||
|
||||
async def take(self) -> str:
|
||||
"""
|
||||
@@ -82,6 +96,8 @@ class StreamingHandler:
|
||||
|
||||
注意:流式渠道不调用此方法,工具消息直接 emit 到 buffer 中。
|
||||
"""
|
||||
self.flush_pending_tool_summary()
|
||||
|
||||
with self._lock:
|
||||
if not self._buffer:
|
||||
return ""
|
||||
@@ -99,6 +115,7 @@ class StreamingHandler:
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
self._pending_tool_stats = {}
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -112,6 +129,7 @@ class StreamingHandler:
|
||||
self._buffer = ""
|
||||
self._sent_text = ""
|
||||
self._msg_start_offset = 0
|
||||
self._pending_tool_stats = {}
|
||||
|
||||
async def start_streaming(
|
||||
self,
|
||||
@@ -141,6 +159,7 @@ class StreamingHandler:
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
self._pending_tool_stats = {}
|
||||
|
||||
# 检查渠道是否支持消息编辑,不支持则仅收集 token 到 buffer,不实时推送
|
||||
if not self._can_stream():
|
||||
@@ -176,6 +195,9 @@ class StreamingHandler:
|
||||
# 取消定时任务
|
||||
await self._cancel_flush_task()
|
||||
|
||||
# 将未落地的工具统计补入缓冲区,避免流式结束时丢失这段执行信息
|
||||
self.flush_pending_tool_summary()
|
||||
|
||||
# 执行最后一次刷新
|
||||
await self._flush()
|
||||
|
||||
@@ -194,11 +216,172 @@ class StreamingHandler:
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
self._pending_tool_stats = {}
|
||||
if all_sent:
|
||||
# 所有内容已通过流式发送,清空缓冲区
|
||||
self._buffer = ""
|
||||
return all_sent, final_text
|
||||
|
||||
def record_tool_call(
    self,
    tool_name: str,
    tool_message: Optional[str] = None,
    tool_kwargs: Optional[dict[str, Any]] = None,
):
    """
    Register one tool invocation so it can be summarised later
    (used in non-verbose mode instead of echoing every tool call).

    :param tool_name: name of the tool that was called
    :param tool_message: optional human-readable message of the call
    :param tool_kwargs: optional arguments the tool was called with
    """
    kwargs = tool_kwargs if tool_kwargs else {}
    category, target = self._classify_tool_call(
        tool_name=tool_name,
        tool_message=tool_message,
        tool_kwargs=kwargs,
    )

    with self._lock:
        stats = self._pending_tool_stats
        if category not in stats:
            stats[category] = {"count": 0, "targets": set()}
        entry = stats[category]
        entry["count"] = entry["count"] + 1
        # Targets are kept as a set of strings so repeated calls on the
        # same target collapse into one entry.
        if target:
            entry["targets"].add(str(target))
|
||||
|
||||
def flush_pending_tool_summary(self) -> str:
    """
    Append the summary of any recorded tool calls to the buffer.

    :return: the summary text that was appended, or "" when there was
        nothing to summarise
    """
    with self._lock:
        text = self._consume_pending_tool_summary_locked()
        if not text:
            return text
        self._buffer = self._buffer + text
        return text
|
||||
|
||||
@staticmethod
def _classify_tool_call(
    tool_name: str,
    tool_message: Optional[str],
    tool_kwargs: dict[str, Any],
) -> tuple[str, Optional[str]]:
    """
    Map one tool call to a coarse summary category and an optional target.

    The tool name decides first (exact names, then name prefixes, then a
    fixed set of action tools). When the name is not recognised, keywords
    in the tool message (Chinese or English) are tried. Anything still
    unmatched lands in the generic ``"tool"`` category.

    :param tool_name: tool name; stripped and compared case-insensitively
    :param tool_message: human-readable message of the call, may be None
    :param tool_kwargs: arguments of the call; the target is read from here
    :return: ``(category, target)`` where ``target`` is the raw kwarg value
        (or None when the category carries no target)
    """

    def first_kwarg(*keys: str):
        # Same result as ``kw.get(a) or kw.get(b) or ...``: the first truthy
        # value, otherwise whatever the last key yields.
        value = None
        for key in keys:
            value = tool_kwargs.get(key)
            if value:
                break
        return value

    name = (tool_name or "").strip().lower()
    message = (tool_message or "").strip()
    message_lower = message.lower()

    # --- 1. Classification by tool name -------------------------------
    if name == "read_file":
        return "file_read", tool_kwargs.get("file_path")
    if name in {"write_file", "edit_file"}:
        return "file_write", tool_kwargs.get("file_path")
    # Must stay ahead of the generic "list_" / "query_" prefixes below.
    if name in {"list_directory", "query_directory_settings"}:
        return "directory", tool_kwargs.get("path")
    if name == "browse_webpage":
        return "web_browse", first_kwarg("url", "target_url", "path")
    if name == "execute_command":
        return "command", tool_kwargs.get("command")
    if name == "ask_user_choice":
        return "interaction", tool_kwargs.get("message")
    # "get_search_results" must stay ahead of the generic "get_" prefix.
    if name.startswith("search_") or name == "get_search_results":
        return "search", first_kwarg("query", "title", "keyword")
    if name.startswith(("query_", "list_", "get_")):
        return "data_query", None
    if name.startswith(("add_", "update_", "delete_", "modify_", "run_")) or name in {
        "recognize_media",
        "scrape_metadata",
        "transfer_file",
        "test_site",
        "send_message",
        "send_local_file",
        "send_voice_message",
    }:
        return "action", None

    # --- 2. Fallback: keywords in the tool message --------------------
    if "读取文件" in message or "read file" in message_lower:
        return "file_read", tool_kwargs.get("file_path")
    if (
        "写入文件" in message
        or "编辑文件" in message
        or "write file" in message_lower
        or "edit file" in message_lower
    ):
        return "file_write", tool_kwargs.get("file_path")
    if "目录" in message or "directory" in message_lower:
        return "directory", tool_kwargs.get("path")
    if "搜索" in message or "search" in message_lower:
        return "search", first_kwarg("query", "title", "keyword")
    if "网页" in message or "browser" in message_lower or "webpage" in message_lower:
        # Same target keys as the name-based "browse_webpage" branch, so
        # per-page de-duplication works on this path as well.
        return "web_browse", first_kwarg("url", "target_url", "path")
    if "命令" in message or "command" in message_lower:
        return "command", tool_kwargs.get("command")

    return "tool", None
|
||||
|
||||
def _consume_pending_tool_summary_locked(self) -> str:
    """
    Build the one-line summary of all recorded tool calls and clear them.

    This method does no locking of its own; the callers hold ``self._lock``
    while invoking it. It reads ``self._buffer`` only to decide whether a
    separating newline is needed and never modifies the buffer.

    :return: the summary wrapped in full-width parentheses and followed by
        a newline, with a leading newline when the buffer holds visible
        text that does not already end a line; "" when nothing was
        recorded or nothing could be formatted
    """
    if not self._pending_tool_stats:
        return ""

    # For these categories distinct targets are counted instead of calls,
    # as long as at least one target was captured.
    counted_by_target = {"file_read", "file_write", "directory", "web_browse"}

    parts = []
    for category, bucket in self._pending_tool_stats.items():
        targets = bucket["targets"]
        if category in counted_by_target and targets:
            value = len(targets)
        else:
            value = bucket["count"]
        part = self._format_tool_stat(category, value)
        if part:
            parts.append(part)

    self._pending_tool_stats = {}
    if not parts:
        return ""

    summary = f"({','.join(parts)})"

    # Trailing spaces/tabs are ignored when looking at how the buffer ends.
    tail = self._buffer.rstrip(" \t")
    # Separate from preceding text only when there is visible text and it
    # does not already end with a newline; a buffer that is empty or holds
    # only whitespace/newlines gets no extra leading newline.
    needs_break = bool(tail.strip()) and not tail.endswith("\n")
    prefix = "\n" if needs_break else ""
    return f"{prefix}{summary}\n"
|
||||
|
||||
@staticmethod
def _format_tool_stat(category: str, count: int) -> str:
    """
    Render one fragment of the tool-call summary.

    :param category: category key as produced by the tool classifier
    :param count: number to show for that category
    :return: the user-facing phrase for the category, the generic
        "tool call" phrase for an unknown category, or "" when ``count``
        is not positive
    """
    if count <= 0:
        return ""

    templates = {
        "search": "执行了 {} 次搜索",
        "file_read": "读取了 {} 个文件",
        "file_write": "修改了 {} 个文件",
        "directory": "查看了 {} 个目录",
        "web_browse": "浏览了 {} 个网页",
        "command": "执行了 {} 条命令",
        "data_query": "查询了 {} 次数据",
        "action": "执行了 {} 次操作",
        "interaction": "发起了 {} 次交互",
    }
    template = templates.get(category, "调用了 {} 次工具")
    return template.format(count)
|
||||
|
||||
def _can_stream(self) -> bool:
|
||||
"""
|
||||
检查当前渠道是否支持流式输出(消息编辑)
|
||||
|
||||
@@ -124,9 +124,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
else:
|
||||
# 非VERBOSE:工具边界至少补一个换行,避免工具前后的文本直接连在一起
|
||||
if self._stream_handler.last_buffer_char not in ("", "\n"):
|
||||
self._stream_handler.emit("\n")
|
||||
# 非VERBOSE:不逐条回显工具调用,转为在下一段文本前补一句聚合摘要
|
||||
self._stream_handler.record_tool_call(
|
||||
tool_name=self.name,
|
||||
tool_message=tool_message,
|
||||
tool_kwargs=kwargs,
|
||||
)
|
||||
else:
|
||||
# 未启用流式传输,不发送任何工具消息内容
|
||||
pass
|
||||
|
||||
@@ -69,9 +69,15 @@ class _OpenAIStreamingHandler(StreamingHandler):
|
||||
self._event_queue = queue
|
||||
|
||||
def emit(self, token: str) -> str:
    """
    Buffer a streamed token through the base handler and mirror the
    resulting text to the bound event queue.

    What is queued is the text returned by the base ``emit`` rather than
    the raw token, so anything the base handler put in front of the token
    reaches the queue consumer too.

    :param token: the streamed token
    :return: the text returned by the base ``emit``; it is passed back so
        callers that forward the emitted text keep working with this
        subclass
    """
    emitted = super().emit(token)
    if emitted and self._event_queue is not None:
        self._event_queue.put_nowait(emitted)
    return emitted
|
||||
|
||||
def flush_pending_tool_summary(self) -> str:
    """
    Flush the pending tool summary through the base handler and, when
    something was produced and a queue is bound, push it to that queue.

    :return: the summary text returned by the base implementation
    """
    summary = super().flush_pending_tool_summary()
    queue = self._event_queue
    if queue is not None and summary:
        queue.put_nowait(summary)
    return summary
|
||||
|
||||
async def start_streaming(
|
||||
self,
|
||||
|
||||
@@ -9,6 +9,7 @@ if not hasattr(langchain_agents, "create_agent"):
|
||||
|
||||
from app.agent.callback import StreamingHandler
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.api.endpoints.openai import _OpenAIStreamingHandler
|
||||
from app.core.config import settings
|
||||
from app.schemas.message import MessageResponse
|
||||
from app.schemas.types import MessageChannel
|
||||
@@ -37,23 +38,101 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
buffered_message = await handler.take()
|
||||
return result, buffered_message
|
||||
|
||||
def test_non_verbose_tool_call_flushes_summary_on_take(self):
    """A tool call after unterminated text yields a newline-separated summary on take()."""
    expected = "prefix\n(调用了 1 次工具)\n"

    outcome, taken = asyncio.run(self._run_tool("prefix"))

    self.assertEqual(outcome, "ok")
    self.assertEqual(taken, expected)
|
||||
|
||||
def test_non_verbose_tool_call_reuses_existing_newline_before_summary(self):
    """Text that already ends with a newline gets no second newline before the summary."""
    expected = "prefix\n(调用了 1 次工具)\n"

    outcome, taken = asyncio.run(self._run_tool("prefix\n"))

    self.assertEqual(outcome, "ok")
    self.assertEqual(taken, expected)
|
||||
|
||||
def test_non_verbose_tool_call_emits_summary_even_when_buffer_was_empty(self):
    """With an empty buffer the summary is emitted on its own, without a leading newline."""
    expected = "(调用了 1 次工具)\n"

    outcome, taken = asyncio.run(self._run_tool(""))

    self.assertEqual(outcome, "ok")
    self.assertEqual(taken, expected)
|
||||
|
||||
def test_non_verbose_tool_summary_is_inserted_before_next_text(self):
    """The summary of a non-verbose tool call appears between the surrounding text chunks."""
    expected = "让我来检查一下:\n(调用了 1 次工具)\n已经拿到结果"

    async def scenario():
        dummy = DummyTool(session_id="session-1", user_id="10001")
        stream = StreamingHandler()
        await stream.start_streaming()
        stream.emit("让我来检查一下:")
        dummy.set_stream_handler(stream)

        # Run the tool with verbose mode switched off so the call is only
        # recorded, not echoed.
        with patch.object(settings, "AI_AGENT_VERBOSE", False):
            await dummy._arun(explanation="run test tool")

        stream.emit("已经拿到结果")
        taken = await stream.take()
        return taken

    self.assertEqual(asyncio.run(scenario()), expected)
|
||||
|
||||
def test_non_verbose_tool_summary_aggregates_multiple_categories(self):
    """Calls of several categories are merged into one summary line before the next text."""
    recorded_calls = (
        ("search_web", "搜索网络内容: MoviePilot", {"query": "MoviePilot"}),
        ("search_web", "搜索网络内容: agent streaming", {"query": "agent streaming"}),
        ("read_file", "读取文件: a.py", {"file_path": "/tmp/a.py"}),
        ("read_file", "读取文件: b.py", {"file_path": "/tmp/b.py"}),
    )
    expected = "处理中:\n(执行了 2 次搜索,读取了 2 个文件)\n继续分析"

    async def scenario():
        stream = StreamingHandler()
        await stream.start_streaming()
        stream.emit("处理中:")
        for name, message, kwargs in recorded_calls:
            stream.record_tool_call(
                tool_name=name,
                tool_message=message,
                tool_kwargs=kwargs,
            )
        stream.emit("继续分析")
        taken = await stream.take()
        return taken

    self.assertEqual(asyncio.run(scenario()), expected)
|
||||
|
||||
def test_openai_streaming_handler_flushes_pending_summary_to_queue(self):
    """Flushing on the OpenAI handler returns, queues and buffers the same summary text."""
    expected = "(读取了 1 个文件)\n"

    async def scenario():
        stream = _OpenAIStreamingHandler()
        events: asyncio.Queue = asyncio.Queue()
        stream.bind_queue(events)
        await stream.start_streaming()
        stream.record_tool_call(
            tool_name="read_file",
            tool_message="读取文件: app.py",
            tool_kwargs={"file_path": "/tmp/app.py"},
        )
        flushed = stream.flush_pending_tool_summary()
        from_queue = await events.get()
        taken = await stream.take()
        return flushed, from_queue, taken

    flushed, from_queue, taken = asyncio.run(scenario())

    self.assertEqual(flushed, expected)
    self.assertEqual(from_queue, flushed)
    self.assertEqual(taken, flushed)
|
||||
|
||||
def test_flush_sends_direct_message_via_threadpool(self):
|
||||
handler = StreamingHandler()
|
||||
|
||||
Reference in New Issue
Block a user