From 39f9550f86140366ac3969c167b3d403ed31a51c Mon Sep 17 00:00:00 2001 From: jxxghp Date: Mon, 4 May 2026 21:27:48 +0800 Subject: [PATCH] =?UTF-8?q?fix(agent):=20=E4=BF=AE=E5=A4=8D=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E8=AE=A2=E9=98=85=E6=97=B6=E7=9A=84=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E5=90=8D=E6=98=A0=E5=B0=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/agent/tools/impl/add_subscribe.py | 38 ++++++++++++++++++++++++-- tests/test_agent_add_subscribe_tool.py | 36 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/app/agent/tools/impl/add_subscribe.py b/app/agent/tools/impl/add_subscribe.py index 518ff000..882fbdd7 100644 --- a/app/agent/tools/impl/add_subscribe.py +++ b/app/agent/tools/impl/add_subscribe.py @@ -1,13 +1,14 @@ """添加订阅工具""" -from typing import Optional, Type, List +from typing import List, Optional, Type from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.subscribe import SubscribeChain +from app.db.user_oper import UserOper from app.log import logger -from app.schemas.types import MediaType +from app.schemas.types import MediaType, MessageChannel class AddSubscribeInput(BaseModel): @@ -101,6 +102,36 @@ class AddSubscribeTool(MoviePilotTool): return message + async def _resolve_subscribe_username(self) -> Optional[str]: + """优先映射为系统用户名,未绑定时回退当前渠道用户名。""" + resolved_username = self._username + if not self._channel or not self._user_id: + return resolved_username + + try: + channel = MessageChannel(self._channel) + except ValueError: + return resolved_username + + binding_keys = { + MessageChannel.Telegram: ("telegram_userid",), + MessageChannel.Discord: ("discord_userid",), + MessageChannel.Wechat: ("wechat_userid",), + MessageChannel.Slack: ("slack_userid",), + MessageChannel.VoceChat: ("vocechat_userid",), + MessageChannel.SynologyChat: ("synologychat_userid",), + MessageChannel.QQ: ("qq_userid", "qq_openid"), + }.get(channel) + if not binding_keys: + return resolved_username + + mapped_username = await self.run_blocking( + "db", + UserOper().get_name, + **{key: self._user_id for key in binding_keys}, + ) + return mapped_username or resolved_username + async def run( self, title: str, @@ -137,6 +168,7 @@ class AddSubscribeTool(MoviePilotTool): if media_type_enum == MediaType.TV else None ) + subscribe_username = await self._resolve_subscribe_username() # 构建额外的订阅参数 subscribe_kwargs = {} @@ -162,7 +194,7 @@ class AddSubscribeTool(MoviePilotTool): tmdbid=tmdb_id, doubanid=douban_id, season=season, - username=self._user_id, + username=subscribe_username, **subscribe_kwargs, ) if sid: diff --git a/tests/test_agent_add_subscribe_tool.py b/tests/test_agent_add_subscribe_tool.py index f1ff82da..983c836d 100644 --- a/tests/test_agent_add_subscribe_tool.py +++ b/tests/test_agent_add_subscribe_tool.py @@ -3,15 +3,24 @@ import unittest from unittest.mock import AsyncMock, patch from app.agent.tools.impl.add_subscribe import AddSubscribeTool +from app.schemas.types import MessageChannel class TestAgentAddSubscribeTool(unittest.TestCase): def test_tv_subscription_without_season_reports_default_first_season(self): tool = AddSubscribeTool(session_id="session-1", user_id="10001") + tool.set_message_attr( + channel=MessageChannel.Telegram.value, + source="telegram-main", + username="tg_display_name", + ) with patch( "app.agent.tools.impl.add_subscribe.SubscribeChain.async_add", new=AsyncMock(return_value=(1, "")), + ) as async_add, patch( + "app.agent.tools.impl.add_subscribe.UserOper.get_name", + return_value="moviepilot-user", ): result = asyncio.run( tool.run( @@ -21,9 +30,36 @@ class TestAgentAddSubscribeTool(unittest.TestCase): ) ) + self.assertEqual(async_add.await_args.kwargs["username"], "moviepilot-user") self.assertIn("第1季", result) self.assertIn("默认按第一季订阅", result) + def test_subscription_falls_back_to_channel_username_when_no_binding_exists(self): + tool = AddSubscribeTool(session_id="session-1", user_id="10001") + tool.set_message_attr( + channel=MessageChannel.Telegram.value, + source="telegram-main", + username="tg_display_name", + ) + + with patch( + "app.agent.tools.impl.add_subscribe.SubscribeChain.async_add", + new=AsyncMock(return_value=(1, "")), + ) as async_add, patch( + "app.agent.tools.impl.add_subscribe.UserOper.get_name", + return_value=None, + ): + result = asyncio.run( + tool.run( + title="The Matrix", + year="1999", + media_type="movie", + ) + ) + + self.assertEqual(async_add.await_args.kwargs["username"], "tg_display_name") + self.assertIn("成功添加订阅:The Matrix (1999)", result) + if __name__ == "__main__": unittest.main()