PyRecover is a robust distributed checkpointing and job management system for multi-GPU SLURM workloads. The project offers efficient training with time-aware checkpointing to maximize cluster utilization.
The instructions below show environment creation with conda, but other tools such as venv can also be used.
- Miniconda or Anaconda
- Clone the repository
- Set up the environment with miniconda:
```bash
conda env create -f env.yml
```
This will create an environment called "pyrecover" based on Python 3.10. Use this environment for development. Activate it by calling:
```bash
conda activate pyrecover
```
To install with Flash Attention support, ensure you have the following prerequisites:
- CUDA toolkit (compatible with your PyTorch installation)
- C++ compiler (gcc/g++)
- Python development headers
Then install with:
```bash
./setup_flashattention.sh
```
or
```bash
pip install ".[flash-attention]"
```
After installation, flash attention can be enabled via a command-line argument. Note that the installation can take up to 2 hours. If flash attention is enabled, the SLURM script attempts to install it automatically; this keeps SLURM runs as stateless as possible, without requiring the installation to exist in a container or environment beforehand.
The codebase contains example code for training a Transformer model on a parquet dataset. It's designed to work with SLURM, automatically detecting when multiple GPUs are available and enabling distributed training via DDP (DistributedDataParallel).
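For illustration, the multi-GPU detection can be sketched from the launcher environment roughly as follows; the exact environment variables that train.py inspects are an assumption here, not a description of its actual code.

```python
import os

# Prefer torchrun-style variables and fall back to SLURM's; these exact
# variable names are an assumption about how the script is launched.
world_size = int(os.environ.get("WORLD_SIZE", os.environ.get("SLURM_NTASKS", "1")))
rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0")))
distributed = world_size > 1  # enable DDP only when more than one process runs
```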
The training script (`train.py`) accepts various arguments to customize the training process. Here are the key parameters:
| Argument | Description | Default |
|---|---|---|
| `--dataset` | Path to parquet file with text data | `/capstor/store/cscs/ethz/large-sc/datasets/train_data.parquet` |
| `--sequence-length` | Maximum sequence length | 2048 |
| `--batch-size` | Batch size per GPU | 1 |
| `--learning-rate` | Learning rate | 1e-5 |
| `--training-steps` | Number of training steps | 1000 |
| `--distributed` | Enable distributed training | False |
| `--model-dtype` | Model precision (fp16/bf16/fp32/fp64) | "bf16" |
| `--checkpoint-dir` | Directory for checkpoints | "checkpoints/" |
| `--checkpoint-frequency` | Save checkpoint every N steps | 10 |
| `--resume-from-checkpoint` | Path to checkpoint or "latest" | None |
| `--profile` | Activate profiling support for nsys | False |
| `--experiment_name` | Name of experiment (for checkpoint subfolder) | "default-exp" |
| `--use-torch-distributed-ckpt` | Use distributed checkpointing | False |
| `--compile` | Compile model with torch.compile | False |
| `--fused-optimizer` | Use fused optimizer | False |
| `--use_flash_attention` | Use flash-attention in the model | False |
| `--log-loss-to-csv` | Log loss to a CSV for plots/comparison | False |
| `--timeaware-checkpointing` | Activate time-aware checkpointing | False |
For a complete list of arguments, run:
```bash
python train.py --help
```
Make sure to set `#SBATCH --ntasks-per-node=1`; this way only one process is spawned per node. The code uses DDP, and one process will only make use of one GPU.
The script `submit-training-simple.sh` is provided for launching training jobs on SLURM clusters.
These key parameters can be adapted by editing the script.
| SLURM Parameter | Description |
|---|---|
| `--nodes` | Number of nodes to allocate |
| `--ntasks-per-node` | Tasks per node (typically 1 per GPU) |
| `--gpus-per-node` | GPUs to use per node |
| `--time` | Time limit for the job |
| `--partition` | SLURM partition to use |
The submission script supports the following arguments:
| Argument | Description |
|---|---|
| `--distributed` | Enable distributed training |
| `--exp_name=NAME` | Set experiment name (affects checkpoint subfolder) |
| `--continue` | Resume from latest checkpoint |
| `--use_torch_distributed_ckpt` | Use torch distributed checkpointing |
| `--use_flash_attention` | Install and use flash-attention in the model |
| `--log-loss-to-csv` | Log the training loss to a CSV file |
| `--timeaware-checkpointing` | Activate time-aware checkpointing in the training script |
| `--fused-optimizer` | Use the fused optimizer for training |
| `--profile-nsys` | Run nsys profiling (only supported with a single GPU, so adapt the batch script accordingly) |
Note: NSYS profiling (`--profile-nsys`) is only supported for single-GPU runs. The script will automatically handle the required setup.
The script automatically computes the job end time based on the SLURM time limit and makes it available to the training script. This enables graceful stopping and checkpointing as the job approaches its time limit.
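For illustration, the training side could turn that end time into a remaining-time budget roughly like this; the environment variable name below is a placeholder, not the actual interface exported by the script.

```python
import os
import time

# Hypothetical: the submission script exports the computed job end time as a
# unix timestamp under JOB_END_TIME (the variable name is an assumption).
job_end_time = float(os.environ.get("JOB_END_TIME", time.time() + 3600))
seconds_remaining = job_end_time - time.time()  # walltime budget still left
```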
```bash
# Non-distributed training
sbatch submit-training-simple.sh --exp_name=my_experiment

# Distributed training on multiple GPUs
sbatch submit-training-simple.sh --distributed --exp_name=distributed_exp

# Resume from checkpoint with distributed checkpointing
sbatch submit-training-simple.sh --distributed --continue --use_torch_distributed_ckpt
```
PyRecover offers two checkpointing methods:
- **Vanilla Checkpointing**: Standard PyTorch checkpointing (default)
  - Used with the standard submission script, no extra flags needed
- **Distributed Checkpointing**: Faster loading/saving for large models (45+ GB)
  - Enable with the `--use_torch_distributed_ckpt` flag
Checkpoints are automatically organized by experiment name, allowing you to run multiple experiments without overwriting previous results.
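In PyTorch terms, the two methods roughly correspond to the calls below (a simplified sketch assuming PyTorch ≥ 2.2 for `torch.distributed.checkpoint`; PyRecover's actual checkpoint layout and file naming may differ).

```python
import torch
import torch.distributed.checkpoint as dcp

def save_checkpoint(model, optimizer, step, path, distributed_ckpt=False):
    """Simplified sketch of both checkpointing modes; the state-dict layout
    is an assumption about PyRecover's format."""
    state = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step}
    if distributed_ckpt:
        # torch.distributed.checkpoint: every rank writes its own shard in
        # parallel, which is what speeds up 45+ GB checkpoints.
        dcp.save(state, checkpoint_id=path)
    else:
        # Vanilla checkpointing: one torch.save'd file.
        torch.save(state, path)
```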
Time-aware checkpointing enables the training script to monitor the remaining SLURM job time and automatically trigger a final checkpoint and graceful exit before the job ends. This prevents loss of training progress due to walltime limits.
- Activate it by adding the `--timeaware-checkpointing` flag to your training command or SLURM script.
- The script dynamically tracks iteration and checkpoint durations to calculate a safe stopping threshold.
- As the job nears its end, a final checkpoint is saved and the process exits cleanly, allowing seamless resumption.
Example usage:
```bash
sbatch submit-training-simple.sh --distributed --timeaware-checkpointing
```
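A minimal sketch of the stopping decision, assuming iteration and checkpoint durations are collected during training; the function and variable names are illustrative, not PyRecover's actual API.

```python
import time

def should_stop(job_end_time, iter_durations, ckpt_durations, safety_factor=2.0):
    """Return True once only enough walltime remains for a final checkpoint.
    All names here are illustrative placeholders."""
    seconds_remaining = job_end_time - time.time()
    # Budget one more iteration plus one checkpoint write, padded for variance.
    needed = safety_factor * (max(iter_durations, default=0.0)
                              + max(ckpt_durations, default=0.0))
    return seconds_remaining <= needed
```

When this returns True, the training loop writes a final checkpoint and exits cleanly, so the next job can resume from it.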
For distributed training across multiple GPUs and nodes:
- Set SLURM parameters in the submission script:
  - `--ntasks-per-node=4` (for 4 GPUs per node)
  - `--nodes=X` (where X is the number of nodes)
- Launch with the distributed flag:

  ```bash
  sbatch submit-training-simple.sh --distributed
  ```
This will automatically:
- Initialize process groups
- Set up data parallelism with DistributedDataParallel
- Configure distributed samplers for the dataset
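A condensed sketch of those three steps, assuming one process per GPU launched under SLURM; the real setup code in train.py may differ in detail.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup_distributed(model, dataset, batch_size=1):
    """Illustrative DDP setup; not the exact code used in train.py."""
    dist.init_process_group(backend="nccl")                      # 1. process group
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)

    model = DDP(model.to(local_rank), device_ids=[local_rank])   # 2. DDP wrapping

    sampler = DistributedSampler(dataset)                        # 3. per-rank data sharding
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return model, loader
```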
To test the checkpointing, we provide several benchmark options. These are enabled either by separate scripts or by command-line arguments; for some, inspecting the training output is sufficient.
By training once without interruption and once resuming from a checkpoint, we obtain two final checkpoints that can be compared.
Make sure both runs use the same hyperparameters and training arguments, and the same fixed seed.
Then use the script `tests/check_weights_equality.py` and pass the paths of the two checkpoints as arguments:

```bash
python check_weights_equality.py <checkpoint1> <checkpoint2> [--distributed] [--tolerance 1e-7] [--verbose]
```
- `checkpoint1`: Path to the first checkpoint
- `checkpoint2`: Path to the second checkpoint
- `--distributed`: Use this flag if the checkpoints were saved using distributed checkpointing
- `--tolerance`: Floating-point tolerance for comparison (default: 1e-7)
- `--verbose`: Enable detailed output of differences
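The core of such a comparison, for vanilla (`torch.save`) checkpoints, might look like the sketch below; the `"model"` key and other details are assumptions about the checkpoint layout rather than the script's actual implementation.

```python
import torch

def checkpoints_equal(path_a, path_b, tolerance=1e-7, verbose=False):
    """Illustrative weight comparison for two vanilla checkpoints."""
    a = torch.load(path_a, map_location="cpu")["model"]
    b = torch.load(path_b, map_location="cpu")["model"]
    if a.keys() != b.keys():
        return False
    equal = True
    for name in a:
        if not torch.allclose(a[name].float(), b[name].float(), atol=tolerance):
            equal = False
            if verbose:
                print(f"Mismatch in parameter {name}")
    return equal
```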
To compare loss convergence across runs (for example, with and without resuming from a checkpoint), add the `--log-loss-to-csv` parameter. This saves step-wise loss values to a CSV file in the experiment folder for easy plotting and analysis.
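A minimal sketch of what such logging amounts to; the exact file name and columns written by `--log-loss-to-csv` are assumptions.

```python
import csv

def log_losses_to_csv(losses, path="checkpoints/default-exp/loss.csv"):
    """Write step-wise losses to a CSV file for later plotting."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["step", "loss"])
        for step, loss in enumerate(losses):
            writer.writerow([step, loss])
```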