Skip to content

Commit

Permalink
Merge pull request #1307 from AI-Hypercomputer:mattdavidow-pp-remat-again
Browse files (browse the repository at this point in the history)

PiperOrigin-RevId: 730645365
  • Loading branch information
maxtext authors committed Feb 25, 2025
2 parents b18d7c6 + 6a1fe19 commit c507c29
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions MaxText/layers/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,10 @@ def gather_weights_for_stages_in(weights):
def get_pipeline_remat_policy(self):
# We ensure that the decoder layer inputs are saved, although we leave it to a custom
# policy to decide whether they are saved on device or offloaded.
if self.remat_policy != "custom":
save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input")
else:
save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input")
if self.config.remat_policy == "custom":
return self.remat_policy

save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input")
if self.remat_policy is not None:
remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy)
else:
Expand Down

0 comments on commit c507c29

Please sign in to comment.