[Bug] Failed to pass pytorch's numerical test on A800 SXM #1360

Closed
junjzhang opened this issue Dec 6, 2024 · 1 comment

junjzhang commented Dec 6, 2024

Hi, I just installed TE using pip install transformer_engine[pytorch] in a clean environment, and many of the tests in test_numerical fail. Could anyone help me with this?

Error messages from tests/pytorch/test_numerical.py

FAILED test_numerical.py::test_gpt_full_activation_recompute[True-True-False-126m-1-dtype2] - AssertionError: Mismatch in tensor 0
FAILED test_numerical.py::test_linear_accuracy[small-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Maximum difference at location [78] with -10.82479190826416 vs -10.818109512329102 (diff 0.006682395935058594).
FAILED test_numerical.py::test_linear_accuracy[small-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Maximum difference at location [66] with 16.163537979125977 vs 16.155488967895508 (diff 0.00804901123046875).
FAILED test_numerical.py::test_layernorm_linear_accuracy[True-LayerNorm-small-1-dtype0] - TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
FAILED test_numerical.py::test_layernorm_linear_accuracy[True-LayerNorm-small-2-dtype0] - TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
FAILED test_numerical.py::test_layernorm_linear_accuracy[True-RMSNorm-small-1-dtype0] - TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
FAILED test_numerical.py::test_layernorm_linear_accuracy[True-RMSNorm-small-2-dtype0] - TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
FAILED test_numerical.py::test_layernorm_linear_accuracy[False-LayerNorm-small-1-dtype0] - TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
FAILED test_numerical.py::test_layernorm_linear_accuracy[False-LayerNorm-small-2-dtype0] - TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
FAILED test_numerical.py::test_layernorm_linear_accuracy[False-RMSNorm-small-1-dtype0] - TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
FAILED test_numerical.py::test_layernorm_linear_accuracy[False-RMSNorm-small-2-dtype0] - TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
FAILED test_numerical.py::test_layernorm_mlp_accuracy[LayerNorm-relu-small-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=4. Maximum difference at location [0, 105] with 0.13338899612426758 vs 0.15530425310134888 (diff 0.0219152569770813).
FAILED test_numerical.py::test_layernorm_mlp_accuracy[LayerNorm-relu-small-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=2. Maximum difference at location [0, 105] with 0.13338899612426758 vs 0.15530425310134888 (diff 0.0219152569770813).
FAILED test_numerical.py::test_layernorm_mlp_accuracy[LayerNorm-reglu-small-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=103. Maximum difference at location [40] with 0.15618646144866943 vs 0.09805639088153839 (diff 0.05813007056713104).
FAILED test_numerical.py::test_layernorm_mlp_accuracy[LayerNorm-reglu-small-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=53. Maximum difference at location [16] with -1.7313060760498047 vs -1.428515911102295 (diff 0.30279016494750977).
FAILED test_numerical.py::test_layernorm_mlp_accuracy[RMSNorm-relu-small-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=79. Maximum difference at location [0, 61] with -0.027338851243257523 vs -0.006757093593478203 (diff 0.02058175764977932).
FAILED test_numerical.py::test_layernorm_mlp_accuracy[RMSNorm-relu-small-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=39. Maximum difference at location [1, 61] with -0.027338841930031776 vs -0.006757088005542755 (diff 0.02058175392448902).
FAILED test_numerical.py::test_layernorm_mlp_accuracy[RMSNorm-reglu-small-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=82. Maximum difference at location [127] with 0.8240067362785339 vs 0.7542246580123901 (diff 0.0697820782661438).
FAILED test_numerical.py::test_layernorm_mlp_accuracy[RMSNorm-reglu-small-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=30. Maximum difference at location [0] with -0.6747775077819824 vs -0.6976503729820251 (diff 0.022872865200042725).
FAILED test_numerical.py::test_transformer_layer_hidden_states_format[126m-1-dtype1] - ValueError: No dot product attention support for the provided inputs!
FAILED test_numerical.py::test_transformer_layer_hidden_states_format[126m-1-dtype2] - ValueError: No dot product attention support for the provided inputs!
FAILED test_numerical.py::test_transformer_layer_hidden_states_format[126m-2-dtype1] - ValueError: No dot product attention support for the provided inputs!
FAILED test_numerical.py::test_transformer_layer_hidden_states_format[126m-2-dtype2] - ValueError: No dot product attention support for the provided inputs!
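
One rough way to read the test_linear_accuracy mismatches above is to compare their relative error with the rounding precision of candidate number formats. The sketch below uses the two values copied from the first test_linear_accuracy failure. It assumes that the second value is the reference, and it relies on the known mantissa widths of TF32 (10 explicit bits) and FP32 (23 explicit bits). The error accumulated in a matrix multiply is not bounded by a single rounding step, so this is only a scale comparison, not a proof of the cause.

```python
# Values copied from the test_linear_accuracy[small-1-dtype0] failure above.
got = -10.82479190826416
ref = -10.818109512329102  # assumption: the second value is the reference


def relative_error(a: float, b: float) -> float:
    """Relative difference of a against the reference b."""
    return abs(a - b) / abs(b)


# Unit roundoff for round-to-nearest with p explicit mantissa bits is 2**-(p + 1).
TF32_UNIT_ROUNDOFF = 2.0 ** -11  # TF32 keeps 10 explicit mantissa bits
FP32_UNIT_ROUNDOFF = 2.0 ** -24  # FP32 keeps 23 explicit mantissa bits

rel = relative_error(got, ref)
print(f"relative error          : {rel:.3e}")
print(f"in TF32 unit roundoffs  : {rel / TF32_UNIT_ROUNDOFF:.2f}")
print(f"in FP32 unit roundoffs  : {rel / FP32_UNIT_ROUNDOFF:.0f}")
```

A relative error within a small multiple of the TF32 unit roundoff, and thousands of times the FP32 one, would suggest that one side of the comparison ran at reduced precision.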

My environment

  • dependencies
Package                   Version      Build                         Size       Kind   Source
_libgcc_mutex             0.1          conda_forge                   2.5 KiB    conda  _libgcc_mutex-0.1-conda_forge.tar.bz2
_openmp_mutex             4.5          2_kmp_llvm                    5.6 KiB    conda  _openmp_mutex-4.5-2_kmp_llvm.tar.bz2
binutils_impl_linux-64    2.43         h4bf12b8_2                    5.4 MiB    conda  binutils_impl_linux-64-2.43-h4bf12b8_2.conda
blas                      2.116        mkl                           12.9 KiB   conda  blas-2.116-mkl.tar.bz2
blas-devel                3.9.0        16_linux64_mkl                12.3 KiB   conda  blas-devel-3.9.0-16_linux64_mkl.tar.bz2
bzip2                     1.0.8        h4bc722e_7                    246.9 KiB  conda  bzip2-1.0.8-h4bc722e_7.conda
ca-certificates           2024.8.30    hbcca054_0                    155.3 KiB  conda  ca-certificates-2024.8.30-hbcca054_0.conda
colorama                  0.4.6        pyhd8ed1ab_1                  26.4 KiB   conda  colorama-0.4.6-pyhd8ed1ab_1.conda
cpython                   3.12.8       py312hd8ed1ab_1               43.7 KiB   conda  cpython-3.12.8-py312hd8ed1ab_1.conda
cuda                      12.4.0       0                             1.8 KiB    conda  cuda-12.4.0-0.tar.bz2
cuda-cccl                 12.4.99      0                             1.4 MiB    conda  cuda-cccl-12.4.99-0.tar.bz2
cuda-command-line-tools   12.4.0       0                             1.8 KiB    conda  cuda-command-line-tools-12.4.0-0.tar.bz2
cuda-compiler             12.4.0       0                             1.8 KiB    conda  cuda-compiler-12.4.0-0.tar.bz2
cuda-cudart               12.4.99      0                             198.3 KiB  conda  cuda-cudart-12.4.99-0.tar.bz2
cuda-cudart-dev           12.4.99      0                             413.9 KiB  conda  cuda-cudart-dev-12.4.99-0.tar.bz2
cuda-cudart-static        12.4.99      0                             1.1 MiB    conda  cuda-cudart-static-12.4.99-0.tar.bz2
cuda-cuobjdump            12.4.99      0                             301 KiB    conda  cuda-cuobjdump-12.4.99-0.tar.bz2
cuda-cupti                12.4.99      0                             16.4 MiB   conda  cuda-cupti-12.4.99-0.tar.bz2
cuda-cupti-static         12.4.99      0                             11.3 MiB   conda  cuda-cupti-static-12.4.99-0.tar.bz2
cuda-cuxxfilt             12.4.99      0                             284.9 KiB  conda  cuda-cuxxfilt-12.4.99-0.tar.bz2
cuda-demo-suite           12.4.99      0                             5 MiB      conda  cuda-demo-suite-12.4.99-0.tar.bz2
cuda-documentation        12.4.99      0                             89.4 KiB   conda  cuda-documentation-12.4.99-0.tar.bz2
cuda-driver-dev           12.4.99      0                             18.1 KiB   conda  cuda-driver-dev-12.4.99-0.tar.bz2
cuda-gdb                  12.4.99      0                             5.8 MiB    conda  cuda-gdb-12.4.99-0.tar.bz2
cuda-libraries            12.4.0       0                             1.9 KiB    conda  cuda-libraries-12.4.0-0.tar.bz2
cuda-libraries-dev        12.4.0       0                             1.9 KiB    conda  cuda-libraries-dev-12.4.0-0.tar.bz2
cuda-libraries-static     12.4.0       0                             1.9 KiB    conda  cuda-libraries-static-12.4.0-0.tar.bz2
cuda-nsight               12.4.99      0                             113.7 MiB  conda  cuda-nsight-12.4.99-0.tar.bz2
cuda-nsight-compute       12.4.0       0                             1.8 KiB    conda  cuda-nsight-compute-12.4.0-0.tar.bz2
cuda-nvcc                 12.4.99      0                             62.6 MiB   conda  cuda-nvcc-12.4.99-0.tar.bz2
cuda-nvdisasm             12.4.99      0                             47.9 MiB   conda  cuda-nvdisasm-12.4.99-0.tar.bz2
cuda-nvml-dev             12.4.99      0                             173.6 KiB  conda  cuda-nvml-dev-12.4.99-0.tar.bz2
cuda-nvprof               12.4.99      0                             4.7 MiB    conda  cuda-nvprof-12.4.99-0.tar.bz2
cuda-nvprune              12.4.99      0                             65.7 KiB   conda  cuda-nvprune-12.4.99-0.tar.bz2
cuda-nvrtc                12.4.99      0                             21.1 MiB   conda  cuda-nvrtc-12.4.99-0.tar.bz2
cuda-nvrtc-dev            12.4.99      0                             12.5 KiB   conda  cuda-nvrtc-dev-12.4.99-0.tar.bz2
cuda-nvrtc-static         12.4.99      0                             21.9 MiB   conda  cuda-nvrtc-static-12.4.99-0.tar.bz2
cuda-nvtx                 12.4.99      0                             57.8 KiB   conda  cuda-nvtx-12.4.99-0.tar.bz2
cuda-nvvp                 12.4.99      0                             114.5 MiB  conda  cuda-nvvp-12.4.99-0.tar.bz2
cuda-opencl               12.4.99      0                             11.5 KiB   conda  cuda-opencl-12.4.99-0.tar.bz2
cuda-opencl-dev           12.4.99      0                             71.6 KiB   conda  cuda-opencl-dev-12.4.99-0.tar.bz2
cuda-profiler-api         12.4.99      0                             18.7 KiB   conda  cuda-profiler-api-12.4.99-0.tar.bz2
cuda-runtime              12.4.0       0                             1.8 KiB    conda  cuda-runtime-12.4.0-0.tar.bz2
cuda-sanitizer-api        12.4.99      0                             17.1 MiB   conda  cuda-sanitizer-api-12.4.99-0.tar.bz2
cuda-toolkit              12.4.0       0                             1.9 KiB    conda  cuda-toolkit-12.4.0-0.tar.bz2
cuda-tools                12.4.0       0                             1.8 KiB    conda  cuda-tools-12.4.0-0.tar.bz2
cuda-version              12.6         3                             16.5 KiB   conda  cuda-version-12.6-3.conda
cuda-visual-tools         12.4.0       0                             1.8 KiB    conda  cuda-visual-tools-12.4.0-0.tar.bz2
cudnn                     9.3.0.75     cuda12.6                      383.6 MiB  conda  cudnn-9.3.0.75-cuda12.6.conda
exceptiongroup            1.2.2        pyhd8ed1ab_1                  20 KiB     conda  exceptiongroup-1.2.2-pyhd8ed1ab_1.conda
filelock                  3.16.1       pyhd8ed1ab_1                  17 KiB     conda  filelock-3.16.1-pyhd8ed1ab_1.conda
gcc                       12.4.0       h236703b_1                    52.5 KiB   conda  gcc-12.4.0-h236703b_1.conda
gcc_impl_linux-64         12.4.0       hb2e57f8_1                    59.2 MiB   conda  gcc_impl_linux-64-12.4.0-hb2e57f8_1.conda
gds-tools                 1.9.0.20     0                             40.7 MiB   conda  gds-tools-1.9.0.20-0.tar.bz2
gmp                       6.3.0        hac33072_2                    449.3 KiB  conda  gmp-6.3.0-hac33072_2.conda
gmpy2                     2.1.5        py312h7201bc8_3               204.7 KiB  conda  gmpy2-2.1.5-py312h7201bc8_3.conda
gxx                       12.4.0       h236703b_1                    52 KiB     conda  gxx-12.4.0-h236703b_1.conda
gxx_impl_linux-64         12.4.0       h613a52c_1                    12.1 MiB   conda  gxx_impl_linux-64-12.4.0-h613a52c_1.conda
iniconfig                 2.0.0        pyhd8ed1ab_1                  11.2 KiB   conda  iniconfig-2.0.0-pyhd8ed1ab_1.conda
jinja2                    3.1.4        pyhd8ed1ab_1                  108.4 KiB  conda  jinja2-3.1.4-pyhd8ed1ab_1.conda
kernel-headers_linux-64   3.10.0       he073ed8_18                   921.4 KiB  conda  kernel-headers_linux-64-3.10.0-he073ed8_18.conda
ld_impl_linux-64          2.43         h712a8e2_2                    653.5 KiB  conda  ld_impl_linux-64-2.43-h712a8e2_2.conda
libblas                   3.9.0        16_linux64_mkl                12.8 KiB   conda  libblas-3.9.0-16_linux64_mkl.tar.bz2
libcblas                  3.9.0        16_linux64_mkl                12.5 KiB   conda  libcblas-3.9.0-16_linux64_mkl.tar.bz2
libcublas                 12.4.2.65    0                             308.8 MiB  conda  libcublas-12.4.2.65-0.tar.bz2
libcublas-dev             12.4.2.65    0                             74.8 KiB   conda  libcublas-dev-12.4.2.65-0.tar.bz2
libcublas-static          12.4.2.65    0                             349.5 MiB  conda  libcublas-static-12.4.2.65-0.tar.bz2
libcufft                  11.2.0.44    0                             190.5 MiB  conda  libcufft-11.2.0.44-0.tar.bz2
libcufft-dev              11.2.0.44    0                             14.4 KiB   conda  libcufft-dev-11.2.0.44-0.tar.bz2
libcufft-static           11.2.0.44    0                             383.8 MiB  conda  libcufft-static-11.2.0.44-0.tar.bz2
libcufile                 1.9.0.20     0                             1 MiB      conda  libcufile-1.9.0.20-0.tar.bz2
libcufile-dev             1.9.0.20     0                             14.5 KiB   conda  libcufile-dev-1.9.0.20-0.tar.bz2
libcufile-static          1.9.0.20     0                             3.6 MiB    conda  libcufile-static-1.9.0.20-0.tar.bz2
libcurand                 10.3.5.119   0                             51.8 MiB   conda  libcurand-10.3.5.119-0.tar.bz2
libcurand-dev             10.3.5.119   0                             449.6 KiB  conda  libcurand-dev-10.3.5.119-0.tar.bz2
libcurand-static          10.3.5.119   0                             52 MiB     conda  libcurand-static-10.3.5.119-0.tar.bz2
libcusolver               11.6.0.99    0                             114.3 MiB  conda  libcusolver-11.6.0.99-0.tar.bz2
libcusolver-dev           11.6.0.99    0                             49.1 KiB   conda  libcusolver-dev-11.6.0.99-0.tar.bz2
libcusolver-static        11.6.0.99    0                             76.4 MiB   conda  libcusolver-static-11.6.0.99-0.tar.bz2
libcusparse               12.3.0.142   0                             179.6 MiB  conda  libcusparse-12.3.0.142-0.tar.bz2
libcusparse-dev           12.3.0.142   0                             179.7 MiB  conda  libcusparse-dev-12.3.0.142-0.tar.bz2
libcusparse-static        12.3.0.142   0                             184.8 MiB  conda  libcusparse-static-12.3.0.142-0.tar.bz2
libexpat                  2.6.4        h5888daf_0                    71.6 KiB   conda  libexpat-2.6.4-h5888daf_0.conda
libffi                    3.4.2        h7f98852_5                    56.9 KiB   conda  libffi-3.4.2-h7f98852_5.tar.bz2
libgcc                    14.2.0       h77fa898_1                    828.9 KiB  conda  libgcc-14.2.0-h77fa898_1.conda
libgcc-devel_linux-64     12.4.0       ha4f9413_101                  2.4 MiB    conda  libgcc-devel_linux-64-12.4.0-ha4f9413_101.conda
libgcc-ng                 14.2.0       h69a702a_1                    52.9 KiB   conda  libgcc-ng-14.2.0-h69a702a_1.conda
libgfortran               14.2.0       h69a702a_1                    52.7 KiB   conda  libgfortran-14.2.0-h69a702a_1.conda
libgfortran-ng            14.2.0       h69a702a_1                    52.8 KiB   conda  libgfortran-ng-14.2.0-h69a702a_1.conda
libgfortran5              14.2.0       hd5240d6_1                    1.4 MiB    conda  libgfortran5-14.2.0-hd5240d6_1.conda
libgomp                   14.2.0       h77fa898_1                    450.2 KiB  conda  libgomp-14.2.0-h77fa898_1.conda
libhwloc                  2.11.2       default_h0d58e46_1001         2.3 MiB    conda  libhwloc-2.11.2-default_h0d58e46_1001.conda
libiconv                  1.17         hd590300_2                    689.2 KiB  conda  libiconv-1.17-hd590300_2.conda
liblapack                 3.9.0        16_linux64_mkl                12.5 KiB   conda  liblapack-3.9.0-16_linux64_mkl.tar.bz2
liblapacke                3.9.0        16_linux64_mkl                12.5 KiB   conda  liblapacke-3.9.0-16_linux64_mkl.tar.bz2
liblzma                   5.6.3        hb9d3cd8_1                    108.5 KiB  conda  liblzma-5.6.3-hb9d3cd8_1.conda
libnpp                    12.2.5.2     0                             142.8 MiB  conda  libnpp-12.2.5.2-0.tar.bz2
libnpp-dev                12.2.5.2     0                             538.9 KiB  conda  libnpp-dev-12.2.5.2-0.tar.bz2
libnpp-static             12.2.5.2     0                             139.2 MiB  conda  libnpp-static-12.2.5.2-0.tar.bz2
libnsl                    2.0.1        hd590300_0                    32.6 KiB   conda  libnsl-2.0.1-hd590300_0.conda
libnvfatbin               12.4.99      0                             855.7 KiB  conda  libnvfatbin-12.4.99-0.tar.bz2
libnvfatbin-dev           12.4.99      0                             685.1 KiB  conda  libnvfatbin-dev-12.4.99-0.tar.bz2
libnvjitlink              12.4.99      0                             18.2 MiB   conda  libnvjitlink-12.4.99-0.tar.bz2
libnvjitlink-dev          12.4.99      0                             18.1 MiB   conda  libnvjitlink-dev-12.4.99-0.tar.bz2
libnvjpeg                 12.3.1.89    0                             3 MiB      conda  libnvjpeg-12.3.1.89-0.tar.bz2
libnvjpeg-dev             12.3.1.89    0                             13.1 KiB   conda  libnvjpeg-dev-12.3.1.89-0.tar.bz2
libnvjpeg-static          12.3.1.89    0                             2.7 MiB    conda  libnvjpeg-static-12.3.1.89-0.tar.bz2
libsanitizer              12.4.0       h46f95d5_1                    3.8 MiB    conda  libsanitizer-12.4.0-h46f95d5_1.conda
libsqlite                 3.47.0       hadc24fc_1                    854.8 KiB  conda  libsqlite-3.47.0-hadc24fc_1.conda
libstdcxx                 14.2.0       hc0a3c3a_1                    3.7 MiB    conda  libstdcxx-14.2.0-hc0a3c3a_1.conda
libstdcxx-devel_linux-64  12.4.0       ha4f9413_101                  11.3 MiB   conda  libstdcxx-devel_linux-64-12.4.0-ha4f9413_101.conda
libstdcxx-ng              14.2.0       h4852527_1                    52.8 KiB   conda  libstdcxx-ng-14.2.0-h4852527_1.conda
libuuid                   2.38.1       h0b41bf4_0                    32.8 KiB   conda  libuuid-2.38.1-h0b41bf4_0.conda
libxcrypt                 4.4.36       hd590300_1                    98 KiB     conda  libxcrypt-4.4.36-hd590300_1.conda
libxml2                   2.13.5       h0d44e9d_1                    673.8 KiB  conda  libxml2-2.13.5-h0d44e9d_1.conda
libzlib                   1.3.1        hb9d3cd8_2                    59.5 KiB   conda  libzlib-1.3.1-hb9d3cd8_2.conda
llvm-openmp               15.0.7       h0cdce71_0                    3.1 MiB    conda  llvm-openmp-15.0.7-h0cdce71_0.conda
markupsafe                3.0.2        py312h178313f_1               24 KiB     conda  markupsafe-3.0.2-py312h178313f_1.conda
mkl                       2022.1.0     h84fe81f_915                  199.6 MiB  conda  mkl-2022.1.0-h84fe81f_915.tar.bz2
mkl-devel                 2022.1.0     ha770c72_916                  25 KiB     conda  mkl-devel-2022.1.0-ha770c72_916.tar.bz2
mkl-include               2022.1.0     h84fe81f_915                  744.7 KiB  conda  mkl-include-2022.1.0-h84fe81f_915.tar.bz2
mpc                       1.3.1        h24ddda3_1                    114 KiB    conda  mpc-1.3.1-h24ddda3_1.conda
mpfr                      4.2.1        h90cbb55_3                    619.9 KiB  conda  mpfr-4.2.1-h90cbb55_3.conda
mpmath                    1.3.0        pyhd8ed1ab_1                  429.4 KiB  conda  mpmath-1.3.0-pyhd8ed1ab_1.conda
ncurses                   6.5          he02047a_1                    868.2 KiB  conda  ncurses-6.5-he02047a_1.conda
networkx                  3.4.2        pyh267e887_2                  1.2 MiB    conda  networkx-3.4.2-pyh267e887_2.conda
ninja                     1.12.1       h297d8ca_0                    2.1 MiB    conda  ninja-1.12.1-h297d8ca_0.conda
nsight-compute            2024.1.0.13  0                             667.1 MiB  conda  nsight-compute-2024.1.0.13-0.tar.bz2
numpy                     2.1.3        py312h58c1407_0               8 MiB      conda  numpy-2.1.3-py312h58c1407_0.conda
openssl                   3.4.0        hb9d3cd8_0                    2.8 MiB    conda  openssl-3.4.0-hb9d3cd8_0.conda
packaging                 24.2         pyhd8ed1ab_2                  58.8 KiB   conda  packaging-24.2-pyhd8ed1ab_2.conda
pip                       24.3.1       pyh8b19718_0                  1.2 MiB    conda  pip-24.3.1-pyh8b19718_0.conda
pluggy                    1.5.0        pyhd8ed1ab_1                  23 KiB     conda  pluggy-1.5.0-pyhd8ed1ab_1.conda
pytest                    8.3.4        pyhd8ed1ab_1                  253.1 KiB  conda  pytest-8.3.4-pyhd8ed1ab_1.conda
python                    3.12.8       h9e4cc4f_1_cpython            30.1 MiB   conda  python-3.12.8-h9e4cc4f_1_cpython.conda
python_abi                3.12         5_cp312                       6.1 KiB    conda  python_abi-3.12-5_cp312.conda
pytorch                   2.5.1        py3.12_cuda12.4_cudnn9.1.0_0  1.5 GiB    conda  pytorch-2.5.1-py3.12_cuda12.4_cudnn9.1.0_0.tar.bz2
pytorch-cuda              12.4         hc786d27_6                    7 KiB      conda  pytorch-cuda-12.4-hc786d27_6.tar.bz2
pytorch-mutex             1.0          cuda                          2.8 KiB    conda  pytorch-mutex-1.0-cuda.tar.bz2
pyyaml                    6.0.2        py312h66e93f0_1               201.7 KiB  conda  pyyaml-6.0.2-py312h66e93f0_1.conda
readline                  8.2          h8228510_1                    274.9 KiB  conda  readline-8.2-h8228510_1.conda
setuptools                75.6.0       pyhff2d567_1                  756.1 KiB  conda  setuptools-75.6.0-pyhff2d567_1.conda
sympy                     1.13.3       pyh2585a3b_104                4.4 MiB    conda  sympy-1.13.3-pyh2585a3b_104.conda
sysroot_linux-64          2.17         h4a8ded7_18                   14.8 MiB   conda  sysroot_linux-64-2.17-h4a8ded7_18.conda
tbb                       2021.13.0    hceb3a55_1                    171.8 KiB  conda  tbb-2021.13.0-hceb3a55_1.conda
tk                        8.6.13       noxft_h4845f30_101            3.2 MiB    conda  tk-8.6.13-noxft_h4845f30_101.conda
tomli                     2.2.1        pyhd8ed1ab_1                  18.7 KiB   conda  tomli-2.2.1-pyhd8ed1ab_1.conda
torchtriton               3.1.0        py312                         233.6 MiB  conda  torchtriton-3.1.0-py312.tar.bz2
typing_extensions         4.12.2       pyha770c72_1                  38.7 KiB   conda  typing_extensions-4.12.2-pyha770c72_1.conda
tzdata                    2024b        hc8b5060_0                    119.5 KiB  conda  tzdata-2024b-hc8b5060_0.conda
wheel                     0.45.1       pyhd8ed1ab_1                  61.5 KiB   conda  wheel-0.45.1-pyhd8ed1ab_1.conda
yaml                      0.2.5        h7f98852_2                    87.1 KiB   conda  yaml-0.2.5-h7f98852_2.tar.bz2
  • device spec
Driver Version: 550.90.07
Device: NVIDIA A800-SXM4-80GB
junjzhang changed the title from "[Bug] Failed to pass pytorch's numerical test" to "[Bug] Failed to pass pytorch's numerical test on A800 SXM" on Dec 6, 2024

junjzhang (Author) commented

Based on two related issues, #494 and #1165, it turns out that TE uses TF32 by default while PyTorch uses FP32 by default.
Thus, we can either let PyTorch use TF32 with

torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

or let TE use FP32 with

export NVIDIA_TF32_OVERRIDE=0
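
Because NVIDIA_TF32_OVERRIDE is read by the NVIDIA libraries themselves, the safest approach is to set it before the test process starts rather than from inside a running Python session. A minimal sketch of that idea launches pytest in a child process with the variable set. The helper names env_with_tf32_disabled and run_numerical_tests are hypothetical, and the default test path is the one quoted in this issue, assumed to be relative to the repository root.

```python
import os
import subprocess
import sys


def env_with_tf32_disabled(base_env=None):
    """Return a copy of the environment with NVIDIA_TF32_OVERRIDE=0 set."""
    env = dict(os.environ if base_env is None else base_env)
    env["NVIDIA_TF32_OVERRIDE"] = "0"
    return env


def run_numerical_tests(test_path="tests/pytorch/test_numerical.py"):
    """Run pytest on test_path in a child process with TF32 disabled.

    Hypothetical helper: the path is assumed to be relative to the
    current working directory. Returns the pytest exit code.
    """
    cmd = [sys.executable, "-m", "pytest", test_path]
    return subprocess.run(cmd, env=env_with_tf32_disabled()).returncode
```

Calling run_numerical_tests() from the repository root should behave like running the export line above followed by pytest, without changing the environment of the calling shell.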
