te.TransformerLayer fails on H100 with cudnn errors. #1350

Closed
wujingyue opened this issue Nov 30, 2024 · 2 comments
wujingyue commented Nov 30, 2024

Repro

import pytest
import torch
from enum import auto, Enum
from functools import partial

import transformer_engine.pytorch as te


class ComputeType(Enum):
    FORWARD = auto()
    BACKWARD = auto()


@pytest.mark.parametrize(
    "compute_type",
    [ComputeType.FORWARD, ComputeType.BACKWARD],
    ids=["forward", "backward"],
)
def test_transformer_layer(benchmark, compute_type):
    # Hyperparameters for GPT-3
    hidden_size = 12288
    num_heads = 96
    ffn_hidden_size = hidden_size * 4
    batch_size = 1
    sequence_length = 2048
    dtype = torch.bfloat16

    transformer_layer = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_heads,
    )
    transformer_layer.to(dtype).to("cuda")

    x = torch.randn(
        batch_size, sequence_length, hidden_size, dtype=dtype, device="cuda"
    )

    match compute_type:
        case ComputeType.FORWARD:

            def benchmark_fn(profile):
                if profile:
                    torch.cuda.cudart().cudaProfilerStart()

                y = transformer_layer(x)
                torch.cuda.synchronize()

                if profile:
                    torch.cuda.cudart().cudaProfilerStop()
                return y

            # Warmup.
            y = benchmark_fn(False)
            assert y.size() == torch.Size([batch_size, sequence_length, hidden_size])

            benchmark.pedantic(benchmark_fn, args=(True,), rounds=5)
        case ComputeType.BACKWARD:
            # Due to
            # https://github.com/Lightning-AI/lightning-thunder/issues/701, a
            # limitation in TransformerEngine, we can't repeatedly call
            # torch.autograd.backward to benchmark just backprop. As a
            # workaround, the code below runs forward before each backprop but
            # only measures the backprop time.
            def setup_fn(profile):
                y = transformer_layer(x)
                dy = torch.rand_like(y)
                torch.cuda.synchronize()
                # Unlike for forward, I can't pass `profile` directly to
                # `benchmark_fn` because `benchmark.pedantic` is not allowed to
                # take both `setup` and `args`. Therefore, we pass `profile` to
                # `setup_fn`, which in turn passes it through to
                # `benchmark_fn`.
                return (y, dy, profile), {}

            def benchmark_fn(y, dy, profile):
                if profile:
                    torch.cuda.cudart().cudaProfilerStart()

                torch.autograd.backward(y, dy)
                torch.cuda.synchronize()

                if profile:
                    torch.cuda.cudart().cudaProfilerStop()

            # Warmup.
            args, kwargs = setup_fn(False)
            benchmark_fn(*args, **kwargs)

            benchmark.pedantic(
                benchmark_fn,
                setup=partial(setup_fn, True),
                rounds=5,
            )

FWIW, this is a simplified version of
https://github.com/NVIDIA/Fuser/blob/c154e90919c40bfe2202b432c2a38e106d1a5444/tests/python/test_transformer_engine.py#L50.

Run with: pytest repro.py

Errors

RuntimeError: /opt/nvidia/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:771 in function operator(): cuDNN Error: [cudnn_frontend] Error: No valid execution plans built.. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment.

With CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr:

tests/python/test_transformer_engine.py .
E! CuDNN (v99903 18) function cudnnBackendFinalize() called:
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: Encountered runtime kernel compilation failure at: compilationResult != NVRTC_SUCCESS
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: rtk(kernelNumRunning)->compile(compilerFlags, this->useNvrtcSassPath, true )
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: compile_internal()
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: ptr.compile()
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: engine_post_checks(*engine_iface, req_size)
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: finalize_internal()
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: ptrDesc->finalize()
e! Time: 2024-11-29T22:10:18.488657 (0d+0h+0m+2s since start)
e! Process=19401; Thread=19657; GPU=NULL; Handle=NULL; StreamId=NULL.


E! CuDNN (v99903 18) function cudnnBackendFinalize() called:
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: Encountered runtime kernel compilation failure at: compilationResult != NVRTC_SUCCESS
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: rtk(kernelNumRunning)->compile(compilerFlags, this->useNvrtcSassPath, true )
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: compile_internal()
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: ptr.compile()
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: engine_post_checks(*engine_iface, req_size)
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: finalize_internal()
e!         Error: CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED; Reason: ptrDesc->finalize()
e! Time: 2024-11-29T22:10:18.504707 (0d+0h+0m+2s since start)
e! Process=19401; Thread=19657; GPU=NULL; Handle=NULL; StreamId=NULL.

Environment

$ nvidia-smi -L
GPU 0: NVIDIA H100 80GB HBM3...
>>> import transformer_engine
>>> transformer_engine.__version__
'1.14.0.dev0+a132ac4'
>>> import cudnn
>>> cudnn.__version__
'1.8.0'

Notes

Setting attn_input_format to "bshd" works around the problem.
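
For reference, a minimal sketch of that workaround using the names from the repro above (this assumes attn_input_format is accepted as a constructor keyword, as the discussion below suggests):

# Sketch: declare the input as batch-first ("bshd") so it matches the
# (batch_size, sequence_length, hidden_size) tensor built in the repro.
transformer_layer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_heads,
    attn_input_format="bshd",
)
transformer_layer.to(dtype).to("cuda")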

@cyanguwa
Collaborator

Hi @wujingyue, it seems you have already solved the issue yourself? :) The input x is in bshd format, so you should set attn_input_format="bshd" accordingly. Let me know if I misunderstood the intent of the bug report - thanks!

@wujingyue
Contributor Author

Thanks! I thought the attention input format was something internal to the layer and orthogonal to the layer's input. Great to know they have to be consistent.
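
For completeness, the other way to keep the two consistent is to reshape the input to the layer's default layout instead of changing the layer. A small sketch, assuming the default attn_input_format is "sbhd" (sequence-first); if that assumption is off, the attn_input_format="bshd" workaround above is the safer route:

# Sketch: feed a sequence-first (sbhd) tensor to a layer built with defaults,
# instead of passing attn_input_format="bshd" at construction time.
x_sbhd = torch.randn(
    sequence_length, batch_size, hidden_size, dtype=dtype, device="cuda"
)
y = transformer_layer(x_sbhd)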

wujingyue added a commit to NVIDIA/Fuser that referenced this issue Dec 14, 2024
wujingyue added a commit to NVIDIA/Fuser that referenced this issue Dec 16, 2024