From 1bd12a941166cc7140b32aa8f4b670b4fcf26096 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 28 Feb 2025 19:02:38 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E5=B7=A5=E4=BD=9C=E6=B5=81?= =?UTF-8?q?=E6=89=8B=E5=8A=A8=E4=B8=AD=E6=AD=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/actions/__init__.py | 2 +- app/actions/add_download.py | 5 ++- app/actions/add_subscribe.py | 6 ++-- app/actions/fetch_downloads.py | 5 ++- app/actions/fetch_medias.py | 6 ++-- app/actions/fetch_rss.py | 6 ++-- app/actions/fetch_torrents.py | 58 ++++++++++++++++++++++++---------- app/actions/filter_medias.py | 5 ++- app/actions/filter_torrents.py | 5 ++- app/actions/scrape_file.py | 5 ++- app/actions/send_event.py | 5 ++- app/actions/send_message.py | 5 ++- app/actions/transfer_file.py | 5 ++- app/api/endpoints/workflow.py | 3 ++ app/chain/workflow.py | 57 ++++++++++++++++----------------- app/core/config.py | 22 +++++++++++++ app/core/workflow.py | 9 ++++-- app/db/models/workflow.py | 6 ++-- 18 files changed, 148 insertions(+), 67 deletions(-) diff --git a/app/actions/__init__.py b/app/actions/__init__.py index eefebb0b..95116fa2 100644 --- a/app/actions/__init__.py +++ b/app/actions/__init__.py @@ -56,7 +56,7 @@ class BaseAction(ABC): self._done_flag = True @abstractmethod - def execute(self, params: ActionParams, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: ActionParams, context: ActionContext) -> ActionContext: """ 执行动作 """ diff --git a/app/actions/add_download.py b/app/actions/add_download.py index 1593a563..c0b6dc1c 100644 --- a/app/actions/add_download.py +++ b/app/actions/add_download.py @@ -3,6 +3,7 @@ from pydantic import Field from app.actions import BaseAction from app.chain.download import DownloadChain from app.chain.media import MediaChain +from app.core.config import global_vars from app.core.metainfo import MetaInfo from app.log import logger from app.schemas import ActionParams, ActionContext, DownloadTask, MediaType @@ -50,12 +51,14 @@ class AddDownloadAction(BaseAction): def success(self) -> bool: return not self._has_error - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 将上下文中的torrents添加到下载任务中 """ params = AddDownloadParams(**params) for t in context.torrents: + if global_vars.is_workflow_stopped(workflow_id): + break if not t.meta_info: t.meta_info = MetaInfo(title=t.title, subtitle=t.description) if not t.media_info: diff --git a/app/actions/add_subscribe.py b/app/actions/add_subscribe.py index 44da0f87..c542c2e9 100644 --- a/app/actions/add_subscribe.py +++ b/app/actions/add_subscribe.py @@ -1,6 +1,6 @@ from app.actions import BaseAction from app.chain.subscribe import SubscribeChain -from app.core.config import settings +from app.core.config import settings, global_vars from app.core.context import MediaInfo from app.db.subscribe_oper import SubscribeOper from app.log import logger @@ -46,11 +46,13 @@ class AddSubscribeAction(BaseAction): def success(self) -> bool: return not self._has_error - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 将medias中的信息添加订阅,如果订阅不存在的话 """ for media in context.medias: + if global_vars.is_workflow_stopped(workflow_id): + break mediainfo = MediaInfo() mediainfo.from_dict(media.dict()) if self.subscribechain.exists(mediainfo): diff --git a/app/actions/fetch_downloads.py b/app/actions/fetch_downloads.py index fb7fbb15..0ebdb023 100644 --- a/app/actions/fetch_downloads.py +++ b/app/actions/fetch_downloads.py @@ -1,4 +1,5 @@ from app.actions import BaseAction, ActionChain +from app.core.config import global_vars from app.schemas import ActionParams, ActionContext from app.log import logger @@ -40,12 +41,14 @@ class FetchDownloadsAction(BaseAction): def success(self) -> bool: return self.done - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 更新downloads中的下载任务状态 """ __all_complete = False for download in self._downloads: + if global_vars.is_workflow_stopped(workflow_id): + break logger.info(f"获取下载任务 {download.download_id} 状态 ...") torrents = self.chain.list_torrents(hashs=[download.download_id]) if not torrents: diff --git a/app/actions/fetch_medias.py b/app/actions/fetch_medias.py index 9b2d777f..7cb74367 100644 --- a/app/actions/fetch_medias.py +++ b/app/actions/fetch_medias.py @@ -5,7 +5,7 @@ from pydantic import Field from app.actions import BaseAction from app.chain.recommend import RecommendChain from app.schemas import ActionParams, ActionContext -from app.core.config import settings +from app.core.config import settings, global_vars from app.core.event import eventmanager from app.log import logger from app.schemas import RecommendSourceEventData, MediaInfo @@ -124,12 +124,14 @@ class FetchMediasAction(BaseAction): return s return None - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 获取媒体数据,填充到medias """ params = FetchMediasParams(**params) for name in params.sources: + if global_vars.is_workflow_stopped(workflow_id): + break source = self.__get_source(name) if not source: continue diff --git a/app/actions/fetch_rss.py b/app/actions/fetch_rss.py index 63cbabdf..d11cad79 100644 --- a/app/actions/fetch_rss.py +++ b/app/actions/fetch_rss.py @@ -3,7 +3,7 @@ from typing import Optional from pydantic import Field from app.actions import BaseAction, ActionChain -from app.core.config import settings +from app.core.config import settings, global_vars from app.core.context import Context from app.core.metainfo import MetaInfo from app.helper.rss import RssHelper @@ -55,7 +55,7 @@ class FetchRssAction(BaseAction): def success(self) -> bool: return not self._has_error - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 请求RSS地址获取数据,并解析为资源列表 """ @@ -86,6 +86,8 @@ class FetchRssAction(BaseAction): # 组装种子 for item in rss_items: + if global_vars.is_workflow_stopped(workflow_id): + break if not item.get("title"): continue torrentinfo = TorrentInfo( diff --git a/app/actions/fetch_torrents.py b/app/actions/fetch_torrents.py index 8aa4b008..cb76dff5 100644 --- a/app/actions/fetch_torrents.py +++ b/app/actions/fetch_torrents.py @@ -1,9 +1,12 @@ +import random +import time from typing import Optional, List from pydantic import Field from app.actions import BaseAction from app.chain.search import SearchChain +from app.core.config import global_vars from app.log import logger from app.schemas import ActionParams, ActionContext, MediaType @@ -12,7 +15,8 @@ class FetchTorrentsParams(ActionParams): """ 获取站点资源参数 """ - name: str = Field(None, description="资源名称") + search_type: Optional[str] = Field("keyword", description="搜索类型") + name: Optional[str] = Field(None, description="资源名称") year: Optional[str] = Field(None, description="年份") type: Optional[str] = Field(None, description="资源类型 (电影/电视剧)") season: Optional[int] = Field(None, description="季度") @@ -49,29 +53,49 @@ class FetchTorrentsAction(BaseAction): def success(self) -> bool: return self.done - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 搜索站点,获取资源列表 """ params = FetchTorrentsParams(**params) - torrents = self.searchchain.search_by_title(title=params.name, sites=params.sites, cache_local=False) - for torrent in torrents: - if params.year and torrent.meta_info.year != params.year: - continue - if params.type and torrent.media_info and torrent.media_info.type != MediaType(params.type): - continue - if params.season and torrent.meta_info.begin_season != params.season: - continue - # 识别媒体信息 - torrent.media_info = self.searchchain.recognize_media(torrent.meta_info) - if not torrent.media_info: - logger.warning(f"{torrent.torrent_info.title} 未识别到媒体信息") - continue - self._torrents.append(torrent) + if params.search_type == "keyword": + # 按关键字搜索 + torrents = self.searchchain.search_by_title(title=params.name, sites=params.sites, cache_local=False) + for torrent in torrents: + if global_vars.is_workflow_stopped(workflow_id): + break + if params.year and torrent.meta_info.year != params.year: + continue + if params.type and torrent.media_info and torrent.media_info.type != MediaType(params.type): + continue + if params.season and torrent.meta_info.begin_season != params.season: + continue + # 识别媒体信息 + torrent.media_info = self.searchchain.recognize_media(torrent.meta_info) + if not torrent.media_info: + logger.warning(f"{torrent.torrent_info.title} 未识别到媒体信息") + continue + self._torrents.append(torrent) + else: + # 搜索媒体列表 + for media in context.medias: + if global_vars.is_workflow_stopped(workflow_id): + break + torrents = self.searchchain.search_by_id(tmdbid=media.tmdb_id, + doubanid=media.douban_id, + mtype=MediaType(media.type), + sites=params.sites) + for torrent in torrents: + self._torrents.append(torrent) + + # 随机休眠 10-60秒 + sleep_time = random.randint(10, 60) + logger.info(f"随机休眠 {sleep_time} 秒 ...") + time.sleep(sleep_time) if self._torrents: context.torrents.extend(self._torrents) - logger.info(f"搜索到 {len(self._torrents)} 条资源") + logger.info(f"共搜索到 {len(self._torrents)} 条资源") self.job_done() return context diff --git a/app/actions/filter_medias.py b/app/actions/filter_medias.py index b0740165..24c74ce0 100644 --- a/app/actions/filter_medias.py +++ b/app/actions/filter_medias.py @@ -3,6 +3,7 @@ from typing import Optional from pydantic import Field from app.actions import BaseAction +from app.core.config import global_vars from app.schemas import ActionParams, ActionContext @@ -42,12 +43,14 @@ class FilterMediasAction(BaseAction): def success(self) -> bool: return self.done - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 过滤medias中媒体数据 """ params = FilterMediasParams(**params) for media in context.medias: + if global_vars.is_workflow_stopped(workflow_id): + break if params.type and media.type != params.type: continue if params.category and media.category != params.category: diff --git a/app/actions/filter_torrents.py b/app/actions/filter_torrents.py index ce96771a..c8f1c4f3 100644 --- a/app/actions/filter_torrents.py +++ b/app/actions/filter_torrents.py @@ -3,6 +3,7 @@ from typing import Optional, List from pydantic import Field from app.actions import BaseAction, ActionChain +from app.core.config import global_vars from app.helper.torrent import TorrentHelper from app.schemas import ActionParams, ActionContext @@ -51,12 +52,14 @@ class FilterTorrentsAction(BaseAction): def success(self) -> bool: return self.done - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 过滤torrents中的资源 """ params = FilterTorrentsParams(**params) for torrent in context.torrents: + if global_vars.is_workflow_stopped(workflow_id): + break if self.torrenthelper.filter_torrent( torrent_info=torrent.torrent_info, filter_params={ diff --git a/app/actions/scrape_file.py b/app/actions/scrape_file.py index e379e187..3c55b4b2 100644 --- a/app/actions/scrape_file.py +++ b/app/actions/scrape_file.py @@ -1,6 +1,7 @@ from pathlib import Path from app.actions import BaseAction +from app.core.config import global_vars from app.schemas import ActionParams, ActionContext from app.chain.media import MediaChain from app.chain.storage import StorageChain @@ -47,11 +48,13 @@ class ScrapeFileAction(BaseAction): def success(self) -> bool: return not self._has_error - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 刮削fileitems中的所有文件 """ for fileitem in context.fileitems: + if global_vars.is_workflow_stopped(workflow_id): + break if fileitem in self._scraped_files: continue if not self.storagechain.exists(fileitem): diff --git a/app/actions/send_event.py b/app/actions/send_event.py index 376d5405..07c6d6d0 100644 --- a/app/actions/send_event.py +++ b/app/actions/send_event.py @@ -1,6 +1,7 @@ import copy from app.actions import BaseAction +from app.core.config import global_vars from app.schemas import ActionParams, ActionContext from app.core.event import eventmanager @@ -36,7 +37,7 @@ class SendEventAction(BaseAction): def success(self) -> bool: return self.done - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 发送events中的事件 """ @@ -44,6 +45,8 @@ class SendEventAction(BaseAction): # 按优先级排序,优先级高的先发送 context.events.sort(key=lambda x: x.priority, reverse=True) for event in copy.deepcopy(context.events): + if global_vars.is_workflow_stopped(workflow_id): + break eventmanager.send_event(etype=event.event_type, data=event.event_data) context.events.remove(event) diff --git a/app/actions/send_message.py b/app/actions/send_message.py index 1d706a51..d6a6b1de 100644 --- a/app/actions/send_message.py +++ b/app/actions/send_message.py @@ -4,6 +4,7 @@ from typing import List, Optional, Union from pydantic import Field from app.actions import BaseAction, ActionChain +from app.core.config import global_vars from app.schemas import ActionParams, ActionContext @@ -43,11 +44,13 @@ class SendMessageAction(BaseAction): def success(self) -> bool: return self.done - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 发送messages中的消息 """ for message in copy.deepcopy(context.messages): + if global_vars.is_workflow_stopped(workflow_id): + break if params.client: message.source = params.client if params.userid: diff --git a/app/actions/transfer_file.py b/app/actions/transfer_file.py index 4cdc78b4..07512e75 100644 --- a/app/actions/transfer_file.py +++ b/app/actions/transfer_file.py @@ -1,6 +1,7 @@ from pathlib import Path from app.actions import BaseAction +from app.core.config import global_vars from app.schemas import ActionParams, ActionContext from app.chain.storage import StorageChain from app.chain.transfer import TransferChain @@ -46,11 +47,13 @@ class TransferFileAction(BaseAction): def success(self) -> bool: return not self._has_error - def execute(self, params: dict, context: ActionContext) -> ActionContext: + def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext: """ 从downloads中整理文件,记录到fileitems """ for download in context.downloads: + if global_vars.is_workflow_stopped(workflow_id): + break if not download.completed: logger.info(f"下载任务 {download.download_id} 未完成") continue diff --git a/app/api/endpoints/workflow.py b/app/api/endpoints/workflow.py index c853d241..46ef5bac 100644 --- a/app/api/endpoints/workflow.py +++ b/app/api/endpoints/workflow.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from app import schemas +from app.core.config import global_vars from app.core.workflow import WorkFlowManager from app.db import get_db from app.db.models.workflow import Workflow @@ -112,6 +113,7 @@ def start_workflow(workflow_id: int, if not workflow: return schemas.Response(success=False, message="工作流不存在") Scheduler().update_workflow_job(workflow) + global_vars.workflow_resume(workflow_id) workflow.update_state(db, workflow_id, "W") return schemas.Response(success=True) @@ -127,5 +129,6 @@ def pause_workflow(workflow_id: int, if not workflow: return schemas.Response(success=False, message="工作流不存在") Scheduler().remove_workflow_job(workflow) + global_vars.stop_workflow(workflow_id) workflow.update_state(db, workflow_id, "P") return schemas.Response(success=True) diff --git a/app/chain/workflow.py b/app/chain/workflow.py index e20d6e61..c5961cf5 100644 --- a/app/chain/workflow.py +++ b/app/chain/workflow.py @@ -97,7 +97,8 @@ class WorkflowExecutor: self.running_tasks += 1 # 已停机 - if global_vars.is_system_stopped: + if global_vars.is_workflow_stopped(self.workflow.id): + global_vars.workflow_resume(self.workflow.id) break # 已执行的跳过 @@ -108,17 +109,18 @@ class WorkflowExecutor: # 提交任务到线程池 future = self.executor.submit( self.execute_node, + self.workflow.id, node_id, self.context ) future.add_done_callback(self.on_node_complete) - def execute_node(self, node_id: int, context: ActionContext) -> Tuple[Action, bool, ActionContext]: + def execute_node(self, workflow_id: int, node_id: int, context: ActionContext) -> Tuple[Action, bool, ActionContext]: """ 执行单个节点操作,返回修改后的上下文和节点ID """ action = self.actions[node_id] - state, result_ctx = self.workflowmanager.excute(action, context=context) + state, result_ctx = self.workflowmanager.excute(workflow_id, action, context=context) return action, state, result_ctx def on_node_complete(self, future): @@ -127,35 +129,32 @@ class WorkflowExecutor: """ action, state, result_ctx = future.result() - # 节点执行失败 - if not state: - self.success = False - self.errmsg = f"{action.name} 失败" + try: + # 节点执行失败 + if not state: + self.success = False + self.errmsg = f"{action.name} 失败" + return + + with self.lock: + # 更新主上下文 + self.merge_context(result_ctx) + # 回调 + if self.step_callback: + self.step_callback(action, self.context) + + # 处理后继节点 + successors = self.adjacency.get(action.id, []) + for succ_id in successors: + with self.lock: + self.indegree[succ_id] -= 1 + if self.indegree[succ_id] == 0: + self.queue.append(succ_id) + finally: # 标记任务完成 with self.lock: self.running_tasks -= 1 - return - - with self.lock: - # 更新主上下文 - self.merge_context(result_ctx) - # 回调 - if self.step_callback: - self.step_callback(action, self.context) - - # 处理后继节点 - successors = self.adjacency.get(action.id, []) - for succ_id in successors: - with self.lock: - self.indegree[succ_id] -= 1 - if self.indegree[succ_id] == 0: - self.queue.append(succ_id) - - # 标记任务完成 - with self.lock: - self.running_tasks -= 1 - def merge_context(self, context: ActionContext): """ 合并上下文 @@ -221,7 +220,7 @@ class WorkflowChain(ChainBase): self.workflowoper.fail(workflow_id, result=executor.errmsg) return False, executor.errmsg else: - logger.info(f"工作流 {workflow.name} 执行成功") + logger.info(f"工作流 {workflow.name} 执行完成") self.workflowoper.success(workflow_id) return True, "" diff --git a/app/core/config.py b/app/core/config.py index 836cb54e..29bf20da 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -607,6 +607,8 @@ class GlobalVar(object): STOP_EVENT: threading.Event = threading.Event() # webpush订阅 SUBSCRIPTIONS: List[dict] = [] + # 需应急停止的工作流 + EMERGENCY_STOP_WORKFLOWS: List[str] = [] def stop_system(self): """ @@ -633,6 +635,26 @@ class GlobalVar(object): """ self.SUBSCRIPTIONS.append(subscription) + def stop_workflow(self, workflow_id: str): + """ + 停止工作流 + """ + if workflow_id not in self.EMERGENCY_STOP_WORKFLOWS: + self.EMERGENCY_STOP_WORKFLOWS.append(workflow_id) + + def workflow_resume(self, workflow_id: str): + """ + 恢复工作流 + """ + if workflow_id in self.EMERGENCY_STOP_WORKFLOWS: + self.EMERGENCY_STOP_WORKFLOWS.remove(workflow_id) + + def is_workflow_stopped(self, workflow_id: str): + """ + 是否停止工作流 + """ + return self.is_system_stopped or workflow_id in self.EMERGENCY_STOP_WORKFLOWS + # 实例化配置 settings = Settings() diff --git a/app/core/workflow.py b/app/core/workflow.py index 8bdb86ce..01e640c3 100644 --- a/app/core/workflow.py +++ b/app/core/workflow.py @@ -1,6 +1,7 @@ from time import sleep from typing import Dict, Any, Tuple, List +from app.core.config import global_vars from app.helper.module import ModuleHelper from app.log import logger from app.schemas import Action, ActionContext @@ -54,7 +55,7 @@ class WorkFlowManager(metaclass=Singleton): """ pass - def excute(self, action: Action, context: ActionContext = None) -> Tuple[bool, ActionContext]: + def excute(self, workflow_id: int, action: Action, context: ActionContext = None) -> Tuple[bool, ActionContext]: """ 执行工作流动作 """ @@ -66,7 +67,7 @@ class WorkFlowManager(metaclass=Singleton): # 执行 logger.info(f"执行动作: {action.id} - {action.name}") try: - result_context = action_obj.execute(action.data, context) + result_context = action_obj.execute(workflow_id, action.data, context) except Exception as err: logger.error(f"{action.name} 执行失败: {err}") return False, context @@ -74,12 +75,14 @@ class WorkFlowManager(metaclass=Singleton): loop_interval = action.data.get("loop_interval") if loop and loop_interval: while not action_obj.done: + if global_vars.is_workflow_stopped(workflow_id): + break # 等待 logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...") sleep(loop_interval) # 执行 logger.info(f"继续执行动作: {action.id} - {action.name}") - result_context = action_obj.execute(action.data, result_context) + result_context = action_obj.execute(workflow_id, action.data, result_context) if action_obj.success: logger.info(f"{action.name} 执行成功") else: diff --git a/app/db/models/workflow.py b/app/db/models/workflow.py index 4b7b14d9..31c00f1c 100644 --- a/app/db/models/workflow.py +++ b/app/db/models/workflow.py @@ -1,6 +1,6 @@ from datetime import datetime -from sqlalchemy import Column, Integer, JSON, Sequence, String +from sqlalchemy import Column, Integer, JSON, Sequence, String, and_ from app.db import Base, db_query, db_update @@ -63,7 +63,7 @@ class Workflow(Base): @staticmethod @db_update def fail(db, wid: int, result: str): - db.query(Workflow).filter(Workflow.id == wid).update({ + db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({ "state": 'F', "result": result, "last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S') @@ -73,7 +73,7 @@ class Workflow(Base): @staticmethod @db_update def success(db, wid: int, result: str = None): - db.query(Workflow).filter(Workflow.id == wid).update({ + db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({ "state": 'S', "result": result, "run_count": Workflow.run_count + 1,