Commit 095b27d

[PyTorch] Userbuffers support in operation-based API (#1142)
* Add Userbuffers support for column TP linear layer
  Signed-off-by: Tim Moon <[email protected]>
* Add Userbuffers support for row TP linear layer
  Signed-off-by: Tim Moon <[email protected]>
* Interpret linear+RS as row TP linear
  Signed-off-by: Tim Moon <[email protected]>
* Add Userbuffers support for FP8 row TP linear layer
  Assumes FP8 RS, which is not a good assumption.
  Signed-off-by: Tim Moon <[email protected]>
* Debug bug with incorrect bias pointers in UB GEMM
  Bias pointers are not properly offset for different data chunks. Also removed logic for FP8 RS.
  Signed-off-by: Tim Moon <[email protected]>
* Add Userbuffers support for linear dgrad
  Test passes with row TP, fails with col TP.
  Signed-off-by: Tim Moon <[email protected]>
* Add Userbuffers support for linear wgrad
  Signed-off-by: Tim Moon <[email protected]>
* Add support for grad bias
  Signed-off-by: Tim Moon <[email protected]>
* Fused cast-transpose-dbias
  Signed-off-by: Tim Moon <[email protected]>
* Support case where wgrad is optional
  Signed-off-by: Tim Moon <[email protected]>
* Expand documentation
  Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Fix linter warnings
  Signed-off-by: Tim Moon <[email protected]>
* Use recently added convenience functions in Float8Tensor
  Signed-off-by: Tim Moon <[email protected]>
* Respect autograd dtype
  Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Fix missing imports
  Signed-off-by: Tim Moon <[email protected]>
* Respect PyT autocast dtype in bprop
  Signed-off-by: Tim Moon <[email protected]>
* Fix linter warnings
  Signed-off-by: Tim Moon <[email protected]>
* Debug merge conflicts
  Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 77c37d4 commit 095b27d

11 files changed: +2033 -10 lines changed

qa/L1_pytorch_distributed_unittest/test.sh

Lines changed: 1 addition & 0 deletions

@@ -10,4 +10,5 @@ pip install pytest==8.2.1
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
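
The commit message mentions that bias pointers in the UB GEMM were not offset for different data chunks. The sketch below is not Transformer Engine code; it is a pure-Python illustration of that class of bug, under the assumption that the GEMM output is split into chunks along the output-feature dimension. Every name in it (`matmul_bias`, `chunked_matmul_bias`, `offset_bias`) is hypothetical and chosen only for this example.

```python
# Illustrative only: not Transformer Engine code. All names are hypothetical.
# Assumption: the GEMM output is split into equal chunks along the
# output-feature (column) dimension, and each chunk adds its own bias slice.

def matmul_bias(x, w, bias):
    """Reference result: y[i][j] = sum_k x[i][k] * w[k][j] + bias[j]."""
    inner = len(w)
    return [
        [sum(row[k] * w[k][j] for k in range(inner)) + bias[j]
         for j in range(len(bias))]
        for row in x
    ]

def chunked_matmul_bias(x, w, bias, num_chunks, offset_bias=True):
    """Compute the same product one output-column chunk at a time."""
    n = len(bias)
    assert n % num_chunks == 0, "sketch assumes equal-sized chunks"
    chunk = n // num_chunks
    inner = len(w)
    out = [[0.0] * n for _ in x]
    for c in range(num_chunks):
        start = c * chunk
        # Correct behaviour: the bias slice starts at the same offset as the
        # weight columns of this chunk. With offset_bias=False every chunk
        # reuses the first bias slice, which mimics an un-offset pointer.
        bias_start = start if offset_bias else 0
        for i, row in enumerate(x):
            for j in range(chunk):
                acc = sum(row[k] * w[k][start + j] for k in range(inner))
                out[i][start + j] = acc + bias[bias_start + j]
    return out

x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]
bias = [10.0, 20.0, 30.0, 40.0]

reference = matmul_bias(x, w, bias)
fixed = chunked_matmul_bias(x, w, bias, num_chunks=2)
buggy = chunked_matmul_bias(x, w, bias, num_chunks=2, offset_bias=False)
```

In this toy setup `fixed` matches `reference`, while `buggy` differs only in the columns of the second chunk, the kind of mismatch a numerics test against an unchunked reference can catch.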
