diff --git a/paxml/tasks/lm/model_params.py b/paxml/tasks/lm/model_params.py index a4267f1f3..58e00f089 100644 --- a/paxml/tasks/lm/model_params.py +++ b/paxml/tasks/lm/model_params.py @@ -663,6 +663,12 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]: self.CHECKPOINT_POLICY) else: model_p.lm_tpl.stacked_transformer_tpl = stacked_transformer_tpl + if (self.CHECKPOINT_POLICY == + layers.AutodiffCheckpointType.OFFLOAD_DOT_WITH_NO_BATCH_DIM): + model_p.lm_tpl.stacked_transformer_tpl.checkpoint_policy = ( + self.CHECKPOINT_POLICY) + model_p.lm_tpl.stacked_transformer_tpl.remat = True + # Enable bf16. model_p.fprop_dtype = self.FPROP_DTYPE