
[PyTorch] Experimental FP8 tensor class #452

Merged: 29 commits merged into NVIDIA:main on Oct 31, 2023

Conversation

timmoon10 (Collaborator) commented:

This FP8 tensor class is based on the implementation at https://github.com/facebookexperimental/protoquant/tree/fp8_poc and is primarily oriented toward enabling efficient FP8 support in Apex's DistributedFusedAdam. See NVIDIA/NeMo#7469 and NVIDIA/NeMo#7565.

CC @sudhakarsingh27 @ksivaman
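
For context, the pattern behind such a tensor class is a `torch.Tensor` wrapper subclass that stores raw quantized bytes plus a cached inverse scale and dequantizes inside `__torch_dispatch__`. Below is a minimal illustrative sketch of that pattern; `MyFP8Tensor` and all of its members are hypothetical stand-ins, not this PR's actual API (the real class also handles transposes, casts, and FP8 metadata objects):

```python
# Minimal sketch of the wrapper-subclass pattern, using uint8 storage as
# a stand-in for real FP8 bits. Hypothetical; not the PR's actual API.
import torch
from torch.utils._pytree import tree_map

class MyFP8Tensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, scale_inv: torch.Tensor):
        # Advertise a high-precision dtype; the quantized bits live in _data.
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=torch.float32, device=data.device
        )
        self._data = data
        # Cache a copy of scale_inv so later changes to the producer's
        # FP8 metadata do not retroactively change this tensor's value.
        self._scale_inv = scale_inv.detach().clone()
        return self

    def dequantize(self) -> torch.Tensor:
        return self._data.to(torch.float32) * self._scale_inv

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Dequantize any MyFP8Tensor arguments, then run the op as usual.
        unwrap = lambda t: t.dequantize() if isinstance(t, MyFP8Tensor) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
```

With this, `MyFP8Tensor(torch.randint(0, 255, (4,), dtype=torch.uint8), torch.tensor(0.01)) + 1.0` routes through `__torch_dispatch__` and returns an ordinary float32 tensor.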

timmoon10 added the enhancement (New feature or request) label on Sep 29, 2023
ptrendx (Member) commented on Sep 30, 2023:

/te-ci

1 similar comment
timmoon10 (Collaborator, Author) commented:

/te-ci

ptrendx added the 1.0.0 label on Oct 16, 2023
sudhakarsingh27 (Collaborator) commented:

/te-ci pytorch

ksivaman force-pushed the float8tensor_experiments branch from de20156 to 4315115 on October 16, 2023
ksivaman force-pushed the float8tensor_experiments branch from 67f7cd3 to b6bfddb on October 19, 2023
ksivaman marked this pull request as ready for review on October 20, 2023
ksivaman (Member) commented:

/te-ci pytorch

ksivaman (Member) commented:

/te-ci pytorch

1 similar comment
ksivaman (Member) commented:

/te-ci pytorch

Review thread on the Float8Tensor docstring:

    handled outside this class. If a tensor is initialized with an FP8
    metadata object, it extracts the information it needs so it isn't
    affected by later changes in the FP8 metadata (although its design
    does cause us to leak some subtle side-effects into FP8 metadata).

ptrendx (Member): This doc is not really correct since we are holding a view to the meta, right?

timmoon10 (Collaborator, Author):
Ops using the tensor class's __torch_dispatch__ are insensitive to external changes in the meta since we cache scale_inv. However, all bets are off when we extract _data and pass it to external ops like tex.fp8_gemm.
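
To make that distinction concrete, here is what the behavior looks like with the hypothetical `MyFP8Tensor` sketch from earlier (illustrative only; the external GEMM itself is not reproduced):

```python
import torch

data = torch.randint(0, 255, (4,), dtype=torch.uint8)
meta_scale_inv = torch.tensor(0.01)
t = MyFP8Tensor(data, meta_scale_inv)  # scale_inv copied at construction

meta_scale_inv.fill_(123.0)  # external update to the FP8 metadata
y = t * 2.0  # __torch_dispatch__ path: still uses the cached 0.01

raw = t._data  # raw bytes handed to an external kernel (e.g. a GEMM)
# An external op that reads scale_inv from the live metadata now sees
# 123.0, disagreeing with the cached 0.01 -- "all bets are off".
```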

Resolved review threads: transformer_engine/pytorch/fp8.py, transformer_engine/pytorch/module/base.py.
Resolved review threads: transformer_engine/pytorch/float8_tensor.py.
sudhakarsingh27 force-pushed the float8tensor_experiments branch from 8ff9e05 to dfcbcf1 on October 24, 2023
timmoon10 and others added 8 commits on October 25, 2023
timmoon10 and others added 2 commits on October 26, 2023
ksivaman and others added 4 commits on October 27, 2023
timmoon10 force-pushed the float8tensor_experiments branch from 718d284 to 94848da on October 31, 2023
ptrendx (Member) left a review:
Approving as experimental. We will iterate upon this in the next release.

ksivaman merged commit b1820c4 into NVIDIA:main on Oct 31, 2023
9 checks passed
ksivaman added a commit that referenced this pull request Oct 31, 2023
* Experimental FP8 tensor

Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Sudhakar Singh <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add fp8 tensor to ci test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* review comments and tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Minor changes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Default to FP8 usage

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix docs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Naming changes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* minor fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix transpose caching

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Debug transpose caching

Handle case where transpose cache is updated externally.

Signed-off-by: Tim Moon <[email protected]>

* Rename FP8GlobalStateManager.with_fp8_parameters

Signed-off-by: Tim Moon <[email protected]>

* remove Float8Tensor from import API

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Avoid caching FP8 transposes if not required

Signed-off-by: Tim Moon <[email protected]>

* Fix import error in FP8 tensor tests

Signed-off-by: Tim Moon <[email protected]>

* Fix transpose caching and checkpointing bug

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Improve caching and fix distopt case

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Update transformer_engine/pytorch/float8_tensor.py

Signed-off-by: Tim Moon <[email protected]>

* Remove recursive logic

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix cache reset bug

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Store FP8 attributes in dict

Easier for multiple tensors to share, e.g. detached tensors.

Signed-off-by: Tim Moon <[email protected]>

* Make sure scale_inv is 1D tensor

Signed-off-by: Tim Moon <[email protected]>

* Make sure scale_inv is 1D tensor

Signed-off-by: Tim Moon <[email protected]>

* Fixes and detach recipe

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Set default fp8 data type

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Sudhakar Singh <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
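
The "Store FP8 attributes in dict" commit above ("Easier for multiple tensors to share, e.g. detached tensors") points at a simple pattern worth spelling out: if all FP8 metadata lives in one dict held by reference, derived tensors such as detached views automatically stay in sync. A hypothetical sketch under that assumption (names are illustrative, not the PR's API):

```python
import torch

class SharedMetaTensor(torch.Tensor):
    """Wrapper subclass whose FP8 attributes live in one shared dict."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, fp8_attrs: dict):
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=torch.float32, device=data.device
        )
        self._data = data
        self._fp8_attrs = fp8_attrs  # shared by reference, never copied
        return self  # (__torch_dispatch__ handling omitted for brevity)

    def detach(self):
        # The detached tensor reuses the same attrs dict, so an update to
        # e.g. fp8_attrs["scale_inv"] is seen by every derived tensor.
        return SharedMetaTensor(self._data, self._fp8_attrs)

attrs = {"scale_inv": torch.tensor(0.01), "fp8_dtype": "E4M3"}
a = SharedMetaTensor(torch.zeros(4, dtype=torch.uint8), attrs)
b = a.detach()
assert a._fp8_attrs is b._fp8_attrs  # one dict, many tensors
```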
mingxu1067 pushed a commit to mingxu1067/TransformerEngine that referenced this pull request Nov 3, 2023, carrying the same commit message as above with an additional trailer:

Signed-off-by: Ming Huang <[email protected]>
cyanguwa pushed a commit to cyanguwa/TransformerEngine that referenced this pull request Nov 13, 2023, carrying the same commit message as above with an additional trailer:

Signed-off-by: Charlene Yang <[email protected]>
Review thread on lines +507 to +508:
    * full model training using optimizer with master weights, where the high
      precision copies of weights are already present in the optimizer.
A contributor asked:

How does this look in practice? If the model will be initialized directly with fp8 weights, how does the optimizer get high-precision copies?
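
One plausible answer, sketched below under stated assumptions: the optimizer dequantizes the FP8 parameters once at construction, keeps those FP32 tensors as its master weights, and on each step updates the masters and re-quantizes back into the FP8 params. This assumes the FP8 tensor class implements `float()` as a dequantizing cast and `copy_()` as a quantizing cast; `MasterWeightSGD` is hypothetical, not DistributedFusedAdam's actual code:

```python
import torch

class MasterWeightSGD:
    """Hypothetical optimizer holding FP32 master copies of FP8 params."""

    def __init__(self, params, lr: float = 1e-3):
        self.lr = lr
        self.params = list(params)
        # Dequantize once at construction; from here on, the only
        # high-precision weights in the job live inside the optimizer.
        self.master = [p.float().detach().clone() for p in self.params]

    @torch.no_grad()
    def step(self):
        for p, m in zip(self.params, self.master):
            if p.grad is None:
                continue
            m -= self.lr * p.grad.float()  # update in high precision
            p.copy_(m)  # assumed to quantize back into the FP8 param
```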

Labels: 1.0.0, enhancement (New feature or request)
5 participants