Compare commits

...

8 Commits

22 changed files with 676 additions and 137 deletions

View File

@@ -1604,6 +1604,7 @@ class MoviePilotAgent:
original_chat_id=self.original_chat_id,
title=title,
text=message,
save_history=False,
)
)

View File

@@ -536,6 +536,7 @@ class StreamingHandler:
original_chat_id=self._original_chat_id,
title=self._title,
text=current_text,
save_history=False,
),
)
if response and response.success and response.message_id:
@@ -581,6 +582,7 @@ class StreamingHandler:
original_chat_id=self._original_chat_id,
title=self._title,
text=current_text,
save_history=False,
),
)
if response and response.success and response.message_id:

View File

@@ -649,5 +649,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
title=title,
text=message,
image=image,
save_history=False,
)
)

View File

@@ -105,6 +105,7 @@ class SendLocalFileTool(MoviePilotTool):
text=message,
file_path=str(resolved_path),
file_name=file_name or resolved_path.name,
save_history=False,
)
)
return "本地附件已发送"

View File

@@ -100,6 +100,7 @@ class SendMessageTool(MoviePilotTool):
title=title,
text=text,
image=image_url,
save_history=False,
)
)
self._agent_context["user_reply_sent"] = True

View File

@@ -96,6 +96,7 @@ class SendVoiceMessageTool(MoviePilotTool):
if voice_path and settings.AUDIO_OUTPUT_INCLUDE_TEXT
else None
),
save_history=False,
)
)
self._agent_context["user_reply_sent"] = True

View File

@@ -10,6 +10,7 @@ from app.agent.tools.tags import ToolTag
from app.core.event import eventmanager
from app.db.subscribe_oper import SubscribeOper
from app.log import logger
from app.schemas.event import SubscribeModifiedEventData
from app.schemas.types import EventType
@@ -261,13 +262,14 @@ class UpdateSubscribeTool(MoviePilotTool):
# 发送订阅调整事件
await eventmanager.async_send_event(
EventType.SubscribeModified,
{
"subscribe_id": subscribe_id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": updated_subscribe.to_dict()
SubscribeModifiedEventData(
subscribe_id=subscribe_id,
old_subscribe_info=old_subscribe_dict,
subscribe_info=updated_subscribe.to_dict()
if updated_subscribe
else {},
},
scene="agent_update",
).to_dict(),
)
# 构建返回结果

View File

