feat:工作流手动中止

This commit is contained in:
jxxghp
2025-02-28 19:02:38 +08:00
parent 4086ba4763
commit 1bd12a9411
18 changed files with 148 additions and 67 deletions

View File

@@ -56,7 +56,7 @@ class BaseAction(ABC):
self._done_flag = True
@abstractmethod
def execute(self, params: ActionParams, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: ActionParams, context: ActionContext) -> ActionContext:
"""
执行动作
"""

View File

@@ -3,6 +3,7 @@ from pydantic import Field
from app.actions import BaseAction
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
from app.core.config import global_vars
from app.core.metainfo import MetaInfo
from app.log import logger
from app.schemas import ActionParams, ActionContext, DownloadTask, MediaType
@@ -50,12 +51,14 @@ class AddDownloadAction(BaseAction):
def success(self) -> bool:
return not self._has_error
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
将上下文中的torrents添加到下载任务中
"""
params = AddDownloadParams(**params)
for t in context.torrents:
if global_vars.is_workflow_stopped(workflow_id):
break
if not t.meta_info:
t.meta_info = MetaInfo(title=t.title, subtitle=t.description)
if not t.media_info:

View File

@@ -1,6 +1,6 @@
from app.actions import BaseAction
from app.chain.subscribe import SubscribeChain
from app.core.config import settings
from app.core.config import settings, global_vars
from app.core.context import MediaInfo
from app.db.subscribe_oper import SubscribeOper
from app.log import logger
@@ -46,11 +46,13 @@ class AddSubscribeAction(BaseAction):
def success(self) -> bool:
return not self._has_error
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
将medias中的信息添加订阅如果订阅不存在的话
"""
for media in context.medias:
if global_vars.is_workflow_stopped(workflow_id):
break
mediainfo = MediaInfo()
mediainfo.from_dict(media.dict())
if self.subscribechain.exists(mediainfo):

View File

@@ -1,4 +1,5 @@
from app.actions import BaseAction, ActionChain
from app.core.config import global_vars
from app.schemas import ActionParams, ActionContext
from app.log import logger
@@ -40,12 +41,14 @@ class FetchDownloadsAction(BaseAction):
def success(self) -> bool:
return self.done
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
更新downloads中的下载任务状态
"""
__all_complete = False
for download in self._downloads:
if global_vars.is_workflow_stopped(workflow_id):
break
logger.info(f"获取下载任务 {download.download_id} 状态 ...")
torrents = self.chain.list_torrents(hashs=[download.download_id])
if not torrents:

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from app.actions import BaseAction
from app.chain.recommend import RecommendChain
from app.schemas import ActionParams, ActionContext
from app.core.config import settings
from app.core.config import settings, global_vars
from app.core.event import eventmanager
from app.log import logger
from app.schemas import RecommendSourceEventData, MediaInfo
@@ -124,12 +124,14 @@ class FetchMediasAction(BaseAction):
return s
return None
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
获取媒体数据填充到medias
"""
params = FetchMediasParams(**params)
for name in params.sources:
if global_vars.is_workflow_stopped(workflow_id):
break
source = self.__get_source(name)
if not source:
continue

View File

