Don't save fp8 weight tensors if is_first_microbatch is None (#244)
* extend fp8 weight placeholder logic for Linear, LNLinear, LNMLP

Signed-off-by: Sudhakar Singh <[email protected]>

* Update transformer_engine/pytorch/module/base.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>

* Update transformer_engine/pytorch/module/base.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>

* Update transformer_engine/pytorch/module/base.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>

* Update transformer_engine/pytorch/module/base.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>

* Update transformer_engine/pytorch/module/base.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>

* Update transformer_engine/pytorch/module/layernorm_linear.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>

* Update transformer_engine/pytorch/module/layernorm_mlp.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>

* Update transformer_engine/pytorch/module/linear.py

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>

* Update linear.py

Signed-off-by: Sudhakar Singh <[email protected]>

* Update layernorm_linear.py

Signed-off-by: Sudhakar Singh <[email protected]>

* Update layernorm_mlp.py

Signed-off-by: Sudhakar Singh <[email protected]>

* lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
sudhakarsingh27 and ksivaman authored Jun 1, 2023
1 parent 5495883 commit 80825fd
Showing 4 changed files with 149 additions and 12 deletions.
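
For context, here is a minimal usage sketch (not part of this commit) of the `is_first_microbatch` argument whose `None` case this change optimizes. Module and context-manager names follow the Transformer Engine PyTorch API; the sizes and training-loop structure are illustrative assumptions.

```python
# Illustrative sketch only, not part of the diff below. Assumes a CUDA GPU
# with fp8 support and Transformer Engine's PyTorch API (te.Linear,
# te.fp8_autocast, recipe.DelayedScaling); sizes are made up.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(1024, 1024, bias=True)
fp8_recipe = recipe.DelayedScaling()
inp = torch.randn(16, 1024, device="cuda")

# Weight caching on: fp8 weights are computed on the first microbatch and
# reused afterwards, so the module keeps persistent fp8 weight placeholders.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp, is_first_microbatch=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp, is_first_microbatch=False)

# Caching off (is_first_microbatch=None, the default): fp8 weights are
# recomputed every call, so after this commit the module skips the persistent
# placeholders and hands the forward pass freshly allocated scratch tensors.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp, is_first_microbatch=None)
```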
50 changes: 49 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -561,7 +561,11 @@ def prepare_forward(

self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()

# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used.
if is_first_microbatch is not None:
self.set_fp8_weights()

update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if self.fp8 and self.sequence_parallel:
@@ -765,6 +769,50 @@ def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor:

return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])

def get_fp8_weights_empty_tensors(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""
Returns empty tensors to be later used to store fp8 version of weights
and their transposes (for the bwd pass) for this batch (or microbatch).
When `is_first_microbatch` is `None`, this is especially useful since
we then don't need to store the fp8 weights that are needed for one time
only in the forward pass. Note that we still need to store the tensor
for the fp8 weight transpose which is at least needed in the backward
pass but that's taken care of by storing the transpose tensor in
`ctx.save_for_backward`.
"""
assert is_first_microbatch is None, "Should only be here when "\
"`is_first_microbatch` is None!"
fp8_weight_tensors = []
for shape in self.fp8_weight_shapes:
fp8_weight_tensors.append(
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
)
)

fp8_weight_tensors.append(
torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
)
)
return fp8_weight_tensors


@abstractmethod
def forward(self):
"""Needs override."""

@abstractmethod
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""Needs override."""
35 changes: 32 additions & 3 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -4,7 +4,7 @@

"""LayerNormLinear API"""
import os
from typing import Union, Optional, Callable, Tuple, Dict, Any
from typing import Union, Optional, Callable, Tuple, List, Dict, Any


import torch
@@ -791,6 +791,30 @@ def reset_layer_norm_parameters(self) -> None:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)

def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
return [None, None]

if is_first_microbatch is None:
# Return empty weight placeholders for each fwd/bwd pass
fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
is_first_microbatch
)
else:
# These persistent weight placeholders should've been created in
# `set_fp8_weights` method
fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]

return fp8_weight_tensors

def forward(
self,
inp: torch.Tensor,
Expand Down Expand Up @@ -841,6 +865,11 @@ def forward(
else self.noop_cat("weight_tensor", self.weight_names)
)

# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
is_first_microbatch
)

if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
args = []
@@ -852,8 +881,8 @@ def forward(
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
weight1_fp8,
weight1_t_fp8,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
41 changes: 36 additions & 5 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -4,7 +4,7 @@

"""LayerNormMLP API"""
import os
from typing import Union, Optional, Callable, Tuple, Dict, Any
from typing import Union, Optional, Callable, Tuple, List, Dict, Any

import torch
from torch.nn.parameter import Parameter
@@ -1063,6 +1063,31 @@ def reset_layer_norm_parameters(self) -> None:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)

def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
return [None, None, None, None]

if is_first_microbatch is None:
# Return empty weight placeholders for each fwd/bwd pass
fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
is_first_microbatch
)
else:
# These persistent weight placeholders should've been created in
# `set_fp8_weights` method
fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8,
self.weight2_fp8, self.weight2_t_fp8]

return fp8_weight_tensors

def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
@@ -1089,6 +1114,12 @@ def forward(
"""

with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \
self.get_fp8_weights_scratchpad(
is_first_microbatch
)

if torch.is_grad_enabled():
fwd_fn = _LayerNormMLP.apply
args = []
@@ -1100,13 +1131,13 @@ def forward(
self.layer_norm_weight,
self.layer_norm_bias,
self.fc1_weight,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
weight1_fp8,
weight1_t_fp8,
self.fc1_bias,
self.use_bias,
self.fc2_weight,
self.weight2_fp8 if self.fp8 else None,
self.weight2_t_fp8 if self.fp8 else None,
weight2_fp8,
weight2_t_fp8,
self.fc2_bias,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
35 changes: 32 additions & 3 deletions transformer_engine/pytorch/module/linear.py
@@ -3,7 +3,7 @@
# See LICENSE for license information.

"""Linear API"""
from typing import Union, Optional, Callable, Tuple, Dict, Any
from typing import Union, Optional, Callable, Tuple, List, Dict, Any

import torch
from torch.nn.parameter import Parameter
@@ -641,6 +641,30 @@ def __init__(
else:
self.gemm_bias_unfused_add = False

def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
return [None, None]

if is_first_microbatch is None:
# Return empty weight placeholders for each fwd/bwd pass
fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
is_first_microbatch
)
else:
# These persistent weight placeholders should've been created in
# `set_fp8_weights` method
fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]

return fp8_weight_tensors

def forward(
self,
inp: torch.Tensor,
Expand Down Expand Up @@ -691,6 +715,11 @@ def forward(
else self.noop_cat("weight_tensor", self.weight_names)
)

# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
is_first_microbatch
)

if torch.is_grad_enabled():
linear_fn = _Linear.apply
args = []
@@ -699,8 +728,8 @@ def forward(
args = [None]
args += (
weight_tensor,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
weight1_fp8,
weight1_t_fp8,
inp,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
