diff --git a/nemo_aligner/utils/train_script_utils.py b/nemo_aligner/utils/train_script_utils.py index c6f6f8089..eeed1a538 100644 --- a/nemo_aligner/utils/train_script_utils.py +++ b/nemo_aligner/utils/train_script_utils.py @@ -50,12 +50,18 @@ def retrieve_custom_trainer_state_dict(ptl_trainer): consumed_samples = extract_value_from_ckpt(key="consumed_samples", ckpt_path=trainer_restore_path) step = extract_value_from_ckpt(key="step", ckpt_path=trainer_restore_path) epoch = extract_value_from_ckpt(key="epoch", ckpt_path=trainer_restore_path) + + # TODO: unify alignment step key to avoid adding one for each algo ppo_optimization_step = extract_value_from_ckpt(key="ppo_optimization_step", ckpt_path=trainer_restore_path) + reinforce_optimization_step = extract_value_from_ckpt( + key="reinforce_optimization_step", ckpt_path=trainer_restore_path + ) trainer_state_dict = { "step": step, "consumed_samples": consumed_samples, "epoch": epoch, "ppo_optimization_step": ppo_optimization_step, + "reinforce_optimization_step": reinforce_optimization_step, } return trainer_state_dict