Skip to content

Commit

Permalink
Remove unnecessary global step increment
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Nov 4, 2024
1 parent 649dc16 commit e498bbd
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
@@ -440,7 +440,6 @@ def repeat_generator():
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
gradient_accumulation_idx += 1
minibatch_idx += 1
- self.state.global_step += 1
# del everything and empty cache
# fmt: off
del (
@@ -474,11 +473,11 @@ def repeat_generator():
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
metrics["episode"] = self.state.episode
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
- self.state.global_step += 1
self.log(metrics)
del kl, mean_kl, mean_entropy, scores

self.lr_scheduler.step()
+ self.state.global_step += 1
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None)
Expand Down

0 comments on commit e498bbd

Please sign in to comment.