mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-20 15:10:22 +08:00
支持全局 AI 下绕过传统搜索
This commit is contained in:
@@ -42,6 +42,8 @@ class MessageChain(ChainBase):
|
||||
外来消息处理链
|
||||
"""
|
||||
|
||||
_ai_prefix = "/ai"
|
||||
_no_ai_prefix = "/noai"
|
||||
# 用户会话信息 {userid: (session_id, last_time)}
|
||||
_user_sessions: Dict[Union[str, int], tuple] = {}
|
||||
# 会话超时时间(分钟)
|
||||
@@ -283,7 +285,22 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return False
|
||||
|
||||
if text.startswith("/") and not text.lower().startswith("/ai"):
|
||||
no_ai_requested, no_ai_text = self._strip_no_ai_prefix(text)
|
||||
if no_ai_requested:
|
||||
text = no_ai_text
|
||||
if not text:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="请输入要使用传统交互处理的内容",
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
||||
if text.startswith("/") and not self._has_ai_prefix(text):
|
||||
self.eventmanager.send_event(
|
||||
EventType.CommandExcute,
|
||||
{
|
||||
@@ -298,7 +315,7 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return bool(processing_status)
|
||||
|
||||
if text.lower().startswith("/ai"):
|
||||
if not no_ai_requested and self._has_ai_prefix(text):
|
||||
return self._handle_ai_message(
|
||||
text=text,
|
||||
channel=channel,
|
||||
@@ -354,6 +371,8 @@ class MessageChain(ChainBase):
|
||||
return False
|
||||
|
||||
if (
|
||||
not no_ai_requested
|
||||
and
|
||||
settings.AI_AGENT_ENABLE
|
||||
and (settings.AI_AGENT_GLOBAL or images or files or has_audio_input)
|
||||
):
|
||||
@@ -390,6 +409,25 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _strip_no_ai_prefix(cls, text: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
解析 /noai 前缀,显式要求本条消息绕过全局智能体。
|
||||
"""
|
||||
normalized = (text or "").strip()
|
||||
pattern = rf"^{re.escape(cls._no_ai_prefix)}(?:\s+|[::]\s*|$)(.*)$"
|
||||
match = re.match(pattern, normalized, re.IGNORECASE | re.DOTALL)
|
||||
if not match:
|
||||
return False, text
|
||||
return True, match.group(1).strip()
|
||||
|
||||
@classmethod
|
||||
def _has_ai_prefix(cls, text: str) -> bool:
|
||||
"""
|
||||
判断消息是否使用显式 AI 前缀。
|
||||
"""
|
||||
return (text or "").lower().startswith(cls._ai_prefix)
|
||||
|
||||
def _is_agent_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
@@ -404,7 +442,7 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
if text.startswith("CALLBACK:"):
|
||||
return self._parse_agent_choice_callback(text[9:]) is not None
|
||||
if text.lower().startswith("/ai"):
|
||||
if self._has_ai_prefix(text):
|
||||
return True
|
||||
if text.startswith("/"):
|
||||
return False
|
||||
@@ -1229,8 +1267,9 @@ class MessageChain(ChainBase):
|
||||
images = CommingMessage.MessageImage.normalize_list(images)
|
||||
|
||||
# 提取用户消息
|
||||
if text.lower().startswith("/ai"):
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀(大小写不敏感)
|
||||
if self._has_ai_prefix(text):
|
||||
# 前缀匹配不区分大小写,但保留原始正文避免改变用户输入内容。
|
||||
user_message = text[len(self._ai_prefix):].strip()
|
||||
else:
|
||||
user_message = text.strip() # 按原消息处理
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.chain.media import MediaChain
|
||||
import pytest
|
||||
|
||||
from app.chain.message import MediaInteractionChain, MessageChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
@@ -9,139 +9,224 @@ from app.helper.interaction import media_interaction_manager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestMediaInteraction(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
media_interaction_manager.clear()
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_media_interactions():
|
||||
"""清理媒体交互状态,避免用例之间共享内存会话。"""
|
||||
yield
|
||||
media_interaction_manager.clear()
|
||||
|
||||
@staticmethod
|
||||
def _build_meta(name: str) -> MetaBase:
|
||||
meta = MetaBase(name)
|
||||
meta.name = name
|
||||
meta.begin_season = 1
|
||||
return meta
|
||||
|
||||
def test_message_routes_text_reply_to_media_interaction_before_ai(self):
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
def _build_meta(name: str) -> MetaBase:
|
||||
"""构造媒体识别元数据。"""
|
||||
meta = MetaBase(name)
|
||||
meta.name = name
|
||||
meta.begin_season = 1
|
||||
return meta
|
||||
|
||||
|
||||
def test_message_routes_text_reply_to_media_interaction_before_ai():
|
||||
"""已有传统媒体交互时,用户回复应优先交给传统交互处理。"""
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=_build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
)
|
||||
assert request is not None
|
||||
|
||||
with patch.object(chain, "_record_user_message"), patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_text_interaction",
|
||||
return_value=True,
|
||||
) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
text="1",
|
||||
)
|
||||
self.assertIsNotNone(request)
|
||||
|
||||
with patch.object(chain, "_record_user_message"), patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_text_interaction",
|
||||
return_value=True,
|
||||
) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="1",
|
||||
)
|
||||
handle_text.assert_called_once()
|
||||
handle_ai.assert_not_called()
|
||||
|
||||
handle_text.assert_called_once()
|
||||
handle_ai.assert_not_called()
|
||||
|
||||
def test_callback_routes_to_media_interaction_chain(self):
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
def test_noai_prefix_starts_traditional_search_when_global_ai_enabled():
|
||||
"""全局 AI 开启时,/noai 前缀应让本条消息进入传统搜索交互。"""
|
||||
chain = MessageChain()
|
||||
meta = _build_meta("星际穿越")
|
||||
medias = [
|
||||
MediaInfo(title="星际穿越", year="2014"),
|
||||
MediaInfo(title="Interstellar", year="2014"),
|
||||
]
|
||||
|
||||
with patch.object(chain, "_record_user_message"), patch(
|
||||
"app.chain.message.settings.AI_AGENT_ENABLE", True
|
||||
), patch(
|
||||
"app.chain.message.settings.AI_AGENT_GLOBAL", True
|
||||
), patch(
|
||||
"app.chain.media.MediaChain.search",
|
||||
return_value=(meta, medias),
|
||||
) as search_media, patch(
|
||||
"app.chain.message.MediaInteractionChain.post_medias_message"
|
||||
) as post_medias_message, patch.object(
|
||||
chain, "_handle_ai_message"
|
||||
) as handle_ai:
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
text="/noai 星际穿越",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_callback_interaction",
|
||||
return_value=True,
|
||||
) as handle_callback:
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:media:{request.request_id}:page-next",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
)
|
||||
search_media.assert_called_once_with("星际穿越")
|
||||
post_medias_message.assert_called_once()
|
||||
handle_ai.assert_not_called()
|
||||
|
||||
handle_callback.assert_called_once()
|
||||
request = media_interaction_manager.get_by_user("10001")
|
||||
assert request is not None
|
||||
assert request.action == "Search"
|
||||
assert request.keyword == "星际穿越"
|
||||
assert len(request.items) == 2
|
||||
|
||||
def test_media_interaction_starts_search_and_posts_media_list(self):
|
||||
chain = MediaInteractionChain()
|
||||
meta = self._build_meta("星际穿越")
|
||||
medias = [
|
||||
MediaInfo(title="星际穿越", year="2014"),
|
||||
MediaInfo(title="Interstellar", year="2014"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.chain.media.MediaChain.search",
|
||||
return_value=(meta, medias),
|
||||
), patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_text_interaction(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="星际穿越",
|
||||
)
|
||||
def test_noai_prefix_preserves_traditional_interaction_priority_after_search():
|
||||
"""通过 /noai 进入传统交互后,后续选择应继续优先走传统交互。"""
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=_build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
)
|
||||
assert request is not None
|
||||
|
||||
self.assertTrue(handled)
|
||||
post_medias_message.assert_called_once()
|
||||
notification = post_medias_message.call_args.args[0]
|
||||
self.assertTrue(notification.buttons)
|
||||
self.assertTrue(
|
||||
notification.buttons[0][0]["callback_data"].startswith("media:")
|
||||
with patch.object(chain, "_record_user_message"), patch(
|
||||
"app.chain.message.settings.AI_AGENT_ENABLE", True
|
||||
), patch(
|
||||
"app.chain.message.settings.AI_AGENT_GLOBAL", True
|
||||
), patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_text_interaction",
|
||||
return_value=True,
|
||||
) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai:
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Wechat,
|
||||
source="wechat-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="1",
|
||||
)
|
||||
|
||||
request = media_interaction_manager.get_by_user("10001")
|
||||
self.assertIsNotNone(request)
|
||||
self.assertEqual(request.action, "Search")
|
||||
self.assertEqual(len(request.items), 2)
|
||||
handle_text.assert_called_once()
|
||||
handle_ai.assert_not_called()
|
||||
|
||||
def test_media_interaction_legacy_page_callback_updates_existing_request(self):
|
||||
chain = MediaInteractionChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
|
||||
def test_callback_routes_to_media_interaction_chain():
|
||||
"""媒体按钮回调应路由到媒体交互链。"""
|
||||
chain = MessageChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=_build_meta("星际穿越"),
|
||||
items=[MediaInfo(title="星际穿越", year="2014")],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.chain.message.MediaInteractionChain.handle_callback_interaction",
|
||||
return_value=True,
|
||||
) as handle_callback:
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:media:{request.request_id}:page-next",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=self._build_meta("星际穿越"),
|
||||
items=[
|
||||
MediaInfo(title=f"资源 {index}", year="2024")
|
||||
for index in range(1, 11)
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_callback_interaction(
|
||||
callback_data="page_n",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
original_message_id=123,
|
||||
original_chat_id="456",
|
||||
)
|
||||
handle_callback.assert_called_once()
|
||||
|
||||
self.assertTrue(handled)
|
||||
self.assertEqual(request.page, 1)
|
||||
post_medias_message.assert_called_once()
|
||||
notification = post_medias_message.call_args.args[0]
|
||||
self.assertEqual(notification.original_message_id, 123)
|
||||
self.assertEqual(notification.original_chat_id, "456")
|
||||
|
||||
def test_media_interaction_starts_search_and_posts_media_list():
|
||||
"""传统媒体交互应能搜索媒体并发送候选列表。"""
|
||||
chain = MediaInteractionChain()
|
||||
meta = _build_meta("星际穿越")
|
||||
medias = [
|
||||
MediaInfo(title="星际穿越", year="2014"),
|
||||
MediaInfo(title="Interstellar", year="2014"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"app.chain.media.MediaChain.search",
|
||||
return_value=(meta, medias),
|
||||
), patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_text_interaction(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="星际穿越",
|
||||
)
|
||||
|
||||
assert handled
|
||||
post_medias_message.assert_called_once()
|
||||
notification = post_medias_message.call_args.args[0]
|
||||
assert notification.buttons
|
||||
assert notification.buttons[0][0]["callback_data"].startswith("media:")
|
||||
|
||||
request = media_interaction_manager.get_by_user("10001")
|
||||
assert request is not None
|
||||
assert request.action == "Search"
|
||||
assert len(request.items) == 2
|
||||
|
||||
|
||||
def test_media_interaction_legacy_page_callback_updates_existing_request():
|
||||
"""旧格式翻页回调仍应更新当前媒体交互请求。"""
|
||||
chain = MediaInteractionChain()
|
||||
request = media_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
action="Search",
|
||||
keyword="星际穿越",
|
||||
title="星际穿越",
|
||||
meta=_build_meta("星际穿越"),
|
||||
items=[
|
||||
MediaInfo(title=f"资源 {index}", year="2024")
|
||||
for index in range(1, 11)
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(chain, "post_medias_message") as post_medias_message:
|
||||
handled = chain.handle_callback_interaction(
|
||||
callback_data="page_n",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
original_message_id=123,
|
||||
original_chat_id="456",
|
||||
)
|
||||
|
||||
assert handled
|
||||
assert request.page == 1
|
||||
post_medias_message.assert_called_once()
|
||||
notification = post_medias_message.call_args.args[0]
|
||||
assert notification.original_message_id == 123
|
||||
assert notification.original_chat_id == "456"
|
||||
|
||||
Reference in New Issue
Block a user