[PyTorch] Refactor parameter splitting in Linear and LayerNormLinear #590


Merged · 6 commits · Jan 8, 2024
Changes from all commits
15 changes: 10 additions & 5 deletions transformer_engine/pytorch/attention.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

"""Attention."""
import collections
import os
import warnings
import math
@@ -2717,9 +2718,13 @@ def __init__(
qkv_parallel_mode = "column" if set_parallel_mode else None

if self.attention_type == "self":
parameters_split = {"query_": hidden_size,
"key_": self.hidden_size_kv,
"value_": self.hidden_size_kv} if not fuse_qkv_params else None
parameters_split = None
if not fuse_qkv_params:
parameters_split = collections.OrderedDict([
("query", hidden_size),
("key", self.hidden_size_kv),
("value", self.hidden_size_kv),
])
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
@@ -2761,7 +2766,7 @@ def __init__(
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("query_",) if not fuse_qkv_params else None,
parameters_split=("query",) if not fuse_qkv_params else None,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
@@ -2789,7 +2794,7 @@ def __init__(
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("key_", "value_") if not fuse_qkv_params else None,
parameters_split=("key", "value") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)

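To make the new calling convention concrete, here is a minimal standalone sketch; the concrete sizes and the combined output dimension are illustrative assumptions, not values taken from this diff:

import collections

from transformer_engine.pytorch import LayerNormLinear

# Hypothetical sizes, for illustration only.
hidden_size = 1024
hidden_size_kv = 256

# After this refactor, parameters_split is an ordered mapping from a parameter
# name to its share of the output dimension; the trailing underscore that the
# old keys carried ("query_", "key_", "value_") is no longer part of the key.
parameters_split = collections.OrderedDict([
    ("query", hidden_size),
    ("key", hidden_size_kv),
    ("value", hidden_size_kv),
])

layernorm_qkv = LayerNormLinear(
    hidden_size,                        # in_features
    hidden_size + 2 * hidden_size_kv,   # out_features, split across the entries above
    parameters_split=parameters_split,
)

When `fuse_qkv_params=True`, the attention module keeps passing `parameters_split=None`, exactly as before.
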
96 changes: 95 additions & 1 deletion transformer_engine/pytorch/module/_common.py
@@ -4,7 +4,7 @@

"""Internal function used by multiple modules."""

from typing import Union, Dict, Any
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

@@ -93,3 +93,97 @@ def _apply_normalization(inputmat:torch.Tensor,
elif normalization == "LayerNorm":
output = (ln_out, output[1], output[2])
return output


class _NoopCatFunc(torch.autograd.Function):
"""No-op concatenate tensors along dim 0

`full_tensor` is assumed to already be the concatenation of
`tensors`, i.e. they occupy the same memory with the correct
offsets.

"""

@staticmethod
def forward(
ctx,
split_ranges: List[Tuple[int, int]],
full_tensor: torch.Tensor,
*tensors: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
# pylint: disable=unused-argument
ctx.split_ranges = split_ranges
assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient"
out = full_tensor.new()
out.set_(
full_tensor.untyped_storage(),
full_tensor.storage_offset(),
full_tensor.size(),
full_tensor.stride(),
)
out.requires_grad = True
return out

@staticmethod
def backward(
ctx,
grad_output: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
grads = [
grad_output[split_start:split_end]
for split_start, split_end in ctx.split_ranges
]
return None, None, *grads


def _noop_cat(
tensors: List[torch.Tensor],
full_tensor: torch.Tensor,
) -> torch.Tensor:
"""Concatenate tensors along dim 0, doing a no-op if possible

If `full_tensor` is already the concatenation of `tensors`, i.e.
they occupy the same memory region with the correct offsets, then
no copies are performed. Otherwise the buffers in all the tensors
are reallocated so that another call would result in a no-op.

In the backward pass, the gradients for `tensors` are just
tensor views into the output gradient.

"""

# Determine split points
split_ranges = []
full_tensor_shape = full_tensor.size()
offset = 0
for tensor in tensors:
tensor_shape = tensor.size()
if tensor_shape[1:] != full_tensor_shape[1:]:
raise ValueError(
f"Attempting to concatenate tensor with shape={list(tensor_shape)} "
f"into a tensor with shape={list(full_tensor_shape)}"
)
split_start = offset
offset += tensor_shape[0]
split_end = offset
split_ranges.append((split_start, split_end))
if offset != full_tensor_shape[0]:
raise ValueError(
f"Attempting to concatenate tensors with total shape[0]={offset} "
f"into a tensor with shape[0]={full_tensor_shape[0]}"
)

# Reallocate buffers if no-op concat isn't possible
need_to_reallocate = False
for tensor, (split_start, _) in zip(tensors, split_ranges):
if tensor.data_ptr() != full_tensor[split_start].data_ptr():
need_to_reallocate = True
break
if need_to_reallocate:
with torch.no_grad():
full_tensor.data = torch.cat(tensors)
for tensor, (split_start, split_end) in zip(tensors, split_ranges):
tensor.data = full_tensor[split_start:split_end]

# Perform no-op concat
return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors)
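
Below is a rough, self-contained sketch of the behavior `_noop_cat` provides; the shapes are made up, and `_noop_cat` is an internal helper imported here only for illustration. If the input tensors already alias consecutive slices of `full_tensor`, no copy is performed; otherwise the buffer is rebuilt once and the inputs are re-pointed at it so that the next call is a true no-op.

import torch

from transformer_engine.pytorch.module._common import _noop_cat

# Hypothetical 6-row buffer split into a 2-row and a 4-row piece.
full = torch.randn(6, 8)
query = full[0:2]        # views sharing full's storage at the right offsets
key_value = full[2:6]

out = _noop_cat([query, key_value], full)
assert out.data_ptr() == full.data_ptr()      # no copy was made

# With separately allocated tensors, the buffer is reallocated once and the
# inputs are re-pointed at slices of it, so a second call would be a no-op.
q = torch.randn(2, 8)
kv = torch.randn(4, 8)
out = _noop_cat([q, kv], full)
assert q.data_ptr() == full[0:2].data_ptr()
assert kv.data_ptr() == full[2:6].data_ptr()

In the backward pass, per `_NoopCatFunc.backward`, the gradient of each input is simply the corresponding slice of the output gradient.
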
73 changes: 0 additions & 73 deletions transformer_engine/pytorch/module/base.py
@@ -14,7 +14,6 @@

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import transformer_engine_extensions as tex
from ..export import is_in_onnx_export_mode
@@ -213,44 +212,6 @@ def get_ub(name: str):
return _ub_communicators[name]


class _NoopCat(torch.autograd.Function):
"""This class is a no-op replacement for `torch.cat`."""

@staticmethod
def forward(ctx,
full_param_buffer: torch.Tensor,
*params_split: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
sum_params_shape = sum(p.shape[0] for p in params_split)
assert (
full_param_buffer.shape[0] == sum_params_shape
), "Dimensions not compatible for concatenation"

param_temp = full_param_buffer.new()
param_temp.set_(full_param_buffer.untyped_storage(),
full_param_buffer.storage_offset(),
full_param_buffer.size(),
full_param_buffer.stride())
param_temp.requires_grad = True

ctx.save_for_backward(*params_split)
return param_temp

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
params_split = ctx.saved_tensors
grads = []
slice_begin = 0
for i, _ in enumerate(params_split):
slice_size = params_split[i].shape[0]
slice_end = slice_begin + slice_size
grads.append(grad_output[slice_begin:slice_end])
slice_begin = slice_end

return None, *grads


class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""

@@ -742,40 +703,6 @@ def grad_output_preprocess(

return grad_output_mat, grad_output_c, grad_output_t, grad_bias

def noop_cat(self,
buffer_name: str,
pnames: List[str],
parameters_split: Dict[str, int]
) -> torch.Tensor:
"""No-op replacement of `torch.cat`. The buffer and split parameters must occupy
the same memory region. If this is not the case, then the split parameters
are concatenated and the buffer is overwritten. The parameters' memory is then
re-assigned to point to the buffer to avoid subsequent concatenations.
"""

assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
full_param_buffer = getattr(self, buffer_name)
params = [getattr(self, name) for name in pnames]
slice_begin = 0
for i, p in enumerate(params):
slice_size = parameters_split[pnames[i].split('_')[0]+'_']
slice_end = slice_begin + slice_size
if p.data.data_ptr() != full_param_buffer[slice_begin:slice_end].data_ptr():
with torch.no_grad():
setattr(self, buffer_name, torch.cat(params))
slice_begin_j = 0
for pname in pnames:
slice_size_j = parameters_split[pname.split('_')[0]+'_']
slice_end_j = slice_begin_j + slice_size_j
full_param_buffer = getattr(self, buffer_name)
setattr(self, pname,
Parameter(full_param_buffer[slice_begin_j:slice_end_j]))
slice_begin_j = slice_end_j
break
slice_begin = slice_end

return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])

def get_fp8_weights_empty_tensors(
self,
is_first_microbatch: Union[bool, None],