
FSDP2 support for activation checkpointing #359


Merged: le1nux merged 73 commits into main from fsdp2_activation_checkpointing on Jul 22, 2025

Conversation

@le1nux (Member) commented Apr 18, 2025

What does this PR do?

This PR adds activation checkpointing (AC) support for FSDP2.
There are now three AC variants (a minimal sketch of all three follows this list):

  • Full AC (same as before: entire modules are ACed, yielding the largest memory footprint reduction)
  • Selective layer AC (only every nth layer or module is ACed)
  • Selective op AC (only certain ops, typically low-memory but compute-intensive ones, are checkpointed)
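
A minimal sketch of the first two variants, assuming PyTorch's checkpoint_wrapper utilities; the helper names and the block_cls parameter are illustrative and not the PR's actual API:

    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )

    def apply_full_ac(model: nn.Module, block_cls: type) -> None:
        # Full AC: checkpoint every block -> largest memory footprint reduction.
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=checkpoint_wrapper,
            check_fn=lambda m: isinstance(m, block_cls),
        )

    def apply_selective_layer_ac(model: nn.Module, block_cls: type, every_nth: int) -> None:
        # Selective layer AC: checkpoint only every nth block.
        counter = {"n": 0}

        def check_fn(m: nn.Module) -> bool:
            if not isinstance(m, block_cls):
                return False
            counter["n"] += 1
            return counter["n"] % every_nth == 0

        apply_activation_checkpointing(
            model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn
        )

Selective op AC instead supplies a custom context_fn to the checkpoint wrapper so that only selected ops are saved (see the diff excerpts further down in this conversation).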

Additionally

  • Minor restructurings of the code
  • new results subscriber variant EvaluationResultToDiscSubscriber (will be used in benchmark tooling)
  • new class RandomDatasetBatchGenerator (will be used in the profiler)
  • changes in get_compiled_model: We now check that a module to be compiled has exactly one parent module referencing it and raise an exception otherwise. Previously, we replaced the compiled module in only one of the parents and silently skipped the others (see the sketch after this list).
  • changes in experiment_id generation: Previously, get_experiment_id_from_config(...) contained both local experiment ID generation and syncing. I refactored it so that we can now also sync arbitrary strings.
  • Originally this PR also contained profiling and benchmarking tooling. Since this is not production-ready yet, I moved it to https://github.com/Modalities/modalities/tree/legacy_profiling_env
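
As a hedged illustration of the new parent check in get_compiled_model, the helper below is hypothetical, not the PR's actual code:

    import torch.nn as nn

    def find_parent_modules(root: nn.Module, target: nn.Module) -> list[nn.Module]:
        # Collect every module in `root` that directly references `target` as a child.
        parents = []
        for module in root.modules():
            for child in module.children():
                if child is target:
                    parents.append(module)
        return parents

    # Sketch of the check: compilation replaces the module in its parent, so more than
    # one parent would leave stale references behind.
    # parents = find_parent_modules(model, module_to_compile)
    # if len(parents) != 1:
    #     raise ValueError(f"Expected exactly one parent module, found {len(parents)}.")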

Checklist before submitting final PR

  • [ ] My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

le1nux added 30 commits April 18, 2025 19:01
…w for testing with a distributed environment
@le1nux le1nux requested review from flxst and rrutmann July 17, 2025 11:15
@le1nux le1nux marked this pull request as ready for review July 17, 2025 11:20
@flxst flxst mentioned this pull request Jul 18, 2025
if config_file_path is None:
    experiment_id = f"{date_of_run}"
else:
    hash = hashlib.sha256(str(config_file_path).encode()).hexdigest()[:hash_length]
Collaborator:

Why do we hash the path of the config instead of the content of the file? The latter would yield the same experiment ID for identical configs

@le1nux (Member, Author) Jul 22, 2025:

That's a valid point. We can address this later on; previously, we did not hash the file content either.
#388
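
For reference, a minimal sketch of the content-based hashing suggested above (illustrative only, not what this PR implements; see #388 for the follow-up):

    import hashlib
    from pathlib import Path

    def get_content_hash(config_file_path: Path, hash_length: int = 8) -> str:
        # Hash the config file's bytes so identical configs yield the same experiment ID.
        content = Path(config_file_path).read_bytes()
        return hashlib.sha256(content).hexdigest()[:hash_length]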

raise NotImplementedError


class RandomDatasetBatchGenerator(DatasetBatchGeneratorIF):
Collaborator:

In which case do we need this? What is the advantage over using our non-random test data?

Member:

I also wonder why this module was added. The class does not seem to be used anywhere.


def _selective_checkpointing_context_fn():
    meta = defaultdict(int)
    save_ops_set = {ActivationCheckpointing.SAVE_DICT[key] for key in save_ops_keys}
Collaborator:

This throws an error for operations that are not listed in ActivationCheckpointing.SAVE_DICT. Why do we restrict to the ops in ActivationCheckpointing.SAVE_DICT?

Member (Author):

In the config we can only define strings, so we need to map each string to the respective function. In theory we could do something with eval(), but I find that rather ugly and error-prone. I would suggest we run the benchmarking w.r.t. SAC and then check whether we need to make this more generic.
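
Roughly, the mapping looks like the following sketch; the entries are illustrative, the actual keys and ops live in ActivationCheckpointing.SAVE_DICT:

    import torch

    # Config strings -> torch ops, resolved without eval().
    SAVE_DICT = {
        "torch.ops.aten.mm.default": torch.ops.aten.mm.default,
        "torch.ops.aten._scaled_dot_product_flash_attention.default": (
            torch.ops.aten._scaled_dot_product_flash_attention.default
        ),
    }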

mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
    meta[mm_count_key] += 1
# Saves output of all compute ops in save_ops_set, except every second mm
Collaborator:

Why only every second?

Member (Author):

I followed the setup in torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L301

From my understanding, it is about balancing compute cost against memory savings. If we wanted to make this completely configurable, we would have to store a checkpointing frequency for every op.
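
For context, a sketch of the torchtitan-style policy, assuming PyTorch's create_selective_checkpoint_contexts / CheckpointPolicy API from torch.utils.checkpoint (simplified, not the PR's exact code):

    from collections import defaultdict

    import torch
    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

    def make_selective_checkpointing_context_fn(save_ops_set):
        def _selective_checkpointing_context_fn():
            meta = defaultdict(int)

            def policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                if func == torch.ops.aten.mm.default:
                    meta[f"{mode}_mm_count"] += 1
                # Save outputs of ops in save_ops_set, except every second mm,
                # to balance recompute cost against memory savings.
                to_save = func in save_ops_set and not (
                    func == torch.ops.aten.mm.default and meta[f"{mode}_mm_count"] % 2 == 0
                )
                return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE

            return create_selective_checkpoint_contexts(policy)

        return _selective_checkpointing_context_fn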

@flxst (Member) left a comment:

Great work! I left a few comments.

Generally, it is a bit hard to understand the various changes that are seemingly unrelated to AC:

  • new results subscriber variant EvaluationResultToDiscSubscriber
  • new class RandomDatasetBatchGenerator
  • changes in get_compiled_model
  • changes in experiment_id generation

I think it would be good to at least list them explicitly in the PR description (instead of "Minor restructurings of the code") and indicate their purpose.

        | ActivationCheckpointedModelConfig.SelectiveOpACParams
    ),
) -> nn.Module:
    """FSDP2 variant for applying activation checkpointing to the given model (in-place operation).
Member:

Suggested change:
- """FSDP2 variant for applying activation checkpointing to the given model (in-place operation).
+ """General variant for applying activation checkpointing to the given model (in-place operation).

…since it can be used in the absence of FSDP2, directly on nn.Module classes (as the name of the method also indicates).

Member (Author):

I would say that since we always use either FSDP1 or FSDP2 and don't support training a model without one of these parallelizations, we should be a bit more restrictive here. Otherwise, the user might get the wrong idea: in theory it could be possible, but in practice it is not.

Member:

Ok, I get your point. I wonder if we should put "FSDP2" in the function name (like we do with FSDP1) to emphasize this, despite the model being an nn.Module. This might be slightly less confusing.

Member (Author):

done 👍

        (22310, 2, "config_activation_checkpointing_fsdp1_legacy.yaml"),
    ],
)
def test_full_activation_checkpointing_FSDP1_legacy(world_size: int, rdvz_port: int, relative_config_path: str):
Member:

Good question. I think the point is that modalities with FSDP1 is stable and has successfully been used for model training in practice. FSDP2, in contrast, requires some additional work (like this PR, or #374). Once the work is done and modalities with FSDP2 has proven to be stable and reliable in practice, we will probably drop support for FSDP1 after a certain grace period.


@le1nux le1nux requested review from rrutmann and flxst July 22, 2025 10:28
@flxst (Member) left a comment:

The first three tests in tests/training/test_activation_checkpointing.py require 2 GPUs, but they are not skipped if only a single GPU is available.

This makes the GitHub Actions tests fail:
https://github.com/Modalities/modalities/actions/runs/16443411219/job/46469270473

Locally, the tests also fail with CUDA_VISIBLE_DEVICES=0, while they run through with CUDA_VISIBLE_DEVICES=0,1.
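
A minimal sketch of one way to guard such tests (illustrative only; not necessarily the fix applied in this PR):

    import pytest
    import torch

    @pytest.mark.skipif(
        torch.cuda.device_count() < 2,
        reason="Requires at least 2 GPUs for the distributed activation checkpointing tests.",
    )
    def test_full_activation_checkpointing_fsdp2():
        ...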

@flxst flxst self-requested a review July 22, 2025 12:52
@flxst (Member) left a comment:

LGTM

@rrutmann (Collaborator) left a comment:

LGTM

@le1nux le1nux merged commit f6f663b into main Jul 22, 2025
7 checks passed
@le1nux le1nux deleted the fsdp2_activation_checkpointing branch July 22, 2025 14:19