mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-11 18:50:59 +08:00
feat(workflow): implement action contract management for inputs and outputs
This commit is contained in:
@@ -242,6 +242,7 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
"type": key,
|
||||
"name": action.name,
|
||||
"description": action.description,
|
||||
"contract": action.get_contract(),
|
||||
"data": {
|
||||
"label": action.name,
|
||||
**action.data
|
||||
@@ -249,6 +250,15 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
} for key, action in self._actions.items()
|
||||
]
|
||||
|
||||
def get_action_contract(self, action_type: str) -> dict:
|
||||
"""
|
||||
获取动作输入输出契约。
|
||||
"""
|
||||
action = self._actions.get(action_type)
|
||||
if not action or not hasattr(action, "get_contract"):
|
||||
return {}
|
||||
return action.get_contract()
|
||||
|
||||
def update_workflow_event(self, workflow: Workflow):
|
||||
"""
|
||||
更新工作流事件触发器
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
@@ -23,6 +23,8 @@ class BaseAction(ABC):
|
||||
_message = ""
|
||||
# 缓存键值
|
||||
_cache_key = "WorkflowCache-%s"
|
||||
# 动作输入输出契约,由具体动作按需覆盖
|
||||
contract = {}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
self._action_id = action_id
|
||||
@@ -48,6 +50,41 @@ class BaseAction(ABC):
|
||||
def data(cls) -> dict: # noqa
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_contract(cls) -> dict:
|
||||
"""
|
||||
获取动作输入输出契约。
|
||||
"""
|
||||
contract = getattr(cls, "contract", None) or {}
|
||||
input_fields = cls._build_contract_fields(contract.get("inputs") or [])
|
||||
output_fields = cls._build_contract_fields(contract.get("outputs") or [])
|
||||
return {
|
||||
"inputs": input_fields,
|
||||
"outputs": output_fields,
|
||||
"condition_fields": output_fields,
|
||||
"concurrency_key": contract.get("concurrency_key"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _build_contract_fields(cls, fields: list) -> list:
|
||||
"""
|
||||
标准化动作契约字段。
|
||||
"""
|
||||
result = []
|
||||
for field in fields:
|
||||
if isinstance(field, str):
|
||||
field = {"name": field}
|
||||
if not isinstance(field, dict) or not field.get("name"):
|
||||
continue
|
||||
result.append({
|
||||
"name": field["name"],
|
||||
"label": field.get("label") or field["name"],
|
||||
"kind": field.get("kind") or "scalar",
|
||||
"merge": field.get("merge"),
|
||||
"identity": field.get("identity"),
|
||||
})
|
||||
return result
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
@@ -115,10 +152,55 @@ class BaseAction(ABC):
|
||||
"""
|
||||
使用显式输入与运行期信息执行动作。
|
||||
"""
|
||||
_ = inputs, runtime
|
||||
self._apply_inputs_to_context(inputs=inputs, context=context)
|
||||
self._apply_runtime_to_context(runtime=runtime, context=context)
|
||||
result_context = self.execute(workflow_id, params, context)
|
||||
outputs = self._extract_outputs_from_context(result_context)
|
||||
return ActionResult(
|
||||
success=self.success,
|
||||
message=self.message,
|
||||
context=result_context
|
||||
context=result_context,
|
||||
outputs=outputs
|
||||
)
|
||||
|
||||
def _apply_inputs_to_context(self, inputs: dict, context: ActionContext) -> None:
|
||||
"""
|
||||
将显式输入回填到旧版上下文字段,兼容仍读取 context 的动作。
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
for field in self.get_contract().get("inputs") or []:
|
||||
missing = object()
|
||||
field_name = field["name"]
|
||||
value = inputs.get(field_name, missing)
|
||||
if value is missing:
|
||||
# 兼容旧版节点输入路径,例如 outputs.A.torrents。
|
||||
for input_key, input_value in inputs.items():
|
||||
if isinstance(input_key, str) and input_key.split(".")[-1] == field_name:
|
||||
value = input_value
|
||||
break
|
||||
if value is not missing:
|
||||
setattr(context, field_name, value)
|
||||
|
||||
@staticmethod
|
||||
def _apply_runtime_to_context(runtime: dict, context: ActionContext) -> None:
|
||||
"""
|
||||
将运行期信息写入 runtime_state,供动作和执行状态读取。
|
||||
"""
|
||||
if not runtime:
|
||||
return
|
||||
context.runtime_state = context.runtime_state or {}
|
||||
context.runtime_state["current_action_runtime"] = {
|
||||
key: value for key, value in runtime.items()
|
||||
if key != "cancel_token"
|
||||
}
|
||||
|
||||
def _extract_outputs_from_context(self, context: ActionContext) -> dict[str, Any]:
|
||||
"""
|
||||
按动作契约从上下文提取输出。
|
||||
"""
|
||||
outputs = {}
|
||||
for field in self.get_contract().get("outputs") or []:
|
||||
value = getattr(context, field["name"], None)
|
||||
if value not in (None, "", [], {}):
|
||||
outputs[field["name"]] = value
|
||||
return outputs
|
||||
|
||||
@@ -26,6 +26,12 @@ class AddDownloadAction(BaseAction):
|
||||
添加下载资源
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
|
||||
"concurrency_key": "download",
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._added_downloads = []
|
||||
|
||||
@@ -19,6 +19,11 @@ class AddSubscribeAction(BaseAction):
|
||||
添加订阅
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
|
||||
"outputs": [{"name": "subscribes", "label": "订阅", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._added_subscribes = []
|
||||
|
||||
@@ -16,6 +16,12 @@ class FetchDownloadsAction(BaseAction):
|
||||
获取下载任务
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "downloads", "label": "下载任务", "kind": "list"}],
|
||||
"outputs": [{"name": "downloads", "label": "下载任务", "kind": "list", "merge": "replace"}],
|
||||
"concurrency_key": "download",
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._downloads = []
|
||||
|
||||
@@ -27,6 +27,10 @@ class FetchMediasAction(BaseAction):
|
||||
获取媒体数据
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"outputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
|
||||
|
||||
@@ -30,6 +30,10 @@ class FetchRssAction(BaseAction):
|
||||
获取RSS资源列表
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"outputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._rss_torrents = []
|
||||
|
||||
@@ -30,6 +30,11 @@ class FetchTorrentsAction(BaseAction):
|
||||
搜索站点资源
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
|
||||
"outputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._torrents = []
|
||||
|
||||
@@ -22,6 +22,11 @@ class FilterMediasAction(BaseAction):
|
||||
过滤媒体数据
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "medias", "label": "媒体", "kind": "list"}],
|
||||
"outputs": [{"name": "medias", "label": "媒体", "kind": "list", "merge": "replace"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._medias = []
|
||||
|
||||
@@ -27,6 +27,11 @@ class FilterTorrentsAction(BaseAction):
|
||||
过滤资源数据
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "torrents", "label": "资源", "kind": "list"}],
|
||||
"outputs": [{"name": "torrents", "label": "资源", "kind": "list", "merge": "replace"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._torrents = []
|
||||
|
||||
@@ -20,6 +20,8 @@ class InvokePluginAction(BaseAction):
|
||||
调用插件
|
||||
"""
|
||||
|
||||
contract = {}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._success = False
|
||||
|
||||
@@ -7,6 +7,8 @@ class NoteAction(BaseAction):
|
||||
备注
|
||||
"""
|
||||
|
||||
contract = {}
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def name(cls) -> str: # noqa
|
||||
|
||||
@@ -24,6 +24,10 @@ class ScanFileAction(BaseAction):
|
||||
整理文件
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._fileitems = []
|
||||
|
||||
@@ -18,6 +18,11 @@ class ScrapeFileAction(BaseAction):
|
||||
刮削文件
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
|
||||
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._scraped_files = []
|
||||
|
||||
@@ -16,6 +16,8 @@ class SendEventAction(BaseAction):
|
||||
发送事件
|
||||
"""
|
||||
|
||||
contract = {}
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def name(cls) -> str: # noqa
|
||||
|
||||
@@ -20,6 +20,8 @@ class SendMessageAction(BaseAction):
|
||||
发送消息
|
||||
"""
|
||||
|
||||
contract = {}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
|
||||
|
||||
@@ -26,6 +26,15 @@ class TransferFileAction(BaseAction):
|
||||
整理文件
|
||||
"""
|
||||
|
||||
contract = {
|
||||
"inputs": [
|
||||
{"name": "downloads", "label": "下载任务", "kind": "list"},
|
||||
{"name": "fileitems", "label": "文件", "kind": "list"},
|
||||
],
|
||||
"outputs": [{"name": "fileitems", "label": "文件", "kind": "list"}],
|
||||
"concurrency_key": "transfer",
|
||||
}
|
||||
|
||||
def __init__(self, action_id: str):
|
||||
super().__init__(action_id)
|
||||
self._fileitems = []
|
||||
|
||||
Reference in New Issue
Block a user