@@ -3,7 +3,7 @@ from typing import Optional
from pydantic import Field
from app.actions import BaseAction, ActionChain
from app.core.config import settings
from app.core.config import settings, global_vars
from app.core.context import Context
from app.core.metainfo import MetaInfo
from app.helper.rss import RssHelper
@@ -55,7 +55,7 @@ class FetchRssAction(BaseAction):
def success(self) -> bool:
return not self._has_error
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
请求RSS地址获取数据并解析为资源列表
"""
@@ -86,6 +86,8 @@ class FetchRssAction(BaseAction):
# 组装种子
for item in rss_items:
if global_vars.is_workflow_stopped(workflow_id):
break
if not item.get("title"):
continue
torrentinfo = TorrentInfo(

View File

@@ -1,9 +1,12 @@
import random
import time
from typing import Optional, List
from pydantic import Field
from app.actions import BaseAction
from app.chain.search import SearchChain
from app.core.config import global_vars
from app.log import logger
from app.schemas import ActionParams, ActionContext, MediaType
@@ -12,7 +15,8 @@ class FetchTorrentsParams(ActionParams):
"""
获取站点资源参数
"""
name: str = Field(None, description="资源名称")
search_type: Optional[str] = Field("keyword", description="搜索类型")
name: Optional[str] = Field(None, description="资源名称")
year: Optional[str] = Field(None, description="年份")
type: Optional[str] = Field(None, description="资源类型 (电影/电视剧)")
season: Optional[int] = Field(None, description="季度")
@@ -49,29 +53,49 @@ class FetchTorrentsAction(BaseAction):
def success(self) -> bool:
return self.done
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
搜索站点,获取资源列表
"""
params = FetchTorrentsParams(**params)
torrents = self.searchchain.search_by_title(title=params.name, sites=params.sites, cache_local=False)
for torrent in torrents:
if params.year and torrent.meta_info.year != params.year:
continue
if params.type and torrent.media_info and torrent.media_info.type != MediaType(params.type):
continue
if params.season and torrent.meta_info.begin_season != params.season:
continue
# 识别媒体信息
torrent.media_info = self.searchchain.recognize_media(torrent.meta_info)
if not torrent.media_info:
logger.warning(f"{torrent.torrent_info.title} 未识别到媒体信息")
continue
self._torrents.append(torrent)
if params.search_type == "keyword":
# 按关键字搜索
torrents = self.searchchain.search_by_title(title=params.name, sites=params.sites, cache_local=False)
for torrent in torrents:
if global_vars.is_workflow_stopped(workflow_id):
break
if params.year and torrent.meta_info.year != params.year:
continue
if params.type and torrent.media_info and torrent.media_info.type != MediaType(params.type):
continue
if params.season and torrent.meta_info.begin_season != params.season:
continue
# 识别媒体信息
torrent.media_info = self.searchchain.recognize_media(torrent.meta_info)
if not torrent.media_info:
logger.warning(f"{torrent.torrent_info.title} 未识别到媒体信息")
continue
self._torrents.append(torrent)
else:
# 搜索媒体列表
for media in context.medias:
if global_vars.is_workflow_stopped(workflow_id):
break
torrents = self.searchchain.search_by_id(tmdbid=media.tmdb_id,
doubanid=media.douban_id,
mtype=MediaType(media.type),
sites=params.sites)
for torrent in torrents:
self._torrents.append(torrent)
# 随机休眠 10-60秒
sleep_time = random.randint(10, 60)
logger.info(f"随机休眠 {sleep_time} 秒 ...")
time.sleep(sleep_time)
if self._torrents:
context.torrents.extend(self._torrents)
logger.info(f"搜索到 {len(self._torrents)} 条资源")
logger.info(f"搜索到 {len(self._torrents)} 条资源")
self.job_done()
return context

View File

