From 689e30a017b1f22fe3ab0b6984e3ebecedf56f28 Mon Sep 17 00:00:00 2001 From: Youngeun Kwon Date: Wed, 4 Dec 2024 17:49:03 -0800 Subject: [PATCH] add ref Signed-off-by: Youngeun Kwon --- transformer_engine/pytorch/tensor/float8_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 32cb8e7c28..414e819f53 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -447,6 +447,7 @@ def fsdp_pre_all_gather(self, mesh): # pylint: disable=unused-argument """ A hook function used in torch fsdp2, called before all-gather return (all-gather input), (metadata) + Ref: https://github.com/pytorch/pytorch/pull/122908 """ @@ -463,6 +464,7 @@ def fsdp_post_all_gather( """ A hook function used in torch fsdp2, called after all-gather return (Float8Tensor class instance of all-gathered input), (Things to free after forward) + Ref: https://github.com/pytorch/pytorch/pull/122908 """ (data,) = all_gather_outputs