Compare commits

..

19 Commits

Author SHA1 Message Date
jxxghp
871d1ec0d8 更新 version.py 2026-06-05 16:16:44 +08:00
InfinityPacer
ca1dbdf843 ci: harden pull request unit test workflow (#5902) 2026-06-05 15:31:31 +08:00
InfinityPacer
e77bef7cf1 fix(subscribe): respect custom start episode for missing seasons (#5901) 2026-06-05 15:20:50 +08:00
ui_beam
f4011d3ac2 fix: 修复前端代理服务器设置清空保存后,httpx 持续报 `Unknown scheme for proxy URL (#5899) 2026-06-05 15:20:31 +08:00
jxxghp
d0b62523a0 chore(version): bump application and frontend versions to v2.13.5 2026-06-05 08:27:09 +08:00
Album
a9b1f7e9c9 fix(alist): support openlist rapid upload headers (#5897) 2026-06-05 06:50:20 +08:00
jxxghp
fc8933c648 feat(workflow): enhance workflow context serialization and execution state management 2026-06-05 00:41:02 +08:00
jxxghp
51981d151e feat(workflow): enhance execution state handling for non-JSON serializable values 2026-06-05 00:01:28 +08:00
jxxghp
97cfcda03c feat(workflow): implement action contract management for inputs and outputs 2026-06-04 21:06:25 +08:00
jxxghp
a2984530f8 feat(workflow): add execution configuration and structured execution state to workflow 2026-06-04 15:57:34 +08:00
jxxghp
7474ecd02f feat(workflow): enhance action execution with structured results and context management 2026-06-04 14:28:46 +08:00
jxxghp
9056caae40 feat(workflow): enhance workflow execution and context management 2026-06-04 14:10:06 +08:00
jxxghp
fd280a49b7 feat(auth): implement authentication provider endpoints and ticket exchange 2026-06-04 08:23:54 +08:00
DDSRem
df75f42753 fix: retry stale keep-alive requests (#5893) 2026-06-04 06:55:03 +08:00
DDSRem
0d2c324e28 fix(db): repair episode_priority column type mismatch on PostgreSQL (#5892) 2026-06-04 06:53:11 +08:00
DDSRem
dc0ee2b466 fix: patch urllib3.fields for urllib3-future compatibility (#5890) 2026-06-04 06:40:16 +08:00
InfinityPacer
781b1ce2aa test: 修复单测 warnings 并精确忽略上游弃用告警 (#5889) 2026-06-03 18:34:45 +08:00
InfinityPacer
791f1fe4ac test: 共享测试 harness 入 app/testing(网络守卫 + 引导)并统一 sys.modules 打桩原语 (#5888) 2026-06-03 18:34:20 +08:00
InfinityPacer
6405ff1191 test: split agenttokens plugin test out, un-skip agent event tests (#5885) 2026-06-03 10:50:55 +08:00
58 changed files with 3653 additions and 534 deletions

View File

@@ -11,6 +11,13 @@ on:
# 允许手动触发
workflow_dispatch:
permissions:
contents: read
concurrency:
group: unit-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
pytest:
runs-on: ubuntu-latest

View File

@@ -1,10 +1,11 @@
from fastapi import APIRouter
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
from app.api.endpoints import auth, login, user, webhook, message, site, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic, llm, notification
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
api_router.include_router(login.router, prefix="/login", tags=["login"])
api_router.include_router(user.router, prefix="/user", tags=["user"])
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])

70
app/api/endpoints/auth.py Normal file
View File

@@ -0,0 +1,70 @@
from typing import Any
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from app import schemas
from app.core.auth_bridge import build_token_response, consume_plugin_auth_ticket
from app.core.plugin import PluginManager
from app.db.models.passkey import PassKey
from app.db.models.user import User
router = APIRouter()
class AuthExchangeRequest(BaseModel):
"""
插件认证票据兑换请求。
"""
ticket: str
def _system_auth_providers() -> list[dict[str, Any]]:
"""
获取系统内建的匿名登录方式摘要。
:return: 系统认证提供方列表
"""
has_passkey = bool(PassKey.list(db=None))
return [
{
"id": "system:passkey",
"type": "system",
"method": "passkey",
"name": "通行密钥",
"icon": "material-symbols:passkey",
"enabled": has_passkey,
}
]
@router.get("/providers", summary="查询登录认证提供方", response_model=list[dict])
def auth_providers() -> list[dict[str, Any]]:
"""
查询系统和插件提供的登录认证入口。
:return: 认证提供方摘要列表
"""
providers = _system_auth_providers()
providers.extend(PluginManager().get_plugin_auth_providers())
return [provider for provider in providers if provider.get("enabled", True)]
@router.post("/exchange", summary="兑换插件认证登录票据", response_model=schemas.Token)
def auth_exchange(body: AuthExchangeRequest) -> schemas.Token:
"""
将插件认证成功后生成的一次性票据兑换为系统 Token。
:param body: 票据兑换请求
:return: 标准登录 Token
"""
ticket_data = consume_plugin_auth_ticket(body.ticket)
if not ticket_data:
raise HTTPException(status_code=401, detail="认证票据无效或已过期")
user = User.get(db=None, rid=ticket_data.get("user_id"))
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
return build_token_response(user)

View File

@@ -22,6 +22,10 @@ from app.schemas.types import EventType, EVENT_TYPE_NAMES
router = APIRouter()
WORKFLOW_TRIGGER_TIMER = "timer"
WORKFLOW_TRIGGER_EVENT = "event"
WORKFLOW_TRIGGER_MANUAL = "manual"
@router.get("/", summary="所有工作流", response_model=List[schemas.Workflow])
async def list_workflows(
@@ -148,16 +152,20 @@ async def workflow_fork(
except json.JSONDecodeError:
return schemas.Response(success=False, message="context字段JSON格式错误")
try:
event_conditions = json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {}
except json.JSONDecodeError:
return schemas.Response(success=False, message="event_conditions字段JSON格式错误")
share_id = workflow.id
# 创建工作流
workflow_dict = {
"name": workflow.name,
"description": workflow.description,
"timer": workflow.timer,
"trigger_type": workflow.trigger_type or "timer",
"trigger_type": workflow.trigger_type or WORKFLOW_TRIGGER_TIMER,
"event_type": workflow.event_type,
"event_conditions": json.loads(workflow.event_conditions or "{}")
if workflow.event_conditions
else {},
"event_conditions": event_conditions,
"actions": actions,
"flows": flows,
"context": context,
@@ -170,11 +178,11 @@ async def workflow_fork(
return schemas.Response(success=False, message="已存在相同名称的工作流")
# 创建新工作流
workflow = await Workflow(**workflow_dict).async_create(db)
workflow_obj = await Workflow(**workflow_dict).async_create(db)
# 更新复用次数
if workflow:
await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=workflow.id)
if workflow_obj and share_id:
await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=share_id)
return schemas.Response(success=True, message="复用成功")
@@ -225,14 +233,23 @@ def start_workflow(
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
if not workflow.trigger_type or workflow.trigger_type == "timer":
trigger_type = workflow.trigger_type or WORKFLOW_TRIGGER_TIMER
if trigger_type == WORKFLOW_TRIGGER_TIMER and not workflow.timer:
return schemas.Response(success=False, message="定时工作流缺少定时器配置")
if trigger_type not in {
WORKFLOW_TRIGGER_TIMER,
WORKFLOW_TRIGGER_EVENT,
WORKFLOW_TRIGGER_MANUAL,
}:
return schemas.Response(success=False, message="工作流触发类型不支持")
# 先更新状态,事件触发注册会重新读取工作流并跳过暂停状态。
workflow.update_state(db, workflow_id, "W")
if trigger_type == WORKFLOW_TRIGGER_TIMER:
# 添加定时任务
Scheduler().update_workflow_job(workflow)
else:
elif trigger_type == WORKFLOW_TRIGGER_EVENT:
# 事件触发:添加到事件触发器
WorkFlowManager().load_workflow_events(workflow_id)
# 更新状态
workflow.update_state(db, workflow_id, "W")
return schemas.Response(success=True)
@@ -251,10 +268,10 @@ def pause_workflow(
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 根据触发类型进行不同处理
if workflow.trigger_type == "timer":
if workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
# 定时触发:移除定时任务
Scheduler().remove_workflow_job(workflow)
elif workflow.trigger_type == "event":
elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT:
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 停止工作流
@@ -319,8 +336,11 @@ def update_workflow(
wf.update(db, workflow.model_dump())
# 更新后的工作流对象
updated_workflow = workflow_oper.get(workflow.id)
# 更新定时任务
Scheduler().update_workflow_job(updated_workflow)
scheduler = Scheduler()
scheduler.remove_workflow_job(updated_workflow)
if not updated_workflow.trigger_type or updated_workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
if updated_workflow.timer:
scheduler.update_workflow_job(updated_workflow)
# 更新事件注册
WorkFlowManager().update_workflow_event(updated_workflow)
return schemas.Response(success=True, message="更新成功")
@@ -338,10 +358,10 @@ def delete_workflow(
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
if not workflow.trigger_type or workflow.trigger_type == "timer":
if not workflow.trigger_type or workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
# 定时触发:删除定时任务
Scheduler().remove_workflow_job(workflow)
else:
elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT:
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 删除工作流

View File

@@ -2785,9 +2785,16 @@ class SubscribeChain(ChainBase):
# 更新剧集列表、开始集数、总集数
if not episode_list:
# 整季缺失
episodes = []
start_episode = start_episode or start
total_episode = total_episode or total
original_start = start if start is not None else 1
# 空集列表会被下载链解释为整季下载;当订阅开始集裁掉季初范围时,需要转成显式集数。
if start_episode and total_episode and start_episode > original_start:
episodes = list(range(start_episode, total_episode + 1))
if not episodes:
return True, {}
else:
episodes = []
else:
# 部分缺失
if not start_episode \

File diff suppressed because it is too large Load Diff

146
app/core/auth_bridge.py Normal file
View File

@@ -0,0 +1,146 @@
import secrets
import threading
import time
from datetime import timedelta
from typing import Any, Optional
from app import schemas
from app.core import security
from app.core.config import settings
from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.sites import SitesHelper
from app.schemas.types import SystemConfigKey
from app.utils.singleton import Singleton
class AuthTicketStore(metaclass=Singleton):
"""
插件认证一次性票据存储。
"""
_ttl_seconds = 120
_max_items = 1024
def __init__(self):
"""
初始化内存票据缓存。
"""
self._tickets: dict[str, dict[str, Any]] = {}
self._lock = threading.RLock()
def create(self, user_id: int, provider_id: str, metadata: Optional[dict[str, Any]] = None) -> str:
"""
创建短时一次性登录票据。
:param user_id: 已通过插件认证的本地用户 ID
:param provider_id: 认证提供方 ID
:param metadata: 插件侧附加信息
:return: 一次性票据字符串
"""
ticket = secrets.token_urlsafe(32)
now = time.time()
with self._lock:
self._cleanup(now)
self._tickets[ticket] = {
"user_id": int(user_id),
"provider_id": provider_id,
"metadata": metadata or {},
"created_at": now,
}
return ticket
def consume(self, ticket: str) -> Optional[dict[str, Any]]:
"""
消费并删除一次性登录票据。
:param ticket: 登录票据
:return: 票据数据,票据不存在或过期时返回 None
"""
if not ticket:
return None
now = time.time()
with self._lock:
data = self._tickets.pop(ticket, None)
self._cleanup(now)
if not data:
return None
if now - float(data.get("created_at") or 0) > self._ttl_seconds:
return None
return data
def _cleanup(self, now: Optional[float] = None) -> None:
"""
清理过期或过量的票据缓存。
:param now: 当前时间戳,未传入时自动读取
"""
current = now or time.time()
expired = [
key
for key, value in self._tickets.items()
if current - float(value.get("created_at") or 0) > self._ttl_seconds
]
for key in expired:
self._tickets.pop(key, None)
if len(self._tickets) <= self._max_items:
return
ordered = sorted(
self._tickets.items(),
key=lambda item: float(item[1].get("created_at") or 0),
)
for key, _ in ordered[: len(self._tickets) - self._max_items]:
self._tickets.pop(key, None)
def create_plugin_auth_ticket(user_id: int, provider_id: str, metadata: Optional[dict[str, Any]] = None) -> str:
"""
为插件认证成功的用户创建一次性登录票据。
:param user_id: 本地用户 ID
:param provider_id: 认证提供方 ID
:param metadata: 插件侧附加信息
:return: 一次性票据字符串
"""
return AuthTicketStore().create(user_id=user_id, provider_id=provider_id, metadata=metadata)
def consume_plugin_auth_ticket(ticket: str) -> Optional[dict[str, Any]]:
"""
消费插件认证登录票据。
:param ticket: 登录票据
:return: 票据数据,票据不存在或过期时返回 None
"""
return AuthTicketStore().consume(ticket)
def build_token_response(user: User) -> schemas.Token:
"""
使用系统统一逻辑构造登录 Token 响应。
:param user: 已认证的本地用户
:return: 标准 Token 响应
"""
level = SitesHelper().auth_level
show_wizard = (
not SystemConfigOper().get(SystemConfigKey.SetupWizardState)
and not settings.ADVANCED_MODE
)
return schemas.Token(
access_token=security.create_access_token(
userid=user.id,
username=user.name,
super_user=user.is_superuser,
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
level=level,
),
token_type="bearer",
super_user=user.is_superuser,
user_id=user.id,
user_name=user.name,
avatar=user.avatar,
level=level,
permissions=user.permissions or {},
wizard=show_wizard,
)

View File

@@ -9,10 +9,10 @@ import sys
import threading
from asyncio import AbstractEventLoop
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_origin, get_args
from urllib.parse import quote, urlencode, urlparse
from dotenv import set_key
from dotenv import set_key, unset_key
from pydantic import BaseModel, Field, ConfigDict, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -690,6 +690,18 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
if isinstance(value, str):
value = value.strip()
# 处理 Optional 类型:当值为空字符串且类型允许 None 时,转为 None
# 兼容 typing.Union (Python 3.9) 与 types.UnionType (Python 3.10+ PEP 604)
origin = get_origin(expected_type)
is_union = origin is Union or getattr(origin, "__name__", None) == "UnionType"
if (
is_union
and type(None) in get_args(expected_type)
and isinstance(value, str)
and not value
):
return default, str(default) != str(original_value)
try:
if expected_type is bool:
if isinstance(value, bool):
@@ -812,13 +824,19 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
logger.warning(message)
return False, message
else:
# 当值为 None 时,从 env 文件中删除该键,恢复为默认值
if converted_value is None:
unset_key(
dotenv_path=SystemUtils.get_env_path(),
key_to_unset=field_name,
)
logger.info(f"配置项 '{field_name}' 已清空,从 'app.env' 中移除")
return True, message
# 如果是列表、字典或集合类型将其转换为JSON字符串
if isinstance(converted_value, (list, dict, set)):
value_to_write = json.dumps(converted_value)
else:
value_to_write = (
str(converted_value) if converted_value is not None else ""
)
value_to_write = str(converted_value)
set_key(
dotenv_path=SystemUtils.get_env_path(),
@@ -967,7 +985,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
@property
def PROXY(self):
if self.PROXY_HOST:
if self.PROXY_HOST and self.PROXY_HOST.strip():
return {
"http": self.PROXY_HOST,
"https": self.PROXY_HOST,
@@ -1009,7 +1027,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
@property
def PROXY_SERVER(self):
if self.PROXY_HOST:
if self.PROXY_HOST and self.PROXY_HOST.strip():
try:
parsed = urlparse(self.PROXY_HOST)
if not parsed.scheme:

View File

@@ -950,6 +950,47 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
})
return remotes
def get_plugin_auth_providers(self) -> List[Dict[str, Any]]:
"""
聚合插件声明的登录认证提供方。
:return: 插件认证入口列表
"""
providers: List[Dict[str, Any]] = []
running_plugins_snapshot = dict(self._running_plugins)
for plugin_id, plugin in running_plugins_snapshot.items():
if not plugin.get_state():
continue
if not hasattr(plugin, "get_auth_providers") or not ObjectUtils.check_method(plugin.get_auth_providers):
continue
try:
plugin_providers = plugin.get_auth_providers() or []
except Exception as e:
logger.error(f"获取插件 {plugin_id} 登录认证提供方出错:{str(e)}")
continue
render_mode = None
dist_path = None
if hasattr(plugin, "get_render_mode"):
render_mode, dist_path = plugin.get_render_mode()
for raw_provider in plugin_providers:
if not raw_provider or not isinstance(raw_provider, dict):
continue
provider = raw_provider.copy()
provider["type"] = "plugin"
provider["plugin_id"] = plugin_id
provider.setdefault("id", f"plugin:{plugin_id}")
provider.setdefault("name", plugin.plugin_name)
provider.setdefault("enabled", True)
if render_mode == "vue" and dist_path:
provider.setdefault("component", "AuthPage")
provider["remote"] = {
"id": plugin_id,
"url": self.get_plugin_remote_entry(plugin_id, dist_path),
"name": plugin.plugin_name,
}
providers.append(provider)
return providers
def get_plugin_sidebar_nav(self) -> List[Dict[str, Any]]:
"""
聚合所有已启用 Vue 插件的侧栏导航项get_sidebar_nav

View File

@@ -40,8 +40,12 @@ class Workflow(Base):
flows = Column(JSON, default=builtin_list)
# 执行上下文
context = Column(JSON, default=dict)
# 执行配置
execution_config = Column(JSON, default=dict)
# 结构化执行状态
execution_state = Column(JSON, default=dict)
# 创建时间
add_time = Column(String, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
add_time = Column(String, default=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
# 最后执行时间
last_time = Column(String)
@@ -79,7 +83,7 @@ class Workflow(Base):
and_(
or_(
cls.trigger_type == 'timer',
not cls.trigger_type
cls.trigger_type.is_(None)
),
cls.state != 'P'
)
@@ -93,7 +97,7 @@ class Workflow(Base):
and_(
or_(
cls.trigger_type == 'timer',
not cls.trigger_type
cls.trigger_type.is_(None)
),
cls.state != 'P'
)
@@ -217,6 +221,8 @@ class Workflow(Base):
"state": 'W',
"result": None,
"current_action": None,
"context": {},
"execution_state": {},
"run_count": 0 if reset_count else cls.run_count,
})
return True
@@ -229,30 +235,49 @@ class Workflow(Base):
state='W',
result=None,
current_action=None,
context={},
execution_state={},
run_count=0 if reset_count else cls.run_count,
))
return True
@classmethod
@db_update
def update_current_action(cls, db, wid: int, action_id: str, context: dict):
db.query(cls).filter(cls.id == wid).update({
"current_action": cls.current_action + f",{action_id}" if cls.current_action else action_id,
def update_current_action(cls, db, wid: int, action_id: str, context: dict,
execution_state: Optional[dict] = None):
workflow = db.query(cls).filter(cls.id == wid).first()
current_actions = []
if workflow and workflow.current_action:
current_actions = [item for item in workflow.current_action.split(",") if item]
if action_id and action_id not in current_actions:
current_actions.append(action_id)
update_values = {
"current_action": ",".join(current_actions),
"context": context
})
}
if execution_state is not None:
update_values["execution_state"] = execution_state
db.query(cls).filter(cls.id == wid).update(update_values)
return True
@classmethod
@async_db_update
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict):
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict,
execution_state: Optional[dict] = None):
from sqlalchemy import update
# 先获取当前current_action
result = await db.execute(select(cls.current_action).where(cls.id == wid))
current_action = result.scalar()
new_current_action = current_action + f",{action_id}" if current_action else action_id
current_actions = [item for item in (current_action or "").split(",") if item]
if action_id and action_id not in current_actions:
current_actions.append(action_id)
new_current_action = ",".join(current_actions)
await db.execute(update(cls).where(cls.id == wid).values(
current_action=new_current_action,
context=context
))
update_values = {
"current_action": new_current_action,
"context": context
}
if execution_state is not None:
update_values["execution_state"] = execution_state
await db.execute(update(cls).where(cls.id == wid).values(**update_values))
return True

View File

@@ -91,11 +91,17 @@ class WorkflowOper(DbOper):
"""
return Workflow.fail(self._db, wid, result)
def step(self, wid: int, action_id: str, context: dict) -> bool:
def step(self, wid: int, action_id: str, context: dict, execution_state: Optional[dict] = None) -> bool:
"""
步进
"""
return Workflow.update_current_action(self._db, wid, action_id, context)
return Workflow.update_current_action(
self._db,
wid,
action_id,
context,
execution_state
)
def reset(self, wid: int, reset_count: bool = False) -> bool:
"""

View File

@@ -1,3 +1,4 @@
import hashlib
import json
import time
from datetime import datetime
@@ -708,6 +709,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
# 获取文件大小
target_name = new_name or path.name
target_path = Path(fileitem.path) / target_name
stat = path.stat()
# 初始化进度回调
progress_callback = transfer_process(path.as_posix())
@@ -718,6 +720,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
headers.setdefault("Content-Type", "application/octet-stream")
headers.setdefault("As-Task", str(task).lower())
headers.setdefault("File-Path", encoded_path)
headers.setdefault("Content-Length", str(stat.st_size))
headers.setdefault("Last-Modified", str(int(stat.st_mtime * 1000)))
headers.update(self.__get_upload_hash_headers(path))
# 创建自定义的文件流,支持进度回调
class ProgressFileReader:
@@ -783,6 +788,28 @@ class Alist(StorageBase, metaclass=WeakSingleton):
logger.error(f"【OpenList】上传文件 {path} 失败:{e}")
return None
@staticmethod
def __get_upload_hash_headers(path: Path) -> dict:
"""
计算 OpenList 秒传所需的文件哈希请求头。
"""
md5_hash = hashlib.md5()
sha1_hash = hashlib.sha1()
sha256_hash = hashlib.sha256()
with open(path, "rb") as file_handler:
while True:
chunk = file_handler.read(1024 * 1024)
if not chunk:
break
md5_hash.update(chunk)
sha1_hash.update(chunk)
sha256_hash.update(chunk)
return {
"X-File-Md5": md5_hash.hexdigest(),
"X-File-Sha1": sha1_hash.hexdigest(),
"X-File-Sha256": sha256_hash.hexdigest(),
}
def detail(self, fileitem: schemas.FileItem) -> Optional[schemas.FileItem]:
"""
获取文件详情

View File

@@ -116,7 +116,7 @@ class NexusPhpSiteUserInfo(SiteParserBase):
has_ucoin, self.bonus = self._parse_ucoin(html)
if has_ucoin:
return
tmps = html.xpath('//a[contains(@href,"mybonus")]/text()') if html else None
tmps = html.xpath('//a[contains(@href,"mybonus")]/text()') if html is not None else None
if tmps:
bonus_text = str(tmps[0]).strip()
bonus_match = re.search(r"([\d,.]+)", bonus_text)

View File

@@ -174,6 +174,21 @@ class _PluginBase(metaclass=ABCMeta):
"""
pass
def get_auth_providers(self) -> List[Dict[str, Any]]:
"""
声明插件提供的登录认证入口。
返回示例:
[{
"id": "oidc",
"name": "OIDC 登录",
"icon": "mdi-openid",
"component": "AuthPage",
"enabled": True
}]
"""
pass
def get_module(self) -> Dict[str, Any]:
"""
获取插件模块声明,用于胁持系统模块实现(方法名:方法实现)

View File

@@ -1,6 +1,6 @@
from typing import Optional, List
from typing import Any, List, Optional
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from app.schemas.context import Context, MediaInfo
from app.schemas.download import DownloadTask
@@ -19,13 +19,15 @@ class Workflow(BaseModel):
timer: Optional[str] = Field(default=None, description="定时器")
trigger_type: Optional[str] = Field(default='timer', description="触发类型timer-定时触发 event-事件触发 manual-手动触发")
event_type: Optional[str] = Field(default=None, description="事件类型当trigger_type为event时使用")
event_conditions: Optional[dict] = Field(default={}, description="事件条件JSON格式用于过滤事件")
event_conditions: Optional[dict] = Field(default_factory=dict, description="事件条件JSON格式用于过滤事件")
state: Optional[str] = Field(default=None, description="状态")
current_action: Optional[str] = Field(default=None, description="已执行动作")
result: Optional[str] = Field(default=None, description="任务执行结果")
run_count: Optional[int] = Field(default=0, description="已执行次数")
actions: Optional[list] = Field(default=[], description="任务列表")
flows: Optional[list] = Field(default=[], description="任务流")
actions: Optional[list] = Field(default_factory=list, description="任务列表")
flows: Optional[list] = Field(default_factory=list, description="任务流")
execution_config: Optional[dict] = Field(default_factory=dict, description="工作流执行配置")
execution_state: Optional[dict] = Field(default_factory=dict, description="工作流结构化执行状态")
add_time: Optional[str] = Field(default=None, description="创建时间")
last_time: Optional[str] = Field(default=None, description="最后执行时间")
@@ -48,8 +50,16 @@ class Action(BaseModel):
type: Optional[str] = Field(default=None, description="动作类型 (类名)")
name: Optional[str] = Field(default=None, description="动作名称")
description: Optional[str] = Field(default=None, description="动作描述")
position: Optional[dict] = Field(default={}, description="位置")
data: Optional[dict] = Field(default={}, description="参数")
position: Optional[dict] = Field(default_factory=dict, description="位置")
data: Optional[dict] = Field(default_factory=dict, description="参数")
inputs: Optional[List[str]] = Field(default_factory=list, description="动作输入声明")
outputs: Optional[dict] = Field(default_factory=dict, description="动作输出声明")
join_policy: Optional[str] = Field(default=None, description="多上游节点汇合策略")
fail_policy: Optional[str] = Field(default=None, description="动作失败后的工作流处理策略")
branch_policy: Optional[str] = Field(default=None, description="多出边分支策略")
concurrency_key: Optional[str] = Field(default=None, description="并发互斥键")
timeout: Optional[int] = Field(default=None, description="动作执行超时时间(秒)")
retry: Optional[dict] = Field(default_factory=dict, description="动作重试策略")
class ActionExecution(BaseModel):
@@ -66,16 +76,32 @@ class ActionContext(BaseModel):
动作基础上下文,各动作通用数据
"""
content: Optional[str] = Field(default=None, description="文本类内容")
torrents: Optional[List[Context]] = Field(default=[], description="资源列表")
medias: Optional[List[MediaInfo]] = Field(default=[], description="媒体列表")
fileitems: Optional[List[FileItem]] = Field(default=[], description="文件列表")
downloads: Optional[List[DownloadTask]] = Field(default=[], description="下载任务列表")
sites: Optional[List[Site]] = Field(default=[], description="站点列表")
subscribes: Optional[List[Subscribe]] = Field(default=[], description="订阅列表")
execute_history: Optional[List[ActionExecution]] = Field(default=[], description="执行历史")
torrents: Optional[List[Context]] = Field(default_factory=list, description="资源列表")
medias: Optional[List[MediaInfo]] = Field(default_factory=list, description="媒体列表")
fileitems: Optional[List[FileItem]] = Field(default_factory=list, description="文件列表")
downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表")
sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表")
subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表")
workflow_context: Optional[dict] = Field(default_factory=dict, description="工作流全局上下文")
node_outputs: Optional[dict] = Field(default_factory=dict, description="节点输出数据")
runtime_state: Optional[dict] = Field(default_factory=dict, description="运行期状态")
artifacts: Optional[dict] = Field(default_factory=dict, description="大对象引用与产物数据")
execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史")
progress: Optional[int] = Field(default=0, description="执行进度(%")
class ActionResult(BaseModel):
"""
动作执行结果。
"""
success: Optional[bool] = Field(default=True, description="动作是否执行成功")
message: Optional[str] = Field(default=None, description="动作执行消息")
context: Optional[ActionContext] = Field(default=None, description="动作执行后的上下文")
outputs: Optional[dict[str, Any]] = Field(default_factory=dict, description="当前节点显式输出")
next_policy: Optional[str] = Field(default=None, description="动作完成后的调度策略")
attempts: Optional[int] = Field(default=1, description="动作实际尝试次数")
class ActionFlow(BaseModel):
"""
工作流流程
@@ -84,6 +110,10 @@ class ActionFlow(BaseModel):
source: Optional[str] = Field(default=None, description="源动作")
target: Optional[str] = Field(default=None, description="目标动作")
animated: Optional[bool] = Field(default=True, description="是否动画流程")
data: Optional[dict] = Field(default_factory=dict, description="流程扩展配置")
condition: Optional[str] = Field(default=None, description="流转条件表达式")
join_policy: Optional[str] = Field(default=None, description="目标节点汇合策略")
branch_policy: Optional[str] = Field(default=None, description="源节点分支策略")
class WorkflowShare(BaseModel):

View File

@@ -3,6 +3,19 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
# urllib3-future 覆盖 urllib3 命名空间后删除了 format_header_param导致 telebot 崩溃,需在加载模块前打补丁
try:
import urllib3.fields as _urllib3_fields
if not hasattr(_urllib3_fields, "format_header_param") and hasattr(
_urllib3_fields, "format_header_param_rfc2231"
):
_urllib3_fields.format_header_param = (
_urllib3_fields.format_header_param_rfc2231
)
except Exception:
pass
from app.chain.system import SystemChain
from app.core.config import global_vars
from app.helper.server import MoviePilotServerHelper
@@ -12,7 +25,11 @@ from app.startup.modules_initializer import init_modules, stop_modules
from app.startup.monitor_initializer import stop_monitor, init_monitor
from app.startup.plugins_initializer import init_plugins, stop_plugins, sync_plugins
from app.startup.routers_initializer import init_routers
from app.startup.scheduler_initializer import stop_scheduler, init_scheduler, init_plugin_scheduler
from app.startup.scheduler_initializer import (
stop_scheduler,
init_scheduler,
init_plugin_scheduler,
)
from app.startup.workflow_initializer import init_workflow, stop_workflow
from app.utils.http import aclose_shared_async_transports

View File

@@ -1,7 +1,13 @@
"""测试辅助工具(主程序与插件仓共享)。
提供测试期对 ``sys.modules`` 的临时打桩能力,保证打桩在使用后还原,避免测试间
因残留假模块而相互污染。仅供测试使用,不参与运行时逻辑。
汇集主程序与插件仓共用的测试 harness仅供测试使用、不参与运行时逻辑
- :mod:`app.testing.stub`:测试期对 ``sys.modules`` 的临时打桩并自动还原,避免残留假模块相互污染;
- :mod:`app.testing.bootstrap`:隔离 CONFIG_DIR、建表、插件目录注入与 v1/v2 marker 等引导逻辑;
- :mod:`app.testing.network_guard`autouse 拦截测试期对非本地主机的真实出站。
子模块各自按需 import如 ``network_guard`` 依赖 pytest故此处只 re-export 无第三方依赖的
:func:`stub_modules`,保持 ``import app.testing`` 不引入 pytest 等测试期依赖。
"""
from app.testing.stub import stub_modules

169
app/testing/bootstrap.py Normal file
View File

@@ -0,0 +1,169 @@
"""测试引导共享实现(主程序与插件仓同源)。
主程序 ``tests/conftest.py`` 与各插件仓的极薄 shim``tests/_bootstrap.py``,仅负责把
后端定位并加入 ``sys.path``)都委托到这里,使「隔离 CONFIG_DIR / 建表 / 注入插件目录 /
按目录打 v1·v2 marker / 退出清理」等引导逻辑只在主程序维护一处,所有消费方行为与修复一致。
其中 :func:`isolate_config_dir` 为主程序与插件仓共用,``prepare_v1/v2_backend`` 与
:func:`mark_plugin_generation` 为插件仓专用。
本模块只依赖标准库,``import`` 期不连库、不触发 ``app.db``:调用方可安全地「先 import 本模块、
再隔离 CONFIG_DIR」不破坏「隔离必须早于首个 ``import app.db``」这一硬约束。
"""
from __future__ import annotations
import atexit
import os
import shutil
import sys
import tempfile
from pathlib import Path
from typing import Optional
# 本进程隔离出的临时 CONFIG_DIR兼作幂等标记
_isolated_config_dir: Optional[str] = None
def isolate_config_dir() -> str:
"""把 ``CONFIG_DIR`` 指向进程私有临时目录,隔离主程序真实库与配置(幂等)。
``import app.db`` / ``import app.chain.*`` 在 import 期即按 ``settings.CONFIG_PATH`` 连接
``user.db``,故本函数必须在首个 ``import app.db`` 之前调用。调用方已显式设置 ``CONFIG_DIR``
(如 CI 指定隔离目录)时尊重之、不覆盖。
:return: 实际生效的 CONFIG_DIR 绝对路径
"""
global _isolated_config_dir
if _isolated_config_dir is not None:
return _isolated_config_dir
existing = os.environ.get("CONFIG_DIR")
if existing:
_isolated_config_dir = existing
return existing
tmp = tempfile.mkdtemp(prefix="mp-test-config-")
os.environ["CONFIG_DIR"] = tmp
_isolated_config_dir = tmp
def _cleanup(path: str = tmp, rmtree=shutil.rmtree, sys_mod=sys) -> None:
"""进程退出时释放 SQLite 连接池再删临时目录。
默认参数绑定 ``rmtree``/``path``/``sys_mod``:解释器关停期标准库模块可能已被回收为 ``None``
绑定后仍可安全调用。先 ``Engine.dispose`` 释放 ``user.db`` 连接,规避 Windows 下
文件锁导致 ``rmtree`` 静默失败(``ignore_errors``)、残留临时目录。
"""
try:
db_mod = sys_mod.modules.get("app.db")
if db_mod is not None:
db_mod.Engine.dispose()
except Exception:
pass
rmtree(path, ignore_errors=True)
atexit.register(_cleanup)
return tmp
def _prepend_sys_path(path: Path) -> None:
"""把目录前置到 ``sys.path``(去重),使其内顶层包可被导入。"""
value = str(path)
if value not in sys.path:
sys.path.insert(0, value)
def ensure_sites_stub() -> None:
"""为 ``app.helper.sites`` 补最小垫片(仅在缺失时)。
``app.helper.sites`` 由独立仓库动态拉取CI / 全新环境无该模块,而众多 ``app.chain.*`` /
``app.modules.*`` 在 import 期依赖它。统一补一个最小垫片,省去各测试文件各自打桩;若真实模块
已存在(本地已拉取)则用真实模块、不覆盖,不影响真实行为。须在隔离 CONFIG_DIR 之后调用,
以免试探性 ``import app.helper.sites`` 触发的连库落到真实库。
"""
if "app.helper.sites" in sys.modules:
return
try:
import app.helper.sites # noqa: F401 本地已拉取时用真实模块
except (ModuleNotFoundError, ImportError):
from types import ModuleType
stub = ModuleType("app.helper.sites")
stub.SitesHelper = object
sys.modules["app.helper.sites"] = stub
def ensure_optional_stub(name: str, **attrs) -> None:
"""为可选第三方依赖补占位模块(仅在缺失时),可带属性。
用例 import 的 app 代码会牵入可选三方库(如 psutil / dateparser / Pinyin2Hanzi /
qbittorrentapi / transmission_rpcCI / 全新环境可能未安装。本函数在该库缺失时补一个带
指定属性的占位,使 import 不致失败;若已真实安装则保留真实模块、不覆盖。占位为进程级常驻
(与 import 生命周期一致、不作用域还原),是「让可选 import 不失败」的垫片——与
:func:`stub_modules`(作用域内打桩并还原)属不同用途,故不收进 stub_modules。
:param name: 可选依赖的顶层模块名
:param attrs: 占位模块需暴露的属性(仅在真正创建占位时设置)
"""
if name in sys.modules:
return
try:
__import__(name)
return
except ImportError:
pass
from types import ModuleType
module = ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
sys.modules[name] = module
def prepare_backend() -> None:
"""隔离 CONFIG_DIR、补 sites 垫片并建表(后端须已在 ``sys.path`` 上)。
主程序中后端即当前包;插件仓由其 ``tests/_bootstrap.py`` shim 在 import 本模块前
先把后端目录注入 ``sys.path``。顺序固定:先隔离 CONFIG_DIR再补 ``app.helper.sites`` 垫片,
最后建表——隔离出的临时库为空,运行期查 ``systemconfig`` 等表会报 no such table故建表
``init_db`` 仅 import models + create_all无 alembic/网络、幂等、毫秒级。
"""
isolate_config_dir()
ensure_sites_stub()
from app.db.init import init_db
init_db()
def prepare_v2_backend(plugins_repo: Path) -> None:
"""v2 插件单测引导:``prepare_backend`` + 把 ``<repo>/plugins.v2`` 注入 ``sys.path``。
与 :func:`prepare_v1_backend` 互斥v1/v2 存在同名插件包,同一进程同时加载会相互覆盖,
须在各自独立的 pytest 会话中运行。
:param plugins_repo: 插件仓根目录(由调用方 shim 传入)
"""
prepare_backend()
_prepend_sys_path(Path(plugins_repo) / "plugins.v2")
def prepare_v1_backend(plugins_repo: Path) -> None:
"""v1 插件单测引导:``prepare_backend`` + 把 ``<repo>/plugins`` 注入 ``sys.path``(与 v2 互斥)。
:param plugins_repo: 插件仓根目录(由调用方 shim 传入)
"""
prepare_backend()
_prepend_sys_path(Path(plugins_repo) / "plugins")
def mark_plugin_generation(items, pytest_module) -> None:
"""按用例所在目录自动给其打 ``v1`` / ``v2`` marker供按代筛选与分会话运行。
优先读取 pytest 7+ 的 ``item.path``,旧版 pytest 缺失该属性时回退到 ``item.fspath``。用
「不带前导斜杠」的子串匹配(``tests/v2/`` / ``tests/v1/``),兼容相对路径与绝对路径两种
运行方式:以 ``pytest tests/v2`` 等相对路径运行时收集路径可能不含前导斜杠。
``pytest`` 模块由各仓 conftest 传入,避免本模块在非测试态强依赖 pytest。
:param items: pytest 收集到的用例集合
:param pytest_module: 调用方传入的 ``pytest`` 模块对象
"""
for item in items:
item_path = getattr(item, "path", None)
path = str(item_path if item_path is not None else item.fspath).replace("\\", "/")
if "tests/v2/" in path:
item.add_marker(pytest_module.mark.v2)
elif "tests/v1/" in path:
item.add_marker(pytest_module.mark.v1)

View File

@@ -0,0 +1,40 @@
"""测试网络守卫(主程序与插件仓共享)。
提供一个 autouse 的 pytest fixture拦截测试期对非本地主机的真实出站网络。主程序
``tests/conftest.py`` 与各插件仓 conftest 只需 ``from app.testing.network_guard import
block_real_network`` 即复用同一道守卫——pytest 会把 conftest 命名空间内(含 import 进来的)
fixture 一并识别autouse 自动作用于每个用例,无需逐用例改动。
仅供测试使用,不参与运行时逻辑。
"""
from __future__ import annotations
import pytest
# 本地回环/通配地址放行其余主机一律视为真实出站getaddrinfo 的 host 可能为 str 或 bytes
_ALLOWED_NETWORK_HOSTS = {"127.0.0.1", "::1", "localhost", "0.0.0.0", "::", ""}
@pytest.fixture(autouse=True)
def block_real_network(monkeypatch):
"""防御纵深:拦截对非本地主机的真实出站,强制测试零真实网络。
补在各用例自身 mock 之上:某用例万一漏 mock 外部依赖TMDB / LLM 目录 / 下载器 /
媒体服务器 / 任意外链),其真实 DNS 解析会在此被拦并报错,而非静默发请求。本地回环放行
sqlite 等。asyncio 默认解析器经线程池调用 ``socket.getaddrinfo``,故拦此一处即覆盖
同步与异步出站。``monkeypatch`` 在用例结束后自动还原,不影响其他用例与进程退出。
"""
import socket
_real_getaddrinfo = socket.getaddrinfo
def _guarded_getaddrinfo(host, *args, **kwargs):
normalized = host.decode() if isinstance(host, (bytes, bytearray)) else host
if normalized is not None and normalized not in _ALLOWED_NETWORK_HOSTS:
raise RuntimeError(
f"测试禁止真实出站网络:尝试解析 {normalized!r};请 mock 对应外部依赖"
)
return _real_getaddrinfo(host, *args, **kwargs)
monkeypatch.setattr(socket, "getaddrinfo", _guarded_getaddrinfo)
yield

View File

@@ -70,6 +70,8 @@ _DEFAULT_MAX_KEEPALIVE_CONNECTIONS = 20
_DEFAULT_MAX_CONNECTIONS = 40
# 默认的 keep-alive 连接过期时间(秒)
_DEFAULT_KEEPALIVE_EXPIRY = 30
# 同步 requests.Session 复用连接时,遇到对端或代理关闭 keep-alive 后允许重试的方法
_REQUESTS_RETRY_IDEMPOTENT_METHODS = ("GET", "HEAD", "OPTIONS")
# 持有 LRU 淘汰后正在异步关闭的 transport task避免 fire-and-forget 被 GC 警告
_pending_eviction_tasks: set[asyncio.Task] = set()
@@ -90,6 +92,9 @@ def _get_shared_async_transport(
会话级状态由调用方在外层 AsyncClient(transport=...) 实例化时单独配置,
每次调用用完即销毁,因此天然无 jar 累积串扰。
"""
# 规范化代理:拒绝空字符串等非法值,防止 httpx 抛出 Unknown scheme for proxy URL
if proxy is not None and (not proxy or not proxy.strip()):
proxy = None
try:
loop = asyncio.get_running_loop()
except RuntimeError:
@@ -344,14 +349,47 @@ class RequestUtils:
kwargs.setdefault("timeout", self._timeout)
kwargs.setdefault("verify", False)
kwargs.setdefault("stream", False)
method_upper = method.upper()
try:
return req_method(method, url, **kwargs)
except (
requests.exceptions.ConnectionError,
requests.exceptions.ChunkedEncodingError,
requests.exceptions.ReadTimeout,
) as e:
if (
self._session is not None
and method_upper in _REQUESTS_RETRY_IDEMPOTENT_METHODS
):
logger.debug(f"keep-alive 连接已失效,同步幂等请求重试一次: {e!r}")
try:
self._session.close()
return req_method(method, url, **kwargs)
except requests.exceptions.RequestException as retry_error:
error_msg = (
str(retry_error)
if str(retry_error)
else f"未知网络错误 (URL: {url}, Method: {method_upper})"
)
logger.debug(f"重试后同步请求仍失败: {error_msg}")
if raise_exception:
raise
return None
error_msg = (
str(e)
if str(e)
else f"未知网络错误 (URL: {url}, Method: {method_upper})"
)
logger.debug(f"同步请求失败(不重试): {error_msg}")
if raise_exception:
raise
return None
except requests.exceptions.RequestException as e:
# 获取更详细的错误信息
error_msg = (
str(e)
if str(e)
else f"未知网络错误 (URL: {url}, Method: {method.upper()})"
else f"未知网络错误 (URL: {url}, Method: {method_upper})"
)
logger.debug(f"请求失败: {error_msg}")
if raise_exception:
@@ -864,12 +902,17 @@ class AsyncRequestUtils:
# 如果已经是字符串格式,直接返回
if isinstance(proxies, str):
return proxies
return proxies.strip() or None
# 如果是字典格式提取http或https代理
if isinstance(proxies, dict):
# 优先使用https代理如果没有则使用http代理
proxy_url = proxies.get("https") or proxies.get("http")
# 先各自 strip避免空白字符串阻断裂合取或回退到 http 代理
https_proxy = proxies.get("https")
http_proxy = proxies.get("http")
https_proxy = https_proxy.strip() if isinstance(https_proxy, str) else None
http_proxy = http_proxy.strip() if isinstance(http_proxy, str) else None
proxy_url = https_proxy or http_proxy
if proxy_url:
return proxy_url

View File

@@ -1,7 +1,6 @@
import threading
from time import sleep
from typing import Dict, Any, Optional
from typing import List, Tuple
from time import monotonic, sleep
from typing import Any, Dict, List, Optional, Tuple
from app.core.config import global_vars
from app.core.event import eventmanager, Event
@@ -9,7 +8,7 @@ from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper
from app.helper.module import ModuleHelper
from app.log import logger
from app.schemas import ActionContext, Action
from app.schemas import ActionContext, Action, ActionResult
from app.schemas.types import EventType
from app.utils.singleton import Singleton
@@ -63,48 +62,176 @@ class WorkFlowManager(metaclass=Singleton):
"""
停止
"""
for event_type_str in list(self._event_workflows.keys()):
self.remove_workflow_event(event_type_str=event_type_str)
self._actions = {}
self._event_workflows = {}
def excute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
def execute(self, workflow_id: int, action: Action, context: ActionContext = None,
inputs: Optional[dict] = None, runtime: Optional[dict] = None,
cancel_token: Optional[Any] = None) -> ActionResult:
"""
执行工作流动作
"""
if not context:
context = ActionContext()
if action.type in self._actions:
# 实例化之前,清理掉类对象的数据
# 实例化
action_obj = self._actions[action.type](action.id)
# 执行
logger.info(f"执行动作: {action.id} - {action.name}")
try:
result_context = action_obj.execute(workflow_id, action.data, context)
except Exception as err:
logger.error(f"{action.name} 执行失败: {err}")
return False, f"{err}", context
loop = action.data.get("loop")
loop_interval = action.data.get("loop_interval")
if loop and loop_interval:
while not action_obj.done:
if global_vars.is_workflow_stopped(workflow_id):
break
# 等待
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
sleep(loop_interval)
# 执行
logger.info(f"继续执行动作: {action.id} - {action.name}")
result_context = action_obj.execute(workflow_id, action.data, result_context)
if action_obj.success:
logger.info(f"{action.name} 执行成功")
else:
logger.error(f"{action.name} 执行失败!")
return action_obj.success, action_obj.message, result_context
else:
if action.type not in self._actions:
logger.error(f"未找到动作: {action.type} - {action.name}")
return False, " ", context
return ActionResult(success=False, message=" ", context=context)
retry_config = self._get_retry_config(action)
max_attempts = retry_config["max_attempts"]
interval = retry_config["interval"]
backoff = retry_config["backoff"]
action_result = ActionResult(success=False, message="", context=context)
for attempt in range(1, max_attempts + 1):
if self._is_cancelled(workflow_id, cancel_token):
return ActionResult(success=False, message="工作流已取消", context=context)
runtime_data = {
**(runtime or {}),
"attempt": attempt,
"max_attempts": max_attempts,
"cancel_token": cancel_token,
}
action_result = self._execute_action_once(
workflow_id=workflow_id,
action=action,
context=context,
inputs=inputs or {},
runtime=runtime_data,
cancel_token=cancel_token
)
action_result.attempts = attempt
context = action_result.context or context
if action_result.success:
logger.info(f"{action.name} 执行成功")
return action_result
if attempt < max_attempts and not self._is_cancelled(workflow_id, cancel_token):
wait_seconds = interval * (backoff ** (attempt - 1))
logger.info(f"{action.name} 执行失败,{wait_seconds} 秒后重试({attempt}/{max_attempts}...")
self._sleep_with_cancel(workflow_id, wait_seconds, cancel_token)
logger.error(f"{action.name} 执行失败!")
return action_result
def excute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
"""
执行工作流动作,兼容历史拼写错误的方法名。
"""
action_result = self.execute(workflow_id=workflow_id, action=action, context=context)
return bool(action_result.success), action_result.message or "", action_result.context or context or ActionContext()
@staticmethod
def _normalize_action_result(result: Any, action_obj: Any, fallback_context: ActionContext) -> ActionResult:
"""
将旧版动作上下文与新版结构化结果统一为动作执行结果。
"""
if isinstance(result, ActionResult):
result.context = result.context or fallback_context
if result.message is None:
result.message = action_obj.message
return result
return ActionResult(
success=action_obj.success,
message=action_obj.message,
context=result or fallback_context
)
def _execute_action_once(self, workflow_id: int, action: Action, context: ActionContext,
inputs: dict, runtime: dict, cancel_token: Optional[Any]) -> ActionResult:
action_obj = self._actions[action.type](action.id)
logger.info(f"执行动作: {action.id} - {action.name}")
try:
action_result = self._run_action_with_loop(
workflow_id=workflow_id,
action=action,
action_obj=action_obj,
context=context,
inputs=inputs,
runtime=runtime,
cancel_token=cancel_token
)
except Exception as err:
logger.error(f"{action.name} 执行失败: {err}")
return ActionResult(success=False, message=f"{err}", context=context)
return action_result
def _run_action_with_loop(self, workflow_id: int, action: Action, action_obj: Any,
context: ActionContext, inputs: dict, runtime: dict,
cancel_token: Optional[Any]) -> ActionResult:
timeout = self._get_action_timeout(action)
started_at = monotonic()
action_result = self._call_action(
workflow_id=workflow_id,
action=action,
action_obj=action_obj,
context=context,
inputs=inputs,
runtime=runtime
)
loop = self._get_action_data_value(action, "loop")
loop_interval = self._get_action_data_value(action, "loop_interval")
while loop and loop_interval and not action_obj.done:
if self._is_cancelled(workflow_id, cancel_token):
return ActionResult(success=False, message="工作流已取消", context=action_result.context or context)
if timeout and monotonic() - started_at >= timeout:
return ActionResult(success=False, message=f"动作执行超时({timeout}秒)", context=action_result.context or context)
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
self._sleep_with_cancel(workflow_id, loop_interval, cancel_token)
if self._is_cancelled(workflow_id, cancel_token):
return ActionResult(success=False, message="工作流已取消", context=action_result.context or context)
logger.info(f"继续执行动作: {action.id} - {action.name}")
action_result = self._call_action(
workflow_id=workflow_id,
action=action,
action_obj=action_obj,
context=action_result.context or context,
inputs=inputs,
runtime=runtime
)
return action_result
def _call_action(self, workflow_id: int, action: Action, action_obj: Any,
context: ActionContext, inputs: dict, runtime: dict) -> ActionResult:
if hasattr(action_obj, "execute_with_inputs"):
result = action_obj.execute_with_inputs(workflow_id, action.data, inputs, runtime, context)
else:
result = action_obj.execute(workflow_id, action.data, context)
return self._normalize_action_result(result, action_obj, context)
@staticmethod
def _get_action_data_value(action: Action, key: str) -> Any:
data = action.data or {}
return data.get(key) if isinstance(data, dict) else None
def _get_action_timeout(self, action: Action) -> Optional[int]:
timeout = action.timeout or self._get_action_data_value(action, "timeout")
return int(timeout) if timeout else None
def _get_retry_config(self, action: Action) -> dict:
retry_config = action.retry or self._get_action_data_value(action, "retry") or {}
if not isinstance(retry_config, dict):
retry_config = {}
return {
"max_attempts": max(int(retry_config.get("max_attempts") or 1), 1),
"interval": max(float(retry_config.get("interval") or 0), 0),
"backoff": max(float(retry_config.get("backoff") or 1), 1),
}
@staticmethod
def _is_cancelled(workflow_id: int, cancel_token: Optional[Any]) -> bool:
if cancel_token and cancel_token.is_cancelled():
return True
return global_vars.is_workflow_stopped(workflow_id)
def _sleep_with_cancel(self, workflow_id: int, seconds: float, cancel_token: Optional[Any]) -> None:
deadline = monotonic() + seconds
while monotonic() < deadline:
if self._is_cancelled(workflow_id, cancel_token):
return
sleep(min(0.1, deadline - monotonic()))
def list_actions(self) -> List[dict]:
"""
@@ -115,6 +242,7 @@ class WorkFlowManager(metaclass=Singleton):
"type": key,
"name": action.name,
"description": action.description,
"contract": action.get_contract(),
"data": {
"label": action.name,
**action.data
@@ -122,12 +250,21 @@ class WorkFlowManager(metaclass=Singleton):
} for key, action in self._actions.items()
]
def get_action_contract(self, action_type: str) -> dict:
"""
获取动作输入输出契约。
"""
action = self._actions.get(action_type)
if not action or not hasattr(action, "get_contract"):
return {}
return action.get_contract()
def update_workflow_event(self, workflow: Workflow):
"""
更新工作流事件触发器
"""
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id=workflow.id, event_type_str=workflow.event_type)
# 工作流可能切换触发事件先按工作流ID从所有事件映射中移除。
self.remove_workflow_event(workflow_id=workflow.id)
# 如果工作流是事件触发类型且未被禁用
if workflow.trigger_type == "event" and workflow.state != 'P':
# 注册事件触发器
@@ -154,41 +291,46 @@ class WorkFlowManager(metaclass=Singleton):
"""
注册工作流事件触发器
"""
if not event_type_str:
return
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id, event_type.value)
with self._lock:
# 添加新的事件监听器
eventmanager.add_event_listener(event_type, self._handle_event)
# 记录工作流事件触发器
if event_type.value not in self._event_workflows:
self._event_workflows[event_type.value] = []
self._event_workflows[event_type.value].append(workflow_id)
eventmanager.add_event_listener(event_type, self._handle_event)
# 记录工作流事件触发器
if workflow_id not in self._event_workflows[event_type.value]:
self._event_workflows[event_type.value].append(workflow_id)
logger.info(f"已注册工作流 {workflow_id} 事件触发器: {event_type.value}")
def remove_workflow_event(self, workflow_id: int, event_type_str: str):
def remove_workflow_event(self, workflow_id: Optional[int] = None, event_type_str: Optional[str] = None):
"""
移除工作流事件触发器
"""
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
event_type_values = [event_type_str] if event_type_str else list(self._event_workflows.keys())
for event_type_value in event_type_values:
try:
event_type = EventType(event_type_value)
except ValueError:
logger.error(f"无效的事件类型: {event_type_value}")
continue
with self._lock:
eventmanager.remove_event_listener(event_type, self._handle_event)
if event_type.value in self._event_workflows:
if workflow_id in self._event_workflows[event_type.value]:
self._event_workflows[event_type.value].remove(workflow_id)
if not self._event_workflows[event_type.value]:
del self._event_workflows[event_type.value]
logger.info(f"已移除工作流 {workflow_id} 事件触发器")
workflow_ids = self._event_workflows.get(event_type.value)
if not workflow_ids:
continue
if workflow_id is None:
workflow_ids.clear()
elif workflow_id in workflow_ids:
workflow_ids.remove(workflow_id)
if not workflow_ids:
self._event_workflows.pop(event_type.value, None)
eventmanager.remove_event_listener(event_type, self._handle_event)
logger.info(f"已移除工作流 {workflow_id or ''} 事件触发器")
def _handle_event(self, event: Event):
"""

View File

@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from typing import Union
from typing import Any, Union
from app.chain import ChainBase
from app.db.systemconfig_oper import SystemConfigOper
from app.schemas import ActionContext, ActionParams
from app.schemas import ActionContext, ActionParams, ActionResult
class ActionChain(ChainBase):
@@ -23,9 +23,13 @@ class BaseAction(ABC):
_message = ""
# 缓存键值
_cache_key = "WorkflowCache-%s"
# 动作输入输出契约,由具体动作按需覆盖
contract = {}
def __init__(self, action_id: str):
self._action_id = action_id
self._done_flag = False
self._message = ""
self.systemconfigoper = SystemConfigOper()
@classmethod
@@ -46,6 +50,41 @@ class BaseAction(ABC):
def data(cls) -> dict: # noqa
pass
@classmethod
def get_contract(cls) -> dict:
"""
获取动作输入输出契约。
"""
contract = getattr(cls, "contract", None) or {}
input_fields = cls._build_contract_fields(contract.get("inputs") or [])
output_fields = cls._build_contract_fields(contract.get("outputs") or [])
return {
"inputs": input_fields,
"outputs": output_fields,
"condition_fields": output_fields,
"concurrency_key": contract.get("concurrency_key"),
}
@classmethod
def _build_contract_fields(cls, fields: list) -> list:
"""
标准化动作契约字段。
"""
result = []
for field in fields:
if isinstance(field, str):
field = {"name": field}
if not isinstance(field, dict) or not field.get("name"):
continue
result.append({
"name": field["name"],
"label": field.get("label") or field["name"],
"kind": field.get("kind") or "scalar",
"merge": field.get("merge"),
"identity": field.get("identity"),
})
return result
@property
def done(self) -> bool:
"""
@@ -92,9 +131,12 @@ class BaseAction(ABC):
workflow_cache = self.systemconfigoper.get(workflow_key) or {}
action_cache = workflow_cache.get(self._action_id) or []
if isinstance(data, list):
action_cache.extend(data)
for item in data:
if item not in action_cache:
action_cache.append(item)
else:
action_cache.append(data)
if data not in action_cache:
action_cache.append(data)
workflow_cache[self._action_id] = action_cache
self.systemconfigoper.set(workflow_key, workflow_cache)
@@ -104,3 +146,61 @@ class BaseAction(ABC):
执行动作
"""
raise NotImplementedError
def execute_with_inputs(self, workflow_id: int, params: ActionParams, inputs: dict,
runtime: dict, context: ActionContext) -> ActionResult:
"""
使用显式输入与运行期信息执行动作。
"""
self._apply_inputs_to_context(inputs=inputs, context=context)
self._apply_runtime_to_context(runtime=runtime, context=context)
result_context = self.execute(workflow_id, params, context)
outputs = self._extract_outputs_from_context(result_context)
return ActionResult(
success=self.success,
message=self.message,
context=result_context,
outputs=outputs
)
def _apply_inputs_to_context(self, inputs: dict, context: ActionContext) -> None:
"""
将显式输入回填到旧版上下文字段,兼容仍读取 context 的动作。
"""
inputs = inputs or {}
for field in self.get_contract().get("inputs") or []:
missing = object()
field_name = field["name"]
value = inputs.get(field_name, missing)
if value is missing:
# 兼容旧版节点输入路径,例如 outputs.A.torrents。
for input_key, input_value in inputs.items():
if isinstance(input_key, str) and input_key.split(".")[-1] == field_name:
value = input_value
break
if value is not missing:
setattr(context, field_name, value)
@staticmethod
def _apply_runtime_to_context(runtime: dict, context: ActionContext) -> None:
"""
将运行期信息写入 runtime_state供动作和执行状态读取。
"""
if not runtime:
return
context.runtime_state = context.runtime_state or {}
context.runtime_state["current_action_runtime"] = {
key: value for key, value in runtime.items()
if key != "cancel_token"
}
def _extract_outputs_from_context(self, context: ActionContext) -> dict[str, Any]:
"""
按动作契约从上下文提取输出。
"""
outputs = {}
for field in self.get_contract().get("outputs") or []:
value = getattr(context, field["name"], None)
if value not in (None, "", [], {}):
outputs[field["name"]] = value
return outputs

View File

@@ -26,6 +26,12 @@ class AddDownloadAction(BaseAction):
添加下载资源
"""
contract = {
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
"concurrency_key": "download",
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._added_downloads = []

View File

@@ -19,6 +19,11 @@ class AddSubscribeAction(BaseAction):
添加订阅
"""
contract = {
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
"outputs": [{"name": "subscribes", "label": "订阅", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._added_subscribes = []

View File

@@ -16,6 +16,12 @@ class FetchDownloadsAction(BaseAction):
获取下载任务
"""
contract = {
"inputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list", "merge": "replace"}],
"concurrency_key": "download",
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._downloads = []
@@ -43,12 +49,19 @@ class FetchDownloadsAction(BaseAction):
"""
更新downloads中的下载任务状态
"""
__all_complete = False
self._downloads = context.downloads or []
if not self._downloads:
self.job_done("无下载任务")
return context
for download in self._downloads:
if global_vars.is_workflow_stopped(workflow_id):
break
logger.info(f"获取下载任务 {download.download_id} 状态 ...")
torrents = ActionChain().list_torrents(hashs=[download.download_id])
torrents = ActionChain().list_torrents(
hashs=[download.download_id],
downloader=download.downloader,
)
if not torrents:
download.completed = True
continue
@@ -61,5 +74,5 @@ class FetchDownloadsAction(BaseAction):
logger.info(f"下载任务 {download.download_id} 未完成")
download.completed = False
if all([d.completed for d in self._downloads]):
self.job_done()
self.job_done("下载任务已全部完成")
return context

View File

@@ -27,6 +27,10 @@ class FetchMediasAction(BaseAction):
获取媒体数据
"""
contract = {
"outputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)

View File

@@ -30,6 +30,10 @@ class FetchRssAction(BaseAction):
获取RSS资源列表
"""
contract = {
"outputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._rss_torrents = []

View File

@@ -30,6 +30,11 @@ class FetchTorrentsAction(BaseAction):
搜索站点资源
"""
contract = {
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
"outputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._torrents = []

View File

@@ -22,6 +22,11 @@ class FilterMediasAction(BaseAction):
过滤媒体数据
"""
contract = {
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
"outputs": [{"name": "medias", "label": "媒体", "kind": "list", "merge": "replace"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._medias = []

View File

@@ -27,6 +27,11 @@ class FilterTorrentsAction(BaseAction):
过滤资源数据
"""
contract = {
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
"outputs": [{"name": "torrents", "label": "资源", "kind": "list", "merge": "replace"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._torrents = []

View File

@@ -20,6 +20,8 @@ class InvokePluginAction(BaseAction):
调用插件
"""
contract = {}
def __init__(self, action_id: str):
super().__init__(action_id)
self._success = False

View File

@@ -7,6 +7,8 @@ class NoteAction(BaseAction):
备注
"""
contract = {}
@classmethod
@property
def name(cls) -> str: # noqa

View File

@@ -24,6 +24,10 @@ class ScanFileAction(BaseAction):
整理文件
"""
contract = {
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._fileitems = []

View File

@@ -18,6 +18,11 @@ class ScrapeFileAction(BaseAction):
刮削文件
"""
contract = {
"inputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._scraped_files = []
@@ -61,18 +66,18 @@ class ScrapeFileAction(BaseAction):
logger.info(f"{fileitem.path} 已刮削过,跳过")
continue
mediachain = MediaChain()
context = mediachain.recognize_by_path(
media_context = mediachain.recognize_by_path(
fileitem.path,
obtain_images=True,
)
if not context or not context.media_info:
if not media_context or not media_context.media_info:
_failed_count += 1
logger.info(f"{fileitem.path} 未识别到媒体信息,无法刮削")
continue
mediachain.scrape_metadata(
fileitem=fileitem,
meta=context.meta_info,
mediainfo=context.media_info
meta=media_context.meta_info,
mediainfo=media_context.media_info
)
self._scraped_files.append(fileitem)
# 保存缓存

View File

@@ -16,6 +16,8 @@ class SendEventAction(BaseAction):
发送事件
"""
contract = {}
@classmethod
@property
def name(cls) -> str: # noqa

View File

@@ -20,6 +20,8 @@ class SendMessageAction(BaseAction):
发送消息
"""
contract = {}
def __init__(self, action_id: str):
super().__init__(action_id)

View File

@@ -26,6 +26,15 @@ class TransferFileAction(BaseAction):
整理文件
"""
contract = {
"inputs": [
{"name": "downloads", "label": "下载任务", "kind": "list"},
{"name": "fileitems", "label": "文件", "kind": "list"},
],
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
"concurrency_key": "transfer",
}
def __init__(self, action_id: str):
super().__init__(action_id)
self._fileitems = []

View File

@@ -0,0 +1,45 @@
"""2.2.9
为工作流增加执行配置和结构化执行状态
Revision ID: 7c1a2b3d4e5f
Revises: d5e6f7a8b9c0
Create Date: 2026-06-04
"""
from alembic import op
import sqlalchemy as sa
revision = "7c1a2b3d4e5f"
down_revision = "d5e6f7a8b9c0"
branch_labels = None
depends_on = None
def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> bool:
"""检查数据表是否已存在指定列。"""
if table_name not in inspector.get_table_names():
return False
return any(column["name"] == column_name for column in inspector.get_columns(table_name))
def _add_json_column_if_missing(table_name: str, column_name: str) -> None:
"""缺失时为数据表新增 JSON 列。"""
inspector = sa.inspect(op.get_bind())
if not _has_column(inspector, table_name, column_name):
op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True))
def upgrade() -> None:
"""升级数据库结构。"""
_add_json_column_if_missing("workflow", "execution_config")
_add_json_column_if_missing("workflow", "execution_state")
def downgrade() -> None:
"""回滚数据库结构。"""
inspector = sa.inspect(op.get_bind())
if _has_column(inspector, "workflow", "execution_state"):
op.drop_column("workflow", "execution_state")
inspector = sa.inspect(op.get_bind())
if _has_column(inspector, "workflow", "execution_config"):
op.drop_column("workflow", "execution_config")

View File

@@ -22,14 +22,31 @@ def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> b
return any(column["name"] == column_name for column in inspector.get_columns(table_name))
def upgrade() -> None:
inspector = sa.inspect(op.get_bind())
if _has_column(inspector, "subscribe", "episode_priority") is False:
op.add_column("subscribe", sa.Column("episode_priority", sa.JSON(), nullable=True))
def _ensure_json_column(bind, table_name: str, column_name: str) -> None:
"""Add the column as JSON, or fix it if it already exists with the wrong type (PostgreSQL only)."""
inspector = sa.inspect(bind)
if not _has_column(inspector, table_name, column_name):
op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True))
return
if bind.dialect.name != "postgresql":
return
for col in inspector.get_columns(table_name):
if col["name"] != column_name:
continue
type_name = type(col["type"]).__name__.upper()
if type_name in ("JSON", "JSONB"):
return
# Column exists with wrong type (e.g., INTEGER from an intermediate build); replace it.
# Existing values are non-functional garbage so data loss is acceptable.
op.drop_column(table_name, column_name)
op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True))
return
inspector = sa.inspect(op.get_bind())
if _has_column(inspector, "subscribehistory", "episode_priority") is False:
op.add_column("subscribehistory", sa.Column("episode_priority", sa.JSON(), nullable=True))
def upgrade() -> None:
bind = op.get_bind()
_ensure_json_column(bind, "subscribe", "episode_priority")
_ensure_json_column(bind, "subscribehistory", "episode_priority")
def downgrade() -> None:

View File

@@ -0,0 +1,50 @@
"""2.2.8
修复 episode_priority 列类型PostgreSQL 下若列为 INTEGER 则重建为 JSON
Revision ID: d5e6f7a8b9c0
Revises: 1f0d2c3b4a5e
Create Date: 2026-06-04
"""
from alembic import op
import sqlalchemy as sa
revision = "d5e6f7a8b9c0"
down_revision = "1f0d2c3b4a5e"
branch_labels = None
depends_on = None
def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> bool:
if table_name not in inspector.get_table_names():
return False
return any(column["name"] == column_name for column in inspector.get_columns(table_name))
def _fix_episode_priority_type(bind, table_name: str) -> None:
"""On PostgreSQL, if episode_priority exists but is not JSON/JSONB, drop and re-add it as JSON."""
if bind.dialect.name != "postgresql":
return
inspector = sa.inspect(bind)
if not _has_column(inspector, table_name, "episode_priority"):
op.add_column(table_name, sa.Column("episode_priority", sa.JSON(), nullable=True))
return
for col in inspector.get_columns(table_name):
if col["name"] != "episode_priority":
continue
type_name = type(col["type"]).__name__.upper()
if type_name in ("JSON", "JSONB"):
return
op.drop_column(table_name, "episode_priority")
op.add_column(table_name, sa.Column("episode_priority", sa.JSON(), nullable=True))
return
def upgrade() -> None:
bind = op.get_bind()
_fix_episode_priority_type(bind, "subscribe")
_fix_episode_priority_type(bind, "subscribehistory")
def downgrade() -> None:
pass

11
pytest.ini Normal file
View File

@@ -0,0 +1,11 @@
[pytest]
# 仅对「无法在本仓修复根因」的已知上游/三方弃用告警做精确忽略,保持测试输出干净、
# 让本仓自身的新告警更醒目。本仓代码引发的告警一律不在此忽略,应在源码/用例处修复。
filterwarnings =
ignore:datetime.datetime.utcfromtimestamp\(\) is deprecated:DeprecationWarning
ignore:websockets.legacy is deprecated:DeprecationWarning
ignore:websockets.InvalidStatusCode is deprecated:DeprecationWarning
ignore:pkg_resources is deprecated as an API:DeprecationWarning
ignore:Deprecated call to .pkg_resources.declare_namespace:DeprecationWarning
ignore:'crypt' is deprecated:DeprecationWarning
ignore:'audioop' is deprecated:DeprecationWarning

View File

@@ -1,43 +1,14 @@
"""pytest 全局引导:在 import 任何测试模块前把 CONFIG_DIR 指向临时目录并建表,隔离真实库。"""
import atexit
import os
import shutil
import sys
import tempfile
from types import ModuleType
"""pytest 全局引导:隔离 CONFIG_DIR、补 sites 垫片、建表、装载网络守卫。
# 必须早于首个 import app.*app.db 在导入时即按 CONFIG_PATH 连接 user.db
if not os.environ.get("CONFIG_DIR"):
_isolated_config_dir = tempfile.mkdtemp(prefix="mp-test-config-")
os.environ["CONFIG_DIR"] = _isolated_config_dir
引导与网络守卫均复用 ``app/testing`` 的共享 harness与插件仓 conftest 同源),
引导逻辑只在 ``app/testing`` 维护一处。
"""
# 必须早于首个 import app.db其在 import 期即按 CONFIG_PATH 连库prepare_backend 内部
# 先隔离 CONFIG_DIR、补 app.helper.sites 垫片再建表。app/testing 仅依赖标准库、import 不连库,
# 故此处先 import 再调用是安全的。
from app.testing.bootstrap import prepare_backend
def _cleanup_isolated_config_dir():
"""进程退出时先释放 SQLite 连接池再删临时目录。
prepare_backend()
Windows 下 Engine 若仍持有 user.db 的文件锁,直接 rmtree 会因占用而静默失败
ignore_errors=True、残留临时目录先 dispose 释放连接再删可规避。
"""
try:
from app.db import Engine
Engine.dispose()
except Exception:
pass
shutil.rmtree(_isolated_config_dir, ignore_errors=True)
atexit.register(_cleanup_isolated_config_dir)
# app.helper.sites 由独立仓库动态拉取CI / 全新环境无该模块),而众多 app.chain.* /
# app.modules.* 在 import 期依赖它。在此统一补一个最小垫片,省去各测试文件各自打桩;
# 若真实模块已存在(本地已拉取)则 setdefault 不覆盖,不影响真实行为。
if "app.helper.sites" not in sys.modules:
try:
import app.helper.sites # noqa: F401 本地已拉取时用真实模块
except ModuleNotFoundError:
_sites_stub = ModuleType("app.helper.sites")
_sites_stub.SitesHelper = object
sys.modules["app.helper.sites"] = _sites_stub
# 必须在 CONFIG_DIR 设好之后再 import空库会让运行期查表报 no such table故建表
from app.db.init import init_db # noqa: E402
init_db()
# 复用共享 autouse 网络守卫;同一实现亦供各插件仓 conftest import 复用,避免逐仓维护
from app.testing.network_guard import block_real_network # noqa: E402,F401

View File

@@ -453,7 +453,8 @@ class AgentImageSupportTest(unittest.TestCase):
) as prepare_files, patch(
"app.chain.message.agent_manager.process_message", new_callable=AsyncMock
) as process_message, patch(
"app.chain.message.asyncio.run_coroutine_threadsafe"
"app.chain.message.asyncio.run_coroutine_threadsafe",
side_effect=lambda coro, _loop: coro.close(),
) as run_coroutine_threadsafe:
chain._handle_ai_message(
text="/ai 帮我看看这张图",
@@ -486,7 +487,7 @@ class AgentImageSupportTest(unittest.TestCase):
"app.chain.message.agent_manager.process_message", new_callable=AsyncMock
) as process_message, patch(
"app.chain.message.asyncio.run_coroutine_threadsafe",
side_effect=lambda coro, _loop: (coro.close(), Mock())[1],
side_effect=lambda coro, _loop: coro.close(),
):
chain._handle_ai_message(
text="帮我推荐一部电影",

View File

@@ -2,14 +2,10 @@ import unittest
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from langchain_core.messages import AIMessage
from app.agent import MoviePilotAgent
from app.agent.memory import memory_manager
# agenttokens 为动态安装插件app/plugins/** 被 gitignoreCI / 全新环境无此插件),
# 缺失时跳过本模块,避免 collection 阶段 ImportError。
AgentTokens = pytest.importorskip("app.plugins.agenttokens").AgentTokens
from app.schemas.types import ChainEventType, EventType
@@ -44,31 +40,6 @@ class _FakeFailingAgent(_FakeAgent):
class AgentTokensEventsTest(unittest.IsolatedAsyncioTestCase):
async def test_plugin_sidebar_nav_respects_config(self):
"""插件侧边栏入口应受 show_sidebar_nav 配置控制。"""
plugin = AgentTokens()
with patch.object(plugin, "update_config"):
plugin.init_plugin(
{
"enabled": True,
"show_sidebar_nav": False,
"providers": [],
}
)
self.assertEqual([], plugin.get_sidebar_nav())
plugin.init_plugin(
{
"enabled": True,
"show_sidebar_nav": True,
"providers": [],
}
)
nav = plugin.get_sidebar_nav()
self.assertEqual("Agent Tokens 管理", nav[0]["title"])
async def test_initialize_llm_uses_chain_event_selection(self):
"""Agent 初始化 LLM 时应优先使用链式事件返回的供应商配置。"""
agent = MoviePilotAgent(session_id="agent-tokens-test", user_id="user-1")

View File

@@ -1,5 +1,7 @@
import hashlib
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, patch
from app.modules.filemanager.storages import alist as alist_module
@@ -137,3 +139,47 @@ class AlistStorageTest(unittest.TestCase):
self.assertEqual("alist", target.storage)
self.assertEqual("file", target.type)
self.assertEqual(1024, target.size)
def test_upload_sends_hash_headers_for_rapid_upload(self):
"""
OpenList 上传应附带文件哈希头,供服务端尝试秒传。
"""
with TemporaryDirectory() as temp_dir:
local_path = Path(temp_dir) / "rapid.bin"
content = b"moviepilot-openlist-rapid-upload"
local_path.write_bytes(content)
upload_dir = FileItem(
storage="alist",
type="dir",
path="/library/",
name="library",
basename="library",
)
uploaded_item = FileItem(
storage="alist",
type="file",
path="/library/rapid.bin",
name="rapid.bin",
basename="rapid",
extension="bin",
size=len(content),
)
request_utils = MagicMock()
request_utils.put_res.return_value = _FakeResponse(
{"code": 200, "message": "success", "data": None}
)
with patch.object(Alist, "get_conf", return_value={"url": "http://openlist.test", "token": "token"}):
with patch.object(alist_module, "RequestUtils", return_value=request_utils) as request_utils_factory:
with patch.object(self.storage, "_delay_get_item", return_value=uploaded_item):
result = self.storage.upload(upload_dir, local_path)
self.assertEqual(uploaded_item, result)
request_utils.put_res.assert_called_once()
headers = request_utils_factory.call_args.kwargs["headers"]
self.assertEqual(hashlib.md5(content).hexdigest(), headers["X-File-Md5"])
self.assertEqual(hashlib.sha1(content).hexdigest(), headers["X-File-Sha1"])
self.assertEqual(hashlib.sha256(content).hexdigest(), headers["X-File-Sha256"])
self.assertEqual(str(len(content)), headers["Content-Length"])
self.assertEqual("application/octet-stream", headers["Content-Type"])
self.assertEqual("/library/rapid.bin", headers["File-Path"])

View File

@@ -1,19 +1,16 @@
import sys
import asyncio
import json
import tempfile
import unittest
from types import ModuleType, SimpleNamespace
from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, patch
from app.testing.bootstrap import ensure_optional_stub
sys.modules.setdefault("psutil", ModuleType("psutil"))
sys.modules.setdefault("dateparser", ModuleType("dateparser"))
if "Pinyin2Hanzi" not in sys.modules:
pinyin_module = ModuleType("Pinyin2Hanzi")
setattr(pinyin_module, "is_pinyin", lambda value: False)
sys.modules["Pinyin2Hanzi"] = pinyin_module
# 可选三方依赖在 CI / 全新环境可能未安装,补占位避免 app.modules.feishu 导入失败
ensure_optional_stub("psutil")
ensure_optional_stub("dateparser")
ensure_optional_stub("Pinyin2Hanzi", is_pinyin=lambda value: False)
from app.modules.feishu import FeishuModule
from app.modules.feishu.feishu import Feishu

View File

@@ -1,19 +1,6 @@
import sys
import unittest
from unittest.mock import Mock
for _module_name in (
"app.chain.mediaserver",
"app.db.models",
"app.db.user_oper",
"app.helper.message",
"app.utils.crypto",
):
if _module_name in sys.modules and not hasattr(
sys.modules[_module_name], "__file__"
):
del sys.modules[_module_name]
from app.chain.mediaserver import MediaServerChain
from app.schemas import MediaServerLibrary, MediaServerPlayItem
from app.utils.security import SecurityUtils

View File

@@ -131,14 +131,15 @@ class PluginHelperTest(TestCase):
self.skipTest(f"missing dependency: {exc}")
module_names = ["app.plugins.dynamicwechat.helper", "Crypto.Cipher._mode_cbc"]
previous_modules = {name: sys.modules.get(name) for name in module_names}
def fake_execute(_cmd):
for module_name in module_names:
sys.modules[module_name] = ModuleType(module_name)
return True, "ok"
try:
# patch.dict 进入时快照 sys.modules、退出时整体还原替代手写逐项 save/restore
# 保证 fake_execute 在安装窗口注入的运行态模块在用例结束后被清理、不污染其他用例
with patch.dict(sys.modules):
with tempfile.TemporaryDirectory() as temp_dir:
requirements_file = Path(temp_dir) / "requirements.txt"
requirements_file.write_text("demo-package\n", encoding="utf-8")
@@ -149,12 +150,6 @@ class PluginHelperTest(TestCase):
self.assertEqual("ok", message)
for module_name in module_names:
self.assertIn(module_name, sys.modules)
finally:
for module_name, previous_module in previous_modules.items():
if previous_module is None:
sys.modules.pop(module_name, None)
else:
sys.modules[module_name] = previous_module
def test_pip_install_serializes_concurrent_calls(self):
"""

100
tests/test_request_utils.py Normal file
View File

@@ -0,0 +1,100 @@
import requests
from app.utils.http import RequestUtils
class _FakeSession:
"""
测试用 requests.Session 替身,记录请求次数与连接池关闭行为。
"""
def __init__(self, side_effects):
"""
初始化请求结果序列。
:param side_effects: 每次 request 调用要返回或抛出的对象
"""
self.side_effects = list(side_effects)
self.calls = []
self.close_count = 0
def request(self, method, url, **kwargs):
"""
模拟 requests.Session.request。
"""
self.calls.append((method, url, kwargs))
effect = self.side_effects.pop(0)
if isinstance(effect, Exception):
raise effect
return effect
def close(self):
"""
模拟清空 session 连接池。
"""
self.close_count += 1
def _make_response(status_code: int = 200) -> requests.Response:
response = requests.Response()
response.status_code = status_code
return response
def test_request_utils_retries_idempotent_session_connection_error():
"""
同步幂等请求遇到失效 session 连接时应清理连接池并重试一次。
"""
response = _make_response()
session = _FakeSession(
[
requests.exceptions.ConnectionError("stale keep-alive"),
response,
]
)
request_utils = RequestUtils(session=session)
result = request_utils.get_res("https://example.com/data")
assert result is response
assert len(session.calls) == 2
assert session.close_count == 1
def test_request_utils_does_not_retry_non_idempotent_connection_error():
"""
非幂等请求连接异常时不应自动重试,避免重复提交副作用。
"""
session = _FakeSession(
[
requests.exceptions.ConnectionError("connection failed"),
_make_response(),
]
)
request_utils = RequestUtils(session=session)
result = request_utils.post_res("https://example.com/data", data={"name": "demo"})
assert result is None
assert len(session.calls) == 1
assert session.close_count == 0
def test_request_utils_raises_retry_error_when_retry_still_fails():
"""
开启 raise_exception 后,重试仍失败时应抛出重试阶段的异常。
"""
first_error = requests.exceptions.ConnectionError("stale keep-alive")
retry_error = requests.exceptions.ConnectionError("proxy still unavailable")
session = _FakeSession([first_error, retry_error])
request_utils = RequestUtils(session=session)
try:
request_utils.get_res("https://example.com/data", raise_exception=True)
except requests.exceptions.ConnectionError as err:
assert err is retry_error
else:
raise AssertionError("请求重试失败时应抛出异常")
assert len(session.calls) == 2
assert session.close_count == 1

View File

@@ -1,28 +1,15 @@
import asyncio
import importlib.machinery
import sys
import unittest
from types import SimpleNamespace
from types import ModuleType
from unittest.mock import AsyncMock, patch
from app.testing.bootstrap import ensure_optional_stub
def _stub_module(name: str, **attrs):
module = sys.modules.get(name)
if module is None:
module = ModuleType(name)
sys.modules[name] = module
for key, value in attrs.items():
setattr(module, key, value)
return module
_stub_module("qbittorrentapi", TorrentFilesList=list)
_stub_module("transmission_rpc", File=object)
_stub_module(
"psutil",
__spec__=importlib.machinery.ModuleSpec("psutil", loader=None),
)
# 可选三方依赖在 CI / 全新环境可能未安装,补占位(带用例所需属性)避免导入失败
ensure_optional_stub("qbittorrentapi", TorrentFilesList=list)
ensure_optional_stub("transmission_rpc", File=object)
ensure_optional_stub("psutil", __spec__=importlib.machinery.ModuleSpec("psutil", loader=None))
from app.agent.tools.factory import MoviePilotToolFactory
from app.agent import ReplyMode

View File

@@ -7,6 +7,7 @@ from unittest import TestCase
from unittest.mock import patch
from app.schemas.types import MediaType
from app.testing import stub_modules
def _load_subscribe_chain_class():
@@ -16,13 +17,11 @@ def _load_subscribe_chain_class():
module = sys.modules[module_name]
return module, module.SubscribeChain
original_modules = {}
stub_deps = {}
def ensure_module(name: str, module: types.ModuleType):
"""临时替换模块依赖,并记录原模块以便加载完成后恢复"""
if name not in original_modules:
original_modules[name] = sys.modules.get(name)
sys.modules[name] = module
"""登记一个加载期临时替换模块;实际替换与精确还原由 stub_modules 在加载时统一处理"""
stub_deps[name] = module
return module
chain_module = ensure_module("app.chain", types.ModuleType("app.chain"))
@@ -298,18 +297,12 @@ def _load_subscribe_chain_class():
subscribe_path = Path(__file__).resolve().parents[1] / "app" / "chain" / "subscribe.py"
spec = importlib.util.spec_from_file_location(module_name, subscribe_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
assert spec and spec.loader
spec.loader.exec_module(module)
module._injected_modules = {
name: sys.modules.get(name)
for name in original_modules
}
for injected_name, original_module in original_modules.items():
if original_module is None:
sys.modules.pop(injected_name, None)
else:
sys.modules[injected_name] = original_module
# 加载期用 stub_modules 精确替换依赖、退出时统一还原module_name 非桩,缓存入 sys.modules 供复用
with stub_modules(stub_deps):
sys.modules[module_name] = module
spec.loader.exec_module(module)
module._injected_modules = {name: sys.modules.get(name) for name in stub_deps}
return module, module.SubscribeChain
@@ -463,6 +456,69 @@ class SubscribeChainTest(TestCase):
self.assertEqual(SubscribeChain.get_best_version_current_priority(subscribe), 100)
self.assertTrue(SubscribeChain.is_best_version_complete(subscribe))
def test_get_subscribe_no_exists_expands_whole_missing_when_custom_start_skips_existing_range(self):
"""自定义开始集跳过季初集数时,缺失整季需要转成显式目标集。"""
no_exists = {
"media-key": {
1: SimpleNamespace(season=1, episodes=[], total_episode=48, start_episode=1)
}
}
exist_flag, result = SubscribeChain._SubscribeChain__get_subscribe_no_exits(
subscribe_name="主角 S01",
no_exists=no_exists,
mediakey="media-key",
begin_season=1,
total_episode=48,
start_episode=44,
)
self.assertFalse(exist_flag)
self.assertEqual(result["media-key"][1].episodes, [44, 45, 46, 47, 48])
self.assertEqual(result["media-key"][1].start_episode, 44)
self.assertEqual(result["media-key"][1].total_episode, 48)
def test_get_subscribe_no_exists_keeps_whole_missing_when_custom_start_matches_original_start(self):
"""自定义开始集没有缩小范围时,仍保留空集列表表示整季缺失。"""
no_exists = {
"media-key": {
1: SimpleNamespace(season=1, episodes=[], total_episode=48, start_episode=1)
}
}
exist_flag, result = SubscribeChain._SubscribeChain__get_subscribe_no_exits(
subscribe_name="主角 S01",
no_exists=no_exists,
mediakey="media-key",
begin_season=1,
total_episode=48,
start_episode=1,
)
self.assertFalse(exist_flag)
self.assertEqual(result["media-key"][1].episodes, [])
self.assertEqual(result["media-key"][1].start_episode, 1)
self.assertEqual(result["media-key"][1].total_episode, 48)
def test_best_version_full_pack_first_keeps_whole_missing_for_custom_start_episode(self):
"""分集洗版优先全集时,空集列表仍表示下载链按整季资源处理。"""
subscribe = self._build_subscribe(
best_version=1,
best_version_full=0,
start_episode=44,
total_episode=48,
episode_priority={str(episode): 80 for episode in range(44, 49)},
)
result = SubscribeChain._SubscribeChain__build_full_pack_first_no_exists(
subscribe=subscribe,
mediakey="media-key",
)
self.assertEqual(result["media-key"][1].episodes, [])
self.assertEqual(result["media-key"][1].start_episode, 44)
self.assertEqual(result["media-key"][1].total_episode, 48)
def test_is_episode_range_covered_matches_pending_episodes(self):
subscribe = self._build_subscribe(
total_episode=12,

View File

@@ -1,24 +1,17 @@
import asyncio
import sys
import unittest
from types import ModuleType
from unittest.mock import AsyncMock, patch
_ORIGINAL_STUBBED_MODULES = {}
from app.testing import stub_modules
def _stub_module(name: str, **attrs):
"""
安装临时 stub 模块,并记录原模块用于导入后恢复。
"""
if name not in _ORIGINAL_STUBBED_MODULES:
_ORIGINAL_STUBBED_MODULES[name] = sys.modules.get(name)
def _stub(name: str, **attrs) -> tuple:
"""构造带指定属性的占位模块,返回 ``(模块名, 模块)`` 供 :func:`stub_modules` 使用。"""
module = ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
sys.modules[name] = module
return module
return name, module
class _Dummy:
@@ -35,69 +28,44 @@ class _DummyError(Exception):
self.duration_ms = duration_ms
for _module_name in ("pillow_avif", "aiofiles", "psutil"):
_stub_module(_module_name)
# 在 import 期用占位模块替换重依赖/外部模块import 完由 stub_modules 精确还原,避免污染其它用例
_STUB_MODULES = dict([
_stub("pillow_avif"),
_stub("aiofiles"),
_stub("psutil"),
_stub("app.helper.sites", SitesHelper=_Dummy),
_stub("app.chain.mediaserver", MediaServerChain=_Dummy),
_stub("app.chain.search", SearchChain=_Dummy),
_stub("app.chain.system", SystemChain=_Dummy),
_stub("app.agent.llm", LLMHelper=_Dummy, LLMProviderManager=_Dummy,
LLMTestError=_DummyError, LLMTestTimeout=_DummyError,
render_auth_result_html=lambda success, message: message),
_stub("app.core.event", eventmanager=_Dummy(), Event=_Dummy, EventManager=_Dummy),
_stub("app.core.metainfo", MetaInfo=_Dummy),
_stub("app.core.module", ModuleManager=_Dummy),
_stub("app.core.security", verify_apitoken=_Dummy, verify_resource_token=_Dummy, verify_token=_Dummy),
_stub("app.db.models", User=_Dummy),
_stub("app.db.systemconfig_oper", SystemConfigOper=_Dummy),
_stub("app.db.user_oper", get_current_active_superuser=_Dummy,
get_current_active_superuser_async=_Dummy, get_current_active_user_async=_Dummy),
_stub("app.helper.llm", LLMHelper=_Dummy, LLMTestError=_DummyError, LLMTestTimeout=_DummyError),
_stub("app.helper.mediaserver", MediaServerHelper=_Dummy),
_stub("app.helper.message", MessageHelper=_Dummy),
_stub("app.helper.progress", ProgressHelper=_Dummy),
_stub("app.helper.rule", RuleHelper=_Dummy),
_stub("app.helper.server", MoviePilotServerHelper=_Dummy),
_stub("app.helper.system", SystemHelper=_Dummy),
_stub("app.helper.image", ImageHelper=_Dummy),
_stub("app.scheduler", Scheduler=_Dummy),
_stub("app.log", logger=_Dummy(), log_settings=_Dummy(),
LogConfigModel=type("LogConfigModel", (), {})),
_stub("app.utils.crypto", HashUtils=_Dummy),
_stub("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy),
_stub("version", APP_VERSION="test"),
])
_stub_module("app.helper.sites", SitesHelper=_Dummy)
_stub_module("app.chain.mediaserver", MediaServerChain=_Dummy)
_stub_module("app.chain.search", SearchChain=_Dummy)
_stub_module("app.chain.system", SystemChain=_Dummy)
_stub_module(
"app.agent.llm",
LLMHelper=_Dummy,
LLMProviderManager=_Dummy,
LLMTestError=_DummyError,
LLMTestTimeout=_DummyError,
render_auth_result_html=lambda success, message: message,
)
_stub_module("app.core.event", eventmanager=_Dummy(), Event=_Dummy, EventManager=_Dummy)
_stub_module("app.core.metainfo", MetaInfo=_Dummy)
_stub_module("app.core.module", ModuleManager=_Dummy)
_stub_module(
"app.core.security",
verify_apitoken=_Dummy,
verify_resource_token=_Dummy,
verify_token=_Dummy,
)
_stub_module("app.db.models", User=_Dummy)
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_Dummy)
_stub_module(
"app.db.user_oper",
get_current_active_superuser=_Dummy,
get_current_active_superuser_async=_Dummy,
get_current_active_user_async=_Dummy,
)
_stub_module(
"app.helper.llm",
LLMHelper=_Dummy,
LLMTestError=_DummyError,
LLMTestTimeout=_DummyError,
)
_stub_module("app.helper.mediaserver", MediaServerHelper=_Dummy)
_stub_module("app.helper.message", MessageHelper=_Dummy)
_stub_module("app.helper.progress", ProgressHelper=_Dummy)
_stub_module("app.helper.rule", RuleHelper=_Dummy)
_stub_module("app.helper.server", MoviePilotServerHelper=_Dummy)
_stub_module("app.helper.system", SystemHelper=_Dummy)
_stub_module("app.helper.image", ImageHelper=_Dummy)
_stub_module("app.scheduler", Scheduler=_Dummy)
_stub_module(
"app.log",
logger=_Dummy(),
log_settings=_Dummy(),
LogConfigModel=type("LogConfigModel", (), {}),
)
_stub_module("app.utils.crypto", HashUtils=_Dummy)
_stub_module("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy)
_stub_module("version", APP_VERSION="test")
from app.api.endpoints import llm as system_endpoint
for _module_name, _module in _ORIGINAL_STUBBED_MODULES.items():
if _module is None:
sys.modules.pop(_module_name, None)
else:
sys.modules[_module_name] = _module
with stub_modules(_STUB_MODULES):
from app.api.endpoints import llm as system_endpoint
class LlmTestEndpointTest(unittest.TestCase):

View File

@@ -1,25 +1,18 @@
import asyncio
import ipaddress
import sys
import unittest
from types import ModuleType, SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
_ORIGINAL_STUBBED_MODULES = {}
from app.testing import stub_modules
def _stub_module(name: str, **attrs):
"""
安装临时 stub 模块,并记录原模块用于导入后恢复。
"""
if name not in _ORIGINAL_STUBBED_MODULES:
_ORIGINAL_STUBBED_MODULES[name] = sys.modules.get(name)
def _stub(name: str, **attrs) -> tuple:
"""构造带指定属性的占位模块,返回 ``(模块名, 模块)`` 供 :func:`stub_modules` 使用。"""
module = ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
sys.modules[name] = module
return module
return name, module
class _Dummy:
@@ -36,67 +29,42 @@ class _DummyError(Exception):
self.duration_ms = duration_ms
for _module_name in ("pillow_avif", "aiofiles", "psutil"):
_stub_module(_module_name)
# 在 import 期用占位模块替换重依赖/外部模块import 完由 stub_modules 精确还原,避免污染其它用例
_STUB_MODULES = dict([
_stub("pillow_avif"),
_stub("aiofiles"),
_stub("psutil"),
_stub("app.helper.sites", SitesHelper=_Dummy),
_stub("app.chain.media", MediaChain=_Dummy),
_stub("app.chain.mediaserver", MediaServerChain=_Dummy),
_stub("app.chain.search", SearchChain=_Dummy),
_stub("app.chain.system", SystemChain=_Dummy),
_stub("app.core.event", eventmanager=_Dummy(), Event=_Dummy, EventManager=_Dummy),
_stub("app.core.metainfo", MetaInfo=_Dummy),
_stub("app.core.module", ModuleManager=_Dummy),
_stub("app.core.security", verify_apitoken=_Dummy, verify_resource_token=_Dummy, verify_token=_Dummy),
_stub("app.db.models", User=_Dummy),
_stub("app.db.systemconfig_oper", SystemConfigOper=_Dummy),
_stub("app.db.user_oper", get_current_active_superuser=_Dummy,
get_current_active_superuser_async=_Dummy, get_current_active_user_async=_Dummy),
_stub("app.helper.llm", LLMHelper=_Dummy, LLMTestError=_DummyError, LLMTestTimeout=_DummyError),
_stub("app.helper.mediaserver", MediaServerHelper=_Dummy),
_stub("app.helper.message", MessageHelper=_Dummy),
_stub("app.helper.progress", ProgressHelper=_Dummy),
_stub("app.helper.rule", RuleHelper=_Dummy),
_stub("app.helper.server", MoviePilotServerHelper=_Dummy),
_stub("app.helper.system", SystemHelper=_Dummy),
_stub("app.helper.image", ImageHelper=_Dummy),
_stub("app.scheduler", Scheduler=_Dummy),
_stub("app.log", logger=_Dummy(), log_settings=_Dummy(),
LogConfigModel=type("LogConfigModel", (), {})),
_stub("app.utils.crypto", HashUtils=_Dummy),
_stub("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy),
_stub("version", APP_VERSION="test", FRONTEND_VERSION="frontend-test"),
])
_stub_module("app.helper.sites", SitesHelper=_Dummy)
_stub_module("app.chain.media", MediaChain=_Dummy)
_stub_module("app.chain.mediaserver", MediaServerChain=_Dummy)
_stub_module("app.chain.search", SearchChain=_Dummy)
_stub_module("app.chain.system", SystemChain=_Dummy)
_stub_module(
"app.core.event",
eventmanager=_Dummy(),
Event=_Dummy,
EventManager=_Dummy,
)
_stub_module("app.core.metainfo", MetaInfo=_Dummy)
_stub_module("app.core.module", ModuleManager=_Dummy)
_stub_module(
"app.core.security",
verify_apitoken=_Dummy,
verify_resource_token=_Dummy,
verify_token=_Dummy,
)
_stub_module("app.db.models", User=_Dummy)
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_Dummy)
_stub_module(
"app.db.user_oper",
get_current_active_superuser=_Dummy,
get_current_active_superuser_async=_Dummy,
get_current_active_user_async=_Dummy,
)
_stub_module(
"app.helper.llm",
LLMHelper=_Dummy,
LLMTestError=_DummyError,
LLMTestTimeout=_DummyError,
)
_stub_module("app.helper.mediaserver", MediaServerHelper=_Dummy)
_stub_module("app.helper.message", MessageHelper=_Dummy)
_stub_module("app.helper.progress", ProgressHelper=_Dummy)
_stub_module("app.helper.rule", RuleHelper=_Dummy)
_stub_module("app.helper.server", MoviePilotServerHelper=_Dummy)
_stub_module("app.helper.system", SystemHelper=_Dummy)
_stub_module("app.helper.image", ImageHelper=_Dummy)
_stub_module("app.scheduler", Scheduler=_Dummy)
_stub_module(
"app.log",
logger=_Dummy(),
log_settings=_Dummy(),
LogConfigModel=type("LogConfigModel", (), {}),
)
_stub_module("app.utils.crypto", HashUtils=_Dummy)
_stub_module("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy)
_stub_module("version", APP_VERSION="test", FRONTEND_VERSION="frontend-test")
from app.api.endpoints import system as system_endpoint
for _module_name, _module in _ORIGINAL_STUBBED_MODULES.items():
if _module is None:
sys.modules.pop(_module_name, None)
else:
sys.modules[_module_name] = _module
with stub_modules(_STUB_MODULES):
from app.api.endpoints import system as system_endpoint
class NettestSecurityTest(unittest.TestCase):

View File

@@ -0,0 +1,62 @@
"""共享测试引导工具的回归用例。"""
from __future__ import annotations
import builtins
import types
from pathlib import Path
from app.testing import bootstrap
def test_isolate_config_cleanup_uses_loaded_db_module_without_late_import(monkeypatch):
"""清理回调只读取已加载模块,避免解释器关停期触发二次导入。"""
captured = {}
import_calls = []
def fake_import(name, *args, **kwargs):
"""记录清理回调是否试图重新导入数据库模块。"""
if name == "app.db":
import_calls.append(name)
return original_import(name, *args, **kwargs)
def fake_register(func):
"""截获 atexit 回调,便于直接验证清理行为。"""
captured["cleanup"] = func
monkeypatch.setattr(bootstrap, "_isolated_config_dir", None)
monkeypatch.delenv("CONFIG_DIR", raising=False)
monkeypatch.setattr(bootstrap.tempfile, "mkdtemp", lambda prefix: "/tmp/mp-test-config-demo")
monkeypatch.setattr(bootstrap.shutil, "rmtree", lambda *args, **kwargs: None)
monkeypatch.setattr(bootstrap.atexit, "register", fake_register)
original_import = builtins.__import__
monkeypatch.setattr(builtins, "__import__", fake_import)
bootstrap.isolate_config_dir()
captured["cleanup"]()
assert import_calls == []
def test_mark_plugin_generation_prefers_pathlib_item_path():
"""pytest 新版 item.path 可独立驱动 v1/v2 marker 标记。"""
class FakeItem:
"""只暴露 pytest 7+ 的 path 属性,模拟新版收集对象。"""
def __init__(self, value: str):
self.path = Path(value)
self.markers = []
def add_marker(self, marker):
"""记录被添加的 marker。"""
self.markers.append(marker)
pytest_module = types.SimpleNamespace(mark=types.SimpleNamespace(v1="v1", v2="v2"))
v2_item = FakeItem("/repo/tests/v2/test_demo.py")
v1_item = FakeItem("/repo/tests/v1/test_demo.py")
bootstrap.mark_plugin_generation([v2_item, v1_item], pytest_module)
assert v2_item.markers == ["v2"]
assert v1_item.markers == ["v1"]

View File

@@ -95,11 +95,16 @@ class TestTransferFailedRetryButtons(unittest.TestCase):
errmsg="未识别到媒体信息",
)
def _close_pending_coro(coro, *args, **kwargs):
"""关闭被调度的协程:测试中事件循环未运行,不关闭会残留 never-awaited 警告。"""
coro.close()
with patch.object(settings, "AI_AGENT_ENABLE", True):
with patch(
"app.chain.message.TransferHistoryOper"
) as history_oper_cls, patch(
"app.chain.message.asyncio.run_coroutine_threadsafe"
"app.chain.message.asyncio.run_coroutine_threadsafe",
side_effect=_close_pending_coro,
) as run_task:
history_oper_cls.return_value.get.return_value = history
with patch.object(chain, "post_message") as post_message:

View File

@@ -0,0 +1,161 @@
from types import SimpleNamespace
from app.schemas import ActionContext, DownloadTask, FileItem
from app.schemas.workflow import ActionResult
from app.workflow.actions import BaseAction
from app.workflow.actions import fetch_downloads as fetch_downloads_module
from app.workflow.actions import scrape_file as scrape_file_module
from app.workflow.actions.fetch_downloads import FetchDownloadsAction
from app.workflow.actions.scrape_file import ScrapeFileAction
from app.workflow.actions.fetch_rss import FetchRssAction
from app.workflow import WorkFlowManager
def test_fetch_downloads_updates_context_downloads(monkeypatch):
"""获取下载任务动作应更新上游上下文中的下载任务。"""
calls = []
class FakeActionChain:
"""模拟下载器查询链。"""
def list_torrents(self, hashs=None, downloader=None, **kwargs):
calls.append((hashs, downloader))
return [SimpleNamespace(path="/downloads/movie.mkv", progress=100)]
monkeypatch.setattr(fetch_downloads_module, "ActionChain", FakeActionChain)
monkeypatch.setattr(fetch_downloads_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
context = ActionContext(
downloads=[
DownloadTask(download_id="hash-1", downloader="qbittorrent"),
]
)
result = FetchDownloadsAction("fetch-downloads").execute(
workflow_id=1,
params={},
context=context,
)
assert calls == [(["hash-1"], "qbittorrent")]
assert result.downloads[0].completed is True
assert result.downloads[0].path == "/downloads/movie.mkv"
def test_scrape_file_keeps_workflow_action_context(monkeypatch):
"""刮削文件动作不应将工作流上下文替换为媒体识别上下文。"""
scraped = []
class FakeStorageChain:
"""模拟存储链。"""
def exists(self, fileitem):
return True
class FakeMediaChain:
"""模拟媒体识别和刮削链。"""
def recognize_by_path(self, path, obtain_images=False):
return SimpleNamespace(meta_info="meta", media_info="media")
def scrape_metadata(self, fileitem, meta=None, mediainfo=None):
scraped.append((fileitem.path, meta, mediainfo))
monkeypatch.setattr(scrape_file_module, "StorageChain", FakeStorageChain)
monkeypatch.setattr(scrape_file_module, "MediaChain", FakeMediaChain)
monkeypatch.setattr(scrape_file_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
monkeypatch.setattr(ScrapeFileAction, "check_cache", lambda self, workflow_id, key: False)
monkeypatch.setattr(ScrapeFileAction, "save_cache", lambda self, workflow_id, data: None)
context = ActionContext(
fileitems=[
FileItem(path="/library/movie.mkv", storage="local", type="file"),
]
)
result = ScrapeFileAction("scrape-file").execute(
workflow_id=1,
params={},
context=context,
)
assert result is context
assert result.fileitems[0].path == "/library/movie.mkv"
assert scraped == [("/library/movie.mkv", "meta", "media")]
def test_execute_with_inputs_maps_contract_inputs_outputs_and_runtime(monkeypatch):
"""新版动作桥接方法应按契约映射输入、输出和运行期信息。"""
class ContractAction(BaseAction):
"""测试动作契约桥接。"""
contract = {
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
}
@classmethod
@property
def name(cls) -> str:
return "契约动作"
@classmethod
@property
def description(cls) -> str:
return "测试契约动作"
@classmethod
@property
def data(cls) -> dict:
return {}
@property
def success(self) -> bool:
return True
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""执行测试动作。"""
_ = workflow_id, params
context.downloads = [
DownloadTask(download_id=f"{item}-hash", downloader="qbittorrent")
for item in context.torrents
]
self.job_done("完成")
return context
result = ContractAction("contract").execute_with_inputs(
workflow_id=1,
params={},
inputs={"torrents": ["movie"]},
runtime={"attempt": 1, "max_attempts": 1, "cancel_token": object()},
context=ActionContext(),
)
assert isinstance(result, ActionResult)
assert result.outputs["downloads"][0].download_id == "movie-hash"
assert result.context.runtime_state["current_action_runtime"] == {
"attempt": 1,
"max_attempts": 1,
}
path_result = ContractAction("contract").execute_with_inputs(
workflow_id=1,
params={},
inputs={"outputs.FetchRssAction.torrents": ["legacy"]},
runtime={},
context=ActionContext(),
)
assert path_result.outputs["downloads"][0].download_id == "legacy-hash"
def test_workflow_manager_list_actions_exposes_contract():
"""动作列表应返回固定输入输出契约。"""
manager = object.__new__(WorkFlowManager)
manager._actions = {"FetchRssAction": FetchRssAction}
actions = manager.list_actions()
assert actions[0]["contract"]["outputs"][0]["name"] == "torrents"
assert actions[0]["contract"]["condition_fields"][0]["label"] == "资源"

View File

@@ -0,0 +1,787 @@
import base64
import pickle
import threading
import time
from types import SimpleNamespace
from app.chain import workflow as workflow_module
from app.schemas import Action, ActionContext, ActionResult
from app.schemas.types import EventType
from app import workflow as workflow_package
def _build_workflow(current_action=None, context=None, actions=None, flows=None,
execution_config=None, execution_state=None):
"""构造最小工作流对象。"""
return SimpleNamespace(
id=1,
name="测试工作流",
actions=actions if actions is not None else [
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
],
flows=flows if flows is not None else [
{"id": "flow-1", "source": "A", "target": "B", "animated": True},
],
current_action=current_action,
context=context,
execution_config=execution_config or {},
execution_state=execution_state or {},
)
def _encoded_context(context: ActionContext) -> dict:
"""编码工作流恢复上下文。"""
return {
"content": base64.b64encode(pickle.dumps(context)).decode("utf-8"),
}
class _FakeWorkflowManager:
"""记录执行动作的工作流管理器。"""
def __init__(self, calls, results=None, contracts=None):
self.calls = calls
self.results = results or {}
self.contracts = contracts or {}
self.received_inputs = []
def execute(self, workflow_id, action, context=None, inputs=None, runtime=None, cancel_token=None):
"""执行伪动作并记录新版输入。"""
self.calls.append(action.id)
self.received_inputs.append((action.id, inputs or {}, runtime or {}, cancel_token))
result = self.results.get(action.id)
if callable(result):
return result(action, context or ActionContext())
if result:
return result
return ActionResult(success=True, message=f"{action.name}完成", context=context or ActionContext())
def excute(self, workflow_id, action, context=None):
"""兼容历史执行方法。"""
result = self.execute(workflow_id, action, context)
return result.success, result.message, result.context
def get_action_contract(self, action_type):
"""获取伪动作契约。"""
return self.contracts.get(action_type) or {}
class _FakeWorkflowOper:
"""记录工作流持久化调用。"""
def __init__(self, workflow):
self.workflow = workflow
self.steps = []
self.started = False
self.failed_result = None
self.succeeded = False
def reset(self, wid):
"""模拟重置工作流。"""
_ = wid
return True
def get(self, wid):
"""返回预置工作流。"""
_ = wid
return self.workflow
def start(self, wid):
"""记录启动调用。"""
_ = wid
self.started = True
return True
def step(self, wid, action_id, context, execution_state=None):
"""记录步骤持久化数据。"""
self.steps.append(
{
"wid": wid,
"action_id": action_id,
"context": context,
"execution_state": execution_state,
}
)
return True
def fail(self, wid, result):
"""记录失败结果。"""
_ = wid
self.failed_result = result
return True
def success(self, wid, result=None):
"""记录成功结果。"""
_ = wid, result
self.succeeded = True
return True
class _OpaqueValue:
"""模拟无法直接 JSON 序列化的值。"""
__slots__ = ()
def __str__(self):
return "opaque-value"
def test_workflow_executor_resumes_downstream_nodes(monkeypatch):
"""恢复执行时应释放已完成节点的后继节点。"""
calls = []
fake_manager = _FakeWorkflowManager(calls)
workflow = _build_workflow(
current_action="A",
context=_encoded_context(ActionContext()),
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["B"]
assert executor.success is True
assert executor.context.progress == 100
def test_workflow_executor_restores_structured_context(monkeypatch):
"""恢复执行时应兼容新版结构化上下文存储格式。"""
calls = []
fake_manager = _FakeWorkflowManager(calls)
workflow = _build_workflow(
current_action="A",
context={
"workflow_context": {"trace_id": "wf-1"},
"node_outputs": {"A": {"items": ["movie"]}},
"progress": 50,
},
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["B"]
assert executor.context.workflow_context["trace_id"] == "wf-1"
assert executor.context.node_outputs["A"]["items"] == ["movie"]
def test_workflow_executor_reports_incremental_progress(monkeypatch):
"""顺序工作流的中间进度应按已完成比例计算。"""
calls = []
progresses = []
fake_manager = _FakeWorkflowManager(calls)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(
_build_workflow(),
step_callback=lambda action, context: progresses.append(context.progress),
)
executor.execute()
assert calls == ["A", "B"]
assert progresses == [50, 100]
def test_workflow_executor_skips_false_condition_branch(monkeypatch):
"""条件边不满足时应跳过对应分支,并继续执行满足条件的分支。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"items": ["movie"]}
)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
],
flows=[
{"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.items.count == 0"},
{"id": "flow-ac", "source": "A", "target": "C", "data": {"condition": "outputs.A.items.count > 0"}},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["A", "C"]
assert executor.success is True
assert executor.context.progress == 100
assert executor.context.node_outputs["A"]["items"] == ["movie"]
def test_workflow_executor_all_success_join_waits_parallel_branches(monkeypatch):
"""默认汇合策略应等待所有上游分支成功后再执行目标节点。"""
calls = []
joined_outputs = {}
def run_join(action, context):
"""记录汇合节点读取到的上游输出。"""
joined_outputs.update(context.node_outputs)
return ActionResult(success=True, message=f"{action.name}完成", context=context)
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"value": "A"}
),
"B": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"value": "B"}
),
"C": run_join,
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
],
flows=[
{"id": "flow-ac", "source": "A", "target": "C"},
{"id": "flow-bc", "source": "B", "target": "C"},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert set(calls) == {"A", "B", "C"}
assert calls[-1] == "C"
assert joined_outputs["A"] == {"value": "A"}
assert joined_outputs["B"] == {"value": "B"}
def test_workflow_executor_any_success_join_runs_after_available_branch(monkeypatch):
"""any_success 汇合策略应允许任一满足条件的上游分支触发目标节点。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"items": ["movie"]}
)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
{"id": "D", "type": "FakeAction", "name": "动作D", "data": {"join_policy": "any_success"}},
],
flows=[
{"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.items.count == 0"},
{"id": "flow-ac", "source": "A", "target": "C", "condition": "outputs.A.items.count > 0"},
{"id": "flow-bd", "source": "B", "target": "D"},
{"id": "flow-cd", "source": "C", "target": "D"},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["A", "C", "D"]
assert executor.context.progress == 100
def test_workflow_executor_all_done_join_can_continue_after_failure(monkeypatch):
"""continue 失败策略配合 all_done 汇合时应继续执行收尾节点。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(success=False, message=f"{action.name}失败", context=context)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"fail_policy": "continue"}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {"join_policy": "all_done"}},
],
flows=[
{"id": "flow-ac", "source": "A", "target": "C"},
{"id": "flow-bc", "source": "B", "target": "C"},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert set(calls) == {"A", "B", "C"}
assert calls[-1] == "C"
assert executor.has_failure is True
assert executor.success is True
def test_workflow_executor_exclusive_branch_uses_first_matching_flow(monkeypatch):
"""互斥分支应只执行第一条满足条件的出边。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"count": 2}
)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"branch_policy": "exclusive"}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
],
flows=[
{"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.count > 0"},
{"id": "flow-ac", "source": "A", "target": "C", "condition": "outputs.A.count > 1"},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["A", "B"]
assert executor.node_states["C"] == "skipped"
def test_workflow_executor_passes_declared_inputs(monkeypatch):
"""动作输入声明应从 node_outputs 中读取指定路径。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"torrents": ["a", "b"]}
)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{
"id": "B",
"type": "FakeAction",
"name": "动作B",
"data": {"inputs": ["A.torrents", "outputs.A.torrents.count"]},
},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0]
assert b_inputs == {
"A.torrents": ["a", "b"],
"outputs.A.torrents.count": 2,
}
def test_workflow_executor_uses_contract_inputs(monkeypatch):
"""未手写输入声明时应按动作契约读取上下文字段。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
contracts={
"NeedsTorrentsAction": {
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
"outputs": [],
}
},
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"torrents": ["a", "b"]}
)
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "NeedsTorrentsAction", "name": "动作B", "data": {}},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0]
assert b_inputs == {"torrents": ["a", "b"]}
def test_workflow_executor_persists_structured_state(monkeypatch):
"""步骤回调应收到可持久化的结构化执行状态。"""
calls = []
states = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"items": ["movie"]}
)
}
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(
_build_workflow(actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}], flows=[]),
step_callback=lambda action, context, execution_state, completed: states.append(execution_state),
)
executor.execute()
assert states[-1]["nodes"]["A"]["state"] == "success"
assert states[-1]["outputs"]["A"]["items"] == ["movie"]
assert states[-1]["runtime"]["progress"] == 100
def test_workflow_executor_restores_outputs_from_execution_state(monkeypatch):
"""恢复执行时应从结构化状态读取节点输出并继续判断条件边。"""
calls = []
fake_manager = _FakeWorkflowManager(calls)
workflow = _build_workflow(
execution_state={
"nodes": {
"A": {"state": "success", "attempt": 1},
},
"outputs": {
"A": {"torrents": ["movie"]},
},
},
flows=[
{"id": "flow-ab", "source": "A", "target": "B", "condition": "A.torrents.count > 0"},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert calls == ["B"]
assert executor.context.node_outputs["A"]["torrents"] == ["movie"]
def test_workflow_executor_keeps_execution_state_dict_for_non_json_leaf(monkeypatch):
"""结构化状态遇到不可序列化叶子节点时仍应保持字典结构。"""
calls = []
states = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"opaque": _OpaqueValue()}
)
}
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(
_build_workflow(actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}], flows=[]),
step_callback=lambda action, context, execution_state, completed: states.append(execution_state),
)
executor.execute()
assert isinstance(states[-1], dict)
assert states[-1]["outputs"]["A"]["opaque"] == "opaque-value"
def test_workflow_chain_process_serializes_circular_context(monkeypatch):
"""工作流步骤持久化应清洗循环引用和不可序列化上下文。"""
calls = []
def run_action(action, context):
"""构造包含循环引用的上下文。"""
context.workflow_context["self"] = context.workflow_context
context.workflow_context["opaque"] = _OpaqueValue()
return ActionResult(success=True, message=f"{action.name}完成", context=context)
fake_manager = _FakeWorkflowManager(calls, results={"A": run_action})
workflow = _build_workflow(
actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}],
flows=[{"id": "flow-end", "source": "A", "target": "END", "animated": True}],
)
fake_oper = _FakeWorkflowOper(workflow)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module, "WorkflowOper", lambda: fake_oper)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
success, message = workflow_module.WorkflowChain.process(workflow_id=1)
assert success is True
assert message == ""
assert fake_oper.succeeded is True
saved_workflow_context = fake_oper.steps[-1]["context"]["workflow_context"]
saved_self = saved_workflow_context["self"]
assert saved_workflow_context["opaque"] == "opaque-value"
if isinstance(saved_self, dict):
assert saved_self["self"] == workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER
assert saved_self["opaque"] == "opaque-value"
else:
assert saved_self == workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER
def test_workflow_executor_concurrency_key_serializes_parallel_nodes(monkeypatch):
"""相同 concurrency_key 的并行节点不应同时运行。"""
calls = []
active_count = 0
max_active_count = 0
lock = threading.Lock()
def run_action(action, context):
"""记录同一并发键下的同时运行数量。"""
nonlocal active_count, max_active_count
with lock:
active_count += 1
max_active_count = max(max_active_count, active_count)
time.sleep(0.05)
with lock:
active_count -= 1
return ActionResult(success=True, message=f"{action.name}完成", context=context)
fake_manager = _FakeWorkflowManager(calls, results={"A": run_action, "B": run_action})
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"concurrency_key": "download"}},
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {"concurrency_key": "download"}},
],
flows=[],
execution_config={"max_workers": 2},
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert set(calls) == {"A", "B"}
assert max_active_count == 1
def test_workflow_executor_filter_action_replaces_artifact_outputs(monkeypatch):
"""过滤类动作默认应替换列表输出,避免把过滤前数据重新合并回来。"""
calls = []
fake_manager = _FakeWorkflowManager(
calls,
results={
"A": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"torrents": ["old", "keep"]}
),
"B": lambda action, context: ActionResult(
success=True,
message=f"{action.name}完成",
context=context,
outputs={"torrents": ["keep"]}
),
}
)
workflow = _build_workflow(
actions=[
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
{"id": "B", "type": "FilterTorrentsAction", "name": "过滤资源", "data": {}},
],
)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
executor = workflow_module.WorkflowExecutor(workflow)
executor.execute()
assert executor.context.torrents == ["keep"]
assert executor.context.artifacts["torrents"] == ["keep"]
def test_workflow_executor_stop_is_not_success(monkeypatch):
"""停止信号不应被执行器汇报为成功完成。"""
calls = []
fake_manager = _FakeWorkflowManager(calls)
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: True)
executor = workflow_module.WorkflowExecutor(_build_workflow())
executor.execute()
assert calls == []
assert executor.stopped is True
assert executor.success is False
assert executor.errmsg == "工作流已停止"
def test_workflow_context_merge_preserves_runtime_objects():
"""合并上下文时应保留运行时对象,而不是转成字典。"""
executor = object.__new__(workflow_module.WorkflowExecutor)
executor.context = ActionContext()
runtime_torrent = SimpleNamespace(title="runtime torrent")
result_context = ActionContext()
result_context.torrents.append(runtime_torrent)
executor.merge_context(result_context)
assert executor.context.torrents[0] is runtime_torrent
class _FakeEventManager:
"""记录事件监听器注册和移除次数。"""
def __init__(self):
self.added = []
self.removed = []
def add_event_listener(self, event_type, handler):
self.added.append(event_type)
def remove_event_listener(self, event_type, handler):
self.removed.append(event_type)
def test_workflow_event_listener_keeps_shared_handler_until_last_workflow(monkeypatch):
"""同一事件下移除单个工作流时不应断开其他工作流监听。"""
fake_eventmanager = _FakeEventManager()
manager = object.__new__(workflow_package.WorkFlowManager)
manager._lock = threading.Lock()
manager._event_workflows = {}
monkeypatch.setattr(workflow_package, "eventmanager", fake_eventmanager)
manager.register_workflow_event(1, EventType.DownloadAdded.value)
manager.register_workflow_event(2, EventType.DownloadAdded.value)
manager.remove_workflow_event(1, EventType.DownloadAdded.value)
assert fake_eventmanager.added == [EventType.DownloadAdded]
assert fake_eventmanager.removed == []
assert manager.get_event_workflows() == {EventType.DownloadAdded.value: [2]}
manager.remove_workflow_event(2, EventType.DownloadAdded.value)
assert fake_eventmanager.removed == [EventType.DownloadAdded]
assert manager.get_event_workflows() == {}
def test_workflow_manager_retries_action_until_success(monkeypatch):
"""动作管理器应按 retry 配置重试失败动作。"""
class RetryAction:
"""模拟第二次才成功的动作。"""
call_count = 0
def __init__(self, action_id):
self.action_id = action_id
def execute_with_inputs(self, workflow_id, params, inputs, runtime, context):
"""执行动作并在第二次返回成功。"""
_ = workflow_id, params, inputs, runtime
RetryAction.call_count += 1
if RetryAction.call_count == 1:
return ActionResult(success=False, message="第一次失败", context=context)
return ActionResult(success=True, message="第二次成功", context=context, outputs={"ok": True})
manager = object.__new__(workflow_package.WorkFlowManager)
manager._actions = {"RetryAction": RetryAction}
monkeypatch.setattr(workflow_package.global_vars, "is_workflow_stopped", lambda workflow_id: False)
result = manager.execute(
workflow_id=1,
action=Action(
id="retry",
type="RetryAction",
name="重试动作",
data={"retry": {"max_attempts": 2, "interval": 0}},
),
context=ActionContext(),
)
assert result.success is True
assert result.attempts == 2
assert result.outputs == {"ok": True}
assert RetryAction.call_count == 2

View File

@@ -1,2 +1,2 @@
APP_VERSION = 'v2.13.4'
FRONTEND_VERSION = 'v2.13.4'
APP_VERSION = 'v2.13.5-1'
FRONTEND_VERSION = 'v2.13.5'