fix: 统一消息 typing 生命周期

This commit is contained in:
jxxghp
2026-05-22 22:59:20 +08:00
parent 7e6cd47712
commit f7b78721c3
5 changed files with 242 additions and 116 deletions

View File

@@ -137,40 +137,6 @@ class MessageChain(ChainBase):
"""
images = CommingMessage.MessageImage.normalize_list(images)
# 语音输入只用于转写为文本,不默认改变回复形式。
has_audio_input = bool(audio_refs)
if audio_refs:
transcript = self._transcribe_audio_refs(audio_refs, channel, source)
merged_parts = []
seen_parts = set()
for item in [text.strip() if text else "", transcript or ""]:
normalized = item.strip()
if not normalized or normalized in seen_parts:
continue
seen_parts.add(normalized)
merged_parts.append(normalized)
text = "\n".join(merged_parts).strip()
if not text:
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="语音识别失败,请稍后重试",
)
)
return
if not text.startswith("CALLBACK:"):
self._record_user_message(
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
)
processing_status = self._mark_message_processing_started(
channel=channel,
source=source,
@@ -181,6 +147,40 @@ class MessageChain(ChainBase):
)
continues_async = False
try:
# 语音输入只用于转写为文本,不默认改变回复形式。
has_audio_input = bool(audio_refs)
if audio_refs:
transcript = self._transcribe_audio_refs(audio_refs, channel, source)
merged_parts = []
seen_parts = set()
for item in [text.strip() if text else "", transcript or ""]:
normalized = item.strip()
if not normalized or normalized in seen_parts:
continue
seen_parts.add(normalized)
merged_parts.append(normalized)
text = "\n".join(merged_parts).strip()
if not text:
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="语音识别失败,请稍后重试",
)
)
return
if not text.startswith("CALLBACK:"):
self._record_user_message(
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
)
continues_async = self._handle_message_core(
channel=channel,
source=source,
@@ -225,7 +225,7 @@ class MessageChain(ChainBase):
if text.startswith("CALLBACK:"):
if ChannelCapabilityManager.supports_callbacks(channel):
self._handle_callback(
return self._handle_callback(
text=text,
channel=channel,
source=source,
@@ -233,6 +233,7 @@ class MessageChain(ChainBase):
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
processing_status=processing_status,
)
else:
logger.warning(
@@ -245,9 +246,17 @@ class MessageChain(ChainBase):
if text.startswith("/") and not text.lower().startswith("/ai"):
self.eventmanager.send_event(
EventType.CommandExcute,
{"cmd": text, "user": userid, "channel": channel, "source": source},
{
"cmd": text,
"user": userid,
"channel": channel,
"source": source,
"processing_status": processing_status.to_dict()
if processing_status
else None,
},
)
return False
return bool(processing_status)
latest_slash_interaction = self._get_latest_slash_interaction(userid)
if latest_slash_interaction == "sites":
@@ -355,8 +364,6 @@ class MessageChain(ChainBase):
channel, ChannelCapability.PROCESSING_STATUS
):
return None
if not text:
return None
try:
status = self.run_module(
@@ -423,7 +430,8 @@ class MessageChain(ChainBase):
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> None:
processing_status: Optional[_ProcessingStatus] = None,
) -> bool:
"""
处理按钮回调
"""
@@ -439,7 +447,7 @@ class MessageChain(ChainBase):
userid=userid,
username=username,
):
return
return False
if SkillsChain().handle_callback_interaction(
callback_data=callback_data,
@@ -450,7 +458,7 @@ class MessageChain(ChainBase):
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
return False
if SiteChain().handle_callback_interaction(
callback_data=callback_data,
@@ -461,7 +469,7 @@ class MessageChain(ChainBase):
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
return False
if SubscribeChain().handle_callback_interaction(
callback_data=callback_data,
@@ -472,7 +480,7 @@ class MessageChain(ChainBase):
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
return False
if MediaInteractionChain().handle_callback_interaction(
callback_data=callback_data,
@@ -483,7 +491,7 @@ class MessageChain(ChainBase):
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
return False
if self._handle_agent_choice_callback(
callback_data=callback_data,
@@ -493,8 +501,9 @@ class MessageChain(ChainBase):
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
processing_status=processing_status,
):
return
return True
# 插件消息的事件回调 [PLUGIN]插件ID|内容
if callback_data.startswith("[PLUGIN]"):
@@ -513,7 +522,7 @@ class MessageChain(ChainBase):
"original_chat_id": original_chat_id,
},
)
return
return False
logger.error(f"回调数据格式错误:{callback_data}")
self.post_message(
@@ -525,6 +534,7 @@ class MessageChain(ChainBase):
title="回调数据格式错误,请检查!",
)
)
return False
@staticmethod
def _get_latest_slash_interaction(userid: Union[str, int]) -> Optional[str]:
@@ -628,6 +638,7 @@ class MessageChain(ChainBase):
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
processing_status: Optional[_ProcessingStatus] = None,
) -> bool:
"""
将 Agent 按钮选择回传为同一会话中的下一条用户消息。
@@ -652,7 +663,7 @@ class MessageChain(ChainBase):
title="该选择已失效,请重新发起选择",
)
)
return True
return False
request, option = resolved
selected_text = option.value
@@ -673,15 +684,15 @@ class MessageChain(ChainBase):
username=username,
text=selected_text,
)
self._handle_ai_message(
return self._handle_ai_message(
text=selected_text,
channel=channel,
source=source,
userid=userid,
username=username,
session_id=request.session_id,
processing_status=processing_status,
)
return True
def _update_interaction_message_feedback(
self,

View File

@@ -28,6 +28,30 @@ class CommandChain(ChainBase):
pass
def _finish_command_processing_status(status: Optional[dict], user_id: Optional[str] = None) -> None:
"""
命令执行完成后通过消息模块收口渠道处理状态。
"""
if not status:
return
try:
channel = MessageChannel(status.get("channel"))
except Exception:
return
try:
CommandChain().run_module(
"mark_message_processing_finished",
channel=channel,
source=status.get("source"),
userid=status.get("userid") or user_id,
message_id=status.get("message_id"),
chat_id=status.get("chat_id"),
status=status,
)
except Exception as err:
logger.debug(f"结束命令消息处理状态失败: {err}")
class Command(metaclass=Singleton):
"""
全局命令管理,消费事件
@@ -434,17 +458,23 @@ class Command(metaclass=Singleton):
event_source = event.event_data.get("source")
# 消息用户
event_user = event.event_data.get("user")
if event_str:
cmd = event_str.split()[0]
args = " ".join(event_str.split()[1:])
if self.get(cmd):
self.execute(
cmd=cmd,
data_str=args,
channel=event_channel,
source=event_source,
userid=event_user,
)
try:
if event_str:
cmd = event_str.split()[0]
args = " ".join(event_str.split()[1:])
if self.get(cmd):
self.execute(
cmd=cmd,
data_str=args,
channel=event_channel,
source=event_source,
userid=event_user,
)
finally:
_finish_command_processing_status(
event.event_data.get("processing_status"),
user_id=event_user,
)
@eventmanager.register(EventType.ModuleReload)
def module_reload_event(self, _: ManagerEvent) -> None:

View File

@@ -15,7 +15,6 @@ from app.schemas import (
CommandRegisterEventData,
NotificationConf,
MessageResponse,
NotificationType,
)
from app.schemas.types import ModuleType, ChainEventType
from app.utils.structures import DictUtils
@@ -452,7 +451,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return
client: Telegram = self.get_instance(conf.name)
if client:
stop_typing = message.mtype != NotificationType.Agent
if message.file_path:
client.send_file(
file_path=message.file_path,
@@ -461,7 +459,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
text=message.text,
userid=userid,
original_chat_id=message.original_chat_id,
stop_typing=stop_typing,
)
elif message.voice_path:
client.send_voice(
@@ -469,7 +466,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
userid=userid,
caption=message.voice_caption,
original_chat_id=message.original_chat_id,
stop_typing=stop_typing,
)
else:
client.send_msg(
@@ -482,7 +478,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
disable_web_page_preview=message.disable_web_page_preview,
stop_typing=stop_typing,
)
def post_medias_message(
@@ -507,7 +502,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
stop_typing=message.mtype != NotificationType.Agent,
)
def post_torrents_message(
@@ -532,7 +526,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
stop_typing=message.mtype != NotificationType.Agent,
)
def delete_message(
@@ -616,11 +609,18 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
) -> Optional[dict]:
"""
标记 Telegram 消息正在处理。
入站侧已经启动 typing 任务,这里只返回可用于统一收口的上下文
Telegram typing 需要周期性续发,因此在模块接口中启动保活任务
"""
if channel != self._channel:
return None
if not text:
client_config = self.get_config(source)
if not client_config:
return None
client: Telegram = self.get_instance(client_config.name)
if not client:
return None
started = client.start_typing(chat_id=chat_id, userid=userid)
if not started:
return None
return {
"channel": channel.value,
@@ -674,14 +674,12 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return None
client: Telegram = self.get_instance(conf.name)
if client:
agent_managed_typing = message.mtype == NotificationType.Agent
if message.voice_path:
result = client.send_voice(
voice_path=message.voice_path,
userid=userid,
caption=message.voice_caption,
original_chat_id=message.original_chat_id,
stop_typing=not agent_managed_typing,
)
else:
result = client.send_msg(
@@ -691,7 +689,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
userid=userid,
link=message.link,
disable_web_page_preview=message.disable_web_page_preview,
stop_typing=not agent_managed_typing,
)
if result and result.get("success"):
return MessageResponse(
@@ -699,9 +696,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
chat_id=result.get("chat_id"),
channel=MessageChannel.Telegram,
source=conf.name,
metadata={"agent_managed_typing": True}
if agent_managed_typing
else None,
success=True,
)
return None

View File

@@ -118,30 +118,15 @@ class Telegram:
# Check if we should process this message
if self._should_process_message(message):
# 启动持续发送正在输入状态
message_text = message.text or message.caption or ""
max_duration = (
self._typing_command_max_duration_seconds
if (
message_text.startswith("/")
and not message_text.lower().startswith("/ai")
)
else None
)
self._start_typing_task(
message.chat.id, max_duration_seconds=max_duration
)
payload = self._serialize_update_payload(message)
if not payload:
logger.warn("Telegram消息序列化失败跳过转发")
self._stop_typing_task(message.chat.id)
return
response = RequestUtils(timeout=15).post_res(
self._ds_url, json=payload
)
if not response or response.status_code >= 400:
logger.warn("Telegram消息转发失败停止typing状态")
self._stop_typing_task(message.chat.id)
logger.warn("Telegram消息转发失败")
@_bot.callback_query_handler(func=lambda call: True)
def callback_query(call):
@@ -149,7 +134,6 @@ class Telegram:
处理按钮点击回调
"""
chat_id = None
typing_started = False
try:
# Update user-chat mapping for callbacks too
chat_id = call.message.chat.id
@@ -181,25 +165,15 @@ class Telegram:
# 先确认回调避免用户看到loading状态
_bot.answer_callback_query(call.id)
# 启动持续发送正在输入状态
self._start_typing_task(
chat_id,
max_duration_seconds=self._typing_callback_max_duration_seconds,
)
typing_started = True
# 发送给主程序处理
response = RequestUtils(timeout=15).post_res(
self._ds_url, json=callback_json
)
if not response or response.status_code >= 400:
logger.warn("Telegram按钮回调转发失败停止typing状态")
self._stop_typing_task(chat_id)
logger.warn("Telegram按钮回调转发失败")
except Exception as err:
logger.error(f"处理按钮回调失败:{str(err)}")
if typing_started and chat_id is not None:
self._stop_typing_task(chat_id)
_bot.answer_callback_query(call.id, "处理失败,请重试")
def run_polling():
@@ -427,11 +401,31 @@ class Telegram:
) -> None:
"""
按调用方要求停止 typing。
Agent 回复和流式编辑由 worker 统一收口,避免中途发送消息时误关后续排队消息的状态
typing 由消息处理状态统一收口,兼容显式要求立即停止的调用
"""
if stop_typing:
self._stop_typing_task(chat_id)
def start_typing(
self,
chat_id: Optional[Union[str, int]] = None,
userid: Optional[Union[str, int]] = None,
) -> bool:
"""
外部链路主动启动 typing 状态。
"""
if chat_id:
target_chat_id = chat_id
elif userid:
target_chat_id = self._get_user_chat_id(str(userid)) or str(userid)
else:
target_chat_id = None
target_chat_id = target_chat_id or (str(userid) if userid else None)
if not target_chat_id:
return False
self._start_typing_task(target_chat_id)
return True
def stop_typing(
self,
chat_id: Optional[Union[str, int]] = None,
@@ -463,7 +457,7 @@ class Telegram:
original_message_id: Optional[int] = None,
original_chat_id: Optional[str] = None,
disable_web_page_preview: Optional[bool] = None,
stop_typing: bool = True,
stop_typing: bool = False,
) -> Optional[dict]:
"""
发送Telegram消息
@@ -559,7 +553,7 @@ class Telegram:
userid: Optional[str] = None,
caption: Optional[str] = None,
original_chat_id: Optional[str] = None,
stop_typing: bool = True,
stop_typing: bool = False,
) -> Optional[dict]:
"""
发送Telegram语音消息。
@@ -608,7 +602,7 @@ class Telegram:
text: Optional[str] = None,
file_name: Optional[str] = None,
original_chat_id: Optional[str] = None,
stop_typing: bool = True,
stop_typing: bool = False,
) -> Optional[dict]:
"""
发送本地图片或文件给 Telegram 用户。
@@ -699,7 +693,7 @@ class Telegram:
buttons: Optional[List[List[Dict]]] = None,
original_message_id: Optional[int] = None,
original_chat_id: Optional[str] = None,
stop_typing: bool = True,
stop_typing: bool = False,
) -> Optional[bool]:
"""
发送媒体列表消息
@@ -779,7 +773,7 @@ class Telegram:
buttons: Optional[List[List[Dict]]] = None,
original_message_id: Optional[int] = None,
original_chat_id: Optional[str] = None,
stop_typing: bool = True,
stop_typing: bool = False,
) -> Optional[bool]:
"""
发送种子列表消息
@@ -937,7 +931,7 @@ class Telegram:
text: str,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
stop_typing: bool = True,
stop_typing: bool = False,
) -> Optional[bool]:
"""
编辑Telegram消息公开方法

