diff --git a/app/chain/message.py b/app/chain/message.py index bd0c55ae..0d86989c 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -137,40 +137,6 @@ class MessageChain(ChainBase): """ images = CommingMessage.MessageImage.normalize_list(images) - # 语音输入只用于转写为文本,不默认改变回复形式。 - has_audio_input = bool(audio_refs) - if audio_refs: - transcript = self._transcribe_audio_refs(audio_refs, channel, source) - merged_parts = [] - seen_parts = set() - for item in [text.strip() if text else "", transcript or ""]: - normalized = item.strip() - if not normalized or normalized in seen_parts: - continue - seen_parts.add(normalized) - merged_parts.append(normalized) - text = "\n".join(merged_parts).strip() - if not text: - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title="语音识别失败,请稍后重试", - ) - ) - return - - if not text.startswith("CALLBACK:"): - self._record_user_message( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, - ) - processing_status = self._mark_message_processing_started( channel=channel, source=source, @@ -181,6 +147,40 @@ class MessageChain(ChainBase): ) continues_async = False try: + # 语音输入只用于转写为文本,不默认改变回复形式。 + has_audio_input = bool(audio_refs) + if audio_refs: + transcript = self._transcribe_audio_refs(audio_refs, channel, source) + merged_parts = [] + seen_parts = set() + for item in [text.strip() if text else "", transcript or ""]: + normalized = item.strip() + if not normalized or normalized in seen_parts: + continue + seen_parts.add(normalized) + merged_parts.append(normalized) + text = "\n".join(merged_parts).strip() + if not text: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="语音识别失败,请稍后重试", + ) + ) + return + + if not text.startswith("CALLBACK:"): + self._record_user_message( + channel=channel, + source=source, + userid=userid, + username=username, + text=text, + ) + continues_async = self._handle_message_core( channel=channel, source=source, @@ -225,7 +225,7 @@ class MessageChain(ChainBase): if text.startswith("CALLBACK:"): if ChannelCapabilityManager.supports_callbacks(channel): - self._handle_callback( + return self._handle_callback( text=text, channel=channel, source=source, @@ -233,6 +233,7 @@ class MessageChain(ChainBase): username=username, original_message_id=original_message_id, original_chat_id=original_chat_id, + processing_status=processing_status, ) else: logger.warning( @@ -245,9 +246,17 @@ class MessageChain(ChainBase): if text.startswith("/") and not text.lower().startswith("/ai"): self.eventmanager.send_event( EventType.CommandExcute, - {"cmd": text, "user": userid, "channel": channel, "source": source}, + { + "cmd": text, + "user": userid, + "channel": channel, + "source": source, + "processing_status": processing_status.to_dict() + if processing_status + else None, + }, ) - return False + return bool(processing_status) latest_slash_interaction = self._get_latest_slash_interaction(userid) if latest_slash_interaction == "sites": @@ -355,8 +364,6 @@ class MessageChain(ChainBase): channel, ChannelCapability.PROCESSING_STATUS ): return None - if not text: - return None try: status = self.run_module( @@ -423,7 +430,8 @@ class MessageChain(ChainBase): username: str, original_message_id: Optional[Union[str, int]] = None, original_chat_id: Optional[str] = None, - ) -> None: + processing_status: Optional[_ProcessingStatus] = None, + ) -> bool: """ 处理按钮回调 """ @@ -439,7 +447,7 @@ class MessageChain(ChainBase): userid=userid, username=username, ): - return + return False if SkillsChain().handle_callback_interaction( callback_data=callback_data, @@ -450,7 +458,7 @@ class MessageChain(ChainBase): original_message_id=original_message_id, original_chat_id=original_chat_id, ): - return + return False if SiteChain().handle_callback_interaction( callback_data=callback_data, @@ -461,7 +469,7 @@ class MessageChain(ChainBase): original_message_id=original_message_id, original_chat_id=original_chat_id, ): - return + return False if SubscribeChain().handle_callback_interaction( callback_data=callback_data, @@ -472,7 +480,7 @@ class MessageChain(ChainBase): original_message_id=original_message_id, original_chat_id=original_chat_id, ): - return + return False if MediaInteractionChain().handle_callback_interaction( callback_data=callback_data, @@ -483,7 +491,7 @@ class MessageChain(ChainBase): original_message_id=original_message_id, original_chat_id=original_chat_id, ): - return + return False if self._handle_agent_choice_callback( callback_data=callback_data, @@ -493,8 +501,9 @@ class MessageChain(ChainBase): username=username, original_message_id=original_message_id, original_chat_id=original_chat_id, + processing_status=processing_status, ): - return + return True # 插件消息的事件回调 [PLUGIN]插件ID|内容 if callback_data.startswith("[PLUGIN]"): @@ -513,7 +522,7 @@ class MessageChain(ChainBase): "original_chat_id": original_chat_id, }, ) - return + return False logger.error(f"回调数据格式错误:{callback_data}") self.post_message( @@ -525,6 +534,7 @@ class MessageChain(ChainBase): title="回调数据格式错误,请检查!", ) ) + return False @staticmethod def _get_latest_slash_interaction(userid: Union[str, int]) -> Optional[str]: @@ -628,6 +638,7 @@ class MessageChain(ChainBase): username: str, original_message_id: Optional[Union[str, int]] = None, original_chat_id: Optional[str] = None, + processing_status: Optional[_ProcessingStatus] = None, ) -> bool: """ 将 Agent 按钮选择回传为同一会话中的下一条用户消息。 @@ -652,7 +663,7 @@ class MessageChain(ChainBase): title="该选择已失效,请重新发起选择", ) ) - return True + return False request, option = resolved selected_text = option.value @@ -673,15 +684,15 @@ class MessageChain(ChainBase): username=username, text=selected_text, ) - self._handle_ai_message( + return self._handle_ai_message( text=selected_text, channel=channel, source=source, userid=userid, username=username, session_id=request.session_id, + processing_status=processing_status, ) - return True def _update_interaction_message_feedback( self, diff --git a/app/command.py b/app/command.py index 57e547e7..c9d0c5f1 100644 --- a/app/command.py +++ b/app/command.py @@ -28,6 +28,30 @@ class CommandChain(ChainBase): pass +def _finish_command_processing_status(status: Optional[dict], user_id: Optional[str] = None) -> None: + """ + 命令执行完成后通过消息模块收口渠道处理状态。 + """ + if not status: + return + try: + channel = MessageChannel(status.get("channel")) + except Exception: + return + try: + CommandChain().run_module( + "mark_message_processing_finished", + channel=channel, + source=status.get("source"), + userid=status.get("userid") or user_id, + message_id=status.get("message_id"), + chat_id=status.get("chat_id"), + status=status, + ) + except Exception as err: + logger.debug(f"结束命令消息处理状态失败: {err}") + + class Command(metaclass=Singleton): """ 全局命令管理,消费事件 @@ -434,17 +458,23 @@ class Command(metaclass=Singleton): event_source = event.event_data.get("source") # 消息用户 event_user = event.event_data.get("user") - if event_str: - cmd = event_str.split()[0] - args = " ".join(event_str.split()[1:]) - if self.get(cmd): - self.execute( - cmd=cmd, - data_str=args, - channel=event_channel, - source=event_source, - userid=event_user, - ) + try: + if event_str: + cmd = event_str.split()[0] + args = " ".join(event_str.split()[1:]) + if self.get(cmd): + self.execute( + cmd=cmd, + data_str=args, + channel=event_channel, + source=event_source, + userid=event_user, + ) + finally: + _finish_command_processing_status( + event.event_data.get("processing_status"), + user_id=event_user, + ) @eventmanager.register(EventType.ModuleReload) def module_reload_event(self, _: ManagerEvent) -> None: diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 8fe53776..cc66033d 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -15,7 +15,6 @@ from app.schemas import ( CommandRegisterEventData, NotificationConf, MessageResponse, - NotificationType, ) from app.schemas.types import ModuleType, ChainEventType from app.utils.structures import DictUtils @@ -452,7 +451,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return client: Telegram = self.get_instance(conf.name) if client: - stop_typing = message.mtype != NotificationType.Agent if message.file_path: client.send_file( file_path=message.file_path, @@ -461,7 +459,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): text=message.text, userid=userid, original_chat_id=message.original_chat_id, - stop_typing=stop_typing, ) elif message.voice_path: client.send_voice( @@ -469,7 +466,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): userid=userid, caption=message.voice_caption, original_chat_id=message.original_chat_id, - stop_typing=stop_typing, ) else: client.send_msg( @@ -482,7 +478,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): original_message_id=message.original_message_id, original_chat_id=message.original_chat_id, disable_web_page_preview=message.disable_web_page_preview, - stop_typing=stop_typing, ) def post_medias_message( @@ -507,7 +502,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): buttons=message.buttons, original_message_id=message.original_message_id, original_chat_id=message.original_chat_id, - stop_typing=message.mtype != NotificationType.Agent, ) def post_torrents_message( @@ -532,7 +526,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): buttons=message.buttons, original_message_id=message.original_message_id, original_chat_id=message.original_chat_id, - stop_typing=message.mtype != NotificationType.Agent, ) def delete_message( @@ -616,11 +609,18 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): ) -> Optional[dict]: """ 标记 Telegram 消息正在处理。 - 入站侧已经启动 typing 任务,这里只返回可用于统一收口的上下文。 + Telegram typing 需要周期性续发,因此在模块接口中启动保活任务。 """ if channel != self._channel: return None - if not text: + client_config = self.get_config(source) + if not client_config: + return None + client: Telegram = self.get_instance(client_config.name) + if not client: + return None + started = client.start_typing(chat_id=chat_id, userid=userid) + if not started: return None return { "channel": channel.value, @@ -674,14 +674,12 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return None client: Telegram = self.get_instance(conf.name) if client: - agent_managed_typing = message.mtype == NotificationType.Agent if message.voice_path: result = client.send_voice( voice_path=message.voice_path, userid=userid, caption=message.voice_caption, original_chat_id=message.original_chat_id, - stop_typing=not agent_managed_typing, ) else: result = client.send_msg( @@ -691,7 +689,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): userid=userid, link=message.link, disable_web_page_preview=message.disable_web_page_preview, - stop_typing=not agent_managed_typing, ) if result and result.get("success"): return MessageResponse( @@ -699,9 +696,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): chat_id=result.get("chat_id"), channel=MessageChannel.Telegram, source=conf.name, - metadata={"agent_managed_typing": True} - if agent_managed_typing - else None, success=True, ) return None diff --git a/app/modules/telegram/telegram.py b/app/modules/telegram/telegram.py index 9b22d09c..4cf39299 100644 --- a/app/modules/telegram/telegram.py +++ b/app/modules/telegram/telegram.py @@ -118,30 +118,15 @@ class Telegram: # Check if we should process this message if self._should_process_message(message): - # 启动持续发送正在输入状态 - message_text = message.text or message.caption or "" - max_duration = ( - self._typing_command_max_duration_seconds - if ( - message_text.startswith("/") - and not message_text.lower().startswith("/ai") - ) - else None - ) - self._start_typing_task( - message.chat.id, max_duration_seconds=max_duration - ) payload = self._serialize_update_payload(message) if not payload: logger.warn("Telegram消息序列化失败,跳过转发") - self._stop_typing_task(message.chat.id) return response = RequestUtils(timeout=15).post_res( self._ds_url, json=payload ) if not response or response.status_code >= 400: - logger.warn("Telegram消息转发失败,停止typing状态") - self._stop_typing_task(message.chat.id) + logger.warn("Telegram消息转发失败") @_bot.callback_query_handler(func=lambda call: True) def callback_query(call): @@ -149,7 +134,6 @@ class Telegram: 处理按钮点击回调 """ chat_id = None - typing_started = False try: # Update user-chat mapping for callbacks too chat_id = call.message.chat.id @@ -181,25 +165,15 @@ class Telegram: # 先确认回调,避免用户看到loading状态 _bot.answer_callback_query(call.id) - # 启动持续发送正在输入状态 - self._start_typing_task( - chat_id, - max_duration_seconds=self._typing_callback_max_duration_seconds, - ) - typing_started = True - # 发送给主程序处理 response = RequestUtils(timeout=15).post_res( self._ds_url, json=callback_json ) if not response or response.status_code >= 400: - logger.warn("Telegram按钮回调转发失败,停止typing状态") - self._stop_typing_task(chat_id) + logger.warn("Telegram按钮回调转发失败") except Exception as err: logger.error(f"处理按钮回调失败:{str(err)}") - if typing_started and chat_id is not None: - self._stop_typing_task(chat_id) _bot.answer_callback_query(call.id, "处理失败,请重试") def run_polling(): @@ -427,11 +401,31 @@ class Telegram: ) -> None: """ 按调用方要求停止 typing。 - Agent 回复和流式编辑由 worker 统一收口,避免中途发送消息时误关后续排队消息的状态。 + typing 由消息处理状态统一收口,兼容显式要求立即停止的调用。 """ if stop_typing: self._stop_typing_task(chat_id) + def start_typing( + self, + chat_id: Optional[Union[str, int]] = None, + userid: Optional[Union[str, int]] = None, + ) -> bool: + """ + 外部链路主动启动 typing 状态。 + """ + if chat_id: + target_chat_id = chat_id + elif userid: + target_chat_id = self._get_user_chat_id(str(userid)) or str(userid) + else: + target_chat_id = None + target_chat_id = target_chat_id or (str(userid) if userid else None) + if not target_chat_id: + return False + self._start_typing_task(target_chat_id) + return True + def stop_typing( self, chat_id: Optional[Union[str, int]] = None, @@ -463,7 +457,7 @@ class Telegram: original_message_id: Optional[int] = None, original_chat_id: Optional[str] = None, disable_web_page_preview: Optional[bool] = None, - stop_typing: bool = True, + stop_typing: bool = False, ) -> Optional[dict]: """ 发送Telegram消息 @@ -559,7 +553,7 @@ class Telegram: userid: Optional[str] = None, caption: Optional[str] = None, original_chat_id: Optional[str] = None, - stop_typing: bool = True, + stop_typing: bool = False, ) -> Optional[dict]: """ 发送Telegram语音消息。 @@ -608,7 +602,7 @@ class Telegram: text: Optional[str] = None, file_name: Optional[str] = None, original_chat_id: Optional[str] = None, - stop_typing: bool = True, + stop_typing: bool = False, ) -> Optional[dict]: """ 发送本地图片或文件给 Telegram 用户。 @@ -699,7 +693,7 @@ class Telegram: buttons: Optional[List[List[Dict]]] = None, original_message_id: Optional[int] = None, original_chat_id: Optional[str] = None, - stop_typing: bool = True, + stop_typing: bool = False, ) -> Optional[bool]: """ 发送媒体列表消息 @@ -779,7 +773,7 @@ class Telegram: buttons: Optional[List[List[Dict]]] = None, original_message_id: Optional[int] = None, original_chat_id: Optional[str] = None, - stop_typing: bool = True, + stop_typing: bool = False, ) -> Optional[bool]: """ 发送种子列表消息 @@ -937,7 +931,7 @@ class Telegram: text: str, title: Optional[str] = None, buttons: Optional[List[List[dict]]] = None, - stop_typing: bool = True, + stop_typing: bool = False, ) -> Optional[bool]: """ 编辑Telegram消息(公开方法) diff --git a/tests/test_telegram_typing_lifecycle.py b/tests/test_telegram_typing_lifecycle.py index d4f2a79f..7a839974 100644 --- a/tests/test_telegram_typing_lifecycle.py +++ b/tests/test_telegram_typing_lifecycle.py @@ -1,11 +1,17 @@ import asyncio +import sys import time import unittest -from types import SimpleNamespace +from types import ModuleType, SimpleNamespace from unittest.mock import AsyncMock, Mock, patch +sys.modules.setdefault("app.helper.sites", ModuleType("app.helper.sites")) +setattr(sys.modules["app.helper.sites"], "SitesHelper", object) + from app.agent import AgentManager, _MessageTask from app.chain.message import MessageChain +from app.command import Command, _finish_command_processing_status +from app.modules.telegram import TelegramModule from app.modules.telegram.telegram import Telegram from app.schemas.types import MessageChannel @@ -82,8 +88,47 @@ class TestTelegramTypingLifecycle(unittest.TestCase): self.assertTrue(result["success"]) stop_typing.assert_not_called() - def test_slash_command_stops_typing_when_message_handler_returns(self): + def test_send_msg_does_not_stop_typing_by_default(self): + """ + 响应发送不再默认结束 typing,由处理状态统一收口。 + """ + telegram = self._telegram_client() + sent = SimpleNamespace(message_id=1, chat=SimpleNamespace(id="chat-1")) + + with patch.object( + telegram, "_Telegram__send_request", return_value=sent + ), patch.object(telegram, "_stop_typing_task") as stop_typing: + result = telegram.send_msg(title="处理中", userid="10001") + + self.assertTrue(result["success"]) + stop_typing.assert_not_called() + + def test_telegram_module_processing_status_starts_typing(self): + """ + Telegram 通过模块处理状态接口启动 typing 保活。 + """ + module = TelegramModule() + module._channel = MessageChannel.Telegram + client = Mock() + client.start_typing.return_value = True + + with patch.object( + module, "get_config", return_value=SimpleNamespace(name="telegram-test") + ), patch.object(module, "get_instance", return_value=client): + status = module.mark_message_processing_started( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + chat_id="-100", + text="hello", + ) + + client.start_typing.assert_called_once_with(chat_id="-100", userid="10001") + self.assertEqual(status["metadata"]["kind"], "typing") + + def test_slash_command_defers_processing_status_to_command_handler(self): chain = MessageChain.__new__(MessageChain) + chain.eventmanager = Mock() status = MessageChain._ProcessingStatus( channel=MessageChannel.Telegram, source="telegram-test", @@ -94,7 +139,7 @@ class TestTelegramTypingLifecycle(unittest.TestCase): with patch.object(chain, "_record_user_message"), patch.object( chain, "_mark_message_processing_started", return_value=status - ), patch.object(chain, "_handle_message_core"), patch.object( + ), patch.object( chain, "_mark_message_processing_finished" ) as finish_status: chain.handle_message( @@ -106,13 +151,65 @@ class TestTelegramTypingLifecycle(unittest.TestCase): original_chat_id="-100", ) + finish_status.assert_not_called() + chain.eventmanager.send_event.assert_called_once() + self.assertEqual( + chain.eventmanager.send_event.call_args.args[1]["processing_status"], + status.to_dict(), + ) + + def test_command_handler_finishes_processing_status_after_execute(self): + """ + 传统命令响应完成后由命令处理器统一结束 processing status。 + """ + command = Command.__new__(Command) + command.get = Mock(return_value={"func": Mock()}) + command.execute = Mock() + event = SimpleNamespace( + event_data={ + "cmd": "/sites", + "user": "10001", + "channel": MessageChannel.Telegram, + "source": "telegram-test", + "processing_status": { + "channel": MessageChannel.Telegram.value, + "source": "telegram-test", + "userid": "10001", + "chat_id": "-100", + "metadata": {"kind": "typing"}, + }, + } + ) + + with patch("app.command._finish_command_processing_status") as finish_status: + command.command_event(event) + + command.execute.assert_called_once() finish_status.assert_called_once_with( + event.event_data["processing_status"], + user_id="10001", + ) + + def test_finish_command_processing_status_uses_module_interface(self): + status = { + "channel": MessageChannel.Telegram.value, + "source": "telegram-test", + "userid": "10001", + "chat_id": "-100", + "metadata": {"kind": "typing"}, + } + + with patch("app.command.CommandChain") as chain_cls: + _finish_command_processing_status(status, user_id="fallback") + + chain_cls.return_value.run_module.assert_called_once_with( + "mark_message_processing_finished", channel=MessageChannel.Telegram, source="telegram-test", userid="10001", + message_id=None, + chat_id="-100", status=status, - original_message_id=None, - original_chat_id="-100", ) def test_async_agent_keeps_processing_status_for_worker(self):