mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
fix(agent): 修复添加订阅时的用户名映射
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type, List
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.db.user_oper import UserOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, MessageChannel
|
||||
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
@@ -101,6 +102,36 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def _resolve_subscribe_username(self) -> Optional[str]:
|
||||
"""优先映射为系统用户名,未绑定时回退当前渠道用户名。"""
|
||||
resolved_username = self._username
|
||||
if not self._channel or not self._user_id:
|
||||
return resolved_username
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return resolved_username
|
||||
|
||||
binding_keys = {
|
||||
MessageChannel.Telegram: ("telegram_userid",),
|
||||
MessageChannel.Discord: ("discord_userid",),
|
||||
MessageChannel.Wechat: ("wechat_userid",),
|
||||
MessageChannel.Slack: ("slack_userid",),
|
||||
MessageChannel.VoceChat: ("vocechat_userid",),
|
||||
MessageChannel.SynologyChat: ("synologychat_userid",),
|
||||
MessageChannel.QQ: ("qq_userid", "qq_openid"),
|
||||
}.get(channel)
|
||||
if not binding_keys:
|
||||
return resolved_username
|
||||
|
||||
mapped_username = await self.run_blocking(
|
||||
"db",
|
||||
UserOper().get_name,
|
||||
**{key: self._user_id for key in binding_keys},
|
||||
)
|
||||
return mapped_username or resolved_username
|
||||
|
||||
async def run(
|
||||
self,
|
||||
title: str,
|
||||
@@ -137,6 +168,7 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
if media_type_enum == MediaType.TV
|
||||
else None
|
||||
)
|
||||
subscribe_username = await self._resolve_subscribe_username()
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
@@ -162,7 +194,7 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
season=season,
|
||||
username=self._user_id,
|
||||
username=subscribe_username,
|
||||
**subscribe_kwargs,
|
||||
)
|
||||
if sid:
|
||||
|
||||
@@ -3,15 +3,24 @@ import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestAgentAddSubscribeTool(unittest.TestCase):
|
||||
def test_tv_subscription_without_season_reports_default_first_season(self):
|
||||
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
username="tg_display_name",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
|
||||
new=AsyncMock(return_value=(1, "")),
|
||||
) as async_add, patch(
|
||||
"app.agent.tools.impl.add_subscribe.UserOper.get_name",
|
||||
return_value="moviepilot-user",
|
||||
):
|
||||
result = asyncio.run(
|
||||
tool.run(
|
||||
@@ -21,9 +30,36 @@ class TestAgentAddSubscribeTool(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(async_add.await_args.kwargs["username"], "moviepilot-user")
|
||||
self.assertIn("第1季", result)
|
||||
self.assertIn("默认按第一季订阅", result)
|
||||
|
||||
def test_subscription_falls_back_to_channel_username_when_no_binding_exists(self):
|
||||
tool = AddSubscribeTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-main",
|
||||
username="tg_display_name",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.add_subscribe.SubscribeChain.async_add",
|
||||
new=AsyncMock(return_value=(1, "")),
|
||||
) as async_add, patch(
|
||||
"app.agent.tools.impl.add_subscribe.UserOper.get_name",
|
||||
return_value=None,
|
||||
):
|
||||
result = asyncio.run(
|
||||
tool.run(
|
||||
title="The Matrix",
|
||||
year="1999",
|
||||
media_type="movie",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(async_add.await_args.kwargs["username"], "tg_display_name")
|
||||
self.assertIn("成功添加订阅:The Matrix (1999)", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user