Shape mismatch in RoPE embeddings gpt_neox model when rotary_ndims is odd #35233

Open
2 of 4 tasks
mseeger opened this issue Dec 12, 2024 · 11 comments · May be fixed by #35376

@mseeger

mseeger commented Dec 12, 2024

System Info

  • transformers version: 4.48.0.dev0
  • Platform: macOS-15.1.1-arm64-arm-64bit
  • Python version: 3.12.7
  • Huggingface_hub version: 0.26.5
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I just appended the following to https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py:

def reproduce_bug():
    # With the config below:
    #   head_size = hidden_size // num_attention_heads = 4
    #   rotary_ndims = int(head_size * rotary_pct) = 3
    config = GPTNeoXConfig(
        vocab_size=96,
        max_position_embeddings=32,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=8,
        intermediate_size=3 * 32,
        rotary_pct=0.75,
        use_parallel_residual=False,
    )
    model = GPTNeoXModel(config)
    input_ids = torch.randint(0, config.vocab_size, (1, config.max_position_embeddings))
    logits = model(input_ids)  # fails here; the return value is a model output object, not a tensor
    print(f"last_hidden_state.shape = {logits.last_hidden_state.shape}")


if __name__ == "__main__":
    reproduce_bug()

Then, I ran

python -m src.transformers.models.gpt_neox.modeling_gpt_neox src/transformers/models/gpt_neox/modeling_gpt_neox.py

This gives me the following error output:

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/seeger/git/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 1510, in <module>
    reproduce_bug()
  File "/Users/seeger/git/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 1505, in reproduce_bug
    logits = model(input_ids)
             ^^^^^^^^^^^^^^^^
  File "/Users/seeger/venvs/transformers_venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seeger/venvs/transformers_venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seeger/git/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 910, in forward
    outputs = layer(
              ^^^^^^
  File "/Users/seeger/venvs/transformers_venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seeger/venvs/transformers_venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seeger/git/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 657, in forward
    attention_layer_outputs = self.attention(
                              ^^^^^^^^^^^^^^^
  File "/Users/seeger/venvs/transformers_venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seeger/venvs/transformers_venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seeger/git/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 319, in forward
    query, key, value, present = self._attn_projections_and_rope(
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seeger/git/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 431, in _attn_projections_and_rope
    query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/seeger/git/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 607, in apply_rotary_pos_emb
    q_embed = (q * cos) + (rotate_half(q) * sin)
               ~~^~~~~
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 3

This is what I expected: the code does not work if rotary_ndims is odd, and here it is 3. The way cos and sin are computed gives them a final dimension of size 2 * ceil(rotary_ndims / 2) == rotary_ndims + 1, which is 1 too large.

Note that the code actually "works" if rotary_ndims = 1. In that case cos and sin have final dimension size 2, and the line above goes through thanks to broadcasting (size 1 is broadcast to 2). Both q and k then end up with a final dimension that is 1 too large, but nothing fails. Once rotary_ndims is odd and larger than 1, it fails.
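To make the size argument concrete, here is a minimal sketch of the arithmetic only. It assumes the frequencies are built from a stride-2 range over rotary_ndims and then duplicated by concatenation, which is what the formula 2 * ceil(rotary_ndims / 2) above implies. The helper name cos_sin_last_dim is made up for this illustration and is not part of transformers.

```python
import math


def cos_sin_last_dim(rotary_ndims: int) -> int:
    """Final dim of cos/sin, assuming one frequency per pair of dims, duplicated."""
    # Stand-in for arange(0, rotary_ndims, 2): ceil(rotary_ndims / 2) entries.
    num_freqs = len(range(0, rotary_ndims, 2))
    # The frequencies are concatenated with themselves, doubling the size.
    return 2 * num_freqs


for n in (1, 2, 3, 4, 21):
    size = cos_sin_last_dim(n)
    # The closed form from the report should agree with the count above.
    assert size == 2 * math.ceil(n / 2)
    print(f"rotary_ndims={n}: cos/sin last dim={size}, excess={size - n}")
```

For even rotary_ndims the excess is 0; for odd values it is 1, which is the mismatch in the traceback (3 vs. 4).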

Expected behavior

Without this bug, cos and sin should have size rotary_ndims in the final dimension, whether that is even or odd. My suggestions:

  • Restrict rotary_ndims to be even, or
  • Subselect cos, sin so their final dim size is rotary_ndims
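A rough sketch of the second suggestion, written as standalone PyTorch rather than as a patch to transformers. The function names rope_cos_sin and apply_rope, the base of 10000, and the tensor layout are assumptions for this illustration. It only shows that the shapes line up once cos and sin are truncated; whether the resulting transform is the right one for odd rotary_ndims is a separate question that this sketch does not settle.

```python
import torch


def rope_cos_sin(seq_len: int, rotary_ndims: int, base: float = 10000.0):
    """Build cos/sin tables whose last dim is exactly rotary_ndims (hypothetical helper)."""
    # One inverse frequency per pair of dims: ceil(rotary_ndims / 2) entries.
    exponents = torch.arange(0, rotary_ndims, 2, dtype=torch.float32) / rotary_ndims
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)  # (seq_len, ceil(rotary_ndims / 2))
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, 2 * ceil(rotary_ndims / 2))
    emb = emb[..., :rotary_ndims]             # drop the extra column when rotary_ndims is odd
    return emb.cos(), emb.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into two parts, then swap them with a sign flip.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return (q * cos) + (rotate_half(q) * sin)


# Shapes from the reproduction: 8 heads, 32 positions, rotary_ndims = 3.
q = torch.randn(1, 8, 32, 3)  # (batch, heads, seq_len, rotary_ndims)
cos, sin = rope_cos_sin(seq_len=32, rotary_ndims=3)
q_embed = apply_rope(q, cos, sin)
print(cos.shape, q_embed.shape)
```

For even rotary_ndims the slice is a no-op, so this kind of change should not alter outputs for configurations that already work.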

My feeling is that this affects not only this single model but many others as well, though I did not check.

@mseeger mseeger added the bug label Dec 12, 2024
@Rocketknight1
Member

Hi @mseeger, thanks for the bug report!

First question: Does this affect any of the major gpt-neox checkpoints on the Hub, or do they all have rotary_pct values that don't trigger the issue?

Secondly: Would you be willing to make a PR to fix this? I think either of the two solutions you suggested is viable, as long as it doesn't change the output for existing models that aren't broken.

@mseeger
Author

mseeger commented Dec 12, 2024

Sure, I can do this. I could also try to find out whether any of the other models have the same issue.

I'd be surprised if any models on the hub surface this issue, since otherwise their creators would have noticed, no? But I can do a quick check.

I'd prefer the second option. Unless I am missing something, cos and sin are derived tensors, so they would never be part of a checkpoint. Also, even if the shapes of q and k change, this only happens in the dimension over which their inner products are taken, so it would not affect any other sizes.

@mseeger
Author

mseeger commented Dec 12, 2024

@Rocketknight1 . I found a model where the issue would likely arise:

https://huggingface.co/Isotonic/gpt_neox_225M/blob/main/config.json

A pretty odd one: hidden_size=1024, num_attention_heads=12, rotary_pct=0.25, which gives head_size=85 and rotary_ndims=21.
Note that 12 * 85 = 1020, not 1024. Does this even work?
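The numbers above can be checked with a few lines of plain Python. The two formulas are the ones from the comments in the reproduction script; the helper name rope_dims is made up for this sketch.

```python
def rope_dims(hidden_size: int, num_attention_heads: int, rotary_pct: float):
    """Return (head_size, rotary_ndims) using the formulas from the reproduction."""
    head_size = hidden_size // num_attention_heads
    rotary_ndims = int(head_size * rotary_pct)
    return head_size, rotary_ndims


# (hidden_size, num_attention_heads, rotary_pct) for the two configs discussed here.
configs = {
    "reproduction": (32, 8, 0.75),
    "Isotonic/gpt_neox_225M": (1024, 12, 0.25),
}

for name, (hidden_size, num_heads, rotary_pct) in configs.items():
    head_size, rotary_ndims = rope_dims(hidden_size, num_heads, rotary_pct)
    covered = num_heads * head_size  # equals hidden_size only if the division is exact
    print(
        f"{name}: head_size={head_size}, rotary_ndims={rotary_ndims}, "
        f"odd={rotary_ndims % 2 == 1}, heads*head_size={covered} vs hidden_size={hidden_size}"
    )
```

Both configs end up with an odd rotary_ndims, and the second one also has a hidden_size that the number of heads does not divide evenly.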

@mseeger
Author

mseeger commented Dec 12, 2024

https://huggingface.co/mkshing/gpt-neox-285m-init/blob/main/config.json

Another one of the same size.
Somebody who knows better than me should write a script to run over all models on the hub.

@Rocketknight1
Member

Hmm, yeah - their num_attention_heads doesn't seem to evenly divide their hidden_size, so I think there are other issues in those checkpoints.

@Rocketknight1
Member

Anyway, for now, I think you can just make the PR, and we can test it with models with more 'normal' inits!

@mseeger
Author

mseeger commented Dec 13, 2024

The HF approach of massively copying and pasting code between different models makes it really painful to change anything. I know this is something of a trademark, but the risk is quite high that a fix like this one misses some models. The copying and pasting even happens within a single model.

@mseeger
Author

mseeger commented Dec 16, 2024

This is taking a little longer than I thought. Almost 40 models are involved. I am trying to write new tests, which takes the most time, because frankly many of these implementations just do different things. I suppose the copy-and-paste approach lowers the barrier to entry, but it makes changes and fixes such as this one painful.

BTW: I also found at least 2-3 bugs affecting some of the models. Shall I send a separate PR for fixing them? @Rocketknight1

@Rocketknight1
Member

Woah, okay - are all 40 models copying the same code with Copied from?

@mseeger
Author

mseeger commented Dec 16, 2024

That would be nice, but there are different implementations. The best one is that of gpt_neox, which is used quite a bit: the position_embeddings are computed once and passed to the layers. Another version keeps copies of them in the attention classes. And then there are odd ones.

I'll not refactor anything, because I don't want to change anything; I just fix bugs. But it may be a good exercise to unify RoPE across all models that use it, so that it really is just copy&paste.

mseeger added a commit to mseeger/transformers that referenced this issue Dec 20, 2024
First part of resolution of huggingface#35233
- Changes related to `position_embeddings` being a mandatory argument
- Remove `position_ids` argument of `apply_rotary_pos_emb`
- Replace `torch.stack` by `torch.cat`, former requires equal shapes
- `esm`: RoPE depends on `position_ids`, which was ignored.
- `gpt_neox`: Selection of attention compute type via class removed
- `gptj`: RoPE must be applied per head, and some shape issues.
- `nemotron`: `config.partial_rotary_factor` was not implemented.
@mseeger mseeger linked a pull request Dec 20, 2024 that will close this issue
5 tasks
@mseeger
Author

mseeger commented Dec 21, 2024

#35376 is the first of 2 PRs fixing this issue. I split it into two for easier reviewing.

mseeger added a commit to mseeger/transformers that referenced this issue Dec 21, 2024
First part of resolution of huggingface#35233
- Changes related to `position_embeddings` being a mandatory argument
- Remove `position_ids` argument of `apply_rotary_pos_emb`
- Replace `torch.stack` by `torch.cat`, former requires equal shapes
- `esm`: RoPE depends on `position_ids`, which was ignored.
- `gpt_neox`: Selection of attention compute type via class removed
- `gptj`: RoPE must be applied per head, and some shape issues.
- `nemotron`: `config.partial_rotary_factor` was not implemented.
mseeger added a commit to mseeger/transformers that referenced this issue Dec 21, 2024
First part of resolution of huggingface#35233
- Changes related to `position_embeddings` being a mandatory argument
- Remove `position_ids` argument of `apply_rotary_pos_emb`
- Replace `torch.stack` by `torch.cat`, former requires equal shapes
- `esm`: RoPE depends on `position_ids`, which was ignored.
- `gpt_neox`: Selection of attention compute type via class removed
- `gptj`, `codegen`: RoPE must be applied per head, and some shape issues.
- `nemotron`: `config.partial_rotary_factor` was not implemented.
mseeger added a commit to mseeger/transformers that referenced this issue Dec 22, 2024
First part of resolution of huggingface#35233
- Changes related to `position_embeddings` being a mandatory argument
- Remove `position_ids` argument of `apply_rotary_pos_emb`
- Replace `torch.stack` by `torch.cat`, former requires equal shapes
- `esm`: RoPE depends on `position_ids`, which was ignored.
- `gpt_neox`: Selection of attention compute type via class removed
- `gptj`, `codegen`: RoPE must be applied per head, and some shape issues.
- `nemotron`: `config.partial_rotary_factor` was not implemented.

2 participants