[PyTorch] Experimental FP8 tensor class #452
Conversation
/te-ci
/te-ci
/te-ci pytorch
Force-pushed from de20156 to 4315115
Co-authored-by: Tim Moon <[email protected]> Co-authored-by: Sudhakar Singh <[email protected]> Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Force-pushed from 67f7cd3 to b6bfddb
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
handled outside this class. If a tensor is initialized with an FP8 metadata object, it extracts the information it needs so it isn't affected by later changes in the FP8 metadata (although its design does cause us to leak some subtle side-effects into FP8 metadata).
This doc is not really correct since we are holding a view to the meta, right?
Ops using the tensor class's __torch_dispatch__ are insensitive to external changes in the meta since we cache scale_inv. However, all bets are off when we extract _data and pass it to external ops like tex.fp8_gemm.
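Not from the PR itself, but a minimal sketch of the behavior described above (the class name, the uint8 payload, and the dequantize-on-dispatch logic are illustrative assumptions, not the PR's actual Float8Tensor):

```python
import torch
from torch.utils._pytree import tree_map

class CachingFP8Tensor(torch.Tensor):
    """Illustrative wrapper tensor that snapshots scale_inv at construction."""

    @staticmethod
    def __new__(cls, data, scale_inv):
        # Standard wrapper-subclass idiom: advertise the dequantized dtype.
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.size(), dtype=torch.float32, device=data.device
        )
        self._data = data                    # raw quantized payload
        self._scale_inv = scale_inv.clone()  # cached copy, not a view of the meta
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Dequantize with the *cached* scale_inv, so mutating the external
        # FP8 metadata after construction does not affect dispatched ops.
        def unwrap(x):
            return x._data.float() * x._scale_inv if isinstance(x, cls) else x
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

meta_scale_inv = torch.tensor([2.0])
t = CachingFP8Tensor(torch.ones(2, 2, dtype=torch.uint8), meta_scale_inv)
meta_scale_inv.fill_(100.0)  # external change to the metadata...
print((t + 0).unique())      # ...dispatched ops still see the cached 2.0
```

Anything that reads t._data directly and hands it to an external kernel (the tex.fp8_gemm case above) bypasses this cache entirely.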
Force-pushed from 8ff9e05 to dfcbcf1
Handle case where transpose cache is updated externally. Signed-off-by: Tim Moon <[email protected]>
Easier for multiple tensors to share, e.g. detached tensors. Signed-off-by: Tim Moon <[email protected]>
Force-pushed from 718d284 to 94848da
…10/TransformerEngine into float8tensor_experiments
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Approving as experimental. We will iterate upon this in the next release.
* Experimental FP8 tensor. Co-authored-by: Tim Moon <[email protected]>. Co-authored-by: Sudhakar Singh <[email protected]>. Co-authored-by: Przemyslaw Tredak <[email protected]>. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Add fp8 tensor to ci test. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Review comments and tests. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Minor changes. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Default to FP8 usage. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Fix docs. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Naming changes. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Minor fix. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Fix transpose caching. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Debug transpose caching: handle case where transpose cache is updated externally. Signed-off-by: Tim Moon <[email protected]>
* Rename FP8GlobalStateManager.with_fp8_parameters. Signed-off-by: Tim Moon <[email protected]>
* Remove Float8Tensor from import API. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Avoid caching FP8 transposes if not required. Signed-off-by: Tim Moon <[email protected]>
* Fix import error in FP8 tensor tests. Signed-off-by: Tim Moon <[email protected]>
* Fix transpose caching and checkpointing bug. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Improve caching and fix distopt case. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Update transformer_engine/pytorch/float8_tensor.py. Signed-off-by: Tim Moon <[email protected]>
* Remove recursive logic. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Fix cache reset bug. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Store FP8 attributes in dict: easier for multiple tensors to share, e.g. detached tensors. Signed-off-by: Tim Moon <[email protected]>
* Make sure scale_inv is 1D tensor. Signed-off-by: Tim Moon <[email protected]>
* Make sure scale_inv is 1D tensor. Signed-off-by: Tim Moon <[email protected]>
* Fixes and detach recipe. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* Set default fp8 data type. Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
---------
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Sudhakar Singh <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
* full model training using optimizer with master weights, where the high precision copies of weights are already present in the optimizer.
How does this look in practice? If the model will be initialized directly with fp8 weights, how does the optimizer get high-precision copies?
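Not from the PR, but a minimal sketch of the master-weights flow the docstring is referring to, with a hypothetical quantize_to_fp8 helper standing in for the real FP8 cast (Apex's DistributedFusedAdam keeps FP32 master copies internally and implements roughly this loop in sharded form):

```python
import torch

def quantize_to_fp8(t):
    # Hypothetical stand-in: a real implementation would produce FP8 data
    # plus scale/scale_inv metadata (e.g. this PR's Float8Tensor).
    return t.half().float()

model = torch.nn.Linear(16, 16)  # params conceptually live in low precision

# The optimizer, not the model, owns the high-precision master copies.
masters = [p.detach().clone().float() for p in model.parameters()]
opt = torch.optim.Adam(masters, lr=1e-3)

def train_step(x, target):
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    with torch.no_grad():
        # Route grads from the model params to the masters, update the
        # masters, then refresh the model's params from the masters.
        for master, p in zip(masters, model.parameters()):
            master.grad = p.grad.float()
        opt.step()
        opt.zero_grad()
        for master, p in zip(masters, model.parameters()):
            p.copy_(quantize_to_fp8(master))
        model.zero_grad()

train_step(torch.randn(4, 16), torch.randn(4, 16))
```

So the model can be initialized directly with FP8 weights as long as the optimizer is seeded with (or already holds) the high-precision copies; the FP8 params are then a quantized mirror that gets refreshed after each step.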
This FP8 tensor class is based on the implementation at https://github.com/facebookexperimental/protoquant/tree/fp8_poc and is primarily oriented toward enabling efficient FP8 support in Apex's DistributedFusedAdam. See NVIDIA/NeMo#7469 and NVIDIA/NeMo#7565. CC @sudhakarsingh27 @ksivaman