From ad18fb073b87dc994a66420e38e16705bfb75957 Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Tue, 12 Mar 2024 11:38:44 +0000 Subject: [PATCH] Set offload checkpoint policy --- paxml/tasks/lm/model_params.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paxml/tasks/lm/model_params.py b/paxml/tasks/lm/model_params.py index a4267f1f3..58e00f089 100644 --- a/paxml/tasks/lm/model_params.py +++ b/paxml/tasks/lm/model_params.py @@ -663,6 +663,12 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]: self.CHECKPOINT_POLICY) else: model_p.lm_tpl.stacked_transformer_tpl = stacked_transformer_tpl + if (self.CHECKPOINT_POLICY == + layers.AutodiffCheckpointType.OFFLOAD_DOT_WITH_NO_BATCH_DIM): + model_p.lm_tpl.stacked_transformer_tpl.checkpoint_policy = ( + self.CHECKPOINT_POLICY) + model_p.lm_tpl.stacked_transformer_tpl.remat = True + # Enable bf16. model_p.fprop_dtype = self.FPROP_DTYPE