FSDP2 support for activation checkpointing #359
Conversation
```python
if config_file_path is None:
    experiment_id = f"{date_of_run}"
else:
    hash = hashlib.sha256(str(config_file_path).encode()).hexdigest()[:hash_length]
```
Why do we hash the path of the config instead of the content of the file? The latter would yield the same experiment ID for identical configs.
That's a valid point. We can address this later on; before this PR we also did not hash the file itself.
#388
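As a reference for the suggested alternative, a minimal sketch that hashes the file contents instead of the path; the function name is illustrative, only config_file_path and hash_length are taken from the snippet above:

```python
import hashlib
from pathlib import Path

def hash_config_content(config_file_path: Path, hash_length: int) -> str:
    # Identical configs stored at different paths map to the same hash.
    content = Path(config_file_path).read_bytes()
    return hashlib.sha256(content).hexdigest()[:hash_length]
```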
```python
raise NotImplementedError


class RandomDatasetBatchGenerator(DatasetBatchGeneratorIF):
```
In which case do we need this? What is the advantage over using our non-random test data?
I also wonder why this module was added. The class does not seem to be used anywhere.
Resolved review thread on src/modalities/training/activation_checkpointing/activation_checkpointing.py (outdated).
```python
def _selective_checkpointing_context_fn():
    meta = defaultdict(int)
    save_ops_set = {ActivationCheckpointing.SAVE_DICT[key] for key in save_ops_keys}
```
This throws an error for operations that are not listed in ActivationCheckpointing.SAVE_DICT. Why do we restrict to the ops in ActivationCheckpointing.SAVE_DICT?
In the config we can only define strings, so we need to map each string to the respective function. In theory, we could do something with eval(), but I found that rather ugly and error-prone. I would suggest we run the benchmarking w.r.t. SAC and check whether we need to make it more generic.
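For illustration, a minimal sketch of such a string-to-op mapping with an explicit validation step; the concrete keys and ops below are placeholders, not the actual contents of ActivationCheckpointing.SAVE_DICT:

```python
import torch

# Hypothetical mapping from config strings to ATen ops.
SAVE_DICT = {
    "mm": torch.ops.aten.mm.default,
    "max": torch.ops.aten.max.default,
    "_scaled_dot_product_flash_attention": torch.ops.aten._scaled_dot_product_flash_attention.default,
}

def resolve_save_ops(save_ops_keys: list[str]) -> set:
    unknown = set(save_ops_keys) - SAVE_DICT.keys()
    if unknown:
        # An explicit error is friendlier than the raw KeyError the
        # reviewer points out above.
        raise ValueError(f"Unsupported ops in config: {unknown}")
    return {SAVE_DICT[key] for key in save_ops_keys}
```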
```python
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
    meta[mm_count_key] += 1
# Saves output of all compute ops in save_ops_set, except every second mm
```
Why only every second?
I followed the setup in torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L301
From my understanding, it's for balancing compute vs. memory savings. If we wanted to make this completely configurable, we would have to store the checkpointing frequency of every op.
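A rough sketch of the torchtitan-style policy being discussed, assuming PyTorch's torch.utils.checkpoint.CheckpointPolicy API; the function and variable names follow the snippets above but are otherwise assumptions:

```python
from collections import defaultdict

import torch
from torch.utils.checkpoint import CheckpointPolicy

def make_policy_fn(save_ops_set: set):
    meta = defaultdict(int)

    def policy_fn(ctx, func, *args, **kwargs):
        mode = "recompute" if ctx.is_recompute else "forward"
        if func == torch.ops.aten.mm.default:
            meta[f"{mode}_mm_count"] += 1
            # Recompute every second matmul: saving all of them keeps too
            # many activations, recomputing all of them costs too much
            # compute, so the policy splits the difference.
            if meta[f"{mode}_mm_count"] % 2 == 0:
                return CheckpointPolicy.PREFER_RECOMPUTE
        # Save the outputs of all other ops listed in save_ops_set.
        if func in save_ops_set:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy_fn
```

Presumably such a policy would be wrapped via torch.utils.checkpoint.create_selective_checkpoint_contexts and passed as context_fn to the checkpoint wrapper, as torchtitan does.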
Great work! I left a few comments.
Generally, it is a bit hard to understand the various changes that are seemingly unrelated to AC:
- new results subscriber variant EvaluationResultToDiscSubscriber
- new class RandomDatasetBatchGenerator
- changes in get_compiled_model
- changes in experiment_id generation
I think it would be good to at least list them explicitly in the PR description (instead of "Minor restructurings of the code") and indicate their purpose.
```python
        | ActivationCheckpointedModelConfig.SelectiveOpACParams
    ),
) -> nn.Module:
    """FSDP2 variant for applying activation checkpointing to the given model (in-place operation).
```
"""FSDP2 variant for applying activation checkpointing to the given model (in-place operation). | |
"""General variant for applying activation checkpointing to the given model (in-place operation). |
.. since it can be used in the absence of FSDP2, directly on nn.Module
classes (as the name of the method also indicates).
I would say, since we always use either FSDP1 or FSDP2 and don't support training a model without either of those parallelizations, we should be a bit more restrictive here. Otherwise, the user might get the wrong idea: in theory it could be possible, but in practice it's not.
Ok, I get your point. I wonder if we should put "FSDP2" in the function name (like we do with FSDP1) to emphasize this, despite the model being an nn.Module. This might be slightly less confusing.
done 👍
Resolved review thread on src/modalities/training/activation_checkpointing/activation_checkpointing.py (outdated).
```python
        (22310, 2, "config_activation_checkpointing_fsdp1_legacy.yaml"),
    ],
)
def test_full_activation_checkpointing_FSDP1_legacy(world_size: int, rdvz_port: int, relative_config_path: str):
```
Good question. I think the point is that modalities with FSDP1 is stable and has successfully been used for model training in practice. FSDP2, in contrast, requires some additional work (like this PR, or #374). Once the work is done and modalities with FSDP2 has proven to be stable and reliable in practice, we will probably drop support for FSDP1 after a certain grace period.
Resolved review thread on src/modalities/logging_broker/subscriber_impl/results_subscriber.py (outdated).
The first three tests in tests/training/test_activation_checkpointing.py require 2 GPUs, but they are not skipped if only a single GPU is available. This makes the tests fail with GitHub Actions: https://github.com/Modalities/modalities/actions/runs/16443411219/job/46469270473. Locally, the tests also fail with CUDA_VISIBLE_DEVICES=0, while they run through with CUDA_VISIBLE_DEVICES=0,1.
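One possible guard for this, sketched with standard pytest and torch APIs (the repo may prefer a shared marker instead, and the test name here is only a placeholder):

```python
import pytest
import torch

# Skip instead of failing when fewer than 2 GPUs are visible.
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="This test requires at least 2 GPUs.",
)
def test_full_activation_checkpointing():
    ...
```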
LGTM
LGTM
What does this PR do?
This PR adds activation checkpointing (AC) support for FSDP2.
There are now three AC variants: full AC, selective layer AC, and selective op AC.
Additionally:
- EvaluationResultToDiscSubscriber: new results subscriber variant (will be used in benchmark tooling)
- RandomDatasetBatchGenerator: new class (will be used in profiler)
- get_compiled_model: we now check that a module that is to be compiled has exactly one parent module that references it and throw an exception otherwise. Before, we were only replacing the compiled module for one of the parents and silently skipped the other parents (see the sketch below).
- experiment_id generation: previously, get_experiment_id_from_config(...) contained local experiment ID generation and syncing. I refactored it such that we can now also sync arbitrary strings.
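To make the get_compiled_model change concrete, a hypothetical sketch of the parent-uniqueness check described above; the actual implementation in modalities may look different:

```python
import torch
import torch.nn as nn

def _parents_of(root: nn.Module, target: nn.Module) -> list[tuple[nn.Module, str]]:
    # Collect every (parent, attribute_name) pair that references `target`.
    return [
        (parent, name)
        for parent in root.modules()
        for name, child in parent.named_children()
        if child is target
    ]

def replace_with_compiled(root: nn.Module, target: nn.Module) -> None:
    parents = _parents_of(root, target)
    if len(parents) != 1:
        # Previously, only one parent reference was replaced and the other
        # parents silently kept the uncompiled module.
        raise ValueError(f"Expected exactly one parent module, found {len(parents)}.")
    parent, name = parents[0]
    setattr(parent, name, torch.compile(target))
```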
Checklist before submitting final PR
- Tests run through (python tests/tests.py)
- Changelog updated (CHANGELOG_DEV.md)