Skip to content

Commit d097883

Browse files
authored
Better way of checking cuDNN version (#485)
* Ability to check cuDNN version from Python Signed-off-by: Przemek Tredak <[email protected]> * Modify the fused attention test to not use the CUDNN_VERSION env variable which is specific to NGC containers Signed-off-by: Przemek Tredak <[email protected]> --------- Signed-off-by: Przemek Tredak <[email protected]>
1 parent 136acac commit d097883

File tree

5 files changed

+17
-1
lines changed

5 files changed

+17
-1
lines changed

tests/pytorch/test_fused_attn.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,15 @@
4444
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
4545
_flash_attn_version = packaging.version.Version(version("flash-attn"))
4646
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
47-
_cudnn_version = [int(i) for i in os.environ['CUDNN_VERSION'].split('.')]
47+
48+
def _get_cudnn_version():
    """Return the linked cuDNN version as a list ``[major, minor, patch]``.

    ``ext.get_cudnn_version()`` exposes the integer returned by the C API
    ``cudnnGetVersion()``. cuDNN releases before 9.0 encode it as
    ``major * 1000 + minor * 100 + patch``; cuDNN 9.0+ switched to
    ``major * 10000 + minor * 100 + patch``. The largest value the old
    scheme can produce is below 90000, so that threshold disambiguates
    the two encodings.

    Returns:
        list[int]: ``[major, minor, patch]`` — a list (not a tuple) so
        callers can compare it against version lists like ``[8, 9, 0]``.
    """
    encoded = ext.get_cudnn_version()
    # Pick the major-version base according to the encoding scheme in use.
    major_base = 10000 if encoded >= 90000 else 1000
    major, remainder = divmod(encoded, major_base)
    minor, patch = divmod(remainder, 100)
    return [major, minor, patch]


_cudnn_version = _get_cudnn_version()
4856

4957

5058
class ModelConfig:

transformer_engine/pytorch/csrc/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <cuda_runtime.h>
3232
#include <cuda_bf16.h>
3333
#include <cublasLt.h>
34+
#include <cudnn.h>
3435
#include <stdexcept>
3536
#include <memory>
3637
#include <iomanip>

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,8 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
524524

525525
size_t get_cublasLt_version();
526526

527+
size_t get_cudnn_version();
528+
527529
bool userbuf_comm_available();
528530

529531
void placeholder();

transformer_engine/pytorch/csrc/extensions/misc.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ size_t get_cublasLt_version() {
1313
return cublasLtGetVersion();
1414
}
1515

16+
size_t get_cudnn_version() {
17+
return cudnnGetVersion();
18+
}
19+
1620

1721
bool userbuf_comm_available() { // TODO(ksivamani) check on python side
1822
#ifdef NVTE_WITH_USERBUFFERS

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7777

7878
// Misc
7979
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
80+
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
8081
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
8182

8283
// Data structures

0 commit comments

Comments
 (0)