Skip to content

Commit 5a091c0

Browse files
authored
Merge branch 'main' into activation-ops
2 parents c82aea7 + 095b27d commit 5a091c0

File tree

106 files changed

+9929
-3653
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

106 files changed

+9929
-3653
lines changed

.github/workflows/docs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ jobs:
1717
uses: actions/checkout@v3
1818
- name: 'Install dependencies'
1919
run: |
20-
pip install sphinx==5.1.1 sphinx_rtd_theme==1.0.0 nbsphinx==0.8.10 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==2.15.7
21-
pip install breathe==4.34.0 sphinx-autoapi==2.0.1
20+
pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2
21+
pip install breathe==4.35.0 sphinx-autoapi==3.3.2
2222
sudo apt-get install -y pandoc graphviz doxygen
2323
export GIT_SHA=$(git show-ref --hash HEAD)
2424
- name: 'Build docs'

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ __pycache__
2222
.hypothesis
2323
.devcontainer.json
2424
tests/cpp/build/
25-
docs/_build
2625
.ipynb_checkpoints
27-
docs/doxygen
2826
*.log
2927
CMakeFiles/CMakeSystem.cmake
3028
sdist/
@@ -40,3 +38,4 @@ dist/
4038
downloads/
4139
.pytest_cache/
4240
compile_commands.json
41+
.nfs

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 146 files

build_tools/pytorch.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .utils import (
1212
all_files_in_dir,
1313
cuda_archs,
14-
cuda_path,
1514
cuda_version,
1615
)
1716

