2 files changed: +18 −16 lines

@@ -134,7 +134,7 @@ def validate_and_set_config(
 
     # Get other configuration values
     cpu_offload = config["dtensor_cfg"]["cpu_offload"]
-    offload_optimizer_for_logprob = config["offload_optimizer_for_logprob"]
+    offload_optimizer_for_logprob = config.get("offload_optimizer_for_logprob", False)
     max_grad_norm = config["max_grad_norm"]
     enable_seq_packing = config["sequence_packing"]["enabled"]
     model_name = config["model_name"]
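The first hunk makes `offload_optimizer_for_logprob` an optional key: `dict.get` with a `False` default avoids a `KeyError` when older configs omit it. A minimal, self-contained sketch of the difference (the config dict below is hypothetical, not the project's real config):

    config = {"dtensor_cfg": {"cpu_offload": False}}  # no "offload_optimizer_for_logprob" key

    # Old form: direct indexing raises KeyError for configs written before the key existed.
    try:
        offload = config["offload_optimizer_for_logprob"]
    except KeyError:
        offload = None  # what a legacy config would have hit

    # New form: .get() falls back to False, so legacy configs keep working.
    offload = config.get("offload_optimizer_for_logprob", False)
    print(offload)  # False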
@@ -349,21 +349,23 @@ def train(
 
     # Forward and backward pass
     loss, loss_metrics = forward_backward(
-        self.model,
-        mb,
-        loss_fn,
-        global_valid_seqs,
-        global_valid_toks,
-        processed_inputs,
-        self.dtype,
-        self.cp_size,
-        self.cp_mesh,
-        self.device_mesh,
-        self.enable_seq_packing,
-        self._is_reward_model,
-        self.allow_flash_attn_args,
-        eval_mode,
-        self._apply_temperature_scaling,
+        model=self.model,
+        mb=mb,
+        loss_fn=loss_fn,
+        global_valid_seqs=global_valid_seqs,
+        global_valid_toks=global_valid_toks,
+        processed_inputs=processed_inputs,
+        dtype=self.dtype,
+        cp_size=self.cp_size,
+        cp_mesh=self.cp_mesh,
+        device_mesh=self.device_mesh,
+        enable_seq_packing=self.enable_seq_packing,
+        is_reward_model=self._is_reward_model,
+        allow_flash_attn_args=self.allow_flash_attn_args,
+        is_hf_model=self.is_hf_model,
+        is_moe_model=self.is_moe_model,
+        eval_mode=eval_mode,
+        apply_temperature_fn=self._apply_temperature_scaling,
     )
 
     # skip the update for dummy batches
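The second hunk switches the `forward_backward` call from positional to keyword arguments and threads through two new flags, `is_hf_model` and `is_moe_model`. With fifteen-plus positional arguments, any insertion or reordering in the callee's signature would silently re-bind values; keywords pin each value to its parameter name. A minimal sketch of the idea, using a hypothetical stand-in rather than the project's real `forward_backward`:

    def forward_backward(model, mb, loss_fn, *, is_hf_model=False, is_moe_model=False, eval_mode=False):
        # Stand-in: the real function runs the forward/backward pass; this one just echoes its flags.
        return f"hf={is_hf_model} moe={is_moe_model} eval={eval_mode}"

    # Keyword form: adding is_hf_model / is_moe_model to the signature cannot
    # shift existing arguments, and the call site documents itself.
    print(forward_backward(model=None, mb=None, loss_fn=None, is_moe_model=True))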