Description
When both `policy.sequence_packing.enabled=true` and `policy.dtensor_cfg.sequence_parallel=true` are set, GRPO training fails during logprob computation with a RoPE size mismatch in Qwen2. Disabling sequence parallel makes the run complete normally.
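The failing shapes differ by exactly one position, and 17168 is 17167 rounded up to a multiple of the tensor-parallel size (2). Below is a minimal sketch of that pattern in isolation, assuming (hypothesis only, not confirmed against the NeMo-RL/Automodel code path) that sequence parallel pads the packed hidden states to a multiple of `tensor_parallel_size` while the RoPE `cos`/`sin` tables are built from the unpadded `position_ids`:

```python
# Standalone sketch of the suspected off-by-one; shapes follow the Qwen2
# attention layout (batch, heads, seq, head_dim). Hypothetical repro, not
# the actual NeMo-RL code.
import torch

tp_size = 2                       # policy.dtensor_cfg.tensor_parallel_size=2
packed_len = 17167                # odd packed-sequence length from the error
padded_len = ((packed_len + tp_size - 1) // tp_size) * tp_size  # -> 17168

q = torch.randn(1, 4, padded_len, 64)  # query states after SP padding
cos = torch.randn(1, packed_len, 64)   # RoPE table from unpadded positions

# Same broadcast as `q * cos` inside apply_rotary_pos_emb:
q * cos.unsqueeze(1)
# RuntimeError: The size of tensor a (17168) must match the size of
# tensor b (17167) at non-singleton dimension 2
```

If that is what is happening, either the position embeddings need to be padded alongside the hidden states, or the packed length needs to be rounded to a multiple of the TP size before sharding.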
Error Log

```
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1829, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Automodel-workspace/Automodel/nemo_automodel/_transformers/auto_model.py", line 91, in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/utils/generic.py", line 918, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 449, in forward
outputs: BaseModelOutputWithPast = self.model(
^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
outputs = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 384, in forward
hidden_states = decoder_layer(
^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1829, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 234, in forward
hidden_states, _ = self.self_attn(
^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
return self.checkpoint_fn( # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 503, in checkpoint
ret = function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 158, in forward
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 79, in apply_rotary_pos_emb
q_embed = (q * cos) + (rotate_half(q) * sin)
~~^~~~~
RuntimeError: The size of tensor a (17168) must match the size of tensor b (17167) at non-singleton dimension 2
```

Steps to reproduce
- Use this config: `test_assets/debug_sequence_packing/grpo_deepseek_r1_7B_16n8g.yaml`
- Set `policy.dtensor_cfg.sequence_parallel=true` (keep `policy.sequence_packing.enabled=true`)
- Run:

```bash
NRL_FORCE_REBUILD_VENVS=true uv run examples/run_grpo_math.py \
  --config test_assets/debug_sequence_packing/grpo_deepseek_r1_7B_16n8g.yaml \
  cluster.num_nodes=16 \
  cluster.gpus_per_node=8 \
  policy.dtensor_cfg.tensor_parallel_size=2 \
  checkpointing.enabled=false \
  logger.wandb_enabled=false
```
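If the padding hypothesis above holds, a guard at the packing stage would sidestep the mismatch by never producing a packed length that sequence parallel has to pad. A sketch, where `pad_packed_len` is a hypothetical helper and not an existing NeMo-RL function:

```python
# Hypothetical mitigation sketch: round every packed-sequence length up to a
# multiple of tensor_parallel_size up front, so sequence parallel never adds
# a pad position that the RoPE tables don't know about.
def pad_packed_len(packed_len: int, tp_size: int) -> int:
    """Return the smallest multiple of tp_size that is >= packed_len."""
    return ((packed_len + tp_size - 1) // tp_size) * tp_size

assert pad_packed_len(17167, 2) == 17168  # the two lengths from the error
assert pad_packed_len(17168, 2) == 17168  # already aligned: unchanged
```

Until something like that lands, `policy.dtensor_cfg.sequence_parallel=false` is the working combination, per the description above.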