From 7709e65d05652906936ada2cdb53c31ab4e68663 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Apr 2025 06:51:19 +0000 Subject: [PATCH 001/161] WIP: multimodal support --- fast_llm/data/config.py | 39 ++++++++++++++++++ fast_llm/data/image_processor.py | 40 +++++++++++++++++++ .../data/preparator/gpt_memmap/prepare.py | 3 ++ fast_llm/data/processor.py | 11 +++++ setup.cfg | 2 + 5 files changed, 95 insertions(+) create mode 100644 fast_llm/data/image_processor.py create mode 100644 fast_llm/data/processor.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 1586d370d..351dcaaef 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -34,3 +34,42 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) + + +@config_class() +class ImageProcessorConfig(Config): + """ + Configuration for the image processor + """ + + # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201) + patch_size: list[int] = Field( + default_factory=lambda: [16, 16], + desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer", + hint=FieldHint.optional, + ) + max_height: int = Field( + default=1024, + desc="Maximum height of the image. Image will be resized if larger", + hint=FieldHint.optional, + ) + max_width: int = Field( + default=1024, + desc="Maximum width of the image. Image will be resized if larger", + hint=FieldHint.optional, + ) + mean: list[float] = Field( + default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], + desc="Mean RGB values for pixel normalization", + hint=FieldHint.optional, + ) + std: list[float] = Field( + default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], + desc="Standard deviation RGB values for pixel normalization", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Diminisher factor for pixel normalization", + hint=FieldHint.optional, + ) diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py new file mode 100644 index 000000000..cf4c6e938 --- /dev/null +++ b/fast_llm/data/image_processor.py @@ -0,0 +1,40 @@ +import math + +import torch +from torchvision.transforms.v2 import functional as F + +from fast_llm.data.config import ImageProcessorConfig + + +class ImageProcessor: + def __init__(self, config: ImageProcessorConfig): + self.patch_size = config.patch_size + self.mean = config.mean / config.rescale_factor + self.std = config.std / config.rescale_factor + self.max_height = config.max_height + self.max_width = config.max_width + assert ( + self.max_height % self.patch_size[0] == 0 + ), "max_height must be divisible by patch_size[0]. Found {max_height} and {self.patch_size[0]}" + assert ( + self.max_width % self.patch_size[1] == 0 + ), "max_width must be divisible by patch_size[1]. 
Found {max_width} and {self.patch_size[1]}" + + def resize(self, image: torch.Tensor) -> torch.Tensor: + # Resize the image to the specified size + height = image.shape[0] + width = image.shape[1] + ratio = max(height / self.max_height, width / self.max_width) + if ratio > 1: + height = math.ceil(height / ratio) + width = math.ceil(width / ratio) + else: + height = self.patch_size[0] * math.ceil(height / self.self.patch_size[0]) + width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) + + # TODO: options for interpolation mode + return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + + def normalize(self, image: torch.Tensor) -> torch.Tensor: + # Normalize the image using the mean and std + return F.normalize(image, mean=self.mean, std=self.std) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b3dae1df1..5cfad9ec5 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -38,6 +38,9 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _tokenizer: Tokenizer _data_type: DataType + def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + pass + def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) diff --git a/fast_llm/data/processor.py b/fast_llm/data/processor.py new file mode 100644 index 000000000..43b1cda83 --- /dev/null +++ b/fast_llm/data/processor.py @@ -0,0 +1,11 @@ +from fast_llm.data.tokenizer import Tokenizer + + +class MultiModalProcessor: + """ + Combines multiple modalities (text and image) and converts to tokens/patches for text and images. 
+ """ + + def __init__(self, tokenizer: Tokenizer, image_processor=None): + self._tokenizer = tokenizer + self._image_processor = image_processor diff --git a/setup.cfg b/setup.cfg index c21f02a7b..3c1dad9da 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,8 @@ OPTIONAL = # Miscellanous requests>=2.32.3 tqdm>=4.66.3 + # Vision Tools + torchvision>=0.20.0 DEV = # Pre-commit git hook From 0db2bd21218fa133d4a1e41223552ece8f3044a7 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Apr 2025 06:19:10 +0000 Subject: [PATCH 002/161] rough idea for memmap --- fast_llm/data/config.py | 18 ++++ fast_llm/data/dataset/gpt/memmap.py | 59 ++++++++++-- fast_llm/data/dataset/gpt/sampled.py | 2 + fast_llm/data/image_processor.py | 3 + fast_llm/data/preparator/gpt_memmap/config.py | 8 +- .../data/preparator/gpt_memmap/prepare.py | 92 ++++++++++++++----- fast_llm/data/processor.py | 11 --- 7 files changed, 145 insertions(+), 48 deletions(-) delete mode 100644 fast_llm/data/processor.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 351dcaaef..8c2c3c28e 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -73,3 +73,21 @@ class ImageProcessorConfig(Config): desc="Diminisher factor for pixel normalization", hint=FieldHint.optional, ) + + +@config_class() +class MultiModalProcessorConfig(Config): + """ + Wrapper config that stores the `ImageProcessorConfig` and `TokenizerConfig` + """ + + tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, + desc="Configuration for the tokenizer.", + hint=FieldHint.core, + ) + image_processor: ImageProcessorConfig = Field( + default_factory=ImageProcessorConfig, + desc="Configuration for the image processor.", + hint=FieldHint.core, + ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ef060b008..c8b2592f1 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -38,10 +38,14 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: + self._has_images = struct.unpack("= 2: self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -82,6 +86,8 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) + if self._has_images and self._version >= 3: + self._image_sizes = np.frombuffer() self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -151,7 +157,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Initialize metadata dtype = None num_documents = 0 - lengths = [] + doc_lengths = [] + n_images = [] + im_lengths = [] + im_positions = [] + total_images = 0 pointers = [] offset = 0 # number of spans for each document @@ -160,8 +170,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) + pathlib.Path(prefix + "_images") # Write the binary data file (.bin) lazily + # TODO Soham: append image tokens along with text tokens with prefix.with_suffix(".bin").open("wb") as bin_stream: for document in documents: # Infer dtype from the first 
document @@ -174,10 +186,18 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + if document.images: + n_images.append(len(document.images)) + total_images += len(document.images) + for image, image_position in zip(document.images, document.image_positions): + im_lengths.append(image.size) + im_positions.append(document.image_positions) + bin_stream.write(image.tobytes(order="C")) # Update metadata doc_length = len(document.token_ids) - lengths.append(doc_length) + doc_lengths.append(doc_length) + im_lengths.append() pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) @@ -186,7 +206,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP num_documents += 1 # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) + doc_lengths = np.array(doc_lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) num_spans = np.array(num_spans, dtype=np.int32) if len(spans) > 0: @@ -194,27 +214,46 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP else: spans = np.array(spans, dtype=np.int32) + # TODO Soham: else condition might not be necessary + if total_images: + n_images = np.array(n_images, dtype=np.int32) + im_lengths = np.array(im_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) + else: + n_images = np.array([]) + im_lengths = np.array([]) + im_positions = np.array([]) + # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 optionally adds loss-masking spans - idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether images are present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" torch.Tensor: def normalize(self, image: torch.Tensor) -> torch.Tensor: # Normalize the image using the mean and std return F.normalize(image, mean=self.mean, std=self.std) + + def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: + return (image.size(0) // self.patch_size[0]) * (image.size(1) // self.patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c37..60262743e 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import TokenizerConfig +from fast_llm.data.config import MultiModalProcessorConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -153,9 +153,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the dataset.", hint=FieldHint.feature, ) - tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, - desc="Configuration for the tokenizer.", + data_processor: MultiModalProcessorConfig = Field( + default_factory=MultiModalProcessorConfig, + desc="Configuration for data processing. 
Describes the tokenizer and image processor", hint=FieldHint.feature, ) splits: dict[str, float] | None = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 5cfad9ec5..d4180986e 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -23,9 +23,9 @@ ) from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.multi_modal_processor import MultiModalProcessor from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig -from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -35,45 +35,79 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - _tokenizer: Tokenizer + # _tokenizer: Tokenizer + _data_processor: MultiModalProcessor _data_type: DataType def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) - for text in batch[self._config.dataset.field] + # input_ids = [ + # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) + # for text in batch[self._config.dataset.field] + # ] + input_ids, images, image_token_positions = map( + list, + zip( + *[ + ( + np.array(input_ids, dtype=self._data_type.numpy), + np.array(images, dtype=np.uint8), + np.array(image_token_positions, dtype=np.int32), + ) + for input_ids, images, image_token_positions in [ + self._data_processor.tokenize(text, ims, im_char_positions) + for text, ims, im_char_positions in zip( + batch[self._config.dataset.field], + batch[self._config.dataset.images], + batch[self._config.dataset.image_positions], + ) + ] + ] + ), + ) + num_tokens = [ + len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) ] - num_tokens = [len(x) for x in input_ids] return { "input_ids": input_ids, + "images": images, + "image_positions": image_token_positions, "num_tokens": num_tokens, } def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + input_ids, token_spans, images, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(images, dtype=np.uint8), + np.array(image_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) + for input_ids, token_spans, images, image_token_positions in [ + self._data_processor.tokenize_with_spans(text, char_spans) for text, char_spans in zip( - batch[self._config.dataset.field], batch[self._config.dataset.loss_masking_spans] + batch[self._config.dataset.field], + batch[self._config.dataset.loss_masking_spans], + batch[self._config.dataset.images], + batch[self._config.dataset.image_positions], ) ] ] ), ) - num_tokens = [len(x) for x in input_ids] + num_tokens = [ + len(x) + 
self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) + ] return { "input_ids": input_ids, "token_spans": token_spans, + "images": images, + "image_positions": image_token_positions, "num_tokens": num_tokens, } @@ -83,15 +117,27 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + # TODO Soham: simplify this + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._config.dataset.loss_masking_spans + else None + ), + images if self._config.dataset.images else None, + image_positions if self._config.dataset.image_positions else None, + ) + # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: + # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + # yield GPTSample( + # np.array(item["input_ids"], dtype=self._data_type.numpy), + # np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), + # ) + # else: + # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + # yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -169,12 +215,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load tokenizer - self._tokenizer = Tokenizer(config=self._config.tokenizer) + # Load Processor + self._processor = MultiModalProcessor(config=self._config.data_processor) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self._tokenizer.vocab_size) + get_unsigned_integer_type(self.processor._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) diff --git a/fast_llm/data/processor.py b/fast_llm/data/processor.py deleted file mode 100644 index 43b1cda83..000000000 --- a/fast_llm/data/processor.py +++ /dev/null @@ -1,11 +0,0 @@ -from fast_llm.data.tokenizer import Tokenizer - - -class MultiModalProcessor: - """ - Combines multiple modalities (text and image) and converts to tokens/patches for text and images. 
- """ - - def __init__(self, tokenizer: Tokenizer, image_processor=None): - self._tokenizer = tokenizer - self._image_processor = image_processor From 0d89f68d7c4d5a40f5fa7e2651ac61b75da31aa5 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Apr 2025 06:10:33 +0000 Subject: [PATCH 003/161] faster image size reading --- fast_llm/data/dataset/gpt/memmap.py | 54 ++++++++++++------ fast_llm/data/image_processor.py | 17 +++--- fast_llm/data/preparator/gpt_memmap/config.py | 18 +++++- .../data/preparator/gpt_memmap/prepare.py | 55 ++++++++++++------- setup.cfg | 3 + 5 files changed, 101 insertions(+), 46 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c8b2592f1..069240540 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -34,12 +34,12 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 + self._has_images = 0 with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 2: self._spans = [] @@ -73,9 +74,8 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, + offset=offset, ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] for idx in range(self._num_documents): self._spans.append( @@ -83,18 +83,40 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + offset=offset + + self._num_spans.nbytes + + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) + offset += ( + self._num_spans.nbytes + + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + + sum([x.nbytes for x in self._spans]) + ) if self._has_images and self._version >= 3: - self._image_sizes = np.frombuffer() + self._n_images = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + self._im_lengths = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._n_images.sum() * 3, + offset=offset + self._n_images.nbytes, + ) + self._im_positions = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._n_images.sum(), + offset=offset + self._n_images.nbytes + self._im_lengths.nbytes, + ) self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) + # TODO Soham: fix num_tokens to include images self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - if num_tokens is not None: - assert self._num_tokens == num_tokens + # if num_tokens is not None: + # assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens) @@ -110,6 +132,7 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap + # TODO Soham: get images def get( self, 
idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False ) -> GPTSample: @@ -170,10 +193,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) - pathlib.Path(prefix + "_images") # Write the binary data file (.bin) lazily - # TODO Soham: append image tokens along with text tokens with prefix.with_suffix(".bin").open("wb") as bin_stream: for document in documents: # Infer dtype from the first document @@ -186,23 +207,25 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + total_im_size = 0 if document.images: n_images.append(len(document.images)) total_images += len(document.images) for image, image_position in zip(document.images, document.image_positions): - im_lengths.append(image.size) + # assume 3 channels (RGB) for all images + im_lengths.append(np.array(image.shape[1:])) im_positions.append(document.image_positions) bin_stream.write(image.tobytes(order="C")) + total_im_size += image.size # Update metadata doc_length = len(document.token_ids) doc_lengths.append(doc_length) - im_lengths.append() pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - offset += doc_length * np.dtype(dtype).itemsize + offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize num_documents += 1 # Finalize metadata arrays @@ -214,15 +237,12 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP else: spans = np.array(spans, dtype=np.int32) - # TODO Soham: else condition might not be necessary if total_images: n_images = np.array(n_images, dtype=np.int32) - im_lengths = np.array(im_lengths, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - im_lengths = np.array([]) - im_positions = np.array([]) + im_lengths = np.stack(im_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index 473db11a2..c5cbe9095 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -9,8 +9,8 @@ class ImageProcessor: def __init__(self, config: ImageProcessorConfig): self.patch_size = config.patch_size - self.mean = config.mean / config.rescale_factor - self.std = config.std / config.rescale_factor + self.mean = [x / config.rescale_factor for x in config.mean] + self.std = [x / config.rescale_factor for x in config.std] self.max_height = config.max_height self.max_width = config.max_width assert ( @@ -20,16 +20,19 @@ def __init__(self, config: ImageProcessorConfig): self.max_width % self.patch_size[1] == 0 ), "max_width must be divisible by patch_size[1]. Found {max_width} and {self.patch_size[1]}" - def resize(self, image: torch.Tensor) -> torch.Tensor: + def resize(self, image): # Resize the image to the specified size - height = image.shape[0] - width = image.shape[1] + # TODO Soham: resize for patches only during train? + # TODO Soham: convert all images to tensor? 
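+        # Shrink the image if it exceeds (max_height, max_width); otherwise round each side up to the nearest multiple of the patch size.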
+ # height = image.shape[0] + # width = image.shape[1] + height, width = image.size ratio = max(height / self.max_height, width / self.max_width) if ratio > 1: height = math.ceil(height / ratio) width = math.ceil(width / ratio) else: - height = self.patch_size[0] * math.ceil(height / self.self.patch_size[0]) + height = self.patch_size[0] * math.ceil(height / self.patch_size[0]) width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) # TODO: options for interpolation mode @@ -40,4 +43,4 @@ def normalize(self, image: torch.Tensor) -> torch.Tensor: return F.normalize(image, mean=self.mean, std=self.std) def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: - return (image.size(0) // self.patch_size[0]) * (image.size(1) // self.patch_size[1]) + return (image.size[0] // self.patch_size[0]) * (image.size[1] // self.patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 60262743e..8a15d96c8 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -59,6 +59,15 @@ class GPTHuggingfaceDatasetConfig(Config): loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) + image_paths: None | str = Field( + default=None, desc="Field containing images within the document", hint=FieldHint.optional + ) + image_positions: None | str = Field( + default=None, desc="Field containing image positions within a document", hint=FieldHint.optional + ) + images: None | str = Field( + default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." @@ -142,6 +151,12 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) + tokenize_batch_size: int = Field( + default=1000, + desc="Batch size for tokenization.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 1), + ) saving_workers: int = Field( default=1, desc="Number of processes for saving the data.", @@ -165,8 +180,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, ) + # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: - assert self.tokenizer.path is not None + assert self.data_processor.tokenizer.path is not None if self.dataset.data_type is not None: Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d4180986e..0199cb400 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -13,6 +15,7 @@ import tqdm import transformers import yaml +from PIL import Image from fast_llm.data.dataset.gpt.config import ( GPTBlendedDatasetConfig, @@ -42,37 +45,43 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass + # TODO Soham: can we merged tokenize_batch and tokenize_batch_with_spans? 
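+    # Tokenize a batch of documents, returning token ids, the token positions where images belong, and per-document token counts.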
def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: # input_ids = [ # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) # for text in batch[self._config.dataset.field] # ] - input_ids, images, image_token_positions = map( + input_ids, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), - np.array(images, dtype=np.uint8), np.array(image_token_positions, dtype=np.int32), ) - for input_ids, images, image_token_positions in [ - self._data_processor.tokenize(text, ims, im_char_positions) - for text, ims, im_char_positions in zip( + for input_ids, image_token_positions in [ + self._data_processor.tokenize( + text, + im_char_positions, + ) + for text, im_char_positions in zip( batch[self._config.dataset.field], - batch[self._config.dataset.images], - batch[self._config.dataset.image_positions], + batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] ] ), ) - num_tokens = [ - len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) - ] + num_tokens = [len(x) for x in input_ids] + # TODO Soham: is this ok? Should we get num_image_tokens separately? + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_tokens[idx] += (width * height * 3) // np.dtype(self._dtype).itemsize + return { "input_ids": input_ids, - "images": images, "image_positions": image_token_positions, "num_tokens": num_tokens, } @@ -92,16 +101,17 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict self._data_processor.tokenize_with_spans(text, char_spans) for text, char_spans in zip( batch[self._config.dataset.field], - batch[self._config.dataset.loss_masking_spans], - batch[self._config.dataset.images], - batch[self._config.dataset.image_positions], + batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), + batch.get(self._config.dataset.images, itertools.repeat(None)), + batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] ] ), ) num_tokens = [ - len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) + len(x) + sum([self._data_processor._image_processor.get_num_patches(im) for im in doc_images]) + for x, doc_images in zip(input_ids, images) ] return { "input_ids": input_ids, @@ -117,7 +127,6 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - # TODO Soham: simplify this for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), @@ -126,8 +135,9 @@ def _document_generator(): if self._config.dataset.loss_masking_spans else None ), - images if self._config.dataset.images else None, - image_positions if self._config.dataset.image_positions else None, + # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, + [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + item["image_positions"] if self._config.dataset.image_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: # for item in tqdm.tqdm(shard_dataset, desc=f"Saving 
shard {shard_idx}", unit="docs"): @@ -215,12 +225,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load Processor - self._processor = MultiModalProcessor(config=self._config.data_processor) + # Load the data processor + self._data_processor = MultiModalProcessor(config=self._config.data_processor) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self.processor._tokenizer.vocab_size) + get_unsigned_integer_type(self._data_processor._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) @@ -269,6 +279,9 @@ def run(self) -> None: tokenize_fn = self._tokenize_batch_with_spans else: tokenize_fn = self._tokenize_batch + # Avoid decoding bytes to images unless asked + if self._config.dataset.images is not None: + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) # Tokenize the dataset in parallel tokenized_dataset = dataset.map( diff --git a/setup.cfg b/setup.cfg index 3c1dad9da..57913f83d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,9 @@ OPTIONAL = requests>=2.32.3 tqdm>=4.66.3 # Vision Tools + # TODO Soham: use pillow-simd instead of pillow? + webp>=0.4.0 + pillow-simd>=9.5.0 torchvision>=0.20.0 DEV = From 3866a5330fcf299ba8347b8e3aed057b598b5185 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 21 Apr 2025 07:04:41 +0000 Subject: [PATCH 004/161] solidify prepare --- fast_llm/data/config.py | 60 ++++----- fast_llm/data/data/gpt/config.py | 1 + fast_llm/data/data/gpt/data.py | 9 +- fast_llm/data/dataset/gpt/config.py | 15 ++- fast_llm/data/dataset/gpt/memmap.py | 127 ++++++++++++++---- fast_llm/data/dataset/gpt/sampled.py | 52 +++++-- fast_llm/data/image_processor.py | 25 ++-- fast_llm/data/preparator/gpt_memmap/config.py | 10 +- .../data/preparator/gpt_memmap/prepare.py | 49 ++++--- fast_llm/data/tokenizer.py | 30 ++++- fast_llm/layers/language_model/config.py | 14 ++ 11 files changed, 291 insertions(+), 101 deletions(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 8c2c3c28e..f1a0fd58a 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -43,36 +43,36 @@ class ImageProcessorConfig(Config): """ # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201) - patch_size: list[int] = Field( - default_factory=lambda: [16, 16], - desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer", - hint=FieldHint.optional, - ) - max_height: int = Field( - default=1024, - desc="Maximum height of the image. Image will be resized if larger", - hint=FieldHint.optional, - ) - max_width: int = Field( - default=1024, - desc="Maximum width of the image. 
Image will be resized if larger", - hint=FieldHint.optional, - ) - mean: list[float] = Field( - default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], - desc="Mean RGB values for pixel normalization", - hint=FieldHint.optional, - ) - std: list[float] = Field( - default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], - desc="Standard deviation RGB values for pixel normalization", - hint=FieldHint.optional, - ) - rescale_factor: float = Field( - default=255.0, - desc="Diminisher factor for pixel normalization", - hint=FieldHint.optional, - ) + # patch_size: list[int] = Field( + # default_factory=lambda: [16, 16], + # desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer", + # hint=FieldHint.optional, + # ) + # max_height: int = Field( + # default=1024, + # desc="Maximum height of the image. Image will be resized if larger", + # hint=FieldHint.optional, + # ) + # max_width: int = Field( + # default=1024, + # desc="Maximum width of the image. Image will be resized if larger", + # hint=FieldHint.optional, + # ) + # mean: list[float] = Field( + # default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], + # desc="Mean RGB values for pixel normalization", + # hint=FieldHint.optional, + # ) + # std: list[float] = Field( + # default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], + # desc="Standard deviation RGB values for pixel normalization", + # hint=FieldHint.optional, + # ) + # rescale_factor: float = Field( + # default=255.0, + # desc="Diminisher factor for pixel normalization", + # hint=FieldHint.optional, + # ) @config_class() diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index c98a781e6..652342b58 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -21,6 +21,7 @@ class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): gpu: bool = FieldUpdate(default=True) use_loss_masking_spans: bool = FieldUpdate(default=False) + use_images: bool = FieldUpdate(default=False) shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index a0940e7c6..5bd9d09e2 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,10 +32,16 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + images: list[torch.Tensor] | None = None + image_positions: list[torch.Tensor] | None = None +# TODO: do we need a separate use_images? 
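+# Collate a list of GPTSample objects into a single GPTBatch.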
def gpt_data_collate_fn( - batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool + batch: list[GPTSample], + use_loss_masking_spans: bool, + cross_document_attention: bool, + use_images: bool, ) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None @@ -170,6 +176,7 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, + use_images=self._config.sampling.use_images, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 0f04884b6..45d27e7d0 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -57,6 +57,11 @@ class GPTSamplingConfig(SamplingConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + use_images: bool | None = Field( + default=None, + desc="Use images in the dataset.", + hint=FieldHint.feature, + ) shuffle: ShufflingType | None = Field( default=None, desc="Shuffling strategy.", @@ -73,6 +78,7 @@ class GPTSamplingData(SamplingData): tokenizer: "Tokenizer" truncate_documents: bool = True cross_document_attention: bool = True + patch_size: list[int] | None = None @config_class() @@ -178,11 +184,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) @config_class() diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 069240540..87bd3a8eb 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,8 +1,10 @@ +import io import pathlib import struct import typing import numpy as np +import PIL.Image from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -26,10 +28,18 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None, + num_tokens: int | None, + num_pixels: int | None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) @@ -93,30 +103,48 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + sum([x.nbytes for x in self._spans]) ) + self._n_pixels = 0 if self._has_images and self._version >= 3: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._im_lengths = np.frombuffer( - self._index_bin_buffer, - 
dtype=np.int32, - count=self._n_images.sum() * 3, - offset=offset + self._n_images.nbytes, - ) - self._im_positions = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._n_images.sum(), - offset=offset + self._n_images.nbytes + self._im_lengths.nbytes, - ) + self._im_lengths = [] + self._im_positions = [] + images_seen = 0 + # TODO Soham: verify correctness, reshaping into width, height? + for n_images in self._n_images: + self._im_lengths.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images * 2, + offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + ) + self._n_pixels += self._im_lengths[-1].prod(axis=1, initial=3).sum() + self._im_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images, + offset=offset + + self._n_images.nbytes + + 2 * self._n_images.sum() * np.dtype(np.int32).itemsize + + images_seen * np.dtype(np.int32).itemsize, + ) + ) + images_seen += n_images self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - # TODO Soham: fix num_tokens to include images - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - # if num_tokens is not None: - # assert self._num_tokens == num_tokens + # TODO Soham: fix num_tokens to include images. Get total number of image pixels from index file and assign + # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - self._n_pixels, np.dtype(self._dtype).itemsize) + if num_pixels is not None: + assert self._n_pixels == num_pixels + if num_tokens is not None: + assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens) @@ -133,6 +161,42 @@ def __del__(self): del self._index_bin_buffer_mmap # TODO Soham: get images + def get( + self, + idx: int, + offset: int = 0, + length: int | None = None, + use_loss_masking_spans: bool = False, + # , patch_size: tuple(int), max_height: int, max_width: int + ): + # TODO Soham: Handle truncations? 
+ # if self._has_images: + # doc_size = self._document_sizes[idx] + # n_images = self._n_images[idx] + # image_positions = self._im_positions[idx] + # image_lengths = self._im_lengths[idx] + # image_tokens_seen = 0 + # for idx in range(n_images): + # height, width = ImageProcessor.get_resize_dims(image_lengths[0], image_lengths[1], max_height, max_width) + # n_image_tokens = (height // patch_size[0]) * (width // patch_size[1]) + # if (image_positions[idx] > offset + length) or (image_positions[idx] + n_tokens < offset): + # continue + token_ids = np.frombuffer( + self._bin_buffer, + dtype=self._dtype, + count=self._document_sizes[idx] - offset if length is None else length, + offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + ) + if self._has_images: + image_positions = self._im_positions[idx] + images = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.uint8).itemsize, + count=self._image_lengths[idx][0] * self._image_lengths[idx][1] * 3, + offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + ) + return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + def get( self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False ) -> GPTSample: @@ -164,16 +228,25 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens + @property + def has_images(self) -> bool: + return self._has_images + + # TODO: image sizes def get_document_sizes(self) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + return self._document_sizes, self._im_lengths - def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() + def get_document_size(self, index: int, patch_size: list[int]) -> int: + return self._document_sizes[index].item() + ( + sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) + if self._has_images + else 0 + ) @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -211,12 +284,14 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.images: n_images.append(len(document.images)) total_images += len(document.images) - for image, image_position in zip(document.images, document.image_positions): + for image in document.images: # assume 3 channels (RGB) for all images - im_lengths.append(np.array(image.shape[1:])) - im_positions.append(document.image_positions) - bin_stream.write(image.tobytes(order="C")) - total_im_size += image.size + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + pixels = np.array(img) + im_lengths.append(np.array(pixels.shape[:2])) + bin_stream.write(pixels.tobytes(order="C")) + total_im_size += pixels.size + im_positions.append(document.image_positions) # Update metadata doc_length = len(document.token_ids) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 22e3396b4..288018b12 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -12,6 +12,7 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset +from fast_llm.data.image_processor 
import ImageProcessor from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -89,11 +90,17 @@ def __init__( self._indexed_dataset = indexed_dataset self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length + self._patch_size = sampling.patch_size self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") + if self._indexed_dataset.has_images and self._truncate_documents: + raise RuntimeError( + "Truncating documents with images is not supported. Please turn off truncation to use images." + ) + if sampling.cache_directory is None: self._document_shuffling = MemmapArray() self._token_cumsum_shuffled = MemmapArray() @@ -126,9 +133,15 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + # TODO Soham: verify numpy correctness + document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes = torch.from_numpy(document_sizes).to(self._device) + image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + for i, sizes in enumerate(image_sizes): + image_token_sizes[i] = sum(sizes[0, :] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1]) + documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() # Calculate basic stats. if not self._truncate_documents: @@ -136,14 +149,14 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._sequence_length + 1 + long_docs_filter = document_sizes + image_token_sizes > self._sequence_length + 1 ignored_documents = sum(long_docs_filter) if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item() if tokens_per_epoch == 0: raise RuntimeError( f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." @@ -177,6 +190,7 @@ def _sample(self) -> None: "num_samples": self._num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._sequence_length, + "patch_size": self._patch_size, "truncate_documents": self._truncate_documents, "config": self._config.to_serialized(), } @@ -258,7 +272,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum( - document_sizes, + document_sizes + image_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -282,6 +296,9 @@ def _sample(self) -> None: document_shuffling.to( dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 ) + ] + + image_token_sizes[ + document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) ], offset=num_tokens_unshuffled, # TODO: Allowing for max 100% extra tokens for padding, is that enough? @@ -360,6 +377,9 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] + images = [] + image_positions = [] + image_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -367,7 +387,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + document_size = self._indexed_dataset.get_document_size(document_index, self._patch_size) if not self._truncate_documents: if document_size > self._sequence_length + 1: @@ -398,6 +418,12 @@ def __getitem__(self, index: int) -> typing.Any: length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._config.use_loss_masking_spans, ) + # TODO Soham: handle images with loss masking spans + for idx, im_position in enumerate(sample.image_positions): + # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + image_positions.append(im_position + len(token_ids) + image_tokens_added) + image_tokens_added += ImageProcessor.get_num_patches(sample.images[idx]) + images.append(sample.images) token_ids.append(sample.token_ids) if self._config.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: @@ -411,6 +437,7 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) + + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) @@ -420,9 +447,16 @@ def __getitem__(self, index: int) -> typing.Any: if self._config.use_loss_masking_spans else None ) - Assert.eq(len(token_ids), self._sequence_length + 1) - - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + images = [im for img_list in images for im in img_list] + Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) + + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + ) @property def name(self) -> str: diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index c5cbe9095..567c81469 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -26,21 +26,26 @@ def resize(self, image): # TODO Soham: convert all images to tensor? 
# height = image.shape[0] # width = image.shape[1] - height, width = image.size - ratio = max(height / self.max_height, width / self.max_width) - if ratio > 1: - height = math.ceil(height / ratio) - width = math.ceil(width / ratio) - else: - height = self.patch_size[0] * math.ceil(height / self.patch_size[0]) - width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) + height, width = self.get_resize_dims(image.shape[0], image.shape[1], self.max_height, self.max_width) # TODO: options for interpolation mode return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + # TODO Soham: move to utils + @classmethod + def get_resize_dims(height, width, max_height, max_width, patch_size: list[int]): + ratio = max(height / max_height, width / max_width) + return ( + (math.ceil(height / ratio), math.ceil(width / ratio)) + if ratio > 1 + else (patch_size[0] * math.ceil(height / patch_size[0]), patch_size[1] * math.ceil(width / patch_size[1])) + ) + def normalize(self, image: torch.Tensor) -> torch.Tensor: # Normalize the image using the mean and std return F.normalize(image, mean=self.mean, std=self.std) - def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: - return (image.size[0] // self.patch_size[0]) * (image.size[1] // self.patch_size[1]) + @classmethod + # TODO Soham: move to utils + def get_num_patches(image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: + return (image.size[0] // patch_size[0]) * (image.size[1] // patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 8a15d96c8..89fe904cd 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import MultiModalProcessorConfig +from fast_llm.data.config import TokenizerConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -168,9 +168,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the dataset.", hint=FieldHint.feature, ) - data_processor: MultiModalProcessorConfig = Field( - default_factory=MultiModalProcessorConfig, - desc="Configuration for data processing. 
Describes the tokenizer and image processor", + tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, + desc="Tokenizer configuration.", hint=FieldHint.feature, ) splits: dict[str, float] | None = Field( @@ -182,7 +182,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: - assert self.data_processor.tokenizer.path is not None + assert self.tokenizer.path is not None if self.dataset.data_type is not None: Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 0199cb400..4965dfdfc 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -10,12 +10,12 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm import transformers import yaml -from PIL import Image from fast_llm.data.dataset.gpt.config import ( GPTBlendedDatasetConfig, @@ -26,9 +26,9 @@ ) from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.data.multi_modal_processor import MultiModalProcessor from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -38,8 +38,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - # _tokenizer: Tokenizer - _data_processor: MultiModalProcessor + _tokenizer: Tokenizer _data_type: DataType def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -60,7 +59,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ np.array(image_token_positions, dtype=np.int32), ) for input_ids, image_token_positions in [ - self._data_processor.tokenize( + self._tokenizer.tokenize( text, im_char_positions, ) @@ -73,17 +72,18 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ ), ) num_tokens = [len(x) for x in input_ids] - # TODO Soham: is this ok? Should we get num_image_tokens separately? 
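+        # Track raw pixel counts (width * height * 3 channels) per document, returned as "num_pixels" alongside "num_tokens".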
+ num_pixels = [0] * len(input_ids) for idx, images in enumerate(batch.get("images", [])): for bytes_im in images: - with Image.open(io.BytesIO(bytes_im["bytes"])) as im: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: width, height = im.size - num_tokens[idx] += (width * height * 3) // np.dtype(self._dtype).itemsize + num_pixels[idx] += width * height * 3 return { "input_ids": input_ids, "image_positions": image_token_positions, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -98,7 +98,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict np.array(image_token_positions, dtype=np.int32), ) for input_ids, token_spans, images, image_token_positions in [ - self._data_processor.tokenize_with_spans(text, char_spans) + self._tokenizer.tokenize_with_spans(text, char_spans) for text, char_spans in zip( batch[self._config.dataset.field], batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), @@ -109,16 +109,20 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict ] ), ) - num_tokens = [ - len(x) + sum([self._data_processor._image_processor.get_num_patches(im) for im in doc_images]) - for x, doc_images in zip(input_ids, images) - ] + num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(images): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 return { "input_ids": input_ids, "token_spans": token_spans, "images": images, "image_positions": image_token_positions, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: @@ -136,7 +140,8 @@ def _document_generator(): else None ), # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, - [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: @@ -157,6 +162,7 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), } ) @@ -225,12 +231,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load the data processor - self._data_processor = MultiModalProcessor(config=self._config.data_processor) + # Load tokenizer + self._tokenizer = Tokenizer(config=self._config.tokenizer) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self._data_processor._tokenizer.vocab_size) + get_unsigned_integer_type(self._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) @@ -293,6 +299,12 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = 
sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._config.dataset.images + else 0 + ) + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) @@ -391,7 +403,8 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() + # TODO Soham: handle pixels (could still work with number of tokens?) + sizes_cumsum = dataset.get_document_sizes()[0].cumsum() Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 28e105ee8..0e7d54709 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -35,13 +35,41 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) + def tokenize(self, text, image_positions=None): + if not image_positions: + return self._tokenize(text), [], [] + image_idx = 0 + char_pos = 0 + token_ids = [] + image_token_positions = [] + beginning_of_text = True + while image_idx < len(image_positions): + if image_positions[image_idx] > len(text): + raise ValueError( + f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" + ) + curr_text = text[char_pos : image_positions[image_idx]] + tokenized_text = self._tokenize( + curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) + ) + beginning_of_text = False + token_ids.extend(tokenized_text) + image_token_positions = len(token_ids) + char_pos = image_positions[image_idx] + image_idx += 1 + if char_pos < len(text): + curr_text = text[char_pos:] + tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) + token_ids.extend(tokenized_text) + return token_ids, image_token_positions + def tokenize_with_spans( self, text: str, char_spans: list[tuple[int, int]] ) -> tuple[list[int], list[tuple[int, int]]]: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 3bd796033..75c5418bb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -6,6 +6,7 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig +from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig from fast_llm.utils import Assert @@ -198,3 +199,16 @@ def _validate(self) -> None: if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, 
self.init_method_max_embed) super()._validate() + + +class MultiModalBaseConfig: + language_model: LanguageModelBaseConfig = Field( + default_factory=LanguageModelBaseConfig, + desc="Configuration for the language model.", + hint=FieldHint.core, + ) + vision_model: VisionArchitectureConfig = Field( + default_factory=VisionArchitectureConfig, + desc="Configuration for the vision inputs.", + hint=FieldHint.core, + ) From 841398396714e5c3b346d6d2c46dcb37f532c167 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 07:55:31 +0000 Subject: [PATCH 005/161] wip --- fast_llm/data/data/gpt/config.py | 1 - fast_llm/data/data/gpt/data.py | 31 ++- fast_llm/data/dataset/gpt/config.py | 7 +- fast_llm/data/dataset/gpt/indexed.py | 12 +- fast_llm/data/dataset/gpt/memmap.py | 97 +++++---- fast_llm/data/dataset/gpt/sampled.py | 32 ++- fast_llm/data/image_processor.py | 10 +- fast_llm/engine/schedule/config.py | 15 ++ fast_llm/layers/language_model/config.py | 13 +- fast_llm/models/gpt/config.py | 4 + fast_llm/models/gpt/conversion.py | 258 +++++++++++++++++++++-- fast_llm/models/gpt/model.py | 12 +- fast_llm/models/gpt/trainer.py | 3 + 13 files changed, 400 insertions(+), 95 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 652342b58..c98a781e6 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -21,7 +21,6 @@ class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): gpu: bool = FieldUpdate(default=True) use_loss_masking_spans: bool = FieldUpdate(default=False) - use_images: bool = FieldUpdate(default=False) shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 5bd9d09e2..22e4730c9 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -36,12 +36,11 @@ class GPTBatch: image_positions: list[torch.Tensor] | None = None -# TODO: do we need a separate use_images? 
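Images in a batch keep different shapes after resizing, so the collate changes below stack token ids into one tensor but leave images as per-sample lists. A standalone sketch of that pattern (the _Sample dataclass is a stand-in for GPTSample, not the real class):

import dataclasses

import numpy as np
import torch


@dataclasses.dataclass
class _Sample:
    token_ids: np.ndarray
    images: list[np.ndarray] | None = None


def collate(batch: list[_Sample]) -> tuple[torch.Tensor, list[list[torch.Tensor] | None] | None]:
    # Token ids share a fixed sampled length, so they stack into a single (batch, sequence) tensor.
    token_ids = torch.from_numpy(np.stack([sample.token_ids for sample in batch]))
    # Images stay ragged: one list of (height, width, 3) tensors per sample, None when a sample has no image.
    images = [
        [torch.from_numpy(image) for image in sample.images] if sample.images is not None else None
        for sample in batch
    ]
    return token_ids, images if any(image is not None for image in images) else None


tokens, images = collate(
    [
        _Sample(np.arange(8, dtype=np.int64), [np.zeros((32, 48, 3), dtype=np.uint8)]),
        _Sample(np.arange(8, 16, dtype=np.int64)),
    ]
)
print(tokens.shape, [im if im is None else [t.shape for t in im] for im in images])
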
+# TODO: collate images def gpt_data_collate_fn( batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool, - use_images: bool, ) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None @@ -50,8 +49,24 @@ def gpt_data_collate_fn( stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] if not cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + batch_images = [] + for sample in batch: + if sample.images is not None: + batch_images.append([torch.from_numpy(image) for image in sample.images]) + else: + batch_images.append(None) + batch_image_positions = [] + for sample in batch: + if sample.image_positions is not None: + batch_image_positions.append(torch.from_numpy(sample.image_positions)) + else: + batch_image_positions.append(None) return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths + token_ids=torch.from_numpy(stacked_ids), + loss_masking_spans=stacked_spans, + sequence_lengths=sequence_lengths, + images=batch_images if any(batch_images) else None, + image_positions=batch_image_positions if any(batch_image_positions) else None, ) @@ -73,6 +88,9 @@ def __init__( vocab_size: int, max_sequence_length: int, cross_document_attention: bool = True, + patch_size: list[int] | None = None, + max_image_height: int | None = None, + max_image_width: int | None = None, ): """ Create the data and gather some basic information on the dataset(s). @@ -82,6 +100,9 @@ def __init__( self._vocab_size = vocab_size self._max_sequence_length = max_sequence_length self._cross_document_attention = cross_document_attention + self._patch_size = patch_size + self._max_image_height = max_image_height + self._max_image_width = max_image_width def setup( self, @@ -129,6 +150,9 @@ def setup( tokenizer=self._tokenizer, truncate_documents=self._config.truncate_documents, cross_document_attention=self._cross_document_attention, + patch_size=self._patch_size, + image_height=self._max_image_height, + image_width=self._max_image_width, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) @@ -176,7 +200,6 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, - use_images=self._config.sampling.use_images, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 45d27e7d0..8022a05f7 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -57,11 +57,6 @@ class GPTSamplingConfig(SamplingConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - use_images: bool | None = Field( - default=None, - desc="Use images in the dataset.", - hint=FieldHint.feature, - ) shuffle: ShufflingType | None = Field( default=None, desc="Shuffling strategy.", @@ -79,6 +74,8 @@ class GPTSamplingData(SamplingData): truncate_documents: bool = True cross_document_attention: bool = True patch_size: list[int] | None = None + image_height: int | None = None + image_width: int | None = None @config_class() diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 
688ea6a70..209c6e317 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -11,6 +11,7 @@ class GPTIndexedDataset(IndexedDataset): + # TODO Soham: should we change this to include images? @abc.abstractmethod def get_document_sizes(self) -> np.ndarray: """ @@ -44,10 +45,15 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] + doc_sizes, im_sizes = self._dataset.get_document_sizes() + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) + def get_document_size(self, index: int, patch_size: list[int]) -> int: + return self._dataset.get_document_size(self._begin + index, patch_size) + + @property + def has_images(self) -> bool: + return self._dataset.has_images class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 87bd3a8eb..43fba843c 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -103,17 +103,17 @@ def _init( + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + sum([x.nbytes for x in self._spans]) ) - self._n_pixels = 0 + self._num_pixels = 0 if self._has_images and self._version >= 3: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._im_lengths = [] - self._im_positions = [] + self._image_lengths = [] + self._image_positions = [] images_seen = 0 # TODO Soham: verify correctness, reshaping into width, height? for n_images in self._n_images: - self._im_lengths.append( + self._image_lengths.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, @@ -121,8 +121,8 @@ def _init( offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) - self._n_pixels += self._im_lengths[-1].prod(axis=1, initial=3).sum() - self._im_positions.append( + self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._image_positions.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, @@ -140,14 +140,14 @@ def _init( # TODO Soham: fix num_tokens to include images. Get total number of image pixels from index file and assign # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) - self._num_tokens = div(self._bin_buffer_mmap.size - self._n_pixels, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) if num_pixels is not None: - assert self._n_pixels == num_pixels + assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): self._init(*state) @@ -169,7 +169,7 @@ def get( use_loss_masking_spans: bool = False, # , patch_size: tuple(int), max_height: int, max_width: int ): - # TODO Soham: Handle truncations? 
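The read path below pulls a document's images straight out of the flat .bin buffer: the raw RGB bytes are written right after the document's tokens, so only the stored (height, width) pairs are needed to slice them back out. A simplified, self-contained sketch of that layout (the buffer built here is illustrative); note that a single .prod(initial=3) over the whole (n_images, 2) lengths array equals the required byte count only when a document holds one image, so the sketch sums the per-image products instead:

import numpy as np

# Write side: a document's tokens, then each image's pixels flattened in C order as uint8.
tokens = np.arange(5, dtype=np.uint16)
images = [np.full((4, 6, 3), 7, dtype=np.uint8), np.full((2, 3, 3), 9, dtype=np.uint8)]
buffer = tokens.tobytes() + b"".join(image.tobytes(order="C") for image in images)
image_lengths = np.array([image.shape[:2] for image in images], dtype=np.int32)

# Read side: skip the tokens, then carve out each image using its recorded height and width.
offset = tokens.size * tokens.dtype.itemsize
pixel_count = image_lengths.prod(axis=1).sum() * 3
pixels = np.frombuffer(buffer, dtype=np.uint8, count=pixel_count, offset=offset)
start, decoded = 0, []
for height, width in image_lengths:
    n_pixels = height * width * 3
    decoded.append(pixels[start : start + n_pixels].reshape(height, width, 3))
    start += n_pixels
assert all(np.array_equal(a, b) for a, b in zip(decoded, images))
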
+ # TODO Soham: handle spans # if self._has_images: # doc_size = self._document_sizes[idx] # n_images = self._n_images[idx] @@ -188,34 +188,42 @@ def get( offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) if self._has_images: - image_positions = self._im_positions[idx] - images = np.frombuffer( + image_positions = self._image_positions[idx] + pixels = np.frombuffer( self._bin_buffer, - dtype=np.dtype(np.uint8).itemsize, - count=self._image_lengths[idx][0] * self._image_lengths[idx][1] * 3, + dtype=np.dtype(np.uint8), + count=self._image_lengths[idx].prod(initial=3), offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, ) + images = [] + start = 0 + for image_length in self._image_lengths[idx]: + # TODO Soham: verify reshape dimension order + n_pixels = image_length.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(image_length[0], image_length[1], 3)) + start += n_pixels + # TODO Soham: return loss_masking_spans return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) - def get( - self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - ) -> GPTSample: - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - ) - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - # adjust the spans for the offset and length - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + # def get( + # self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False + # ) -> GPTSample: + # token_ids = np.frombuffer( + # self._bin_buffer, + # dtype=self._dtype, + # count=self._document_sizes[idx] - offset if length is None else length, + # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + # ) + # sample_spans = None + # if use_loss_masking_spans and self._spans is not None: + # sample_spans = self._spans[idx] + # # adjust the spans for the offset and length + # sample_spans = sample_spans[ + # (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) + # ] + # sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + # sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + # return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) @property def name(self) -> str: @@ -233,20 +241,21 @@ def has_images(self) -> bool: return self._has_images # TODO: image sizes - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. 
""" - return self._document_sizes, self._im_lengths + return self._document_sizes, self._image_lengths def get_document_size(self, index: int, patch_size: list[int]) -> int: - return self._document_sizes[index].item() + ( - sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) - if self._has_images - else 0 - ) + # return self._document_sizes[index].item() + ( + # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) + # if self._has_images + # else 0 + # ) + return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -255,7 +264,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP num_documents = 0 doc_lengths = [] n_images = [] - im_lengths = [] + image_lengths = [] im_positions = [] total_images = 0 pointers = [] @@ -288,7 +297,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: pixels = np.array(img) - im_lengths.append(np.array(pixels.shape[:2])) + image_lengths.append(np.array(pixels.shape[:2])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) @@ -316,7 +325,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP n_images = np.array(n_images, dtype=np.int32) else: n_images = np.array([]) - im_lengths = np.stack(im_lengths, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) im_positions = np.array(im_positions, dtype=np.int32) # Write the index file (.idx) @@ -347,7 +356,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Number of images per document idx_stream.write(n_images.tobytes(order="C")) # n_pixels * 3 per image - idx_stream.write(im_lengths.tobytes(order="C")) + idx_stream.write(image_lengths.tobytes(order="C")) # Position of each image in the document idx_stream.write(im_positions.tobytes(order="C")) # Document indices, unused but needed for compatibility with Megatron-LM diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 288018b12..8acbf9ee6 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -91,11 +91,14 @@ def __init__( self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length self._patch_size = sampling.patch_size + self._image_height = sampling.image_height + self._image_width = sampling.image_width self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") + # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( "Truncating documents with images is not supported. Please turn off truncation to use images." 
@@ -137,8 +140,9 @@ def _sample(self) -> None: document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum(sizes[0, :] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1]) + image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1])) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -387,15 +391,26 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index, self._patch_size) + document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) + + image_sizes = [ + ImageProcessor.get_num_patches_from_dims( + *ImageProcessor.get_resize_dims( + *image_length, self._image_height, self._image_width, self._patch_size + ), + self._patch_size, + ) + for image_length in image_lengths + ] + image_tokens = sum(image_sizes) if not self._truncate_documents: - if document_size > self._sequence_length + 1: + if document_size + image_tokens > self._sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue tokens_in_sample = token_count % (self._sequence_length + 1) - if document_size + tokens_in_sample > self._sequence_length + 1: + if document_size + image_tokens + tokens_in_sample > self._sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -408,7 +423,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count += padding_size # Determine if the document belongs to the requested sample. - if token_count + document_size >= token_start: + if token_count + document_size + image_tokens >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) @@ -422,7 +437,7 @@ def __getitem__(self, index: int) -> typing.Any: for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += ImageProcessor.get_num_patches(sample.images[idx]) + image_tokens_added += image_tokens images.append(sample.images) token_ids.append(sample.token_ids) if self._config.use_loss_masking_spans: @@ -433,7 +448,7 @@ def __getitem__(self, index: int) -> typing.Any: # Go to the next document. 
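To make the per-document accounting above concrete, here is a standalone version of the resize-then-count arithmetic (the two helpers mirror the ImageProcessor methods shown earlier; the 1500 x 500 example is illustrative):

import math


def get_resize_dims(height, width, max_height, max_width, patch_size):
    # Shrink to fit the maximum size; otherwise round up to whole patches.
    ratio = max(height / max_height, width / max_width)
    if ratio > 1:
        return math.ceil(height / ratio), math.ceil(width / ratio)
    return patch_size[0] * math.ceil(height / patch_size[0]), patch_size[1] * math.ceil(width / patch_size[1])


def get_num_patches_from_dims(height, width, patch_size):
    return (height // patch_size[0]) * (width // patch_size[1])


patch_size = (16, 16)
height, width = get_resize_dims(1500, 500, 1024, 1024, patch_size)
print(height, width)  # 1024 342: the height drives the ~1.46x downscale
print(get_num_patches_from_dims(height, width, patch_size))  # 64 * 21 = 1344 image tokens
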
document_sampling_index += 1 - token_count += document_size + token_count += document_size + image_tokens sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) @@ -447,7 +462,8 @@ def __getitem__(self, index: int) -> typing.Any: if self._config.use_loss_masking_spans else None ) - images = [im for img_list in images for im in img_list] + images = [im for img_list in images for im in img_list] if images else None + image_positions = np.array(image_positions) if image_positions else None Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) return GPTSample( diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index 567c81469..edfeceb95 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -33,7 +33,7 @@ def resize(self, image): # TODO Soham: move to utils @classmethod - def get_resize_dims(height, width, max_height, max_width, patch_size: list[int]): + def get_resize_dims(self, height, width, max_height, max_width, patch_size: list[int]): ratio = max(height / max_height, width / max_width) return ( (math.ceil(height / ratio), math.ceil(width / ratio)) @@ -47,5 +47,9 @@ def normalize(self, image: torch.Tensor) -> torch.Tensor: @classmethod # TODO Soham: move to utils - def get_num_patches(image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: - return (image.size[0] // patch_size[0]) * (image.size[1] // patch_size[1]) + def get_num_patches(self, image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: + return (image.shape[0] // patch_size[0]) * (image.shape[1] // patch_size[1]) + + @classmethod + def get_num_patches_from_dims(self, height: int, width: int, patch_size: list[int]) -> torch.Tensor: + return (height // patch_size[0]) * (width // patch_size[1]) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 83d3d51a3..16cfaf713 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,6 +55,21 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) + patch_size: list[int] | None = Field( + default=None, + desc="Patch size for each image token", + hint=FieldHint.optional, + ) + max_image_height: int | None = Field( + default=None, + desc="Maximum image height for each image token", + hint=FieldHint.optional, + ) + max_image_width: int | None = Field( + default=None, + desc="Maximum image width for each image token", + hint=FieldHint.optional, + ) num_micro_sequences: int = Field( init=False, desc="Number of micro-sequences to split each sample (= seqence length / micro-sequence length).", diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 75c5418bb..0175296c5 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -125,6 +125,11 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) + vision_encoder: VisionArchitectureConfig | None = Field( + default=None, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", @@ -200,8 +205,14 @@ def _validate(self) -> None: 
Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + super().setup_tensor_space(tensor_space) + + if self.vision_encoder is not None: + self.vision_encoder.setup_tensor_space(tensor_space) + -class MultiModalBaseConfig: +class MultiModalBaseConfig(BaseModelConfig): language_model: LanguageModelBaseConfig = Field( default_factory=LanguageModelBaseConfig, desc="Configuration for the language model.", diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 5a21368fa..c90da81b3 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,6 +48,10 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + + @config_class() class GPTArchitectureConfig(LanguageModelArchitectureConfig): _abstract = False diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 30ae80416..30f54f06d 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -32,6 +32,7 @@ LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -163,54 +164,65 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, + hf_base_prefix: str = "", + fast_llm_offset: int = 0, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers # Embeddings - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append( + WeightConverter( + f"layers.{fast_llm_offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" + ) + ) - converters += self._create_lm_head_converters() + converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(i) + converters += self._create_transformer_layer_converters(i, hf_base_prefix, fast_llm_offset) return converters - def _create_transformer_layer_converters(self, i: int, ignore_export: bool = False) -> list[WeightConverter]: + def _create_transformer_layer_converters( + self, i: int, ignore_export: bool = False, hf_base_prefix: str = "", fast_llm_offset: int = 1 + ) -> list[WeightConverter]: transformer_config: TransformerConfig = self._model.config.base_model.transformer norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm converters = [] names_bias_cls = [ # Self-attn ( - f"layers.{i+1}.self_attn.query", - f"model.layers.{i}.self_attn.q_proj", + f"layers.{i+fast_llm_offset}.self_attn.query", + f"{hf_base_prefix}model.layers.{i}.self_attn.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+1}.self_attn.key_value", - (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"), + f"layers.{i+fast_llm_offset}.self_attn.key_value", + ( + f"{hf_base_prefix}model.layers.{i}.self_attn.k_proj", + f"{hf_base_prefix}model.layers.{i}.self_attn.v_proj", + ), transformer_config.add_attn_qkv_bias, KeyValueWeightConverter, ), ( - 
f"layers.{i+1}.self_attn.dense", - f"model.layers.{i}.self_attn.o_proj", + f"layers.{i+fast_llm_offset}.self_attn.dense", + f"{hf_base_prefix}model.layers.{i}.self_attn.o_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+1}.norm_1", - f"model.layers.{i}.input_layernorm", + f"layers.{i+fast_llm_offset}.norm_1", + f"{hf_base_prefix}model.layers.{i}.input_layernorm", norm_bias, WeightConverter, ), ( - f"layers.{i+1}.norm_2", - f"model.layers.{i}.post_attention_layernorm", + f"layers.{i+fast_llm_offset}.norm_2", + f"{hf_base_prefix}model.layers.{i}.post_attention_layernorm", norm_bias, WeightConverter, ), @@ -226,17 +238,23 @@ def _create_transformer_layer_converters(self, i: int, ignore_export: bool = Fal # MLP if ignore_export: converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mlp.layer_1", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter + f"layers.{i+fast_llm_offset}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mlp.layer_2", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter + f"layers.{i+fast_llm_offset}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, ) - converters += [IgnoreExportWeightConverter(f"layers.{i+1}.mlp.router.weight", ())] + converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] else: - converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") + converters += self._get_mlp_converters(f"layers.{i+fast_llm_offset}", f"{hf_base_prefix}model.layers.{i}") return converters - def _create_lm_head_converters(self) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str, fast_llm_offset: int = 1) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm @@ -245,15 +263,20 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: # Next-token prediction head # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + f"layers.{num_layers + fast_llm_offset}.final_norm", f"{hf_base_prefix}model.norm", norm_bias ) # Output weights if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + converters.append( + WeightConverter( + f"layers.{num_layers + fast_llm_offset}.output_weights", f"{hf_base_prefix}lm_head.weight" + ) + ) # MTP-heads > 0 are thrown away + # TODO Soham: handle offset with MTP for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." 
@@ -531,6 +554,196 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class PixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + lm_converters = super()._create_config_converters() + for converter in lm_converters: + if converter.fast_llm_names[0][0] == "transformer": + converter.export_names[0] = ("text_config", *converter.export_names[0]) + return lm_converters + [ + # Multimodal adapter + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=( + ( + "text_config", + "hidden_size", + ) + ), + ), + # Image processing and conv layer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "image_size"),), + export_names=( + ( + "vision_config", + "image_size", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), + export_names=( + ( + "vision_config", + "patch_size", + ) + ), + ), + # Vision Transformer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), + export_names=( + ( + "vision_config", + "num_hidden_layers", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "hidden_size"),), + export_names=( + ( + "vision_config", + "hidden_size", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_attention_heads"),), + export_names=( + ( + "vision_config", + "num_attention_heads", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "intermediate_size"),), + export_names=( + ( + "vision_config", + "intermediate_size", + ) + ), + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "activation_type"),), + export_names=( + ( + "vision_config", + "hidden_act", + ) + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + export_names=( + ( + "vision_config", + "num_channels", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "attention_dropout"),), + export_names=( + ( + "vision_config", + "attention_dropout", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "rope_theta"),), + export_names=(("vision_config", "rope_theta"),), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "initializer_range"),), + export_names=(("vision_config", "initializer_range"),), + ), + ] + + def _create_vision_transformer_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.vision_encoder.encoder.num_hidden_layers + vision_transformer_converters = [] + for i in range(num_layers): + vision_transformer_converters += [ + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.k_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.v_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.q_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", + ), + WeightConverter( + 
f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.o_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention_norm.weight", + f"vision_tower.transformer.layers.{i}.attention_norm.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.ffn_norm.weight", + f"vision_tower.transformer.layers.{i}.ffn_norm.weight", + ), + ] + + return vision_transformer_converters + + def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: + patch_conv_converter = WeightConverter( + "layers.0._vision_encoder.patch_conv.weight", + "vision_tower.patch_conv.weight", + ) + vision_transformer_converters = self._create_vision_transformer_converters() + adapter_converters = [ + WeightConverter( + "layers.0._vision_encoder._adapter.layer_1.weight", + "multi_modal_projector.linear_1.weight", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_1.bias", + "multi_modal_projector.linear_1.bias", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_2.weight", + "multi_modal_projector.linear_2.weight", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_2.bias", + "multi_modal_projector.linear_2.bias", + ), + ] + return [patch_conv_converter] + vision_transformer_converters + adapter_converters + + def _create_weight_converters(self) -> list[WeightConverter]: + vision_encoder_converter = self._create_vision_encoder_weight_converters() + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + return vision_encoder_converter + lm_converters + + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat @@ -580,4 +793,5 @@ class AutoGPTHuggingfaceCheckpointHandler( Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, + PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index e878530cf..674116413 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -26,6 +26,7 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.vision_encoder.encoder import VisionEncoder from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -100,7 +101,10 @@ def get_layers(self) -> list[Layer]: LanguageModelEmbedding(self._config, self._tensor_space), LanguageModelHead(self._config, self._tensor_space, 0), ] 
- return [ + return ( + [VisionEncoder(self._config, self._tensor_space)] if self._config.vision_encoder is not None else [] + ) + [ + # return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ TransformerLayer( @@ -312,11 +316,11 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + return self.layers[self._config.vision_encoder is not None] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[1:-1] + return self.layers[(self._config.vision_encoder is not None) + 1 : -1] @property def model_head(self) -> LanguageModelHead: @@ -331,7 +335,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), + (self._config.vision_encoder is not None, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 376d8b840..b801fbd3d 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -21,6 +21,9 @@ def _get_data(self) -> GPTData: vocab_size=self._config.model.base_model.vocab_size, max_sequence_length=self._config.batch.sequence_length, cross_document_attention=self._config.batch.cross_document_attention, + patch_size=self._config.batch.patch_size, + max_image_height=self._config.batch.max_image_height, + max_image_width=self._config.batch.max_image_width, ) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 6521e41920fe8b17f207b32f58c43978bfcc8a46 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 19:00:23 +0000 Subject: [PATCH 006/161] vision model --- fast_llm/layers/vision_encoder/adapter.py | 44 ++++++++ fast_llm/layers/vision_encoder/config.py | 128 ++++++++++++++++++++++ fast_llm/layers/vision_encoder/encoder.py | 89 +++++++++++++++ 3 files changed, 261 insertions(+) create mode 100644 fast_llm/layers/vision_encoder/adapter.py create mode 100644 fast_llm/layers/vision_encoder/config.py create mode 100644 fast_llm/layers/vision_encoder/encoder.py diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py new file mode 100644 index 000000000..234c451a9 --- /dev/null +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -0,0 +1,44 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.common.linear import LinearBase +from fast_llm.layers.transformer.config import TransformerDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.tensor import init_normal_ + + +class VisionAdapter(Layer): + """ + Vision adapter layer for the LLM. 
+ """ + + def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str = "vision_adapter"): + super().__init__() + self._name = name + input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self.layer_1 = LinearBase( + input_dim, + tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + self.layer_2 = LinearBase( + tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(TransformerDimNames.hidden), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ): + return self.layer_2(self.layer_1(input_)) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py new file mode 100644 index 000000000..d410f92dc --- /dev/null +++ b/fast_llm/layers/vision_encoder/config.py @@ -0,0 +1,128 @@ +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType + + +class VisionEncoderDimNames: + out_channels = "vision_out_channels" + intermediate_size = "vision_intermediate_size" + patch_height = "vision_patch_height" + patch_width = "vision_patch_width" + + +@config_class() +class PatchConvConfig(BaseModelArchitectureConfig): + _abstract = False + """ + Configuration class for the convolution layers to apply on the image patches + """ + in_channels: int = Field( + default=3, + desc="Number of input channels for the convolution layers. 
Typically 3 for RGB images.", + hint=FieldHint.optional, + ) + bias: bool = Field( + default=False, desc="Whether to use a bias term in the convolution layers.", hint=FieldHint.optional + ) + height: int = Field( + default=16, + desc="Height of the image patches considered as tokens", + ) + width: int | None = Field( + default=16, + desc="Width of the image patches considered as tokens", + ) + + +@config_class() +class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + """ + Configuration class for the vision encoder, which transforms images into embeddings + """ + path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) + hidden_size: int = Field( + default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional + ) + intermediate_size: int = Field( + default=4096, + desc="The size of the intermediate (feed-forward) layers in the transformer model.", + hint=FieldHint.optional, + ) + num_hidden_layers: int = Field( + default=24, desc="The number of hidden layers in the transformer model.", hint=FieldHint.optional + ) + num_attention_heads: int = Field( + default=16, desc="The number of attention heads for the multi-head attention layers.", hint=FieldHint.optional + ) + num_channels: int = Field( + default=3, desc="Number of channels in the input image, typically 3 for RGB.", hint=FieldHint.optional + ) + image_size: int = Field( + default=1024, desc="The size of the input images (assumed square).", hint=FieldHint.optional + ) + patch_size: int = Field(default=16, desc="The size of the image patches to be encoded.", hint=FieldHint.optional) + hidden_act: str = Field( + default="gelu", desc="The activation function used in the hidden layers.", hint=FieldHint.optional + ) + attention_dropout: float = Field( + default=0.0, desc="The dropout probability for attention layers.", hint=FieldHint.optional + ) + rope_theta: float = Field( + default=10000.0, desc="The base value for rotary position embeddings.", hint=FieldHint.optional + ) + initializer_range: float = Field( + default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional + ) + + +@config_class() +class VisionArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + + encoder: VisionEncoderArchitectureConfig = Field( + default_factory=VisionEncoderArchitectureConfig, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.optional, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + # tensor_space.add_tensor_dim( + # CompositeTensorDim(VisionEncoderDimNames.) 
+ # ) + + # patch_convolution: PatchConvConfig = Field( + # default_factory=PatchConvConfig, + # desc="Configuration for the convolution layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # normalization: NormalizationArchitectureConfig = Field( + # default_factory=NormalizationArchitectureConfig, + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # transformer: TransformerArchitectureConfig = Field( + # default_factory=TransformerArchitectureConfig, + # desc="Configuration for the transformer layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # patch_rotary: RotaryArchitectureConfig = Field( + # default_factory=RotaryArchitectureConfig, + # desc="Configuration for the rotary positional embeddings applied to the image patches.", + # hint=FieldHint.optional + # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py new file mode 100644 index 000000000..2ea5c1e4f --- /dev/null +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -0,0 +1,89 @@ +import functools +import typing + +import torch +from transformers import PixtralVisionConfig, PixtralVisionModel + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class VisionEncoder(Layer): + """ + A vision encoder layer for creating token embeddings from vision model + """ + + def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + super().__init__() + + self._config = config.vision_encoder + self._distributed_config = tensor_space.distributed_config + with torch.device("meta"): + if self._config.encoder.path: + self._vision_encoder = PixtralVisionModel.from_pretrained( + self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch + ) + else: + self._vision_encoder = PixtralVisionModel( + PixtralVisionConfig( + hidden_size=self._config.hidden_size, + intermediate_size=self._config.intermediate_size, + num_hidden_layers=self._config.num_hidden_layers, + num_attention_heads=self._config.num_attention_heads, + num_channels=self._config.num_channels, + image_size=self._config.image_size, + patch_size=self._config.patch_size, + hidden_act=self._config.hidden_act, + attention_dropout=self._config.attention_dropout, + rope_theta=self._config.rope_theta, + initializer_range=self._config.initializer_range, + ) + ) + param_names = [] + # gather all names first. 
PyTorch complains if we do it in the loop + for name, param in self._vision_encoder.named_parameters(): + param_names.append(name) + for name in param_names: + # exclude .weight/.bias + *module_path, stem = name.split(".")[:-1] + module = functools.reduce(getattr, module_path, self._vision_encoder) + param = self._vision_encoder.get_parameter(name) + setattr( + module, + stem, + ParameterMeta.from_dims( + tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), + init_method=init_normal_(), + ), + # ParameterMeta( + # param, + # tensor_name=name, + # dims=(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), + # init_method=init_normal_(), + # allow_no_grad=True, + # ), + ) + self._adapter = VisionAdapter( + intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space=tensor_space, + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision Output", + dtype=self._distributed_config.training_dtype.torch, + ) + return self._vision_encoder(input_) From daf586f8d6a428398674771bce71a61a7e32cdbf Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 21:49:28 +0000 Subject: [PATCH 007/161] wip --- fast_llm/models/gpt/config.py | 5 ++-- fast_llm/models/gpt/conversion.py | 43 ++++++++++++++++--------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index c90da81b3..ca73b879e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,8 +48,8 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" -class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "pixtral" +class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llava" @config_class() @@ -109,6 +109,7 @@ class GPTModelConfig(FastLLMModelConfig): Qwen2GPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 30f54f06d..ad74ad53e 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -30,9 +30,9 @@ GPTArchitectureConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, - PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -367,7 +367,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), RenameParamConverter( fast_llm_names=(("transformer", "kv_channels"),), - export_names=(("head_dim"),), + export_names=(("head_dim",),), ), ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), @@ -554,23 +554,24 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class 
PixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): +class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + @classmethod def _create_config_converters(cls) -> list[ParamConverter]: lm_converters = super()._create_config_converters() for converter in lm_converters: - if converter.fast_llm_names[0][0] == "transformer": - converter.export_names[0] = ("text_config", *converter.export_names[0]) + if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): + # Llava uses a different name for the text config + # if converter.fast_llm_names[0][0] == "transformer": + converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) + # if converter.fast_llm_names[0][0] == "transformer": + # converter.export_names[0] = ("text_config", *converter.export_names[0]) return lm_converters + [ # Multimodal adapter RenameParamConverter( fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=( - ( - "text_config", - "hidden_size", - ) - ), + export_names=(("text_config", "hidden_size"),), ), # Image processing and conv layer RenameParamConverter( @@ -579,7 +580,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "image_size", - ) + ), ), ), RenameParamConverter( @@ -588,7 +589,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "patch_size", - ) + ), ), ), # Vision Transformer @@ -598,7 +599,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_hidden_layers", - ) + ), ), ), RenameParamConverter( @@ -607,7 +608,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "hidden_size", - ) + ), ), ), RenameParamConverter( @@ -616,7 +617,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_attention_heads", - ) + ), ), ), RenameParamConverter( @@ -625,7 +626,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "intermediate_size", - ) + ), ), ), MappedConfigParamConverter( @@ -634,7 +635,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "hidden_act", - ) + ), ), fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, @@ -645,7 +646,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_channels", - ) + ), ), ), RenameParamConverter( @@ -654,7 +655,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "attention_dropout", - ) + ), ), ), RenameParamConverter( @@ -793,5 +794,5 @@ class AutoGPTHuggingfaceCheckpointHandler( Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, - PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, + LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, } From ef4488d4f94b9c19b04f409917b3091b8e8601e8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 25 Apr 2025 22:47:08 +0000 Subject: [PATCH 008/161] wip --- fast_llm/layers/language_model/config.py | 7 ++- fast_llm/layers/vision_encoder/config.py | 11 +++++ fast_llm/layers/vision_encoder/encoder.py | 35 +++++++-------- 
fast_llm/models/gpt/conversion.py | 54 ++++++++++++++++------- 4 files changed, 69 insertions(+), 38 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0175296c5..ec80a9334 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -44,6 +44,11 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) + vision_encoder: None | VisionArchitectureConfig = Field( + default=None, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -125,7 +130,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: VisionArchitectureConfig | None = Field( + vision_encoder: None | VisionArchitectureConfig = FieldUpdate( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index d410f92dc..76af3d371 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -2,6 +2,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationType class VisionEncoderDimNames: @@ -42,6 +43,11 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): Configuration class for the vision encoder, which transforms images into embeddings """ path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) + pre_norm: NormalizationType = Field( + default=NormalizationType.rms_norm, + desc="The type of normalization to use before the transformer layers.", + hint=FieldHint.optional, + ) hidden_size: int = Field( default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional ) @@ -75,6 +81,11 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): initializer_range: float = Field( default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional ) + activation_type: ActivationType = Field( + default=ActivationType.silu, + desc="The activation function used in the hidden layers. 
Default: SiLU.", + hint=FieldHint.optional, + ) @config_class() diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 2ea5c1e4f..88064b51a 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -31,17 +31,17 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): else: self._vision_encoder = PixtralVisionModel( PixtralVisionConfig( - hidden_size=self._config.hidden_size, - intermediate_size=self._config.intermediate_size, - num_hidden_layers=self._config.num_hidden_layers, - num_attention_heads=self._config.num_attention_heads, - num_channels=self._config.num_channels, - image_size=self._config.image_size, - patch_size=self._config.patch_size, - hidden_act=self._config.hidden_act, - attention_dropout=self._config.attention_dropout, - rope_theta=self._config.rope_theta, - initializer_range=self._config.initializer_range, + hidden_size=self._config.encoder.hidden_size, + intermediate_size=self._config.encoder.intermediate_size, + num_hidden_layers=self._config.encoder.num_hidden_layers, + num_attention_heads=self._config.encoder.num_attention_heads, + num_channels=self._config.encoder.num_channels, + image_size=self._config.encoder.image_size, + patch_size=self._config.encoder.patch_size, + hidden_act=self._config.encoder.hidden_act, + attention_dropout=self._config.encoder.attention_dropout, + rope_theta=self._config.encoder.rope_theta, + initializer_range=self._config.encoder.initializer_range, ) ) param_names = [] @@ -49,8 +49,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): for name, param in self._vision_encoder.named_parameters(): param_names.append(name) for name in param_names: - # exclude .weight/.bias - *module_path, stem = name.split(".")[:-1] + *module_path, stem = name.split(".") module = functools.reduce(getattr, module_path, self._vision_encoder) param = self._vision_encoder.get_parameter(name) setattr( @@ -60,14 +59,10 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), init_method=init_normal_(), ), - # ParameterMeta( - # param, - # tensor_name=name, - # dims=(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), - # init_method=init_normal_(), - # allow_no_grad=True, - # ), ) + none_params = [key for key, value in module._parameters.items() if value is None] + for key in none_params: + module._parameters.pop(key) self._adapter = VisionAdapter( intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space=tensor_space, diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad74ad53e..f730d79c6 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -173,14 +173,16 @@ def _create_weight_converters( # Embeddings converters.append( WeightConverter( - f"layers.{fast_llm_offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" + f"layers.{fast_llm_offset - 1}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" ) ) converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(i, hf_base_prefix, fast_llm_offset) + converters += self._create_transformer_layer_converters( + i, hf_base_prefix=hf_base_prefix, fast_llm_offset=fast_llm_offset + ) 
return converters @@ -560,6 +562,9 @@ class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: lm_converters = super()._create_config_converters() + lm_converters[-2] = ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ) for converter in lm_converters: if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): # Llava uses a different name for the text config @@ -674,39 +679,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.v_proj.weight", f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -718,30 +723,45 @@ def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: "layers.0._vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) + # TODO Soham: use _get_weight_and_bias_converters? 
+ layer_norm_converter = WeightConverter( + "layers.0._vision_encoder.ln_pre.weight", + "vision_tower.ln_pre.weight", + ) + if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: + layer_norm_bias_converter = WeightConverter( + "layers.0._vision_encoder.ln_pre.bias", + "vision_tower.ln_pre.bias", + ) vision_transformer_converters = self._create_vision_transformer_converters() adapter_converters = [ WeightConverter( - "layers.0._vision_encoder._adapter.layer_1.weight", + "layers.0._adapter.layer_1.weight", "multi_modal_projector.linear_1.weight", ), WeightConverter( - "layers.0._vision_encoder._adapter.layer_1.bias", + "layers.0._adapter.layer_1.bias", "multi_modal_projector.linear_1.bias", ), + # TODO Soham: conditionally add bias WeightConverter( - "layers.0._vision_encoder._adapter.layer_2.weight", + "layers.0._adapter.layer_2.weight", "multi_modal_projector.linear_2.weight", ), WeightConverter( - "layers.0._vision_encoder._adapter.layer_2.bias", + "layers.0._adapter.layer_2.bias", "multi_modal_projector.linear_2.bias", ), ] - return [patch_conv_converter] + vision_transformer_converters + adapter_converters + return ( + [patch_conv_converter, layer_norm_converter, layer_norm_bias_converter] + + vision_transformer_converters + + adapter_converters + ) def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=2) return vision_encoder_converter + lm_converters From 6d9d595b921bc5139a910fe843d6ece3403445fb Mon Sep 17 00:00:00 2001 From: root Date: Mon, 28 Apr 2025 15:23:50 +0000 Subject: [PATCH 009/161] missing files --- fast_llm/data/dataset/gpt/memmap.py | 6 +- fast_llm/engine/multi_stage/stage_base.py | 3 + fast_llm/layers/multi_modal/embedding.py | 83 +++++++++++++++++++ fast_llm/layers/vision_encoder/config.py | 57 ++++++++++++- fast_llm/layers/vision_encoder/encoder.py | 26 +++--- .../layers/vision_encoder/preprocessing.py | 74 +++++++++++++++++ fast_llm/models/gpt/conversion.py | 44 +++++----- fast_llm/models/gpt/model.py | 59 +++++++++++-- setup.cfg | 2 +- 9 files changed, 309 insertions(+), 45 deletions(-) create mode 100644 fast_llm/layers/multi_modal/embedding.py create mode 100644 fast_llm/layers/vision_encoder/preprocessing.py diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 43fba843c..99bfbfa42 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -200,7 +200,7 @@ def get( for image_length in self._image_lengths[idx]: # TODO Soham: verify reshape dimension order n_pixels = image_length.prod(initial=3) - images.append(pixels[start : start + n_pixels].reshape(image_length[0], image_length[1], 3)) + images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels # TODO Soham: return loss_masking_spans return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) @@ -296,8 +296,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: - pixels = np.array(img) - image_lengths.append(np.array(pixels.shape[:2])) + pixels = np.array(img).transpose(2, 0, 1) # 
HWC to CHW + image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 0f83c862d..e97ef0410 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -161,6 +161,9 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) + # TODO Soham: clean way to get around check? + if meta is None: + continue module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py new file mode 100644 index 000000000..a92fdc4e5 --- /dev/null +++ b/fast_llm/layers/multi_modal/embedding.py @@ -0,0 +1,83 @@ +import typing + +import torch + +from fast_llm.core.distributed import set_generator +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.tensor import TensorMeta + + +class MultiModalEmbedding(LanguageModelEmbedding): + """ + Multi-modal embedding layer to combine embeddings from text, image and more modalities. + """ + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super().__init__(config, tensor_space) + self.vision_encoder = VisionEncoder(config, tensor_space) + + def _forward( + self, + input_: torch.Tensor, + position_ids: torch.Tensor | None, + images: torch.Tensor | None, + image_sizes: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + ) -> torch.Tensor: + image_embeddings = self.vision_encoder(images, kwargs={VisionModelKwargs.image_sizes: image_sizes}) + # TODO Soham: offset position ids + img_tokens_seen = 0 + image_idx = 0 + text_embeddings = super()._forward(input_, position_ids) + embeddings = [] + for sample_idx, positions in enumerate(image_positions): + embedding_parts = [] + for position in positions: + image = images[image_idx] + image_tokens = (image.size[1] // self._config.vision_encoder.encoder.patch_size) * ( + image.size[2] // self._config.vision_encoder.encoder.patch_size + ) + image_idx += 1 + img_tokens_seen += image_tokens + embedding_parts.append(text_embeddings[sample_idx, :position]) + embedding_parts.append(image_embeddings[img_tokens_seen : img_tokens_seen + image_tokens]) + embedding_parts.append(text_embeddings[sample_idx, position + image_tokens :]) + embeddings.append(torch.cat(embedding_parts, dim=0)) + embeddings = torch.stack(embeddings, dim=0) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(self._residual_dtype) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + 
kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + return self._forward( + input_, + kwargs.get(LanguageModelKwargs.position_ids), + kwargs.get(VisionModelKwargs.images), + kwargs.get(VisionModelKwargs.image_sizes), + kwargs.get(VisionModelKwargs.image_positions), + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 76af3d371..5e4722513 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,4 +1,4 @@ -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Config, Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType @@ -12,6 +12,17 @@ class VisionEncoderDimNames: patch_width = "vision_patch_width" +class VisionModelKwargs: + images = "images" + image_positions = "image_positions" + image_height = "image_height" + image_width = "image_width" + image_sizes = "image_sizes" + image_mean = "image_normalization_mean" + image_std = "image_normalization_std" + image_rescale_factor = "image_rescale_factor" + + @config_class() class PatchConvConfig(BaseModelArchitectureConfig): _abstract = False @@ -88,6 +99,45 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): ) +@config_class() +class ImageNormalizationConfig(Config): + mean_r: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_g: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_b: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_r: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_g: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_b: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + @config_class() class VisionArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -107,6 +157,11 @@ class VisionArchitectureConfig(BaseModelArchitectureConfig): desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", hint=FieldHint.core, ) + normalization: ImageNormalizationConfig = Field( + default_factory=ImageNormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 88064b51a..b028fa1fa 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -9,10 +9,11 @@ from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +# TODO Soham: should this just be nn.Module? class VisionEncoder(Layer): """ A vision encoder layer for creating token embeddings from vision model @@ -25,11 +26,14 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._distributed_config = tensor_space.distributed_config with torch.device("meta"): if self._config.encoder.path: - self._vision_encoder = PixtralVisionModel.from_pretrained( + self.vision_encoder = PixtralVisionModel.from_pretrained( self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch ) else: - self._vision_encoder = PixtralVisionModel( + # TODO Soham options to fix rotary: + # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta + # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope + self.vision_encoder = PixtralVisionModel( PixtralVisionConfig( hidden_size=self._config.encoder.hidden_size, intermediate_size=self._config.encoder.intermediate_size, @@ -46,12 +50,12 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): ) param_names = [] # gather all names first. 
PyTorch complains if we do it in the loop - for name, param in self._vision_encoder.named_parameters(): + for name, param in self.vision_encoder.named_parameters(): param_names.append(name) for name in param_names: *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self._vision_encoder) - param = self._vision_encoder.get_parameter(name) + module = functools.reduce(getattr, module_path, self.vision_encoder) + param = self.vision_encoder.get_parameter(name) setattr( module, stem, @@ -60,10 +64,10 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): init_method=init_normal_(), ), ) - none_params = [key for key, value in module._parameters.items() if value is None] - for key in none_params: - module._parameters.pop(key) - self._adapter = VisionAdapter( + # none_params = [key for key, value in module._parameters.items() if value is None] + # for key in none_params: + # module._parameters.pop(key) + self.adapter = VisionAdapter( intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space=tensor_space, ) @@ -81,4 +85,4 @@ def forward( tensor_name="Vision Output", dtype=self._distributed_config.training_dtype.torch, ) - return self._vision_encoder(input_) + return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py new file mode 100644 index 000000000..7ebfd3d7d --- /dev/null +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -0,0 +1,74 @@ +import typing + +import torch +import torchvision.transforms.v2.functional as F + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + return ( + (int(height / ratio), int(width / ratio)) + if ratio > 1 + else (max_height * (height // max_height), max_width * (width // max_width)) + ) + + +def resize(image: torch.Tensor, max_height: int, max_width: int) -> tuple[int, int]: + resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width) + # TODO: options for interpolation mode? + return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) + + +def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: + """ + Normalize the image using the specified mean and standard deviation. 
+ """ + return F.normalize(image, mean=mean, std=std) + + +def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: + """ + Pad images on the right and bottom with 0s untitl max_height and max_width + """ + width_padding = max(0, max_height - image.size(1)) + depth_padding = max(0, max_width - image.size(2)) + return F.pad(image, (0, 0, width_padding, depth_padding), 0) + + +class VisionPreprocessor: + def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + images = kwargs.get("images") + im_height = kwargs.get(VisionModelKwargs.image_height) + im_width = kwargs.get(VisionModelKwargs.image_width) + kwargs[VisionModelKwargs.image_sizes] = [(im.size(1), im.size(2)) for im in images] + images = [ + pad( + normalize( + resize(image, im_height, im_width) / kwargs[VisionModelKwargs.image_rescale_factor], + mean=kwargs[VisionModelKwargs.image_mean], + std=kwargs[VisionModelKwargs.image_std], + ), + max_height=im_height, + max_width=im_width, + ) + for image in images + ] + images = torch.stack(images, dim=0).to( + # TODO Soham: is this needed? + device=self._tensor_space.distributed.device, + dtype=self._distributed_config.training_dtype.torch, + ) + kwargs[VisionModelKwargs.images] = images diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index f730d79c6..3caaee5ad 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -679,39 +679,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.v_proj.weight", f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", 
f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -720,48 +720,48 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: patch_conv_converter = WeightConverter( - "layers.0._vision_encoder.patch_conv.weight", + "layers.0.vision_encoder.vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) # TODO Soham: use _get_weight_and_bias_converters? + layernorm_converters = [] layer_norm_converter = WeightConverter( - "layers.0._vision_encoder.ln_pre.weight", + "layers.0.vision_encoder.vision_encoder.ln_pre.weight", "vision_tower.ln_pre.weight", ) + layernorm_converters.append(layer_norm_converter) + layer_norm_converter if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: layer_norm_bias_converter = WeightConverter( - "layers.0._vision_encoder.ln_pre.bias", + "layers.0.vision_encoder.vision_encoder.ln_pre.bias", "vision_tower.ln_pre.bias", ) + layernorm_converters.append(layer_norm_bias_converter) vision_transformer_converters = self._create_vision_transformer_converters() adapter_converters = [ WeightConverter( - "layers.0._adapter.layer_1.weight", + "layers.0.vision_encoder.adapter.layer_1.weight", "multi_modal_projector.linear_1.weight", ), WeightConverter( - "layers.0._adapter.layer_1.bias", + "layers.0.vision_encoder.adapter.layer_1.bias", "multi_modal_projector.linear_1.bias", ), # TODO Soham: conditionally add bias WeightConverter( - "layers.0._adapter.layer_2.weight", + "layers.0.vision_encoder.adapter.layer_2.weight", "multi_modal_projector.linear_2.weight", ), WeightConverter( - "layers.0._adapter.layer_2.bias", + "layers.0.vision_encoder.adapter.layer_2.bias", "multi_modal_projector.linear_2.bias", ), ] - return ( - [patch_conv_converter, layer_norm_converter, layer_norm_bias_converter] - + vision_transformer_converters - + adapter_converters - ) + return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=2) + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) return vision_encoder_converter + lm_converters diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 674116413..0890051ea 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -14,6 +14,7 @@ from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor +from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding from fast_llm.layers.transformer.config import ( 
RoutingType, TransformerDimNames, @@ -26,7 +27,8 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -72,6 +74,9 @@ def __init__( else: self._flash_varlen_preprocessor = FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space) + if self._config.vision_encoder: + self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) + def get_output_layers(self) -> list[Layer]: return [ layer @@ -98,14 +103,19 @@ def get_layers(self) -> list[Layer]: if self._config.transformer.num_layers == 0: Assert.eq(self._config.prediction_heads, 1) return [ - LanguageModelEmbedding(self._config, self._tensor_space), + ( + LanguageModelEmbedding(self._config, self._tensor_space) + if self._config.vision_encoder is None + else MultiModalEmbedding(self._config, self._tensor_space) + ), LanguageModelHead(self._config, self._tensor_space, 0), ] - return ( - [VisionEncoder(self._config, self._tensor_space)] if self._config.vision_encoder is not None else [] - ) + [ - # return [ - LanguageModelEmbedding(self._config, self._tensor_space), + return [ + ( + LanguageModelEmbedding(self._config, self._tensor_space) + if self._config.vision_encoder is None + else MultiModalEmbedding(self._config, self._tensor_space) + ), *[ TransformerLayer( self._config.transformer, @@ -139,6 +149,30 @@ def preprocess_meta( sequence_length -= 1 micro_sequence_length = sequence_length + if self._config.vision_encoder: + image_height = batch_meta.max_image_height + image_width = batch_meta.max_image_width + image_mean = [ + self._config.vision_encoder.normalization.mean_r, + self._config.vision_encoder.normalization.mean_g, + self._config.vision_encoder.normalization.mean_b, + ] + image_std = [ + self._config.vision_encoder.normalization.std_r, + self._config.vision_encoder.normalization.std_g, + self._config.vision_encoder.normalization.std_b, + ] + image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor + vision_kwargs = { + VisionModelKwargs.image_height: image_height, + VisionModelKwargs.image_width: image_width, + VisionModelKwargs.image_mean: image_mean, + VisionModelKwargs.image_std: image_std, + VisionModelKwargs.image_rescale_factor: image_rescale_factor, + } + else: + vision_kwargs = {} + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) @@ -189,6 +223,7 @@ def preprocess_meta( TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, } + common_kwargs.update(vision_kwargs) preprocessed_meta = [] for sequence_k_past in range( @@ -271,6 +306,16 @@ def preprocess( if self._use_flash_attention: self._flash_varlen_preprocessor.preprocess(kwargs_meta) + if batch.images is not None: + kwargs_meta[VisionModelKwargs.images] = [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for images in batch.images + for img in images + ] + kwargs_meta[VisionModelKwargs.image_positions] = 
batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs_meta) + # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. pasts = presents diff --git a/setup.cfg b/setup.cfg index 57913f83d..52676c799 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ CORE = # Required for some optional features and tools. OPTIONAL = # Huggingface tools - transformers>=4.44.2 + transformers>=4.48.3 hf-transfer>=0.1.8 datasets>=3.1.0 huggingface-hub>=0.28.1 From 6cb8f5d0e85e8b1bd24470e387b4c0d259124201 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 30 Apr 2025 16:31:46 +0000 Subject: [PATCH 010/161] make it work, barely --- Dockerfile | 1 + fast_llm/data/data/gpt/data.py | 6 +- fast_llm/data/dataset/gpt/sampled.py | 17 ++- fast_llm/layers/multi_modal/embedding.py | 76 +++++----- fast_llm/layers/vision_encoder/adapter.py | 19 +-- fast_llm/layers/vision_encoder/config.py | 17 ++- fast_llm/layers/vision_encoder/encoder.py | 134 ++++++++++++++---- .../layers/vision_encoder/preprocessing.py | 31 +++- fast_llm/models/gpt/conversion.py | 30 ++-- fast_llm/models/gpt/model.py | 26 ++-- 10 files changed, 240 insertions(+), 117 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8c2efa85e..b8e1f8887 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ && apt-get install --no-install-recommends -y acl git-lfs \ + && apt-get install --no-install-recommends -y libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 22e4730c9..cffaa734f 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -49,10 +49,12 @@ def gpt_data_collate_fn( stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] if not cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + has_images = False batch_images = [] for sample in batch: if sample.images is not None: batch_images.append([torch.from_numpy(image) for image in sample.images]) + has_images = True else: batch_images.append(None) batch_image_positions = [] @@ -65,8 +67,8 @@ def gpt_data_collate_fn( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, - images=batch_images if any(batch_images) else None, - image_positions=batch_image_positions if any(batch_image_positions) else None, + images=batch_images if has_images else None, + image_positions=batch_image_positions if has_images else None, ) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 8acbf9ee6..973c1db53 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -433,13 +433,22 @@ def __getitem__(self, index: int) -> typing.Any: length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._config.use_loss_masking_spans, ) - # TODO Soham: handle images with loss masking spans + start_pos = 0 for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # Add placeholders for image tokens + token_ids.append(sample.token_ids[start_pos:im_position]) + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_positions.append(im_position + len(token_ids) + 
image_tokens_added) image_tokens_added += image_tokens + start_pos = im_position + token_ids.append(sample.token_ids[start_pos:]) + # TODO Soham: remove this + # if len(sample.images) == 1: + # sample.images.append(sample.images[0]) + # sample.image_positions = np.concatenate([sample.image_positions, sample.image_positions]) images.append(sample.images) - token_ids.append(sample.token_ids) + # TODO Soham: add offsets for loss masking spans if self._config.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip(loss_masking_span + token_count - token_start, 0, self._sequence_length + 1) @@ -452,7 +461,7 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) + # + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) @@ -464,7 +473,7 @@ def __getitem__(self, index: int) -> typing.Any: ) images = [im for img_list in images for im in img_list] if images else None image_positions = np.array(image_positions) if image_positions else None - Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) + Assert.eq(len(token_ids), self._sequence_length + 1) return GPTSample( token_ids=token_ids, diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index a92fdc4e5..3b62c60b7 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -25,59 +25,59 @@ def __init__( super().__init__(config, tensor_space) self.vision_encoder = VisionEncoder(config, tensor_space) - def _forward( + def forward( self, input_: torch.Tensor, - position_ids: torch.Tensor | None, - images: torch.Tensor | None, - image_sizes: torch.Tensor | None, - image_positions: list[torch.Tensor] | None, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, ) -> torch.Tensor: - image_embeddings = self.vision_encoder(images, kwargs={VisionModelKwargs.image_sizes: image_sizes}) + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + # return self._forward( + # input_, + # kwargs.get(LanguageModelKwargs.position_ids), + # kwargs.get(VisionModelKwargs.images), + # kwargs.get(VisionModelKwargs.image_sizes), + # kwargs.get(VisionModelKwargs.image_positions), + # ) # TODO Soham: offset position ids + images = kwargs.pop(VisionModelKwargs.images)[:1] + position_ids = kwargs.get(LanguageModelKwargs.position_ids) + image_positions = kwargs.get(VisionModelKwargs.image_positions)[:1] + image_embeddings = self.vision_encoder(images, kwargs) + embeddings = super()._forward(input_, position_ids) img_tokens_seen = 0 image_idx = 0 - text_embeddings = super()._forward(input_, position_ids) - embeddings = [] for sample_idx, positions in enumerate(image_positions): - embedding_parts = [] - for position in positions: + # embedding_parts = [] + for position in positions[:1]: image = images[image_idx] - image_tokens = (image.size[1] // self._config.vision_encoder.encoder.patch_size) * ( - image.size[2] // self._config.vision_encoder.encoder.patch_size + image_tokens = (image.size(1) // 
self._config.vision_encoder.encoder.patch_size) * ( + image.size(2) // self._config.vision_encoder.encoder.patch_size ) + embeddings[sample_idx, position : position + image_tokens] = image_embeddings[ + sample_idx, img_tokens_seen : img_tokens_seen + image_tokens + ] + # embedding_parts.append(text_embeddings[sample_idx, :position]) + # embedding_parts.append(image_embeddings[sample_idx, img_tokens_seen : img_tokens_seen + image_tokens]) image_idx += 1 img_tokens_seen += image_tokens - embedding_parts.append(text_embeddings[sample_idx, :position]) - embedding_parts.append(image_embeddings[img_tokens_seen : img_tokens_seen + image_tokens]) - embedding_parts.append(text_embeddings[sample_idx, position + image_tokens :]) - embeddings.append(torch.cat(embedding_parts, dim=0)) - embeddings = torch.stack(embeddings, dim=0) + # embedding_parts.append(text_embeddings[sample_idx, position:]) + # TODO Soham: debug from here + # embeddings.append(torch.cat(embedding_parts, dim=0)) + # embeddings = torch.stack(embeddings, dim=0) with set_generator( self._tensor_space.distributed.tp_generator if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + # assert embeddings.size(1) == 8192 + del image_embeddings + del images return embeddings.to(self._residual_dtype) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict | None = None, - ) -> torch.Tensor: - if isinstance(input_, TensorMeta): - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], - tensor_name="Embedding output", - dtype=self._residual_dtype, - ) - return self._forward( - input_, - kwargs.get(LanguageModelKwargs.position_ids), - kwargs.get(VisionModelKwargs.images), - kwargs.get(VisionModelKwargs.image_sizes), - kwargs.get(VisionModelKwargs.image_positions), - ) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 234c451a9..b8436f72e 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -1,16 +1,13 @@ -import typing - import torch -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.common.linear import LinearBase +from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import TransformerDimNames from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames from fast_llm.tensor import init_normal_ -class VisionAdapter(Layer): +class VisionAdapter(torch.nn.Module): """ Vision adapter layer for the LLM. 
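    Two stacked linear layers map the vision encoder hidden size to the adapter intermediate size and
    then to the language model hidden size, so projected image features can be consumed as token embeddings.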
""" @@ -19,14 +16,14 @@ def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str super().__init__() self._name = name input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) - self.layer_1 = LinearBase( + self.layer_1 = Linear( input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) - self.layer_2 = LinearBase( + self.layer_2 = Linear( tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, @@ -34,11 +31,5 @@ def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str bias_init_method=init_normal_(), ) - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ): + def forward(self, input_: torch.Tensor): return self.layer_2(self.layer_1(input_)) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 5e4722513..65ae8e502 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -2,7 +2,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.common.config import NormalizationConfig class VisionEncoderDimNames: @@ -10,9 +10,11 @@ class VisionEncoderDimNames: intermediate_size = "vision_intermediate_size" patch_height = "vision_patch_height" patch_width = "vision_patch_width" + kv_channels = "vision_kv_channels" class VisionModelKwargs: + patch_size = "patch_size" images = "images" image_positions = "image_positions" image_height = "image_height" @@ -21,6 +23,9 @@ class VisionModelKwargs: image_mean = "image_normalization_mean" image_std = "image_normalization_std" image_rescale_factor = "image_rescale_factor" + rope_theta = "vit_rope_theta" + rotary_inv_freq = "vit_rotary_inv_freq" + kv_channels = "vit_kv_channels" @config_class() @@ -54,10 +59,8 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): Configuration class for the vision encoder, which transforms images into embeddings """ path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) - pre_norm: NormalizationType = Field( - default=NormalizationType.rms_norm, - desc="The type of normalization to use before the transformer layers.", - hint=FieldHint.optional, + pre_norm: NormalizationConfig = Field( + default_factory=NormalizationConfig, ) hidden_size: int = Field( default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional @@ -168,6 +171,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + # TODO Soham: add a check for kv channels + tensor_space.add_tensor_dim( + TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) + ) # tensor_space.add_tensor_dim( # 
CompositeTensorDim(VisionEncoderDimNames.) # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index b028fa1fa..bbcebf251 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -2,7 +2,8 @@ import typing import torch -from transformers import PixtralVisionConfig, PixtralVisionModel +from transformers import PixtralVisionConfig +from transformers.models.pixtral.modeling_pixtral import PixtralTransformer from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -13,6 +14,33 @@ from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +def position_ids_in_meshgrid(patch_embeddings_list, max_width): + positions = [] + for patch in patch_embeddings_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +def generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask + + # TODO Soham: should this just be nn.Module? class VisionEncoder(Layer): """ @@ -25,37 +53,49 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config.vision_encoder self._distributed_config = tensor_space.distributed_config with torch.device("meta"): - if self._config.encoder.path: - self.vision_encoder = PixtralVisionModel.from_pretrained( - self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch - ) - else: - # TODO Soham options to fix rotary: - # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta - # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope - self.vision_encoder = PixtralVisionModel( - PixtralVisionConfig( - hidden_size=self._config.encoder.hidden_size, - intermediate_size=self._config.encoder.intermediate_size, - num_hidden_layers=self._config.encoder.num_hidden_layers, - num_attention_heads=self._config.encoder.num_attention_heads, - num_channels=self._config.encoder.num_channels, - image_size=self._config.encoder.image_size, - patch_size=self._config.encoder.patch_size, - hidden_act=self._config.encoder.hidden_act, - attention_dropout=self._config.encoder.attention_dropout, - rope_theta=self._config.encoder.rope_theta, - initializer_range=self._config.encoder.initializer_range, - ) - ) + # TODO Soham options to fix rotary: + # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta + # 2. 
set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope + config = PixtralVisionConfig( + hidden_size=self._config.encoder.hidden_size, + intermediate_size=self._config.encoder.intermediate_size, + num_hidden_layers=self._config.encoder.num_hidden_layers, + num_attention_heads=self._config.encoder.num_attention_heads, + num_channels=self._config.encoder.num_channels, + image_size=self._config.encoder.image_size, + patch_size=self._config.encoder.patch_size, + hidden_act=self._config.encoder.hidden_act, + attention_dropout=self._config.encoder.attention_dropout, + rope_theta=self._config.encoder.rope_theta, + initializer_range=self._config.encoder.initializer_range, + ) + self.patch_conv = torch.nn.Conv2d( + in_channels=3, + out_channels=self._config.encoder.hidden_size, + kernel_size=self._config.encoder.patch_size, + stride=self._config.encoder.patch_size, + bias=False, + ) + self.patch_conv.weight = ParameterMeta.from_dims( + tuple( + TensorDim(f"patch_conv_weight_{idx}", size) + for idx, size in enumerate(self.patch_conv.weight.shape) + ), + init_method=init_normal_(), + ) + self.norm = self._config.encoder.pre_norm.get_layer( + tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + ) + self.vision_transformer = PixtralTransformer(config) + # self.vision_encoder = PixtralVisionModel(config) param_names = [] # gather all names first. PyTorch complains if we do it in the loop - for name, param in self.vision_encoder.named_parameters(): + for name, param in self.vision_transformer.named_parameters(): param_names.append(name) for name in param_names: *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self.vision_encoder) - param = self.vision_encoder.get_parameter(name) + module = functools.reduce(getattr, module_path, self.vision_transformer) + param = self.vision_transformer.get_parameter(name) setattr( module, stem, @@ -72,6 +112,38 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): tensor_space=tensor_space, ) + def _forward( + self, input_: torch.Tensor, image_sizes: torch.Tensor, inv_freq: torch.Tensor, image_width: int + ) -> torch.Tensor: + patch_embeddings = self.patch_conv(input_) + patch_embeddings_list = [ + embedding[..., : image_size[0], : image_size[1]] + for embedding, image_size in zip(patch_embeddings, image_sizes) + ] + patch_embeddings = torch.cat([p.flatten(1).T for p in patch_embeddings_list], dim=0).unsqueeze(0) + patch_embeddings = self.norm(patch_embeddings) + position_ids = position_ids_in_meshgrid(patch_embeddings_list, image_width // self._config.encoder.patch_size) + freqs = inv_freq[position_ids] + with torch.autocast(device_type=input_.device.type): + cos = freqs.cos() + sin = freqs.sin() + cos = cos.to(dtype=input_.dtype) + sin = sin.to(dtype=input_.dtype) + + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeddings_list], patch_embeddings + ) + + (out,) = self.vision_transformer( + patch_embeddings, + attention_mask=attention_mask, + position_embeddings=(cos, sin), + output_attentions=False, + return_dict=False, + ) + + return self.adapter(out) + def forward( self, input_: torch.Tensor, @@ -85,4 +157,10 @@ def forward( tensor_name="Vision Output", dtype=self._distributed_config.training_dtype.torch, ) - return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) + return self._forward( + input_, + kwargs[VisionModelKwargs.image_sizes][:1], + 
kwargs[VisionModelKwargs.rotary_inv_freq], + image_width=kwargs[VisionModelKwargs.image_width], + ) + # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7ebfd3d7d..57ee3a0b2 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -40,7 +40,27 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: """ width_padding = max(0, max_height - image.size(1)) depth_padding = max(0, max_width - image.size(2)) - return F.pad(image, (0, 0, width_padding, depth_padding), 0) + return F.pad(image, (0, 0, depth_padding, width_padding), 0) + + +def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_size: int) -> torch.Tensor: + freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) + max_patches_per_side = image_size // patch_size + + h = torch.arange(max_patches_per_side) + w = torch.arange(max_patches_per_side) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + return torch.cat((inv_freq, inv_freq), dim=-1) class VisionPreprocessor: @@ -53,7 +73,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get("images") im_height = kwargs.get(VisionModelKwargs.image_height) im_width = kwargs.get(VisionModelKwargs.image_width) - kwargs[VisionModelKwargs.image_sizes] = [(im.size(1), im.size(2)) for im in images] + image_sizes = [get_resize_dims(im.size(1), im.size(2), im_height, im_width) for im in images] + kwargs[VisionModelKwargs.image_sizes] = image_sizes images = [ pad( normalize( @@ -72,3 +93,9 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: dtype=self._distributed_config.training_dtype.torch, ) kwargs[VisionModelKwargs.images] = images + kwargs[VisionModelKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionModelKwargs.rope_theta], + kwargs[VisionModelKwargs.kv_channels], + im_height, + kwargs[VisionModelKwargs.patch_size], + ).to(device=self._tensor_space.distributed.device) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 3caaee5ad..bd7da7979 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -597,6 +597,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "pre_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), # Vision Transformer RenameParamConverter( fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), @@ -679,39 +683,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.v_proj.weight", 
f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.gate_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -720,20 +724,20 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: patch_conv_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.patch_conv.weight", + "layers.0.vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) # TODO Soham: use _get_weight_and_bias_converters? 
layernorm_converters = [] layer_norm_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.ln_pre.weight", + "layers.0.vision_encoder.norm.weight", "vision_tower.ln_pre.weight", ) layernorm_converters.append(layer_norm_converter) layer_norm_converter - if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: + if self._model.config.base_model.vision_encoder.encoder.pre_norm.type == NormalizationType.layer_norm: layer_norm_bias_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.ln_pre.bias", + "layers.0.vision_encoder.norm.bias", "vision_tower.ln_pre.bias", ) layernorm_converters.append(layer_norm_bias_converter) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 0890051ea..ffbd22816 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -27,7 +27,7 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -164,11 +164,16 @@ def preprocess_meta( ] image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor vision_kwargs = { + VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, VisionModelKwargs.image_height: image_height, VisionModelKwargs.image_width: image_width, VisionModelKwargs.image_mean: image_mean, VisionModelKwargs.image_std: image_std, VisionModelKwargs.image_rescale_factor: image_rescale_factor, + VisionModelKwargs.rope_theta: self._config.vision_encoder.encoder.rope_theta, + VisionModelKwargs.kv_channels: self._tensor_space.get_tensor_dim( + VisionEncoderDimNames.kv_channels + ).size, } else: vision_kwargs = {} @@ -306,16 +311,6 @@ def preprocess( if self._use_flash_attention: self._flash_varlen_preprocessor.preprocess(kwargs_meta) - if batch.images is not None: - kwargs_meta[VisionModelKwargs.images] = [ - img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - for images in batch.images - for img in images - ] - kwargs_meta[VisionModelKwargs.image_positions] = batch.image_positions - if self._config.vision_encoder: - self._vision_preprocessor.preprocess(kwargs_meta) - # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. 
pasts = presents @@ -349,6 +344,15 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels + if batch.images is not None: + kwargs[VisionModelKwargs.images] = [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for images in batch.images + for img in images + ] + kwargs[VisionModelKwargs.image_positions] = batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) if self._config.transformer.rotary.enabled: From 5761a2d52cf4e7e5fcfd38ec19750be48cb06f8e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 30 Apr 2025 18:24:54 +0000 Subject: [PATCH 011/161] fix --- fast_llm/data/dataset/gpt/memmap.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 99bfbfa42..54bf6826a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -49,11 +49,14 @@ def _init( with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: + self._has_preference_spans = struct.unpack("= 4: self._has_images = struct.unpack("= 3: + if self._has_images and self._version >= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) @@ -333,10 +336,12 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version # Version 2 onwards optionally add loss-masking spans - # Version 3 onwards optionally add images - idx_stream.write(struct.pack(" 0 else 0)) + # Placeholder flag for preference spans + idx_stream.write(struct.pack(" 0 else 0)) # Data type From d45d60061068b316c3e49d633ea0e8adbc2d52ef Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 May 2025 05:43:50 +0000 Subject: [PATCH 012/161] fixes --- fast_llm/data/config.py | 57 ------------------- fast_llm/data/data/gpt/data.py | 9 +-- fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/dataset/gpt/memmap.py | 15 +---- fast_llm/data/dataset/gpt/sampled.py | 18 ++---- fast_llm/data/image_processor.py | 55 ------------------ fast_llm/engine/schedule/config.py | 11 +--- fast_llm/layers/vision_encoder/config.py | 57 +------------------ fast_llm/layers/vision_encoder/encoder.py | 2 +- .../layers/vision_encoder/preprocessing.py | 30 +++++++--- fast_llm/models/gpt/model.py | 6 +- fast_llm/models/gpt/trainer.py | 3 +- 12 files changed, 44 insertions(+), 224 deletions(-) delete mode 100644 fast_llm/data/image_processor.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index f1a0fd58a..1586d370d 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -34,60 +34,3 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) - - -@config_class() -class ImageProcessorConfig(Config): - """ - Configuration for the image processor - """ - - # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201) - # patch_size: list[int] = Field( - # default_factory=lambda: [16, 16], - # desc="Size for each path extracted from the image. 
Each patch corresponds to a token for the transformer", - # hint=FieldHint.optional, - # ) - # max_height: int = Field( - # default=1024, - # desc="Maximum height of the image. Image will be resized if larger", - # hint=FieldHint.optional, - # ) - # max_width: int = Field( - # default=1024, - # desc="Maximum width of the image. Image will be resized if larger", - # hint=FieldHint.optional, - # ) - # mean: list[float] = Field( - # default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], - # desc="Mean RGB values for pixel normalization", - # hint=FieldHint.optional, - # ) - # std: list[float] = Field( - # default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], - # desc="Standard deviation RGB values for pixel normalization", - # hint=FieldHint.optional, - # ) - # rescale_factor: float = Field( - # default=255.0, - # desc="Diminisher factor for pixel normalization", - # hint=FieldHint.optional, - # ) - - -@config_class() -class MultiModalProcessorConfig(Config): - """ - Wrapper config that stores the `ImageProcessorConfig` and `TokenizerConfig` - """ - - tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, - desc="Configuration for the tokenizer.", - hint=FieldHint.core, - ) - image_processor: ImageProcessorConfig = Field( - default_factory=ImageProcessorConfig, - desc="Configuration for the image processor.", - hint=FieldHint.core, - ) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index cffaa734f..34b86f213 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -91,8 +91,7 @@ def __init__( max_sequence_length: int, cross_document_attention: bool = True, patch_size: list[int] | None = None, - max_image_height: int | None = None, - max_image_width: int | None = None, + max_image_size: int | None = None, ): """ Create the data and gather some basic information on the dataset(s). 
@@ -103,8 +102,7 @@ def __init__( self._max_sequence_length = max_sequence_length self._cross_document_attention = cross_document_attention self._patch_size = patch_size - self._max_image_height = max_image_height - self._max_image_width = max_image_width + self._max_image_size = max_image_size def setup( self, @@ -153,8 +151,7 @@ def setup( truncate_documents=self._config.truncate_documents, cross_document_attention=self._cross_document_attention, patch_size=self._patch_size, - image_height=self._max_image_height, - image_width=self._max_image_width, + image_size=self._max_image_size, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 8022a05f7..65adf0bda 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -73,9 +73,8 @@ class GPTSamplingData(SamplingData): tokenizer: "Tokenizer" truncate_documents: bool = True cross_document_attention: bool = True - patch_size: list[int] | None = None - image_height: int | None = None - image_width: int | None = None + patch_size: int | None = None + image_size: int | None = None @config_class() diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 54bf6826a..8651b8fcd 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -170,20 +170,8 @@ def get( offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - # , patch_size: tuple(int), max_height: int, max_width: int ): # TODO Soham: handle spans - # if self._has_images: - # doc_size = self._document_sizes[idx] - # n_images = self._n_images[idx] - # image_positions = self._im_positions[idx] - # image_lengths = self._im_lengths[idx] - # image_tokens_seen = 0 - # for idx in range(n_images): - # height, width = ImageProcessor.get_resize_dims(image_lengths[0], image_lengths[1], max_height, max_width) - # n_image_tokens = (height // patch_size[0]) * (width // patch_size[1]) - # if (image_positions[idx] > offset + length) or (image_positions[idx] + n_tokens < offset): - # continue token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, @@ -299,6 +287,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode == "L": + # Convert grayscale to RGB + img = img.convert("RGB") pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 973c1db53..0ba3f0e13 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -12,9 +12,9 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.image_processor import ImageProcessor from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims from fast_llm.utils import Assert try: @@ -91,8 +91,7 @@ def 
__init__( self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length self._patch_size = sampling.patch_size - self._image_height = sampling.image_height - self._image_width = sampling.image_width + self._image_size = sampling.image_size self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents @@ -142,7 +141,7 @@ def _sample(self) -> None: image_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1])) + image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -394,10 +393,8 @@ def __getitem__(self, index: int) -> typing.Any: document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) image_sizes = [ - ImageProcessor.get_num_patches_from_dims( - *ImageProcessor.get_resize_dims( - *image_length, self._image_height, self._image_width, self._patch_size - ), + get_num_patches( + *get_resize_dims(*image_length, self._image_size, self._image_size, self._patch_size), self._patch_size, ) for image_length in image_lengths @@ -443,10 +440,6 @@ def __getitem__(self, index: int) -> typing.Any: image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) - # TODO Soham: remove this - # if len(sample.images) == 1: - # sample.images.append(sample.images[0]) - # sample.image_positions = np.concatenate([sample.image_positions, sample.image_positions]) images.append(sample.images) # TODO Soham: add offsets for loss masking spans if self._config.use_loss_masking_spans: @@ -461,7 +454,6 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - # + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py deleted file mode 100644 index edfeceb95..000000000 --- a/fast_llm/data/image_processor.py +++ /dev/null @@ -1,55 +0,0 @@ -import math - -import torch -from torchvision.transforms.v2 import functional as F - -from fast_llm.data.config import ImageProcessorConfig - - -class ImageProcessor: - def __init__(self, config: ImageProcessorConfig): - self.patch_size = config.patch_size - self.mean = [x / config.rescale_factor for x in config.mean] - self.std = [x / config.rescale_factor for x in config.std] - self.max_height = config.max_height - self.max_width = config.max_width - assert ( - self.max_height % self.patch_size[0] == 0 - ), "max_height must be divisible by patch_size[0]. Found {max_height} and {self.patch_size[0]}" - assert ( - self.max_width % self.patch_size[1] == 0 - ), "max_width must be divisible by patch_size[1]. Found {max_width} and {self.patch_size[1]}" - - def resize(self, image): - # Resize the image to the specified size - # TODO Soham: resize for patches only during train? - # TODO Soham: convert all images to tensor? 
- # height = image.shape[0] - # width = image.shape[1] - height, width = self.get_resize_dims(image.shape[0], image.shape[1], self.max_height, self.max_width) - - # TODO: options for interpolation mode - return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) - - # TODO Soham: move to utils - @classmethod - def get_resize_dims(self, height, width, max_height, max_width, patch_size: list[int]): - ratio = max(height / max_height, width / max_width) - return ( - (math.ceil(height / ratio), math.ceil(width / ratio)) - if ratio > 1 - else (patch_size[0] * math.ceil(height / patch_size[0]), patch_size[1] * math.ceil(width / patch_size[1])) - ) - - def normalize(self, image: torch.Tensor) -> torch.Tensor: - # Normalize the image using the mean and std - return F.normalize(image, mean=self.mean, std=self.std) - - @classmethod - # TODO Soham: move to utils - def get_num_patches(self, image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: - return (image.shape[0] // patch_size[0]) * (image.shape[1] // patch_size[1]) - - @classmethod - def get_num_patches_from_dims(self, height: int, width: int, patch_size: list[int]) -> torch.Tensor: - return (height // patch_size[0]) * (width // patch_size[1]) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 16cfaf713..9cf8f8b57 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,19 +55,14 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - patch_size: list[int] | None = Field( + patch_size: int | None = Field( default=None, desc="Patch size for each image token", hint=FieldHint.optional, ) - max_image_height: int | None = Field( + max_image_size: int | None = Field( default=None, - desc="Maximum image height for each image token", - hint=FieldHint.optional, - ) - max_image_width: int | None = Field( - default=None, - desc="Maximum image width for each image token", + desc="Maximum image height and width", hint=FieldHint.optional, ) num_micro_sequences: int = Field( diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 65ae8e502..b83a118b5 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -8,8 +8,7 @@ class VisionEncoderDimNames: out_channels = "vision_out_channels" intermediate_size = "vision_intermediate_size" - patch_height = "vision_patch_height" - patch_width = "vision_patch_width" + patch_size = "vision_patch_size" kv_channels = "vision_kv_channels" @@ -17,8 +16,7 @@ class VisionModelKwargs: patch_size = "patch_size" images = "images" image_positions = "image_positions" - image_height = "image_height" - image_width = "image_width" + image_size = "image_size" image_sizes = "image_sizes" image_mean = "image_normalization_mean" image_std = "image_normalization_std" @@ -28,30 +26,6 @@ class VisionModelKwargs: kv_channels = "vit_kv_channels" -@config_class() -class PatchConvConfig(BaseModelArchitectureConfig): - _abstract = False - """ - Configuration class for the convolution layers to apply on the image patches - """ - in_channels: int = Field( - default=3, - desc="Number of input channels for the convolution layers. 
Typically 3 for RGB images.", - hint=FieldHint.optional, - ) - bias: bool = Field( - default=False, desc="Whether to use a bias term in the convolution layers.", hint=FieldHint.optional - ) - height: int = Field( - default=16, - desc="Height of the image patches considered as tokens", - ) - width: int | None = Field( - default=16, - desc="Width of the image patches considered as tokens", - ) - - @config_class() class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -169,33 +143,8 @@ class VisionArchitectureConfig(BaseModelArchitectureConfig): def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.encoder.patch_size)) # TODO Soham: add a check for kv channels tensor_space.add_tensor_dim( TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) ) - # tensor_space.add_tensor_dim( - # CompositeTensorDim(VisionEncoderDimNames.) - # ) - - # patch_convolution: PatchConvConfig = Field( - # default_factory=PatchConvConfig, - # desc="Configuration for the convolution layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # normalization: NormalizationArchitectureConfig = Field( - # default_factory=NormalizationArchitectureConfig, - # desc="Configuration for the normalization layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # transformer: TransformerArchitectureConfig = Field( - # default_factory=TransformerArchitectureConfig, - # desc="Configuration for the transformer layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # patch_rotary: RotaryArchitectureConfig = Field( - # default_factory=RotaryArchitectureConfig, - # desc="Configuration for the rotary positional embeddings applied to the image patches.", - # hint=FieldHint.optional - # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index bbcebf251..8c694d28a 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -161,6 +161,6 @@ def forward( input_, kwargs[VisionModelKwargs.image_sizes][:1], kwargs[VisionModelKwargs.rotary_inv_freq], - image_width=kwargs[VisionModelKwargs.image_width], + image_width=kwargs[VisionModelKwargs.image_size], ) # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 57ee3a0b2..154c1a16d 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -1,3 +1,4 @@ +import math import typing import torch @@ -5,9 +6,17 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs +from fast_llm.utils import div -def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> tuple[int, int]: +def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int]: + """ + 
Calculate the number of patches in height and width dimensions. + """ + return div(height, patch_size) * div(width, patch_size) + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: """ Calculate the new dimensions for resizing an image while maintaining the aspect ratio. If the image is larger than the max dimensions, it will be resized to fit within them. @@ -17,12 +26,12 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> return ( (int(height / ratio), int(width / ratio)) if ratio > 1 - else (max_height * (height // max_height), max_width * (width // max_width)) + else (patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)) ) -def resize(image: torch.Tensor, max_height: int, max_width: int) -> tuple[int, int]: - resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width) +def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width, patch_size=patch_size) # TODO: options for interpolation mode? return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) @@ -71,14 +80,17 @@ def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get("images") - im_height = kwargs.get(VisionModelKwargs.image_height) - im_width = kwargs.get(VisionModelKwargs.image_width) - image_sizes = [get_resize_dims(im.size(1), im.size(2), im_height, im_width) for im in images] + im_height = kwargs.get(VisionModelKwargs.image_size) + im_width = kwargs.get(VisionModelKwargs.image_size) + patch_size = kwargs[VisionModelKwargs.patch_size] + image_sizes = [ + get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in images + ] kwargs[VisionModelKwargs.image_sizes] = image_sizes images = [ pad( normalize( - resize(image, im_height, im_width) / kwargs[VisionModelKwargs.image_rescale_factor], + resize(image, im_height, im_width, patch_size) / kwargs[VisionModelKwargs.image_rescale_factor], mean=kwargs[VisionModelKwargs.image_mean], std=kwargs[VisionModelKwargs.image_std], ), @@ -97,5 +109,5 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionModelKwargs.rope_theta], kwargs[VisionModelKwargs.kv_channels], im_height, - kwargs[VisionModelKwargs.patch_size], + patch_size, ).to(device=self._tensor_space.distributed.device) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index ffbd22816..c273f09b1 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -150,8 +150,7 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder: - image_height = batch_meta.max_image_height - image_width = batch_meta.max_image_width + image_size = batch_meta.max_image_size image_mean = [ self._config.vision_encoder.normalization.mean_r, self._config.vision_encoder.normalization.mean_g, @@ -165,8 +164,7 @@ def preprocess_meta( image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor vision_kwargs = { VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, - VisionModelKwargs.image_height: image_height, - VisionModelKwargs.image_width: image_width, + VisionModelKwargs.image_size: image_size, VisionModelKwargs.image_mean: image_mean, VisionModelKwargs.image_std: 
image_std, VisionModelKwargs.image_rescale_factor: image_rescale_factor, diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index b801fbd3d..bc16829b3 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -22,8 +22,7 @@ def _get_data(self) -> GPTData: max_sequence_length=self._config.batch.sequence_length, cross_document_attention=self._config.batch.cross_document_attention, patch_size=self._config.batch.patch_size, - max_image_height=self._config.batch.max_image_height, - max_image_width=self._config.batch.max_image_width, + max_image_size=self._config.batch.max_image_size, ) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 74a99b8ec047e31acd514a48237196ed9da761be Mon Sep 17 00:00:00 2001 From: root Date: Tue, 6 May 2025 17:44:50 +0000 Subject: [PATCH 013/161] changes --- fast_llm/engine/schedule/config.py | 21 +- fast_llm/functional/config.py | 2 + fast_llm/layers/language_model/config.py | 20 +- fast_llm/layers/multi_modal/embedding.py | 52 +-- fast_llm/layers/transformer/attention.py | 109 +++--- fast_llm/layers/transformer/config.py | 96 +++-- fast_llm/layers/transformer/mlp.py | 22 +- fast_llm/layers/transformer/preprocessing.py | 139 ++++++-- fast_llm/layers/transformer/transformer.py | 18 +- fast_llm/layers/vision_encoder/adapter.py | 39 ++- fast_llm/layers/vision_encoder/config.py | 178 ++++++---- fast_llm/layers/vision_encoder/encoder.py | 141 ++------ .../layers/vision_encoder/preprocessing.py | 153 ++++++-- fast_llm/models/gpt/conversion.py | 330 +++++++++++------- fast_llm/models/gpt/model.py | 113 ++++-- fast_llm/tools/cli.py | 1 - 16 files changed, 886 insertions(+), 548 deletions(-) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 9cf8f8b57..517a9cff5 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,16 +55,6 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - patch_size: int | None = Field( - default=None, - desc="Patch size for each image token", - hint=FieldHint.optional, - ) - max_image_size: int | None = Field( - default=None, - desc="Maximum image height and width", - hint=FieldHint.optional, - ) num_micro_sequences: int = Field( init=False, desc="Number of micro-sequences to split each sample (= seqence length / micro-sequence length).", @@ -81,6 +71,17 @@ class BatchConfig(Config): desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) + # Image inputs + patch_size: int | None = Field( + default=None, + desc="Patch size for each image token", + hint=FieldHint.optional, + ) + max_image_size: int | None = Field( + default=None, + desc="Maximum image height and width", + hint=FieldHint.optional, + ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 9f1fe005e..c5da0f9b1 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -82,6 +82,8 @@ def _set_activation_fn_map() -> None: ActivationType.squared_relu: "relu2", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} +_ACTIVATION_HF_NAMES_INV["gelu"] = ActivationType.gelu + MAX_DROPLESS_BLOCK_SIZE_ROW = 128 diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ec80a9334..887952d7a 
100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -6,7 +6,7 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig -from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderArchitectureConfig, VisionEncoderConfig from fast_llm.utils import Assert @@ -34,6 +34,7 @@ class LanguageModelKwargs: position_ids = "position_ids" # TODO: These are generic labels = "labels" + tokens = "tokens" phase = "phase" @@ -44,7 +45,7 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) - vision_encoder: None | VisionArchitectureConfig = Field( + vision_encoder: None | VisionEncoderArchitectureConfig = Field( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, @@ -130,7 +131,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: None | VisionArchitectureConfig = FieldUpdate( + vision_encoder: None | VisionEncoderConfig = FieldUpdate( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, @@ -215,16 +216,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: if self.vision_encoder is not None: self.vision_encoder.setup_tensor_space(tensor_space) - - -class MultiModalBaseConfig(BaseModelConfig): - language_model: LanguageModelBaseConfig = Field( - default_factory=LanguageModelBaseConfig, - desc="Configuration for the language model.", - hint=FieldHint.core, - ) - vision_model: VisionArchitectureConfig = Field( - default_factory=VisionArchitectureConfig, - desc="Configuration for the vision inputs.", - hint=FieldHint.core, - ) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 3b62c60b7..a3abe7813 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -7,8 +7,8 @@ from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionModelKwargs -from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta @@ -23,7 +23,6 @@ def __init__( tensor_space: TensorSpace, ): super().__init__(config, tensor_space) - self.vision_encoder = VisionEncoder(config, tensor_space) def forward( self, @@ -38,46 +37,29 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) - # return self._forward( - # input_, - # kwargs.get(LanguageModelKwargs.position_ids), - # kwargs.get(VisionModelKwargs.images), - # kwargs.get(VisionModelKwargs.image_sizes), - # kwargs.get(VisionModelKwargs.image_positions), - # ) - # TODO Soham: offset position ids - images = 
kwargs.pop(VisionModelKwargs.images)[:1] + # image_embeddings = kwargs.pop(VisionEncoderKwargs.patch_embeddings) position_ids = kwargs.get(LanguageModelKwargs.position_ids) - image_positions = kwargs.get(VisionModelKwargs.image_positions)[:1] - image_embeddings = self.vision_encoder(images, kwargs) - embeddings = super()._forward(input_, position_ids) - img_tokens_seen = 0 + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + tokens = kwargs.get(LanguageModelKwargs.tokens) + # get text embeddings + embeddings = super()._forward(tokens, position_ids) image_idx = 0 - for sample_idx, positions in enumerate(image_positions): - # embedding_parts = [] - for position in positions[:1]: - image = images[image_idx] - image_tokens = (image.size(1) // self._config.vision_encoder.encoder.patch_size) * ( - image.size(2) // self._config.vision_encoder.encoder.patch_size - ) - embeddings[sample_idx, position : position + image_tokens] = image_embeddings[ - sample_idx, img_tokens_seen : img_tokens_seen + image_tokens + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens ] - # embedding_parts.append(text_embeddings[sample_idx, :position]) - # embedding_parts.append(image_embeddings[sample_idx, img_tokens_seen : img_tokens_seen + image_tokens]) + image_embedding_offset += num_image_tokens image_idx += 1 - img_tokens_seen += image_tokens - # embedding_parts.append(text_embeddings[sample_idx, position:]) - # TODO Soham: debug from here - # embeddings.append(torch.cat(embedding_parts, dim=0)) - # embeddings = torch.stack(embeddings, dim=0) + with set_generator( self._tensor_space.distributed.tp_generator if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): embeddings = torch.dropout(embeddings, self._dropout_p, self.training) - # assert embeddings.size(1) == 8192 - del image_embeddings - del images + return embeddings.to(self._residual_dtype) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c7ae55c5c..3a3f40239 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -14,7 +14,9 @@ TransformerDimNames, TransformerKwargs, TransformerSubLayerName, + VisionTransformerConfig, ) +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -57,24 +59,6 @@ class Attention(torch.nn.Module): A self-attention layer. 
""" - _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, - ) - _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, - TransformerDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, - ) - def __init__( self, config: TransformerConfig, @@ -82,12 +66,19 @@ def __init__( layer_index, ): super().__init__() + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + # Assert.in_range_incl(layer_index, 1, self._config.num_layers) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer + self._causal = self._config.causal self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -101,19 +92,19 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size + self._head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).global_size + self._local_head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).size + self._local_heads_per_group = self._tensor_space.get_tensor_dim(self._transformer_dim_names.group_heads).size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_query), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -122,7 +113,7 @@ def __init__( ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_key_value), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -133,7 +124,7 @@ def __init__( # Output. 
self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_dense), hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, @@ -199,7 +190,7 @@ def _attn_fused( def _get_meta( self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} + hidden_dims = {dim.name: dim for dim in kwargs[self._transformer_kwargs.hidden_dims]} return TensorMeta.from_dims( tuple( hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) @@ -209,6 +200,32 @@ def _get_meta( dtype=input_.dtype, ) + @property + def query_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def kv_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.group_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def context_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_dense, + ) + def _debug_log( self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> None: @@ -307,12 +324,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(self._transformer_kwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(self._transformer_kwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
presents.append(present := key_value.detach().requires_grad_()) @@ -339,23 +356,23 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._config.rotary.enabled: if self._debug_transformer: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug_log(query, "query_rotary_input", self.query_dims, kwargs) self._debug_log( key, "key_rotary_input", - self._KV_DIMS, + self.kv_dims, kwargs, ) rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[self._transformer_kwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[self._transformer_kwargs.rotary_freq_k]) window_size = self._decide_window_size() if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(self._transformer_kwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -365,12 +382,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(self._transformer_kwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(self._transformer_kwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(self._transformer_kwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -380,7 +397,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) @@ -390,25 +407,25 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[self._transformer_kwargs.attention_mask], + kwargs[self._transformer_kwargs.attention_mask_value], ) if self._debug_transformer: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) + self._debug_log(query, "query", self.query_dims, kwargs) self._debug_log( key, "key", - self._KV_DIMS, + self.kv_dims, kwargs, ) self._debug_log( value, "value", - self._KV_DIMS, + self.kv_dims, kwargs, ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + self._debug_log(input_, "context", self.context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 4806e37ec..6b0d7ad68 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -84,6 +84,7 @@ class TransformerKwargs: sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" sequence_length = "sequence_length" + micro_batch_size = "micro_batch_size" # TODO: Move grad_output = "grad_output" @@ -98,6 +99,8 @@ class RotaryEmbeddingType(str, enum.Enum): default = "default" llama3 = "llama3" yarn = "yarn" + # TODO Soham: generic name? + pixtral = "pixtral" @config_class() @@ -166,6 +169,15 @@ class RotaryConfig(RotaryArchitectureConfig, BaseModelConfig): pass +@config_class() +class VisionRotaryConfig(RotaryConfig): + type: RotaryEmbeddingType = Field( + default=RotaryEmbeddingType.pixtral, + desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", + hint=FieldHint.feature, + ) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -398,63 +410,73 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + if type == "vision": + # TODO Soham: better way to get around circular imports? + from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames + + transformer_dim_names = VisionTransformerDimNames + else: + transformer_dim_names = TransformerDimNames + # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( 
- CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim( + CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) @@ -656,6 +678,11 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + causal: bool = Field( + default=True, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) def _validate(self) -> None: if self.init_method_std is None: @@ -718,3 +745,30 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: Assert.is_(self.window_size, None) return use_flash_attention + + +@config_class() +class VisionRotaryConfig(RotaryConfig): + type: RotaryEmbeddingType = Field( + default=RotaryEmbeddingType.pixtral, + desc="The type of rotary embedding to use. 
Choices: none, default, llama3, yarn, pixtral.", + hint=FieldHint.feature, + ) + + +@config_class() +class VisionTransformerConfig(TransformerConfig): + """ + Configuration for the Vision Transformer (ViT) model. + """ + + causal: bool = FieldUpdate( + default=False, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) + rotary: VisionRotaryConfig = FieldUpdate( + default_factory=VisionRotaryConfig, + desc="Configuration for the rotary positional embeddings.", + hint=FieldHint.feature, + ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 9b90beffb..1b494fc0b 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,7 +8,14 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName +from fast_llm.layers.transformer.config import ( + TransformerConfig, + TransformerDimNames, + TransformerKwargs, + TransformerSubLayerName, + VisionTransformerConfig, +) +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -18,6 +25,13 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__() self._name = name + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs + init_method_1 = init_normal_( std=config.init_method_std_mlp_1, min_val=config.init_method_min_mlp_1, @@ -29,8 +43,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + self._intermediate_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.composite_expert_mlp) self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -41,7 +55,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space.get_tensor_dim(self._transformer_dim_names.composite_gated_expert_mlp), bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index cbafe6c97..542b4d42e 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -12,7 +12,9 @@ TransformerConfig, TransformerDimNames, TransformerKwargs, + VisionTransformerConfig, ) +from 
fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -129,63 +131,122 @@ def get_rotary_frequencies( return frequencies +def get_2d_rotary_frequencies( + config: RotaryConfig, + height, + width, + kv_channels, + *, + device="cuda", +) -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(height, device=device, dtype=torch.float64) + width_positions = torch.arange(width, device=device, dtype=torch.float64) + frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + # TODO Soham: apply scaling + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, width, 1), + angles_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies + + class RotaryEmbeddingPreprocessor: _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 def __init__( self, config: RotaryConfig, tensor_space: TensorSpace, ): + # if isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + # elif isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # TODO Soham: better way to do this? 
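+        # The rotary type doubles as a marker for which transformer variant owns this
+        # preprocessor: `pixtral` implies the vision transformer (vit_* dim/kwargs names),
+        # anything else the language-model transformer. A possible alternative to the type
+        # check below is a small lookup table (illustrative sketch only, assuming the same
+        # namespaces as this patch, not part of the change itself):
+        #
+        #   _ROTARY_NAMESPACES = {
+        #       RotaryEmbeddingType.pixtral: (VisionTransformerDimNames, VisionTransformerKwargs),
+        #   }
+        #   self._transformer_dim_names, self._transformer_kwargs = _ROTARY_NAMESPACES.get(
+        #       config.type, (TransformerDimNames, TransformerKwargs)
+        #   )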
+ if config.type == RotaryEmbeddingType.pixtral: + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + else: + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config assert self._config.enabled self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._kv_channels_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels) + self._tensor_cache_max_sequence_length: int = -1 - def create_tensors(self, sequence_length: int) -> None: + def create_tensors(self, sequence_length: int, num_patches: None | int = None) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length - self._rotary_embedding_frequencies = get_rotary_frequencies( - self._config, - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) + if self._config.type == RotaryEmbeddingType.pixtral: + self._rotary_embedding_frequencies = get_2d_rotary_frequencies( + self._config, + num_patches, + num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + else: + self._rotary_embedding_frequencies = get_rotary_frequencies( + self._config, + sequence_length, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + if self._config.type == RotaryEmbeddingType.pixtral: + position_ids = kwargs[self._transformer_kwargs.patch_position_ids] + # TODO Soham: use position_ids_q and position_ids_k for sequence_data_parallelism + kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + else: + kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - sequence_q : sequence_k + ] + kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=self._transformer_kwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=self._transformer_kwargs.rotary_freq_k, ) @@ -202,6 +263,12 @@ def __init__( config: 
TransformerConfig, tensor_space: TensorSpace, ): + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -231,22 +298,22 @@ def create_tensors(self, sequence_length: int) -> None: def preprocess(self, kwargs: dict[str, typing.Any]) -> None: sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + kwargs[self._transformer_kwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(self._transformer_kwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + kwargs[self._transformer_kwargs.attention_mask] = ( + kwargs[self._transformer_kwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, @@ -254,12 +321,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: self._scalar_dim, kwargs[TransformerKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=self._transformer_kwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=self._transformer_kwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) @@ -270,6 +337,12 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs def preprocess(self, kwargs: dict[str, typing.Any]) -> None: """ @@ -281,7 +354,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. 
We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths) + sequence_lengths = kwargs.get(self._transformer_kwargs.sequence_lengths) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if sequence_q < kwargs[TransformerKwargs.sequence_length]: @@ -318,17 +391,17 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[self._transformer_kwargs.max_seqlen_q] = seqlens_q.max() + kwargs[self._transformer_kwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 311403fc9..ba4e5139f 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,9 +8,15 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import ( + TransformerConfig, + TransformerDimNames, + TransformerKwargs, + VisionTransformerConfig, +) from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -30,6 +36,12 @@ def __init__( return_input: bool = False, ): super().__init__() + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space self._dropout_p = self._config.hidden_dropout @@ -39,7 +51,7 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) @@ -66,7 +78,7 @@ def name(self) -> str: return f"Transformer layer 
{self._layer_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[TransformerKwargs.hidden_dims] + dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index b8436f72e..bf5f3f1aa 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -1,35 +1,54 @@ +import typing + import torch +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation from fast_llm.layers.common.linear import Linear -from fast_llm.layers.transformer.config import TransformerDimNames -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames -from fast_llm.tensor import init_normal_ +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import TensorMeta, init_normal_ -class VisionAdapter(torch.nn.Module): +class VisionAdapter(Layer): """ Vision adapter layer for the LLM. """ - def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str = "vision_adapter"): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - self._name = name input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self._activation_type = config.adapter_activation_type + # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? 
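+        # The adapter is a plain two-layer MLP that projects vision features into the
+        # language model's hidden space (vit hidden_size -> adapter_size -> LM hidden_size),
+        # with the configured `adapter_activation_type` (GeLU by default) applied in between;
+        # forward() below computes layer_2(act(layer_1(x))). The checkpoint converters map
+        # these weights to the Llava-style multi_modal_projector.linear_1 / linear_2.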
self.layer_1 = Linear( input_dim, - tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) self.layer_2 = Linear( - tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) - def forward(self, input_: torch.Tensor): - return self.layer_2(self.layer_1(input_)) + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index b83a118b5..7c650bf93 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,20 +1,55 @@ -from fast_llm.config import Config, Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.transformer.config import TransformerArchitectureConfig, VisionTransformerConfig class VisionEncoderDimNames: out_channels = "vision_out_channels" - intermediate_size = "vision_intermediate_size" + adapter_size = "vision_adapter_size" patch_size = "vision_patch_size" kv_channels = "vision_kv_channels" -class VisionModelKwargs: +class VisionTransformerDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "vit_batch" + # TODO: Distinguish micro-sequence? 
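+    # These names mirror TransformerDimNames one-to-one; the "vit_" prefix lets the vision
+    # and language transformers register their dimensions in the same TensorSpace without
+    # collisions (e.g. "vit_hidden" and "hidden" can coexist with different sizes).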
+ sequence_q = "vit_sequence_q" + sequence_q_tp = "vit_sequence_q_tp" + sequence_k = "vit_sequence_k" + hidden = "vit_hidden" + # Self-attention dimensions + head_groups = "vit_head_groups" + group_heads = "vit_group_heads" + key_and_value = "vit_key_value" + kv_channels = "vit_kv_channels" + composite_heads = "vit_composite_heads" + composite_query = "vit_composite_query" + composite_key_value = "vit_composite_key_value" + composite_dense = "vit_composite_dense" + # MLP dimensions + mlp = "vit_mlp" + gate_and_up = "vit_gate_and_up" + composite_gated_mlp = "vit_composite_gated_mlp" + experts = "vit_experts" + top_experts = "vit_top_experts" + shared_experts = "vit_shared_experts" + unshared_experts = "vit_unshared_experts" + composite_expert_mlp = "vit_composite_expert_mlp" + composite_gated_expert_mlp = "vit_composite_gated_expert_mlp" + composite_shared_expert_mlp = "vit_composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "vit_composite_gated_shared_expert_mlp" + + +class VisionEncoderKwargs: patch_size = "patch_size" images = "images" + image_patches = "image_patches" image_positions = "image_positions" image_size = "image_size" image_sizes = "image_sizes" @@ -24,56 +59,34 @@ class VisionModelKwargs: rope_theta = "vit_rope_theta" rotary_inv_freq = "vit_rotary_inv_freq" kv_channels = "vit_kv_channels" + max_image_tokens = "max_image_tokens" + patch_embeddings = "patch_embeddings" + hidden_dims = "vit_hidden_dims" -@config_class() -class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): - _abstract = False - """ - Configuration class for the vision encoder, which transforms images into embeddings - """ - path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) - pre_norm: NormalizationConfig = Field( - default_factory=NormalizationConfig, - ) - hidden_size: int = Field( - default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional - ) - intermediate_size: int = Field( - default=4096, - desc="The size of the intermediate (feed-forward) layers in the transformer model.", - hint=FieldHint.optional, - ) - num_hidden_layers: int = Field( - default=24, desc="The number of hidden layers in the transformer model.", hint=FieldHint.optional - ) - num_attention_heads: int = Field( - default=16, desc="The number of attention heads for the multi-head attention layers.", hint=FieldHint.optional - ) - num_channels: int = Field( - default=3, desc="Number of channels in the input image, typically 3 for RGB.", hint=FieldHint.optional - ) - image_size: int = Field( - default=1024, desc="The size of the input images (assumed square).", hint=FieldHint.optional - ) - patch_size: int = Field(default=16, desc="The size of the image patches to be encoded.", hint=FieldHint.optional) - hidden_act: str = Field( - default="gelu", desc="The activation function used in the hidden layers.", hint=FieldHint.optional - ) - attention_dropout: float = Field( - default=0.0, desc="The dropout probability for attention layers.", hint=FieldHint.optional - ) - rope_theta: float = Field( - default=10000.0, desc="The base value for rotary position embeddings.", hint=FieldHint.optional - ) - initializer_range: float = Field( - default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional - ) - activation_type: ActivationType = Field( - default=ActivationType.silu, - desc="The activation function used in the hidden layers. 
Default: SiLU.", - hint=FieldHint.optional, - ) +# TODO Soham: do we need all of them? +class VisionTransformerKwargs: + rotary_freq_q = "vit_rotary_freq_q" + rotary_freq_k = "vit_rotary_freq_k" + attention_mask = "vit_attention_mask" + attention_mask_value = "vit_attention_mask_value" + sequence_lengths = "vit_sequence_lengths" + cu_seqlens_q = "vit_cu_seqlens_q" + cu_seqlens_k = "vit_cu_seqlens_k" + max_seqlen_q = "vit_max_seqlen_q" + max_seqlen_k = "vit_max_seqlen_k" + # TODO: Review these + presents = "vit_presents" + past_key_values = "vit_past_key_values" + sequence_first = "vit_sequence_first" + hidden_dims = "vit_hidden_dims" + sequence_q_dim = "vit_sequence_q_dim" + sequence_k_dim = "vit_sequence_k_dim" + sequence_length = "vit_sequence_length" + micro_batch_size = "vit_micro_batch_size" + # TODO: Move + grad_output = "vit_grad_output" + patch_position_ids = "patch_position_ids" @config_class() @@ -116,35 +129,70 @@ class ImageNormalizationConfig(Config): @config_class() -class VisionArchitectureConfig(BaseModelArchitectureConfig): +class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): _abstract = False - encoder: VisionEncoderArchitectureConfig = Field( - default_factory=VisionEncoderArchitectureConfig, - desc="Configuration for the vision encoder that transforms images into embeddings.", + transformer: TransformerArchitectureConfig = Field( + default_factory=TransformerArchitectureConfig, + desc="Configuration for the vision transformer architecture.", + hint=FieldHint.core, + ) + patch_size: int = Field( + default=16, + desc="Patch size for the image encoder.", + hint=FieldHint.core, + ) + patch_norm: NormalizationConfig = Field( + default_factory=NormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) adapter_size: int = Field( default=5120, desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", - hint=FieldHint.optional, + hint=FieldHint.core, ) adapter_activation_type: ActivationType = Field( default=ActivationType.gelu, desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", hint=FieldHint.core, ) - normalization: ImageNormalizationConfig = Field( + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + # TODO Soham: add a check for presence of kv channels parameter (head_dim) + tensor_space.add_tensor_dim( + TensorDim( + VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads + ) + ) + self.transformer.setup_tensor_space(tensor_space, type="vision") + + +@config_class() +class VisionEncoderConfig(VisionEncoderArchitectureConfig, BaseModelConfig): + transformer: VisionTransformerConfig = FieldUpdate( + default_factory=VisionTransformerConfig, + desc="Configuration for the transformer architecture.", + hint=FieldHint.core, + ) + patch_norm: NormalizationConfig = FieldUpdate( + default_factory=NormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + image_normalization: ImageNormalizationConfig = Field( default_factory=ImageNormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + adapter_activation_type: ActivationType = FieldUpdate( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) def setup_tensor_space(self, tensor_space: TensorSpace): - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.encoder.patch_size)) - # TODO Soham: add a check for kv channels - tensor_space.add_tensor_dim( - TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) - ) + super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 8c694d28a..9369037d4 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -1,26 +1,20 @@ -import functools import typing import torch -from transformers import PixtralVisionConfig -from transformers.models.pixtral.modeling_pixtral import PixtralTransformer from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ -def position_ids_in_meshgrid(patch_embeddings_list, max_width): +def position_ids_in_meshgrid(patch_embeddings_list, max_size): positions = [] for patch in patch_embeddings_list: height, width = patch.shape[-2:] mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") 
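+        # Each patch's (row, col) grid position is flattened into a single index,
+        # id = row * max_size + col, which later selects the matching entry of the
+        # precomputed rotary-frequency table laid out over a max_size x max_size grid.
+        # For example, with 1024-px images and 16-px patches (the defaults used elsewhere
+        # in this patch), max_size = 1024 // 16 = 64 and the patch at row 2, col 5 gets
+        # id 2 * 64 + 5 = 133.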
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_width + v_grid + ids = h_grid * max_size + v_grid positions.append(ids[:, 0]) return torch.cat(positions) @@ -41,108 +35,24 @@ def generate_block_attention_mask(patch_embeds_list, tensor): return causal_mask -# TODO Soham: should this just be nn.Module? -class VisionEncoder(Layer): - """ - A vision encoder layer for creating token embeddings from vision model - """ - - def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): +class PatchConv(Layer): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - - self._config = config.vision_encoder - self._distributed_config = tensor_space.distributed_config + # TODO Soham: device=meta with torch.device("meta"): - # TODO Soham options to fix rotary: - # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta - # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope - config = PixtralVisionConfig( - hidden_size=self._config.encoder.hidden_size, - intermediate_size=self._config.encoder.intermediate_size, - num_hidden_layers=self._config.encoder.num_hidden_layers, - num_attention_heads=self._config.encoder.num_attention_heads, - num_channels=self._config.encoder.num_channels, - image_size=self._config.encoder.image_size, - patch_size=self._config.encoder.patch_size, - hidden_act=self._config.encoder.hidden_act, - attention_dropout=self._config.encoder.attention_dropout, - rope_theta=self._config.encoder.rope_theta, - initializer_range=self._config.encoder.initializer_range, - ) - self.patch_conv = torch.nn.Conv2d( + self.conv = torch.nn.Conv2d( in_channels=3, - out_channels=self._config.encoder.hidden_size, - kernel_size=self._config.encoder.patch_size, - stride=self._config.encoder.patch_size, + out_channels=config.transformer.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, bias=False, + dtype=tensor_space.distributed_config.training_dtype.torch, ) - self.patch_conv.weight = ParameterMeta.from_dims( - tuple( - TensorDim(f"patch_conv_weight_{idx}", size) - for idx, size in enumerate(self.patch_conv.weight.shape) - ), + self.conv.weight = ParameterMeta.from_dims( + tuple(TensorDim(f"patch_conv_weight_{idx}", size) for idx, size in enumerate(self.conv.weight.shape)), init_method=init_normal_(), ) - self.norm = self._config.encoder.pre_norm.get_layer( - tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) - ) - self.vision_transformer = PixtralTransformer(config) - # self.vision_encoder = PixtralVisionModel(config) - param_names = [] - # gather all names first. 
PyTorch complains if we do it in the loop - for name, param in self.vision_transformer.named_parameters(): - param_names.append(name) - for name in param_names: - *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self.vision_transformer) - param = self.vision_transformer.get_parameter(name) - setattr( - module, - stem, - ParameterMeta.from_dims( - tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), - init_method=init_normal_(), - ), - ) - # none_params = [key for key, value in module._parameters.items() if value is None] - # for key in none_params: - # module._parameters.pop(key) - self.adapter = VisionAdapter( - intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), - tensor_space=tensor_space, - ) - - def _forward( - self, input_: torch.Tensor, image_sizes: torch.Tensor, inv_freq: torch.Tensor, image_width: int - ) -> torch.Tensor: - patch_embeddings = self.patch_conv(input_) - patch_embeddings_list = [ - embedding[..., : image_size[0], : image_size[1]] - for embedding, image_size in zip(patch_embeddings, image_sizes) - ] - patch_embeddings = torch.cat([p.flatten(1).T for p in patch_embeddings_list], dim=0).unsqueeze(0) - patch_embeddings = self.norm(patch_embeddings) - position_ids = position_ids_in_meshgrid(patch_embeddings_list, image_width // self._config.encoder.patch_size) - freqs = inv_freq[position_ids] - with torch.autocast(device_type=input_.device.type): - cos = freqs.cos() - sin = freqs.sin() - cos = cos.to(dtype=input_.dtype) - sin = sin.to(dtype=input_.dtype) - - attention_mask = generate_block_attention_mask( - [p.shape[-2] * p.shape[-1] for p in patch_embeddings_list], patch_embeddings - ) - - (out,) = self.vision_transformer( - patch_embeddings, - attention_mask=attention_mask, - position_embeddings=(cos, sin), - output_attentions=False, - return_dict=False, - ) - - return self.adapter(out) + self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) def forward( self, @@ -150,17 +60,14 @@ def forward( kwargs: dict[str, typing.Any], losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> torch.Tensor: + hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], - tensor_name="Vision Output", - dtype=self._distributed_config.training_dtype.torch, - ) - return self._forward( - input_, - kwargs[VisionModelKwargs.image_sizes][:1], - kwargs[VisionModelKwargs.rotary_inv_freq], - image_width=kwargs[VisionModelKwargs.image_size], - ) - # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) + return TensorMeta.from_dims(hidden_dims) + # we don't need images after this point + # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) + patch_embeddings = self.norm(self.conv(input_)) + patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) + # Hack to pass patch embeddings to the next layer + # kwargs[VisionEncoderKwargs.patch_embeddings] = patch_embeddings + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 154c1a16d..abae6f11a 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -5,7 +5,12 @@ import torchvision.transforms.v2.functional as F from 
fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import ( + VisionEncoderArchitectureConfig, + VisionEncoderKwargs, + VisionTransformerKwargs, +) from fast_llm.utils import div @@ -23,11 +28,11 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int, pa If the image is smaller, it will be resized to the nearest multiple of the patch size. """ ratio = max(height / max_height, width / max_width) - return ( - (int(height / ratio), int(width / ratio)) - if ratio > 1 - else (patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)) - ) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: @@ -72,42 +77,128 @@ def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_s return torch.cat((inv_freq, inv_freq), dim=-1) +def position_ids_in_meshgrid(image_sizes: list[torch.Tensor], max_size: int, patch_size: int) -> torch.Tensor: + positions = [] + for h, w in image_sizes: + patch_height = h // patch_size + patch_width = w // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + positions.append(ids[:, 0]) + return positions + + +def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: + patch_height = height // patch_size + patch_width = width // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + return ids[:, 0] + + class VisionPreprocessor: - def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): + def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - images = kwargs.get("images") - im_height = kwargs.get(VisionModelKwargs.image_size) - im_width = kwargs.get(VisionModelKwargs.image_size) - patch_size = kwargs[VisionModelKwargs.patch_size] + images = kwargs.get(VisionEncoderKwargs.images) + im_height = kwargs.get(VisionEncoderKwargs.image_size) + im_width = kwargs.get(VisionEncoderKwargs.image_size) + patch_size = kwargs[VisionEncoderKwargs.patch_size] image_sizes = [ - get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in images + [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] + for ims in images ] - kwargs[VisionModelKwargs.image_sizes] = image_sizes + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes images = [ - pad( + [ normalize( - resize(image, im_height, im_width, patch_size) / kwargs[VisionModelKwargs.image_rescale_factor], - mean=kwargs[VisionModelKwargs.image_mean], - std=kwargs[VisionModelKwargs.image_std], - ), - max_height=im_height, - max_width=im_width, - ) - 
for image in images + resize(image, im_height, im_width, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch + ) + / kwargs[VisionEncoderKwargs.image_rescale_factor], + mean=kwargs[VisionEncoderKwargs.image_mean], + std=kwargs[VisionEncoderKwargs.image_std], + ) + for image in imgs + ] + for imgs in images ] - images = torch.stack(images, dim=0).to( - # TODO Soham: is this needed? - device=self._tensor_space.distributed.device, - dtype=self._distributed_config.training_dtype.torch, - ) - kwargs[VisionModelKwargs.images] = images - kwargs[VisionModelKwargs.rotary_inv_freq] = create_inv_freqs( - kwargs[VisionModelKwargs.rope_theta], - kwargs[VisionModelKwargs.kv_channels], + # position_ids = position_ids_in_meshgrid(image_sizes, im_height, patch_size) + patches = [] + patch_position_ids = [] + cu_seqlens = [0] + max_seqlen = -1 + for imgs, sizes in zip(images, image_sizes): + # TODO Soham: should this be micro_sequence_length? + # sum( + # get_num_patches(*size, patch_size) for size in sizes + # ) + seq_patches = [] + for image, size in zip(imgs, sizes): + seqlen = get_num_patches(*size, patch_size) + if seqlen > max_seqlen: + max_seqlen = seqlen + cu_seqlens.append(cu_seqlens[-1] + seqlen) + seq_patches.append( + torch.cat( + [ + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size + ), + ] + ) + ) + padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] + if padding_size > max_seqlen: + max_seqlen = padding_size + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) + patches.append( + torch.cat( + [ + *seq_patches, + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ] + ) + ) + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks + patch_position_ids.append( + torch.cat( + [ + position_ids, + torch.full((padding_size,), 0).to(device=self._tensor_space.distributed.device), + ] + ) + ) + # TODO Soham: remove + assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] + patches = torch.cat(patches) + patch_position_ids = torch.cat(patch_position_ids) + kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionEncoderKwargs.rope_theta], + kwargs[VisionEncoderKwargs.kv_channels], im_height, patch_size, ).to(device=self._tensor_space.distributed.device) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids + # TODO Soham: handle sequence data parallel + kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index bd7da7979..d599a1148 100644 --- 
a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -165,7 +165,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, hf_base_prefix: str = "", - fast_llm_offset: int = 0, + fast_llm_offset: int = 1, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers @@ -187,9 +187,18 @@ def _create_weight_converters( return converters def _create_transformer_layer_converters( - self, i: int, ignore_export: bool = False, hf_base_prefix: str = "", fast_llm_offset: int = 1 + self, + i: int, + ignore_export: bool = False, + hf_base_prefix: str = "", + fast_llm_offset: int = 1, + type: str | None = None, ) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + if type is not None: + if type == "vision": + transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer + else: + transformer_config: TransformerConfig = self._model.config.base_model.transformer norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm converters = [] names_bias_cls = [ @@ -565,6 +574,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: lm_converters[-2] = ConstantExportParamConverter( export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] ) + # TODO Soham: cleaner way to get language model config converters for converter in lm_converters: if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): # Llava uses a different name for the text config @@ -579,31 +589,36 @@ def _create_config_converters(cls) -> list[ParamConverter]: export_names=(("text_config", "hidden_size"),), ), # Image processing and conv layer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "image_size"),), - export_names=( - ( - "vision_config", - "image_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), - export_names=( - ( - "vision_config", - "patch_size", - ), - ), + # TODO Soham: these options are not in the fast-llm model config. 
They're read from BatchConfig currently + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "image_size"),), + # export_names=( + # ( + # "vision_config", + # "image_size", + # ), + # ), + # ), + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), + # export_names=( + # ( + # "vision_config", + # "patch_size", + # ), + # ), + # ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "patch_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, ), ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "pre_norm", "type"),), + fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm, ), # Vision Transformer RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), + fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), export_names=( ( "vision_config", @@ -612,7 +627,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "hidden_size"),), + fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), export_names=( ( "vision_config", @@ -621,7 +636,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_attention_heads"),), + fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), export_names=( ( "vision_config", @@ -630,144 +645,213 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "intermediate_size"),), + fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), export_names=( ( "vision_config", - "intermediate_size", + "num_key_value_heads", ), ), ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "activation_type"),), - export_names=( - ( - "vision_config", - "hidden_act", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), export_names=( ( "vision_config", - "num_channels", + "intermediate_size", ), ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "attention_dropout"),), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), export_names=( ( "vision_config", - "attention_dropout", + "hidden_act", ), ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "rope_theta"),), - export_names=(("vision_config", "rope_theta"),), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False + ), + # TODO Soham: add this config param for completeness? 
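+            # For reference, the vision-side mapping established by this converter list is
+            # (Fast-LLM -> HF vision_config):
+            #   transformer.num_layers          -> num_hidden_layers
+            #   transformer.hidden_size         -> hidden_size
+            #   transformer.num_attention_heads -> num_attention_heads
+            #   transformer.head_groups         -> num_key_value_heads
+            #   transformer.ffn_hidden_size     -> intermediate_size
+            #   transformer.activation_type     -> hidden_act
+            #   transformer.rotary.theta        -> rope_theta
+            # plus adapter_activation_type -> projector_hidden_act at the top level of the
+            # Llava config.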
+ # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + # export_names=( + # ( + # "vision_config", + # "num_channels", + # ), + # ), + # ), + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "attention_dropout"),), + # export_names=( + # ( + # "vision_config", + # "attention_dropout", + # ), + # ), + # ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "initializer_range"),), - export_names=(("vision_config", "initializer_range"),), + fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), + export_names=(("vision_config", "rope_theta"),), ), + # TODO Soham: add this config param in vision encoder for completeness? + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "initializer_range"),), + # export_names=(("vision_config", "initializer_range"),), + # ), ] def _create_vision_transformer_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.vision_encoder.encoder.num_hidden_layers + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers vision_transformer_converters = [] - for i in range(num_layers): - vision_transformer_converters += [ - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.k_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.v_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.q_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.o_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention_norm.weight", - f"vision_tower.transformer.layers.{i}.attention_norm.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.down_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.gate_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.up_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.ffn_norm.weight", - f"vision_tower.transformer.layers.{i}.ffn_norm.weight", - ), - ] + for layer in range(num_layers): + # TODO Soham: check if args are correct + vision_transformer_converters.extend( + self._create_vision_transformer_layer_converters( + layer, + ignore_export=False, + hf_base_prefix="vision_tower.transformer.layers.", + fast_llm_offset=1, + type="vision", + ) + ) return vision_transformer_converters def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converter = WeightConverter( - "layers.0.vision_encoder.patch_conv.weight", - "vision_tower.patch_conv.weight", - ) - # TODO Soham: use _get_weight_and_bias_converters? 
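+        # The Fast-LLM-side names in the replacement below follow the flattened layer layout
+        # built by get_vision_layers() in fast_llm/models/gpt/model.py: layers.0 is the patch
+        # conv, layers.1..N the vision transformer blocks, layers.{N+1} the adapter and
+        # layers.{N+2} the multimodal embedding, which is why the language-model converters
+        # use fast_llm_offset = num_layers + 3.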
- layernorm_converters = [] - layer_norm_converter = WeightConverter( - "layers.0.vision_encoder.norm.weight", - "vision_tower.ln_pre.weight", - ) - layernorm_converters.append(layer_norm_converter) - layer_norm_converter - if self._model.config.base_model.vision_encoder.encoder.pre_norm.type == NormalizationType.layer_norm: - layer_norm_bias_converter = WeightConverter( - "layers.0.vision_encoder.norm.bias", - "vision_tower.ln_pre.bias", - ) - layernorm_converters.append(layer_norm_bias_converter) + patch_conv_converter = WeightConverter("layers.0.conv.weight", "vision_tower.patch_conv.weight") + layernorm_converters = [ + WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), + ] + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + vision_transformer_converters = self._create_vision_transformer_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 adapter_converters = [ - WeightConverter( - "layers.0.vision_encoder.adapter.layer_1.weight", - "multi_modal_projector.linear_1.weight", - ), - WeightConverter( - "layers.0.vision_encoder.adapter.layer_1.bias", - "multi_modal_projector.linear_1.bias", - ), - # TODO Soham: conditionally add bias - WeightConverter( - "layers.0.vision_encoder.adapter.layer_2.weight", - "multi_modal_projector.linear_2.weight", - ), - WeightConverter( - "layers.0.vision_encoder.adapter.layer_2.bias", - "multi_modal_projector.linear_2.bias", - ), + WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), + WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), + # TODO Soham: add bias based on config + WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), + WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), ] + return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) return vision_encoder_converter + lm_converters + def _create_vision_transformer_layer_converters( + self, + i: int, + ignore_export: bool = False, + hf_base_prefix: str = "", + fast_llm_offset: int = 1, + type: str | None = None, + ) -> list[WeightConverter]: + if type is not None: + if type == "vision": + transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer + else: + transformer_config: TransformerConfig = self._model.config.base_model.transformer + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] + names_bias_cls = [ + # Self-attn + ( + f"layers.{i+fast_llm_offset}.self_attn.query", + f"vision_tower.transformer.layers.{i}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.key_value", + ( + f"vision_tower.transformer.layers.{i}.attention.k_proj", + 
f"vision_tower.transformer.layers.{i}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.dense", + f"vision_tower.transformer.layers.{i}.attention.o_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{i+fast_llm_offset}.norm_1", + f"vision_tower.transformer.layers.{i}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.norm_2", + f"vision_tower.transformer.layers.{i}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + () if ignore_export else hf_prefix, + use_bias, + cls=IgnoreExportWeightConverter if ignore_export else cls, + ) + + # MLP + if ignore_export: + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] + else: + converters += self._get_vision_transformer_mlp_converters( + f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" + ) + return converters + + def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c273f09b1..6aef273f6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -27,7 +27,10 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs +from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames +from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -76,6 +79,10 @@ def __init__( if self._config.vision_encoder: self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) + if self._config.vision_encoder.transformer.rotary.enabled: + self._vision_rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( + self._config.vision_encoder.transformer.rotary, self._tensor_space + ) def get_output_layers(self) -> list[Layer]: return [ @@ -99,22 +106,35 @@ def get_output_layers(self) -> list[Layer]: ] ] + def 
get_vision_layers(self) -> list[Layer]: + patch_conv = PatchConv(self._config.vision_encoder, self._tensor_space) + vit_layers = [ + VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.vision_encoder.transformer.num_layers) + ] + return [ + patch_conv, + *vit_layers, + VisionAdapter(self._config.vision_encoder, self._tensor_space), + MultiModalEmbedding(self._config, self._tensor_space), + ] + def get_layers(self) -> list[Layer]: if self._config.transformer.num_layers == 0: Assert.eq(self._config.prediction_heads, 1) return [ - ( - LanguageModelEmbedding(self._config, self._tensor_space) + *( + [LanguageModelEmbedding(self._config, self._tensor_space)] if self._config.vision_encoder is None - else MultiModalEmbedding(self._config, self._tensor_space) + else self.get_vision_layers(self._config, self._tensor_space) ), LanguageModelHead(self._config, self._tensor_space, 0), ] return [ - ( - LanguageModelEmbedding(self._config, self._tensor_space) + *( + [LanguageModelEmbedding(self._config, self._tensor_space)] if self._config.vision_encoder is None - else MultiModalEmbedding(self._config, self._tensor_space) + else self.get_vision_layers() ), *[ TransformerLayer( @@ -152,24 +172,24 @@ def preprocess_meta( if self._config.vision_encoder: image_size = batch_meta.max_image_size image_mean = [ - self._config.vision_encoder.normalization.mean_r, - self._config.vision_encoder.normalization.mean_g, - self._config.vision_encoder.normalization.mean_b, + self._config.vision_encoder.image_normalization.mean_r, + self._config.vision_encoder.image_normalization.mean_g, + self._config.vision_encoder.image_normalization.mean_b, ] image_std = [ - self._config.vision_encoder.normalization.std_r, - self._config.vision_encoder.normalization.std_g, - self._config.vision_encoder.normalization.std_b, + self._config.vision_encoder.image_normalization.std_r, + self._config.vision_encoder.image_normalization.std_g, + self._config.vision_encoder.image_normalization.std_b, ] - image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor + image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { - VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, - VisionModelKwargs.image_size: image_size, - VisionModelKwargs.image_mean: image_mean, - VisionModelKwargs.image_std: image_std, - VisionModelKwargs.image_rescale_factor: image_rescale_factor, - VisionModelKwargs.rope_theta: self._config.vision_encoder.encoder.rope_theta, - VisionModelKwargs.kv_channels: self._tensor_space.get_tensor_dim( + VisionEncoderKwargs.patch_size: batch_meta.patch_size, + VisionEncoderKwargs.image_size: image_size, + VisionEncoderKwargs.image_mean: image_mean, + VisionEncoderKwargs.image_std: image_std, + VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, + VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.kv_channels ).size, } @@ -218,6 +238,18 @@ def preprocess_meta( if sequence_first else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) + if self._config.vision_encoder: + vision_hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) + vision_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + ) + 
vision_kwargs.update( + { + VisionEncoderKwargs.hidden_dims: vision_hidden_dims, + } + ) common_kwargs = { LanguageModelKwargs.phase: phase, @@ -225,6 +257,7 @@ def preprocess_meta( TransformerKwargs.hidden_dims: hidden_dims, TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, + TransformerKwargs.micro_batch_size: micro_batch_size, } common_kwargs.update(vision_kwargs) @@ -253,6 +286,9 @@ def preprocess_meta( self._position_embedding_preprocessor.preprocess_meta(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) + if self._config.vision_encoder: + if self._config.vision_encoder.transformer.rotary.enabled: + self._vision_rotary_embedding_preprocessor.preprocess_meta(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess_meta(kwargs) preprocessed_meta.append((tokens, kwargs)) @@ -294,6 +330,11 @@ def preprocess( self._rotary_embedding_preprocessor.create_tensors(sequence_length) if not self._use_flash_attention: self._backup_attention_preprocessor.create_tensors(sequence_length) + if self._config.vision_encoder and self._config.vision_encoder.transformer.rotary.enabled: + max_num_patches = ( + common_kwargs[VisionEncoderKwargs.image_size] // common_kwargs[VisionEncoderKwargs.patch_size] + ) + self._vision_rotary_embedding_preprocessor.create_tensors(sequence_length, max_num_patches) preprocessed = [] presents = None @@ -342,32 +383,38 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels - if batch.images is not None: - kwargs[VisionModelKwargs.images] = [ - img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - for images in batch.images - for img in images - ] - kwargs[VisionModelKwargs.image_positions] = batch.image_positions - if self._config.vision_encoder: - self._vision_preprocessor.preprocess(kwargs) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess(kwargs) - preprocessed.append((tokens, kwargs)) + if batch.images is not None: + kwargs[VisionEncoderKwargs.images] = [ + [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for img in images + ] + for images in batch.images + ] + kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs) + self._vision_rotary_embedding_preprocessor.preprocess(kwargs) + kwargs[LanguageModelKwargs.tokens] = tokens + preprocessed.append((kwargs[VisionEncoderKwargs.image_patches], kwargs)) + else: + preprocessed.append((tokens, kwargs)) return preprocessed @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[self._config.vision_encoder is not None] + return self.layers[self._config.vision_encoder.transformer.num_layers + 2] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[(self._config.vision_encoder is not None) + 1 : -1] + return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] @property def model_head(self) -> LanguageModelHead: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index e9df18ed2..b1f14ccc5 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ 
-32,7 +32,6 @@ def fast_llm(args=None): sys.exit(1) except Exception: # noqa logger.critical(traceback.format_exc()) - sys.exit(1) if __name__ == "__main__": From 99ad5d9bda84eea74e377c8cc75f7184bb0dcc76 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 18:34:24 +0000 Subject: [PATCH 014/161] patches and fixes --- fast_llm/layers/language_model/config.py | 10 ++++++---- fast_llm/layers/vision_encoder/config.py | 2 ++ fast_llm/layers/vision_encoder/encoder.py | 2 +- .../layers/vision_encoder/preprocessing.py | 20 ++++++++++++++++++- fast_llm/models/gpt/model.py | 10 +++++++--- 5 files changed, 35 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 887952d7a..ef0e7a5cc 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -45,8 +45,9 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) - vision_encoder: None | VisionEncoderArchitectureConfig = Field( - default=None, + # TODO Soham: make this None by default. Need to figure out how to handle this in the config (see ) + vision_encoder: VisionEncoderArchitectureConfig = Field( + default_factory=VisionEncoderArchitectureConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) @@ -131,8 +132,9 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: None | VisionEncoderConfig = FieldUpdate( - default=None, + # TODO Soham: make this None by default. Need to figure out how to handle this in the config + vision_encoder: VisionEncoderConfig = FieldUpdate( + default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 7c650bf93..283513727 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -7,6 +7,7 @@ class VisionEncoderDimNames: + in_channels = "vision_in_channels" out_channels = "vision_out_channels" adapter_size = "vision_adapter_size" patch_size = "vision_patch_size" @@ -62,6 +63,7 @@ class VisionEncoderKwargs: max_image_tokens = "max_image_tokens" patch_embeddings = "patch_embeddings" hidden_dims = "vit_hidden_dims" + image_patches_meta = "vit_image_patches_meta" # TODO Soham: do we need all of them? 
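Note on the new `in_channels` dim name and `image_patches_meta` kwarg added above: later in this series (`VisionPreprocessor.preprocess_meta`) the patch tensor they describe is sized as (micro_batch_size * sequence_q, in_channels=3, patch_size, patch_size), i.e. at most one image patch per token position. As a rough standalone sketch of that layout in plain PyTorch (the helper name and example sizes here are illustrative only, not part of the Fast-LLM API):

import torch

def image_to_patches(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    # image: (3, height, width), already resized so both dims are multiples of patch_size.
    channels, height, width = image.shape
    patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # -> (3, height // patch_size, width // patch_size, patch_size, patch_size)
    patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, channels, patch_size, patch_size)
    # One row per patch token: (num_patches, 3, patch_size, patch_size).
    return patches

# A 64x48 image with 16x16 patches yields 4 * 3 = 12 patch tokens.
print(image_to_patches(torch.randn(3, 64, 48), 16).shape)  # torch.Size([12, 3, 16, 16])

Patches from all images in a micro-batch are concatenated along that leading dimension, which is why the meta uses micro_batch_size * sequence_q as an upper bound rather than an exact patch count.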
diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 9369037d4..ed6fbc92a 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -63,7 +63,7 @@ def forward( ) -> torch.Tensor: hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): - return TensorMeta.from_dims(hidden_dims) + return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) # we don't need images after this point # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) patch_embeddings = self.norm(self.conv(input_)) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index abae6f11a..c087cf6d0 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -4,13 +4,16 @@ import torch import torchvision.transforms.v2.functional as F -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import ( VisionEncoderArchitectureConfig, + VisionEncoderDimNames, VisionEncoderKwargs, + VisionTransformerDimNames, VisionTransformerKwargs, ) +from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -104,6 +107,21 @@ def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: Tensor self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + # kwargs[VisionEncoderDimNames] + kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( + ( + TensorDim( + VisionTransformerDimNames.batch, + kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + ), + TensorDim(VisionEncoderDimNames.in_channels, 3), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + ), + dtype=self._distributed_config.training_dtype.torch, + ) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) im_height = kwargs.get(VisionEncoderKwargs.image_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6aef273f6..5425a1e13 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -286,12 +286,16 @@ def preprocess_meta( self._position_embedding_preprocessor.preprocess_meta(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) + if not self._use_flash_attention: + self._backup_attention_preprocessor.preprocess_meta(kwargs) if self._config.vision_encoder: + self._vision_preprocessor.preprocess_meta(kwargs) if self._config.vision_encoder.transformer.rotary.enabled: self._vision_rotary_embedding_preprocessor.preprocess_meta(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess_meta(kwargs) - preprocessed_meta.append((tokens, kwargs)) + # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + else: + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta From 
bcb557aca291afcbb2e19969d2e7e1da16a93612 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 20:44:40 +0000 Subject: [PATCH 015/161] fix dependency --- Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index b8e1f8887..149a498e0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,8 +3,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs \ - && apt-get install --no-install-recommends -y libtiff5-dev \ + && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From a6f5364d33c8d80ff46ea592612362fd03f85f30 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 20:49:53 +0000 Subject: [PATCH 016/161] remove for testing --- Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 149a498e0..b7e42d4dc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,8 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ + && apt-get install --no-install-recommends -y acl git-lfs \ + # && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From 73b431b22d0c4b54d41d25a4dcf0738c5a1b1711 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 21:57:17 +0000 Subject: [PATCH 017/161] mising --- .../layers/transformer/vision_transformer.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 fast_llm/layers/transformer/vision_transformer.py diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py new file mode 100644 index 000000000..94a9c70af --- /dev/null +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -0,0 +1,55 @@ +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.tensor import TensorMeta + + +class VisionTransformerLayer(TransformerLayer): + """ + A vision transformer layer to encode image patches + """ + + def __init__( + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, return_input) + + hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) + self.norm_1 = self._config.normalization.get_layer(hidden_dim) + self.norm_2 = self._config.normalization.get_layer(hidden_dim) + + self.norm_1 = self._config.peft.apply_other(self.norm_1) + self.norm_2 = self._config.peft.apply_other(self.norm_2) + + @property + def name(self) -> str: + return f"Vision transformer layer {self._layer_index}" + + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): + dims = kwargs[VisionTransformerKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) + + # TODO Soham: remove this since we only need to call the parent method + # def forward( + # self, + # input_: torch.Tensor, + # kwargs: dict[str, typing.Any], + # 
losses: dict[str, typing.Any] | None = None, + # metrics: dict[str, typing.Any] | None = None, + # ) -> torch.Tensor: + # if isinstance(input_, TensorMeta): + # return self._get_meta(input_, "output", kwargs) + # # Hack for now to compute the patch embeddings + # kwargs[VisionTransformerKwargs.patch_embeddings] = super().forward( + # kwargs.pop(VisionTransformerKwargs.patch_embeddings), kwargs, losses, metrics + # ) + # return input_ From 6d6567673450e3e97ae07879957a55875ec80caf Mon Sep 17 00:00:00 2001 From: root Date: Thu, 8 May 2025 06:11:54 +0000 Subject: [PATCH 018/161] fix --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/multi_modal/embedding.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 0ba3f0e13..2f80ee77d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -485,7 +485,7 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if unshuffled_tokens := data.get("unshuffled_tokens") is not None: + if (unshuffled_tokens := data.get("unshuffled_tokens")) is not None: self._unshuffled_tokens = unshuffled_tokens else: self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"] diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index a3abe7813..b7d79dd37 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -43,7 +43,8 @@ def forward( image_positions = kwargs.get(VisionEncoderKwargs.image_positions) tokens = kwargs.get(LanguageModelKwargs.tokens) # get text embeddings - embeddings = super()._forward(tokens, position_ids) + # TODO Soham: cloning to avoid pytorch complaint about in-place operation. Can we do better? + embeddings = super()._forward(tokens, position_ids).clone() image_idx = 0 for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 From 66e708170d98bd476e679fcbaf6fbf761b284388 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 9 May 2025 18:39:55 +0000 Subject: [PATCH 019/161] fixes --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/language_model/config.py | 6 +----- fast_llm/layers/vision_encoder/config.py | 6 +++--- fast_llm/layers/vision_encoder/preprocessing.py | 7 ++++--- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index cb6d6c8d4..54564a212 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -278,7 +278,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes + image_token_sizes + document_sizes + image_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 451044207..ab5707804 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,12 +5,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl -<<<<<<< HEAD -from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig -from fast_llm.layers.vision_encoder.config import VisionEncoderArchitectureConfig, VisionEncoderConfig -======= from fast_llm.layers.transformer.config import TransformerConfig ->>>>>>> main +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index b15f90bdb..345b118ed 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,9 +1,9 @@ -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig -from fast_llm.layers.transformer.config import TransformerArchitectureConfig, VisionTransformerConfig +from fast_llm.layers.transformer.config import VisionTransformerConfig class VisionEncoderDimNames: diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index c087cf6d0..7bd8a2aa1 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -4,10 +4,11 @@ import torch import torchvision.transforms.v2.functional as F +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import ( - VisionEncoderArchitectureConfig, + VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames, @@ -101,8 +102,8 @@ def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tenso return ids[:, 0] -class VisionPreprocessor: - def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: TensorSpace): +class VisionPreprocessor(Preprocessor): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config From 7f86a7f1889065ca06dade517d0cc69ef8b83215 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 06:18:03 +0000 Subject: [PATCH 020/161] fix --- fast_llm/data/dataset/gpt/sampled.py | 19 ++- fast_llm/data/tokenizer.py | 153 ++++++++++++------ fast_llm/engine/schedule/config.py | 2 +- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/transformer/config.py | 43 ++--- fast_llm/layers/vision_encoder/config.py | 1 - 
.../layers/vision_encoder/preprocessing.py | 2 +- fast_llm/models/gpt/model.py | 19 ++- 8 files changed, 153 insertions(+), 88 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 54564a212..f99a9d3ef 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -138,7 +138,9 @@ def _sample(self) -> None: image_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) + image_token_sizes[i] = sum( + (sizes[:, 0] // self._parameters.patch_size) * (sizes[:, 1] // self._parameters.patch_size) + ) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -195,7 +197,7 @@ def _sample(self) -> None: "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._parameters.sequence_length, - "patch_size": self._patch_size, + "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } @@ -405,12 +407,19 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) + document_size, image_lengths = self._indexed_dataset.get_document_size( + document_index, self._parameters.patch_size + ) image_sizes = [ get_num_patches( - *get_resize_dims(*image_length, self._image_size, self._image_size, self._patch_size), - self._patch_size, + *get_resize_dims( + *image_length, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, ) for image_length in image_lengths ] diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 0e7d54709..10b8b2c64 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,64 +42,119 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text, image_positions=None): + def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[list[int], list[tuple[int, int]]]: + """ + Tokenize the input text and return the tokenized input_ids along with token spans. 
+ """ + # if not image_positions and not char_spans: + # return self._tokenize(text), [], [] if not image_positions: - return self._tokenize(text), [], [] + image_positions = [] + if not char_spans: + char_spans = [] + image_idx = 0 char_pos = 0 token_ids = [] image_token_positions = [] beginning_of_text = True - while image_idx < len(image_positions): - if image_positions[image_idx] > len(text): - raise ValueError( - f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" - ) - curr_text = text[char_pos : image_positions[image_idx]] - tokenized_text = self._tokenize( - curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) - ) - beginning_of_text = False - token_ids.extend(tokenized_text) - image_token_positions = len(token_ids) - char_pos = image_positions[image_idx] - image_idx += 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) - token_ids.extend(tokenized_text) - return token_ids, image_token_positions - - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: - """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. - """ - input_ids = [] - token_spans = [] - char_pos = 0 - beginning_of_text = True for start, end in char_spans: + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + while image_position <= start: + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + image_token_positions.append(len(token_ids)) + token_ids.extend(tokenized_text) + image_idx += 1 + char_pos = image_position + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + token_ids.extend(tokenized_text) + char_pos = start + len(token_ids) + span_length = 0 + while image_position <= end: + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + image_token_positions.append(len(token_ids)) + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = image_position + image_idx += 1 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + if char_pos < end: + if end >= len(text) - 1: + tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=True) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = end + 1 + else: 
+ tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + + # def tokenize(self, text, image_positions=None): + # if not image_positions: + # return self._tokenize(text), [], [] + # image_idx = 0 + # char_pos = 0 + # token_ids = [] + # image_token_positions = [] + # beginning_of_text = True + # while image_idx < len(image_positions): + # if image_positions[image_idx] > len(text): + # raise ValueError( + # f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" + # ) + # curr_text = text[char_pos : image_positions[image_idx]] + # tokenized_text = self._tokenize( + # curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) + # ) + # beginning_of_text = False + # token_ids.extend(tokenized_text) + # image_token_positions = len(token_ids) + # char_pos = image_positions[image_idx] + # image_idx += 1 + # if char_pos < len(text): + # curr_text = text[char_pos:] + # tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) + # token_ids.extend(tokenized_text) + # return token_ids, image_token_positions + + # def tokenize_with_spans( + # self, text: str, char_spans: list[tuple[int, int]] + # ) -> tuple[list[int], list[tuple[int, int]]]: + # """ + # Perform span-aware tokenization and return the tokenized input_ids along with token spans. + # """ + # input_ids = [] + # token_spans = [] + # char_pos = 0 + # beginning_of_text = True + # for start, end in char_spans: + # if char_pos < start: + # curr_text = text[char_pos:start] + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + # beginning_of_text = False + # input_ids.extend(tokenized_text) + # curr_text = text[start : end + 1] + # if end >= len(text) - 1: + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) + # else: + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + # beginning_of_text = False + # token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) + # input_ids.extend(tokenized_text) + # char_pos = end + 1 + # if char_pos < len(text): + # curr_text = text[char_pos:] + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) + # input_ids.extend(tokenized_text) + # return input_ids, token_spans def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 10f87835b..48daf0e69 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,7 +55,7 @@ class BatchConfig(Config): desc="Patch size for each image token", hint=FieldHint.optional, ) - max_image_size: int | None = Field( + image_size: int | None = Field( default=None, desc="Maximum image height and width", hint=FieldHint.optional, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ab5707804..78de218f1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -167,7 +167,7 @@ def _validate(self) -> None: raise NotImplementedError("Multi-token prediction not supported with distillation.") def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space, type="vision" if self.vision_encoder is not None else None) + 
self.transformer.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 55320a1b5..38dc9ec48 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -169,6 +169,7 @@ class VisionRotaryConfig(RotaryConfig): hint=FieldHint.feature, ) + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -668,59 +669,61 @@ def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim( + gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) + ) + 
tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 345b118ed..4dde28bee 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -176,4 +176,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace): ) ) self.transformer.setup_tensor_space(tensor_space, type="vision") - super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7bd8a2aa1..46bf0ab3f 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -123,7 +123,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: dtype=self._distributed_config.training_dtype.torch, ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) im_height = kwargs.get(VisionEncoderKwargs.image_size) im_width = kwargs.get(VisionEncoderKwargs.image_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c80c05f94..b832f1b04 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -77,14 +77,10 @@ def __init__( self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) if self._config.vision_encoder: - self._preprocessors.append( - VisionPreprocessor(self._config.vision_encoder, self._tensor_space) - ) + self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) if self._config.vision_encoder.transformer.rotary.enabled: self._preprocessors.append( - RotaryEmbeddingPreprocessor( - self._config.vision_encoder.transformer.rotary, self._tensor_space - ) + 
RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) # self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) # if self._config.vision_encoder.transformer.rotary.enabled: @@ -167,7 +163,7 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder: - image_size = batch_meta.max_image_size + image_size = batch_meta.image_size image_mean = [ self._config.vision_encoder.image_normalization.mean_r, self._config.vision_encoder.image_normalization.mean_g, @@ -411,8 +407,6 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - for preprocessor in self._preprocessors: - preprocessor.preprocess(tokens, kwargs) if batch.images is not None: kwargs[VisionEncoderKwargs.images] = [ [ @@ -423,7 +417,12 @@ def preprocess( ] kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions kwargs[LanguageModelKwargs.tokens] = tokens - preprocessed.append((kwargs[VisionEncoderKwargs.image_patches], kwargs)) + + for preprocessor in self._preprocessors: + preprocessor.preprocess(tokens, kwargs) + image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + if image_patches is not None: + preprocessed.append((image_patches, kwargs)) else: preprocessed.append((tokens, kwargs)) From 3a8a99d62c559f97f35d37dc4c2133d5e0a77a73 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 15:23:33 +0000 Subject: [PATCH 021/161] more fixes after merge --- fast_llm/layers/transformer/preprocessing.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 01b953976..870463df2 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -15,7 +15,11 @@ TransformerKwargs, VisionTransformerConfig, ) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import ( + VisionEncoderKwargs, + VisionTransformerDimNames, + VisionTransformerKwargs, +) from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -163,6 +167,7 @@ def get_2d_rotary_frequencies( return frequencies + class RotaryEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _mask: torch.Tensor @@ -216,7 +221,11 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + if self._config.type == RotaryEmbeddingType.pixtral: + max_num_patches = kwargs[VisionEncoderKwargs.image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(kwargs[TransformerKwargs.sequence_length], max_num_patches) + else: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if self._config.type == RotaryEmbeddingType.pixtral: From d16284ee0b96598e63e74c27b6b09e7e70d9d367 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 19:32:32 +0000 Subject: [PATCH 022/161] conv cleanup --- fast_llm/data/dataset/gpt/memmap.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 1 - fast_llm/layers/vision_encoder/config.py | 6 +++ fast_llm/layers/vision_encoder/encoder.py | 39 ++++++++++--------- fast_llm/models/gpt/conversion.py | 6 ++- 
setup.cfg | 1 - 6 files changed, 30 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 8651b8fcd..5d3df5983 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -163,7 +163,6 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - # TODO Soham: get images def get( self, idx: int, diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 89fe904cd..38d90ed42 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -180,7 +180,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, ) - # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: assert self.tokenizer.path is not None if self.dataset.data_type is not None: diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 4dde28bee..be3fb38cb 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -144,6 +144,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="Patch size for the image encoder.", hint=FieldHint.core, ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) patch_norm: NormalizationConfig = Field( default_factory=NormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", @@ -169,6 +174,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) # TODO Soham: add a check for presence of kv channels parameter (head_dim) tensor_space.add_tensor_dim( TensorDim( diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index ed6fbc92a..59212c58f 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -3,7 +3,7 @@ import torch from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -38,21 +38,25 @@ def generate_block_attention_mask(patch_embeds_list, tensor): class PatchConv(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - # TODO Soham: device=meta - with torch.device("meta"): - self.conv = torch.nn.Conv2d( - in_channels=3, - out_channels=config.transformer.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size, - bias=False, - dtype=tensor_space.distributed_config.training_dtype.torch, - ) - self.conv.weight = ParameterMeta.from_dims( - tuple(TensorDim(f"patch_conv_weight_{idx}", size) for idx, size in enumerate(self.conv.weight.shape)), - init_method=init_normal_(), + self._tensor_space = tensor_space + # TODO Soham: 
lr_scale + self.weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + ), + init_method=init_normal_(), + ) + if config.conv_bias: + self.bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),) ) + else: + self.bias = None self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) + self.stride = config.patch_size def forward( self, @@ -64,10 +68,7 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - # we don't need images after this point - # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) - patch_embeddings = self.norm(self.conv(input_)) + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) - # Hack to pass patch embeddings to the next layer - # kwargs[VisionEncoderKwargs.patch_embeddings] = patch_embeddings return patch_embeddings diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 4b08d564a..6aa3aaf1f 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -728,7 +728,9 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: return vision_transformer_converters def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converter = WeightConverter("layers.0.conv.weight", "vision_tower.patch_conv.weight") + patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] + if self._model.config.base_model.vision_encoder.conv_bias: + patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) layernorm_converters = [ WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), ] @@ -745,7 +747,7 @@ def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), ] - return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters + return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() diff --git a/setup.cfg b/setup.cfg index 3b5eea402..25f8af8bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,7 +45,6 @@ OPTIONAL = requests>=2.32.3 tqdm>=4.66.3 # Vision Tools - # TODO Soham: use pillow-simd instead of pillow? 
webp>=0.4.0 pillow-simd>=9.5.0 torchvision>=0.20.0 From b3134aade1428641c47ea831d50701a43ee222ca Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 19:35:17 +0000 Subject: [PATCH 023/161] more conv cleanup --- fast_llm/engine/multi_stage/stage_base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 4d9cd8488..fd50f55c5 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -162,9 +162,6 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) - # TODO Soham: clean way to get around check? - if meta is None: - continue module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 From c8aa66ec3793e222e0412afa0b142869f513e431 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:08:16 +0000 Subject: [PATCH 024/161] images + loss-masks --- fast_llm/data/dataset/gpt/memmap.py | 94 +++++++++++++------ .../data/preparator/gpt_memmap/prepare.py | 9 +- fast_llm/data/tokenizer.py | 81 ++++------------ 3 files changed, 92 insertions(+), 92 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 5d3df5983..73fb3903a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,6 +10,7 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims from fast_llm.utils import Assert, div @@ -114,7 +115,6 @@ def _init( self._image_lengths = [] self._image_positions = [] images_seen = 0 - # TODO Soham: verify correctness, reshaping into width, height? for n_images in self._n_images: self._image_lengths.append( np.frombuffer( @@ -141,8 +141,6 @@ def _init( self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - # TODO Soham: fix num_tokens to include images. 
Get total number of image pixels from index file and assign - # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) if num_pixels is not None: assert self._num_pixels == num_pixels @@ -163,21 +161,54 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap + # def get( + # self, + # idx: int, + # offset: int = 0, + # image_offset: int = 0, + # length: int | None = None, + # use_loss_masking_spans: bool = False, + # ): + # token_ids = np.frombuffer( + # self._bin_buffer, + # dtype=self._dtype, + # count=self._document_sizes[idx] - offset if length is None else length, + # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + # ) + # if self._has_images: + # image_positions = self._image_positions[idx] + # pixels = np.frombuffer( + # self._bin_buffer, + # dtype=np.dtype(np.uint8), + # count=self._image_lengths[idx].prod(initial=3), + # offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + # ) + # images = [] + # start = 0 + # for image_length in self._image_lengths[idx]: + # n_pixels = image_length.prod(initial=3) + # images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) + # start += n_pixels + # return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + def get( self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - ): - # TODO Soham: handle spans + patch_size: int | None = None, + image_size: int | None = None, + ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + images = None if self._has_images: + # Truncations with images are not yet supported image_positions = self._image_positions[idx] pixels = np.frombuffer( self._bin_buffer, @@ -188,32 +219,39 @@ def get( images = [] start = 0 for image_length in self._image_lengths[idx]: - # TODO Soham: verify reshape dimension order n_pixels = image_length.prod(initial=3) images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels - # TODO Soham: return loss_masking_spans - return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) - - # def get( - # self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - # ) -> GPTSample: - # token_ids = np.frombuffer( - # self._bin_buffer, - # dtype=self._dtype, - # count=self._document_sizes[idx] - offset if length is None else length, - # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - # ) - # sample_spans = None - # if use_loss_masking_spans and self._spans is not None: - # sample_spans = self._spans[idx] - # # adjust the spans for the offset and length - # sample_spans = sample_spans[ - # (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - # ] - # sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset - # sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - # return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + sample_spans = None + if use_loss_masking_spans and self._spans is not None: + sample_spans = self._spans[idx] + sample_spans = 
sample_spans[ + (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) + ] + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + if images: + image_idx = 0 + for span in sample_spans: + additional_tokens = 0 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + while image_position >= span[0] and image_position <= span[1]: + image_tokens = get_num_patches( + get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + patch_size, + ) + additional_tokens += image_tokens + image_idx += 1 + image_position = ( + image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + ) + span[1] += additional_tokens + return GPTSample( + token_ids=token_ids, + images=images, + image_positions=image_positions, + loss_masking_spans=sample_spans, + ) @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 2a3778df6..b6d817730 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -50,21 +50,24 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) # for text in batch[self._config.dataset.field] # ] - input_ids, image_token_positions = map( + input_ids, token_spans, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), + np.array(token_spans, dtype=np.int32).reshape(-1, 2), np.array(image_token_positions, dtype=np.int32), ) - for input_ids, image_token_positions in [ + for input_ids, token_spans, image_token_positions in [ self._tokenizer.tokenize( text, + loss_mask_spans, im_char_positions, ) - for text, im_char_positions in zip( + for text, loss_mask_spans, im_char_positions in zip( batch[self._config.dataset.field], + batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 10b8b2c64..c44715d80 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,7 +42,7 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: """ Tokenize the input text and return the tokenized input_ids along with token spans. 
""" @@ -57,14 +57,15 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li char_pos = 0 token_ids = [] image_token_positions = [] + token_spans = [] beginning_of_text = True + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") for start, end in char_spans: - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") while image_position <= start: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False - image_token_positions.append(len(token_ids)) token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) image_idx += 1 char_pos = image_position image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") @@ -75,11 +76,12 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li char_pos = start len(token_ids) span_length = 0 + token_start = len(token_ids) while image_position <= end: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False - image_token_positions.append(len(token_ids)) token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) span_length += len(tokenized_text) char_pos = image_position image_idx += 1 @@ -96,65 +98,22 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li beginning_of_text = False token_ids.extend(tokenized_text) span_length += len(tokenized_text) + char_pos = end + 1 + token_spans.append((token_start, token_start + span_length - 1)) - # def tokenize(self, text, image_positions=None): - # if not image_positions: - # return self._tokenize(text), [], [] - # image_idx = 0 - # char_pos = 0 - # token_ids = [] - # image_token_positions = [] - # beginning_of_text = True - # while image_idx < len(image_positions): - # if image_positions[image_idx] > len(text): - # raise ValueError( - # f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" - # ) - # curr_text = text[char_pos : image_positions[image_idx]] - # tokenized_text = self._tokenize( - # curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) - # ) - # beginning_of_text = False - # token_ids.extend(tokenized_text) - # image_token_positions = len(token_ids) - # char_pos = image_positions[image_idx] - # image_idx += 1 - # if char_pos < len(text): - # curr_text = text[char_pos:] - # tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) - # token_ids.extend(tokenized_text) - # return token_ids, image_token_positions + while image_position <= len(text): + image_position = image_positions[image_idx] + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) + char_pos = image_position + image_idx += 1 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + tokenized_text = self._tokenize(text[char_pos:], begin=beginning_of_text, end=True) + token_ids.extend(tokenized_text) - # def tokenize_with_spans( - # self, text: str, char_spans: list[tuple[int, int]] - # ) -> tuple[list[int], list[tuple[int, int]]]: - # """ - # Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
- # """ - # input_ids = [] - # token_spans = [] - # char_pos = 0 - # beginning_of_text = True - # for start, end in char_spans: - # if char_pos < start: - # curr_text = text[char_pos:start] - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - # beginning_of_text = False - # input_ids.extend(tokenized_text) - # curr_text = text[start : end + 1] - # if end >= len(text) - 1: - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - # else: - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - # beginning_of_text = False - # token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - # input_ids.extend(tokenized_text) - # char_pos = end + 1 - # if char_pos < len(text): - # curr_text = text[char_pos:] - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - # input_ids.extend(tokenized_text) - # return input_ids, token_spans + return token_ids, token_spans, image_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) From 0baae59dc9c4d7401a98b253b03fb41323219910 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:21:39 +0000 Subject: [PATCH 025/161] minor fixes --- fast_llm/data/dataset/gpt/indexed.py | 4 ++-- fast_llm/data/dataset/gpt/memmap.py | 8 +------- fast_llm/data/dataset/gpt/sampled.py | 6 ++---- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 209c6e317..f8260413d 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -48,8 +48,8 @@ def get_document_sizes(self) -> np.ndarray: doc_sizes, im_sizes = self._dataset.get_document_sizes() return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] - def get_document_size(self, index: int, patch_size: list[int]) -> int: - return self._dataset.get_document_size(self._begin + index, patch_size) + def get_document_size(self, index: int) -> int: + return self._dataset.get_document_size(self._begin + index) @property def has_images(self) -> bool: diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 73fb3903a..af632d5b4 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -268,7 +268,6 @@ def num_tokens(self) -> int: def has_images(self) -> bool: return self._has_images - # TODO: image sizes def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. 
@@ -277,12 +276,7 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ return self._document_sizes, self._image_lengths - def get_document_size(self, index: int, patch_size: list[int]) -> int: - # return self._document_sizes[index].item() + ( - # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) - # if self._has_images - # else 0 - # ) + def get_document_size(self, index: int) -> int: return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] @classmethod diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f99a9d3ef..2a1df4430 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -407,9 +407,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size( - document_index, self._parameters.patch_size - ) + document_size, image_lengths = self._indexed_dataset.get_document_size(document_index) image_sizes = [ get_num_patches( @@ -582,7 +580,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") - document_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, _ = self._indexed_dataset.get_document_sizes() num_documents = len(document_sizes) num_tokens = document_sizes.sum() np_rng = np.random.RandomState(seed=self._config.seed) From 48855be3c9413298a38af9a94ee25eb56167815f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:30:55 +0000 Subject: [PATCH 026/161] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2a1df4430..a8ad574c1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -135,12 +135,24 @@ def _sample(self) -> None: # TODO Soham: verify numpy correctness document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + image_token_sizes = [] # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum( - (sizes[:, 0] // self._parameters.patch_size) * (sizes[:, 1] // self._parameters.patch_size) + image_token_sizes.append( + sum( + get_num_patches( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, + ) + for size in sizes + ) ) + image_token_sizes = image_token_sizes.to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() From f35e003d82b05e4787bc791928e1955262d4ba6a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:34:37 +0000 Subject: [PATCH 027/161] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index a8ad574c1..ce92d1c1f 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -434,14 +434,15 @@ def __getitem__(self, 
index: int) -> typing.Any: for image_length in image_lengths ] image_tokens = sum(image_sizes) + document_size += image_tokens if not self._truncate_documents: - if document_size + image_tokens > self._parameters.sequence_length + 1: + if document_size > self._parameters.sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + image_tokens + tokens_in_sample > self._parameters.sequence_length + 1: + if document_size + tokens_in_sample > self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -454,7 +455,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count += padding_size # Determine if the document belongs to the requested sample. - if token_count + document_size + image_tokens >= token_start: + if token_count + document_size >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) @@ -488,7 +489,7 @@ def __getitem__(self, index: int) -> typing.Any: # Go to the next document. document_sampling_index += 1 - token_count += document_size + image_tokens + token_count += document_size sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) From 4eb34cb0c4a4be901d079aaf0997e048035dbce6 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:41:02 +0000 Subject: [PATCH 028/161] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index ce92d1c1f..01459fa0a 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -96,7 +96,7 @@ def __init__( # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( - "Truncating documents with images is not supported. Please turn off truncation to use images." + "Truncating documents with images is not yet supported. Please turn off truncation to use images." ) if sampling.cache_directory is None: @@ -132,11 +132,9 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. 
- # TODO Soham: verify numpy correctness document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = [] - # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): image_token_sizes.append( sum( @@ -476,7 +474,6 @@ def __getitem__(self, index: int) -> typing.Any: start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) images.append(sample.images) - # TODO Soham: add offsets for loss masking spans if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip( From ebb9e276a3b97b3571e26c346a986be67d8e87cc Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 15:13:09 +0000 Subject: [PATCH 029/161] cleanup --- fast_llm/data/dataset/gpt/indexed.py | 1 - .../layers/transformer/vision_transformer.py | 16 ---------------- 2 files changed, 17 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index f8260413d..6e9bef96d 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -11,7 +11,6 @@ class GPTIndexedDataset(IndexedDataset): - # TODO Soham: should we change this to include images? @abc.abstractmethod def get_document_sizes(self) -> np.ndarray: """ diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 94a9c70af..3588956c7 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -37,19 +37,3 @@ def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) - - # TODO Soham: remove this since we only need to call the parent method - # def forward( - # self, - # input_: torch.Tensor, - # kwargs: dict[str, typing.Any], - # losses: dict[str, typing.Any] | None = None, - # metrics: dict[str, typing.Any] | None = None, - # ) -> torch.Tensor: - # if isinstance(input_, TensorMeta): - # return self._get_meta(input_, "output", kwargs) - # # Hack for now to compute the patch embeddings - # kwargs[VisionTransformerKwargs.patch_embeddings] = super().forward( - # kwargs.pop(VisionTransformerKwargs.patch_embeddings), kwargs, losses, metrics - # ) - # return input_ From 51098ef106b72a0528c71558aba1405993d96aa0 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 15:45:07 +0000 Subject: [PATCH 030/161] fix --- fast_llm/data/dataset/gpt/sampled.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 01459fa0a..fc2ddb6a0 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -150,7 +150,7 @@ def _sample(self) -> None: for size in sizes ) ) - image_token_sizes = image_token_sizes.to(self._device) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -417,7 +417,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size(document_index) + text_size, image_lengths = 
self._indexed_dataset.get_document_size(document_index) image_sizes = [ get_num_patches( @@ -432,7 +432,7 @@ def __getitem__(self, index: int) -> typing.Any: for image_length in image_lengths ] image_tokens = sum(image_sizes) - document_size += image_tokens + document_size = text_size + image_tokens if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -456,7 +456,7 @@ def __getitem__(self, index: int) -> typing.Any: if token_count + document_size >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) + token_end_index_in_document = min(token_end - token_count, text_size) sample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, From 60b87fa766a77a183a4aa998ae914a2d22b1e195 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 16:39:46 +0000 Subject: [PATCH 031/161] prepare cleanup --- fast_llm/data/dataset/gpt/memmap.py | 2 + fast_llm/data/dataset/gpt/sampled.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 5 +++ .../data/preparator/gpt_memmap/prepare.py | 44 ++++++++++--------- fast_llm/data/tokenizer.py | 2 - 5 files changed, 31 insertions(+), 23 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index af632d5b4..e1297b14a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -108,6 +108,8 @@ def _init( + sum([x.nbytes for x in self._spans]) ) self._num_pixels = 0 + self._image_lengths = None + self._image_positions = None if self._has_images and self._version >= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index fc2ddb6a0..91f8ca8fa 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -93,7 +93,6 @@ def __init__( self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") - # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( "Truncating documents with images is not yet supported. Please turn off truncation to use images." diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 38d90ed42..53f8e4688 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -173,6 +173,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Tokenizer configuration.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." 
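Note on `image_patch_size`: the prepare.py hunk below divides the per-dataset pixel count by `image_patch_size**2` so that images contribute an approximate token count when computing split boundaries and blending weights. A minimal sketch of that bookkeeping, assuming per-document image sizes arrive as (num_images, 2) arrays of (height, width); the helper name is illustrative and not part of this patch:

import numpy as np

def approximate_image_tokens(image_sizes: list[np.ndarray], image_patch_size: int) -> np.ndarray:
    # One entry per document: total pixels across its images (0 when the document has none).
    pixels = np.array([sizes.prod(axis=1).sum() if len(sizes) else 0 for sizes in image_sizes])
    # Each image_patch_size x image_patch_size patch maps to one transformer token,
    # so pixels // patch_area approximates the number of image tokens per document.
    return pixels // (image_patch_size**2)

As the comments in the hunk below note, this estimate is used only to get an even split; the patch size used at training time can differ without significant impact.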
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b6d817730..c5a1b339c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -44,12 +44,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass - # TODO Soham: can we merged tokenize_batch and tokenize_batch_with_spans? def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - # input_ids = [ - # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) - # for text in batch[self._config.dataset.field] - # ] input_ids, token_spans, image_token_positions = map( list, zip( @@ -85,6 +80,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ return { "input_ids": input_ids, "image_positions": image_token_positions, + "token_spans": token_spans, "num_tokens": num_tokens, "num_pixels": num_pixels, } @@ -282,12 +278,7 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") - if self._config.dataset.loss_masking_spans is not None: - if self._config.dataset.loss_masking_spans not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") - tokenize_fn = self._tokenize_batch_with_spans - else: - tokenize_fn = self._tokenize_batch + tokenize_fn = self._tokenize_batch # Avoid decoding bytes to images unless asked if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) @@ -336,7 +327,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ -376,7 +367,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: int, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -406,11 +401,20 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - # TODO Soham: handle pixels (could still work with number of tokens?) 
- sizes_cumsum = dataset.get_document_sizes()[0].cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes: + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( GPTDatasetSliceConfig.from_dict( @@ -423,8 +427,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c44715d80..0acb65e47 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -46,8 +46,6 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li """ Tokenize the input text and return the tokenized input_ids along with token spans. 
""" - # if not image_positions and not char_spans: - # return self._tokenize(text), [], [] if not image_positions: image_positions = [] if not char_spans: From f8a5532f16df73794bbed793721a2b507bb8b280 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 22:21:27 +0000 Subject: [PATCH 032/161] slightly better conversion --- fast_llm/models/gpt/conversion.py | 328 +++++++++++++----------------- 1 file changed, 146 insertions(+), 182 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 6aa3aaf1f..4363c96c6 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -167,20 +167,16 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, - hf_base_prefix: str = "", - fast_llm_offset: int = 1, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers # Embeddings converters.append( - WeightConverter( - f"layers.{fast_llm_offset - 1}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" - ) + WeightConverter(f"layers.{num_layers - 1}.word_embeddings_weight", f"model.embed_tokens.weight") ) - converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) + converters += self._create_lm_head_converters() for i in range(num_layers): converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") @@ -565,196 +561,111 @@ class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + # lm_converters = super()._create_config_converters() lm_converters = super()._create_config_converters() - lm_converters[-2] = ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ) - # TODO Soham: cleaner way to get language model config converters - for converter in lm_converters: - if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): - # Llava uses a different name for the text config - # if converter.fast_llm_names[0][0] == "transformer": + for idx, converter in enumerate(lm_converters): + if converter.export_names == (("model_type",),): + continue + elif converter.export_names == (("architectures",),): + ignore_index = idx + if converter.export_names: converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - # if converter.fast_llm_names[0][0] == "transformer": - # converter.export_names[0] = ("text_config", *converter.export_names[0]) - return lm_converters + [ - # Multimodal adapter - RenameParamConverter( - fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=(("text_config", "hidden_size"),), - ), - # Image processing and conv layer - # TODO Soham: these options are not in the fast-llm model config. 
They're read from BatchConfig currently - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "image_size"),), - # export_names=( - # ( - # "vision_config", - # "image_size", - # ), - # ), - # ), - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), - # export_names=( - # ( - # "vision_config", - # "patch_size", - # ), - # ), - # ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "patch_norm", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - # Vision Transformer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), - export_names=( - ( - "vision_config", - "num_hidden_layers", + + return ( + lm_converters[:ignore_index] + + lm_converters[ignore_index + 1 :] + + [ + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ), + # Vision Adapter + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("text_config", "hidden_size"),), + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "patch_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), + # Vision Transformer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), + export_names=( + ( + "vision_config", + "num_hidden_layers", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), - export_names=( - ( - "vision_config", - "hidden_size", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), + export_names=( + ( + "vision_config", + "hidden_size", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), - export_names=( - ( - "vision_config", - "num_attention_heads", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), + export_names=( + ( + "vision_config", + "num_attention_heads", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), - export_names=( - ( - "vision_config", - "num_key_value_heads", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), + export_names=( + ( + "vision_config", + "num_key_value_heads", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), - export_names=( - ( - "vision_config", - "intermediate_size", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), + export_names=( + ( + "vision_config", + "intermediate_size", + ), ), ), - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), - export_names=( - ( - "vision_config", - "hidden_act", + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), + export_names=( + ( + "vision_config", + "hidden_act", + ), ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, ), - 
fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - export_names=(("projector_hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False - ), - # TODO Soham: add this config param for completeness? - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), - # export_names=( - # ( - # "vision_config", - # "num_channels", - # ), - # ), - # ), - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "attention_dropout"),), - # export_names=( - # ( - # "vision_config", - # "attention_dropout", - # ), - # ), - # ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), - export_names=(("vision_config", "rope_theta"),), - ), - # TODO Soham: add this config param in vision encoder for completeness? - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "initializer_range"),), - # export_names=(("vision_config", "initializer_range"),), - # ), - ] - - def _create_vision_transformer_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - vision_transformer_converters = [] - for layer in range(num_layers): - # TODO Soham: check if args are correct - vision_transformer_converters.extend( - self._create_vision_transformer_layer_converters( - layer, - ignore_export=False, - hf_base_prefix="vision_tower.transformer.layers.", - fast_llm_offset=1, - type="vision", - ) - ) - - return vision_transformer_converters - - def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] - if self._model.config.base_model.vision_encoder.conv_bias: - patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) - layernorm_converters = [ - WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), - ] - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: - layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - - vision_transformer_converters = self._create_vision_transformer_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 - adapter_converters = [ - WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), - WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), - # TODO Soham: add bias based on config - WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), - WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), - ] - - return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - - def _create_weight_converters(self) -> list[WeightConverter]: - vision_encoder_converter = self._create_vision_encoder_weight_converters() - offset = 
self._model.config.base_model.vision_encoder.transformer.num_layers + 3 - # TODO Soham: call _create_transformer_layer_converters with llava's custom offset - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) - return vision_encoder_converter + lm_converters + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), + export_names=(("vision_config", "rope_theta"),), + ), + ] + ) def _create_vision_transformer_layer_converters( self, @@ -850,6 +761,59 @@ def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix ), ] + def _create_vision_transformer_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + vision_transformer_converters = [] + for layer in range(num_layers): + # TODO Soham: check if args are correct + vision_transformer_converters.extend( + self._create_vision_transformer_layer_converters( + layer, + ignore_export=False, + hf_base_prefix="vision_tower.transformer.layers.", + fast_llm_offset=1, + type="vision", + ) + ) + + return vision_transformer_converters + + def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: + patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] + if self._model.config.base_model.vision_encoder.conv_bias: + patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) + layernorm_converters = [ + WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), + ] + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + + vision_transformer_converters = self._create_vision_transformer_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 + adapter_converters = [ + WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), + WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), + # TODO Soham: add bias based on config + WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), + WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), + ] + + return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + + def _create_weight_converters(self) -> list[WeightConverter]: + vision_encoder_converter = self._create_vision_encoder_weight_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 + # Embeddings + lm_converters = [ + WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") + ] + for i in range(self._model.config.base_model.transformer.num_layers): + lm_converters += self._create_transformer_layer_converters( + 
fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" + ) + lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) + return vision_encoder_converter + lm_converters + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat From 490651e4b074073e60e36910c6d6d0ed1fa46c21 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 06:51:31 +0000 Subject: [PATCH 033/161] cleanup, sequence parallelism --- fast_llm/data/dataset/gpt/indexed.py | 2 +- fast_llm/data/dataset/gpt/memmap.py | 1 + fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/multi_modal/embedding.py | 96 +++++++++++++++---- fast_llm/layers/vision_encoder/config.py | 16 ++++ fast_llm/layers/vision_encoder/encoder.py | 8 +- .../layers/vision_encoder/preprocessing.py | 2 +- fast_llm/models/gpt/conversion.py | 10 +- fast_llm/models/gpt/model.py | 26 +++-- 9 files changed, 126 insertions(+), 37 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 6e9bef96d..cbe77ff0a 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -45,7 +45,7 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else None def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index e1297b14a..1efc312e8 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -209,6 +209,7 @@ def get( offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) images = None + image_positions = None if self._has_images: # Truncations with images are not yet supported image_positions = self._image_positions[idx] diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 78de218f1..e46e104c2 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -175,7 +175,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: # TODO: Need both? 
tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - if self.vision_encoder is not None: + if self.vision_encoder.enabled: self.vision_encoder.setup_tensor_space(tensor_space) @property diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index b7d79dd37..52eaaac34 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -3,6 +3,7 @@ import torch from fast_llm.core.distributed import set_generator +from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -10,6 +11,7 @@ from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert class MultiModalEmbedding(LanguageModelEmbedding): @@ -24,6 +26,78 @@ def __init__( ): super().__init__(config, tensor_space) + @torch.compile + def _forward( + self, + input_: torch.Tensor, + tokens: torch.Tensor, + position_ids: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + image_sizes: list[list[tuple[int, int]]] | None, + ) -> torch.Tensor: + """ + Forward pass for the multi-modal embedding layer. + Args: + input_: The input tensor (image embeddings). + tokens: The tokenized text input. + position_ids: The position ids for the text input. + image_positions: The positions of the image tokens in the input. + image_sizes: The sizes of the images in the input. + Returns: + The combined embeddings for text and images. + """ + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + group = self._tensor_space.distributed.tensor_group + if self._parallel_embeddings: + token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) + masked_tokens = (tokens - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + # TODO Soham: avoid cloning? + embeddings = embeddings.clone() + input_ = gather(input_, group, dim=0) + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] + image_embedding_offset += num_image_tokens + if self._sequence_parallel: + embeddings = split(embeddings, group=group, dim=0) + else: + if self._sequence_parallel: + tokens = split(tokens, group=group, dim=0) + if self._use_absolute_position_embeddings: + position_ids = split(position_ids, group=group, dim=0) + # TODO Soham: get image positions for current split. Maybe in preprocessing? 
+ # for positions in image_positions: + # if positions > self._distributed_config.tensor_rank + embeddings = torch.embedding(self.word_embeddings_weight, tokens) + # TODO Soham: avoid cloning? + embeddings = embeddings.clone() + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] + image_embedding_offset += num_image_tokens + + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(dtype=self._residual_dtype) + def forward( self, input_: torch.Tensor, @@ -42,25 +116,5 @@ def forward( image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) image_positions = kwargs.get(VisionEncoderKwargs.image_positions) tokens = kwargs.get(LanguageModelKwargs.tokens) - # get text embeddings - # TODO Soham: cloning to avoid pytorch complaint about in-place operation. Can we do better? - embeddings = super()._forward(tokens, position_ids).clone() - image_idx = 0 - for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): - image_embedding_offset = 0 - for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] - image_embedding_offset += num_image_tokens - image_idx += 1 - - with set_generator( - self._tensor_space.distributed.tp_generator - if self._sequence_parallel - else self._tensor_space.distributed.pp_generator - ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) - return embeddings.to(self._residual_dtype) + return self._forward(input_, tokens, position_ids, image_positions, image_sizes) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index be3fb38cb..e9bfd7d1c 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,3 +1,5 @@ +import enum + from fast_llm.config import Config, Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -130,10 +132,20 @@ class ImageNormalizationConfig(Config): ) +class VisionEncoderType(str, enum.Enum): + none = "none" + pixtral = "pixtral" + + @config_class() class VisionEncoderConfig(BaseModelConfig): _abstract = False + type: VisionEncoderType = Field( + default=VisionEncoderType.none, + desc="Type of the vision encoder. 
Choices: none, pixtral.", + hint=FieldHint.architecture, + ) transformer: VisionTransformerConfig = Field( default_factory=VisionTransformerConfig, desc="Configuration for the vision transformer architecture.", @@ -182,3 +194,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): ) ) self.transformer.setup_tensor_space(tensor_space, type="vision") + + @property + def enabled(self) -> bool: + return self.type != VisionEncoderType.none diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 59212c58f..a67053d56 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -2,6 +2,7 @@ import torch +from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs @@ -39,6 +40,8 @@ class PatchConv(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space + self._distributed_config = tensor_space.distributed_config + self._sequence_parallel = self._distributed_config.sequence_tensor_parallel # TODO Soham: lr_scale self.weight = ParameterMeta.from_dims( ( @@ -68,7 +71,10 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) + group = self._tensor_space.distributed.tensor_group input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) patch_embeddings = self.norm(input_.flatten(1)) - patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) + patch_embeddings = patch_embeddings.reshape(*(x.global_size for x in hidden_dims)) + if self._sequence_parallel: + patch_embeddings = split(patch_embeddings, group=group, dim=0) return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 46bf0ab3f..db726e24f 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -153,7 +153,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: cu_seqlens = [0] max_seqlen = -1 for imgs, sizes in zip(images, image_sizes): - # TODO Soham: should this be micro_sequence_length? # sum( # get_num_patches(*size, patch_size) for size in sizes # ) @@ -172,6 +171,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) + # TODO Soham: should this be micro_sequence_length? 
padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] if padding_size > max_seqlen: max_seqlen = padding_size diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 4363c96c6..ad4df7378 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -27,6 +27,7 @@ from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( GPTBaseModelConfig, GPTModelConfig, @@ -172,9 +173,7 @@ def _create_weight_converters( num_layers = self._model.config.base_model.transformer.num_layers # Embeddings - converters.append( - WeightConverter(f"layers.{num_layers - 1}.word_embeddings_weight", f"model.embed_tokens.weight") - ) + converters.append(WeightConverter(f"layers.0.word_embeddings_weight", f"model.embed_tokens.weight")) converters += self._create_lm_head_converters() @@ -250,7 +249,7 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self, hf_base_prefix: str, fast_llm_offset: int = 1) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str = "", fast_llm_offset: int = 1) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm @@ -575,6 +574,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: lm_converters[:ignore_index] + lm_converters[ignore_index + 1 :] + [ + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral + ), ConstantExportParamConverter( export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] ), diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b832f1b04..4219ac324 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -76,7 +76,7 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) if self._config.vision_encoder.transformer.rotary.enabled: self._preprocessors.append( @@ -129,7 +129,7 @@ def get_layers(self) -> list[Layer]: return [ *( [LanguageModelEmbedding(self._config, self._tensor_space)] - if self._config.vision_encoder is None + if not self._config.vision_encoder.enabled else self.get_vision_layers() ), *[ @@ -162,7 +162,7 @@ def preprocess_meta( sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: image_size = batch_meta.image_size image_mean = [ self._config.vision_encoder.image_normalization.mean_r, @@ -231,7 +231,7 @@ def preprocess_meta( if sequence_first else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: vision_hidden_dim = 
self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) vision_hidden_dims = ( (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) @@ -298,7 +298,7 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) else: @@ -430,11 +430,17 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[self._config.vision_encoder.transformer.num_layers + 2] + if self._config.vision_encoder.enabled: + return self.layers[self._config.vision_encoder.transformer.num_layers + 2] + else: + return self.layers[0] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + if self._config.vision_encoder.enabled: + return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + else: + return self.layers[1:-1] @property def model_head(self) -> LanguageModelHead: @@ -449,7 +455,11 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (self._config.vision_encoder is not None, *self.model_head_indices), + # TODO Soham: make embedding layer index a property + ( + self._config.vision_encoder.enabled * (self._config.vision_encoder.transformer.num_layers + 2), + *self.model_head_indices, + ), ) } elif self._config.prediction_heads > 1: From 24e1b83f15c0ec89cb866b5438283533218bc005 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 07:19:49 +0000 Subject: [PATCH 034/161] fix conv --- fast_llm/layers/vision_encoder/encoder.py | 28 +++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index a67053d56..cff874793 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -61,6 +61,25 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) self.stride = config.patch_size + @torch.compile + def _forward( + self, + input_: torch.Tensor, + hidden_dims: tuple[TensorMeta, ...], + ): + group = self._tensor_space.distributed.tensor_group + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) + batch_dim, sequence_q_dim, hidden_dim = hidden_dims + if self._sequence_parallel: + patch_embeddings = patch_embeddings.reshape( + sequence_q_dim.global_size, batch_dim.size, hidden_dim.global_size + ) + patch_embeddings = split(patch_embeddings, group=group, dim=0) + else: + patch_embeddings = patch_embeddings.reshape(batch_dim.size, sequence_q_dim.size, hidden_dim.size) + return patch_embeddings + def forward( self, input_: torch.Tensor, @@ -71,10 +90,5 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) - patch_embeddings = 
self.norm(input_.flatten(1)) - patch_embeddings = patch_embeddings.reshape(*(x.global_size for x in hidden_dims)) - if self._sequence_parallel: - patch_embeddings = split(patch_embeddings, group=group, dim=0) - return patch_embeddings + + return self._forward(input_, hidden_dims) From 0f1612a63c84b355c45b282cb10f174c6a9a7da3 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 14:57:47 +0000 Subject: [PATCH 035/161] wip fixes --- fast_llm/data/dataset/gpt/indexed.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 53 ++++++++++++++++------------ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index cbe77ff0a..56c4c8927 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -45,7 +45,7 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else None + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else [] def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 91f8ca8fa..9fbb218ee 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -133,23 +133,26 @@ def _sample(self) -> None: # Get the document sizes, the main information needed for sampling. document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - image_token_sizes = [] - for i, sizes in enumerate(image_sizes): - image_token_sizes.append( - sum( - get_num_patches( - *get_resize_dims( - *size, - self._parameters.image_size, - self._parameters.image_size, + if image_sizes: + image_token_sizes = [] + for i, sizes in enumerate(image_sizes): + image_token_sizes.append( + sum( + get_num_patches( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), self._parameters.patch_size, - ), - self._parameters.patch_size, + ) + for size in sizes ) - for size in sizes ) - ) - image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + else: + image_token_sizes = torch.zeros_like(document_sizes) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -463,16 +466,20 @@ def __getitem__(self, index: int) -> typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 - for idx, im_position in enumerate(sample.image_positions): - # image_positions.append(im_positions + len(token_ids) + image_tokens_added) - # Add placeholders for image tokens - token_ids.append(sample.token_ids[start_pos:im_position]) - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += image_tokens - start_pos = im_position + if sample.image_positions: + for idx, im_position in enumerate(sample.image_positions): + # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # Add placeholders for image tokens + 
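[Editor's note] The __getitem__ change being made here splices a run of -100 placeholder tokens into the text stream at every image position, one placeholder per eventual image patch. A toy illustration of that interleaving, standalone rather than the dataset class itself:

import numpy as np

text_tokens = np.array([5, 6, 7, 8, 9], dtype=np.int64)
image_positions = [2]        # the image sits after the first two text tokens
patches_per_image = [4]      # e.g. a 32x32 crop with 16x16 patches

pieces, start = [], 0
for position, n_patches in zip(image_positions, patches_per_image):
    pieces.append(text_tokens[start:position])
    pieces.append(np.full((n_patches,), -100, dtype=np.int64))  # placeholders, overwritten later
    start = position
pieces.append(text_tokens[start:])
print(np.concatenate(pieces))  # [   5    6 -100 -100 -100 -100    7    8    9]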
token_ids.append(sample.token_ids[start_pos:im_position]) + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + image_positions.append(im_position + len(token_ids) + image_tokens_added) + image_tokens_added += image_tokens + start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) - images.append(sample.images) + if sample.images: + images.append(sample.images) + else: + images.append([]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip( From 2e48c5f282e4e5b1e460e96efdc9e42b2c0743db Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 22:26:10 +0000 Subject: [PATCH 036/161] fix --- fast_llm/layers/multi_modal/embedding.py | 11 ++++-- fast_llm/layers/vision_encoder/config.py | 1 + fast_llm/layers/vision_encoder/encoder.py | 34 +++++++------------ .../layers/vision_encoder/preprocessing.py | 5 ++- fast_llm/models/gpt/model.py | 3 ++ 5 files changed, 29 insertions(+), 25 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 52eaaac34..9a035d8fd 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -62,9 +62,14 @@ def _forward( image_embedding_offset = 0 for position, size in zip(positions, sizes): num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + if self._sequence_parallel: + embeddings[position : position + num_image_tokens, sample_idx] = input_[ + image_embedding_offset : image_embedding_offset + num_image_tokens, sample_idx + ] + else: + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] image_embedding_offset += num_image_tokens if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index e9bfd7d1c..fdbe2726f 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -66,6 +66,7 @@ class VisionEncoderKwargs: patch_embeddings = "patch_embeddings" hidden_dims = "vit_hidden_dims" image_patches_meta = "vit_image_patches_meta" + out_channels = "vit_out_channels" # TODO Soham: do we need all of them? 
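[Editor's note] The MultiModalEmbedding fix above writes each image's patch embeddings over the corresponding placeholder slice, indexing (sequence, batch) in the sequence-parallel branch and (batch, sequence) otherwise. A minimal batch-first sketch with assumed shapes:

import torch

batch, seq_len, hidden = 1, 9, 8
embeddings = torch.zeros(batch, seq_len, hidden)   # token embeddings, placeholders included
patch_embeddings = torch.randn(batch, 4, hidden)   # 4 patches produced by the vision tower

sample_idx, position, num_image_tokens, offset = 0, 2, 4, 0
embeddings[sample_idx, position : position + num_image_tokens] = patch_embeddings[
    sample_idx, offset : offset + num_image_tokens
]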
diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index cff874793..1df7f889c 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -5,6 +5,7 @@ from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -61,25 +62,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) self.stride = config.patch_size - @torch.compile - def _forward( - self, - input_: torch.Tensor, - hidden_dims: tuple[TensorMeta, ...], - ): - group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) - patch_embeddings = self.norm(input_.flatten(1)) - batch_dim, sequence_q_dim, hidden_dim = hidden_dims - if self._sequence_parallel: - patch_embeddings = patch_embeddings.reshape( - sequence_q_dim.global_size, batch_dim.size, hidden_dim.global_size - ) - patch_embeddings = split(patch_embeddings, group=group, dim=0) - else: - patch_embeddings = patch_embeddings.reshape(batch_dim.size, sequence_q_dim.size, hidden_dim.size) - return patch_embeddings - def forward( self, input_: torch.Tensor, @@ -90,5 +72,15 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - - return self._forward(input_, hidden_dims) + micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] + sequence_length = kwargs[TransformerKwargs.sequence_length] + out_channels = kwargs[VisionEncoderKwargs.out_channels] + reshape_dims = (micro_batch_size, sequence_length, out_channels) + group = self._tensor_space.distributed.tensor_group + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) + patch_embeddings = patch_embeddings.view(reshape_dims) + if self._sequence_parallel: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + patch_embeddings = split(patch_embeddings, group=group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index db726e24f..7ebfb5228 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -152,16 +152,19 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 + sequence_first = kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): # sum( # get_num_patches(*size, patch_size) for size in sizes # ) seq_patches = [] + sample_cu_seqlen = 0 for image, size in zip(imgs, sizes): seqlen = get_num_patches(*size, patch_size) if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) + sample_cu_seqlen += seqlen seq_patches.append( torch.cat( [ @@ -172,7 +175,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) ) # TODO Soham: should this be 
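[Editor's note] For reference, the cu_seqlens assembled in this hunk follow the varlen-attention convention: cumulative patch counts per image, plus a final padding segment so the packed length matches the configured sequence length (the open TODO about micro_sequence_length is left as-is). A worked example with arbitrary sizes:

seqlens = [64, 256, 100]        # patches contributed by three images in one sample
sequence_length = 512           # packed length the sample is padded to

cu_seqlens = [0]
for n in seqlens:
    cu_seqlens.append(cu_seqlens[-1] + n)
padding = sequence_length - cu_seqlens[-1]   # 92 padding patches
cu_seqlens.append(sequence_length)
max_seqlen = max(max(seqlens), padding)

print(cu_seqlens)   # [0, 64, 320, 420, 512]
print(max_seqlen)   # 256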
micro_sequence_length? - padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] + padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen if padding_size > max_seqlen: max_seqlen = padding_size cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4219ac324..9fff50bc7 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -185,6 +185,9 @@ def preprocess_meta( VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.kv_channels ).size, + VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( + VisionEncoderDimNames.out_channels + ).size, } else: vision_kwargs = {} From d529d37d881849afff40e57609ef4d10a916b742 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 17 May 2025 17:42:24 +0000 Subject: [PATCH 037/161] fix image position --- fast_llm/data/dataset/gpt/sampled.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 9fbb218ee..780b18878 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -412,6 +412,7 @@ def __getitem__(self, index: int) -> typing.Any: images = [] image_positions = [] image_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -471,11 +472,13 @@ def __getitem__(self, index: int) -> typing.Any: # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) + text_tokens_added += len(token_ids[-1]) token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_positions.append(im_position + len(token_ids) + image_tokens_added) image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids[-1]) if sample.images: images.append(sample.images) else: From 3c22ddafc27e02a6f5af31ad7022a6d315cb3f03 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 17 May 2025 17:45:04 +0000 Subject: [PATCH 038/161] cleanup --- .../layers/transformer/vision_transformer.py | 25 ++----------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 3588956c7..72bd95ddd 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -1,33 +1,12 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): - """ - A vision transformer layer to encode image patches - """ - - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index: int, - return_input: bool = False, - ): - super().__init__(config, tensor_space, layer_index, return_input) - - hidden_dim = 
self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) - - self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) - @property def name(self) -> str: return f"Vision transformer layer {self._layer_index}" From f0c8d830da9c4ea43df478a6cafbbb48bf910111 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 20 May 2025 07:05:01 +0000 Subject: [PATCH 039/161] cleanup --- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/transformer/attention.py | 17 +- fast_llm/layers/transformer/config.py | 259 ++++++++++++------ fast_llm/layers/transformer/mlp.py | 17 +- fast_llm/layers/transformer/preprocessing.py | 58 ++-- fast_llm/layers/transformer/transformer.py | 24 +- .../layers/transformer/vision_transformer.py | 12 +- fast_llm/layers/vision_encoder/config.py | 60 +--- fast_llm/models/gpt/model.py | 19 +- fast_llm/utils.py | 7 + 10 files changed, 239 insertions(+), 235 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e46e104c2..cdb27d9ef 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -46,7 +46,6 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - # TODO Soham: make this None by default. Need to figure out how to handle this in the config (see ) vision_encoder: VisionEncoderConfig = Field( default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index b16f17405..3180b6cb8 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -9,14 +9,7 @@ from fast_llm.functional.rotary import apply_rotary_embeddings from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs, TransformerSubLayerName from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -66,12 +59,8 @@ def __init__( layer_index, ): super().__init__() - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space # TODO Soham: fix assert diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 38dc9ec48..9a6bec07d 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -28,60 
+28,109 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -class TransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? - sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - -class TransformerKwargs: - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" - sequence_lengths = "sequence_lengths" - cu_seqlens_q = "cu_seqlens_q" - cu_seqlens_k = "cu_seqlens_k" - max_seqlen_q = "max_seqlen_q" - max_seqlen_k = "max_seqlen_k" - # TODO: Review these - presents = "presents" - past_key_values = "past_key_values" - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" - sequence_q_dim = "sequence_q_dim" - sequence_k_dim = "sequence_k_dim" - sequence_length = "sequence_length" - micro_batch_size = "micro_batch_size" - # TODO: Move - grad_output = "grad_output" +class BaseTransformerDimNames: + _kwargs_attributes = { + "batch": "batch", + "sequence_q": "sequence_q", + "sequence_q_tp": "sequence_q_tp", + "sequence_k": "sequence_k", + "hidden": "hidden", + "head_groups": "head_groups", + "group_heads": "group_heads", + "key_and_value": "key_value", + "kv_channels": "kv_channels", + "composite_heads": "composite_heads", + "composite_query": "composite_query", + "composite_key_value": "composite_key_value", + "composite_dense": "composite_dense", + "mlp": "mlp", + "gate_and_up": "gate_and_up", + "composite_gated_mlp": "composite_gated_mlp", + "experts": "experts", + "top_experts": "top_experts", + "shared_experts": "shared_experts", + "unshared_experts": "unshared_experts", + "composite_expert_mlp": "composite_expert_mlp", + "composite_gated_expert_mlp": "composite_gated_expert_mlp", + "composite_shared_expert_mlp": "composite_shared_expert_mlp", + "composite_gated_shared_expert_mlp": "composite_gated_shared_expert_mlp", + } + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}") + + +class TransformerDimNames(BaseTransformerDimNames, prefix=""): + pass + + +class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder"): + pass + + +class BaseTransformerKwargs: + _kwargs_attributes = { + "rotary_freq_q": "rotary_freq_q", + "rotary_freq_k": "rotary_freq_k", + "attention_mask": "attention_mask", + "attention_mask_value": "attention_mask_value", + "sequence_lengths": "sequence_lengths", + "cu_seqlens_q": "cu_seqlens_q", 
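[Editor's note] The BaseTransformerDimNames / BaseTransformerKwargs pattern introduced above generates one prefixed namespace per transformer stack, so the text and vision towers never collide on tensor-dim or kwargs keys. A condensed reproduction of the idea, with only two names and the empty prefix left unprefixed for readability; the real classes carry the full lists:

class BaseDimNames:
    _names = ("hidden", "sequence_q")

    def __init_subclass__(cls, prefix: str = "", **kwargs):
        super().__init_subclass__(**kwargs)
        for name in BaseDimNames._names:
            setattr(cls, name, f"{prefix}_{name}" if prefix else name)


class TextDimNames(BaseDimNames, prefix=""):
    pass


class VisionDimNames(BaseDimNames, prefix="image_encoder"):
    pass


print(TextDimNames.hidden, VisionDimNames.hidden)  # hidden image_encoder_hidden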
+ "cu_seqlens_k": "cu_seqlens_k", + "max_seqlen_q": "max_seqlen_q", + "max_seqlen_k": "max_seqlen_k", + "presents": "presents", + "past_key_values": "past_key_values", + "sequence_first": "sequence_first", + "hidden_dims": "hidden_dims", + "sequence_q_dim": "sequence_q_dim", + "sequence_k_dim": "sequence_k_dim", + "sequence_length": "sequence_length", + "micro_batch_size": "micro_batch_size", + "grad_output": "grad_output", + } + + _prefix = "" + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}") + + +class TransformerKwargs(BaseTransformerKwargs, prefix=""): + pass + + +class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): + patch_position_ids = "patch_position_ids" + + +# class TransformerKwargs: +# rotary_freq_q = "rotary_freq_q" +# rotary_freq_k = "rotary_freq_k" +# attention_mask = "attention_mask" +# attention_mask_value = "attention_mask_value" +# sequence_lengths = "sequence_lengths" +# cu_seqlens_q = "cu_seqlens_q" +# cu_seqlens_k = "cu_seqlens_k" +# max_seqlen_q = "max_seqlen_q" +# max_seqlen_k = "max_seqlen_k" +# # TODO: Review these +# presents = "presents" +# past_key_values = "past_key_values" +# sequence_first = "sequence_first" +# hidden_dims = "hidden_dims" +# sequence_q_dim = "sequence_q_dim" +# sequence_k_dim = "sequence_k_dim" +# sequence_length = "sequence_length" +# micro_batch_size = "micro_batch_size" +# # TODO: Move +# grad_output = "grad_output" class TransformerLossNames: @@ -98,6 +147,11 @@ class RotaryEmbeddingType(str, enum.Enum): pixtral = "pixtral" +class TransformerType(str, enum.Enum): + lm_decoder = "lm_decoder" + image_encoder = "image_encoder" + + @config_class() class RotaryConfig(BaseModelConfig): _abstract = False @@ -160,6 +214,14 @@ def _validate(self) -> None: if self.triton and not TritonConfig.TRITON_ENABLED: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + @config_class() class VisionRotaryConfig(RotaryConfig): @@ -169,6 +231,14 @@ class VisionRotaryConfig(RotaryConfig): hint=FieldHint.feature, ) + @property + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" @@ -259,6 +329,11 @@ def _validate(self) -> None: @config_class() class TransformerConfig(BaseModelConfig): _abstract = False + transformer_type: TransformerType = Field( + default=TransformerType.lm_decoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + hint=FieldHint.architecture, + ) normalization: NormalizationConfig = Field( default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", @@ -658,72 +733,71 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) -> None: - if type == "vision": - # TODO Soham: better way to get around circular imports? Maybe add a type class variable to TransformerConfig? 
- from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames - - transformer_dim_names = VisionTransformerDimNames - else: - transformer_dim_names = TransformerDimNames + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + self.transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - transformer_dim_names.group_heads, + self.transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(self.transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(self.transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(self.transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self.transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim( + self.transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) + ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self.transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(mlp := TensorDim(self.transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) + gate_and_up := TensorDim(self.transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) - tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(self.transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) - 
tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(experts := TensorDim(self.transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim( + CompositeTensorDim(self.transformer_dim_names.composite_expert_mlp, (experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(self.transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(self.transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(self.transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + self.transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) @@ -739,6 +813,14 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return use_flash_attention + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + @config_class() class VisionRotaryConfig(RotaryConfig): @@ -755,6 +837,11 @@ class VisionTransformerConfig(TransformerConfig): Configuration for the Vision Transformer (ViT) model. """ + transformer_type: TransformerType = FieldUpdate( + default=TransformerType.image_encoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + hint=FieldHint.architecture, + ) causal: bool = FieldUpdate( default=False, desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Vision Transformer.", @@ -765,3 +852,11 @@ class VisionTransformerConfig(TransformerConfig): desc="Configuration for the rotary positional embeddings.", hint=FieldHint.feature, ) + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + + @property + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index dcea463a8..42393a413 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,14 +8,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -25,12 +18,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__() self._name = name - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs init_method_1 = init_normal_( std=config.init_method_std_mlp_1, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 870463df2..97c6c0f3f 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -7,19 +7,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.rotary import convert_rotary_complex_to_real -from fast_llm.layers.transformer.config import ( - RotaryConfig, - RotaryEmbeddingType, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import ( - VisionEncoderKwargs, - VisionTransformerDimNames, - VisionTransformerKwargs, -) +from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType, TransformerConfig, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -178,19 +167,8 @@ def __init__( config: RotaryConfig, tensor_space: TensorSpace, ): - # if isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs - # elif isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # TODO Soham: better way to do this? 
- if config.type == RotaryEmbeddingType.pixtral: - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - else: - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config assert self._config.enabled self._tensor_space = tensor_space @@ -273,12 +251,14 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + # if isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # elif isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -348,12 +328,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + # if isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # elif isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: """ diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 5590be322..8bd1394e1 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -9,15 +9,9 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - VisionTransformerConfig, -) +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from 
fast_llm.tensor import TensorMeta @@ -35,12 +29,8 @@ def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__() - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config: TransformerConfig = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout @@ -80,6 +70,14 @@ def _bias_dropout_add( def name(self) -> str: return f"{self._name} {self._layer_index}" + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 72bd95ddd..7f39f9cff 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -2,14 +2,20 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): + _name: str = "Vision transformer layer" + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + @property - def name(self) -> str: - return f"Vision transformer layer {self._layer_index}" + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index fdbe2726f..70504901b 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -16,39 +16,6 @@ class VisionEncoderDimNames: kv_channels = "vision_kv_channels" -class VisionTransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "vit_batch" - # TODO: Distinguish micro-sequence? 
- sequence_q = "vit_sequence_q" - sequence_q_tp = "vit_sequence_q_tp" - sequence_k = "vit_sequence_k" - hidden = "vit_hidden" - # Self-attention dimensions - head_groups = "vit_head_groups" - group_heads = "vit_group_heads" - key_and_value = "vit_key_value" - kv_channels = "vit_kv_channels" - composite_heads = "vit_composite_heads" - composite_query = "vit_composite_query" - composite_key_value = "vit_composite_key_value" - composite_dense = "vit_composite_dense" - # MLP dimensions - mlp = "vit_mlp" - gate_and_up = "vit_gate_and_up" - composite_gated_mlp = "vit_composite_gated_mlp" - experts = "vit_experts" - top_experts = "vit_top_experts" - shared_experts = "vit_shared_experts" - unshared_experts = "vit_unshared_experts" - composite_expert_mlp = "vit_composite_expert_mlp" - composite_gated_expert_mlp = "vit_composite_gated_expert_mlp" - composite_shared_expert_mlp = "vit_composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "vit_composite_gated_shared_expert_mlp" - - class VisionEncoderKwargs: patch_size = "patch_size" images = "images" @@ -69,31 +36,6 @@ class VisionEncoderKwargs: out_channels = "vit_out_channels" -# TODO Soham: do we need all of them? -class VisionTransformerKwargs: - rotary_freq_q = "vit_rotary_freq_q" - rotary_freq_k = "vit_rotary_freq_k" - attention_mask = "vit_attention_mask" - attention_mask_value = "vit_attention_mask_value" - sequence_lengths = "vit_sequence_lengths" - cu_seqlens_q = "vit_cu_seqlens_q" - cu_seqlens_k = "vit_cu_seqlens_k" - max_seqlen_q = "vit_max_seqlen_q" - max_seqlen_k = "vit_max_seqlen_k" - # TODO: Review these - presents = "vit_presents" - past_key_values = "vit_past_key_values" - sequence_first = "vit_sequence_first" - hidden_dims = "vit_hidden_dims" - sequence_q_dim = "vit_sequence_q_dim" - sequence_k_dim = "vit_sequence_k_dim" - sequence_length = "vit_sequence_length" - micro_batch_size = "vit_micro_batch_size" - # TODO: Move - grad_output = "vit_grad_output" - patch_position_ids = "patch_position_ids" - - @config_class() class ImageNormalizationConfig(Config): mean_r: float = Field( @@ -194,7 +136,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads ) ) - self.transformer.setup_tensor_space(tensor_space, type="vision") + self.transformer.setup_tensor_space(tensor_space) @property def enabled(self) -> bool: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9fff50bc7..c1d9df90f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -433,17 +433,18 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - if self._config.vision_encoder.enabled: - return self.layers[self._config.vision_encoder.transformer.num_layers + 2] - else: - return self.layers[0] + return self.layers[self.embedding_layer_index] @property def transformer_layers(self) -> list[TransformerLayer]: + return self.layers[self.embedding_layer_index + 1 : -1] + + @property + def embedding_layer_index(self) -> int: if self._config.vision_encoder.enabled: - return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + return self._config.vision_encoder.transformer.num_layers + 2 else: - return self.layers[1:-1] + return 0 @property def model_head(self) -> LanguageModelHead: @@ -458,11 +459,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - # TODO Soham: 
make embedding layer index a property - ( - self._config.vision_encoder.enabled * (self._config.vision_encoder.transformer.num_layers + 2), - *self.model_head_indices, - ), + (self.embedding_layer_index, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 51e0eee59..c5b7f07ae 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -336,3 +336,10 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) + + +def prefix_class_vars(cls, prefix: str, base_cls: type): + for attr, value in vars(base_cls).items(): + if not attr.startswith("__") and isinstance(value, str) and not hasattr(cls, attr): + setattr(cls, attr, prefix + value) + return cls From ca33ee83b22bea5c45a946a13209572b6aa73680 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 20:59:14 +0000 Subject: [PATCH 040/161] cleaner, extensible multimodal config --- fast_llm/layers/transformer/config.py | 44 +- fast_llm/layers/transformer/preprocessing.py | 12 - fast_llm/layers/transformer/transformer.py | 18 +- .../layers/transformer/vision_transformer.py | 14 +- fast_llm/layers/vision_encoder/config.py | 5 + .../layers/vision_encoder/preprocessing.py | 12 +- fast_llm/models/gpt/config.py | 30 + fast_llm/models/gpt/conversion.py | 984 ++++++++++++------ fast_llm/models/gpt/model.py | 3 +- 9 files changed, 740 insertions(+), 382 deletions(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 9a6bec07d..a634bc3c8 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -60,7 +60,7 @@ def __init_subclass__(cls, prefix="", **kwargs): super().__init_subclass__(**kwargs) cls._prefix = prefix for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): - setattr(cls, value, f"{cls._prefix}_{value}") + setattr(cls, attr, f"{cls._prefix}_{value}") class TransformerDimNames(BaseTransformerDimNames, prefix=""): @@ -737,67 +737,69 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - self.transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + self._transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - self.transformer_dim_names.group_heads, + self._transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(self.transformer_dim_names.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(self.transformer_dim_names.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(self._transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim( + kv_channels := TensorDim(self._transformer_dim_names.kv_channels, self.kv_channels) + ) tensor_space.add_tensor_dim( - 
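[Editor's note] The embedding_layer_index property introduced just above encodes the assumed layer ordering: with the vision encoder enabled, the word embeddings sit after the vision stack (read here as patch convolution, the vision transformer layers, and the adapter), hence num_layers + 2; without it they are layer 0. A small sketch of that arithmetic, a reading of the diff rather than an authoritative statement of get_vision_layers:

def embedding_layer_index(vision_enabled: bool, vision_num_layers: int) -> int:
    # vision enabled: 0 = patch conv, 1..N = vision transformer layers, N+1 = adapter,
    # N+2 = word embeddings; text-only: embeddings are the first layer.
    return vision_num_layers + 2 if vision_enabled else 0


print(embedding_layer_index(False, 0))   # 0
print(embedding_layer_index(True, 24))   # 26, so decoder layers start at index 27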
CompositeTensorDim(self.transformer_dim_names.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(self._transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - self.transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) + self._transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(self.transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(mlp := TensorDim(self._transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - gate_and_up := TensorDim(self.transformer_dim_names.gate_and_up, 2 if self.gated else 1) + gate_and_up := TensorDim(self._transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(experts := TensorDim(self.transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(experts := TensorDim(self._transformer_dim_names.experts, self.num_experts)) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_expert_mlp, (experts, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_expert_mlp, (experts, mlp)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(self.transformer_dim_names.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(self._transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - self.transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + self._transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 
97c6c0f3f..af1a53f68 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -251,12 +251,6 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - # if isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # elif isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs self._config = config @@ -328,12 +322,6 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - # if isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # elif isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 8bd1394e1..2c79883b3 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage @@ -70,14 +70,6 @@ def _bias_dropout_add( def name(self) -> str: return f"{self._name} {self._layer_index}" - @property - def _transformer_kwargs(self) -> TransformerKwargs: - return TransformerKwargs - - @property - def _transformer_dim_names(self) -> TransformerDimNames: - return TransformerDimNames - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: @@ -157,3 +149,11 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + # @property + # def _transformer_kwargs(self) -> TransformerKwargs: + # return TransformerKwargs + + # @property + # def _transformer_dim_names(self) -> TransformerDimNames: + # return TransformerDimNames diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 7f39f9cff..c2cfe9f23 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -1,21 +1,21 @@ import torch from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.layers.transformer.config import VisionTransformerKwargs from fast_llm.layers.transformer.transformer import TransformerLayer -from 
fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): _name: str = "Vision transformer layer" - @property - def _transformer_kwargs(self) -> VisionTransformerKwargs: - return VisionTransformerKwargs + # @property + # def _transformer_kwargs(self) -> VisionTransformerKwargs: + # return VisionTransformerKwargs - @property - def _transformer_dim_names(self) -> VisionTransformerDimNames: - return VisionTransformerDimNames + # @property + # def _transformer_dim_names(self) -> VisionTransformerDimNames: + # return VisionTransformerDimNames def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 70504901b..6932c8fc0 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -119,6 +119,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", hint=FieldHint.core, ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter linear layer.", + hint=FieldHint.optional, + ) image_normalization: ImageNormalizationConfig = Field( default_factory=ImageNormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7ebfb5228..5009123f0 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -6,14 +6,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.config import ( - VisionEncoderConfig, - VisionEncoderDimNames, - VisionEncoderKwargs, - VisionTransformerDimNames, - VisionTransformerKwargs, -) +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -152,7 +146,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 - sequence_first = kwargs.get(TransformerKwargs.sequence_first) + kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): # sum( # get_num_patches(*size, patch_size) for size in sizes diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 162015768..d7d32221d 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -51,12 +51,22 @@ class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" + class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" trust_remote_code: typing.ClassVar[bool] = True + class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llava" + # Using default values for vision and text models. 
Can be overridden in the config + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "mistral" + + +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + @config_class() class GPTBatchConfig(BatchConfig): @@ -140,6 +150,7 @@ class GPTModelConfig(FastLLMModelConfig): MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, LlavaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, ) @classmethod @@ -154,6 +165,25 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: return HuggingfaceGPTModelForCausalLM + @classmethod + def get_checkpoint_format(cls, format: type[CheckpointFormat]) -> type[CheckpointFormat]: + if isinstance(format, type) and issubclass(format, CheckpointFormat): + format_ = cls.get_checkpoint_format(format.name) + Assert.is_(format, format_) + return format_ + elif isinstance(format, dict): + for format_ in cls.checkpoint_formats: + if format_.name == format["name"]: + if (vision_name := format.get("vision_name")) is not None: + format_.vision_name = vision_name + if (text_name := format.get("text_name")) is not None: + format_.text_name = text_name + return format_ + for format_ in cls.checkpoint_formats: + if format_.name == format: + return format_ + raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad4df7378..0b0796ed2 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -6,8 +6,10 @@ import torch from transformers.configuration_utils import PretrainedConfig -from fast_llm.config import DEFAULT, MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm import __version__ +from fast_llm.config import DEFAULT, MISSING, get_nested_dict_value, set_nested_dict_value +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.external import ( AutoStateDictCheckpointHandler, ConstantExportParamConverter, @@ -22,7 +24,7 @@ WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType @@ -36,6 +38,7 @@ MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -112,73 +115,70 @@ def import_weight( return (merged_weight.t().contiguous(),) -class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: GPTModel - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - """ - Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) - """ - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return 
super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), - export_names=(("num_key_value_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=(("intermediate_size",),), - ), - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_word_embeddings",),), - ), - ] - - @abc.abstractmethod - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - pass +class TransformerWeightConverterMixin: - def _create_weight_converters( + def _get_weight_and_bias_converters( self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, ) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters - # Embeddings - converters.append(WeightConverter(f"layers.0.word_embeddings_weight", f"model.embed_tokens.weight")) + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: + num_layers = self._model.config.base_model.transformer.num_layers + prediction_heads = self._model.config.base_model.prediction_heads + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] - converters += self._create_lm_head_converters() + # Next-token prediction head + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias + ) + # Output weights + if self._model.config.base_model.tie_word_embeddings: + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) + else: + converters.append( + WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") + ) - for i in range(num_layers): - converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") + # MTP-heads > 0 are thrown away + # 
TODO Soham: handle offset with MTP + for i in range(1, prediction_heads): + logger.warning( + f"The model weights for the multi-token prediction head {i} are discarded during conversion." + ) + mtp_transformer_layer_index = num_layers - 1 + 2 * i + # MTP transformer layer + converters += self._create_transformer_layer_converters( + f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True + ) + # MTP output norm + converters += self._get_weight_and_bias_converters( + f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter + ) return converters @@ -249,71 +249,81 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self, hf_base_prefix: str = "", fast_llm_offset: int = 1) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - # Next-token prediction head - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + fast_llm_offset}.final_norm", f"{hf_base_prefix}model.norm", norm_bias - ) - # Output weights - if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) - else: - converters.append( - WeightConverter( - f"layers.{num_layers + fast_llm_offset}.output_weights", f"{hf_base_prefix}lm_head.weight" - ) - ) +class CommonHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + """ + Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) + """ - # MTP-heads > 0 are thrown away - # TODO Soham: handle offset with MTP - for i in range(1, prediction_heads): - logger.warning( - f"The model weights for the multi-token prediction head {i} are discarded during conversion." 
- ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter - ) + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_attention_heads"),), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "head_groups"),), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ] - return converters + @abc.abstractmethod + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + pass - def _get_weight_and_bias_converters( + def _create_weight_converters( self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, + hf_base_prefix: str = "", + offset: int = 0, ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + + # Embeddings + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) + + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) + + for i in range(num_layers): + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" ) + return converters @@ -555,266 +565,592 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat +class 
PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - # lm_converters = super()._create_config_converters() - lm_converters = super()._create_config_converters() - for idx, converter in enumerate(lm_converters): - if converter.export_names == (("model_type",),): - continue - elif converter.export_names == (("architectures",),): - ignore_index = idx - if converter.export_names: - converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - - return ( - lm_converters[:ignore_index] - + lm_converters[ignore_index + 1 :] - + [ - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral - ), - ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ), - # Vision Adapter - RenameParamConverter( - fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=(("text_config", "hidden_size"),), - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "patch_norm", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - # Vision Transformer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), - export_names=( - ( - "vision_config", - "num_hidden_layers", - ), + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), - export_names=( - ( - "vision_config", - "hidden_size", - ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), - export_names=( - ( - "vision_config", - "num_attention_heads", - ), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), - export_names=( - ( - "vision_config", - "num_key_value_heads", - ), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "head_groups", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), - export_names=( - ( - "vision_config", - "intermediate_size", - ), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", ), ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), - export_names=( - ( - "vision_config", - 
"hidden_act", - ), + export_names=(("intermediate_size",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "kv_channels", ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - export_names=(("projector_hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), - export_names=(("vision_config", "rope_theta"),), + export_names=(("head_dim",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "rotary", + "theta", + ), ), - ] - ) + export_names=(("rope_theta",),), + ), + RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] def _create_vision_transformer_layer_converters( - self, - i: int, - ignore_export: bool = False, - hf_base_prefix: str = "", - fast_llm_offset: int = 1, - type: str | None = None, + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" ) -> list[WeightConverter]: - if type is not None: - if type == "vision": - transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer - else: - transformer_config: TransformerConfig = self._model.config.base_model.transformer - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - names_bias_cls = [ + # Vision transformer layer + transformer_config = self._model.config.base_model.vision_encoder.transformer + norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + name_bias_cls = [ # Self-attn ( - f"layers.{i+fast_llm_offset}.self_attn.query", - f"vision_tower.transformer.layers.{i}.attention.q_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.key_value", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", ( - f"vision_tower.transformer.layers.{i}.attention.k_proj", - f"vision_tower.transformer.layers.{i}.attention.v_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.k_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.v_proj", ), 
transformer_config.add_attn_qkv_bias, KeyValueWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.dense", - f"vision_tower.transformer.layers.{i}.attention.o_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.o_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+fast_llm_offset}.norm_1", - f"vision_tower.transformer.layers.{i}.attention_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention_norm", norm_bias, WeightConverter, ), ( - f"layers.{i+fast_llm_offset}.norm_2", - f"vision_tower.transformer.layers.{i}.ffn_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.ffn_norm", norm_bias, WeightConverter, ), ] - for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: converters += self._get_weight_and_bias_converters( fast_llm_prefix, - () if ignore_export else hf_prefix, + hf_prefix, use_bias, - cls=IgnoreExportWeightConverter if ignore_export else cls, + cls, ) - # MLP - if ignore_export: - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_1", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_2", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] - else: - converters += self._get_vision_transformer_mlp_converters( - f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" - ) + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}", + ) return converters - def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - f"{hf_prefix}.feed_forward.down_proj.weight", - self._model.config.base_model, - ), - ] + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) + if self._model.config.base_model.vision_encoder.conv_bias: + converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) + converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) - def _create_vision_transformer_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - vision_transformer_converters = [] - for layer in range(num_layers): - # TODO Soham: check if args are correct - 
vision_transformer_converters.extend( - self._create_vision_transformer_layer_converters( - layer, - ignore_export=False, - hf_base_prefix="vision_tower.transformer.layers.", - fast_llm_offset=1, - type="vision", - ) + for i in range(num_layers): + converters += self._create_vision_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.weight", "multi_modal_projector.linear_1.weight" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.weight", "multi_modal_projector.linear_2.weight" + ), + ] + ) + if self._model.config.base_model.vision_encoder.adapter_bias: + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.bias", "multi_modal_projector.linear_1.bias" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.bias", "multi_modal_projector.linear_2.bias" + ), + ] ) - return vision_transformer_converters + return converters - def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] - if self._model.config.base_model.vision_encoder.conv_bias: - patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) - layernorm_converters = [ - WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), - ] - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: - layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - - vision_transformer_converters = self._create_vision_transformer_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 - adapter_converters = [ - WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), - WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), - # TODO Soham: add bias based on config - WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), - WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), - ] + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 - return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - def _create_weight_converters(self) -> list[WeightConverter]: - vision_encoder_converter = self._create_vision_encoder_weight_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 - # Embeddings - lm_converters = [ - WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") - ] - for i in range(self._model.config.base_model.transformer.num_layers): - lm_converters += self._create_transformer_layer_converters( - fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" +class LlavaHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if 
"text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "vision_config" in cfg_dict: + vision_kwargs = cls._import_config(cfg_dict["vision_config"]) + vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} + kwargs.update(vision_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "vision_config")} ) - lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) - return vision_encoder_converter + lm_converters + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + exported_config = {} + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for 
export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.vision_name) + vision_handler = vision_handler_cls(self._model) + converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) + ) + return converters + + +# class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): +# format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + +# @classmethod +# def _create_config_converters(cls) -> list[ParamConverter]: +# # lm_converters = super()._create_config_converters() +# lm_converters = super()._create_config_converters() +# for idx, converter in enumerate(lm_converters): +# if converter.export_names == (("model_type",),): +# continue +# elif converter.export_names == (("architectures",),): +# ignore_index = idx +# if converter.export_names: +# converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) + +# return ( +# lm_converters[:ignore_index] +# + lm_converters[ignore_index + 1 :] +# + [ +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral +# ), +# ConstantExportParamConverter( +# export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] +# ), +# # Vision Adapter +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "adapter_size"),), +# export_names=(("text_config", "hidden_size"),), +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "patch_norm", "type"),), +# fast_llm_value=NormalizationType.rms_norm, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), +# fast_llm_value=NormalizationType.rms_norm, +# ), +# # Vision Transformer +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), +# export_names=( +# ( +# "vision_config", +# "num_hidden_layers", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), +# export_names=( +# ( +# "vision_config", +# "hidden_size", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), +# export_names=( +# ( +# "vision_config", +# "num_attention_heads", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", 
"transformer", "head_groups"),), +# export_names=( +# ( +# "vision_config", +# "num_key_value_heads", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), +# export_names=( +# ( +# "vision_config", +# "intermediate_size", +# ), +# ), +# ), +# MappedConfigParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), +# export_names=( +# ( +# "vision_config", +# "hidden_act", +# ), +# ), +# fast_llm_value=ActivationType.from_hf_name, +# export_value=lambda activation_type: activation_type.hf_name, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True +# ), +# MappedConfigParamConverter( +# fast_llm_names=(("vision_encoder", "adapter_activation_type"),), +# export_names=(("projector_hidden_act",),), +# fast_llm_value=ActivationType.from_hf_name, +# export_value=lambda activation_type: activation_type.hf_name, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), +# export_names=(("vision_config", "rope_theta"),), +# ), +# ] +# ) + +# def _create_vision_transformer_layer_converters( +# self, +# i: int, +# ignore_export: bool = False, +# hf_base_prefix: str = "", +# fast_llm_offset: int = 1, +# type: str | None = None, +# ) -> list[WeightConverter]: +# if type is not None: +# if type == "vision": +# transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer +# else: +# transformer_config: TransformerConfig = self._model.config.base_model.transformer +# norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm +# converters = [] +# names_bias_cls = [ +# # Self-attn +# ( +# f"layers.{i+fast_llm_offset}.self_attn.query", +# f"vision_tower.transformer.layers.{i}.attention.q_proj", +# transformer_config.add_attn_qkv_bias, +# QueryWeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.self_attn.key_value", +# ( +# f"vision_tower.transformer.layers.{i}.attention.k_proj", +# f"vision_tower.transformer.layers.{i}.attention.v_proj", +# ), +# transformer_config.add_attn_qkv_bias, +# KeyValueWeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.self_attn.dense", +# f"vision_tower.transformer.layers.{i}.attention.o_proj", +# transformer_config.add_attn_dense_bias, +# WeightConverter, +# ), +# # Norm +# ( +# f"layers.{i+fast_llm_offset}.norm_1", +# f"vision_tower.transformer.layers.{i}.attention_norm", +# norm_bias, +# WeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.norm_2", +# f"vision_tower.transformer.layers.{i}.ffn_norm", +# norm_bias, +# WeightConverter, +# ), +# ] +# for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: +# converters += self._get_weight_and_bias_converters( +# fast_llm_prefix, +# () if ignore_export else hf_prefix, +# use_bias, +# cls=IgnoreExportWeightConverter if ignore_export else cls, +# ) + +# # MLP +# if ignore_export: +# converters += self._get_weight_and_bias_converters( +# f"layers.{i+fast_llm_offset}.mlp.layer_1", +# (), +# transformer_config.add_mlp_bias, +# cls=IgnoreExportWeightConverter, +# ) +# converters += self._get_weight_and_bias_converters( +# f"layers.{i+fast_llm_offset}.mlp.layer_2", +# (), +# transformer_config.add_mlp_bias, +# cls=IgnoreExportWeightConverter, +# ) +# converters += 
[IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] +# else: +# converters += self._get_vision_transformer_mlp_converters( +# f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" +# ) +# return converters + +# def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: +# return [ +# SplitWeightConverter( +# f"{fast_llm_prefix}.mlp.layer_1.weight", +# (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), +# ), +# MLPLayer2Converter( +# f"{fast_llm_prefix}.mlp.layer_2.weight", +# f"{hf_prefix}.feed_forward.down_proj.weight", +# self._model.config.base_model, +# ), +# ] + +# def _create_vision_transformer_converters(self) -> list[WeightConverter]: +# num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers +# vision_transformer_converters = [] +# for layer in range(num_layers): +# # TODO Soham: check if args are correct +# vision_transformer_converters.extend( +# self._create_vision_transformer_layer_converters( +# layer, +# ignore_export=False, +# hf_base_prefix="vision_tower.transformer.layers.", +# fast_llm_offset=1, +# type="vision", +# ) +# ) + +# return vision_transformer_converters + +# def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: +# patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] +# if self._model.config.base_model.vision_encoder.conv_bias: +# patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) +# layernorm_converters = [ +# WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), +# ] +# if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: +# layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + +# vision_transformer_converters = self._create_vision_transformer_converters() +# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 +# adapter_converters = [ +# WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), +# WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), +# # TODO Soham: add bias based on config +# WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), +# WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), +# ] + +# return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + +# def _create_weight_converters(self) -> list[WeightConverter]: +# vision_encoder_converter = self._create_vision_encoder_weight_converters() +# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 +# # Embeddings +# lm_converters = [ +# WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") +# ] +# for i in range(self._model.config.base_model.transformer.num_layers): +# lm_converters += self._create_transformer_layer_converters( +# fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" +# ) +# lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) +# return vision_encoder_converter + lm_converters class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): @@ -950,4 +1286,6 @@ class 
AutoGPTHuggingfaceCheckpointHandler( MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, + # MultiModalGPTHuggingfaceCheckpointFormat.name: MultiModalHuggingfaceCheckpointHandler } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c1d9df90f..72ff1b887 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -20,6 +20,7 @@ TransformerDimNames, TransformerKwargs, TransformerLossNames, + VisionTransformerDimNames, ) from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, @@ -29,7 +30,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig From f3a4a74a086f5cb81da86195a00d6549cf66844b Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 21:00:11 +0000 Subject: [PATCH 041/161] cleanup --- fast_llm/models/gpt/conversion.py | 262 ------------------------------ 1 file changed, 262 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 0b0796ed2..356525471 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -891,268 +891,6 @@ def _create_weight_converters(self): return converters -# class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): -# format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat - -# @classmethod -# def _create_config_converters(cls) -> list[ParamConverter]: -# # lm_converters = super()._create_config_converters() -# lm_converters = super()._create_config_converters() -# for idx, converter in enumerate(lm_converters): -# if converter.export_names == (("model_type",),): -# continue -# elif converter.export_names == (("architectures",),): -# ignore_index = idx -# if converter.export_names: -# converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - -# return ( -# lm_converters[:ignore_index] -# + lm_converters[ignore_index + 1 :] -# + [ -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral -# ), -# ConstantExportParamConverter( -# export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] -# ), -# # Vision Adapter -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "adapter_size"),), -# export_names=(("text_config", "hidden_size"),), -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "patch_norm", "type"),), -# fast_llm_value=NormalizationType.rms_norm, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), -# fast_llm_value=NormalizationType.rms_norm, -# ), -# # Vision Transformer -# 
RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), -# export_names=( -# ( -# "vision_config", -# "num_hidden_layers", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), -# export_names=( -# ( -# "vision_config", -# "hidden_size", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), -# export_names=( -# ( -# "vision_config", -# "num_attention_heads", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), -# export_names=( -# ( -# "vision_config", -# "num_key_value_heads", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), -# export_names=( -# ( -# "vision_config", -# "intermediate_size", -# ), -# ), -# ), -# MappedConfigParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), -# export_names=( -# ( -# "vision_config", -# "hidden_act", -# ), -# ), -# fast_llm_value=ActivationType.from_hf_name, -# export_value=lambda activation_type: activation_type.hf_name, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True -# ), -# MappedConfigParamConverter( -# fast_llm_names=(("vision_encoder", "adapter_activation_type"),), -# export_names=(("projector_hidden_act",),), -# fast_llm_value=ActivationType.from_hf_name, -# export_value=lambda activation_type: activation_type.hf_name, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), -# export_names=(("vision_config", "rope_theta"),), -# ), -# ] -# ) - -# def _create_vision_transformer_layer_converters( -# self, -# i: int, -# ignore_export: bool = False, -# hf_base_prefix: str = "", -# fast_llm_offset: int = 1, -# type: str | None = None, -# ) -> list[WeightConverter]: -# if type is not None: -# if type == "vision": -# transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer -# else: -# transformer_config: TransformerConfig = self._model.config.base_model.transformer -# norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm -# converters = [] -# names_bias_cls = [ -# # Self-attn -# ( -# f"layers.{i+fast_llm_offset}.self_attn.query", -# f"vision_tower.transformer.layers.{i}.attention.q_proj", -# transformer_config.add_attn_qkv_bias, -# QueryWeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.self_attn.key_value", -# ( -# f"vision_tower.transformer.layers.{i}.attention.k_proj", -# f"vision_tower.transformer.layers.{i}.attention.v_proj", -# ), -# transformer_config.add_attn_qkv_bias, -# KeyValueWeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.self_attn.dense", -# f"vision_tower.transformer.layers.{i}.attention.o_proj", -# transformer_config.add_attn_dense_bias, -# WeightConverter, -# ), -# # Norm -# ( -# f"layers.{i+fast_llm_offset}.norm_1", -# f"vision_tower.transformer.layers.{i}.attention_norm", -# norm_bias, -# WeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.norm_2", -# f"vision_tower.transformer.layers.{i}.ffn_norm", -# norm_bias, -# WeightConverter, -# ), -# ] -# for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: -# converters += 
self._get_weight_and_bias_converters( -# fast_llm_prefix, -# () if ignore_export else hf_prefix, -# use_bias, -# cls=IgnoreExportWeightConverter if ignore_export else cls, -# ) - -# # MLP -# if ignore_export: -# converters += self._get_weight_and_bias_converters( -# f"layers.{i+fast_llm_offset}.mlp.layer_1", -# (), -# transformer_config.add_mlp_bias, -# cls=IgnoreExportWeightConverter, -# ) -# converters += self._get_weight_and_bias_converters( -# f"layers.{i+fast_llm_offset}.mlp.layer_2", -# (), -# transformer_config.add_mlp_bias, -# cls=IgnoreExportWeightConverter, -# ) -# converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] -# else: -# converters += self._get_vision_transformer_mlp_converters( -# f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" -# ) -# return converters - -# def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: -# return [ -# SplitWeightConverter( -# f"{fast_llm_prefix}.mlp.layer_1.weight", -# (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), -# ), -# MLPLayer2Converter( -# f"{fast_llm_prefix}.mlp.layer_2.weight", -# f"{hf_prefix}.feed_forward.down_proj.weight", -# self._model.config.base_model, -# ), -# ] - -# def _create_vision_transformer_converters(self) -> list[WeightConverter]: -# num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers -# vision_transformer_converters = [] -# for layer in range(num_layers): -# # TODO Soham: check if args are correct -# vision_transformer_converters.extend( -# self._create_vision_transformer_layer_converters( -# layer, -# ignore_export=False, -# hf_base_prefix="vision_tower.transformer.layers.", -# fast_llm_offset=1, -# type="vision", -# ) -# ) - -# return vision_transformer_converters - -# def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: -# patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] -# if self._model.config.base_model.vision_encoder.conv_bias: -# patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) -# layernorm_converters = [ -# WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), -# ] -# if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: -# layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - -# vision_transformer_converters = self._create_vision_transformer_converters() -# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 -# adapter_converters = [ -# WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), -# WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), -# # TODO Soham: add bias based on config -# WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), -# WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), -# ] - -# return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - -# def _create_weight_converters(self) -> list[WeightConverter]: -# vision_encoder_converter = self._create_vision_encoder_weight_converters() -# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 -# # Embeddings -# lm_converters = [ -# WeightConverter(f"layers.{offset - 
1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") -# ] -# for i in range(self._model.config.base_model.transformer.num_layers): -# lm_converters += self._create_transformer_layer_converters( -# fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" -# ) -# lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) -# return vision_encoder_converter + lm_converters - - class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat From 3b955b1600ba09c5b7844113b6fc55ee3916f261 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 22:23:07 +0000 Subject: [PATCH 042/161] fixes for pixtral --- fast_llm/models/gpt/conversion.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 356525471..b7f9f7733 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -572,6 +572,12 @@ class PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, Huggi @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter( + fast_llm_names=(("patch_norm", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), @@ -646,6 +652,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: export_names=(("rope_theta",),), ), RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), ] def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: @@ -803,6 +811,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("projector_intermediate_size",),), + ), ] @classmethod From 49daf581600175c884265c00df4aaf04a9dc0f74 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 23:27:52 +0000 Subject: [PATCH 043/161] model fixes --- fast_llm/layers/multi_modal/embedding.py | 2 -- fast_llm/layers/transformer/config.py | 11 +---------- fast_llm/layers/vision_encoder/encoder.py | 4 ++-- fast_llm/models/gpt/model.py | 5 +++-- 4 files changed, 6 insertions(+), 16 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 9a035d8fd..c67f82b41 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -55,7 +55,6 @@ def _forward( embeddings = reduce_forward(embeddings, group) if self._use_absolute_position_embeddings: embeddings = embeddings + 
torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) - # TODO Soham: avoid cloning? embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): @@ -82,7 +81,6 @@ def _forward( # for positions in image_positions: # if positions > self._distributed_config.tensor_rank embeddings = torch.embedding(self.word_embeddings_weight, tokens) - # TODO Soham: avoid cloning? embeddings = embeddings.clone() for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a634bc3c8..49babb06b 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -99,7 +99,7 @@ def __init_subclass__(cls, prefix="", **kwargs): super().__init_subclass__(**kwargs) cls._prefix = prefix for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): - setattr(cls, value, f"{cls._prefix}_{value}") + setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value) class TransformerKwargs(BaseTransformerKwargs, prefix=""): @@ -824,15 +824,6 @@ def _transformer_dim_names(self) -> TransformerDimNames: return TransformerDimNames -@config_class() -class VisionRotaryConfig(RotaryConfig): - type: RotaryEmbeddingType = Field( - default=RotaryEmbeddingType.pixtral, - desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", - hint=FieldHint.feature, - ) - - @config_class() class VisionTransformerConfig(TransformerConfig): """ diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 1df7f889c..20749af48 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -5,7 +5,7 @@ from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -69,7 +69,7 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, ) -> torch.Tensor: - hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] + hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 72ff1b887..cbce66f2e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,6 +21,7 @@ TransformerKwargs, TransformerLossNames, VisionTransformerDimNames, + VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, @@ -30,7 +31,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.config import 
VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -244,7 +245,7 @@ def preprocess_meta( ) vision_kwargs.update( { - VisionEncoderKwargs.hidden_dims: vision_hidden_dims, + VisionTransformerKwargs.hidden_dims: vision_hidden_dims, } ) From b5ed9f4f6fdd6205225f730a136edb2f211c9f95 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 22 May 2025 19:07:00 +0000 Subject: [PATCH 044/161] more cleanup --- fast_llm/data/data/gpt/data.py | 4 +- fast_llm/data/preparator/gpt_memmap/config.py | 8 +- .../data/preparator/gpt_memmap/prepare.py | 13 +- fast_llm/data/tokenizer.py | 2 +- fast_llm/engine/schedule/config.py | 5 - fast_llm/functional/config.py | 8 +- fast_llm/layers/multi_modal/embedding.py | 1 - fast_llm/layers/transformer/attention.py | 18 +- fast_llm/layers/transformer/config.py | 23 --- fast_llm/layers/transformer/transformer.py | 8 - .../layers/transformer/vision_transformer.py | 8 - fast_llm/layers/vision_encoder/adapter.py | 1 - fast_llm/layers/vision_encoder/config.py | 21 ++- .../{encoder.py => patch_conv.py} | 6 +- .../layers/vision_encoder/preprocessing.py | 5 - fast_llm/models/gpt/conversion.py | 161 +++++++++--------- fast_llm/models/gpt/model.py | 11 +- fast_llm/models/gpt/trainer.py | 9 +- fast_llm/tools/cli.py | 1 + fast_llm/utils.py | 7 - 20 files changed, 129 insertions(+), 191 deletions(-) rename fast_llm/layers/vision_encoder/{encoder.py => patch_conv.py} (95%) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 4fcd42ae1..31a19e148 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -51,13 +51,13 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_images.append([torch.from_numpy(image) for image in sample.images]) has_images = True else: - batch_images.append(None) + batch_images.append([]) batch_image_positions = [] for sample in batch: if sample.image_positions is not None: batch_image_positions.append(torch.from_numpy(sample.image_positions)) else: - batch_image_positions.append(None) + batch_image_positions.append([]) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 53f8e4688..2e9243807 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -151,12 +151,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) - tokenize_batch_size: int = Field( - default=1000, - desc="Batch size for tokenization.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), - ) saving_workers: int = Field( default=1, desc="Number of processes for saving the data.", @@ -170,7 +164,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): ) tokenizer: TokenizerConfig = Field( default_factory=TokenizerConfig, - desc="Tokenizer configuration.", + desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) image_patch_size: int = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index c5a1b339c..fa46ee92e 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ 
b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -138,20 +138,9 @@ def _document_generator(): if self._config.dataset.loss_masking_spans else None ), - # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, - # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, ) - # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: - # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - # yield GPTSample( - # np.array(item["input_ids"], dtype=self._data_type.numpy), - # np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - # ) - # else: - # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - # yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -279,7 +268,7 @@ def run(self) -> None: if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") tokenize_fn = self._tokenize_batch - # Avoid decoding bytes to images unless asked + # decoding bytes to images is slow and should be done only when needed if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 0acb65e47..1cbc1ec56 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -44,7 +44,7 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: """ - Tokenize the input text and return the tokenized input_ids along with token spans. + Tokenize the input text and return the tokenized input_ids and if provided, token spans and image positions. 
""" if not image_positions: image_positions = [] diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 48daf0e69..204abdf1c 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -50,11 +50,6 @@ class BatchConfig(Config): hint=FieldHint.setup, ) # Image inputs - patch_size: int | None = Field( - default=None, - desc="Patch size for each image token", - hint=FieldHint.optional, - ) image_size: int | None = Field( default=None, desc="Maximum image height and width", diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 233ea339d..480fa067e 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -40,6 +40,7 @@ class ActivationType(enum.StrEnum): """ gelu = "gelu" + gelu_pytorch_tanh = "gelu_pytorch_tanh" silu = "silu" relu = "relu" squared_relu = "squared_relu" @@ -67,7 +68,8 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP _ACTIVATION_FN_MAP = { - ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.gelu: torch.nn.functional.gelu, + ActivationType.gelu_pytorch_tanh: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), @@ -78,14 +80,14 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu_pytorch_tanh", + ActivationType.gelu: "gelu", + ActivationType.gelu_pytorch_tanh: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", ActivationType.identity: "identity", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} -_ACTIVATION_HF_NAMES_INV["gelu"] = ActivationType.gelu MAX_DROPLESS_BLOCK_SIZE_ROW = 128 diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index c67f82b41..8c541e983 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -114,7 +114,6 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) - # image_embeddings = kwargs.pop(VisionEncoderKwargs.patch_embeddings) position_ids = kwargs.get(LanguageModelKwargs.position_ids) image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) image_positions = kwargs.get(VisionEncoderKwargs.image_positions) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3180b6cb8..e88f64a30 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -191,7 +191,7 @@ def _get_meta( ) @property - def query_dims(self): + def _query_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -200,7 +200,7 @@ def query_dims(self): ) @property - def kv_dims(self): + def _kv_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -209,7 +209,7 @@ def kv_dims(self): ) @property - def context_dims(self): + def _context_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -346,11 +346,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._config.rotary.enabled: if self._debug_transformer: - 
self._debug_log(query, "query_rotary_input", self.query_dims, kwargs) + self._debug_log(query, "query_rotary_input", self._query_dims, kwargs) self._debug_log( key, "key_rotary_input", - self.kv_dims, + self._kv_dims, kwargs, ) rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings @@ -402,20 +402,20 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ) if self._debug_transformer: - self._debug_log(query, "query", self.query_dims, kwargs) + self._debug_log(query, "query", self._query_dims, kwargs) self._debug_log( key, "key", - self.kv_dims, + self._kv_dims, kwargs, ) self._debug_log( value, "value", - self.kv_dims, + self._kv_dims, kwargs, ) - self._debug_log(input_, "context", self.context_dims, kwargs) + self._debug_log(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 49babb06b..b8d153672 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -110,29 +110,6 @@ class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): patch_position_ids = "patch_position_ids" -# class TransformerKwargs: -# rotary_freq_q = "rotary_freq_q" -# rotary_freq_k = "rotary_freq_k" -# attention_mask = "attention_mask" -# attention_mask_value = "attention_mask_value" -# sequence_lengths = "sequence_lengths" -# cu_seqlens_q = "cu_seqlens_q" -# cu_seqlens_k = "cu_seqlens_k" -# max_seqlen_q = "max_seqlen_q" -# max_seqlen_k = "max_seqlen_k" -# # TODO: Review these -# presents = "presents" -# past_key_values = "past_key_values" -# sequence_first = "sequence_first" -# hidden_dims = "hidden_dims" -# sequence_q_dim = "sequence_q_dim" -# sequence_k_dim = "sequence_k_dim" -# sequence_length = "sequence_length" -# micro_batch_size = "micro_batch_size" -# # TODO: Move -# grad_output = "grad_output" - - class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 2c79883b3..392ebb889 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -149,11 +149,3 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) - - # @property - # def _transformer_kwargs(self) -> TransformerKwargs: - # return TransformerKwargs - - # @property - # def _transformer_dim_names(self) -> TransformerDimNames: - # return TransformerDimNames diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index c2cfe9f23..7c1be0d16 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -9,14 +9,6 @@ class VisionTransformerLayer(TransformerLayer): _name: str = "Vision transformer layer" - # @property - # def _transformer_kwargs(self) -> VisionTransformerKwargs: - # return VisionTransformerKwargs - - # @property - # def _transformer_dim_names(self) -> VisionTransformerDimNames: - # return VisionTransformerDimNames - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] if self._return_input: diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 
bf5f3f1aa..41ea065d0 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -20,7 +20,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) self._activation_type = config.adapter_activation_type - # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? self.layer_1 = Linear( input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 6932c8fc0..f788b5149 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,11 +1,12 @@ import enum -from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig from fast_llm.layers.transformer.config import VisionTransformerConfig +from fast_llm.utils import Assert class VisionEncoderDimNames: @@ -129,18 +130,24 @@ class VisionEncoderConfig(BaseModelConfig): desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) - # TODO Soham: add a check for presence of kv channels parameter (head_dim) - tensor_space.add_tensor_dim( - TensorDim( - VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads - ) - ) self.transformer.setup_tensor_space(tensor_space) @property diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/patch_conv.py similarity index 95% rename from fast_llm/layers/vision_encoder/encoder.py rename to fast_llm/layers/vision_encoder/patch_conv.py index 20749af48..68f22200a 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -43,6 +43,7 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = tensor_space.distributed_config self._sequence_parallel = self._distributed_config.sequence_tensor_parallel + self._lr_scale = config.adapter_lr_scale # TODO Soham: lr_scale self.weight = ParameterMeta.from_dims( ( @@ -52,10 +53,13 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): 
self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), ), init_method=init_normal_(), + lr_scale=self._lr_scale, ) if config.conv_bias: self.bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),) + (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_sclae=self._lr_scale, ) else: self.bias = None diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 5009123f0..d85442a3e 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -103,7 +103,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._distributed_config = self._tensor_space.distributed_config def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - # kwargs[VisionEncoderDimNames] kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( ( TensorDim( @@ -141,16 +140,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] - # position_ids = position_ids_in_meshgrid(image_sizes, im_height, patch_size) patches = [] patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): - # sum( - # get_num_patches(*size, patch_size) for size in sizes - # ) seq_patches = [] sample_cu_seqlen = 0 for image, size in zip(imgs, sizes): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index b7f9f7733..95bbebde2 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -115,8 +115,7 @@ def import_weight( return (merged_weight.t().contiguous(),) -class TransformerWeightConverterMixin: - +class WeightAndBiasConverterMixin: def _get_weight_and_bias_converters( self, fast_llm_prefix: str | tuple[str, ...], @@ -145,6 +144,83 @@ def _get_weight_and_bias_converters( ) return converters + +class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + """ + Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) + """ + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_attention_heads"),), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "head_groups"),), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + 
RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ] + + @abc.abstractmethod + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + pass + + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + + # Embeddings + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) + + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) + + for i in range(num_layers): + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" + ) + + return converters + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads @@ -250,83 +326,6 @@ def _create_transformer_layer_converters( return converters -class CommonHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): - _model: GPTModel - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - """ - Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) - """ - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), - export_names=(("num_key_value_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=(("intermediate_size",),), - ), - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_word_embeddings",),), - ), - ] - - @abc.abstractmethod - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - pass - - def _create_weight_converters( - self, - hf_base_prefix: str = "", - offset: int = 0, - ) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - - # Embeddings - converters.append( - WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") - ) - - converters += 
self._create_lm_head_converters(hf_base_prefix, offset=offset) - - for i in range(num_layers): - converters += self._create_transformer_layer_converters( - f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" - ) - - return converters - - class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat @@ -565,7 +564,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): +class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -770,7 +769,7 @@ def num_layers(self) -> int: return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 -class LlavaHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): +class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index cbce66f2e..586b511ba 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -32,7 +32,7 @@ from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs -from fast_llm.layers.vision_encoder.encoder import PatchConv +from fast_llm.layers.vision_encoder.patch_conv import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -84,11 +84,6 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) - # self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) - # if self._config.vision_encoder.transformer.rotary.enabled: - # self._vision_rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - # self._config.vision_encoder.transformer.rotary, self._tensor_space - # ) def get_output_layers(self) -> list[Layer]: layers = [] @@ -178,14 +173,14 @@ def preprocess_meta( ] image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { - VisionEncoderKwargs.patch_size: batch_meta.patch_size, + VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, VisionEncoderKwargs.image_size: image_size, VisionEncoderKwargs.image_mean: image_mean, VisionEncoderKwargs.image_std: image_std, VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( - VisionEncoderDimNames.kv_channels + VisionTransformerDimNames.kv_channels ).size, VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.out_channels diff --git 
a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 482fea02f..840b80926 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -30,10 +30,15 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, - "patch_size": self._config.batch.patch_size, - "image_size": self._config.batch.image_size, } ) + if self._config.model.base_model.vision_encoder.enabled: + parameters.update( + { + "patch_size": self._config.model.base_model.vision_encoder.patch_size, + "image_size": self._config.batch.image_size, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 4d218c3ff..0cc02f426 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -36,6 +36,7 @@ def fast_llm(args=None): if sys.gettrace(): raise logger.critical(traceback.format_exc()) + sys.exit(1) if __name__ == "__main__": diff --git a/fast_llm/utils.py b/fast_llm/utils.py index c5b7f07ae..51e0eee59 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -336,10 +336,3 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) - - -def prefix_class_vars(cls, prefix: str, base_cls: type): - for attr, value in vars(base_cls).items(): - if not attr.startswith("__") and isinstance(value, str) and not hasattr(cls, attr): - setattr(cls, attr, prefix + value) - return cls From dc888c8fc6596b0ba7483b4eaf184ba7015e2063 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 22 May 2025 23:05:57 +0000 Subject: [PATCH 045/161] image break token in sampling --- fast_llm/data/dataset/gpt/config.py | 1 + fast_llm/data/dataset/gpt/sampled.py | 45 +++++++++++++++++-- fast_llm/layers/vision_encoder/config.py | 5 +++ .../layers/vision_encoder/preprocessing.py | 10 +++++ fast_llm/models/gpt/trainer.py | 1 + 5 files changed, 58 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 44d1f4cc9..004a062c2 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -76,6 +76,7 @@ class GPTSamplingParameters(SamplingParameters): cross_document_attention: bool = True patch_size: int | None = None image_size: int | None = None + image_break_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 780b18878..de8e1d75c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -14,7 +14,7 @@ from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert try: @@ -138,7 +138,7 @@ def _sample(self) -> None: for i, sizes in enumerate(image_sizes): image_token_sizes.append( sum( - get_num_patches( + get_num_image_tokens( *get_resize_dims( *size, self._parameters.image_size, @@ -146,6 +146,7 @@ def _sample(self) -> None: self._parameters.patch_size, ), self._parameters.patch_size, + break_token=self._parameters.image_break_token is not None, ) for size in sizes ) @@ -211,6 +212,7 @@ def _sample(self) -> None: "sequence_length": self._parameters.sequence_length, "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, + "image_break_token": self._parameters.image_break_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -423,7 +425,7 @@ def __getitem__(self, index: int) -> typing.Any: text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) image_sizes = [ - get_num_patches( + get_num_image_tokens( *get_resize_dims( *image_length, self._parameters.image_size, @@ -431,6 +433,7 @@ def __getitem__(self, index: int) -> typing.Any: self._parameters.patch_size, ), self._parameters.patch_size, + break_token=self._parameters.image_break_token is not None, ) for image_length in image_lengths ] @@ -473,7 +476,41 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + if self._parameters.image_break_token is not None: + # Calculate patch dimensions for the image + width, height = get_resize_dims( + image_lengths[idx][0], + image_lengths[idx][1], + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ) + num_patches_w = math.ceil(width / self._parameters.patch_size) + num_patches_h = math.ceil(height / self._parameters.patch_size) + + # Calculate the token count considering break tokens + tokens_per_row = num_patches_w + total_tokens = num_patches_h * tokens_per_row + ( + num_patches_h - 1 + ) # Add break tokens after each row except last + + # Create image token placeholder array + image_token_array = np.full((total_tokens,), -100, dtype=np.int64) + + # Add break tokens after each row except the last row + for row in range(num_patches_h - 1): + position = (row + 1) * tokens_per_row + row + image_token_array[position] = self._parameters.image_break_token + + token_ids.append(image_token_array) + + # Update image_tokens_added to reflect actual number of tokens added + image_tokens_added += total_tokens + else: + # Just add placeholders for all image tokens without break tokens + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + image_tokens_added += image_sizes[idx] 
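# Note: the -100 entries are placeholders that only reserve sequence positions for
# image patches; the embedding layers (updated later in this series) mask them out of
# the token-embedding lookup and overwrite those positions with patch embeddings, so
# they never act as real vocabulary ids.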
image_positions.append(im_position + len(token_ids) + image_tokens_added) image_tokens_added += image_tokens start_pos = im_position diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index f788b5149..5b972f128 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -130,6 +130,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + image_break_token: int | None = Field( + default=None, + desc="Token id to separate image rows. If None, no token id is applied is applied.", + hint=FieldHint.optional, + ) adapter_lr_scale: float | None = Field( default=None, desc="Custom learning rate scale for the adapter weights.", diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index d85442a3e..5cffbff58 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -19,6 +19,16 @@ def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int] return div(height, patch_size) * div(width, patch_size) +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + return height_patches * (width_patches + image_break) + + def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: """ Calculate the new dimensions for resizing an image while maintaining the aspect ratio. 
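For reference, a minimal standalone sketch of the image-token accounting used by the sampling code in this series (one placeholder per patch, plus one break token after every patch row except the last). The helper name `count_image_tokens` and the use of `math.ceil` are illustrative assumptions, not part of the patch itself:

import math


def count_image_tokens(height: int, width: int, patch_size: int, use_break_token: bool) -> int:
    # Patch grid after resizing; ceil covers sizes that are not exact multiples of patch_size.
    rows = math.ceil(height / patch_size)
    cols = math.ceil(width / patch_size)
    # One placeholder token per patch, plus one break token after each row except the last.
    return rows * cols + (rows - 1 if use_break_token else 0)


# Example: a 512x1024 image with 16x16 patches yields 32 * 64 = 2048 patch tokens
# plus 31 break tokens, i.e. 2079 image tokens in the text stream.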
diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 840b80926..d1b6d19e2 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -37,6 +37,7 @@ def _get_sampling_parameters( { "patch_size": self._config.model.base_model.vision_encoder.patch_size, "image_size": self._config.batch.image_size, + "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From af3e2dbcb19bec618d88dbf1bfb913fe8940caf7 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 23 May 2025 22:47:04 +0000 Subject: [PATCH 046/161] minor fixes --- fast_llm/data/dataset/gpt/memmap.py | 6 +- fast_llm/data/dataset/gpt/sampled.py | 13 ++-- fast_llm/layers/multi_modal/embedding.py | 64 ++++++++++++++----- .../layers/vision_encoder/preprocessing.py | 2 +- 4 files changed, 60 insertions(+), 25 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 1efc312e8..a202d2e1f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,7 +10,7 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -201,6 +201,7 @@ def get( use_loss_masking_spans: bool = False, patch_size: int | None = None, image_size: int | None = None, + image_break: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -239,9 +240,10 @@ def get( additional_tokens = 0 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") while image_position >= span[0] and image_position <= span[1]: - image_tokens = get_num_patches( + image_tokens = get_num_image_tokens( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), patch_size, + image_break=image_break, ) additional_tokens += image_tokens image_idx += 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index de8e1d75c..f441d9b9e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -146,7 +146,7 @@ def _sample(self) -> None: self._parameters.patch_size, ), self._parameters.patch_size, - break_token=self._parameters.image_break_token is not None, + image_break=self._parameters.image_break_token is not None, ) for size in sizes ) @@ -433,7 +433,7 @@ def __getitem__(self, index: int) -> typing.Any: self._parameters.patch_size, ), self._parameters.patch_size, - break_token=self._parameters.image_break_token is not None, + image_break=self._parameters.image_break_token is not None, ) for image_length in image_lengths ] @@ -476,6 +476,7 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) + image_positions.append(text_tokens_added + im_position + image_tokens_added) # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: # Calculate patch dimensions for the image @@ -491,12 +492,12 @@ def __getitem__(self, index: int) 
-> typing.Any: # Calculate the token count considering break tokens tokens_per_row = num_patches_w - total_tokens = num_patches_h * tokens_per_row + ( + resized_image_tokens = num_patches_h * tokens_per_row + ( num_patches_h - 1 ) # Add break tokens after each row except last # Create image token placeholder array - image_token_array = np.full((total_tokens,), -100, dtype=np.int64) + image_token_array = np.full((resized_image_tokens,), -100, dtype=np.int64) # Add break tokens after each row except the last row for row in range(num_patches_h - 1): @@ -506,13 +507,11 @@ def __getitem__(self, index: int) -> typing.Any: token_ids.append(image_token_array) # Update image_tokens_added to reflect actual number of tokens added - image_tokens_added += total_tokens + image_tokens_added += resized_image_tokens else: # Just add placeholders for all image tokens without break tokens token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_tokens_added += image_sizes[idx] - image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) text_tokens_added += len(token_ids[-1]) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 8c541e983..12b58a764 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -9,9 +9,9 @@ from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class MultiModalEmbedding(LanguageModelEmbedding): @@ -60,15 +60,32 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - if self._sequence_parallel: - embeddings[position : position + num_image_tokens, sample_idx] = input_[ - image_embedding_offset : image_embedding_offset + num_image_tokens, sample_idx - ] - else: - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + num_image_tokens = get_num_image_tokens(*size, self._config.vision_encoder.patch_size) + # Calculate the patch dimensions + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + # Process row by row for both sequence parallel and non-parallel cases + for row in range(patch_height): + # Calculate source and destination starting positions + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + # Always use full patch_width + tokens_in_row = patch_width + + if self._sequence_parallel: + # Copy with dimensions swapped for sequence parallel case + embeddings[row_start_dst : row_start_dst + tokens_in_row, sample_idx] = input_[ + row_start_src : row_start_src + tokens_in_row, sample_idx + ] + else: + # Copy with normal dimension ordering + embeddings[sample_idx, row_start_dst : 
row_start_dst + tokens_in_row] = input_[ + sample_idx, row_start_src : row_start_src + tokens_in_row + ] + + # Move to the next image in the input tensor image_embedding_offset += num_image_tokens if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) @@ -85,10 +102,27 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + num_image_tokens = get_num_image_tokens( + *size, + self._config.vision_encoder.patch_size, + image_break=self._config.vision_encoder.image_break_token is not None, + ) + # Calculate the patch dimensions + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + # Process row by row + for row in range(patch_height): + # Calculate source and destination starting positions + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + # Copy row by row + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + + # Move to the next image in the input tensor image_embedding_offset += num_image_tokens if self._use_absolute_position_embeddings: diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 5cffbff58..c5c14a262 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -26,7 +26,7 @@ def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * (width_patches + image_break) + return height_patches * (width_patches + image_break) - 1 def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: From 6d56be085309a4e0f74c24c5bad4aa8aea442708 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 24 May 2025 19:43:34 +0000 Subject: [PATCH 047/161] fix img break --- fast_llm/data/dataset/gpt/sampled.py | 6 +++--- fast_llm/layers/vision_encoder/preprocessing.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f441d9b9e..2c068742c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -476,19 +476,19 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) - image_positions.append(text_tokens_added + im_position + image_tokens_added) + image_positions.append(text_tokens_added + image_tokens_added) # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: # Calculate patch dimensions for the image - width, height = get_resize_dims( + height, width = get_resize_dims( image_lengths[idx][0], image_lengths[idx][1], self._parameters.image_size, self._parameters.image_size, self._parameters.patch_size, ) - num_patches_w = math.ceil(width / 
self._parameters.patch_size) num_patches_h = math.ceil(height / self._parameters.patch_size) + num_patches_w = math.ceil(width / self._parameters.patch_size) # Calculate the token count considering break tokens tokens_per_row = num_patches_w diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index c5c14a262..8404adae9 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -26,7 +26,7 @@ def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * (width_patches + image_break) - 1 + return height_patches * width_patches + (height_patches - 1 if image_break else 0) def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: From ce9164647d3a582b8a13fd3646a66f3a019c8966 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 27 May 2025 23:34:57 +0000 Subject: [PATCH 048/161] fixes --- fast_llm/layers/language_model/embedding.py | 5 ++++- fast_llm/layers/multi_modal/embedding.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed1..f51f40df7 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -99,7 +99,10 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t input_ = split(input_, group=group, dim=0) if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) - embeddings = torch.embedding(self.word_embeddings_weight, input_) + # mask padded tokens + input_mask = input_ >= 0 + masked_input = input_ * input_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 12b58a764..f40df3f09 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -60,7 +60,11 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens(*size, self._config.vision_encoder.patch_size) + num_image_tokens = get_num_image_tokens( + *size, + self._config.vision_encoder.patch_size, + image_break=self._config.vision_encoder.image_break_token is not None, + ) # Calculate the patch dimensions patch_width = div(size[0], self._config.vision_encoder.patch_size) patch_height = div(size[1], self._config.vision_encoder.patch_size) @@ -97,7 +101,10 @@ def _forward( # TODO Soham: get image positions for current split. Maybe in preprocessing? 
# for positions in image_positions: # if positions > self._distributed_config.tensor_rank - embeddings = torch.embedding(self.word_embeddings_weight, tokens) + # mask padded tokens + token_mask = tokens >= 0 + masked_tokens = tokens * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) embeddings = embeddings.clone() for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 From 204b3e9f27e6d12168f72a4ae045fc7ab9dbe475 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 28 May 2025 06:04:47 +0000 Subject: [PATCH 049/161] fix image embeddings offset --- fast_llm/data/dataset/gpt/config.py | 1 + fast_llm/data/dataset/gpt/memmap.py | 2 + fast_llm/data/dataset/gpt/sampled.py | 68 +++++++------- fast_llm/layers/multi_modal/embedding.py | 89 ++++++++----------- fast_llm/layers/vision_encoder/config.py | 7 +- .../layers/vision_encoder/preprocessing.py | 31 ++++++- fast_llm/models/gpt/trainer.py | 1 + 7 files changed, 109 insertions(+), 90 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 004a062c2..bb3ff717a 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -77,6 +77,7 @@ class GPTSamplingParameters(SamplingParameters): patch_size: int | None = None image_size: int | None = None image_break_token: int | None = None + image_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a202d2e1f..d83064b1e 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -202,6 +202,7 @@ def get( patch_size: int | None = None, image_size: int | None = None, image_break: bool = False, + image_end: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -244,6 +245,7 @@ def get( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), patch_size, image_break=image_break, + image_end=image_end, ) additional_tokens += image_tokens image_idx += 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2c068742c..6c8e9fe71 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -15,7 +15,7 @@ from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div try: from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx # noqa @@ -147,6 +147,7 @@ def _sample(self) -> None: ), self._parameters.patch_size, image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, ) for size in sizes ) @@ -213,6 +214,7 @@ def _sample(self) -> None: "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, "image_break_token": self._parameters.image_break_token, + "image_end_token": self._parameters.image_end_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -424,18 +426,23 @@ def __getitem__(self, index: int) -> typing.Any: text_size, image_lengths = 
self._indexed_dataset.get_document_size(document_index) + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] image_sizes = [ get_num_image_tokens( - *get_resize_dims( - *image_length, - self._parameters.image_size, - self._parameters.image_size, - self._parameters.patch_size, - ), + *image_length, self._parameters.patch_size, image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, ) - for image_length in image_lengths + for image_length in resized_image_lengths ] image_tokens = sum(image_sizes) document_size = text_size + image_tokens @@ -468,6 +475,8 @@ def __getitem__(self, index: int) -> typing.Any: offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, + # image_break=self._parameters.image_break_token is not None, + # image_end=self._parameters.image_end_token is not None, ) start_pos = 0 if sample.image_positions: @@ -477,41 +486,30 @@ def __getitem__(self, index: int) -> typing.Any: token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) image_positions.append(text_tokens_added + image_tokens_added) - # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: - # Calculate patch dimensions for the image - height, width = get_resize_dims( - image_lengths[idx][0], - image_lengths[idx][1], - self._parameters.image_size, - self._parameters.image_size, - self._parameters.patch_size, - ) - num_patches_h = math.ceil(height / self._parameters.patch_size) - num_patches_w = math.ceil(width / self._parameters.patch_size) - - # Calculate the token count considering break tokens - tokens_per_row = num_patches_w - resized_image_tokens = num_patches_h * tokens_per_row + ( - num_patches_h - 1 - ) # Add break tokens after each row except last + height, width = resized_image_lengths[idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) # Create image token placeholder array - image_token_array = np.full((resized_image_tokens,), -100, dtype=np.int64) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) # Add break tokens after each row except the last row for row in range(num_patches_h - 1): - position = (row + 1) * tokens_per_row + row + position = (row + 1) * num_patches_w + row image_token_array[position] = self._parameters.image_break_token - - token_ids.append(image_token_array) - - # Update image_tokens_added to reflect actual number of tokens added - image_tokens_added += resized_image_tokens + # add end token if specified, else break token + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token else: - # Just add placeholders for all image tokens without break tokens - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - image_tokens_added += image_sizes[idx] + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = self._parameters.image_end_token + 
token_ids.append(image_token_array) + image_tokens_added += image_sizes[idx] start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) text_tokens_added += len(token_ids[-1]) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index f40df3f09..4dd4a46eb 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div @@ -60,37 +60,30 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens( - *size, - self._config.vision_encoder.patch_size, - image_break=self._config.vision_encoder.image_break_token is not None, - ) - # Calculate the patch dimensions - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) - - # Process row by row for both sequence parallel and non-parallel cases - for row in range(patch_height): - # Calculate source and destination starting positions - row_start_src = image_embedding_offset + row * patch_width - row_start_dst = position + row * (patch_width + 1) - - # Always use full patch_width - tokens_in_row = patch_width - - if self._sequence_parallel: - # Copy with dimensions swapped for sequence parallel case - embeddings[row_start_dst : row_start_dst + tokens_in_row, sample_idx] = input_[ - row_start_src : row_start_src + tokens_in_row, sample_idx - ] - else: - # Copy with normal dimension ordering - embeddings[sample_idx, row_start_dst : row_start_dst + tokens_in_row] = input_[ - sample_idx, row_start_src : row_start_src + tokens_in_row - ] - - # Move to the next image in the input tensor - image_embedding_offset += num_image_tokens + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + if self._sequence_parallel: + # Copy with dimensions swapped for sequence parallel case + embeddings[row_start_dst : row_start_dst + patch_width, sample_idx] = input_[ + row_start_src : row_start_src + patch_width, sample_idx + ] + else: + # Copy with normal dimension ordering + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches + ] + image_embedding_offset += num_patches if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: @@ -109,28 +102,24 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): 
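# For each sample, patch embeddings are copied into the reserved positions one image at
# a time. When an image-break token is used, each destination row is one slot longer
# than the source row (stride patch_width + 1), so rows are copied individually to skip
# the break-token position; otherwise the whole image block is copied contiguously.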
image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens( - *size, - self._config.vision_encoder.patch_size, - image_break=self._config.vision_encoder.image_break_token is not None, - ) - # Calculate the patch dimensions - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) - # Process row by row - for row in range(patch_height): - # Calculate source and destination starting positions - row_start_src = image_embedding_offset + row * patch_width - row_start_dst = position + row * (patch_width + 1) + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) - # Copy row by row - embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ - sample_idx, row_start_src : row_start_src + patch_width + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches ] - # Move to the next image in the input tensor - image_embedding_offset += num_image_tokens + image_embedding_offset += num_patches if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 5b972f128..267941741 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -132,7 +132,12 @@ class VisionEncoderConfig(BaseModelConfig): ) image_break_token: int | None = Field( default=None, - desc="Token id to separate image rows. If None, no token id is applied is applied.", + desc="Token id to separate image rows. If None, no token id is applied.", + hint=FieldHint.optional, + ) + image_end_token: int | None = Field( + default=None, + desc="Token id to indicate the end of an image. 
If None, no token id is applied.", hint=FieldHint.optional, ) adapter_lr_scale: float | None = Field( diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 8404adae9..41da4fb6f 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -6,6 +6,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import TensorMeta @@ -19,14 +20,19 @@ def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int] return div(height, patch_size) * div(width, patch_size) -def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool) -> int: +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: """ Calculate the number of image tokens. If image_break is True, we consider 1 additional token after every row of patches. """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * width_patches + (height_patches - 1 if image_break else 0) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: @@ -150,16 +156,32 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? 
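To make the counting rule above concrete, here is a worked example expressed directly in patch counts (a 1024x1024 image with 16x16 patches gives a 64x64 grid). This is a standalone check, not part of the patch:

def num_image_tokens(height_patches, width_patches, image_break, image_end):
    # Mirrors get_num_image_tokens, but takes patch counts directly.
    tokens = height_patches * width_patches
    if image_break:
        tokens += height_patches  # one slot per row; the last becomes the end token if set
    elif image_end:
        tokens += 1               # single trailing end token
    return tokens

assert num_image_tokens(64, 64, image_break=False, image_end=False) == 4096
assert num_image_tokens(64, 64, image_break=True, image_end=True) == 4160
assert num_image_tokens(64, 64, image_break=False, image_end=True) == 4097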
+ labels = labels.clone() + patches = [] patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) - for imgs, sizes in zip(images, image_sizes): + for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): seq_patches = [] sample_cu_seqlen = 0 - for image, size in zip(imgs, sizes): + for image, size, position in zip(imgs, sizes, positions): seqlen = get_num_patches(*size, patch_size) + num_tokens = get_num_image_tokens( + *size, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) @@ -204,6 +226,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # TODO Soham: remove assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] patches = torch.cat(patches) + kwargs[LanguageModelKwargs.labels] = labels patch_position_ids = torch.cat(patch_position_ids) kwargs[VisionEncoderKwargs.image_patches] = patches kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index d1b6d19e2..a4f0b0b42 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -38,6 +38,7 @@ def _get_sampling_parameters( "patch_size": self._config.model.base_model.vision_encoder.patch_size, "image_size": self._config.batch.image_size, "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, + "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From fd08eac092f508b50219d4314f22a54af8efe768 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 29 May 2025 00:10:38 +0000 Subject: [PATCH 050/161] heterogeneous data fixes --- fast_llm/engine/multi_stage/stage.py | 2 +- fast_llm/functional/cross_entropy.py | 2 +- fast_llm/functional/triton/mlp.py | 4 ++-- .../layers/vision_encoder/preprocessing.py | 21 +++++++++++++++---- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 675e878b3..b1c7df819 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -121,7 +121,7 @@ def backward( assert self._mode.support_backward input_, output = grad_context output.backward(output_grad) - return input_.grad + return input_.grad if input_.grad is not None else torch.zeros_like(input_) def restore_parameters(self) -> None: assert self._is_setup diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 513510ec7..53b5979ed 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,7 +145,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask + per_sample_loss = per_sample_loss[loss_mask] loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ee3ba304c..0fb71bd56 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -50,7 +50,7 @@ def 
triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + if activation_type == _TritonActivationType.gelu_pytorch_tanh: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -100,7 +100,7 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + if activation_type == _TritonActivationType.gelu_pytorch_tanh: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 41da4fb6f..8fad35722 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -170,7 +170,13 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): - seq_patches = [] + # add an empty tensor for clean concatenation in case of no images + seq_patches = [ + torch.tensor([]).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + ] sample_cu_seqlen = 0 for image, size, position in zip(imgs, sizes, positions): seqlen = get_num_patches(*size, patch_size) @@ -211,9 +217,16 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) - position_ids = torch.cat( - [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] - ).to(device=self._tensor_space.distributed.device) + if sizes: + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + else: + position_ids = torch.tensor( + [], + dtype=torch.int64, + device=self._tensor_space.distributed.device, + ) # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks patch_position_ids.append( torch.cat( From 1e3652aeae78f930fdd1c58d09b45681adec2047 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 29 May 2025 15:25:49 +0000 Subject: [PATCH 051/161] convert to rgb --- fast_llm/data/dataset/gpt/memmap.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index d83064b1e..703809417 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -325,10 +325,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: - if img.mode == "L": - # Convert grayscale to RGB + if img.mode != "RGB": + # Convert all images to RGB img = img.convert("RGB") pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
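The decode path above can be summarized as a small self-contained helper. This is a sketch of the convention the writer relies on (RGB, C-contiguous CHW, uint8), not code lifted from the patch:

import io

import numpy as np
import PIL.Image

def encode_image(image_bytes: bytes) -> tuple[np.ndarray, np.ndarray]:
    # Decode, force RGB, and lay pixels out as CHW uint8, matching what gets
    # streamed into the .bin file together with its (height, width) entry.
    with PIL.Image.open(io.BytesIO(image_bytes)) as img:
        if img.mode != "RGB":
            img = img.convert("RGB")
        pixels = np.asarray(img).transpose(2, 0, 1)
    assert pixels.dtype == np.uint8
    return np.ascontiguousarray(pixels), np.array(pixels.shape[1:])

# Reading back later only needs the stored (height, width):
#   np.frombuffer(raw, dtype=np.uint8, count=3 * h * w).reshape(3, h, w)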
image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size From 2aabf353752eeb9290f470cd76e44da8482c0456 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 30 May 2025 20:48:27 +0000 Subject: [PATCH 052/161] fix sequence parallel image patches --- fast_llm/layers/multi_modal/embedding.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 4dd4a46eb..9e11df3f3 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -48,11 +48,17 @@ def _forward( """ Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group + if self._sequence_parallel: + micro_seqlen = input_.size(0) + patch_start_offset = self._distributed_config.tensor_rank * micro_seqlen + patch_end_offset = (self._distributed_config.tensor_rank + 1) * micro_seqlen + else: + patch_start_offset = 0 + patch_end_offset = input_.size(0) if self._parallel_embeddings: token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) masked_tokens = (tokens - self._vocab_start_index) * token_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa - embeddings = reduce_forward(embeddings, group) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) embeddings = embeddings.clone() @@ -61,13 +67,18 @@ def _forward( image_embedding_offset = 0 for position, size in zip(positions, sizes): num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if image_embedding_offset + num_patches < patch_start_offset: + continue if self._config.vision_encoder.image_break_token is not None: patch_width = div(size[0], self._config.vision_encoder.patch_size) patch_height = div(size[1], self._config.vision_encoder.patch_size) - for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width row_start_dst = position + row * (patch_width + 1) + if row_start_src > patch_end_offset: + break + if row_start_dst < patch_start_offset: + continue if self._sequence_parallel: # Copy with dimensions swapped for sequence parallel case @@ -84,6 +95,9 @@ def _forward( sample_idx, image_embedding_offset : image_embedding_offset + num_patches ] image_embedding_offset += num_patches + if image_embedding_offset > patch_end_offset: + break + embeddings = reduce_forward(embeddings, group) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: From b6d48589ad500034efdecb3727a5d163702f60e2 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 31 May 2025 01:50:12 +0000 Subject: [PATCH 053/161] fixes --- fast_llm/layers/multi_modal/embedding.py | 46 +++++++++++++------ .../layers/vision_encoder/preprocessing.py | 2 +- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 9e11df3f3..76060a004 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -59,8 +59,6 @@ def _forward( token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) masked_tokens = (tokens - self._vocab_start_index) * token_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) 
# noqa - if self._use_absolute_position_embeddings: - embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): @@ -70,34 +68,56 @@ def _forward( if image_embedding_offset + num_patches < patch_start_offset: continue if self._config.vision_encoder.image_break_token is not None: - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width row_start_dst = position + row * (patch_width + 1) if row_start_src > patch_end_offset: break - if row_start_dst < patch_start_offset: + if row_start_src + patch_width <= patch_start_offset: continue + input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset + input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset + embeddings_start_index = row_start_dst - max(patch_start_offset - row_start_src, 0) + embeddings_end_index = ( + row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) + ) + # row_end_src = min(row_start_src + patch_width, patch_end_offset) if self._sequence_parallel: # Copy with dimensions swapped for sequence parallel case - embeddings[row_start_dst : row_start_dst + patch_width, sample_idx] = input_[ - row_start_src : row_start_src + patch_width, sample_idx + embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ + input_start_index:input_end_index, sample_idx ] + tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: # Copy with normal dimension ordering - embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ - sample_idx, row_start_src : row_start_src + patch_width + embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ + sample_idx, input_start_index:input_end_index ] + tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: - embeddings[sample_idx, position : position + num_patches] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_patches + input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset + input_end_index = ( + min(image_embedding_offset + num_patches, patch_end_offset) - patch_start_offset + ) + embedding_start_index = position - max(patch_start_offset - image_embedding_offset, 0) + embedding_end_index = ( + position + num_patches - max(image_embedding_offset + num_patches - patch_end_offset, 0) + ) + embeddings[sample_idx, embedding_start_index:embedding_end_index] = input_[ + input_start_index:input_end_index, sample_idx ] + # embeddings[sample_idx, position : position + num_patches] = input_[ + # sample_idx, image_embedding_offset : image_embedding_offset + num_patches + # ] image_embedding_offset += num_patches if image_embedding_offset > patch_end_offset: break embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: @@ -118,8 +138,8 @@ def _forward( 
for position, size in zip(positions, sizes): num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) if self._config.vision_encoder.image_break_token is not None: - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 8fad35722..ab0d23787 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -205,7 +205,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen if padding_size > max_seqlen: max_seqlen = padding_size - cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length] * (idx + 1)) patches.append( torch.cat( [ From 25a650bf588e8a20b02a4b6f6b991aa42993808b Mon Sep 17 00:00:00 2001 From: root Date: Sat, 31 May 2025 17:10:16 +0000 Subject: [PATCH 054/161] no compile for embeddings --- fast_llm/layers/multi_modal/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 76060a004..7f09347bf 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -26,7 +26,7 @@ def __init__( ): super().__init__(config, tensor_space) - @torch.compile + # @torch.compile def _forward( self, input_: torch.Tensor, From c904da5def23c6db1abb775971f6790a4bec8272 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 1 Jun 2025 17:48:59 +0000 Subject: [PATCH 055/161] fix sampling --- fast_llm/data/dataset/gpt/sampled.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 6c8e9fe71..8d216b3d4 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -453,7 +453,7 @@ def __getitem__(self, index: int) -> typing.Any: document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample > self._parameters.sequence_length + 1: + if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -464,6 +464,7 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. token_count += padding_size + continue # Determine if the document belongs to the requested sample. 
if token_count + document_size >= token_start: From 7a4701c522431eb94a873f59a220e13691c007b9 Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Mon, 2 Jun 2025 00:15:54 -0700 Subject: [PATCH 056/161] sampling and preprocessing bugs --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/vision_encoder/preprocessing.py | 6 +++--- fast_llm/models/gpt/model.py | 13 ++++++++++--- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 8d216b3d4..f58b009a1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -456,7 +456,7 @@ def __getitem__(self, index: int) -> typing.Any: if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample - if token_count > token_start: + if token_count >= token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) Assert.eq(token_count + padding_size, token_end) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index ab0d23787..76b0aa284 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -137,6 +137,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: im_height = kwargs.get(VisionEncoderKwargs.image_size) im_width = kwargs.get(VisionEncoderKwargs.image_size) patch_size = kwargs[VisionEncoderKwargs.patch_size] + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) image_sizes = [ [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] for ims in images @@ -156,7 +157,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] - image_positions = kwargs.get(VisionEncoderKwargs.image_positions) labels = kwargs[LanguageModelKwargs.labels] if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): @@ -239,9 +239,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # TODO Soham: remove assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] patches = torch.cat(patches) - kwargs[LanguageModelKwargs.labels] = labels patch_position_ids = torch.cat(patch_position_ids) kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( kwargs[VisionEncoderKwargs.rope_theta], kwargs[VisionEncoderKwargs.kv_channels], @@ -249,7 +249,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_size, ).to(device=self._tensor_space.distributed.device) kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) - kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids # TODO Soham: handle sequence data parallel kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 @@ -259,3 +258,4 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen + kwargs[LanguageModelKwargs.labels] = labels diff --git a/fast_llm/models/gpt/model.py 
b/fast_llm/models/gpt/model.py index 586b511ba..45cf4a4fe 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -407,15 +407,22 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - if batch.images is not None: + if self._config.vision_encoder.enabled: + batch_images = ( + batch.images if batch.images is not None else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) kwargs[VisionEncoderKwargs.images] = [ [ img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) for img in images ] - for images in batch.images + for images in batch_images ] - kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions + kwargs[VisionEncoderKwargs.image_positions] = ( + batch.image_positions + if batch.image_positions is not None + else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) kwargs[LanguageModelKwargs.tokens] = tokens for preprocessor in self._preprocessors: From 067f901bc8bc0b51148f2531d1f929f74b90081a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 18:35:24 +0000 Subject: [PATCH 057/161] speed up sampling --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f58b009a1..2972632cb 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -166,7 +166,7 @@ def _sample(self) -> None: " Please make sure Fast-LLM is installed correctly." ) long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 - ignored_documents = sum(long_docs_filter) + ignored_documents = long_docs_filter.sum() if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", From f24325eaf768de0dac5a3e4c7f879a3bf0d5f3cc Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 22:15:57 +0000 Subject: [PATCH 058/161] cap image size reduction --- fast_llm/layers/vision_encoder/preprocessing.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 76b0aa284..a9115c97c 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -50,9 +50,22 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int, pa def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: - resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width, patch_size=patch_size) + target_height, target_width = get_resize_dims( + image.size(1), image.size(2), max_height, max_width, patch_size=patch_size + ) + height, width = image.size(1), image.size(2) + while height > 2 * target_height or width > 2 * target_width: + # cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + intermediate_max_width = max(target_width, width // 2) + intermediate_max_height = max(target_height, height // 2) + height, width = get_resize_dims( + height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size + ) + image = F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + # TODO: options for interpolation mode? 
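The loop above works around the torchvision issue with very large downscales by halving in stages before the final resize. A simplified sketch of the resulting size schedule (it skips the aspect-ratio and patch-multiple rounding that get_resize_dims applies at every step):

def resize_steps(height, width, target_height, target_width):
    # Halve until within 2x of the target, then do the final resize.
    steps = []
    while height > 2 * target_height or width > 2 * target_width:
        height = max(target_height, height // 2)
        width = max(target_width, width // 2)
        steps.append((height, width))
    steps.append((target_height, target_width))
    return steps

# e.g. an 8192x6144 input resized towards 1024x1024:
print(resize_steps(8192, 6144, 1024, 1024))  # [(4096, 3072), (2048, 1536), (1024, 1024)]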
- return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) + return F.resize(image, size=(target_height, target_width), interpolation=F.InterpolationMode.BICUBIC) def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: From 0f376643df53c2831c2a164436fd7aba92cb4f80 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 23:15:50 +0000 Subject: [PATCH 059/161] fix span offset with images --- fast_llm/data/dataset/gpt/memmap.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 703809417..6f5a963f4 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -240,6 +240,18 @@ def get( for span in sample_spans: additional_tokens = 0 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + while image_position < span[0]: + image_tokens = get_num_image_tokens( + get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + patch_size, + image_break=image_break, + image_end=image_end, + ) + additional_tokens += image_tokens + image_idx += 1 + image_position = ( + image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + ) while image_position >= span[0] and image_position <= span[1]: image_tokens = get_num_image_tokens( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), From ff8fecc46e37c76a84ba8036b3781e3c1c9c447e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 23:35:47 +0000 Subject: [PATCH 060/161] fix span offset with images --- fast_llm/data/dataset/gpt/memmap.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 6f5a963f4..a3f2f9019 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -237,8 +237,9 @@ def get( sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset if images: image_idx = 0 + prev_image_tokens = 0 for span in sample_spans: - additional_tokens = 0 + span_image_tokens = 0 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") while image_position < span[0]: image_tokens = get_num_image_tokens( @@ -247,11 +248,12 @@ def get( image_break=image_break, image_end=image_end, ) - additional_tokens += image_tokens + span_image_tokens += image_tokens image_idx += 1 image_position = ( image_positions[image_idx] if image_idx < len(image_positions) else float("inf") ) + prev_image_tokens += image_tokens while image_position >= span[0] and image_position <= span[1]: image_tokens = get_num_image_tokens( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), @@ -259,12 +261,14 @@ def get( image_break=image_break, image_end=image_end, ) - additional_tokens += image_tokens + span_image_tokens += image_tokens image_idx += 1 image_position = ( image_positions[image_idx] if image_idx < len(image_positions) else float("inf") ) - span[1] += additional_tokens + span[0] += prev_image_tokens + span[1] += prev_image_tokens + span_image_tokens + prev_image_tokens += span_image_tokens return GPTSample( token_ids=token_ids, images=images, From c663cbb69b6334e2802e04113c9aa33ca7e4f8c3 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 3 Jun 2025 00:06:02 +0000 Subject: [PATCH 061/161] move image logic to sampled --- 
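Since this commit moves the span adjustment out of the memmap reader, a condensed sketch of the rule that lands in the sampled dataset may help: each loss-masking span is shifted right by the image tokens that precede it and widened by the image tokens that fall inside it. The sketch collapses this to a single span; the patch processes spans cumulatively.

def shift_span(span, image_positions, image_token_counts):
    start, end = span
    before = sum(n for p, n in zip(image_positions, image_token_counts) if p < start)
    inside = sum(n for p, n in zip(image_positions, image_token_counts) if start <= p <= end)
    return start + before, end + before + inside

# One 64-token image at text position 3, span over text positions [5, 9]:
assert shift_span((5, 9), [3], [64]) == (69, 73)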
fast_llm/data/dataset/gpt/memmap.py | 35 ---------------------------- fast_llm/data/dataset/gpt/sampled.py | 31 +++++++++++++++++++++--- 2 files changed, 28 insertions(+), 38 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a3f2f9019..ce24f3b97 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,7 +10,6 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -235,40 +234,6 @@ def get( ] sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - if images: - image_idx = 0 - prev_image_tokens = 0 - for span in sample_spans: - span_image_tokens = 0 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - while image_position < span[0]: - image_tokens = get_num_image_tokens( - get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), - patch_size, - image_break=image_break, - image_end=image_end, - ) - span_image_tokens += image_tokens - image_idx += 1 - image_position = ( - image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - ) - prev_image_tokens += image_tokens - while image_position >= span[0] and image_position <= span[1]: - image_tokens = get_num_image_tokens( - get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), - patch_size, - image_break=image_break, - image_end=image_end, - ) - span_image_tokens += image_tokens - image_idx += 1 - image_position = ( - image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - ) - span[0] += prev_image_tokens - span[1] += prev_image_tokens + span_image_tokens - prev_image_tokens += span_image_tokens return GPTSample( token_ids=token_ids, images=images, diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2972632cb..d0a867510 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -476,8 +476,6 @@ def __getitem__(self, index: int) -> typing.Any: offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, - # image_break=self._parameters.image_break_token is not None, - # image_end=self._parameters.image_end_token is not None, ) start_pos = 0 if sample.image_positions: @@ -520,12 +518,39 @@ def __getitem__(self, index: int) -> typing.Any: images.append([]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: + prev_image_tokens = 0 + image_idx = 0 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + while image_position < loss_masking_span[0]: + prev_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + span_image_tokens = 0 + while image_position <= loss_masking_span[1]: + span_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position 
= ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + loss_masking_span[0] += prev_image_tokens + loss_masking_span[1] += prev_image_tokens + span_image_tokens + prev_image_tokens += span_image_tokens span = np.clip( loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) - if span[1] > span[0]: + if span[1] >= span[0]: loss_masking_spans.append(span) # Go to the next document. From f52f02bf71ac17abd64cca4f7ecae4de9eea4cb2 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 3 Jun 2025 00:06:32 +0000 Subject: [PATCH 062/161] cleanup --- fast_llm/data/dataset/gpt/memmap.py | 34 ----------------------------- 1 file changed, 34 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ce24f3b97..21c096b38 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -162,46 +162,12 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - # def get( - # self, - # idx: int, - # offset: int = 0, - # image_offset: int = 0, - # length: int | None = None, - # use_loss_masking_spans: bool = False, - # ): - # token_ids = np.frombuffer( - # self._bin_buffer, - # dtype=self._dtype, - # count=self._document_sizes[idx] - offset if length is None else length, - # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - # ) - # if self._has_images: - # image_positions = self._image_positions[idx] - # pixels = np.frombuffer( - # self._bin_buffer, - # dtype=np.dtype(np.uint8), - # count=self._image_lengths[idx].prod(initial=3), - # offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, - # ) - # images = [] - # start = 0 - # for image_length in self._image_lengths[idx]: - # n_pixels = image_length.prod(initial=3) - # images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) - # start += n_pixels - # return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) - def get( self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - patch_size: int | None = None, - image_size: int | None = None, - image_break: bool = False, - image_end: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, From 02f6d8fa114dfd25c0ffbf52868131ba52011f20 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 01:50:33 +0000 Subject: [PATCH 063/161] cleanup --- fast_llm/data/dataset/gpt/memmap.py | 68 +++++++------- fast_llm/data/dataset/gpt/sampled.py | 9 +- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/multi_modal/embedding.py | 9 +- fast_llm/layers/transformer/config.py | 89 ++++++------------- fast_llm/layers/transformer/preprocessing.py | 9 +- fast_llm/layers/vision_encoder/config.py | 16 ++-- fast_llm/layers/vision_encoder/patch_conv.py | 1 - .../layers/vision_encoder/preprocessing.py | 4 +- fast_llm/models/gpt/conversion.py | 3 +- 10 files changed, 92 insertions(+), 117 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 76637565b..372415249 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -114,6 +114,34 @@ def _init( + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + sum([x.nbytes for x in self._spans]) ) + # read preference spans + self._chosen_spans = None + self._rejected_spans = None + if 
self._has_preference_spans and self._version >= 3: + self._chosen_spans = [] + self._rejected_spans = [] + for idx in range(self._num_documents): + self._chosen_spans.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=offset + idx * 2 * np.dtype(np.int32).itemsize, + ) + ) + + rejected_span_offset = offset + np.array(self._chosen_spans).nbytes + for idx in range(self._num_documents): + self._rejected_spans.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, + ) + ) + offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes + self._num_pixels = 0 self._image_lengths = None self._image_positions = None @@ -147,36 +175,6 @@ def _init( ) images_seen += n_images - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes - ) - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -215,7 +213,9 @@ def get( offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) images = None + image_positions = None if self._has_images: + image_positions = self._image_positions[idx] # Truncations with images are not yet supported, so we get all images from the document pixels = np.frombuffer( self._bin_buffer, @@ -275,6 +275,8 @@ def get( return GPTSample( token_ids=token_ids, + images=images, + image_positions=image_positions, loss_masking_spans=sample_spans, chosen_span=chosen_span, rejected_span=rejected_span, @@ -384,10 +386,12 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if total_images: n_images = np.array(n_images, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - image_lengths = np.stack(image_lengths, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) + image_lengths = np.array([]) + im_positions = np.array([]) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index ddd45539c..092a1c1c9 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -519,10 +519,10 @@ def __getitem__(self, index: int) -> typing.Any: document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: + if document_size + tokens_in_sample > self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for 
padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample - if token_count >= token_start: + if token_count > token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) Assert.eq(token_count + padding_size, token_end) @@ -531,6 +531,11 @@ def __getitem__(self, index: int) -> typing.Any: # Move on to the next sample. token_count += padding_size continue + elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: + if token_count + document_size == token_start: + token_count += document_size + document_sampling_index += 1 + continue # Determine if the document belongs to the requested sample. if token_count + document_size >= token_start: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c50a26ab9..ff4d5ec97 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -48,7 +48,6 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, ) vision_encoder: VisionEncoderConfig = Field( - default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 7f09347bf..fa5c0356b 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -59,8 +59,12 @@ def _forward( token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) masked_tokens = (tokens - self._vocab_start_index) * token_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + # Cloning since we will modify the embeddings in-place embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) + # the embeddings tensor are full-sized, but we might get a split of the patch embeddings + # We need to determine the offset in the embeddings tensor for each sample + # and also account for the special image tokens if applicable for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): @@ -86,13 +90,11 @@ def _forward( ) # row_end_src = min(row_start_src + patch_width, patch_end_offset) if self._sequence_parallel: - # Copy with dimensions swapped for sequence parallel case embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ input_start_index:input_end_index, sample_idx ] tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: - # Copy with normal dimension ordering embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ sample_idx, input_start_index:input_end_index ] @@ -125,9 +127,6 @@ def _forward( tokens = split(tokens, group=group, dim=0) if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) - # TODO Soham: get image positions for current split. Maybe in preprocessing? 
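For the sequence-parallel path kept above, each tensor-parallel rank only holds its slice of the patch embeddings, so images are skipped or truncated against the rank's bounds. A minimal sketch of that ownership test, mirroring the skip/break conditions with the boundary handling simplified:

def shard_bounds(tensor_rank, micro_seqlen):
    # Patch indices [start, end) owned by this rank when the patch embeddings
    # are split along the sequence dimension.
    return tensor_rank * micro_seqlen, (tensor_rank + 1) * micro_seqlen

def image_touches_shard(image_offset, num_patches, tensor_rank, micro_seqlen):
    start, end = shard_bounds(tensor_rank, micro_seqlen)
    return image_offset + num_patches >= start and image_offset <= end

# Rank 1 of a 4-way split over 4096 patches owns [1024, 2048); an image covering
# patches [1000, 1100) overlaps it, while one covering [200, 300) does not.
assert image_touches_shard(1000, 100, 1, 1024)
assert not image_touches_shard(200, 100, 1, 1024)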
- # for positions in image_positions: - # if positions > self._distributed_config.tensor_rank # mask padded tokens token_mask = tokens >= 0 masked_tokens = tokens * token_mask diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 1052e01ea..3bb302dd6 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -5,7 +5,7 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace @@ -120,13 +120,7 @@ class RotaryEmbeddingType(str, enum.Enum): default = "default" llama3 = "llama3" yarn = "yarn" - # TODO Soham: generic name? - pixtral = "pixtral" - - -class TransformerType(str, enum.Enum): - lm_decoder = "lm_decoder" - image_encoder = "image_encoder" + rope_2d = "rope_2d" @config_class(registry=True) @@ -193,28 +187,17 @@ def _validate(self) -> None: @property def _transformer_dim_names(self) -> TransformerDimNames: - return TransformerDimNames + if self.type == RotaryEmbeddingType.rope_2d: + return VisionTransformerDimNames + else: + return TransformerDimNames @property def _transformer_kwargs(self) -> TransformerKwargs: - return TransformerKwargs - - -@config_class() -class VisionRotaryConfig(RotaryConfig): - type: RotaryEmbeddingType = Field( - default=RotaryEmbeddingType.pixtral, - desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", - hint=FieldHint.feature, - ) - - @property - def _transformer_dim_names(self) -> VisionTransformerDimNames: - return VisionTransformerDimNames - - @property - def _transformer_kwargs(self) -> VisionTransformerKwargs: - return VisionTransformerKwargs + if self.type == RotaryEmbeddingType.rope_2d: + return VisionTransformerKwargs + else: + return TransformerKwargs for name in RotaryEmbeddingType: @@ -315,10 +298,15 @@ def _validate(self) -> None: TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) -@config_class() +class TransformerType(str, enum.Enum): + lm_decoder = "lm_decoder" + image_encoder = "image_encoder" + + +@config_class(registry=True) class TransformerConfig(BaseModelConfig): _abstract = False - transformer_type: TransformerType = Field( + type: TransformerType = Field( default=TransformerType.lm_decoder, desc="Type of the transformer. Choices: lm_decoder, image_encoder.", hint=FieldHint.architecture, @@ -803,39 +791,20 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: @property def _transformer_kwargs(self) -> TransformerKwargs: - return TransformerKwargs + if self.type == TransformerType.image_encoder: + return VisionTransformerKwargs + else: + return TransformerKwargs @property def _transformer_dim_names(self) -> TransformerDimNames: - return TransformerDimNames - - -@config_class() -class VisionTransformerConfig(TransformerConfig): - """ - Configuration for the Vision Transformer (ViT) model. - """ - - transformer_type: TransformerType = FieldUpdate( - default=TransformerType.image_encoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", - hint=FieldHint.architecture, - ) - causal: bool = FieldUpdate( - default=False, - desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Vision Transformer.", - hint=FieldHint.feature, - ) - rotary: VisionRotaryConfig = FieldUpdate( - default_factory=VisionRotaryConfig, - desc="Configuration for the rotary positional embeddings.", - hint=FieldHint.feature, - ) + if self.type == TransformerType.image_encoder: + return VisionTransformerDimNames + else: + return TransformerDimNames - @property - def _transformer_kwargs(self) -> VisionTransformerKwargs: - return VisionTransformerKwargs - @property - def _transformer_dim_names(self) -> VisionTransformerDimNames: - return VisionTransformerDimNames +for name in TransformerType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + TransformerConfig.register_subclass(name.value, TransformerConfig) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index e5cb5fb89..ae74724c4 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -137,7 +137,6 @@ def get_2d_rotary_frequencies( height_positions = torch.arange(height, device=device, dtype=torch.float64) width_positions = torch.arange(width, device=device, dtype=torch.float64) frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) - # TODO Soham: apply scaling angles_h = torch.outer(height_positions, frequencies[::2]) angles_w = torch.outer(width_positions, frequencies[1::2]) angles = torch.cat( @@ -182,7 +181,7 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) return self._tensor_cache_max_sequence_length = sequence_length - if self._config.type == RotaryEmbeddingType.pixtral: + if self._config.type == RotaryEmbeddingType.rope_2d: self._rotary_embedding_frequencies = get_2d_rotary_frequencies( self._config, num_patches, @@ -199,16 +198,16 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - if self._config.type == RotaryEmbeddingType.pixtral: + if self._config.type == RotaryEmbeddingType.rope_2d: max_num_patches = kwargs[VisionEncoderKwargs.image_size] // kwargs[VisionEncoderKwargs.patch_size] self._create_tensors(kwargs[TransformerKwargs.sequence_length], max_num_patches) else: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if self._config.type == RotaryEmbeddingType.pixtral: + if self._config.type == RotaryEmbeddingType.rope_2d: position_ids = kwargs[self._transformer_kwargs.patch_position_ids] - # TODO Soham: use position_ids_q and position_ids_k for sequence_data_parallelism + # sequence data parallelism is not yet supported with images, so we can safely assume that sequence_q == sequence_k kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] else: diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 267941741..c5b790fe4 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config 
import NormalizationConfig -from fast_llm.layers.transformer.config import VisionTransformerConfig +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.utils import Assert @@ -78,10 +78,11 @@ class ImageNormalizationConfig(Config): class VisionEncoderType(str, enum.Enum): none = "none" + # TODO: better name? normalization, patch size, adapter can change based on implementation, no standard way currently. pixtral = "pixtral" -@config_class() +@config_class(registry=True) class VisionEncoderConfig(BaseModelConfig): _abstract = False @@ -90,8 +91,7 @@ class VisionEncoderConfig(BaseModelConfig): desc="Type of the vision encoder. Choices: none, pixtral.", hint=FieldHint.architecture, ) - transformer: VisionTransformerConfig = Field( - default_factory=VisionTransformerConfig, + transformer: TransformerConfig = Field( desc="Configuration for the vision transformer architecture.", hint=FieldHint.core, ) @@ -106,7 +106,6 @@ class VisionEncoderConfig(BaseModelConfig): hint=FieldHint.optional, ) patch_norm: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) @@ -126,7 +125,6 @@ class VisionEncoderConfig(BaseModelConfig): hint=FieldHint.optional, ) image_normalization: ImageNormalizationConfig = Field( - default_factory=ImageNormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) @@ -163,3 +161,9 @@ def setup_tensor_space(self, tensor_space: TensorSpace): @property def enabled(self) -> bool: return self.type != VisionEncoderType.none + + +for name in VisionEncoderType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + VisionEncoderConfig.register_subclass(name.value, VisionEncoderConfig) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py index 68f22200a..559ecc22d 100644 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -44,7 +44,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._distributed_config = tensor_space.distributed_config self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._lr_scale = config.adapter_lr_scale - # TODO Soham: lr_scale self.weight = ParameterMeta.from_dims( ( self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index a9115c97c..12dc68db6 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -214,7 +214,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) - # TODO Soham: should this be micro_sequence_length? 
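As context for the padding handled below, the variable-length attention boundaries are built from per-image patch counts, with the unused remainder of the fixed sequence treated as one padded segment. A single-sample illustration with made-up counts (the patch additionally offsets cu_seqlens by sequence_length per sample in the batch):

sequence_length = 16
image_patch_counts = [6, 4]            # patches contributed by each image in the sample
cu_seqlens = [0]
for num_patches in image_patch_counts:
    cu_seqlens.append(cu_seqlens[-1] + num_patches)
padding_size = sequence_length - cu_seqlens[-1]
cu_seqlens.append(sequence_length)     # -> [0, 6, 10, 16]
max_seqlen = max(max(image_patch_counts), padding_size)
assert cu_seqlens == [0, 6, 10, 16] and max_seqlen == 6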
padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen if padding_size > max_seqlen: max_seqlen = padding_size @@ -249,7 +248,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) - # TODO Soham: remove assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] patches = torch.cat(patches) patch_position_ids = torch.cat(patch_position_ids) @@ -262,7 +260,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_size, ).to(device=self._tensor_space.distributed.device) kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) - # TODO Soham: handle sequence data parallel + # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 95bbebde2..661f5e516 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -241,12 +241,11 @@ def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) ) # MTP-heads > 0 are thrown away - # TODO Soham: handle offset with MTP for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i + mtp_transformer_layer_index = num_layers + offset - 1 + 2 * i # MTP transformer layer converters += self._create_transformer_layer_converters( f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True From 68431293beed488b50ed963474df1a29d05222ec Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 02:10:54 +0000 Subject: [PATCH 064/161] jpeg dependency --- Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index b7e42d4dc..be579bccb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,8 +3,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs \ - # && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ + && apt-get install --no-install-recommends -y acl git-lfs libjpeg-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From b94b1eefda1d4447678a231e32ccdee9745c2184 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 02:14:56 +0000 Subject: [PATCH 065/161] install libjpeg-dev in gh actions --- .github/workflows/ci.yaml | 1 + Dockerfile | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 912ddaf5e..05ce16216 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,6 +27,7 @@ jobs: - name: Install dependencies run: | + sudo apt install libjpeg-dev pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" diff --git a/Dockerfile b/Dockerfile index be579bccb..dda7b6535 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,8 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. 
RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs libjpeg-dev \ + # && apt-get install --no-install-recommends -y acl git-lfs libjpeg-dev \ + && apt-get install --no-install-recommends -y acl git-lfs \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From 9e4f14fe19d3b389a257eefb6245bc15869639b5 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 02:35:16 +0000 Subject: [PATCH 066/161] fix sampling test --- .github/workflows/docs.yaml | 2 ++ fast_llm/data/dataset/gpt/indexed.py | 8 ++++++++ tests/data/test_sampling.py | 3 +++ 3 files changed, 13 insertions(+) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 93191972e..e8cb56d85 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -29,6 +29,7 @@ jobs: restore-keys: | mkdocs-material- - run: | + sudo apt install libjpeg-dev pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ @@ -56,6 +57,7 @@ jobs: restore-keys: | mkdocs-material- - run: | + sudo apt install libjpeg-dev pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 56c4c8927..2c7aefc80 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -34,6 +34,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": else GPTSampledIndexedDataset(self, sampling) ) + @property + @abc.abstractmethod + def has_images(self) -> bool: + """ + Whether the dataset contains images. + This is used to determine whether to use image-related fields in the sampled data. 
+ """ + class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): """ diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 386795826..a0aff3a72 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -106,6 +106,9 @@ def get_document_size(self, index: int) -> int: def name(self) -> str: return "dataset" + def has_images(self) -> bool: + return False + TEST_DATASET = SimpleGPTIndexedDataset( [ From d1c804ff558e0b34f7dc47822a281cbf9c2c796c Mon Sep 17 00:00:00 2001 From: root Date: Fri, 6 Jun 2025 05:46:22 +0000 Subject: [PATCH 067/161] fix --- fast_llm/data/dataset/gpt/memmap.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 372415249..ba2aa5800 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -56,9 +56,6 @@ def _init( if self._version >= 3: self._has_preference_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 4: self._has_images = struct.unpack(" Date: Mon, 9 Jun 2025 17:40:10 +0000 Subject: [PATCH 068/161] fix data cache reloading --- fast_llm/data/dataset/gpt/sampled.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 092a1c1c9..b4648af40 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -236,9 +236,10 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) + yaml_data["unshuffled_tokens"] = loaded_yaml_data.get("unshuffled_tokens", 0) self._load_yaml_data(yaml_data) - if not self._truncate_documents and not self._parameters.use_preference_loss_spans: - del loaded_yaml_data["unshuffled_tokens"] + # if not self._truncate_documents and not self._parameters.use_preference_loss_spans: + # del loaded_yaml_data["unshuffled_tokens"] if loaded_yaml_data != yaml_data: raise RuntimeError( From cba6986a5d9665f7dc26ff50ebc6875667af43e5 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 9 Jun 2025 17:43:20 +0000 Subject: [PATCH 069/161] fix tokenization --- fast_llm/data/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 5988769f2..24eb77bd3 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -69,7 +69,7 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li char_pos = image_position image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") if char_pos < start: - self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) + tokenized_text = self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) beginning_of_text = False token_ids.extend(tokenized_text) char_pos = start From 275fefa1dbb0dcd4a85417a0488e115a6efe647c Mon Sep 17 00:00:00 2001 From: shruthan Date: Wed, 11 Jun 2025 12:05:26 -0700 Subject: [PATCH 070/161] pixtral SFT (#296) Co-authored-by: sohamparikh --- fast_llm/data/dataset/gpt/memmap.py | 8 +++++--- fast_llm/data/dataset/gpt/sampled.py | 9 +++++---- fast_llm/data/preparator/gpt_memmap/prepare.py | 4 ++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ba2aa5800..acc7914f1 100644 --- 
a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -58,6 +58,8 @@ def _init( if self._version >= 4: self._has_images = struct.unpack(" typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 - if sample.image_positions: + has_images = sample.image_positions is not None + if has_image_positions: for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens @@ -594,7 +595,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx = 0 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) while image_position < loss_masking_span[0]: @@ -602,7 +603,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx += 1 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) span_image_tokens = 0 @@ -611,7 +612,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx += 1 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) loss_masking_span[0] += prev_image_tokens diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index ad3dd4496..0b6803100 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -158,13 +158,13 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), + item["images"] if self._config.dataset.images else None, + item["image_positions"] if self._config.dataset.image_positions else None, ( np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) if self._config.dataset.loss_masking_spans else None ), - item["images"] if self._config.dataset.images else None, - item["image_positions"] if self._config.dataset.image_positions else None, item.get("chosen_token_spans", None), item.get("rejected_token_spans", None), ) From 605cc7ffac53047bb82871233413ac13cef35cac Mon Sep 17 00:00:00 2001 From: root Date: Wed, 11 Jun 2025 21:12:51 +0000 Subject: [PATCH 071/161] review comments --- Dockerfile | 1 - fast_llm/data/dataset/gpt/memmap.py | 45 ++++++++----------- fast_llm/data/tokenizer.py | 28 ++++++------ fast_llm/layers/transformer/transformer.py | 4 ++ .../layers/transformer/vision_transformer.py | 16 ------- fast_llm/layers/vision_encoder/patch_conv.py | 27 ----------- .../layers/vision_encoder/preprocessing.py | 12 ----- fast_llm/models/gpt/model.py | 3 +- 8 files changed, 39 insertions(+), 97 deletions(-) delete mode 100644 fast_llm/layers/transformer/vision_transformer.py diff --git a/Dockerfile b/Dockerfile index dda7b6535..8c2efa85e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,6 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. 
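The sampled-dataset hunk above walks image positions to shift each loss-masking span by the placeholder tokens inserted before and inside it. A simplified sketch of that adjustment, with hypothetical inputs and ignoring the break/end-token details:

def shift_spans(
    spans: list[tuple[int, int]],
    image_positions: list[int],
    image_token_counts: list[int],
) -> list[tuple[int, int]]:
    # Shift span boundaries by the image tokens inserted at or before each boundary.
    shifted = []
    for start, end in spans:
        tokens_before_start = sum(
            count for position, count in zip(image_positions, image_token_counts) if position <= start
        )
        tokens_before_end = sum(
            count for position, count in zip(image_positions, image_token_counts) if position <= end
        )
        shifted.append((start + tokens_before_start, end + tokens_before_end))
    return shifted

# Two images of 4 placeholder tokens each, inserted at text positions 2 and 10.
print(shift_spans([(0, 5), (8, 12)], image_positions=[2, 10], image_token_counts=[4, 4]))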
RUN apt-get update \ - # && apt-get install --no-install-recommends -y acl git-lfs libjpeg-dev \ && apt-get install --no-install-recommends -y acl git-lfs \ && rm -rf /var/lib/apt/lists/* \ && git lfs install diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index acc7914f1..642cd9800 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -58,8 +58,6 @@ def _init( if self._version >= 4: self._has_images = struct.unpack("= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._image_lengths = [] + self._image_sizes = [] self._image_positions = [] images_seen = 0 for n_images in self._n_images: - self._image_lengths.append( + self._image_sizes.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, @@ -159,7 +154,7 @@ def _init( offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) - self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._num_pixels += self._image_sizes[-1].prod(axis=1, initial=3).sum() self._image_positions.append( np.frombuffer( self._index_bin_buffer, @@ -214,19 +209,19 @@ def get( image_positions = None if self._has_images: image_positions = self._image_positions[idx] - + # Truncations with images are not yet supported, so we get all images from the document pixels = np.frombuffer( self._bin_buffer, dtype=np.dtype(np.uint8), - count=self._image_lengths[idx].prod(initial=3, axis=1).sum(), + count=self._image_sizes[idx].prod(initial=3, axis=1).sum(), offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, ) images = [] start = 0 - for image_length in self._image_lengths[idx]: - n_pixels = image_length.prod(initial=3) - images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) + for image_size in self._image_sizes[idx]: + n_pixels = image_size.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(3, image_size[0], image_size[1])) start += n_pixels sample_spans = None if use_loss_masking_spans and self._spans is not None: @@ -302,10 +297,10 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes, self._image_lengths + return self._document_sizes, self._image_sizes def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] + return self._document_sizes[index].item(), self._image_sizes[index] if self._has_images else [] @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -314,7 +309,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP num_documents = 0 doc_lengths = [] n_images = [] - image_lengths = [] + image_sizes = [] im_positions = [] total_images = 0 pointers = [] @@ -353,7 +348,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP img = img.convert("RGB") pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
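Because pixels are appended to the .bin file directly after a document's token ids, reading an image back only needs the document pointer, the token count, and the stored (height, width) pair. A minimal NumPy round-trip sketch (in-memory bytes instead of a memory map, made-up sizes):

import numpy as np

token_dtype = np.int32
tokens = np.arange(10, dtype=token_dtype)
image = np.random.randint(0, 256, size=(3, 32, 48), dtype=np.uint8)  # CHW, uint8

# Write: token ids first, then the raw pixel bytes.
buffer = tokens.tobytes(order="C") + image.tobytes(order="C")

# Read: skip past the tokens, then reshape the pixel bytes using the stored size.
pointer, document_size, (height, width) = 0, len(tokens), (32, 48)
pixels = np.frombuffer(
    buffer,
    dtype=np.uint8,
    count=3 * height * width,
    offset=pointer + document_size * np.dtype(token_dtype).itemsize,
).reshape(3, height, width)
assert np.array_equal(pixels, image)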
- image_lengths.append(np.array(pixels.shape[1:])) + image_sizes.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.extend(document.image_positions) @@ -385,11 +380,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if total_images: n_images = np.array(n_images, dtype=np.int32) - image_lengths = np.stack(image_lengths, dtype=np.int32) + image_sizes = np.stack(image_sizes, dtype=np.int32) im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - image_lengths = np.array([]) + image_sizes = np.array([]) im_positions = np.array([]) # Write the index file (.idx) @@ -402,12 +397,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP idx_stream.write(struct.pack(" 0 else 0)) - # Placeholder flag for preference spans - idx_stream.write(struct.pack(" 0 else 0)) # Flag to indicate whether preference loss-masking spans are present idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) + # Flag to indicate whether images are present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" tuple[li image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") for start, end in char_spans: + # Tokenize all text before the span, with image positions in mind (i.e., break text at image positions). while image_position <= start: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False @@ -76,6 +77,7 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li len(token_ids) span_length = 0 token_start = len(token_ids) + # Tokenize all text before the end of the span while image_position <= end: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False @@ -85,21 +87,21 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li char_pos = image_position image_idx += 1 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - if char_pos < end: - if end >= len(text) - 1: - tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=True) - beginning_of_text = False - token_ids.extend(tokenized_text) - span_length += len(tokenized_text) - char_pos = end + 1 - else: - tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - span_length += len(tokenized_text) - char_pos = end + 1 + # Tokenize the last part of the span, since there are no more images + if char_pos < end + 1: + # end of span is end of text + tokenized_text = self._tokenize( + text[char_pos : end + 1], + begin=beginning_of_text, + end=(end >= len(text) - 1), + ) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = end + 1 token_spans.append((token_start, token_start + span_length - 1)) + # Tokenize text remaining after the last span while image_position <= len(text): image_position = image_positions[image_idx] tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 392ebb889..784a0f051 100644 --- a/fast_llm/layers/transformer/transformer.py +++ 
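The .idx header assembled above is a sequence of fixed-size `struct` fields: a version, one byte per optional feature flag (loss-masking spans, preference spans, images), then the token data type and the document counts. A standalone sketch of writing and reading such a header, with a hypothetical field order (the real index packs more fields than shown):

import io
import struct

def write_header(stream, version, has_spans, has_preference_spans, has_images, dtype_code, num_documents):
    stream.write(struct.pack("<Q", version))
    stream.write(struct.pack("<B", int(has_spans)))
    stream.write(struct.pack("<B", int(has_preference_spans)))
    stream.write(struct.pack("<B", int(has_images)))
    stream.write(struct.pack("<B", dtype_code))
    stream.write(struct.pack("<Q", num_documents))

def read_header(stream):
    (version,) = struct.unpack("<Q", stream.read(8))
    has_spans, has_preference_spans, has_images, dtype_code = struct.unpack("<4B", stream.read(4))
    (num_documents,) = struct.unpack("<Q", stream.read(8))
    return version, bool(has_spans), bool(has_preference_spans), bool(has_images), dtype_code, num_documents

buffer = io.BytesIO()
write_header(buffer, version=4, has_spans=True, has_preference_spans=False, has_images=True, dtype_code=2, num_documents=1000)
buffer.seek(0)
print(read_header(buffer))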
b/fast_llm/layers/transformer/transformer.py @@ -149,3 +149,7 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + +class VisionTransformerLayer(TransformerLayer): + _name: str = "Vision transformer layer" diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py deleted file mode 100644 index 7c1be0d16..000000000 --- a/fast_llm/layers/transformer/vision_transformer.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - -from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.layers.transformer.config import VisionTransformerKwargs -from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.tensor import TensorMeta - - -class VisionTransformerLayer(TransformerLayer): - _name: str = "Vision transformer layer" - - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[VisionTransformerKwargs.hidden_dims] - if self._return_input: - dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py index 559ecc22d..3d1845dd8 100644 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -10,33 +10,6 @@ from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ -def position_ids_in_meshgrid(patch_embeddings_list, max_size): - positions = [] - for patch in patch_embeddings_list: - height, width = patch.shape[-2:] - mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_size + v_grid - positions.append(ids[:, 0]) - return torch.cat(positions) - - -def generate_block_attention_mask(patch_embeds_list, tensor): - dtype = tensor.dtype - device = tensor.device - seq_len = tensor.shape[1] - d_min = torch.finfo(dtype).min - causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) - - block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) - block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) - for start, end in zip(block_start_idx, block_end_idx): - causal_mask[start:end, start:end] = 0 - - causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) - return causal_mask - - class PatchConv(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 12dc68db6..77220a063 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -104,18 +104,6 @@ def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_s return torch.cat((inv_freq, inv_freq), dim=-1) -def position_ids_in_meshgrid(image_sizes: list[torch.Tensor], max_size: int, patch_size: int) -> torch.Tensor: - positions = [] - for h, w in image_sizes: - patch_height = h // patch_size - patch_width = w // patch_size - mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_size + v_grid - positions.append(ids[:, 0]) - return positions - - def 
position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: patch_height = height // patch_size patch_width = width // patch_size diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6b16938fb..bf3778cc7 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -113,13 +113,12 @@ def get_output_layers(self) -> list[Layer]: return layers def get_vision_layers(self) -> list[Layer]: - patch_conv = PatchConv(self._config.vision_encoder, self._tensor_space) vit_layers = [ VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ - patch_conv, + PatchConv(self._config.vision_encoder, self._tensor_space), *vit_layers, VisionAdapter(self._config.vision_encoder, self._tensor_space), MultiModalEmbedding(self._config, self._tensor_space), From 06aa7401119302d1ced30c54012e7b5d19e88ea9 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 12 Jun 2025 17:40:56 +0000 Subject: [PATCH 072/161] simplified tokenization with spans --- fast_llm/data/tokenizer.py | 97 ++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 57 deletions(-) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index d8b0ff87b..284ae21f7 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,77 +42,60 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize( + self, text: str, char_spans=None, image_positions=None + ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ - Tokenize the input text and return the tokenized input_ids and if provided, token spans and image positions. + Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. + This version simplifies logic by merging all relevant positions, sorting, and tokenizing between them. """ if not image_positions: image_positions = [] if not char_spans: char_spans = [] - image_idx = 0 - char_pos = 0 + # Collect all positions with their type + positions = [] + for idx, pos in enumerate(image_positions): + positions.append((pos, "image")) + for idx, (start, end) in enumerate(char_spans): + positions.append((start, "span_start")) + positions.append((end + 1, "span_end")) + # Sort positions by character index. We assume that image and span positions are individually sorted and spans do not overlap + positions = sorted(positions, key=lambda x: x[0]) + token_ids = [] - image_token_positions = [] token_spans = [] - beginning_of_text = True - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + image_token_positions = [] + char_pos = 0 + current_span_start = None - for start, end in char_spans: - # Tokenize all text before the span, with image positions in mind (i.e., break text at image positions). 
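The `position_ids_in_meshgrid` helper kept in the preprocessing module flattens a (rows, cols) patch grid into one id per patch, using the maximum grid width as the row stride so ids from differently sized images never collide. A standalone sketch of the same computation (plain PyTorch, hypothetical sizes):

import torch

def patch_position_ids(height: int, width: int, max_size: int, patch_size: int) -> torch.Tensor:
    rows, cols = height // patch_size, width // patch_size
    mesh = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing="ij")
    grid = torch.stack(mesh, dim=-1).reshape(-1, 2)
    # Row-major id inside a virtual (max_size x max_size) patch grid.
    return grid[:, 0] * max_size + grid[:, 1]

# A 64x48 image with 16x16 patches inside a 1024 // 16 = 64-wide virtual grid.
print(patch_position_ids(64, 48, max_size=64, patch_size=16))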
- while image_position <= start: - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) - image_idx += 1 - char_pos = image_position - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - if char_pos < start: - tokenized_text = self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - char_pos = start - len(token_ids) - span_length = 0 - token_start = len(token_ids) - # Tokenize all text before the end of the span - while image_position <= end: - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) - span_length += len(tokenized_text) - char_pos = image_position - image_idx += 1 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - # Tokenize the last part of the span, since there are no more images - if char_pos < end + 1: - # end of span is end of text + for position in positions: + if char_pos < position[0]: tokenized_text = self._tokenize( - text[char_pos : end + 1], - begin=beginning_of_text, - end=(end >= len(text) - 1), + text[char_pos : position[0]], begin=(char_pos == 0), end=position[0] > len(text) - 1 ) - beginning_of_text = False token_ids.extend(tokenized_text) - span_length += len(tokenized_text) - char_pos = end + 1 - token_spans.append((token_start, token_start + span_length - 1)) - - # Tokenize text remaining after the last span - while image_position <= len(text): - image_position = image_positions[image_idx] - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) - beginning_of_text = False + char_pos = position[0] + # beginning_of_text = False + if position[1] == "image": + image_token_positions.append(len(token_ids)) + elif position[1] == "span_start": + assert ( + current_span_start is None + ), "Starting a new span before current has ended, please check for overlapping spans" + current_span_start = len(token_ids) + elif position[1] == "span_end": + assert ( + current_span_start is not None + ), "Closing a span that has not started, please check for overlapping spans" + # spans are inclusive, so we take the index of the last token in the span + token_spans.append((current_span_start, len(token_ids) - 1)) + current_span_start = None + # Handle any remaining text after the last position and add EOS token + if char_pos < len(text): + tokenized_text = self._tokenize(text[char_pos:], begin=(char_pos == 0), end=True) token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) - char_pos = image_position - image_idx += 1 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - tokenized_text = self._tokenize(text[char_pos:], begin=beginning_of_text, end=True) - token_ids.extend(tokenized_text) return token_ids, token_spans, image_token_positions From 30e3d34acca8a1cb89149ad69ea0720fa0d327ca Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Thu, 12 Jun 2025 10:42:12 -0700 Subject: [PATCH 073/161] Update fast_llm/data/preparator/gpt_memmap/prepare.py Co-authored-by: RaymondLi0 --- fast_llm/data/preparator/gpt_memmap/prepare.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 0b6803100..43849857b 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -329,6 +329,7 @@ def run(self) -> None: if self._config.dataset.images else 0 ) + # Add the token-equivalent bytes of pixels to determine shard size total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens From c1aa7094924cd6931d27db5e02384fb79aaa1b36 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 12 Jun 2025 17:47:17 +0000 Subject: [PATCH 074/161] rename --- fast_llm/data/dataset/gpt/config.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 8 ++++---- fast_llm/engine/schedule/config.py | 2 +- fast_llm/layers/transformer/preprocessing.py | 2 +- fast_llm/layers/vision_encoder/config.py | 2 +- .../layers/vision_encoder/preprocessing.py | 18 +++++++++--------- fast_llm/models/gpt/model.py | 4 ++-- fast_llm/models/gpt/trainer.py | 2 +- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9a5aa2007..250bfcb09 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -76,7 +76,7 @@ class GPTSamplingParameters(SamplingParameters): use_preference_loss_spans: bool = False cross_document_attention: bool = True patch_size: int | None = None - image_size: int | None = None + max_image_size: int | None = None image_break_token: int | None = None image_end_token: int | None = None # How many extra tokens to add to the sequence length. diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 255a30963..d4bcacddd 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -151,8 +151,8 @@ def _sample(self) -> None: get_num_image_tokens( *get_resize_dims( *size, - self._parameters.image_size, - self._parameters.image_size, + self._parameters.max_image_size, + self._parameters.max_image_size, self._parameters.patch_size, ), self._parameters.patch_size, @@ -496,8 +496,8 @@ def __getitem__(self, index: int) -> typing.Any: resized_image_lengths = [ get_resize_dims( *image_length, - self._parameters.image_size, - self._parameters.image_size, + self._parameters.max_image_size, + self._parameters.max_image_size, self._parameters.patch_size, ) for image_length in image_lengths diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 204abdf1c..f5c1bc133 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -50,7 +50,7 @@ class BatchConfig(Config): hint=FieldHint.setup, ) # Image inputs - image_size: int | None = Field( + max_image_size: int | None = Field( default=None, desc="Maximum image height and width", hint=FieldHint.optional, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index ae74724c4..9b79aa1b3 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -199,7 +199,7 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: if self._config.type == RotaryEmbeddingType.rope_2d: - max_num_patches = kwargs[VisionEncoderKwargs.image_size] // kwargs[VisionEncoderKwargs.patch_size] + max_num_patches = 
kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size] self._create_tensors(kwargs[TransformerKwargs.sequence_length], max_num_patches) else: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index c5b790fe4..2ea7f6114 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -22,7 +22,7 @@ class VisionEncoderKwargs: images = "images" image_patches = "image_patches" image_positions = "image_positions" - image_size = "image_size" + max_image_size = "max_image_size" image_sizes = "image_sizes" image_mean = "image_normalization_mean" image_std = "image_normalization_std" diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 77220a063..ebd41b3d7 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -84,9 +84,9 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: return F.pad(image, (0, 0, depth_padding, width_padding), 0) -def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_size: int) -> torch.Tensor: +def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) - max_patches_per_side = image_size // patch_size + max_patches_per_side = max_image_size // patch_size h = torch.arange(max_patches_per_side) w = torch.arange(max_patches_per_side) @@ -135,19 +135,19 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) - im_height = kwargs.get(VisionEncoderKwargs.image_size) - im_width = kwargs.get(VisionEncoderKwargs.image_size) + max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) + im_width = kwargs.get(VisionEncoderKwargs.max_image_size) patch_size = kwargs[VisionEncoderKwargs.patch_size] image_positions = kwargs.get(VisionEncoderKwargs.image_positions) image_sizes = [ - [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] + [get_resize_dims(im.size(1), im.size(2), max_image_size, im_width, patch_size=patch_size) for im in ims] for ims in images ] kwargs[VisionEncoderKwargs.image_sizes] = image_sizes images = [ [ normalize( - resize(image, im_height, im_width, patch_size).to( + resize(image, max_image_size, im_width, patch_size).to( dtype=self._tensor_space.distributed_config.training_dtype.torch ) / kwargs[VisionEncoderKwargs.image_rescale_factor], @@ -219,7 +219,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) if sizes: position_ids = torch.cat( - [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + [position_ids_in_meshgrid(*size, max_image_size // patch_size, patch_size) for size in sizes] ).to(device=self._tensor_space.distributed.device) else: position_ids = torch.tensor( @@ -244,10 +244,10 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( kwargs[VisionEncoderKwargs.rope_theta], kwargs[VisionEncoderKwargs.kv_channels], - im_height, + max_image_size, patch_size, ).to(device=self._tensor_space.distributed.device) - kwargs[VisionEncoderKwargs.max_image_tokens] = 
div(im_height * im_width, patch_size**2) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, patch_size**2) # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bf3778cc7..a1479a34d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -162,7 +162,7 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder.enabled: - image_size = batch_meta.image_size + max_image_size = batch_meta.max_image_size image_mean = [ self._config.vision_encoder.image_normalization.mean_r, self._config.vision_encoder.image_normalization.mean_g, @@ -176,7 +176,7 @@ def preprocess_meta( image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, - VisionEncoderKwargs.image_size: image_size, + VisionEncoderKwargs.max_image_size: max_image_size, VisionEncoderKwargs.image_mean: image_mean, VisionEncoderKwargs.image_std: image_std, VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index b2736b447..92cb20554 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -37,7 +37,7 @@ def _get_sampling_parameters( parameters.update( { "patch_size": self._config.model.base_model.vision_encoder.patch_size, - "image_size": self._config.batch.image_size, + "max_image_size": self._config.batch.max_image_size, "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, } From 8e106f74041a4c4181ecf414968f91d975000c5f Mon Sep 17 00:00:00 2001 From: root Date: Thu, 12 Jun 2025 23:18:13 +0000 Subject: [PATCH 075/161] fix conversion --- fast_llm/models/gpt/conversion.py | 8 +++++++- fast_llm/models/gpt/model.py | 3 +-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 661f5e516..080b8b3ae 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -28,7 +28,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType -from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig, TransformerType from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( GPTBaseModelConfig, @@ -576,6 +576,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "type"),), fast_llm_value=TransformerType.image_encoder + ), ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), ConstantImportParamConverter(fast_llm_names=(("type",),), 
fast_llm_value=VisionEncoderType.pixtral), ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), @@ -639,6 +642,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), export_names=(("head_dim",),), ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d + ), RenameParamConverter( fast_llm_names=( ( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a1479a34d..8fc8c830b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -28,8 +28,7 @@ FlashAttnVarlenPreprocessor, RotaryEmbeddingPreprocessor, ) -from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer +from fast_llm.layers.transformer.transformer import TransformerLayer, VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.layers.vision_encoder.patch_conv import PatchConv From 080dcb58e8a8c364513eb7f915950c7c7ffb5c28 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 16 Jun 2025 21:15:27 +0000 Subject: [PATCH 076/161] fix sequence lengths, parallel conv --- fast_llm/data/dataset/gpt/sampled.py | 33 ++++++++++++++---------- fast_llm/layers/multi_modal/embedding.py | 5 +--- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index d4bcacddd..2f1575f7d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -551,26 +551,23 @@ def __getitem__(self, index: int) -> typing.Any: ) start_pos = 0 has_images = sample.image_positions is not None - if has_image_positions: + if has_images: + sample_token_ids = [] for idx, im_position in enumerate(sample.image_positions): - # image_positions.append(im_positions + len(token_ids) + image_tokens_added) - # Add placeholders for image tokens - token_ids.append(sample.token_ids[start_pos:im_position]) - text_tokens_added += len(token_ids[-1]) - image_positions.append(text_tokens_added + image_tokens_added) + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] if self._parameters.image_break_token is not None: height, width = resized_image_lengths[idx] num_patches_h = div(height, self._parameters.patch_size) num_patches_w = div(width, self._parameters.patch_size) - - # Create image token placeholder array image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) - - # Add break tokens after each row except the last row + # account for break tokens after each row for row in range(num_patches_h - 1): position = (row + 1) * num_patches_w + row image_token_array[position] = self._parameters.image_break_token - # add end token if specified, else break token + # handle the last row separately last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 if self._parameters.image_end_token is not None: image_token_array[last_row_position] = self._parameters.image_end_token @@ -580,11 +577,19 @@ def __getitem__(self, index: int) -> typing.Any: image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) if self._parameters.image_end_token is not 
None: image_token_array[-1] = self._parameters.image_end_token - token_ids.append(image_token_array) + segment = np.concatenate([text_part, image_token_array], dtype=np.int64) + sample_token_ids.append(segment) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) image_tokens_added += image_sizes[idx] start_pos = im_position - token_ids.append(sample.token_ids[start_pos:]) - text_tokens_added += len(token_ids[-1]) + # Add the last text segment after the last image + sample_token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(sample_token_ids[-1]) + token_ids.append(np.concatenate(sample_token_ids)) + else: + token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids[-1]) if sample.images: images.append(sample.images) else: diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index fa5c0356b..948b2acf9 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -3,7 +3,7 @@ import torch from fast_llm.core.distributed import set_generator -from fast_llm.core.ops import gather, reduce_forward, split +from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -61,7 +61,6 @@ def _forward( embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa # Cloning since we will modify the embeddings in-place embeddings = embeddings.clone() - input_ = gather(input_, group, dim=0) # the embeddings tensor are full-sized, but we might get a split of the patch embeddings # We need to determine the offset in the embeddings tensor for each sample # and also account for the special image tokens if applicable @@ -93,12 +92,10 @@ def _forward( embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ input_start_index:input_end_index, sample_idx ] - tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ sample_idx, input_start_index:input_end_index ] - tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset input_end_index = ( From f1868687f2a230a98613ab936de87d6a145b17d1 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 16 Jun 2025 21:19:50 +0000 Subject: [PATCH 077/161] minor --- fast_llm/data/dataset/gpt/sampled.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2f1575f7d..8641ee707 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -577,8 +577,7 @@ def __getitem__(self, index: int) -> typing.Any: image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) if self._parameters.image_end_token is not None: image_token_array[-1] = self._parameters.image_end_token - segment = np.concatenate([text_part, image_token_array], dtype=np.int64) - sample_token_ids.append(segment) + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) text_tokens_added += len(text_part) image_positions.append(text_tokens_added + image_tokens_added) image_tokens_added += 
image_sizes[idx] From 6b9ea2e1b22e83fa936aae66c95d66218adfa0b3 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 16 Jun 2025 21:37:51 +0000 Subject: [PATCH 078/161] fix image at beginning --- fast_llm/data/tokenizer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 284ae21f7..7268ba3ce 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -56,9 +56,9 @@ def tokenize( # Collect all positions with their type positions = [] - for idx, pos in enumerate(image_positions): + for pos in image_positions: positions.append((pos, "image")) - for idx, (start, end) in enumerate(char_spans): + for start, end in char_spans: positions.append((start, "span_start")) positions.append((end + 1, "span_end")) # Sort positions by character index. We assume that image and span positions are individually sorted and spans do not overlap @@ -71,6 +71,7 @@ def tokenize( current_span_start = None for position in positions: + # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times if char_pos < position[0]: tokenized_text = self._tokenize( text[char_pos : position[0]], begin=(char_pos == 0), end=position[0] > len(text) - 1 @@ -79,7 +80,11 @@ def tokenize( char_pos = position[0] # beginning_of_text = False if position[1] == "image": - image_token_positions.append(len(token_ids)) + if position[0] == 0: + # image should be after the bos token + image_token_positions.append(1) + else: + image_token_positions.append(len(token_ids)) elif position[1] == "span_start": assert ( current_span_start is None From ad18ea1903e8ceacd58ab2817759c2f963ccb511 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Fri, 20 Jun 2025 15:14:22 -0400 Subject: [PATCH 079/161] pixtral fix conversion (#315) --- fast_llm/models/gpt/conversion.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 080b8b3ae..a7e624ffe 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -563,6 +563,26 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class PixtralNumHeadsConverter(ParamConverter): + """ + Pixtral encoder uses Multi-Head Attention. + Map `num_attention_heads` and `head_groups` to a single `num_heads` parameter. 
+ """ + + def __post_init__(self): + Assert.eq(len(self.fast_llm_names), 2) + Assert.eq(len(self.export_names), 1) + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads, head_groups) = fast_llm_values + assert head_groups == num_heads, "Pixtral encoder expects num_heads == head_groups (MHA)" + return (num_heads,) + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads,) = export_values + return (num_heads, num_heads) + + class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -600,23 +620,18 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), export_names=(("hidden_size",),), ), - RenameParamConverter( + PixtralNumHeadsConverter( fast_llm_names=( ( "transformer", "num_attention_heads", ), - ), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=( ( "transformer", "head_groups", ), ), - export_names=(("num_key_value_heads",),), + export_names=(("num_attention_heads",),), ), RenameParamConverter( fast_llm_names=( From 29e66d944034f58ae454b4cfd028be21efb8f848 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 25 Jun 2025 21:11:56 +0000 Subject: [PATCH 080/161] handle no image samples --- fast_llm/data/dataset/gpt/memmap.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 642cd9800..c7a99f10f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -145,6 +145,7 @@ def _init( self._image_sizes = [] self._image_positions = [] images_seen = 0 + num_total_images = self._n_images.sum() for n_images in self._n_images: self._image_sizes.append( np.frombuffer( @@ -162,8 +163,8 @@ def _init( count=n_images, offset=offset + self._n_images.nbytes - + 2 * self._n_images.sum() * np.dtype(np.int32).itemsize - + images_seen * np.dtype(np.int32).itemsize, + + 2 * num_total_images * np.dtype(np.int32).itemsize + + +images_seen * np.dtype(np.int32).itemsize, ) ) images_seen += n_images @@ -352,6 +353,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.extend(document.image_positions) + else: + n_images.append(0) # Update metadata doc_length = len(document.token_ids) From 06a0910bd06de25f84115b89dd8ddb566780a24a Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Jun 2025 18:23:38 +0000 Subject: [PATCH 081/161] mask special image tokens --- fast_llm/models/gpt/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8fc8c830b..9bef7ae5c 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -410,6 +410,12 @@ def preprocess( if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) + if self._config.vision_encoder.enabled: + labels = labels.clone() + if self._config.vision_encoder.image_break_token is not None: + labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + if self._config.vision_encoder.image_end_token is not None: + labels = torch.where(labels == 
self._config.vision_encoder.image_end_token, -100, labels) kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) From bbd71dfb2706d19022eccb5e55019e526ded2007 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 27 Jun 2025 16:34:27 +0000 Subject: [PATCH 082/161] avoid multiple labels cloning --- fast_llm/models/gpt/model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9bef7ae5c..23bb3d067 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -387,9 +387,11 @@ def preprocess( labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config + labels_cloned = False if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() + labels_cloned = True for i, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue @@ -411,10 +413,15 @@ def preprocess( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) if self._config.vision_encoder.enabled: - labels = labels.clone() if self._config.vision_encoder.image_break_token is not None: + if not labels_cloned: + labels = labels.clone() + labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) if self._config.vision_encoder.image_end_token is not None: + if not labels_cloned: + labels = labels.clone() + labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) From 96a5fd82f5200712274ba217323793b91038615b Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Jul 2025 19:11:33 +0000 Subject: [PATCH 083/161] fix training --- Dockerfile | 2 +- fast_llm/functional/triton/mlp.py | 4 +- fast_llm/layers/transformer/rotary/config.py | 8 ++ fast_llm/layers/transformer/rotary/rotary.py | 72 +++++++++- fast_llm/models/gpt/conversion.py | 134 +++++++++---------- fast_llm/models/gpt/model.py | 5 +- 6 files changed, 150 insertions(+), 75 deletions(-) diff --git a/Dockerfile b/Dockerfile index e98223de8..6c013c14d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. 
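The label-masking change above replaces the special image break/end tokens with -100 so they never contribute to the loss, cloning the labels at most once along the way. A condensed sketch of the masking itself (hypothetical token ids):

import torch

def mask_special_image_tokens(labels: torch.Tensor, image_break_token: int | None, image_end_token: int | None) -> torch.Tensor:
    # torch.where returns a new tensor, so the input batch is left untouched.
    for token in (image_break_token, image_end_token):
        if token is not None:
            labels = torch.where(labels == token, torch.full_like(labels, -100), labels)
    return labels

labels = torch.tensor([5, 12, 7, 13, 9])
print(mask_special_image_tokens(labels, image_break_token=12, image_end_token=13))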
COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index a34af4f5e..f3d9d7d0c 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -47,7 +47,7 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type in ["gelu_pytorch_tanh", "gelu"]: + if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -97,7 +97,7 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type in ["gelu_pytorch_tanh", "gelu"]: + if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index ce7af88d5..d7285714f 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -136,3 +136,11 @@ def _get_configurable_class(self) -> "type[YarnRotary]": from fast_llm.layers.transformer.rotary.rotary import YarnRotary return YarnRotary + + +@config_class(dynamic_type={RotaryConfig: "rope_2d"}) +class Rotary2DConfig(DefaultRotaryConfig): + def _get_configurable_class(self) -> "type[Rotary2D]": + from fast_llm.layers.transformer.rotary.rotary import Rotary2D + + return Rotary2D diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..b2c69dd8d 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,14 +8,16 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, VisionTransformerKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + Rotary2DConfig, RotaryConfig, YarnRotaryConfig, ) +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -212,3 +214,71 @@ def _get_correction(self, beta: float, dim: int) -> float: * math.log(self._config.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self._config.theta)) ) + + +class Rotary2D[ConfigType: DefaultRotaryConfig](DefaultRotary[Rotary2DConfig]): + _rotary_embedding_frequencies: torch.Tensor + _tensor_cache_max_num_patches: int = -1 + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + max_num_patches = kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(max_num_patches) + position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] + kwargs[VisionTransformerKwargs.rotary_freq_q] = 
self._rotary_embedding_frequencies[:, position_ids] + kwargs[VisionTransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + kwargs[VisionTransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_q_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=VisionTransformerKwargs.rotary_freq_q, + ) + kwargs[VisionTransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_k_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=VisionTransformerKwargs.rotary_freq_k, + ) + + def _create_tensors(self, max_num_patches: int) -> None: + if max_num_patches <= self._tensor_cache_max_num_patches: + return + self._tensor_cache_max_num_patches = max_num_patches + + self._rotary_embedding_frequencies = self._get_frequencies( + max_num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + + def _get_frequencies(self, max_num_patches: int, kv_channels: int, device="cuda") -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + width_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + frequencies = self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, max_num_patches, 1), + angles_w[None, :, :].repeat(max_num_patches, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self._config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 319a495d7..e01aaf702 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -26,11 +26,15 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import LayerNormalizationConfig, NormalizationType -from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig, TransformerType -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.common.config import LayerNormalizationConfig +from fast_llm.layers.transformer.config import RoutingType, TransformerConfig +from fast_llm.layers.transformer.rotary.config import ( + DefaultRotaryConfig, + Llama3RotaryConfig, + Rotary2DConfig, + YarnRotaryConfig, +) from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex -from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, 
DiffusionLlamaGPTHuggingfaceCheckpointFormat, @@ -161,6 +165,7 @@ class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, Huggingfac @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="lm_decoder"), ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( @@ -228,42 +233,6 @@ def _create_weight_converters( return converters - def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - - # Next-token prediction head - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias - ) - # Output weights - if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) - else: - converters.append( - WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") - ) - - # MTP-heads > 0 are thrown away - for i in range(1, prediction_heads): - logger.warning( - f"The model weights for the multi-token prediction head {i} are discarded during conversion." - ) - mtp_transformer_layer_index = num_layers + offset - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter - ) - - return converters - def _create_transformer_layer_converters( self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False ) -> list[WeightConverter]: @@ -331,7 +300,7 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) @@ -340,20 +309,22 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: # Next-token prediction head # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias ) # Output weights if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", 
"lm_head.weight")) + converters.append( + WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") + ) # MTP-heads > 0 are thrown away for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i + mtp_transformer_layer_index = num_layers + offset - 1 + 2 * i # MTP transformer layer converters += self._create_transformer_layer_converters( f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True @@ -466,7 +437,7 @@ def __post_init__(self): def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: (rotary_config,) = fast_llm_values - if type(rotary_config) is DefaultRotaryConfig: + if type(rotary_config) is DefaultRotaryConfig or rotary_config is Rotary2DConfig: rotary_scaling = { "rope_type": "default", } @@ -663,6 +634,34 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A return (num_heads, num_heads) +class PixtralRotaryParamConverter(ParamConverter): + """ + Pixtral encoder uses 2D Rotary Embeddings. + Map `rope_theta` to a single `rotary` parameter. `rotary_scaling` is not needed. + """ + + def __init__(self, fast_llm_names, export_names): + Assert.eq(len(fast_llm_names), 1) + Assert.eq(len(export_names), 1) + self.fast_llm_names = fast_llm_names + self.export_names = export_names + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_config,) = fast_llm_values + if type(rotary_config) is Rotary2DConfig: + return (rotary_config.theta,) + else: + raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_theta,) = export_values + rotary_config = { + "type": "rope_2d", + "theta": rotary_theta, + } + return (rotary_config,) + + class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -670,17 +669,13 @@ class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, Huggingfa @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value="pixtral"), + ConstantImportParamConverter(fast_llm_names=(("patch_norm", "type"),), fast_llm_value="rms_norm"), ConstantImportParamConverter( - fast_llm_names=(("patch_norm", "type"),), fast_llm_value=NormalizationType.rms_norm - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "type"),), fast_llm_value=TransformerType.image_encoder + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value="rms_norm" ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="image_encoder"), ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), - ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), 
fast_llm_value=False), RenameParamConverter( fast_llm_names=( @@ -737,17 +732,21 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), export_names=(("head_dim",),), ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d - ), - RenameParamConverter( - fast_llm_names=( - ( - "transformer", - "rotary", - "theta", - ), - ), + # ConstantImportParamConverter( + # fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d + # ), + # RenameParamConverter( + # fast_llm_names=( + # ( + # "transformer", + # "rotary", + # "theta", + # ), + # ), + # export_names=(("rope_theta",),), + # ), + PixtralRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), export_names=(("rope_theta",),), ), RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), @@ -773,7 +772,7 @@ def _create_vision_transformer_layer_converters( ) -> list[WeightConverter]: # Vision transformer layer transformer_config = self._model.config.base_model.vision_encoder.transformer - norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) name_bias_cls = [ # Self-attn ( @@ -828,11 +827,12 @@ def _create_vision_transformer_layer_converters( def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: converters = [] + norm_bias = isinstance(self._model.config.base_model.vision_encoder.patch_norm, LayerNormalizationConfig) converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) if self._model.config.base_model.vision_encoder.conv_bias: converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + if norm_bias: converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 38639fc5f..436b4a60f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -72,10 +72,7 @@ def __init__( if self._config.vision_encoder.enabled: self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) - if self._config.vision_encoder.transformer.rotary.enabled: - self._preprocessors.append( - RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) - ) + self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) def get_output_layers(self) -> list[Layer]: layers = [] From 8f93a276e7406d6195edaacda4fbdd774d193356 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Jul 2025 19:34:41 +0000 Subject: [PATCH 084/161] fix prepare config --- fast_llm/data/preparator/gpt_memmap/config.py | 18 ++++++++++++------ fast_llm/data/preparator/gpt_memmap/prepare.py | 18 +++++++++++------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 9f25cba4c..da353793d 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py 
+++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -42,6 +42,18 @@ class TextColumnConfig(SourceSchemaConfig): ) +@config_class(dynamic_type={SourceSchemaConfig: "text_image_column"}) +class TextImageColumnConfig(TextColumnConfig): + images_column: str = Field( + default="images", + desc="Field containing images relevant to a document.", + ) + image_positions_column: None | str = Field( + default="image_positions", + desc="Field containing image positions within a document.", + ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( @@ -79,12 +91,6 @@ class GPTHuggingfaceDatasetConfig(Config): rejected_text: None | str = Field( default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional ) - image_positions: None | str = Field( - default=None, desc="Field containing image positions within a document", hint=FieldHint.optional - ) - images: None | str = Field( - default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional - ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index dee1b37bf..b100ce400 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -27,7 +27,11 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + GPTMemmapDatasetPreparatorConfig, + TextColumnConfig, + TextImageColumnConfig, +) from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -60,9 +64,9 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ im_char_positions, ) for text, loss_mask_spans, im_char_positions in zip( - batch[self._config.dataset.field], - batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), - batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + batch[self._text_column], + batch.get(self._loss_masking_spans_column, itertools.repeat(None)), + batch.get(self._image_positions_column, itertools.repeat(None)), ) ] ] @@ -160,8 +164,8 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), - item["images"] if self._config.dataset.images else None, - item["image_positions"] if self._config.dataset.image_positions else None, + item["images"] if self._images_column else None, + item["image_positions"] if self._image_positions_column else None, ( np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) if self._loss_masking_spans_column @@ -344,7 +348,7 @@ def run(self) -> None: total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) total_pixels = ( sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) - if self._config.dataset.images + if self._images_column else 0 ) # Add the token-equivalent bytes of pixels to determine shard size From 
c3eda1c1c792a144d353e99f6d4532ec644a837a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Jul 2025 19:44:18 +0000 Subject: [PATCH 085/161] fix imports --- .github/workflows/ci.yaml | 2 +- .github/workflows/docs.yaml | 2 +- fast_llm/layers/vision_encoder/preprocessing.py | 14 +++++++++----- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 03353a79b..cb5260dca 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,7 +32,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]" - name: Run tests run: pytest . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index b509b2702..75ba3bb31 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -34,7 +34,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index ebd41b3d7..c81e7c646 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -2,7 +2,7 @@ import typing import torch -import torchvision.transforms.v2.functional as F +import torchvision from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -62,17 +62,21 @@ def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int height, width = get_resize_dims( height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size ) - image = F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + image = torchvision.transforms.v2.functional.resize( + image, size=(height, width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC + ) # TODO: options for interpolation mode? - return F.resize(image, size=(target_height, target_width), interpolation=F.InterpolationMode.BICUBIC) + return torchvision.transforms.v2.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC + ) def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: """ Normalize the image using the specified mean and standard deviation. 
""" - return F.normalize(image, mean=mean, std=std) + return torchvision.transforms.v2.functional.normalize(image, mean=mean, std=std) def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: @@ -81,7 +85,7 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: """ width_padding = max(0, max_height - image.size(1)) depth_padding = max(0, max_width - image.size(2)) - return F.pad(image, (0, 0, depth_padding, width_padding), 0) + return torchvision.transforms.v2.functional.pad(image, (0, 0, depth_padding, width_padding), 0) def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: From 1cf0ea0285bdfb8754ce6030ff03be6582d3c7c5 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Jul 2025 21:46:26 +0000 Subject: [PATCH 086/161] fix tests --- fast_llm/data/dataset/gpt/fim.py | 6 +++--- fast_llm/data/dataset/gpt/indexed.py | 13 +++++++++++-- fast_llm/data/preparator/gpt_memmap/prepare.py | 2 +- tests/data/common.py | 4 ++-- tests/data/test_sampling.py | 1 + 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..b05b79b24 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -158,9 +158,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64) + middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 2c7aefc80..8a4440ae4 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -53,7 +53,7 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else [] + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else np.array([]) def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) @@ -70,8 +70,17 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. 
- return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + # return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + sizes = [dataset.get_document_sizes() for dataset in self._datasets] + return ( + np.concatenate([size[0] for size in sizes]), + np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]), + ) def get_document_size(self, index: int) -> int: dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + + @property + def has_images(self) -> bool: + return any(dataset.has_images for dataset in self._datasets) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b100ce400..c6a0528f1 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -425,7 +425,7 @@ def _split_and_blend_dataset_configs( dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path, - image_patch_size: int, + image_patch_size: None | int = None, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] diff --git a/tests/data/common.py b/tests/data/common.py index 2bb90a6b4..858380816 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -127,10 +127,10 @@ def compare_indexed_dataset( loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) - sizes = dataset.get_document_sizes() + text_sizes, image_sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], text_sizes[: min(len(dataset), 100)] ) for i, expected_sample in expected_samples.items(): Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 123e5e955..b8e7a92ff 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -106,6 +106,7 @@ def get_document_size(self, index: int) -> int: def name(self) -> str: return "dataset" + @property def has_images(self) -> bool: return False From 77d294c76157210caf9856b23dc9a52ca1d4f44c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Jul 2025 17:57:48 +0000 Subject: [PATCH 087/161] fix tests --- fast_llm/data/dataset/gpt/memmap.py | 3 ++- fast_llm/data/dataset/gpt/sampled.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 4 ++-- tests/data/common.py | 7 +++++- tests/data/test_sampling.py | 10 ++++++-- tests/test_config.py | 24 +++++++++++++++++++ 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c7a99f10f..2a1986b63 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -136,7 +136,7 @@ def _init( offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes self._num_pixels = 0 - self._image_sizes = None + self._image_sizes = [] self._image_positions = None if self._has_images and self._version >= 4: self._n_images = np.frombuffer( @@ -177,6 +177,7 
@@ def _init( assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens + self._image_sizes = np.array(self._image_sizes, dtype=np.int32) def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 29a784b77..42062a58c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -143,7 +143,7 @@ def _sample(self) -> None: # Get the document sizes, the main information needed for sampling. document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - if image_sizes: + if image_sizes.any(): image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index c6a0528f1..fce0f022c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -458,7 +458,7 @@ def _split_and_blend_dataset_configs( text_sizes, image_sizes = dataset.get_document_sizes() tokens_cumsum = text_sizes.cumsum() Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) - if image_sizes: + if image_sizes.any(): num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) # We use the patch sizes only for the purposes of even splitting and blending weights. # We can always use a different patch size for training without any significant impact @@ -466,7 +466,7 @@ def _split_and_blend_dataset_configs( image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) tokens_cumsum += image_tokens_cumsum num_pixels_cumsum = num_pixels_cumsum * 3 - Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: diff --git a/tests/data/common.py b/tests/data/common.py index 858380816..23ed9d76b 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -224,10 +224,15 @@ def __len__(self) -> int: return self._config.num_documents def get_document_sizes(self) -> np.ndarray: - return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) + return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64), np.array( + [], dtype=np.int64 + ) def get_document_size(self, index: int) -> int: return self._config.num_tokens_per_document def get(self, index: int, *args, **kwargs) -> typing.Any: raise NotImplementedError() + + def has_images(self) -> bool: + return False diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index b8e7a92ff..296102f7d 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -98,10 +98,16 @@ def __len__(self) -> int: return len(self._samples) def get_document_sizes(self) -> np.ndarray: - return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64) + doc_sizes = [] + im_sizes = [] + for index in range(len(self)): + doc_size, im_size = self.get_document_size(index) + doc_sizes.append(doc_size) + im_sizes.append(im_size) + return np.array(doc_sizes, 
dtype=np.int64), np.array(im_sizes, dtype=np.int64) def get_document_size(self, index: int) -> int: - return len(self._samples[index]) + return len(self._samples[index]), [] def name(self) -> str: return "dataset" diff --git a/tests/test_config.py b/tests/test_config.py index b6a9a9854..c12ef9f03 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -88,6 +88,14 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "multi_stage": {"zero_stage": 3}, "distributed": {"training_dtype": "bfloat16"}, + # "vision_encoder": { + # "type": "none", + # "transformer": { + # "normalization": { + # "type": "rms_norm", + # } + # } + # } } ) with NoAutoValidate(): @@ -137,6 +145,14 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "tie_word_embeddings": False, "vocab_size": 1000, + "vision_encoder": { + "transformer": { + "normalization": {"type": "layer_norm"}, + "rotary": {"type": "none"}, + "peft": {"type": "none"}, + }, + "patch_norm": {"type": "layer_norm"}, + }, } else: base_model_update["transformer"]["peft"] = { @@ -146,6 +162,14 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): } base_model_update["transformer"]["normalization"]["type"] = "layer_norm" base_model_update["transformer"]["rotary"] = {"type": "none"} + base_model_update["vision_encoder"] = { + "transformer": { + "normalization": {"type": "layer_norm"}, + "rotary": {"type": "none"}, + "peft": {"type": "none"}, + }, + "patch_norm": {"type": "layer_norm"}, + } expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) From 8434b20e7f2dc3e4885ab87772bea3694fddaff2 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Jul 2025 18:00:58 +0000 Subject: [PATCH 088/161] cleanup --- tests/test_config.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index c12ef9f03..52c00f0a1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -88,14 +88,6 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "multi_stage": {"zero_stage": 3}, "distributed": {"training_dtype": "bfloat16"}, - # "vision_encoder": { - # "type": "none", - # "transformer": { - # "normalization": { - # "type": "rms_norm", - # } - # } - # } } ) with NoAutoValidate(): From a0b6e45b3e6af541e48703b0b358b19149c3a471 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 17 Jul 2025 14:05:06 +0000 Subject: [PATCH 089/161] add assert --- fast_llm/data/dataset/gpt/sampled.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 42062a58c..969cafa74 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -144,6 +144,10 @@ def _sample(self) -> None: document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) if image_sizes.any(): + assert self._parameters.max_image_size is not None, ( + f"Dataset {self._indexed_dataset.name} contains images, but no max_image_size is set." 
+ f"image_sizes: {image_sizes}, max_image_size: {self._parameters.max_image_size}" + ) image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( From d35c6853ef0156ba551cf20fa3f07db0da86bfa9 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 17 Jul 2025 15:00:54 +0000 Subject: [PATCH 090/161] move check to config validation --- fast_llm/data/dataset/gpt/sampled.py | 4 ---- fast_llm/models/gpt/config.py | 3 +++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 969cafa74..42062a58c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -144,10 +144,6 @@ def _sample(self) -> None: document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) if image_sizes.any(): - assert self._parameters.max_image_size is not None, ( - f"Dataset {self._indexed_dataset.name} contains images, but no max_image_size is set." - f"image_sizes: {image_sizes}, max_image_size: {self._parameters.max_image_size}" - ) image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 039b97f8c..bc64821f2 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -257,6 +257,9 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + if self.model.base_model.vision_encoder.enabled: + assert self.batch.max_image_size is not None, "max_image_size must be set when using vision encoder" + Assert.gt(self.batch.max_image_size, 0) @classmethod def _from_dict( From ef982c9cd1d73ed4ddf4d117603e07d9a7eae697 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 17 Jul 2025 15:59:39 +0000 Subject: [PATCH 091/161] fix torchvision import --- fast_llm/layers/vision_encoder/preprocessing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index c81e7c646..3b857ba26 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -2,7 +2,7 @@ import typing import torch -import torchvision +import torchvision.transforms.v2 as torchvision_transforms from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -62,13 +62,13 @@ def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int height, width = get_resize_dims( height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size ) - image = torchvision.transforms.v2.functional.resize( - image, size=(height, width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC + image = torchvision_transforms.functional.resize( + image, size=(height, width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC ) # TODO: options for interpolation mode? 
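# For context: the patch-aligned target size used in `resize` comes from `get_resize_dims`,
# which is not shown in this hunk. A simplified stand-in under the assumed convention
# (downscale to fit the cap, then round up to a whole number of patches); the name and
# exact behaviour below are assumptions, not the actual helper:
import math

def _patch_aligned_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]:
    ratio = max(height / max_height, width / max_width)
    if ratio > 1:
        # Downscale so the image fits within the configured maximum.
        height, width = math.ceil(height / ratio), math.ceil(width / ratio)
    # Round up to whole patches so the image tiles evenly into patch tokens.
    return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)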
- return torchvision.transforms.v2.functional.resize( - image, size=(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC + return torchvision_transforms.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC ) @@ -76,7 +76,7 @@ def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch """ Normalize the image using the specified mean and standard deviation. """ - return torchvision.transforms.v2.functional.normalize(image, mean=mean, std=std) + return torchvision_transforms.functional.normalize(image, mean=mean, std=std) def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: @@ -85,7 +85,7 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: """ width_padding = max(0, max_height - image.size(1)) depth_padding = max(0, max_width - image.size(2)) - return torchvision.transforms.v2.functional.pad(image, (0, 0, depth_padding, width_padding), 0) + return torchvision_transforms.functional.pad(image, (0, 0, depth_padding, width_padding), 0) def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: From 3345ab122a9d5aa50f17c11cfa95e4c7a1279cef Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 17 Jul 2025 18:25:15 +0000 Subject: [PATCH 092/161] debug log --- fast_llm/engine/distributed/distributed.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 200074ee9..e9247871a 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -46,6 +46,8 @@ def __init__( Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count()) torch.cuda.init() self._device = torch.device(self._rank) + logger.info(f"Using device {self._device} for rank {self._rank}.") + logger.info(f"Number of local devices: {torch.cuda.device_count()}.") torch.cuda.set_device(self._device) if self._world_size > 1: From b0b52fa98b366ffa52c3844b177b88d3b558c36e Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 17 Jul 2025 18:35:27 +0000 Subject: [PATCH 093/161] fix device --- fast_llm/engine/distributed/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index e9247871a..1f8798305 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -45,7 +45,7 @@ def __init__( else: Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count()) torch.cuda.init() - self._device = torch.device(self._rank) + self._device = torch.device(self._rank % self._local_world_size) logger.info(f"Using device {self._device} for rank {self._rank}.") logger.info(f"Number of local devices: {torch.cuda.device_count()}.") torch.cuda.set_device(self._device) From 2be288ca52a459a24d086d520e1e9ad338e442c7 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 17 Jul 2025 19:41:57 +0000 Subject: [PATCH 094/161] remove log --- fast_llm/engine/distributed/distributed.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 1f8798305..f17a8f452 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -46,8 +46,6 @@ def __init__( Assert.in_range_incl(self._local_world_size, 1, 
torch.cuda.device_count()) torch.cuda.init() self._device = torch.device(self._rank % self._local_world_size) - logger.info(f"Using device {self._device} for rank {self._rank}.") - logger.info(f"Number of local devices: {torch.cuda.device_count()}.") torch.cuda.set_device(self._device) if self._world_size > 1: From d1caa987e648719ca8418c8bd93a2f31a46e7d2c Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 17 Jul 2025 20:47:42 +0000 Subject: [PATCH 095/161] fix name --- fast_llm/layers/ssm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c69ada389..46d629aa8 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -22,7 +22,7 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads # Mamba 2 - x_proj_dim_2 = "x_proj_dim" # d_xb + x_proj_dim_2 = "x_proj_dim_2" # d_xb class SSMBlockType(enum.StrEnum): From 0fbe88156f2fd8c9d218f554b63bb699eba49cc0 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 13:48:46 +0000 Subject: [PATCH 096/161] fix hybrid get_layers --- fast_llm/models/gpt/model.py | 12 +++++++----- fast_llm/models/ssm/model.py | 3 +-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6356cf23d..8d70a8944 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -110,13 +110,15 @@ def get_vision_layers(self) -> list[Layer]: MultiModalEmbedding(self._config, self._tensor_space), ] + def get_embedding_layers(self) -> list[Layer]: + if self._config.vision_encoder.enabled: + return self.get_vision_layers() + else: + return [LanguageModelEmbedding(self._config, self._tensor_space)] + def get_layers(self) -> list[Layer]: return [ - *( - [LanguageModelEmbedding(self._config, self._tensor_space)] - if not self._config.vision_encoder.enabled - else self.get_vision_layers() - ), + *(self.get_embedding_layers()), *[ TransformerLayer( self._config.transformer, diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..80f9ca8ba 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -3,7 +3,6 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.llamba_block import LlambaBlock @@ -94,7 +93,7 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. 
""" - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers = self.get_embedding_layers() # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): From 854f305507f5d8852a8f8e5cc56cb3ca1b6248d5 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 14:21:02 +0000 Subject: [PATCH 097/161] debug --- fast_llm/layers/ssm/mamba2.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..96116abe5 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -10,6 +11,8 @@ from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale +logger = logging.getLogger(__name__) + try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -144,6 +147,9 @@ def init_from_tensor_( value: torch.Tensor, ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + logger.info( + f"Initializing {meta.tensor_name} with shape {meta.shape} from tensor with shape {value.shape}" + ) return tensor.copy_(value) return init_ @@ -156,6 +162,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) lr_scale=mamba_layer_lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn + logger.info(f"td_inner: {td_inner}, inv_dt: {inv_dt.shape}") self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale ) @@ -166,6 +173,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) d=self.d_inner, ).contiguous() A_log = torch.log(A).flatten() # Keep A_log in fp32 + logger.info(f"A_log: {A_log.shape}, td_inner: {td_inner}, td_state: {td_state}") self.A_log = ParameterMeta.from_dims( (td_inner, td_state), init_method=init_from_tensor_(A_log), From b7b81931ad634646476663222f04585957633bee Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 20:30:41 +0000 Subject: [PATCH 098/161] add llava hybrid format --- fast_llm/models/ssm/config.py | 20 ++++++++++++++++++++ fast_llm/models/ssm/conversion.py | 11 ++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..95c8ca84c 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -166,6 +166,26 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler +# class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): +# name: typing.ClassVar[str] = "llava" +# # Using default values for vision and text models. 
Can be overridden in the config +# vision_name: typing.ClassVar[str] = "pixtral" +# text_name: typing.ClassVar[str] = "mistral" + + +class LlavaHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "llava_hybrid" + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import LlavaHybridHuggingfaceCheckpointHandler + + return LlavaHybridHuggingfaceCheckpointHandler + + @config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..ddf45c6c5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -20,13 +20,18 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter +from fast_llm.models.gpt.conversion import ( + CommonLlamaHuggingfaceCheckpointHandler, + LlavaHuggingfaceCheckpointHandler, + MLPLayer2Converter, +) from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -762,3 +767,7 @@ def _load_config(cls, directory: pathlib.Path | str) -> dict: def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: with open(directory / "config.json", "w") as f: json.dump(config, f) + + +class LlavaHybridHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat From a1589da4f320c05e2515ec8559109e0fb6f6b3cc Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 20:47:04 +0000 Subject: [PATCH 099/161] workaround init --- fast_llm/layers/ssm/mamba2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 96116abe5..8a61a8969 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -148,9 +148,14 @@ def init_from_tensor_( ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa logger.info( - f"Initializing {meta.tensor_name} with shape {meta.shape} from tensor with shape {value.shape}" + f"Initializing {meta.tensor_name} with shape {meta.shape}, tensor shape {tensor.shape} from value shape {value.shape}" ) - return tensor.copy_(value) + # TODO: fix and remove try-except + try: + return tensor.copy_(value) + except RuntimeError as e: + logger.error(f"Failed to copy value to tensor: {e}") + return tensor.fill_(0.0) return init_ From a202b4ce4d8ba34edf867895dbe4e40b06039918 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 21:00:26 +0000 Subject: [PATCH 100/161] update --- fast_llm/models/ssm/config.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 95c8ca84c..55a7ef548 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -196,6 +196,7 @@ class HybridSSMModelConfig(FastLLMModelConfig): AprielSSMHuggingfaceCheckpointFormat, AprielSSMHHybridHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, ) @classmethod From b675ec2d1e073a0430962832dd37fd86e08fe768 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 21:16:16 +0000 Subject: [PATCH 101/161] update --- fast_llm/models/ssm/conversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index ddf45c6c5..0b2a36c63 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -771,3 +771,4 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class LlavaHybridHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig From a055d2a8aff442173291cb890a31a3664fd15262 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 22:15:47 +0000 Subject: [PATCH 102/161] update --- fast_llm/models/gpt/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index e01aaf702..3a53754db 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -664,7 +664,7 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + _model_class: typing.ClassVar[FastLLMModelConfig] = FastLLMModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: From 8e4ef5d52f94960820b8dbe062c12256e3249373 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 22:43:02 +0000 Subject: [PATCH 103/161] refactoring attempt --- fast_llm/models/gpt/conversion.py | 17 +++++++++++++---- fast_llm/models/ssm/conversion.py | 7 +++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 3a53754db..79d7099cd 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -14,6 +14,7 @@ AutoStateDictCheckpointHandler, ConstantExportParamConverter, ConstantImportParamConverter, + ExternalStateDictCheckpointHandler, IgnoreExportWeightConverter, IgnoreImportParamConverter, IgnoreImportWeightConverter, @@ -873,6 +874,14 @@ class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + @classmethod + def get_vision_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + + @classmethod + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + 
@classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: cfg_dict = cls._load_config(config.path) @@ -944,8 +953,8 @@ def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: @classmethod def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: exported_config = {} - vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) - text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() for converter in vision_handler_cls._create_config_converters(): try: values = converter.export_params( @@ -991,10 +1000,10 @@ def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: return exported_config def _create_weight_converters(self): - vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.vision_name) + vision_handler_cls = self.get_vision_handler_class() vision_handler = vision_handler_cls(self._model) converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) - text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler_cls = self.get_text_handler_class() text_handler = text_handler_cls(self._model) converters.extend( text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 0b2a36c63..d62a0549a 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -7,6 +7,7 @@ from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, ConstantImportParamConverter, + ExternalStateDictCheckpointHandler, IgnoreImportParamConverter, IgnoreImportWeightConverter, MappedConfigParamConverter, @@ -772,3 +773,9 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class LlavaHybridHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + + @classmethod + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler From 4bfce67042390d0eef54a5697457469b2443a812 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 21 Jul 2025 14:11:11 +0000 Subject: [PATCH 104/161] update ssm conversion: use hf_prefix/offset --- fast_llm/models/ssm/conversion.py | 59 +++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d62a0549a..55ed3058a 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -221,7 +221,11 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] - def _create_weight_converters(self) -> list[WeightConverter]: + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: converters = super()._create_weight_converters() or [] num_layers = self._model.config.base_model.transformer.num_layers @@ -230,55 +234,65 @@ def _create_weight_converters(self) -> 
list[WeightConverter]: for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{offset+i+1}.mixer.in_proj", f"{hf_base_prefix}model.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{offset+i+1}.mixer.out_proj", f"{hf_base_prefix}model.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter( + f"layers.{offset+i+1}.mixer.D", + f"{hf_base_prefix}model.layers.{i}.mixer.D", + self._model.config.base_model, + ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{offset+i+1}.mixer.z_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{offset+i+1}.mixer.z_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"model.layers.{i}.mixer.conv1d.weight", + f"layers.{offset+i+1}.mixer.conv1d_weight", + f"{hf_base_prefix}model.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"model.layers.{i}.mixer.conv1d.bias", + f"layers.{offset+i+1}.mixer.conv1d_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # ================================================ # Mamba2 specific parameters converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False + f"layers.{offset+i+1}.mixer.dt_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj", False ) # bias is treated separately in Mamba2 and must always exist (https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py) converters.append( WeightConverter( - f"layers.{i+1}.mixer.dt_proj_bias", - f"model.layers.{i}.mixer.dt_proj.bias", + f"layers.{offset+i+1}.mixer.dt_proj_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj.bias", self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.A_log", f"model.layers.{i}.mixer.A_log", self._model.config.base_model + f"layers.{offset+i+1}.mixer.A_log", + f"{hf_base_prefix}model.layers.{i}.mixer.A_log", + self._model.config.base_model, ) ) @@ -572,11 +586,16 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = super()._create_weight_converters(hf_base_prefix, offset) num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False + # TODO: use hf_base_prefix and offset # Embedding and output if self._model.config.base_model.tie_word_embeddings: converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -710,8 +729,12 @@ class 
AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = super()._create_weight_converters(hf_base_prefix, offset) # num_layers = self._model.config.base_model.transformer.num_layers # # Embedding and output # if self._model.config.base_model.tie_word_embeddings: From 82eed2b44c30c891ef2e07c2c80c4f5fcfa1e7f1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 21 Jul 2025 17:17:26 -0400 Subject: [PATCH 105/161] TP mamba --- fast_llm/layers/common/config.py | 6 +- fast_llm/layers/ssm/config.py | 214 +++++++++---- fast_llm/layers/ssm/discrete_mamba2.py | 39 ++- fast_llm/layers/ssm/llamba_block.py | 18 +- fast_llm/layers/ssm/mamba2.py | 302 +++++++----------- fast_llm/layers/ssm/mamba_layer.py | 159 ++++----- fast_llm/layers/transformer/attention.py | 3 +- fast_llm/layers/transformer/transformer.py | 27 +- fast_llm/models/custom/model.py | 4 +- fast_llm/models/gpt/model.py | 8 +- fast_llm/models/ssm/config.py | 42 +-- .../external/llamba/modeling_mtp_llamba.py | 10 +- fast_llm/models/ssm/model.py | 34 +- fast_llm/tensor.py | 8 +- setup.cfg | 2 +- tests/test_multi_stage.py | 4 +- 16 files changed, 407 insertions(+), 473 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c69ada389..f4c8067dd 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,28 +1,35 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class SSMDimNames: - model_dim = "model_dim" # Model dimension (D) - state_dim = "state_dim" # State dimension (N) - conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers - inner_dim = "inner_dim" # Inner dimension after expansion - dt_rank = "dt_rank" # Rank of Δ - inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba - inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection 
dimension for discrete mamba2 - inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2 - x_proj_dim = "x_proj_dim" # X projection dimension - head_dim = "head_dim" # Dimension of the mamba2 head (P) - conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers - qk_heads = "qk_heads" # Number of QK heads - v_heads = "v_heads" # Number of V heads + # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. + state = "ssm_state" # State dimension (N), aka head size / num channels + + head_groups = "ssm_head_groups" + group_heads = "ssm_group_heads" + + composite_heads = "ssm_composite_heads" + composite_heads_and_state = "ssm_composite_heads_and_state" + composite_head_groups_and_state = "ssm_composite_head_groups_and_state" + + # Inner projection total dimension. + inner_projection = "ssm_inner_projection" + composite_inner_projection = "ssm_composite_inner_projection" + + # Convolution shape in discrete mamba 2. TODO: Remove (dim too complex) + conv_dim = "ssm_conv_dim" + + dt_rank = "ssm_dt_rank" - # Mamba 2 - x_proj_dim_2 = "x_proj_dim" # d_xb + x_proj_dim = "x_proj_dim" # X projection dimension + conv_kernel = "conv_kernel" # Kernel size of the conv1d in mamba layers class SSMBlockType(enum.StrEnum): @@ -36,6 +43,16 @@ class SSMBlockType(enum.StrEnum): transformer = "t" +class DTInitType(enum.StrEnum): + constant = "constant" + random = "random" + + def get_init_method(self, scale: float): + from fast_llm.tensor import init_fill_, init_uniform_centered_ + + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) + + @config_class() class SSMConfig(LLMBlockConfig): _abstract = False @@ -45,79 +62,87 @@ class SSMConfig(LLMBlockConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) + + # Model dimensions + # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, desc="State size for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, desc="Conv kernel dimension for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # Layer parameters - add_bias_linear: bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - + # [MambaLayer, Mamba2] dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, ) - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) + # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, desc="Number of QK heads for Mamba2 blocks.", hint=FieldHint.architecture, ) + # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, desc="Number of V heads for Mamba2 blocks.", hint=FieldHint.architecture, ) - activation_type: ActivationType = Field( + # c_size [MambaLayer, Mamba2, DiscreteMamba2]? 
+ d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + # xb_size [Mamba2] + d_xb: int = Field( default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + desc="Dimension of the xB in Mamba2 blocks.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( + + # Model options + # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] + add_bias_linear: bool = Field( default=False, - desc="debug_ssm", - hint=FieldHint.optional, + desc="Whether to use bias in SSM layers", + hint=FieldHint.architecture, ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] + activation_type: ActivationType = Field( + default=None, + hint=FieldHint.architecture, ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # repeat_xb_before_conv [Mamba2] + repeat_kv_before_conv: bool = Field( + default=True, + desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", + hint=FieldHint.architecture, ) - - d_inner: None | int = Field( - default=None, - desc="Inner dimension for Mamba2 blocks.", - hint=FieldHint.core, + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, ) + + # Learning rate + # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", @@ -125,43 +150,38 @@ class SSMConfig(LLMBlockConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Mamba 2 - repeat_kv_before_conv: bool = Field( - default=True, - desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - d_xb: int = Field( - default=None, - desc="Dimension of the xB in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - dt_init: str = Field( + # Initialization + # dt_weight_initialization_method [Mamba2] + dt_init: DTInitType = Field( default="random", desc="Initialization method for dt", hint=FieldHint.core, ) - dt_max: float = Field( - default=0.1, - desc="Maximum step size for discretization", + # dt_weight_initialization_scale [Mamba2] + dt_scale: float = Field( + default=1.0, + desc="Scale for dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + # dt_bias_initialization_min [MambaLayer, Mamba2] dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", + # dt_bias_initialization_max [MambaLayer, Mamba2] + dt_max: float = Field( + default=0.1, + desc="Maximum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", + # dt_bias_initialization_floor [MambaLayer, Mamba2] + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) @@ -172,3 +192,59 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) + + def setup_tensor_space(self, 
tensor_space: TensorSpace, block_type: SSMBlockType) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + num_heads = div(self.d_inner, self.state_size) + # Head groups are configured differently depending on the block type. + if block_type == SSMBlockType.mamba: + num_head_groups = num_heads + # (head_groups, 2 * group_heads * state_dim) + inner_projection_size = self.d_inner * 2 + elif block_type == SSMBlockType.mamba2: + num_head_groups = div(self.d_xb, self.state_size) + # (head_groups, 2 * group_heads + 2, state_dim) + (dt,) + inner_projection_size: int = 2 * self.d_inner + 2 * num_head_groups * self.state_size + self.dt_rank + elif block_type == SSMBlockType.mamba2_discrete: + Assert.eq(num_heads, self.n_v_heads) + num_head_groups = self.n_qk_heads + # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) + inner_projection_size = 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads + else: + raise NotImplementedError(block_type) + + tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) + tensor_space.add_tensor_dim( + group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) + ) + tensor_space.add_tensor_dim(CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim)) + ) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) + + # DT projection + if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + + if block_type == SSMBlockType.mamba: + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) + inner_projection_size = 2 * num_group_heads * self.state_size + elif block_type == SSMBlockType.mamba2: + inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + elif block_type == SSMBlockType.mamba2_discrete: + inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + num_group_heads + # TODO: (head_groups, group_heads + 2, state_size) + tensor_space.add_tensor_dim( + TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) + ) + + tensor_space.add_tensor_dim(inner_projection := TensorDim(SSMDimNames.inner_projection, inner_projection_size)) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_inner_projection, (head_groups, inner_projection)) + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..d06b47965 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,8 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from 
fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def bias_init_method(conv_weight): fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) + return init_uniform_centered_(bound) class DiscreteMamba2(torch.nn.Module): @@ -53,21 +54,20 @@ def __init__( # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} super().__init__() self.config: SSMConfig = config - bias = config.add_bias_linear self.layer_idx = layer_idx self._return_input = return_input layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) + td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + td_state = tensor_space.get_tensor_dim(SSMDimNames.state) + td_model = tensor_space.get_tensor_dim(TransformerDimNames.hidden) td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) + td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) + td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection) self.d_model = td_model.size self.d_inner = td_inner.size @@ -85,8 +85,8 @@ def __init__( self.in_proj = Linear( td_model, td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.z_bias = ( @@ -96,15 +96,13 @@ def __init__( init_method=init_zeros_, lr_scale=mamba_layer_lr_scale, ) - if not bias + if not config.add_bias_linear else 0.0 ) self.conv1d_weight = ParameterMeta.from_dims( (td_conv, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( @@ -123,12 +121,12 @@ def __init__( self.out_proj = Linear( td_inner, td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: 
""" ON variable names and pep8: keeping some variable names as in the original code for clarity. @@ -144,7 +142,6 @@ def forward(self, hidden_states, kwargs): raise NotImplementedError(f"Sequence-first not supported for SSMs.") assert _mamba_available - input_ = hidden_states outputs = {} # assert state is None batch, seqlen, dim = input_.shape diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..e877ff9c2 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -14,21 +14,19 @@ class LlambaBlock(BaseBlock): """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, + mixer_cls: type[Mixer], layer_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + super().__init__(transformer_config, tensor_space, layer_index, return_input) + self.mixer = mixer_cls(ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def get_mixer(self) -> Mixer: + return self.mixer diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..011889d04 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,14 +1,15 @@ -import math -import typing - -import einops import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.ssm.discrete_mamba2 import bias_init_method +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -25,25 +26,7 @@ _causal_conv1d_available = False -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) - - -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -53,207 +36,138 @@ def __init__( config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, ): super().__init__() - self.config: SSMConfig = config - bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input + self._config: SSMConfig = config + Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale: float | tuple[float | None, ...] | None = get_lr_scale( - self.config.mamba_lr_scale, layer_lr_scale - ) - - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) + lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.repeat_kv_before_conv = config.repeat_kv_before_conv + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) - self.d_state = td_state.size - self.d_model = td_model.size - self.d_xb = td_xb.size - self.d_inner = td_inner.size - self.dt_rank = tdt_rank.size - - if self.repeat_kv_before_conv: - self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, - ) + self._head_groups = div(self._config.d_xb, self._config.state_size) + self._heads = div(self._config.d_inner, self._config.state_size) + self._group_heads = div(self._heads, self._head_groups) - self.conv1d_bias = ParameterMeta.from_dims( - (td_inner,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - else: - self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - ), - ) - self.conv1d_bias = ParameterMeta.from_dims( - (td_xb,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - - self.activation = "silu" - - self.num_xb_head = td_xb.size // td_state.size - self.num_C_head = td_inner.size // td_state.size - self.repeat_group = self.num_C_head // self.num_xb_head - - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, + conv1d_dim = ( + inner_dim + if self._config.repeat_kv_before_conv + else tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) ) - - # Initialize special dt projection to preserve variance at initialization - dt_scale = config.dt_scale # 1.0 - dt_init_std = self.dt_rank**-0.5 * dt_scale - if config.dt_init == "constant": - dt_init = init_fill_(dt_init_std) - elif config.dt_init == "random": - dt_init = init_uniform_(-dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor + self.conv1d_weight = ParameterMeta.from_dims( + (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), + init_method=init_uniform_centered_((conv1d_dim.size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (conv1d_dim,), init_method=bias_init_method(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale + ) + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space.get_tensor_dim(name=SSMDimNames.composite_inner_projection), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + 
torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( - tdt_rank, - td_inner, + tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank), + inner_dim, bias=False, - weight_init_method=dt_init, - lr_scale=mamba_layer_lr_scale, + # Initialize special dt projection to preserve variance at initialization + weight_init_method=self._config.dt_init.get_init_method( + self._config.dt_rank**-0.5 * self._config.dt_scale + ), + lr_scale=lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) - - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), - init_method=init_from_tensor_(A_log), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + init_method=init_A(self._config.state_size, self._config.d_inner), + lr_scale=lr_scale, weight_decay=False, ) - self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), ) def forward(self, hidden_states, kwargs): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ assert _mamba_available - batch, seqlen, dim = hidden_states.shape - outputs = {} - - conv_state, ssm_state = None, None - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) - - x = einops.rearrange(x, "b l d -> b d l") - z = einops.rearrange(z, "b l d -> b d l") - - B = einops.rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) # B, n_group, L, H - B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() - C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner - dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L + assert _causal_conv1d_available + + inner_projection = self.in_proj(hidden_states) + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) + sequence_length = hidden_states.size(1) + + z, x, b, c, dt = torch.split( + inner_projection, + [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner, self._config.dt_rank], + dim=2, + ) + # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + z = 
z.transpose(1, 2) + + # x: (batch, sequence, head_groups * state) -> (batch, heads * state, sequence) + x = x.transpose(1, 2) + if self._config.repeat_kv_before_conv: + x = ( + x.unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .flatten(1, 2) + ) + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + else: + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = ( + x.unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .flatten(1, 2) + ) - if self.repeat_kv_before_conv: - assert self.repeat_group > 0 - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) + b = ( + b.transpose(1, 2) + .unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + ) - assert self.activation in ["silu", "swish"] - if _causal_conv1d_available: - x = _causal_conv1d_fn( - x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - bias=self.conv1d_bias, - activation=self.activation, - ) # B, L, D - else: - raise RuntimeError("Causal conv1d is not available. Please install causal_conv1d.") + # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) + c = c.transpose(1, 2).unflatten(1, (self._heads, self._config.state_size)) - if not self.repeat_kv_before_conv: - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) + dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) y = selective_scan_fn( x, dt, - A, - B, - C, + -torch.exp(self.A_log.float()), + b, + c, self.D.float(), - z=z, - delta_bias=self.dt_proj_bias.float(), # self.dt_proj.bias.float(), + z, + delta_bias=self.dt_proj_bias.float(), delta_softplus=True, - return_last_state=False, ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) - - y = einops.rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - return outputs["hidden_states"], None + # y: (batch, heads * state, sequence) -> out: (batch, sequence, hidden) + out = self.out_proj(y.transpose(1, 2))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) + # TODO: Is contiguous needed? 
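+        # The prefix slice along the sequence dimension and the optional transpose
+        # above only adjust shapes and strides, so `out` may be non-contiguous here;
+        # .contiguous() materializes a densely packed copy for downstream kernels.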
+ return out.contiguous(), None diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..fa2789b1e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,14 +1,18 @@ +import logging import math +import typing from typing import Callable -import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.utils import Assert, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -17,6 +21,8 @@ except (ImportError, RuntimeError): _mamba_available = False +logger = logging.getLogger(__name__) + """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. For now it only supports training and not inference. @@ -26,169 +32,126 @@ def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # S4D real initialization # TODO: adopt this initialization to work for tensor parallel setting! - A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if tensor.shape != A_log.shape: - if tensor.numel() == A_log.numel(): - tensor_view = tensor.view(d_inner, d_state) - tensor_view.copy_(A_log) - else: - raise ValueError(f"Tensor size {tensor.numel()} doesn't match expected size {A_log.numel()}") - else: - tensor.copy_(A_log) - return tensor + if tensor.numel() != d_state * d_inner: + raise ValueError(f"_init_A requires not supported for tensor slices.") + return torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner), out=tensor) return init_ def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + tensor = tensor.uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - tensor.copy_(inv_dt) - return tensor + return tensor.add_(torch.log(-torch.expm1(-tensor))) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): def __init__( self, config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, ): - factory_kwargs = {} super().__init__() - self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm + assert 
tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + self._config = config + # TODO: It's not silu? + Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - self.d_conv = td_conv_kernel.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.d_model = td_model.size - self.dt_rank = tdt_rank.size + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.in_proj_weight = ParameterMeta.from_dims( - (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + # TODO: Backward compatibility? + # TODO: lr_scale? + self.in_proj = Linear( + hidden_dim, + tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection), + bias=False, + weight_init_method=init_kaiming_(hidden_dim.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.conv_kernel)), + init_method=init_kaiming_(inner_dim.size), + lr_scale=lr_scale, ) - self.conv1d_bias = None - - self.activation = "silu" - self.act = torch.nn.SiLU() - self.x_proj = Linear( - td_inner, - td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + inner_dim, + tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim), + weight_init_method=init_kaiming_(inner_dim.size), bias=False, - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + init_method=init_kaiming_(self._config.dt_rank), + lr_scale=lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), - init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs - ), - lr_scale=mamba_layer_lr_scale, + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), weight_decay=False, - init_method=init_A(self.d_state, self.d_inner), - lr_scale=mamba_layer_lr_scale, + 
init_method=init_A(self._config.state_size, inner_dim.size), + lr_scale=lr_scale, ) # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) self.out_proj = Linear( - td_inner, - td_model, + inner_dim, + hidden_dim, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - - # We do matmul and transpose BLH -> HBL at the same time - xz = einops.rearrange( - self.in_proj_weight @ einops.rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - if self._debug_mode: - print("XZ: ", xz.shape) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( - xz, - self.conv1d_weight, - self.conv1d_bias, + in_proj, + self.conv1d_weight.unsqueeze(1), + None, self.x_proj.weight, self.dt_proj_weight, self.out_proj.weight, self.out_proj.bias, # is None here - A, + -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..76b8ed1ca 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,6 +13,7 @@ TransformerKwargs, TransformerSubLayerName, ) +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale @@ -50,7 +51,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. """ diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..f80e903f0 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -18,13 +18,24 @@ logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. 
Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. """ - _mixer_module_name = "self_attn" - def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): @@ -54,7 +65,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def get_mixer(self) -> Mixer: pass @torch.compile @@ -115,7 +126,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = self.get_mixer()(hidden_states, kwargs) if self._debug_mode: self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): @@ -137,14 +148,14 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__(config, tensor_space, layer_index, return_input) - - def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + def get_mixer(self) -> Mixer: + return self.self_attn diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..a9cf3bb8c 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,7 +31,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..a3a68e0a6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? 
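For reference, the mixer contract introduced in this patch can be illustrated with a minimal sketch. It relies only on the interfaces shown above (`Mixer`, `BaseBlock.get_mixer`, and the `(output, optional bias)` return convention); the `NoOpMixer`/`NoOpBlock` names and their wiring are illustrative, not part of Fast-LLM.

import typing

import torch

from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.layers.transformer.transformer import BaseBlock, Mixer


class NoOpMixer(Mixer):
    """Illustrative mixer that returns its input unchanged."""

    def forward(
        self, input_: torch.Tensor, kwargs: dict[str, typing.Any]
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # The second element is an optional bias that BaseBlock folds into its
        # fused bias-dropout-add; None means there is no bias to add.
        return input_, None


class NoOpBlock(BaseBlock):
    _name = "No-op block"

    def __init__(
        self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False
    ):
        super().__init__(config, tensor_space, layer_index, return_input)
        self.mixer = NoOpMixer()

    def get_mixer(self) -> Mixer:
        # BaseBlock.forward calls `self.get_mixer()(hidden_states, kwargs)`, so any
        # module returned here must follow the Mixer contract above.
        return self.mixer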
@@ -91,7 +91,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..c294fe528 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,12 +6,11 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,7 +23,7 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( @@ -51,38 +50,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: Some of these can be setup directly in the layer config, but keeping them here for clarity. 
""" super().setup_tensor_space(tensor_space) - d_inner: int = self.ssm.d_inner - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.ssm.dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.ssm.dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + self.ssm.setup_tensor_space(tensor_space) def _validate(self): with self._set_implicit_default(None): diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py index 6d9746db1..8f49ded40 100644 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -322,19 +322,21 @@ def __init__(self, config, factory_kwargs, layer_idx, **kwargs): # Mixer self.mixer = DiscreteMamba2( - d_model=self.config.d_model, + d_model=self.config._hidden_size, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) # Other components - self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) + self.input_layernorm = LlamaRMSNorm( + hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs + ) self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs + hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs ) self.mlp = LlamaMLP( - hidden_size=self.config.d_model, + hidden_size=self.config._hidden_size, **config.mlp_cfg, factory_kwargs=factory_kwargs, ) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..3e57689b6 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -9,7 +9,7 @@ 
from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba2 import Mamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -39,14 +39,14 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=len(self._config.hybrid_block_layout), @@ -55,8 +55,8 @@ def get_output_layers(self) -> list[Layer]: ) elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=DiscreteMamba2, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -65,8 +65,8 @@ def get_output_layers(self) -> list[Layer]: layers.append(mamba_block) elif block_type == SSMBlockType.mamba: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=MambaLayer, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -75,8 +75,8 @@ def get_output_layers(self) -> list[Layer]: layers.append(mamba_block) elif block_type == SSMBlockType.mamba2: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=Mamba2, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -94,14 +94,14 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. 
""" - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, @@ -112,8 +112,8 @@ def get_layers(self) -> list[Layer]: ) elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=DiscreteMamba2, layer_index=i + 1, tensor_space=self._tensor_space, @@ -126,8 +126,8 @@ def get_layers(self) -> list[Layer]: elif block_type == SSMBlockType.mamba: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=MambaLayer, layer_index=i + 1, tensor_space=self._tensor_space, @@ -139,8 +139,8 @@ def get_layers(self) -> list[Layer]: elif block_type == SSMBlockType.mamba2: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=Mamba2, layer_index=i + 1, tensor_space=self._tensor_space, diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..b474fe87f 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -354,7 +354,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -def kaiming_init_(d_in): +def init_kaiming_(d_in): return init_normal_(0.0, math.sqrt(2.0 / d_in)) @@ -369,3 +369,9 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return tensor return init_ + + +def init_uniform_centered_( + high, max_val=None, mean=0.0 +) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: + return init_uniform_(mean - high, mean + high, min_val=mean - max_val, max_val=mean + max_val) diff --git a/setup.cfg b/setup.cfg index 2f69b8e06..c086af7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..e5fbc7d69 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -4,7 +4,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -39,7 +39,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, LlambaBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( From f93c51f465998370e4a9ac2b1f2eb78903b57ba3 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 21 Jul 2025 21:50:55 +0000 Subject: [PATCH 106/161] draft llava hybrid --- fast_llm/engine/checkpoint/huggingface.py | 3 + fast_llm/models/ssm/conversion.py | 18 ++- .../configuration_llava_hybrid.py | 110 ++++++++++++++++++ .../llava_hybrid/modeling_llava_hybrid.py | 12 ++ 4 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py create mode 100644 fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 16b3e005f..4cfff4afa 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -134,6 +134,7 @@ class CustomModelingExportMixin: configuration_file: typing.ClassVar[str] configuration_cls: typing.ClassVar[type[PretrainedConfig]] generation_utils_file: str | None = None + additional_files: typing.ClassVar[list[str]] = [] # Use custom config instead of relying on the transformers library @classmethod @@ -159,3 +160,5 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None: gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json" if gen_config.exists(): shutil.copy(gen_config, config.path) + for file in self.additional_files: + shutil.copy(file, config.path) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 55ed3058a..7d908a135 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,8 @@ import pathlib import typing +from transformers.configuration_utils import PretrainedConfig + from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -16,7 +18,7 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import 
CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig @@ -34,6 +36,11 @@ LLambaHuggingfaceCheckpointFormat, LlavaHybridHuggingfaceCheckpointFormat, ) +from fast_llm.models.ssm.external.apriel_15b_hybrid import ( + configuration_ssm_hybrid_apriel15b, + modeling_ssm_hybrid_apriel15b, +) +from fast_llm.models.ssm.external.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -793,9 +800,16 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An json.dump(config, f) -class LlavaHybridHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): +class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat + modeling_file = modeling_llava_hybrid.__file__ + configuration_file = configuration_llava_hybrid.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = configuration_llava_hybrid.LlavaHybridConfig _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + additional_files = [ + modeling_ssm_hybrid_apriel15b.__file__, + configuration_ssm_hybrid_apriel15b.__file__, + ] @classmethod def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: diff --git a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py new file mode 100644 index 000000000..09e17a92b --- /dev/null +++ b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py @@ -0,0 +1,110 @@ +from transformers import MistralConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +# Copied from configuration_ssm_hybrid_apriel15b.py +# TODO: split into mamba 2 and discrete mamba 2 configs with a base dict +ssm_config_default = { + # discrete mamba2 + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 32 * 128, + # mamba2 + "d_xb": None, # will be set to model dim + "dt_rank": "auto", + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init": "random", + "dt_scale": 1.0, + "dt_init_floor": 1e-4, + "conv_bias": True, +} + + +class AprielSSMHybridConfig(MistralConfig): + model_type = "apriel_ssm_thinker_hybrid" + + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 + self.ssm_cfg = ssm_cfg or ssm_config_default + + for k, v in ssm_config_default.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v # to make sure all elements are present in the config + + +class LlavaHybridConfig(PretrainedConfig): + """ + Configuration class for Llava SSM-Hybrid-decoder model. 
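As a rough illustration of the default-filling behaviour of AprielSSMHybridConfig defined above (a sketch that assumes transformers is installed and this module is importable from the exported checkpoint directory):

    from configuration_llava_hybrid import AprielSSMHybridConfig

    cfg = AprielSSMHybridConfig(hybrid_block_layout=["t", "m2d"], ssm_cfg={"d_state": 16})
    assert cfg.ssm_cfg["d_state"] == 16      # user-provided value is kept
    assert cfg.ssm_cfg["chunk_size"] == 128  # missing keys are filled from ssm_config_default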
+ """ + + model_type = "llava_hybrid" + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." + f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": + text_config = AprielSSMHybridConfig(**text_config) + else: + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py new file mode 100644 index 000000000..d58b3535d --- /dev/null +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -0,0 +1,12 @@ +from transformers import LlavaModel + +from .configuration_llava_hybrid import LlavaHybridConfig + + +class LlavaHybridModel(LlavaModel): + """ + Llava SSM-Hybrid-decoder model. 
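A hedged sketch of how the exported configuration could be instantiated; the field values below are illustrative assumptions, and the import assumes the checkpoint directory containing these custom modeling files is on the Python path:

    from configuration_llava_hybrid import LlavaHybridConfig

    config = LlavaHybridConfig(
        text_config={"model_type": "apriel_ssm_thinker_hybrid", "hybrid_block_layout": ["t", "m2d"]},
        vision_config=None,  # falls back to the default CLIP vision settings above
        image_token_index=32000,
    )
    assert type(config.text_config).__name__ == "AprielSSMHybridConfig"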
+ """ + + def __init__(self, config: LlavaHybridConfig): + super().__init__(config) From 4e310c74634a70c4d8117cc025f18a040ffbd098 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 13:04:54 -0400 Subject: [PATCH 107/161] TP mamba --- fast_llm/engine/config_utils/tensor_space.py | 174 ++++++++++++------- fast_llm/layers/common/linear.py | 8 +- fast_llm/layers/common/normalization.py | 4 +- fast_llm/layers/common/peft.py | 4 +- fast_llm/layers/ssm/config.py | 45 +++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba2.py | 22 ++- fast_llm/layers/ssm/mamba_layer.py | 2 +- fast_llm/tensor.py | 31 ++-- 9 files changed, 184 insertions(+), 108 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf70..dceeb7da4 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -5,6 +5,8 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed @@ -23,7 +25,7 @@ def __repr__(self) -> str: f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +40,134 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: + return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.parallel_group is not None: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? 
- parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.is_parallel: + # TODO: Allow more than one parallel subdim? + assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? 
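    # Standalone plain-torch sketch (independent of the classes above) of the ordering
    # CompositeTensorDim.global_to_local applies: unflatten into sub-dims, shard the
    # parallel sub-dim, then flatten back. Sizes below are made up for illustration.
    import torch

    heads, state, world_size, rank = 4, 8, 2, 1
    global_tensor = torch.arange(heads * state)
    local_tensor = global_tensor.unflatten(0, (heads, state)).chunk(world_size, 0)[rank].flatten(0, 1)
    assert local_tensor.shape == (heads // world_size * state,)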
+ Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + # TODO: Implement + raise NotImplementedError() + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim)[0] + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + if self.is_parallel and expand: + raise NotImplementedError() + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +200,22 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim def get_tensor_dim(self, name: str) -> TensorDim: return self._tensor_dims[name] + + # TODO: Replace uses + __getitem__ = get_tensor_dim diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] 
= None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index f4c8067dd..ce37a9804 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,7 +1,7 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig @@ -20,8 +20,7 @@ class SSMDimNames: composite_head_groups_and_state = "ssm_composite_head_groups_and_state" # Inner projection total dimension. - inner_projection = "ssm_inner_projection" - composite_inner_projection = "ssm_composite_inner_projection" + concatenated_inner_projection = "ssm_concatenated_inner_projection" # Convolution shape in discrete mamba 2. 
TODO: Remove (dim too complex) conv_dim = "ssm_conv_dim" @@ -210,7 +209,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType Assert.eq(num_heads, self.n_v_heads) num_head_groups = self.n_qk_heads # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) - inner_projection_size = 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads + 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads else: raise NotImplementedError(block_type) @@ -219,12 +218,18 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType tensor_space.add_tensor_dim( group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) ) - tensor_space.add_tensor_dim(CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))) tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim)) + heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim)) + heads_and_state := CompositeTensorDim( + SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim) + ) + ) + tensor_space.add_tensor_dim( + head_groups_and_state := CompositeTensorDim( + SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim) + ) ) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) @@ -234,17 +239,27 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType if block_type == SSMBlockType.mamba: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) - inner_projection_size = 2 * num_group_heads * self.state_size + # TODO: Use composition instead + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_inner_projection, (heads_and_state, heads_and_state)) + ) elif block_type == SSMBlockType.mamba2: - inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + # TODO: Factor out state? + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state), + ) + ) elif block_type == SSMBlockType.mamba2_discrete: - inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + num_group_heads + # TODO: Factor as (head_groups, (group_heads + 2) * state_size + group_heads)? 
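    # Size check for the (non-discrete) Mamba 2 concatenation built just above, with
    # illustrative numbers: z and c each span heads * state (d_inner), x and b each span
    # head_groups * state (d_xb), matching the [d_inner, d_xb, d_xb, d_inner] split that
    # Mamba2.forward applies to the projection output.
    d_inner, d_xb = 512, 256
    assert d_inner + d_xb + d_xb + d_inner == 1536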
+ tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state, heads), + ) + ) # TODO: (head_groups, group_heads + 2, state_size) tensor_space.add_tensor_dim( TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) ) - - tensor_space.add_tensor_dim(inner_projection := TensorDim(SSMDimNames.inner_projection, inner_projection_size)) - tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_inner_projection, (head_groups, inner_projection)) - ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index d06b47965..988a09504 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -67,7 +67,7 @@ def __init__( td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection) self.d_model = td_model.size self.d_inner = td_inner.size diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 011889d04..dff1356e6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -45,6 +45,7 @@ def __init__( inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) + dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) self._head_groups = div(self._config.d_xb, self._config.state_size) self._heads = div(self._config.d_inner, self._config.state_size) @@ -65,13 +66,21 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.composite_inner_projection), + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, weight_init_method=init_kaiming_(hidden_dim.size), lr_scale=lr_scale, ) - self.dt_proj = Linear( - tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank), + + self.dt_in_proj = Linear( + hidden_dim, + dt_rank_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, + ) + self.dt_proj = OutputParallelLinear( + dt_rank_dim, inner_dim, bias=False, # Initialize special dt projection to preserve variance at initialization @@ -110,16 +119,19 @@ def forward(self, hidden_states, kwargs): assert _causal_conv1d_available inner_projection = self.in_proj(hidden_states) + dt = self.dt_in_proj(hidden_states) # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) + dt = dt.transpose(0, 1) sequence_length = hidden_states.size(1) - z, x, b, c, dt = torch.split( + z, x, b, c = torch.split( inner_projection, - [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner, self._config.dt_rank], + [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner], dim=2, ) + # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) z = z.transpose(1, 2) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index fa2789b1e..0cdcb5242 100644 --- 
a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -74,7 +74,7 @@ def __init__( # TODO: lr_scale? self.in_proj = Linear( hidden_dim, - tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection), + tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b474fe87f..f312f1962 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -5,7 +5,7 @@ import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -166,14 +166,13 @@ def local_to_global( ) -> tuple[torch.Tensor, ...]: # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -187,23 +186,19 @@ def local_to_global( def global_to_local( self, tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. + # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. expand: bool = False, ) -> torch.Tensor: """ Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. """ # Take a trivial slice to convert safetensor slices. 
- tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] - - return tensor_ if expand else tensor_.reshape(self.shape) + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim, expand) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): From 3cc41182a71d28e02918d76cd882978ca8384f73 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 16:57:38 -0400 Subject: [PATCH 108/161] fix --- fast_llm/engine/config_utils/tensor_space.py | 6 +- fast_llm/layers/ssm/config.py | 24 +++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 + fast_llm/layers/ssm/llamba_block.py | 10 +- fast_llm/layers/ssm/mamba_layer.py | 13 ++- fast_llm/layers/transformer/transformer.py | 20 ++-- fast_llm/models/ssm/config.py | 41 +++----- fast_llm/models/ssm/model.py | 99 +++++--------------- fast_llm/tensor.py | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_multi_stage.py | 6 +- tests/utils/model_configs.py | 43 +++++---- 14 files changed, 127 insertions(+), 148 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index dceeb7da4..d927f2e71 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -70,7 +70,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else: return tensor - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] if self.parallel_dim is not None and self.parallel_dim.size > 1 @@ -108,7 +108,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): tensor = tensor_dim.global_to_local(tensor, dim + i) @@ -150,7 +150,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else tensor ) - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() return ( diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index ce37a9804..aa011f75f 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -41,6 +41,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + 
from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + class DTInitType(enum.StrEnum): constant = "constant" @@ -199,17 +215,13 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType # Head groups are configured differently depending on the block type. if block_type == SSMBlockType.mamba: num_head_groups = num_heads - # (head_groups, 2 * group_heads * state_dim) - inner_projection_size = self.d_inner * 2 elif block_type == SSMBlockType.mamba2: num_head_groups = div(self.d_xb, self.state_size) - # (head_groups, 2 * group_heads + 2, state_dim) + (dt,) - inner_projection_size: int = 2 * self.d_inner + 2 * num_head_groups * self.state_size + self.dt_rank elif block_type == SSMBlockType.mamba2_discrete: Assert.eq(num_heads, self.n_v_heads) + # TODO: Fix (Du einsum crashes) + Assert.eq(self.n_qk_heads, self.n_v_heads) num_head_groups = self.n_qk_heads - # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) - 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads else: raise NotImplementedError(block_type) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 988a09504..14fb8aaed 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -216,6 +216,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ else: y = result + print("AHNFIUWEGIUWEI", self.D.shape, x.shape) + # TODO: h different for D and x (qk_heads, v_heads) Du = torch.einsum("h,blhp->blhp", self.D, x) y = einops.rearrange(y + Du, "b l h p -> b l (h p)") diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index e877ff9c2..774ee7303 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -8,7 +8,7 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ @@ -24,9 +24,9 @@ def __init__( layer_index: int, return_input: bool = False, ): - self._debug_mode = self._config_ssm.debug_ssm + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls super().__init__(transformer_config, tensor_space, layer_index, return_input) - self.mixer = mixer_cls(ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) - def get_mixer(self) -> Mixer: - return self.mixer + def _create_mixer(self) -> Mixer: + return self._mixer_cls(self._ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 0cdcb5242..8235f4f1a 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,7 +1,6 @@ import logging import math import typing -from typing import Callable import torch @@ -30,21 +29,25 @@ """ -def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_A(d_state, d_inner) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa # TODO: adopt this initialization to work for tensor parallel setting! 
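    # The resulting (d_inner, d_state) tensor holds A_log[i, j] = log(j + 1), the "S4D real"
    # initialization, so exp(A_log) spans the state values 1..d_state in every channel;
    # the numel check below rejects tensor-parallel slices, which this layout cannot support yet.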
if tensor.numel() != d_state * d_inner: raise ValueError(f"_init_A requires not supported for tensor slices.") - return torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner), out=tensor) + return torch.log( + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device).repeat(d_inner), out=tensor + ) return init_ def init_dtprojbias( dt_max: float, dt_min: float, dt_init_floor: float -) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor = tensor.uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min(dt_init_floor) + tensor = ( + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min(dt_init_floor) + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 return tensor.add_(torch.log(-torch.expm1(-tensor))) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index f80e903f0..a0611cd29 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP @@ -36,6 +35,9 @@ class BaseBlock(Layer, abc.ABC): A transformer-like decoder base block with abstract mixer. """ + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" + def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): @@ -54,7 +56,8 @@ def __init__( self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. 
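    # Creating the mixer at this exact point keeps the parameter creation order, and keeping
    # the per-block attribute name ("self_attn" for transformer blocks, "mixer" for SSM blocks)
    # keeps the checkpoint parameter names, so previously saved checkpoints still load.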
+ setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index @@ -65,7 +68,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def get_mixer(self) -> Mixer: + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -126,7 +129,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = self.get_mixer()(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug_mode: self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): @@ -150,12 +153,15 @@ def forward( class TransformerBlock(BaseBlock): _name = "Transformer layer" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__(config, tensor_space, layer_index, return_input) - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) - def get_mixer(self) -> Mixer: - return self.self_attn + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention + + return Attention(self._config, self._tensor_space, self._layer_index) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index c294fe528..6b9e28584 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -30,7 +30,7 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -43,14 +43,16 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): use_megatron_initialization: bool = Field( default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. + ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. 
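For illustration (this mirrors the validation logic further below), a layout combines transformer blocks with at most one SSM block type, and a short layout is tiled to cover all layers:

    layout, num_layers = ["t", "m2"], 4
    layout = layout * (num_layers // len(layout))  # -> ['t', 'm2', 't', 'm2']
    ssm_types = set(layout) - {"t"}
    assert len(ssm_types) <= 1  # a single SSM block type per model
    ssm_block_type = ssm_types.pop() if ssm_types else None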
""" super().setup_tensor_space(tensor_space) - self.ssm.setup_tensor_space(tensor_space) + if self.ssm_block_type is not None: + self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) def _validate(self): with self._set_implicit_default(None): @@ -64,30 +66,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. + Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): @@ -162,12 +155,6 @@ def _validate(self): logger.warning( "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." 
) - if ( - self.base_model.sequence_first - or self.distributed.sequence_data_parallel > 1 - or self.distributed.sequence_tensor_parallel - ): - raise NotImplementedError(f"Sequence-first not supported for SSMs.") super()._validate() diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 3e57689b6..4a95891a7 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,10 +5,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -53,38 +49,17 @@ def get_output_layers(self) -> list[Layer]: return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + layer_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -110,47 +85,19 @@ def get_layers(self) -> list[Layer]: ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + layer_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index f312f1962..1111fd044 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -369,4 +369,9 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_uniform_centered_( high, max_val=None, mean=0.0 ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - return init_uniform_(mean - high, mean + high, min_val=mean - max_val, max_val=mean + max_val) + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py 
b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index e5fbc7d69..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock +from fast_llm.layers.ssm.llamba_block import SSMBlock from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b96a8963b..b834ed4d1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -451,16 +451,14 @@ def _update_and_add_testing_config( ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. + # Tests hybrid Mamba, llamba converter. "llama", "llamba", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m']", - "model.base_model.ssm.state_size=8", - "model.base_model.ssm.chunk_size=32", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=8", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", ], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, @@ -468,26 +466,31 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # SSMs don't support sequence-first configurations. 
- skip_tests=("sf", "sdp", "stp", "ms"), + # Micro-sequence split not supported. + skip_tests=("sdp", "ms"), ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", + # Tests hybrid discrete Mamba 2. + "llama", "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=8", + # TODO: Set to 16 once fixed. + "model.base_model.ssm.n_qk_heads=32", + "model.base_model.ssm.n_v_heads=32", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=None, @@ -497,17 +500,23 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, + # Micro-sequence split and sequence-first not supported. + skip_tests=("sf", "stp", "sdp", "ms"), ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", + # Tests hybrid Mamba 2. + "llama", "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", + "model.base_model.ssm.d_xb=256", ], megatron_args=None, checkpoint_format=None, @@ -517,8 +526,10 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + # Micro-sequence split not supported. 
+ skip_tests=("sdp", "ms"), ) From 9f7f75c72f1fff36a781773c8c772441d7fa9067 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 19:56:35 -0400 Subject: [PATCH 109/161] fix --- fast_llm/engine/config_utils/tensor_space.py | 6 +++++- fast_llm/layers/ssm/config.py | 2 -- fast_llm/layers/ssm/discrete_mamba2.py | 4 +--- fast_llm/layers/ssm/mamba2.py | 19 +++++++++++-------- fast_llm/layers/ssm/mamba_layer.py | 5 ++++- fast_llm/tensor.py | 6 ++++++ tests/utils/model_configs.py | 9 +++++---- 7 files changed, 32 insertions(+), 19 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index d927f2e71..2ca7e3e9a 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -21,7 +21,7 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," @@ -134,6 +134,8 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: raise NotImplementedError() def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + return ( torch.concatenate( [ @@ -153,6 +155,8 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() + import torch + return ( torch.concatenate( [ diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index aa011f75f..7da4283ba 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -219,8 +219,6 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType num_head_groups = div(self.d_xb, self.state_size) elif block_type == SSMBlockType.mamba2_discrete: Assert.eq(num_heads, self.n_v_heads) - # TODO: Fix (Du einsum crashes) - Assert.eq(self.n_qk_heads, self.n_v_heads) num_head_groups = self.n_qk_heads else: raise NotImplementedError(block_type) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 14fb8aaed..102accb85 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -111,7 +111,7 @@ def __init__( # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_qk_heads,), + (td_n_v_heads,), weight_decay=False, init_method=init_ones_, lr_scale=mamba_layer_lr_scale, @@ -216,8 +216,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ else: y = result - print("AHNFIUWEGIUWEI", self.D.shape, x.shape) - # TODO: h different for D and x (qk_heads, v_heads) Du = torch.einsum("h,blhp->blhp", self.D, x) y = einops.rearrange(y + Du, "b l h p -> b l (h p)") diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index dff1356e6..11ab91e40 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.discrete_mamba2 import bias_init_method from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerDimNames, 
TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -62,7 +61,9 @@ def __init__( lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), init_method=bias_init_method(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) self.in_proj = OutputParallelLinear( hidden_dim, @@ -124,7 +125,7 @@ def forward(self, hidden_states, kwargs): if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) - sequence_length = hidden_states.size(1) + sequence_length = inner_projection.size(1) z, x, b, c = torch.split( inner_projection, @@ -177,9 +178,11 @@ def forward(self, hidden_states, kwargs): delta_softplus=True, ) - # y: (batch, heads * state, sequence) -> out: (batch, sequence, hidden) - out = self.out_proj(y.transpose(1, 2))[:, :sequence_length] + # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) + y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: - out = out.transpose(0, 1) - # TODO: Is contiguous needed? - return out.contiguous(), None + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + a, b = self.out_proj(y) + Assert.eq(a.shape, hidden_states.shape) + return a, b diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 8235f4f1a..49b9e45b7 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -35,7 +35,10 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) if tensor.numel() != d_state * d_inner: raise ValueError(f"_init_A requires not supported for tensor slices.") return torch.log( - torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device).repeat(d_inner), out=tensor + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) + .unsqueeze(0) + .expand(d_inner, d_state), + out=tensor, ) return init_ diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 1111fd044..25ae49a31 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -164,6 +164,9 @@ def local_to_global( *, distributed: Distributed, ) -> tuple[torch.Tensor, ...]: + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication is_first_rank, modified = distributed.config.tensor_rank == 0, False @@ -195,6 +198,9 @@ def global_to_local( # Take a trivial slice to convert safetensor slices. tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b834ed4d1..47314263b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -487,9 +487,8 @@ def _update_and_add_testing_config( "model.base_model.hybrid_block_layout=['t','m2d']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - # TODO: Set to 16 once fixed. 
- "model.base_model.ssm.n_qk_heads=32", - "model.base_model.ssm.n_v_heads=32", + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", "model.base_model.ssm.chunk_size=32", ], megatron_args=None, @@ -503,6 +502,7 @@ def _update_and_add_testing_config( # TODO: Implement ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, + compare_factor=2.0, # Micro-sequence split and sequence-first not supported. skip_tests=("sf", "stp", "sdp", "ms"), ) @@ -515,7 +515,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", "model.base_model.ssm.d_inner=512", - "model.base_model.ssm.state_size=16", + "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", ], megatron_args=None, @@ -528,6 +528,7 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=2.0, # Micro-sequence split not supported. skip_tests=("sdp", "ms"), ) From d10eaad67d80efd00f70c574d5743c94a42f90eb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 23 Jul 2025 15:10:01 +0000 Subject: [PATCH 110/161] fix and add test config --- fast_llm/layers/vision_encoder/adapter.py | 8 +++--- fast_llm/layers/vision_encoder/config.py | 12 ++++++++ fast_llm/models/ssm/conversion.py | 8 +++++- tests/utils/model_configs.py | 34 ++++++++++++++++++++++- 4 files changed, 56 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 41ea065d0..d324d5221 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -24,15 +24,15 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), bias=True, - weight_init_method=init_normal_(), - bias_init_method=init_normal_(), + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), ) self.layer_2 = Linear( tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, - weight_init_method=init_normal_(), - bias_init_method=init_normal_(), + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), ) def forward( diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 2ea7f6114..a705d948a 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -150,6 +150,18 @@ class VisionEncoderConfig(BaseModelConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + adapter_init_method_std: float = Field( + default=None, + desc="Standard deviation for the normal initialization of the adapter weights. 
Default: adapter_size ** -0.5.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.adapter_init_method_std is None: + self.adapter_init_method_std = self.adapter_size**-0.5 + super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 7d908a135..97794b4fb 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -233,7 +233,13 @@ def _create_weight_converters( hf_base_prefix: str = "", offset: int = 0, ) -> list[WeightConverter]: - converters = super()._create_weight_converters() or [] + converters = ( + super()._create_weight_converters( + hf_base_prefix=hf_base_prefix, + offset=offset, + ) + or [] + ) num_layers = self._model.config.base_model.transformer.num_layers ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b96a8963b..6af35eeb1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -19,7 +19,7 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat, LlavaHybridHuggingfaceCheckpointFormat from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig @@ -521,6 +521,38 @@ def _update_and_add_testing_config( }, ) +_update_and_add_testing_config( + # Tests hybrid ssm, llamba converter. 
+ "hybrid_mamba2", + "vision_hybrid_mamba2", + model_type="hybrid_ssm", + extra_args=[ + "batch.max_image_size=128", + "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.transformer.type=image_encoder", + "model.base_model.vision_encoder.transformer.gated=True", + "model.base_model.vision_encoder.transformer.num_layers=2", + "model.base_model.vision_encoder.transformer.hidden_size=256", + "model.base_model.vision_encoder.transformer.num_attention_heads=8", + "model.base_model.vision_encoder.transformer.head_groups=4", + "model.base_model.vision_encoder.transformer.init_method_std=0.022", + "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", + "model.base_model.vision_encoder.adapter_size=512", + "model.distributed.training_dtype=torch.bfloat16", + ], + megatron_args=None, + checkpoint_format=LlavaHybridHuggingfaceCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=16.0, +) + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: From 4054e047d7318c2dfd6e37712f3b6b94d3beca5b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 15:22:24 -0400 Subject: [PATCH 111/161] fixes --- fast_llm/engine/config_utils/tensor_space.py | 11 ++++-- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/ssm/mamba2.py | 41 +++++++++++--------- fast_llm/tensor.py | 2 + 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 2ca7e3e9a..0d971a88a 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -10,6 +11,8 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -130,8 +133,10 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - # TODO: Implement - raise NotImplementedError() + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": import torch @@ -139,7 +144,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return ( torch.concatenate( [ - tensor_dim.local_to_global(tensor_, dim)[0] + tensor_dim.local_to_global(tensor_, dim) for tensor_, tensor_dim in zip( tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), self._tensor_dims, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..9a8ce2092 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 
+191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 11ab91e40..a285711c6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,5 @@ +import logging + import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -24,6 +26,8 @@ except (ImportError, RuntimeError): _causal_conv1d_available = False +logger = logging.getLogger(__name__) + class Mamba2(Mixer): """ @@ -43,21 +47,20 @@ def __init__( lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - self._head_groups = div(self._config.d_xb, self._config.state_size) - self._heads = div(self._config.d_inner, self._config.state_size) - self._group_heads = div(self._heads, self._head_groups) + self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size + self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._group_heads = div(self._local_heads, self._local_head_groups) + self._local_inner_size = inner_dim.size + self._local_xb_size = xb_dim.size - conv1d_dim = ( - inner_dim - if self._config.repeat_kv_before_conv - else tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) - ) + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), - init_method=init_uniform_centered_((conv1d_dim.size * self._config.conv_kernel_dimension) ** -0.5), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( @@ -69,7 +72,7 @@ def __init__( hidden_dim, tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.size), + weight_init_method=init_kaiming_(hidden_dim.global_size), lr_scale=lr_scale, ) @@ -77,7 +80,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.size), + weight_init_method=init_kaiming_(hidden_dim.global_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -129,7 +132,7 @@ def forward(self, hidden_states, kwargs): z, x, b, c = torch.split( inner_projection, - [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner], + [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size], dim=2, ) @@ -140,28 +143,28 @@ def forward(self, hidden_states, kwargs): x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: x = ( - 
x.unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") else: x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") x = ( - x.unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) b = ( b.transpose(1, 2) - .unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) ) # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) - c = c.transpose(1, 2).unflatten(1, (self._heads, self._config.state_size)) + c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 25ae49a31..6995e9e94 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -184,6 +184,7 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank def global_to_local( @@ -204,6 +205,7 @@ def global_to_local( for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) + Assert.eq(tensor.shape, self.shape) return tensor @classmethod From 7c8de47cbc424b39a5abeb0c09881bb159ee2893 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 23 Jul 2025 19:57:24 +0000 Subject: [PATCH 112/161] conversion fixes and tests --- fast_llm/engine/checkpoint/external.py | 6 +++- .../layers/vision_encoder/preprocessing.py | 33 ++++++++++++++----- fast_llm/models/gpt/conversion.py | 10 +++--- fast_llm/models/gpt/model.py | 6 +++- fast_llm/models/ssm/config.py | 3 +- fast_llm/models/ssm/conversion.py | 14 ++++++++ .../llava_hybrid/modeling_llava_hybrid.py | 11 ++++++- fast_llm/models/ssm/huggingface.py | 6 ++-- fast_llm/models/ssm/model.py | 7 ++++ tests/utils/model_configs.py | 8 +++-- 10 files changed, 84 insertions(+), 20 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 72db80f6a..f8a42b31a 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -283,7 +283,7 @@ def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: return exported_config # Noqa @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa + def _import_config_dict(cls, config: dict[str, typing.Any]) -> dict[str | tuple[str, ...], typing.Any]: kwargs = {} for converter in cls._get_config_converters(): try: @@ -306,7 +306,11 @@ def _import_config(cls, config: dict[str, 
typing.Any]) -> BaseModelConfig: # no kwargs[fast_llm_name] = value except Exception as e: raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + return kwargs + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa + kwargs = cls._import_config_dict(config) return cls._model_class.get_base_model_config_class().from_dict({}, kwargs) def _convert_state_dict( diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 3b857ba26..adacd380c 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -163,11 +163,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: for imgs in images ] - labels = kwargs[LanguageModelKwargs.labels] - if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): - # If image break or end token is present, we need to replace image token ids to -100 in labels - # TODO: avoid double cloning labels in case of loss masking spans? - labels = labels.clone() + if LanguageModelKwargs.labels in kwargs: + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? + labels = labels.clone() patches = [] patch_position_ids = [] @@ -191,8 +192,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: image_break=self._config.image_break_token is not None, image_end=self._config.image_end_token is not None, ) - # set labels for image patches to -100 - labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 + if LanguageModelKwargs.labels in kwargs: + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) @@ -261,4 +263,19 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen - kwargs[LanguageModelKwargs.labels] = labels + if LanguageModelKwargs.labels in kwargs: + kwargs[LanguageModelKwargs.labels] = labels + + # TODO: add proper preprocessing for attention-mask when not using flash attention + # Following is just a dummy code to run the tests. 
+ kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones( + (1, 1, kwargs[TransformerKwargs.sequence_length], 1, kwargs[TransformerKwargs.sequence_length]), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._distributed_config.training_dtype.torch).min, + dtype=self._distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 79d7099cd..2f4d9b61b 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -884,13 +884,15 @@ def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() cfg_dict = cls._load_config(config.path) kwargs = {} if "text_config" in cfg_dict: - text_kwargs = cls._import_config(cfg_dict["text_config"]) + text_kwargs = text_handler_cls._import_config_dict(cfg_dict["text_config"]) kwargs.update(text_kwargs) if "vision_config" in cfg_dict: - vision_kwargs = cls._import_config(cfg_dict["vision_config"]) + vision_kwargs = vision_handler_cls._import_config_dict(cfg_dict["vision_config"]) vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} kwargs.update(vision_kwargs) kwargs.update( @@ -927,9 +929,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: - handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + # handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) kwargs = {} - for converter in handler_cls._create_config_converters(): + for converter in cls._create_config_converters(): try: values = () for export_name in converter.export_names: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8d70a8944..1e439e72e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -150,7 +150,11 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder.enabled: - max_image_size = batch_meta.max_image_size + try: + max_image_size = batch_meta.max_image_size + except AttributeError: + max_image_size = 256 + logger.warning("Inference mode: max_image_size not provided, defaulting to 256") image_mean = [ self._config.vision_encoder.image_normalization.mean_r, self._config.vision_encoder.image_normalization.mean_g, diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 55a7ef548..41f4eadbe 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -178,6 +178,7 @@ class LlavaHybridHuggingfaceCheckpointFormat(CheckpointFormat): name: typing.ClassVar[str] = "llava_hybrid" vision_name: typing.ClassVar[str] = "pixtral" text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + trust_remote_code: typing.ClassVar[bool] = True @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: @@ -206,7 +207,7 @@ def get_model_class(cls) -> type["HybridSSMModel"]: return HybridSSMModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: + def 
get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM return HuggingfaceHybridSSMModelForCausalLM diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 97794b4fb..db960f0fc 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -822,3 +822,17 @@ def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "LlavaHybridForConditionalGeneration" + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForConditionalGeneration": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + }, + ), + ] diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index d58b3535d..78e390a17 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -1,4 +1,5 @@ -from transformers import LlavaModel +from torch import nn +from transformers import LlavaForConditionalGeneration, LlavaModel from .configuration_llava_hybrid import LlavaHybridConfig @@ -10,3 +11,11 @@ class LlavaHybridModel(LlavaModel): def __init__(self, config: LlavaHybridConfig): super().__init__(config) + + +class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): + def __init__(self, config: LlavaHybridConfig): + super(LlavaForConditionalGeneration, self).__init__(config) + self.model = LlavaHybridModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py index 77cd346f7..02f472076 100644 --- a/fast_llm/models/ssm/huggingface.py +++ b/fast_llm/models/ssm/huggingface.py @@ -1,9 +1,10 @@ import logging +import typing -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig +from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from fast_llm.models.ssm.config import HybridSSMModelConfig -from fast_llm.models.ssm.model import HybridSSMModel +from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel logger = logging.getLogger(__name__) @@ -17,5 +18,6 @@ class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): config_class = HuggingfaceSSMModelConfig config: HuggingfaceSSMModelConfig + runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner model_class = HybridSSMModel _fast_llm_model: HybridSSMModel diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 80f9ca8ba..df15907d2 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -3,12 +3,14 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.distributed.config import 
DistributedConfig +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba2 import Mamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -164,3 +166,8 @@ class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel + + +class HybridSSMInferenceRunner(InferenceRunner): + model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel + batch_config_class: typing.ClassVar[type[GPTBatchConfig]] = GPTBatchConfig diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6af35eeb1..c2bb68003 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -529,12 +529,16 @@ def _update_and_add_testing_config( extra_args=[ "batch.max_image_size=128", "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.patch_norm.type=rms_norm", + "model.base_model.vision_encoder.transformer.add_linear_biases=False", + "model.base_model.vision_encoder.transformer.causal=False", + "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", "model.base_model.vision_encoder.transformer.type=image_encoder", "model.base_model.vision_encoder.transformer.gated=True", "model.base_model.vision_encoder.transformer.num_layers=2", "model.base_model.vision_encoder.transformer.hidden_size=256", "model.base_model.vision_encoder.transformer.num_attention_heads=8", - "model.base_model.vision_encoder.transformer.head_groups=4", + "model.base_model.vision_encoder.transformer.head_groups=8", "model.base_model.vision_encoder.transformer.init_method_std=0.022", "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", "model.base_model.vision_encoder.adapter_size=512", @@ -545,7 +549,7 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, From 0014cc6b3f79138e53610dc86cb654a5eaba90a0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 18:02:43 -0400 Subject: [PATCH 113/161] fix --- fast_llm/layers/ssm/discrete_mamba2.py | 27 +++----- fast_llm/layers/ssm/llamba_block.py | 11 ++- fast_llm/layers/ssm/mamba2.py | 53 ++++++++++---- fast_llm/layers/ssm/mamba_layer.py | 11 +-- fast_llm/layers/transformer/attention.py | 69 ++++--------------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +-- fast_llm/layers/transformer/transformer.py | 63 ++++++++++++++--- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/model.py | 4 +- fast_llm/models/ssm/model.py | 
8 +-- tests/utils/model_configs.py | 6 +- 12 files changed, 154 insertions(+), 116 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 102accb85..b95ff76da 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -8,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale @@ -37,28 +38,23 @@ def bias_init_method(conv_weight): return init_uniform_centered_(bound) -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Other options are all experimental and should not need to be configured. - """ - # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) td_state = tensor_space.get_tensor_dim(SSMDimNames.state) @@ -223,9 +219,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. 
return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index 774ee7303..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -21,12 +21,17 @@ def __init__( ssm_config: "SSMConfig", tensor_space: "TensorSpace", mixer_cls: type[Mixer], - layer_index: int, + block_index: int, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(transformer_config, tensor_space, layer_index, return_input) + super().__init__(transformer_config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: - return self._mixer_cls(self._ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a285711c6..88fe4abc0 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,4 +1,5 @@ import logging +import typing import torch @@ -7,7 +8,7 @@ from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale @@ -34,16 +35,31 @@ class Mamba2(Mixer): This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads_and_state, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads, + SSMDimNames.state, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) @@ -72,7 +88,8 @@ def __init__( hidden_dim, tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.global_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -80,7 +97,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.global_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -91,6 +108,7 @@ def __init__( weight_init_method=self._config.dt_init.get_init_method( self._config.dt_rank**-0.5 * self._config.dt_scale ), + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn @@ -116,6 +134,8 @@ def __init__( hidden_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + # TODO: lr_scale? ) def forward(self, hidden_states, kwargs): @@ -123,11 +143,12 @@ def forward(self, hidden_states, kwargs): assert _causal_conv1d_available inner_projection = self.in_proj(hidden_states) - dt = self.dt_in_proj(hidden_states) + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) + sequence_length = inner_projection.size(1) z, x, b, c = torch.split( @@ -166,8 +187,15 @@ def forward(self, hidden_states, kwargs): # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) - # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) - dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) + # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + dt = dt.transpose(1, 2) + + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(b, "b", self._BC_DIMS, kwargs) + self._debug_log(c, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) y = selective_scan_fn( x, @@ -181,11 +209,12 @@ def forward(self, hidden_states, kwargs): delta_softplus=True, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() - a, b = self.out_proj(y) - Assert.eq(a.shape, hidden_states.shape) - return a, b + return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 49b9e45b7..49afa910e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -8,7 +8,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale @@ -58,13 +58,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" self._config = config # TODO: It's not silu? @@ -73,7 +76,7 @@ def __init__( # Tensor dims: inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: Backward compatibility? diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 76b8ed1ca..174e19588 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -14,9 +14,8 @@ TransformerSubLayerName, ) from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -56,6 +55,8 @@ class Attention(Mixer): A self-attention layer. 
""" + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -65,7 +66,7 @@ class Attention(Mixer): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -74,19 +75,9 @@ class Attention(Mixer): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -109,7 +100,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) @@ -179,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -201,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -301,7 +258,7 @@ def _decide_window_size(self) -> int | None: # 
https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -342,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -396,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..efe0c5cc5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index a0611cd29..d08db9a94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -13,6 +13,7 @@ from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -22,6 +23,15 @@ class Mixer(torch.nn.Module, abc.ABC): Base class for mixer modules. 
""" + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + @abc.abstractmethod def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -29,6 +39,43 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ in case its addition can be made more efficient in `_bias_dropout_add`. """ + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + class BaseBlock(Layer, abc.ABC): """ @@ -39,7 +86,7 @@ class BaseBlock(Layer, abc.ABC): _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -48,11 +95,11 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) @@ -60,7 +107,7 @@ def __init__( setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
@@ -81,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -157,11 +204,11 @@ class TransformerBlock(BaseBlock): _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: from fast_llm.layers.transformer.attention import Attention - return Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index a9cf3bb8c..534d813ff 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -34,7 +34,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a3a68e0a6..4c1eab46f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -72,7 +72,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -94,7 +94,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 4a95891a7..89f0cd4aa 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -45,7 +45,7 @@ def get_output_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) @@ -55,7 +55,7 @@ def get_output_layers(self) -> list[Layer]: transformer_config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, return_input=i != self._config.prediction_heads - 1, ) @@ -79,7 +79,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), @@ -91,7 +91,7 @@ def get_layers(self) -> list[Layer]: transformer_config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), - layer_index=i + 1, + block_index=i + 1, tensor_space=self._tensor_space, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 47314263b..4090e5a38 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -517,6 +517,7 @@ def _update_and_add_testing_config( "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, checkpoint_format=None, @@ -530,7 +531,10 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. 
- skip_tests=("sdp", "ms"), + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) From 47ad5485454236d557570a32771c5888bbb3658e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:03:01 -0400 Subject: [PATCH 114/161] fixes --- Megatron-LM | 2 +- fast_llm/layers/language_model/head.py | 16 ++++++++++------ fast_llm/logging.py | 2 ++ fast_llm/tensor.py | 3 ++- tests/test_attention.py | 4 ++-- tests/utils/model_configs.py | 2 +- 6 files changed, 18 insertions(+), 11 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..21bf3bbd0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 6995e9e94..899e70005 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -205,7 +205,8 @@ def global_to_local( for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) - Assert.eq(tensor.shape, self.shape) + if not expand: + Assert.eq(tensor.shape, self.shape) return tensor @classmethod diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4090e5a38..18db0d401 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -467,7 +467,7 @@ def 
_update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, From 6a074fa3c72bbe16c617a11cff690c543e4c5e86 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:50:05 -0400 Subject: [PATCH 115/161] fixes --- fast_llm/layers/ssm/config.py | 2 +- fast_llm/models/ssm/conversion.py | 18 ++++++---- tests/utils/model_configs.py | 55 ++++++++++++++++--------------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 7da4283ba..15a6a8210 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -168,7 +168,7 @@ class SSMConfig(LLMBlockConfig): # Initialization # dt_weight_initialization_method [Mamba2] dt_init: DTInitType = Field( - default="random", + default=DTInitType.random, desc="Initialization method for dt", hint=FieldHint.core, ) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += 
self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 18db0d401..3ffc3281b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -19,7 +19,10 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import ( + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LLambaHuggingfaceCheckpointFormat, +) from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig @@ -467,7 +470,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, @@ -477,47 +480,49 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms"), ) - _update_and_add_testing_config( - # Tests hybrid discrete Mamba 2. + # Tests hybrid Mamba 2. "llama", - "hybrid_discrete_mamba2", + "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.hybrid_block_layout=['t','m2']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=16", - "model.base_model.ssm.chunk_size=32", + "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Implement - ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=2.0, - # Micro-sequence split and sequence-first not supported. - skip_tests=("sf", "stp", "sdp", "ms"), + # Micro-sequence split not supported. + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) + _update_and_add_testing_config( - # Tests hybrid Mamba 2. + # Tests hybrid discrete Mamba 2. 
"llama", - "hybrid_mamba2", + "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.hybrid_block_layout=['t','m2d']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - "model.base_model.ssm.d_xb=256", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=None, @@ -527,14 +532,12 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # Micro-sequence split not supported. - skip_tests=( - "sdp", - "ms", - ), # "pp","dp", "ce","16", "bf", "df", "stp"), + # Micro-sequence split and sequence-first not supported. + skip_tests=("sf", "stp", "sdp", "ms"), ) From d66651f5433392794d1b45560282d9237824881d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:56:19 -0400 Subject: [PATCH 116/161] Update external --- .../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..4fde72458 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: From 752274be515a8ee417d454f9e2f6032427327489 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 24 Jul 2025 15:46:41 +0000 Subject: [PATCH 117/161] fix conversion --- 
fast_llm/models/gpt/conversion.py | 5 +- fast_llm/models/ssm/conversion.py | 4 +- .../configuration_llava_hybrid.py | 7 ++ .../llava_hybrid/modeling_llava_hybrid.py | 105 +++++++++++++++++- tests/models/test_checkpoint.py | 2 +- 5 files changed, 115 insertions(+), 8 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 2f4d9b61b..f376db3ff 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -872,6 +872,7 @@ def num_layers(self) -> int: class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaForConditionalGeneration" _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @classmethod @@ -912,9 +913,7 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), MappedConfigParamConverter( fast_llm_names=(("vision_encoder", "adapter_activation_type"),), export_names=(("projector_hidden_act",),), diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index db960f0fc..c765c956b 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -808,6 +808,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaHybridForConditionalGeneration" modeling_file = modeling_llava_hybrid.__file__ configuration_file = configuration_llava_hybrid.__file__ configuration_cls: typing.ClassVar[type[PretrainedConfig]] = configuration_llava_hybrid.LlavaHybridConfig @@ -825,14 +826,13 @@ def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "LlavaHybridForConditionalGeneration" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), export_value={ "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", - "AutoModelForConditionalGeneration": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", }, ), ] diff --git a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py index 09e17a92b..b8e822d9f 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py @@ -57,6 +57,7 @@ def __init__( text_config=None, image_token_index=32000, projector_hidden_act="gelu", + projector_intermediate_size=4096, vision_feature_select_strategy="default", vision_feature_layer=-2, image_seq_length=576, @@ -65,6 +66,8 @@ def __init__( ): self.image_token_index = image_token_index 
self.projector_hidden_act = projector_hidden_act + # projector_intermediate_size is an addition to the original Llava config + self.projector_intermediate_size = projector_intermediate_size self.image_seq_length = image_seq_length if vision_feature_select_strategy not in ["default", "full"]: @@ -96,6 +99,7 @@ def __init__( self.vision_config = vision_config if isinstance(text_config, dict): + # Load the custom SSM hybrid config if specified if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": text_config = AprielSSMHybridConfig(**text_config) else: @@ -108,3 +112,6 @@ def __init__( self.multimodal_projector_bias = multimodal_projector_bias super().__init__(**kwargs) + + +__all__ = ["LlavaHybridConfig"] diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 78e390a17..6917fea93 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -1,7 +1,31 @@ from torch import nn -from transformers import LlavaForConditionalGeneration, LlavaModel +from transformers import AutoModel, LlavaForConditionalGeneration, LlavaModel +from transformers.activations import ACT2FN from .configuration_llava_hybrid import LlavaHybridConfig +from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaHybridConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.projector_intermediate_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.projector_intermediate_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states class LlavaHybridModel(LlavaModel): @@ -9,13 +33,90 @@ class LlavaHybridModel(LlavaModel): Llava SSM-Hybrid-decoder model. 
""" + config_class = LlavaHybridConfig + def __init__(self, config: LlavaHybridConfig): - super().__init__(config) + super(LlavaModel, self).__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + assert ( + config.text_config.model_type == "apriel_ssm_thinker_hybrid" + ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" + # from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel + self.language_model = AprielThinkerSSMHybridModel(config.text_config) + self.post_init() class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): + config_class = LlavaHybridConfig + def __init__(self, config: LlavaHybridConfig): super(LlavaForConditionalGeneration, self).__init__(config) self.model = LlavaHybridModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Copy of the method from `AprielThinkerSSMHybridForCausalLM` + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + # Copy from `LlavaForConditionalGeneration.prepare_inputs_for_generation` + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 05acf23dc..b1e9e74f0 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -309,7 +309,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): errors = [] auto_model = ( transformers.AutoModel - if model_testing_config.name in ("diffusion_llama", "dream") + if model_testing_config.name in ("diffusion_llama", "dream", "vision_hybrid_mamba2") else transformers.AutoModelForCausalLM ) model_as_hf = auto_model.from_pretrained( From 50083ba88a0bfa58747d2bc8307814b62af1a79a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 15:14:13 -0400 Subject: [PATCH 118/161] SSM debugging --- Megatron-LM | 2 +- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/language_model/head.py | 16 ++- fast_llm/layers/ssm/config.py | 34 +++--- fast_llm/layers/ssm/discrete_mamba2.py | 23 ++-- fast_llm/layers/ssm/llamba_block.py | 29 +++-- fast_llm/layers/ssm/mamba2.py | 38 ++++-- fast_llm/layers/ssm/mamba_layer.py | 36 +++--- fast_llm/layers/transformer/attention.py | 72 +++-------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +- fast_llm/layers/transformer/transformer.py | 94 ++++++++++++--- fast_llm/logging.py | 2 + fast_llm/models/gpt/model.py | 12 +- fast_llm/models/ssm/config.py | 40 +++---- fast_llm/models/ssm/model.py | 113 +++++------------- setup.cfg | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_attention.py | 4 +- tests/test_multi_stage.py | 8 +- tests/utils/model_configs.py | 1 + 23 files changed, 271 insertions(+), 282 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 
511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..9a8ce2092 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..21bf3bbd0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 46d629aa8..a1f357de9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -23,6 +23,7 @@ class SSMDimNames: # Mamba 2 x_proj_dim_2 = "x_proj_dim_2" # d_xb + c_heads = "c_heads" class SSMBlockType(enum.StrEnum): @@ -35,6 +36,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + @config_class() class SSMConfig(LLMBlockConfig): @@ -95,11 +112,6 @@ class SSMConfig(LLMBlockConfig): desc="The MLP intermediate activation type. 
Default: SiLU for gated MLP, GeLU otherwise.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( - default=False, - desc="debug_ssm", - hint=FieldHint.optional, - ) dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", @@ -147,18 +159,6 @@ class SSMConfig(LLMBlockConfig): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) dt_scale: float = Field( default=1.0, desc="Scale for dt", diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..734e35b21 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -36,29 +38,29 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): """ See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Other options are all experimental and should not need to be configured. 
""" # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -226,9 +228,6 @@ def forward(self, hidden_states, kwargs): out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -8,27 +8,30 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, + mixer_cls: type[Mixer], + block_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls + super().__init__(transformer_config, tensor_space, block_index, return_input) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def _create_mixer(self) -> Mixer: + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..ead32fa2a 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.transformer 
import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -43,24 +45,36 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.inner_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.c_heads, + SSMDimNames.state_dim, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale: float | tuple[float | None, ...] | None = get_lr_scale( self.config.mamba_lr_scale, layer_lr_scale ) @@ -236,6 +250,13 @@ def forward(self, hidden_states, kwargs): x = repeat_kv(x, self.repeat_group) x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(B, "b", self._BC_DIMS, kwargs) + self._debug_log(C, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + y = selective_scan_fn( x, dt, @@ -249,6 +270,9 @@ def forward(self, hidden_states, kwargs): return_last_state=False, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + if ssm_state is not None: y, last_state = y ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..a95e94c03 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,4 +1,5 @@ import math +import typing from typing import Callable import einops @@ -7,6 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -44,12 +47,12 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - 
).clamp(min=dt_init_floor) + dt = torch.exp(torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( + min=dt_init_floor + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) tensor.copy_(inv_dt) @@ -58,20 +61,18 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - factory_kwargs = {} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm # Tensor dims: td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) @@ -88,7 +89,7 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( @@ -113,7 +114,6 @@ def __init__( weight_init_method=kaiming_init_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -127,7 +127,7 @@ def __init__( self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor ), lr_scale=mamba_layer_lr_scale, ) @@ -153,10 +153,8 @@ def __init__( bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. 
weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input def forward(self, hidden_states, kwargs): assert _mamba_available @@ -168,8 +166,6 @@ def forward(self, hidden_states, kwargs): "d (b l) -> b d l", l=seqlen, ) - if self._debug_mode: - print("XZ: ", xz.shape) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat @@ -189,6 +185,4 @@ def forward(self, hidden_states, kwargs): delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..174e19588 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,9 +13,9 @@ TransformerKwargs, TransformerSubLayerName, ) -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,11 +50,13 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. """ + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -64,7 +66,7 @@ class Attention(torch.nn.Module): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -73,19 +75,9 @@ class Attention(torch.nn.Module): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -108,7 +100,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
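
The `init_dtprojbias` change in the mamba_layer.py hunk above stores the inverse softplus of a log-uniformly sampled step size, so that `softplus(dt_proj_bias)` lands back in `[dt_min, dt_max]`. A standalone sketch of that identity (the numeric values are illustrative, not necessarily the config defaults):

    import math

    import torch

    dt_min, dt_max, dt_init_floor, d_inner = 0.001, 0.1, 1e-4, 512

    # Sample dt log-uniformly in [dt_min, dt_max], as in init_dtprojbias.
    dt = torch.exp(
        torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
    ).clamp(min=dt_init_floor)

    # Inverse of softplus (https://github.com/pytorch/pytorch/issues/72759):
    # storing inv_dt as the bias means softplus recovers the sampled dt.
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    assert torch.allclose(torch.nn.functional.softplus(inv_dt), dt, atol=1e-5)
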
@@ -178,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -200,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -300,7 +258,7 @@ def _decide_window_size(self) -> int | None: # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -341,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -395,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", 
block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..efe0c5cc5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..d08db9a94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,25 +8,85 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import 
TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. """ - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -35,18 +95,19 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
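
The docstring above refers to `_bias_dropout_add`: mixers (and the MLP) may return their bias separately so the block can fold it into a single fused dropout-plus-residual step. A rough functional equivalent, assuming the standard formulation rather than the exact implementation:

    import torch

    def bias_dropout_add(
        x: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor, p: float, training: bool
    ) -> torch.Tensor:
        # Fold the optional bias into the same op as the dropout and the residual addition.
        if bias is not None:
            x = x + bias
        return residual + torch.nn.functional.dropout(x, p=p, training=training)
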
@@ -54,7 +115,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -67,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -137,14 +198,17 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention - def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..4c1eab46f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,11 +68,11 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -91,10 +91,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..9ca0123b2 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,9 +9,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,14 +23,14 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -41,9 +40,8 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. 
+ ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ @@ -83,6 +81,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) def _validate(self): with self._set_implicit_default(None): @@ -96,30 +95,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. 
+ Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..89f0cd4aa 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,11 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -39,52 +35,31 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -94,63 +69,35 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. """ - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/setup.cfg b/setup.cfg index 843aa15ca..c086af7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,14 +48,9 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 - cartesia_pytorch>=0.0.2 - -GENERATION = - lm_eval>=0.4.9 - DEV = # Pre-commit git hook diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, 
model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1eee3675d..42252c620 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,6 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 8c03b54a05ba05324b16d638d10db91db8da0a06 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 24 Jul 2025 20:13:11 +0000 Subject: [PATCH 119/161] use hybrid cache, update test --- .../modeling_ssm_hybrid_apriel15b.py | 33 ++++++++++++++++++- .../llava_hybrid/modeling_llava_hybrid.py | 6 ++-- tests/models/test_checkpoint.py | 11 ++++--- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..bd12243eb 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import LossKwargs, can_return_tuple, logging from transformers.utils.generic import ModelOutput from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig @@ -1209,6 +1209,37 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): # Initialize weights and apply final processing self.post_init() + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and past_key_values is None: + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + cache_position=cache_position, + **flash_attn_kwargs, + ) + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 6917fea93..9896d91d1 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -3,7 +3,6 @@ from transformers.activations import ACT2FN from .configuration_llava_hybrid import LlavaHybridConfig -from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache class LlavaMultiModalProjector(nn.Module): @@ -43,7 +42,8 @@ def __init__(self, config: LlavaHybridConfig): assert ( config.text_config.model_type == "apriel_ssm_thinker_hybrid" ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" - # from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel + from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel + self.language_model = AprielThinkerSSMHybridModel(config.text_config) self.post_init() @@ -69,6 +69,8 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): + from .modeling_ssm_hybrid_apriel15b import HybridMambaAttentionDynamicCache + # Copy of the method from `AprielThinkerSSMHybridForCausalLM` # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index b1e9e74f0..73bb24c82 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -307,11 +307,12 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) ) errors = [] - auto_model = ( - transformers.AutoModel - if model_testing_config.name in ("diffusion_llama", "dream", "vision_hybrid_mamba2") - else transformers.AutoModelForCausalLM - ) + if model_testing_config.name in ("diffusion_llama", "dream"): + auto_model = transformers.AutoModel + elif model_testing_config.name == "vision_hybrid_mamba2": + auto_model = transformers.AutoModelForVision2Seq + else: + auto_model = transformers.AutoModelForCausalLM model_as_hf = auto_model.from_pretrained( hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code ).cuda() From 7b32699be7c1a1fb29cc7386eb33280b0bc19a5c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:28:56 -0400 Subject: [PATCH 120/161] stuff --- fast_llm/layers/ssm/mamba2.py | 57 ++++++++++++++--------------------- fast_llm/models/ssm/config.py | 2 +- tests/utils/model_configs.py | 2 +- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index ead32fa2a..b936ccf14 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ @@ -97,9 +98,9 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - 
(td_inner, TensorDim("1", 1), td_conv_kernel), + (td_inner, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), + -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 lr_scale=mamba_layer_lr_scale, @@ -110,9 +111,9 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), + (td_xb, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), + -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), ), ) @@ -133,7 +134,13 @@ def __init__( weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, ) - + self.dt_in_proj = Linear( + td_model, + tdt_rank, + bias=config.add_bias_linear, + weight_init_method=kaiming_init_(transformer_config.hidden_size), + lr_scale=mamba_layer_lr_scale, + ) # Initialize special dt projection to preserve variance at initialization dt_scale = config.dt_scale # 1.0 dt_init_std = self.dt_rank**-0.5 * dt_scale @@ -144,24 +151,6 @@ def __init__( else: raise NotImplementedError - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor - ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( tdt_rank, td_inner, @@ -171,18 +160,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (td_inner,), + init_method=init_dtprojbias( + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor + ), + lr_scale=mamba_layer_lr_scale, ) - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( (td_inner, td_state), - init_method=init_from_tensor_(A_log), + init_method=init_A(self.config.state_size, self.config.d_inner), lr_scale=mamba_layer_lr_scale, weight_decay=False, ) @@ -214,8 +201,8 @@ def forward(self, hidden_states, kwargs): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) x = einops.rearrange(x, "b l d -> b d l") z = einops.rearrange(z, "b l d -> b d l") @@ -225,7 +212,7 @@ def forward(self, hidden_states, kwargs): B = einops.rearrange(B, "b n_group l dstate -> b 
n_group dstate l").contiguous() C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # B, L, d_inner dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: @@ -238,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + weight=self.conv1d_weight, bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9ca0123b2..b04b1f210 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -78,7 +78,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank + inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner # + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 42252c620..4976ad2b1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 1feccc866c1dea2da66567476fc911a37a855038 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:48:23 -0400 Subject: [PATCH 121/161] stuff --- fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/ssm/mamba_layer.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 88fe4abc0..fdba10beb 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -111,7 +111,7 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - # define bias outside the linear layer since its also used in the selective_scan_fn + # define bias outside the linear layer since it's also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( (inner_dim,), init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 49afa910e..11db37910 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -48,9 +48,7 @@ def init_dtprojbias( dt_max: float, dt_min: float, dt_init_floor: float ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor = ( - tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min(dt_init_floor) - ) + 
tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 return tensor.add_(torch.log(-torch.expm1(-tensor))) From d51f817c34097792d50ff7a6696092fa620fad9c Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 24 Jul 2025 21:58:25 +0000 Subject: [PATCH 122/161] finish ssm-hybrid conversion --- fast_llm/models/ssm/config.py | 1 + fast_llm/models/ssm/conversion.py | 40 +++++++++++++------ .../modeling_ssm_hybrid_apriel15b.py | 8 +++- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 41f4eadbe..70362a40e 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -158,6 +158,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): support_optimizer: typing.ClassVar[bool] = False name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + trust_remote_code: typing.ClassVar[bool] = True @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index c765c956b..640615e0e 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -727,6 +727,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -741,6 +742,11 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _default_block_type: str = SSMBlockType.mamba2_discrete.value _hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) def _create_weight_converters( self, @@ -767,6 +773,14 @@ def _create_weight_converters( @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -791,19 +805,19 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ), ] - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) + # @classmethod + 
# def _load_config(cls, directory: pathlib.Path | str) -> dict: + # if not os.path.exists(directory / "config.json"): + # raise FileNotFoundError(f"config.json not found in {directory}") + # with open(directory / "config.json") as f: + # config = json.load(f) + # Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + # return config + + # @classmethod + # def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + # with open(directory / "config.json", "w") as f: + # json.dump(config, f) class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index bd12243eb..da7984c70 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -357,7 +357,13 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx if len(self.key_cache) <= layer_idx: return 0 - return self.key_cache[layer_idx].shape[-2] + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + # return self.key_cache[layer_idx].shape[-2] def reset(self): self.conv_states.zero_() From e528b50ba5c5e2ea726876779db010f83fccd8ef Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:00:20 -0400 Subject: [PATCH 123/161] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 12 ++++++++---- fast_llm/layers/ssm/mamba_layer.py | 10 +++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index b95ff76da..fdce9bf63 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs @@ -97,7 +97,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index fdba10beb..8be9dcb9b 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,7 +3,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, 
OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -75,7 +75,11 @@ def __init__( conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( - (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) @@ -168,9 +172,9 @@ def forward(self, hidden_states, kwargs): .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 11db37910..07eec38e6 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,7 +4,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -87,7 +87,11 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.conv_kernel)), + ( + inner_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(SSMDimNames.conv_kernel), + ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, ) @@ -146,7 +150,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( in_proj, - self.conv1d_weight.unsqueeze(1), + self.conv1d_weight, None, self.x_proj.weight, self.dt_proj_weight, From f898ff25648a66f8cc8bf5f44c1d43c3ffe3fa34 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 24 Jul 2025 22:00:28 +0000 Subject: [PATCH 124/161] fix architecture classvar --- fast_llm/models/gpt/conversion.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index f376db3ff..fb1801067 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -368,10 +368,10 @@ def _get_weight_and_bias_converters( class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Starcoder2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Starcoder2ForCausalLM" return super()._create_config_converters() + 
[ ConstantImportParamConverter( fast_llm_names=(("transformer", "rotary", "type"),), @@ -495,10 +495,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlamaForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "LlamaForCausalLM" return super()._create_config_converters() + [ # TODO: Llama supports biases ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), @@ -547,10 +547,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Qwen2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), @@ -593,10 +593,10 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MistralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MistralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MistralForCausalLM" return super()._create_config_converters() + [ IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] @@ -1014,10 +1014,10 @@ def _create_weight_converters(self): class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MixtralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MixtralForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk @@ -1055,13 +1055,13 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" modeling_file = modeling_mtp_llama.__file__ configuration_file = configuration_mtp_llama.__file__ configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MTPLlamaForCausalLM" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -1143,6 +1143,7 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DreamModel" modeling_file = 
modeling_dream.__file__ configuration_file = configuration_dream.__file__ generation_utils_file = generation_utils.__file__ @@ -1150,7 +1151,6 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DreamModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -1171,6 +1171,7 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam ) format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DiffusionLlamaModel" modeling_file = modeling_diffusion_llama.__file__ configuration_file = configuration_diffusion_llama.__file__ generation_utils_file = generation_utils.__file__ @@ -1178,7 +1179,6 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DiffusionLlamaModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), From c447fc3857800484929fc6cfb80ce879f201210c Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 24 Jul 2025 22:05:33 +0000 Subject: [PATCH 125/161] add llava test and m2 conversion test --- tests/models/test_checkpoint.py | 2 +- tests/utils/model_configs.py | 46 ++++++++++++++++++++++++++++++--- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 73bb24c82..da719a42d 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -309,7 +309,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): errors = [] if model_testing_config.name in ("diffusion_llama", "dream"): auto_model = transformers.AutoModel - elif model_testing_config.name == "vision_hybrid_mamba2": + elif model_testing_config.name in ("llava", "vision_hybrid_mamba2"): auto_model = transformers.AutoModelForVision2Seq else: auto_model = transformers.AutoModelForCausalLM diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c2bb68003..643ca6c27 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -13,13 +13,18 @@ DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, LlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat, LlavaHybridHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import ( + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LLambaHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, +) from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig @@ -450,6 +455,41 @@ def _update_and_add_testing_config( compare_factor=2.0, ) +_update_and_add_testing_config( + # Tests hybrid ssm, llamba converter. 
+ "llama", + "llava", + extra_args=[ + "batch.max_image_size=128", + "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.patch_norm.type=rms_norm", + "model.base_model.vision_encoder.transformer.add_linear_biases=False", + "model.base_model.vision_encoder.transformer.causal=False", + "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", + "model.base_model.vision_encoder.transformer.type=image_encoder", + "model.base_model.vision_encoder.transformer.gated=True", + "model.base_model.vision_encoder.transformer.num_layers=2", + "model.base_model.vision_encoder.transformer.hidden_size=256", + "model.base_model.vision_encoder.transformer.num_attention_heads=8", + "model.base_model.vision_encoder.transformer.head_groups=8", + "model.base_model.vision_encoder.transformer.init_method_std=0.022", + "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", + "model.base_model.vision_encoder.adapter_size=256", + "model.distributed.training_dtype=torch.bfloat16", + ], + megatron_args=None, + checkpoint_format=LlavaGPTHuggingfaceCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=8.0, +) + _update_and_add_testing_config( # Tests hybrid ssm, llamba converter. "llama", @@ -510,11 +550,11 @@ def _update_and_add_testing_config( "model.base_model.hybrid_block_layout=['t','m2']", ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, From b49c42febac4f32dc1be83655b242d6199a385bc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:16:42 -0400 Subject: [PATCH 126/161] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 8 ++++---- fast_llm/layers/ssm/mamba_layer.py | 4 ++-- .../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ tests/utils/model_configs.py | 1 - 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 734e35b21..c0ae7e781 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, 
tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b936ccf14..74c212add 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, td_conv_kernel), + (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -225,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=self.conv1d_weight, + weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index a95e94c03..4493332ce 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig @@ -98,7 +98,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=kaiming_init_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..4fde72458 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, 
self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4976ad2b1..1eee3675d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,6 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From c1b7f44a10ff379a067b10b76df296f3bee4cac1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:19:08 -0400 Subject: [PATCH 127/161] misc --- .../models/ssm/external/llamba/modeling_mtp_llamba.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py index 8f49ded40..6d9746db1 100644 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -322,21 +322,19 @@ def __init__(self, config, factory_kwargs, layer_idx, **kwargs): # Mixer self.mixer = DiscreteMamba2( - d_model=self.config._hidden_size, + d_model=self.config.d_model, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) # Other components - self.input_layernorm = LlamaRMSNorm( - hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs - ) + self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs + hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs ) self.mlp = LlamaMLP( - hidden_size=self.config._hidden_size, + hidden_size=self.config.d_model, **config.mlp_cfg, factory_kwargs=factory_kwargs, ) From 31f5d415ef0c7eeca54a26d415076cbf3ba33cfd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:20:26 -0400 Subject: [PATCH 128/161] misc --- fast_llm/models/ssm/conversion.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import 
CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False ) From 0a9ff25f6e0a699caef881dfcaeef0b19f825764 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:22:24 -0400 Subject: [PATCH 129/161] misc --- fast_llm/models/ssm/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 6b9e28584..d2a69303c 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -40,9 +40,6 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? # TODO: Support combination of different SSM block types. 
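Both MappedConfigParamConverter uses above follow the same two-callable pattern: one callable maps the exported value into the Fast-LLM config, supplying a default when the field is MISSING, and the other maps the config value back to its export form. A stand-alone sketch of that pattern in plain Python (the helper names here are illustrative, not the Fast-LLM API):

import enum

MISSING = object()  # stand-in for fast_llm.config.MISSING


class DTInitType(enum.StrEnum):
    constant = "constant"
    random = "random"


def dt_init_to_fast_llm(value):
    # Import direction: fall back to the default when the exported config omits the field.
    return DTInitType.random if value is MISSING else DTInitType(value)


def dt_init_to_export(value):
    # Export direction: serialize the enum back to its plain string value.
    return value.value


assert dt_init_to_fast_llm(MISSING) is DTInitType.random
assert dt_init_to_export(DTInitType("constant")) == "constant"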
ssm_block_type: SSMBlockType | None = Field(init=False) From e7d9636819ab83df7204cc2b021fd4565188e946 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 19:55:53 -0400 Subject: [PATCH 130/161] Parallel discrete mamba 2 --- fast_llm/layers/ssm/config.py | 12 +- fast_llm/layers/ssm/discrete_mamba2.py | 212 ++++++++++--------------- fast_llm/layers/ssm/mamba2.py | 6 +- 3 files changed, 95 insertions(+), 135 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 15a6a8210..7f0b3cf61 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -211,23 +211,25 @@ def _validate(self) -> None: def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - num_heads = div(self.d_inner, self.state_size) # Head groups are configured differently depending on the block type. if block_type == SSMBlockType.mamba: + num_heads = div(self.d_inner, self.state_size) num_head_groups = num_heads elif block_type == SSMBlockType.mamba2: + num_heads = div(self.d_inner, self.state_size) num_head_groups = div(self.d_xb, self.state_size) elif block_type == SSMBlockType.mamba2_discrete: - Assert.eq(num_heads, self.n_v_heads) + # TODO: Use different variables? + num_heads = self.n_v_heads num_head_groups = self.n_qk_heads + # v_heads have size `headdim` that may be different from `state_size`. + Assert.multiple(self.d_inner, num_heads) else: raise NotImplementedError(block_type) tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim( - group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) - ) + tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index fdce9bf63..ac4fb87cc 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,12 +1,12 @@ import logging -import math import typing import einops import torch from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -32,12 +32,6 @@ _causal_conv1d_available = False -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_centered_(bound) - - class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" @@ -51,198 +45,162 @@ def __init__( transformer_config: TransformerConfig, ): super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) 
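The head-group bookkeeping in setup_tensor_space above mirrors grouped-query attention: the inner dimension splits into heads, and the B/C projections are shared within head groups (for the discrete variant the head size is d_inner / n_v_heads rather than the state size). A worked example with assumed sizes:

# Assumed sizes, for illustration only (Mamba-2 case, head size == state size).
d_inner, d_xb, state_size = 4096, 1024, 64

num_heads = d_inner // state_size            # 64 value heads
num_head_groups = d_xb // state_size         # 16 groups, one shared B/C projection each
group_heads = num_heads // num_head_groups   # 4 heads per group

assert num_head_groups * group_heads * state_size == d_inner
assert num_head_groups * state_size == d_xb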
- self.config: SSMConfig = config + self._config: SSMConfig = config layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state) - td_model = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection) - - self.d_model = td_model.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.chunk_size = config.chunk_size - self.n_qk_heads = td_n_qk_heads.size - self.n_v_heads = td_n_v_heads.size - self.conv_kernel_size = td_conv_kernel.size - - self.act = config.activation_type.activation_fn - self.activation_name = config.activation_type.name + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + + self._local_heads = heads_dim.size + self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + self._local_inner_size = inner_dim.size + self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations # Projections - self.in_proj = Linear( - td_model, - td_inner_proj, + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(td_model.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - self.z_bias = ( - ParameterMeta.from_dims( - (td_inner,), + if not config.add_bias_linear: + self.z_bias = ParameterMeta.from_dims( + (inner_dim,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - if not config.add_bias_linear - else 0.0 - ) - self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), - init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), - lr_scale=mamba_layer_lr_scale, + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_v_heads,), 
+ (heads_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - # out_proj - self.out_proj = Linear( - td_inner, - td_model, + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - ON variable names and pep8: keeping some variable names as in the original code for clarity. - - Args: - u: (B, L, D), - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. - """ if kwargs[TransformerKwargs.sequence_first]: raise NotImplementedError(f"Sequence-first not supported for SSMs.") assert _mamba_available - outputs = {} - # assert state is None - batch, seqlen, dim = input_.shape - - state = None - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + sequence_length = input_.size(0 if kwargs[TransformerKwargs.sequence_first] else 1) # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen)) + padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size + if padded_length != sequence_length: + assert not kwargs[TransformerKwargs.sequence_first] and not self._sequence_parallel + input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) - # Project input - xBCzA_log = self.in_proj(u) + inner_projection = self.in_proj(input_) + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) - ( - xBC, - z, - A_log, - ) = torch.split( - xBCzA_log, + xBC, z, A_log = torch.split( + inner_projection, [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, + self._local_inner_size + 2 * self._local_bc_size, + self._local_inner_size, + self._local_heads, ], dim=-1, ) - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
- xBC_t = einops.rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_( - torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0)) - ) # Update state (B D W) - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) + xBC = self.convolutional_forward(xBC, sequence_length) x, B, C = torch.split( xBC, [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, + self._local_inner_size, + self._local_bc_size, + self._local_bc_size, ], dim=-1, ) - x = einops.rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = einops.rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) + C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward - result = _mamba_chunk_scan_combined( + y = _mamba_chunk_scan_combined( x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), dt=A_log, dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), + A=-torch.ones(self._local_heads, device=A_log.device), B=B, C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + chunk_size=self._config.chunk_size, + return_final_states=False, ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = einops.rearrange(y + Du, "b l h p -> b l (h p)") # Norm and gate - out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() + if not self._config.add_bias_linear: + z = z + self.z_bias - # TODO: since we do not support inference for now, we only return the hidden states for now. - return outputs["hidden_states"], None + # y: (batch, sequence, heads, state) -> (batch, sequence, heads * state) + y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? 
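The chunked scan used above only accepts sequence lengths that are multiples of chunk_size, hence the right-padding of the input and the trim back to the original length at the end. A minimal PyTorch sketch of that pattern (shapes assumed):

import torch


def pad_to_chunk(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # x: (batch, sequence, channels); right-pad the sequence dimension with zeros.
    sequence_length = x.size(1)
    padded_length = (1 + (sequence_length - 1) // chunk_size) * chunk_size
    return torch.nn.functional.pad(x, (0, 0, 0, padded_length - sequence_length))


x = torch.randn(2, 100, 32)
padded = pad_to_chunk(x, chunk_size=64)
assert padded.shape == (2, 128, 32)
assert torch.equal(padded[:, :100], x)
assert torch.all(padded[:, 100:] == 0)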
+ y = y.transpose(0, 1).contiguous() + return self.out_proj(y) def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self.activation_name in ( - "silu", + if _causal_conv1d_available and self._config.activation_type in ( + ActivationType.silu, "swish", - "identity", + ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, + activation=( + None + if self._config.activation_type == ActivationType.identity + else self._config.activation_type.value + ), ).transpose(1, 2) else: - xBC = self.act( + xBC = self._config.activation_type.activation_fn( torch.nn.functional.conv1d( xBC.transpose(1, 2), self.conv1d_weight, bias=self.conv1d_bias, groups=self.conv1d_weight.shape[0], - padding=self.conv_kernel_size - 1, + padding=self._config.conv_kernel_dimension - 1, )[..., :padded_len].transpose(1, 2) ) return xBC diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 8be9dcb9b..cba28f8b8 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -142,12 +142,12 @@ def __init__( # TODO: lr_scale? ) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available assert _causal_conv1d_available - inner_projection = self.in_proj(hidden_states) - dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias + inner_projection = self.in_proj(input_) + dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) From f88fb2f1f44aca0852ac293fc9b8950941cc7dd2 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Jul 2025 19:10:28 +0000 Subject: [PATCH 131/161] rename vit layer to block --- fast_llm/layers/transformer/transformer.py | 2 +- fast_llm/models/gpt/model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 19443a777..d2f3bfba8 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -217,5 +217,5 @@ def _create_mixer(self) -> Mixer: return Attention(self._config, self._tensor_space, self._block_index) -class VisionTransformerLayer(TransformerLayer): +class VisionTransformerBlock(TransformerBlock): _name: str = "Vision transformer layer" diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 62a58546a..47100d673 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -24,7 +24,7 @@ VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerBlock, VisionTransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock, VisionTransformerBlock from fast_llm.layers.vision_encoder.adapter import VisionAdapter from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.layers.vision_encoder.patch_conv import PatchConv @@ -100,7 +100,7 @@ def get_output_layers(self) -> list[Layer]: def 
get_vision_layers(self) -> list[Layer]: vit_layers = [ - VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ From 22296b32ce1b296c505e42dab7ca8893d0bf75a4 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Jul 2025 19:24:58 +0000 Subject: [PATCH 132/161] block_index --- fast_llm/models/gpt/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 47100d673..c76c2191d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -100,7 +100,7 @@ def get_output_layers(self) -> list[Layer]: def get_vision_layers(self) -> list[Layer]: vit_layers = [ - VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ From c14b7643ae3f840f8da23404922f9482ff507284 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 25 Jul 2025 17:14:17 -0400 Subject: [PATCH 133/161] Mamba 2, misc --- fast_llm/engine/multi_stage/stage_base.py | 5 +- fast_llm/layers/ssm/config.py | 62 ++++++++++--------- fast_llm/layers/ssm/discrete_mamba2.py | 50 ++++++++++----- fast_llm/layers/ssm/mamba2.py | 22 ++++--- fast_llm/layers/ssm/mamba_layer.py | 27 ++++----- fast_llm/tensor.py | 74 +++++++++++++++-------- tests/models/test_checkpoint.py | 11 +++- tests/utils/model_configs.py | 9 +-- 8 files changed, 160 insertions(+), 100 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9a8ce2092..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,8 +185,9 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 7f0b3cf61..c06d85148 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -5,31 +5,31 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig +from fast_llm.tensor import Initializer from fast_llm.utils import Assert, div class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
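With the stage_base change above, parameters whose initializer sets requires_global_initialization always take the global-then-slice path: every rank materializes and initializes the full tensor with the same generator, then keeps only its shard, so the values are independent of the parallel layout. A rough single-process illustration in plain PyTorch (shapes and seed assumed):

import torch


def init_local_shard(global_shape, local_slice, seed: int = 1234) -> torch.Tensor:
    # Initialize the *global* parameter with a fixed seed, then keep only the local slice.
    generator = torch.Generator().manual_seed(seed)
    global_param = torch.empty(global_shape)
    global_param.normal_(generator=generator)
    return global_param[local_slice].clone()


full = init_local_shard((8, 4), slice(None))
shard = init_local_shard((8, 4), slice(2, 4))  # what a rank owning rows 2..3 would keep
assert torch.equal(shard, full[2:4])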
state = "ssm_state" # State dimension (N), aka head size / num channels - + head_dim = "ssm_head_dim" head_groups = "ssm_head_groups" group_heads = "ssm_group_heads" + convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers + + dt_rank = "ssm_dt_rank" + + # Composite dimensions composite_heads = "ssm_composite_heads" - composite_heads_and_state = "ssm_composite_heads_and_state" + composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - # Inner projection total dimension. + # Concatenated dimensions + concatenated_convolution = "ssm_concatenated_convolution" + concatenated_x_projection = "ssm_x_concatenated_x_projection" concatenated_inner_projection = "ssm_concatenated_inner_projection" - # Convolution shape in discrete mamba 2. TODO: Remove (dim too complex) - conv_dim = "ssm_conv_dim" - - dt_rank = "ssm_dt_rank" - - x_proj_dim = "x_proj_dim" # X projection dimension - conv_kernel = "conv_kernel" # Kernel size of the conv1d in mamba layers - class SSMBlockType(enum.StrEnum): """ @@ -62,7 +62,7 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float): + def get_init_method(self, scale: float) -> Initializer: from fast_llm.tensor import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) @@ -222,56 +222,64 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType # TODO: Use different variables? num_heads = self.n_v_heads num_head_groups = self.n_qk_heads - # v_heads have size `headdim` that may be different from `state_size`. - Assert.multiple(self.d_inner, num_heads) else: raise NotImplementedError(block_type) - tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) + tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) + if block_type == SSMBlockType.mamba2_discrete: + tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) + else: + head_dim = state + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - heads_and_state := CompositeTensorDim( - SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim) + heads_and_head_dim := CompositeTensorDim( + SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) ) ) tensor_space.add_tensor_dim( head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim) + SSMDimNames.composite_head_groups_and_state, (head_groups, state) ) ) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) # DT projection if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank 
+ self.state_size * 2)) + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) + ) # TODO: Use composition instead tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_inner_projection, (heads_and_state, heads_and_state)) + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) + ) ) elif block_type == SSMBlockType.mamba2: # TODO: Factor out state? tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state), + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), ) ) elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Factor as (head_groups, (group_heads + 2) * state_size + group_heads)? tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state, heads), + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), ) ) - # TODO: (head_groups, group_heads + 2, state_size) tensor_space.add_tensor_dim( - TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) + ConcatenatedTensorDim( + SSMDimNames.concatenated_convolution, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state), + ) ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index ac4fb87cc..64377b93c 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -49,14 +49,18 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) - self._local_heads = heads_dim.size + # local_head_groups = head_groups / TP self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + # local_heads = local_head_groups * group_heads + self._local_heads = heads_dim.size + # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size + # local_bc_size = local_head_groups * state self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations @@ -80,7 +84,7 @@ def __init__( ( conv1d_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -107,24 +111,25 @@ def __init__( ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - if kwargs[TransformerKwargs.sequence_first]: - raise NotImplementedError(f"Sequence-first not supported for SSMs.") - assert _mamba_available - sequence_length = 
input_.size(0 if kwargs[TransformerKwargs.sequence_first] else 1) + sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size if padded_length != sequence_length: - assert not kwargs[TransformerKwargs.sequence_first] and not self._sequence_parallel + assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) + # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) + # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) + # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) inner_projection = self.in_proj(input_) - # Standardize to (batch, sequence, inner_projection) + # Standardize to (batch, padded_sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) + print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) xBC, z, A_log = torch.split( inner_projection, [ @@ -134,9 +139,13 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) + print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) + print("QAIKOFNMJOWENM z", z.shape) + print("QAIKOFNMJOWENM A_log", A_log.shape) # Convolutional layer - xBC = self.convolutional_forward(xBC, sequence_length) + # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) + xBC = self.convolutional_forward(xBC, padded_length) x, B, C = torch.split( xBC, @@ -148,13 +157,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dim=-1, ) + # x: (batch, padded_sequence, local_heads * head_size) -> (batch, padded_sequence, local_heads, head_size) x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + + # b,c: (batch, padded_sequence, local_head_groups * state) -> (batch, padded_sequence, local_head_groups, state) B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward y = _mamba_chunk_scan_combined( - x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), + x=self._apply_a_log(x, A_log), dt=A_log, dt_softplus=True, A=-torch.ones(self._local_heads, device=A_log.device), @@ -169,23 +181,31 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if not self._config.add_bias_linear: z = z + self.z_bias - # y: (batch, sequence, heads, state) -> (batch, sequence, heads * state) + # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
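The tail of the forward above adds a per-head skip term to the scan output, flattens heads back into channels, and gates the result with silu(z) before the output projection. A shape-level sketch with plain tensors (sizes assumed):

import torch

batch, sequence, heads, head_size = 2, 8, 4, 16
y = torch.randn(batch, sequence, heads, head_size)   # chunked-scan output
x = torch.randn(batch, sequence, heads, head_size)   # pre-scan activations
z = torch.randn(batch, sequence, heads * head_size)  # gate branch
D = torch.ones(heads)                                 # per-head skip parameter

Du = torch.einsum("h,blhp->blhp", D, x)
gated = (y + Du).flatten(2, 3) * torch.nn.functional.silu(z)
assert gated.shape == (batch, sequence, heads * head_size)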
y = y.transpose(0, 1).contiguous() + # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) + # -> (batch/local_sequence, local_sequence/batch, hidden) + a, b = self.out_proj(y) + logger.info(f"EKFBN y {y.shape}") + logger.info(f"EKFBN a {a.shape}") return self.out_proj(y) + @torch.compile + def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: + return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) + def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" if _causal_conv1d_available and self._config.activation_type in ( ActivationType.silu, - "swish", ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_weight.squeeze(1), self.conv1d_bias, activation=( None diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index cba28f8b8..1ae25e44c 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -39,7 +39,7 @@ class Mamba2(Mixer): _XZ_DIMS = ( TransformerDimNames.batch, - SSMDimNames.composite_heads_and_state, + SSMDimNames.composite_heads_and_head_dim, TransformerDimNames.sequence_q, ) _BC_DIMS = ( @@ -62,7 +62,7 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) @@ -78,7 +78,7 @@ def __init__( ( conv1d_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -146,6 +146,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ assert _mamba_available assert _causal_conv1d_available + # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) + # -> (batch/sequence, sequence/batch, inner_projection) inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) @@ -161,10 +163,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dim=2, ) - # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + # z: (batch, sequence, local_heads * state) -> (batch, local_heads * state, sequence) z = z.transpose(1, 2) - # x: (batch, sequence, head_groups * state) -> (batch, heads * state, sequence) + # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: x = ( @@ -172,16 +174,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, 
bias=self.conv1d_bias.squeeze(1), activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) + # b: (batch, sequence, local_head_groups * state) -> (batch, local_heads, state, sequence) b = ( b.transpose(1, 2) .unflatten(1, (self._local_head_groups, self._config.state_size)) @@ -216,9 +218,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._debug_level: self._debug_log(y, "y", self._XZ_DIMS, kwargs) - # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) + # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? y = y.transpose(0, 1).contiguous() + # (batch/sequence, sequence/batch, local_heads * state) + # -> (batch/local_sequence, local_sequence/batch, hidden) return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 07eec38e6..64c8227fc 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale try: @@ -29,30 +29,27 @@ """ -def init_A(d_state, d_inner) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # TODO: adopt this initialization to work for tensor parallel setting! 
+def init_A(d_state, d_inner) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa if tensor.numel() != d_state * d_inner: - raise ValueError(f"_init_A requires not supported for tensor slices.") - return torch.log( + raise ValueError("_init_A requires not supported for tensor slices.") + torch.log( torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) .unsqueeze(0) .expand(d_inner, d_state), out=tensor, ) - return init_ + return LambdaInitializer(init_, requires_global_initialization=True) -def init_dtprojbias( - dt_max: float, dt_min: float, dt_init_floor: float -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_dtprojbias(dt_max: float, dt_min: float, dt_init_floor: float) -> LambdaInitializer: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - return tensor.add_(torch.log(-torch.expm1(-tensor))) + tensor.add_(torch.log(-torch.expm1(-tensor))) - return init_ + return LambdaInitializer(init_) class MambaLayer(Mixer): @@ -72,7 +69,7 @@ def __init__( Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -90,7 +87,7 @@ def __init__( ( inner_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -98,7 +95,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim), + tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 899e70005..b89ed4a04 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,3 +1,4 @@ +import abc import functools import math import typing @@ -241,7 +242,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -251,7 +252,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. 
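init_dtprojbias above stores the inverse softplus of the sampled step size, so that softplus(dt_proj_bias) recovers a dt drawn log-uniformly from [dt_min, dt_max] at initialization. A worked check with assumed bounds:

import math

import torch

dt_min, dt_max, dt_init_floor = 1e-3, 1e-1, 1e-4

dt = torch.empty(16).uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min_(dt_init_floor)
bias = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus

assert torch.allclose(torch.nn.functional.softplus(bias), dt, atol=1e-6)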
+ assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -276,7 +281,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] = None, allow_sequence_tensor_parallel: bool = True, @@ -303,6 +308,10 @@ def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -334,11 +343,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False - return init_ + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -346,38 +376,32 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def init_kaiming_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: 
torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def init_uniform_centered_( - high, max_val=None, mean=0.0 -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: return init_uniform_( mean - high, mean + high, diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 05acf23dc..4bda5512c 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -284,10 +284,15 @@ def test_load_pretrained( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) - model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + try: + hf_class = model_testing_config.huggingface_model_for_causal_lm_class + except NotImplementedError: + pytest.skip(f"Hugging Face wrapper not implemented for {model_testing_config.name}.") + model_ref = hf_class.from_pretrained( CheckpointLoadConfig( path=get_convert_path(), format=DistributedCheckpointFormat, @@ -298,8 +303,8 @@ def test_huggingface_model(model_testing_config, get_convert_path): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) - model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) + model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 038b53c26..722d8d63a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -20,6 +20,7 @@ Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat, ) @@ -540,19 +541,19 @@ def _update_and_add_testing_config( "model.base_model.ssm.chunk_size=32", ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, 
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement - ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=2.0, # Micro-sequence split and sequence-first not supported. - skip_tests=("sf", "stp", "sdp", "ms"), + skip_tests=("sdp", "ms"), ) From fa211747ea1ed81528e771e140b58ed7b579c3b7 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Jul 2025 21:29:07 +0000 Subject: [PATCH 134/161] flexible import --- .../external/llava_hybrid/modeling_llava_hybrid.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 9896d91d1..0c7fd9b9b 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -4,6 +4,16 @@ from .configuration_llava_hybrid import LlavaHybridConfig +try: + # In the fast-llm repo, import from the SSM modeling file + from ..apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielThinkerSSMHybridModel, + HybridMambaAttentionDynamicCache, + ) +except ImportError: + # In the exported checkpoint, import from local file + from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache + class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlavaHybridConfig): @@ -42,7 +52,6 @@ def __init__(self, config: LlavaHybridConfig): assert ( config.text_config.model_type == "apriel_ssm_thinker_hybrid" ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" - from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel self.language_model = AprielThinkerSSMHybridModel(config.text_config) self.post_init() @@ -69,8 +78,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - from .modeling_ssm_hybrid_apriel15b import HybridMambaAttentionDynamicCache - # Copy of the method from `AprielThinkerSSMHybridForCausalLM` # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` From d3cc1583f0d75b60f99c3ab96a82424371adb83f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 28 Jul 2025 14:10:39 +0000 Subject: [PATCH 135/161] update import --- .../models/ssm/external/llava_hybrid/modeling_llava_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 0c7fd9b9b..b056d3a00 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -6,7 +6,7 @@ try: # In the fast-llm repo, import from the SSM modeling file - from ..apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache, ) From 6d245c0578cf5d91ef95a4ce528fffe3a5d69f3f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 28 Jul 2025 18:21:44 +0000 Subject: [PATCH 136/161] fix automodel export --- fast_llm/models/ssm/conversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 
059bff436..64afbea06 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -851,6 +851,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", }, ), ] From 61ecb5d8206cd3335747a860cd87114919d80666 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 28 Jul 2025 18:59:06 +0000 Subject: [PATCH 137/161] try: remove assert for TP and distillation --- fast_llm/engine/training/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 4b8d805b8..3dbec5348 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -388,7 +388,7 @@ def _validate(self) -> None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. - Assert.eq(self.model.distributed.tensor_parallel, 1) + # Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() From 2565dac89eac0c746d9e3868ac46bbc62a8de695 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 28 Jul 2025 22:46:41 +0000 Subject: [PATCH 138/161] more verbose config --- fast_llm/engine/config_utils/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 7ab5b8e41..b23037e84 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -130,7 +130,7 @@ def __init__( self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0 ) config_dict = config.to_dict() - config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() From 7a7f12c54660a9091bcef3409867d426728dfc23 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 29 Jul 2025 00:19:41 +0000 Subject: [PATCH 139/161] use local token_ids instead of modifying batch --- fast_llm/models/gpt/model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c76c2191d..3d393fd40 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -342,19 +342,20 @@ def preprocess( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + token_ids = batch.token_ids if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. - batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() + token_ids = token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: - tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] + tokens = token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? 
- tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: @@ -374,10 +375,10 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + prediction_heads] else: # TODO: Avoid multiple contiguous calls? - labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config labels_cloned = False From 743b42c393b162d6c5881619a9a0d5bd2a21aec6 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 29 Jul 2025 01:08:56 +0000 Subject: [PATCH 140/161] fix allreduce --- fast_llm/functional/cross_entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 7a289b579..df93bca28 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -277,7 +277,8 @@ def _torch_reverse_kl_forward_backward( loss = (loss_per_sample * loss_mask).mean() if group is not None and target_format != TargetFormat.labels: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= group.size() if grad_output is not None: loss.backward(torch.full_like(loss, grad_output)) From c7247dc4c70752a163c048d2fd720659e7e55200 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 29 Jul 2025 02:09:34 +0000 Subject: [PATCH 141/161] fix --- fast_llm/functional/cross_entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index df93bca28..95f141d96 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -151,7 +151,8 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= group.size() return loss, grad From 3074ec90d3d2e3164a77aeb57d040c7aa326c85f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 29 Jul 2025 20:06:19 +0000 Subject: [PATCH 142/161] revert images_sizes conversion to np array --- fast_llm/data/dataset/gpt/memmap.py | 1 - fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/data/preparator/gpt_memmap/prepare.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 2a1986b63..493361f32 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -177,7 +177,6 @@ def _init( assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens - self._image_sizes = np.array(self._image_sizes, dtype=np.int32) def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, 
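The two cross_entropy fixes above replace ReduceOp.MEAN with a sum followed by a division by the group size, which is the portable way to average a loss across ranks. A hedged sketch using plain torch.distributed (it assumes a process group is already initialized and does not use the Fast-LLM wrappers):

import torch
import torch.distributed as dist


def average_over_group(loss: torch.Tensor, group=None) -> torch.Tensor:
    # Sum across ranks, then divide by the group size to obtain the mean.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=group)
        loss /= dist.get_world_size(group)
    return loss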
self._num_pixels) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 42062a58c..29a784b77 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -143,7 +143,7 @@ def _sample(self) -> None: # Get the document sizes, the main information needed for sampling. document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - if image_sizes.any(): + if image_sizes: image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index fce0f022c..d6d473838 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -458,7 +458,7 @@ def _split_and_blend_dataset_configs( text_sizes, image_sizes = dataset.get_document_sizes() tokens_cumsum = text_sizes.cumsum() Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) - if image_sizes.any(): + if image_sizes: num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) # We use the patch sizes only for the purposes of even splitting and blending weights. # We can always use a different patch size for training without any significant impact From c4cdd86403942141e289221502392fb675398cca Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 30 Jul 2025 16:13:29 +0000 Subject: [PATCH 143/161] debug logs --- fast_llm/engine/schedule/runner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559d..338c7a5df 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -193,6 +193,8 @@ def run_step( for step in schedule: self._train_step(context, step) + logger.info("End of the schedule steps") + # Make sure we used all the data. This also ensures the generator terminates and prevents a memory leak. try: next(context.data_iterator) @@ -202,6 +204,7 @@ def run_step( raise AssertionError("Data iterator did not terminate") assert context.done, context + logger.info("End data-iterator") if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"End of the schedule steps", str)) @@ -240,7 +243,9 @@ def run_step( # TODO: Option to update with reduce (needs per-layer grad_norm and update_successful) # TODO: Avoid blocking synchronizations: async transfer, turn noop_flag into a real noop flag # (uncomment line in apex). 
+ logger.info("Updating weights") update_successful = self._optimizer.step(metrics) + logger.info("Weights updated") if self._multi_stage.config.multi_stage.debug_tensor_parallel and self._distributed.tensor_group is not None: for stage in self._stages_on_device: @@ -275,6 +280,7 @@ def run_step( return self._reduce_losses(context), update_successful, metrics def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: + logger.info("Reducing losses") reduced_losses = {} num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs for name, losses in context.losses.items(): @@ -290,6 +296,7 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: else: reduced_loss = 0.0 reduced_losses[name] = reduced_loss + logger.info(f"Reduced losses: {reduced_losses}") return { name: reduced_loss.item() if isinstance(reduced_loss, torch.Tensor) else reduced_loss for name, reduced_loss in reduced_losses.items() From 24d7a05df5bf730629fee0f0ad9cbdae0da0bf22 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 5 Aug 2025 15:23:26 +0000 Subject: [PATCH 144/161] rm debug logs --- fast_llm/engine/schedule/runner.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 338c7a5df..8eca4559d 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -193,8 +193,6 @@ def run_step( for step in schedule: self._train_step(context, step) - logger.info("End of the schedule steps") - # Make sure we used all the data. This also ensures the generator terminates and prevents a memory leak. try: next(context.data_iterator) @@ -204,7 +202,6 @@ def run_step( raise AssertionError("Data iterator did not terminate") assert context.done, context - logger.info("End data-iterator") if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"End of the schedule steps", str)) @@ -243,9 +240,7 @@ def run_step( # TODO: Option to update with reduce (needs per-layer grad_norm and update_successful) # TODO: Avoid blocking synchronizations: async transfer, turn noop_flag into a real noop flag # (uncomment line in apex). 
- logger.info("Updating weights") update_successful = self._optimizer.step(metrics) - logger.info("Weights updated") if self._multi_stage.config.multi_stage.debug_tensor_parallel and self._distributed.tensor_group is not None: for stage in self._stages_on_device: @@ -280,7 +275,6 @@ def run_step( return self._reduce_losses(context), update_successful, metrics def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: - logger.info("Reducing losses") reduced_losses = {} num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs for name, losses in context.losses.items(): @@ -296,7 +290,6 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: else: reduced_loss = 0.0 reduced_losses[name] = reduced_loss - logger.info(f"Reduced losses: {reduced_losses}") return { name: reduced_loss.item() if isinstance(reduced_loss, torch.Tensor) else reduced_loss for name, reduced_loss in reduced_losses.items() From 37ddef4d2d366b97a4b5fe9112df68c9d2bb5b5a Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 5 Aug 2025 20:17:39 +0000 Subject: [PATCH 145/161] changes for stp reverse-kl --- fast_llm/functional/cross_entropy.py | 28 +++++++++++++++++--------- fast_llm/layers/language_model/head.py | 11 +++++++++- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 95f141d96..afd7c2ef2 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -226,7 +226,8 @@ def _torch_reverse_kl_forward_backward( ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. - Much simpler and more reliable than custom implementation! + This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. + In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. """ Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) @@ -244,7 +245,6 @@ def _torch_reverse_kl_forward_backward( scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) # Clamp to prevent extreme values before log_softmax - scaled_target = torch.clamp(scaled_target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) @@ -254,14 +254,9 @@ def _torch_reverse_kl_forward_backward( with torch.enable_grad(): logits_ = logits.detach().requires_grad_(grad_output is not None) - # Use log_softmax for consistency instead of _fused_softmax scaled_logits = logits_ * logits_scale_factor - scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - # Convert to probabilities for kl_div - # student_probs_ = torch.exp(student_log_probs) - # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: loss = torch.nn.functional.kl_div( @@ -299,6 +294,7 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, + vocab_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). 
@@ -342,6 +338,18 @@ def reverse_kl_forward_backward( if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) # TODO: implement fused? - return _torch_reverse_kl_forward_backward( - logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature - ) + if vocab_parallel: + Assert.eq(teacher_softmax_temperature, 1) + Assert.eq(logits_scale_factor, 1) + raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.") + else: + return _torch_reverse_kl_forward_backward( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + group, + teacher_softmax_temperature, + ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21bf3bbd0..791b1f09d 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -239,7 +239,15 @@ def _get_targets( lm_target = None targets = (dpo_target, lm_target, distillation_target, loss_mask) - if self._sequence_parallel_logits: + # If we do distillation, no need to split it here as it has already been split in the embedding layer! + # if we do CPT/language modeling, we need to split the targets here! + if ( + self._config.distillation_model is not None + and self._sequence_parallel_logits + and not self._parallel_embeddings + and not self._sequence_parallel + ) or (self._config.distillation_model is None and self._sequence_parallel_logits): + # We dont split targets if they already have been split in the embedding layer! targets = [ None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) for target in targets @@ -412,6 +420,7 @@ def _logits_cross_entropy_forward_backward( target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), + vocab_parallel=logits.shape[-1] != self._config.vocab_size, ) elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( From a0d7a09fdce57e70fbf1524574ca8393adc0453b Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 6 Aug 2025 03:49:09 +0000 Subject: [PATCH 146/161] reverse kl: add clamping --- fast_llm/functional/cross_entropy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index afd7c2ef2..eaeaa0d18 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -245,6 +245,7 @@ def _torch_reverse_kl_forward_backward( scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) # Clamp to prevent extreme values before log_softmax + scaled_target = torch.clamp(scaled_target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) @@ -255,6 +256,7 @@ def _torch_reverse_kl_forward_backward( logits_ = logits.detach().requires_grad_(grad_output is not None) scaled_logits = logits_ * logits_scale_factor + scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) student_log_probs = torch.log_softmax(scaled_logits, dim=-1) # Reverse KL: input=teacher_log_probs, target=student_probs From 72945bcfd099c09ee7bf399d13258052aef5b483 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 6 Aug 2025 14:09:35 +0000 Subject: [PATCH 147/161] add loss mask for vision. 
should also handle padded sequences --- fast_llm/layers/language_model/head.py | 6 +++++- fast_llm/models/gpt/model.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 791b1f09d..eed2d134f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -238,7 +238,7 @@ def _get_targets( else: lm_target = None - targets = (dpo_target, lm_target, distillation_target, loss_mask) + targets = (dpo_target, lm_target, distillation_target) # If we do distillation, no need to split it here as it has already been split in the embedding layer! # if we do CPT/language modeling, we need to split the targets here! if ( @@ -252,6 +252,10 @@ def _get_targets( None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) for target in targets ] + # Loss mask may need to be split. It was not split in the embedding layer as it is not used there. + if loss_mask is not None and self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0) + targets = (*targets, loss_mask) if not any(target is not None for target in targets): # Simplify so we don't have to check every time. targets = None diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 3d393fd40..be172af96 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -407,16 +407,32 @@ def preprocess( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) if self._config.vision_encoder.enabled: + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask, torch.ones_like(labels, dtype=torch.bool)) if self._config.vision_encoder.image_break_token is not None: if not labels_cloned: labels = labels.clone() labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + loss_mask = torch.where( + labels == self._config.vision_encoder.image_break_token, False, loss_mask + ) + if self._config.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask if self._config.vision_encoder.image_end_token is not None: if not labels_cloned: labels = labels.clone() labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + loss_mask = torch.where( + labels == self._config.vision_encoder.image_end_token, False, loss_mask + ) + if self._config.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + # TODO: Check that this works. Can we remove previous loss_masking? 
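# A compact sketch of the masking rule introduced here, on made-up token ids:
# any label position already set to -100 (masked spans, image break/end tokens,
# padding) is excluded from the distillation loss as well.
#
#   import torch
#
#   labels = torch.tensor([[12, -100, 57, -100, 98]])   # -100 marks positions to ignore
#   loss_mask = torch.ones_like(labels, dtype=torch.bool)
#   loss_mask = torch.where(labels == -100, False, loss_mask)
#   print(loss_mask)  # tensor([[ True, False,  True, False,  True]])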
+ if self._config.distillation_model is not None: + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask, torch.ones_like(labels, dtype=torch.bool)) + loss_mask = torch.where(labels == -100, False, loss_mask) + kwargs[LanguageModelKwargs.loss_mask] = loss_mask kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) From 26541c6f8d89841cc604b79457ea9df8283ba6ae Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 6 Aug 2025 19:17:17 +0000 Subject: [PATCH 148/161] simplify loss-masking --- fast_llm/models/gpt/model.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index be172af96..7b4d165c5 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -397,40 +397,25 @@ def preprocess( valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset - loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, i] = False + labels[start : end + 1, i] = -100 else: - loss_mask[i, start : end + 1] = False - if self._config.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - labels = torch.where(loss_mask, labels, -100) + labels[i, start : end + 1] = -100 if self._config.vision_encoder.enabled: - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask, torch.ones_like(labels, dtype=torch.bool)) if self._config.vision_encoder.image_break_token is not None: if not labels_cloned: labels = labels.clone() labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) - loss_mask = torch.where( - labels == self._config.vision_encoder.image_break_token, False, loss_mask - ) - if self._config.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask if self._config.vision_encoder.image_end_token is not None: if not labels_cloned: labels = labels.clone() labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) - loss_mask = torch.where( - labels == self._config.vision_encoder.image_end_token, False, loss_mask - ) - if self._config.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - # TODO: Check that this works. Can we remove previous loss_masking? 
+ # Loss-masking for distillation losses if self._config.distillation_model is not None: - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask, torch.ones_like(labels, dtype=torch.bool)) + loss_mask = torch.ones_like(labels, dtype=torch.bool) loss_mask = torch.where(labels == -100, False, loss_mask) kwargs[LanguageModelKwargs.loss_mask] = loss_mask kwargs[LanguageModelKwargs.labels] = labels From d66942b8e374404663fe8f9cbecb9296033da5b4 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 6 Aug 2025 19:47:46 +0000 Subject: [PATCH 149/161] data time debug --- fast_llm/data/data/gpt/data.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 9df9b9b86..13625099f 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,6 +1,7 @@ import dataclasses import logging import pathlib +import time import typing import warnings from functools import partial @@ -39,6 +40,8 @@ class GPTBatch: def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: + start_time = time.perf_counter() + stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None @@ -65,6 +68,12 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_image_positions.append(torch.from_numpy(sample.image_positions)) else: batch_image_positions.append([]) + + data_time = (time.perf_counter() - start_time) * 1000 + if data_time > 1000: + logger.warning( + f"Data collate-fn took {data_time:,.2f} ms, Num images: {len(batch_images[0]) if has_images else 0}" + ) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, From f00509d99e61d21d31d2dde31003d8978a4cfa4c Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 6 Aug 2025 21:54:34 +0000 Subject: [PATCH 150/161] fix mm embedding indices --- fast_llm/layers/multi_modal/embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 948b2acf9..a5a789f9e 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -69,6 +69,7 @@ def _forward( for position, size in zip(positions, sizes): num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) if image_embedding_offset + num_patches < patch_start_offset: + image_embedding_offset += num_patches continue if self._config.vision_encoder.image_break_token is not None: patch_height = div(size[0], self._config.vision_encoder.patch_size) @@ -83,7 +84,7 @@ def _forward( input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset - embeddings_start_index = row_start_dst - max(patch_start_offset - row_start_src, 0) + embeddings_start_index = row_start_dst + max(patch_start_offset - row_start_src, 0) embeddings_end_index = ( row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) ) From 8a5e8f0616a1ae6c2d4fa5a4d220a84e369174e7 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 7 Aug 2025 17:37:56 +0000 Subject: [PATCH 151/161] cleanup --- fast_llm/data/data/gpt/data.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 13625099f..9df9b9b86 100644 --- a/fast_llm/data/data/gpt/data.py +++ 
b/fast_llm/data/data/gpt/data.py @@ -1,7 +1,6 @@ import dataclasses import logging import pathlib -import time import typing import warnings from functools import partial @@ -40,8 +39,6 @@ class GPTBatch: def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: - start_time = time.perf_counter() - stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None @@ -68,12 +65,6 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_image_positions.append(torch.from_numpy(sample.image_positions)) else: batch_image_positions.append([]) - - data_time = (time.perf_counter() - start_time) * 1000 - if data_time > 1000: - logger.warning( - f"Data collate-fn took {data_time:,.2f} ms, Num images: {len(batch_images[0]) if has_images else 0}" - ) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, From 01783de82289b1124e76a87884ce05aacc0c8a6c Mon Sep 17 00:00:00 2001 From: Oleksiy Ostapenko Date: Thu, 7 Aug 2025 17:22:50 -0400 Subject: [PATCH 152/161] Mamba2 to be merged (#349) Co-authored-by: Joel Lamy-Poirier Co-authored-by: Denis Kochetkov Co-authored-by: Toolkit User --- fast_llm/engine/config_utils/tensor_space.py | 49 +- fast_llm/engine/multi_stage/config.py | 5 + fast_llm/engine/multi_stage/fsdp.py | 32 +- fast_llm/engine/training/config.py | 5 - fast_llm/engine/training/trainer.py | 2 +- fast_llm/functional/cross_entropy.py | 95 +- fast_llm/layers/language_model/embedding.py | 8 +- fast_llm/layers/language_model/head.py | 11 +- .../layers/language_model/preprocessing.py | 4 +- fast_llm/layers/ssm/config.py | 9 +- fast_llm/layers/ssm/discrete_mamba2.py | 18 +- fast_llm/layers/ssm/mamba2.py | 20 +- fast_llm/layers/ssm/mamba_layer.py | 16 +- fast_llm/layers/transformer/attention.py | 16 +- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 6 +- fast_llm/layers/transformer/preprocessing.py | 2 +- .../transformer/rotary/preprocessing.py | 4 +- fast_llm/layers/transformer/rotary/rotary.py | 4 +- fast_llm/layers/transformer/transformer.py | 5 +- fast_llm/models/gpt/config.py | 12 +- fast_llm/models/gpt/megatron.py | 29 +- fast_llm/models/gpt/model.py | 2 +- fast_llm/models/ssm/config.py | 45 +- fast_llm/models/ssm/external/15B_hybrid.ipynb | 1562 +++++++++++++++++ fast_llm/models/ssm/external/5B_hybrid.ipynb | 416 +++++ .../modeling_ssm_hybrid_apriel15b.py | 27 +- fast_llm/models/ssm/huggingface.py | 1 + fast_llm/tensor.py | 56 +- setup.cfg | 8 +- tests/layers/test_lm_head.py | 2 - 31 files changed, 2291 insertions(+), 186 deletions(-) create mode 100644 fast_llm/models/ssm/external/15B_hybrid.ipynb create mode 100644 fast_llm/models/ssm/external/5B_hybrid.ipynb diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 0d971a88a..6c4b95b20 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -66,13 +66,23 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - if self.parallel_group is not None: + if self.is_parallel: from fast_llm.core.ops import gather_op return gather_op(tensor, self.parallel_group, dim) else: return tensor + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, 
fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] @@ -85,7 +95,7 @@ class CompositeTensorDim(TensorDim): def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = None for dim, tensor_dim in enumerate(tensor_dims): - if tensor_dim.is_parallel: + if tensor_dim.parallel_dim is not None: # TODO: Allow more than one parallel subdim? assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim @@ -111,6 +121,15 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): @@ -157,6 +176,27 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else tensor ) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() @@ -223,8 +263,5 @@ def add_tensor_dim(self, tensor_dim: TensorDim) -> None: ) self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> TensorDim: return self._tensor_dims[name] - - # TODO: Replace uses - __getitem__ = get_tensor_dim diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 6ac157dfe..719088057 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -31,6 +31,7 @@ if typing.TYPE_CHECKING: from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM + from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) @@ -241,6 +242,10 @@ def get_checkpoint_handler_class(cls, format: type[CheckpointFormat] | str) -> t def get_model_class(cls) -> type["FastLLMModel"]: raise NotImplementedError + @classmethod + def get_inference_runner_class(cls) -> type["InferenceRunner"]: + raise NotImplementedError + 
@classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]: raise NotImplementedError diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. - try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. + buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. - if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. 
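A small stand-alone sketch (plain tensors, no TensorDim or ParameterMeta) of what the local_to_global_partial call just below produces for a dim split over a parallel group: the local shard lands at this rank's offset and every other position holds the fill value -1, meaning "not stored in this shard". World size, rank and values are illustrative.

import torch

world_size, rank, fill_value = 4, 1, -1
local = torch.arange(3)  # this rank's slice of a dim with global size 12

# Allocate the global-shaped buffer, fill it with the sentinel, drop the local
# slice into this rank's slot, then flatten back to the global dim.
output = local.new_full((world_size, *local.shape), fill_value)
output.narrow(0, rank, 1).copy_(local.unsqueeze(0))
global_partial = output.flatten(0, 1)

print(global_partial)  # tensor([-1, -1, -1,  0,  1,  2, -1, -1, -1, -1, -1, -1])
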
+ return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 3dbec5348..9372ad7fb 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -32,7 +32,6 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator @@ -403,10 +402,6 @@ def _setup(self): def get_trainer_class(cls) -> type["Trainer"]: raise NotImplementedError - @classmethod - def get_inference_runner_class(cls) -> type["InferenceRunner"]: - raise NotImplementedError - def _get_runnable(self) -> typing.Callable[[], None]: from fast_llm.engine.distributed.distributed import Distributed diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 5f5511a15..ec3c4cebe 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -142,7 +142,7 @@ def __init__(self, config: TrainerConfig): self._reference_models = {} for name, reference_config in self._config.reference_models.items(): log_main_rank(f"Creating `{name} reference model...") - self._reference_models[name] = self._config.get_inference_runner_class()( + self._reference_models[name] = reference_config.model.get_inference_runner_class()( reference_config.model.get_model_class()(reference_config.model) ) self._multi_stage.base_model.add_reference_model(name, self._reference_models[name]) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index eaeaa0d18..b18a9ec0b 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -49,6 +49,19 @@ def _torch_cross_entropy_forward_backward( return loss.detach_(), grad +def distributed_log_softmax(logits: torch.Tensor, group: ProcessGroup, dim: int = -1): + logits = logits.float() + local_max = logits.max(dim=dim, keepdim=True)[0] + all_reduce(local_max, op=ReduceOp.MAX, group=group) + + logits_shifted = logits - local_max + exp_logits = torch.exp(logits_shifted) + sum_exp = exp_logits.sum(dim=dim, keepdim=True) + all_reduce(sum_exp, op=ReduceOp.SUM, group=group) + + return logits_shifted - sum_exp.log() # log_softmax + + @torch.compile def _fused_softmax_base( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 @@ -214,21 +227,21 @@ def cross_entropy_forward_backward( ) -def _torch_reverse_kl_forward_backward( +def _torch_reverse_kl_forward_backward_vocab_parallel( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, - logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, - teacher_softmax_temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. + This is used for TP version where we split accross vocab dimantion. This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. 
""" + # TODO: merge into single function _torch_reverse_kl_forward_backward Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype @@ -236,16 +249,66 @@ def _torch_reverse_kl_forward_backward( Assert.eq(loss_mask.shape, logits.shape[:-1]) # Compute log probabilities - let _fused_softmax handle scaling internally - # teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group) - # # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p) - # teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6 - # teacher_log_probs = torch.log(teacher_probs) + teacher_log_probs = distributed_log_softmax(target, group=group) + batch_size = logits.shape[0] + with torch.enable_grad(): + logits_ = logits.detach().requires_grad_(grad_output is not None) + student_log_probs = distributed_log_softmax(logits_, group=group) + + # Reverse KL: input=teacher_log_probs, target=student_probs + if loss_mask is None: + loss = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="sum", + log_target=True, + ) + else: + # Apply loss mask - this requires some reshaping + raise NotImplementedError("Loss mask not implemented with TP for reverse KL , it must be doublechecked") + loss_per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + loss = (loss_per_sample * loss_mask).sum() + + if group is not None and target_format != TargetFormat.labels: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= batch_size + + if grad_output is not None: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.to(logits.dtype) + else: + grad = None + return loss.detach_(), grad + + +def _torch_reverse_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, + teacher_softmax_temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Reverse KL using PyTorch's native kl_div function. + This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. + In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. 
+ """ + Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) # Scale target logits more carefully scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) - # Clamp to prevent extreme values before log_softmax - scaled_target = torch.clamp(scaled_target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) @@ -256,9 +319,10 @@ def _torch_reverse_kl_forward_backward( logits_ = logits.detach().requires_grad_(grad_output is not None) scaled_logits = logits_ * logits_scale_factor - scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - + # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: loss = torch.nn.functional.kl_div( @@ -279,6 +343,7 @@ def _torch_reverse_kl_forward_backward( loss /= group.size() if grad_output is not None: + # note, we never get here in TP over seq. dim. loss.backward(torch.full_like(loss, grad_output)) grad = logits_.grad.to(logits.dtype) else: @@ -344,6 +409,14 @@ def reverse_kl_forward_backward( Assert.eq(teacher_softmax_temperature, 1) Assert.eq(logits_scale_factor, 1) raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.") + return _torch_reverse_kl_forward_backward_vocab_parallel( + logits, + target, + loss_mask, + grad_output, + target_format, + group, + ) else: return _torch_reverse_kl_forward_backward( logits, diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index eed2d134f..24c06d5cc 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -61,7 +61,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = 
self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -237,7 +237,6 @@ def _get_targets( ).flatten() else: lm_target = None - targets = (dpo_target, lm_target, distillation_target) # If we do distillation, no need to split it here as it has already been split in the embedding layer! # if we do CPT/language modeling, we need to split the targets here! @@ -350,9 +349,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c06d85148..3b21ca698 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,13 +1,16 @@ import enum +import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.tensor import Initializer from fast_llm.utils import Assert, div +if typing.TYPE_CHECKING: + from fast_llm.tensor import Initializer + class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no 
risk of name conflict. @@ -16,6 +19,8 @@ class SSMDimNames: head_groups = "ssm_head_groups" group_heads = "ssm_group_heads" + # Mamba 2 + x_proj_dim_2 = "x_proj_dim_2" # d_xb convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers dt_rank = "ssm_dt_rank" @@ -62,7 +67,7 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float) -> Initializer: + def get_init_method(self, scale: float) -> "Initializer": from fast_llm.tensor import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 64377b93c..c9d555de9 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -49,25 +49,25 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) - heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] + conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] + heads_dim = tensor_space[SSMDimNames.composite_heads] # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size + self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -83,8 +83,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 1ae25e44c..77c1b3869 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -62,13 +62,13 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) - xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) - hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) - dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) + inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + dt_rank_dim = tensor_space[SSMDimNames.dt_rank] - self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size - self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._local_heads = tensor_space[SSMDimNames.composite_heads].size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size @@ -77,8 +77,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -90,7 +90,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -122,7 +122,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + (inner_dim, tensor_space[SSMDimNames.state]), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 64c8227fc..9343ef1b8 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -69,8 +69,8 @@ def __init__( Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -78,7 +78,7 @@ def __init__( # TODO: lr_scale? 
self.in_proj = Linear( hidden_dim, - tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) @@ -86,8 +86,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -95,7 +95,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), + tensor_space[SSMDimNames.concatenated_x_projection], weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -104,7 +104,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + (inner_dim, tensor_space[SSMDimNames.dt_rank]), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) @@ -116,7 +116,7 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), + (inner_dim, tensor_space[SSMDimNames.state]), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 7b8bc98c8..c03aeed8e 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -72,14 +72,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(self._transformer_dim_names.group_heads).size + self._kv_channels = self._tensor_space[self._transformer_dim_names.kv_channels].size + self._head_groups = self._tensor_space[self._transformer_dim_names.head_groups].global_size + self._local_head_groups = self._tensor_space[self._transformer_dim_names.head_groups].size + self._local_heads_per_group = self._tensor_space[self._transformer_dim_names.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + hidden_dim = self._tensor_space[self._transformer_dim_names.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -87,7 +87,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_query), + self._tensor_space[self._transformer_dim_names.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -96,7 +96,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_key_value), + self._tensor_space[self._transformer_dim_names.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -110,7 +110,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_dense), + self._tensor_space[self._transformer_dim_names.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 73f83ccf5..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -63,8 +63,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index ecf2c3fea..5dee4e077 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -33,8 +33,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.composite_expert_mlp) + hidden_dim = tensor_space[self._transformer_dim_names.hidden] + self._intermediate_dim = tensor_space[self._transformer_dim_names.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -49,7 +49,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - 
tensor_space.get_tensor_dim(self._transformer_dim_names.composite_gated_expert_mlp), + tensor_space[self._transformer_dim_names.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index cb64ccf06..ee30112d7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -30,7 +30,7 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index b2c69dd8d..6b4b81415 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -84,8 +84,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index d2f3bfba8..9289dccfb 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -48,7 +48,7 @@ def _get_meta( } return TensorMeta.from_dims( tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] for dim_name in dim_names ), tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", @@ -99,8 +99,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - - hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + hidden_dim = 
self._tensor_space[self._transformer_dim_names.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index bc64821f2..182ad1712 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -184,6 +184,12 @@ def get_model_class(cls) -> type["GPTModel"]: return GPTModel + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + return GPTInferenceRunner + @classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM @@ -289,9 +295,3 @@ def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer return GPTTrainer - - @classmethod - def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: - from fast_llm.models.gpt.model import GPTInferenceRunner - - return GPTInferenceRunner diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. @@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. 
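Editor's note on the `tensor_space.get_tensor_dim(...)` to `tensor_space[...]` change that runs through the hunks above: both spellings appear side by side in the diff, so the refactor amounts to routing the dimension lookup through `__getitem__`. A minimal sketch of that pattern, using a hypothetical `ToyTensorSpace` stand-in rather than the real `TensorSpace` class:

class ToyTensorSpace:
    """Toy container mapping dimension names to sizes (illustration only)."""

    def __init__(self, dims):
        self._dims = dims

    def get_tensor_dim(self, name):
        # Old call style, kept as an alias.
        return self._dims[name]

    def __getitem__(self, name):
        # New call style used by the updated call sites.
        return self.get_tensor_dim(name)

space = ToyTensorSpace({"hidden": 4096})
assert space["hidden"] == space.get_tensor_dim("hidden") == 4096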
assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 7b4d165c5..ebf84fc58 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -219,7 +219,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 886fa7a32..471e6d06c 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -4,19 +4,23 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import ( + GPTBaseModelConfig, + GPTBatchConfig, + GPTHuggingfaceCheckpointFormat, + PretrainedGPTModelConfig, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.models.gpt.model import GPTInferenceRunner from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - from fast_llm.models.ssm.model import HybridSSMModel + from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel from fast_llm.models.ssm.trainer import HybridSSMTrainer logger = logging.getLogger(__name__) @@ -80,8 +84,7 @@ def _validate(self): self.ssm_block_type = ssm_block_types.pop() if ssm_block_types 
else None -class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class LLambaHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llamba" @classmethod @@ -91,8 +94,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm" @classmethod @@ -102,8 +104,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHuggingfaceCheckpointHandler -class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_hybrid" @classmethod @@ -113,8 +114,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHHybridHuggingfaceCheckpointHandler -class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" trust_remote_code: typing.ClassVar[bool] = True @@ -165,6 +165,16 @@ def get_model_class(cls) -> type["HybridSSMModel"]: return HybridSSMModel + @classmethod + def get_inference_runner_class(cls) -> type["HybridSSMInferenceRunner"]: + from fast_llm.models.ssm.model import HybridSSMInferenceRunner + + logger.warning( + "HybridSSMInferenceRunner only supports training-style forward pass. Use generate with cache disabled." + ) + + return HybridSSMInferenceRunner + @classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM @@ -213,14 +223,3 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) - - @classmethod - def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: - from fast_llm.models.gpt.model import GPTInferenceRunner - - # TODO: we dont have inference runner for SSM/Hybrid yet, should return None? - logger.warning( - "No inference runner for SSM/Hybrid yet, using GPTInferenceRunner for now, which does not support SSM/Hybrid" - ) - - return GPTInferenceRunner diff --git a/fast_llm/models/ssm/external/15B_hybrid.ipynb b/fast_llm/models/ssm/external/15B_hybrid.ipynb new file mode 100644 index 000000000..a8f0c33b7 --- /dev/null +++ b/fast_llm/models/ssm/external/15B_hybrid.ipynb @@ -0,0 +1,1562 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from transformers import AutoConfig, AutoModelForCausalLM\n", + "# from transformers import MistralForCausalLM\n", + "# from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n", + "# from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM\n", + "# autoreload changes to the code\n", + "%reload_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15bch-ifrhyb20l32h-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000\"\n", + "# AutoConfig.from_pretrained(model_path, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15bch-ifrhyb20l32h-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000\"\n", + "# m = AutoModelForCausalLM.from_pretrained(\n", + "# model_path, trust_remote_code=True,\n", + "# config=AutoConfig.from_pretrained(model_path, trust_remote_code=True),\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Slam 15B upcycled" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Lead the weights of https://huggingface.co/ServiceNow-AI/Slam-15B-Upcycled/ into Thiked modeling, it shoudl work" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"/home/toolkit/dev/fml-ops/__oo_playground\")\n", + "from results_analysis.results_loader import ResultsLoader\n", + "layer_importance_path = \"/mnt/evaluations/training_evaluation/model_runs/lm_eval_runner/apriel_ssm_importance/\"\n", + "results_loader = ResultsLoader(layer_importance_path)\n", + "\n", + "results_loader.deserialize_results()\n", + "results_df = results_loader.to_df()\n", + "results_df[\"layer_index\"] = results_df.apply(lambda row: int(row[\"model_name_sanitized\"].split(\"_\")[-1] if \"layers_\" in row[\"model_name_sanitized\"] else -1), axis=1)\n", + "results_df = results_df[results_df[\"metric\"] == \"acc_norm\"]\n", + "columns_to_keep = [\"layer_index\", \"metric_value\"]\n", + "results_df = results_df[columns_to_keep]\n", + "layer_importance = results_df.groupby(\"layer_index\").mean()\n", + "layer_importance = layer_importance.sort_values(by=\"metric_value\", ascending=False).reset_index()\n", + "layer_importance = layer_importance[layer_importance[\"layer_index\"]!= -1]\n", + "layer_importance = list(layer_importance[\"layer_index\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[22,\n", + " 25,\n", + " 20,\n", + " 31,\n", + " 29,\n", + " 46,\n", + " 23,\n", + " 26,\n", + " 33,\n", + " 24,\n", + " 47,\n", + " 27,\n", + " 21,\n", + " 41,\n", + " 17,\n", + " 18,\n", + " 34,\n", + " 42,\n", + " 44,\n", + " 30,\n", + " 16,\n", + " 8,\n", + " 43,\n", + " 35,\n", + " 19,\n", + " 38,\n", + " 15,\n", + " 28,\n", + " 32,\n", + " 45,\n", + " 37,\n", + " 
40,\n", + " 7,\n", + " 36,\n", + " 13,\n", + " 10,\n", + " 5,\n", + " 39,\n", + " 6,\n", + " 14,\n", + " 4,\n", + " 12,\n", + " 9,\n", + " 48,\n", + " 1,\n", + " 3,\n", + " 11,\n", + " 49,\n", + " 0]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_importance" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", + "n_ssm = 25\n", + "\n", + "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", + "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", + "\n", + "for i in range(n_ssm):\n", + " hybrid_block_layout[layer_importance[i]] = \"m2d\"\n", + "\n", + "config_hybrid = AprielSSMHybridConfig(\n", + " **config_thinker.to_dict(),\n", + " hybrid_block_layout=hybrid_block_layout,\n", + " ssm_cfg = {\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 32,\n", + " \"n_qk_heads\": 32,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_conv\": 4,\n", + " \"d_inner\": 32 * 128\n", + " }\n", + ")\n", + "model_hybrid = AprielSSMHybridForCausalLM(config_hybrid)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['t',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 't']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_block_layout" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are using a model of type llama to instantiate a model of type mistral. 
This is not supported for all configurations of models and can yield errors.\n", + "Loading checkpoint shards: 0%| | 0/4 [00:00 v, B -> k, C -> q\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.v_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] : mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.k_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"] : 2 * mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.q_proj.weight.data)\n", + "\n", + " print(\"Init Mamba using Attention\")\n", + "\n", + " transformer.model.layers[layer_idx] = mamba_encoder\n", + "\n", + " # elif type == \"m2d\":\n", + " # print(\"Converting layer %d...\" % layer_idx)\n", + " # mamba_encoder = AprielSSMDecoderLayer(\n", + " # mamba_config,\n", + " # layer_idx,\n", + " # device=\"cpu\",\n", + " # dtype=torch_dtype,\n", + " # )\n", + " # mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())\n", + " # mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())\n", + " # mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())\n", + " # mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())\n", + "\n", + " # if init_with_kqvo:\n", + " \n", + "\n", + "\n", + " \n", + " else:\n", + " raise ValueError(f\"Invalid layer type: {type}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00, 1.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Converting layer %d... 0\n", + "Skipping transformer layer 0...\n", + "Converting layer %d... 1\n", + "Skipping transformer layer 1...\n", + "Converting layer %d... 2\n", + "Skipping transformer layer 2...\n", + "Converting layer %d... 3\n", + "Skipping transformer layer 3...\n", + "Converting layer %d... 4\n", + "Skipping transformer layer 4...\n", + "Converting layer %d... 5\n", + "Skipping transformer layer 5...\n", + "Converting layer %d... 6\n", + "Skipping transformer layer 6...\n", + "Converting layer %d... 7\n", + "Skipping transformer layer 7...\n", + "Converting layer %d... 8\n", + "Skipping transformer layer 8...\n", + "Converting layer %d... 9\n", + "Skipping transformer layer 9...\n", + "Converting layer %d... 10\n", + "Skipping transformer layer 10...\n", + "Converting layer %d... 11\n", + "Skipping transformer layer 11...\n", + "Converting layer %d... 12\n", + "Skipping transformer layer 12...\n", + "Converting layer %d... 13\n", + "Skipping transformer layer 13...\n", + "Converting layer %d... 14\n", + "Skipping transformer layer 14...\n", + "Converting layer %d... 15\n", + "Skipping transformer layer 15...\n", + "Converting layer %d... 16\n", + "Skipping transformer layer 16...\n", + "Converting layer %d... 17\n", + "Skipping transformer layer 17...\n", + "Converting layer %d... 18\n", + "Skipping transformer layer 18...\n", + "Converting layer %d... 
19\n", + "Skipping transformer layer 19...\n", + "Converting layer %d... 20\n", + "Converting layer 20...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 21\n", + "Converting layer 21...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 22\n", + "Converting layer 22...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 23\n", + "Converting layer 23...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 24\n", + "Converting layer 24...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 25\n", + "Converting layer 25...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 26\n", + "Skipping transformer layer 26...\n", + "Converting layer %d... 27\n", + "Converting layer 27...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 28\n", + "Skipping transformer layer 28...\n", + "Converting layer %d... 29\n", + "Skipping transformer layer 29...\n", + "Converting layer %d... 30\n", + "Converting layer 30...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 31\n", + "Converting layer 31...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 32\n", + "Converting layer 32...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 33\n", + "Converting layer 33...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 34\n", + "Skipping transformer layer 34...\n", + "Converting layer %d... 35\n", + "Converting layer 35...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 36\n", + "Converting layer 36...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 37\n", + "Converting layer 37...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 38\n", + "Converting layer 38...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 39\n", + "Converting layer 39...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 40\n", + "Converting layer 40...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 41\n", + "Converting layer 41...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 42\n", + "Converting layer 42...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 43\n", + "Converting layer 43...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 44\n", + "Converting layer 44...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 45\n", + "Converting layer 45...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 46\n", + "Converting layer 46...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 47\n", + "Converting layer 47...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 48\n", + "Skipping transformer layer 48...\n", + "Converting layer %d... 
49\n", + "Converting layer 49...\n", + "Init Mamba using Attention\n" + ] + } + ], + "source": [ + "transformer = AutoModelForCausalLM.from_pretrained(path_thinker)\n", + "init_with_kqvo = True\n", + "torch_dtype = torch.bfloat16\n", + "attn_bias = True\n", + "convert_layers(transformer, config_hybrid, hybrid_block_layout, init_with_kqvo, attn_bias, torch_dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "transformer.config = config_hybrid" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridConfig {\n", + " \"architectures\": [\n", + " \"MistralForCausalLM\"\n", + " ],\n", + " \"attention_dropout\": 0.0,\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"head_dim\": 128,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 5120,\n", + " \"hybrid_block_layout\": [\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\"\n", + " ],\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 14336,\n", + " \"max_position_embeddings\": 65536,\n", + " \"model_type\": \"apriel_ssm_thinker_hybrid\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 50,\n", + " \"num_key_value_heads\": 8,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_theta\": 1000000.0,\n", + " \"sliding_window\": null,\n", + " \"ssm_cfg\": {\n", + " \"activation\": \"identity\",\n", + " \"bias\": false,\n", + " \"chunk_size\": 128,\n", + " \"conv_bias\": true,\n", + " \"d_conv\": 4,\n", + " \"d_inner\": 4096,\n", + " \"d_state\": 16,\n", + " \"d_xb\": 1024,\n", + " \"dt_init\": \"random\",\n", + " \"dt_init_floor\": 0.0001,\n", + " \"dt_max\": 0.1,\n", + " \"dt_min\": 0.001,\n", + " \"dt_rank\": \"auto\",\n", + " \"dt_scale\": 1.0,\n", + " \"expand\": 1,\n", + " \"n_qk_heads\": 32,\n", + " \"n_v_heads\": 32\n", + " },\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.53.2\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transformer.config" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "transformer.config.architectures=[\"AprielThinkerSSMHybridForCausalLM\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 427.77it/s]\n" + ] + } + ], + "source": [ + "# load state dict from existing pretrained SSM?\n", + "path_25hyb = 
\"/mnt/checkpoints/ssm/apriel_ssm_thinker5l_hybrid_1ssm_init_rand_debug_tpformat\" #\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6/export/apriel_ssm_thinker_hybrid/5000_new\"\n", + "model = AprielThinkerSSMHybridForCausalLM.from_pretrained(path_25hyb)\n", + "state_dict = model.state_dict()\n", + "\n", + "# missing, unexpected = transformer.load_state_dict(state_dict, strict=False)\n", + "# print(missing)\n", + "# print(unexpected)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Note: saving as transformer wilkl still keep architectures[\"Mistral....\"]. So currently need to manually update the checkpoints architectures list to have AprielThinkerSSMHybridForCausalLM" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# mamba2, state 16, expand 1, i.e. same as M1, but with discrete mamba2 and MIL\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_leastimportant_m2_16hexp1_init_mil\") # 1 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil\") # 25 ssm\n", + "transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil_tpformat\") # 25 ssm\n", + "\n", + "\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_40ssm_leastimportant_m2_16hexp1_init_mil_uniform_from_25h5000lm6\") # 40 ssm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data mixing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([])\n", + "KL (global, F.kl_div) = 0.738795\n", + "KL (sum of shards, manual) = 0.738795\n" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fast_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/5B_hybrid.ipynb b/fast_llm/models/ssm/external/5B_hybrid.ipynb new file mode 100644 index 000000000..9a33f577e --- /dev/null +++ b/fast_llm/models/ssm/external/5B_hybrid.ipynb @@ -0,0 +1,416 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": 
[ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "\n", + "import torch\n", + "import random\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "fast_llm_path = \"/home/toolkit/dev/Fast-LLM\"\n", + "\n", + "# add fast_llm to the python path\n", + "import sys\n", + "sys.path.append(fast_llm_path)\n", + "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer, AprielSSMHybridForCausalLM\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "base = 0.612615\n", + "layer_scores = {\n", + " \"22\": 0.607389,\n", + " \"24\": 0.603498,\n", + " \"19\": 0.597907,\n", + " \"27\": 0.597173,\n", + " \"20\": 0.590442,\n", + " \"5\": 0.578949,\n", + " \"4\": 0.576852,\n", + " \"9\": 0.576484,\n", + " \"23\": 0.574833,\n", + " \"7\": 0.571860,\n", + " \"8\": 0.571790,\n", + " \"6\": 0.571614,\n", + " \"2\": 0.571330,\n", + " \"26\": 0.570205,\n", + " \"11\": 0.567128,\n", + " \"14\": 0.566175,\n", + " \"15\": 0.566076,\n", + " \"3\": 0.562861,\n", + " \"1\": 0.560154,\n", + " \"13\": 0.559304,\n", + " \"16\": 0.559017,\n", + " \"10\": 0.558789,\n", + " \"12\": 0.555186,\n", + " \"17\": 0.554236,\n", + " \"25\": 0.549215,\n", + " \"18\": 0.537257,\n", + " \"0\": 0.233085,\n", + "}\n", + "layer_scores = {k: base - v for k, v in layer_scores.items()}\n", + "layer_importanfce = sorted(layer_scores.items(), key=lambda x: x[1])\n", + "layer_importanfce_rand = random.sample(layer_importanfce, len(layer_importanfce))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('22', 0.005226000000000064),\n", + " ('24', 0.009117000000000042),\n", + " ('19', 0.014708000000000054),\n", + " ('27', 0.015442000000000067),\n", + " ('20', 0.022173),\n", + " ('5', 0.033665999999999974),\n", + " ('4', 0.03576299999999999),\n", + " ('9', 0.036131000000000024),\n", + " ('23', 0.03778199999999998),\n", + " ('7', 0.040754999999999986),\n", + " ('8', 0.040825),\n", + " ('6', 0.041001000000000065),\n", + " ('2', 0.041285000000000016),\n", + " ('26', 0.04241000000000006),\n", + " ('11', 0.045487000000000055),\n", + " ('14', 0.04644000000000004),\n", + " ('15', 0.046539),\n", + " ('3', 0.049754000000000076),\n", + " ('1', 0.05246099999999998),\n", + " ('13', 0.053311),\n", + " ('16', 0.053598000000000035),\n", + " ('10', 0.05382600000000004),\n", + " ('12', 0.05742900000000006),\n", + " ('17', 0.05837900000000007),\n", + " ('25', 0.06340000000000001),\n", + " ('18', 0.07535800000000004),\n", + " ('0', 0.37953000000000003)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_importanfce" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "layer_importanfce = layer_importanfce_rand" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create hybrid with any number of SSM 
layers" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "device = \"cuda\"\n", + "n_hybrid = 0\n", + "\n", + "index_swaped = []\n", + "hybrid_block_layout = [\"t\"] * config.num_hidden_layers\n", + "for i in range(n_hybrid):\n", + " hybrid_block_layout[int(layer_importanfce[i][0])] = \"m2d\"\n", + " index_swaped.append(int(layer_importanfce[i][0]))\n", + "\n", + "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", + " hybrid_block_layout=hybrid_block_layout,\n", + " ssm_cfg={\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 24,\n", + " \"n_qk_heads\": 24,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_inner\": 24 * 128, # num_heads * head_dim\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['t',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrdif_apriel_config.hybrid_block_layout" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridForCausalLM(\n", + " (model): AprielSSMHybridModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config)\n", + "hybrid_apriel_model.to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 2.22it/s]\n" + ] + } + ], + "source": [ + "\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, 
torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing keys: []\n", + "Unexpected keys: []\n" + ] + } + ], + "source": [ + "# unexpected will contain keys from the SSM layers we added\n", + "print(\"Missing keys:\", missing)\n", + "# unexpected will contain keys from the transformer layers we replaced\n", + "print(\"Unexpected keys:\", unexpected)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 5/5 [00:04<00:00, 1.22it/s]\n" + ] + } + ], + "source": [ + "from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMModel, AprielSSMForCausalLM\n", + "\n", + "mohawk_path = \"/mnt/checkpoints/ssm/mohawk_distributed_stage2_apriel_8GPU_16ksteps_lr0.0_layernorm/final\"\n", + "# config = AutoConfig.from_pretrained(mohawk_path, trust_remote_code=True)\n", + "apriel_model = AprielSSMForCausalLM.from_pretrained(mohawk_path, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing keys: ['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 
'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight']\n", + "Unexpected keys: ['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 
'model.layers.21.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight']\n" + ] + } + ], + "source": [ + "# unexpected will contain keys from the SSM layers we added\n", + "print(\"Missing keys:\", missing)\n", + "# unexpected will contain keys from the transformer layers we replaced\n", + "print(\"Unexpected keys:\", unexpected)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_14ssm_leastimportant_init_MOHAWK\")\n", + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_20ssm_leastimportant_init_rand\")\n", + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_14ssm_randplacement_init_rand\")\n", + "hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_0ssm_full_transformer_debug\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# save the hybrid model\n", + "output_path = \"/mnt/checkpoints/ssm/iterative_hybrids_5b\"\n", + "assert len(index_swaped) == 1\n", + "layer_swaped = index_swaped[0]\n", + "hybrid_apriel_model.save_pretrained(\n", + " f\"{output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\"\n", + " )\n", + "print(f\"Hybrid model saved to {output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fast_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 35e9b6885..9f4588a29 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -893,7 +893,7 @@ def forward( self, hidden_states: torch.Tensor, past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - attention_mask: Optional[torch.Tensor] = None, + mamba_mask: Optional[torch.Tensor] = None, return_mixer_matrix=False, **kwargs, ): @@ -905,6 +905,10 @@ def forward( assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = hidden_states.shape + # mamba_mask = ( + # None if seqlen == 1 else mamba_mask + # ) # prevent that hidden_states are expanded to mask's seq. dimention., i.e. 
we do not need apply_mask_to_padding_states when generating single token at a time + # hidden_states = apply_mask_to_padding_states(hidden_states, mamba_mask) ssm_state, conv_state = None, None use_precomputed_states = False @@ -985,7 +989,7 @@ def forward( # Update state (B D W) conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) + x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( @@ -993,7 +997,10 @@ def forward( weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, - ) + ) # .transpose(1, 2) + # x = apply_mask_to_padding_states(x, mamba_mask).transpose( + # 1, 2 + # ) # zero out everything that comes from padding tokens if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) @@ -1048,14 +1055,14 @@ def step(self, hidden_states, conv_state, ssm_state): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states_input) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states_input) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group) C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states_input)) # B, d_inner if self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) @@ -1239,9 +1246,10 @@ def forward( ) -> BaseModelOutputWithPast: use_cache = use_cache if use_cache is not None else self.config.use_cache if use_cache and past_key_values is None: + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) - return super().forward( + output = super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1253,6 +1261,10 @@ def forward( cache_position=cache_position, **flash_attn_kwargs, ) + past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + return output class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
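Aside on the `forward` change above: when the caller never goes through `prepare_inputs_for_generation`, the model now builds its `HybridMambaAttentionDynamicCache` lazily on the first call and flips `has_previous_state` once that pass has filled it. A minimal sketch of the lazy-cache pattern with toy stand-ins (not the real Fast-LLM or HuggingFace classes):

import torch

class ToyCache:
    """Toy stand-in for a dynamic generation cache (illustration only)."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.has_previous_state = False

class ToyModel:
    def forward(self, input_ids, past_key_values=None, use_cache=True):
        if use_cache and past_key_values is None:
            # Build the cache lazily when the generation helper did not create one.
            past_key_values = ToyCache(batch_size=input_ids.shape[0])
        # ... decoder layers would read from and write to past_key_values here ...
        if past_key_values is not None and not past_key_values.has_previous_state:
            # After the first pass the cache holds state, so later passes can reuse it.
            past_key_values.has_previous_state = True
        return past_key_values

cache = ToyModel().forward(torch.zeros(2, 5, dtype=torch.long))
assert cache.has_previous_state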
@@ -1435,6 +1447,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, + mamba_mask=attention_mask, # non-expended mask **kwargs, ) diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py index 02f472076..1ece10edf 100644 --- a/fast_llm/models/ssm/huggingface.py +++ b/fast_llm/models/ssm/huggingface.py @@ -20,4 +20,5 @@ class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): config: HuggingfaceSSMModelConfig runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner model_class = HybridSSMModel + runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner _fast_llm_model: HybridSSMModel diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b89ed4a04..d080e6a1e 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,5 +1,6 @@ import abc import functools +import logging import math import typing @@ -13,6 +14,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -147,7 +150,7 @@ def from_tensor_space( reductions: tuple[tuple[str, ReduceOp], ...] = (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. kwargs["reductions"] = tuple( @@ -159,12 +162,11 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ if tensor.ndim == 0: tensor = tensor[None] Assert.eq(tensor.shape, self.shape) @@ -188,14 +190,28 @@ def local_to_global( Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: + """ + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. 
+ Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. tensor = tensor[:] @@ -205,9 +221,9 @@ def global_to_local( Assert.eq(tensor.shape, self.global_shape) for dim, tensor_dim in reversed(list(enumerate(self.dims))): - tensor = tensor_dim.global_to_local(tensor, dim, expand) - if not expand: - Assert.eq(tensor.shape, self.shape) + tensor = tensor_dim.global_to_local(tensor, dim) + + Assert.eq(tensor.shape, self.shape) return tensor @classmethod @@ -302,7 +318,11 @@ def __repr__(self, *, tensor_contents=()) -> str: def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: assert self.param_init_method is not None - if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init: + if ( + distributed.config.tensor_parallel == 1 + or distributed.config.reproducible_init + or self.param_init_method.requires_global_initialization + ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator diff --git a/setup.cfg b/setup.cfg index baa6e4adc..6ea98610c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,20 +41,20 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.52.4 + transformers==4.53.2 hf-transfer>=0.1.9 datasets>=3.6.0 huggingface-hub>=0.32.6 # Required to run SSMs # To install on cpu environment (ex. for IDE support): -# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation +# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 cartesia_pytorch>=0.0.2 -GENERATION = - lm_eval>=0.4.9 +# GENERATION = +# lm_eval>=0.4.9 # Required for supporting vision inputs diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9a878c494..6d00d05ba 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -23,12 +23,10 @@ def _reverse_kl_loss( ): scaled_target = target / teacher_softmax_temperature - scaled_target = torch.clamp(target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) with torch.enable_grad(): # Use log_softmax for consistency instead of _fused_softmax - logits = torch.clamp(logits, min=-50, max=50) student_log_probs = torch.log_softmax(logits, dim=-1) if loss_mask is None: loss = torch.nn.functional.kl_div( From ad6c0c0a0a91e77495c730dc8caf625dd0c020a5 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 7 Aug 2025 21:31:03 +0000 Subject: [PATCH 153/161] checkpoint format for llava --- fast_llm/models/ssm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 471e6d06c..8fb0d5982 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -132,7 +132,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: # text_name: typing.ClassVar[str] = "mistral" -class LlavaHybridHuggingfaceCheckpointFormat(CheckpointFormat): +class LlavaHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): support_optimizer: typing.ClassVar[bool] = False name: typing.ClassVar[str] = "llava_hybrid" vision_name: typing.ClassVar[str] = "pixtral" From 
cbc94e0bb4511b06c18f0261923078606bf886ce Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 8 Aug 2025 14:33:52 +0000 Subject: [PATCH 154/161] fixes --- fast_llm/layers/vision_encoder/adapter.py | 8 ++++---- fast_llm/layers/vision_encoder/patch_conv.py | 14 +++++++------- fast_llm/models/gpt/model.py | 10 +++------- fast_llm/models/ssm/config.py | 1 - 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index d324d5221..a59c6226f 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -18,18 +18,18 @@ class VisionAdapter(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + input_dim = tensor_space[VisionEncoderDimNames.out_channels] self._activation_type = config.adapter_activation_type self.layer_1 = Linear( input_dim, - tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + tensor_space[VisionEncoderDimNames.adapter_size], bias=True, weight_init_method=init_normal_(std=config.adapter_init_method_std), bias_init_method=init_normal_(std=config.adapter_init_method_std), ) self.layer_2 = Linear( - tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), - tensor_space.get_tensor_dim(TransformerDimNames.hidden), + tensor_space[VisionEncoderDimNames.adapter_size], + tensor_space[TransformerDimNames.hidden], bias=True, weight_init_method=init_normal_(std=config.adapter_init_method_std), bias_init_method=init_normal_(std=config.adapter_init_method_std), diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py index 3d1845dd8..6c2a70930 100644 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -19,23 +19,23 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._lr_scale = config.adapter_lr_scale self.weight = ParameterMeta.from_dims( ( - self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), - self._tensor_space.get_tensor_dim(VisionEncoderDimNames.in_channels), - self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), - self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + self._tensor_space[VisionEncoderDimNames.out_channels], + self._tensor_space[VisionEncoderDimNames.in_channels], + self._tensor_space[VisionEncoderDimNames.patch_size], + self._tensor_space[VisionEncoderDimNames.patch_size], ), init_method=init_normal_(), lr_scale=self._lr_scale, ) if config.conv_bias: self.bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),), + (self._tensor_space[VisionEncoderDimNames.out_channels],), init_method=init_normal_(), - lr_sclae=self._lr_scale, + lr_scale=self._lr_scale, ) else: self.bias = None - self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) + self.norm = config.patch_norm.get_layer(tensor_space[VisionEncoderDimNames.out_channels]) self.stride = config.patch_size def forward( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index ebf84fc58..6261821ab 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -173,12 +173,8 @@ def preprocess_meta( VisionEncoderKwargs.image_std: image_std, VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, 
VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, - VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( - VisionTransformerDimNames.kv_channels - ).size, - VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( - VisionEncoderDimNames.out_channels - ).size, + VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, + VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, } else: vision_kwargs = {} @@ -226,7 +222,7 @@ def preprocess_meta( else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) if self._config.vision_encoder.enabled: - vision_hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) + vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] vision_hidden_dims = ( (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) if sequence_first diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 8fb0d5982..5dca41a70 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -133,7 +133,6 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class LlavaHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False name: typing.ClassVar[str] = "llava_hybrid" vision_name: typing.ClassVar[str] = "pixtral" text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" From 52639960a37e5930243f67e2f73046717643becc Mon Sep 17 00:00:00 2001 From: Oleksiy Ostapenko Date: Wed, 13 Aug 2025 09:00:56 -0400 Subject: [PATCH 155/161] [Dev Hybrid] Distill with loss_mask (SFT dataset) and sequence-TP (#350) --- fast_llm/functional/config.py | 6 + fast_llm/functional/cross_entropy.py | 157 ++++++++++++++++++------- fast_llm/layers/language_model/head.py | 52 ++++++-- fast_llm/models/gpt/model.py | 3 + 4 files changed, 169 insertions(+), 49 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index c56b63065..5c8d75a6f 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -93,6 +93,12 @@ def _set_activation_fn_map() -> None: MAX_DROPLESS_BLOCK_SIZE_ROW = 128 +class ReverseKLImpl(str, enum.Enum): + tp = "tp" + stp = "stp" + no_tp = "no_tp" + + class CrossEntropyImpl(str, enum.Enum): auto = "auto" torch = "torch" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index b18a9ec0b..1be4ed82b 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -1,7 +1,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat +from fast_llm.functional.config import CrossEntropyImpl, ReverseKLImpl, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert @@ -234,6 +234,9 @@ def _torch_reverse_kl_forward_backward_vocab_parallel( grad_output: float | None, target_format: TargetFormat, group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + teacher_softmax_temperature: float = 1.0, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. @@ -241,6 +244,12 @@ def _torch_reverse_kl_forward_backward_vocab_parallel( This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. 
In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. """ + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") # TODO: merge into single function _torch_reverse_kl_forward_backward Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) @@ -249,10 +258,10 @@ def _torch_reverse_kl_forward_backward_vocab_parallel( Assert.eq(loss_mask.shape, logits.shape[:-1]) # Compute log probabilities - let _fused_softmax handle scaling internally - teacher_log_probs = distributed_log_softmax(target, group=group) + teacher_log_probs = distributed_log_softmax(target.float(), group=group) batch_size = logits.shape[0] with torch.enable_grad(): - logits_ = logits.detach().requires_grad_(grad_output is not None) + logits_ = logits.float().detach().requires_grad_(grad_output is not None) student_log_probs = distributed_log_softmax(logits_, group=group) # Reverse KL: input=teacher_log_probs, target=student_probs @@ -284,20 +293,19 @@ def _torch_reverse_kl_forward_backward_vocab_parallel( return loss.detach_(), grad -def _torch_reverse_kl_forward_backward( +def _torch_reverse_kl_forward_backward_no_tp( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, - group: ProcessGroup | None = None, teacher_softmax_temperature: float = 1.0, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. - This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. - In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. + THis is only used for no-TP case. """ Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) @@ -309,20 +317,20 @@ def _torch_reverse_kl_forward_backward( # Clamp to prevent extreme values that cause NaNs in log_softmax scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) - teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) + teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) # Use kl_div with: input=log(p), target=q, log_target=False # This gives: Σ q * (log(q) - log(p)) = exactly what we want! 
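    # Clarifying note (comment only): the call below actually passes log-probabilities for both
    # arguments with log_target=True. In that mode torch.nn.functional.kl_div(input, target)
    # computes sum_i exp(target_i) * (target_i - input_i), so input=log(p) (teacher) and
    # target=log(q) (student) give sum_i q_i * (log q_i - log p_i) = KL(q || p). Illustrative
    # check with made-up values:
    #   log_p = torch.log_softmax(torch.tensor([[1.0, 2.0]]), dim=-1)
    #   log_q = torch.log_softmax(torch.tensor([[2.0, 1.0]]), dim=-1)
    #   torch.nn.functional.kl_div(log_p, log_q, reduction="sum", log_target=True)
    #   == (log_q.exp() * (log_q - log_p)).sum()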
with torch.enable_grad(): - logits_ = logits.detach().requires_grad_(grad_output is not None) + logits_ = logits.float().detach().requires_grad_(grad_output is not None) scaled_logits = logits_ * logits_scale_factor # Clamp to prevent extreme values that cause NaNs in log_softmax scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) - student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - + student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1) + # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: loss = torch.nn.functional.kl_div( @@ -336,11 +344,7 @@ def _torch_reverse_kl_forward_backward( loss_per_sample = torch.nn.functional.kl_div( teacher_log_probs, student_log_probs, reduction="none", log_target=True ).sum(dim=-1) - loss = (loss_per_sample * loss_mask).mean() - - if group is not None and target_format != TargetFormat.labels: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= group.size() + loss = (loss_per_sample * loss_mask).sum() / loss_mask.sum() if grad_output is not None: # note, we never get here in TP over seq. dim. @@ -352,6 +356,88 @@ def _torch_reverse_kl_forward_backward( return loss.detach_(), grad +def _torch_reverse_kl_forward_backward_sequence_tensor_parallel( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + teacher_softmax_temperature: float = 1.0, + total_valid_tokens: int | None = None, # total number of unmasked tokens in the batch + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Reverse KL using PyTorch's native kl_div function. + THis is only used for sequence-tensor-parallel case where we split over sequence dimension. + """ + Assert.eq( + total_valid_tokens is not None, + msg="Total valid tokens must be provided for sequence-tensor-parallel reverse KL", + ) + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") + Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + # Scale target logits more carefully + scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) + + teacher_log_probs = torch.log_softmax(scaled_target.float(), dim=-1) + + # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) + # Use kl_div with: input=log(p), target=q, log_target=False + # This gives: Σ q * (log(q) - log(p)) = exactly what we want! 
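+    # Note on normalization (see the caller, LanguageModelHead, later in this patch): this
+    # variant returns the *local sum* of per-token losses over the sequence shard. The caller
+    # is expected to all-reduce that sum over the tensor group and divide by
+    # total_valid_tokens, and to fold the same 1 / total_valid_tokens factor into grad_output.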
+ + with torch.enable_grad(): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + + scaled_logits = logits_ * logits_scale_factor + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) + student_log_probs = torch.log_softmax(scaled_logits.float(), dim=-1) + + # Reverse KL: input=teacher_log_probs, target=student_probs + if loss_mask is None: + loss = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="sum", + log_target=True, + ) + else: + # Apply loss mask - this requires some reshaping + loss_per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + loss = (loss_per_sample * loss_mask).sum() # this can be 0.0 if all tokens are masked + + if grad_output is not None: + # note, if we compute gradient w.r.t sum of losses, + # and grad_output should reflect the scaling by 1/valid samples + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.to(logits.dtype) + else: + grad = None + + return loss.detach_(), grad + + +REVERSE_KL_IMPLEMENTATIONS = { + ReverseKLImpl.no_tp: _torch_reverse_kl_forward_backward_no_tp, + ReverseKLImpl.tp: _torch_reverse_kl_forward_backward_vocab_parallel, + ReverseKLImpl.stp: _torch_reverse_kl_forward_backward_sequence_tensor_parallel, +} + + def reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -361,7 +447,8 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, - vocab_parallel: bool = False, + reverse_kl_impl: ReverseKLImpl = ReverseKLImpl.no_tp, + total_valid_tokens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -404,27 +491,15 @@ def reverse_kl_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - # TODO: implement fused? - if vocab_parallel: - Assert.eq(teacher_softmax_temperature, 1) - Assert.eq(logits_scale_factor, 1) - raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.") - return _torch_reverse_kl_forward_backward_vocab_parallel( - logits, - target, - loss_mask, - grad_output, - target_format, - group, - ) - else: - return _torch_reverse_kl_forward_backward( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - group, - teacher_softmax_temperature, - ) + # TODO: implement fused reverse KL? 
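+    # Dispatch summary (the implementation is chosen by the caller, e.g. LanguageModelHead in
+    # this patch):
+    #   no_tp - full vocabulary and full local sequence on each rank (default path),
+    #   tp    - vocab-parallel logits (vocabulary dimension sharded over the tensor group),
+    #   stp   - sequence-tensor-parallel logits (sequence dimension sharded; caller rescales
+    #           the returned sum by total_valid_tokens).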
+ return REVERSE_KL_IMPLEMENTATIONS[reverse_kl_impl]( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + teacher_softmax_temperature=teacher_softmax_temperature, + group=group, + total_valid_tokens=total_valid_tokens, + ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 24c06d5cc..b1f3564b9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -11,7 +11,13 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig +from fast_llm.functional.config import ( + CrossEntropyImpl, + DistillationLossImpl, + ReverseKLImpl, + TargetFormat, + TritonConfig, +) from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward @@ -313,12 +319,13 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) - if loss_count != 1: - loss.div_(loss_count) - if self._sequence_parallel_logits: - # TODO: Async - all_reduce(loss, group=self._tensor_space.distributed.tensor_group) + assert self._cross_entropy_splits is None, "This is not supported for now" + # loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) + # if loss_count != 1: + # loss.div_(loss_count) + # if self._sequence_parallel_logits: + # # TODO: Async + # all_reduce(loss, group=self._tensor_space.distributed.tensor_group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -412,6 +419,29 @@ def _logits_cross_entropy_forward_backward( if distillation_target is not None and self._distillation_loss_factor > 0.0: if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: + local_valid_tokens = total_valid_tokens = logits.shape[0] + if logits.shape[-1] != self._config.vocab_size: + reverse_kl_impl = ReverseKLImpl.tp + assert loss_mask is None, "Loss mask is not implemented for TP (vocab dim) reverse KL yet" + elif self._sequence_parallel_logits: + # grad_output already reflects scaling 1/ number of ranks (group_size), see _forward_backward + reverse_kl_impl = ReverseKLImpl.stp + if loss_mask is not None: + local_valid_tokens = loss_mask.sum() + total_valid_tokens = local_valid_tokens.clone() + all_reduce( + total_valid_tokens, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group + ) + else: + local_valid_tokens = logits.shape[0] + total_valid_tokens = local_valid_tokens * self._group_size + # in the loss function we compute grads w.r.t sum of losses, + # so we need to multiply back by the group size and divide by the number of valid tokens to get the correct scaling + # note, the function returns the sum of local losses, so we need to handle this properly for reporting + grad_output *= self._group_size / 
total_valid_tokens # multiply back by the group size + else: + reverse_kl_impl = ReverseKLImpl.no_tp + distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -423,8 +453,14 @@ def _logits_cross_entropy_forward_backward( target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), - vocab_parallel=logits.shape[-1] != self._config.vocab_size, + reverse_kl_impl=reverse_kl_impl, + total_valid_tokens=total_valid_tokens, ) + if self._sequence_parallel_logits: + # distillation_loss is local sum, so we need to divide by the number of valid tokens to get the correct scaling + all_reduce(distillation_loss, op=ReduceOp.SUM, group=self._tensor_space.distributed.tensor_group) + distillation_loss /= total_valid_tokens # final global loss + elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6261821ab..da07e5291 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -143,11 +143,13 @@ def preprocess_meta( micro_batch_size = batch_meta.micro_batch_size sequence_length = batch_meta.sequence_length micro_sequence_length = batch_meta.micro_sequence_length + truncate_documents = batch_meta.truncate_documents else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length + truncate_documents = True if self._config.vision_encoder.enabled: try: @@ -241,6 +243,7 @@ def preprocess_meta( TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, TransformerKwargs.micro_batch_size: micro_batch_size, + LanguageModelKwargs.mask_inputs: not truncate_documents, } common_kwargs.update(vision_kwargs) From eb4ad2f4047c21275a5c2ed35cb983817282288d Mon Sep 17 00:00:00 2001 From: Oleksiy Ostapenko Date: Tue, 19 Aug 2025 18:15:13 +0200 Subject: [PATCH 156/161] varlen maba (#352) --- Dockerfile | 3 +- fast_llm/layers/ssm/config.py | 20 ++ fast_llm/layers/ssm/mamba2.py | 74 ++++++-- fast_llm/layers/ssm/preprocessing.py | 68 +++++++ fast_llm/models/ssm/config.py | 5 + fast_llm/models/ssm/model.py | 2 + setup.cfg | 2 +- tests/test_ssms.py | 271 ++++++++++++++++++++++++++- 8 files changed, 428 insertions(+), 17 deletions(-) create mode 100644 fast_llm/layers/ssm/preprocessing.py diff --git a/Dockerfile b/Dockerfile index 0f3c2d8cb..7cf951017 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,8 +29,9 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) +# Using varlen_mamba for variable length sequence support RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" +RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba" # Copy dependency files with universal write permissions for all users. 
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 3b21ca698..194063a26 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -12,6 +12,26 @@ from fast_llm.tensor import Initializer +class BaseSSMKwargs: + _kwargs_attributes = { + "cu_seqlens": "cu_seqlens", + "seq_idx": "seq_idx", + "ssm_position_ids": "ssm_position_ids", + } + + _prefix = "" + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseSSMKwargs._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value) + + +class SSMKwargs(BaseSSMKwargs, prefix=""): + pass + + class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. state = "ssm_state" # State dimension (N), aka head size / num channels diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 77c1b3869..ff96c5ce8 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,4 @@ +import inspect import logging import typing @@ -6,17 +7,28 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames, SSMKwargs from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale +_mamba_varlen = False try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa _mamba_available = True + sig = inspect.signature(selective_scan_fn) + if "position_indices" in sig.parameters: + _mamba_varlen = True + logging.warning("Using selective_scan_fn from varlen_mamba that supports packing") + else: + _mamba_varlen = False + logging.warning("Using selective_scan_fn from original mamba without packing support") + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + except (ImportError, RuntimeError): _mamba_available = False @@ -143,8 +155,16 @@ def __init__( ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Note, we are nto doing "read" sequence-tensor parallel trainign here, since inner_projection is gathered over all GPUS. + This is also desired, since the currently used mamba kernel does not support STP. + TODO: use correct kernel from Mamba2! 
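+        When document packing is enabled, SSMKwargs.cu_seqlens, SSMKwargs.seq_idx and
+        SSMKwargs.ssm_position_ids are expected to be filled in by Mamba2Preprocessor
+        (fast_llm/layers/ssm/preprocessing.py); when they are None, the regular unpacked
+        convolution and scan path is used.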
+ """ assert _mamba_available assert _causal_conv1d_available + cu_seqlens = kwargs[SSMKwargs.cu_seqlens] + seq_idx = kwargs[SSMKwargs.seq_idx] + position_indices = kwargs[SSMKwargs.ssm_position_ids] # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) # -> (batch/sequence, sequence/batch, inner_projection) @@ -174,9 +194,20 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") + + if cu_seqlens is not None: + # from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152 + x = _causal_conv1d_fn( + x=x.transpose(1, 2).contiguous().transpose(1, 2), + weight=self.conv1d_weight.squeeze(1), + bias=self.conv1d_bias, + seq_idx=seq_idx, + activation="silu", + ) else: x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") + + if not self._config.repeat_kv_before_conv: x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) @@ -203,17 +234,34 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ self._debug_log(c, "c", self._BC_DIMS, kwargs) self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) - y = selective_scan_fn( - x, - dt, - -torch.exp(self.A_log.float()), - b, - c, - self.D.float(), - z, - delta_bias=self.dt_proj_bias.float(), - delta_softplus=True, - ) + if not _mamba_varlen: + Assert.eq(cu_seqlens, None, msg="This version of Mamba2 does not support cu_seqlens, install verlen mamba") + y = selective_scan_fn( + x, + dt, + -torch.exp(self.A_log.float()), + b, + c, + self.D.float(), + z, + delta_bias=self.dt_proj_bias.float(), + delta_softplus=True, + ) + else: + position_indices = position_indices if cu_seqlens is not None else None + + y = selective_scan_fn( + x, + dt, + -torch.exp(self.A_log.float()), + b, + c, + self.D.float(), + z, + delta_bias=self.dt_proj_bias.float(), + delta_softplus=True, + position_indices=position_indices, + ) if self._debug_level: self._debug_log(y, "y", self._XZ_DIMS, kwargs) diff --git a/fast_llm/layers/ssm/preprocessing.py b/fast_llm/layers/ssm/preprocessing.py new file mode 100644 index 000000000..343f0bb28 --- /dev/null +++ b/fast_llm/layers/ssm/preprocessing.py @@ -0,0 +1,68 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.ssm.config import SSMKwargs +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +class Mamba2Preprocessor(Preprocessor): + def __init__(self, config: HybridSSMBaseModelConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + self._transformer_dim_names = config.transformer._transformer_dim_names + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + """ + Simplified preprocessor that does not take into account micro-sequences. 
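+        For example (illustrative values), a micro-batch with sequence_lengths = [[3, 2]],
+        i.e. one sample packing two documents of lengths 3 and 2, produces
+        cu_seqlens = [0, 3, 5], seq_idx = [[0, 0, 0, 1, 1]] and
+        ssm_position_ids = [[0, 1, 2, 0, 1]].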
+ """ + if TransformerKwargs.sequence_lengths not in kwargs: + return + sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] + if TransformerKwargs.cu_seqlens_k in kwargs: + # already set this in the transformer preprocessor, so we can use it here + cu_seqlens_k = kwargs[TransformerKwargs.cu_seqlens_k] + cu_seqlens_q = kwargs[TransformerKwargs.cu_seqlens_q] + Assert.eq( + cu_seqlens_k.shape[0], + cu_seqlens_q.shape[0], + msg="cu_seqlens_k and cu_seqlens_q have different lengths, is micro_sequence_length being used? This is currently not supported for Mamba.", + ) + Assert.all_equal(cu_seqlens_k, cu_seqlens_q) + cu_seqlens = cu_seqlens_k + else: + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), + torch.cumsum(seqlens, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + ) + ) + kwargs[SSMKwargs.cu_seqlens] = cu_seqlens + # from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152 + kwargs[SSMKwargs.seq_idx] = torch.cat( + [ + torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1]) + ], + dim=0, + ).unsqueeze(0) + + sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + position_ids = torch.stack( + [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + ).to(self._tensor_space.distributed.device, dtype=torch.int64) + position_ids = position_ids[ + :, sequence_k - sequence_q : sequence_k + ] # this is only needed if we do micro-sequences? + kwargs[SSMKwargs.ssm_position_ids] = position_ids.to(torch.int32) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 5dca41a70..34f3151a6 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -207,6 +207,11 @@ def get_trainer_class(cls) -> type["HybridSSMTrainer"]: def _validate(self) -> None: super()._validate() + Assert.eq( + self.batch.micro_sequence_length, + self.batch.sequence_length, + msg="Micro-sequences not supported for SSMs. at htis point", + ) if (name := self.model.base_model.distillation_model) is None: Assert.empty(self.reference_models) else: diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 29f115bd9..fafe44090 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -6,6 +6,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.ssm.preprocessing import Mamba2Preprocessor from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel @@ -30,6 +31,7 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) + self._preprocessors.append(Mamba2Preprocessor(config, self._tensor_space)) def get_output_layers(self) -> list[Layer]: """ diff --git a/setup.cfg b/setup.cfg index 6ea98610c..c2eb1f6f2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ HUGGINGFACE = # To install on cpu environment (ex. 
for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]==2.2.4 + mamba_ssm[causal-conv1d] @ git+https://github.com/jxiw/varlen_mamba.git@varlen_mamba cartesia_pytorch>=0.0.2 # GENERATION = diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 694faa55b..2a338f1ba 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -1,19 +1,60 @@ +import inspect +import itertools import pathlib +from functools import partial import pytest import torch +from mamba2 import Mamba2 from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import CheckpointLoadConfig +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel +_mamba_varlen = False +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa + + _mamba_available = True + sig = inspect.signature(selective_scan_fn) + if "position_indices" in sig.parameters: + _mamba_varlen = True + else: + _mamba_varlen = False + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + +except (ImportError, RuntimeError): + _mamba_available = False + + +def get_hybrid_config(hybrid_block_layout=["t", "m2"], prediction_heads=1, default_mtp_type=None): + hidden_size = 512 + config = HybridSSMBaseModelConfig( + transformer=TransformerConfig(num_layers=len(hybrid_block_layout), hidden_size=hidden_size), + ssm=SSMConfig(d_xb=hidden_size, dt_rank=10, d_inner=hidden_size * 2), + hybrid_block_layout=hybrid_block_layout, + prediction_heads=prediction_heads, + default_mtp_type=default_mtp_type, + init_method_std_embed=0.02, + init_method_min_embed=-0.02, + init_method_max_embed=0.02, + use_position_embeddings=True, + tie_word_embeddings=False, + ) + return config + @pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") @pytest.mark.slow @@ -80,3 +121,229 @@ def test_load_from_llamba_checkpoint(): logits = input_data[0][1]["logits"].cpu() assert torch.allclose(logits, hf_logits, atol=1e-2) + + +@pytest.fixture +def distributed_config(): + return DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + + +@pytest.fixture +def distributed(distributed_config): + return Distributed(config=distributed_config) + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, 
"tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def unpack(packed_hidden_states, cu_seqlens): + batch_size = packed_hidden_states.shape[0] + package_num = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros( + package_num * batch_size, + seq_len, + hidden_dim, + dtype=packed_hidden_states.dtype, + device=packed_hidden_states.device, + ) + for j in range(batch_size): + for i in range(package_num): + line = j * package_num + i + hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ + j, cu_seqlens[i] : cu_seqlens[i + 1], : + ] + return hidden_states + + +def pack(hidden_states, cu_seqlens, batch_size): + package_num, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) + packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) + return packed_hidden_states + + +def generate_random_cu_seqlens(seq_len, packages_num=2): + if packages_num < 1: + raise ValueError("packages_num must be at least 1") + + # base size of each chunk, and how many get an extra token + base, rem = divmod(seq_len, packages_num) + # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] + lengths = [base + 1 if i < rem else base for i in range(packages_num)] + + # split points exclude the final cumulative (seq_len) + split_points = list(itertools.accumulate(lengths))[:-1] + + # cu_seqlens = [0] + split_points + [seq_len] + cu_seqlens = [0] + split_points + [seq_len] + + # index: for each chunk, we emit 0,1,...,length-1 + index = [] + for length in lengths: + index.extend(range(length)) + + # sanity check + assert len(cu_seqlens) - 1 == packages_num + assert sum(lengths) == seq_len + assert len(index) == seq_len + + return cu_seqlens, index + + +# Quick and dirty test for Mamba2 varlen block from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/tests/pack_mamba/test_mamba_layer.py +# TODO: integrate in the testing framework +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") +@pytest.mark.skipif(not _mamba_available, reason="Mamba2 is not available") +@pytest.mark.skipif(not _mamba_varlen, reason="Mamba2 varlen is not available") +def test_mamba_varlen_block(distributed_config, distributed): + """ + Compare that the output and grads of packed and unpacked Mamba2 varlen block are the same. 
+ """ + hybrid_config = get_hybrid_config(hybrid_block_layout=["m2", "t"]) + tensor_space = TensorSpace(distributed_config=distributed_config) + tensor_space.setup(distributed) + hybrid_config.setup_tensor_space(tensor_space) + layer_idx = 0 + + mixer_cls = partial(Mamba2, block_index=layer_idx) + block_packed = SSMBlock( + hybrid_config.transformer, + hybrid_config.ssm, + mixer_cls=mixer_cls, + tensor_space=tensor_space, + block_index=layer_idx, + ) + block_ref = SSMBlock( + hybrid_config.transformer, + hybrid_config.ssm, + mixer_cls=mixer_cls, + tensor_space=tensor_space, + block_index=layer_idx, + ) + device = "cuda" + materialize_meta_tensors(block_packed, tensor_space) + materialize_meta_tensors(block_ref, tensor_space) + block_ref.load_state_dict(block_packed.state_dict()) + block_packed.to(device) + block_ref.to(device) + + batch_size = 2 + seq_len = 64 + packages_num = 2 + hidden_dim = hybrid_config.transformer.hidden_size + + cu_seqlens, index = generate_random_cu_seqlens(seq_len, packages_num=packages_num) + cu_seqlens = torch.tensor(cu_seqlens).cuda() + ssm_position_ids = torch.tensor(index, dtype=torch.int32).unsqueeze(0).expand(batch_size, -1).contiguous().cuda() + seq_idx = ( + torch.cat( + [ + torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1]) + ], + dim=0, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + + # Generate packed_hidden_states with random values for testing + hidden_states_list = [ + torch.randn(l, hidden_dim, device=device, dtype=torch.bfloat16, requires_grad=True) + for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + ] + packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0) + packed_hidden_states = packed_hidden_states.expand(batch_size, -1, -1).contiguous() + # hidden_states should be forwarded without cu_seqlens + hidden_states = unpack(packed_hidden_states, cu_seqlens) + + # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states + assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1] + # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states + assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] + + output_states_packed = block_packed( + packed_hidden_states, + {"cu_seqlens": cu_seqlens, "seq_idx": seq_idx, "ssm_position_ids": ssm_position_ids, "sequence_first": False}, + ) + output_states_unpacked = block_ref( + hidden_states.clone(), {"cu_seqlens": None, "seq_idx": None, "ssm_position_ids": None, "sequence_first": False} + ) + tollerance = 1e-4 + assert output_states_packed.shape == packed_hidden_states.shape + assert output_states_unpacked.shape == hidden_states.shape + assert not torch.isnan(hidden_states).any() + assert not torch.isinf(hidden_states).any() + + output_states_unpacked = pack(output_states_unpacked, cu_seqlens, batch_size) + torch.allclose(output_states_packed, output_states_unpacked, atol=tollerance) + + loss = output_states_packed.sum() + loss.backward() + loss_ref = output_states_unpacked.sum() + loss_ref.backward() + assert torch.allclose(block_packed.mixer.conv1d_weight.grad, block_ref.mixer.conv1d_weight.grad, atol=tollerance) + assert torch.allclose(block_packed.mixer.conv1d_bias.grad, block_ref.mixer.conv1d_bias.grad, atol=tollerance) + assert torch.allclose( + block_packed.mixer.in_proj.weight.grad_buffer, block_ref.mixer.in_proj.weight.grad_buffer, atol=tollerance + ) + assert 
torch.allclose( + block_packed.mixer.out_proj.weight.grad_buffer, block_ref.mixer.out_proj.weight.grad_buffer, atol=tollerance + ) + assert torch.allclose( + block_packed.mixer.dt_in_proj.weight.grad_buffer, + block_ref.mixer.dt_in_proj.weight.grad_buffer, + atol=tollerance, + ) + + assert torch.allclose( + block_packed.mlp.layer_1.weight.grad_buffer, block_ref.mlp.layer_1.weight.grad_buffer, atol=tollerance + ) + assert torch.allclose( + block_packed.mlp.layer_1.bias.grad_buffer, block_ref.mlp.layer_1.bias.grad_buffer, atol=tollerance + ) + assert torch.allclose( + block_packed.mlp.layer_2.weight.grad_buffer, block_ref.mlp.layer_2.weight.grad_buffer, atol=tollerance + ) + assert torch.allclose( + block_packed.mlp.layer_2.bias.grad_buffer, block_ref.mlp.layer_2.bias.grad_buffer, atol=tollerance + ) + + +if __name__ == "__main__": + pytest.main([__file__]) From 085991a874d1968029b63085c4e80551cc453d66 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 20 Aug 2025 13:41:21 +0000 Subject: [PATCH 157/161] lr scale --- fast_llm/layers/ssm/mamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index ff96c5ce8..d09ae9a67 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -151,7 +151,7 @@ def __init__( bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), sequence_parallel=self._sequence_parallel, - # TODO: lr_scale? + lr_scale=lr_scale, ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: From d3e3aecb3a98a94cff13381dbc1ef3135637292a Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 20 Aug 2025 19:25:44 +0000 Subject: [PATCH 158/161] fixes --- fast_llm/layers/ssm/mamba2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index d09ae9a67..5ed689a73 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -162,9 +162,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ """ assert _mamba_available assert _causal_conv1d_available - cu_seqlens = kwargs[SSMKwargs.cu_seqlens] - seq_idx = kwargs[SSMKwargs.seq_idx] - position_indices = kwargs[SSMKwargs.ssm_position_ids] + cu_seqlens = kwargs.get(SSMKwargs.cu_seqlens) + seq_idx = kwargs.get(SSMKwargs.seq_idx) + position_indices = kwargs.get(SSMKwargs.ssm_position_ids) # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) # -> (batch/sequence, sequence/batch, inner_projection) From 518ae8d087e8c09ee8e0166bc37625160c425329 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 27 Aug 2025 13:43:22 +0000 Subject: [PATCH 159/161] hybrid checkpoint creation script --- ...brid_checkpoint_with_importance_15b_mil.py | 176 ------------------ .../make_hybrid_checkpoint_with_mil.py | 104 +++++++++++ 2 files changed, 104 insertions(+), 176 deletions(-) delete mode 100644 fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py create mode 100644 fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py deleted file mode 100644 index dde11cfbc..000000000 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py +++ /dev/null @@ -1,176 +0,0 @@ -import click -import torch -import 
transformers -from transformers import AutoConfig, AutoModelForCausalLM - -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig -from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( - AprielSSMM2DecoderLayer, - AprielThinkerSSMHybridForCausalLM, -) - -device = "cuda" if torch.cuda.is_available() else "cpu" - -print("Transformers version:", transformers.__version__) - - -def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, torch_dtype): - - for layer_idx, type in enumerate(hybrid_block_layout): - # print("Converting layer %d...", layer_idx) - # Fetch the layer module for easier access - layer_module = transformer.model.layers._modules[f"{layer_idx}"] - if type == "t": - print("Skipping transformer layer %d..." % layer_idx) - elif type == "m2": - print("Converting layer %d to Mamba2 with MIL init..." % layer_idx) - # Use MambaDecoderLayer for the remaining layers - mamba_encoder = AprielSSMM2DecoderLayer( - mamba_config, - layer_idx, - device="cpu", - dtype=torch_dtype, - ) - - mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict()) - mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict()) - mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict()) - mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict()) - - if init_with_kqvo: - # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], : - ].copy_(layer_module.self_attn.v_proj.weight.data) - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] - + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"], - :, - ].copy_(layer_module.self_attn.k_proj.weight.data) - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"], - :, - ].copy_(layer_module.self_attn.q_proj.weight.data) - - print("Init Mamba using Attention") - - transformer.model.layers[layer_idx] = mamba_encoder - - elif type == "m2d": - raise NotImplementedError("Discrete Mamba2 not implemented") - else: - raise ValueError(f"Invalid layer type: {type}") - - -@click.command() -@click.option("--index_to_swap", type=int, required=True) -@click.option("--checkpoint", type=str, required=True) -@click.option("--output_model_path", type=str, required=True) -@click.option("--layer_type", type=str, default="m2") -@click.option("--mil_init", type=bool, default=True) -def main( - index_to_swap: int, - checkpoint=None, - output_model_path="/mnt/checkpoints/ssm/iterative_hybrids_15b_rkl_m2/apriel_ssm_thinker_15b_hybrid", - layer_type="m2", - mil_init=True, -): - print(f"index_to_swap: {index_to_swap}, checkpoint: {checkpoint}") - - layer_importance = [ - 47, - 39, - 24, - 36, - 31, - 43, - 32, - 20, - 38, - 37, - 30, - 33, - 22, - 23, - 40, - 42, - 44, - 35, - 41, - 27, - 21, - 46, - 45, - 49, - 25, - 34, - 29, - 28, - 19, - 26, - 18, - 17, - 16, - 13, - 15, - 14, - 8, - 9, - 12, - 6, - 11, - 5, - 48, - 7, - 10, - 3, - 4, - 1, - 0, - ] - path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" - config_base = AutoConfig.from_pretrained(path_base) - hybrid_block_layout = ["t"] * 
config_base.num_hidden_layers - - for i in range(index_to_swap + 1): - layer_idx = int(layer_importance[i]) - print(f"Swapping layer {layer_idx} to {layer_type}") - hybrid_block_layout[layer_idx] = layer_type - - transformer = AutoModelForCausalLM.from_pretrained(path_base) - model_hybrid_prev = AprielThinkerSSMHybridForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).to( - torch.bfloat16 - ) - config_hybrid = AprielSSMHybridConfig(**model_hybrid_prev.config.to_dict()) - config_hybrid.hybrid_block_layout = hybrid_block_layout - convert_layers(transformer, config_hybrid, hybrid_block_layout, mil_init, torch.bfloat16) - - missing, unexpected = transformer.load_state_dict( - model_hybrid_prev.state_dict(), strict=False - ) # will not load the newly innitialized layer (will stay MIL), but will overwrite previous layers - if missing: - print("Missing keys:", missing) - if unexpected: - print("Unexpected keys:", unexpected) - transformer.to(torch.bfloat16) - model_hybrid_prev = None - print(transformer) - model_hybrid = AprielThinkerSSMHybridForCausalLM(config_hybrid) - missing, unexpected = model_hybrid.load_state_dict(transformer.state_dict()) - assert len(missing) == 0, "Missing keys: " + str(missing) - assert len(unexpected) == 0, "Unexpected keys: " + str(unexpected) - - model_hybrid.save_pretrained(f"{output_model_path}") - # config_hybrid.save_pretrained(f"{output_model_path}") - - -if __name__ == "__main__": - main() - # main( - # index_to_swap=1, - # checkpoint="/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-ihyb1lrklm216mil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000", - # layer_type="m2", - # ) diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py new file mode 100644 index 000000000..d50a45fa3 --- /dev/null +++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py @@ -0,0 +1,104 @@ +import gc + +import click +import torch +from transformers import AutoModelForCausalLM + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMM2DecoderLayer, + AprielThinkerSSMHybridForCausalLM, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, torch_dtype=torch.bfloat16): + config = transformer.config + embed_dim = config.hidden_size + num_heads = config.num_attention_heads + num_heads_kv = config.num_key_value_heads + head_dim = embed_dim // num_heads + head_dim * num_heads + head_dim * num_heads_kv + + for layer_idx, type in enumerate(hybrid_block_layout): + print("Converting layer %d...", layer_idx) + # Fetch the layer module for easier access + layer_module = transformer.model.layers._modules[f"{layer_idx}"] + if type == "t": + print("Skipping transformer layer %d..." % layer_idx) + elif type == "m2": + print("Converting layer %d..." 
% layer_idx) + # Use MambaDecoderLayer for the remaining layers + mamba_encoder = AprielSSMM2DecoderLayer( + mamba_config, + layer_idx, + device="cpu", + dtype=torch_dtype, + ) + + mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict()) + mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict()) + mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict()) + mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict()) + + if init_with_kqvo: + # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], : + ].copy_(layer_module.self_attn.v_proj.weight.data) + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] + + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"], + :, + ].copy_(layer_module.self_attn.k_proj.weight.data) + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"], + :, + ].copy_(layer_module.self_attn.q_proj.weight.data) + + print("Init Mamba using Attention") + + transformer.model.layers[layer_idx] = mamba_encoder + + else: + raise ValueError(f"Invalid layer type: {type}") + + +@click.command() +@click.option("--m2_index", type=int, required=True) +@click.option("--hybrid_checkpoint", type=str, required=True) +@click.option("--save_dir", type=str, required=True) +def main(m2_index: int, hybrid_checkpoint: str, save_dir: str): + path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" + transformer = AutoModelForCausalLM.from_pretrained(path_base, trust_remote_code=True) + hybrid_config = AprielSSMHybridConfig.from_pretrained(hybrid_checkpoint) + + hybrid_block_layout = hybrid_config.hybrid_block_layout + hybrid_block_layout[m2_index] = "m2" + print(hybrid_block_layout) + + convert_layers(transformer, hybrid_config, hybrid_block_layout, True, torch.bfloat16) + hybrid_config.ssm_cfg["activation"] = "silu" + + # load all existing ssm layers + hybrid_model = AprielThinkerSSMHybridForCausalLM.from_pretrained(hybrid_checkpoint) + state_dict = hybrid_model.state_dict() + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + assert f"model.layers.{m2_index}.mixer.A_log" in missing + assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected + print(missing) + print(unexpected) + transformer.save_pretrained(save_dir) + + hybrid_config.save_pretrained(save_dir) + + gc.collect() + + +if __name__ == "__main__": + main() From 4e860cc7e5698e0262323b23589f4a5c24fa27b4 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 27 Aug 2025 14:49:18 +0000 Subject: [PATCH 160/161] make hybrid checkpoint script --- .../ssm/external/make_hybrid_checkpoint_with_mil.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py index d50a45fa3..6ce283525 100644 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py +++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py @@ -70,16 +70,18 @@ def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqv @click.command() -@click.option("--m2_index", type=int, 
required=True) +@click.option("--m2_indexes", type=int, nargs="-1", required=True) @click.option("--hybrid_checkpoint", type=str, required=True) @click.option("--save_dir", type=str, required=True) -def main(m2_index: int, hybrid_checkpoint: str, save_dir: str): +def main(m2_indexes: list, hybrid_checkpoint: str, save_dir: str): + m2_indexes = list(m2_indexes) # convert tuple -> list path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" transformer = AutoModelForCausalLM.from_pretrained(path_base, trust_remote_code=True) hybrid_config = AprielSSMHybridConfig.from_pretrained(hybrid_checkpoint) hybrid_block_layout = hybrid_config.hybrid_block_layout - hybrid_block_layout[m2_index] = "m2" + for m2_index in m2_indexes: + hybrid_block_layout[m2_index] = "m2" print(hybrid_block_layout) convert_layers(transformer, hybrid_config, hybrid_block_layout, True, torch.bfloat16) @@ -89,8 +91,9 @@ def main(m2_index: int, hybrid_checkpoint: str, save_dir: str): hybrid_model = AprielThinkerSSMHybridForCausalLM.from_pretrained(hybrid_checkpoint) state_dict = hybrid_model.state_dict() missing, unexpected = transformer.load_state_dict(state_dict, strict=False) - assert f"model.layers.{m2_index}.mixer.A_log" in missing - assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected + for m2_index in m2_indexes: + assert f"model.layers.{m2_index}.mixer.A_log" in missing + assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected print(missing) print(unexpected) transformer.save_pretrained(save_dir) From cd1df188689a594d717a8ae68b51af0fd5f5eebc Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Wed, 17 Sep 2025 16:59:49 +0200 Subject: [PATCH 161/161] Multimodal-SSM fixes and utils (#357) --- fast_llm/data/dataset/gpt/memmap.py | 4 +- fast_llm/functional/cross_entropy.py | 2 +- fast_llm/layers/vision_encoder/adapter.py | 2 + .../llava_hybrid/modeling_llava_hybrid.py | 3 +- .../ssm/external/make_hybrid_checkpoint.py | 163 ++++++++++++++++++ .../external/make_llava_hybrid_checkpoint.py | 153 ++++++++++++++++ 6 files changed, 324 insertions(+), 3 deletions(-) create mode 100644 fast_llm/models/ssm/external/make_hybrid_checkpoint.py create mode 100644 fast_llm/models/ssm/external/make_llava_hybrid_checkpoint.py diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 493361f32..4f62561a8 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -65,7 +65,9 @@ def _init( offset = stream.tell() if num_documents is not None: - assert self._num_documents == num_documents + assert ( + self._num_documents == num_documents + ), f"Inconsistent num_documents for dataset {self.name} - {self._prefix}. Expected {num_documents}, got {self._num_documents}." 
self._index_bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".idx"), mode="r", order="C") self._index_bin_buffer = memoryview(self._index_bin_buffer_mmap) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 1be4ed82b..d9ca547a7 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -160,7 +160,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss[loss_mask] + per_sample_loss = per_sample_loss * loss_mask loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index a59c6226f..7ec50dfee 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -26,6 +26,7 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): bias=True, weight_init_method=init_normal_(std=config.adapter_init_method_std), bias_init_method=init_normal_(std=config.adapter_init_method_std), + lr_scale=config.adapter_lr_scale, ) self.layer_2 = Linear( tensor_space[VisionEncoderDimNames.adapter_size], @@ -33,6 +34,7 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): bias=True, weight_init_method=init_normal_(std=config.adapter_init_method_std), bias_init_method=init_normal_(std=config.adapter_init_method_std), + lr_scale=config.adapter_lr_scale, ) def forward( diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index b056d3a00..68073f9cd 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -76,6 +76,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + pixel_values=None, **kwargs, ): # Copy of the method from `AprielThinkerSSMHybridForCausalLM` @@ -95,7 +96,7 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, cache_position] else: past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], self.dtype, device=self.device + self.config.text_config, input_ids.shape[0], self.dtype, device=self.device ) if attention_mask is not None and position_ids is None: diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py new file mode 100644 index 000000000..8a21c906f --- /dev/null +++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py @@ -0,0 +1,163 @@ +import gc + +import click +import torch +from transformers import AutoModelForCausalLM + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMM2DecoderLayer, + AprielThinkerSSMHybridForCausalLM, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" + +dstate = 16 +expand = 1 +# Calculate derived dimensions for the Mamba1 configuration +# d_model = config_base.text_config.hidden_size +d_inner = 4096 # hard code to match thinker #expand * d_model +d_xb = 1024 # hard code to match thinker #config_thinker.num_key_value_heads * (config_thinker.hidden_size // config_thinker.num_attention_heads) + + +def convert_layers( + transformer_config, + 
transformer_model,
+    mamba_config,
+    hybrid_block_layout,
+    init_with_kqvo,
+    torch_dtype=torch.bfloat16,
+):
+    config = transformer_config
+    embed_dim = config.hidden_size
+    num_heads = config.num_attention_heads
+    num_heads_kv = config.num_key_value_heads
+    head_dim = embed_dim // num_heads
+    head_dim * num_heads
+    head_dim * num_heads_kv
+
+    for layer_idx, type in enumerate(hybrid_block_layout):
+        print("Processing layer %d..." % layer_idx)
+        # Fetch the layer module for easier access
+        layer_module = transformer_model.layers._modules[f"{layer_idx}"]
+        if type == "t":
+            print("Skipping transformer layer %d..." % layer_idx)
+        elif type == "m2":
+            print("Converting layer %d..." % layer_idx)
+            # Use MambaDecoderLayer for the remaining layers
+            mamba_encoder = AprielSSMM2DecoderLayer(
+                mamba_config,
+                layer_idx,
+                device="cpu",
+                dtype=torch_dtype,
+            )
+
+            mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())
+            mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())
+            mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())
+            mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())
+
+            if init_with_kqvo:
+                # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q
+                mamba_encoder.mixer.in_proj.weight.data[
+                    mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], :
+                ].copy_(layer_module.self_attn.v_proj.weight.data)
+                mamba_encoder.mixer.in_proj.weight.data[
+                    mamba_config.ssm_cfg["d_inner"]
+                    + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"]
+                    + 2 * mamba_config.ssm_cfg["d_xb"],
+                    :,
+                ].copy_(layer_module.self_attn.k_proj.weight.data)
+                mamba_encoder.mixer.in_proj.weight.data[
+                    mamba_config.ssm_cfg["d_inner"]
+                    + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"]
+                    + 2 * mamba_config.ssm_cfg["d_xb"],
+                    :,
+                ].copy_(layer_module.self_attn.q_proj.weight.data)
+
+                print("Init Mamba using Attention")
+
+            transformer_model.layers[layer_idx] = mamba_encoder
+
+        else:
+            raise ValueError(f"Invalid layer type: {type}")
+
+
+def make_hybrid_config(transformer):
+    config_dict = transformer.config.to_dict()
+    config_dict["hybrid_block_layout"] = ["t"] * transformer.config.num_hidden_layers
+    config_dict["model_type"] = "apriel_ssm_thinker_hybrid"
+    config_dict["ssm_cfg"] = {
+        "activation": "silu",
+        "d_state": dstate,
+        "d_xb": d_xb,
+        "expand": expand,
+        "d_conv": 4,
+        "d_inner": d_inner,
+        "conv_bias": True,
+        "bias": False,
+    }
+    hybrid_config = AprielSSMHybridConfig.from_dict(**config_dict)
+    return hybrid_config
+
+
+@click.command()
+@click.option(
+    "--base_checkpoint", type=str, required=False, default="/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker"
+)
+@click.option("--m2_indices", type=int, multiple=True, required=True)
+@click.option("--hybrid_checkpoint", type=str, required=True)
+@click.option("--save_dir", type=str, required=True)
+def main(base_checkpoint: str, m2_indices: list, hybrid_checkpoint: str, save_dir: str):
+    """
+    base_checkpoint: path to base transformer-model (teacher model)
+    m2_indices: indices of layers to convert to mamba layers with MiL init
+    hybrid_checkpoint: path to hybrid model (student model).
+    save_dir: directory to save the converted model.
+
+    TODO: base_checkpoint can actually be a hybrid.
Rename transformer variable to a better name + """ + m2_indices = list(m2_indices) # convert tuple -> list + transformer = AutoModelForCausalLM.from_pretrained(base_checkpoint, trust_remote_code=True) + if hybrid_checkpoint == "none": + print("No hybrid checkpoint provided, creating new config from base model.") + hybrid_config = make_hybrid_config(transformer) + else: + hybrid_config = AprielSSMHybridConfig.from_pretrained(hybrid_checkpoint) + + hybrid_block_layout = hybrid_config.hybrid_block_layout + for m2_index in m2_indices: + hybrid_block_layout[m2_index] = "m2" + print(hybrid_block_layout) + + convert_layers( + transformer.config, + transformer.model, + hybrid_config, + hybrid_block_layout, + init_with_kqvo=True, + torch_dtype=torch.bfloat16, + ) + hybrid_config.ssm_cfg["activation"] = "silu" + + # load all existing ssm layers + if hybrid_checkpoint != "none": + hybrid_model = AprielThinkerSSMHybridForCausalLM.from_pretrained(hybrid_checkpoint) + state_dict = hybrid_model.state_dict() + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + for m2_index in m2_indices: + assert f"model.layers.{m2_index}.mixer.A_log" in missing + assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected + print("MISSING", missing) + print("UNEXPECTED", unexpected) + + # Save state-dict + transformer.save_pretrained(save_dir) + + hybrid_config.save_pretrained(save_dir) + + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/fast_llm/models/ssm/external/make_llava_hybrid_checkpoint.py b/fast_llm/models/ssm/external/make_llava_hybrid_checkpoint.py new file mode 100644 index 000000000..1f9808f1b --- /dev/null +++ b/fast_llm/models/ssm/external/make_llava_hybrid_checkpoint.py @@ -0,0 +1,153 @@ +import gc +import json +import shutil + +import click +import torch +from transformers import AutoModelForVision2Seq + +from fast_llm.models.ssm.external.apriel_15b_hybrid import modeling_ssm_hybrid_apriel15b +from fast_llm.models.ssm.external.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid +from fast_llm.models.ssm.external.llava_hybrid.configuration_llava_hybrid import LlavaHybridConfig +from fast_llm.models.ssm.external.llava_hybrid.modeling_llava_hybrid import LlavaHybridForConditionalGeneration +from fast_llm.models.ssm.external.make_hybrid_checkpoint import convert_layers + +device = "cuda" if torch.cuda.is_available() else "cpu" + +dstate = 16 +expand = 1 +# Calculate derived dimensions for the Mamba1 configuration +# d_model = config_base.text_config.hidden_size +d_inner = 4096 # hard code to match thinker #expand * d_model +d_xb = 1024 # hard code to match thinker #config_thinker.num_key_value_heads * (config_thinker.hidden_size // config_thinker.num_attention_heads) + + +def make_hybrid_llava_config(transformer): + config_dict = transformer.config.to_dict() + config_dict["text_config"]["hybrid_block_layout"] = ["t"] * transformer.config.text_config.num_hidden_layers + config_dict["text_config"]["model_type"] = "apriel_ssm_thinker_hybrid" + config_dict["text_config"]["ssm_cfg"] = { + "activation": "silu", + "d_state": dstate, + "d_xb": d_xb, + # "d_model": d_model, # will be set automatically + "expand": expand, + "d_conv": 4, + "d_inner": d_inner, # will be same as d_model * expand, + "conv_bias": True, + "bias": False, + } + llava_hybrid_config = LlavaHybridConfig(**config_dict) + return llava_hybrid_config + + +def make_hybrid_llava_model(transformer, llava_hybrid_config): + """ + Create a LlavaHybridForConditionalGeneration model 
with the same configuration as the given transformer model. + """ + llava_hybrid_model = LlavaHybridForConditionalGeneration(llava_hybrid_config) + # llava_hybrid_model.to(dtype=torch.bfloat16).to(device) + llava_hybrid_model.load_state_dict(transformer.state_dict(), strict=False) + return llava_hybrid_model + + +@click.command() +@click.option("--base_checkpoint", type=str, required=False, default="ServiceNow-AI/Apriel-Nemotron-15b-Thinker") +@click.option("--m2_indices", type=int, multiple=True, required=True) +@click.option("--hybrid_checkpoint", type=str, required=True) +@click.option("--save_dir", type=str, required=True) +@click.option( + "--tokenizer_dir", type=str, required=False, default="/mnt/plato/checkpoints/upstream/Mistral-Nemo-Base-2407/" +) +def main(base_checkpoint: str, m2_indices: list[int], hybrid_checkpoint: str, save_dir: str, tokenizer_dir: str): + """ + base_checkpoint: path to base transformer-model (teacher model) + m2_indices: indices of layers to convert to mamba layers with MiL init + hybrid_checkpoint: path to hybrid model (student model). Can be a hybrid with only transformer layers for the first distillation run. + save_dir: directory to save the converted model. + tokenizer_dir: directory containing tokenizer files to copy over to save_dir. + """ + m2_indices = list(m2_indices) # convert tuple -> list + transformer = AutoModelForVision2Seq.from_pretrained(base_checkpoint, trust_remote_code=True) + if hybrid_checkpoint == "none": + print("No hybrid checkpoint provided, creating new config from base model.") + hybrid_config = make_hybrid_llava_config(transformer) + else: + hybrid_config = LlavaHybridConfig.from_pretrained(hybrid_checkpoint) + + hybrid_block_layout = hybrid_config.text_config.hybrid_block_layout + for m2_index in m2_indices: + hybrid_block_layout[m2_index] = "m2" + print(hybrid_block_layout) + + # MiL init + convert_layers( + transformer.model.language_model.config, + transformer.model.language_model, + hybrid_config.text_config, + hybrid_block_layout, + init_with_kqvo=True, + torch_dtype=torch.bfloat16, + ) + hybrid_config.text_config.ssm_cfg["activation"] = "silu" + + # Load existing SSM layers + if hybrid_checkpoint != "none": + hybrid_llava_model = AutoModelForVision2Seq.from_pretrained( + hybrid_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + llava_state_dict = hybrid_llava_model.state_dict() + missing, unexpected = transformer.load_state_dict(llava_state_dict, strict=False) + for m2_index in m2_indices: + assert f"model.layers.{m2_index}.mixer.A_log" in missing + assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected + print("MISSING", missing) + print("UNEXPECTED", unexpected) + + # Save state-dict + transformer.save_pretrained(save_dir) + # Save new config + hybrid_config.save_pretrained(save_dir) + + # Copy modeling and tokenizer files + modeling_files = [ + configuration_llava_hybrid.__file__, + modeling_llava_hybrid.__file__, + modeling_ssm_hybrid_apriel15b.__file__, + ] + tokenizer_files = [ + f"{tokenizer_dir}/tokenizer.json", + f"{tokenizer_dir}/tokenizer_config.json", + f"{tokenizer_dir}/generation_config.json", + f"{tokenizer_dir}/special_tokens_map.json", + ] + for f in modeling_files + tokenizer_files: + shutil.copy(f, save_dir) + + # Update config with auto_maps + config_file = f"{save_dir}/config.json" + with open(config_file) as f: + dumped_config = json.load(f) + + dumped_config["auto_map"] = { + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": 
"modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + } + dumped_config["text_config"]["auto_map"] = { + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + } + dumped_config["architectures"] = ["LlavaHybridForConditionalGeneration"] + dumped_config["text_config"]["architectures"] = ["AprielThinkerSSMHybridForCausalLM"] + with open(config_file, "w") as f: + json.dump(dumped_config, f, indent=2) + + torch.cuda.empty_cache() + gc.collect() + + +if __name__ == "__main__": + main()