mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-21 23:44:31 +08:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
871d1ec0d8 | ||
|
|
ca1dbdf843 | ||
|
|
e77bef7cf1 | ||
|
|
f4011d3ac2 | ||
|
|
d0b62523a0 | ||
|
|
a9b1f7e9c9 | ||
|
|
fc8933c648 | ||
|
|
51981d151e | ||
|
|
97cfcda03c | ||
|
|
a2984530f8 | ||
|
|
7474ecd02f | ||
|
|
9056caae40 | ||
|
|
fd280a49b7 | ||
|
|
df75f42753 | ||
|
|
0d2c324e28 | ||
|
|
dc0ee2b466 | ||
|
|
781b1ce2aa | ||
|
|
791f1fe4ac | ||
|
|
6405ff1191 |
7
.github/workflows/test.yml
vendored
7
.github/workflows/test.yml
vendored
@@ -11,6 +11,13 @@ on:
|
||||
# 允许手动触发
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: unit-tests-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
from app.api.endpoints import auth, login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic, llm, notification
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
api_router.include_router(user.router, prefix="/user", tags=["user"])
|
||||
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])
|
||||
|
||||
70
app/api/endpoints/auth.py
Normal file
70
app/api/endpoints/auth.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app import schemas
|
||||
from app.core.auth_bridge import build_token_response, consume_plugin_auth_ticket
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.models.passkey import PassKey
|
||||
from app.db.models.user import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AuthExchangeRequest(BaseModel):
|
||||
"""
|
||||
插件认证票据兑换请求。
|
||||
"""
|
||||
|
||||
ticket: str
|
||||
|
||||
|
||||
def _system_auth_providers() -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取系统内建的匿名登录方式摘要。
|
||||
|
||||
:return: 系统认证提供方列表
|
||||
"""
|
||||
has_passkey = bool(PassKey.list(db=None))
|
||||
return [
|
||||
{
|
||||
"id": "system:passkey",
|
||||
"type": "system",
|
||||
"method": "passkey",
|
||||
"name": "通行密钥",
|
||||
"icon": "material-symbols:passkey",
|
||||
"enabled": has_passkey,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@router.get("/providers", summary="查询登录认证提供方", response_model=list[dict])
|
||||
def auth_providers() -> list[dict[str, Any]]:
|
||||
"""
|
||||
查询系统和插件提供的登录认证入口。
|
||||
|
||||
:return: 认证提供方摘要列表
|
||||
"""
|
||||
providers = _system_auth_providers()
|
||||
providers.extend(PluginManager().get_plugin_auth_providers())
|
||||
return [provider for provider in providers if provider.get("enabled", True)]
|
||||
|
||||
|
||||
@router.post("/exchange", summary="兑换插件认证登录票据", response_model=schemas.Token)
|
||||
def auth_exchange(body: AuthExchangeRequest) -> schemas.Token:
|
||||
"""
|
||||
将插件认证成功后生成的一次性票据兑换为系统 Token。
|
||||
|
||||
:param body: 票据兑换请求
|
||||
:return: 标准登录 Token
|
||||
"""
|
||||
ticket_data = consume_plugin_auth_ticket(body.ticket)
|
||||
if not ticket_data:
|
||||
raise HTTPException(status_code=401, detail="认证票据无效或已过期")
|
||||
|
||||
user = User.get(db=None, rid=ticket_data.get("user_id"))
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
return build_token_response(user)
|
||||
@@ -22,6 +22,10 @@ from app.schemas.types import EventType, EVENT_TYPE_NAMES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
WORKFLOW_TRIGGER_TIMER = "timer"
|
||||
WORKFLOW_TRIGGER_EVENT = "event"
|
||||
WORKFLOW_TRIGGER_MANUAL = "manual"
|
||||
|
||||
|
||||
@router.get("/", summary="所有工作流", response_model=List[schemas.Workflow])
|
||||
async def list_workflows(
|
||||
@@ -148,16 +152,20 @@ async def workflow_fork(
|
||||
except json.JSONDecodeError:
|
||||
return schemas.Response(success=False, message="context字段JSON格式错误")
|
||||
|
||||
try:
|
||||
event_conditions = json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {}
|
||||
except json.JSONDecodeError:
|
||||
return schemas.Response(success=False, message="event_conditions字段JSON格式错误")
|
||||
|
||||
share_id = workflow.id
|
||||
# 创建工作流
|
||||
workflow_dict = {
|
||||
"name": workflow.name,
|
||||
"description": workflow.description,
|
||||
"timer": workflow.timer,
|
||||
"trigger_type": workflow.trigger_type or "timer",
|
||||
"trigger_type": workflow.trigger_type or WORKFLOW_TRIGGER_TIMER,
|
||||
"event_type": workflow.event_type,
|
||||
"event_conditions": json.loads(workflow.event_conditions or "{}")
|
||||
if workflow.event_conditions
|
||||
else {},
|
||||
"event_conditions": event_conditions,
|
||||
"actions": actions,
|
||||
"flows": flows,
|
||||
"context": context,
|
||||
@@ -170,11 +178,11 @@ async def workflow_fork(
|
||||
return schemas.Response(success=False, message="已存在相同名称的工作流")
|
||||
|
||||
# 创建新工作流
|
||||
workflow = await Workflow(**workflow_dict).async_create(db)
|
||||
workflow_obj = await Workflow(**workflow_dict).async_create(db)
|
||||
|
||||
# 更新复用次数
|
||||
if workflow:
|
||||
await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=workflow.id)
|
||||
if workflow_obj and share_id:
|
||||
await MoviePilotServerHelper.async_workflow_fork_by_id(share_id=share_id)
|
||||
|
||||
return schemas.Response(success=True, message="复用成功")
|
||||
|
||||
@@ -225,14 +233,23 @@ def start_workflow(
|
||||
workflow = WorkflowOper(db).get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not workflow.trigger_type or workflow.trigger_type == "timer":
|
||||
trigger_type = workflow.trigger_type or WORKFLOW_TRIGGER_TIMER
|
||||
if trigger_type == WORKFLOW_TRIGGER_TIMER and not workflow.timer:
|
||||
return schemas.Response(success=False, message="定时工作流缺少定时器配置")
|
||||
if trigger_type not in {
|
||||
WORKFLOW_TRIGGER_TIMER,
|
||||
WORKFLOW_TRIGGER_EVENT,
|
||||
WORKFLOW_TRIGGER_MANUAL,
|
||||
}:
|
||||
return schemas.Response(success=False, message="工作流触发类型不支持")
|
||||
# 先更新状态,事件触发注册会重新读取工作流并跳过暂停状态。
|
||||
workflow.update_state(db, workflow_id, "W")
|
||||
if trigger_type == WORKFLOW_TRIGGER_TIMER:
|
||||
# 添加定时任务
|
||||
Scheduler().update_workflow_job(workflow)
|
||||
else:
|
||||
elif trigger_type == WORKFLOW_TRIGGER_EVENT:
|
||||
# 事件触发:添加到事件触发器
|
||||
WorkFlowManager().load_workflow_events(workflow_id)
|
||||
# 更新状态
|
||||
workflow.update_state(db, workflow_id, "W")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -251,10 +268,10 @@ def pause_workflow(
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
# 根据触发类型进行不同处理
|
||||
if workflow.trigger_type == "timer":
|
||||
if workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
|
||||
# 定时触发:移除定时任务
|
||||
Scheduler().remove_workflow_job(workflow)
|
||||
elif workflow.trigger_type == "event":
|
||||
elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT:
|
||||
# 事件触发:从事件触发器中移除
|
||||
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
|
||||
# 停止工作流
|
||||
@@ -319,8 +336,11 @@ def update_workflow(
|
||||
wf.update(db, workflow.model_dump())
|
||||
# 更新后的工作流对象
|
||||
updated_workflow = workflow_oper.get(workflow.id)
|
||||
# 更新定时任务
|
||||
Scheduler().update_workflow_job(updated_workflow)
|
||||
scheduler = Scheduler()
|
||||
scheduler.remove_workflow_job(updated_workflow)
|
||||
if not updated_workflow.trigger_type or updated_workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
|
||||
if updated_workflow.timer:
|
||||
scheduler.update_workflow_job(updated_workflow)
|
||||
# 更新事件注册
|
||||
WorkFlowManager().update_workflow_event(updated_workflow)
|
||||
return schemas.Response(success=True, message="更新成功")
|
||||
@@ -338,10 +358,10 @@ def delete_workflow(
|
||||
workflow = WorkflowOper(db).get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
if not workflow.trigger_type or workflow.trigger_type == "timer":
|
||||
if not workflow.trigger_type or workflow.trigger_type == WORKFLOW_TRIGGER_TIMER:
|
||||
# 定时触发:删除定时任务
|
||||
Scheduler().remove_workflow_job(workflow)
|
||||
else:
|
||||
elif workflow.trigger_type == WORKFLOW_TRIGGER_EVENT:
|
||||
# 事件触发:从事件触发器中移除
|
||||
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
|
||||
# 删除工作流
|
||||
|
||||
@@ -2785,9 +2785,16 @@ class SubscribeChain(ChainBase):
|
||||
# 更新剧集列表、开始集数、总集数
|
||||
if not episode_list:
|
||||
# 整季缺失
|
||||
episodes = []
|
||||
start_episode = start_episode or start
|
||||
total_episode = total_episode or total
|
||||
original_start = start if start is not None else 1
|
||||
# 空集列表会被下载链解释为整季下载;当订阅开始集裁掉季初范围时,需要转成显式集数。
|
||||
if start_episode and total_episode and start_episode > original_start:
|
||||
episodes = list(range(start_episode, total_episode + 1))
|
||||
if not episodes:
|
||||
return True, {}
|
||||
else:
|
||||
episodes = []
|
||||
else:
|
||||
# 部分缺失
|
||||
if not start_episode \
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
146
app/core/auth_bridge.py
Normal file
146
app/core/auth_bridge.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
from app import schemas
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
class AuthTicketStore(metaclass=Singleton):
|
||||
"""
|
||||
插件认证一次性票据存储。
|
||||
"""
|
||||
|
||||
_ttl_seconds = 120
|
||||
_max_items = 1024
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化内存票据缓存。
|
||||
"""
|
||||
self._tickets: dict[str, dict[str, Any]] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def create(self, user_id: int, provider_id: str, metadata: Optional[dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
创建短时一次性登录票据。
|
||||
|
||||
:param user_id: 已通过插件认证的本地用户 ID
|
||||
:param provider_id: 认证提供方 ID
|
||||
:param metadata: 插件侧附加信息
|
||||
:return: 一次性票据字符串
|
||||
"""
|
||||
ticket = secrets.token_urlsafe(32)
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
self._cleanup(now)
|
||||
self._tickets[ticket] = {
|
||||
"user_id": int(user_id),
|
||||
"provider_id": provider_id,
|
||||
"metadata": metadata or {},
|
||||
"created_at": now,
|
||||
}
|
||||
return ticket
|
||||
|
||||
def consume(self, ticket: str) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
消费并删除一次性登录票据。
|
||||
|
||||
:param ticket: 登录票据
|
||||
:return: 票据数据,票据不存在或过期时返回 None
|
||||
"""
|
||||
if not ticket:
|
||||
return None
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
data = self._tickets.pop(ticket, None)
|
||||
self._cleanup(now)
|
||||
if not data:
|
||||
return None
|
||||
if now - float(data.get("created_at") or 0) > self._ttl_seconds:
|
||||
return None
|
||||
return data
|
||||
|
||||
def _cleanup(self, now: Optional[float] = None) -> None:
|
||||
"""
|
||||
清理过期或过量的票据缓存。
|
||||
|
||||
:param now: 当前时间戳,未传入时自动读取
|
||||
"""
|
||||
current = now or time.time()
|
||||
expired = [
|
||||
key
|
||||
for key, value in self._tickets.items()
|
||||
if current - float(value.get("created_at") or 0) > self._ttl_seconds
|
||||
]
|
||||
for key in expired:
|
||||
self._tickets.pop(key, None)
|
||||
if len(self._tickets) <= self._max_items:
|
||||
return
|
||||
ordered = sorted(
|
||||
self._tickets.items(),
|
||||
key=lambda item: float(item[1].get("created_at") or 0),
|
||||
)
|
||||
for key, _ in ordered[: len(self._tickets) - self._max_items]:
|
||||
self._tickets.pop(key, None)
|
||||
|
||||
|
||||
def create_plugin_auth_ticket(user_id: int, provider_id: str, metadata: Optional[dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
为插件认证成功的用户创建一次性登录票据。
|
||||
|
||||
:param user_id: 本地用户 ID
|
||||
:param provider_id: 认证提供方 ID
|
||||
:param metadata: 插件侧附加信息
|
||||
:return: 一次性票据字符串
|
||||
"""
|
||||
return AuthTicketStore().create(user_id=user_id, provider_id=provider_id, metadata=metadata)
|
||||
|
||||
|
||||
def consume_plugin_auth_ticket(ticket: str) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
消费插件认证登录票据。
|
||||
|
||||
:param ticket: 登录票据
|
||||
:return: 票据数据,票据不存在或过期时返回 None
|
||||
"""
|
||||
return AuthTicketStore().consume(ticket)
|
||||
|
||||
|
||||
def build_token_response(user: User) -> schemas.Token:
|
||||
"""
|
||||
使用系统统一逻辑构造登录 Token 响应。
|
||||
|
||||
:param user: 已认证的本地用户
|
||||
:return: 标准 Token 响应
|
||||
"""
|
||||
level = SitesHelper().auth_level
|
||||
show_wizard = (
|
||||
not SystemConfigOper().get(SystemConfigKey.SetupWizardState)
|
||||
and not settings.ADVANCED_MODE
|
||||
)
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
userid=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
level=level,
|
||||
),
|
||||
token_type="bearer",
|
||||
super_user=user.is_superuser,
|
||||
user_id=user.id,
|
||||
user_name=user.name,
|
||||
avatar=user.avatar,
|
||||
level=level,
|
||||
permissions=user.permissions or {},
|
||||
wizard=show_wizard,
|
||||
)
|
||||
@@ -9,10 +9,10 @@ import sys
|
||||
import threading
|
||||
from asyncio import AbstractEventLoop
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_origin, get_args
|
||||
from urllib.parse import quote, urlencode, urlparse
|
||||
|
||||
from dotenv import set_key
|
||||
from dotenv import set_key, unset_key
|
||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
@@ -690,6 +690,18 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
|
||||
# 处理 Optional 类型:当值为空字符串且类型允许 None 时,转为 None
|
||||
# 兼容 typing.Union (Python 3.9) 与 types.UnionType (Python 3.10+ PEP 604)
|
||||
origin = get_origin(expected_type)
|
||||
is_union = origin is Union or getattr(origin, "__name__", None) == "UnionType"
|
||||
if (
|
||||
is_union
|
||||
and type(None) in get_args(expected_type)
|
||||
and isinstance(value, str)
|
||||
and not value
|
||||
):
|
||||
return default, str(default) != str(original_value)
|
||||
|
||||
try:
|
||||
if expected_type is bool:
|
||||
if isinstance(value, bool):
|
||||
@@ -812,13 +824,19 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
logger.warning(message)
|
||||
return False, message
|
||||
else:
|
||||
# 当值为 None 时,从 env 文件中删除该键,恢复为默认值
|
||||
if converted_value is None:
|
||||
unset_key(
|
||||
dotenv_path=SystemUtils.get_env_path(),
|
||||
key_to_unset=field_name,
|
||||
)
|
||||
logger.info(f"配置项 '{field_name}' 已清空,从 'app.env' 中移除")
|
||||
return True, message
|
||||
# 如果是列表、字典或集合类型,将其转换为JSON字符串
|
||||
if isinstance(converted_value, (list, dict, set)):
|
||||
value_to_write = json.dumps(converted_value)
|
||||
else:
|
||||
value_to_write = (
|
||||
str(converted_value) if converted_value is not None else ""
|
||||
)
|
||||
value_to_write = str(converted_value)
|
||||
|
||||
set_key(
|
||||
dotenv_path=SystemUtils.get_env_path(),
|
||||
@@ -967,7 +985,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
|
||||
@property
|
||||
def PROXY(self):
|
||||
if self.PROXY_HOST:
|
||||
if self.PROXY_HOST and self.PROXY_HOST.strip():
|
||||
return {
|
||||
"http": self.PROXY_HOST,
|
||||
"https": self.PROXY_HOST,
|
||||
@@ -1009,7 +1027,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
|
||||
@property
|
||||
def PROXY_SERVER(self):
|
||||
if self.PROXY_HOST:
|
||||
if self.PROXY_HOST and self.PROXY_HOST.strip():
|
||||
try:
|
||||
parsed = urlparse(self.PROXY_HOST)
|
||||
if not parsed.scheme:
|
||||
|
||||
@@ -950,6 +950,47 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
})
|
||||
return remotes
|
||||
|
||||
def get_plugin_auth_providers(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
聚合插件声明的登录认证提供方。
|
||||
|
||||
:return: 插件认证入口列表
|
||||
"""
|
||||
providers: List[Dict[str, Any]] = []
|
||||
running_plugins_snapshot = dict(self._running_plugins)
|
||||
for plugin_id, plugin in running_plugins_snapshot.items():
|
||||
if not plugin.get_state():
|
||||
continue
|
||||
if not hasattr(plugin, "get_auth_providers") or not ObjectUtils.check_method(plugin.get_auth_providers):
|
||||
continue
|
||||
try:
|
||||
plugin_providers = plugin.get_auth_providers() or []
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件 {plugin_id} 登录认证提供方出错:{str(e)}")
|
||||
continue
|
||||
render_mode = None
|
||||
dist_path = None
|
||||
if hasattr(plugin, "get_render_mode"):
|
||||
render_mode, dist_path = plugin.get_render_mode()
|
||||
for raw_provider in plugin_providers:
|
||||
if not raw_provider or not isinstance(raw_provider, dict):
|
||||
continue
|
||||
provider = raw_provider.copy()
|
||||
provider["type"] = "plugin"
|
||||
provider["plugin_id"] = plugin_id
|
||||
provider.setdefault("id", f"plugin:{plugin_id}")
|
||||
provider.setdefault("name", plugin.plugin_name)
|
||||
provider.setdefault("enabled", True)
|
||||
if render_mode == "vue" and dist_path:
|
||||
provider.setdefault("component", "AuthPage")
|
||||
provider["remote"] = {
|
||||
"id": plugin_id,
|
||||
"url": self.get_plugin_remote_entry(plugin_id, dist_path),
|
||||
"name": plugin.plugin_name,
|
||||
}
|
||||
providers.append(provider)
|
||||
return providers
|
||||
|
||||
def get_plugin_sidebar_nav(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
聚合所有已启用 Vue 插件的侧栏导航项(get_sidebar_nav)。
|
||||
|
||||
@@ -40,8 +40,12 @@ class Workflow(Base):
|
||||
flows = Column(JSON, default=builtin_list)
|
||||
# 执行上下文
|
||||
context = Column(JSON, default=dict)
|
||||
# 执行配置
|
||||
execution_config = Column(JSON, default=dict)
|
||||
# 结构化执行状态
|
||||
execution_state = Column(JSON, default=dict)
|
||||
# 创建时间
|
||||
add_time = Column(String, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
add_time = Column(String, default=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
# 最后执行时间
|
||||
last_time = Column(String)
|
||||
|
||||
@@ -79,7 +83,7 @@ class Workflow(Base):
|
||||
and_(
|
||||
or_(
|
||||
cls.trigger_type == 'timer',
|
||||
not cls.trigger_type
|
||||
cls.trigger_type.is_(None)
|
||||
),
|
||||
cls.state != 'P'
|
||||
)
|
||||
@@ -93,7 +97,7 @@ class Workflow(Base):
|
||||
and_(
|
||||
or_(
|
||||
cls.trigger_type == 'timer',
|
||||
not cls.trigger_type
|
||||
cls.trigger_type.is_(None)
|
||||
),
|
||||
cls.state != 'P'
|
||||
)
|
||||
@@ -217,6 +221,8 @@ class Workflow(Base):
|
||||
"state": 'W',
|
||||
"result": None,
|
||||
"current_action": None,
|
||||
"context": {},
|
||||
"execution_state": {},
|
||||
"run_count": 0 if reset_count else cls.run_count,
|
||||
})
|
||||
return True
|
||||
@@ -229,30 +235,49 @@ class Workflow(Base):
|
||||
state='W',
|
||||
result=None,
|
||||
current_action=None,
|
||||
context={},
|
||||
execution_state={},
|
||||
run_count=0 if reset_count else cls.run_count,
|
||||
))
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def update_current_action(cls, db, wid: int, action_id: str, context: dict):
|
||||
db.query(cls).filter(cls.id == wid).update({
|
||||
"current_action": cls.current_action + f",{action_id}" if cls.current_action else action_id,
|
||||
def update_current_action(cls, db, wid: int, action_id: str, context: dict,
|
||||
execution_state: Optional[dict] = None):
|
||||
workflow = db.query(cls).filter(cls.id == wid).first()
|
||||
current_actions = []
|
||||
if workflow and workflow.current_action:
|
||||
current_actions = [item for item in workflow.current_action.split(",") if item]
|
||||
if action_id and action_id not in current_actions:
|
||||
current_actions.append(action_id)
|
||||
update_values = {
|
||||
"current_action": ",".join(current_actions),
|
||||
"context": context
|
||||
})
|
||||
}
|
||||
if execution_state is not None:
|
||||
update_values["execution_state"] = execution_state
|
||||
db.query(cls).filter(cls.id == wid).update(update_values)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict):
|
||||
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict,
|
||||
execution_state: Optional[dict] = None):
|
||||
from sqlalchemy import update
|
||||
# 先获取当前current_action
|
||||
result = await db.execute(select(cls.current_action).where(cls.id == wid))
|
||||
current_action = result.scalar()
|
||||
new_current_action = current_action + f",{action_id}" if current_action else action_id
|
||||
current_actions = [item for item in (current_action or "").split(",") if item]
|
||||
if action_id and action_id not in current_actions:
|
||||
current_actions.append(action_id)
|
||||
new_current_action = ",".join(current_actions)
|
||||
|
||||
await db.execute(update(cls).where(cls.id == wid).values(
|
||||
current_action=new_current_action,
|
||||
context=context
|
||||
))
|
||||
update_values = {
|
||||
"current_action": new_current_action,
|
||||
"context": context
|
||||
}
|
||||
if execution_state is not None:
|
||||
update_values["execution_state"] = execution_state
|
||||
await db.execute(update(cls).where(cls.id == wid).values(**update_values))
|
||||
return True
|
||||
|
||||
@@ -91,11 +91,17 @@ class WorkflowOper(DbOper):
|
||||
"""
|
||||
return Workflow.fail(self._db, wid, result)
|
||||
|
||||
def step(self, wid: int, action_id: str, context: dict) -> bool:
|
||||
def step(self, wid: int, action_id: str, context: dict, execution_state: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
步进
|
||||
"""
|
||||
return Workflow.update_current_action(self._db, wid, action_id, context)
|
||||
return Workflow.update_current_action(
|
||||
self._db,
|
||||
wid,
|
||||
action_id,
|
||||
context,
|
||||
execution_state
|
||||
)
|
||||
|
||||
def reset(self, wid: int, reset_count: bool = False) -> bool:
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -708,6 +709,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
# 获取文件大小
|
||||
target_name = new_name or path.name
|
||||
target_path = Path(fileitem.path) / target_name
|
||||
stat = path.stat()
|
||||
|
||||
# 初始化进度回调
|
||||
progress_callback = transfer_process(path.as_posix())
|
||||
@@ -718,6 +720,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
headers.setdefault("Content-Type", "application/octet-stream")
|
||||
headers.setdefault("As-Task", str(task).lower())
|
||||
headers.setdefault("File-Path", encoded_path)
|
||||
headers.setdefault("Content-Length", str(stat.st_size))
|
||||
headers.setdefault("Last-Modified", str(int(stat.st_mtime * 1000)))
|
||||
headers.update(self.__get_upload_hash_headers(path))
|
||||
|
||||
# 创建自定义的文件流,支持进度回调
|
||||
class ProgressFileReader:
|
||||
@@ -783,6 +788,28 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
logger.error(f"【OpenList】上传文件 {path} 失败:{e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __get_upload_hash_headers(path: Path) -> dict:
|
||||
"""
|
||||
计算 OpenList 秒传所需的文件哈希请求头。
|
||||
"""
|
||||
md5_hash = hashlib.md5()
|
||||
sha1_hash = hashlib.sha1()
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(path, "rb") as file_handler:
|
||||
while True:
|
||||
chunk = file_handler.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
md5_hash.update(chunk)
|
||||
sha1_hash.update(chunk)
|
||||
sha256_hash.update(chunk)
|
||||
return {
|
||||
"X-File-Md5": md5_hash.hexdigest(),
|
||||
"X-File-Sha1": sha1_hash.hexdigest(),
|
||||
"X-File-Sha256": sha256_hash.hexdigest(),
|
||||
}
|
||||
|
||||
def detail(self, fileitem: schemas.FileItem) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
获取文件详情
|
||||
|
||||
@@ -116,7 +116,7 @@ class NexusPhpSiteUserInfo(SiteParserBase):
|
||||
has_ucoin, self.bonus = self._parse_ucoin(html)
|
||||
if has_ucoin:
|
||||
return
|
||||
tmps = html.xpath('//a[contains(@href,"mybonus")]/text()') if html else None
|
||||
tmps = html.xpath('//a[contains(@href,"mybonus")]/text()') if html is not None else None
|
||||
if tmps:
|
||||
bonus_text = str(tmps[0]).strip()
|
||||
bonus_match = re.search(r"([\d,.]+)", bonus_text)
|
||||
|
||||
@@ -174,6 +174,21 @@ class _PluginBase(metaclass=ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_auth_providers(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
声明插件提供的登录认证入口。
|
||||
|
||||
返回示例:
|
||||
[{
|
||||
"id": "oidc",
|
||||
"name": "OIDC 登录",
|
||||
"icon": "mdi-openid",
|
||||
"component": "AuthPage",
|
||||
"enabled": True
|
||||
}]
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_module(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件模块声明,用于胁持系统模块实现(方法名:方法实现)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.schemas.context import Context, MediaInfo
|
||||
from app.schemas.download import DownloadTask
|
||||
@@ -19,13 +19,15 @@ class Workflow(BaseModel):
|
||||
timer: Optional[str] = Field(default=None, description="定时器")
|
||||
trigger_type: Optional[str] = Field(default='timer', description="触发类型:timer-定时触发 event-事件触发 manual-手动触发")
|
||||
event_type: Optional[str] = Field(default=None, description="事件类型(当trigger_type为event时使用)")
|
||||
event_conditions: Optional[dict] = Field(default={}, description="事件条件(JSON格式,用于过滤事件)")
|
||||
event_conditions: Optional[dict] = Field(default_factory=dict, description="事件条件(JSON格式,用于过滤事件)")
|
||||
state: Optional[str] = Field(default=None, description="状态")
|
||||
current_action: Optional[str] = Field(default=None, description="已执行动作")
|
||||
result: Optional[str] = Field(default=None, description="任务执行结果")
|
||||
run_count: Optional[int] = Field(default=0, description="已执行次数")
|
||||
actions: Optional[list] = Field(default=[], description="任务列表")
|
||||
flows: Optional[list] = Field(default=[], description="任务流")
|
||||
actions: Optional[list] = Field(default_factory=list, description="任务列表")
|
||||
flows: Optional[list] = Field(default_factory=list, description="任务流")
|
||||
execution_config: Optional[dict] = Field(default_factory=dict, description="工作流执行配置")
|
||||
execution_state: Optional[dict] = Field(default_factory=dict, description="工作流结构化执行状态")
|
||||
add_time: Optional[str] = Field(default=None, description="创建时间")
|
||||
last_time: Optional[str] = Field(default=None, description="最后执行时间")
|
||||
|
||||
@@ -48,8 +50,16 @@ class Action(BaseModel):
|
||||
type: Optional[str] = Field(default=None, description="动作类型 (类名)")
|
||||
name: Optional[str] = Field(default=None, description="动作名称")
|
||||
description: Optional[str] = Field(default=None, description="动作描述")
|
||||
position: Optional[dict] = Field(default={}, description="位置")
|
||||
data: Optional[dict] = Field(default={}, description="参数")
|
||||
position: Optional[dict] = Field(default_factory=dict, description="位置")
|
||||
data: Optional[dict] = Field(default_factory=dict, description="参数")
|
||||
inputs: Optional[List[str]] = Field(default_factory=list, description="动作输入声明")
|
||||
outputs: Optional[dict] = Field(default_factory=dict, description="动作输出声明")
|
||||
join_policy: Optional[str] = Field(default=None, description="多上游节点汇合策略")
|
||||
fail_policy: Optional[str] = Field(default=None, description="动作失败后的工作流处理策略")
|
||||
branch_policy: Optional[str] = Field(default=None, description="多出边分支策略")
|
||||
concurrency_key: Optional[str] = Field(default=None, description="并发互斥键")
|
||||
timeout: Optional[int] = Field(default=None, description="动作执行超时时间(秒)")
|
||||
retry: Optional[dict] = Field(default_factory=dict, description="动作重试策略")
|
||||
|
||||
|
||||
class ActionExecution(BaseModel):
|
||||
@@ -66,16 +76,32 @@ class ActionContext(BaseModel):
|
||||
动作基础上下文,各动作通用数据
|
||||
"""
|
||||
content: Optional[str] = Field(default=None, description="文本类内容")
|
||||
torrents: Optional[List[Context]] = Field(default=[], description="资源列表")
|
||||
medias: Optional[List[MediaInfo]] = Field(default=[], description="媒体列表")
|
||||
fileitems: Optional[List[FileItem]] = Field(default=[], description="文件列表")
|
||||
downloads: Optional[List[DownloadTask]] = Field(default=[], description="下载任务列表")
|
||||
sites: Optional[List[Site]] = Field(default=[], description="站点列表")
|
||||
subscribes: Optional[List[Subscribe]] = Field(default=[], description="订阅列表")
|
||||
execute_history: Optional[List[ActionExecution]] = Field(default=[], description="执行历史")
|
||||
torrents: Optional[List[Context]] = Field(default_factory=list, description="资源列表")
|
||||
medias: Optional[List[MediaInfo]] = Field(default_factory=list, description="媒体列表")
|
||||
fileitems: Optional[List[FileItem]] = Field(default_factory=list, description="文件列表")
|
||||
downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表")
|
||||
sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表")
|
||||
subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表")
|
||||
workflow_context: Optional[dict] = Field(default_factory=dict, description="工作流全局上下文")
|
||||
node_outputs: Optional[dict] = Field(default_factory=dict, description="节点输出数据")
|
||||
runtime_state: Optional[dict] = Field(default_factory=dict, description="运行期状态")
|
||||
artifacts: Optional[dict] = Field(default_factory=dict, description="大对象引用与产物数据")
|
||||
execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史")
|
||||
progress: Optional[int] = Field(default=0, description="执行进度(%)")
|
||||
|
||||
|
||||
class ActionResult(BaseModel):
|
||||
"""
|
||||
动作执行结果。
|
||||
"""
|
||||
success: Optional[bool] = Field(default=True, description="动作是否执行成功")
|
||||
message: Optional[str] = Field(default=None, description="动作执行消息")
|
||||
context: Optional[ActionContext] = Field(default=None, description="动作执行后的上下文")
|
||||
outputs: Optional[dict[str, Any]] = Field(default_factory=dict, description="当前节点显式输出")
|
||||
next_policy: Optional[str] = Field(default=None, description="动作完成后的调度策略")
|
||||
attempts: Optional[int] = Field(default=1, description="动作实际尝试次数")
|
||||
|
||||
|
||||
class ActionFlow(BaseModel):
|
||||
"""
|
||||
工作流流程
|
||||
@@ -84,6 +110,10 @@ class ActionFlow(BaseModel):
|
||||
source: Optional[str] = Field(default=None, description="源动作")
|
||||
target: Optional[str] = Field(default=None, description="目标动作")
|
||||
animated: Optional[bool] = Field(default=True, description="是否动画流程")
|
||||
data: Optional[dict] = Field(default_factory=dict, description="流程扩展配置")
|
||||
condition: Optional[str] = Field(default=None, description="流转条件表达式")
|
||||
join_policy: Optional[str] = Field(default=None, description="目标节点汇合策略")
|
||||
branch_policy: Optional[str] = Field(default=None, description="源节点分支策略")
|
||||
|
||||
|
||||
class WorkflowShare(BaseModel):
|
||||
|
||||
@@ -3,6 +3,19 @@ from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
# urllib3-future 覆盖 urllib3 命名空间后删除了 format_header_param,导致 telebot 崩溃,需在加载模块前打补丁
|
||||
try:
|
||||
import urllib3.fields as _urllib3_fields
|
||||
|
||||
if not hasattr(_urllib3_fields, "format_header_param") and hasattr(
|
||||
_urllib3_fields, "format_header_param_rfc2231"
|
||||
):
|
||||
_urllib3_fields.format_header_param = (
|
||||
_urllib3_fields.format_header_param_rfc2231
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from app.chain.system import SystemChain
|
||||
from app.core.config import global_vars
|
||||
from app.helper.server import MoviePilotServerHelper
|
||||
@@ -12,7 +25,11 @@ from app.startup.modules_initializer import init_modules, stop_modules
|
||||
from app.startup.monitor_initializer import stop_monitor, init_monitor
|
||||
from app.startup.plugins_initializer import init_plugins, stop_plugins, sync_plugins
|
||||
from app.startup.routers_initializer import init_routers
|
||||
from app.startup.scheduler_initializer import stop_scheduler, init_scheduler, init_plugin_scheduler
|
||||
from app.startup.scheduler_initializer import (
|
||||
stop_scheduler,
|
||||
init_scheduler,
|
||||
init_plugin_scheduler,
|
||||
)
|
||||
from app.startup.workflow_initializer import init_workflow, stop_workflow
|
||||
from app.utils.http import aclose_shared_async_transports
|
||||
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
"""测试辅助工具(主程序与插件仓共享)。
|
||||
|
||||
提供测试期对 ``sys.modules`` 的临时打桩能力,保证打桩在使用后还原,避免测试间
|
||||
因残留假模块而相互污染。仅供测试使用,不参与运行时逻辑。
|
||||
汇集主程序与插件仓共用的测试 harness,仅供测试使用、不参与运行时逻辑:
|
||||
|
||||
- :mod:`app.testing.stub`:测试期对 ``sys.modules`` 的临时打桩并自动还原,避免残留假模块相互污染;
|
||||
- :mod:`app.testing.bootstrap`:隔离 CONFIG_DIR、建表、插件目录注入与 v1/v2 marker 等引导逻辑;
|
||||
- :mod:`app.testing.network_guard`:autouse 拦截测试期对非本地主机的真实出站。
|
||||
|
||||
子模块各自按需 import(如 ``network_guard`` 依赖 pytest),故此处只 re-export 无第三方依赖的
|
||||
:func:`stub_modules`,保持 ``import app.testing`` 不引入 pytest 等测试期依赖。
|
||||
"""
|
||||
from app.testing.stub import stub_modules
|
||||
|
||||
|
||||
169
app/testing/bootstrap.py
Normal file
169
app/testing/bootstrap.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""测试引导共享实现(主程序与插件仓同源)。
|
||||
|
||||
主程序 ``tests/conftest.py`` 与各插件仓的极薄 shim(``tests/_bootstrap.py``,仅负责把
|
||||
后端定位并加入 ``sys.path``)都委托到这里,使「隔离 CONFIG_DIR / 建表 / 注入插件目录 /
|
||||
按目录打 v1·v2 marker / 退出清理」等引导逻辑只在主程序维护一处,所有消费方行为与修复一致。
|
||||
其中 :func:`isolate_config_dir` 为主程序与插件仓共用,``prepare_v1/v2_backend`` 与
|
||||
:func:`mark_plugin_generation` 为插件仓专用。
|
||||
|
||||
本模块只依赖标准库,``import`` 期不连库、不触发 ``app.db``:调用方可安全地「先 import 本模块、
|
||||
再隔离 CONFIG_DIR」,不破坏「隔离必须早于首个 ``import app.db``」这一硬约束。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# 本进程隔离出的临时 CONFIG_DIR,兼作幂等标记
|
||||
_isolated_config_dir: Optional[str] = None
|
||||
|
||||
|
||||
def isolate_config_dir() -> str:
|
||||
"""把 ``CONFIG_DIR`` 指向进程私有临时目录,隔离主程序真实库与配置(幂等)。
|
||||
|
||||
``import app.db`` / ``import app.chain.*`` 在 import 期即按 ``settings.CONFIG_PATH`` 连接
|
||||
``user.db``,故本函数必须在首个 ``import app.db`` 之前调用。调用方已显式设置 ``CONFIG_DIR``
|
||||
(如 CI 指定隔离目录)时尊重之、不覆盖。
|
||||
|
||||
:return: 实际生效的 CONFIG_DIR 绝对路径
|
||||
"""
|
||||
global _isolated_config_dir
|
||||
if _isolated_config_dir is not None:
|
||||
return _isolated_config_dir
|
||||
existing = os.environ.get("CONFIG_DIR")
|
||||
if existing:
|
||||
_isolated_config_dir = existing
|
||||
return existing
|
||||
tmp = tempfile.mkdtemp(prefix="mp-test-config-")
|
||||
os.environ["CONFIG_DIR"] = tmp
|
||||
_isolated_config_dir = tmp
|
||||
|
||||
def _cleanup(path: str = tmp, rmtree=shutil.rmtree, sys_mod=sys) -> None:
|
||||
"""进程退出时释放 SQLite 连接池再删临时目录。
|
||||
|
||||
默认参数绑定 ``rmtree``/``path``/``sys_mod``:解释器关停期标准库模块可能已被回收为 ``None``,
|
||||
绑定后仍可安全调用。先 ``Engine.dispose`` 释放 ``user.db`` 连接,规避 Windows 下
|
||||
文件锁导致 ``rmtree`` 静默失败(``ignore_errors``)、残留临时目录。
|
||||
"""
|
||||
try:
|
||||
db_mod = sys_mod.modules.get("app.db")
|
||||
if db_mod is not None:
|
||||
db_mod.Engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
rmtree(path, ignore_errors=True)
|
||||
|
||||
atexit.register(_cleanup)
|
||||
return tmp
|
||||
|
||||
|
||||
def _prepend_sys_path(path: Path) -> None:
|
||||
"""把目录前置到 ``sys.path``(去重),使其内顶层包可被导入。"""
|
||||
value = str(path)
|
||||
if value not in sys.path:
|
||||
sys.path.insert(0, value)
|
||||
|
||||
|
||||
def ensure_sites_stub() -> None:
|
||||
"""为 ``app.helper.sites`` 补最小垫片(仅在缺失时)。
|
||||
|
||||
``app.helper.sites`` 由独立仓库动态拉取,CI / 全新环境无该模块,而众多 ``app.chain.*`` /
|
||||
``app.modules.*`` 在 import 期依赖它。统一补一个最小垫片,省去各测试文件各自打桩;若真实模块
|
||||
已存在(本地已拉取)则用真实模块、不覆盖,不影响真实行为。须在隔离 CONFIG_DIR 之后调用,
|
||||
以免试探性 ``import app.helper.sites`` 触发的连库落到真实库。
|
||||
"""
|
||||
if "app.helper.sites" in sys.modules:
|
||||
return
|
||||
try:
|
||||
import app.helper.sites # noqa: F401 本地已拉取时用真实模块
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
from types import ModuleType
|
||||
stub = ModuleType("app.helper.sites")
|
||||
stub.SitesHelper = object
|
||||
sys.modules["app.helper.sites"] = stub
|
||||
|
||||
|
||||
def ensure_optional_stub(name: str, **attrs) -> None:
|
||||
"""为可选第三方依赖补占位模块(仅在缺失时),可带属性。
|
||||
|
||||
用例 import 的 app 代码会牵入可选三方库(如 psutil / dateparser / Pinyin2Hanzi /
|
||||
qbittorrentapi / transmission_rpc),CI / 全新环境可能未安装。本函数在该库缺失时补一个带
|
||||
指定属性的占位,使 import 不致失败;若已真实安装则保留真实模块、不覆盖。占位为进程级常驻
|
||||
(与 import 生命周期一致、不作用域还原),是「让可选 import 不失败」的垫片——与
|
||||
:func:`stub_modules`(作用域内打桩并还原)属不同用途,故不收进 stub_modules。
|
||||
|
||||
:param name: 可选依赖的顶层模块名
|
||||
:param attrs: 占位模块需暴露的属性(仅在真正创建占位时设置)
|
||||
"""
|
||||
if name in sys.modules:
|
||||
return
|
||||
try:
|
||||
__import__(name)
|
||||
return
|
||||
except ImportError:
|
||||
pass
|
||||
from types import ModuleType
|
||||
module = ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
sys.modules[name] = module
|
||||
|
||||
|
||||
def prepare_backend() -> None:
|
||||
"""隔离 CONFIG_DIR、补 sites 垫片并建表(后端须已在 ``sys.path`` 上)。
|
||||
|
||||
主程序中后端即当前包;插件仓由其 ``tests/_bootstrap.py`` shim 在 import 本模块前
|
||||
先把后端目录注入 ``sys.path``。顺序固定:先隔离 CONFIG_DIR,再补 ``app.helper.sites`` 垫片,
|
||||
最后建表——隔离出的临时库为空,运行期查 ``systemconfig`` 等表会报 no such table,故建表;
|
||||
``init_db`` 仅 import models + create_all,无 alembic/网络、幂等、毫秒级。
|
||||
"""
|
||||
isolate_config_dir()
|
||||
ensure_sites_stub()
|
||||
from app.db.init import init_db
|
||||
init_db()
|
||||
|
||||
|
||||
def prepare_v2_backend(plugins_repo: Path) -> None:
|
||||
"""v2 插件单测引导:``prepare_backend`` + 把 ``<repo>/plugins.v2`` 注入 ``sys.path``。
|
||||
|
||||
与 :func:`prepare_v1_backend` 互斥:v1/v2 存在同名插件包,同一进程同时加载会相互覆盖,
|
||||
须在各自独立的 pytest 会话中运行。
|
||||
|
||||
:param plugins_repo: 插件仓根目录(由调用方 shim 传入)
|
||||
"""
|
||||
prepare_backend()
|
||||
_prepend_sys_path(Path(plugins_repo) / "plugins.v2")
|
||||
|
||||
|
||||
def prepare_v1_backend(plugins_repo: Path) -> None:
|
||||
"""v1 插件单测引导:``prepare_backend`` + 把 ``<repo>/plugins`` 注入 ``sys.path``(与 v2 互斥)。
|
||||
|
||||
:param plugins_repo: 插件仓根目录(由调用方 shim 传入)
|
||||
"""
|
||||
prepare_backend()
|
||||
_prepend_sys_path(Path(plugins_repo) / "plugins")
|
||||
|
||||
|
||||
def mark_plugin_generation(items, pytest_module) -> None:
|
||||
"""按用例所在目录自动给其打 ``v1`` / ``v2`` marker,供按代筛选与分会话运行。
|
||||
|
||||
优先读取 pytest 7+ 的 ``item.path``,旧版 pytest 缺失该属性时回退到 ``item.fspath``。用
|
||||
「不带前导斜杠」的子串匹配(``tests/v2/`` / ``tests/v1/``),兼容相对路径与绝对路径两种
|
||||
运行方式:以 ``pytest tests/v2`` 等相对路径运行时收集路径可能不含前导斜杠。
|
||||
``pytest`` 模块由各仓 conftest 传入,避免本模块在非测试态强依赖 pytest。
|
||||
|
||||
:param items: pytest 收集到的用例集合
|
||||
:param pytest_module: 调用方传入的 ``pytest`` 模块对象
|
||||
"""
|
||||
for item in items:
|
||||
item_path = getattr(item, "path", None)
|
||||
path = str(item_path if item_path is not None else item.fspath).replace("\\", "/")
|
||||
if "tests/v2/" in path:
|
||||
item.add_marker(pytest_module.mark.v2)
|
||||
elif "tests/v1/" in path:
|
||||
item.add_marker(pytest_module.mark.v1)
|
||||
40
app/testing/network_guard.py
Normal file
40
app/testing/network_guard.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""测试网络守卫(主程序与插件仓共享)。
|
||||
|
||||
提供一个 autouse 的 pytest fixture,拦截测试期对非本地主机的真实出站网络。主程序
|
||||
``tests/conftest.py`` 与各插件仓 conftest 只需 ``from app.testing.network_guard import
|
||||
block_real_network`` 即复用同一道守卫——pytest 会把 conftest 命名空间内(含 import 进来的)
|
||||
fixture 一并识别,autouse 自动作用于每个用例,无需逐用例改动。
|
||||
|
||||
仅供测试使用,不参与运行时逻辑。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
# 本地回环/通配地址放行,其余主机一律视为真实出站;getaddrinfo 的 host 可能为 str 或 bytes
|
||||
_ALLOWED_NETWORK_HOSTS = {"127.0.0.1", "::1", "localhost", "0.0.0.0", "::", ""}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def block_real_network(monkeypatch):
|
||||
"""防御纵深:拦截对非本地主机的真实出站,强制测试零真实网络。
|
||||
|
||||
补在各用例自身 mock 之上:某用例万一漏 mock 外部依赖(TMDB / LLM 目录 / 下载器 /
|
||||
媒体服务器 / 任意外链),其真实 DNS 解析会在此被拦并报错,而非静默发请求。本地回环放行
|
||||
(sqlite 等)。asyncio 默认解析器经线程池调用 ``socket.getaddrinfo``,故拦此一处即覆盖
|
||||
同步与异步出站。``monkeypatch`` 在用例结束后自动还原,不影响其他用例与进程退出。
|
||||
"""
|
||||
import socket
|
||||
|
||||
_real_getaddrinfo = socket.getaddrinfo
|
||||
|
||||
def _guarded_getaddrinfo(host, *args, **kwargs):
|
||||
normalized = host.decode() if isinstance(host, (bytes, bytearray)) else host
|
||||
if normalized is not None and normalized not in _ALLOWED_NETWORK_HOSTS:
|
||||
raise RuntimeError(
|
||||
f"测试禁止真实出站网络:尝试解析 {normalized!r};请 mock 对应外部依赖"
|
||||
)
|
||||
return _real_getaddrinfo(host, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(socket, "getaddrinfo", _guarded_getaddrinfo)
|
||||
yield
|
||||
@@ -70,6 +70,8 @@ _DEFAULT_MAX_KEEPALIVE_CONNECTIONS = 20
|
||||
_DEFAULT_MAX_CONNECTIONS = 40
|
||||
# 默认的 keep-alive 连接过期时间(秒)
|
||||
_DEFAULT_KEEPALIVE_EXPIRY = 30
|
||||
# 同步 requests.Session 复用连接时,遇到对端或代理关闭 keep-alive 后允许重试的方法
|
||||
_REQUESTS_RETRY_IDEMPOTENT_METHODS = ("GET", "HEAD", "OPTIONS")
|
||||
# 持有 LRU 淘汰后正在异步关闭的 transport task,避免 fire-and-forget 被 GC 警告
|
||||
_pending_eviction_tasks: set[asyncio.Task] = set()
|
||||
|
||||
@@ -90,6 +92,9 @@ def _get_shared_async_transport(
|
||||
会话级状态由调用方在外层 AsyncClient(transport=...) 实例化时单独配置,
|
||||
每次调用用完即销毁,因此天然无 jar 累积串扰。
|
||||
"""
|
||||
# 规范化代理:拒绝空字符串等非法值,防止 httpx 抛出 Unknown scheme for proxy URL
|
||||
if proxy is not None and (not proxy or not proxy.strip()):
|
||||
proxy = None
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
@@ -344,14 +349,47 @@ class RequestUtils:
|
||||
kwargs.setdefault("timeout", self._timeout)
|
||||
kwargs.setdefault("verify", False)
|
||||
kwargs.setdefault("stream", False)
|
||||
method_upper = method.upper()
|
||||
try:
|
||||
return req_method(method, url, **kwargs)
|
||||
except (
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.ChunkedEncodingError,
|
||||
requests.exceptions.ReadTimeout,
|
||||
) as e:
|
||||
if (
|
||||
self._session is not None
|
||||
and method_upper in _REQUESTS_RETRY_IDEMPOTENT_METHODS
|
||||
):
|
||||
logger.debug(f"keep-alive 连接已失效,同步幂等请求重试一次: {e!r}")
|
||||
try:
|
||||
self._session.close()
|
||||
return req_method(method, url, **kwargs)
|
||||
except requests.exceptions.RequestException as retry_error:
|
||||
error_msg = (
|
||||
str(retry_error)
|
||||
if str(retry_error)
|
||||
else f"未知网络错误 (URL: {url}, Method: {method_upper})"
|
||||
)
|
||||
logger.debug(f"重试后同步请求仍失败: {error_msg}")
|
||||
if raise_exception:
|
||||
raise
|
||||
return None
|
||||
error_msg = (
|
||||
str(e)
|
||||
if str(e)
|
||||
else f"未知网络错误 (URL: {url}, Method: {method_upper})"
|
||||
)
|
||||
logger.debug(f"同步请求失败(不重试): {error_msg}")
|
||||
if raise_exception:
|
||||
raise
|
||||
return None
|
||||
except requests.exceptions.RequestException as e:
|
||||
# 获取更详细的错误信息
|
||||
error_msg = (
|
||||
str(e)
|
||||
if str(e)
|
||||
else f"未知网络错误 (URL: {url}, Method: {method.upper()})"
|
||||
else f"未知网络错误 (URL: {url}, Method: {method_upper})"
|
||||
)
|
||||
logger.debug(f"请求失败: {error_msg}")
|
||||
if raise_exception:
|
||||
@@ -864,12 +902,17 @@ class AsyncRequestUtils:
|
||||
|
||||
# 如果已经是字符串格式,直接返回
|
||||
if isinstance(proxies, str):
|
||||
return proxies
|
||||
return proxies.strip() or None
|
||||
|
||||
# 如果是字典格式,提取http或https代理
|
||||
if isinstance(proxies, dict):
|
||||
# 优先使用https代理,如果没有则使用http代理
|
||||
proxy_url = proxies.get("https") or proxies.get("http")
|
||||
# 先各自 strip,避免空白字符串阻断裂合取或回退到 http 代理
|
||||
https_proxy = proxies.get("https")
|
||||
http_proxy = proxies.get("http")
|
||||
https_proxy = https_proxy.strip() if isinstance(https_proxy, str) else None
|
||||
http_proxy = http_proxy.strip() if isinstance(http_proxy, str) else None
|
||||
proxy_url = https_proxy or http_proxy
|
||||
if proxy_url:
|
||||
return proxy_url
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import threading
|
||||
from time import sleep
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import List, Tuple
|
||||
from time import monotonic, sleep
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.config import global_vars
|
||||
from app.core.event import eventmanager, Event
|
||||
@@ -9,7 +8,7 @@ from app.db.models import Workflow
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.helper.module import ModuleHelper
|
||||
from app.log import logger
|
||||
from app.schemas import ActionContext, Action
|
||||
from app.schemas import ActionContext, Action, ActionResult
|
||||
from app.schemas.types import EventType
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
@@ -63,48 +62,176 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
"""
|
||||
停止
|
||||
"""
|
||||
for event_type_str in list(self._event_workflows.keys()):
|
||||
self.remove_workflow_event(event_type_str=event_type_str)
|
||||
self._actions = {}
|
||||
self._event_workflows = {}
|
||||
|
||||
def excute(self, workflow_id: int, action: Action,
|
||||
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
|
||||
def execute(self, workflow_id: int, action: Action, context: ActionContext = None,
|
||||
inputs: Optional[dict] = None, runtime: Optional[dict] = None,
|
||||
cancel_token: Optional[Any] = None) -> ActionResult:
|
||||
"""
|
||||
执行工作流动作
|
||||
"""
|
||||
if not context:
|
||||
context = ActionContext()
|
||||
if action.type in self._actions:
|
||||
# 实例化之前,清理掉类对象的数据
|
||||
|
||||
# 实例化
|
||||
action_obj = self._actions[action.type](action.id)
|
||||
# 执行
|
||||
logger.info(f"执行动作: {action.id} - {action.name}")
|
||||
try:
|
||||
result_context = action_obj.execute(workflow_id, action.data, context)
|
||||
except Exception as err:
|
||||
logger.error(f"{action.name} 执行失败: {err}")
|
||||
return False, f"{err}", context
|
||||
loop = action.data.get("loop")
|
||||
loop_interval = action.data.get("loop_interval")
|
||||
if loop and loop_interval:
|
||||
while not action_obj.done:
|
||||
if global_vars.is_workflow_stopped(workflow_id):
|
||||
break
|
||||
# 等待
|
||||
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
|
||||
sleep(loop_interval)
|
||||
# 执行
|
||||
logger.info(f"继续执行动作: {action.id} - {action.name}")
|
||||
result_context = action_obj.execute(workflow_id, action.data, result_context)
|
||||
if action_obj.success:
|
||||
logger.info(f"{action.name} 执行成功")
|
||||
else:
|
||||
logger.error(f"{action.name} 执行失败!")
|
||||
return action_obj.success, action_obj.message, result_context
|
||||
else:
|
||||
if action.type not in self._actions:
|
||||
logger.error(f"未找到动作: {action.type} - {action.name}")
|
||||
return False, " ", context
|
||||
return ActionResult(success=False, message=" ", context=context)
|
||||
|
||||
retry_config = self._get_retry_config(action)
|
||||
max_attempts = retry_config["max_attempts"]
|
||||
interval = retry_config["interval"]
|
||||
backoff = retry_config["backoff"]
|
||||
action_result = ActionResult(success=False, message="", context=context)
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
if self._is_cancelled(workflow_id, cancel_token):
|
||||
return ActionResult(success=False, message="工作流已取消", context=context)
|
||||
runtime_data = {
|
||||
**(runtime or {}),
|
||||
"attempt": attempt,
|
||||
"max_attempts": max_attempts,
|
||||
"cancel_token": cancel_token,
|
||||
}
|
||||
action_result = self._execute_action_once(
|
||||
workflow_id=workflow_id,
|
||||
action=action,
|
||||
context=context,
|
||||
inputs=inputs or {},
|
||||
runtime=runtime_data,
|
||||
cancel_token=cancel_token
|
||||
)
|
||||
action_result.attempts = attempt
|
||||
context = action_result.context or context
|
||||
if action_result.success:
|
||||
logger.info(f"{action.name} 执行成功")
|
||||
return action_result
|
||||
if attempt < max_attempts and not self._is_cancelled(workflow_id, cancel_token):
|
||||
wait_seconds = interval * (backoff ** (attempt - 1))
|
||||
logger.info(f"{action.name} 执行失败,{wait_seconds} 秒后重试({attempt}/{max_attempts})...")
|
||||
self._sleep_with_cancel(workflow_id, wait_seconds, cancel_token)
|
||||
|
||||
logger.error(f"{action.name} 执行失败!")
|
||||
return action_result
|
||||
|
||||
def excute(self, workflow_id: int, action: Action,
|
||||
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
|
||||
"""
|
||||
执行工作流动作,兼容历史拼写错误的方法名。
|
||||
"""
|
||||
action_result = self.execute(workflow_id=workflow_id, action=action, context=context)
|
||||
return bool(action_result.success), action_result.message or "", action_result.context or context or ActionContext()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_action_result(result: Any, action_obj: Any, fallback_context: ActionContext) -> ActionResult:
|
||||
"""
|
||||
将旧版动作上下文与新版结构化结果统一为动作执行结果。
|
||||
"""
|
||||
if isinstance(result, ActionResult):
|
||||
result.context = result.context or fallback_context
|
||||
if result.message is None:
|
||||
result.message = action_obj.message
|
||||
return result
|
||||
return ActionResult(
|
||||
success=action_obj.success,
|
||||
message=action_obj.message,
|
||||
context=result or fallback_context
|
||||
)
|
||||
|
||||
def _execute_action_once(self, workflow_id: int, action: Action, context: ActionContext,
|
||||
inputs: dict, runtime: dict, cancel_token: Optional[Any]) -> ActionResult:
|
||||
action_obj = self._actions[action.type](action.id)
|
||||
logger.info(f"执行动作: {action.id} - {action.name}")
|
||||
try:
|
||||
action_result = self._run_action_with_loop(
|
||||
workflow_id=workflow_id,
|
||||
action=action,
|
||||
action_obj=action_obj,
|
||||
context=context,
|
||||
inputs=inputs,
|
||||
runtime=runtime,
|
||||
cancel_token=cancel_token
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"{action.name} 执行失败: {err}")
|
||||
return ActionResult(success=False, message=f"{err}", context=context)
|
||||
return action_result
|
||||
|
||||
def _run_action_with_loop(self, workflow_id: int, action: Action, action_obj: Any,
|
||||
context: ActionContext, inputs: dict, runtime: dict,
|
||||
cancel_token: Optional[Any]) -> ActionResult:
|
||||
timeout = self._get_action_timeout(action)
|
||||
started_at = monotonic()
|
||||
action_result = self._call_action(
|
||||
workflow_id=workflow_id,
|
||||
action=action,
|
||||
action_obj=action_obj,
|
||||
context=context,
|
||||
inputs=inputs,
|
||||
runtime=runtime
|
||||
)
|
||||
loop = self._get_action_data_value(action, "loop")
|
||||
loop_interval = self._get_action_data_value(action, "loop_interval")
|
||||
while loop and loop_interval and not action_obj.done:
|
||||
if self._is_cancelled(workflow_id, cancel_token):
|
||||
return ActionResult(success=False, message="工作流已取消", context=action_result.context or context)
|
||||
if timeout and monotonic() - started_at >= timeout:
|
||||
return ActionResult(success=False, message=f"动作执行超时({timeout}秒)", context=action_result.context or context)
|
||||
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
|
||||
self._sleep_with_cancel(workflow_id, loop_interval, cancel_token)
|
||||
if self._is_cancelled(workflow_id, cancel_token):
|
||||
return ActionResult(success=False, message="工作流已取消", context=action_result.context or context)
|
||||
logger.info(f"继续执行动作: {action.id} - {action.name}")
|
||||
action_result = self._call_action(
|
||||
workflow_id=workflow_id,
|
||||
action=action,
|
||||
action_obj=action_obj,
|
||||
context=action_result.context or context,
|
||||
inputs=inputs,
|
||||
runtime=runtime
|
||||
)
|
||||
return action_result
|
||||
|
||||
def _call_action(self, workflow_id: int, action: Action, action_obj: Any,
|
||||
context: ActionContext, inputs: dict, runtime: dict) -> ActionResult:
|
||||
if hasattr(action_obj, "execute_with_inputs"):
|
||||
result = action_obj.execute_with_inputs(workflow_id, action.data, inputs, runtime, context)
|
||||
else:
|
||||
result = action_obj.execute(workflow_id, action.data, context)
|
||||
return self._normalize_action_result(result, action_obj, context)
|
||||
|
||||
@staticmethod
|
||||
def _get_action_data_value(action: Action, key: str) -> Any:
|
||||
data = action.data or {}
|
||||
return data.get(key) if isinstance(data, dict) else None
|
||||
|
||||
def _get_action_timeout(self, action: Action) -> Optional[int]:
|
||||
timeout = action.timeout or self._get_action_data_value(action, "timeout")
|
||||
return int(timeout) if timeout else None
|
||||
|
||||
def _get_retry_config(self, action: Action) -> dict:
|
||||
retry_config = action.retry or self._get_action_data_value(action, "retry") or {}
|
||||
if not isinstance(retry_config, dict):
|
||||
retry_config = {}
|
||||
return {
|
||||
"max_attempts": max(int(retry_config.get("max_attempts") or 1), 1),
|
||||
"interval": max(float(retry_config.get("interval") or 0), 0),
|
||||
"backoff": max(float(retry_config.get("backoff") or 1), 1),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _is_cancelled(workflow_id: int, cancel_token: Optional[Any]) -> bool:
|
||||
if cancel_token and cancel_token.is_cancelled():
|
||||
return True
|
||||
return global_vars.is_workflow_stopped(workflow_id)
|
||||
|
||||
def _sleep_with_cancel(self, workflow_id: int, seconds: float, cancel_token: Optional[Any]) -> None:
|
||||
deadline = monotonic() + seconds
|
||||
while monotonic() < deadline:
|
||||
if self._is_cancelled(workflow_id, cancel_token):
|
||||
return
|
||||
sleep(min(0.1, deadline - monotonic()))
|
||||
|
||||
def list_actions(self) -> List[dict]:
|
||||
"""
|
||||
@@ -115,6 +242,7 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
"type": key,
|
||||
"name": action.name,
|
||||
"description": action.description,
|
||||
"contract": action.get_contract(),
|
||||
"data": {
|
||||
"label": action.name,
|
||||
**action.data
|
||||
@@ -122,12 +250,21 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
} for key, action in self._actions.items()
|
||||
]
|
||||
|
||||
def get_action_contract(self, action_type: str) -> dict:
|
||||
"""
|
||||
获取动作输入输出契约。
|
||||
"""
|
||||
action = self._actions.get(action_type)
|
||||
if not action or not hasattr(action, "get_contract"):
|
||||
return {}
|
||||
return action.get_contract()
|
||||
|
||||
def update_workflow_event(self, workflow: Workflow):
|
||||
"""
|
||||
更新工作流事件触发器
|
||||
"""
|
||||
# 确保先移除旧的事件监听器
|
||||
self.remove_workflow_event(workflow_id=workflow.id, event_type_str=workflow.event_type)
|
||||
# 工作流可能切换触发事件,先按工作流ID从所有事件映射中移除。
|
||||
self.remove_workflow_event(workflow_id=workflow.id)
|
||||
# 如果工作流是事件触发类型且未被禁用
|
||||
if workflow.trigger_type == "event" and workflow.state != 'P':
|
||||
# 注册事件触发器
|
||||
@@ -154,41 +291,46 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
"""
|
||||
注册工作流事件触发器
|
||||
"""
|
||||
if not event_type_str:
|
||||
return
|
||||
try:
|
||||
event_type = EventType(event_type_str)
|
||||
except ValueError:
|
||||
logger.error(f"无效的事件类型: {event_type_str}")
|
||||
return
|
||||
if event_type in EventType:
|
||||
# 确保先移除旧的事件监听器
|
||||
self.remove_workflow_event(workflow_id, event_type.value)
|
||||
with self._lock:
|
||||
# 添加新的事件监听器
|
||||
eventmanager.add_event_listener(event_type, self._handle_event)
|
||||
# 记录工作流事件触发器
|
||||
if event_type.value not in self._event_workflows:
|
||||
self._event_workflows[event_type.value] = []
|
||||
self._event_workflows[event_type.value].append(workflow_id)
|
||||
eventmanager.add_event_listener(event_type, self._handle_event)
|
||||
# 记录工作流事件触发器
|
||||
if workflow_id not in self._event_workflows[event_type.value]:
|
||||
self._event_workflows[event_type.value].append(workflow_id)
|
||||
logger.info(f"已注册工作流 {workflow_id} 事件触发器: {event_type.value}")
|
||||
|
||||
def remove_workflow_event(self, workflow_id: int, event_type_str: str):
|
||||
def remove_workflow_event(self, workflow_id: Optional[int] = None, event_type_str: Optional[str] = None):
|
||||
"""
|
||||
移除工作流事件触发器
|
||||
"""
|
||||
try:
|
||||
event_type = EventType(event_type_str)
|
||||
except ValueError:
|
||||
logger.error(f"无效的事件类型: {event_type_str}")
|
||||
return
|
||||
if event_type in EventType:
|
||||
event_type_values = [event_type_str] if event_type_str else list(self._event_workflows.keys())
|
||||
for event_type_value in event_type_values:
|
||||
try:
|
||||
event_type = EventType(event_type_value)
|
||||
except ValueError:
|
||||
logger.error(f"无效的事件类型: {event_type_value}")
|
||||
continue
|
||||
with self._lock:
|
||||
eventmanager.remove_event_listener(event_type, self._handle_event)
|
||||
if event_type.value in self._event_workflows:
|
||||
if workflow_id in self._event_workflows[event_type.value]:
|
||||
self._event_workflows[event_type.value].remove(workflow_id)
|
||||
if not self._event_workflows[event_type.value]:
|
||||
del self._event_workflows[event_type.value]
|
||||
logger.info(f"已移除工作流 {workflow_id} 事件触发器")
|
||||
workflow_ids = self._event_workflows.get(event_type.value)
|
||||
if not workflow_ids:
|
||||
continue
|
||||
if workflow_id is None:
|
||||
workflow_ids.clear()
|
||||
elif workflow_id in workflow_ids:
|
||||
workflow_ids.remove(workflow_id)
|
||||
if not workflow_ids:
|
||||
self._event_workflows.pop(event_type.value, None)
|
||||
eventmanager.remove_event_listener(event_type, self._handle_event)
|
||||
logger.info(f"已移除工作流 {workflow_id or ''} 事件触发器")
|
||||
|
||||
def _handle_event(self, event: Event):
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.schemas import ActionContext, ActionParams
|
||||
from app.schemas import ActionContext, ActionParams, ActionResult
|
||||
|
||||
|
||||
class ActionChain(ChainBase):
|
||||
@@ -23,9 +23,13 @@ class BaseAction(ABC):
|
||||
_message = ""
|
||||
# 缓存键值
|
||||
_cache_key = "WorkflowCache-%s"
|
||||
# 动作输入输出契约,由具体动作按需覆盖
|
||||
contract = {}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
self._action_id = action_id
|
||||
self._done_flag = False
|
||||
self._message = ""
|
||||
self.systemconfigoper = SystemConfigOper()
|
||||
|
||||
@classmethod
|
||||
@@ -46,6 +50,41 @@ class BaseAction(ABC):
|
||||
def data(cls) -> dict: # noqa
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_contract(cls) -> dict:
|
||||
"""
|
||||
获取动作输入输出契约。
|
||||
"""
|
||||
contract = getattr(cls, "contract", None) or {}
|
||||
input_fields = cls._build_contract_fields(contract.get("inputs") or [])
|
||||
output_fields = cls._build_contract_fields(contract.get("outputs") or [])
|
||||
return {
|
||||
"inputs": input_fields,
|
||||
"outputs": output_fields,
|
||||
"condition_fields": output_fields,
|
||||
"concurrency_key": contract.get("concurrency_key"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _build_contract_fields(cls, fields: list) -> list:
|
||||
"""
|
||||
标准化动作契约字段。
|
||||
"""
|
||||
result = []
|
||||
for field in fields:
|
||||
if isinstance(field, str):
|
||||
field = {"name": field}
|
||||
if not isinstance(field, dict) or not field.get("name"):
|
||||
continue
|
||||
result.append({
|
||||
"name": field["name"],
|
||||
"label": field.get("label") or field["name"],
|
||||
"kind": field.get("kind") or "scalar",
|
||||
"merge": field.get("merge"),
|
||||
"identity": field.get("identity"),
|
||||
})
|
||||
return result
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
@@ -92,9 +131,12 @@ class BaseAction(ABC):
|
||||
workflow_cache = self.systemconfigoper.get(workflow_key) or {}
|
||||
action_cache = workflow_cache.get(self._action_id) or []
|
||||
if isinstance(data, list):
|
||||
action_cache.extend(data)
|
||||
for item in data:
|
||||
if item not in action_cache:
|
||||
action_cache.append(item)
|
||||
else:
|
||||
action_cache.append(data)
|
||||
if data not in action_cache:
|
||||
action_cache.append(data)
|
||||
workflow_cache[self._action_id] = action_cache
|
||||
self.systemconfigoper.set(workflow_key, workflow_cache)
|
||||
|
||||
@@ -104,3 +146,61 @@ class BaseAction(ABC):
|
||||
执行动作
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_with_inputs(self, workflow_id: int, params: ActionParams, inputs: dict,
|
||||
runtime: dict, context: ActionContext) -> ActionResult:
|
||||
"""
|
||||
使用显式输入与运行期信息执行动作。
|
||||
"""
|
||||
self._apply_inputs_to_context(inputs=inputs, context=context)
|
||||
self._apply_runtime_to_context(runtime=runtime, context=context)
|
||||
result_context = self.execute(workflow_id, params, context)
|
||||
outputs = self._extract_outputs_from_context(result_context)
|
||||
return ActionResult(
|
||||
success=self.success,
|
||||
message=self.message,
|
||||
context=result_context,
|
||||
outputs=outputs
|
||||
)
|
||||
|
||||
def _apply_inputs_to_context(self, inputs: dict, context: ActionContext) -> None:
|
||||
"""
|
||||
将显式输入回填到旧版上下文字段,兼容仍读取 context 的动作。
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
for field in self.get_contract().get("inputs") or []:
|
||||
missing = object()
|
||||
field_name = field["name"]
|
||||
value = inputs.get(field_name, missing)
|
||||
if value is missing:
|
||||
# 兼容旧版节点输入路径,例如 outputs.A.torrents。
|
||||
for input_key, input_value in inputs.items():
|
||||
if isinstance(input_key, str) and input_key.split(".")[-1] == field_name:
|
||||
value = input_value
|
||||
break
|
||||
if value is not missing:
|
||||
setattr(context, field_name, value)
|
||||
|
||||
@staticmethod
|
||||
def _apply_runtime_to_context(runtime: dict, context: ActionContext) -> None:
|
||||
"""
|
||||
将运行期信息写入 runtime_state,供动作和执行状态读取。
|
||||
"""
|
||||
if not runtime:
|
||||
return
|
||||
context.runtime_state = context.runtime_state or {}
|
||||
context.runtime_state["current_action_runtime"] = {
|
||||
key: value for key, value in runtime.items()
|
||||
if key != "cancel_token"
|
||||
}
|
||||
|
||||
def _extract_outputs_from_context(self, context: ActionContext) -> dict[str, Any]:
|
||||
"""
|
||||
按动作契约从上下文提取输出。
|
||||
"""
|
||||
outputs = {}
|
||||
for field in self.get_contract().get("outputs") or []:
|
||||
value = getattr(context, field["name"], None)
|
||||
if value not in (None, "", [], {}):
|
||||
outputs[field["name"]] = value
|
||||
return outputs
|
||||
|
||||
@@ -26,6 +26,12 @@ class AddDownloadAction(BaseAction):
|
||||
添加下载资源
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
|
||||
"concurrency_key": "download",
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._added_downloads = []
|
||||
|
||||
@@ -19,6 +19,11 @@ class AddSubscribeAction(BaseAction):
|
||||
添加订阅
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
|
||||
"outputs": [{"name": "subscribes", "label": "订阅", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._added_subscribes = []
|
||||
|
||||
@@ -16,6 +16,12 @@ class FetchDownloadsAction(BaseAction):
|
||||
获取下载任务
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
|
||||
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list", "merge": "replace"}],
|
||||
"concurrency_key": "download",
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._downloads = []
|
||||
@@ -43,12 +49,19 @@ class FetchDownloadsAction(BaseAction):
|
||||
"""
|
||||
更新downloads中的下载任务状态
|
||||
"""
|
||||
__all_complete = False
|
||||
self._downloads = context.downloads or []
|
||||
if not self._downloads:
|
||||
self.job_done("无下载任务")
|
||||
return context
|
||||
|
||||
for download in self._downloads:
|
||||
if global_vars.is_workflow_stopped(workflow_id):
|
||||
break
|
||||
logger.info(f"获取下载任务 {download.download_id} 状态 ...")
|
||||
torrents = ActionChain().list_torrents(hashs=[download.download_id])
|
||||
torrents = ActionChain().list_torrents(
|
||||
hashs=[download.download_id],
|
||||
downloader=download.downloader,
|
||||
)
|
||||
if not torrents:
|
||||
download.completed = True
|
||||
continue
|
||||
@@ -61,5 +74,5 @@ class FetchDownloadsAction(BaseAction):
|
||||
logger.info(f"下载任务 {download.download_id} 未完成")
|
||||
download.completed = False
|
||||
if all([d.completed for d in self._downloads]):
|
||||
self.job_done()
|
||||
self.job_done("下载任务已全部完成")
|
||||
return context
|
||||
|
||||
@@ -27,6 +27,10 @@ class FetchMediasAction(BaseAction):
|
||||
获取媒体数据
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"outputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
|
||||
|
||||
@@ -30,6 +30,10 @@ class FetchRssAction(BaseAction):
|
||||
获取RSS资源列表
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"outputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._rss_torrents = []
|
||||
|
||||
@@ -30,6 +30,11 @@ class FetchTorrentsAction(BaseAction):
|
||||
搜索站点资源
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
|
||||
"outputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._torrents = []
|
||||
|
||||
@@ -22,6 +22,11 @@ class FilterMediasAction(BaseAction):
|
||||
过滤媒体数据
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
|
||||
"outputs": [{"name": "medias", "label": "媒体", "kind": "list", "merge": "replace"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._medias = []
|
||||
|
||||
@@ -27,6 +27,11 @@ class FilterTorrentsAction(BaseAction):
|
||||
过滤资源数据
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
"outputs": [{"name": "torrents", "label": "资源", "kind": "list", "merge": "replace"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._torrents = []
|
||||
|
||||
@@ -20,6 +20,8 @@ class InvokePluginAction(BaseAction):
|
||||
调用插件
|
||||
"""
|
||||
|
||||
contract = {}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._success = False
|
||||
|
||||
@@ -7,6 +7,8 @@ class NoteAction(BaseAction):
|
||||
备注
|
||||
"""
|
||||
|
||||
contract = {}
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def name(cls) -> str: # noqa
|
||||
|
||||
@@ -24,6 +24,10 @@ class ScanFileAction(BaseAction):
|
||||
整理文件
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._fileitems = []
|
||||
|
||||
@@ -18,6 +18,11 @@ class ScrapeFileAction(BaseAction):
|
||||
刮削文件
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
|
||||
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._scraped_files = []
|
||||
@@ -61,18 +66,18 @@ class ScrapeFileAction(BaseAction):
|
||||
logger.info(f"{fileitem.path} 已刮削过,跳过")
|
||||
continue
|
||||
mediachain = MediaChain()
|
||||
context = mediachain.recognize_by_path(
|
||||
media_context = mediachain.recognize_by_path(
|
||||
fileitem.path,
|
||||
obtain_images=True,
|
||||
)
|
||||
if not context or not context.media_info:
|
||||
if not media_context or not media_context.media_info:
|
||||
_failed_count += 1
|
||||
logger.info(f"{fileitem.path} 未识别到媒体信息,无法刮削")
|
||||
continue
|
||||
mediachain.scrape_metadata(
|
||||
fileitem=fileitem,
|
||||
meta=context.meta_info,
|
||||
mediainfo=context.media_info
|
||||
meta=media_context.meta_info,
|
||||
mediainfo=media_context.media_info
|
||||
)
|
||||
self._scraped_files.append(fileitem)
|
||||
# 保存缓存
|
||||
|
||||
@@ -16,6 +16,8 @@ class SendEventAction(BaseAction):
|
||||
发送事件
|
||||
"""
|
||||
|
||||
contract = {}
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def name(cls) -> str: # noqa
|
||||
|
||||
@@ -20,6 +20,8 @@ class SendMessageAction(BaseAction):
|
||||
发送消息
|
||||
"""
|
||||
|
||||
contract = {}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
|
||||
|
||||
@@ -26,6 +26,15 @@ class TransferFileAction(BaseAction):
|
||||
整理文件
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [
|
||||
{"name": "downloads", "label": "下载任务", "kind": "list"},
|
||||
{"name": "fileitems", "label": "文件", "kind": "list"},
|
||||
],
|
||||
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
|
||||
"concurrency_key": "transfer",
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._fileitems = []
|
||||
|
||||
45
database/versions/7c1a2b3d4e5f_2_2_9.py
Normal file
45
database/versions/7c1a2b3d4e5f_2_2_9.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""2.2.9
|
||||
为工作流增加执行配置和结构化执行状态
|
||||
|
||||
Revision ID: 7c1a2b3d4e5f
|
||||
Revises: d5e6f7a8b9c0
|
||||
Create Date: 2026-06-04
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "7c1a2b3d4e5f"
|
||||
down_revision = "d5e6f7a8b9c0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> bool:
|
||||
"""检查数据表是否已存在指定列。"""
|
||||
if table_name not in inspector.get_table_names():
|
||||
return False
|
||||
return any(column["name"] == column_name for column in inspector.get_columns(table_name))
|
||||
|
||||
|
||||
def _add_json_column_if_missing(table_name: str, column_name: str) -> None:
|
||||
"""缺失时为数据表新增 JSON 列。"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if not _has_column(inspector, table_name, column_name):
|
||||
op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""升级数据库结构。"""
|
||||
_add_json_column_if_missing("workflow", "execution_config")
|
||||
_add_json_column_if_missing("workflow", "execution_state")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""回滚数据库结构。"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if _has_column(inspector, "workflow", "execution_state"):
|
||||
op.drop_column("workflow", "execution_state")
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if _has_column(inspector, "workflow", "execution_config"):
|
||||
op.drop_column("workflow", "execution_config")
|
||||
@@ -22,14 +22,31 @@ def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> b
|
||||
return any(column["name"] == column_name for column in inspector.get_columns(table_name))
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if _has_column(inspector, "subscribe", "episode_priority") is False:
|
||||
op.add_column("subscribe", sa.Column("episode_priority", sa.JSON(), nullable=True))
|
||||
def _ensure_json_column(bind, table_name: str, column_name: str) -> None:
|
||||
"""Add the column as JSON, or fix it if it already exists with the wrong type (PostgreSQL only)."""
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_column(inspector, table_name, column_name):
|
||||
op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True))
|
||||
return
|
||||
if bind.dialect.name != "postgresql":
|
||||
return
|
||||
for col in inspector.get_columns(table_name):
|
||||
if col["name"] != column_name:
|
||||
continue
|
||||
type_name = type(col["type"]).__name__.upper()
|
||||
if type_name in ("JSON", "JSONB"):
|
||||
return
|
||||
# Column exists with wrong type (e.g., INTEGER from an intermediate build); replace it.
|
||||
# Existing values are non-functional garbage so data loss is acceptable.
|
||||
op.drop_column(table_name, column_name)
|
||||
op.add_column(table_name, sa.Column(column_name, sa.JSON(), nullable=True))
|
||||
return
|
||||
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
if _has_column(inspector, "subscribehistory", "episode_priority") is False:
|
||||
op.add_column("subscribehistory", sa.Column("episode_priority", sa.JSON(), nullable=True))
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
_ensure_json_column(bind, "subscribe", "episode_priority")
|
||||
_ensure_json_column(bind, "subscribehistory", "episode_priority")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
50
database/versions/d5e6f7a8b9c0_2_2_8.py
Normal file
50
database/versions/d5e6f7a8b9c0_2_2_8.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""2.2.8
|
||||
修复 episode_priority 列类型:PostgreSQL 下若列为 INTEGER 则重建为 JSON
|
||||
|
||||
Revision ID: d5e6f7a8b9c0
|
||||
Revises: 1f0d2c3b4a5e
|
||||
Create Date: 2026-06-04
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "d5e6f7a8b9c0"
|
||||
down_revision = "1f0d2c3b4a5e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _has_column(inspector: sa.Inspector, table_name: str, column_name: str) -> bool:
|
||||
if table_name not in inspector.get_table_names():
|
||||
return False
|
||||
return any(column["name"] == column_name for column in inspector.get_columns(table_name))
|
||||
|
||||
|
||||
def _fix_episode_priority_type(bind, table_name: str) -> None:
|
||||
"""On PostgreSQL, if episode_priority exists but is not JSON/JSONB, drop and re-add it as JSON."""
|
||||
if bind.dialect.name != "postgresql":
|
||||
return
|
||||
inspector = sa.inspect(bind)
|
||||
if not _has_column(inspector, table_name, "episode_priority"):
|
||||
op.add_column(table_name, sa.Column("episode_priority", sa.JSON(), nullable=True))
|
||||
return
|
||||
for col in inspector.get_columns(table_name):
|
||||
if col["name"] != "episode_priority":
|
||||
continue
|
||||
type_name = type(col["type"]).__name__.upper()
|
||||
if type_name in ("JSON", "JSONB"):
|
||||
return
|
||||
op.drop_column(table_name, "episode_priority")
|
||||
op.add_column(table_name, sa.Column("episode_priority", sa.JSON(), nullable=True))
|
||||
return
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
_fix_episode_priority_type(bind, "subscribe")
|
||||
_fix_episode_priority_type(bind, "subscribehistory")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
11
pytest.ini
Normal file
11
pytest.ini
Normal file
@@ -0,0 +1,11 @@
|
||||
[pytest]
|
||||
# 仅对「无法在本仓修复根因」的已知上游/三方弃用告警做精确忽略,保持测试输出干净、
|
||||
# 让本仓自身的新告警更醒目。本仓代码引发的告警一律不在此忽略,应在源码/用例处修复。
|
||||
filterwarnings =
|
||||
ignore:datetime.datetime.utcfromtimestamp\(\) is deprecated:DeprecationWarning
|
||||
ignore:websockets.legacy is deprecated:DeprecationWarning
|
||||
ignore:websockets.InvalidStatusCode is deprecated:DeprecationWarning
|
||||
ignore:pkg_resources is deprecated as an API:DeprecationWarning
|
||||
ignore:Deprecated call to .pkg_resources.declare_namespace:DeprecationWarning
|
||||
ignore:'crypt' is deprecated:DeprecationWarning
|
||||
ignore:'audioop' is deprecated:DeprecationWarning
|
||||
@@ -1,43 +1,14 @@
|
||||
"""pytest 全局引导:在 import 任何测试模块前把 CONFIG_DIR 指向临时目录并建表,隔离真实库。"""
|
||||
import atexit
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from types import ModuleType
|
||||
"""pytest 全局引导:隔离 CONFIG_DIR、补 sites 垫片、建表、装载网络守卫。
|
||||
|
||||
# 必须早于首个 import app.*:app.db 在导入时即按 CONFIG_PATH 连接 user.db
|
||||
if not os.environ.get("CONFIG_DIR"):
|
||||
_isolated_config_dir = tempfile.mkdtemp(prefix="mp-test-config-")
|
||||
os.environ["CONFIG_DIR"] = _isolated_config_dir
|
||||
引导与网络守卫均复用 ``app/testing`` 的共享 harness(与插件仓 conftest 同源),
|
||||
引导逻辑只在 ``app/testing`` 维护一处。
|
||||
"""
|
||||
# 必须早于首个 import app.db(其在 import 期即按 CONFIG_PATH 连库):prepare_backend 内部
|
||||
# 先隔离 CONFIG_DIR、补 app.helper.sites 垫片,再建表。app/testing 仅依赖标准库、import 不连库,
|
||||
# 故此处先 import 再调用是安全的。
|
||||
from app.testing.bootstrap import prepare_backend
|
||||
|
||||
def _cleanup_isolated_config_dir():
|
||||
"""进程退出时先释放 SQLite 连接池再删临时目录。
|
||||
prepare_backend()
|
||||
|
||||
Windows 下 Engine 若仍持有 user.db 的文件锁,直接 rmtree 会因占用而静默失败
|
||||
(ignore_errors=True)、残留临时目录;先 dispose 释放连接再删可规避。
|
||||
"""
|
||||
try:
|
||||
from app.db import Engine
|
||||
Engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
shutil.rmtree(_isolated_config_dir, ignore_errors=True)
|
||||
|
||||
atexit.register(_cleanup_isolated_config_dir)
|
||||
|
||||
# app.helper.sites 由独立仓库动态拉取(CI / 全新环境无该模块),而众多 app.chain.* /
|
||||
# app.modules.* 在 import 期依赖它。在此统一补一个最小垫片,省去各测试文件各自打桩;
|
||||
# 若真实模块已存在(本地已拉取)则 setdefault 不覆盖,不影响真实行为。
|
||||
if "app.helper.sites" not in sys.modules:
|
||||
try:
|
||||
import app.helper.sites # noqa: F401 本地已拉取时用真实模块
|
||||
except ModuleNotFoundError:
|
||||
_sites_stub = ModuleType("app.helper.sites")
|
||||
_sites_stub.SitesHelper = object
|
||||
sys.modules["app.helper.sites"] = _sites_stub
|
||||
|
||||
# 必须在 CONFIG_DIR 设好之后再 import;空库会让运行期查表报 no such table,故建表
|
||||
from app.db.init import init_db # noqa: E402
|
||||
|
||||
init_db()
|
||||
# 复用共享 autouse 网络守卫;同一实现亦供各插件仓 conftest import 复用,避免逐仓维护
|
||||
from app.testing.network_guard import block_real_network # noqa: E402,F401
|
||||
|
||||
@@ -453,7 +453,8 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
) as prepare_files, patch(
|
||||
"app.chain.message.agent_manager.process_message", new_callable=AsyncMock
|
||||
) as process_message, patch(
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe"
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=lambda coro, _loop: coro.close(),
|
||||
) as run_coroutine_threadsafe:
|
||||
chain._handle_ai_message(
|
||||
text="/ai 帮我看看这张图",
|
||||
@@ -486,7 +487,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
"app.chain.message.agent_manager.process_message", new_callable=AsyncMock
|
||||
) as process_message, patch(
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=lambda coro, _loop: (coro.close(), Mock())[1],
|
||||
side_effect=lambda coro, _loop: coro.close(),
|
||||
):
|
||||
chain._handle_ai_message(
|
||||
text="帮我推荐一部电影",
|
||||
|
||||
@@ -2,14 +2,10 @@ import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agent import MoviePilotAgent
|
||||
from app.agent.memory import memory_manager
|
||||
# agenttokens 为动态安装插件(app/plugins/** 被 gitignore,CI / 全新环境无此插件),
|
||||
# 缺失时跳过本模块,避免 collection 阶段 ImportError。
|
||||
AgentTokens = pytest.importorskip("app.plugins.agenttokens").AgentTokens
|
||||
from app.schemas.types import ChainEventType, EventType
|
||||
|
||||
|
||||
@@ -44,31 +40,6 @@ class _FakeFailingAgent(_FakeAgent):
|
||||
|
||||
|
||||
class AgentTokensEventsTest(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_plugin_sidebar_nav_respects_config(self):
|
||||
"""插件侧边栏入口应受 show_sidebar_nav 配置控制。"""
|
||||
plugin = AgentTokens()
|
||||
|
||||
with patch.object(plugin, "update_config"):
|
||||
plugin.init_plugin(
|
||||
{
|
||||
"enabled": True,
|
||||
"show_sidebar_nav": False,
|
||||
"providers": [],
|
||||
}
|
||||
)
|
||||
self.assertEqual([], plugin.get_sidebar_nav())
|
||||
|
||||
plugin.init_plugin(
|
||||
{
|
||||
"enabled": True,
|
||||
"show_sidebar_nav": True,
|
||||
"providers": [],
|
||||
}
|
||||
)
|
||||
nav = plugin.get_sidebar_nav()
|
||||
|
||||
self.assertEqual("Agent Tokens 管理", nav[0]["title"])
|
||||
|
||||
async def test_initialize_llm_uses_chain_event_selection(self):
|
||||
"""Agent 初始化 LLM 时应优先使用链式事件返回的供应商配置。"""
|
||||
agent = MoviePilotAgent(session_id="agent-tokens-test", user_id="user-1")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import hashlib
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.modules.filemanager.storages import alist as alist_module
|
||||
@@ -137,3 +139,47 @@ class AlistStorageTest(unittest.TestCase):
|
||||
self.assertEqual("alist", target.storage)
|
||||
self.assertEqual("file", target.type)
|
||||
self.assertEqual(1024, target.size)
|
||||
|
||||
def test_upload_sends_hash_headers_for_rapid_upload(self):
|
||||
"""
|
||||
OpenList 上传应附带文件哈希头,供服务端尝试秒传。
|
||||
"""
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
local_path = Path(temp_dir) / "rapid.bin"
|
||||
content = b"moviepilot-openlist-rapid-upload"
|
||||
local_path.write_bytes(content)
|
||||
upload_dir = FileItem(
|
||||
storage="alist",
|
||||
type="dir",
|
||||
path="/library/",
|
||||
name="library",
|
||||
basename="library",
|
||||
)
|
||||
uploaded_item = FileItem(
|
||||
storage="alist",
|
||||
type="file",
|
||||
path="/library/rapid.bin",
|
||||
name="rapid.bin",
|
||||
basename="rapid",
|
||||
extension="bin",
|
||||
size=len(content),
|
||||
)
|
||||
request_utils = MagicMock()
|
||||
request_utils.put_res.return_value = _FakeResponse(
|
||||
{"code": 200, "message": "success", "data": None}
|
||||
)
|
||||
|
||||
with patch.object(Alist, "get_conf", return_value={"url": "http://openlist.test", "token": "token"}):
|
||||
with patch.object(alist_module, "RequestUtils", return_value=request_utils) as request_utils_factory:
|
||||
with patch.object(self.storage, "_delay_get_item", return_value=uploaded_item):
|
||||
result = self.storage.upload(upload_dir, local_path)
|
||||
|
||||
self.assertEqual(uploaded_item, result)
|
||||
request_utils.put_res.assert_called_once()
|
||||
headers = request_utils_factory.call_args.kwargs["headers"]
|
||||
self.assertEqual(hashlib.md5(content).hexdigest(), headers["X-File-Md5"])
|
||||
self.assertEqual(hashlib.sha1(content).hexdigest(), headers["X-File-Sha1"])
|
||||
self.assertEqual(hashlib.sha256(content).hexdigest(), headers["X-File-Sha256"])
|
||||
self.assertEqual(str(len(content)), headers["Content-Length"])
|
||||
self.assertEqual("application/octet-stream", headers["Content-Type"])
|
||||
self.assertEqual("/library/rapid.bin", headers["File-Path"])
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
from app.testing.bootstrap import ensure_optional_stub
|
||||
|
||||
sys.modules.setdefault("psutil", ModuleType("psutil"))
|
||||
sys.modules.setdefault("dateparser", ModuleType("dateparser"))
|
||||
|
||||
if "Pinyin2Hanzi" not in sys.modules:
|
||||
pinyin_module = ModuleType("Pinyin2Hanzi")
|
||||
setattr(pinyin_module, "is_pinyin", lambda value: False)
|
||||
sys.modules["Pinyin2Hanzi"] = pinyin_module
|
||||
# 可选三方依赖在 CI / 全新环境可能未安装,补占位避免 app.modules.feishu 导入失败
|
||||
ensure_optional_stub("psutil")
|
||||
ensure_optional_stub("dateparser")
|
||||
ensure_optional_stub("Pinyin2Hanzi", is_pinyin=lambda value: False)
|
||||
|
||||
from app.modules.feishu import FeishuModule
|
||||
from app.modules.feishu.feishu import Feishu
|
||||
|
||||
@@ -1,19 +1,6 @@
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
for _module_name in (
|
||||
"app.chain.mediaserver",
|
||||
"app.db.models",
|
||||
"app.db.user_oper",
|
||||
"app.helper.message",
|
||||
"app.utils.crypto",
|
||||
):
|
||||
if _module_name in sys.modules and not hasattr(
|
||||
sys.modules[_module_name], "__file__"
|
||||
):
|
||||
del sys.modules[_module_name]
|
||||
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.schemas import MediaServerLibrary, MediaServerPlayItem
|
||||
from app.utils.security import SecurityUtils
|
||||
|
||||
@@ -131,14 +131,15 @@ class PluginHelperTest(TestCase):
|
||||
self.skipTest(f"missing dependency: {exc}")
|
||||
|
||||
module_names = ["app.plugins.dynamicwechat.helper", "Crypto.Cipher._mode_cbc"]
|
||||
previous_modules = {name: sys.modules.get(name) for name in module_names}
|
||||
|
||||
def fake_execute(_cmd):
|
||||
for module_name in module_names:
|
||||
sys.modules[module_name] = ModuleType(module_name)
|
||||
return True, "ok"
|
||||
|
||||
try:
|
||||
# patch.dict 进入时快照 sys.modules、退出时整体还原,替代手写逐项 save/restore;
|
||||
# 保证 fake_execute 在安装窗口注入的运行态模块在用例结束后被清理、不污染其他用例
|
||||
with patch.dict(sys.modules):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
requirements_file = Path(temp_dir) / "requirements.txt"
|
||||
requirements_file.write_text("demo-package\n", encoding="utf-8")
|
||||
@@ -149,12 +150,6 @@ class PluginHelperTest(TestCase):
|
||||
self.assertEqual("ok", message)
|
||||
for module_name in module_names:
|
||||
self.assertIn(module_name, sys.modules)
|
||||
finally:
|
||||
for module_name, previous_module in previous_modules.items():
|
||||
if previous_module is None:
|
||||
sys.modules.pop(module_name, None)
|
||||
else:
|
||||
sys.modules[module_name] = previous_module
|
||||
|
||||
def test_pip_install_serializes_concurrent_calls(self):
|
||||
"""
|
||||
|
||||
100
tests/test_request_utils.py
Normal file
100
tests/test_request_utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import requests
|
||||
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""
|
||||
测试用 requests.Session 替身,记录请求次数与连接池关闭行为。
|
||||
"""
|
||||
|
||||
def __init__(self, side_effects):
|
||||
"""
|
||||
初始化请求结果序列。
|
||||
|
||||
:param side_effects: 每次 request 调用要返回或抛出的对象
|
||||
"""
|
||||
self.side_effects = list(side_effects)
|
||||
self.calls = []
|
||||
self.close_count = 0
|
||||
|
||||
def request(self, method, url, **kwargs):
|
||||
"""
|
||||
模拟 requests.Session.request。
|
||||
"""
|
||||
self.calls.append((method, url, kwargs))
|
||||
effect = self.side_effects.pop(0)
|
||||
if isinstance(effect, Exception):
|
||||
raise effect
|
||||
return effect
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
模拟清空 session 连接池。
|
||||
"""
|
||||
self.close_count += 1
|
||||
|
||||
|
||||
def _make_response(status_code: int = 200) -> requests.Response:
|
||||
response = requests.Response()
|
||||
response.status_code = status_code
|
||||
return response
|
||||
|
||||
|
||||
def test_request_utils_retries_idempotent_session_connection_error():
|
||||
"""
|
||||
同步幂等请求遇到失效 session 连接时应清理连接池并重试一次。
|
||||
"""
|
||||
response = _make_response()
|
||||
session = _FakeSession(
|
||||
[
|
||||
requests.exceptions.ConnectionError("stale keep-alive"),
|
||||
response,
|
||||
]
|
||||
)
|
||||
request_utils = RequestUtils(session=session)
|
||||
|
||||
result = request_utils.get_res("https://example.com/data")
|
||||
|
||||
assert result is response
|
||||
assert len(session.calls) == 2
|
||||
assert session.close_count == 1
|
||||
|
||||
|
||||
def test_request_utils_does_not_retry_non_idempotent_connection_error():
|
||||
"""
|
||||
非幂等请求连接异常时不应自动重试,避免重复提交副作用。
|
||||
"""
|
||||
session = _FakeSession(
|
||||
[
|
||||
requests.exceptions.ConnectionError("connection failed"),
|
||||
_make_response(),
|
||||
]
|
||||
)
|
||||
request_utils = RequestUtils(session=session)
|
||||
|
||||
result = request_utils.post_res("https://example.com/data", data={"name": "demo"})
|
||||
|
||||
assert result is None
|
||||
assert len(session.calls) == 1
|
||||
assert session.close_count == 0
|
||||
|
||||
|
||||
def test_request_utils_raises_retry_error_when_retry_still_fails():
|
||||
"""
|
||||
开启 raise_exception 后,重试仍失败时应抛出重试阶段的异常。
|
||||
"""
|
||||
first_error = requests.exceptions.ConnectionError("stale keep-alive")
|
||||
retry_error = requests.exceptions.ConnectionError("proxy still unavailable")
|
||||
session = _FakeSession([first_error, retry_error])
|
||||
request_utils = RequestUtils(session=session)
|
||||
|
||||
try:
|
||||
request_utils.get_res("https://example.com/data", raise_exception=True)
|
||||
except requests.exceptions.ConnectionError as err:
|
||||
assert err is retry_error
|
||||
else:
|
||||
raise AssertionError("请求重试失败时应抛出异常")
|
||||
|
||||
assert len(session.calls) == 2
|
||||
assert session.close_count == 1
|
||||
@@ -1,28 +1,15 @@
|
||||
import asyncio
|
||||
import importlib.machinery
|
||||
import sys
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.testing.bootstrap import ensure_optional_stub
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
_stub_module("qbittorrentapi", TorrentFilesList=list)
|
||||
_stub_module("transmission_rpc", File=object)
|
||||
_stub_module(
|
||||
"psutil",
|
||||
__spec__=importlib.machinery.ModuleSpec("psutil", loader=None),
|
||||
)
|
||||
# 可选三方依赖在 CI / 全新环境可能未安装,补占位(带用例所需属性)避免导入失败
|
||||
ensure_optional_stub("qbittorrentapi", TorrentFilesList=list)
|
||||
ensure_optional_stub("transmission_rpc", File=object)
|
||||
ensure_optional_stub("psutil", __spec__=importlib.machinery.ModuleSpec("psutil", loader=None))
|
||||
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.agent import ReplyMode
|
||||
|
||||
@@ -7,6 +7,7 @@ from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.schemas.types import MediaType
|
||||
from app.testing import stub_modules
|
||||
|
||||
|
||||
def _load_subscribe_chain_class():
|
||||
@@ -16,13 +17,11 @@ def _load_subscribe_chain_class():
|
||||
module = sys.modules[module_name]
|
||||
return module, module.SubscribeChain
|
||||
|
||||
original_modules = {}
|
||||
stub_deps = {}
|
||||
|
||||
def ensure_module(name: str, module: types.ModuleType):
|
||||
"""临时替换模块依赖,并记录原模块以便加载完成后恢复。"""
|
||||
if name not in original_modules:
|
||||
original_modules[name] = sys.modules.get(name)
|
||||
sys.modules[name] = module
|
||||
"""登记一个加载期临时替换模块;实际替换与精确还原由 stub_modules 在加载时统一处理。"""
|
||||
stub_deps[name] = module
|
||||
return module
|
||||
|
||||
chain_module = ensure_module("app.chain", types.ModuleType("app.chain"))
|
||||
@@ -298,18 +297,12 @@ def _load_subscribe_chain_class():
|
||||
subscribe_path = Path(__file__).resolve().parents[1] / "app" / "chain" / "subscribe.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, subscribe_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(module)
|
||||
module._injected_modules = {
|
||||
name: sys.modules.get(name)
|
||||
for name in original_modules
|
||||
}
|
||||
for injected_name, original_module in original_modules.items():
|
||||
if original_module is None:
|
||||
sys.modules.pop(injected_name, None)
|
||||
else:
|
||||
sys.modules[injected_name] = original_module
|
||||
# 加载期用 stub_modules 精确替换依赖、退出时统一还原;module_name 非桩,缓存入 sys.modules 供复用
|
||||
with stub_modules(stub_deps):
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
module._injected_modules = {name: sys.modules.get(name) for name in stub_deps}
|
||||
return module, module.SubscribeChain
|
||||
|
||||
|
||||
@@ -463,6 +456,69 @@ class SubscribeChainTest(TestCase):
|
||||
self.assertEqual(SubscribeChain.get_best_version_current_priority(subscribe), 100)
|
||||
self.assertTrue(SubscribeChain.is_best_version_complete(subscribe))
|
||||
|
||||
def test_get_subscribe_no_exists_expands_whole_missing_when_custom_start_skips_existing_range(self):
|
||||
"""自定义开始集跳过季初集数时,缺失整季需要转成显式目标集。"""
|
||||
no_exists = {
|
||||
"media-key": {
|
||||
1: SimpleNamespace(season=1, episodes=[], total_episode=48, start_episode=1)
|
||||
}
|
||||
}
|
||||
|
||||
exist_flag, result = SubscribeChain._SubscribeChain__get_subscribe_no_exits(
|
||||
subscribe_name="主角 S01",
|
||||
no_exists=no_exists,
|
||||
mediakey="media-key",
|
||||
begin_season=1,
|
||||
total_episode=48,
|
||||
start_episode=44,
|
||||
)
|
||||
|
||||
self.assertFalse(exist_flag)
|
||||
self.assertEqual(result["media-key"][1].episodes, [44, 45, 46, 47, 48])
|
||||
self.assertEqual(result["media-key"][1].start_episode, 44)
|
||||
self.assertEqual(result["media-key"][1].total_episode, 48)
|
||||
|
||||
def test_get_subscribe_no_exists_keeps_whole_missing_when_custom_start_matches_original_start(self):
|
||||
"""自定义开始集没有缩小范围时,仍保留空集列表表示整季缺失。"""
|
||||
no_exists = {
|
||||
"media-key": {
|
||||
1: SimpleNamespace(season=1, episodes=[], total_episode=48, start_episode=1)
|
||||
}
|
||||
}
|
||||
|
||||
exist_flag, result = SubscribeChain._SubscribeChain__get_subscribe_no_exits(
|
||||
subscribe_name="主角 S01",
|
||||
no_exists=no_exists,
|
||||
mediakey="media-key",
|
||||
begin_season=1,
|
||||
total_episode=48,
|
||||
start_episode=1,
|
||||
)
|
||||
|
||||
self.assertFalse(exist_flag)
|
||||
self.assertEqual(result["media-key"][1].episodes, [])
|
||||
self.assertEqual(result["media-key"][1].start_episode, 1)
|
||||
self.assertEqual(result["media-key"][1].total_episode, 48)
|
||||
|
||||
def test_best_version_full_pack_first_keeps_whole_missing_for_custom_start_episode(self):
|
||||
"""分集洗版优先全集时,空集列表仍表示下载链按整季资源处理。"""
|
||||
subscribe = self._build_subscribe(
|
||||
best_version=1,
|
||||
best_version_full=0,
|
||||
start_episode=44,
|
||||
total_episode=48,
|
||||
episode_priority={str(episode): 80 for episode in range(44, 49)},
|
||||
)
|
||||
|
||||
result = SubscribeChain._SubscribeChain__build_full_pack_first_no_exists(
|
||||
subscribe=subscribe,
|
||||
mediakey="media-key",
|
||||
)
|
||||
|
||||
self.assertEqual(result["media-key"][1].episodes, [])
|
||||
self.assertEqual(result["media-key"][1].start_episode, 44)
|
||||
self.assertEqual(result["media-key"][1].total_episode, 48)
|
||||
|
||||
def test_is_episode_range_covered_matches_pending_episodes(self):
|
||||
subscribe = self._build_subscribe(
|
||||
total_episode=12,
|
||||
|
||||
@@ -1,24 +1,17 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
_ORIGINAL_STUBBED_MODULES = {}
|
||||
from app.testing import stub_modules
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
"""
|
||||
安装临时 stub 模块,并记录原模块用于导入后恢复。
|
||||
"""
|
||||
if name not in _ORIGINAL_STUBBED_MODULES:
|
||||
_ORIGINAL_STUBBED_MODULES[name] = sys.modules.get(name)
|
||||
def _stub(name: str, **attrs) -> tuple:
|
||||
"""构造带指定属性的占位模块,返回 ``(模块名, 模块)`` 供 :func:`stub_modules` 使用。"""
|
||||
module = ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
sys.modules[name] = module
|
||||
return module
|
||||
return name, module
|
||||
|
||||
|
||||
class _Dummy:
|
||||
@@ -35,69 +28,44 @@ class _DummyError(Exception):
|
||||
self.duration_ms = duration_ms
|
||||
|
||||
|
||||
for _module_name in ("pillow_avif", "aiofiles", "psutil"):
|
||||
_stub_module(_module_name)
|
||||
# 在 import 期用占位模块替换重依赖/外部模块,import 完由 stub_modules 精确还原,避免污染其它用例
|
||||
_STUB_MODULES = dict([
|
||||
_stub("pillow_avif"),
|
||||
_stub("aiofiles"),
|
||||
_stub("psutil"),
|
||||
_stub("app.helper.sites", SitesHelper=_Dummy),
|
||||
_stub("app.chain.mediaserver", MediaServerChain=_Dummy),
|
||||
_stub("app.chain.search", SearchChain=_Dummy),
|
||||
_stub("app.chain.system", SystemChain=_Dummy),
|
||||
_stub("app.agent.llm", LLMHelper=_Dummy, LLMProviderManager=_Dummy,
|
||||
LLMTestError=_DummyError, LLMTestTimeout=_DummyError,
|
||||
render_auth_result_html=lambda success, message: message),
|
||||
_stub("app.core.event", eventmanager=_Dummy(), Event=_Dummy, EventManager=_Dummy),
|
||||
_stub("app.core.metainfo", MetaInfo=_Dummy),
|
||||
_stub("app.core.module", ModuleManager=_Dummy),
|
||||
_stub("app.core.security", verify_apitoken=_Dummy, verify_resource_token=_Dummy, verify_token=_Dummy),
|
||||
_stub("app.db.models", User=_Dummy),
|
||||
_stub("app.db.systemconfig_oper", SystemConfigOper=_Dummy),
|
||||
_stub("app.db.user_oper", get_current_active_superuser=_Dummy,
|
||||
get_current_active_superuser_async=_Dummy, get_current_active_user_async=_Dummy),
|
||||
_stub("app.helper.llm", LLMHelper=_Dummy, LLMTestError=_DummyError, LLMTestTimeout=_DummyError),
|
||||
_stub("app.helper.mediaserver", MediaServerHelper=_Dummy),
|
||||
_stub("app.helper.message", MessageHelper=_Dummy),
|
||||
_stub("app.helper.progress", ProgressHelper=_Dummy),
|
||||
_stub("app.helper.rule", RuleHelper=_Dummy),
|
||||
_stub("app.helper.server", MoviePilotServerHelper=_Dummy),
|
||||
_stub("app.helper.system", SystemHelper=_Dummy),
|
||||
_stub("app.helper.image", ImageHelper=_Dummy),
|
||||
_stub("app.scheduler", Scheduler=_Dummy),
|
||||
_stub("app.log", logger=_Dummy(), log_settings=_Dummy(),
|
||||
LogConfigModel=type("LogConfigModel", (), {})),
|
||||
_stub("app.utils.crypto", HashUtils=_Dummy),
|
||||
_stub("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy),
|
||||
_stub("version", APP_VERSION="test"),
|
||||
])
|
||||
|
||||
_stub_module("app.helper.sites", SitesHelper=_Dummy)
|
||||
_stub_module("app.chain.mediaserver", MediaServerChain=_Dummy)
|
||||
_stub_module("app.chain.search", SearchChain=_Dummy)
|
||||
_stub_module("app.chain.system", SystemChain=_Dummy)
|
||||
_stub_module(
|
||||
"app.agent.llm",
|
||||
LLMHelper=_Dummy,
|
||||
LLMProviderManager=_Dummy,
|
||||
LLMTestError=_DummyError,
|
||||
LLMTestTimeout=_DummyError,
|
||||
render_auth_result_html=lambda success, message: message,
|
||||
)
|
||||
_stub_module("app.core.event", eventmanager=_Dummy(), Event=_Dummy, EventManager=_Dummy)
|
||||
_stub_module("app.core.metainfo", MetaInfo=_Dummy)
|
||||
_stub_module("app.core.module", ModuleManager=_Dummy)
|
||||
_stub_module(
|
||||
"app.core.security",
|
||||
verify_apitoken=_Dummy,
|
||||
verify_resource_token=_Dummy,
|
||||
verify_token=_Dummy,
|
||||
)
|
||||
_stub_module("app.db.models", User=_Dummy)
|
||||
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_Dummy)
|
||||
_stub_module(
|
||||
"app.db.user_oper",
|
||||
get_current_active_superuser=_Dummy,
|
||||
get_current_active_superuser_async=_Dummy,
|
||||
get_current_active_user_async=_Dummy,
|
||||
)
|
||||
_stub_module(
|
||||
"app.helper.llm",
|
||||
LLMHelper=_Dummy,
|
||||
LLMTestError=_DummyError,
|
||||
LLMTestTimeout=_DummyError,
|
||||
)
|
||||
_stub_module("app.helper.mediaserver", MediaServerHelper=_Dummy)
|
||||
_stub_module("app.helper.message", MessageHelper=_Dummy)
|
||||
_stub_module("app.helper.progress", ProgressHelper=_Dummy)
|
||||
_stub_module("app.helper.rule", RuleHelper=_Dummy)
|
||||
_stub_module("app.helper.server", MoviePilotServerHelper=_Dummy)
|
||||
_stub_module("app.helper.system", SystemHelper=_Dummy)
|
||||
_stub_module("app.helper.image", ImageHelper=_Dummy)
|
||||
_stub_module("app.scheduler", Scheduler=_Dummy)
|
||||
_stub_module(
|
||||
"app.log",
|
||||
logger=_Dummy(),
|
||||
log_settings=_Dummy(),
|
||||
LogConfigModel=type("LogConfigModel", (), {}),
|
||||
)
|
||||
_stub_module("app.utils.crypto", HashUtils=_Dummy)
|
||||
_stub_module("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy)
|
||||
_stub_module("version", APP_VERSION="test")
|
||||
|
||||
from app.api.endpoints import llm as system_endpoint
|
||||
|
||||
for _module_name, _module in _ORIGINAL_STUBBED_MODULES.items():
|
||||
if _module is None:
|
||||
sys.modules.pop(_module_name, None)
|
||||
else:
|
||||
sys.modules[_module_name] = _module
|
||||
with stub_modules(_STUB_MODULES):
|
||||
from app.api.endpoints import llm as system_endpoint
|
||||
|
||||
|
||||
class LlmTestEndpointTest(unittest.TestCase):
|
||||
|
||||
@@ -1,25 +1,18 @@
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
|
||||
_ORIGINAL_STUBBED_MODULES = {}
|
||||
from app.testing import stub_modules
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
"""
|
||||
安装临时 stub 模块,并记录原模块用于导入后恢复。
|
||||
"""
|
||||
if name not in _ORIGINAL_STUBBED_MODULES:
|
||||
_ORIGINAL_STUBBED_MODULES[name] = sys.modules.get(name)
|
||||
def _stub(name: str, **attrs) -> tuple:
|
||||
"""构造带指定属性的占位模块,返回 ``(模块名, 模块)`` 供 :func:`stub_modules` 使用。"""
|
||||
module = ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
sys.modules[name] = module
|
||||
return module
|
||||
return name, module
|
||||
|
||||
|
||||
class _Dummy:
|
||||
@@ -36,67 +29,42 @@ class _DummyError(Exception):
|
||||
self.duration_ms = duration_ms
|
||||
|
||||
|
||||
for _module_name in ("pillow_avif", "aiofiles", "psutil"):
|
||||
_stub_module(_module_name)
|
||||
# 在 import 期用占位模块替换重依赖/外部模块,import 完由 stub_modules 精确还原,避免污染其它用例
|
||||
_STUB_MODULES = dict([
|
||||
_stub("pillow_avif"),
|
||||
_stub("aiofiles"),
|
||||
_stub("psutil"),
|
||||
_stub("app.helper.sites", SitesHelper=_Dummy),
|
||||
_stub("app.chain.media", MediaChain=_Dummy),
|
||||
_stub("app.chain.mediaserver", MediaServerChain=_Dummy),
|
||||
_stub("app.chain.search", SearchChain=_Dummy),
|
||||
_stub("app.chain.system", SystemChain=_Dummy),
|
||||
_stub("app.core.event", eventmanager=_Dummy(), Event=_Dummy, EventManager=_Dummy),
|
||||
_stub("app.core.metainfo", MetaInfo=_Dummy),
|
||||
_stub("app.core.module", ModuleManager=_Dummy),
|
||||
_stub("app.core.security", verify_apitoken=_Dummy, verify_resource_token=_Dummy, verify_token=_Dummy),
|
||||
_stub("app.db.models", User=_Dummy),
|
||||
_stub("app.db.systemconfig_oper", SystemConfigOper=_Dummy),
|
||||
_stub("app.db.user_oper", get_current_active_superuser=_Dummy,
|
||||
get_current_active_superuser_async=_Dummy, get_current_active_user_async=_Dummy),
|
||||
_stub("app.helper.llm", LLMHelper=_Dummy, LLMTestError=_DummyError, LLMTestTimeout=_DummyError),
|
||||
_stub("app.helper.mediaserver", MediaServerHelper=_Dummy),
|
||||
_stub("app.helper.message", MessageHelper=_Dummy),
|
||||
_stub("app.helper.progress", ProgressHelper=_Dummy),
|
||||
_stub("app.helper.rule", RuleHelper=_Dummy),
|
||||
_stub("app.helper.server", MoviePilotServerHelper=_Dummy),
|
||||
_stub("app.helper.system", SystemHelper=_Dummy),
|
||||
_stub("app.helper.image", ImageHelper=_Dummy),
|
||||
_stub("app.scheduler", Scheduler=_Dummy),
|
||||
_stub("app.log", logger=_Dummy(), log_settings=_Dummy(),
|
||||
LogConfigModel=type("LogConfigModel", (), {})),
|
||||
_stub("app.utils.crypto", HashUtils=_Dummy),
|
||||
_stub("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy),
|
||||
_stub("version", APP_VERSION="test", FRONTEND_VERSION="frontend-test"),
|
||||
])
|
||||
|
||||
_stub_module("app.helper.sites", SitesHelper=_Dummy)
|
||||
_stub_module("app.chain.media", MediaChain=_Dummy)
|
||||
_stub_module("app.chain.mediaserver", MediaServerChain=_Dummy)
|
||||
_stub_module("app.chain.search", SearchChain=_Dummy)
|
||||
_stub_module("app.chain.system", SystemChain=_Dummy)
|
||||
_stub_module(
|
||||
"app.core.event",
|
||||
eventmanager=_Dummy(),
|
||||
Event=_Dummy,
|
||||
EventManager=_Dummy,
|
||||
)
|
||||
_stub_module("app.core.metainfo", MetaInfo=_Dummy)
|
||||
_stub_module("app.core.module", ModuleManager=_Dummy)
|
||||
_stub_module(
|
||||
"app.core.security",
|
||||
verify_apitoken=_Dummy,
|
||||
verify_resource_token=_Dummy,
|
||||
verify_token=_Dummy,
|
||||
)
|
||||
_stub_module("app.db.models", User=_Dummy)
|
||||
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_Dummy)
|
||||
_stub_module(
|
||||
"app.db.user_oper",
|
||||
get_current_active_superuser=_Dummy,
|
||||
get_current_active_superuser_async=_Dummy,
|
||||
get_current_active_user_async=_Dummy,
|
||||
)
|
||||
_stub_module(
|
||||
"app.helper.llm",
|
||||
LLMHelper=_Dummy,
|
||||
LLMTestError=_DummyError,
|
||||
LLMTestTimeout=_DummyError,
|
||||
)
|
||||
_stub_module("app.helper.mediaserver", MediaServerHelper=_Dummy)
|
||||
_stub_module("app.helper.message", MessageHelper=_Dummy)
|
||||
_stub_module("app.helper.progress", ProgressHelper=_Dummy)
|
||||
_stub_module("app.helper.rule", RuleHelper=_Dummy)
|
||||
_stub_module("app.helper.server", MoviePilotServerHelper=_Dummy)
|
||||
_stub_module("app.helper.system", SystemHelper=_Dummy)
|
||||
_stub_module("app.helper.image", ImageHelper=_Dummy)
|
||||
_stub_module("app.scheduler", Scheduler=_Dummy)
|
||||
_stub_module(
|
||||
"app.log",
|
||||
logger=_Dummy(),
|
||||
log_settings=_Dummy(),
|
||||
LogConfigModel=type("LogConfigModel", (), {}),
|
||||
)
|
||||
_stub_module("app.utils.crypto", HashUtils=_Dummy)
|
||||
_stub_module("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy)
|
||||
_stub_module("version", APP_VERSION="test", FRONTEND_VERSION="frontend-test")
|
||||
|
||||
from app.api.endpoints import system as system_endpoint
|
||||
|
||||
for _module_name, _module in _ORIGINAL_STUBBED_MODULES.items():
|
||||
if _module is None:
|
||||
sys.modules.pop(_module_name, None)
|
||||
else:
|
||||
sys.modules[_module_name] = _module
|
||||
with stub_modules(_STUB_MODULES):
|
||||
from app.api.endpoints import system as system_endpoint
|
||||
|
||||
|
||||
class NettestSecurityTest(unittest.TestCase):
|
||||
|
||||
62
tests/test_testing_bootstrap.py
Normal file
62
tests/test_testing_bootstrap.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""共享测试引导工具的回归用例。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
from app.testing import bootstrap
|
||||
|
||||
|
||||
def test_isolate_config_cleanup_uses_loaded_db_module_without_late_import(monkeypatch):
|
||||
"""清理回调只读取已加载模块,避免解释器关停期触发二次导入。"""
|
||||
captured = {}
|
||||
import_calls = []
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
"""记录清理回调是否试图重新导入数据库模块。"""
|
||||
if name == "app.db":
|
||||
import_calls.append(name)
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
def fake_register(func):
|
||||
"""截获 atexit 回调,便于直接验证清理行为。"""
|
||||
captured["cleanup"] = func
|
||||
|
||||
monkeypatch.setattr(bootstrap, "_isolated_config_dir", None)
|
||||
monkeypatch.delenv("CONFIG_DIR", raising=False)
|
||||
monkeypatch.setattr(bootstrap.tempfile, "mkdtemp", lambda prefix: "/tmp/mp-test-config-demo")
|
||||
monkeypatch.setattr(bootstrap.shutil, "rmtree", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(bootstrap.atexit, "register", fake_register)
|
||||
|
||||
original_import = builtins.__import__
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
|
||||
bootstrap.isolate_config_dir()
|
||||
captured["cleanup"]()
|
||||
|
||||
assert import_calls == []
|
||||
|
||||
|
||||
def test_mark_plugin_generation_prefers_pathlib_item_path():
|
||||
"""pytest 新版 item.path 可独立驱动 v1/v2 marker 标记。"""
|
||||
|
||||
class FakeItem:
|
||||
"""只暴露 pytest 7+ 的 path 属性,模拟新版收集对象。"""
|
||||
|
||||
def __init__(self, value: str):
|
||||
self.path = Path(value)
|
||||
self.markers = []
|
||||
|
||||
def add_marker(self, marker):
|
||||
"""记录被添加的 marker。"""
|
||||
self.markers.append(marker)
|
||||
|
||||
pytest_module = types.SimpleNamespace(mark=types.SimpleNamespace(v1="v1", v2="v2"))
|
||||
v2_item = FakeItem("/repo/tests/v2/test_demo.py")
|
||||
v1_item = FakeItem("/repo/tests/v1/test_demo.py")
|
||||
|
||||
bootstrap.mark_plugin_generation([v2_item, v1_item], pytest_module)
|
||||
|
||||
assert v2_item.markers == ["v2"]
|
||||
assert v1_item.markers == ["v1"]
|
||||
@@ -95,11 +95,16 @@ class TestTransferFailedRetryButtons(unittest.TestCase):
|
||||
errmsg="未识别到媒体信息",
|
||||
)
|
||||
|
||||
def _close_pending_coro(coro, *args, **kwargs):
|
||||
"""关闭被调度的协程:测试中事件循环未运行,不关闭会残留 never-awaited 警告。"""
|
||||
coro.close()
|
||||
|
||||
with patch.object(settings, "AI_AGENT_ENABLE", True):
|
||||
with patch(
|
||||
"app.chain.message.TransferHistoryOper"
|
||||
) as history_oper_cls, patch(
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe"
|
||||
"app.chain.message.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=_close_pending_coro,
|
||||
) as run_task:
|
||||
history_oper_cls.return_value.get.return_value = history
|
||||
with patch.object(chain, "post_message") as post_message:
|
||||
|
||||
161
tests/test_workflow_actions.py
Normal file
161
tests/test_workflow_actions.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.schemas import ActionContext, DownloadTask, FileItem
|
||||
from app.schemas.workflow import ActionResult
|
||||
from app.workflow.actions import BaseAction
|
||||
from app.workflow.actions import fetch_downloads as fetch_downloads_module
|
||||
from app.workflow.actions import scrape_file as scrape_file_module
|
||||
from app.workflow.actions.fetch_downloads import FetchDownloadsAction
|
||||
from app.workflow.actions.scrape_file import ScrapeFileAction
|
||||
from app.workflow.actions.fetch_rss import FetchRssAction
|
||||
from app.workflow import WorkFlowManager
|
||||
|
||||
|
||||
def test_fetch_downloads_updates_context_downloads(monkeypatch):
|
||||
"""获取下载任务动作应更新上游上下文中的下载任务。"""
|
||||
calls = []
|
||||
|
||||
class FakeActionChain:
|
||||
"""模拟下载器查询链。"""
|
||||
|
||||
def list_torrents(self, hashs=None, downloader=None, **kwargs):
|
||||
calls.append((hashs, downloader))
|
||||
return [SimpleNamespace(path="/downloads/movie.mkv", progress=100)]
|
||||
|
||||
monkeypatch.setattr(fetch_downloads_module, "ActionChain", FakeActionChain)
|
||||
monkeypatch.setattr(fetch_downloads_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
context = ActionContext(
|
||||
downloads=[
|
||||
DownloadTask(download_id="hash-1", downloader="qbittorrent"),
|
||||
]
|
||||
)
|
||||
|
||||
result = FetchDownloadsAction("fetch-downloads").execute(
|
||||
workflow_id=1,
|
||||
params={},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert calls == [(["hash-1"], "qbittorrent")]
|
||||
assert result.downloads[0].completed is True
|
||||
assert result.downloads[0].path == "/downloads/movie.mkv"
|
||||
|
||||
|
||||
def test_scrape_file_keeps_workflow_action_context(monkeypatch):
|
||||
"""刮削文件动作不应将工作流上下文替换为媒体识别上下文。"""
|
||||
scraped = []
|
||||
|
||||
class FakeStorageChain:
|
||||
"""模拟存储链。"""
|
||||
|
||||
def exists(self, fileitem):
|
||||
return True
|
||||
|
||||
class FakeMediaChain:
|
||||
"""模拟媒体识别和刮削链。"""
|
||||
|
||||
def recognize_by_path(self, path, obtain_images=False):
|
||||
return SimpleNamespace(meta_info="meta", media_info="media")
|
||||
|
||||
def scrape_metadata(self, fileitem, meta=None, mediainfo=None):
|
||||
scraped.append((fileitem.path, meta, mediainfo))
|
||||
|
||||
monkeypatch.setattr(scrape_file_module, "StorageChain", FakeStorageChain)
|
||||
monkeypatch.setattr(scrape_file_module, "MediaChain", FakeMediaChain)
|
||||
monkeypatch.setattr(scrape_file_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
monkeypatch.setattr(ScrapeFileAction, "check_cache", lambda self, workflow_id, key: False)
|
||||
monkeypatch.setattr(ScrapeFileAction, "save_cache", lambda self, workflow_id, data: None)
|
||||
|
||||
context = ActionContext(
|
||||
fileitems=[
|
||||
FileItem(path="/library/movie.mkv", storage="local", type="file"),
|
||||
]
|
||||
)
|
||||
|
||||
result = ScrapeFileAction("scrape-file").execute(
|
||||
workflow_id=1,
|
||||
params={},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result is context
|
||||
assert result.fileitems[0].path == "/library/movie.mkv"
|
||||
assert scraped == [("/library/movie.mkv", "meta", "media")]
|
||||
|
||||
|
||||
def test_execute_with_inputs_maps_contract_inputs_outputs_and_runtime(monkeypatch):
|
||||
"""新版动作桥接方法应按契约映射输入、输出和运行期信息。"""
|
||||
|
||||
class ContractAction(BaseAction):
|
||||
"""测试动作契约桥接。"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def name(cls) -> str:
|
||||
return "契约动作"
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def description(cls) -> str:
|
||||
return "测试契约动作"
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def data(cls) -> dict:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True
|
||||
|
||||
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""执行测试动作。"""
|
||||
_ = workflow_id, params
|
||||
context.downloads = [
|
||||
DownloadTask(download_id=f"{item}-hash", downloader="qbittorrent")
|
||||
for item in context.torrents
|
||||
]
|
||||
self.job_done("完成")
|
||||
return context
|
||||
|
||||
result = ContractAction("contract").execute_with_inputs(
|
||||
workflow_id=1,
|
||||
params={},
|
||||
inputs={"torrents": ["movie"]},
|
||||
runtime={"attempt": 1, "max_attempts": 1, "cancel_token": object()},
|
||||
context=ActionContext(),
|
||||
)
|
||||
|
||||
assert isinstance(result, ActionResult)
|
||||
assert result.outputs["downloads"][0].download_id == "movie-hash"
|
||||
assert result.context.runtime_state["current_action_runtime"] == {
|
||||
"attempt": 1,
|
||||
"max_attempts": 1,
|
||||
}
|
||||
|
||||
path_result = ContractAction("contract").execute_with_inputs(
|
||||
workflow_id=1,
|
||||
params={},
|
||||
inputs={"outputs.FetchRssAction.torrents": ["legacy"]},
|
||||
runtime={},
|
||||
context=ActionContext(),
|
||||
)
|
||||
|
||||
assert path_result.outputs["downloads"][0].download_id == "legacy-hash"
|
||||
|
||||
|
||||
def test_workflow_manager_list_actions_exposes_contract():
|
||||
"""动作列表应返回固定输入输出契约。"""
|
||||
manager = object.__new__(WorkFlowManager)
|
||||
manager._actions = {"FetchRssAction": FetchRssAction}
|
||||
|
||||
actions = manager.list_actions()
|
||||
|
||||
assert actions[0]["contract"]["outputs"][0]["name"] == "torrents"
|
||||
assert actions[0]["contract"]["condition_fields"][0]["label"] == "资源"
|
||||
787
tests/test_workflow_execution.py
Normal file
787
tests/test_workflow_execution.py
Normal file
@@ -0,0 +1,787 @@
|
||||
import base64
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.chain import workflow as workflow_module
|
||||
from app.schemas import Action, ActionContext, ActionResult
|
||||
from app.schemas.types import EventType
|
||||
from app import workflow as workflow_package
|
||||
|
||||
|
||||
def _build_workflow(current_action=None, context=None, actions=None, flows=None,
|
||||
execution_config=None, execution_state=None):
|
||||
"""构造最小工作流对象。"""
|
||||
return SimpleNamespace(
|
||||
id=1,
|
||||
name="测试工作流",
|
||||
actions=actions if actions is not None else [
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
],
|
||||
flows=flows if flows is not None else [
|
||||
{"id": "flow-1", "source": "A", "target": "B", "animated": True},
|
||||
],
|
||||
current_action=current_action,
|
||||
context=context,
|
||||
execution_config=execution_config or {},
|
||||
execution_state=execution_state or {},
|
||||
)
|
||||
|
||||
|
||||
def _encoded_context(context: ActionContext) -> dict:
|
||||
"""编码工作流恢复上下文。"""
|
||||
return {
|
||||
"content": base64.b64encode(pickle.dumps(context)).decode("utf-8"),
|
||||
}
|
||||
|
||||
|
||||
class _FakeWorkflowManager:
|
||||
"""记录执行动作的工作流管理器。"""
|
||||
|
||||
def __init__(self, calls, results=None, contracts=None):
|
||||
self.calls = calls
|
||||
self.results = results or {}
|
||||
self.contracts = contracts or {}
|
||||
self.received_inputs = []
|
||||
|
||||
def execute(self, workflow_id, action, context=None, inputs=None, runtime=None, cancel_token=None):
|
||||
"""执行伪动作并记录新版输入。"""
|
||||
self.calls.append(action.id)
|
||||
self.received_inputs.append((action.id, inputs or {}, runtime or {}, cancel_token))
|
||||
result = self.results.get(action.id)
|
||||
if callable(result):
|
||||
return result(action, context or ActionContext())
|
||||
if result:
|
||||
return result
|
||||
return ActionResult(success=True, message=f"{action.name}完成", context=context or ActionContext())
|
||||
|
||||
def excute(self, workflow_id, action, context=None):
|
||||
"""兼容历史执行方法。"""
|
||||
result = self.execute(workflow_id, action, context)
|
||||
return result.success, result.message, result.context
|
||||
|
||||
def get_action_contract(self, action_type):
|
||||
"""获取伪动作契约。"""
|
||||
return self.contracts.get(action_type) or {}
|
||||
|
||||
|
||||
class _FakeWorkflowOper:
|
||||
"""记录工作流持久化调用。"""
|
||||
|
||||
def __init__(self, workflow):
|
||||
self.workflow = workflow
|
||||
self.steps = []
|
||||
self.started = False
|
||||
self.failed_result = None
|
||||
self.succeeded = False
|
||||
|
||||
def reset(self, wid):
|
||||
"""模拟重置工作流。"""
|
||||
_ = wid
|
||||
return True
|
||||
|
||||
def get(self, wid):
|
||||
"""返回预置工作流。"""
|
||||
_ = wid
|
||||
return self.workflow
|
||||
|
||||
def start(self, wid):
|
||||
"""记录启动调用。"""
|
||||
_ = wid
|
||||
self.started = True
|
||||
return True
|
||||
|
||||
def step(self, wid, action_id, context, execution_state=None):
|
||||
"""记录步骤持久化数据。"""
|
||||
self.steps.append(
|
||||
{
|
||||
"wid": wid,
|
||||
"action_id": action_id,
|
||||
"context": context,
|
||||
"execution_state": execution_state,
|
||||
}
|
||||
)
|
||||
return True
|
||||
|
||||
def fail(self, wid, result):
|
||||
"""记录失败结果。"""
|
||||
_ = wid
|
||||
self.failed_result = result
|
||||
return True
|
||||
|
||||
def success(self, wid, result=None):
|
||||
"""记录成功结果。"""
|
||||
_ = wid, result
|
||||
self.succeeded = True
|
||||
return True
|
||||
|
||||
|
||||
class _OpaqueValue:
|
||||
"""模拟无法直接 JSON 序列化的值。"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __str__(self):
|
||||
return "opaque-value"
|
||||
|
||||
|
||||
def test_workflow_executor_resumes_downstream_nodes(monkeypatch):
|
||||
"""恢复执行时应释放已完成节点的后继节点。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
workflow = _build_workflow(
|
||||
current_action="A",
|
||||
context=_encoded_context(ActionContext()),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["B"]
|
||||
assert executor.success is True
|
||||
assert executor.context.progress == 100
|
||||
|
||||
|
||||
def test_workflow_executor_restores_structured_context(monkeypatch):
|
||||
"""恢复执行时应兼容新版结构化上下文存储格式。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
workflow = _build_workflow(
|
||||
current_action="A",
|
||||
context={
|
||||
"workflow_context": {"trace_id": "wf-1"},
|
||||
"node_outputs": {"A": {"items": ["movie"]}},
|
||||
"progress": 50,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["B"]
|
||||
assert executor.context.workflow_context["trace_id"] == "wf-1"
|
||||
assert executor.context.node_outputs["A"]["items"] == ["movie"]
|
||||
|
||||
|
||||
def test_workflow_executor_reports_incremental_progress(monkeypatch):
|
||||
"""顺序工作流的中间进度应按已完成比例计算。"""
|
||||
calls = []
|
||||
progresses = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(
|
||||
_build_workflow(),
|
||||
step_callback=lambda action, context: progresses.append(context.progress),
|
||||
)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["A", "B"]
|
||||
assert progresses == [50, 100]
|
||||
|
||||
|
||||
def test_workflow_executor_skips_false_condition_branch(monkeypatch):
|
||||
"""条件边不满足时应跳过对应分支,并继续执行满足条件的分支。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"items": ["movie"]}
|
||||
)
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
|
||||
],
|
||||
flows=[
|
||||
{"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.items.count == 0"},
|
||||
{"id": "flow-ac", "source": "A", "target": "C", "data": {"condition": "outputs.A.items.count > 0"}},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["A", "C"]
|
||||
assert executor.success is True
|
||||
assert executor.context.progress == 100
|
||||
assert executor.context.node_outputs["A"]["items"] == ["movie"]
|
||||
|
||||
|
||||
def test_workflow_executor_all_success_join_waits_parallel_branches(monkeypatch):
|
||||
"""默认汇合策略应等待所有上游分支成功后再执行目标节点。"""
|
||||
calls = []
|
||||
joined_outputs = {}
|
||||
|
||||
def run_join(action, context):
|
||||
"""记录汇合节点读取到的上游输出。"""
|
||||
joined_outputs.update(context.node_outputs)
|
||||
return ActionResult(success=True, message=f"{action.name}完成", context=context)
|
||||
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"value": "A"}
|
||||
),
|
||||
"B": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"value": "B"}
|
||||
),
|
||||
"C": run_join,
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
|
||||
],
|
||||
flows=[
|
||||
{"id": "flow-ac", "source": "A", "target": "C"},
|
||||
{"id": "flow-bc", "source": "B", "target": "C"},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert set(calls) == {"A", "B", "C"}
|
||||
assert calls[-1] == "C"
|
||||
assert joined_outputs["A"] == {"value": "A"}
|
||||
assert joined_outputs["B"] == {"value": "B"}
|
||||
|
||||
|
||||
def test_workflow_executor_any_success_join_runs_after_available_branch(monkeypatch):
|
||||
"""any_success 汇合策略应允许任一满足条件的上游分支触发目标节点。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"items": ["movie"]}
|
||||
)
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
|
||||
{"id": "D", "type": "FakeAction", "name": "动作D", "data": {"join_policy": "any_success"}},
|
||||
],
|
||||
flows=[
|
||||
{"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.items.count == 0"},
|
||||
{"id": "flow-ac", "source": "A", "target": "C", "condition": "outputs.A.items.count > 0"},
|
||||
{"id": "flow-bd", "source": "B", "target": "D"},
|
||||
{"id": "flow-cd", "source": "C", "target": "D"},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["A", "C", "D"]
|
||||
assert executor.context.progress == 100
|
||||
|
||||
|
||||
def test_workflow_executor_all_done_join_can_continue_after_failure(monkeypatch):
|
||||
"""continue 失败策略配合 all_done 汇合时应继续执行收尾节点。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(success=False, message=f"{action.name}失败", context=context)
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"fail_policy": "continue"}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {"join_policy": "all_done"}},
|
||||
],
|
||||
flows=[
|
||||
{"id": "flow-ac", "source": "A", "target": "C"},
|
||||
{"id": "flow-bc", "source": "B", "target": "C"},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert set(calls) == {"A", "B", "C"}
|
||||
assert calls[-1] == "C"
|
||||
assert executor.has_failure is True
|
||||
assert executor.success is True
|
||||
|
||||
|
||||
def test_workflow_executor_exclusive_branch_uses_first_matching_flow(monkeypatch):
|
||||
"""互斥分支应只执行第一条满足条件的出边。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"count": 2}
|
||||
)
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"branch_policy": "exclusive"}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {}},
|
||||
{"id": "C", "type": "FakeAction", "name": "动作C", "data": {}},
|
||||
],
|
||||
flows=[
|
||||
{"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.count > 0"},
|
||||
{"id": "flow-ac", "source": "A", "target": "C", "condition": "outputs.A.count > 1"},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["A", "B"]
|
||||
assert executor.node_states["C"] == "skipped"
|
||||
|
||||
|
||||
def test_workflow_executor_passes_declared_inputs(monkeypatch):
|
||||
"""动作输入声明应从 node_outputs 中读取指定路径。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"torrents": ["a", "b"]}
|
||||
)
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{
|
||||
"id": "B",
|
||||
"type": "FakeAction",
|
||||
"name": "动作B",
|
||||
"data": {"inputs": ["A.torrents", "outputs.A.torrents.count"]},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0]
|
||||
assert b_inputs == {
|
||||
"A.torrents": ["a", "b"],
|
||||
"outputs.A.torrents.count": 2,
|
||||
}
|
||||
|
||||
|
||||
def test_workflow_executor_uses_contract_inputs(monkeypatch):
|
||||
"""未手写输入声明时应按动作契约读取上下文字段。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
contracts={
|
||||
"NeedsTorrentsAction": {
|
||||
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
"outputs": [],
|
||||
}
|
||||
},
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"torrents": ["a", "b"]}
|
||||
)
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "NeedsTorrentsAction", "name": "动作B", "data": {}},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
b_inputs = [item for action_id, item, _, _ in fake_manager.received_inputs if action_id == "B"][0]
|
||||
assert b_inputs == {"torrents": ["a", "b"]}
|
||||
|
||||
|
||||
def test_workflow_executor_persists_structured_state(monkeypatch):
|
||||
"""步骤回调应收到可持久化的结构化执行状态。"""
|
||||
calls = []
|
||||
states = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"items": ["movie"]}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(
|
||||
_build_workflow(actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}], flows=[]),
|
||||
step_callback=lambda action, context, execution_state, completed: states.append(execution_state),
|
||||
)
|
||||
executor.execute()
|
||||
|
||||
assert states[-1]["nodes"]["A"]["state"] == "success"
|
||||
assert states[-1]["outputs"]["A"]["items"] == ["movie"]
|
||||
assert states[-1]["runtime"]["progress"] == 100
|
||||
|
||||
|
||||
def test_workflow_executor_restores_outputs_from_execution_state(monkeypatch):
|
||||
"""恢复执行时应从结构化状态读取节点输出并继续判断条件边。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
workflow = _build_workflow(
|
||||
execution_state={
|
||||
"nodes": {
|
||||
"A": {"state": "success", "attempt": 1},
|
||||
},
|
||||
"outputs": {
|
||||
"A": {"torrents": ["movie"]},
|
||||
},
|
||||
},
|
||||
flows=[
|
||||
{"id": "flow-ab", "source": "A", "target": "B", "condition": "A.torrents.count > 0"},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert calls == ["B"]
|
||||
assert executor.context.node_outputs["A"]["torrents"] == ["movie"]
|
||||
|
||||
|
||||
def test_workflow_executor_keeps_execution_state_dict_for_non_json_leaf(monkeypatch):
|
||||
"""结构化状态遇到不可序列化叶子节点时仍应保持字典结构。"""
|
||||
calls = []
|
||||
states = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"opaque": _OpaqueValue()}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(
|
||||
_build_workflow(actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}], flows=[]),
|
||||
step_callback=lambda action, context, execution_state, completed: states.append(execution_state),
|
||||
)
|
||||
executor.execute()
|
||||
|
||||
assert isinstance(states[-1], dict)
|
||||
assert states[-1]["outputs"]["A"]["opaque"] == "opaque-value"
|
||||
|
||||
|
||||
def test_workflow_chain_process_serializes_circular_context(monkeypatch):
|
||||
"""工作流步骤持久化应清洗循环引用和不可序列化上下文。"""
|
||||
calls = []
|
||||
|
||||
def run_action(action, context):
|
||||
"""构造包含循环引用的上下文。"""
|
||||
context.workflow_context["self"] = context.workflow_context
|
||||
context.workflow_context["opaque"] = _OpaqueValue()
|
||||
return ActionResult(success=True, message=f"{action.name}完成", context=context)
|
||||
|
||||
fake_manager = _FakeWorkflowManager(calls, results={"A": run_action})
|
||||
workflow = _build_workflow(
|
||||
actions=[{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}],
|
||||
flows=[{"id": "flow-end", "source": "A", "target": "END", "animated": True}],
|
||||
)
|
||||
fake_oper = _FakeWorkflowOper(workflow)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowOper", lambda: fake_oper)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
success, message = workflow_module.WorkflowChain.process(workflow_id=1)
|
||||
|
||||
assert success is True
|
||||
assert message == ""
|
||||
assert fake_oper.succeeded is True
|
||||
saved_workflow_context = fake_oper.steps[-1]["context"]["workflow_context"]
|
||||
saved_self = saved_workflow_context["self"]
|
||||
|
||||
assert saved_workflow_context["opaque"] == "opaque-value"
|
||||
if isinstance(saved_self, dict):
|
||||
assert saved_self["self"] == workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER
|
||||
assert saved_self["opaque"] == "opaque-value"
|
||||
else:
|
||||
assert saved_self == workflow_module.CIRCULAR_REFERENCE_PLACEHOLDER
|
||||
|
||||
|
||||
def test_workflow_executor_concurrency_key_serializes_parallel_nodes(monkeypatch):
|
||||
"""相同 concurrency_key 的并行节点不应同时运行。"""
|
||||
calls = []
|
||||
active_count = 0
|
||||
max_active_count = 0
|
||||
lock = threading.Lock()
|
||||
|
||||
def run_action(action, context):
|
||||
"""记录同一并发键下的同时运行数量。"""
|
||||
nonlocal active_count, max_active_count
|
||||
with lock:
|
||||
active_count += 1
|
||||
max_active_count = max(max_active_count, active_count)
|
||||
time.sleep(0.05)
|
||||
with lock:
|
||||
active_count -= 1
|
||||
return ActionResult(success=True, message=f"{action.name}完成", context=context)
|
||||
|
||||
fake_manager = _FakeWorkflowManager(calls, results={"A": run_action, "B": run_action})
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {"concurrency_key": "download"}},
|
||||
{"id": "B", "type": "FakeAction", "name": "动作B", "data": {"concurrency_key": "download"}},
|
||||
],
|
||||
flows=[],
|
||||
execution_config={"max_workers": 2},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert set(calls) == {"A", "B"}
|
||||
assert max_active_count == 1
|
||||
|
||||
|
||||
def test_workflow_executor_filter_action_replaces_artifact_outputs(monkeypatch):
|
||||
"""过滤类动作默认应替换列表输出,避免把过滤前数据重新合并回来。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(
|
||||
calls,
|
||||
results={
|
||||
"A": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"torrents": ["old", "keep"]}
|
||||
),
|
||||
"B": lambda action, context: ActionResult(
|
||||
success=True,
|
||||
message=f"{action.name}完成",
|
||||
context=context,
|
||||
outputs={"torrents": ["keep"]}
|
||||
),
|
||||
}
|
||||
)
|
||||
workflow = _build_workflow(
|
||||
actions=[
|
||||
{"id": "A", "type": "FakeAction", "name": "动作A", "data": {}},
|
||||
{"id": "B", "type": "FilterTorrentsAction", "name": "过滤资源", "data": {}},
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(workflow)
|
||||
executor.execute()
|
||||
|
||||
assert executor.context.torrents == ["keep"]
|
||||
assert executor.context.artifacts["torrents"] == ["keep"]
|
||||
|
||||
|
||||
def test_workflow_executor_stop_is_not_success(monkeypatch):
|
||||
"""停止信号不应被执行器汇报为成功完成。"""
|
||||
calls = []
|
||||
fake_manager = _FakeWorkflowManager(calls)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None)
|
||||
monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: True)
|
||||
|
||||
executor = workflow_module.WorkflowExecutor(_build_workflow())
|
||||
executor.execute()
|
||||
|
||||
assert calls == []
|
||||
assert executor.stopped is True
|
||||
assert executor.success is False
|
||||
assert executor.errmsg == "工作流已停止"
|
||||
|
||||
|
||||
def test_workflow_context_merge_preserves_runtime_objects():
|
||||
"""合并上下文时应保留运行时对象,而不是转成字典。"""
|
||||
executor = object.__new__(workflow_module.WorkflowExecutor)
|
||||
executor.context = ActionContext()
|
||||
runtime_torrent = SimpleNamespace(title="runtime torrent")
|
||||
result_context = ActionContext()
|
||||
result_context.torrents.append(runtime_torrent)
|
||||
|
||||
executor.merge_context(result_context)
|
||||
|
||||
assert executor.context.torrents[0] is runtime_torrent
|
||||
|
||||
|
||||
class _FakeEventManager:
|
||||
"""记录事件监听器注册和移除次数。"""
|
||||
|
||||
def __init__(self):
|
||||
self.added = []
|
||||
self.removed = []
|
||||
|
||||
def add_event_listener(self, event_type, handler):
|
||||
self.added.append(event_type)
|
||||
|
||||
def remove_event_listener(self, event_type, handler):
|
||||
self.removed.append(event_type)
|
||||
|
||||
|
||||
def test_workflow_event_listener_keeps_shared_handler_until_last_workflow(monkeypatch):
|
||||
"""同一事件下移除单个工作流时不应断开其他工作流监听。"""
|
||||
fake_eventmanager = _FakeEventManager()
|
||||
manager = object.__new__(workflow_package.WorkFlowManager)
|
||||
manager._lock = threading.Lock()
|
||||
manager._event_workflows = {}
|
||||
|
||||
monkeypatch.setattr(workflow_package, "eventmanager", fake_eventmanager)
|
||||
|
||||
manager.register_workflow_event(1, EventType.DownloadAdded.value)
|
||||
manager.register_workflow_event(2, EventType.DownloadAdded.value)
|
||||
manager.remove_workflow_event(1, EventType.DownloadAdded.value)
|
||||
|
||||
assert fake_eventmanager.added == [EventType.DownloadAdded]
|
||||
assert fake_eventmanager.removed == []
|
||||
assert manager.get_event_workflows() == {EventType.DownloadAdded.value: [2]}
|
||||
|
||||
manager.remove_workflow_event(2, EventType.DownloadAdded.value)
|
||||
|
||||
assert fake_eventmanager.removed == [EventType.DownloadAdded]
|
||||
assert manager.get_event_workflows() == {}
|
||||
|
||||
|
||||
def test_workflow_manager_retries_action_until_success(monkeypatch):
|
||||
"""动作管理器应按 retry 配置重试失败动作。"""
|
||||
|
||||
class RetryAction:
|
||||
"""模拟第二次才成功的动作。"""
|
||||
|
||||
call_count = 0
|
||||
|
||||
def __init__(self, action_id):
|
||||
self.action_id = action_id
|
||||
|
||||
def execute_with_inputs(self, workflow_id, params, inputs, runtime, context):
|
||||
"""执行动作并在第二次返回成功。"""
|
||||
_ = workflow_id, params, inputs, runtime
|
||||
RetryAction.call_count += 1
|
||||
if RetryAction.call_count == 1:
|
||||
return ActionResult(success=False, message="第一次失败", context=context)
|
||||
return ActionResult(success=True, message="第二次成功", context=context, outputs={"ok": True})
|
||||
|
||||
manager = object.__new__(workflow_package.WorkFlowManager)
|
||||
manager._actions = {"RetryAction": RetryAction}
|
||||
monkeypatch.setattr(workflow_package.global_vars, "is_workflow_stopped", lambda workflow_id: False)
|
||||
|
||||
result = manager.execute(
|
||||
workflow_id=1,
|
||||
action=Action(
|
||||
id="retry",
|
||||
type="RetryAction",
|
||||
name="重试动作",
|
||||
data={"retry": {"max_attempts": 2, "interval": 0}},
|
||||
),
|
||||
context=ActionContext(),
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.attempts == 2
|
||||
assert result.outputs == {"ok": True}
|
||||
assert RetryAction.call_count == 2
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.13.4'
|
||||
FRONTEND_VERSION = 'v2.13.4'
|
||||
APP_VERSION = 'v2.13.5-1'
|
||||
FRONTEND_VERSION = 'v2.13.5'
|
||||
|
||||
Reference in New Issue
Block a user