fix(agent): 修复添加订阅时的用户名映射

This commit is contained in:
jxxghp
2026-05-04 21:27:48 +08:00
parent 367ecafbbb
commit 39f9550f86
2 changed files with 71 additions and 3 deletions

View File

@@ -1,13 +1,14 @@
"""添加订阅工具"""
from typing import Optional, Type, List
from typing import List, Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.subscribe import SubscribeChain
from app.db.user_oper import UserOper
from app.log import logger
from app.schemas.types import MediaType
from app.schemas.types import MediaType, MessageChannel
class AddSubscribeInput(BaseModel):
@@ -101,6 +102,36 @@ class AddSubscribeTool(MoviePilotTool):
return message
async def _resolve_subscribe_username(self) -> Optional[str]:
"""优先映射为系统用户名,未绑定时回退当前渠道用户名。"""
resolved_username = self._username
if not self._channel or not self._user_id:
return resolved_username
try:
channel = MessageChannel(self._channel)
except ValueError:
return resolved_username
binding_keys = {
MessageChannel.Telegram: ("telegram_userid",),
MessageChannel.Discord: ("discord_userid",),
MessageChannel.Wechat: ("wechat_userid",),
MessageChannel.Slack: ("slack_userid",),
MessageChannel.VoceChat: ("vocechat_userid",),
MessageChannel.SynologyChat: ("synologychat_userid",),
MessageChannel.QQ: ("qq_userid", "qq_openid"),
}.get(channel)
if not binding_keys:
return resolved_username
mapped_username = await self.run_blocking(
"db",
UserOper().get_name,
**{key: self._user_id for key in binding_keys},
)
return mapped_username or resolved_username
async def run(
self,
title: str,
@@ -137,6 +168,7 @@ class AddSubscribeTool(MoviePilotTool):
if media_type_enum == MediaType.TV
else None
)
subscribe_username = await self._resolve_subscribe_username()
# 构建额外的订阅参数
subscribe_kwargs = {}
@@ -162,7 +194,7 @@ class AddSubscribeTool(MoviePilotTool):
tmdbid=tmdb_id,
doubanid=douban_id,
season=season,
username=self._user_id,
username=subscribe_username,
**subscribe_kwargs,
)
if sid:

View File

@@ -3,15 +3,24 @@ import unittest
from unittest.mock import AsyncMock, patch
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
from app.schemas.types import MessageChannel
class TestAgentAddSubscribeTool(unittest.TestCase):
def test_tv_subscription_without_season_reports_default_first_season(self):
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
tool.set_message_attr(
channel=MessageChannel.Telegram.value,
source="telegram-main",
username="tg_display_name",
)
with patch(
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
new=AsyncMock(return_value=(1, "")),
) as async_add, patch(
"app.agent.tools.impl.add_subscribe.UserOper.get_name",
return_value="moviepilot-user",
):
result = asyncio.run(
tool.run(
@@ -21,9 +30,36 @@ class TestAgentAddSubscribeTool(unittest.TestCase):
)
)
self.assertEqual(async_add.await_args.kwargs["username"], "moviepilot-user")
self.assertIn("第1季", result)
self.assertIn("默认按第一季订阅", result)
def test_subscription_falls_back_to_channel_username_when_no_binding_exists(self):
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
tool.set_message_attr(
channel=MessageChannel.Telegram.value,
source="telegram-main",
username="tg_display_name",
)
with patch(
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
new=AsyncMock(return_value=(1, "")),
) as async_add, patch(
"app.agent.tools.impl.add_subscribe.UserOper.get_name",
return_value=None,
):
result = asyncio.run(
tool.run(
title="The Matrix",
year="1999",
media_type="movie",
)
)
self.assertEqual(async_add.await_args.kwargs["username"], "tg_display_name")
self.assertIn("成功添加订阅The Matrix (1999)", result)
if __name__ == "__main__":
unittest.main()