2 changes: 2 additions & 0 deletions flytekit/core/node.py
@@ -230,9 +230,11 @@ def with_overrides(

if task_config is not None:
logger.warning("This override is beta. We may want to revisit this in the future.")
print(f"[PYTORCH_ELASTIC] with_overrides: Overriding task_config from {self.run_entity._task_config} to {task_config}")
if not isinstance(task_config, type(self.run_entity._task_config)):
raise ValueError("can't change the type of the task config")
self.run_entity._task_config = task_config
print(f"[PYTORCH_ELASTIC] with_overrides: Task config override complete. New config: {self.run_entity._task_config}")

if container_image is not None:
assert_not_promise(container_image, "container_image")
58 changes: 58 additions & 0 deletions plugins/flytekit-kf-pytorch/README.md
@@ -62,3 +62,61 @@ To migrate from v0 to v1, change the following:
```python
task_config=PyTorch(worker=Worker(replicas=10)),
```

## Dynamic Execution Modes with Overrides

The PyTorch Elastic plugin now supports dynamic switching between single-node and multi-node execution modes using `with_overrides()`. This allows you to adapt your training based on runtime conditions without creating separate task definitions.

### Example: Dynamic Node Configuration

```python
from flytekit import dynamic, task
from flytekitplugins.kfpytorch import Elastic

# Define a task with a default multi-node configuration
@task(task_config=Elastic(nnodes=2, nproc_per_node=2))
def train_model(epochs: int, batch_size: int) -> float:
    # Your training code here
    accuracy = 0.0  # placeholder for the real metric
    return accuracy

# Branching on an input requires runtime evaluation, so this is a
# @dynamic workflow; in a plain @workflow you would need flytekit's
# conditional() construct instead of a Python `if`.
@dynamic
def adaptive_training(use_single_node: bool) -> float:
    if use_single_node:
        # Override to single-node execution:
        # this runs as a regular pod without a PyTorchJob
        return train_model(epochs=10, batch_size=32).with_overrides(
            task_config=Elastic(nnodes=1, nproc_per_node=1)
        )
    # Otherwise use the original multi-node configuration
    return train_model(epochs=10, batch_size=32)
```

### Key Benefits

1. **No Rendezvous Timeouts**: Single-node tasks bypass elastic launch entirely, avoiding unnecessary rendezvous attempts
2. **Resource Efficiency**: Single-node tasks run as regular pods, reducing overhead
3. **Flexibility**: Switch between execution modes based on runtime conditions
4. **Backward Compatible**: Existing tasks continue to work as before

### Execution Behavior

- `nnodes=1`: Task type becomes `"python-task"`, executes directly without elastic launch
- `nnodes>1`: Task type is `"pytorch"`, uses PyTorchJob with elastic launch
- String values like `"1"` or `"1:1"` are treated as single-node
- Elastic ranges like `"1:4"` are treated as multi-node
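The single- vs multi-node decision described above can be sketched as a small parser. This is an illustration only, not the plugin's actual implementation, and the helper name `is_single_node` is hypothetical:

```python
from typing import Union

def is_single_node(nnodes: Union[int, str]) -> bool:
    """Return True when the nnodes setting describes exactly one node.

    Mirrors the rules above: the int 1, the string "1", or the
    degenerate range "1:1" are single-node; anything with a node
    count or range bound above 1 (e.g. "1:4") is multi-node.
    """
    if isinstance(nnodes, int):
        return nnodes == 1
    # String form: either "N" or an elastic range "MIN:MAX"
    parts = nnodes.split(":")
    return all(int(p) == 1 for p in parts)

print(is_single_node(1))      # True
print(is_single_node("1:1"))  # True
print(is_single_node("1:4"))  # False
```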

## Debug Output

The plugin now automatically prints debug messages to help diagnose issues. Look for messages with the `[PYTORCH_ELASTIC]` prefix:

```
[PYTORCH_ELASTIC] Plugin loaded with fix version: 1.0-nnodes-override-fix
[PYTORCH_ELASTIC] __init__: nnodes=1, type=<class 'int'>
[PYTORCH_ELASTIC] execute: task_config=Elastic(nnodes=1, nproc_per_node=1, ...)
[PYTORCH_ELASTIC] *** SINGLE-NODE DETECTED - BYPASSING ELASTIC LAUNCH ***
```

If you see these messages in your logs, the fix is working correctly.