Skip to content

Commit

Permalink
add ref
Browse files Browse the repository at this point in the history
Signed-off-by: Youngeun Kwon <[email protected]>
  • Loading branch information
youngeunkwon0405 committed Dec 5, 2024
1 parent daba5a6 commit 689e30a
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def fsdp_pre_all_gather(self, mesh): # pylint: disable=unused-argument
"""
A hook function used in torch fsdp2, called before all-gather
return (all-gather input), (metadata)
Ref: https://github.com/pytorch/pytorch/pull/122908
"""

Expand All @@ -463,6 +464,7 @@ def fsdp_post_all_gather(
"""
A hook function used in torch fsdp2, called after all-gather
return (Float8Tensor class instance of all-gathered input), (Things to free after forward)
Ref: https://github.com/pytorch/pytorch/pull/122908
"""
(data,) = all_gather_outputs
Expand Down

0 comments on commit 689e30a

Please sign in to comment.