From 7474ecd02f72a5894b5630c0b3dd408133dd9479 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 4 Jun 2026 14:28:46 +0800 Subject: [PATCH] feat(workflow): enhance action execution with structured results and context management --- app/chain/workflow.py | 458 ++++++++++++++++++++++++++----- app/schemas/workflow.py | 18 +- app/workflow/__init__.py | 51 ++-- tests/test_workflow_execution.py | 191 ++++++++++++- 4 files changed, 624 insertions(+), 94 deletions(-) diff --git a/app/chain/workflow.py b/app/chain/workflow.py index 45a4e0dd..60c72acf 100644 --- a/app/chain/workflow.py +++ b/app/chain/workflow.py @@ -1,3 +1,4 @@ +import ast import base64 import copy import pickle @@ -5,7 +6,7 @@ import threading from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor from time import sleep -from typing import Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple from app.chain import ChainBase from app.core.config import global_vars @@ -13,7 +14,7 @@ from app.core.event import Event, eventmanager from app.db.models import Workflow from app.db.workflow_oper import WorkflowOper from app.log import logger -from app.schemas import ActionContext, ActionFlow, Action, ActionExecution +from app.schemas import ActionContext, ActionFlow, Action, ActionExecution, ActionResult from app.schemas.types import EventType from app.workflow import WorkFlowManager @@ -42,13 +43,20 @@ class WorkflowExecutor: self.finished_actions = len(self.completed_actions) self.success = True + self.has_failure = False self.stopped = False self.errmsg = "" + self.node_states = {action_id: "pending" for action_id in self.actions} + for action_id in self.completed_actions: + self.node_states[action_id] = "completed" + self.flow_finished = set() + self.flow_satisfied = set() # 工作流管理器 self.workflowmanager = WorkFlowManager() # 线程安全队列 self.queue = deque() + self.queued_actions = set() # 锁用于保证线程安全 self.lock = threading.Lock() # 线程池 @@ -56,23 +64,14 @@ class WorkflowExecutor: # 跟踪运行中的任务数 self.running_tasks = 0 - # 构建邻接表、入度表 - self.adjacency = defaultdict(list) - self.indegree = defaultdict(int) + # 构建出边与入边表,用于条件流转和多上游汇合。 + self.outgoing_flows = defaultdict(list) + self.incoming_flows = defaultdict(list) for flow in self.flows: - source = flow.source - target = flow.target - self.adjacency[source].append(target) - self.indegree[target] += 1 - - # 初始化所有节点的入度(确保未被引用的节点入度为0) - for action_id in self.actions: - if action_id not in self.indegree: - self.indegree[action_id] = 0 - - for action_id in self.completed_actions: - for succ_id in self.adjacency.get(action_id, []): - self.indegree[succ_id] -= 1 + if not flow.source or not flow.target: + continue + self.outgoing_flows[flow.source].append(flow) + self.incoming_flows[flow.target].append(flow) # 初始上下文 if workflow.current_action and workflow.context: @@ -83,13 +82,17 @@ class WorkflowExecutor: self.context = pickle.loads(decoded_data) else: self.context = ActionContext() + self.context.node_outputs = self.context.node_outputs or {} # 恢复工作流 global_vars.workflow_resume(self.workflow.id) - # 初始化队列,添加入度为0的节点 + # 恢复时重新释放已完成节点的出边,使后继节点能继续执行。 + for action_id in self.completed_actions: + self.release_successors(action_id, source_success=True) + # 初始化队列,添加没有入边的起始节点。 for action_id in self.actions: - if action_id not in self.completed_actions and self.indegree[action_id] == 0: - self.queue.append(action_id) + if action_id not in self.completed_actions and not self.incoming_flows.get(action_id): + self.enqueue_node(action_id) def execute(self) -> None: """ @@ -98,6 +101,7 @@ class WorkflowExecutor: try: while True: should_sleep = False + node_id = None with self.lock: if global_vars.is_workflow_stopped(self.workflow.id): self.success = False @@ -119,6 +123,10 @@ class WorkflowExecutor: else: # 取出队首节点 node_id = self.queue.popleft() + self.queued_actions.discard(node_id) + if self.node_states.get(node_id) != "queued": + continue + self.node_states[node_id] = "running" # 标记任务开始 self.running_tasks += 1 @@ -126,18 +134,7 @@ class WorkflowExecutor: sleep(0.1) continue - # 已停机 - if global_vars.is_workflow_stopped(self.workflow.id): - with self.lock: - self.success = False - self.stopped = True - self.errmsg = "工作流已停止" - self.running_tasks -= 1 - break - - # 已执行的跳过,并继续释放后继节点。 - if node_id in self.completed_actions: - self.on_node_skipped(node_id) + if not node_id: continue # 提交任务到线程池,每个节点使用上下文快照,避免并行节点互相修改同一个对象。 @@ -151,32 +148,34 @@ class WorkflowExecutor: finally: self.executor.shutdown(wait=True, cancel_futures=True) - def execute_node(self, workflow_id: int, node_id: int, - context: ActionContext) -> Tuple[Action, bool, str, ActionContext]: + def execute_node(self, workflow_id: int, node_id: str, + context: ActionContext) -> Tuple[Action, ActionResult]: """ 执行单个节点操作,返回修改后的上下文和节点ID """ action = self.actions[node_id] - state, message, result_ctx = self.workflowmanager.excute(workflow_id, action, context=context) - return action, state, message, result_ctx + action_result = self.workflowmanager.execute(workflow_id, action, context=context) + return action, action_result def on_node_complete(self, future): """ 节点完成回调:更新上下文、处理后继节点 """ try: - action, state, message, result_ctx = future.result() + action, action_result = future.result() with self.lock: if global_vars.is_workflow_stopped(self.workflow.id): self.success = False self.stopped = True self.errmsg = "工作流已停止" return - self.finished_actions += 1 - # 更新当前进度 - self.context.progress = round(self.finished_actions / self.total_actions * 100) if self.total_actions else 100 + state = bool(action_result.success) + message = action_result.message or "" + result_ctx = action_result.context or ActionContext() - # 补充执行历史 + self.finished_actions += 1 + self.update_progress() + # 更新当前进度 self.context.execute_history.append( ActionExecution( action=action.name, @@ -185,28 +184,33 @@ class WorkflowExecutor: ) ) - # 节点执行失败 - if not state: - with self.lock: - self.success = False - self.errmsg = f"{action.name} 失败" - return + # 节点执行失败时默认停止;显式配置 continue/ignore 时继续释放后续 all_done 汇合。 + if not state: + self.node_states[action.id] = "failed" + fail_policy = self.get_action_fail_policy(action) + if fail_policy != "ignore": + self.has_failure = True + self.errmsg = f"{action.name} 失败" + if fail_policy == "stop": + self.success = False + return + if fail_policy not in ("continue", "ignore"): + self.success = False + self.errmsg = f"{action.name} 失败:无效失败策略 {fail_policy}" + return + self.release_successors(action.id, source_success=False) + return - with self.lock: # 更新主上下文 self.merge_context(result_ctx) + self.record_node_outputs(action.id, action_result, result_ctx) self.completed_actions.add(action.id) + self.node_states[action.id] = "completed" + # 处理后继节点 + self.release_successors(action.id, source_success=True) # 回调 if self.step_callback: self.step_callback(action, self.context) - - # 处理后继节点 - successors = self.adjacency.get(action.id, []) - for succ_id in successors: - with self.lock: - self.indegree[succ_id] -= 1 - if self.indegree[succ_id] == 0: - self.queue.append(succ_id) except Exception as err: logger.error(f"工作流节点执行回调失败: {str(err)}") with self.lock: @@ -217,16 +221,331 @@ class WorkflowExecutor: with self.lock: self.running_tasks -= 1 - def on_node_skipped(self, node_id: str) -> None: + def enqueue_node(self, node_id: str) -> None: """ - 跳过已完成节点,并释放其后继节点。 + 将满足条件的节点加入待执行队列。 """ - with self.lock: - for succ_id in self.adjacency.get(node_id, []): - self.indegree[succ_id] -= 1 - if succ_id not in self.completed_actions and self.indegree[succ_id] == 0: - self.queue.append(succ_id) - self.running_tasks -= 1 + if node_id not in self.actions: + return + if self.node_states.get(node_id) != "pending" or node_id in self.queued_actions: + return + self.queue.append(node_id) + self.queued_actions.add(node_id) + self.node_states[node_id] = "queued" + + def skip_node(self, node_id: str, message: str) -> None: + """ + 将不可达节点标记为跳过,并把跳过状态继续传递给后继节点。 + """ + if node_id not in self.actions: + return + if self.node_states.get(node_id) not in ("pending", "queued"): + return + self.queued_actions.discard(node_id) + self.node_states[node_id] = "skipped" + self.finished_actions += 1 + self.update_progress() + self.context.execute_history.append( + ActionExecution( + action=self.actions[node_id].name, + result=True, + message=message + ) + ) + self.release_successors(node_id, source_success=False) + + def release_successors(self, source_id: str, source_success: bool) -> None: + """ + 根据源节点状态释放出边,并重新判断目标节点是否可运行。 + """ + for flow in self.outgoing_flows.get(source_id, []): + flow_key = self.get_flow_key(flow) + if flow_key in self.flow_finished: + continue + condition_matched = False + if source_success: + try: + condition_matched = self.evaluate_condition(self.get_flow_condition(flow)) + except ValueError as err: + self.success = False + self.errmsg = f"流程条件判断失败:{err}" + return + self.flow_finished.add(flow_key) + if source_success and condition_matched: + self.flow_satisfied.add(flow_key) + self.evaluate_target_state(flow.target) + + def evaluate_target_state(self, target_id: str) -> None: + """ + 按目标节点汇合策略判断节点是否入队或跳过。 + """ + if not target_id or target_id not in self.actions: + return + if self.node_states.get(target_id) != "pending": + return + incoming_flows = self.incoming_flows.get(target_id, []) + if not incoming_flows: + self.enqueue_node(target_id) + return + + total_count = len(incoming_flows) + finished_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_finished) + satisfied_count = sum(1 for flow in incoming_flows if self.get_flow_key(flow) in self.flow_satisfied) + join_policy = self.get_action_join_policy(self.actions[target_id], incoming_flows) + + if join_policy == "any_success": + if satisfied_count > 0: + self.enqueue_node(target_id) + elif finished_count == total_count: + self.skip_node(target_id, "所有上游条件均未满足,已跳过") + return + + if join_policy == "all_done": + if finished_count == total_count: + self.enqueue_node(target_id) + return + + if join_policy != "all_success": + self.success = False + self.errmsg = f"{self.actions[target_id].name} 汇合策略无效:{join_policy}" + return + + if finished_count != total_count: + return + if satisfied_count == total_count: + self.enqueue_node(target_id) + else: + self.skip_node(target_id, "上游条件未全部满足,已跳过") + + def update_progress(self) -> None: + """ + 根据已完成和已跳过节点数量更新整体进度。 + """ + self.context.progress = round(self.finished_actions / self.total_actions * 100) if self.total_actions else 100 + + def record_node_outputs(self, action_id: str, action_result: ActionResult, result_context: ActionContext) -> None: + """ + 记录当前节点输出,供后续条件表达式读取。 + """ + outputs = action_result.outputs or self.extract_context_outputs(result_context) + if outputs: + self.context.node_outputs[action_id] = outputs + + @staticmethod + def extract_context_outputs(context: ActionContext) -> dict: + """ + 从动作上下文中提取非空业务字段作为节点默认输出。 + """ + if not context: + return {} + outputs = {} + for key in context.__class__.model_fields: + if key in ("execute_history", "progress", "node_outputs"): + continue + value = getattr(context, key, None) + if value in (None, "", [], {}): + continue + outputs[key] = value + return outputs + + @staticmethod + def get_flow_key(flow: ActionFlow) -> str: + """ + 生成流程边的运行期唯一标识。 + """ + return flow.id or f"{flow.source}->{flow.target}:{id(flow)}" + + def get_action_join_policy(self, action: Action, incoming_flows: List[ActionFlow]) -> str: + """ + 获取动作汇合策略,优先使用动作配置,其次兼容流程边配置。 + """ + join_policy = action.join_policy or self.get_action_data_value(action, "join_policy") + if join_policy: + return join_policy + for flow in incoming_flows: + join_policy = flow.join_policy or self.get_flow_data_value(flow, "join_policy") + if join_policy: + return join_policy + return "all_success" + + def get_action_fail_policy(self, action: Action) -> str: + """ + 获取动作失败策略。 + """ + return action.fail_policy or self.get_action_data_value(action, "fail_policy") or "stop" + + def get_flow_condition(self, flow: ActionFlow) -> Optional[str]: + """ + 获取流程边条件表达式。 + """ + return flow.condition or self.get_flow_data_value(flow, "condition") + + @staticmethod + def get_action_data_value(action: Action, key: str) -> Any: + """ + 从动作 data 中读取扩展配置。 + """ + data = action.data or {} + return data.get(key) if isinstance(data, dict) else None + + @staticmethod + def get_flow_data_value(flow: ActionFlow, key: str) -> Any: + """ + 从流程边 data 中读取扩展配置。 + """ + data = flow.data or {} + return data.get(key) if isinstance(data, dict) else None + + def evaluate_condition(self, condition: Optional[str]) -> bool: + """ + 安全计算流程边条件表达式。 + """ + if not condition: + return True + expression = condition.strip() + if not expression: + return True + expression = expression.replace("&&", " and ").replace("||", " or ") + try: + tree = ast.parse(expression, mode="eval") + except SyntaxError as err: + raise ValueError(f"{condition} 语法错误") from err + return bool(self.evaluate_condition_node(tree.body)) + + def evaluate_condition_node(self, node: ast.AST) -> Any: + """ + 递归计算受限 AST 节点,避免执行任意代码。 + """ + if isinstance(node, ast.BoolOp): + values = [bool(self.evaluate_condition_node(value)) for value in node.values] + if isinstance(node.op, ast.And): + return all(values) + if isinstance(node.op, ast.Or): + return any(values) + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return not bool(self.evaluate_condition_node(node.operand)) + if isinstance(node, ast.Compare): + return self.evaluate_compare_node(node) + if isinstance(node, ast.Name): + return self.resolve_condition_name(node.id) + if isinstance(node, ast.Attribute): + return self.read_value(self.evaluate_condition_node(node.value), node.attr) + if isinstance(node, ast.Subscript): + return self.read_subscript_node(node) + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.List): + return [self.evaluate_condition_node(item) for item in node.elts] + if isinstance(node, ast.Tuple): + return tuple(self.evaluate_condition_node(item) for item in node.elts) + if isinstance(node, ast.Set): + return {self.evaluate_condition_node(item) for item in node.elts} + if isinstance(node, ast.Dict): + return { + self.evaluate_condition_node(key): self.evaluate_condition_node(value) + for key, value in zip(node.keys, node.values) + } + raise ValueError(f"不支持的条件表达式:{ast.dump(node)}") + + def evaluate_compare_node(self, node: ast.Compare) -> bool: + """ + 计算比较表达式,支持链式比较和成员判断。 + """ + left = self.evaluate_condition_node(node.left) + for operator, comparator in zip(node.ops, node.comparators): + right = self.evaluate_condition_node(comparator) + if not self.compare_values(left, operator, right): + return False + left = right + return True + + def read_subscript_node(self, node: ast.Subscript) -> Any: + """ + 读取下标访问表达式。 + """ + if isinstance(node.slice, ast.Slice): + raise ValueError("条件表达式不支持切片访问") + container = self.evaluate_condition_node(node.value) + key = self.evaluate_condition_node(node.slice) + return self.read_value(container, key) + + def resolve_condition_name(self, name: str) -> Any: + """ + 将条件表达式中的根名称映射到当前工作流上下文。 + """ + if name in ("true", "True"): + return True + if name in ("false", "False"): + return False + if name in ("none", "None", "null"): + return None + if name == "context": + return self.context + if name in ("outputs", "node_outputs"): + return self.context.node_outputs or {} + if name in ActionContext.model_fields: + return getattr(self.context, name, None) + raise ValueError(f"未知上下文变量 {name}") + + def resolve_context_path(self, path: str) -> Any: + """ + 按点分路径读取工作流上下文数据。 + """ + if not path: + return None + value = None + for index, part in enumerate(path.split(".")): + if index == 0: + value = self.resolve_condition_name(part) + continue + key = int(part) if part.isdigit() else part + value = self.read_value(value, key) + return value + + @staticmethod + def read_value(value: Any, key: Any) -> Any: + """ + 从 dict、对象或序列中读取属性值。 + """ + if value is None: + return None + if isinstance(key, str) and key in ("count", "length") and hasattr(value, "__len__"): + return len(value) + if isinstance(value, dict): + return value.get(key) + if isinstance(value, (list, tuple)): + if isinstance(key, int) and 0 <= key < len(value): + return value[key] + return None + if isinstance(key, str) and hasattr(value, key): + return getattr(value, key) + return None + + @staticmethod + def compare_values(left: Any, operator: ast.cmpop, right: Any) -> bool: + """ + 比较两个条件表达式值。 + """ + try: + if isinstance(operator, ast.Eq): + return left == right + if isinstance(operator, ast.NotEq): + return left != right + if isinstance(operator, ast.Gt): + return left > right + if isinstance(operator, ast.GtE): + return left >= right + if isinstance(operator, ast.Lt): + return left < right + if isinstance(operator, ast.LtE): + return left <= right + if isinstance(operator, ast.In): + return left in right + if isinstance(operator, ast.NotIn): + return left not in right + except TypeError: + return False + raise ValueError(f"不支持的比较操作符:{operator.__class__.__name__}") def merge_context(self, context: ActionContext) -> None: """ @@ -318,14 +637,13 @@ class WorkflowChain(ChainBase): logger.info(f"工作流 {workflow.name} 已停止") return False, executor.errmsg - if not executor.success: + if not executor.success or executor.has_failure: logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}") workflowoper.fail(workflow_id, result=executor.errmsg) return False, executor.errmsg - else: - logger.info(f"工作流 {workflow.name} 执行完成") - workflowoper.success(workflow_id) - return True, "" + logger.info(f"工作流 {workflow.name} 执行完成") + workflowoper.success(workflow_id) + return True, "" @staticmethod def get_workflows() -> List[Workflow]: diff --git a/app/schemas/workflow.py b/app/schemas/workflow.py index 358a47d7..30f52dc0 100644 --- a/app/schemas/workflow.py +++ b/app/schemas/workflow.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Any, Optional, List from pydantic import BaseModel, Field, ConfigDict @@ -50,6 +50,8 @@ class Action(BaseModel): description: Optional[str] = Field(default=None, description="动作描述") position: Optional[dict] = Field(default_factory=dict, description="位置") data: Optional[dict] = Field(default_factory=dict, description="参数") + join_policy: Optional[str] = Field(default=None, description="多上游节点汇合策略") + fail_policy: Optional[str] = Field(default=None, description="动作失败后的工作流处理策略") class ActionExecution(BaseModel): @@ -72,10 +74,21 @@ class ActionContext(BaseModel): downloads: Optional[List[DownloadTask]] = Field(default_factory=list, description="下载任务列表") sites: Optional[List[Site]] = Field(default_factory=list, description="站点列表") subscribes: Optional[List[Subscribe]] = Field(default_factory=list, description="订阅列表") + node_outputs: Optional[dict] = Field(default_factory=dict, description="节点输出数据") execute_history: Optional[List[ActionExecution]] = Field(default_factory=list, description="执行历史") progress: Optional[int] = Field(default=0, description="执行进度(%)") +class ActionResult(BaseModel): + """ + 动作执行结果。 + """ + success: Optional[bool] = Field(default=True, description="动作是否执行成功") + message: Optional[str] = Field(default=None, description="动作执行消息") + context: Optional[ActionContext] = Field(default=None, description="动作执行后的上下文") + outputs: Optional[dict[str, Any]] = Field(default_factory=dict, description="当前节点显式输出") + + class ActionFlow(BaseModel): """ 工作流流程 @@ -84,6 +97,9 @@ class ActionFlow(BaseModel): source: Optional[str] = Field(default=None, description="源动作") target: Optional[str] = Field(default=None, description="目标动作") animated: Optional[bool] = Field(default=True, description="是否动画流程") + data: Optional[dict] = Field(default_factory=dict, description="流程扩展配置") + condition: Optional[str] = Field(default=None, description="流转条件表达式") + join_policy: Optional[str] = Field(default=None, description="目标节点汇合策略") class WorkflowShare(BaseModel): diff --git a/app/workflow/__init__.py b/app/workflow/__init__.py index 5f55fc76..4f88b569 100644 --- a/app/workflow/__init__.py +++ b/app/workflow/__init__.py @@ -9,7 +9,7 @@ from app.db.models import Workflow from app.db.workflow_oper import WorkflowOper from app.helper.module import ModuleHelper from app.log import logger -from app.schemas import ActionContext, Action +from app.schemas import ActionContext, Action, ActionResult from app.schemas.types import EventType from app.utils.singleton import Singleton @@ -69,14 +69,7 @@ class WorkFlowManager(metaclass=Singleton): self._event_workflows = {} def execute(self, workflow_id: int, action: Action, - context: ActionContext = None) -> Tuple[bool, str, ActionContext]: - """ - 执行工作流动作 - """ - return self.excute(workflow_id=workflow_id, action=action, context=context) - - def excute(self, workflow_id: int, action: Action, - context: ActionContext = None) -> Tuple[bool, str, ActionContext]: + context: ActionContext = None) -> ActionResult: """ 执行工作流动作 """ @@ -91,11 +84,12 @@ class WorkFlowManager(metaclass=Singleton): logger.info(f"执行动作: {action.id} - {action.name}") try: result_context = action_obj.execute(workflow_id, action.data, context) + action_result = self._normalize_action_result(result_context, action_obj, context) except Exception as err: logger.error(f"{action.name} 执行失败: {err}") - return False, f"{err}", context - loop = action.data.get("loop") - loop_interval = action.data.get("loop_interval") + return ActionResult(success=False, message=f"{err}", context=context) + loop = (action.data or {}).get("loop") + loop_interval = (action.data or {}).get("loop_interval") if loop and loop_interval: while not action_obj.done: if global_vars.is_workflow_stopped(workflow_id): @@ -105,15 +99,40 @@ class WorkFlowManager(metaclass=Singleton): sleep(loop_interval) # 执行 logger.info(f"继续执行动作: {action.id} - {action.name}") - result_context = action_obj.execute(workflow_id, action.data, result_context) - if action_obj.success: + result_context = action_obj.execute(workflow_id, action.data, action_result.context) + action_result = self._normalize_action_result(result_context, action_obj, action_result.context) + if action_result.success: logger.info(f"{action.name} 执行成功") else: logger.error(f"{action.name} 执行失败!") - return action_obj.success, action_obj.message, result_context + return action_result else: logger.error(f"未找到动作: {action.type} - {action.name}") - return False, " ", context + return ActionResult(success=False, message=" ", context=context) + + def excute(self, workflow_id: int, action: Action, + context: ActionContext = None) -> Tuple[bool, str, ActionContext]: + """ + 执行工作流动作,兼容历史拼写错误的方法名。 + """ + action_result = self.execute(workflow_id=workflow_id, action=action, context=context) + return bool(action_result.success), action_result.message or "", action_result.context or context or ActionContext() + + @staticmethod + def _normalize_action_result(result: Any, action_obj: Any, fallback_context: ActionContext) -> ActionResult: + """ + 将旧版动作上下文与新版结构化结果统一为动作执行结果。 + """ + if isinstance(result, ActionResult): + result.context = result.context or fallback_context + if result.message is None: + result.message = action_obj.message + return result + return ActionResult( + success=action_obj.success, + message=action_obj.message, + context=result or fallback_context + ) def list_actions(self) -> List[dict]: """ diff --git a/tests/test_workflow_execution.py b/tests/test_workflow_execution.py index a23a004d..e4a409db 100644 --- a/tests/test_workflow_execution.py +++ b/tests/test_workflow_execution.py @@ -4,21 +4,21 @@ import threading from types import SimpleNamespace from app.chain import workflow as workflow_module -from app.schemas import ActionContext +from app.schemas import ActionContext, ActionResult from app.schemas.types import EventType from app import workflow as workflow_package -def _build_workflow(current_action=None, context=None): +def _build_workflow(current_action=None, context=None, actions=None, flows=None): """构造最小工作流对象。""" return SimpleNamespace( id=1, name="测试工作流", - actions=[ + actions=actions or [ {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, {"id": "B", "type": "FakeAction", "name": "动作B", "data": {}}, ], - flows=[ + flows=flows or [ {"id": "flow-1", "source": "A", "target": "B", "animated": True}, ], current_action=current_action, @@ -36,12 +36,23 @@ def _encoded_context(context: ActionContext) -> dict: class _FakeWorkflowManager: """记录执行动作的工作流管理器。""" - def __init__(self, calls): + def __init__(self, calls, results=None): self.calls = calls + self.results = results or {} + + def execute(self, workflow_id, action, context=None): + self.calls.append(action.id) + result = self.results.get(action.id) + if callable(result): + return result(action, context or ActionContext()) + if result: + return result + return ActionResult(success=True, message=f"{action.name}完成", context=context or ActionContext()) def excute(self, workflow_id, action, context=None): - self.calls.append(action.id) - return True, f"{action.name}完成", context or ActionContext() + """兼容历史执行方法。""" + result = self.execute(workflow_id, action, context) + return result.success, result.message, result.context def test_workflow_executor_resumes_downstream_nodes(monkeypatch): @@ -85,6 +96,172 @@ def test_workflow_executor_reports_incremental_progress(monkeypatch): assert progresses == [50, 100] +def test_workflow_executor_skips_false_condition_branch(monkeypatch): + """条件边不满足时应跳过对应分支,并继续执行满足条件的分支。""" + calls = [] + fake_manager = _FakeWorkflowManager( + calls, + results={ + "A": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"items": ["movie"]} + ) + } + ) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, + {"id": "B", "type": "FakeAction", "name": "动作B", "data": {}}, + {"id": "C", "type": "FakeAction", "name": "动作C", "data": {}}, + ], + flows=[ + {"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.items.count == 0"}, + {"id": "flow-ac", "source": "A", "target": "C", "data": {"condition": "outputs.A.items.count > 0"}}, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert calls == ["A", "C"] + assert executor.success is True + assert executor.context.progress == 100 + assert executor.context.node_outputs["A"]["items"] == ["movie"] + + +def test_workflow_executor_all_success_join_waits_parallel_branches(monkeypatch): + """默认汇合策略应等待所有上游分支成功后再执行目标节点。""" + calls = [] + joined_outputs = {} + + def run_join(action, context): + """记录汇合节点读取到的上游输出。""" + joined_outputs.update(context.node_outputs) + return ActionResult(success=True, message=f"{action.name}完成", context=context) + + fake_manager = _FakeWorkflowManager( + calls, + results={ + "A": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"value": "A"} + ), + "B": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"value": "B"} + ), + "C": run_join, + } + ) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, + {"id": "B", "type": "FakeAction", "name": "动作B", "data": {}}, + {"id": "C", "type": "FakeAction", "name": "动作C", "data": {}}, + ], + flows=[ + {"id": "flow-ac", "source": "A", "target": "C"}, + {"id": "flow-bc", "source": "B", "target": "C"}, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert set(calls) == {"A", "B", "C"} + assert calls[-1] == "C" + assert joined_outputs["A"] == {"value": "A"} + assert joined_outputs["B"] == {"value": "B"} + + +def test_workflow_executor_any_success_join_runs_after_available_branch(monkeypatch): + """any_success 汇合策略应允许任一满足条件的上游分支触发目标节点。""" + calls = [] + fake_manager = _FakeWorkflowManager( + calls, + results={ + "A": lambda action, context: ActionResult( + success=True, + message=f"{action.name}完成", + context=context, + outputs={"items": ["movie"]} + ) + } + ) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {}}, + {"id": "B", "type": "FakeAction", "name": "动作B", "data": {}}, + {"id": "C", "type": "FakeAction", "name": "动作C", "data": {}}, + {"id": "D", "type": "FakeAction", "name": "动作D", "data": {"join_policy": "any_success"}}, + ], + flows=[ + {"id": "flow-ab", "source": "A", "target": "B", "condition": "outputs.A.items.count == 0"}, + {"id": "flow-ac", "source": "A", "target": "C", "condition": "outputs.A.items.count > 0"}, + {"id": "flow-bd", "source": "B", "target": "D"}, + {"id": "flow-cd", "source": "C", "target": "D"}, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert calls == ["A", "C", "D"] + assert executor.context.progress == 100 + + +def test_workflow_executor_all_done_join_can_continue_after_failure(monkeypatch): + """continue 失败策略配合 all_done 汇合时应继续执行收尾节点。""" + calls = [] + fake_manager = _FakeWorkflowManager( + calls, + results={ + "A": lambda action, context: ActionResult(success=False, message=f"{action.name}失败", context=context) + } + ) + workflow = _build_workflow( + actions=[ + {"id": "A", "type": "FakeAction", "name": "动作A", "data": {"fail_policy": "continue"}}, + {"id": "B", "type": "FakeAction", "name": "动作B", "data": {}}, + {"id": "C", "type": "FakeAction", "name": "动作C", "data": {"join_policy": "all_done"}}, + ], + flows=[ + {"id": "flow-ac", "source": "A", "target": "C"}, + {"id": "flow-bc", "source": "B", "target": "C"}, + ], + ) + + monkeypatch.setattr(workflow_module, "WorkFlowManager", lambda: fake_manager) + monkeypatch.setattr(workflow_module.global_vars, "workflow_resume", lambda workflow_id: None) + monkeypatch.setattr(workflow_module.global_vars, "is_workflow_stopped", lambda workflow_id: False) + + executor = workflow_module.WorkflowExecutor(workflow) + executor.execute() + + assert set(calls) == {"A", "B", "C"} + assert calls[-1] == "C" + assert executor.has_failure is True + assert executor.success is True + + def test_workflow_executor_stop_is_not_success(monkeypatch): """停止信号不应被执行器汇报为成功完成。""" calls = []