fix: simplify message typing lifecycle

This commit is contained in:
jxxghp
2026-05-23 00:11:56 +08:00
parent cde267c55f
commit a74f04a149
7 changed files with 336 additions and 236 deletions

View File

@@ -58,22 +58,36 @@ def _finish_processing_status(status: Optional[dict], user_id: Optional[str] = N
"""结束入站消息的渠道处理状态。"""
if not status:
return
try:
channel = MessageChannel(status.get("channel"))
except Exception:
return
try:
AgentChain().run_module(
"mark_message_processing_finished",
channel=channel,
source=status.get("source"),
userid=status.get("userid") or user_id,
message_id=status.get("message_id"),
chat_id=status.get("chat_id"),
status=status,
)
except Exception as err:
logger.debug(f"结束Agent消息处理状态失败: {err}")
AgentChain().finish_message_processing_status(
status=status,
userid=user_id,
)
async def _async_start_processing_status(task: "_MessageTask") -> Optional[dict]:
"""
在 Agent worker 中启动渠道处理状态。
渠道启动可能触发外部 API同步实现需切到线程池避免阻塞事件循环。
"""
if not task.channel:
return None
def _start() -> Optional[dict]:
"""在线程池中通过统一 Chain 接口启动处理状态。"""
try:
return AgentChain().start_message_processing_status(
channel=MessageChannel(task.channel),
source=task.source,
userid=task.user_id,
message_id=task.original_message_id,
chat_id=task.original_chat_id,
text=task.message,
)
except Exception as err:
logger.debug(f"启动Agent消息处理状态失败: {err}")
return None
return await run_in_threadpool(_start)
async def _async_finish_processing_status(
@@ -952,8 +966,6 @@ class AgentManager:
self._session_queues: Dict[str, asyncio.Queue] = {}
# 每个会话的worker任务
self._session_workers: Dict[str, asyncio.Task] = {}
# typing 这类状态按会话/聊天共享,前一条任务结束时可能仍需延续到后续排队消息。
self._deferred_processing_statuses: Dict[str, dict] = {}
def get_session_status(self, session_id: str) -> dict[str, Any]:
"""获取会话当前模型与 token 使用状态。"""
@@ -1009,7 +1021,6 @@ class AgentManager:
pass
self._session_workers.clear()
self._session_queues.clear()
self._deferred_processing_statuses.clear()
for agent in self.active_agents.values():
await agent.cleanup()
self.active_agents.clear()
@@ -1026,7 +1037,6 @@ class AgentManager:
username: str = None,
original_message_id: Optional[str] = None,
original_chat_id: Optional[str] = None,
processing_status: Optional[dict] = None,
reply_mode: ReplyMode = ReplyMode.DISPATCH,
) -> str:
"""
@@ -1044,7 +1054,6 @@ class AgentManager:
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
processing_status=processing_status,
reply_mode=reply_mode,
)
@@ -1099,15 +1108,12 @@ class AgentManager:
break
try:
await self._start_task_processing_status(task)
await self._process_message_internal(task)
except Exception as e:
logger.error(f"处理会话 {session_id} 的消息失败: {e}")
finally:
await self._finish_task_processing_status(
session_id=session_id,
task=task,
queue=queue,
)
await self._finish_task_processing_status(task)
queue.task_done()
except asyncio.CancelledError:
@@ -1121,52 +1127,23 @@ class AgentManager:
and self._session_queues[session_id].empty()
):
self._session_queues.pop(session_id, None)
self._deferred_processing_statuses.pop(session_id, None)
@staticmethod
def _is_shared_processing_status(status: Optional[dict]) -> bool:
async def _start_task_processing_status(task: _MessageTask) -> None:
"""
判断状态是否属于同一聊天窗口共享的处理提示
reaction 绑定到具体消息应按消息收口typing 绑定到会话/聊天,需要等队列空闲再关闭。
在 Agent worker 真正开始处理消息时启动渠道处理状态
"""
metadata = (status or {}).get("metadata") or {}
return isinstance(metadata, dict) and metadata.get("kind") == "typing"
async def _finish_task_processing_status(
self,
session_id: str,
task: _MessageTask,
queue: asyncio.Queue,
) -> None:
"""
根据会话队列状态结束或延后处理提示。
当后面还有排队消息时typing 状态继续保留;队列真正空闲后再统一关闭。
"""
status = task.processing_status
if self._is_shared_processing_status(status) and not queue.empty():
self._deferred_processing_statuses[session_id] = status
if task.processing_status:
return
task.processing_status = await _async_start_processing_status(task)
if status:
await _async_finish_processing_status(status, task.user_id)
if self._is_shared_processing_status(status):
self._deferred_processing_statuses.pop(session_id, None)
elif queue.empty():
deferred_status = self._deferred_processing_statuses.pop(
session_id, None
)
if deferred_status:
await _async_finish_processing_status(
deferred_status, task.user_id
)
return
if not queue.empty():
return
deferred_status = self._deferred_processing_statuses.pop(session_id, None)
if deferred_status:
await _async_finish_processing_status(deferred_status, task.user_id)
@staticmethod
async def _finish_task_processing_status(task: _MessageTask) -> None:
"""
在 Agent worker 完成或异常后结束本条消息的渠道处理状态。
"""
await _async_finish_processing_status(task.processing_status, task.user_id)
task.processing_status = None
async def _process_message_internal(self, task: _MessageTask):
"""
@@ -1232,7 +1209,6 @@ class AgentManager:
break
self._session_queues.pop(session_id, None)
stopped = True
self._deferred_processing_statuses.pop(session_id, None)
if stopped:
logger.info(f"会话 {session_id} 的Agent推理已应急停止")
@@ -1256,7 +1232,6 @@ class AgentManager:
# 清理队列
self._session_queues.pop(session_id, None)
self._deferred_processing_statuses.pop(session_id, None)
# 清理agent
if session_id in self.active_agents:

View File

@@ -41,6 +41,7 @@ from app.schemas import (
MessageResponse,
)
from app.utils.identity import normalize_internal_user_id
from app.schemas.message import ChannelCapability, ChannelCapabilityManager
from app.schemas.category import CategoryConfig
from app.schemas.types import (
TorrentStatus,
@@ -122,6 +123,74 @@ class ChainBase(metaclass=ABCMeta):
"""
self.filecache.delete(filename)
def start_message_processing_status(
self,
channel: MessageChannel,
source: Optional[str],
userid: Optional[Union[str, int]] = None,
message_id: Optional[Union[str, int]] = None,
chat_id: Optional[Union[str, int]] = None,
text: Optional[str] = None,
) -> Optional[dict]:
"""
启动渠道侧消息输入/处理状态。
具体表现由消息模块实现,例如 typing 保活或消息 reaction。
"""
if not channel or not ChannelCapabilityManager.supports_capability(
channel, ChannelCapability.PROCESSING_STATUS
):
return None
try:
status = self.run_module(
"mark_message_processing_started",
channel=channel,
source=source,
userid=userid,
message_id=message_id,
chat_id=chat_id,
text=text,
)
except Exception as err:
logger.debug(f"启动消息处理状态失败: {err}")
return None
return status if isinstance(status, dict) else None
def finish_message_processing_status(
self,
status: Optional[dict] = None,
channel: Optional[MessageChannel] = None,
source: Optional[str] = None,
userid: Optional[Union[str, int]] = None,
message_id: Optional[Union[str, int]] = None,
chat_id: Optional[Union[str, int]] = None,
) -> None:
"""
结束渠道侧消息输入/处理状态。
优先使用 start 返回的 status缺失时使用显式渠道和消息定位参数。
"""
target_channel = channel
if status:
try:
target_channel = MessageChannel(status.get("channel"))
except Exception:
target_channel = channel
if not target_channel or not ChannelCapabilityManager.supports_capability(
target_channel, ChannelCapability.PROCESSING_STATUS
):
return
try:
self.run_module(
"mark_message_processing_finished",
channel=target_channel,
source=(status or {}).get("source") or source,
userid=(status or {}).get("userid") or userid,
message_id=(status or {}).get("message_id") or message_id,
chat_id=(status or {}).get("chat_id") or chat_id,
status=status,
)
except Exception as err:
logger.debug(f"结束消息处理状态失败: {err}")
@staticmethod
def _normalize_notification_for_dispatch(
message: Notification

View File

@@ -137,14 +137,7 @@ class MessageChain(ChainBase):
"""
images = CommingMessage.MessageImage.normalize_list(images)
processing_status = self._mark_message_processing_started(
channel=channel,
source=source,
userid=userid,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
text=text,
)
processing_status = None
continues_async = False
try:
# 语音输入只用于转写为文本,不默认改变回复形式。
@@ -181,6 +174,23 @@ class MessageChain(ChainBase):
text=text,
)
if not self._is_agent_message(
channel=channel,
userid=userid,
text=text,
images=images,
files=files,
has_audio_input=has_audio_input,
):
processing_status = self._mark_message_processing_started(
channel=channel,
source=source,
userid=userid,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
text=text,
)
continues_async = self._handle_message_core(
channel=channel,
source=source,
@@ -310,7 +320,6 @@ class MessageChain(ChainBase):
original_chat_id=original_chat_id,
images=images,
files=files,
processing_status=processing_status,
)
if (
@@ -327,7 +336,6 @@ class MessageChain(ChainBase):
original_chat_id=original_chat_id,
images=images,
files=files,
processing_status=processing_status,
)
if MediaInteractionChain().handle_text_interaction(
@@ -350,6 +358,35 @@ class MessageChain(ChainBase):
)
return False
def _is_agent_message(
self,
channel: MessageChannel,
userid: Union[str, int],
text: str,
images: Optional[List[CommingMessage.MessageImage]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
has_audio_input: bool = False,
) -> bool:
"""
判断本条消息是否会进入 Agent worker由 Agent worker 管理 typing 生命周期。
"""
if text.startswith("CALLBACK:"):
return self._parse_agent_choice_callback(text[9:]) is not None
if text.lower().startswith("/ai"):
return True
if text.startswith("/"):
return False
if not (
settings.AI_AGENT_ENABLE
and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input)
):
return False
if self._get_latest_slash_interaction(userid):
return False
if media_interaction_manager.get_by_user(userid):
return False
return True
def _mark_message_processing_started(
self,
channel: MessageChannel,
@@ -360,27 +397,17 @@ class MessageChain(ChainBase):
text: str,
) -> Optional[_ProcessingStatus]:
"""为支持的渠道标记“消息正在处理”。"""
if not ChannelCapabilityManager.supports_capability(
channel, ChannelCapability.PROCESSING_STATUS
):
status = self.start_message_processing_status(
channel=channel,
source=source,
userid=userid,
message_id=original_message_id,
chat_id=original_chat_id,
text=text,
)
if not status:
return None
try:
status = self.run_module(
"mark_message_processing_started",
channel=channel,
source=source,
userid=userid,
message_id=original_message_id,
chat_id=original_chat_id,
text=text,
)
except Exception as err:
logger.debug(f"标记消息处理状态失败: {err}")
return None
if not isinstance(status, dict):
return None
metadata = status.get("metadata")
return self._ProcessingStatus(
channel=channel,
@@ -404,22 +431,16 @@ class MessageChain(ChainBase):
结束渠道侧“消息正在处理”状态。
不同渠道的表现可能是 reaction、typing 等,消息链只负责调用通用模块接口。
"""
if not status and not ChannelCapabilityManager.supports_capability(
channel, ChannelCapability.PROCESSING_STATUS
):
if not status:
return
try:
self.run_module(
"mark_message_processing_finished",
channel=channel,
source=source,
userid=userid,
message_id=status.message_id if status else original_message_id,
chat_id=status.chat_id if status else original_chat_id,
status=status.to_dict() if status else None,
)
except Exception as err:
logger.debug(f"结束消息处理状态失败: {err}")
self.finish_message_processing_status(
status=status.to_dict(),
channel=channel,
source=source,
userid=userid,
message_id=status.message_id or original_message_id,
chat_id=status.chat_id or original_chat_id,
)
def _handle_callback(
self,
@@ -501,7 +522,6 @@ class MessageChain(ChainBase):
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
processing_status=processing_status,
):
return True
@@ -1148,7 +1168,6 @@ class MessageChain(ChainBase):
images: Optional[List[CommingMessage.MessageImage]] = None,
files: Optional[List[CommingMessage.MessageAttachment]] = None,
session_id: Optional[str] = None,
processing_status: Optional[_ProcessingStatus] = None,
) -> bool:
"""
处理AI智能体消息
@@ -1261,9 +1280,6 @@ class MessageChain(ChainBase):
username=username,
original_message_id=str(original_message_id) if original_message_id else None,
original_chat_id=original_chat_id,
processing_status=processing_status.to_dict()
if processing_status
else None,
),
global_vars.loop,
)

View File

@@ -34,22 +34,10 @@ def _finish_command_processing_status(status: Optional[dict], user_id: Optional[
"""
if not status:
return
try:
channel = MessageChannel(status.get("channel"))
except Exception:
return
try:
CommandChain().run_module(
"mark_message_processing_finished",
channel=channel,
source=status.get("source"),
userid=status.get("userid") or user_id,
message_id=status.get("message_id"),
chat_id=status.get("chat_id"),
status=status,
)
except Exception as err:
logger.debug(f"结束命令消息处理状态失败: {err}")
CommandChain().finish_message_processing_status(
status=status,
userid=user_id,
)
class Command(metaclass=Singleton):

View File

@@ -585,14 +585,12 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
continue
client: Telegram = self.get_instance(conf.name)
if client:
stop_typing = not (metadata or {}).get("agent_managed_typing")
result = client.edit_msg(
chat_id=chat_id,
message_id=message_id,
text=text,
title=title,
buttons=buttons,
stop_typing=stop_typing,
)
if result:
return True