Exception: The provided qkv memory layout is not supported! #544

Closed · carmocca opened this issue Nov 29, 2023 · 3 comments
carmocca commented Nov 29, 2023

Hi!

This piece of code used to work with a previous TransformerEngine version: https://github.com/Lightning-AI/lit-gpt/blob/d26ae7adb39918d0e821d2daf5089f86b752c004/lit_gpt/model_transformer_engine.py#L215-L219

I am working on upgrading it to the latest changes, but I am running into:

Traceback (most recent call last):
  File "/workspace/lit-gpt/compare.py", line 22, in <module>
    actual = te_model(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/lit-gpt/model_te.py", line 62, in forward
    x = block(x, cos, sin)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/lit-gpt/model_te.py", line 132, in forward
    y = self.attn(q, k, v)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 2085, in forward
    qkv_layout = _get_qkv_layout(query_layer, key_layer, value_layer,
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 1195, in _get_qkv_layout
    raise Exception("The provided qkv memory layout is not supported!")
Exception: The provided qkv memory layout is not supported!

Since DotProductAttention uses qkv_format="sbhd" by default, I also tried

        # TE attn requires (T, B, nh, hs)
        q = q.permute(2, 0, 1, 3)
        k = k.permute(2, 0, 1, 3)
        v = v.permute(2, 0, 1, 3)
        y = self.attn(q, k, v)
        y = y.transpose(0, 1)

That attempt was also unsuccessful. Perhaps _get_qkv_layout could be more verbose about what exactly is not correct.

For reference, these are the shapes of q, k, and v after RoPE and before any attention permutes.

# B, n_head, T, head_size
q.shape=torch.Size([2, 8, 2048, 64]) 
k.shape=torch.Size([2, 8, 2048, 64])
v.shape=torch.Size([2, 8, 2048, 64])
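
For what it's worth, the permute alone leaves the tensors non-contiguous, which may be what the layout check is rejecting (an assumption about what _get_qkv_layout inspects; a minimal sketch):

import torch

# Shapes from above: (B, n_head, T, head_size)
B, nh, T, hs = 2, 8, 2048, 64
q = torch.randn(B, nh, T, hs)

# Permuting to (T, B, nh, hs) gives the sbhd shape, but the underlying
# memory layout is still (B, nh, T, hs), so the view is not contiguous.
q_sbhd = q.permute(2, 0, 1, 3)
print(q_sbhd.shape)            # torch.Size([2048, 2, 8, 64])
print(q_sbhd.is_contiguous())  # False
print(q_sbhd.contiguous().is_contiguous())  # True after an explicit copy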

My code can be viewed at Lightning-AI/litgpt@main...carmocca/transformer-engine

cyanguwa (Collaborator) commented:
Hi @carmocca, does your TE installation have these lines of code? They should have made q/k/v contiguous and, in your case, detected the layout as sbhd_sbhd_sbhd.

if qkv_layout == 'not_supported':
    # force q, k, v to be contiguous and run get_layout again
    q, k, v = [x.contiguous() for x in [q, k, v]]
    qkv_layout = run_iteratively(q, k, v)
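
If the installed TE predates that fallback, one possible caller-side workaround is to force contiguity after the permute. This is only a sketch, assuming self.attn is DotProductAttention with the default qkv_format="sbhd", and has not been verified against this model:

# Hypothetical change to the model's forward; not from this repo.
q = q.permute(2, 0, 1, 3).contiguous()  # (B, nh, T, hs) -> (T, B, nh, hs)
k = k.permute(2, 0, 1, 3).contiguous()
v = v.permute(2, 0, 1, 3).contiguous()
y = self.attn(q, k, v)
y = y.transpose(0, 1)  # back to batch-first, as in the original snippet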

I didn't use your exact code, but I put together this example, and it seems to give the right results.

import torch
# b, s, d
B, T, C = (2, 2048, 64)
h = 8
q_per_kv = 1

# (b, s, h, 3, d) -> permute -> (b, h, 3, s, d)
qkv = torch.randn(B, T, h, q_per_kv + 2, C).permute(0, 2, 3, 1, 4)

q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
q = q.reshape(B, -1, T, C)  # (B, nh_q, T, hs)
k = k.view(B, -1, T, C)  # (B, nh_k, T, hs)
v = v.view(B, -1, T, C)  # (B, nh_v, T, hs)
print("before rope: ", q.shape, k.shape, v.shape)

q = q.permute(2, 0, 1, 3)
k = k.permute(2, 0, 1, 3)
v = v.permute(2, 0, 1, 3)
print("after permutes: ",q.shape, k.shape, v.shape)

from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
    _get_qkv_layout,
)
PE = RotaryPositionEmbedding(dim=C)
rotary_pos_emb = PE(T)
q_pos_emb = rotary_pos_emb
k_pos_emb = rotary_pos_emb
q = apply_rotary_pos_emb(q, q_pos_emb)
k = apply_rotary_pos_emb(k, k_pos_emb)
print("after rope: ", q.shape, k.shape, v.shape)

layout, *_ = _get_qkv_layout(q, k, v)
print("layout: ",layout)

Results:

before rope:  torch.Size([2, 8, 2048, 64]) torch.Size([2, 8, 2048, 64]) torch.Size([2, 8, 2048, 64])
after permutes:  torch.Size([2048, 2, 8, 64]) torch.Size([2048, 2, 8, 64]) torch.Size([2048, 2, 8, 64])
after rope:  torch.Size([2048, 2, 8, 64]) torch.Size([2048, 2, 8, 64]) torch.Size([2048, 2, 8, 64])
layout:  sbhd_sbhd_sbhd
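
Another option that avoids the permutes in the model is to construct DotProductAttention with qkv_format="bshd" (the argument mentioned above) and pass batch-first tensors. This is a sketch under the assumption that this TE version accepts "bshd" and that the constructor argument names below match the installed API; a GPU and bf16 are assumed so the fused/flash backends can run:

import torch
import transformer_engine.pytorch as te

B, nh, T, hs = 2, 8, 2048, 64

# Assumed constructor arguments; "bshd" replaces the default "sbhd".
attn = te.DotProductAttention(num_attention_heads=nh, kv_channels=hs,
                              qkv_format="bshd")

q, k, v = [torch.randn(B, nh, T, hs, device="cuda", dtype=torch.bfloat16)
           for _ in range(3)]

# "bshd" expects (batch, seq, heads, head_dim); transpose + contiguous
# keeps a memory layout that _get_qkv_layout can classify.
q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)]
y = attn(q, k, v)
print(y.shape)  # expected: (B, T, nh * hs)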

carmocca commented Nov 30, 2023

Unfortunately, I just lost access to an H100 again, so I won't be able to check whether the main branch fixes it. I'll close the issue since you believe it's fixed there.
