diff --git a/flytekit/core/node.py b/flytekit/core/node.py index f579d391ad..dadbc1c146 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -230,9 +230,11 @@ def with_overrides( if task_config is not None: logger.warning("This override is beta. We may want to revisit this in the future.") + print(f"[PYTORCH_ELASTIC] with_overrides: Overriding task_config from {self.run_entity._task_config} to {task_config}") if not isinstance(task_config, type(self.run_entity._task_config)): raise ValueError("can't change the type of the task config") self.run_entity._task_config = task_config + print(f"[PYTORCH_ELASTIC] with_overrides: Task config override complete. New config: {self.run_entity._task_config}") if container_image is not None: assert_not_promise(container_image, "container_image") diff --git a/plugins/flytekit-kf-pytorch/README.md b/plugins/flytekit-kf-pytorch/README.md index 436636224f..cbc3f6a948 100644 --- a/plugins/flytekit-kf-pytorch/README.md +++ b/plugins/flytekit-kf-pytorch/README.md @@ -62,3 +62,61 @@ To migrate from v0 to v1, change the following: ``` task_config=PyTorch(worker=Worker(replicas=10)), ``` + +## Dynamic Execution Modes with Overrides + +The PyTorch Elastic plugin now supports dynamic switching between single-node and multi-node execution modes using `with_overrides()`. This allows you to adapt your training based on runtime conditions without creating separate task definitions. + +### Example: Dynamic Node Configuration + +```python +from flytekit import task, workflow +from flytekitplugins.kfpytorch import Elastic + +# Define a task with default multi-node configuration +@task(task_config=Elastic(nnodes=2, nproc_per_node=2)) +def train_model(epochs: int, batch_size: int) -> float: + # Your training code here + return accuracy + +@workflow +def adaptive_training(use_single_node: bool) -> float: + if use_single_node: + # Override to single-node execution + # This will run as a regular pod without PyTorchJob + result = train_model(epochs=10, batch_size=32).with_overrides( + task_config=Elastic(nnodes=1, nproc_per_node=1) + ) + else: + # Use the original multi-node configuration + result = train_model(epochs=10, batch_size=32) + + return result +``` + +### Key Benefits + +1. **No Rendezvous Timeouts**: Single-node tasks bypass elastic launch entirely, avoiding unnecessary rendezvous attempts +2. **Resource Efficiency**: Single-node tasks run as regular pods, reducing overhead +3. **Flexibility**: Switch between execution modes based on runtime conditions +4. **Backward Compatible**: Existing tasks continue to work as before + +### Execution Behavior + +- `nnodes=1`: Task type becomes `"python-task"`, executes directly without elastic launch +- `nnodes>1`: Task type is `"pytorch"`, uses PyTorchJob with elastic launch +- String values like `"1"` or `"1:1"` are treated as single-node +- Elastic ranges like `"1:4"` are treated as multi-node + +## Debug Output + +The plugin now automatically prints debug messages to help diagnose issues. Look for messages with the `[PYTORCH_ELASTIC]` prefix: + +``` +[PYTORCH_ELASTIC] Plugin loaded with fix version: 1.0-nnodes-override-fix +[PYTORCH_ELASTIC] __init__: nnodes=1, type= +[PYTORCH_ELASTIC] execute: task_config=Elastic(nnodes=1, nproc_per_node=1, ...) +[PYTORCH_ELASTIC] *** SINGLE-NODE DETECTED - BYPASSING ELASTIC LAUNCH *** +``` + +If you see these messages in your logs, the fix is working correctly. diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 1972f10bd9..57c7c9dc87 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -3,35 +3,61 @@ Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. """ +import logging import os +import sys + +# Force unbuffered output for immediate visibility +sys.stdout.flush() +os.environ['PYTHONUNBUFFERED'] = '1' + from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union +import cloudpickle +import flytekit from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task from google.protobuf.json_format import MessageToDict -import flytekit -from flytekit import PythonFunctionTask, Resources, lazy_module +from flytekit import FlyteContextManager, PythonFunctionTask, Resources, lazy_module, task from flytekit.configuration import SerializationSettings -from flytekit.core.context_manager import FlyteContextManager, OutputMetadata +from flytekit.core.base_task import PythonTask +from flytekit.core.context_manager import FlyteContext, OutputMetadata, OutputMetadataTracker from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import convert_resources_to_resource_model -from flytekit.exceptions.user import ( - FlyteRecoverableException, - FlyteUserRuntimeException, -) +from flytekit.exceptions.base import FlyteRecoverableException +from flytekit.exceptions.user import FlyteUserRuntimeException from flytekit.extend import IgnoreOutputs, TaskPlugins -from flytekit.loggers import logger +from flytekit.models import task as _task_models -from .error_handling import create_recoverable_error_file, is_recoverable_worker_error +from .error_handling import is_recoverable_worker_error from .pod_template import add_shared_mem_volume_to_pod_template -cloudpickle = lazy_module("cloudpickle") +pd = lazy_module("pandas") TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." +# Configure logger to show INFO level messages +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# Add console handler if not already present +if not logger.handlers: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + formatter = logging.Formatter('[%(levelname)s] %(asctime)s - flytekitplugins.kfpytorch - %(message)s') + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + +# Force immediate output +logger.info(f"PyTorch Elastic plugin logger initialized") + +# Version marker for debugging +PYTORCH_ELASTIC_FIX_VERSION = "1.0-nnodes-override-fix" +print(f"[PYTORCH_ELASTIC] Plugin loaded with fix version: {PYTORCH_ELASTIC_FIX_VERSION}") + @dataclass class RestartPolicy(Enum): @@ -138,6 +164,17 @@ class Elastic(object): Please see https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html for potential performance improvements. To change `OMP_NUM_THREADS`, specify it in the environment dict of the flytekit task decorator or via `pyflyte run --env`. + .. note:: + + The task type (and execution backend) is dynamically determined based on the `nnodes` value: + + - When `nnodes=1`: Task runs as a standalone pod (task_type="python-task") + - When `nnodes>1`: Task runs as a PyTorchJob via Kubeflow operator (task_type="pytorch") + + This behavior is preserved even when using `with_overrides()` to change the task configuration. + For example, a task created with `nnodes=2` can be overridden to `nnodes=1` and will correctly + execute as a standalone pod instead of a PyTorchJob. + Args: nnodes (Union[int, str]): Number of nodes, or the range of nodes in form :. nproc_per_node (str): Number of workers per node. @@ -166,6 +203,11 @@ class Elastic(object): increase_shared_mem: bool = True run_policy: Optional[RunPolicy] = None + def __repr__(self) -> str: + """String representation for better logging.""" + return (f"Elastic(nnodes={self.nnodes}, nproc_per_node={self.nproc_per_node}, " + f"start_method={self.start_method}, max_restarts={self.max_restarts})") + class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): """ @@ -309,17 +351,38 @@ class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]): """ Plugin for distributed training with torch elastic/torchrun (see https://pytorch.org/docs/stable/elastic/run.html). + + This task type dynamically adjusts its execution behavior based on the `nnodes` configuration: + + - When `nnodes=1`: Executes as a regular Python task without elastic launch, avoiding + unnecessary overhead and rendezvous timeouts. + - When `nnodes>1`: Uses torch elastic launch for distributed execution across multiple nodes. + + This behavior is preserved even when using `with_overrides()` to change the configuration, + allowing seamless switching between single-node and multi-node execution modes. """ _ELASTIC_TASK_TYPE = "pytorch" _ELASTIC_TASK_TYPE_STANDALONE = "python-task" def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): - task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.nnodes == 1 else self._ELASTIC_TASK_TYPE + # Store initial task type based on initial config + # Handle both int and string nnodes values + nnodes = task_config.nnodes + print(f"[PYTORCH_ELASTIC] __init__: nnodes={nnodes}, type={type(nnodes)}") + + if isinstance(nnodes, int): + initial_task_type = self._ELASTIC_TASK_TYPE_STANDALONE if nnodes == 1 else self._ELASTIC_TASK_TYPE + else: + # For string values like "1:4", check if it's "1" or "1:1" + nnodes_str = str(nnodes) + initial_task_type = self._ELASTIC_TASK_TYPE_STANDALONE if nnodes_str in ["1", "1:1"] else self._ELASTIC_TASK_TYPE + + print(f"[PYTORCH_ELASTIC] __init__: initial_task_type={initial_task_type}") super(PytorchElasticFunctionTask, self).__init__( task_config=task_config, - task_type=task_type, + task_type=initial_task_type, task_function=task_function, # task_type_version controls the version of the task template, do not change task_type_version=1, @@ -340,7 +403,31 @@ def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): self.pod_template = PodTemplate() add_shared_mem_volume_to_pod_template(self.pod_template) - self._task_config = task_config + @property + def task_type(self) -> str: + """ + Dynamically determine task type based on current nnodes configuration. + This ensures that task type updates when task_config is overridden. + """ + print(f"[PYTORCH_ELASTIC] task_type property accessed") + if self._task_config: + # Handle both int and string nnodes values + nnodes = self._task_config.nnodes + print(f"[PYTORCH_ELASTIC] task_type property: checking nnodes={nnodes}, type={type(nnodes)}") + + if isinstance(nnodes, int): + if nnodes == 1: + print(f"[PYTORCH_ELASTIC] task_type property: returning STANDALONE (nnodes=1)") + return self._ELASTIC_TASK_TYPE_STANDALONE + else: + # For string values like "1:4", check if it's "1" or "1:1" + nnodes_str = str(nnodes) + if nnodes_str == "1" or nnodes_str == "1:1": + print(f"[PYTORCH_ELASTIC] task_type property: returning STANDALONE (nnodes_str={nnodes_str})") + return self._ELASTIC_TASK_TYPE_STANDALONE + + print(f"[PYTORCH_ELASTIC] task_type property: returning ELASTIC (multi-node)") + return self._ELASTIC_TASK_TYPE def _execute(self, **kwargs) -> Any: """ @@ -352,10 +439,13 @@ def _execute(self, **kwargs) -> Any: Raises: FlyteRecoverableException: If the first exception raised in the local worker group is or inherits from `FlyteRecoverableException`. - RuntimeError: If the first exception raised in the local worker group is not and does not + RuntimeError: The first exception raised in the local worker group is not and does not inherit from `FlyteRecoverableException`. IgnoreOutputs: Raised when the task is successful in any worker group with index > 0. """ + print(f"[PYTORCH_ELASTIC] _execute: ENTERED ELASTIC LAUNCH METHOD") + print(f"[PYTORCH_ELASTIC] _execute: task_config.nnodes={self._task_config.nnodes}") + try: from torch.distributed import run from torch.distributed.launcher.api import LaunchConfig, elastic_launch @@ -363,19 +453,24 @@ def _execute(self, **kwargs) -> Any: raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) nnodes_str = os.environ.get("PET_NNODES", str(self._task_config.nnodes)) + print(f"[PYTORCH_ELASTIC] _execute: PET_NNODES env var={os.environ.get('PET_NNODES')}, using nnodes_str={nnodes_str}") min_nodes, max_nodes = run.parse_min_max_nnodes(nnodes_str) + print(f"[PYTORCH_ELASTIC] _execute: parsed min_nodes={min_nodes}, max_nodes={max_nodes}") nproc_per_node = int(os.environ.get("PET_NPROC_PER_NODE", self._task_config.nproc_per_node)) max_restarts = int(os.environ.get("PET_MAX_RESTARTS", self._task_config.max_restarts)) monitor_interval = int(os.environ.get("PET_MONITOR_INTERVAL", self._task_config.monitor_interval)) rdzv_endpoint = os.environ.get("PET_RDZV_ENDPOINT", "localhost:0") + + print(f"[PYTORCH_ELASTIC] _execute: nproc_per_node={nproc_per_node}, max_restarts={max_restarts}") + print(f"[PYTORCH_ELASTIC] _execute: monitor_interval={monitor_interval}, rdzv_endpoint={rdzv_endpoint}") # If OMP_NUM_THREADS is not set, set it to 1 to avoid overloading the system. # Doing so to copy the default behavior of torchrun. # See https://github.com/pytorch/pytorch/blob/eea4ece256d74c6f25c1f4eab37b3f2f4aeefd4d/torch/distributed/run.py#L791 if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: omp_num_threads = 1 - logger.warning( + print( "\n*****************************************\n" "Setting OMP_NUM_THREADS environment variable for each process to be " "%s in default, to avoid your system being overloaded, " @@ -398,6 +493,13 @@ def _execute(self, **kwargs) -> Any: monitor_interval=monitor_interval, start_method=self._task_config.start_method, ) + + print(f"[PYTORCH_ELASTIC] _execute: LaunchConfig created with:") + print(f" - min_nodes={min_nodes}, max_nodes={max_nodes}") + print(f" - nproc_per_node={nproc_per_node}") + print(f" - rdzv_backend={self.rdzv_backend}") + print(f" - rdzv_endpoint={rdzv_endpoint}") + print(f" - start_method={self._task_config.start_method}") if self._task_config.start_method == "spawn": """ @@ -503,44 +605,100 @@ def execute(self, **kwargs) -> Any: Handles the exception scope for the `_execute` method. """ - + print(f"[PYTORCH_ELASTIC] ========== EXECUTE METHOD CALLED ==========") + print(f"PytorchElasticFunctionTask.execute called") + print(f"Current task_config: {self._task_config}") + print(f"Current task_type: {self.task_type}") + + print(f"[PYTORCH_ELASTIC] execute: task_config={self._task_config}") + print(f"[PYTORCH_ELASTIC] execute: task_type={self.task_type}") + + # Log relevant environment variables + print(f"[PYTORCH_ELASTIC] Environment: PET_NNODES={os.environ.get('PET_NNODES', 'NOT SET')}") + + # Check if this is a single-node configuration + nnodes = self._task_config.nnodes + is_single_node = False + if isinstance(nnodes, int): + is_single_node = (nnodes == 1) + print(f"[PYTORCH_ELASTIC] execute: nnodes is int={nnodes}, is_single_node={is_single_node}") + else: + # For string values like "1:4", check if it's "1" or "1:1" + nnodes_str = str(nnodes) + is_single_node = nnodes_str in ["1", "1:1"] + print(f"[PYTORCH_ELASTIC] execute: nnodes is str={nnodes_str}, is_single_node={is_single_node}") + + # For single-node execution, bypass elastic launch and run directly + if is_single_node: + # Run as a regular Python task without elastic launch + print(f"[PYTORCH_ELASTIC] *** SINGLE-NODE DETECTED - BYPASSING ELASTIC LAUNCH ***") + try: + # Get parent class info + parent_class = super().__class__ + print(f"[PYTORCH_ELASTIC] execute: Parent class is {parent_class}") + print(f"[PYTORCH_ELASTIC] execute: Calling {parent_class.__name__}.execute()") + result = super().execute(**kwargs) + print(f"[PYTORCH_ELASTIC] execute: Parent execute returned successfully") + return result + except Exception as e: + print(f"[PYTORCH_ELASTIC] execute: ERROR in parent execute: {type(e).__name__}: {e}") + raise + + # For multi-node execution, use elastic launch + print(f"[PYTORCH_ELASTIC] *** MULTI-NODE DETECTED - USING ELASTIC LAUNCH ***") return self._execute(**kwargs) def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: - if self._task_config.nnodes == 1: - """ - Torch elastic distributed training is executed in a normal k8s pod so that this - works without the kubeflow train operator. - """ - return super().get_custom(settings) + print(f"[PYTORCH_ELASTIC] get_custom: Called for serialization") + print(f"[PYTORCH_ELASTIC] get_custom: Current task_type property returns: {self.task_type}") + print(f"[PYTORCH_ELASTIC] get_custom: settings={settings}") + + # Always return ElasticConfig, even for single-node + # This ensures the task specification is valid + from flyteidl.plugins.kubeflow.pytorch_pb2 import ElasticConfig + + try: + from torch.distributed import run + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + + # Check if this is a single-node configuration + nnodes = self._task_config.nnodes + is_single_node = False + if isinstance(nnodes, int): + is_single_node = (nnodes == 1) else: - from flyteidl.plugins.kubeflow.pytorch_pb2 import ElasticConfig - - try: - from torch.distributed import run - except ImportError: - raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) - - min_nodes, max_nodes = run.parse_min_max_nnodes(str(self._task_config.nnodes)) - - elastic_config = ElasticConfig( - rdzv_backend=self.rdzv_backend, - min_replicas=min_nodes, - max_replicas=max_nodes, - nproc_per_node=self._task_config.nproc_per_node, - max_restarts=self._task_config.max_restarts, - ) - run_policy = ( - _convert_run_policy_to_flyte_idl(self._task_config.run_policy) if self._task_config.run_policy else None - ) - job = pytorch_task.DistributedPyTorchTrainingTask( - worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec( - replicas=max_nodes, - ), - elastic_config=elastic_config, - run_policy=run_policy, - ) - return MessageToDict(job) + # For string values like "1:4", check if it's "1" or "1:1" + nnodes_str = str(nnodes) + is_single_node = nnodes_str in ["1", "1:1"] + + print(f"[PYTORCH_ELASTIC] get_custom: nnodes={nnodes}, is_single_node={is_single_node}") + + min_nodes, max_nodes = run.parse_min_max_nnodes(str(self._task_config.nnodes)) + + elastic_config = ElasticConfig( + rdzv_backend=self.rdzv_backend, + min_replicas=min_nodes, + max_replicas=max_nodes, + nproc_per_node=self._task_config.nproc_per_node, + max_restarts=self._task_config.max_restarts, + ) + run_policy = ( + _convert_run_policy_to_flyte_idl(self._task_config.run_policy) if self._task_config.run_policy else None + ) + + # For single-node, we still return a valid PyTorch job spec + # but it will execute differently based on task_type + job = pytorch_task.DistributedPyTorchTrainingTask( + worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec( + replicas=max_nodes, + ), + elastic_config=elastic_config, + run_policy=run_policy, + ) + + print(f"[PYTORCH_ELASTIC] get_custom: Returning PyTorch job spec for {'single' if is_single_node else 'multi'}-node") + return MessageToDict(job) # Register the PytorchElastic Plugin into the flytekit core plugin system diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index f8742d1fe9..1260b65623 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -310,3 +310,138 @@ def test_task(): test_task() assert e.value.timestamp is not None + + +def test_elastic_task_type_override(): + """Test that task_type changes correctly when overriding nnodes.""" + # Create a task with nnodes=2 (multi-node) + @task(task_config=Elastic(nnodes=2, nproc_per_node=1)) + def multi_node_task(x: int) -> int: + return x * 2 + + # Verify initial task type is "pytorch" for multi-node + assert multi_node_task.task_type == "pytorch" + + @workflow + def test_override_workflow() -> int: + # Override with nnodes=1 (single-node) + return multi_node_task(x=5).with_overrides( + task_config=Elastic(nnodes=1, nproc_per_node=1) + ) + + # Get the workflow node + node = test_override_workflow.nodes[0] + + # Verify that the task config was updated + assert node.flyte_entity._task_config.nnodes == 1 + + # Verify that task_type now reflects single-node execution + assert node.flyte_entity.task_type == "python-task" + + # Test the opposite direction: single-node to multi-node + @task(task_config=Elastic(nnodes=1, nproc_per_node=1)) + def single_node_task(x: int) -> int: + return x * 3 + + # Verify initial task type is "python-task" for single-node + assert single_node_task.task_type == "python-task" + + @workflow + def test_override_workflow2() -> int: + # Override with nnodes=2 (multi-node) + return single_node_task(x=5).with_overrides( + task_config=Elastic(nnodes=2, nproc_per_node=1) + ) + + # Get the workflow node + node2 = test_override_workflow2.nodes[0] + + # Verify that the task config was updated + assert node2.flyte_entity._task_config.nnodes == 2 + + # Verify that task_type now reflects multi-node execution + assert node2.flyte_entity.task_type == "pytorch" + + +def test_elastic_task_type_with_string_nnodes(): + """Test that task_type works correctly with string nnodes values.""" + # Test with "1" string value + @task(task_config=Elastic(nnodes="1", nproc_per_node=1)) + def single_node_str_task(x: int) -> int: + return x * 2 + + assert single_node_str_task.task_type == "python-task" + + # Test with "1:1" string value (min and max both 1) + @task(task_config=Elastic(nnodes="1:1", nproc_per_node=1)) + def single_node_range_task(x: int) -> int: + return x * 2 + + assert single_node_range_task.task_type == "python-task" + + # Test with "1:4" string value (elastic range) + @task(task_config=Elastic(nnodes="1:4", nproc_per_node=1)) + def elastic_range_task(x: int) -> int: + return x * 2 + + assert elastic_range_task.task_type == "pytorch" + + # Test override from "2:4" to "1" + @task(task_config=Elastic(nnodes="2:4", nproc_per_node=1)) + def multi_range_task(x: int) -> int: + return x * 2 + + assert multi_range_task.task_type == "pytorch" + + @workflow + def test_string_override_workflow() -> int: + # Override with string "1" + return multi_range_task(x=5).with_overrides( + task_config=Elastic(nnodes="1", nproc_per_node=1) + ) + + node = test_string_override_workflow.nodes[0] + assert node.flyte_entity._task_config.nnodes == "1" + assert node.flyte_entity.task_type == "python-task" + + +def test_elastic_single_node_execution(): + """Test that single-node tasks execute without elastic launch.""" + import os + from unittest.mock import patch, MagicMock + + # Create a simple task function + def simple_task(x: int) -> int: + return x * 2 + + # Test with nnodes=1 (should not use elastic_launch) + @task(task_config=Elastic(nnodes=1, nproc_per_node=1)) + def single_node_task(x: int) -> int: + return simple_task(x) + + # Mock elastic_launch to ensure it's not called + with patch('torch.distributed.launcher.api.elastic_launch') as mock_elastic_launch: + # Execute the task + result = single_node_task.execute(x=5) + + # Verify the result is correct + assert result == 10 + + # Verify elastic_launch was NOT called + mock_elastic_launch.assert_not_called() + + # Test with nnodes=2 (should use elastic_launch) + @task(task_config=Elastic(nnodes=2, nproc_per_node=1)) + def multi_node_task(x: int) -> int: + return simple_task(x) + + # Mock elastic_launch for multi-node case + mock_result = MagicMock() + mock_result.return_value = {0: MagicMock(return_value=20, decks=[], om=None)} + + with patch('torch.distributed.launcher.api.elastic_launch', return_value=mock_result) as mock_elastic_launch: + # Execute the task (this will use mocked elastic_launch) + result = multi_node_task.execute(x=10) + + # Verify elastic_launch WAS called for multi-node + mock_elastic_launch.assert_called_once()