From 0f42a0fb8c9388e3226c36052d1598f4aeb137ad Mon Sep 17 00:00:00 2001 From: jxxghp Date: Mon, 15 Jun 2026 07:50:45 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=85=A8=E5=B1=80=20AI=20?= =?UTF-8?q?=E4=B8=8B=E7=BB=95=E8=BF=87=E4=BC=A0=E7=BB=9F=E6=90=9C=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/chain/message.py | 49 ++++- tests/test_media_interaction.py | 307 ++++++++++++++++++++------------ 2 files changed, 240 insertions(+), 116 deletions(-) diff --git a/app/chain/message.py b/app/chain/message.py index 4a19e8e7..506423b5 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -42,6 +42,8 @@ class MessageChain(ChainBase): 外来消息处理链 """ + _ai_prefix = "/ai" + _no_ai_prefix = "/noai" # 用户会话信息 {userid: (session_id, last_time)} _user_sessions: Dict[Union[str, int], tuple] = {} # 会话超时时间(分钟) @@ -283,7 +285,22 @@ class MessageChain(ChainBase): ) return False - if text.startswith("/") and not text.lower().startswith("/ai"): + no_ai_requested, no_ai_text = self._strip_no_ai_prefix(text) + if no_ai_requested: + text = no_ai_text + if not text: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="请输入要使用传统交互处理的内容", + ) + ) + return False + + if text.startswith("/") and not self._has_ai_prefix(text): self.eventmanager.send_event( EventType.CommandExcute, { @@ -298,7 +315,7 @@ class MessageChain(ChainBase): ) return bool(processing_status) - if text.lower().startswith("/ai"): + if not no_ai_requested and self._has_ai_prefix(text): return self._handle_ai_message( text=text, channel=channel, @@ -354,6 +371,8 @@ class MessageChain(ChainBase): return False if ( + not no_ai_requested + and settings.AI_AGENT_ENABLE and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input) ): @@ -390,6 +409,25 @@ class MessageChain(ChainBase): ) return False + @classmethod + def _strip_no_ai_prefix(cls, text: str) -> Tuple[bool, str]: + """ + 解析 /noai 前缀,显式要求本条消息绕过全局智能体。 + """ + normalized = (text or "").strip() + pattern = rf"^{re.escape(cls._no_ai_prefix)}(?:\s+|[::]\s*|$)(.*)$" + match = re.match(pattern, normalized, re.IGNORECASE | re.DOTALL) + if not match: + return False, text + return True, match.group(1).strip() + + @classmethod + def _has_ai_prefix(cls, text: str) -> bool: + """ + 判断消息是否使用显式 AI 前缀。 + """ + return (text or "").lower().startswith(cls._ai_prefix) + def _is_agent_message( self, channel: MessageChannel, @@ -404,7 +442,7 @@ class MessageChain(ChainBase): """ if text.startswith("CALLBACK:"): return self._parse_agent_choice_callback(text[9:]) is not None - if text.lower().startswith("/ai"): + if self._has_ai_prefix(text): return True if text.startswith("/"): return False @@ -1229,8 +1267,9 @@ class MessageChain(ChainBase): images = CommingMessage.MessageImage.normalize_list(images) # 提取用户消息 - if text.lower().startswith("/ai"): - user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感) + if self._has_ai_prefix(text): + # 前缀匹配不区分大小写,但保留原始正文避免改变用户输入内容。 + user_message = text[len(self._ai_prefix):].strip() else: user_message = text.strip() # 按原消息处理 diff --git a/tests/test_media_interaction.py b/tests/test_media_interaction.py index 07be40e9..a5fc0630 100644 --- a/tests/test_media_interaction.py +++ b/tests/test_media_interaction.py @@ -1,7 +1,7 @@ -import unittest from unittest.mock import patch -from app.chain.media import MediaChain +import pytest + from app.chain.message import MediaInteractionChain, MessageChain from app.core.context import MediaInfo from app.core.meta import MetaBase @@ -9,139 +9,224 @@ from app.helper.interaction import media_interaction_manager from app.schemas.types import MessageChannel -class TestMediaInteraction(unittest.TestCase): - def tearDown(self): - media_interaction_manager.clear() +@pytest.fixture(autouse=True) +def clear_media_interactions(): + """清理媒体交互状态,避免用例之间共享内存会话。""" + yield + media_interaction_manager.clear() - @staticmethod - def _build_meta(name: str) -> MetaBase: - meta = MetaBase(name) - meta.name = name - meta.begin_season = 1 - return meta - def test_message_routes_text_reply_to_media_interaction_before_ai(self): - chain = MessageChain() - request = media_interaction_manager.create_or_replace( - user_id="10001", +def _build_meta(name: str) -> MetaBase: + """构造媒体识别元数据。""" + meta = MetaBase(name) + meta.name = name + meta.begin_season = 1 + return meta + + +def test_message_routes_text_reply_to_media_interaction_before_ai(): + """已有传统媒体交互时,用户回复应优先交给传统交互处理。""" + chain = MessageChain() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[MediaInfo(title="星际穿越", year="2014")], + ) + assert request is not None + + with patch.object(chain, "_record_user_message"), patch( + "app.chain.message.MediaInteractionChain.handle_text_interaction", + return_value=True, + ) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai: + chain.handle_message( channel=MessageChannel.Wechat, source="wechat-test", + userid="10001", username="tester", - action="Search", - keyword="星际穿越", - title="星际穿越", - meta=self._build_meta("星际穿越"), - items=[MediaInfo(title="星际穿越", year="2014")], + text="1", ) - self.assertIsNotNone(request) - with patch.object(chain, "_record_user_message"), patch( - "app.chain.message.MediaInteractionChain.handle_text_interaction", - return_value=True, - ) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai: - chain.handle_message( - channel=MessageChannel.Wechat, - source="wechat-test", - userid="10001", - username="tester", - text="1", - ) + handle_text.assert_called_once() + handle_ai.assert_not_called() - handle_text.assert_called_once() - handle_ai.assert_not_called() - def test_callback_routes_to_media_interaction_chain(self): - chain = MessageChain() - request = media_interaction_manager.create_or_replace( - user_id="10001", +def test_noai_prefix_starts_traditional_search_when_global_ai_enabled(): + """全局 AI 开启时,/noai 前缀应让本条消息进入传统搜索交互。""" + chain = MessageChain() + meta = _build_meta("星际穿越") + medias = [ + MediaInfo(title="星际穿越", year="2014"), + MediaInfo(title="Interstellar", year="2014"), + ] + + with patch.object(chain, "_record_user_message"), patch( + "app.chain.message.settings.AI_AGENT_ENABLE", True + ), patch( + "app.chain.message.settings.AI_AGENT_GLOBAL", True + ), patch( + "app.chain.media.MediaChain.search", + return_value=(meta, medias), + ) as search_media, patch( + "app.chain.message.MediaInteractionChain.post_medias_message" + ) as post_medias_message, patch.object( + chain, "_handle_ai_message" + ) as handle_ai: + chain.handle_message( channel=MessageChannel.Telegram, source="telegram-test", + userid="10001", username="tester", - action="Search", - keyword="星际穿越", - title="星际穿越", - meta=self._build_meta("星际穿越"), - items=[MediaInfo(title="星际穿越", year="2014")], + text="/noai 星际穿越", ) - with patch( - "app.chain.message.MediaInteractionChain.handle_callback_interaction", - return_value=True, - ) as handle_callback: - chain._handle_callback( - text=f"CALLBACK:media:{request.request_id}:page-next", - channel=MessageChannel.Telegram, - source="telegram-test", - userid="10001", - username="tester", - ) + search_media.assert_called_once_with("星际穿越") + post_medias_message.assert_called_once() + handle_ai.assert_not_called() - handle_callback.assert_called_once() + request = media_interaction_manager.get_by_user("10001") + assert request is not None + assert request.action == "Search" + assert request.keyword == "星际穿越" + assert len(request.items) == 2 - def test_media_interaction_starts_search_and_posts_media_list(self): - chain = MediaInteractionChain() - meta = self._build_meta("星际穿越") - medias = [ - MediaInfo(title="星际穿越", year="2014"), - MediaInfo(title="Interstellar", year="2014"), - ] - with patch( - "app.chain.media.MediaChain.search", - return_value=(meta, medias), - ), patch.object(chain, "post_medias_message") as post_medias_message: - handled = chain.handle_text_interaction( - channel=MessageChannel.Telegram, - source="telegram-test", - userid="10001", - username="tester", - text="星际穿越", - ) +def test_noai_prefix_preserves_traditional_interaction_priority_after_search(): + """通过 /noai 进入传统交互后,后续选择应继续优先走传统交互。""" + chain = MessageChain() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[MediaInfo(title="星际穿越", year="2014")], + ) + assert request is not None - self.assertTrue(handled) - post_medias_message.assert_called_once() - notification = post_medias_message.call_args.args[0] - self.assertTrue(notification.buttons) - self.assertTrue( - notification.buttons[0][0]["callback_data"].startswith("media:") + with patch.object(chain, "_record_user_message"), patch( + "app.chain.message.settings.AI_AGENT_ENABLE", True + ), patch( + "app.chain.message.settings.AI_AGENT_GLOBAL", True + ), patch( + "app.chain.message.MediaInteractionChain.handle_text_interaction", + return_value=True, + ) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai: + chain.handle_message( + channel=MessageChannel.Wechat, + source="wechat-test", + userid="10001", + username="tester", + text="1", ) - request = media_interaction_manager.get_by_user("10001") - self.assertIsNotNone(request) - self.assertEqual(request.action, "Search") - self.assertEqual(len(request.items), 2) + handle_text.assert_called_once() + handle_ai.assert_not_called() - def test_media_interaction_legacy_page_callback_updates_existing_request(self): - chain = MediaInteractionChain() - request = media_interaction_manager.create_or_replace( - user_id="10001", + +def test_callback_routes_to_media_interaction_chain(): + """媒体按钮回调应路由到媒体交互链。""" + chain = MessageChain() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[MediaInfo(title="星际穿越", year="2014")], + ) + + with patch( + "app.chain.message.MediaInteractionChain.handle_callback_interaction", + return_value=True, + ) as handle_callback: + chain._handle_callback( + text=f"CALLBACK:media:{request.request_id}:page-next", channel=MessageChannel.Telegram, source="telegram-test", + userid="10001", username="tester", - action="Search", - keyword="星际穿越", - title="星际穿越", - meta=self._build_meta("星际穿越"), - items=[ - MediaInfo(title=f"资源 {index}", year="2024") - for index in range(1, 11) - ], ) - with patch.object(chain, "post_medias_message") as post_medias_message: - handled = chain.handle_callback_interaction( - callback_data="page_n", - channel=MessageChannel.Telegram, - source="telegram-test", - userid="10001", - username="tester", - original_message_id=123, - original_chat_id="456", - ) + handle_callback.assert_called_once() - self.assertTrue(handled) - self.assertEqual(request.page, 1) - post_medias_message.assert_called_once() - notification = post_medias_message.call_args.args[0] - self.assertEqual(notification.original_message_id, 123) - self.assertEqual(notification.original_chat_id, "456") + +def test_media_interaction_starts_search_and_posts_media_list(): + """传统媒体交互应能搜索媒体并发送候选列表。""" + chain = MediaInteractionChain() + meta = _build_meta("星际穿越") + medias = [ + MediaInfo(title="星际穿越", year="2014"), + MediaInfo(title="Interstellar", year="2014"), + ] + + with patch( + "app.chain.media.MediaChain.search", + return_value=(meta, medias), + ), patch.object(chain, "post_medias_message") as post_medias_message: + handled = chain.handle_text_interaction( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="星际穿越", + ) + + assert handled + post_medias_message.assert_called_once() + notification = post_medias_message.call_args.args[0] + assert notification.buttons + assert notification.buttons[0][0]["callback_data"].startswith("media:") + + request = media_interaction_manager.get_by_user("10001") + assert request is not None + assert request.action == "Search" + assert len(request.items) == 2 + + +def test_media_interaction_legacy_page_callback_updates_existing_request(): + """旧格式翻页回调仍应更新当前媒体交互请求。""" + chain = MediaInteractionChain() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=_build_meta("星际穿越"), + items=[ + MediaInfo(title=f"资源 {index}", year="2024") + for index in range(1, 11) + ], + ) + + with patch.object(chain, "post_medias_message") as post_medias_message: + handled = chain.handle_callback_interaction( + callback_data="page_n", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + original_message_id=123, + original_chat_id="456", + ) + + assert handled + assert request.page == 1 + post_medias_message.assert_called_once() + notification = post_medias_message.call_args.args[0] + assert notification.original_message_id == 123 + assert notification.original_chat_id == "456"