[PyTorch] Add Float8Tensor option to avoid updating transpose cache when possible (#662)

* Add option to avoid updating transpose cache when possible

Signed-off-by: Tim Moon <[email protected]>

* Fix typo

Signed-off-by: Tim Moon <[email protected]>

* Use string kwarg for FP8 transpose caching

Signed-off-by: Tim Moon <[email protected]>

* Remove unused attr

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
timmoon10 authored Feb 15, 2024
1 parent bdf1afe commit 1e78094
Showing 5 changed files with 56 additions and 16 deletions.
28 changes: 25 additions & 3 deletions tests/pytorch/test_float8tensor.py
@@ -298,22 +298,44 @@ def test_transpose(

# Check transpose caching
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:

# Check that cached transpose is returned when expected
# Note: Sneakily destroy data so that recalculating
# transpose would give wrong answer.
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
- x_fp8.transpose(*transpose_dims, update_cache=True),
+ x_fp8.transpose(*transpose_dims, update_cache="lazy"),
x_ref.transpose(*transpose_dims),
**tols,
)
x_fp8_data = x_fp8._data.clone()
x_fp8._data.zero_()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
- x_fp8.transpose(*transpose_dims, update_cache=True),
+ x_fp8.transpose(*transpose_dims, update_cache="lazy"),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="force"),
torch.zeros_like(x_ref.transpose(*transpose_dims)),
rtol=0,
atol=0,
)
x_fp8._data.copy_(x_fp8_data)
x_fp8._reset_caches()

# Make sure cache is reset after in-place operation
x_fp8.transpose(*transpose_dims, update_cache="force")
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
- x_fp8.transpose(*transpose_dims, update_cache=True),
+ x_fp8.transpose(*transpose_dims),
x_ref.transpose(*transpose_dims),
**tols,
)
31 changes: 22 additions & 9 deletions transformer_engine/pytorch/float8_tensor.py
@@ -440,7 +440,7 @@ def transpose(
dim0: int = 0,
dim1: int = 1,
*,
- update_cache: bool = False,
+ update_cache: str | bool = "reuse_only",
) -> torch.Tensor:
"""
Swap tensor dimensions
@@ -454,31 +454,44 @@
The first dimension to be transposed
dim1: int, default = 1
The second dimension to be transposed
- update_cache: bool, default = False
-   If `True`, the transpose is computed and stored
-   in a cache. If `False`, a cached version is
-   returned if available and otherwise the
-   transpose is computed. Caching is only supported
+ update_cache: str or bool, default = "reuse_only"
+   Memoization behavior. Options are
+   "reuse_only"/`False` (reuse cached value if
+   available, otherwise calculate transpose without
+   caching), "force"/`True` (calculate transpose
+   and cache), "lazy" (reuse cached value if
+   available, otherwise calculate transpose and
+   cache if possible). Caching is only supported
for basic 2D transposes and the cache is reset
after any in-place operations.
"""

+ # Check caching mode
+ if not isinstance(update_cache, str):
+     update_cache = "force" if update_cache else "reuse_only"
+ if update_cache not in ("force", "reuse_only", "lazy"):
+     raise ValueError(
+         "Supported values for update_cache are "
+         '"force" (True), "reuse_only" (False), "lazy" '
+         f"(got {update_cache})"
+     )

# Handle non-2D transposes
if -self.dim() <= dim0 < 0:
dim0 += self.dim()
if -self.dim() <= dim1 < 0:
dim1 += self.dim()
if self.dim() != 2 or dim0 == dim1:
- if update_cache:
+ if update_cache == "force":
raise ValueError(
"Transpose caching is only supported for basic 2D transposes "
f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
)
return super().transpose(dim0, dim1)

# Clear cache if needed
- if update_cache:
+ if update_cache == "force":
self._transpose = None

# Compute transpose if needed
@@ -493,7 +506,7 @@
)

# Update cache if needed
- if update_cache:
+ if update_cache in ("force", "lazy"):
self._transpose = out
return out
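
For reference, a minimal usage sketch of the three caching modes accepted by the new update_cache keyword. This is an illustration, not part of the commit; it assumes the Float8Tensor.to_float8 constructor and a CUDA device are available, and the exact way x_fp8 is built is incidental to the caching behavior shown.

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Build an FP8 tensor (assumed constructor; any Float8Tensor works here).
x_fp8 = Float8Tensor.to_float8(torch.randn(128, 64, device="cuda"))

# "reuse_only" (default, same as False): reuse a cached transpose if one
# exists, otherwise compute it without populating the cache.
x_t = x_fp8.transpose(0, 1, update_cache="reuse_only")

# "lazy": reuse the cached value if available, otherwise compute the
# transpose and cache it if possible.
x_t = x_fp8.transpose(0, 1, update_cache="lazy")

# "force" (same as True): recompute the transpose and overwrite the cache.
x_t = x_fp8.transpose(0, 1, update_cache="force")

# Any in-place operation resets the cache, so a later transpose is recomputed.
x_fp8 += 0.5
x_t = x_fp8.transpose(0, 1)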

4 changes: 3 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -331,7 +331,9 @@ def backward(

# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
- weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
+ weight_t_fp8 = weight.transpose(
+     update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
+ )

if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
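
The same mapping from is_first_microbatch to a caching mode appears in layernorm_mlp.py and linear.py below. As a hedged reading of its intent, the hypothetical helper below restates the inline expression: with no microbatching hint (None) the backward pass only reuses an already-cached weight transpose, while with gradient accumulation (True or False) the weight is stable across microbatches, so the transpose may be lazily cached for reuse by later microbatches.

from typing import Optional

def transpose_cache_mode(is_first_microbatch: Optional[bool]) -> str:
    # Hypothetical helper, not part of the commit: mirrors the expression above.
    return "reuse_only" if is_first_microbatch is None else "lazy"

assert transpose_cache_mode(None) == "reuse_only"
assert transpose_cache_mode(True) == "lazy"
assert transpose_cache_mode(False) == "lazy"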
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -560,10 +560,11 @@ def backward(
fc2_weight.main_grad = fc2_weight_main_grad

# Primary weights are in FP8.
+ update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy"
if ctx.fp8 and fc1_weight_t_fp8 is None:
- fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch)
+ fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache)
if ctx.fp8 and fc2_weight_t_fp8 is None:
- fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=ctx.is_first_microbatch)
+ fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache)

activation_func = _act_func(ctx.activation)[1]

4 changes: 3 additions & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -347,7 +347,9 @@ def backward(

# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
- weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch)
+ weight_t_fp8 = weight.transpose(
+     update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
+ )

if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
