Deprecate unused APIs (#321)
* Deprecate unused APIs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* review comments

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Review

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Jul 14, 2023
1 parent b172bad commit 58d2eba
Showing 2 changed files with 154 additions and 130 deletions.
142 changes: 77 additions & 65 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -4,6 +4,7 @@

"""LayerNormLinear API"""
import os
import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any


@@ -538,6 +539,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
r"""
Applies layer normalization followed by linear transformation to the incoming data.

.. warning::

    Argument :attr:`skip_weight_param_allocation` is deprecated and will
    be fully removed in future releases.

Parameters
----------
in_features : int
@@ -585,9 +591,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
skip_weight_param_allocation: bool, default = `False`
if set to `True`, weight parameter is not allocated and must be
passed as a keyword argument `weight` during the forward pass.
Optimization parameters
-----------------------
@@ -633,6 +636,14 @@ def __init__(
) -> None:
super().__init__()

if skip_weight_param_allocation:
warnings.warn(
"Argument `skip_weight_param_allocation` is deprecated and"
"will be fully removed in future releases. It is ignored"
"starting from v0.11.",
category=DeprecationWarning,
)

params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.in_features = in_features
self.out_features = out_features
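
The user-facing effect of the deprecation warning added in the hunk above can be sketched as follows. This example is not part of the diff; it assumes transformer_engine v0.11 or later with this change applied and an available CUDA device, since parameters are allocated on `torch.cuda.current_device()`.

```python
# Minimal sketch (not part of this commit): the flag is still accepted but
# ignored, and construction now emits a DeprecationWarning. Assumes a CUDA device.
import warnings

import transformer_engine.pytorch as te

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer = te.LayerNormLinear(1024, 1024, skip_weight_param_allocation=True)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(layer.weight.shape)  # parameters are allocated by the module regardless
```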
@@ -695,72 +706,71 @@ def __init__(
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
self.reset_layer_norm_parameters()

if not skip_weight_param_allocation:
self.weight_tensor = torch.empty(
self.out_features, self.in_features,
self.weight_tensor = torch.empty(
self.out_features, self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype)

initialize_affine_weight_gpu(
self.weight_tensor,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)

if self.use_bias:
self.bias_tensor = torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype)
else:
self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device())

initialize_affine_weight_gpu(
self.weight_tensor,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
with torch.no_grad():
self.bias_tensor.zero_()

if self.use_bias:
self.bias_tensor = torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype)
else:
self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device())
if parameters_split is None:
parameters_split = ("",)

with torch.no_grad():
self.bias_tensor.zero_()
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"

if parameters_split is None:
parameters_split = ("",)
split_size = self.out_features // len(parameters_split)

assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
self.weight_names = []
self.bias_names = []

split_size = self.out_features // len(parameters_split)
for i, pname in enumerate(parameters_split):
wname = pname + "weight"
bname = pname + "bias"

self.weight_names = []
self.bias_names = []
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
)

for i, pname in enumerate(parameters_split):
wname = pname + "weight"
bname = pname + "bias"
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)

if self.use_bias:
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()))

set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)

if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()))

if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)

self.weight_names.append(wname)
self.bias_names.append(bname)
self.weight_names.append(wname)
self.bias_names.append(bname)

self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

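The split-parameter registration retained above can be illustrated with a short sketch; the `query_`/`key_`/`value_` prefixes below are hypothetical names, and a CUDA device is assumed.

```python
# Illustrative only: prefixes are hypothetical, out_features must divide
# evenly by the number of splits (see the assert in __init__), and the
# parameters live on the current CUDA device.
import transformer_engine.pytorch as te

hidden = 1024
qkv = te.LayerNormLinear(
    hidden,
    3 * hidden,  # 3072 % 3 == 0, satisfying the divisibility assert
    parameters_split=("query_", "key_", "value_"),
)

# Each prefix owns a view into the single underlying weight/bias storage.
print([name for name, _ in qkv.named_parameters()])
# e.g. [..., 'query_weight', 'query_bias', 'key_weight', 'key_bias',
#       'value_weight', 'value_bias']
```
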
@@ -821,17 +831,15 @@ def forward(
"""
Apply layer normalization to the input followed by a linear transformation.

.. warning::

    Arguments :attr:`weight` and :attr:`bias` are deprecated and will
    be fully removed in future releases.

Parameters
----------
inp : torch.Tensor
Input tensor.
weight : torch.Tensor, default = None
An optional weight tensor for the module. This argument is compulsory if module
is initialized with `skip_weight_param_allocation=True`
bias : torch.Tensor, default = None
An optional bias tensor for the module. This argument is compulsory if module
is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
or `return_bias`
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
@@ -847,16 +855,20 @@
produced)
"""

if weight is not None or bias is not None:
raise RuntimeError(
"Arguments `weight` and `bias` are deprecated and "
"will be fully removed in future releases."
)

with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = (
bias if bias is not None
else self.bias if self.parameters_split is None
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names)
)
weight_tensor = (
weight if weight is not None
else self.weight if self.parameters_split is None
self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names)
)
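On the forward side, here is a sketch of how the deprecated `weight` and `bias` arguments are now rejected while the normal path is unchanged. Again, this is an assumption-laden example rather than part of the diff; it expects a CUDA device and the default float32 params_dtype.

```python
# Sketch only: the module's own parameters are always used; passing the
# deprecated forward arguments raises instead of being silently consumed.
import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(1024, 1024)
inp = torch.randn(8, 1024, device="cuda")

out = layer(inp)  # normal path: module-owned parameters are used

try:
    layer(inp, weight=torch.empty_like(layer.weight))
except RuntimeError as err:
    print(f"rejected deprecated argument: {err}")
```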
