From 70e4f31dea1fa0dcf7e425af46344fa1db45f44f Mon Sep 17 00:00:00 2001 From: Alexander Bukharin <59148829+abukharin3@users.noreply.github.com> Date: Mon, 2 Dec 2024 21:26:28 -0500 Subject: [PATCH] fix: correct REINFORCE to resume training (#427) Signed-off-by: abukharin Signed-off-by: NeMo-Aligner CI Co-authored-by: abukharin Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Terry Kong --- nemo_aligner/utils/train_script_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nemo_aligner/utils/train_script_utils.py b/nemo_aligner/utils/train_script_utils.py index c6f6f8089..eeed1a538 100644 --- a/nemo_aligner/utils/train_script_utils.py +++ b/nemo_aligner/utils/train_script_utils.py @@ -50,12 +50,18 @@ def retrieve_custom_trainer_state_dict(ptl_trainer): consumed_samples = extract_value_from_ckpt(key="consumed_samples", ckpt_path=trainer_restore_path) step = extract_value_from_ckpt(key="step", ckpt_path=trainer_restore_path) epoch = extract_value_from_ckpt(key="epoch", ckpt_path=trainer_restore_path) + + # TODO: unify alignment step key to avoid adding one for each algo ppo_optimization_step = extract_value_from_ckpt(key="ppo_optimization_step", ckpt_path=trainer_restore_path) + reinforce_optimization_step = extract_value_from_ckpt( + key="reinforce_optimization_step", ckpt_path=trainer_restore_path + ) trainer_state_dict = { "step": step, "consumed_samples": consumed_samples, "epoch": epoch, "ppo_optimization_step": ppo_optimization_step, + "reinforce_optimization_step": reinforce_optimization_step, } return trainer_state_dict