Skip to content

Commit

Permalink
fix: correct REINFORCE to resume training (#427)
Browse files Browse the repository at this point in the history
Signed-off-by: abukharin <[email protected]>
Signed-off-by: NeMo-Aligner CI <[email protected]>
Co-authored-by: abukharin <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Terry Kong <[email protected]>
  • Loading branch information
4 people authored Dec 3, 2024
1 parent 35d0c59 commit 70e4f31
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions nemo_aligner/utils/train_script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,18 @@ def retrieve_custom_trainer_state_dict(ptl_trainer):
consumed_samples = extract_value_from_ckpt(key="consumed_samples", ckpt_path=trainer_restore_path)
step = extract_value_from_ckpt(key="step", ckpt_path=trainer_restore_path)
epoch = extract_value_from_ckpt(key="epoch", ckpt_path=trainer_restore_path)

# TODO: unify alignment step key to avoid adding one for each algo
ppo_optimization_step = extract_value_from_ckpt(key="ppo_optimization_step", ckpt_path=trainer_restore_path)
reinforce_optimization_step = extract_value_from_ckpt(
key="reinforce_optimization_step", ckpt_path=trainer_restore_path
)
trainer_state_dict = {
"step": step,
"consumed_samples": consumed_samples,
"epoch": epoch,
"ppo_optimization_step": ppo_optimization_step,
"reinforce_optimization_step": reinforce_optimization_step,
}

return trainer_state_dict
Expand Down

0 comments on commit 70e4f31

Please sign in to comment.