
Commit 530a216

add is_hf_model and is_moe_model to model state
Signed-off-by: Hemil Desai <[email protected]>
1 parent 5047e18 commit 530a216

7 files changed: +185 additions, −26 deletions

nemo_rl/models/policy/dtensor_init.py

Lines changed: 9 additions & 1 deletion
@@ -20,6 +20,7 @@

 import torch
 from accelerate import init_empty_weights
+from nemo_automodel._transformers.registry import ModelRegistry
 from nemo_automodel._transformers.utils import sliding_window_overwrite
 from nemo_automodel.components.config.loader import _resolve_target
 from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager

@@ -83,6 +84,8 @@ class ModelAndOptimizerState:
     optimizer: Optional[torch.optim.Optimizer]
     scheduler: Optional[Any]
     reference_model_state_dict: Optional[dict[str, torch.Tensor]]
+    is_hf_model: bool
+    is_moe_model: bool


 def validate_and_set_config(

@@ -423,8 +426,11 @@ def setup_model_and_optimizer(
     )

     # Parallelize model
+    is_hf_model = (
+        model_config.architectures[0] not in ModelRegistry.model_arch_name_to_cls
+    )
     is_moe_model = any(["expert" in key for key in model_state_dict_keys])
-    if not isinstance(model, PreTrainedModel) and is_moe_model:
+    if not isinstance(model, PreTrainedModel) and is_moe_model and not is_hf_model:
         moe_parallelize_model(
             model=model,
             world_mesh=device_mesh,

@@ -539,4 +545,6 @@ def setup_model_and_optimizer(
         optimizer=optimizer,
         scheduler=scheduler,
         reference_model_state_dict=reference_model_state_dict,
+        is_hf_model=is_hf_model,
+        is_moe_model=is_moe_model,
     )
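
Taken together, these changes derive two booleans during setup and store them on ModelAndOptimizerState. A minimal sketch of that logic, lifted out of setup_model_and_optimizer for illustration only (derive_model_flags is a hypothetical helper name; model_config and model_state_dict_keys are the values already in scope in the diff):

from nemo_automodel._transformers.registry import ModelRegistry


def derive_model_flags(model_config, model_state_dict_keys) -> tuple[bool, bool]:
    # Architectures without a custom implementation registered in
    # nemo_automodel's ModelRegistry fall back to the plain HF model class.
    is_hf_model = (
        model_config.architectures[0] not in ModelRegistry.model_arch_name_to_cls
    )
    # MoE detection is unchanged: any state-dict key containing "expert".
    is_moe_model = any("expert" in key for key in model_state_dict_keys)
    return is_hf_model, is_moe_model

The net effect on parallelization is that moe_parallelize_model now runs only for MoE checkpoints served by a custom (non-HF) implementation; MoE models that fall back to the HF class skip that branch.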

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 10 additions & 1 deletion
@@ -224,7 +224,14 @@ def __init__(
         _copy_dataclass_fields(
             self,
             model_state,
-            ["model", "model_state_dict_keys", "optimizer", "scheduler"],
+            [
+                "model",
+                "model_state_dict_keys",
+                "optimizer",
+                "scheduler",
+                "is_hf_model",
+                "is_moe_model",
+            ],
         )
         if init_reference_model:
             self.reference_model_state_dict = model_state.reference_model_state_dict

@@ -503,6 +510,8 @@ def get_logprobs(
             self.cp_mesh,
             self._is_reward_model,
             self.allow_flash_attn_args,
+            self.is_hf_model,
+            self.is_moe_model,
         )

         # Process outputs for logprobs
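
The worker-side change is plumbing: the two new flags are copied from ModelAndOptimizerState onto the worker and then threaded into the model_forward call inside get_logprobs. _copy_dataclass_fields is an existing helper in the repo; assuming it performs a plain attribute copy, its effect on the new fields is roughly the following sketch (hypothetical implementation, shown only to make the field list above concrete):

def copy_flags(worker, model_state):
    # Hypothetical equivalent of _copy_dataclass_fields for the two new fields:
    # mirror each flag from the setup-time state onto the worker instance so
    # that worker.is_hf_model / worker.is_moe_model exist when get_logprobs runs.
    for name in ("is_hf_model", "is_moe_model"):
        setattr(worker, name, getattr(model_state, name))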

nemo_rl/models/policy/dtensor_train.py

Lines changed: 14 additions & 3 deletions
@@ -88,6 +88,8 @@ def forward_backward(
     enable_seq_packing: bool,
     is_reward_model: bool,
     allow_flash_attn_args: bool,
+    is_hf_model: bool,
+    is_moe_model: bool,
     eval_mode: bool,
     apply_temperature_fn,
 ) -> tuple[torch.Tensor, dict[str, Any]]:

@@ -107,6 +109,8 @@
         enable_seq_packing: Whether sequence packing is enabled
         is_reward_model: Whether this is a reward model
         allow_flash_attn_args: Whether model supports flash_attn_kwargs
+        is_hf_model: Whether the model is an HF model
+        is_moe_model: Whether the model is a MoE model
         eval_mode: Whether in evaluation mode
         apply_temperature_fn: Function to apply temperature scaling to logits

@@ -121,6 +125,8 @@
         cp_mesh,
         is_reward_model,
         allow_flash_attn_args,
+        is_hf_model,
+        is_moe_model,
     )

     # Process outputs for training (loss + backward)

@@ -230,6 +236,8 @@ def model_forward(
     cp_mesh: Any,
     is_reward_model: bool,
     allow_flash_attn_args: bool,
+    is_hf_model: bool,
+    is_moe_model: bool,
 ) -> Any:
     """Perform model forward pass.

@@ -240,7 +248,8 @@
         cp_mesh: Context parallel mesh
         is_reward_model: Whether this is a reward model
         allow_flash_attn_args: Whether model supports flash_attn_kwargs
-
+        is_hf_model: Whether the model is an HF model
+        is_moe_model: Whether the model is a MoE model
     Returns:
         Model outputs
     """

@@ -268,12 +277,14 @@
     model_args = dict(
         input_ids=input_ids,
         attention_mask=attention_mask,
-        padding_mask=~attention_mask if attention_mask is not None else None,
         position_ids=position_ids,
         use_cache=False,
         flash_attn_kwargs=flash_attn_kwargs,
         **vlm_kwargs,
     )
+    if is_moe_model and not is_hf_model:
+        padding_mask = ~attention_mask if attention_mask is not None else None
+        model_args["padding_mask"] = padding_mask

     if is_reward_model:
         # `flash_attn_kwarg` is not supported for `LlamaForSequenceClassification`.

@@ -291,7 +302,7 @@
     # Remove None attention_mask padding_mask if present
     if model_args.get("attention_mask") is None:
         del model_args["attention_mask"]
-    if model_args.get("padding_mask") is None:
+    if "padding_mask" in model_args and model_args.get("padding_mask") is None:
         del model_args["padding_mask"]

     outputs = model(**model_args)
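
The model_forward changes make padding_mask an opt-in argument: it is only built and attached for custom (non-HF) MoE models, and HF models never receive it. A condensed sketch of the resulting argument handling, assuming the surrounding locals (input_ids, attention_mask, position_ids, flash_attn_kwargs, vlm_kwargs, model) exactly as they appear in the diff:

# Base arguments passed for every model type.
model_args = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    use_cache=False,
    flash_attn_kwargs=flash_attn_kwargs,
    **vlm_kwargs,
)
# padding_mask is only meaningful for the custom MoE implementations,
# so it is added conditionally instead of unconditionally as before.
if is_moe_model and not is_hf_model:
    model_args["padding_mask"] = (
        ~attention_mask if attention_mask is not None else None
    )
# Drop None-valued masks so they are not forwarded to the model; the
# padding_mask check now also guards against the key being absent.
if model_args.get("attention_mask") is None:
    del model_args["attention_mask"]
if "padding_mask" in model_args and model_args["padding_mask"] is None:
    del model_args["padding_mask"]
outputs = model(**model_args)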

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ automodel = [
     "flash-attn==2.8.1",
     "mamba-ssm",
     "causal-conv1d",
+    "transformers>=4.57.1",
 ]
 vllm = [
     "cuda-python",

tests/unit/models/policy/test_dtensor_init.py

Lines changed: 12 additions & 2 deletions
@@ -188,7 +188,7 @@ def test_sequence_packing_with_vlm_raises_error(
     @patch("nemo_rl.models.policy.dtensor_init.resolve_model_class")
     @patch("nemo_rl.models.policy.dtensor_init.configure_dynamo_cache")
     @patch("nemo_rl.models.policy.dtensor_init.sliding_window_overwrite")
-    @patch("nemo_rl.models.policy.dtensor_init.NeMoAutoModelForSequenceClassification")
+    @patch("nemo_automodel.NeMoAutoModelForSequenceClassification")
     def test_reward_model_bradley_terry(
         self,
         mock_rm_class,

@@ -407,7 +407,7 @@ def test_hf_config_overrides_none(
     @patch("nemo_rl.models.policy.dtensor_init.resolve_model_class")
     @patch("nemo_rl.models.policy.dtensor_init.configure_dynamo_cache")
     @patch("nemo_rl.models.policy.dtensor_init.sliding_window_overwrite")
-    @patch("nemo_rl.models.policy.dtensor_init.NeMoAutoModelForSequenceClassification")
+    @patch("nemo_automodel.NeMoAutoModelForSequenceClassification")
     def test_reward_model_with_num_labels_equals_one(
         self,
         mock_rm_class,

@@ -753,6 +753,8 @@ def test_basic_model_setup(
         assert result.scheduler == mock_scheduler
         assert result.reference_model_state_dict is not None
         assert len(result.model_state_dict_keys) > 0
+        assert isinstance(result.is_hf_model, bool)
+        assert isinstance(result.is_moe_model, bool)

     @patch("nemo_rl.models.policy.dtensor_init.init_empty_weights")
     @patch("nemo_rl.models.policy.utils.import_class_from_path")

@@ -835,6 +837,8 @@ def test_model_setup_without_optimizer(
         assert result.optimizer is None
         assert result.scheduler is None
         assert result.reference_model_state_dict is None
+        assert isinstance(result.is_hf_model, bool)
+        assert isinstance(result.is_moe_model, bool)

     @patch("nemo_rl.models.policy.dtensor_init.init_empty_weights")
     def test_context_parallel_with_gemma3_raises_error(

@@ -1014,6 +1018,8 @@ def import_side_effect(path):
         )

         assert result.scheduler == mock_final_scheduler
+        assert isinstance(result.is_hf_model, bool)
+        assert isinstance(result.is_moe_model, bool)

     @patch("nemo_rl.models.policy.dtensor_init.init_empty_weights")
     @patch("nemo_rl.models.policy.utils.import_class_from_path")

@@ -1211,10 +1217,14 @@ def test_model_and_optimizer_state_creation(self):
             optimizer=MagicMock(),
             scheduler=MagicMock(),
             reference_model_state_dict={"layer.weight": torch.zeros(10, 10)},
+            is_hf_model=False,
+            is_moe_model=True,
         )

         assert state.model is not None
         assert len(state.model_state_dict_keys) == 2
         assert state.optimizer is not None
         assert state.scheduler is not None
         assert state.reference_model_state_dict is not None
+        assert state.is_hf_model is False
+        assert state.is_moe_model is True
