File tree (expand / collapse file tree) — 7 files changed, +10 −10 lines changed
applications/ColossalChat/coati/trainer (expand / collapse file tree) — 7 files changed, +10 −10 lines changed
Columns: original file line number | diff line number | diff line change
@@ -380,9 +380,9 @@ def _criterion(outputs, inputs):
380380 self .accumulative_meter .get ("accuracy" ),
381381 global_step ,
382382 )
383- self .num_train_step += 1
384383 self .accumulative_meter .reset ()
385-
384+ self .num_train_step += 1
385+
386386 if self .save_dir is not None and self .num_train_step > 0 and self .num_train_step % self .save_interval == 0 :
387387 # save checkpoint
388388 self .coordinator .print_on_master ("\n Start saving model checkpoint with running states" )
Original file line number Diff line number Diff line change @@ -231,7 +231,6 @@ def _training_step(self, experience: Experience):
231231 experience:
232232 sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
233233 """
234- self .num_train_step += 1
235234 self .actor .train ()
236235 num_actions = experience .action_log_probs .size (1 )
237236 # policy loss
@@ -294,7 +293,7 @@ def _training_step(self, experience: Experience):
294293 self .temperature_annealing_scheduler .step_forward ()
295294
296295 # preparing logging model output and corresponding rewards.
297- if self .num_train_step % 10 == 1 :
296+ if self .num_train_step % 10 == 0 :
298297 response_text = self .experience_maker .tokenizer .batch_decode (
299298 experience .sequences , skip_special_tokens = True
300299 )
@@ -327,6 +326,7 @@ def _training_step(self, experience: Experience):
327326 self .writer .add_scalar ("approx_kl" , self .accumulative_meter .get ("kl" ), global_step )
328327 self .writer .add_scalar ("advantages" , self .accumulative_meter .get ("advantages" ), global_step )
329328 self .accumulative_meter .reset ()
329+ self .num_train_step += 1
330330
331331 def _learn (self , update_step : int ):
332332 """
Original file line number Diff line number Diff line change @@ -256,7 +256,7 @@ def _train(self, epoch: int):
256256 self .coordinator .print_on_master (
257257 f"Saved checkpoint at epoch { epoch } step { self .save_interval } at folder { self .save_dir } "
258258 )
259- self .num_train_step += 1
259+ self .num_train_step += 1
260260
261261 step_bar .close ()
262262
Original file line number Diff line number Diff line change @@ -233,7 +233,7 @@ def _train(self, epoch: int):
233233 self .coordinator .print_on_master (
234234 f"Saved checkpoint at epoch { epoch } step { self .save_interval } at folder { self .save_dir } "
235235 )
236- self .num_train_step += 1
236+ self .num_train_step += 1
237237
238238 step_bar .close ()
239239
Original file line number Diff line number Diff line change @@ -220,7 +220,6 @@ def _training_step(self, experience: Experience):
220220 experience:
221221 sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
222222 """
223- self .num_train_step += 1
224223 self .actor .train ()
225224 self .critic .train ()
226225 num_actions = experience .action_log_probs .size (1 )
@@ -294,7 +293,7 @@ def _training_step(self, experience: Experience):
294293 self .critic_scheduler .step ()
295294
296295 # preparing logging model output and corresponding rewards.
297- if self .num_train_step % 10 == 1 :
296+ if self .num_train_step % 10 == 0 :
298297 response_text = self .experience_maker .tokenizer .batch_decode (
299298 experience .sequences , skip_special_tokens = True
300299 )
@@ -336,6 +335,7 @@ def _training_step(self, experience: Experience):
336335 self .writer .add_scalar ("value" , self .accumulative_meter .get ("value" ), self .num_train_step )
337336 self .writer .add_scalar ("advantages" , self .accumulative_meter .get ("advantages" ), self .num_train_step )
338337 self .accumulative_meter .reset ()
338+ self .num_train_step += 1
339339
340340 def _learn (self , update_step : int ):
341341 """
Original file line number Diff line number Diff line change @@ -193,7 +193,7 @@ def _train(self, epoch):
193193 self .coordinator .print_on_master (
194194 f"Saved checkpoint at epoch { epoch } step { (i + 1 )/ self .accumulation_steps } at folder { self .save_dir } "
195195 )
196- self .num_train_step += 1
196+ self .num_train_step += 1
197197 step_bar .close ()
198198
199199 def _eval (self , epoch ):
Original file line number Diff line number Diff line change @@ -152,9 +152,9 @@ def _train(self, epoch: int):
152152 if self .writer :
153153 self .writer .add_scalar ("train/loss" , self .accumulative_meter .get ("loss" ), global_step )
154154 self .writer .add_scalar ("train/lr" , self .scheduler .get_last_lr ()[0 ], global_step )
155- self .num_train_step += 1
156155 self .accumulative_meter .reset ()
157156 step_bar .update ()
157+ self .num_train_step += 1
158158
159159 # Save checkpoint
160160 if (
You can’t perform that action at this time.
0 commit comments