[PyTorch] TransformerLayer: add support for Falcon architecture #513

Merged: 7 commits merged into NVIDIA:main from the add-parallel-attention-mlp branch on Dec 4, 2023

Conversation

@Marks101 (Contributor) commented Nov 10, 2023

Falcon-40b and Falcon-180b are two exciting publicly available models. Currently, transformer-engine does not support their architecture because, in their implementation, self-attention and the MLP are not computed in sequence. Instead, both blocks (layer norm -> self-attention) and (layer norm -> MLP) are fed the input of the layer, so in the computational graph these operations run in parallel. In the Falcon configs this is denoted as new_decoder_architecture. This PR introduces this feature and thus makes it possible to fine-tune Falcon models with transformer-engine. We would be really happy if this feature finds its way into transformer-engine.
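For illustration, here is a minimal sketch of the difference between the standard sequential block and the Falcon-style parallel block; the module and helper names below are made up for this example and are not the transformer-engine implementation:

```python
import torch
import torch.nn as nn

class ParallelAttentionMLPBlock(nn.Module):
    """Illustrative Falcon-style block: attention and MLP branches share the layer input."""

    def __init__(self, hidden_size, ffn_hidden_size, num_heads):
        super().__init__()
        self.ln_attn = nn.LayerNorm(hidden_size)
        self.ln_mlp = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, ffn_hidden_size),
            nn.GELU(),
            nn.Linear(ffn_hidden_size, hidden_size),
        )

    def forward(self, x):
        # Sequential (GPT-style) block:
        #   x = x + attn(ln_attn(x));  x = x + mlp(ln_mlp(x))
        # Parallel (Falcon new_decoder_architecture): both branches consume the
        # same layer input and their outputs are summed onto one residual.
        a = self.ln_attn(x)
        attn_out, _ = self.attn(a, a, a, need_weights=False)
        mlp_out = self.mlp(self.ln_mlp(x))
        return x + attn_out + mlp_out
```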

A few notes on the implementation:

  • I named the new option parallel_attention_mlp. I am not sure this is the perfect name; it is up for discussion.
  • I created a new method _bias_dropout_add() to keep the new code clear (see the sketch after this list).
  • Falcon models do not have a bias, so I think a pretty clean solution is to set return_bias=False for attention and MLP and then use bias-dropout-add. For models using parallel_attention_mlp together with a bias, this might not be optimal.
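As a rough sketch of the helper idea from the second bullet (not the code in this PR), _bias_dropout_add could fuse an optional bias add, dropout, and the residual add; with return_bias=False the sub-modules handle their own bias (and Falcon has none anyway), so the helper is simply called with bias=None:

```python
import torch.nn.functional as F

def _bias_dropout_add(x, bias, residual, p, training):
    # Hypothetical sketch of the helper described above: add an optional bias,
    # apply dropout, and add the residual stream. Falcon-style models have no
    # biases and use return_bias=False, so they call this with bias=None.
    if bias is not None:
        x = x + bias
    return residual + F.dropout(x, p=p, training=training)
```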

@timmoon10 self-requested a review November 14, 2023 05:49

@timmoon10 (Collaborator) commented:

/te-ci

@timmoon10 (Collaborator) left a review:

LGTM

@Marks101 (Contributor, Author) commented Nov 14, 2023

Thanks for proceeding so quickly with this PR 🥳
Unfortunately, the unit tests for JAX and Paddle failed. I am not sure how they could be influenced by my changes.

@ptrendx (Member) commented Nov 20, 2023

@Marks101 Could you add a test to test_numerics, similar to https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/test_numerics.py#L622? Otherwise LGTM :-). The other frameworks' tests seem to have failed because of machine issues, so they are not related to this PR.
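For illustration, such a test could look roughly like the following sketch. It assumes the new parallel_attention_mlp keyword on transformer_engine.pytorch.TransformerLayer and uses made-up sizes; the actual test added to test_numerics.py follows the existing comparison pattern linked above:

```python
import pytest
import torch
import transformer_engine.pytorch as te

@pytest.mark.parametrize("parallel_attention_mlp", [False, True])
def test_transformer_layer_parallel_attention_mlp(parallel_attention_mlp):
    # Simple smoke test with illustrative hyperparameters: build the layer with
    # and without the new option and check that the forward pass runs.
    hidden_size, ffn_hidden_size, num_heads = 128, 512, 4
    seq_len, batch_size = 32, 2
    layer = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_heads,
        parallel_attention_mlp=parallel_attention_mlp,  # option introduced in this PR
    ).cuda()
    x = torch.randn(seq_len, batch_size, hidden_size, device="cuda")
    out = layer(x)
    assert out.shape == x.shape
```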

@Marks101 force-pushed the add-parallel-attention-mlp branch from af7d81a to b7c908b on November 22, 2023 12:32
Now uses nn.functional.dropout because, depending on the path, there are one or two dropouts.

Signed-off-by: Markus Schnoes <[email protected]>
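In other words (again only an illustrative sketch, not the PR's exact code), the serial path performs a single dropout-add here, while the parallel path needs two independent dropout applications when the branch outputs are combined, which is why the stateless nn.functional.dropout is more convenient than holding nn.Dropout modules:

```python
import torch.nn.functional as F

def _combine_branch_outputs(attn_out, mlp_out, residual, p, training, parallel_attention_mlp):
    # Hypothetical helper name; biases are omitted since Falcon has none.
    if parallel_attention_mlp:
        # Parallel path: one dropout per branch, both added to the same residual.
        out = residual + F.dropout(attn_out, p=p, training=training)
        return out + F.dropout(mlp_out, p=p, training=training)
    # Serial path: the attention output was already folded into `residual`
    # earlier, so only the MLP output gets a dropout-add here.
    return residual + F.dropout(mlp_out, p=p, training=training)
```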
5 review comments on tests/pytorch/test_numerics.py (outdated, resolved)
@timmoon10 (Collaborator) commented:

/te-ci pytorch

@Marks101 (Contributor, Author) commented:

Thanks for fixing my spelling mistakes ... sorry for that. The tests failed because there was one last occurrence of parallel_attention_ml. I fixed that and ran the tests locally. Now it should be fine.

@timmoon10 (Collaborator) commented:

/te-ci pytorch

1 similar comment
@ptrendx (Member) commented Dec 4, 2023

/te-ci pytorch

@ptrendx (Member) commented Dec 4, 2023

Tim's attempt to run the CI apparently failed due to a network issue; I just retried it.

@ptrendx added the 1.2.0 label on Dec 4, 2023
@ptrendx merged commit 4e33a69 into NVIDIA:main on Dec 4, 2023
20 checks passed
@ptrendx (Member) commented Dec 4, 2023

Merged. Thank you @Marks101 for the contribution!

@Marks101 deleted the add-parallel-attention-mlp branch on December 5, 2023 07:01
@Marks101 (Contributor, Author) commented Dec 5, 2023

Great, thank you for the support!
