Exception: The provided qkv memory layout is not supported! #544
Comments
Hi @carmocca, does your TE installation have these lines of code? They should have made q/k/v contiguous and, in your case, detected the layout as sbhd_sbhd_sbhd: TransformerEngine/transformer_engine/pytorch/attention.py, lines 1204 to 1207 at commit 753eed3.
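For readers following along, here is a rough sketch of the idea behind those lines (this is not the actual TransformerEngine code at that commit, and the helper name is made up): make q/k/v contiguous before inferring the memory layout, so that three separately stored sbhd tensors are reported as sbhd_sbhd_sbhd.

```python
import torch

def infer_qkv_layout_sketch(q, k, v, qkv_format="sbhd"):
    """Hypothetical stand-in for TE's _get_qkv_layout, for illustration only."""
    # Make the three tensors contiguous first, so stride-based layout
    # detection does not reject otherwise-valid inputs (e.g. permuted views).
    if not (q.is_contiguous() and k.is_contiguous() and v.is_contiguous()):
        q, k, v = [x.contiguous() for x in (q, k, v)]
    # Three independently stored contiguous tensors in the same format map to
    # "<format>_<format>_<format>", e.g. "sbhd_sbhd_sbhd".
    return q, k, v, "_".join([qkv_format] * 3)

# Permuted (non-contiguous) tensors that should still be accepted:
s, b, h, d = 8, 2, 4, 16
q, k, v = (torch.randn(b, h, s, d).permute(2, 0, 1, 3) for _ in range(3))
print(infer_qkv_layout_sketch(q, k, v)[-1])  # -> sbhd_sbhd_sbhd
```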
I didn't use exactly your code, but I made this example, and it seems to give me the right results.
Results:
I was using the stable branch as suggested in the README: https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/pytorch/attention.py#L1195
Unfortunately, I just lost access to an H100 again, so I won't be able to check whether the main branch fixes it. I'll close the issue since you believe it's fixed there.
Hi!
This piece of code used to work with a previous TransformerEngine version: https://github.com/Lightning-AI/lit-gpt/blob/d26ae7adb39918d0e821d2daf5089f86b752c004/lit_gpt/model_transformer_engine.py#L215-L219
I am working on upgrading it to the latest changes, but I am running into: Exception: The provided qkv memory layout is not supported!
Since DotProductAttention uses qkv_format="sbhd" by default, I also tried that, unsuccessfully. Perhaps _get_qkv_layout could be more verbose about what is not correct.
For reference, this is the shape of q, k, v after RoPE and before any attention permutations.
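To make the failure mode concrete, here is a minimal, hypothetical repro sketch (not the litgpt code): the DotProductAttention constructor arguments and the permute-plus-.contiguous() workaround are assumptions based on this thread, and running it requires a GPU supported by TransformerEngine.

```python
import torch
import transformer_engine.pytorch as te

b, t, n_head, head_dim = 1, 2048, 16, 64

# Assumed constructor arguments (num_attention_heads, kv_channels); the
# default qkv_format is "sbhd", as noted above.
attn = te.DotProductAttention(n_head, head_dim)

# Tensors after RoPE, here in a (batch, heads, seq, head_dim) layout.
q, k, v = (
    torch.randn(b, n_head, t, head_dim, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

# Permuting to (seq, batch, heads, head_dim) alone produces non-contiguous
# views, which the stable-branch layout check rejects with
# "The provided qkv memory layout is not supported!". Calling .contiguous()
# on each tensor is the workaround implied by the lines quoted above.
q, k, v = (x.permute(2, 0, 1, 3).contiguous() for x in (q, k, v))

out = attn(q, k, v)
print(out.shape)
```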
My code can be viewed at Lightning-AI/litgpt@main...carmocca/transformer-engine