@@ -3,6 +3,7 @@ from typing import Optional
from pydantic import Field
from app.actions import BaseAction
from app.core.config import global_vars
from app.schemas import ActionParams, ActionContext
@@ -42,12 +43,14 @@ class FilterMediasAction(BaseAction):
def success(self) -> bool:
return self.done
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
过滤medias中媒体数据
"""
params = FilterMediasParams(**params)
for media in context.medias:
if global_vars.is_workflow_stopped(workflow_id):
break
if params.type and media.type != params.type:
continue
if params.category and media.category != params.category:

View File

@@ -3,6 +3,7 @@ from typing import Optional, List
from pydantic import Field
from app.actions import BaseAction, ActionChain
from app.core.config import global_vars
from app.helper.torrent import TorrentHelper
from app.schemas import ActionParams, ActionContext
@@ -51,12 +52,14 @@ class FilterTorrentsAction(BaseAction):
def success(self) -> bool:
return self.done
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
过滤torrents中的资源
"""
params = FilterTorrentsParams(**params)
for torrent in context.torrents:
if global_vars.is_workflow_stopped(workflow_id):
break
if self.torrenthelper.filter_torrent(
torrent_info=torrent.torrent_info,
filter_params={

View File

@@ -1,6 +1,7 @@
from pathlib import Path
from app.actions import BaseAction
from app.core.config import global_vars
from app.schemas import ActionParams, ActionContext
from app.chain.media import MediaChain
from app.chain.storage import StorageChain
@@ -47,11 +48,13 @@ class ScrapeFileAction(BaseAction):
def success(self) -> bool:
return not self._has_error
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
刮削fileitems中的所有文件
"""
for fileitem in context.fileitems:
if global_vars.is_workflow_stopped(workflow_id):
break
if fileitem in self._scraped_files:
continue
if not self.storagechain.exists(fileitem):

View File

@@ -1,6 +1,7 @@
import copy
from app.actions import BaseAction
from app.core.config import global_vars
from app.schemas import ActionParams, ActionContext
from app.core.event import eventmanager
@@ -36,7 +37,7 @@ class SendEventAction(BaseAction):
def success(self) -> bool:
return self.done
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
发送events中的事件
"""
@@ -44,6 +45,8 @@ class SendEventAction(BaseAction):
# 按优先级排序,优先级高的先发送
context.events.sort(key=lambda x: x.priority, reverse=True)
for event in copy.deepcopy(context.events):
if global_vars.is_workflow_stopped(workflow_id):
break
eventmanager.send_event(etype=event.event_type, data=event.event_data)
context.events.remove(event)

View File

@@ -4,6 +4,7 @@ from typing import List, Optional, Union
from pydantic import Field
from app.actions import BaseAction, ActionChain
from app.core.config import global_vars
from app.schemas import ActionParams, ActionContext
@@ -43,11 +44,13 @@ class SendMessageAction(BaseAction):
def success(self) -> bool:
return self.done
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
发送messages中的消息
"""
for message in copy.deepcopy(context.messages):
if global_vars.is_workflow_stopped(workflow_id):
break
if params.client:
message.source = params.client
if params.userid:

View File

@@ -1,6 +1,7 @@
from pathlib import Path
from app.actions import BaseAction
from app.core.config import global_vars
from app.schemas import ActionParams, ActionContext
from app.chain.storage import StorageChain
from app.chain.transfer import TransferChain
@@ -46,11 +47,13 @@ class TransferFileAction(BaseAction):
def success(self) -> bool:
return not self._has_error
def execute(self, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
从downloads中整理文件记录到fileitems
"""
for download in context.downloads:
if global_vars.is_workflow_stopped(workflow_id):
break
if not download.completed:
logger.info(f"下载任务 {download.download_id} 未完成")
continue

View File

@@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import schemas
from app.core.config import global_vars
from app.core.workflow import WorkFlowManager
from app.db import get_db
from app.db.models.workflow import Workflow
@@ -112,6 +113,7 @@ def start_workflow(workflow_id: int,
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
Scheduler().update_workflow_job(workflow)
global_vars.workflow_resume(workflow_id)
workflow.update_state(db, workflow_id, "W")
return schemas.Response(success=True)
@@ -127,5 +129,6 @@ def pause_workflow(workflow_id: int,
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
Scheduler().remove_workflow_job(workflow)
global_vars.stop_workflow(workflow_id)
workflow.update_state(db, workflow_id, "P")
return schemas.Response(success=True)

View File

@@ -97,7 +97,8 @@ class WorkflowExecutor:
self.running_tasks += 1
# 已停机
if global_vars.is_system_stopped:
if global_vars.is_workflow_stopped(self.workflow.id):
global_vars.workflow_resume(self.workflow.id)
break
# 已执行的跳过
@@ -108,17 +109,18 @@ class WorkflowExecutor:
# 提交任务到线程池
future = self.executor.submit(
self.execute_node,
self.workflow.id,
node_id,
self.context
)
future.add_done_callback(self.on_node_complete)
def execute_node(self, node_id: int, context: ActionContext) -> Tuple[Action, bool, ActionContext]:
def execute_node(self, workflow_id: int, node_id: int, context: ActionContext) -> Tuple[Action, bool, ActionContext]:
"""
执行单个节点操作返回修改后的上下文和节点ID
"""
action = self.actions[node_id]
state, result_ctx = self.workflowmanager.excute(action, context=context)
state, result_ctx = self.workflowmanager.excute(workflow_id, action, context=context)
return action, state, result_ctx
def on_node_complete(self, future):
@@ -127,35 +129,32 @@ class WorkflowExecutor:
"""
action, state, result_ctx = future.result()
# 节点执行失败
if not state:
self.success = False
self.errmsg = f"{action.name} 失败"
try:
# 节点执行失败
if not state:
self.success = False
self.errmsg = f"{action.name} 失败"
return
with self.lock:
# 更新主上下文
self.merge_context(result_ctx)
# 回调
if self.step_callback:
self.step_callback(action, self.context)
# 处理后继节点
successors = self.adjacency.get(action.id, [])
for succ_id in successors:
with self.lock:
self.indegree[succ_id] -= 1
if self.indegree[succ_id] == 0:
self.queue.append(succ_id)
finally:
# 标记任务完成
with self.lock:
self.running_tasks -= 1
return
with self.lock:
# 更新主上下文
self.merge_context(result_ctx)
# 回调
if self.step_callback:
self.step_callback(action, self.context)
# 处理后继节点
successors = self.adjacency.get(action.id, [])
for succ_id in successors:
with self.lock:
self.indegree[succ_id] -= 1
if self.indegree[succ_id] == 0:
self.queue.append(succ_id)
# 标记任务完成
with self.lock:
self.running_tasks -= 1
def merge_context(self, context: ActionContext):
"""
合并上下文
@@ -221,7 +220,7 @@ class WorkflowChain(ChainBase):
self.workflowoper.fail(workflow_id, result=executor.errmsg)
return False, executor.errmsg
else:
logger.info(f"工作流 {workflow.name} 执行成")
logger.info(f"工作流 {workflow.name} 执行")
self.workflowoper.success(workflow_id)
return True, ""

View File

@@ -607,6 +607,8 @@ class GlobalVar(object):
STOP_EVENT: threading.Event = threading.Event()
# webpush订阅
SUBSCRIPTIONS: List[dict] = []
# 需应急停止的工作流
EMERGENCY_STOP_WORKFLOWS: List[str] = []
def stop_system(self):
"""
@@ -633,6 +635,26 @@ class GlobalVar(object):
"""
self.SUBSCRIPTIONS.append(subscription)
def stop_workflow(self, workflow_id: str):
"""
停止工作流
"""
if workflow_id not in self.EMERGENCY_STOP_WORKFLOWS:
self.EMERGENCY_STOP_WORKFLOWS.append(workflow_id)
def workflow_resume(self, workflow_id: str):
"""
恢复工作流
"""
if workflow_id in self.EMERGENCY_STOP_WORKFLOWS:
self.EMERGENCY_STOP_WORKFLOWS.remove(workflow_id)
def is_workflow_stopped(self, workflow_id: str):
"""
是否停止工作流
"""
return self.is_system_stopped or workflow_id in self.EMERGENCY_STOP_WORKFLOWS
# 实例化配置
settings = Settings()

View File

@@ -1,6 +1,7 @@
from time import sleep
from typing import Dict, Any, Tuple, List
from app.core.config import global_vars
from app.helper.module import ModuleHelper
from app.log import logger
from app.schemas import Action, ActionContext
@@ -54,7 +55,7 @@ class WorkFlowManager(metaclass=Singleton):
"""
pass
def excute(self, action: Action, context: ActionContext = None) -> Tuple[bool, ActionContext]:
def excute(self, workflow_id: int, action: Action, context: ActionContext = None) -> Tuple[bool, ActionContext]:
"""
执行工作流动作
"""
@@ -66,7 +67,7 @@ class WorkFlowManager(metaclass=Singleton):
# 执行
logger.info(f"执行动作: {action.id} - {action.name}")
try:
result_context = action_obj.execute(action.data, context)
result_context = action_obj.execute(workflow_id, action.data, context)
except Exception as err:
logger.error(f"{action.name} 执行失败: {err}")
return False, context
@@ -74,12 +75,14 @@ class WorkFlowManager(metaclass=Singleton):
loop_interval = action.data.get("loop_interval")
if loop and loop_interval:
while not action_obj.done:
if global_vars.is_workflow_stopped(workflow_id):
break
# 等待
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
sleep(loop_interval)
# 执行
logger.info(f"继续执行动作: {action.id} - {action.name}")
result_context = action_obj.execute(action.data, result_context)
result_context = action_obj.execute(workflow_id, action.data, result_context)
if action_obj.success:
logger.info(f"{action.name} 执行成功")
else:

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from sqlalchemy import Column, Integer, JSON, Sequence, String
from sqlalchemy import Column, Integer, JSON, Sequence, String, and_
from app.db import Base, db_query, db_update
@@ -63,7 +63,7 @@ class Workflow(Base):
@staticmethod
@db_update
def fail(db, wid: int, result: str):
db.query(Workflow).filter(Workflow.id == wid).update({
db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({
"state": 'F',
"result": result,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
@@ -73,7 +73,7 @@ class Workflow(Base):
@staticmethod
@db_update
def success(db, wid: int, result: str = None):
db.query(Workflow).filter(Workflow.id == wid).update({
db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({
"state": 'S',
"result": result,
"run_count": Workflow.run_count + 1,