From 2b5528c0aca873b1627474d9560649fe3df8802d Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 22 May 2026 16:46:25 +0800 Subject: [PATCH] fix: keep agent typing status while queued --- app/agent/__init__.py | 57 ++++++++++- app/modules/telegram/__init__.py | 15 +++ app/modules/telegram/telegram.py | 50 +++++++--- tests/test_telegram_typing_lifecycle.py | 127 +++++++++++++++++++++++- 4 files changed, 231 insertions(+), 18 deletions(-) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index e8f22ee1..6f058a3c 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -952,6 +952,8 @@ class AgentManager: self._session_queues: Dict[str, asyncio.Queue] = {} # 每个会话的worker任务 self._session_workers: Dict[str, asyncio.Task] = {} + # typing 这类状态按会话/聊天共享,前一条任务结束时可能仍需延续到后续排队消息。 + self._deferred_processing_statuses: Dict[str, dict] = {} def get_session_status(self, session_id: str) -> dict[str, Any]: """获取会话当前模型与 token 使用状态。""" @@ -1007,6 +1009,7 @@ class AgentManager: pass self._session_workers.clear() self._session_queues.clear() + self._deferred_processing_statuses.clear() for agent in self.active_agents.values(): await agent.cleanup() self.active_agents.clear() @@ -1100,8 +1103,10 @@ class AgentManager: except Exception as e: logger.error(f"处理会话 {session_id} 的消息失败: {e}") finally: - await _async_finish_processing_status( - task.processing_status, task.user_id + await self._finish_task_processing_status( + session_id=session_id, + task=task, + queue=queue, ) queue.task_done() @@ -1116,6 +1121,52 @@ class AgentManager: and self._session_queues[session_id].empty() ): self._session_queues.pop(session_id, None) + self._deferred_processing_statuses.pop(session_id, None) + + @staticmethod + def _is_shared_processing_status(status: Optional[dict]) -> bool: + """ + 判断状态是否属于同一聊天窗口共享的处理提示。 + reaction 绑定到具体消息,应按消息收口;typing 绑定到会话/聊天,需要等队列空闲再关闭。 + """ + metadata = (status or {}).get("metadata") or {} + return isinstance(metadata, dict) and metadata.get("kind") == "typing" + + async def _finish_task_processing_status( + self, + session_id: str, + task: _MessageTask, + queue: asyncio.Queue, + ) -> None: + """ + 根据会话队列状态结束或延后处理提示。 + 当后面还有排队消息时,typing 状态继续保留;队列真正空闲后再统一关闭。 + """ + status = task.processing_status + if self._is_shared_processing_status(status) and not queue.empty(): + self._deferred_processing_statuses[session_id] = status + return + + if status: + await _async_finish_processing_status(status, task.user_id) + if self._is_shared_processing_status(status): + self._deferred_processing_statuses.pop(session_id, None) + elif queue.empty(): + deferred_status = self._deferred_processing_statuses.pop( + session_id, None + ) + if deferred_status: + await _async_finish_processing_status( + deferred_status, task.user_id + ) + return + + if not queue.empty(): + return + + deferred_status = self._deferred_processing_statuses.pop(session_id, None) + if deferred_status: + await _async_finish_processing_status(deferred_status, task.user_id) async def _process_message_internal(self, task: _MessageTask): """ @@ -1181,6 +1232,7 @@ class AgentManager: break self._session_queues.pop(session_id, None) stopped = True + self._deferred_processing_statuses.pop(session_id, None) if stopped: logger.info(f"会话 {session_id} 的Agent推理已应急停止") @@ -1204,6 +1256,7 @@ class AgentManager: # 清理队列 self._session_queues.pop(session_id, None) + self._deferred_processing_statuses.pop(session_id, None) # 清理agent if session_id in self.active_agents: diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index f7392de4..8fe53776 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -15,6 +15,7 @@ from app.schemas import ( CommandRegisterEventData, NotificationConf, MessageResponse, + NotificationType, ) from app.schemas.types import ModuleType, ChainEventType from app.utils.structures import DictUtils @@ -451,6 +452,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return client: Telegram = self.get_instance(conf.name) if client: + stop_typing = message.mtype != NotificationType.Agent if message.file_path: client.send_file( file_path=message.file_path, @@ -459,6 +461,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): text=message.text, userid=userid, original_chat_id=message.original_chat_id, + stop_typing=stop_typing, ) elif message.voice_path: client.send_voice( @@ -466,6 +469,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): userid=userid, caption=message.voice_caption, original_chat_id=message.original_chat_id, + stop_typing=stop_typing, ) else: client.send_msg( @@ -478,6 +482,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): original_message_id=message.original_message_id, original_chat_id=message.original_chat_id, disable_web_page_preview=message.disable_web_page_preview, + stop_typing=stop_typing, ) def post_medias_message( @@ -502,6 +507,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): buttons=message.buttons, original_message_id=message.original_message_id, original_chat_id=message.original_chat_id, + stop_typing=message.mtype != NotificationType.Agent, ) def post_torrents_message( @@ -526,6 +532,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): buttons=message.buttons, original_message_id=message.original_message_id, original_chat_id=message.original_chat_id, + stop_typing=message.mtype != NotificationType.Agent, ) def delete_message( @@ -585,12 +592,14 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): continue client: Telegram = self.get_instance(conf.name) if client: + stop_typing = not (metadata or {}).get("agent_managed_typing") result = client.edit_msg( chat_id=chat_id, message_id=message_id, text=text, title=title, buttons=buttons, + stop_typing=stop_typing, ) if result: return True @@ -665,12 +674,14 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return None client: Telegram = self.get_instance(conf.name) if client: + agent_managed_typing = message.mtype == NotificationType.Agent if message.voice_path: result = client.send_voice( voice_path=message.voice_path, userid=userid, caption=message.voice_caption, original_chat_id=message.original_chat_id, + stop_typing=not agent_managed_typing, ) else: result = client.send_msg( @@ -680,6 +691,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): userid=userid, link=message.link, disable_web_page_preview=message.disable_web_page_preview, + stop_typing=not agent_managed_typing, ) if result and result.get("success"): return MessageResponse( @@ -687,6 +699,9 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): chat_id=result.get("chat_id"), channel=MessageChannel.Telegram, source=conf.name, + metadata={"agent_managed_typing": True} + if agent_managed_typing + else None, success=True, ) return None diff --git a/app/modules/telegram/telegram.py b/app/modules/telegram/telegram.py index e3be15fd..9b22d09c 100644 --- a/app/modules/telegram/telegram.py +++ b/app/modules/telegram/telegram.py @@ -422,6 +422,16 @@ class Telegram: if task and task.is_alive() and task is not threading.current_thread(): task.join(timeout=1) + def _stop_typing_if_needed( + self, chat_id: Union[str, int], stop_typing: bool + ) -> None: + """ + 按调用方要求停止 typing。 + Agent 回复和流式编辑由 worker 统一收口,避免中途发送消息时误关后续排队消息的状态。 + """ + if stop_typing: + self._stop_typing_task(chat_id) + def stop_typing( self, chat_id: Optional[Union[str, int]] = None, @@ -453,6 +463,7 @@ class Telegram: original_message_id: Optional[int] = None, original_chat_id: Optional[str] = None, disable_web_page_preview: Optional[bool] = None, + stop_typing: bool = True, ) -> Optional[dict]: """ 发送Telegram消息 @@ -465,6 +476,7 @@ class Telegram: :param original_message_id: 原消息ID,如果提供则编辑原消息 :param original_chat_id: 原消息的聊天ID,编辑消息时需要 :param disable_web_page_preview: 是否禁用链接预览 + :param stop_typing: 发送完成后是否立即停止 typing :return: 包含 message_id, chat_id, success 的字典 """ if not self._telegram_token or not self._telegram_chat_id: @@ -474,7 +486,7 @@ class Telegram: chat_id = self._determine_target_chat_id(userid, original_chat_id) if not title and not text: logger.warn("标题和内容不能同时为空") - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) return {"success": False} try: @@ -510,7 +522,7 @@ class Telegram: image, disable_web_page_preview=disable_web_page_preview, ) - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) return { "success": bool(result), "message_id": original_message_id, @@ -525,7 +537,7 @@ class Telegram: reply_markup=reply_markup, disable_web_page_preview=disable_web_page_preview, ) - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) if sent and hasattr(sent, "message_id"): return { "success": True, @@ -538,7 +550,7 @@ class Telegram: except Exception as msg_e: logger.error(f"发送消息失败:{msg_e}") - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) return {"success": False} def send_voice( @@ -547,6 +559,7 @@ class Telegram: userid: Optional[str] = None, caption: Optional[str] = None, original_chat_id: Optional[str] = None, + stop_typing: bool = True, ) -> Optional[dict]: """ 发送Telegram语音消息。 @@ -558,7 +571,7 @@ class Telegram: voice_file = Path(voice_path) if not voice_file.exists(): logger.error(f"语音文件不存在: {voice_file}") - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) return {"success": False} try: @@ -569,7 +582,7 @@ class Telegram: caption=standardize(caption) if caption else None, parse_mode="MarkdownV2" if caption else None, ) - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) if sent and hasattr(sent, "message_id"): return { "success": True, @@ -579,7 +592,7 @@ class Telegram: return {"success": bool(sent)} except Exception as err: logger.error(f"发送语音消息失败:{err}") - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) return {"success": False} finally: try: @@ -595,6 +608,7 @@ class Telegram: text: Optional[str] = None, file_name: Optional[str] = None, original_chat_id: Optional[str] = None, + stop_typing: bool = True, ) -> Optional[dict]: """ 发送本地图片或文件给 Telegram 用户。 @@ -606,7 +620,7 @@ class Telegram: local_file = Path(file_path) if not local_file.exists() or not local_file.is_file(): logger.error(f"附件文件不存在: {local_file}") - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) return {"success": False} send_name = file_name or local_file.name @@ -639,7 +653,7 @@ class Telegram: caption=standardize(caption) if caption else None, parse_mode="MarkdownV2" if caption else None, ) - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) if sent and hasattr(sent, "message_id"): return { "success": True, @@ -649,7 +663,7 @@ class Telegram: return {"success": bool(sent)} except Exception as err: logger.error(f"发送本地附件失败: {err}") - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) return {"success": False} def _determine_target_chat_id( @@ -685,6 +699,7 @@ class Telegram: buttons: Optional[List[List[Dict]]] = None, original_message_id: Optional[int] = None, original_chat_id: Optional[str] = None, + stop_typing: bool = True, ) -> Optional[bool]: """ 发送媒体列表消息 @@ -695,11 +710,12 @@ class Telegram: :param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]] :param original_message_id: 原消息ID,如果提供则编辑原消息 :param original_chat_id: 原消息的聊天ID,编辑消息时需要 + :param stop_typing: 发送完成后是否立即停止 typing """ if not self._telegram_token or not self._telegram_chat_id: return None - # 列表消息也可能是一次交互的最终响应,需要确保 typing 状态在发送后结束。 + # 列表消息也可能是一次交互的最终响应,默认在发送后结束 typing。 chat_id = self._determine_target_chat_id(userid, original_chat_id) try: index, image, caption = 1, "", "*%s*" % title @@ -752,7 +768,7 @@ class Telegram: logger.error(f"发送消息失败:{msg_e}") return False finally: - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) def send_torrents_msg( self, @@ -763,6 +779,7 @@ class Telegram: buttons: Optional[List[List[Dict]]] = None, original_message_id: Optional[int] = None, original_chat_id: Optional[str] = None, + stop_typing: bool = True, ) -> Optional[bool]: """ 发送种子列表消息 @@ -773,11 +790,12 @@ class Telegram: :param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]] :param original_message_id: 原消息ID,如果提供则编辑原消息 :param original_chat_id: 原消息的聊天ID,编辑消息时需要 + :param stop_typing: 发送完成后是否立即停止 typing """ if not self._telegram_token or not self._telegram_chat_id: return None - # 资源列表是搜索交互的常见出口,也必须统一释放 typing 状态。 + # 资源列表是搜索交互的常见出口,默认在发送后结束 typing。 chat_id = self._determine_target_chat_id(userid, original_chat_id) try: index, caption = 1, "*%s*" % title @@ -829,7 +847,7 @@ class Telegram: logger.error(f"发送消息失败:{msg_e}") return False finally: - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) @staticmethod def _create_inline_keyboard(buttons: List[List[Dict]]) -> InlineKeyboardMarkup: @@ -919,6 +937,7 @@ class Telegram: text: str, title: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, + stop_typing: bool = True, ) -> Optional[bool]: """ 编辑Telegram消息(公开方法) @@ -927,6 +946,7 @@ class Telegram: :param text: 新的消息内容 :param title: 消息标题 :param buttons: 新的按钮列表 + :param stop_typing: 编辑完成后是否立即停止 typing :return: 编辑是否成功 """ if not self._bot: @@ -952,7 +972,7 @@ class Telegram: logger.error(f"编辑Telegram消息异常: {str(e)}") return False finally: - self._stop_typing_task(chat_id) + self._stop_typing_if_needed(chat_id, stop_typing) def __edit_message( self, diff --git a/tests/test_telegram_typing_lifecycle.py b/tests/test_telegram_typing_lifecycle.py index 4e96fb87..d4f2a79f 100644 --- a/tests/test_telegram_typing_lifecycle.py +++ b/tests/test_telegram_typing_lifecycle.py @@ -1,7 +1,10 @@ +import asyncio import time import unittest -from unittest.mock import Mock, patch +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock, patch +from app.agent import AgentManager, _MessageTask from app.chain.message import MessageChain from app.modules.telegram.telegram import Telegram from app.schemas.types import MessageChannel @@ -63,6 +66,22 @@ class TestTelegramTypingLifecycle(unittest.TestCase): self.assertNotIn("chat-3", Telegram._typing_tasks) + def test_agent_managed_send_msg_keeps_typing_for_worker_cleanup(self): + telegram = self._telegram_client() + sent = SimpleNamespace(message_id=1, chat=SimpleNamespace(id="chat-1")) + + with patch.object( + telegram, "_Telegram__send_request", return_value=sent + ), patch.object(telegram, "_stop_typing_task") as stop_typing: + result = telegram.send_msg( + title="处理中", + userid="10001", + stop_typing=False, + ) + + self.assertTrue(result["success"]) + stop_typing.assert_not_called() + def test_slash_command_stops_typing_when_message_handler_returns(self): chain = MessageChain.__new__(MessageChain) status = MessageChain._ProcessingStatus( @@ -184,6 +203,112 @@ class TestTelegramTypingLifecycle(unittest.TestCase): status=status.to_dict(), ) + def test_agent_manager_defers_shared_typing_until_queued_task_finishes(self): + async def _run(): + manager = AgentManager() + queue = asyncio.Queue() + first = _MessageTask( + session_id="session-1", + user_id="10001", + message="第一条", + processing_status={ + "channel": MessageChannel.Telegram.value, + "source": "telegram-test", + "userid": "10001", + "chat_id": "-100", + "metadata": {"kind": "typing"}, + }, + ) + second = _MessageTask( + session_id="session-1", + user_id="10001", + message="第二条", + processing_status={ + "channel": MessageChannel.Telegram.value, + "source": "telegram-test", + "userid": "10001", + "chat_id": "-100", + "metadata": {"kind": "typing"}, + }, + ) + await queue.put(second) + + with patch( + "app.agent._async_finish_processing_status", + new_callable=AsyncMock, + ) as finish_status: + await manager._finish_task_processing_status( + session_id="session-1", + task=first, + queue=queue, + ) + finish_status.assert_not_awaited() + self.assertEqual( + manager._deferred_processing_statuses["session-1"], + first.processing_status, + ) + + queue.get_nowait() + await manager._finish_task_processing_status( + session_id="session-1", + task=second, + queue=queue, + ) + + finish_status.assert_awaited_once_with( + second.processing_status, "10001" + ) + self.assertNotIn("session-1", manager._deferred_processing_statuses) + + asyncio.run(_run()) + + def test_agent_manager_closes_deferred_typing_when_next_task_has_no_status(self): + async def _run(): + manager = AgentManager() + queue = asyncio.Queue() + first = _MessageTask( + session_id="session-1", + user_id="10001", + message="第一条", + processing_status={ + "channel": MessageChannel.Telegram.value, + "source": "telegram-test", + "userid": "10001", + "chat_id": "-100", + "metadata": {"kind": "typing"}, + }, + ) + second = _MessageTask( + session_id="session-1", + user_id="10001", + message="第二条", + processing_status=None, + ) + await queue.put(second) + + with patch( + "app.agent._async_finish_processing_status", + new_callable=AsyncMock, + ) as finish_status: + await manager._finish_task_processing_status( + session_id="session-1", + task=first, + queue=queue, + ) + queue.get_nowait() + await manager._finish_task_processing_status( + session_id="session-1", + task=second, + queue=queue, + ) + + finish_status.assert_awaited_once_with( + first.processing_status, "10001" + ) + self.assertNotIn("session-1", manager._deferred_processing_statuses) + + asyncio.run(_run()) + if __name__ == "__main__": unittest.main()