fix: keep agent typing status while queued

This commit is contained in:
jxxghp
2026-05-22 16:46:25 +08:00
parent cb15b711b9
commit 2b5528c0ac
4 changed files with 231 additions and 18 deletions

View File

@@ -952,6 +952,8 @@ class AgentManager:
self._session_queues: Dict[str, asyncio.Queue] = {}
# 每个会话的worker任务
self._session_workers: Dict[str, asyncio.Task] = {}
# typing 这类状态按会话/聊天共享,前一条任务结束时可能仍需延续到后续排队消息。
self._deferred_processing_statuses: Dict[str, dict] = {}
def get_session_status(self, session_id: str) -> dict[str, Any]:
"""获取会话当前模型与 token 使用状态。"""
@@ -1007,6 +1009,7 @@ class AgentManager:
pass
self._session_workers.clear()
self._session_queues.clear()
self._deferred_processing_statuses.clear()
for agent in self.active_agents.values():
await agent.cleanup()
self.active_agents.clear()
@@ -1100,8 +1103,10 @@ class AgentManager:
except Exception as e:
logger.error(f"处理会话 {session_id} 的消息失败: {e}")
finally:
await _async_finish_processing_status(
task.processing_status, task.user_id
await self._finish_task_processing_status(
session_id=session_id,
task=task,
queue=queue,
)
queue.task_done()
@@ -1116,6 +1121,52 @@ class AgentManager:
and self._session_queues[session_id].empty()
):
self._session_queues.pop(session_id, None)
self._deferred_processing_statuses.pop(session_id, None)
@staticmethod
def _is_shared_processing_status(status: Optional[dict]) -> bool:
"""
判断状态是否属于同一聊天窗口共享的处理提示。
reaction 绑定到具体消息应按消息收口typing 绑定到会话/聊天,需要等队列空闲再关闭。
"""
metadata = (status or {}).get("metadata") or {}
return isinstance(metadata, dict) and metadata.get("kind") == "typing"
async def _finish_task_processing_status(
self,
session_id: str,
task: _MessageTask,
queue: asyncio.Queue,
) -> None:
"""
根据会话队列状态结束或延后处理提示。
当后面还有排队消息时typing 状态继续保留;队列真正空闲后再统一关闭。
"""
status = task.processing_status
if self._is_shared_processing_status(status) and not queue.empty():
self._deferred_processing_statuses[session_id] = status
return
if status:
await _async_finish_processing_status(status, task.user_id)
if self._is_shared_processing_status(status):
self._deferred_processing_statuses.pop(session_id, None)
elif queue.empty():
deferred_status = self._deferred_processing_statuses.pop(
session_id, None
)
if deferred_status:
await _async_finish_processing_status(
deferred_status, task.user_id
)
return
if not queue.empty():
return
deferred_status = self._deferred_processing_statuses.pop(session_id, None)
if deferred_status:
await _async_finish_processing_status(deferred_status, task.user_id)
async def _process_message_internal(self, task: _MessageTask):
"""
@@ -1181,6 +1232,7 @@ class AgentManager:
break
self._session_queues.pop(session_id, None)
stopped = True
self._deferred_processing_statuses.pop(session_id, None)
if stopped:
logger.info(f"会话 {session_id} 的Agent推理已应急停止")
@@ -1204,6 +1256,7 @@ class AgentManager:
# 清理队列
self._session_queues.pop(session_id, None)
self._deferred_processing_statuses.pop(session_id, None)
# 清理agent
if session_id in self.active_agents:

View File

@@ -15,6 +15,7 @@ from app.schemas import (
CommandRegisterEventData,
NotificationConf,
MessageResponse,
NotificationType,
)
from app.schemas.types import ModuleType, ChainEventType
from app.utils.structures import DictUtils
@@ -451,6 +452,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return
client: Telegram = self.get_instance(conf.name)
if client:
stop_typing = message.mtype != NotificationType.Agent
if message.file_path:
client.send_file(
file_path=message.file_path,
@@ -459,6 +461,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
text=message.text,
userid=userid,
original_chat_id=message.original_chat_id,
stop_typing=stop_typing,
)
elif message.voice_path:
client.send_voice(
@@ -466,6 +469,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
userid=userid,
caption=message.voice_caption,
original_chat_id=message.original_chat_id,
stop_typing=stop_typing,
)
else:
client.send_msg(
@@ -478,6 +482,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
disable_web_page_preview=message.disable_web_page_preview,
stop_typing=stop_typing,
)
def post_medias_message(
@@ -502,6 +507,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
stop_typing=message.mtype != NotificationType.Agent,
)
def post_torrents_message(
@@ -526,6 +532,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
stop_typing=message.mtype != NotificationType.Agent,
)
def delete_message(
@@ -585,12 +592,14 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
continue
client: Telegram = self.get_instance(conf.name)
if client:
stop_typing = not (metadata or {}).get("agent_managed_typing")
result = client.edit_msg(
chat_id=chat_id,
message_id=message_id,
text=text,
title=title,
buttons=buttons,
stop_typing=stop_typing,
)
if result:
return True
@@ -665,12 +674,14 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return None
client: Telegram = self.get_instance(conf.name)
if client:
agent_managed_typing = message.mtype == NotificationType.Agent
if message.voice_path:
result = client.send_voice(
voice_path=message.voice_path,
userid=userid,
caption=message.voice_caption,
original_chat_id=message.original_chat_id,
stop_typing=not agent_managed_typing,
)
else:
result = client.send_msg(
@@ -680,6 +691,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
userid=userid,
link=message.link,
disable_web_page_preview=message.disable_web_page_preview,
stop_typing=not agent_managed_typing,
)
if result and result.get("success"):
return MessageResponse(
@@ -687,6 +699,9 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
chat_id=result.get("chat_id"),
channel=MessageChannel.Telegram,
source=conf.name,
metadata={"agent_managed_typing": True}
if agent_managed_typing
else None,
success=True,
)
return None

View File

@@ -422,6 +422,16 @@ class Telegram:
if task and task.is_alive() and task is not threading.current_thread():
task.join(timeout=1)
def _stop_typing_if_needed(
self, chat_id: Union[str, int], stop_typing: bool
) -> None:
"""
按调用方要求停止 typing。
Agent 回复和流式编辑由 worker 统一收口,避免中途发送消息时误关后续排队消息的状态。
"""
if stop_typing:
self._stop_typing_task(chat_id)
def stop_typing(
self,
chat_id: Optional[Union[str, int]] = None,
@@ -453,6 +463,7 @@ class Telegram:
original_message_id: Optional[int] = None,
original_chat_id: Optional[str] = None,
disable_web_page_preview: Optional[bool] = None,
stop_typing: bool = True,
) -> Optional[dict]:
"""
发送Telegram消息
@@ -465,6 +476,7 @@ class Telegram:
:param original_message_id: 原消息ID如果提供则编辑原消息
:param original_chat_id: 原消息的聊天ID编辑消息时需要
:param disable_web_page_preview: 是否禁用链接预览
:param stop_typing: 发送完成后是否立即停止 typing
:return: 包含 message_id, chat_id, success 的字典
"""
if not self._telegram_token or not self._telegram_chat_id:
@@ -474,7 +486,7 @@ class Telegram:
chat_id = self._determine_target_chat_id(userid, original_chat_id)
if not title and not text:
logger.warn("标题和内容不能同时为空")
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
return {"success": False}
try:
@@ -510,7 +522,7 @@ class Telegram:
image,
disable_web_page_preview=disable_web_page_preview,
)
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
return {
"success": bool(result),
"message_id": original_message_id,
@@ -525,7 +537,7 @@ class Telegram:
reply_markup=reply_markup,
disable_web_page_preview=disable_web_page_preview,
)
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
if sent and hasattr(sent, "message_id"):
return {
"success": True,
@@ -538,7 +550,7 @@ class Telegram:
except Exception as msg_e:
logger.error(f"发送消息失败:{msg_e}")
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
return {"success": False}
def send_voice(
@@ -547,6 +559,7 @@ class Telegram:
userid: Optional[str] = None,
caption: Optional[str] = None,
original_chat_id: Optional[str] = None,
stop_typing: bool = True,
) -> Optional[dict]:
"""
发送Telegram语音消息。
@@ -558,7 +571,7 @@ class Telegram:
voice_file = Path(voice_path)
if not voice_file.exists():
logger.error(f"语音文件不存在: {voice_file}")
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
return {"success": False}
try:
@@ -569,7 +582,7 @@ class Telegram:
caption=standardize(caption) if caption else None,
parse_mode="MarkdownV2" if caption else None,
)
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
if sent and hasattr(sent, "message_id"):
return {
"success": True,
@@ -579,7 +592,7 @@ class Telegram:
return {"success": bool(sent)}
except Exception as err:
logger.error(f"发送语音消息失败:{err}")
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
return {"success": False}
finally:
try:
@@ -595,6 +608,7 @@ class Telegram:
text: Optional[str] = None,
file_name: Optional[str] = None,
original_chat_id: Optional[str] = None,
stop_typing: bool = True,
) -> Optional[dict]:
"""
发送本地图片或文件给 Telegram 用户。
@@ -606,7 +620,7 @@ class Telegram:
local_file = Path(file_path)
if not local_file.exists() or not local_file.is_file():
logger.error(f"附件文件不存在: {local_file}")
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
return {"success": False}
send_name = file_name or local_file.name
@@ -639,7 +653,7 @@ class Telegram:
caption=standardize(caption) if caption else None,
parse_mode="MarkdownV2" if caption else None,
)
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
if sent and hasattr(sent, "message_id"):
return {
"success": True,
@@ -649,7 +663,7 @@ class Telegram:
return {"success": bool(sent)}
except Exception as err:
logger.error(f"发送本地附件失败: {err}")
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
return {"success": False}
def _determine_target_chat_id(
@@ -685,6 +699,7 @@ class Telegram:
buttons: Optional[List[List[Dict]]] = None,
original_message_id: Optional[int] = None,
original_chat_id: Optional[str] = None,
stop_typing: bool = True,
) -> Optional[bool]:
"""
发送媒体列表消息
@@ -695,11 +710,12 @@ class Telegram:
:param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]]
:param original_message_id: 原消息ID如果提供则编辑原消息
:param original_chat_id: 原消息的聊天ID编辑消息时需要
:param stop_typing: 发送完成后是否立即停止 typing
"""
if not self._telegram_token or not self._telegram_chat_id:
return None
# 列表消息也可能是一次交互的最终响应,需要确保 typing 状态在发送后结束
# 列表消息也可能是一次交互的最终响应,默认在发送后结束 typing
chat_id = self._determine_target_chat_id(userid, original_chat_id)
try:
index, image, caption = 1, "", "*%s*" % title
@@ -752,7 +768,7 @@ class Telegram:
logger.error(f"发送消息失败:{msg_e}")
return False
finally:
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
def send_torrents_msg(
self,
@@ -763,6 +779,7 @@ class Telegram:
buttons: Optional[List[List[Dict]]] = None,
original_message_id: Optional[int] = None,
original_chat_id: Optional[str] = None,
stop_typing: bool = True,
) -> Optional[bool]:
"""
发送种子列表消息
@@ -773,11 +790,12 @@ class Telegram:
:param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]]
:param original_message_id: 原消息ID如果提供则编辑原消息
:param original_chat_id: 原消息的聊天ID编辑消息时需要
:param stop_typing: 发送完成后是否立即停止 typing
"""
if not self._telegram_token or not self._telegram_chat_id:
return None
# 资源列表是搜索交互的常见出口,也必须统一释放 typing 状态
# 资源列表是搜索交互的常见出口,默认在发送后结束 typing。
chat_id = self._determine_target_chat_id(userid, original_chat_id)
try:
index, caption = 1, "*%s*" % title
@@ -829,7 +847,7 @@ class Telegram:
logger.error(f"发送消息失败:{msg_e}")
return False
finally:
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
@staticmethod
def _create_inline_keyboard(buttons: List[List[Dict]]) -> InlineKeyboardMarkup:
@@ -919,6 +937,7 @@ class Telegram:
text: str,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
stop_typing: bool = True,
) -> Optional[bool]:
"""
编辑Telegram消息公开方法
@@ -927,6 +946,7 @@ class Telegram:
:param text: 新的消息内容
:param title: 消息标题
:param buttons: 新的按钮列表
:param stop_typing: 编辑完成后是否立即停止 typing
:return: 编辑是否成功
"""
if not self._bot:
@@ -952,7 +972,7 @@ class Telegram:
logger.error(f"编辑Telegram消息异常: {str(e)}")
return False
finally:
self._stop_typing_task(chat_id)
self._stop_typing_if_needed(chat_id, stop_typing)
def __edit_message(
self,

View File

@@ -1,7 +1,10 @@
import asyncio
import time
import unittest
from unittest.mock import Mock, patch
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from app.agent import AgentManager, _MessageTask
from app.chain.message import MessageChain
from app.modules.telegram.telegram import Telegram
from app.schemas.types import MessageChannel
@@ -63,6 +66,22 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
self.assertNotIn("chat-3", Telegram._typing_tasks)
def test_agent_managed_send_msg_keeps_typing_for_worker_cleanup(self):
telegram = self._telegram_client()
sent = SimpleNamespace(message_id=1, chat=SimpleNamespace(id="chat-1"))
with patch.object(
telegram, "_Telegram__send_request", return_value=sent
), patch.object(telegram, "_stop_typing_task") as stop_typing:
result = telegram.send_msg(
title="处理中",
userid="10001",
stop_typing=False,
)
self.assertTrue(result["success"])
stop_typing.assert_not_called()
def test_slash_command_stops_typing_when_message_handler_returns(self):
chain = MessageChain.__new__(MessageChain)
status = MessageChain._ProcessingStatus(
@@ -184,6 +203,112 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
status=status.to_dict(),
)
def test_agent_manager_defers_shared_typing_until_queued_task_finishes(self):
async def _run():
manager = AgentManager()
queue = asyncio.Queue()
first = _MessageTask(
session_id="session-1",
user_id="10001",
message="第一条",
processing_status={
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
},
)
second = _MessageTask(
session_id="session-1",
user_id="10001",
message="第二条",
processing_status={
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
},
)
await queue.put(second)
with patch(
"app.agent._async_finish_processing_status",
new_callable=AsyncMock,
) as finish_status:
await manager._finish_task_processing_status(
session_id="session-1",
task=first,
queue=queue,
)
finish_status.assert_not_awaited()
self.assertEqual(
manager._deferred_processing_statuses["session-1"],
first.processing_status,
)
queue.get_nowait()
await manager._finish_task_processing_status(
session_id="session-1",
task=second,
queue=queue,
)
finish_status.assert_awaited_once_with(
second.processing_status, "10001"
)
self.assertNotIn("session-1", manager._deferred_processing_statuses)
asyncio.run(_run())
def test_agent_manager_closes_deferred_typing_when_next_task_has_no_status(self):
async def _run():
manager = AgentManager()
queue = asyncio.Queue()
first = _MessageTask(
session_id="session-1",
user_id="10001",
message="第一条",
processing_status={
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
},
)
second = _MessageTask(
session_id="session-1",
user_id="10001",
message="第二条",
processing_status=None,
)
await queue.put(second)
with patch(
"app.agent._async_finish_processing_status",
new_callable=AsyncMock,
) as finish_status:
await manager._finish_task_processing_status(
session_id="session-1",
task=first,
queue=queue,
)
queue.get_nowait()
await manager._finish_task_processing_status(
session_id="session-1",
task=second,
queue=queue,
)
finish_status.assert_awaited_once_with(
first.processing_status, "10001"
)
self.assertNotIn("session-1", manager._deferred_processing_statuses)
asyncio.run(_run())
if __name__ == "__main__":
unittest.main()