diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py index 76416d36..b1bf4d5c 100644 --- a/fast_llm/engine/config_utils/parameter.py +++ b/fast_llm/engine/config_utils/parameter.py @@ -10,29 +10,10 @@ from fast_llm.tensor import ParameterMeta -def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): +def combine_lr_scales(*lr_scales: float | None) -> float | None: # Remove `None` entries. lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None) - if not lr_scales: - # Everything is None - return None - tuple_length = None - # Check if we have tuples, and determine the length. - for lr_scale in lr_scales: - if isinstance(lr_scale, tuple): - if tuple_length is None: - tuple_length = len(lr_scale) - else: - assert len(lr_scale) == tuple_length - if tuple_length is None: - # No tuple: simple product. - return math.prod(lr_scales) - else: - # Tuple(s): use recursion. - return tuple( - combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) - for i in range(tuple_length) - ) + return math.prod(lr_scales) if lr_scales else None @config_class() diff --git a/fast_llm/engine/config_utils/tensor_dim.py b/fast_llm/engine/config_utils/tensor_dim.py index f67916a6..9b235d39 100644 --- a/fast_llm/engine/config_utils/tensor_dim.py +++ b/fast_llm/engine/config_utils/tensor_dim.py @@ -157,19 +157,18 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - import torch - return ( - torch.concatenate( - [ + self.merge_tensors( + tuple( tensor_dim.local_to_global(tensor_, dim) for tensor_, tensor_dim in zip( - tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self.split_tensor(tensor, dim, False), self._tensor_dims, strict=True, ) - ], + ), dim, + True, ) if self.is_parallel else tensor @@ -178,19 +177,18 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": - import torch - return ( - torch.concatenate( - [ + self.merge_tensors( + tuple( tensor_dim.local_to_global_partial(tensor_, dim) for tensor_, tensor_dim in zip( - tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self.split_tensor(tensor, dim, False), self._tensor_dims, strict=True, ) - ], + ), dim, + True, ) if self.is_parallel else tensor @@ -199,23 +197,48 @@ def local_to_global_partial( def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() - import torch - return ( - torch.concatenate( - [ + self.merge_tensors( + tuple( tensor_dim.global_to_local(tensor_, dim) for tensor_, tensor_dim in zip( - tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self.split_tensor(tensor, dim, True), self._tensor_dims, strict=True, ) - ], + ), dim, + False, ) if self.is_parallel else tensor ) + def split_tensor(self, tensor: "torch.Tensor", dim: int = 0, global_: bool = False) -> tuple["torch.Tensor", ...]: + return tensor.split( + [(tensor_dim.global_size if global_ else tensor_dim.size) for tensor_dim in self._tensor_dims], dim + ) + + def merge_tensors( + self, tensors: tuple["torch.Tensor", ...], dim: int = 0, global_: bool = False + ) -> "torch.Tensor": + import torch + + assert 
all( + tensor.size(dim) == (tensor_dim.global_size if global_ else tensor_dim.size) + for tensor, tensor_dim in zip(tensors, self._tensor_dims, strict=True) + ) + return torch.concatenate(tensors, dim) + + def get_split_ranges(self, global_: bool = False) -> tuple[tuple[int, int], ...]: + split_ranges = [] + split_begin = 0 + for tensor_dim in self._tensor_dims: + split_end = split_begin + (tensor_dim.global_size if global_ else tensor_dim.size) + split_ranges.append((split_begin, split_end)) + split_begin = split_end + Assert.eq(split_begin, (self.global_size if global_ else self.size)) + return tuple(split_ranges) + scalar_dim = TensorDim("scalar", 1) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index ded24e53..5e5b3d39 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -15,7 +15,7 @@ from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.logging import log_generator from fast_llm.tensor import ParameterMeta, SafeTensorSlice -from fast_llm.utils import Assert, div +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -234,28 +234,28 @@ def get_param_groups( for i, fsdp in enumerate(self._fsdps): if not fsdp.requires_grad: continue + buffer_begin = 0 for parameter_name in fsdp.parameter_names: # If needed, chunk the parameter on the first dimension. - parameter_meta = fsdp.get_parameter_meta(parameter_name) + parameter_meta: ParameterMeta = fsdp.get_parameter_meta(parameter_name) + Assert.eq(buffer_begin, fsdp.get_parameter_begin_in_buffer(parameter_meta.tensor_name)) if not parameter_meta.requires_grad: continue - chunk_size = div(parameter_meta.numel(), len(parameter_meta.lr_scale)) - buffer_begin = fsdp.get_parameter_begin_in_buffer(parameter_meta.tensor_name) - for i, lr_scale in enumerate(parameter_meta.lr_scale): - begin = fsdp.index_buffer_to_shard(buffer_begin + i * chunk_size) - end = fsdp.index_buffer_to_shard(buffer_begin + (i + 1) * chunk_size) - if lr_scale == 0 or begin == end: - continue - optimizer_params = (parameter_meta.param_weight_decay, lr_scale) - if optimizer_params in grouped_parameter_slices: - last_slice = grouped_parameter_slices[optimizer_params][-1] - if begin == last_slice.stop: - grouped_parameter_slices[optimizer_params][-1] = slice(last_slice.start, end) - continue - else: - grouped_parameter_slices[optimizer_params] = [] - grouped_parameter_slices[optimizer_params].append(slice(begin, end)) + for parameter_meta_ in parameter_meta.metas_for_grad: + # Metas for grad are contiguous. 
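+ # Adjacent slices that share the same (weight_decay, lr_scale) are merged below into a single optimizer slice.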
+ buffer_end = buffer_begin + parameter_meta_.numel() + if buffer_begin == buffer_end: + pass + elif ( + optimizer_params := (parameter_meta_._param_weight_decay, parameter_meta_.lr_scale) + ) not in grouped_parameter_slices: + grouped_parameter_slices[optimizer_params] = [slice(buffer_begin, buffer_end)] + elif buffer_begin == (last_slice := grouped_parameter_slices[optimizer_params][-1]).stop: + grouped_parameter_slices[optimizer_params][-1] = slice(last_slice.start, buffer_end) + else: + grouped_parameter_slices[optimizer_params].append(slice(buffer_begin, buffer_end)) + buffer_begin = buffer_end param_groups += [ param_group_cls( name=f"wd_{weight_decay}_lr_scale_{lr_scale}", # noqa @@ -340,9 +340,9 @@ def _reorder_parameter_metas(cls, parameter_metas): reorder_index = sorted( range(len(parameter_metas)), key=lambda i: ( - parameter_metas[i].param_weight_decay, - parameter_metas[i].param_weight_decay == parameter_metas[i].is_tensor_parallel, - parameter_metas[i].param_weight_decay != parameter_metas[i].sequence_tensor_parallel, + parameter_metas[i]._param_weight_decay, + parameter_metas[i]._param_weight_decay == parameter_metas[i].is_tensor_parallel, + parameter_metas[i]._param_weight_decay != parameter_metas[i].sequence_tensor_parallel, ), ) reordered_metas = [parameter_metas[i] for i in reorder_index] diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 9a940f4c..43d0d0d6 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -6,11 +6,12 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.common.linear.linear import concatenate_linear_layers from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.tensor import TensorMeta @@ -92,13 +93,9 @@ def __init__( head_size_dim = TensorDim("head_size", self._config.head_size) query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, head_size_dim)) - key_value_dim = ConcatenatedTensorDim( - "key_value", - ( - CompositeTensorDim("key", (head_group_dim, head_size_dim)), - CompositeTensorDim("value", (head_group_dim, head_size_dim)), - ), - ) + key_dim = CompositeTensorDim("key", (head_group_dim, head_size_dim)) + value_dim = CompositeTensorDim("value", (head_group_dim, head_size_dim)) + dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim)) self._softmax_scale = self._config.head_size ** (-self._config.softmax_scale_power) @@ -114,22 +111,30 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - # TODO: Use value config. 
- self.key_value = self._config.key_layer.get_layer( + key = self._config.key_layer.get_layer( hidden_dim, - key_value_dim, + key_dim, default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), default_add_bias=self._config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=self._lr_scale, - peft=None if self._config.key_layer.apply_peft is None else self._peft, + peft=None, + ) + value = self._config.key_layer.get_layer( + hidden_dim, + value_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=self._config.add_linear_biases, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=None, + ) + self.key_value = concatenate_linear_layers( + (key, value), + (self._config.key_layer, self._config.value_layer), + default_apply_peft=(False, True), + peft=peft, ) - if self._peft is not None and self._config.key_layer.apply_peft is None: - # Default: Apply to value only. - # TODO: Avoid this hack. - self.key_value = self._peft.apply_linear( - self.key_value, True, out_channel_begin=div(key_value_dim.global_size, 2) - ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e7c6d9e9..8f90ccf5 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -200,7 +200,7 @@ def get_layer( if default_bias_initialization is None: default_bias_initialization = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5) - lr_scale = (combine_lr_scales(lr_scale, self.lr_scale),) + lr_scale = combine_lr_scales(lr_scale, self.lr_scale) weight = self.weight.get_parameter( (in_dim, scalar_dim, kernel_dim), default_initialization=default_weight_initialization, diff --git a/fast_llm/layers/common/linear/linear.py b/fast_llm/layers/common/linear/linear.py index 3028fd1e..013f34a5 100644 --- a/fast_llm/layers/common/linear/linear.py +++ b/fast_llm/layers/common/linear/linear.py @@ -15,7 +15,9 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.layers.common.linear.config import LinearConfig +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.tensor import ConcatenatedParameterMeta, ParameterMeta, TensorMeta from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -164,3 +166,70 @@ def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor def backward(self, grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor: # noqa # TODO: Needs grad_bias as input too? return input_parallel_linear_backward(grad_output, context) + + +def concatenate_linear_layers[ + T: LinearBase +]( + layers: tuple[T, ...], + configs: tuple[LinearConfig, ...], + *, + concatenate_input_dim: bool = False, + dim_name: str | None = None, + default_apply_peft: bool | tuple[bool, ...] = False, + peft: PeftConfig | None, +) -> T: + # TODO: Simplify. + # All biases must be either enabled or disabled. TODO: Allow non-constant. + enable_bias = layers[0].bias is not None + # Concatenate on `in_dim` (instead of `out_dim`) + if concatenate_input_dim: + # TODO: Support this case? (needs one bias instead of a concatenation) + assert not enable_bias + + cls = type(layers[0]) + # Should not already be wrapped with Peft. 
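+ # PEFT (if any) is applied once, to the concatenated layer, at the end of this helper.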
+ Assert.incl(cls, (Linear, InputParallelLinear, OutputParallelLinear)) + # The concatenated dimension must be at index zero. + transposed_weight = concatenate_input_dim + for layer in layers: + Assert.eq(layer._transposed_weight, transposed_weight) + Assert.is_(type(layer), cls) + Assert.eq(layer.bias is not None, enable_bias) + + if cls in (InputParallelLinear, OutputParallelLinear): + for layer in layers[1:]: + Assert.is_(layer._parallel_dim, layers[0]._parallel_dim) + Assert.eq(layer._sequence_parallel, layers[0]._sequence_parallel) + args = {"parallel_dim": layers[0]._parallel_dim, "sequence_parallel": layers[0]._sequence_parallel} + else: + args = {} + + # TODO: Original parameters won't get names. + weight = ConcatenatedParameterMeta.from_metas(tuple(layer.weight for layer in layers), dim_name=dim_name) + bias = ( + ConcatenatedParameterMeta.from_metas(tuple(layer.bias for layer in layers), dim_name=dim_name) + if enable_bias + else None + ) + + out = cls(weight, bias, transposed_weight=transposed_weight, **args) + if peft is not None: + if isinstance(default_apply_peft, bool): + default_apply_peft = (default_apply_peft,) * len(layers) + apply_peft = [ + default if config.apply_peft is None else config.apply_peft + for config, default in zip(configs, default_apply_peft, strict=True) + ] + if len(set(apply_peft)) == 1: + out_channel_ranges = None + enabled = apply_peft[0] + else: + enabled = True + out_channel_ranges = tuple( + split_range + for split_range, apply_peft_ in zip(weight.dims[0].get_split_ranges(True), apply_peft) + if apply_peft_ + ) + out = peft.apply_linear(out, enabled, out_channel_ranges=out_channel_ranges) + return out diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 6c765683..775f3cc5 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -25,8 +25,7 @@ def apply_linear( self, module: "LinearBase", enabled: bool, - out_channel_begin: int | None = None, - out_channel_end: int | None = None, + out_channel_ranges: tuple[tuple[int | None, int | None], ...] | None = None, ) -> "LinearLike": return self.apply_other(module) @@ -75,8 +74,7 @@ def apply_linear( self, module: "LinearBase", enabled: bool, - out_channel_begin: int | None = None, - out_channel_end: int | None = None, + out_channel_ranges: tuple[tuple[int | None, int | None], ...] | None = None, ) -> "LinearLike": if not enabled: return self.apply_other(module) diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index fcff5d49..1fccdfde 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -6,6 +6,7 @@ from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear.linear import Linear, LinearBase from fast_llm.tensor import ParameterMeta +from fast_llm.utils import Assert def lora_linear( @@ -13,8 +14,7 @@ def lora_linear( rank: int, alpha: float, dropout: float = 0.0, - out_channel_begin: int | None = None, - out_channel_end: int | None = None, + out_channel_ranges: tuple[tuple[int | None, int | None], ...] | None = None, ): module.weight.requires_grad = False in_dim = module._in_dim @@ -25,20 +25,32 @@ def lora_linear( assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." 
if out_dim.parallel_dim is not None: out_dim = TensorDim(out_dim.name, out_dim.global_size) - if out_channel_begin is not None or out_channel_end is not None: - if out_channel_begin is None: - out_channel_begin = 0 - if out_channel_end is None: - out_channel_end = out_dim.global_size + + if out_channel_ranges is not None: + out_channel_range_map = [] + lora_channel_begin = 0 + # We construct a lora output dimension from the sum of the ranges and map it to the original output dimension. # TODO: This won't work with TP. Use Composite dim structure for proper split? - out_dim = TensorDim(out_dim.name, out_channel_end - out_channel_begin) + for out_channel_begin, out_channel_end in out_channel_ranges: + if out_channel_begin is None: + out_channel_begin = 0 + if out_channel_end is None: + out_channel_end = out_dim.global_size + Assert.gt(out_channel_end, out_channel_begin) + lora_channel_end = lora_channel_begin + out_channel_end - out_channel_begin + out_channel_range_map.append( + ((out_channel_begin, out_channel_end), (lora_channel_begin, lora_channel_end)) + ) + lora_channel_begin = lora_channel_end + out_dim = TensorDim(f"lora_{out_dim.name}", lora_channel_begin) middle_dim = TensorDim("lora_middle", rank) + # TODO: Doesn't make sense for concatenated layers. module.lora_0 = Linear( ParameterMeta.from_dims( (in_dim, middle_dim), - init_method=module.weight.param_init_method, + init_method=module.weight._param_init_method, lr_scale=module.weight.lr_scale, ), None, @@ -47,7 +59,7 @@ def lora_linear( module.lora_1 = Linear( ParameterMeta.from_dims( (middle_dim, out_dim), - init_method=module.weight.param_init_method, + init_method=module.weight._param_init_method, lr_scale=module.weight.lr_scale, ), None, @@ -67,10 +79,14 @@ def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor lora_out = (alpha / rank) * module.lora_1( module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) ) - if out_channel_begin is None: + if out_channel_ranges is None: output = output + lora_out else: - output.view(-1, layer_out.size(-1))[:, out_channel_begin:out_channel_end] += lora_out + # Add each piece individually. + for (out_channel_begin, out_channel_end), (lora_channel_begin, lora_channel_end) in out_channel_ranges: + output.view(-1, layer_out.size(-1))[:, out_channel_begin:out_channel_end] += lora_out[ + lora_channel_begin:lora_channel_end + ] return output.detach(), (input_, output) def backward( diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index f63bd76f..7d20e1c4 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -20,7 +20,7 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. - return meta.param_init_method(meta, tensor, distributed.tp_init_generator) + return _get_init_method(meta)(meta, tensor, distributed.tp_init_generator) if "query" in meta.tensor_name or "key_value" in meta.tensor_name or "dense" in meta.tensor_name: tensor_ = _init_attention_megatron(config, meta, tensor, distributed, hidden_size) elif "position_embeddings" in meta.tensor_name: @@ -33,7 +33,7 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. 
- return meta.param_init_method(meta, tensor, distributed.tp_init_generator) + return _get_init_method(meta)(meta, tensor, distributed.tp_init_generator) tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -59,11 +59,11 @@ def _init_attention_megatron( ) -> "torch.Tensor": # Megatron combines q and kv and inverts the initialization order of qkv and dense layers. # It also always treats the tensors as tensor-parallel and uses a different rotary embedding format. - assert meta.param_init_method is not None + init_method = _get_init_method(meta) generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - meta.param_init_method( + init_method( meta, dense_tensor_ := tensor.new_empty( config.mixer.head_size * config.mixer.heads, @@ -73,7 +73,7 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.mixer.heads, config.mixer.head_groups) - meta.param_init_method( + init_method( meta, qkv_tensor_ := tensor.new_empty( config.mixer.head_groups, @@ -107,16 +107,23 @@ def _init_attention_megatron( return tensor_ +def _get_init_method(meta: "ParameterMeta"): + from fast_llm.tensor import ConcatenatedParameterMeta + + # For concatenated parameters (ex. key_value), either initialization is good enough. + return (meta.metas[0] if isinstance(meta, ConcatenatedParameterMeta) else meta)._param_init_method + + def _init_position_embeddings_megatron( meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron initializes the position embeddings on cpu twice. - assert meta.param_init_method is not None generator = distributed.default_cpu_generator - meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) - meta.param_init_method(meta, tensor_, generator) + init_method = _get_init_method(meta) + init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + init_method(meta, tensor_, generator) return tensor_ @@ -126,8 +133,7 @@ def _init_transposed_mlp_weight_megatron( import torch # Megatron never transposes the mlp layer 2 weight. - assert meta.param_init_method is not None - meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) + _get_init_method(meta)(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -137,8 +143,7 @@ def _init_moe_router_megatron( import torch # Megatron initializes the router on cpu. 
- assert meta.param_init_method is not None - meta.param_init_method( + _get_init_method(meta)( meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ @@ -151,18 +156,17 @@ def _init_moe_mlp_megatron( distributed: "Distributed", hidden_size: int, ) -> "torch.Tensor": - assert meta.param_init_method is not None generator = distributed.tp_init_generator if meta.is_tensor_parallel else distributed.pp_init_generator - # self.param_init_method(self, tensor, generator) + init_method = _get_init_method(meta) state = generator.get_state() weight_1 = tensor.new_empty( config.mlp.experts * (1 + config.mlp.gated) * config.mlp.intermediate_size, hidden_size ) weight_2 = tensor.new_empty(config.mlp.experts * config.mlp.intermediate_size, hidden_size) for chunk_1, chunk_2 in zip(weight_1.chunk(config.mlp.experts), weight_2.chunk(config.mlp.experts)): - meta.param_init_method(meta, chunk_1, generator) + init_method(meta, chunk_1, generator) chunk_2_ = chunk_2.new_empty(hidden_size, config.mlp.intermediate_size) - meta.param_init_method(meta, chunk_2_, generator) + init_method(meta, chunk_2_, generator) chunk_2.copy_(chunk_2_.t()) if "layer_1.weight" in meta.tensor_name: # Keep the original random state for weight_2. diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b709ea83..2c86cdc8 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.initialization import Initialization, Initializer, LambdaInitializer -from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy @@ -84,6 +84,7 @@ def __new__( tensor_name: str, dims: tuple[TensorDim, ...], reductions: tuple[tuple[DistributedDim, ReduceOp], ...] = (), + **kwargs, ): return super().__new__( cls, @@ -241,7 +242,7 @@ def __init__( init_method: "Initialization | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. - lr_scale: float | None | tuple[float | None, ...] = None, + lr_scale: float | None = None, requires_grad: bool = True, allow_sequence_tensor_parallel: bool = True, allow_no_grad: bool = False, @@ -253,8 +254,8 @@ def __init__( # Support non-wrapped callables for convenience. assert callable(init_method) init_method = LambdaInitializer(init_method) - self.param_init_method: Initializer | None = init_method - self.param_weight_decay = weight_decay + self._param_init_method: Initializer | None = init_method + self._param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False # Almost all parameters are either tensor-parallel or process tensor-sequence-parallel inputs. @@ -264,68 +265,198 @@ def __init__( # to support cases where gradients may not always be computed (ex. MOE layers). self.allow_no_grad = allow_no_grad - self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - # Ensure the parameter is split in chunks of equal size. 
- Assert.multiple(self.dims[0].size, len(self.lr_scale)) + self._lr_scale = lr_scale if requires_grad else 0 + self.requires_grad = self._lr_scale != 0 - def __new__( - cls, - data: torch.Tensor, - *, - tensor_name: str = "", - dims: tuple[TensorDim, ...], - init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", - weight_decay: bool = True, - lr_scale: float | None | tuple[float | None, ...] = None, - allow_sequence_tensor_parallel: bool = True, - allow_no_grad: bool = False, - ): - return super().__new__( - cls, - data, - tensor_name=tensor_name, - dims=dims, - ) + @property + def lr_scale(self) -> float | None: + return self._lr_scale + + @property + def param_weight_decay(self) -> bool: + return self._param_weight_decay def __repr__(self, *, tensor_contents=()) -> str: return super().__repr__( - tensor_contents=(f"wd={self.param_weight_decay}", f"lr_scale={self.lr_scale}", *tensor_contents) + tensor_contents=(f"wd={self._param_weight_decay}", f"lr_scale={self._lr_scale}", *tensor_contents) ) def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: - assert self.param_init_method is not None + assert self._param_init_method is not None if ( distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init - or self.param_init_method.requires_global_initialization + or self._param_init_method.requires_global_initialization ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator - self.param_init_method(self, tensor, generator) + self._param_init_method(self, tensor, generator) @property def requires_global_initialization(self) -> bool: - return self.param_init_method.requires_global_initialization + return self._param_init_method.requires_global_initialization def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, "dim_names": self.dim_names, "shape": tuple(self.shape), - "weight_decay": self.param_weight_decay, + "weight_decay": self._param_weight_decay, "sequence_tensor_parallel": self.sequence_tensor_parallel, "requires_grad": self.requires_grad, "tensor_parallel": self.is_tensor_parallel, "allow_no_grad": self.allow_no_grad, - "lr_scale": self.lr_scale, + "lr_scale": self._lr_scale, } def load(self, state: dict[str, typing.Any]) -> None: current = self.save() Assert.eq(state, current) + @property + def metas_for_grad(self) -> tuple["ParameterMeta", ...]: + return (self,) + + +class ConcatenatedTensorMeta(TensorMeta): + def __init__( + self, + data: torch.Tensor, + *, + metas: tuple[ParameterMeta, ...], + dim_index: int, + _concatenate_check: bool = False, + **kwargs, + ): + super().__init__(data, **kwargs) + if not _concatenate_check: + raise RuntimeError( + f"Please instantiate {type(self).__name__} tensors through {type(self).__name__}.from_dict()" + ) + self.metas = metas + self.dim_index = dim_index + + @classmethod + def from_metas( + cls, + metas: tuple[ParameterMeta, ...], + *, + tensor_name: str = "", + dim_index: int = 0, + dim_name: str | None = None, + **kwargs, + ): + for meta in metas: + # TODO: Support recursion? 
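+ # Inputs must also agree on every dim except `dim_index`, and on dtype and reductions (checked below).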
+ assert not isinstance(meta, ConcatenatedTensorMeta) + for meta in metas[1:]: + Assert.eq(meta.ndim, metas[0].ndim) + for index, dim in enumerate(meta.dims): + if index != dim_index: + Assert.is_(dim, metas[0].dims[index]) + Assert.eq(meta.dtype, metas[0].dtype) + Assert.eq(meta._reductions, metas[0]._reductions) + + dims = list(metas[0].dims) + if dim_name is None: + dim_name = f"concatenated_{'_'.join([meta.dims[dim_index].name for meta in metas])}" + dims[dim_index] = ConcatenatedTensorDim(dim_name, tuple(meta.dims[dim_index] for meta in metas)) + return cls.from_dims( + tuple(dims), + tensor_name=tensor_name, + dtype=metas[0].dtype, + _concatenate_check=True, + metas=metas, + dim_index=dim_index, + **kwargs, + ) + + def split_tensor(self, tensor: torch.Tensor) -> list[tuple[torch.Tensor, "TensorMeta"]]: + return [ + (tensor_, meta_) + for tensor_, meta_ in zip( + self.dims[self.dim_index].split_tensor(tensor, self.dim_index, global_=True), self.metas, strict=True + ) + ] + + +class ConcatenatedParameterMeta(ConcatenatedTensorMeta, ParameterMeta): + def __init__( + self, + data: torch.Tensor, + *, + split_optimization: bool = False, + **kwargs, + ): + super().__init__(data, **kwargs) + # Split the parameter from the point of view of the optimizer so they can go in different param groups. + # Only possible if `dim_index == 0` because the optimizer requires contiguous parameters. + self._split_optimization = split_optimization + if self._split_optimization: + Assert.eq(self.dim_index, 0) + + @classmethod + def from_metas( + cls, + metas: tuple[ParameterMeta, ...], + *, + tensor_name: str = "", + dim_index: int = 0, + dim_name: str | None = None, + **kwargs, + ): + for meta in metas: + assert isinstance(meta, ParameterMeta) + split_optimization = False + # TODO: Support more varying attributes. + for meta in metas[1:]: + if meta._lr_scale != metas[0]._lr_scale or meta._param_weight_decay != metas[0]._param_weight_decay: + Assert.eq(dim_index, 0) + split_optimization = True + Assert.eq(meta.sequence_tensor_parallel, metas[0].sequence_tensor_parallel) + Assert.eq(meta.allow_no_grad, metas[0].allow_no_grad) + + return super().from_metas( + metas, + tensor_name=tensor_name, + dim_index=dim_index, + dim_name=dim_name, + init_method=None, # Unused + weight_decay=True, # Unused + lr_scale=None, # Unused + # If partially frozen, this will place the whole parameter in the non-frozen shard, + # but skip the optimization step for + requires_grad=any(meta.requires_grad for meta in metas), + allow_sequence_tensor_parallel=metas[0].sequence_tensor_parallel, + split_optimization=split_optimization, + **kwargs, + ) + + @property + def lr_scale(self) -> float | None: + if self._split_optimization: + raise RuntimeError() + return super().lr_scale + + @property + def param_weight_decay(self) -> bool: + if self._split_optimization: + raise RuntimeError() + return super().param_weight_decay + + def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: + for tensor_, meta_ in self.split_tensor(tensor): + meta_.init_parameter(tensor_, distributed) + + @property + def requires_global_initialization(self) -> bool: + return True + + @property + def metas_for_grad(self) -> tuple["ParameterMeta", ...]: + return self.metas if self._split_optimization else super().metas_for_grad + def param_get_and_unset_is_zero(param: torch.Tensor) -> bool: is_zero = param.param_grad_is_zero
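
A minimal usage sketch (not part of the patch) of the new `ConcatenatedTensorDim` helpers introduced above. It assumes plain, non-parallel dims so that local and global sizes coincide, and uses the `TensorDim`/`ConcatenatedTensorDim` constructors as they appear elsewhere in this diff:

    import torch

    from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim

    # Sketch only: non-parallel dims, so size == global_size.
    key = TensorDim("key", 4)
    value = TensorDim("value", 8)
    key_value = ConcatenatedTensorDim("key_value", (key, value))

    tensor = torch.arange(12.0)
    # split_tensor cuts along `dim` using the per-dim (local or global) sizes.
    pieces = key_value.split_tensor(tensor, dim=0)
    assert [piece.numel() for piece in pieces] == [4, 8]
    assert key_value.get_split_ranges() == ((0, 4), (4, 12))
    # merge_tensors checks the chunk sizes against the dims, then concatenates them back.
    restored = key_value.merge_tensors(pieces, 0)
    assert torch.equal(restored, tensor)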