[C] Normalization Refactor + Adding CUDNN backend #1315

Merged (85 commits) on Dec 6, 2024


@phu0ngng (Collaborator) commented Nov 5, 2024

Description

  1. cuDNN normalization has been integrated into TE. Users can enable this backend with the following environment variables:
export NVTE_NORM_FWD_USE_CUDNN=1
export NVTE_NORM_BWD_USE_CUDNN=1

cuDNN 9.6 is required.

By default, TE's own kernels are still used for normalization. Further performance evaluation on the different supported GPUs is needed before the cuDNN backend can become the default.
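As a rough illustration of opting in from the host side, the sketch below reads the two variables named above. The helpers `norm_flag_value_enabled`, `env_flag_enabled`, and `query_norm_backends` are hypothetical, and treating only the exact value `1` as "on" is an assumption of this sketch, not a statement about how TE parses these variables.

```c
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical helper: in this sketch only the exact string "1" enables a
 * flag. TE's own parsing may accept other values. */
static bool norm_flag_value_enabled(const char *value) {
  return value != NULL && strcmp(value, "1") == 0;
}

/* Hypothetical helper: looks up an environment variable and applies the rule
 * above. An unset variable keeps the default TE kernels. */
static bool env_flag_enabled(const char *name) {
  return norm_flag_value_enabled(getenv(name));
}

/* Reports, per direction, whether the cuDNN backend was requested. */
static void query_norm_backends(bool *fwd_cudnn, bool *bwd_cudnn) {
  *fwd_cudnn = env_flag_enabled("NVTE_NORM_FWD_USE_CUDNN");
  *bwd_cudnn = env_flag_enabled("NVTE_NORM_BWD_USE_CUDNN");
}
```

Because the forward and backward flags are separate variables, a user can, for example, try cuDNN for the forward pass only.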

  2. Major changes were made to the TE normalization C APIs:
  • barrier, dgamma_part, and dbeta_part are merged into the workspace, so those arguments are no longer needed in the APIs.
  • The APIs with and without zero centering are unified; zero centering is now selected with the zero_centered_gamma flag.
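One way to picture merging the former `barrier`, `dgamma_part`, and `dbeta_part` arguments into `workspace` is a single allocation carved into aligned regions. This is a hypothetical sketch: the region order, the alignment value, and the names `WorkspaceLayout` and `plan_workspace` are assumptions for illustration, not TE's actual workspace layout.

```c
#include <stddef.h>

/* Hypothetical layout: byte offsets of each region inside one workspace. */
typedef struct {
  size_t scratch_offset;
  size_t barrier_offset;
  size_t dgamma_part_offset;
  size_t dbeta_part_offset;
  size_t total_bytes;
} WorkspaceLayout;

/* Rounds n up to the next multiple of alignment (alignment must be > 0). */
static size_t align_up(size_t n, size_t alignment) {
  return (n + alignment - 1) / alignment * alignment;
}

/* Places the regions back to back, each starting on an aligned boundary.
 * Pass 0 for a region that is not needed (e.g. dbeta_part for RMSNorm). */
static WorkspaceLayout plan_workspace(size_t scratch_bytes, size_t barrier_bytes,
                                      size_t dgamma_part_bytes, size_t dbeta_part_bytes,
                                      size_t alignment) {
  WorkspaceLayout l;
  size_t offset = 0;
  l.scratch_offset = offset;
  offset = align_up(offset + scratch_bytes, alignment);
  l.barrier_offset = offset;
  offset = align_up(offset + barrier_bytes, alignment);
  l.dgamma_part_offset = offset;
  offset = align_up(offset + dgamma_part_bytes, alignment);
  l.dbeta_part_offset = offset;
  offset = align_up(offset + dbeta_part_bytes, alignment);
  l.total_bytes = offset;
  return l;
}
```

In this sketch the caller allocates one buffer of `total_bytes` and passes only its base pointer, which is the kind of change that lets the separate arguments drop out of the signatures.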

These are the updated C APIs:

void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
                        const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma,
                        NVTETensor workspace, const int multiprocessorCount,
                        const bool zero_centered_gamma, cudaStream_t stream);

void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor mu,
                        const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx,
                        NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace,
                        const int multiprocessorCount, const bool zero_centered_gamma,
                        cudaStream_t stream);

void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z,
                      NVTETensor rsigma, NVTETensor workspace, const int multiprocessorCount,
                      const bool zero_centered_gamma, cudaStream_t stream);

void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma,
                      const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma,
                      NVTETensor workspace, const int multiprocessorCount,
                      const bool zero_centered_gamma, cudaStream_t stream);
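For reference, the standard forward definitions over a row x of length N are written below using the epsilon, mu, rsigma, gamma, and beta arguments above. Reading zero_centered_gamma as a switch from γ to 1 + γ follows the usual TE convention for zero-centered gamma; treat it as an interpretation of the flag rather than a transcription of TE's source.

```latex
\begin{aligned}
\mu &= \frac{1}{N}\sum_{i=1}^{N} x_i \\
\mathrm{rsigma}_{\mathrm{LN}} &= \Big(\frac{1}{N}\sum_{i=1}^{N}(x_i-\mu)^2 + \epsilon\Big)^{-1/2} \\
\tilde{\gamma}_i &= \begin{cases} 1 + \gamma_i & \text{if zero\_centered\_gamma} \\ \gamma_i & \text{otherwise} \end{cases} \\
z_i^{\mathrm{LN}} &= (x_i - \mu)\,\mathrm{rsigma}_{\mathrm{LN}}\,\tilde{\gamma}_i + \beta_i \\
\mathrm{rsigma}_{\mathrm{RMS}} &= \Big(\frac{1}{N}\sum_{i=1}^{N} x_i^2 + \epsilon\Big)^{-1/2} \\
z_i^{\mathrm{RMS}} &= x_i\,\mathrm{rsigma}_{\mathrm{RMS}}\,\tilde{\gamma}_i
\end{aligned}
```

This matches the shape of the signatures: the RMSNorm calls take no beta and produce no mu.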

The following APIs are deprecated:

nvte_layernorm1p_fwd
nvte_layernorm1p_bwd
nvte_rmsnorm1p_fwd
nvte_rmsnorm1p_bwd
  3. The TE normalization implementations were refactored to make adding the cuDNN backend easier.
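Making the backends pluggable can be pictured as a small selection step in front of the two implementations. This is a hypothetical sketch: the enum, the function names, and the choice to fall back silently to the TE kernels when the available cuDNN is older than 9.6 are assumptions for illustration (the real code may report an error instead).

```c
#include <stdbool.h>

/* Hypothetical identifiers for the two normalization backends. */
typedef enum { NORM_BACKEND_TE, NORM_BACKEND_CUDNN } NormBackend;

/* Picks a backend for one direction (forward or backward).
 * cudnn_requested: whether the user opted in, e.g. via NVTE_NORM_FWD_USE_CUDNN.
 * cudnn_major / cudnn_minor: version of the cuDNN library found at runtime. */
static NormBackend select_norm_backend(bool cudnn_requested, int cudnn_major, int cudnn_minor) {
  /* cuDNN 9.6 is the minimum version stated for this backend. */
  const bool cudnn_recent_enough =
      cudnn_major > 9 || (cudnn_major == 9 && cudnn_minor >= 6);
  return (cudnn_requested && cudnn_recent_enough) ? NORM_BACKEND_CUDNN : NORM_BACKEND_TE;
}

static const char *norm_backend_name(NormBackend backend) {
  return backend == NORM_BACKEND_CUDNN ? "cuDNN" : "TE";
}
```

Returning NORM_BACKEND_TE unless the user explicitly opts in keeps this sketch consistent with TE kernels remaining the default.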

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng requested a review from timmoon10 on November 5, 2024 23:16
@timmoon10 self-requested a review on November 8, 2024 00:47
@phu0ngng requested a review from ptrendx on November 13, 2024 19:18
tests/cpp/operator/test_normalization.cu (outdated, resolved)
transformer_engine/common/normalization/common.h (outdated, resolved)
tests/cpp/operator/test_normalization.cu (outdated, resolved)
@timmoon10 self-requested a review on November 14, 2024 03:04
Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng and others added 15 commits December 3, 2024 12:26
Signed-off-by: Phuong Nguyen <[email protected]>
* Update list of CI users

Signed-off-by: Tim Moon <[email protected]>

* Update list of CI users

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
…age (NVIDIA#1308)

* draft implementation

Signed-off-by: Youngeun Kwon <[email protected]>

* compile error fix

Signed-off-by: Youngeun Kwon <[email protected]>

* fix compile error

Signed-off-by: Youngeun Kwon <[email protected]>

* remove print

Signed-off-by: Youngeun Kwon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Edit comments

Signed-off-by: Youngeun Kwon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* edit the bulk-overlap test case

Signed-off-by: Youngeun Kwon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add version guard

Signed-off-by: Youngeun Kwon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add runtime version guard

Signed-off-by: Youngeun Kwon <[email protected]>

* fix the version guard

Signed-off-by: Youngeun Kwon <[email protected]>

---------

Signed-off-by: Youngeun Kwon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@phu0ngng added the 1.14.0 label on Dec 3, 2024
Signed-off-by: Phuong Nguyen <[email protected]>
@phu0ngng (Collaborator, Author) commented Dec 3, 2024:

/te-ci

@timmoon10 (Collaborator) left a comment:

I have minor comments and we should make sure this passes CI, but overall LGTM.

tests/cpp/operator/test_normalization.cu (outdated, resolved)
transformer_engine/common/1 (outdated, resolved)
transformer_engine/common/normalization/common.cpp (outdated, resolved)
phu0ngng and others added 3 commits December 5, 2024 16:17
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Phuong Nguyen <[email protected]>
@phu0ngng (Collaborator, Author) commented Dec 5, 2024:

/te-ci

@phu0ngng merged commit 3102fdd into NVIDIA:main on Dec 6, 2024
14 checks passed
@phu0ngng deleted the te_norm_refactor branch on December 6, 2024 18:59

3 participants