mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-08 09:10:32 +08:00
feat(agent): Telegram与Agent相互时支持流式输出
This commit is contained in:
@@ -5,7 +5,8 @@ from typing import Dict, List
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
SummarizationMiddleware, LLMToolSelectorMiddleware,
|
||||
SummarizationMiddleware,
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
@@ -36,12 +37,12 @@ class MoviePilotAgent:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
@@ -80,9 +81,7 @@ class MoviePilotAgent:
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=self.channel
|
||||
).format(
|
||||
current_date=strftime('%Y-%m-%d')
|
||||
)
|
||||
).format(current_date=strftime("%Y-%m-%d"))
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm()
|
||||
@@ -93,21 +92,15 @@ class MoviePilotAgent:
|
||||
# 中间件
|
||||
middlewares = [
|
||||
# 工具选择
|
||||
LLMToolSelectorMiddleware(
|
||||
model=llm,
|
||||
max_tools=20
|
||||
),
|
||||
LLMToolSelectorMiddleware(model=llm, max_tools=20),
|
||||
# 记忆管理
|
||||
MemoryMiddleware(
|
||||
sources=[str(settings.CONFIG_PATH / "agent" / "MEMORY.md")]
|
||||
),
|
||||
# 上下文压缩
|
||||
SummarizationMiddleware(
|
||||
model=llm,
|
||||
trigger=("fraction", 0.85)
|
||||
),
|
||||
SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)),
|
||||
# 错误工具调用修复
|
||||
PatchToolCallsMiddleware()
|
||||
PatchToolCallsMiddleware(),
|
||||
]
|
||||
|
||||
return create_agent(
|
||||
@@ -130,8 +123,7 @@ class MoviePilotAgent:
|
||||
|
||||
# 获取历史消息
|
||||
messages = memory_manager.get_agent_messages(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id
|
||||
session_id=self.session_id, user_id=self.user_id
|
||||
)
|
||||
|
||||
# 增加用户消息
|
||||
@@ -150,6 +142,7 @@ class MoviePilotAgent:
|
||||
"""
|
||||
调用 LangGraph Agent,通过 astream_events 流式获取 token,
|
||||
同时用 UsageMetadataCallbackHandler 统计 token 用量。
|
||||
支持流式输出:在支持消息编辑的渠道上实时推送 token。
|
||||
"""
|
||||
try:
|
||||
# Agent运行配置
|
||||
@@ -162,37 +155,57 @@ class MoviePilotAgent:
|
||||
# 创建智能体
|
||||
agent = self._create_agent()
|
||||
|
||||
# 启动流式输出(内部会检查渠道是否支持消息编辑)
|
||||
await self.stream_handler.start_streaming(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
)
|
||||
|
||||
# 流式运行智能体
|
||||
async for chunk in agent.astream(
|
||||
{"messages": messages},
|
||||
stream_mode="messages",
|
||||
config=agent_config,
|
||||
version="v2"
|
||||
{"messages": messages},
|
||||
stream_mode="messages",
|
||||
config=agent_config,
|
||||
version="v2",
|
||||
):
|
||||
# 处理流式token(过滤工具调用token,只保留模型生成的内容)
|
||||
if chunk["type"] == "messages":
|
||||
token, metadata = chunk["data"]
|
||||
if (token and hasattr(token, "tool_call_chunks")
|
||||
and not token.tool_call_chunks):
|
||||
if (
|
||||
token
|
||||
and hasattr(token, "tool_call_chunks")
|
||||
and not token.tool_call_chunks
|
||||
):
|
||||
if token.content:
|
||||
self.stream_handler.emit(token.content)
|
||||
|
||||
# 发送最终消息给用户
|
||||
await self.send_agent_message(
|
||||
self.stream_handler.take()
|
||||
)
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容
|
||||
all_sent_via_stream = await self.stream_handler.stop_streaming()
|
||||
|
||||
if not all_sent_via_stream:
|
||||
# 流式输出未能发送全部内容(渠道不支持编辑,或发送失败)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text:
|
||||
await self.send_agent_message(remaining_text)
|
||||
|
||||
# 保存消息
|
||||
memory_manager.save_agent_messages(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
messages=agent.get_state(agent_config).values.get("messages", [])
|
||||
messages=agent.get_state(agent_config).values.get("messages", []),
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 确保取消时也停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return "任务已取消", {}
|
||||
except Exception as e:
|
||||
# 确保异常时也停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
logger.error(f"Agent执行失败: {e} - {traceback.format_exc()}")
|
||||
return str(e), {}
|
||||
|
||||
@@ -243,13 +256,13 @@ class AgentManager:
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息
|
||||
|
||||
Reference in New Issue
Block a user