23 changes: 2 additions & 21 deletions fast_llm/engine/config_utils/parameter.py
@@ -10,29 +10,10 @@
from fast_llm.tensor import ParameterMeta


def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]):
def combine_lr_scales(*lr_scales: float | None) -> float | None:
# Remove `None` entries.
lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None)
if not lr_scales:
# Everything is None
return None
tuple_length = None
# Check if we have tuples, and determine the length.
for lr_scale in lr_scales:
if isinstance(lr_scale, tuple):
if tuple_length is None:
tuple_length = len(lr_scale)
else:
assert len(lr_scale) == tuple_length
if tuple_length is None:
# No tuple: simple product.
return math.prod(lr_scales)
else:
# Tuple(s): use recursion.
return tuple(
combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales])
for i in range(tuple_length)
)
return math.prod(lr_scales) if lr_scales else None


@config_class()
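Note: with the tuple branch gone, combine_lr_scales is just a product over the non-None entries (or None if nothing is left). A minimal standalone sketch of the expected behaviour; the sample values are illustrative:

import math

def combine_lr_scales(*lr_scales: float | None) -> float | None:
    # Same shape as the new implementation: drop None entries, multiply the rest.
    lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None)
    return math.prod(lr_scales) if lr_scales else None

assert combine_lr_scales(None, None) is None      # nothing set -> no scale
assert combine_lr_scales(0.5, None) == 0.5        # None entries are ignored
assert combine_lr_scales(0.5, 2.0, 2.0) == 2.0    # plain product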
59 changes: 41 additions & 18 deletions fast_llm/engine/config_utils/tensor_dim.py
@@ -157,19 +157,18 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self:
)

def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
import torch

return (
torch.concatenate(
[
self.merge_tensors(
tuple(
tensor_dim.local_to_global(tensor_, dim)
for tensor_, tensor_dim in zip(
tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim),
self.split_tensor(tensor, dim, False),
self._tensor_dims,
strict=True,
)
],
),
dim,
True,
)
if self.is_parallel
else tensor
@@ -178,19 +177,18 @@ def local_to_global_partial(
def local_to_global_partial(
self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
) -> "torch.Tensor":
import torch

return (
torch.concatenate(
[
self.merge_tensors(
tuple(
tensor_dim.local_to_global_partial(tensor_, dim)
for tensor_, tensor_dim in zip(
tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim),
self.split_tensor(tensor, dim, False),
self._tensor_dims,
strict=True,
)
],
),
dim,
True,
)
if self.is_parallel
else tensor
@@ -199,23 +197,48 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
if self.is_parallel and expand:
raise NotImplementedError()
import torch

return (
torch.concatenate(
[
self.merge_tensors(
tuple(
tensor_dim.global_to_local(tensor_, dim)
for tensor_, tensor_dim in zip(
tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim),
self.split_tensor(tensor, dim, True),
self._tensor_dims,
strict=True,
)
],
),
dim,
False,
)
if self.is_parallel
else tensor
)

def split_tensor(self, tensor: "torch.Tensor", dim: int = 0, global_: bool = False) -> tuple["torch.Tensor", ...]:
return tensor.split(
[(tensor_dim.global_size if global_ else tensor_dim.size) for tensor_dim in self._tensor_dims], dim
)

def merge_tensors(
self, tensors: tuple["torch.Tensor", ...], dim: int = 0, global_: bool = False
) -> "torch.Tensor":
import torch

assert all(
tensor.size(dim) == (tensor_dim.global_size if global_ else tensor_dim.size)
for tensor, tensor_dim in zip(tensors, self._tensor_dims, strict=True)
)
return torch.concatenate(tensors, dim)

def get_split_ranges(self, global_: bool = False) -> tuple[tuple[int, int], ...]:
split_ranges = []
split_begin = 0
for tensor_dim in self._tensor_dims:
split_end = split_begin + (tensor_dim.global_size if global_ else tensor_dim.size)
split_ranges.append((split_begin, split_end))
split_begin = split_end
Assert.eq(split_begin, (self.global_size if global_ else self.size))
return tuple(split_ranges)


scalar_dim = TensorDim("scalar", 1)
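Note: the three new helpers factor the split/concat plumbing out of local_to_global / global_to_local / local_to_global_partial. A standalone sketch of the intended semantics for two concatenated parts of sizes 2 and 3 (non-parallel case, so local and global sizes coincide; the numbers are made up):

import torch

part_sizes = [2, 3]

# get_split_ranges: cumulative (begin, end) offsets along the concatenated dim.
ranges, begin = [], 0
for size in part_sizes:
    ranges.append((begin, begin + size))
    begin += size
assert ranges == [(0, 2), (2, 5)] and begin == sum(part_sizes)

# split_tensor / merge_tensors round-trip along dim 0.
tensor = torch.arange(5)
parts = tensor.split(part_sizes, 0)                      # what split_tensor does
assert torch.equal(torch.concatenate(parts, 0), tensor)  # what merge_tensors does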
42 changes: 21 additions & 21 deletions fast_llm/engine/multi_stage/stage_base.py
@@ -15,7 +15,7 @@
from fast_llm.engine.optimizer.config import ParamGroup
from fast_llm.logging import log_generator
from fast_llm.tensor import ParameterMeta, SafeTensorSlice
from fast_llm.utils import Assert, div
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)

@@ -234,28 +234,28 @@ def get_param_groups(
for i, fsdp in enumerate(self._fsdps):
if not fsdp.requires_grad:
continue
buffer_begin = 0
for parameter_name in fsdp.parameter_names:
# If needed, chunk the parameter on the first dimension.
parameter_meta = fsdp.get_parameter_meta(parameter_name)
parameter_meta: ParameterMeta = fsdp.get_parameter_meta(parameter_name)
Assert.eq(buffer_begin, fsdp.get_parameter_begin_in_buffer(parameter_meta.tensor_name))
if not parameter_meta.requires_grad:
continue
chunk_size = div(parameter_meta.numel(), len(parameter_meta.lr_scale))
buffer_begin = fsdp.get_parameter_begin_in_buffer(parameter_meta.tensor_name)
for i, lr_scale in enumerate(parameter_meta.lr_scale):
begin = fsdp.index_buffer_to_shard(buffer_begin + i * chunk_size)
end = fsdp.index_buffer_to_shard(buffer_begin + (i + 1) * chunk_size)
if lr_scale == 0 or begin == end:
continue
optimizer_params = (parameter_meta.param_weight_decay, lr_scale)
if optimizer_params in grouped_parameter_slices:
last_slice = grouped_parameter_slices[optimizer_params][-1]
if begin == last_slice.stop:
grouped_parameter_slices[optimizer_params][-1] = slice(last_slice.start, end)
continue
else:
grouped_parameter_slices[optimizer_params] = []
grouped_parameter_slices[optimizer_params].append(slice(begin, end))

for parameter_meta_ in parameter_meta.metas_for_grad:
# Metas for grad are contiguous.
buffer_end = buffer_begin + parameter_meta_.numel()
if buffer_begin == buffer_end:
pass
elif (
optimizer_params := (parameter_meta_._param_weight_decay, parameter_meta_.lr_scale)
) not in grouped_parameter_slices:
grouped_parameter_slices[optimizer_params] = [slice(buffer_begin, buffer_end)]
elif buffer_begin == (last_slice := grouped_parameter_slices[optimizer_params][-1]).stop:
grouped_parameter_slices[optimizer_params][-1] = slice(last_slice.start, buffer_end)
else:
grouped_parameter_slices[optimizer_params].append(slice(buffer_begin, buffer_end))
buffer_begin = buffer_end
param_groups += [
param_group_cls(
name=f"wd_{weight_decay}_lr_scale_{lr_scale}", # noqa
@@ -340,9 +340,9 @@ def _reorder_parameter_metas(cls, parameter_metas):
reorder_index = sorted(
range(len(parameter_metas)),
key=lambda i: (
parameter_metas[i].param_weight_decay,
parameter_metas[i].param_weight_decay == parameter_metas[i].is_tensor_parallel,
parameter_metas[i].param_weight_decay != parameter_metas[i].sequence_tensor_parallel,
parameter_metas[i]._param_weight_decay,
parameter_metas[i]._param_weight_decay == parameter_metas[i].is_tensor_parallel,
parameter_metas[i]._param_weight_decay != parameter_metas[i].sequence_tensor_parallel,
),
)
reordered_metas = [parameter_metas[i] for i in reorder_index]
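Note: get_param_groups now assumes the metas_for_grad of each parameter are laid out back-to-back in the buffer (hence the Assert.eq on buffer_begin), and coalesces adjacent ranges that share the same (weight_decay, lr_scale) key. A standalone sketch of that coalescing step, with made-up ranges:

# (weight_decay, lr_scale) key plus contiguous (begin, end) buffer ranges, made up.
metas = [((True, 1.0), 0, 10), ((True, 1.0), 10, 25), ((False, 0.1), 25, 25), ((False, 0.1), 25, 30)]

grouped: dict[tuple, list[slice]] = {}
for key, begin, end in metas:
    if begin == end:
        continue  # empty range: nothing from this meta lands in the group
    elif key not in grouped:
        grouped[key] = [slice(begin, end)]
    elif begin == (last := grouped[key][-1]).stop:
        grouped[key][-1] = slice(last.start, end)  # extend the previous slice
    else:
        grouped[key].append(slice(begin, end))

assert grouped == {(True, 1.0): [slice(0, 25)], (False, 0.1): [slice(25, 30)]}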
41 changes: 23 additions & 18 deletions fast_llm/layers/attention/attention.py
@@ -6,11 +6,12 @@
from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim
from fast_llm.engine.base_model.config import ResourceUsageConfig
from fast_llm.engine.config_utils.initialization import init_normal_
from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.autograd import wrap_forward_backward
from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs
from fast_llm.layers.block.config import BlockDimNames
from fast_llm.layers.common.linear.linear import concatenate_linear_layers
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.tensor import TensorMeta
@@ -92,13 +93,9 @@ def __init__(

head_size_dim = TensorDim("head_size", self._config.head_size)
query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, head_size_dim))
key_value_dim = ConcatenatedTensorDim(
"key_value",
(
CompositeTensorDim("key", (head_group_dim, head_size_dim)),
CompositeTensorDim("value", (head_group_dim, head_size_dim)),
),
)
key_dim = CompositeTensorDim("key", (head_group_dim, head_size_dim))
value_dim = CompositeTensorDim("value", (head_group_dim, head_size_dim))

dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim))

self._softmax_scale = self._config.head_size ** (-self._config.softmax_scale_power)
@@ -114,22 +111,30 @@
lr_scale=self._lr_scale,
peft=self._peft,
)
# TODO: Use value config.
self.key_value = self._config.key_layer.get_layer(
key = self._config.key_layer.get_layer(
hidden_dim,
key_value_dim,
key_dim,
default_weight_initialization=init_normal_(std=self._hidden_size**-0.5),
default_add_bias=self._config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=self._lr_scale,
peft=None if self._config.key_layer.apply_peft is None else self._peft,
peft=None,
)
value = self._config.key_layer.get_layer(
hidden_dim,
value_dim,
default_weight_initialization=init_normal_(std=self._hidden_size**-0.5),
default_add_bias=self._config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=self._lr_scale,
peft=None,
)
self.key_value = concatenate_linear_layers(
(key, value),
(self._config.key_layer, self._config.value_layer),
default_apply_peft=(False, True),
peft=peft,
)
if self._peft is not None and self._config.key_layer.apply_peft is None:
# Default: Apply to value only.
# TODO: Avoid this hack.
self.key_value = self._peft.apply_linear(
self.key_value, True, out_channel_begin=div(key_value_dim.global_size, 2)
)

self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward)

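Note: the removed hack applied LoRA to everything past the midpoint of the fused key_value dim; with default_apply_peft=(False, True) the same rows should be selected, since the value block sits right after the key block in the concatenated weight. A quick check with made-up sizes (4 head groups, head size 64):

head_groups, head_size = 4, 64                  # made-up sizes
key_rows = value_rows = head_groups * head_size

# Old behaviour: out_channel_begin = key_value_dim.global_size // 2.
old_range = ((key_rows + value_rows) // 2, key_rows + value_rows)
# New default: adapt only the value sub-layer, i.e. the second block of rows.
value_range = (key_rows, key_rows + value_rows)

assert old_range == value_range == (256, 512)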
2 changes: 1 addition & 1 deletion fast_llm/layers/common/linear/config.py
@@ -200,7 +200,7 @@ def get_layer(
if default_bias_initialization is None:
default_bias_initialization = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5)

lr_scale = (combine_lr_scales(lr_scale, self.lr_scale),)
lr_scale = combine_lr_scales(lr_scale, self.lr_scale)
weight = self.weight.get_parameter(
(in_dim, scalar_dim, kernel_dim),
default_initialization=default_weight_initialization,
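Note: the one-character fix matters. The old trailing comma turned the combined scale into a one-element tuple, so the parameter received (0.5,) where a float (or None) was expected:

lr_scale = 0.5
assert (lr_scale,) == (0.5,)   # old line: accidental 1-tuple
assert lr_scale == 0.5         # new line: the plain combined scale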
71 changes: 70 additions & 1 deletion fast_llm/layers/common/linear/linear.py
@@ -15,7 +15,9 @@
output_parallel_linear_backward,
output_parallel_linear_forward,
)
from fast_llm.tensor import ParameterMeta, TensorMeta
from fast_llm.layers.common.linear.config import LinearConfig
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.tensor import ConcatenatedParameterMeta, ParameterMeta, TensorMeta
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)
@@ -164,3 +166,70 @@ def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor
def backward(self, grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor: # noqa
# TODO: Needs grad_bias as input too?
return input_parallel_linear_backward(grad_output, context)


def concatenate_linear_layers[
T: LinearBase
](
layers: tuple[T, ...],
configs: tuple[LinearConfig, ...],
*,
concatenate_input_dim: bool = False,
dim_name: str | None = None,
default_apply_peft: bool | tuple[bool, ...] = False,
peft: PeftConfig | None,
) -> T:
# TODO: Simplify.
# All biases must be either enabled or disabled. TODO: Allow non-constant.
enable_bias = layers[0].bias is not None
# Concatenate on `in_dim` (instead of `out_dim`)
if concatenate_input_dim:
# TODO: Support this case? (needs one bias instead of a concatenation)
assert not enable_bias

cls = type(layers[0])
# Should not already be wrapped with Peft.
Assert.incl(cls, (Linear, InputParallelLinear, OutputParallelLinear))
# The concatenated dimension must be at index zero.
transposed_weight = concatenate_input_dim
for layer in layers:
Assert.eq(layer._transposed_weight, transposed_weight)
Assert.is_(type(layer), cls)
Assert.eq(layer.bias is not None, enable_bias)

if cls in (InputParallelLinear, OutputParallelLinear):
for layer in layers[1:]:
Assert.is_(layer._parallel_dim, layers[0]._parallel_dim)
Assert.eq(layer._sequence_parallel, layers[0]._sequence_parallel)
args = {"parallel_dim": layers[0]._parallel_dim, "sequence_parallel": layers[0]._sequence_parallel}
else:
args = {}

# TODO: Original parameters won't get names.
weight = ConcatenatedParameterMeta.from_metas(tuple(layer.weight for layer in layers), dim_name=dim_name)
bias = (
ConcatenatedParameterMeta.from_metas(tuple(layer.bias for layer in layers), dim_name=dim_name)
if enable_bias
else None
)

out = cls(weight, bias, transposed_weight=transposed_weight, **args)
if peft is not None:
if isinstance(default_apply_peft, bool):
default_apply_peft = (default_apply_peft,) * len(layers)
apply_peft = [
default if config.apply_peft is None else config.apply_peft
for config, default in zip(configs, default_apply_peft, strict=True)
]
if len(set(apply_peft)) == 1:
out_channel_ranges = None
enabled = apply_peft[0]
else:
enabled = True
out_channel_ranges = tuple(
split_range
for split_range, apply_peft_ in zip(weight.dims[0].get_split_ranges(True), apply_peft)
if apply_peft_
)
out = peft.apply_linear(out, enabled, out_channel_ranges=out_channel_ranges)
return out
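Note: the PEFT tail of concatenate_linear_layers has three outcomes: every sub-layer opts out, every sub-layer opts in (no range restriction in either case), or a mix, where only the enabled output ranges are forwarded. A standalone sketch of that decision (not the repo API, just the branch logic):

def resolve_peft(apply_peft: tuple[bool, ...], split_ranges: tuple[tuple[int, int], ...]):
    # A single shared flag means no per-range restriction; mixed flags keep
    # only the ranges whose sub-layer wants adaptation.
    if len(set(apply_peft)) == 1:
        return apply_peft[0], None
    return True, tuple(r for r, on in zip(split_ranges, apply_peft, strict=True) if on)

assert resolve_peft((False, False), ((0, 2), (2, 5))) == (False, None)
assert resolve_peft((False, True), ((0, 2), (2, 5))) == (True, ((2, 5),))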
6 changes: 2 additions & 4 deletions fast_llm/layers/common/peft/config.py
@@ -25,8 +25,7 @@ def apply_linear(
self,
module: "LinearBase",
enabled: bool,
out_channel_begin: int | None = None,
out_channel_end: int | None = None,
out_channel_ranges: tuple[tuple[int | None, int | None], ...] | None = None,
) -> "LinearLike":
return self.apply_other(module)

@@ -75,8 +74,7 @@ def apply_linear(
self,
module: "LinearBase",
enabled: bool,
out_channel_begin: int | None = None,
out_channel_end: int | None = None,
out_channel_ranges: tuple[tuple[int | None, int | None], ...] | None = None,
) -> "LinearLike":
if not enabled:
return self.apply_other(module)
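Note: apply_linear now takes a tuple of (begin, end) ranges instead of a single out_channel_begin/out_channel_end pair, so a caller can adapt several disjoint output blocks in one call. The old single range maps onto the new argument as a one-element tuple; the ranges below are made up:

# Old interface: at most one contiguous block.
out_channel_begin, out_channel_end = 256, 512
# New interface: the same block as a 1-tuple, or several disjoint blocks.
single = ((out_channel_begin, out_channel_end),)
multiple = ((0, 128), (384, 512))
assert single == ((256, 512),) and len(multiple) == 2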