feat(workflow): implement action contract management for inputs and outputs

This commit is contained in:
jxxghp
2026-06-04 21:06:25 +08:00
parent a2984530f8
commit 97cfcda03c
20 changed files with 341 additions and 10 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Union
from typing import Any, Union
from app.chain import ChainBase
from app.db.systemconfig_oper import SystemConfigOper
@@ -23,6 +23,8 @@ class BaseAction(ABC):
_message = ""
# 缓存键值
_cache_key = "WorkflowCache-%s"
# 动作输入输出契约,由具体动作按需覆盖
contract = {}
def __init__(self, action_id: str):
self._action_id = action_id
@@ -48,6 +50,41 @@ class BaseAction(ABC):
def data(cls) -> dict: # noqa
pass
@classmethod
def get_contract(cls) -> dict:
"""
获取动作输入输出契约。
"""
contract = getattr(cls, "contract", None) or {}
input_fields = cls._build_contract_fields(contract.get("inputs") or [])
output_fields = cls._build_contract_fields(contract.get("outputs") or [])
return {
"inputs": input_fields,
"outputs": output_fields,
"condition_fields": output_fields,
"concurrency_key": contract.get("concurrency_key"),
}
@classmethod
def _build_contract_fields(cls, fields: list) -> list:
"""
标准化动作契约字段。
"""
result = []
for field in fields:
if isinstance(field, str):
field = {"name": field}
if not isinstance(field, dict) or not field.get("name"):
continue
result.append({
"name": field["name"],
"label": field.get("label") or field["name"],
"kind": field.get("kind") or "scalar",
"merge": field.get("merge"),
"identity": field.get("identity"),
})
return result
@property
def done(self) -> bool:
"""
@@ -115,10 +152,55 @@ class BaseAction(ABC):
"""
使用显式输入与运行期信息执行动作。
"""
_ = inputs, runtime
self._apply_inputs_to_context(inputs=inputs, context=context)
self._apply_runtime_to_context(runtime=runtime, context=context)
result_context = self.execute(workflow_id, params, context)
outputs = self._extract_outputs_from_context(result_context)
return ActionResult(
success=self.success,
message=self.message,
context=result_context
context=result_context,
outputs=outputs
)
def _apply_inputs_to_context(self, inputs: dict, context: ActionContext) -> None:
"""
将显式输入回填到旧版上下文字段,兼容仍读取 context 的动作。
"""
inputs = inputs or {}
for field in self.get_contract().get("inputs") or []:
missing = object()
field_name = field["name"]
value = inputs.get(field_name, missing)
if value is missing:
# 兼容旧版节点输入路径,例如 outputs.A.torrents。
for input_key, input_value in inputs.items():
if isinstance(input_key, str) and input_key.split(".")[-1] == field_name:
value = input_value
break
if value is not missing:
setattr(context, field_name, value)
@staticmethod
def _apply_runtime_to_context(runtime: dict, context: ActionContext) -> None:
"""
将运行期信息写入 runtime_state供动作和执行状态读取。
"""
if not runtime:
return
context.runtime_state = context.runtime_state or {}
context.runtime_state["current_action_runtime"] = {
key: value for key, value in runtime.items()
if key != "cancel_token"
}
def _extract_outputs_from_context(self, context: ActionContext) -> dict[str, Any]:
"""
按动作契约从上下文提取输出。
"""
outputs = {}
for field in self.get_contract().get("outputs") or []:
value = getattr(context, field["name"], None)
if value not in (None, "", [], {}):
outputs[field["name"]] = value
return outputs