Skip to content

Commit 0ac6167

Browse files
Commit message: fix
Signed-off-by: Hemil Desai <[email protected]>
Parent: ae04fcc — this commit: 0ac6167

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

nemo_rl/models/policy/dtensor_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def validate_and_set_config(

     # Get other configuration values
     cpu_offload = config["dtensor_cfg"]["cpu_offload"]
-    offload_optimizer_for_logprob = config["offload_optimizer_for_logprob"]
+    offload_optimizer_for_logprob = config.get("offload_optimizer_for_logprob", False)
     max_grad_norm = config["max_grad_norm"]
     enable_seq_packing = config["sequence_packing"]["enabled"]
     model_name = config["model_name"]

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -349,21 +349,23 @@ def train(

         # Forward and backward pass
         loss, loss_metrics = forward_backward(
-            self.model,
-            mb,
-            loss_fn,
-            global_valid_seqs,
-            global_valid_toks,
-            processed_inputs,
-            self.dtype,
-            self.cp_size,
-            self.cp_mesh,
-            self.device_mesh,
-            self.enable_seq_packing,
-            self._is_reward_model,
-            self.allow_flash_attn_args,
-            eval_mode,
-            self._apply_temperature_scaling,
+            model=self.model,
+            mb=mb,
+            loss_fn=loss_fn,
+            global_valid_seqs=global_valid_seqs,
+            global_valid_toks=global_valid_toks,
+            processed_inputs=processed_inputs,
+            dtype=self.dtype,
+            cp_size=self.cp_size,
+            cp_mesh=self.cp_mesh,
+            device_mesh=self.device_mesh,
+            enable_seq_packing=self.enable_seq_packing,
+            is_reward_model=self._is_reward_model,
+            allow_flash_attn_args=self.allow_flash_attn_args,
+            is_hf_model=self.is_hf_model,
+            is_moe_model=self.is_moe_model,
+            eval_mode=eval_mode,
+            apply_temperature_fn=self._apply_temperature_scaling,
         )

         # skip the update for dummy batches

0 commit comments

Comments (0)