@@ -21,6 +21,7 @@ from app.db.user_oper import get_current_active_user_async
from app.helper.server import MoviePilotServerHelper
from app.log import logger
from app.scheduler import Scheduler
from app.schemas.event import SubscribeModifiedEventData
from app.schemas.types import MediaType, EventType, SystemConfigKey
router = APIRouter()
@@ -149,11 +150,12 @@ async def update_subscribe(
# 发送订阅调整事件
await eventmanager.async_send_event(
EventType.SubscribeModified,
{
"subscribe_id": subscribe_in.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
},
SubscribeModifiedEventData(
subscribe_id=subscribe_in.id,
old_subscribe_info=old_subscribe_dict,
subscribe_info=updated_subscribe.to_dict() if updated_subscribe else {},
scene="update",
).to_dict(),
)
return schemas.Response(success=True)
@@ -181,11 +183,12 @@ async def update_subscribe_status(
# 发送订阅调整事件
await eventmanager.async_send_event(
EventType.SubscribeModified,
{
"subscribe_id": subid,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
},
SubscribeModifiedEventData(
subscribe_id=subid,
old_subscribe_info=old_subscribe_dict,
subscribe_info=updated_subscribe.to_dict() if updated_subscribe else {},
scene="status",
).to_dict(),
)
return schemas.Response(success=True)
@@ -275,13 +278,14 @@ async def reset_subscribes(
# 发送订阅调整事件
await eventmanager.async_send_event(
EventType.SubscribeModified,
{
"subscribe_id": subid,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": updated_subscribe.to_dict()
SubscribeModifiedEventData(
subscribe_id=subid,
old_subscribe_info=old_subscribe_dict,
subscribe_info=updated_subscribe.to_dict()
if updated_subscribe
else {},
},
scene="reset",
).to_dict(),
)
return schemas.Response(success=True)
return schemas.Response(success=False, message="订阅不存在")

View File

@@ -2635,6 +2635,8 @@ class MediaInteractionChain(ChainBase):
download_dirs = self._get_download_dirs(media_info)
if not download_dirs:
return False
if len(download_dirs) == 1 and not self._is_auto_download_dir(download_dirs[0]):
return False
request.pending_torrent_page = request.page
request.phase = "download-dir"
@@ -3252,6 +3254,11 @@ class MediaInteractionChain(ChainBase):
"""
获取可供消息交互选择的下载目录。
"""
dir_infos = [
dir_info
for dir_info in DirectoryHelper().get_download_dirs()
if dir_info.download_path
]
download_dirs = [
DownloadDirectory(
name=dir_info.name,
@@ -3265,11 +3272,13 @@ class MediaInteractionChain(ChainBase):
media_type=dir_info.media_type,
media_category=dir_info.media_category,
)
for dir_info in DirectoryHelper().get_download_dirs()
if dir_info.download_path and cls._match_download_dir_media(dir_info, media_info)
for dir_info in dir_infos
if cls._match_download_dir_media(dir_info, media_info)
]
if not download_dirs:
return []
if len(download_dirs) == 1:
return download_dirs
return [cls._build_auto_download_dir(), *download_dirs]
@classmethod

View File

@@ -40,6 +40,7 @@ from app.schemas import (
TransferQueue,
TransferJob,
TransferJobTask,
TmdbEpisode,
)
from app.schemas.exception import OperationInterrupted
from app.schemas.types import (
@@ -947,6 +948,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
mediainfo=task.mediainfo,
transferinfo=transferinfo,
season_episode=se_str,
episodes_info=task.episodes_info,
username=task.username,
)
@@ -3395,10 +3397,17 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
mediainfo: MediaInfo,
transferinfo: TransferInfo,
season_episode: Optional[str] = None,
episodes_info: Optional[List[TmdbEpisode]] = None,
username: Optional[str] = None,
):
"""
发送入库成功的消息
:param meta: 文件元数据
:param mediainfo: 识别的媒体信息
:param transferinfo: 文件整理信息
:param season_episode: 已入库季集文本
:param episodes_info: 当前季的全部集信息
:param username: 用户名
"""
self.post_message(
Notification(
@@ -3412,6 +3421,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
mediainfo=mediainfo,
transferinfo=transferinfo,
season_episode=season_episode,
episodes_info=episodes_info,
username=username,
)

View File

@@ -180,6 +180,8 @@ class TemplateContextBuilder:
"season_fmt": meta.season,
# 集号
"episode": meta.episode_seqs,
# 当前季总集数
"total_episodes": len(episodes) if episodes else 0,
# 季集 SxxExx
"season_episode": "%s%s" % (meta.season, meta.episode),
# 段/节

View File

@@ -603,6 +603,48 @@ class SubscribeEpisodesRefreshEventData(ChainEventData):
reason: str = Field(default="", description="覆盖原因")
class SubscribeModifiedEventData(BaseEventData):
"""
SubscribeModified 广播事件数据。
主程序在订阅字段被普通更新、状态入口、重置或 Agent 更新后发出。payload
继续保持 dict 形态scene 用于表达操作场景fields 表达最终快照里的真实字段差异。
"""
subscribe_id: int = Field(description="订阅 ID")
old_subscribe_info: Dict[str, Any] = Field(default_factory=dict, description="更新前订阅快照")
subscribe_info: Dict[str, Any] = Field(default_factory=dict, description="更新后订阅快照")
scene: str = Field(default="update", description="触发场景update/status/reset/agent_update")
fields: List[str] = Field(default_factory=list, description="真实变更字段")
@model_validator(mode="after")
def compute_fields(self):
self.fields = self._diff_fields(self.old_subscribe_info, self.subscribe_info)
return self
@staticmethod
def _diff_fields(old_info: Dict[str, Any], new_info: Dict[str, Any]) -> List[str]:
"""
按 old/new 快照并集计算真实字段差异;缺失 key 按 None 参与比较。
"""
old_info = old_info or {}
new_info = new_info or {}
keys = set(old_info) | set(new_info)
return sorted(key for key in keys if old_info.get(key) != new_info.get(key))
def to_dict(self) -> Dict[str, Any]:
"""
输出公开事件 payload避免内部属性被未来扩展意外暴露。
"""
return {
"subscribe_id": self.subscribe_id,
"old_subscribe_info": self.old_subscribe_info,
"subscribe_info": self.subscribe_info,
"scene": self.scene,
"fields": list(self.fields),
}
class SubscribeCompletionCheckEventData(ChainEventData):
"""
SubscribeCompletionCheck 事件的数据模型

View File

@@ -1,4 +1,4 @@
moviepilot-rust~=0.1.13
moviepilot-rust~=0.1.14
pydantic>=2.13.4,<3.0.0
pydantic-settings>=2.14.1,<3.0.0
SQLAlchemy~=2.0.50

View File

@@ -1,17 +1,19 @@
import asyncio
from unittest.mock import AsyncMock, Mock, patch
from app.agent import MoviePilotAgent
from app.agent.tools.impl.ask_user_choice import (
AskUserChoiceTool,
UserChoiceOptionInput,
)
from app.agent.tools.impl.send_message import SendMessageTool
from app.chain.message import MessageChain
from app.core.config import settings
from app.db import SessionFactory
from app.db.message_oper import MessageOper
from app.db.models.message import Message
from app.helper.interaction import AgentInteractionOption, agent_interaction_manager, media_interaction_manager
from app.schemas.types import MessageChannel
from app.schemas.types import MessageChannel, NotificationType
def _clear_messages() -> None:
@@ -120,6 +122,49 @@ def test_ask_user_choice_message_is_not_recorded_to_message_history():
async_send_message.assert_awaited_once()
def test_agent_final_reply_disables_notification_history():
"""Agent 最终回复发往渠道时不保存通知历史。"""
agent = MoviePilotAgent(
session_id="session-agent-reply",
user_id="10001",
channel=MessageChannel.Telegram.value,
source="telegram-test",
username="tester",
)
with patch(
"app.agent.AgentChain.async_post_message",
new_callable=AsyncMock,
) as async_post_message:
asyncio.run(agent.send_agent_message("已完成处理"))
notification = async_post_message.await_args.args[0]
assert notification.mtype == NotificationType.Agent
assert notification.save_history is False
def test_send_message_tool_disables_notification_history():
"""Agent 主动发消息工具发送的通知不保存通知历史。"""
tool = SendMessageTool(session_id="session-send-message", user_id="10001")
tool.set_message_attr(
channel=MessageChannel.Telegram.value,
source="telegram-test",
username="tester",
)
tool.set_agent_context(agent_context={})
with patch(
"app.agent.tools.base.ToolChain.async_post_message",
new_callable=AsyncMock,
) as async_post_message:
result = asyncio.run(tool.run(message="处理结果", title="MoviePilot助手"))
notification = async_post_message.await_args.args[0]
assert result == "消息已发送"
assert notification.text == "处理结果"
assert notification.save_history is False
def test_agent_choice_callback_is_not_recorded_to_message_history():
"""Agent 按钮选择回传不登记到数据库或实时消息队列。"""
chain = MessageChain()

View File

@@ -1,6 +1,4 @@
import importlib.util
from types import SimpleNamespace
from pathlib import Path
from typing import Iterator, Optional, Type
import pytest
@@ -69,30 +67,8 @@ def _schema_properties(args_schema: Type[BaseModel]) -> dict:
return args_schema.model_json_schema().get("properties", {})
def _load_lexiannot_tool_schemas() -> list[Type[BaseModel]]:
"""只加载 LexiAnnot schema 文件,避免触发插件包可选依赖。"""
schema_path = (
Path(__file__).resolve().parents[1]
/ "app"
/ "plugins"
/ "lexiannot"
/ "schemas.py"
)
spec = importlib.util.spec_from_file_location(
"_test_lexiannot_schemas",
schema_path,
)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return [
module.VocabularyAnnotatingToolInput,
module.QueryAnnotationTasksToolInput,
]
def test_agent_tool_schemas_do_not_expose_explanation_parameter() -> None:
"""所有 Agent 工具输入模型不应暴露 explanation 参数。"""
"""仓库内置 Agent 工具和中间件输入模型不应暴露 explanation 参数。"""
tool_classes = [
*MoviePilotToolFactory.BUILTIN_TOOL_CLASSES,
AskUserChoiceTool,
@@ -103,7 +79,6 @@ def test_agent_tool_schemas_do_not_expose_explanation_parameter() -> None:
SkillToolInput,
QueryActivityLogInput,
]
plugin_schemas = _load_lexiannot_tool_schemas()
for tool_class in tool_classes:
args_schema = getattr(tool_class, "args_schema", None)
@@ -111,7 +86,7 @@ def test_agent_tool_schemas_do_not_expose_explanation_parameter() -> None:
continue
assert "explanation" not in _schema_properties(args_schema), tool_class.name
for args_schema in middleware_schemas + plugin_schemas:
for args_schema in middleware_schemas:
assert "explanation" not in _schema_properties(args_schema), args_schema.__name__

View File

@@ -0,0 +1,74 @@
import asyncio
import json
from unittest.mock import AsyncMock, patch
from app.agent.tools.impl.update_subscribe import UpdateSubscribeTool
from app.schemas.types import EventType
def test_agent_update_subscribe_sends_modified_event_payload_with_agent_scene():
"""
Agent 更新订阅后只发送 modify 事件,并标记 agent_update 场景。
"""
subscribe = _AgentSubscribe(id=9, name="旧标题", state="R", total_episode=8)
oper = _SubscribeOperStub(subscribe)
with patch(
"app.agent.tools.impl.update_subscribe.SubscribeOper",
return_value=oper,
), patch(
"app.agent.tools.impl.update_subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
result = asyncio.run(
UpdateSubscribeTool(session_id="session-1", user_id="10001").run(
subscribe_id=9,
name="新标题",
state="S",
)
)
payload = json.loads(result)
assert payload["success"] is True
assert oper.updates == [(9, {"name": "新标题", "state": "S"})]
send_event.assert_awaited_once()
event_type, event_payload = send_event.await_args.args
assert event_type == EventType.SubscribeModified
assert event_payload["subscribe_id"] == 9
assert event_payload["scene"] == "agent_update"
assert event_payload["fields"] == ["name", "state"]
assert event_payload["old_subscribe_info"]["name"] == "旧标题"
assert event_payload["subscribe_info"]["name"] == "新标题"
class _AgentSubscribe:
"""
最小订阅替身,模拟 Agent 工具依赖的订阅对象接口。
"""
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __getattr__(self, item):
return None
def to_dict(self):
return dict(self.__dict__)
class _SubscribeOperStub:
"""
内存订阅操作替身,记录工具最终提交的更新字段。
"""
def __init__(self, subscribe):
self.subscribe = subscribe
self.updates = []
async def async_get(self, subscribe_id):
return self.subscribe if subscribe_id == self.subscribe.id else None
async def async_update(self, subscribe_id, payload):
self.updates.append((subscribe_id, dict(payload)))
self.subscribe.__dict__.update(payload)
return self.subscribe

View File

@@ -65,7 +65,7 @@ def _build_tv_context(title: str = "葬送的芙莉莲") -> Context:
def _build_download_dirs() -> list[TransferDirectoryConf]:
"""构造消息交互可选择的下载目录配置。"""
"""构造不同媒体类型各一个下载目录配置。"""
return [
TransferDirectoryConf(
name="电影下载",
@@ -85,6 +85,46 @@ def _build_download_dirs() -> list[TransferDirectoryConf]:
]
def _build_multiple_movie_download_dirs() -> list[TransferDirectoryConf]:
"""构造多个匹配电影类型的下载目录配置。"""
return [
TransferDirectoryConf(
name="电影下载",
storage="local",
download_path="/downloads/movies",
priority=1,
media_type=MediaType.MOVIE.value,
),
TransferDirectoryConf(
name="4K电影下载",
storage="local",
download_path="/downloads/uhd-movies",
priority=2,
media_type=MediaType.MOVIE.value,
),
TransferDirectoryConf(
name="动画下载",
storage="rclone",
download_path="/media/anime",
priority=3,
media_type=MediaType.TV.value,
media_category="动漫",
),
]
def _build_single_download_dir() -> list[TransferDirectoryConf]:
"""构造只有一个下载目录的配置。"""
return [
TransferDirectoryConf(
name="默认下载",
storage="local",
download_path="/downloads",
priority=1,
),
]
def test_message_routes_text_reply_to_media_interaction_before_ai():
"""已有传统媒体交互时,用户回复应优先交给传统交互处理。"""
chain = MessageChain()
@@ -295,7 +335,7 @@ def test_media_interaction_legacy_page_callback_updates_existing_request():
def test_torrent_selection_prompts_download_dir_buttons_before_download():
"""支持按钮的渠道选择资源后,应先发送下载目录按钮而不是立即下载。"""
"""匹配当前媒体的目录有多个时,应先发送下载目录按钮而不是立即下载。"""
chain = MediaInteractionChain()
context = _build_context()
request = media_interaction_manager.create_or_replace(
@@ -313,7 +353,7 @@ def test_torrent_selection_prompts_download_dir_buttons_before_download():
with patch(
"app.chain.message.DirectoryHelper.get_download_dirs",
return_value=_build_download_dirs(),
return_value=_build_multiple_movie_download_dirs(),
), patch.object(chain, "post_message") as post_message, patch(
"app.chain.message.DownloadChain.download_single"
) as download_single:
@@ -334,12 +374,93 @@ def test_torrent_selection_prompts_download_dir_buttons_before_download():
assert "请选择下载目录" in notification.title
assert "1. 自动匹配目录" in notification.text
assert "2. 电影下载 (/downloads/movies)" in notification.text
assert "3. 4K电影下载 (/downloads/uhd-movies)" in notification.text
assert "动画下载" not in notification.text
assert notification.buttons[0][0]["callback_data"] == f"media:{request.request_id}:download-dir:1"
def test_torrent_selection_skips_download_dir_when_only_one_dir_matches_media():
"""匹配当前媒体的目录只有一个时,应跳过目录选择并交给下载链自动匹配。"""
chain = MediaInteractionChain()
context = _build_context()
request = media_interaction_manager.create_or_replace(
user_id="10001",
channel=MessageChannel.Telegram,
source="telegram-test",
username="tester",
action="Search",
keyword="星际穿越",
title="星际穿越",
meta=_build_meta("星际穿越"),
items=[context],
)
request.phase = "torrent"
with patch(
"app.chain.message.DirectoryHelper.get_download_dirs",
return_value=_build_download_dirs(),
), patch.object(chain, "post_message") as post_message, patch(
"app.chain.message.DownloadChain.download_single",
return_value="hash",
) as download_single:
handled = chain.handle_text_interaction(
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
text="1",
)
assert handled
assert request.phase == "torrent"
post_message.assert_not_called()
download_single.assert_called_once()
assert download_single.call_args.args[0] is context
assert "save_path" not in download_single.call_args.kwargs
def test_torrent_selection_skips_download_dir_when_user_has_single_dir():
"""用户只有一个下载目录时,也应跳过目录选择并交给下载链自动匹配。"""
chain = MediaInteractionChain()
context = _build_context()
request = media_interaction_manager.create_or_replace(
user_id="10001",
channel=MessageChannel.Telegram,
source="telegram-test",
username="tester",
action="Search",
keyword="星际穿越",
title="星际穿越",
meta=_build_meta("星际穿越"),
items=[context],
)
request.phase = "torrent"
with patch(
"app.chain.message.DirectoryHelper.get_download_dirs",
return_value=_build_single_download_dir(),
), patch.object(chain, "post_message") as post_message, patch(
"app.chain.message.DownloadChain.download_single",
return_value="hash",
) as download_single:
handled = chain.handle_text_interaction(
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
text="1",
)
assert handled
assert request.phase == "torrent"
post_message.assert_not_called()
download_single.assert_called_once()
assert download_single.call_args.args[0] is context
assert "save_path" not in download_single.call_args.kwargs
def test_torrent_selection_prompts_text_download_dir_for_plain_channel():
"""不支持按钮的渠道选择资源后,应提示用户回复数字选择下载目录。"""
"""不支持按钮的渠道在多个匹配目录时,应提示用户回复数字选择下载目录。"""
chain = MediaInteractionChain()
context = _build_context()
request = media_interaction_manager.create_or_replace(
@@ -357,7 +478,7 @@ def test_torrent_selection_prompts_text_download_dir_for_plain_channel():
with patch(
"app.chain.message.DirectoryHelper.get_download_dirs",
return_value=_build_download_dirs(),
return_value=_build_multiple_movie_download_dirs(),
), patch.object(chain, "post_message") as post_message:
handled = chain.handle_text_interaction(
channel=MessageChannel.Wechat,
@@ -374,6 +495,7 @@ def test_torrent_selection_prompts_text_download_dir_for_plain_channel():
assert notification.buttons is None
assert "1. 自动匹配目录" in notification.text
assert "2. 电影下载 (/downloads/movies)" in notification.text
assert "3. 4K电影下载 (/downloads/uhd-movies)" in notification.text
assert "动画下载" not in notification.text
@@ -398,7 +520,7 @@ def test_download_dir_callback_runs_pending_single_download_without_save_path_fo
with patch(
"app.chain.message.DirectoryHelper.get_download_dirs",
return_value=_build_download_dirs(),
return_value=_build_multiple_movie_download_dirs(),
), patch(
"app.chain.message.DownloadChain.download_single",
return_value="hash",
@@ -440,7 +562,7 @@ def test_download_dir_callback_runs_pending_single_download_with_save_path():
with patch(
"app.chain.message.DirectoryHelper.get_download_dirs",
return_value=_build_download_dirs(),
return_value=_build_multiple_movie_download_dirs(),
), patch(
"app.chain.message.DownloadChain.download_single",
return_value="hash",
@@ -482,7 +604,7 @@ def test_download_dir_text_reply_runs_pending_single_download_without_save_path(
with patch(
"app.chain.message.DirectoryHelper.get_download_dirs",
return_value=_build_download_dirs(),
return_value=_build_multiple_movie_download_dirs(),
), patch(
"app.chain.message.DownloadChain.download_single",
return_value="hash",
@@ -515,7 +637,6 @@ def test_get_download_dirs_keeps_matching_tv_category_dir():
download_dirs = chain._get_download_dirs(context.media_info)
assert [download_dir.name for download_dir in download_dirs] == [
"自动匹配目录",
"动画下载",
]
assert download_dirs[1].save_path == "rclone:/media/anime"
assert download_dirs[0].save_path == "rclone:/media/anime"

View File

@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, patch
from app.api.endpoints.subscribe import create_subscribe
from app.schemas.subscribe import Subscribe
from app.schemas.types import MediaType
from app.schemas.types import EventType, MediaType
class SubscribeEndpointTest(TestCase):
@@ -73,3 +73,135 @@ class SubscribeEndpointTest(TestCase):
self.assertTrue(response.success)
self.assertEqual(async_add.await_args.kwargs["season"], 0)
def test_update_status_sends_modified_event_payload_with_scene_and_fields(self):
"""
状态更新只负责发出订阅修改事件,并携带场景和真实变更字段。
"""
from app.api.endpoints.subscribe import update_subscribe_status
subscribe = _EndpointSubscribe(id=5, state="R", name="测试订阅")
with patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(side_effect=[subscribe, subscribe]),
), patch(
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(update_subscribe_status(subid=5, state="S", db=object()))
self.assertTrue(response.success)
send_event.assert_awaited_once()
event_type, payload = send_event.await_args.args
self.assertEqual(event_type, EventType.SubscribeModified)
self.assertEqual(payload["subscribe_id"], 5)
self.assertEqual(payload["scene"], "status")
self.assertEqual(payload["fields"], ["state"])
self.assertEqual(payload["old_subscribe_info"]["state"], "R")
self.assertEqual(payload["subscribe_info"]["state"], "S")
def test_reset_sends_modified_event_payload_with_reset_scene(self):
"""
reset 事件需要明确 scene消费者不需要再从字段差异猜测用户意图。
"""
from app.api.endpoints.subscribe import reset_subscribes
subscribe = _EndpointSubscribe(
id=6,
state="S",
name="测试订阅",
total_episode=10,
lack_episode=3,
note=[1, 2],
current_priority=80,
episode_priority={"1": 80},
)
with patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(side_effect=[subscribe, subscribe]),
), patch(
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(reset_subscribes(subid=6, db=object()))
self.assertTrue(response.success)
send_event.assert_awaited_once()
event_type, payload = send_event.await_args.args
self.assertEqual(event_type, EventType.SubscribeModified)
self.assertEqual(payload["subscribe_id"], 6)
self.assertEqual(payload["scene"], "reset")
self.assertEqual(
payload["fields"],
["current_priority", "episode_priority", "lack_episode", "note", "state"],
)
self.assertEqual(payload["subscribe_info"]["note"], [])
self.assertEqual(payload["subscribe_info"]["lack_episode"], 10)
def test_update_subscribe_sends_modified_event_payload_without_progress_refresh(self):
"""
普通更新只发送 modify 事件;进度刷新由事件消费者或后续流程处理。
"""
from app.api.endpoints.subscribe import update_subscribe
subscribe = _EndpointSubscribe(
id=7,
name="旧标题",
total_episode=8,
lack_episode=2,
vote=0.0,
sites=[],
search_imdbid=0,
filter_groups=[],
start_episode=0,
)
subscribe_in = Subscribe(id=7, name="新标题", total_episode=8, lack_episode=2)
with patch(
"app.api.endpoints.subscribe.Subscribe.async_get",
new=AsyncMock(side_effect=[subscribe, subscribe]),
), patch(
"app.api.endpoints.subscribe.eventmanager.async_send_event",
new=AsyncMock(),
) as send_event:
response = asyncio.run(update_subscribe(subscribe_in=subscribe_in, db=object()))
self.assertTrue(response.success)
send_event.assert_awaited_once()
event_type, payload = send_event.await_args.args
self.assertEqual(event_type, EventType.SubscribeModified)
self.assertEqual(payload["subscribe_id"], 7)
self.assertEqual(payload["scene"], "update")
self.assertEqual(payload["fields"], ["name"])
self.assertEqual(payload["old_subscribe_info"]["name"], "旧标题")
self.assertEqual(payload["subscribe_info"]["name"], "新标题")
class _EndpointSubscribe:
"""
最小订阅替身,模拟 endpoint 依赖的 ORM 对象接口。
"""
def __init__(self, **kwargs):
self.id = kwargs.pop("id", None)
self.name = kwargs.pop("name", None)
self.total_episode = kwargs.pop("total_episode", None)
self.lack_episode = kwargs.pop("lack_episode", None)
self.state = kwargs.pop("state", None)
self.note = kwargs.pop("note", None)
self.current_priority = kwargs.pop("current_priority", None)
self.episode_priority = kwargs.pop("episode_priority", None)
self.manual_total_episode = kwargs.pop("manual_total_episode", None)
self.__dict__.update(kwargs)
def to_dict(self):
return {
key: value
for key, value in self.__dict__.items()
if value is not None
}
async def async_update(self, _db, payload):
self.__dict__.update(payload)

View File

@@ -0,0 +1,50 @@
from app.schemas.event import SubscribeModifiedEventData
def test_subscribe_modified_event_data_computes_sorted_fields():
data = SubscribeModifiedEventData(
subscribe_id=7,
old_subscribe_info={"state": "R", "lack_episode": 3, "name": "A"},
subscribe_info={"state": "S", "lack_episode": 3, "name": "B"},
scene="status",
)
assert data.fields == ["name", "state"]
assert data.to_dict() == {
"subscribe_id": 7,
"old_subscribe_info": {"state": "R", "lack_episode": 3, "name": "A"},
"subscribe_info": {"state": "S", "lack_episode": 3, "name": "B"},
"scene": "status",
"fields": ["name", "state"],
}
def test_subscribe_modified_event_data_diffs_missing_keys_as_none():
data = SubscribeModifiedEventData(
subscribe_id=8,
old_subscribe_info={"state": "R", "episode_priority": {"1": 80}},
subscribe_info={"state": "R"},
scene="reset",
)
assert data.fields == ["episode_priority"]
assert set(data.to_dict()) == {
"subscribe_id",
"old_subscribe_info",
"subscribe_info",
"scene",
"fields",
}
def test_subscribe_modified_event_data_ignores_caller_supplied_fields():
data = SubscribeModifiedEventData(
subscribe_id=9,
old_subscribe_info={"state": "R"},
subscribe_info={"state": "S"},
scene="update",
fields=["fake"],
)
assert data.fields == ["state"]
assert data.to_dict()["fields"] == ["state"]

View File

@@ -9,93 +9,112 @@ TemplateContextBuilder 的并发安全单元测试。
线程下连续调用 ``build()``,校验每个线程拿到的字典只反映自己的入参。
"""
import threading
import unittest
from app.helper.message import TemplateContextBuilder
from app.schemas.tmdb import TmdbEpisode
class TemplateContextBuilderConcurrencyTest(unittest.TestCase):
THREAD_COUNT = 8
ITERATIONS_PER_THREAD = 200
def _build_fake_meta():
"""
构造模板上下文测试所需的最小元数据对象。
"""
meta = type("FakeMeta", (), {})()
meta.begin_episode = None
meta.title = "Movie.2024.1080p.x265.10bit.mkv"
meta.name = "Movie"
meta.en_name = "Movie"
meta.year = "2024"
meta.season_seq = ""
meta.season = ""
meta.episode_seqs = ""
meta.episode = ""
meta.part = None
meta.customization = None
meta.fps = None
meta.resource_type = None
meta.resource_effect = None
meta.edition = ""
meta.resource_pix = "1080p"
meta.resource_term = "1080p"
meta.resource_team = None
meta.video_encode = "x265 10bit"
meta.video_bit = "10bit"
meta.audio_encode = "AAC"
meta.web_source = None
return meta
def test_concurrent_build_no_cross_contamination() -> None:
"""
使用 8 个线程并发调用同一 TemplateContextBuilder 实例的 build()
确保各自的 file_extension / 自定义 kwargs 不会被其它线程覆盖。
"""
builder = TemplateContextBuilder()
errors = []
THREAD_COUNT = 8
ITERATIONS_PER_THREAD = 200
def worker(tag: int) -> None:
try:
for _ in range(ITERATIONS_PER_THREAD):
ctx = builder.build(
file_extension=f".{tag}",
marker=tag,
)
assert ctx.get("fileExt") == f".{tag}"
assert ctx.get("marker") == tag
except AssertionError as exc:
errors.append(exc)
def test_concurrent_build_no_cross_contamination(self):
builder = TemplateContextBuilder()
errors = []
threads = [
threading.Thread(target=worker, args=(i,), name=f"builder-{i}")
for i in range(THREAD_COUNT)
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
def worker(tag: int) -> None:
try:
for _ in range(self.ITERATIONS_PER_THREAD):
ctx = builder.build(
file_extension=f".{tag}",
marker=tag,
)
self.assertEqual(ctx.get("fileExt"), f".{tag}")
self.assertEqual(ctx.get("marker"), tag)
except AssertionError as exc:
errors.append(exc)
assert not errors, f"检测到并发串味,共 {len(errors)} 条;首个错误:{errors[0] if errors else ''}"
threads = [
threading.Thread(target=worker, args=(i,), name=f"builder-{i}")
for i in range(self.THREAD_COUNT)
]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertFalse(
errors,
msg=f"检测到并发串味,共 {len(errors)} 条;首个错误:{errors[0] if errors else ''}",
)
def test_build_returns_independent_dicts() -> None:
"""
连续两次 build() 应返回相互独立的 dict 实例,避免调用方误用共享结果。
"""
builder = TemplateContextBuilder()
first = builder.build(file_extension=".a", marker=1)
second = builder.build(file_extension=".b", marker=2)
def test_build_returns_independent_dicts(self):
"""
即便不开线程,连续两次 build() 也应当返回相互独立的 dict 实例,
避免无状态化后调用方误以为返回的还是 builder 内部共享对象。
"""
builder = TemplateContextBuilder()
first = builder.build(file_extension=".a", marker=1)
second = builder.build(file_extension=".b", marker=2)
self.assertIsNot(first, second)
self.assertEqual(first.get("fileExt"), ".a")
self.assertEqual(second.get("fileExt"), ".b")
# 第二次调用不应反向污染第一次的结果
self.assertEqual(first.get("marker"), 1)
assert first is not second
assert first.get("fileExt") == ".a"
assert second.get("fileExt") == ".b"
assert first.get("marker") == 1
def test_build_exposes_video_bit_from_meta(self):
"""
模板上下文应提供独立 videoBit 字段,避免用户只能从 videoCodec 中手工拆位深。
"""
meta = type("FakeMeta", (), {})()
meta.begin_episode = None
meta.title = "Movie.2024.1080p.x265.10bit.mkv"
meta.name = "Movie"
meta.en_name = "Movie"
meta.year = "2024"
meta.season_seq = ""
meta.season = ""
meta.episode_seqs = ""
meta.episode = ""
meta.part = None
meta.customization = None
meta.fps = None
meta.resource_type = None
meta.resource_effect = None
meta.edition = ""
meta.resource_pix = "1080p"
meta.resource_term = "1080p"
meta.resource_team = None
meta.video_encode = "x265 10bit"
meta.video_bit = "10bit"
meta.audio_encode = "AAC"
meta.web_source = None
context = TemplateContextBuilder().build(meta=meta)
def test_build_exposes_video_bit_from_meta() -> None:
"""
模板上下文应提供独立 videoBit 字段,避免用户只能从 videoCodec 中手工拆位深。
"""
context = TemplateContextBuilder().build(meta=_build_fake_meta())
self.assertEqual(context.get("videoCodec"), "x265 10bit")
self.assertEqual(context.get("videoBit"), "10bit")
assert context.get("videoCodec") == "x265 10bit"
assert context.get("videoBit") == "10bit"
def test_build_exposes_total_episodes_from_current_season() -> None:
"""
模板上下文应提供当前季总集数,供入库通知模板直接引用。
"""
context = TemplateContextBuilder().build(
meta=_build_fake_meta(),
episodes_info=[
TmdbEpisode(episode_number=1, name="第一集"),
TmdbEpisode(episode_number=2, name="第二集"),
TmdbEpisode(episode_number=3, name="第三集"),
],
)
assert context.get("total_episodes") == 3

View File

@@ -0,0 +1,47 @@
from unittest.mock import patch
from app.chain.transfer import TransferChain
from app.core.context import MediaInfo
from app.core.meta.metabase import MetaBase
from app.schemas import TransferInfo
from app.schemas.tmdb import TmdbEpisode
from app.schemas.types import ContentType, MediaType, NotificationType
def test_send_transfer_message_passes_episode_info_to_template_context() -> None:
"""
入库成功通知应把当前季集信息传给消息模板,确保 total_episodes 可渲染。
"""
chain = TransferChain()
meta = MetaBase("Test.Show.S01E01.mkv")
meta.type = MediaType.TV
meta.name = "Test Show"
meta.begin_season = 1
meta.begin_episode = 1
episodes_info = [
TmdbEpisode(episode_number=1, name="第一集"),
TmdbEpisode(episode_number=2, name="第二集"),
]
mediainfo = MediaInfo(
type=MediaType.TV,
title="Test Show",
season=1,
tmdb_id=12345,
)
transferinfo = TransferInfo(success=True)
with patch.object(chain, "post_message") as post_message:
chain.send_transfer_message(
meta=meta,
mediainfo=mediainfo,
transferinfo=transferinfo,
season_episode="S01 E01",
episodes_info=episodes_info,
username="tester",
)
message = post_message.call_args.args[0]
assert message.mtype == NotificationType.Organize
assert message.ctype == ContentType.OrganizeSuccess
assert post_message.call_args.kwargs["episodes_info"] is episodes_info
assert post_message.call_args.kwargs["season_episode"] == "S01 E01"

View File

@@ -1,2 +1,2 @@
APP_VERSION = 'v2.13.15'
FRONTEND_VERSION = 'v2.13.15'
APP_VERSION = 'v2.13.16'
FRONTEND_VERSION = 'v2.13.16'