diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 00a7c2003b..568de4e17c 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Attention.""" +import collections import os import warnings import math @@ -2717,9 +2718,13 @@ def __init__( qkv_parallel_mode = "column" if set_parallel_mode else None if self.attention_type == "self": - parameters_split = {"query_": hidden_size, - "key_": self.hidden_size_kv, - "value_": self.hidden_size_kv} if not fuse_qkv_params else None + parameters_split = None + if not fuse_qkv_params: + parameters_split = collections.OrderedDict([ + ("query", hidden_size), + ("key", self.hidden_size_kv), + ("value", self.hidden_size_kv), + ]) if self.input_layernorm: self.layernorm_qkv = LayerNormLinear( hidden_size, @@ -2761,7 +2766,7 @@ def __init__( bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, - parameters_split=("query_",) if not fuse_qkv_params else None, + parameters_split=("query",) if not fuse_qkv_params else None, return_layernorm_output=return_layernorm_output, zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, @@ -2789,7 +2794,7 @@ def __init__( bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, - parameters_split=("key_", "value_") if not fuse_qkv_params else None, + parameters_split=("key", "value") if not fuse_qkv_params else None, **common_gemm_kwargs, ) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index b8f4a75f7d..edc3da120d 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -4,7 +4,7 @@ """Internal function used by multiple modules.""" -from typing import Union, Dict, Any +from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -93,3 +93,97 @@ def _apply_normalization(inputmat:torch.Tensor, elif normalization == "LayerNorm": output = (ln_out, output[1], output[2]) return output + + +class _NoopCatFunc(torch.autograd.Function): + """No-op concatenate tensors along dim 0 + + `full_tensor` is assumed to already be the concatenation of + `tensors`, i.e. they occupy the same memory with the correct + offsets. + + """ + + @staticmethod + def forward( + ctx, + split_ranges: List[Tuple[int, int]], + full_tensor: torch.Tensor, + *tensors: Tuple[torch.Tensor, ...], + ) -> torch.Tensor: + # pylint: disable=unused-argument + ctx.split_ranges = split_ranges + assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient" + out = full_tensor.new() + out.set_( + full_tensor.untyped_storage(), + full_tensor.storage_offset(), + full_tensor.size(), + full_tensor.stride(), + ) + out.requires_grad = True + return out + + @staticmethod + def backward( + ctx, + grad_output: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + grads = [ + grad_output[split_start:split_end] + for split_start, split_end in ctx.split_ranges + ] + return None, None, *grads + + +def _noop_cat( + tensors: List[torch.Tensor], + full_tensor: torch.Tensor, +) -> torch.Tensor: + """Concatenate tensors along dim 0, doing a no-op if possible + + If `full_tensor` is already the concatenation of `tensors`, i.e. + they occupy the same memory region with the correct offsets, then + no copies are performed. Otherwise the buffers in all the tensors + are reallocated so that another call would result in a no-op. 
+ + In the backward pass, gradients to `partial_tensors` will just be + tensor views. + + """ + + # Determine split points + split_ranges = [] + full_tensor_shape = full_tensor.size() + offset = 0 + for tensor in tensors: + tensor_shape = tensor.size() + if tensor_shape[1:] != full_tensor_shape[1:]: + raise ValueError( + f"Attempting to concatenate tensor with shape={list(tensor_shape)} " + f"into a tensor with shape={list(full_tensor_shape)}" + ) + split_start = offset + offset += tensor_shape[0] + split_end = offset + split_ranges.append((split_start, split_end)) + if offset != full_tensor_shape[0]: + raise ValueError( + f"Attempting to concatenate tensors with total shape[0]={offset} " + f"into a tensor with shape[0]={full_tensor_shape[0]}" + ) + + # Reallocate buffers if no-op concat isn't possible + need_to_reallocate = False + for tensor, (split_start, _) in zip(tensors, split_ranges): + if tensor.data_ptr() != full_tensor[split_start].data_ptr(): + need_to_reallocate = True + break + if need_to_reallocate: + with torch.no_grad(): + full_tensor.data = torch.cat(tensors) + for tensor, (split_start, split_end) in zip(tensors, split_ranges): + tensor.data = full_tensor[split_start:split_end] + + # Perform no-op concat + return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 05bb35efec..cf9634b2cc 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -14,7 +14,6 @@ import torch import torch.nn.functional as F -from torch.nn.parameter import Parameter import transformer_engine_extensions as tex from ..export import is_in_onnx_export_mode @@ -213,44 +212,6 @@ def get_ub(name: str): return _ub_communicators[name] -class _NoopCat(torch.autograd.Function): - """This class is a no-op replacement for `torch.cat`.""" - - @staticmethod - def forward(ctx, - full_param_buffer: torch.Tensor, - *params_split: Tuple[torch.Tensor, ...], - ) -> torch.Tensor: - assert not full_param_buffer.requires_grad, "Buffers should not require gradient" - sum_params_shape = sum(p.shape[0] for p in params_split) - assert ( - full_param_buffer.shape[0] == sum_params_shape - ), "Dimensions not compatible for concatenation" - - param_temp = full_param_buffer.new() - param_temp.set_(full_param_buffer.untyped_storage(), - full_param_buffer.storage_offset(), - full_param_buffer.size(), - full_param_buffer.stride()) - param_temp.requires_grad = True - - ctx.save_for_backward(*params_split) - return param_temp - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - params_split = ctx.saved_tensors - grads = [] - slice_begin = 0 - for i, _ in enumerate(params_split): - slice_size = params_split[i].shape[0] - slice_end = slice_begin + slice_size - grads.append(grad_output[slice_begin:slice_end]) - slice_begin = slice_end - - return None, *grads - - class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" @@ -742,40 +703,6 @@ def grad_output_preprocess( return grad_output_mat, grad_output_c, grad_output_t, grad_bias - def noop_cat(self, - buffer_name: str, - pnames: List[str], - parameters_split: Dict[str, int] - ) -> torch.Tensor: - """No-op replacement of `torch.cat`. The buffer and split parameters must occupy - the same memory region. If this is not the case, then the split parameters - are concatenated and the buffer is overwritten. 
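
As a standalone aside (not part of the patch itself): the new `_noop_cat` helper assumes each split parameter is a dim-0 view into the shared buffer at the correct offset, and it verifies this by comparing data pointers. A minimal sketch of that invariant in plain PyTorch, with made-up sizes:

    import torch

    # A contiguous "full" buffer and three dim-0 views into it, analogous
    # to weight_tensor and the per-parameter splits (e.g. query/key/value).
    full = torch.empty(6, 4)
    splits = [full[0:2], full[2:4], full[4:6]]

    # Each view starts exactly where the corresponding slice of the buffer
    # starts, so concatenating them can be a no-op.
    for view, start in zip(splits, (0, 2, 4)):
        assert view.data_ptr() == full[start].data_ptr()

    # torch.cat allocates fresh memory; this is the one-time copy that
    # _noop_cat falls back to when the views are misaligned.
    assert torch.cat(splits).data_ptr() != full.data_ptr()

This mirrors the `data_ptr()` comparison in `_noop_cat` above: if it fails, the helper re-concatenates once, rebinds the split tensors to views of the new buffer, and later calls take the no-op path again.
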
The parameters' memory is then - re-assigned to point to the buffer to avoid subsequent concatenations. - """ - - assert hasattr(self, buffer_name), f"No buffer named {buffer_name}" - full_param_buffer = getattr(self, buffer_name) - params = [getattr(self, name) for name in pnames] - slice_begin = 0 - for i, p in enumerate(params): - slice_size = parameters_split[pnames[i].split('_')[0]+'_'] - slice_end = slice_begin + slice_size - if p.data.data_ptr() != full_param_buffer[slice_begin:slice_end].data_ptr(): - with torch.no_grad(): - setattr(self, buffer_name, torch.cat(params)) - slice_begin_j = 0 - for pname in pnames: - slice_size_j = parameters_split[pname.split('_')[0]+'_'] - slice_end_j = slice_begin_j + slice_size_j - full_param_buffer = getattr(self, buffer_name) - setattr(self, pname, - Parameter(full_param_buffer[slice_begin_j:slice_end_j])) - slice_begin_j = slice_end_j - break - slice_begin = slice_end - - return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames]) - def get_fp8_weights_empty_tensors( self, is_first_microbatch: Union[bool, None], diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fb36d3427b..d36d5a9923 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -7,9 +7,7 @@ import warnings from typing import Union, Optional, Callable, Tuple, List, Dict, Any - import torch -from torch.nn.parameter import Parameter from torch.nn import init from .. import cpp_extensions as tex @@ -41,7 +39,7 @@ ) from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo -from ._common import _apply_normalization +from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor @@ -612,13 +610,13 @@ class LayerNormLinear(TransformerEngineBaseModule): Example use case: residual connection for transformer module is taken post layernorm. parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None - if a tuple of strings or a dict of strings to integers is provided, - the weight and bias parameters of the module are exposed as `N` separate - `torch.nn.parameter.Parameter`s each, split along the first dimension, - where `N` is the length of the argument and the strings contained are the - names of the split parameters. In the case of a tuple, each parameter - has the same shape. In the case of a dict, the values give the - `out_features` for each projection. + Configuration for splitting the weight and bias tensors along dim 0 into + multiple PyTorch parameters. If a list or tuple of strings is provided, + they are used to make the names of equally-sized parameters. If a dict + (preferably an OrderedDict) is provided, the keys are used as names and + values as split sizes along dim 0. The resulting parameters will have + names that end in `_weight` or `_bias`, so trailing underscores are + stripped from any provided names. 
zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -705,7 +703,6 @@ def __init__( self.return_bias = return_bias self.apply_bias = self.use_bias and not return_bias self.return_layernorm_output = return_layernorm_output - self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad @@ -752,12 +749,12 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.eps = eps - self.layer_norm_weight = Parameter( + self.layer_norm_weight = torch.nn.Parameter( torch.empty(in_features, device=device, dtype=params_dtype) ) setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) if self.normalization != "RMSNorm": - self.layer_norm_bias = Parameter( + self.layer_norm_bias = torch.nn.Parameter( torch.empty(in_features, device=device, dtype=params_dtype) ) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) @@ -800,68 +797,100 @@ def __init__( with torch.no_grad(): self.bias_tensor.zero_() + # Configure parameter splits + self.weight_names = [] + self.bias_names = [] + self.parameter_split_sizes = [] if parameters_split is None: - parameters_split = {"": self.out_features} - elif isinstance(parameters_split, tuple): - assert ( - self.out_features % len(parameters_split) == 0 - ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" - split_size = self.out_features // len(parameters_split) - parameters_split = {key: split_size for key in parameters_split} + # Split into a single parameter by default + self.weight_names = ["weight"] + self.bias_names = ["bias"] + self.parameter_split_sizes = [out_features] + elif not parameters_split: + raise ValueError("Cannot split weight buffer into 0 parameters") elif isinstance(parameters_split, dict): - overall_split_size = sum(parameters_split.values()) - assert( - self.out_features == overall_split_size - ), f"Overall sum of parameters_split (={overall_split_size}) does not match "\ - f"to out features (={self.out_features})" + # Split parameters with provided sizes + for name, split_size in parameters_split.items(): + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) + elif all(isinstance(name, str) for name in parameters_split): + # Split parameters evenly + split_size = out_features // len(parameters_split) + for name in parameters_split: + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) else: - assert False, "Type of 'parameters_split' is not None, tuple or dict" - self.updated_parameters_split = parameters_split + raise TypeError("Invalid configuration for parameters split") - self.weight_names = [] - self.bias_names = [] + # Make sure parameter splits are valid + if sum(self.parameter_split_sizes) != out_features: + raise ValueError( + f"Trying to split weight buffer ({out_features=}) " + f"with split sizes {self.parameter_split_sizes}" + ) - slice_begin = 0 - for pname, slice_size in parameters_split.items(): - wname = pname + "weight" - bname = pname + "bias" - - slice_end = slice_begin + slice_size - # NOTE(future): Figure out a way to support slicing when weights - # are of `Float8Tensor` class - if 
self.primary_weights_in_fp8: - assert len(parameters_split) == 1, ("Slicing operation is not " - "supported in Float8Tensor " - "class!") - self.register_parameter(wname, Parameter(self.weight_tensor)) - else: - self.register_parameter( - wname, Parameter(self.weight_tensor[slice_begin:slice_end]) + # Adjust parameter splits for tensor-parallel distribution + if self.parallel_mode == "column": + for i, size in enumerate(self.parameter_split_sizes): + if size % self.tp_size != 0: + raise RuntimeError( + f"Attempting to distribute a parameter with out_features={size} " + f"between {self.tp_size} tensor-parallel processes" + ) + self.parameter_split_sizes[i] = size // self.tp_size + + # Construct parameters from weight and bias buffers + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + + # Check if parameters are subviews of buffers + is_subview = (split_start, split_end) != (0, self.out_features) + if is_subview and self.primary_weights_in_fp8: + raise RuntimeError( + "Splitting Float8Tensor into multiple params " + "is not supported" ) - set_tensor_model_parallel_attributes( - tensor=getattr(self, wname), - is_parallel=True, - dim=1 if parallel_mode == "row" else 0, - stride=1, - ) + # Construct weight parameter + weight = self.weight_tensor + if is_subview: + weight = weight[split_start:split_end] + weight = torch.nn.Parameter(weight) + self.register_parameter(self.weight_names[i], weight) + # Construct bias parameter if needed if self.use_bias: - self.register_parameter( - bname, Parameter(self.bias_tensor[slice_begin:slice_end]) - ) + bias = self.bias_tensor + if is_subview: + bias = bias[split_start:split_end] + bias = torch.nn.Parameter(bias) + self.register_parameter(self.bias_names[i], bias) if parallel_mode == "row": - setattr(getattr(self, bname), "sequence_parallel", sequence_parallel) + bias.sequence_parallel = sequence_parallel else: - setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device)) + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, self.bias_names[i], bias) + # Configure tensor parallelism + set_tensor_model_parallel_attributes( + tensor=weight, + is_parallel=True, + dim=1 if parallel_mode == "row" else 0, + stride=1, + ) if parallel_mode == "column": - set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) + set_tensor_model_parallel_attributes(bias, True, 0, 1) - self.weight_names.append(wname) - self.bias_names.append(bname) - - slice_begin = slice_end + # Concatenated tensors are not needed if not splitting + # into multiple parameters + if not is_subview: + del self.weight_tensor + del self.bias_tensor self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) @@ -880,12 +909,6 @@ def __init__( self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - # Clean up weight and bias buffers - if self.parameters_split is None: - del self.weight_tensor - if self.use_bias: - del self.bias_tensor - def reset_layer_norm_parameters(self) -> None: """Init LN params""" if not self.zero_centered_gamma: @@ -950,18 +973,26 @@ def forward( with self.prepare_forward(inp, is_first_microbatch) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." 
- bias_tensor = ( - self.bias if self.parameters_split is None - else self.bias_tensor if not torch.is_grad_enabled() - else self.noop_cat("bias_tensor", self.bias_names, - self.updated_parameters_split) - ) - weight_tensor = ( - self.weight if self.parameters_split is None - else self.weight_tensor if not torch.is_grad_enabled() - else self.noop_cat("weight_tensor", self.weight_names, - self.updated_parameters_split) - ) + + # Get concatenated weight and bias tensors + if len(self.parameter_split_sizes) == 1: + weight_tensor = getattr(self, self.weight_names[0]) + bias_tensor = getattr(self, self.bias_names[0]) + elif torch.is_grad_enabled(): + weight_tensor = _noop_cat( + [getattr(self, name) for name in self.weight_names], + self.weight_tensor, + ) + if self.use_bias: + bias_tensor = _noop_cat( + [getattr(self, name) for name in self.bias_names], + self.bias_tensor, + ) + else: + bias_tensor = getattr(self, self.bias_names[0]) # Unused + else: + weight_tensor = self.weight_tensor + bias_tensor = self.bias_tensor # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3ad1470ce3..2a28d67292 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -7,7 +7,6 @@ from typing import Union, Optional, Callable, Tuple, List, Dict, Any import torch -from torch.nn.parameter import Parameter import transformer_engine_extensions as tex @@ -20,6 +19,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from ._common import _noop_cat from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, @@ -521,8 +521,7 @@ def backward( class Linear(TransformerEngineBaseModule): - """ - Applies a linear transformation to the incoming data :math:`y = xA^T + b` + """Applies a linear transformation to the incoming data :math:`y = xA^T + b` On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`. @@ -538,13 +537,13 @@ class Linear(TransformerEngineBaseModule): used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None - if a tuple of strings or a dict of strings to integers is provided, - the weight and bias parameters of the module are exposed as `N` separate - `torch.nn.parameter.Parameter`s each, split along the first dimension, - where `N` is the length of the argument and the strings contained are the - names of the split parameters. In the case of a tuple, each parameter - has the same shape. In the case of a dict, the values give the - `out_features` for each projection. + Configuration for splitting the weight and bias tensors along dim 0 into + multiple PyTorch parameters. If a list or tuple of strings is provided, + they are used to make the names of equally-sized parameters. If a dict + (preferably an OrderedDict) is provided, the keys are used as names and + values as split sizes along dim 0. The resulting parameters will have + names that end in `_weight` or `_bias`, so trailing underscores are + stripped from any provided names. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will allocated. 
It is the user's responsibility to ensure all parameters are moved to the GPU before running the @@ -584,6 +583,7 @@ class Linear(TransformerEngineBaseModule): it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. + """ def __init__( @@ -617,7 +617,6 @@ def __init__( self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias - self.parameters_split = parameters_split self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag @@ -694,69 +693,100 @@ def __init__( with torch.no_grad(): self.bias_tensor.zero_() + # Configure parameter splits + self.weight_names = [] + self.bias_names = [] + self.parameter_split_sizes = [] if parameters_split is None: - parameters_split = {"": self.out_features} - elif isinstance(parameters_split, tuple): - assert ( - self.out_features % len(parameters_split) == 0 - ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" - split_size = self.out_features // len(parameters_split) - parameters_split = {key: split_size for key in parameters_split} + # Split into a single parameter by default + self.weight_names = ["weight"] + self.bias_names = ["bias"] + self.parameter_split_sizes = [out_features] + elif not parameters_split: + raise ValueError("Cannot split weight buffer into 0 parameters") elif isinstance(parameters_split, dict): - overall_split_size = sum(parameters_split.values()) - assert( - self.out_features == overall_split_size - ), f"Overall sum of parameters_split (={overall_split_size}) does not match "\ - f"to out features (={self.out_features})" + # Split parameters with provided sizes + for name, split_size in parameters_split.items(): + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) + elif all(isinstance(name, str) for name in parameters_split): + # Split parameters evenly + split_size = out_features // len(parameters_split) + for name in parameters_split: + self.weight_names.append(f"{name.rstrip('_')}_weight") + self.bias_names.append(f"{name.rstrip('_')}_bias") + self.parameter_split_sizes.append(split_size) else: - assert False, "Type of 'parameters_split' is not None, tuple or dict" - self.updated_parameters_split = parameters_split + raise TypeError("Invalid configuration for parameters split") - self.weight_names = [] - self.bias_names = [] + # Make sure parameter splits are valid + if sum(self.parameter_split_sizes) != out_features: + raise ValueError( + f"Trying to split weight buffer ({out_features=}) " + f"with split sizes {self.parameter_split_sizes}" + ) - slice_begin = 0 - for pname, slice_size in parameters_split.items(): - wname = pname + "weight" - bname = pname + "bias" + # Adjust parameter splits for tensor-parallel distribution + if self.parallel_mode == "column": + for i, size in enumerate(self.parameter_split_sizes): + if size % self.tp_size != 0: + raise RuntimeError( + f"Attempting to distribute a parameter with out_features={size} " + f"between {self.tp_size} tensor-parallel processes" + ) + self.parameter_split_sizes[i] = size // self.tp_size + + # Construct parameters from weight and bias buffers + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + + # Check if parameters 
are subviews of buffers + is_subview = (split_start, split_end) != (0, self.out_features) + if is_subview and self.primary_weights_in_fp8: + raise RuntimeError( + "Splitting Float8Tensor into multiple params " + "is not supported" + ) - slice_end = slice_begin + slice_size + # Construct weight parameter + weight = self.weight_tensor + if is_subview: + weight = weight[split_start:split_end] + weight = torch.nn.Parameter(weight) + self.register_parameter(self.weight_names[i], weight) - # TODO(ksivaman): Add indexing op to torch dispatcher for float8 - if self.primary_weights_in_fp8: - assert len(parameters_split) == 1, ("Slicing operation is not " - "supported in Float8Tensor " - "class!") - self.register_parameter(wname, Parameter(self.weight_tensor)) + # Construct bias parameter if needed + if self.use_bias: + bias = self.bias_tensor + if is_subview: + bias = bias[split_start:split_end] + bias = torch.nn.Parameter(bias) + self.register_parameter(self.bias_names[i], bias) + if parallel_mode == "row": + bias.sequence_parallel = sequence_parallel else: + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, self.bias_names[i], bias) - self.register_parameter( - wname, Parameter(self.weight_tensor[slice_begin:slice_end]) - ) - + # Configure tensor parallelism set_tensor_model_parallel_attributes( - tensor=getattr(self, wname), + tensor=weight, is_parallel=True, dim=1 if parallel_mode == "row" else 0, stride=1, ) - - if self.use_bias: - self.register_parameter( - bname, Parameter(self.bias_tensor[slice_begin:slice_end]) - ) - if parallel_mode == "row": - setattr(getattr(self, bname), "sequence_parallel", sequence_parallel) - else: - setattr(self, bname, torch.Tensor().to(dtype=params_dtype, device=device)) - if parallel_mode == "column": - set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) + set_tensor_model_parallel_attributes(bias, True, 0, 1) - self.weight_names.append(wname) - self.bias_names.append(bname) - - slice_begin = slice_end + # Concatenated tensors are not needed if not splitting + # into multiple parameters + if not is_subview: + del self.weight_tensor + del self.bias_tensor self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) @@ -767,12 +797,6 @@ def __init__( else: self.gemm_bias_unfused_add = False - # Clean up weight and bias buffers - if self.parameters_split is None: - del self.weight_tensor - if self.use_bias: - del self.bias_tensor - def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], @@ -828,18 +852,26 @@ def forward( with self.prepare_forward(inp, is_first_microbatch) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." 
- bias_tensor = ( - self.bias if self.parameters_split is None - else self.bias_tensor if not torch.is_grad_enabled() - else self.noop_cat("bias_tensor", self.bias_names, - self.updated_parameters_split) - ) - weight_tensor = ( - self.weight if self.parameters_split is None - else self.weight_tensor if not torch.is_grad_enabled() - else self.noop_cat("weight_tensor", self.weight_names, - self.updated_parameters_split) - ) + + # Get concatenated weight and bias tensors + if len(self.parameter_split_sizes) == 1: + weight_tensor = getattr(self, self.weight_names[0]) + bias_tensor = getattr(self, self.bias_names[0]) + elif torch.is_grad_enabled(): + weight_tensor = _noop_cat( + [getattr(self, name) for name in self.weight_names], + self.weight_tensor, + ) + if self.use_bias: + bias_tensor = _noop_cat( + [getattr(self, name) for name in self.bias_names], + self.bias_tensor, + ) + else: + bias_tensor = getattr(self, self.bias_names[0]) # Unused + else: + weight_tensor = self.weight_tensor + bias_tensor = self.bias_tensor # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
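
Taken together, the Linear and LayerNormLinear changes turn `parameters_split` into a naming-plus-sizing configuration rather than a prefix convention. A usage sketch, assuming this patch is applied, Transformer Engine is importable as `transformer_engine.pytorch`, and a CUDA device is available (the sizes below are illustrative, not from the patch):

    import collections
    import transformer_engine.pytorch as te

    # Unequal splits via an (ordered) dict mapping names to dim-0 sizes,
    # e.g. a fused QKV projection with a narrower key/value width.
    qkv = te.Linear(
        1024,                      # in_features
        1024 + 256 + 256,          # out_features = sum of the split sizes
        parameters_split=collections.OrderedDict([
            ("query", 1024),
            ("key", 256),
            ("value", 256),
        ]),
    )
    print([name for name, _ in qkv.named_parameters()])
    # Expected with this patch: ['query_weight', 'query_bias',
    #   'key_weight', 'key_bias', 'value_weight', 'value_bias']

    # Equal splits via a tuple of names; trailing underscores (the old
    # "query_" style) are stripped, so these yield q_weight, k_weight, ...
    proj = te.Linear(1024, 3 * 1024, parameters_split=("q_", "k_", "v_"))

LayerNormLinear accepts the same configuration, which is what the MultiheadAttention change at the top of the patch relies on when it passes an OrderedDict of query/key/value sizes.
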
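
Continuing the hypothetical `qkv` sketch above, and relying on the internal `weight_tensor` buffer that this patch keeps whenever there is more than one split: with the default single-process (no tensor-parallel) setup, each split parameter aliases the corresponding rows of that shared buffer. This is what lets the forward pass choose between a single parameter, `_noop_cat` under autograd, and the raw buffer when gradients are disabled, all without copying.

    # Each split parameter is a view into the shared buffer, so the GEMM
    # can consume one contiguous weight tensor without a copy.
    assert qkv.query_weight.data_ptr() == qkv.weight_tensor.data_ptr()
    assert qkv.key_weight.data_ptr() == qkv.weight_tensor[1024:].data_ptr()
    assert qkv.value_weight.data_ptr() == qkv.weight_tensor[1024 + 256:].data_ptr()
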