@@ -29,9 +28,6 @@ def setup_pytorch_extension(
2928
sources = [
3029
csrc_source_files / "common.cu",
3130
csrc_source_files / "ts_fp8_op.cpp",
32-
csrc_source_files / "userbuffers" / "ipcsocket.cc",
33-
csrc_source_files / "userbuffers" / "userbuffers.cu",
34-
csrc_source_files / "userbuffers" / "userbuffers-host.cpp",
3531
] + all_files_in_dir(extensions_dir)
3632

3733
# Header files
@@ -85,19 +81,14 @@ def setup_pytorch_extension(
8581
continue # Already handled
8682
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
8783

88-
# Libraries
89-
library_dirs = []
90-
libraries = []
91-
if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
84+
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
9285
assert (
9386
os.getenv("MPI_HOME") is not None
94-
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
95-
mpi_home = Path(os.getenv("MPI_HOME"))
96-
include_dirs.append(mpi_home / "include")
87+
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
88+
mpi_path = Path(os.getenv("MPI_HOME"))
89+
include_dirs.append(mpi_path / "include")
9790
cxx_flags.append("-DNVTE_UB_WITH_MPI")
9891
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
99-
library_dirs.append(mpi_home / "lib")
100-
libraries.append("mpi")
10192

10293
# Construct PyTorch CUDA extension
10394
sources = [str(path) for path in sources]
@@ -112,6 +103,4 @@ def setup_pytorch_extension(
112103
"cxx": cxx_flags,
113104
"nvcc": nvcc_flags,
114105
},
115-
libraries=[str(lib) for lib in libraries],
116-
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
117106
)

docs/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_build
2+
doxygen
3+
sphinx_rtd_theme

docs/Makefile

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,10 @@ help:
1616

1717
# Catch-all target: route all unknown targets to Sphinx using the new
1818
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19-
%: Makefile
20-
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
19+
%: Makefile sphinx_rtd_theme
20+
PYTHONPATH=sphinx_rtd_theme:$(PYTHONPATH) $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21+
22+
# Patch Sphinx RTD theme 3.0.1 to add version selector in sidebar
23+
sphinx_rtd_theme:
24+
git clone --depth=1 -b 3.0.1 --single-branch https://github.com/readthedocs/sphinx_rtd_theme.git
25+
bash -c "cd sphinx_rtd_theme; git apply ../version_select.patch"

docs/api/pytorch.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,7 @@ pyTorch
5151
.. autoapifunction:: transformer_engine.pytorch.moe_permute
5252

5353
.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
54+
55+
.. autoapifunction:: transformer_engine.pytorch.initialize_ub
56+
57+
.. autoapifunction:: transformer_engine.pytorch.destroy_ub

docs/conf.py

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,30 @@
22
#
33
# See LICENSE for license information.
44

5+
import datetime
56
import os
6-
import sys
7-
import sphinx_rtd_theme
8-
from sphinx.ext.autodoc.mock import mock
9-
from sphinx.ext.autodoc import between, ClassDocumenter, AttributeDocumenter
10-
from sphinx.util import inspect
11-
from builtins import str
12-
from enum import Enum
13-
import re
7+
import pathlib
148
import subprocess
15-
from pathlib import Path
16-
from datetime import date
17-
18-
te_path = os.path.dirname(os.path.realpath(__file__))
9+
from builtins import str
1910

20-
with open(te_path + "/../build_tools/VERSION.txt", "r") as f:
21-
te_version = f.readline().strip()
11+
# Basic project info
12+
project = "Transformer Engine"
13+
author = "NVIDIA CORPORATION & AFFILIATES"
2214

15+
# Copyright statement
2316
release_year = 2022
24-
25-
current_year = date.today().year
17+
current_year = datetime.date.today().year
2618
if current_year == release_year:
2719
copyright_year = release_year
2820
else:
2921
copyright_year = str(release_year) + "-" + str(current_year)
22+
copyright = f"{copyright_year}, NVIDIA CORPORATION & AFFILIATES. All rights reserved."
3023

31-
project = "Transformer Engine"
32-
copyright = "{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.".format(copyright_year)
33-
author = "NVIDIA CORPORATION & AFFILIATES"
24+
# Transformer Engine root directory
25+
root_path = pathlib.Path(__file__).resolve().parent.parent
3426

27+
# Git hash
3528
git_sha = os.getenv("GIT_SHA")
36-
3729
if not git_sha:
3830
try:
3931
git_sha = (
@@ -44,31 +36,16 @@
4436
)
4537
except:
4638
git_sha = "0000000"
47-
4839
git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha
4940

50-
if "dev" in te_version:
51-
version = str(te_version + "-" + git_sha)
41+
# Version
42+
with open(root_path / "build_tools" / "VERSION.txt", "r") as f:
43+
_raw_version = f.readline().strip()
44+
if "dev" in _raw_version:
45+
version = str(_raw_version + "-" + git_sha)
5246
else:
53-
version = str(te_version)
54-
release = te_version
55-
56-
# hack: version is used for html creation, so put the version picker
57-
# link here as well:
58-
option_on = " selected"
59-
option_off = ""
60-
release_opt = option_on
61-
option_nr = 0
62-
version = (
63-
version
64-
+ """<br/>
65-
Version select: <select onChange="window.location.href = this.value" onFocus="this.selectedIndex = {0}">
66-
<option value="https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html"{1}>Current release</option>
67-
<option value="https://docs.nvidia.com/deeplearning/transformer-engine/documentation-archive.html">Older releases</option>
68-
</select>""".format(
69-
option_nr, release_opt
70-
)
71-
)
47+
version = str(_raw_version)
48+
release = _raw_version
7249

7350
# -- General configuration ---------------------------------------------------
7451
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@@ -92,12 +69,10 @@
9269

9370
pygments_style = "sphinx"
9471

95-
9672
# -- Options for HTML output -------------------------------------------------
9773
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
9874

9975
html_theme = "sphinx_rtd_theme"
100-
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
10176
html_static_path = ["_static"]
10277
html_show_sphinx = False
10378

@@ -106,7 +81,12 @@
10681
"css/nvidia_footer.css",
10782
]
10883

109-
html_theme_options = {"display_version": True, "collapse_navigation": False, "logo_only": False}
84+
html_theme_options = {
85+
"collapse_navigation": False,
86+
"logo_only": False,
87+
"version_selector": False,
88+
"language_selector": False,
89+
}
11090

11191
napoleon_custom_sections = [
11292
("Parallelism parameters", "params_style"),
@@ -116,8 +96,8 @@
11696
("FP8-related parameters", "params_style"),
11797
]
11898

119-
breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")}
99+
breathe_projects = {"TransformerEngine": root_path / "docs" / "doxygen" / "xml"}
120100
breathe_default_project = "TransformerEngine"
121101

122102
autoapi_generate_api_docs = False
123-
autoapi_dirs = ["../transformer_engine"]
103+
autoapi_dirs = [root_path / "transformer_engine"]

docs/version_select.patch

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
diff --git a/sphinx_rtd_theme/layout.html b/sphinx_rtd_theme/layout.html
2+
index e6a38b1..579eaec 100644
3+
--- a/sphinx_rtd_theme/layout.html
4+
+++ b/sphinx_rtd_theme/layout.html
5+
@@ -124,6 +124,16 @@
6+
{%- endif %}
7+
</a>
8+
9+
+ {# Show TE version and version selector #}
10+
+ <div class="version">
11+
+ {{ version }}
12+
+ <br>
13+
+ Version select: <select onChange="window.location.href = this.value" onFocus="this.selectedIndex = 0">
14+
+ <option value="https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html" selected>Current release</option>
15+
+ <option value="https://docs.nvidia.com/deeplearning/transformer-engine/documentation-archive.html">Older releases</option>
16+
+ </select>
17+
+ </div>
18+
+
19+
{%- if READTHEDOCS or DEBUG %}
20+
{%- if theme_version_selector or theme_language_selector %}
21+
<div class="switch-menus">

examples/jax/encoder/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
"""Shared functions for the encoder tests"""
5+
from functools import lru_cache
6+
7+
from transformer_engine.transformer_engine_jax import get_device_compute_capability
8+
9+
10+
@lru_cache
11+
def is_bf16_supported():
12+
"""Return whether BF16 is supported by the hardware"""
13+
gpu_arch = get_device_compute_capability(0)
14+
return gpu_arch >= 80

0 commit comments

Comments
 (0)