View File

@@ -1,11 +1,17 @@
import asyncio
import sys
import time
import unittest
from types import SimpleNamespace
from types import ModuleType, SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
sys.modules.setdefault("app.helper.sites", ModuleType("app.helper.sites"))
setattr(sys.modules["app.helper.sites"], "SitesHelper", object)
from app.agent import AgentManager, _MessageTask
from app.chain.message import MessageChain
from app.command import Command, _finish_command_processing_status
from app.modules.telegram import TelegramModule
from app.modules.telegram.telegram import Telegram
from app.schemas.types import MessageChannel
@@ -82,8 +88,47 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
self.assertTrue(result["success"])
stop_typing.assert_not_called()
def test_slash_command_stops_typing_when_message_handler_returns(self):
def test_send_msg_does_not_stop_typing_by_default(self):
"""
响应发送不再默认结束 typing由处理状态统一收口。
"""
telegram = self._telegram_client()
sent = SimpleNamespace(message_id=1, chat=SimpleNamespace(id="chat-1"))
with patch.object(
telegram, "_Telegram__send_request", return_value=sent
), patch.object(telegram, "_stop_typing_task") as stop_typing:
result = telegram.send_msg(title="处理中", userid="10001")
self.assertTrue(result["success"])
stop_typing.assert_not_called()
def test_telegram_module_processing_status_starts_typing(self):
"""
Telegram 通过模块处理状态接口启动 typing 保活。
"""
module = TelegramModule()
module._channel = MessageChannel.Telegram
client = Mock()
client.start_typing.return_value = True
with patch.object(
module, "get_config", return_value=SimpleNamespace(name="telegram-test")
), patch.object(module, "get_instance", return_value=client):
status = module.mark_message_processing_started(
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
chat_id="-100",
text="hello",
)
client.start_typing.assert_called_once_with(chat_id="-100", userid="10001")
self.assertEqual(status["metadata"]["kind"], "typing")
def test_slash_command_defers_processing_status_to_command_handler(self):
chain = MessageChain.__new__(MessageChain)
chain.eventmanager = Mock()
status = MessageChain._ProcessingStatus(
channel=MessageChannel.Telegram,
source="telegram-test",
@@ -94,7 +139,7 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
with patch.object(chain, "_record_user_message"), patch.object(
chain, "_mark_message_processing_started", return_value=status
), patch.object(chain, "_handle_message_core"), patch.object(
), patch.object(
chain, "_mark_message_processing_finished"
) as finish_status:
chain.handle_message(
@@ -106,13 +151,65 @@ class TestTelegramTypingLifecycle(unittest.TestCase):
original_chat_id="-100",
)
finish_status.assert_not_called()
chain.eventmanager.send_event.assert_called_once()
self.assertEqual(
chain.eventmanager.send_event.call_args.args[1]["processing_status"],
status.to_dict(),
)
def test_command_handler_finishes_processing_status_after_execute(self):
"""
传统命令响应完成后由命令处理器统一结束 processing status。
"""
command = Command.__new__(Command)
command.get = Mock(return_value={"func": Mock()})
command.execute = Mock()
event = SimpleNamespace(
event_data={
"cmd": "/sites",
"user": "10001",
"channel": MessageChannel.Telegram,
"source": "telegram-test",
"processing_status": {
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
},
}
)
with patch("app.command._finish_command_processing_status") as finish_status:
command.command_event(event)
command.execute.assert_called_once()
finish_status.assert_called_once_with(
event.event_data["processing_status"],
user_id="10001",
)
def test_finish_command_processing_status_uses_module_interface(self):
status = {
"channel": MessageChannel.Telegram.value,
"source": "telegram-test",
"userid": "10001",
"chat_id": "-100",
"metadata": {"kind": "typing"},
}
with patch("app.command.CommandChain") as chain_cls:
_finish_command_processing_status(status, user_id="fallback")
chain_cls.return_value.run_module.assert_called_once_with(
"mark_message_processing_finished",
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
message_id=None,
chat_id="-100",
status=status,
original_message_id=None,
original_chat_id="-100",
)
def test_async_agent_keeps_processing_status_for_worker(self):