From c0862316708a9d64dbc5392b6ea8f00b35bab739 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Fri, 13 Sep 2024 11:14:36 +0200 Subject: [PATCH 1/8] clean mimi commit --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/mimi.md | 74 + docs/source/en/perf_infer_gpu_one.md | 2 + src/transformers/__init__.py | 14 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/mimi/__init__.py | 57 + .../models/mimi/configuration_mimi.py | 234 +++ .../convert_mimi_checkpoint_to_pytorch.py | 198 ++ src/transformers/models/mimi/modeling_mimi.py | 1723 +++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 14 + tests/models/mimi/__init__.py | 0 tests/models/mimi/test_modeling_mimi.py | 886 +++++++++ 16 files changed, 3210 insertions(+) create mode 100644 docs/source/en/model_doc/mimi.md create mode 100644 src/transformers/models/mimi/__init__.py create mode 100644 src/transformers/models/mimi/configuration_mimi.py create mode 100644 src/transformers/models/mimi/convert_mimi_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/mimi/modeling_mimi.py create mode 100644 tests/models/mimi/__init__.py create mode 100644 tests/models/mimi/test_modeling_mimi.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1c7f62ec6ea7b8..9d997247b722a5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -720,6 +720,8 @@ title: Hubert - local: model_doc/mctct title: MCTCT + - local: model_doc/mimi + title: Mimi - local: model_doc/mms title: MMS - local: model_doc/musicgen diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 8e3a4da8b021de..ba967fa5956b87 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -210,6 +210,7 @@ Flax), PyTorch, and/or TensorFlow. | [Megatron-BERT](model_doc/megatron-bert) | ✅ | ❌ | ❌ | | [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ | | [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ | +| [Mimi](model_doc/mimi) | ✅ | ❌ | ❌ | | [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ | | [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ | | [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/mimi.md b/docs/source/en/model_doc/mimi.md new file mode 100644 index 00000000000000..d8082ba6893719 --- /dev/null +++ b/docs/source/en/model_doc/mimi.md @@ -0,0 +1,74 @@ + + +# Mimi + +## Overview + +The Mimi model was proposed in []() by . + +The abstract from the paper is the following: + +** + +Mimi is a high-fidelity audio codec model developed by the Kyutai team. It can be used to project audio waveforms into quantized latent spaces, and vice versa. In other words, it can be used to map audio waveforms into “audio tokens”, known as “codebooks”. + + +Its architecture is based on [Encodec](model_doc/encodec) with several major differences: +* it uses a much lower frame-rate. +* it uses additional transformers for encoding and decoding for better latent contextualization +* it uses a different quantization scheme: one codebook is dedicated to semantic projection. + + + +## Usage example + +Here is a quick example of how to encode and decode an audio using this model: + +```python +>>> from datasets import load_dataset, Audio +>>> from transformers import MimiModel, AutoFeatureExtractor +>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + +>>> # load model and feature extractor +>>> model = MimiModel.from_pretrained("kmhf/mimi-test") # TODO(YL): modify once official +>>> feature_extractor = AutoFeatureExtractor.from_pretrained("kmhf/mimi-test") + +>>> # load audio sample +>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) +>>> audio_sample = librispeech_dummy[-1]["audio"]["array"] +>>> inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") + +>>> encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"]) +>>> audio_values = model.decode(encoder_outputs.audio_codes, inputs["padding_mask"])[0] +>>> # or the equivalent with a forward pass +>>> audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values +``` + +This model was contributed by [Yoach Lacombe (ylacombe)](https://huggingface.co/ylacombe). +The original code can be found [here](). + + +## MimiConfig + +[[autodoc]] MimiConfig + +## MimiModel + +[[autodoc]] MimiModel + - decode + - encode + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index dd3433f2cd4862..4c220dd0f1483c 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -61,6 +61,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next) * [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video) * [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision) +* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi) * [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava) * [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava) * [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100) @@ -228,6 +229,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) * [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision) +* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 00cc67915f3664..e73e5e66a99595 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -572,6 +572,7 @@ "MgpstrProcessor", "MgpstrTokenizer", ], + "models.mimi": ["MimiConfig"], "models.mistral": ["MistralConfig"], "models.mixtral": ["MixtralConfig"], "models.mluke": [], @@ -2664,6 +2665,12 @@ "MgpstrPreTrainedModel", ] ) + _import_structure["models.mimi"].extend( + [ + "MimiModel", + "MimiPreTrainedModel", + ] + ) _import_structure["models.mistral"].extend( [ "MistralForCausalLM", @@ -5341,6 +5348,9 @@ MgpstrProcessor, MgpstrTokenizer, ) + from .models.mimi import ( + MimiConfig, + ) from .models.mistral import MistralConfig from .models.mixtral import MixtralConfig from .models.mobilebert import ( @@ -7203,6 +7213,10 @@ MgpstrModel, MgpstrPreTrainedModel, ) + from .models.mimi import ( + MimiModel, + MimiPreTrainedModel, + ) from .models.mistral import ( MistralForCausalLM, MistralForSequenceClassification, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 26b96def67d992..358fd12ebf222c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -149,6 +149,7 @@ megatron_bert, megatron_gpt2, mgp_str, + mimi, mistral, mixtral, mluke, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index fa1a7fb88eafa8..15dbc15206fe97 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -167,6 +167,7 @@ ("mega", "MegaConfig"), ("megatron-bert", "MegatronBertConfig"), ("mgp-str", "MgpstrConfig"), + ("mimi", "MimiConfig"), ("mistral", "MistralConfig"), ("mixtral", "MixtralConfig"), ("mobilebert", "MobileBertConfig"), @@ -467,6 +468,7 @@ ("megatron-bert", "Megatron-BERT"), ("megatron_gpt2", "Megatron-GPT2"), ("mgp-str", "MGP-STR"), + ("mimi", "Mimi"), ("mistral", "Mistral"), ("mixtral", "Mixtral"), ("mluke", "mLUKE"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 7f335d66584f9f..dca0c08aa90957 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -69,6 +69,7 @@ ("levit", "LevitFeatureExtractor"), ("maskformer", "MaskFormerFeatureExtractor"), ("mctct", "MCTCTFeatureExtractor"), + ("mimi", "EncodecFeatureExtractor"), ("mobilenet_v1", "MobileNetV1FeatureExtractor"), ("mobilenet_v2", "MobileNetV2FeatureExtractor"), ("mobilevit", "MobileViTFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 45a9c4d0d078b7..6feb992daf6464 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -158,6 +158,7 @@ ("mega", "MegaModel"), ("megatron-bert", "MegatronBertModel"), ("mgp-str", "MgpstrForSceneTextRecognition"), + ("mimi", "MimiModel"), ("mistral", "MistralModel"), ("mixtral", "MixtralModel"), ("mobilebert", "MobileBertModel"), diff --git a/src/transformers/models/mimi/__init__.py b/src/transformers/models/mimi/__init__.py new file mode 100644 index 00000000000000..43b2bec6caa5b3 --- /dev/null +++ b/src/transformers/models/mimi/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_mimi": ["MimiConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mimi"] = [ + "MimiModel", + "MimiPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_mimi import ( + MimiConfig, + ) + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mimi import ( + MimiModel, + MimiPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/mimi/configuration_mimi.py b/src/transformers/models/mimi/configuration_mimi.py new file mode 100644 index 00000000000000..5706881c10e7b7 --- /dev/null +++ b/src/transformers/models/mimi/configuration_mimi.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mimi model configuration""" + +import math + +import numpy as np + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MimiConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`MimiModel`]. It is used to instantiate a + Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [facebook/mimi_24khz](https://huggingface.co/facebook/mimi_24khz) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + frame_rate (`float`, *optional*, defaults to 12.5): + Framerate of the model. + audio_channels (`int`, *optional*, defaults to 1): + Number of channels in the audio data. Either 1 for mono or 2 for stereo. + hidden_size (`int`, *optional*, defaults to 512): + Intermediate representation dimension. + num_filters (`int`, *optional*, defaults to 64): + Number of convolution kernels of first `MimiConv1d` down sampling layer. + num_residual_layers (`int`, *optional*, defaults to 1): + Number of residual layers. + upsampling_ratios (`Sequence[int]`, *optional*): + Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it + will use the ratios in the reverse order to the ones specified here that must match the decoder order. + If not specified, will defaults to `[8, 6, 5, 4]` + kernel_size (`int`, *optional*, defaults to 7): + Kernel size for the initial convolution. + last_kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the last convolution layer. + residual_kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the residual layers. + dilation_growth_rate (`int`, *optional*, defaults to 2): + How much to increase the dilation with each layer. + use_causal_conv (`bool`, *optional*, defaults to `True`): + Whether to use fully causal convolution. + pad_mode (`str`, *optional*, defaults to `"constant"`): + Padding mode for the convolutions. + compress (`int`, *optional*, defaults to 2): + Reduced dimensionality in residual branches (from Demucs v3). + trim_right_ratio (`float`, *optional*, defaults to 1.0): + Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If + equal to 1.0, it means that all the trimming is done at the right. + codebook_size (`int`, *optional*, defaults to 2048): + Number of discret codes in each codebooks. + codebook_dim (`int`, *optional*, defaults to 256): + Dimension of the unquantized codebook vectors. If not defined, uses `hidden_size`. + num_quantizers (`int`, *optional*, defaults to 32): + Number of quantizer channels, or codebooks, in the quantizer. + use_conv_shortcut (`bool`, *optional*, defaults to `False`): + Whether to use a convolutional layer as the 'skip' connection in the `MimiResnetBlock` block. If False, + an identity function will be used, giving a generic residual connection. + vector_quantization_hidden_dimension (`int`, *optional*, defaults to 256): + Intermediate representation dimension in the residual vector quantization space. + num_semantic_quantizers (`int`, *optional*, defaults to 1): + Number of semantic quantizer channels, or codebooks, in the semantic quantizer. Must be lower than `num_quantizers`. + upsample_groups (`int`, *optional*, defaults to 512): + If `frame_rate!=encodec_frame_rate`, indicates the number of groups used in the upsampling operation to go from one rate to another. + num_hidden_layers (`int`, *optional*, defaults to 8): + Number of hidden layers in the Transformer models. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MLP representations. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 8000): + The maximum sequence length that this model might ever be used with. Mimi's sliding window attention + allows sequence of up to 8000 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the LayerNorm normalization layers. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*, defaults to 250): + Sliding window attention window size. If not specified, will default to `250`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + layer_scale_initial_scale (`float`, *optional*, defaults to 0.01): + Initiale scale of the residual rescaling operation done in the Transformer models. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + Example: + + ```python + >>> from transformers import MimiModel, MimiConfig + + >>> # Initializing a "facebook/mimi_24khz" style configuration + >>> configuration = MimiConfig() + + >>> # Initializing a model (with random weights) from the "facebook/mimi_24khz" style configuration + >>> model = MimiModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mimi" + + def __init__( + self, + sampling_rate=24_000, + frame_rate=12.5, + audio_channels=1, + hidden_size=512, + num_filters=64, + num_residual_layers=1, + upsampling_ratios=None, + kernel_size=7, + last_kernel_size=3, + residual_kernel_size=3, + dilation_growth_rate=2, + use_causal_conv=True, + pad_mode="constant", + compress=2, + trim_right_ratio=1.0, + codebook_size=2048, + codebook_dim=256, + num_quantizers=32, + use_conv_shortcut=False, + vector_quantization_hidden_dimension=256, + num_semantic_quantizers=1, + upsample_groups=512, + num_hidden_layers=8, + intermediate_size=2048, + num_attention_heads=8, + num_key_value_heads=8, + head_dim=None, + hidden_act="gelu", + max_position_embeddings=8000, + initializer_range=0.02, + norm_eps=1e-5, + use_cache=False, + rope_theta=10000.0, + sliding_window=250, + attention_dropout=0.0, + layer_scale_initial_scale=0.01, + attention_bias=False, + **kwargs, + ): + self.sampling_rate = sampling_rate + self.frame_rate = frame_rate + self.audio_channels = audio_channels + self.hidden_size = hidden_size + self.num_filters = num_filters + self.num_residual_layers = num_residual_layers + self.upsampling_ratios = upsampling_ratios if upsampling_ratios else [8, 6, 5, 4] + self.kernel_size = kernel_size + self.last_kernel_size = last_kernel_size + self.residual_kernel_size = residual_kernel_size + self.dilation_growth_rate = dilation_growth_rate + self.use_causal_conv = use_causal_conv + self.pad_mode = pad_mode + self.compress = compress + self.trim_right_ratio = trim_right_ratio + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size + self.num_quantizers = num_quantizers + self.use_conv_shortcut = use_conv_shortcut + self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension + self.upsample_groups = upsample_groups + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + self.head_dim = head_dim or hidden_size // num_attention_heads + self.layer_scale_initial_scale = layer_scale_initial_scale + self.attention_bias = attention_bias + + if num_semantic_quantizers >= self.num_quantizers: + raise ValueError( + f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}." + ) + self.num_semantic_quantizers = num_semantic_quantizers + super().__init__(**kwargs) + + @property + def encodec_frame_rate(self) -> int: + hop_length = np.prod(self.upsampling_ratios) + return math.ceil(self.sampling_rate / hop_length) + + @property + def num_codebooks(self) -> int: + # alias to num_quantizers + return self.num_quantizers diff --git a/src/transformers/models/mimi/convert_mimi_checkpoint_to_pytorch.py b/src/transformers/models/mimi/convert_mimi_checkpoint_to_pytorch.py new file mode 100644 index 00000000000000..c617fa036c5d47 --- /dev/null +++ b/src/transformers/models/mimi/convert_mimi_checkpoint_to_pytorch.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Mimi checkpoints.""" + +import argparse + +import safetensors +import torch + +from transformers import ( + EncodecFeatureExtractor, + MimiConfig, + MimiModel, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger("transformers.models.mimi") + + +def assert_param_count(model_1, model_2): + count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0]) + count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0]) + assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" + + +def param_count(model): + return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0]) + + +def _grab_best_device(use_gpu=True): + if torch.cuda.device_count() > 0 and use_gpu: + device = "cuda" + else: + device = "cpu" + return torch.device(device) + + +convert_list = [ + # GENERAL + ("conv.conv.conv", "conv"), + ("convtr.convtr.convtr", "conv"), + ("conv.conv", "conv"), + ("convtr.convtr", "conv"), + # QUANTIZER + ("quantizer.rvq_first.vq", "quantizer.semantic_residual_vector_quantizer"), + ("quantizer.rvq_first", "quantizer.semantic_residual_vector_quantizer"), + ("quantizer.rvq_rest.vq", "quantizer.acoustic_residual_vector_quantizer"), + ("quantizer.rvq_rest", "quantizer.acoustic_residual_vector_quantizer"), + ("_codebook", "codebook"), + ("_initialized", "initialized"), + ("embedding_sum", "embed_sum"), + # ENCODER PART + ("encoder.model", "encoder.layers"), + ("decoder.model", "decoder.layers"), + # TRANSFORMERS PART + ("encoder_transformer.transformer", "encoder_transformer"), + ("decoder_transformer.transformer", "decoder_transformer"), + ("linear1", "mlp.fc1"), + ("linear2", "mlp.fc2"), + ("self_attn.out_proj", "self_attn.o_proj"), + ("norm1", "input_layernorm"), + ("norm2", "post_attention_layernorm"), + ("layer_scale_1", "self_attn_layer_scale"), + ("layer_scale_2", "mlp_layer_scale"), +] + + +def _convert_model( + state_dict, + hf_model, + convert_list, + device, + config, + unwanted_prefix=None, +): + hidden_size = config.hidden_size + head_dim = config.head_dim + num_heads = int(config.hidden_size // config.head_dim) + num_key_value_heads = config.num_key_value_heads + key_value_head_dim = config.num_key_value_heads * head_dim + + # permute for sliced rotary + def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + for k, v in list(state_dict.items()): + new_k = k if unwanted_prefix is None else k[len(unwanted_prefix) :] + for old_layer_name, new_layer_name in convert_list: + if old_layer_name in new_k: + new_k = new_k.replace(old_layer_name, new_layer_name) + + if "in_proj_weight" in new_k: + # split qkv into query key and value + mixed_qkv = state_dict.pop(k) + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + + state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = permute(query_layer, num_heads) + state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = permute( + key_layer, num_key_value_heads, dim1=key_value_head_dim + ) + state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer + else: + state_dict[new_k] = state_dict.pop(k) + + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) + missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + hf_model.load_state_dict(state_dict, strict=True) + n_params = param_count(hf_model) + + logger.info(f"model loaded: {round(n_params/1e6,1)}M params") + + hf_model.eval() + hf_model.to(device) + del state_dict + + return hf_model + + +@torch.no_grad() +def convert_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + config_path=None, + repo_id=None, +): + """ + Copy/paste/tweak model's weights to transformers design. + """ + device = _grab_best_device() + + if config_path is not None: + config = MimiConfig.from_pretrained(config_path) + else: + config = MimiConfig() + + model = MimiModel(config) + + feature_extractor = EncodecFeatureExtractor( + feature_size=config.audio_channels, + sampling_rate=config.sampling_rate, + ) + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + original_checkpoint = safetensors.torch.load_file(checkpoint_path) + if "best_state" in original_checkpoint: + # we might have a training state saved, in which case discard the yaml results and just retain the weights + original_checkpoint = original_checkpoint["best_state"] + + model = _convert_model(original_checkpoint, model, convert_list, device, config) + + model.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + print("Pushing to the hub...") + feature_extractor.push_to_hub(repo_id) + model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." + ) + + args = parser.parse_args() + convert_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.config_path, + args.push_to_hub, + ) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py new file mode 100644 index 00000000000000..3a5afa6319c959 --- /dev/null +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -0,0 +1,1723 @@ +# coding=utf-8 +# Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mimi model.""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_mimi import MimiConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +# General docstring +_CONFIG_FOR_DOC = "MimiConfig" + + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@dataclass +class MimiOutput(ModelOutput): + """ + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + audio_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) + Decoded audio values, obtained using the decoder part of Mimi. + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + decoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + """ + + audio_codes: torch.LongTensor = None + audio_values: torch.FloatTensor = None + encoder_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None + decoder_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None + + +@dataclass +class MimiEncoderOutput(ModelOutput): + """ + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + """ + + audio_codes: torch.LongTensor = None + encoder_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None + + +@dataclass +class MimiDecoderOutput(ModelOutput): + """ + Args: + audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*): + Decoded audio values, obtained using the decoder part of Mimi. + decoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + """ + + audio_values: torch.FloatTensor = None + decoder_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None + + +class MimiConv1d(nn.Module): + """Conv1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, + config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + pad_mode=None, + bias: bool = True, + ): + super().__init__() + self.causal = config.use_causal_conv + self.pad_mode = config.pad_mode if pad_mode is None else pad_mode + + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + logger.warning( + "MimiConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, dilation=dilation, groups=groups, bias=bias + ) + + kernel_size = self.conv.kernel_size[0] + stride = torch.tensor(self.conv.stride[0], dtype=torch.int64) + dilation = self.conv.dilation[0] + + # Effective kernel size with dilations. + kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64) + + self.register_buffer("stride", stride, persistent=False) + self.register_buffer("kernel_size", kernel_size, persistent=False) + self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False) + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._get_extra_padding_for_conv1d + def _get_extra_padding_for_conv1d( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """See `pad_for_conv1d`.""" + length = hidden_states.shape[-1] + n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 + n_frames = torch.ceil(n_frames).to(torch.int64) - 1 + ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total + + return ideal_length - length + + @staticmethod + # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._pad1d + def _pad1d(hidden_states: torch.Tensor, paddings: Tuple[int, int], mode: str = "zero", value: float = 0.0): + """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happens. + """ + length = hidden_states.shape[-1] + padding_left, padding_right = paddings + if not mode == "reflect": + return nn.functional.pad(hidden_states, paddings, mode, value) + + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + hidden_states = nn.functional.pad(hidden_states, (0, extra_pad)) + padded = nn.functional.pad(hidden_states, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + + def forward(self, hidden_states): + extra_padding = self._get_extra_padding_for_conv1d(hidden_states) + + if self.causal: + # Left padding for causal + hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = self.padding_total // 2 + padding_left = self.padding_total - padding_right + hidden_states = self._pad1d( + hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class MimiConvTranspose1d(nn.Module): + """ConvTranspose1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, + config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + bias=True, + ): + super().__init__() + self.causal = config.use_causal_conv + self.trim_right_ratio = config.trim_right_ratio + self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias) + + if not (self.causal or self.trim_right_ratio == 1.0): + raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions") + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, hidden_states): + kernel_size = self.conv.kernel_size[0] + stride = self.conv.stride[0] + padding_total = kernel_size - stride + + hidden_states = self.conv(hidden_states) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + + padding_left = padding_total - padding_right + + # unpad + end = hidden_states.shape[-1] - padding_right + hidden_states = hidden_states[..., padding_left:end] + return hidden_states + + +# Copied from transformers.models.encodec.modeling_encodec.EncodecResnetBlock with Encodec->Mimi,EnCodec->Mimi +class MimiResnetBlock(nn.Module): + """ + Residual block from SEANet model as used by Mimi. + """ + + def __init__(self, config: MimiConfig, dim: int, dilations: List[int]): + super().__init__() + kernel_sizes = (config.residual_kernel_size, 1) + if len(kernel_sizes) != len(dilations): + raise ValueError("Number of kernel sizes should match number of dilations") + + hidden = dim // config.compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [nn.ELU()] + block += [MimiConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)] + self.block = nn.ModuleList(block) + + if config.use_conv_shortcut: + self.shortcut = MimiConv1d(config, dim, dim, kernel_size=1) + else: + self.shortcut = nn.Identity() + + def forward(self, hidden_states): + residual = hidden_states + for layer in self.block: + hidden_states = layer(hidden_states) + + return self.shortcut(residual) + hidden_states + + +class MimiEncoder(nn.Module): + """SEANet encoder as used by Mimi.""" + + def __init__(self, config: MimiConfig): + super().__init__() + model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)] + scaling = 1 + + # Downsample to raw audio scale + for ratio in reversed(config.upsampling_ratios): + current_scale = scaling * config.num_filters + # Add residual layers + for j in range(config.num_residual_layers): + model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])] + # Add downsampling layers + model += [nn.ELU()] + model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)] + scaling *= 2 + + model += [nn.ELU()] + model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)] + + self.layers = nn.ModuleList(model) + + # Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class MimiLayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonally the residual outputs close to 0, with a learnt scale. + """ + + def __init__(self, config): + super().__init__() + channels = config.hidden_size + initial_scale = config.layer_scale_initial_scale + self.scale = nn.Parameter(torch.full((channels,), initial_scale, requires_grad=True)) + + def forward(self, x: torch.Tensor): + return self.scale * x + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi +class MimiRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + # TODO(joao): add me back asap :) + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MimiMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + # Copied from transformers.models.clip.modeling_clip.CLIPMLP.forward + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi +class MimiAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = MimiRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + self.sliding_window = config.sliding_window # Ignore copy + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi +class MimiFlashAttention2(MimiAttention): + """ + Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (MimiRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi +class MimiSdpaAttention(MimiAttention): + """ + Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MimiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MimiAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MimiModel is using MimiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MIMI_ATTENTION_CLASSES = { + "eager": MimiAttention, + "flash_attention_2": MimiFlashAttention2, + "sdpa": MimiSdpaAttention, +} + + +class MimiTransformerLayer(nn.Module): + def __init__(self, config: MimiConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = MimiMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + self.self_attn_layer_scale = MimiLayerScale(config) + self.mlp_layer_scale = MimiLayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + self.self_attn_layer_scale(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.mlp_layer_scale(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MimiTransformerModel(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MimiTransformerLayer`] + + Args: + config: MimiConfig + """ + + def __init__(self, config: MimiConfig): + super().__init__() + + self.layers = nn.ModuleList( + [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Embedded representation that will be contextualized by the model + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if use_cache and past_key_values is None and not self.training: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = None + if attention_mask is not None: + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.gemma.modeling_gemma.GemmaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class MimiDecoder(nn.Module): + """SEANet decoder as used by Mimi.""" + + def __init__(self, config: MimiConfig): + super().__init__() + scaling = int(2 ** len(config.upsampling_ratios)) + model = [MimiConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)] + + # Upsample to raw audio scale + for ratio in config.upsampling_ratios: + current_scale = scaling * config.num_filters + # Add upsampling layers + model += [nn.ELU()] + model += [ + MimiConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio) + ] + # Add residual layers + for j in range(config.num_residual_layers): + model += [MimiResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))] + scaling //= 2 + + # Add final layers + model += [nn.ELU()] + model += [MimiConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)] + self.layers = nn.ModuleList(model) + + # Copied from transformers.models.encodec.modeling_encodec.EncodecDecoder.forward + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class MimiEuclideanCodebook(nn.Module): + """Codebook with Euclidean distance.""" + + def __init__(self, config: MimiConfig, epsilon: float = 1e-5): + super().__init__() + embed = torch.zeros(config.codebook_size, config.codebook_dim) + + self.codebook_size = config.codebook_size + + self.register_buffer("initialized", torch.Tensor([True])) + self.register_buffer("cluster_usage", torch.ones(config.codebook_size)) + self.register_buffer("embed_sum", embed) + self._embed = None + self.epsilon = epsilon + + @property + def embed(self) -> torch.Tensor: + if self._embed is None: + self._embed = self.embed_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None] + return self._embed + + def quantize(self, hidden_states): + # Projects each vector in `hidden_states` over the nearest centroid and return its index. + # `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension. + dists = torch.cdist(hidden_states[None], self.embed[None], p=2)[0] + embed_ind = dists.argmin(dim=-1) + return embed_ind + + # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.encode + def encode(self, hidden_states): + shape = hidden_states.shape + # pre-process + hidden_states = hidden_states.reshape((-1, shape[-1])) + # quantize + embed_ind = self.quantize(hidden_states) + # post-process + embed_ind = embed_ind.view(*shape[:-1]) + return embed_ind + + # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.decode + def decode(self, embed_ind): + quantize = nn.functional.embedding(embed_ind, self.embed) + return quantize + + +# Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization with Encodec->Mimi +class MimiVectorQuantization(nn.Module): + """ + Vector quantization implementation. Currently supports only euclidean distance. + """ + + def __init__(self, config: MimiConfig): + super().__init__() + self.codebook = MimiEuclideanCodebook(config) + + def encode(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1) + embed_in = self.codebook.encode(hidden_states) + return embed_in + + def decode(self, embed_ind): + quantize = self.codebook.decode(embed_ind) + quantize = quantize.permute(0, 2, 1) + return quantize + + +class MimiResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer.""" + + def __init__(self, config: MimiConfig, num_quantizers: int = None): + super().__init__() + self.codebook_size = config.codebook_size + self.frame_rate = config.frame_rate + self.num_quantizers = num_quantizers if num_quantizers is not None else config.num_quantizers + self.layers = nn.ModuleList([MimiVectorQuantization(config) for _ in range(self.num_quantizers)]) + + self.input_proj = None + self.output_proj = None + if config.vector_quantization_hidden_dimension != config.hidden_size: + self.input_proj = torch.nn.Conv1d( + config.hidden_size, config.vector_quantization_hidden_dimension, 1, bias=False + ) + self.output_proj = torch.nn.Conv1d( + config.vector_quantization_hidden_dimension, config.hidden_size, 1, bias=False + ) + + def encode(self, embeddings: torch.Tensor, num_quantizers: Optional[int] = None) -> torch.Tensor: + """ + Encode a given input tensor with the specified frame rate at the given number of quantizers / codebooks. The RVQ encode method sets + the appropriate number of quantizers to use and returns indices for each quantizer. + """ + if self.input_proj is not None: + embeddings = self.input_proj(embeddings) + + num_quantizers = num_quantizers if num_quantizers is not None else self.num_quantizers + + residual = embeddings + all_indices = [] + for layer in self.layers[:num_quantizers]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes of shape [B, K, T] to the quantized representation.""" + quantized_out = torch.tensor(0.0, device=codes.device) + codes = codes.transpose(0, 1) + for i, indices in enumerate(codes): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + + if self.output_proj is not None: + quantized_out = self.output_proj(quantized_out) + return quantized_out + + +class MimiSplitResidualVectorQuantizer(nn.Module): + """Split Residual Vector Quantizer.""" + + def __init__(self, config: MimiConfig): + super().__init__() + self.codebook_size = config.codebook_size + self.frame_rate = config.frame_rate + self.max_num_quantizers = config.num_quantizers + + self.num_semantic_quantizers = config.num_semantic_quantizers + self.num_acoustic_quantizers = config.num_quantizers - config.num_semantic_quantizers + + self.semantic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_semantic_quantizers) + self.acoustic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_acoustic_quantizers) + + def encode(self, embeddings: torch.Tensor, num_quantizers: Optional[float] = None) -> torch.Tensor: + """ + Encode a given input tensor with the specified frame rate at the given number of quantizers / codebooks. The RVQ encode method sets + the appropriate number of quantizers to use and returns indices for each quantizer. + """ + + num_quantizers = self.max_num_quantizers if num_quantizers is None else num_quantizers + + if num_quantizers > self.max_num_quantizers: + raise ValueError( + f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.max_num_quantizers}, but is currently {num_quantizers}." + ) + + if num_quantizers < self.num_semantic_quantizers: + raise ValueError( + f"The number of quantizers (i.e codebooks) asked should be higher than the number of semantic quantizers {self.num_semantic_quantizers}, but is currently {num_quantizers}." + ) + + # codes is [K, B, T], with T frames, K nb of codebooks. + codes = self.semantic_residual_vector_quantizer.encode(embeddings) + + if num_quantizers > self.num_semantic_quantizers: + acoustic_codes = self.acoustic_residual_vector_quantizer.encode( + embeddings, num_quantizers=num_quantizers - self.num_semantic_quantizers + ) + codes = torch.cat([codes, acoustic_codes], dim=0) + + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + + # The first num_semantic_quantizers codebooks are decoded using the semantic RVQ + quantized_out = self.semantic_residual_vector_quantizer.decode(codes[:, : self.num_semantic_quantizers]) + + # The rest of the codebooks are decoded using the acoustic RVQ + if codes.shape[1] > self.num_semantic_quantizers: + quantized_out += self.acoustic_residual_vector_quantizer.decode(codes[:, self.num_semantic_quantizers :]) + return quantized_out + + +class MimiPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MimiConfig + base_model_prefix = "mimi" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _no_split_modules = ["MimiDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + # Copied from transformers.models.encodec.modeling_encodec.EncodecPreTrainedModel._init_weights + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LSTM): + for name, param in module.named_parameters(): + if "weight" in name: + nn.init.xavier_uniform_(param) + elif "bias" in name: + nn.init.constant_(param, 0.0) + + +MIMI_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MimiConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +MIMI_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): + Raw audio input converted to Float. + padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0 + for *masked*. + num_quantizers (`int`, *optional*): + Number of quantizers (i.e codebooks) to use. By default, all quantizers are used. + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + decoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The Mimi neural audio codec model.", + MIMI_START_DOCSTRING, +) +class MimiModel(MimiPreTrainedModel): + def __init__(self, config: MimiConfig): + super().__init__(config) + self.config = config + + self.encoder = MimiEncoder(config) + self.encoder_transformer = MimiTransformerModel(config) + self.downsample = ( + MimiConv1d( + config, + config.hidden_size, + config.hidden_size, + kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), + stride=2, + bias=False, + pad_mode="replicate", + ) + if config.frame_rate != config.encodec_frame_rate + else None + ) + self.upsample = ( + MimiConvTranspose1d( + config, + config.hidden_size, + config.hidden_size, + kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), + stride=2, + bias=False, + groups=config.upsample_groups, + ) + if config.frame_rate != config.encodec_frame_rate + else None + ) + self.decoder_transformer = MimiTransformerModel(config) + self.decoder = MimiDecoder(config) + + self.quantizer = MimiSplitResidualVectorQuantizer(config) + + self.bits_per_codebook = int(math.log2(self.config.codebook_size)) + if 2**self.bits_per_codebook != self.config.codebook_size: + raise ValueError("The codebook_size must be a power of 2.") + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _encode_frame( + self, + input_values: torch.Tensor, + num_quantizers: int, + padding_mask: int, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + return_dict: Optional[bool] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Encodes the given input using the underlying VQVAE. The padding mask is required to compute the correct scale. + """ + embeddings = self.encoder(input_values) + encoder_outputs = self.encoder_transformer( + embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict + ) + if return_dict: + past_key_values = encoder_outputs.get("past_key_values") + elif len(encoder_outputs) > 1: + past_key_values = encoder_outputs[1] + embeddings = encoder_outputs[0].transpose(1, 2) + embeddings = self.downsample(embeddings) + + codes = self.quantizer.encode(embeddings, num_quantizers) + codes = codes.transpose(0, 1) + return codes, past_key_values + + def encode( + self, + input_values: torch.Tensor, + padding_mask: torch.Tensor = None, + num_quantizers: Optional[float] = None, + encoder_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], MimiEncoderOutput]: + """ + Encodes the input audio waveform into discrete codes. + + Args: + input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + Float values of the input audio waveform. + padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0 + for *masked*. + num_quantizers (`int`, *optional*): + Number of quantizers (i.e codebooks) to use. By default, all quantizers are used. + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + `codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers + + if num_quantizers > self.config.num_quantizers: + raise ValueError( + f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.config.num_quantizers}, but is currently {num_quantizers}." + ) + + _, channels, input_length = input_values.shape + + if channels < 1 or channels > 2: + raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}") + + if padding_mask is None: + padding_mask = torch.ones_like(input_values).bool() + + encoded_frames, encoder_past_key_values = self._encode_frame( + input_values, + num_quantizers, + padding_mask.bool(), + past_key_values=encoder_past_key_values, + return_dict=return_dict, + ) + + if not return_dict: + return ( + encoded_frames, + encoder_past_key_values, + ) + + return MimiEncoderOutput(encoded_frames, encoder_past_key_values) + + def _decode_frame( + self, + codes: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + return_dict: Optional[bool] = None, + ) -> torch.Tensor: + embeddings = self.quantizer.decode(codes) + + embeddings = self.upsample(embeddings) + decoder_outputs = self.decoder_transformer( + embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict + ) + if return_dict: + past_key_values = decoder_outputs.get("past_key_values") + elif len(decoder_outputs) > 1: + past_key_values = decoder_outputs[1] + embeddings = decoder_outputs[0].transpose(1, 2) + outputs = self.decoder(embeddings) + return outputs, past_key_values + + def decode( + self, + audio_codes: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + decoder_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], MimiDecoderOutput]: + """ + Decodes the given frames into an output audio waveform. + + Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be + trimmed. + + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0 + for *masked*. + decoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + audio_values, decoder_past_key_values = self._decode_frame( + audio_codes, past_key_values=decoder_past_key_values, return_dict=return_dict + ) + + # truncate based on padding mask + if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]: + audio_values = audio_values[..., : padding_mask.shape[-1]] + + if not return_dict: + return ( + audio_values, + decoder_past_key_values, + ) + return MimiDecoderOutput(audio_values, decoder_past_key_values) + + @add_start_docstrings_to_model_forward(MIMI_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MimiOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + num_quantizers: Optional[int] = None, + audio_codes: Optional[torch.Tensor] = None, + encoder_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + decoder_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], MimiOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import AutoProcessor, MimiModel + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> model_id = "facebook/mimi_24khz" + >>> model = MimiModel.from_pretrained(model_id) + >>> processor = AutoProcessor.from_pretrained(model_id) + + >>> inputs = processor(raw_audio=audio_sample, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> audio_codes = outputs.audio_codes + >>> audio_values = outputs.audio_values + ```""" + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if padding_mask is None: + padding_mask = torch.ones_like(input_values).bool() + + if audio_codes is None: + encoder_outputs = self.encode( + input_values, padding_mask, num_quantizers, encoder_past_key_values, return_dict=return_dict + ) + audio_codes = encoder_outputs[0] + if return_dict: + encoder_past_key_values = encoder_outputs.get("past_key_values") + elif len(encoder_outputs) > 1: + encoder_past_key_values = encoder_outputs[1] + + decoder_outputs = self.decode(audio_codes, padding_mask, decoder_past_key_values, return_dict=return_dict) + audio_values = decoder_outputs[0] + if return_dict: + decoder_past_key_values = decoder_outputs.get("past_key_values") + elif len(decoder_outputs) > 1: + decoder_past_key_values = decoder_outputs[1] + + if not return_dict: + return (audio_codes, audio_values, encoder_past_key_values, decoder_past_key_values) + + return MimiOutput( + audio_codes=audio_codes, + audio_values=audio_values, + encoder_past_key_values=encoder_past_key_values, + decoder_past_key_values=decoder_past_key_values, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b9ce0d0f15bbf5..22e180e9c70201 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5840,6 +5840,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class MimiModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MimiPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MistralForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/mimi/__init__.py b/tests/models/mimi/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py new file mode 100644 index 00000000000000..9ba0f8bf4509da --- /dev/null +++ b/tests/models/mimi/test_modeling_mimi.py @@ -0,0 +1,886 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Mimi model.""" + +import inspect +import os +import tempfile +import unittest + +import numpy as np +from datasets import Audio, load_dataset +from packaging import version +from parameterized import parameterized +from pytest import mark + +from transformers import AutoFeatureExtractor, MimiConfig +from transformers.testing_utils import ( + is_flaky, + is_torch_available, + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_sdpa, + slow, + torch_device, +) +from transformers.utils import ( + is_torch_bf16_available_on_device, + is_torch_fp16_available_on_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import MimiModel + + +# Copied from transformers.tests.encodec.test_modeling_encodec.prepare_inputs_dict +def prepare_inputs_dict( + config, + input_ids=None, + input_values=None, + decoder_input_ids=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, +): + if input_ids is not None: + encoder_dict = {"input_ids": input_ids} + else: + encoder_dict = {"input_values": input_values} + + decoder_dict = {"decoder_input_ids": decoder_input_ids} if decoder_input_ids is not None else {} + + return {**encoder_dict, **decoder_dict} + + +@require_torch +class MimiModelTester: + def __init__( + self, + parent, + batch_size=5, + num_channels=1, + is_training=False, + intermediate_size=40, + hidden_size=32, + num_filters=8, + num_residual_layers=1, + upsampling_ratios=[8, 4], + codebook_size=64, + vector_quantization_hidden_dimension=64, + codebook_dim=64, + upsample_groups=32, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + sliding_window=4, + use_cache=False, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.is_training = is_training + self.intermediate_size = intermediate_size + self.hidden_size = hidden_size + self.num_filters = num_filters + self.num_residual_layers = num_residual_layers + self.upsampling_ratios = upsampling_ratios + self.codebook_size = codebook_size + self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension + self.codebook_dim = codebook_dim + self.upsample_groups = upsample_groups + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.sliding_window = sliding_window + self.use_cache = use_cache + + def prepare_config_and_inputs(self): + input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0) + config = self.get_config() + inputs_dict = {"input_values": input_values} + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def prepare_config_and_inputs_for_model_class(self, model_class): + config, inputs_dict = self.prepare_config_and_inputs() + inputs_dict["audio_codes"] = ids_tensor([self.batch_size, 1, self.num_channels], self.codebook_size).type( + torch.int32 + ) + + return config, inputs_dict + + def get_config(self): + return MimiConfig( + audio_channels=self.num_channels, + chunk_in_sec=None, + hidden_size=self.hidden_size, + num_filters=self.num_filters, + num_residual_layers=self.num_residual_layers, + upsampling_ratios=self.upsampling_ratios, + codebook_size=self.codebook_size, + vector_quantization_hidden_dimension=self.vector_quantization_hidden_dimension, + upsample_groups=self.upsample_groups, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + sliding_window=self.sliding_window, + codebook_dim=self.codebook_dim, + use_cache=self.use_cache, + ) + + def create_and_check_model_forward(self, config, inputs_dict): + model = MimiModel(config=config).to(torch_device).eval() + + input_values = inputs_dict["input_values"] + result = model(input_values) + self.parent.assertEqual( + result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size) + ) + + +@require_torch +class MimiModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (MimiModel,) if is_torch_available() else () + is_encoder_decoder = True + test_pruning = False + test_headmasking = False + test_resize_embeddings = False + test_torchscript = False + input_name = "input_values" + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + # model does support returning hidden states + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + if "output_attentions" in inputs_dict: + inputs_dict.pop("output_attentions") + if "output_hidden_states" in inputs_dict: + inputs_dict.pop("output_hidden_states") + return inputs_dict + + def setUp(self): + self.model_tester = MimiModelTester(self) + self.config_tester = ConfigTester( + self, config_class=MimiConfig, hidden_size=37, common_properties=[], has_text_modality=False + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_forward(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_values", "padding_mask", "num_quantizers"] + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + @unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="The MimiModel does not have the usual `attention` logic") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="The MimiModel does not have the usual `attention` logic") + def test_torchscript_output_attentions(self): + pass + + @unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic") + def test_torchscript_output_hidden_state(self): + pass + + # Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest._create_and_check_torchscript + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + self.skipTest(reason="test_torchscript is set to False") + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + configs_no_init.return_dict = False + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) + + main_input_name = model_class.main_input_name + + try: + main_input = inputs[main_input_name] + model(main_input) + traced_model = torch.jit.trace(model, main_input) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. + # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + + @unittest.skip(reason="The MimiModel does not have the usual `attention` logic") + def test_attention_outputs(self): + pass + + @unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic") + def test_hidden_states_output(self): + pass + + # Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_determinism + def test_determinism(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def check_determinism(first, second): + # outputs are not tensors but list (since each sequence don't have the same frame_length) + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + first = model(**self._prepare_for_class(inputs_dict, model_class))[0] + second = model(**self._prepare_for_class(inputs_dict, model_class))[0] + + if isinstance(first, tuple) and isinstance(second, tuple): + for tensor1, tensor2 in zip(first, second): + check_determinism(tensor1, tensor2) + else: + check_determinism(first, second) + + # Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_model_outputs_equivalence + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs) + + self.assertTrue(isinstance(tuple_output, tuple)) + self.assertTrue(isinstance(dict_output, dict)) + + for tuple_value, dict_value in zip(tuple_output, dict_output.values()): + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:" + f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has" + f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}." + ), + ) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = ["conv", "input_proj", "output_proj"] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_identity_shortcut + def test_identity_shortcut(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + config.use_conv_shortcut = False + self.model_tester.create_and_check_model_forward(config, inputs_dict) + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. + if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 + + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + # FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors. + # These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask. + # This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code. + # However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it. + deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters + + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for output_attentions in [True, False]: + can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters + if not (self.has_attentions and can_output_attn) and output_attentions: + continue + for batch_size in [1, 5]: + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + dummy_input = dummy_input[:batch_size] + if dummy_input.shape[0] != batch_size: + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + extension = torch.rand( + batch_size - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + else: + extension = torch.randint( + high=5, + size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + + if not use_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + if is_encoder_decoder: + seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + else: + seqlen = dummy_input.shape[-1] + dummy_attention_mask = ( + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + ) + + dummy_attention_mask = dummy_attention_mask[:batch_size] + if dummy_attention_mask.shape[0] != batch_size: + extension = torch.ones( + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, + device=torch_device, + ) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + dummy_attention_mask = dummy_attention_mask.to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :-1] = 1 + dummy_attention_mask[-1, -4:] = 0 + elif padding_side == "right": + dummy_attention_mask[-1, 1:] = 1 + dummy_attention_mask[-1, :3] = 0 + + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" + if is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[ + :batch_size + ] + if decoder_input_ids.shape[0] != batch_size: + extension = torch.ones( + batch_size - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) + + # TODO: never an `attention_mask` arg here? + processed_inputs = { + model.main_input_name: dummy_input, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + else: + processed_inputs = { + model.main_input_name: dummy_input, + "output_hidden_states": True, + } + + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + processed_inputs["attention_mask"] = dummy_attention_mask + + if ( + self.has_attentions + and "output_attentions" in inspect.signature(model_sdpa.forward).parameters + ): + processed_inputs["output_attentions"] = output_attentions + if not deactivate_mask and ( + "bool_masked_pos" in inspect.signature(model_eager.forward).parameters + ): + dummy_mask = torch.ones((self.model_tester.num_masks,)) + + # In case of additional token (like class) we define a custom `mask_length` + if hasattr(self.model_tester, "mask_length"): + mask_length = self.model_tester.mask_length - dummy_mask.size(0) + else: + mask_length = self.model_tester.seq_length - dummy_mask.size(0) + dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) + + if "noise" in inspect.signature(model_eager.forward).parameters: + np.random.seed(2) + num_patches = int( + (self.model_tester.image_size // self.model_tester.patch_size) ** 2 + ) + noise = np.random.uniform(size=(batch_size, num_patches)) + processed_inputs["noise"] = torch.from_numpy(noise) + + # TODO: test gradients as well (& for FA2 as well!) + with torch.no_grad(): + with torch.backends.cuda.sdp_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + + # Ignore copy + logits_eager = outputs_eager.audio_values + # Ignore copy + logits_sdpa = outputs_sdpa.audio_values + + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. + if use_mask: + if padding_side == "left": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, :-4] + sub_eager = logits_eager[-1, :-4] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, -4:] + # sub_eager = logits_eager[-1, -4:] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + elif padding_side == "right": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, 3:] + sub_eager = logits_eager[-1, 3:] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, :3] + # sub_eager = logits_eager[-1, :3] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + + else: + if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + outputs = model(dummy_input) + outputs_fa = model_fa(dummy_input) + + logits = outputs[1] + logits_fa = outputs_fa[1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + @unittest.skip(reason="The MimiModel does not support right padding") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip(reason="The MimiModel does not have support dynamic compile yet") + def test_sdpa_can_compile_dynamic(self): + pass + + # For now, Let's focus only on GPU for `torch.compile` + @slow + @require_torch_gpu + def test_torch_compile(self): + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + n_iter = 3 + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.forward = torch.compile(model.forward) + for i in range(n_iter): + _ = model(inputs_dict["input_values"].to(torch_device)) + + +# Copied from transformers.tests.encodec.test_modeling_encodec.normalize +def normalize(arr): + norm = np.linalg.norm(arr) + normalized_arr = arr / norm + return normalized_arr + + +# Copied from transformers.tests.encodec.test_modeling_encodec.compute_rmse +def compute_rmse(arr1, arr2): + arr1_normalized = normalize(arr1) + arr2_normalized = normalize(arr2) + return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean()) + + +@slow +@require_torch +class MimiIntegrationTest(unittest.TestCase): + def test_integration_using_cache_decode(self): + expected_rmse = { + "8": 0.0018785292, + "32": 0.0012330565, + } + + librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + model_id = "kmhf/mimi-test" # TODO(YL): modify once official + + model = MimiModel.from_pretrained(model_id, use_cache=True).to(torch_device) + processor = AutoFeatureExtractor.from_pretrained(model_id) + + librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate)) + audio_sample = librispeech_dummy[-1]["audio"]["array"] + + inputs = processor( + raw_audio=audio_sample, + sampling_rate=processor.sampling_rate, + return_tensors="pt", + ).to(torch_device) + + for num_codebooks, expected_rmse in expected_rmse.items(): + with torch.no_grad(): + # use max bandwith for best possible reconstruction + encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks)) + + audio_codes = encoder_outputs[0] + + decoder_outputs_first_part = model.decode(audio_codes[:, :, : audio_codes.shape[2] // 2]) + decoder_outputs_second_part = model.decode( + audio_codes[:, :, audio_codes.shape[2] // 2 :], + decoder_past_key_values=decoder_outputs_first_part.decoder_past_key_values, + ) + + audio_output_entire_context = model.decode(audio_codes)[0] + audio_output_concat_context = torch.cat( + [decoder_outputs_first_part[0], decoder_outputs_second_part[0]], dim=2 + ) + + # make sure audios are more or less equal + # the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0 + rmse = compute_rmse( + audio_output_concat_context.squeeze().cpu().numpy(), + audio_output_entire_context.squeeze().cpu().numpy(), + ) + self.assertTrue(rmse < 1e-3) + + def test_integration(self): + expected_rmses = { + "8": 0.0018785292, + "32": 0.0012330565, + } + expected_codesums = { + "8": 430423, + "32": 1803071, + } + librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + model_id = "kmhf/mimi-test" # TODO(YL): modify once official + + processor = AutoFeatureExtractor.from_pretrained(model_id) + + librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate)) + audio_sample = librispeech_dummy[-1]["audio"]["array"] + + inputs = processor( + raw_audio=audio_sample, + sampling_rate=processor.sampling_rate, + return_tensors="pt", + ).to(torch_device) + + for use_cache in [False, True]: + model = MimiModel.from_pretrained(model_id, use_cache=use_cache).to(torch_device) + for num_codebooks, expected_rmse in expected_rmses.items(): + with torch.no_grad(): + # use max bandwith for best possible reconstruction + encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks)) + + audio_code_sums = encoder_outputs[0].sum().cpu().item() + + # make sure audio encoded codes are correct + # assert relative difference less than a threshold, because `audio_code_sums` varies a bit + # depending on torch version + self.assertTrue( + np.abs(audio_code_sums - expected_codesums[num_codebooks]) <= (3e-3 * audio_code_sums) + ) + + input_values_dec = model.decode(encoder_outputs[0], padding_mask=inputs["padding_mask"])[0] + input_values_enc_dec = model( + inputs["input_values"], inputs["padding_mask"], num_quantizers=int(num_codebooks) + )[1] + + # make sure forward and decode gives same result + self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec)) + + # make sure shape matches + self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape) + + arr = inputs["input_values"][0].cpu().numpy() + arr_enc_dec = input_values_enc_dec[0].cpu().numpy() + + # make sure audios are more or less equal + # the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0 + rmse = compute_rmse(arr, arr_enc_dec) + self.assertTrue(np.abs(rmse - expected_rmse) < 1e-5) From a544d27fd73f0e16b9a39cfa606027741bfdcb7e Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Wed, 18 Sep 2024 09:39:38 +0200 Subject: [PATCH 2/8] some nits suggestions from Arthur --- docs/source/en/model_doc/mimi.md | 4 +- .../models/mimi/configuration_mimi.py | 8 +- src/transformers/models/mimi/modeling_mimi.py | 115 +++++++++--------- tests/models/mimi/test_modeling_mimi.py | 4 +- 4 files changed, 65 insertions(+), 66 deletions(-) diff --git a/docs/source/en/model_doc/mimi.md b/docs/source/en/model_doc/mimi.md index d8082ba6893719..d9918b4e401e2d 100644 --- a/docs/source/en/model_doc/mimi.md +++ b/docs/source/en/model_doc/mimi.md @@ -44,8 +44,8 @@ Here is a quick example of how to encode and decode an audio using this model: >>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> # load model and feature extractor ->>> model = MimiModel.from_pretrained("kmhf/mimi-test") # TODO(YL): modify once official ->>> feature_extractor = AutoFeatureExtractor.from_pretrained("kmhf/mimi-test") +>>> model = MimiModel.from_pretrained("kmhf/mimi") # TODO(YL): modify once official +>>> feature_extractor = AutoFeatureExtractor.from_pretrained("kmhf/mimi") >>> # load audio sample >>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) diff --git a/src/transformers/models/mimi/configuration_mimi.py b/src/transformers/models/mimi/configuration_mimi.py index 5706881c10e7b7..fe941e3f859e64 100644 --- a/src/transformers/models/mimi/configuration_mimi.py +++ b/src/transformers/models/mimi/configuration_mimi.py @@ -30,7 +30,7 @@ class MimiConfig(PretrainedConfig): This is the configuration class to store the configuration of an [`MimiModel`]. It is used to instantiate a Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the - [facebook/mimi_24khz](https://huggingface.co/facebook/mimi_24khz) architecture. + [kmhf/mimi](https://huggingface.co/kmhf/mimi) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -65,7 +65,7 @@ class MimiConfig(PretrainedConfig): pad_mode (`str`, *optional*, defaults to `"constant"`): Padding mode for the convolutions. compress (`int`, *optional*, defaults to 2): - Reduced dimensionality in residual branches (from Demucs v3). + Reduced dimensionality in residual branches. trim_right_ratio (`float`, *optional*, defaults to 1.0): Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If equal to 1.0, it means that all the trimming is done at the right. @@ -126,10 +126,10 @@ class MimiConfig(PretrainedConfig): ```python >>> from transformers import MimiModel, MimiConfig - >>> # Initializing a "facebook/mimi_24khz" style configuration + >>> # Initializing a "kmhf/mimi" style configuration >>> configuration = MimiConfig() - >>> # Initializing a model (with random weights) from the "facebook/mimi_24khz" style configuration + >>> # Initializing a model (with random weights) from the "kmhf/mimi" style configuration >>> model = MimiModel(configuration) >>> # Accessing the model configuration diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 3a5afa6319c959..0eae5cd5e238a2 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved. +# Copyright 2024 Kyutai, and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -215,6 +215,10 @@ def __init__( self.register_buffer("stride", stride, persistent=False) self.register_buffer("kernel_size", kernel_size, persistent=False) self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False) + + # Asymmetric padding required for odd strides + self.padding_right = self.padding_total // 2 + self.padding_left = self.padding_total - self.padding_right def apply_weight_norm(self): weight_norm = nn.utils.weight_norm @@ -266,11 +270,8 @@ def forward(self, hidden_states): # Left padding for causal hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode) else: - # Asymmetric padding required for odd strides - padding_right = self.padding_total // 2 - padding_left = self.padding_total - padding_right hidden_states = self._pad1d( - hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode + hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode ) hidden_states = self.conv(hidden_states) @@ -294,27 +295,14 @@ def __init__( self.causal = config.use_causal_conv self.trim_right_ratio = config.trim_right_ratio self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias) - + if not (self.causal or self.trim_right_ratio == 1.0): raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions") - - def apply_weight_norm(self): - weight_norm = nn.utils.weight_norm - if hasattr(nn.utils.parametrizations, "weight_norm"): - weight_norm = nn.utils.parametrizations.weight_norm - - weight_norm(self.conv) - - def remove_weight_norm(self): - nn.utils.remove_weight_norm(self.conv) - - def forward(self, hidden_states): + kernel_size = self.conv.kernel_size[0] stride = self.conv.stride[0] padding_total = kernel_size - stride - hidden_states = self.conv(hidden_states) - # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be # removed at the very end, when keeping only the right length for the output, # as removing it here would require also passing the length at the matching layer @@ -322,16 +310,29 @@ def forward(self, hidden_states): if self.causal: # Trim the padding on the right according to the specified ratio # if trim_right_ratio = 1.0, trim everything from right - padding_right = math.ceil(padding_total * self.trim_right_ratio) + self.padding_right = math.ceil(padding_total * self.trim_right_ratio) else: # Asymmetric padding required for odd strides - padding_right = padding_total // 2 + self.padding_right = padding_total // 2 - padding_left = padding_total - padding_right + self.padding_left = padding_total - padding_right + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) # unpad - end = hidden_states.shape[-1] - padding_right - hidden_states = hidden_states[..., padding_left:end] + end = hidden_states.shape[-1] - self.padding_right + hidden_states = hidden_states[..., self.padding_left:end] return hidden_states @@ -510,7 +511,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi +# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi class MimiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -612,7 +613,7 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi +# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi class MimiFlashAttention2(MimiAttention): """ Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays @@ -723,7 +724,7 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi +# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi class MimiSdpaAttention(MimiAttention): """ Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -1452,32 +1453,30 @@ def __init__(self, config: MimiConfig): self.encoder = MimiEncoder(config) self.encoder_transformer = MimiTransformerModel(config) - self.downsample = ( - MimiConv1d( - config, - config.hidden_size, - config.hidden_size, - kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), - stride=2, - bias=False, - pad_mode="replicate", - ) - if config.frame_rate != config.encodec_frame_rate - else None - ) - self.upsample = ( - MimiConvTranspose1d( - config, - config.hidden_size, - config.hidden_size, - kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), - stride=2, - bias=False, - groups=config.upsample_groups, - ) - if config.frame_rate != config.encodec_frame_rate - else None - ) + + self.downsample = None + self.upsample = None + if config.frame_rate != config.encodec_frame_rate: + self.downsample = MimiConv1d( + config, + config.hidden_size, + config.hidden_size, + kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), + stride=2, + bias=False, + pad_mode="replicate", + ) + + self.upsample = MimiConvTranspose1d( + config, + config.hidden_size, + config.hidden_size, + kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), + stride=2, + bias=False, + groups=config.upsample_groups, + ) + self.decoder_transformer = MimiTransformerModel(config) self.decoder = MimiDecoder(config) @@ -1675,16 +1674,16 @@ def forward( ```python >>> from datasets import load_dataset - >>> from transformers import AutoProcessor, MimiModel + >>> from transformers import AutoFeatureExtractor, MimiModel >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") >>> audio_sample = dataset["train"]["audio"][0]["array"] - >>> model_id = "facebook/mimi_24khz" + >>> model_id = "kmhf/mimi" >>> model = MimiModel.from_pretrained(model_id) - >>> processor = AutoProcessor.from_pretrained(model_id) + >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) - >>> inputs = processor(raw_audio=audio_sample, return_tensors="pt") + >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt") >>> outputs = model(**inputs) >>> audio_codes = outputs.audio_codes diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py index 9ba0f8bf4509da..1f44211dde4a22 100644 --- a/tests/models/mimi/test_modeling_mimi.py +++ b/tests/models/mimi/test_modeling_mimi.py @@ -787,7 +787,7 @@ def test_integration_using_cache_decode(self): } librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - model_id = "kmhf/mimi-test" # TODO(YL): modify once official + model_id = "kmhf/mimi" # TODO(YL): modify once official model = MimiModel.from_pretrained(model_id, use_cache=True).to(torch_device) processor = AutoFeatureExtractor.from_pretrained(model_id) @@ -837,7 +837,7 @@ def test_integration(self): "32": 1803071, } librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - model_id = "kmhf/mimi-test" # TODO(YL): modify once official + model_id = "kmhf/mimi" # TODO(YL): modify once official processor = AutoFeatureExtractor.from_pretrained(model_id) From 502865ff86b83fbf714d543a414ecdd9e303b521 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Wed, 18 Sep 2024 09:45:11 +0200 Subject: [PATCH 3/8] make fixup --- src/transformers/models/mimi/modeling_mimi.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 0eae5cd5e238a2..746262b01ca834 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -215,7 +215,7 @@ def __init__( self.register_buffer("stride", stride, persistent=False) self.register_buffer("kernel_size", kernel_size, persistent=False) self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False) - + # Asymmetric padding required for odd strides self.padding_right = self.padding_total // 2 self.padding_left = self.padding_total - self.padding_right @@ -295,10 +295,10 @@ def __init__( self.causal = config.use_causal_conv self.trim_right_ratio = config.trim_right_ratio self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias) - + if not (self.causal or self.trim_right_ratio == 1.0): raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions") - + kernel_size = self.conv.kernel_size[0] stride = self.conv.stride[0] padding_total = kernel_size - stride @@ -315,7 +315,7 @@ def __init__( # Asymmetric padding required for odd strides self.padding_right = padding_total // 2 - self.padding_left = padding_total - padding_right + self.padding_left = padding_total - self.padding_right def apply_weight_norm(self): weight_norm = nn.utils.weight_norm @@ -332,7 +332,7 @@ def forward(self, hidden_states): # unpad end = hidden_states.shape[-1] - self.padding_right - hidden_states = hidden_states[..., self.padding_left:end] + hidden_states = hidden_states[..., self.padding_left : end] return hidden_states @@ -1453,29 +1453,29 @@ def __init__(self, config: MimiConfig): self.encoder = MimiEncoder(config) self.encoder_transformer = MimiTransformerModel(config) - + self.downsample = None self.upsample = None - if config.frame_rate != config.encodec_frame_rate: + if config.frame_rate != config.encodec_frame_rate: self.downsample = MimiConv1d( - config, - config.hidden_size, - config.hidden_size, - kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), - stride=2, - bias=False, - pad_mode="replicate", - ) + config, + config.hidden_size, + config.hidden_size, + kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), + stride=2, + bias=False, + pad_mode="replicate", + ) self.upsample = MimiConvTranspose1d( - config, - config.hidden_size, - config.hidden_size, - kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), - stride=2, - bias=False, - groups=config.upsample_groups, - ) + config, + config.hidden_size, + config.hidden_size, + kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), + stride=2, + bias=False, + groups=config.upsample_groups, + ) self.decoder_transformer = MimiTransformerModel(config) self.decoder = MimiDecoder(config) From c858321648644a0ddc28c2bec02511ebea49c31b Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Fri, 13 Sep 2024 12:10:46 +0200 Subject: [PATCH 4/8] first moshi WIP --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/moshi.md | 53 + src/transformers/__init__.py | 24 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/moshi/__init__.py | 61 + .../models/moshi/configuration_moshi.py | 255 ++ .../moshi/convert_moshi_transformers.py | 236 ++ .../models/moshi/modeling_moshi.py | 2683 +++++++++++++++++ tests/models/moshi/__init__.py | 0 tests/models/moshi/test_modeling_moshi.py | 2587 ++++++++++++++++ 13 files changed, 5908 insertions(+) create mode 100644 docs/source/en/model_doc/moshi.md create mode 100644 src/transformers/models/moshi/__init__.py create mode 100644 src/transformers/models/moshi/configuration_moshi.py create mode 100644 src/transformers/models/moshi/convert_moshi_transformers.py create mode 100644 src/transformers/models/moshi/modeling_moshi.py create mode 100644 tests/models/moshi/__init__.py create mode 100644 tests/models/moshi/test_modeling_moshi.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9d997247b722a5..e2a6757ab855f9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -724,6 +724,8 @@ title: Mimi - local: model_doc/mms title: MMS + - local: model_doc/moshi + title: Moshi - local: model_doc/musicgen title: MusicGen - local: model_doc/musicgen_melody diff --git a/docs/source/en/model_doc/moshi.md b/docs/source/en/model_doc/moshi.md new file mode 100644 index 00000000000000..3958a4dc5c1104 --- /dev/null +++ b/docs/source/en/model_doc/moshi.md @@ -0,0 +1,53 @@ + + +# Moshi + +## Overview + +The Moshi model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## MoshiConfig + +[[autodoc]] MoshiConfig + +## MoshiModel + +[[autodoc]] MoshiModel + - forward + +## MoshiForCausalLM + +[[autodoc]] MoshiForCausalLM + - forward + +## MoshiForConditionalGeneration + +[[autodoc]] MoshiForConditionalGeneration + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e73e5e66a99595..049cd65b535e87 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -595,6 +595,10 @@ "MusicgenConfig", "MusicgenDecoderConfig", ], + "models.moshi": [ + "MoshiConfig", + "MoshiDecoderConfig", + ], "models.musicgen_melody": [ "MusicgenMelodyConfig", "MusicgenMelodyDecoderConfig", @@ -2788,6 +2792,15 @@ "MusicgenProcessor", ] ) + _import_structure["models.moshi"].extend( + [ + "MoshiForCausalLM", + "MoshiForConditionalGeneration", + "MoshiModel", + "MoshiPreTrainedModel", + "MoshiProcessor", + ] + ) _import_structure["models.musicgen_melody"].extend( [ "MusicgenMelodyForCausalLM", @@ -5380,6 +5393,10 @@ MusicgenConfig, MusicgenDecoderConfig, ) + from .models.moshi import ( + MoshiConfig, + MoshiDecoderConfig, + ) from .models.musicgen_melody import ( MusicgenMelodyConfig, MusicgenMelodyDecoderConfig, @@ -7310,6 +7327,13 @@ MusicgenPreTrainedModel, MusicgenProcessor, ) + from .models.moshi import ( + MoshiForCausalLM, + MoshiForConditionalGeneration, + MoshiModel, + MoshiPreTrainedModel, + MoshiProcessor, + ) from .models.musicgen_melody import ( MusicgenMelodyForCausalLM, MusicgenMelodyForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 358fd12ebf222c..37cda7bb29327e 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -163,6 +163,7 @@ mra, mt5, musicgen, + moshi, musicgen_melody, mvp, nemotron, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 15dbc15206fe97..7a330925419484 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -180,6 +180,7 @@ ("mra", "MraConfig"), ("mt5", "MT5Config"), ("musicgen", "MusicgenConfig"), + ("moshi", "MoshiConfig"), ("musicgen_melody", "MusicgenMelodyConfig"), ("mvp", "MvpConfig"), ("nat", "NatConfig"), @@ -483,6 +484,7 @@ ("mra", "MRA"), ("mt5", "MT5"), ("musicgen", "MusicGen"), + ("moshi", "Moshi"), ("musicgen_melody", "MusicGen Melody"), ("mvp", "MVP"), ("nat", "NAT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 6feb992daf6464..c1dd6803ef22f0 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -171,6 +171,7 @@ ("mra", "MraModel"), ("mt5", "MT5Model"), ("musicgen", "MusicgenModel"), + ("moshi", "MoshiModel"), ("musicgen_melody", "MusicgenMelodyModel"), ("mvp", "MvpModel"), ("nat", "NatModel"), @@ -497,6 +498,7 @@ ("mixtral", "MixtralForCausalLM"), ("mpt", "MptForCausalLM"), ("musicgen", "MusicgenForCausalLM"), + ("moshi", "MoshiForCausalLM"), ("musicgen_melody", "MusicgenMelodyForCausalLM"), ("mvp", "MvpForCausalLM"), ("nemotron", "NemotronForCausalLM"), @@ -1260,6 +1262,7 @@ ("bark", "BarkModel"), ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"), ("musicgen", "MusicgenForConditionalGeneration"), + ("moshi", "MoshiForConditionalGeneration"), ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"), ("seamless_m4t", "SeamlessM4TForTextToSpeech"), ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c8eb06db04a098..c73cb9a3d9b451 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -318,6 +318,7 @@ ), ), ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), + ("moshi", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)), ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/moshi/__init__.py b/src/transformers/models/moshi/__init__.py new file mode 100644 index 00000000000000..6d5fe355351744 --- /dev/null +++ b/src/transformers/models/moshi/__init__.py @@ -0,0 +1,61 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_moshi": [ + "MoshiConfig", + "MoshiDecoderConfig", + ], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_moshi"] = [ + "MoshiForConditionalGeneration", + "MoshiForCausalLM", + "MoshiModel", + "MoshiPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_moshi import ( + MoshiConfig, + MoshiDecoderConfig, + ) + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_moshi import ( + MoshiForCausalLM, + MoshiForConditionalGeneration, + MoshiModel, + MoshiPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/moshi/configuration_moshi.py b/src/transformers/models/moshi/configuration_moshi.py new file mode 100644 index 00000000000000..f6ea2139032825 --- /dev/null +++ b/src/transformers/models/moshi/configuration_moshi.py @@ -0,0 +1,255 @@ +# coding=utf-8 +# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Moshi model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import AutoConfig + + +logger = logging.get_logger(__name__) + + +class MoshiDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`MoshiDecoder`]. It is used to instantiate a + Moshi decoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Moshi + [kyutai/moshiko](https://huggingface.co/kyutai/moshiko) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 2048): + Vocabulary size of the MoshiDecoder model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`MoshiDecoder`]. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of decoder layers. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer block. + ffn_dim (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically, set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_factor (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + scale_embedding (`bool`, *optional*, defaults to `False`): + Scale embeddings by diving by sqrt(hidden_size). + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values attentions (not used by all models) + num_codebooks (`int`, *optional*, defaults to 4): + The number of parallel codebooks forwarded to the model. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether input and output word embeddings should be tied. + audio_channels (`int`, *optional*, defaults to 1 + Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate + audio stream for the left/right output channels. Mono models generate a single audio stream output. + """ + + model_type = "moshi_decoder" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=2048, + max_position_embeddings=2048, + num_hidden_layers=24, + ffn_dim=4096, + num_attention_heads=16, + layerdrop=0.0, + use_cache=True, + activation_function="gelu", + hidden_size=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + initializer_factor=0.02, + scale_embedding=False, + num_codebooks=4, + audio_channels=1, + pad_token_id=2048, + bos_token_id=2048, + eos_token_id=None, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.ffn_dim = ffn_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.initializer_factor = initializer_factor + self.layerdrop = layerdrop + self.use_cache = use_cache + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.num_codebooks = num_codebooks + + if audio_channels not in [1, 2]: + raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.") + self.audio_channels = audio_channels + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class MoshiConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MoshiModel`]. It is used to instantiate a + Moshi model according to the specified arguments, defining the text encoder, audio encoder and Moshi decoder + configs. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + + - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the text encoder config. + - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the audio encoder config. + - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines + the decoder config. + + Example: + + ```python + >>> from transformers import ( + ... MoshiConfig, + ... MoshiDecoderConfig, + ... T5Config, + ... EncodecConfig, + ... MoshiForConditionalGeneration, + ... ) + + >>> # Initializing text encoder, audio encoder, and decoder model configurations + >>> text_encoder_config = T5Config() + >>> audio_encoder_config = EncodecConfig() + >>> decoder_config = MoshiDecoderConfig() + + >>> configuration = MoshiConfig.from_sub_models_config( + ... text_encoder_config, audio_encoder_config, decoder_config + ... ) + + >>> # Initializing a MoshiForConditionalGeneration (with random weights) from the kyutai/moshiko style configuration + >>> model = MoshiForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + >>> config_text_encoder = model.config.text_encoder + >>> config_audio_encoder = model.config.audio_encoder + >>> config_decoder = model.config.decoder + + >>> # Saving the model, including its configuration + >>> model.save_pretrained("moshi-model") + + >>> # loading model and config from pretrained folder + >>> moshi_config = MoshiConfig.from_pretrained("moshi-model") + >>> model = MoshiForConditionalGeneration.from_pretrained("moshi-model", config=moshi_config) + ```""" + + model_type = "moshi" + is_composition = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs: + raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config") + + text_encoder_config = kwargs.pop("text_encoder") + text_encoder_model_type = text_encoder_config.pop("model_type") + + audio_encoder_config = kwargs.pop("audio_encoder") + audio_encoder_model_type = audio_encoder_config.pop("model_type") + + decoder_config = kwargs.pop("decoder") + + self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config) + self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) + self.decoder = MoshiDecoderConfig(**decoder_config) + self.is_encoder_decoder = True + + @classmethod + def from_sub_models_config( + cls, + text_encoder_config: PretrainedConfig, + audio_encoder_config: PretrainedConfig, + decoder_config: MoshiDecoderConfig, + **kwargs, + ): + r""" + Instantiate a [`MoshiConfig`] (or a derived class) from text encoder, audio encoder and decoder + configurations. + + Returns: + [`MoshiConfig`]: An instance of a configuration object + """ + + return cls( + text_encoder=text_encoder_config.to_dict(), + audio_encoder=audio_encoder_config.to_dict(), + decoder=decoder_config.to_dict(), + **kwargs, + ) + + @property + # This is a property because you might want to change the codec model on the fly + def sampling_rate(self): + return self.audio_encoder.sampling_rate + + @property + def _attn_implementation(self): + # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + if hasattr(self, "_attn_implementation_internal"): + if self._attn_implementation_internal is None: + # `config.attn_implementation` should never be None, for backward compatibility. + return "eager" + else: + return self._attn_implementation_internal + else: + return "eager" + + @_attn_implementation.setter + def _attn_implementation(self, value): + self._attn_implementation_internal = value + self.decoder._attn_implementation = value diff --git a/src/transformers/models/moshi/convert_moshi_transformers.py b/src/transformers/models/moshi/convert_moshi_transformers.py new file mode 100644 index 00000000000000..7b146c44dfe647 --- /dev/null +++ b/src/transformers/models/moshi/convert_moshi_transformers.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Moshi checkpoints from the original repository.""" + +import argparse +from pathlib import Path +from typing import Dict, OrderedDict, Tuple + +import torch +from audiocraft.models import Moshi + +from transformers import ( + AutoFeatureExtractor, + AutoTokenizer, + EncodecModel, + MoshiDecoderConfig, + MoshiForConditionalGeneration, + MoshiProcessor, + T5EncoderModel, +) +from transformers.models.moshi.modeling_moshi import MoshiForCausalLM +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +EXPECTED_MISSING_KEYS = ["model.decoder.embed_positions.weights"] + + +def rename_keys(name): + if "emb" in name: + name = name.replace("emb", "model.decoder.embed_tokens") + if "transformer" in name: + name = name.replace("transformer", "model.decoder") + if "cross_attention" in name: + name = name.replace("cross_attention", "encoder_attn") + if "linear1" in name: + name = name.replace("linear1", "fc1") + if "linear2" in name: + name = name.replace("linear2", "fc2") + if "norm1" in name: + name = name.replace("norm1", "self_attn_layer_norm") + if "norm_cross" in name: + name = name.replace("norm_cross", "encoder_attn_layer_norm") + if "norm2" in name: + name = name.replace("norm2", "final_layer_norm") + if "out_norm" in name: + name = name.replace("out_norm", "model.decoder.layer_norm") + if "linears" in name: + name = name.replace("linears", "lm_heads") + if "condition_provider.conditioners.description.output_proj" in name: + name = name.replace("condition_provider.conditioners.description.output_proj", "enc_to_dec_proj") + return name + + +def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, Dict]: + """Function that takes the fairseq Moshi state dict and renames it according to the HF + module names. It further partitions the state dict into the decoder (LM) state dict, and that for the + encoder-decoder projection.""" + keys = list(state_dict.keys()) + enc_dec_proj_state_dict = {} + for key in keys: + val = state_dict.pop(key) + key = rename_keys(key) + if "in_proj_weight" in key: + # split fused qkv proj + state_dict[key.replace("in_proj_weight", "q_proj.weight")] = val[:hidden_size, :] + state_dict[key.replace("in_proj_weight", "k_proj.weight")] = val[hidden_size : 2 * hidden_size, :] + state_dict[key.replace("in_proj_weight", "v_proj.weight")] = val[-hidden_size:, :] + elif "enc_to_dec_proj" in key: + enc_dec_proj_state_dict[key[len("enc_to_dec_proj.") :]] = val + else: + state_dict[key] = val + return state_dict, enc_dec_proj_state_dict + + +def decoder_config_from_checkpoint(checkpoint: str) -> MoshiDecoderConfig: + if checkpoint.endswith("small"): + # default config values + hidden_size = 1024 + num_hidden_layers = 24 + num_attention_heads = 16 + elif checkpoint.endswith("medium"): + hidden_size = 1536 + num_hidden_layers = 48 + num_attention_heads = 24 + elif checkpoint.endswith("large"): + hidden_size = 2048 + num_hidden_layers = 48 + num_attention_heads = 32 + else: + raise ValueError( + "Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, " + "`['facebook/moshi-stereo-small', 'facebook/moshi-stereo-medium', 'facebook/moshi-stereo-large']` " + f"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix, got {checkpoint}." + ) + + if "stereo" in checkpoint: + audio_channels = 2 + num_codebooks = 8 + else: + audio_channels = 1 + num_codebooks = 4 + + config = MoshiDecoderConfig( + hidden_size=hidden_size, + ffn_dim=hidden_size * 4, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_codebooks=num_codebooks, + audio_channels=audio_channels, + ) + return config + + +@torch.no_grad() +def convert_moshi_checkpoint( + checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu", safe_serialization=False +): + fairseq_model = Moshi.get_pretrained(checkpoint, device=device) + decoder_config = decoder_config_from_checkpoint(checkpoint) + + decoder_state_dict = fairseq_model.lm.state_dict() + decoder_state_dict, enc_dec_proj_state_dict = rename_state_dict( + decoder_state_dict, hidden_size=decoder_config.hidden_size + ) + + text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-base") + audio_encoder = EncodecModel.from_pretrained("facebook/encodec_32khz") + decoder = MoshiForCausalLM(decoder_config).eval() + + # load all decoder weights - expect that we'll be missing embeddings and enc-dec projection + missing_keys, unexpected_keys = decoder.load_state_dict(decoder_state_dict, strict=False) + + for key in missing_keys.copy(): + if key.startswith(("text_encoder", "audio_encoder")) or key in EXPECTED_MISSING_KEYS: + missing_keys.remove(key) + + if len(missing_keys) > 0: + raise ValueError(f"Missing key(s) in state_dict: {missing_keys}") + + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected key(s) in state_dict: {unexpected_keys}") + + # init the composite model + model = MoshiForConditionalGeneration(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder) + + # load the pre-trained enc-dec projection (from the decoder state dict) + model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict) + + # check we can do a forward pass + input_ids = torch.arange(0, 2 * decoder_config.num_codebooks, dtype=torch.long).reshape(2, -1) + decoder_input_ids = input_ids.reshape(2 * decoder_config.num_codebooks, -1) + + with torch.no_grad(): + logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits + + if logits.shape != (2 * decoder_config.num_codebooks, 1, 2048): + raise ValueError("Incorrect shape for logits") + + # now construct the processor + tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") + feature_extractor = AutoFeatureExtractor.from_pretrained( + "facebook/encodec_32khz", padding_side="left", feature_size=decoder_config.audio_channels + ) + + processor = MoshiProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer) + + # set the appropriate bos/pad token ids + model.generation_config.decoder_start_token_id = 2048 + model.generation_config.pad_token_id = 2048 + + # set other default generation config params + model.generation_config.max_length = int(30 * audio_encoder.config.frame_rate) + model.generation_config.do_sample = True + model.generation_config.guidance_scale = 3.0 + + if pytorch_dump_folder is not None: + Path(pytorch_dump_folder).mkdir(exist_ok=True) + logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}") + model.save_pretrained(pytorch_dump_folder, safe_serialization=safe_serialization) + processor.save_pretrained(pytorch_dump_folder) + + if repo_id: + logger.info(f"Pushing model {checkpoint} to {repo_id}") + model.push_to_hub(repo_id, safe_serialization=safe_serialization) + processor.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint", + default="small", + type=str, + help="Checkpoint size of the Moshi model you'd like to convert. Can be one of: " + "`['small', 'medium', 'large']` for the mono checkpoints, " + "`['facebook/moshi-stereo-small', 'facebook/moshi-stereo-medium', 'facebook/moshi-stereo-large']` " + "for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix.", + ) + parser.add_argument( + "--pytorch_dump_folder", + required=True, + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." + ) + parser.add_argument( + "--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda." + ) + parser.add_argument( + "--safe_serialization", + action="store_true", + help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).", + ) + + args = parser.parse_args() + convert_moshi_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py new file mode 100644 index 00000000000000..0ee858fd0b51a5 --- /dev/null +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -0,0 +1,2683 @@ +# coding=utf-8 +# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Moshi model.""" + +import copy +import inspect +import math +import random +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...generation.configuration_utils import GenerationConfig, GenerationMode +from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList +from ...generation.stopping_criteria import StoppingCriteriaList +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + ModelOutput, + Seq2SeqLMOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .configuration_moshi import MoshiConfig, MoshiDecoderConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +if TYPE_CHECKING: + from ...generation.streamers import BaseStreamer + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MoshiConfig" +_CHECKPOINT_FOR_DOC = "kyutai/moshiko" + + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + # transpose to get (bsz, num_codebooks, seq_len) + input_ids = input_ids.transpose(1, 2) + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi +class MoshiRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + # TODO(joao): add me back asap :) + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MoshiMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + # Copied from transformers.models.clip.modeling_clip.CLIPMLP.forward + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Moshi +class MoshiAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = MoshiRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi +class MoshiFlashAttention2(MoshiAttention): + """ + Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (MoshiRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi +class MoshiSdpaAttention(MoshiAttention): + """ + Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MoshiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MoshiAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MoshiModel is using MoshiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MOSHI_ATTENTION_CLASSES = { + "eager": MoshiAttention, + "flash_attention_2": MoshiFlashAttention2, + "sdpa": MoshiSdpaAttention, +} + +class MoshiDecoderLayer(nn.Module): + def __init__(self, config: MoshiConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = MoshiMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + + +# Copied from transformers.models.mimi.modeling_mimi.MimiTransformerModel with Mimi->Moshi, TransformerModel->Decoder, TransformerLayer->DecoderLayer +class MoshiDecoder(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoshiTransformerLayer`] + + Args: + config: MoshiConfig + """ + + def __init__(self, config: MoshiConfig): + super().__init__() + + self.layers = nn.ModuleList( + [MoshiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Embedded representation that will be contextualized by the model + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache) and not self.training: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + return_legacy_cache = True + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, use_cache, output_attentions + ) + + hidden_states = hidden_states + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_cache: bool, + output_attentions: bool, + ): + if self._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + + # cache_position must be valid here no matter which cache we use + past_seen_tokens = cache_position[0] if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache + if using_sliding_window_cache: + target_length = max(sequence_length, self.config.sliding_window) + # StaticCache + elif using_static_cache: + target_length = past_key_values.get_max_length() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if self.config.sliding_window is not None: + if not using_sliding_window_cache or sequence_length > self.config.sliding_window: + exclude_mask.bitwise_or_( + torch.arange(target_length, device=device) + <= (cache_position.reshape(-1, 1) - self.config.sliding_window) + ) + causal_mask *= exclude_mask + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + + + +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->Moshi +class MoshiPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MoshiDecoderConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MoshiDecoderLayer", "MoshiAttention"] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.initializer_factor + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MOSHI_START_DOCSTRING = r""" + + The Moshi model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by + Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an + encoder decoder transformer trained on the task of conditional music generation + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MoshiConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MOSHI_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. + + Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, + such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + + + The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks, + target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If + you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of + frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, + target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as + `decoder_input_ids`. + + + + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +MOSHI_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): + Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. + + Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, + such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + + + + The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, + target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If + you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of + frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, + target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as + `input_ids`. + + + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + + +@add_start_docstrings( + "The bare Moshi decoder model outputting raw hidden-states without any specific head on top.", + MOSHI_START_DOCSTRING, +) +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with MUSICGEN->MOSHI,Musicgen->Moshi +class MoshiModel(MoshiPreTrainedModel): + def __init__(self, config: MoshiDecoderConfig): + super().__init__(config) + self.decoder = MoshiDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + "The Moshi decoder model with a language modelling head on top.", + MOSHI_START_DOCSTRING, +) +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with MUSICGEN->MOSHI,Musicgen->Moshi,musicgen->moshi,MusicGen->Moshi +class MoshiForCausalLM(MoshiPreTrainedModel): + def __init__(self, config: MoshiDecoderConfig): + super().__init__(config) + + self.model = MoshiModel(config) + + self.num_codebooks = config.num_codebooks + self.lm_heads = nn.ModuleList( + [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_heads + + def set_output_embeddings(self, new_embeddings): + self.lm_heads = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + Returns: + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (labels is not None) and (input_ids is None and inputs_embeds is None): + input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1) + + loss = None + if labels is not None: + # since encoder hidden states have been concatenated to the decoder hidden states, + # we take the last timestamps corresponding to labels + logits = lm_logits[:, :, -labels.shape[1] :] + + loss_fct = CrossEntropyLoss() + loss = torch.zeros([], device=self.device) + + # per codebook cross-entropy + # -100 labels are ignored + labels = labels.masked_fill(labels == self.config.pad_token_id, -100) + + # per codebook cross-entropy + # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/moshi.py#L242-L243 + for codebook in range(self.config.num_codebooks): + codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) + codebook_labels = labels[..., codebook].contiguous().view(-1) + loss += loss_fct(codebook_logits, codebook_labels) + + loss = loss / self.config.num_codebooks + + # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) + lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=True, + delay_pattern_mask=None, + guidance_scale=None, + **kwargs, + ): + if delay_pattern_mask is None: + input_ids, delay_pattern_mask = self.build_delay_pattern_mask( + input_ids, + pad_token_id=self.generation_config.pad_token_id, + max_length=self.generation_config.max_length, + ) + + # apply the delay pattern mask + input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask) + + if guidance_scale is not None and guidance_scale > 1: + # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these + # before sampling) + input_ids = input_ids.repeat((2, 1)) + if attention_mask is not None: + attention_mask = attention_mask.repeat((2, 1)) + + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None): + """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by + one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there + are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, + seq_len)`: + - [P, -1, -1, -1, -1, P, P, P] + - [P, P, -1, -1, -1, -1, P, P] + - [P, P, P, -1, -1, -1, -1, P] + - [P, P, P, P, -1, -1, -1, -1] + where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include + a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the + mask is set to the value in the prompt: + - [P, a, b, -1, -1, P, P, P] + - [P, P, c, d, -1, -1, P, P] + - [P, P, P, e, f, -1, -1, P] + - [P, P, P, P, g, h, -1, -1] + where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 + tokens in our prediction. + """ + # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) + input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) + bsz, num_codebooks, seq_len = input_ids.shape + + max_length = max_length if max_length is not None else self.generation_config.max_length + input_ids_shifted = ( + torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1 + ) + + channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks + # we only apply the mask if we have a large enough seq len - otherwise we return as is + if max_length < 2 * channel_codebooks - 1: + return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1) + + # fill the shifted ids with the prompt entries, offset by the codebook idx + for codebook in range(channel_codebooks): + if self.config.audio_channels == 1: + # mono channel - loop over the codebooks one-by-one + input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] + else: + # left/right channels are interleaved in the generated codebooks, so handle one then the other + input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook] + input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1] + + # construct a pattern mask that indicates the positions of padding tokens for each codebook + # first fill the upper triangular part (the EOS padding) + delay_pattern = torch.triu( + torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1 + ) + # then fill the lower triangular part (the BOS padding) + delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool)) + + if self.config.audio_channels == 2: + # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion + delay_pattern = delay_pattern.repeat_interleave(2, dim=0) + + mask = ~delay_pattern.to(input_ids.device) + input_ids = mask * input_ids_shifted + ~mask * pad_token_id + + # find the first position to start generating - this is the first place we have the -1 token + # and will always be in the first codebook (since it has no codebook offset) + first_codebook_ids = input_ids[:, 0, :] + start_ids = (first_codebook_ids == -1).nonzero()[:, 1] + if len(start_ids) > 0: + first_start_id = min(start_ids) + else: + # we have no tokens that need to be filled - return entire matrix of input ids + first_start_id = seq_len + + # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) + pattern_mask = input_ids.reshape(bsz * num_codebooks, -1) + input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) + return input_ids, pattern_mask + + @staticmethod + def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): + """Apply a delay pattern mask to the decoder input ids, only preserving predictions where + the mask is set to -1, and otherwise setting to the value detailed in the mask.""" + seq_len = input_ids.shape[-1] + decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len] + input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask) + return input_ids + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + synced_gpus: Optional[bool] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, + ): + """ + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects + if generation_config is None: + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() + self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + + # 3. Define model inputs` + input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = input_ids.shape[0] // self.num_codebooks + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) + + # 4. Define other model kwargs + model_kwargs["use_cache"] = generation_config.use_cache + model_kwargs["guidance_scale"] = generation_config.guidance_scale + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor + ) + + # 5. Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=input_ids, + input_ids_length=input_ids_length, + ) + + # 6. Prepare `input_ids` which will be used for auto-regressive generation + # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Moshi) + input_ids, delay_pattern_mask = self.build_delay_pattern_mask( + input_ids, + pad_token_id=generation_config._decoder_start_token_tensor, + max_length=generation_config.max_length, + ) + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # stash the delay mask so that we don't have to recompute it in each forward pass + model_kwargs["delay_pattern_mask"] = delay_pattern_mask + + # 7. determine generation mode + generation_mode = generation_config.get_generation_mode() + + # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) + if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: + logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) + generation_config.guidance_scale = None + + # 9. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=None, + logits_processor=logits_processor, + device=input_ids.device, + ) + + # 10. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + + if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + **model_kwargs, + ) + + # 11. run sample + outputs = self._sample( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + else: + raise ValueError( + "Got incompatible mode for generation, should be one of greedy or sampling. " + "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." + ) + + if generation_config.return_dict_in_generate: + output_ids = outputs.sequences + else: + output_ids = outputs + + # apply the pattern mask to the final ids + output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"]) + + # revert the pattern delay mask by filtering the pad token id + output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( + batch_size, self.num_codebooks, -1 + ) + + if generation_config.return_dict_in_generate: + outputs.sequences = output_ids + return outputs + else: + return output_ids + + +@add_start_docstrings( + "The composite Moshi model with a text encoder, audio encoder and Moshi decoder, " + "for music generation tasks with one or both of text and audio prompts.", + MOSHI_START_DOCSTRING, +) +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration with MUSICGEN->MOSHI,Musicgen->Moshi,musicgen->moshi,MusicGen->Moshi,facebook/musicgen-small->kyutai/moshiko +class MoshiForConditionalGeneration(PreTrainedModel): + config_class = MoshiConfig + base_model_prefix = "encoder_decoder" + main_input_name = "input_ids" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__( + self, + config: Optional[MoshiConfig] = None, + text_encoder: Optional[PreTrainedModel] = None, + audio_encoder: Optional[PreTrainedModel] = None, + decoder: Optional[MoshiForCausalLM] = None, + ): + if config is None and (text_encoder is None or audio_encoder is None or decoder is None): + raise ValueError( + "Either a configuration has to be provided, or all three of text encoder, audio encoder and Moshi decoder." + ) + if config is None: + config = MoshiConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"Config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the Moshi decoder's configuration, it has to be equal" + f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for" + " `config.text_encoder.hidden_size`." + ) + + # initialize with config + super().__init__(config) + + if text_encoder is None: + from ..auto.modeling_auto import AutoModelForTextEncoding + + text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder) + + if audio_encoder is None: + from ..auto.modeling_auto import AutoModel + + audio_encoder = AutoModel.from_config(config.audio_encoder) + + if decoder is None: + decoder = MoshiForCausalLM(config.decoder) + + self.text_encoder = text_encoder + self.audio_encoder = audio_encoder + self.decoder = decoder + + if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict(): + logger.warning( + f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:" + f" {self.config.text_encoder}" + ) + if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict(): + logger.warning( + f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:" + f" {self.config.audio_encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" + f" {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.text_encoder.config = self.config.text_encoder + self.audio_encoder.config = self.config.audio_encoder + self.decoder.config = self.config.decoder + + # text encoder outputs might need to be projected to different dimension for decoder + if ( + self.text_encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size) + + if self.text_encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head" + ) + + decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys()) + if "encoder_hidden_states" not in decoder_signature: + raise ValueError( + "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the " + "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" + ) + + # tie text encoder, decoder weights if config set accordingly + self.tie_weights() + + def tie_weights(self): + # tie text encoder & decoder if needed + if self.config.tie_encoder_decoder: + # tie text encoder and decoder base model + decoder_base_model_prefix = self.decoder.base_model_prefix + tied_weights = self._tie_encoder_decoder_weights( + self.text_encoder, + self.decoder._modules[decoder_base_model_prefix], + self.decoder.base_model_prefix, + "text_encoder", + ) + # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class + # attributed not an instance member, therefore modifying it will modify the entire class + # Leading to issues on subsequent calls by different tests or subsequent calls. + self._dynamic_tied_weights_keys = tied_weights + + def get_audio_encoder(self): + return self.audio_encoder + + def get_text_encoder(self): + return self.text_encoder + + def get_encoder(self): + # get the text encoder to compute the encoder hidden-states for generation + return self.get_text_encoder() + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.text_encoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Example: + + ```python + >>> from transformers import MoshiForConditionalGeneration + + >>> model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko") + ```""" + + # At the moment fast initialization is not supported for composite models + if kwargs.get("_fast_init", False): + logger.warning( + "Fast initialization is currently not supported for MoshiForConditionalGeneration. " + "Falling back to slow initialization..." + ) + kwargs["_fast_init"] = False + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + @classmethod + def from_sub_models_pretrained( + cls, + text_encoder_pretrained_model_name_or_path: str = None, + audio_encoder_pretrained_model_name_or_path: str = None, + decoder_pretrained_model_name_or_path: str = None, + *model_args, + **kwargs, + ) -> PreTrainedModel: + r""" + Instantiate a text encoder, an audio encoder, and a Moshi decoder from one, two or three base classes of the + library from pretrained model checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you need to first set it back in training mode with `model.train()`. + + Params: + text_encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the text encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + audio_encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the audio encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + model_args (remaining positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration + parameter. + - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration + parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import MoshiForConditionalGeneration + + >>> # initialize a moshi model from a t5 text encoder, encodec audio encoder, and moshi decoder + >>> model = MoshiForConditionalGeneration.from_sub_models_pretrained( + ... text_encoder_pretrained_model_name_or_path="google-t5/t5-base", + ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz", + ... decoder_pretrained_model_name_or_path="kyutai/moshiko", + ... ) + >>> # saving model after fine-tuning + >>> model.save_pretrained("./moshi-ft") + >>> # load fine-tuned model + >>> model = MoshiForConditionalGeneration.from_pretrained("./moshi-ft") + ```""" + + kwargs_text_encoder = { + argument[len("text_encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("text_encoder_") + } + + kwargs_audio_encoder = { + argument[len("audio_encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("audio_encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove text encoder, audio encoder and decoder kwargs from kwargs + for key in kwargs_text_encoder.keys(): + del kwargs["text_encoder_" + key] + for key in kwargs_audio_encoder.keys(): + del kwargs["audio_encoder_" + key] + for key in kwargs_decoder.keys(): + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + text_encoder = kwargs_text_encoder.pop("model", None) + if text_encoder is None: + if text_encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_text_encoder: + encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained( + text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_text_encoder["config"] = encoder_config + + text_encoder = AutoModel.from_pretrained( + text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder + ) + + audio_encoder = kwargs_audio_encoder.pop("model", None) + if audio_encoder is None: + if audio_encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_audio_encoder: + encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained( + audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_audio_encoder["config"] = encoder_config + + audio_encoder = AutoModel.from_pretrained( + audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder + ) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + + if isinstance(decoder_config, MoshiConfig): + decoder_config = decoder_config.decoder + + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_sub_models_pretrained(...)`" + ) + + decoder = MoshiForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + config = MoshiConfig.from_sub_models_config( + text_encoder.config, audio_encoder.config, decoder.config, **kwargs + ) + return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config) + + @add_start_docstrings_to_model_forward(MOSHI_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + padding_mask: Optional[torch.BoolTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + Returns: + + Examples: + ```python + >>> from transformers import AutoProcessor, MoshiForConditionalGeneration + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("kyutai/moshiko") + >>> model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko") + + >>> inputs = processor( + ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], + ... padding=True, + ... return_tensors="pt", + ... ) + + >>> pad_token_id = model.generation_config.pad_token_id + >>> decoder_input_ids = ( + ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) + ... * pad_token_id + ... ) + + >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits + >>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size) + torch.Size([8, 1, 2048]) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_text_encoder = { + argument[len("text_encoder_")]: value + for argument, value in kwargs.items() + if argument.startswith("text_encoder_") + } + + kwargs_audio_encoder = { + argument[len("audio_encoder_")]: value + for argument, value in kwargs.items() + if argument.startswith("audio_encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None: + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_text_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.text_encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + if attention_mask is not None: + encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id + ) + + elif decoder_input_ids is None and decoder_inputs_embeds is None: + audio_encoder_outputs = self.audio_encoder( + input_values=input_values, + padding_mask=padding_mask, + **kwargs_audio_encoder, + ) + audio_codes = audio_encoder_outputs.audio_codes + frames, bsz, codebooks, seq_len = audio_codes.shape + if frames != 1: + raise ValueError( + f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " + "disabled by setting `chunk_length=None` in the audio encoder." + ) + + if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2: + # mono input through encodec that we convert to stereo + audio_codes = audio_codes.repeat_interleave(2, dim=2) + + decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + labels=labels, + **kwargs_decoder, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=decoder_outputs.loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_attention_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + decoder_delay_pattern_mask=None, + guidance_scale=None, + **kwargs, + ): + if decoder_delay_pattern_mask is None: + decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( + decoder_input_ids, + self.generation_config.pad_token_id, + max_length=self.generation_config.max_length, + ) + + # apply the delay pattern mask + decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) + + if guidance_scale is not None and guidance_scale > 1: + # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these + # before sampling) + decoder_input_ids = decoder_input_ids.repeat((2, 1)) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) + + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = ( + torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device) + * decoder_start_token_id + ) + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + + return decoder_input_ids, model_kwargs + + def _prepare_text_encoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str], + generation_config: GenerationConfig, + ) -> Dict[str, Any]: + # 1. get text encoder + encoder = self.get_text_encoder() + # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device + # as the inputs. + if hasattr(encoder, "_hf_hook"): + encoder._hf_hook.io_same_device = True + + # 2. Prepare encoder args and encoder kwargs from model kwargs. + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + encoder_signature = set(inspect.signature(encoder.forward).parameters) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + if not encoder_accepts_wildcard: + encoder_kwargs = { + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + } + encoder_kwargs["output_attentions"] = generation_config.output_attentions + encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states + guidance_scale = generation_config.guidance_scale + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name + encoder_kwargs["return_dict"] = True + encoder_kwargs[model_input_name] = inputs_tensor + last_hidden_state = encoder(**encoder_kwargs).last_hidden_state + + # for classifier free guidance we need to add a 'null' input to our encoder hidden states + if guidance_scale is not None and guidance_scale > 1: + last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0) + if "attention_mask" in model_kwargs: + model_kwargs["attention_mask"] = torch.concatenate( + [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0 + ) + + model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state) + + return model_kwargs + + def _prepare_audio_encoder_kwargs_for_generation( + self, input_values, model_kwargs, model_input_name: Optional[str] = None + ): + # 1. get audio encoder + encoder = self.get_audio_encoder() + # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device + # as the inputs. + if hasattr(encoder, "_hf_hook"): + encoder._hf_hook.io_same_device = True + + # 2. Prepare encoder args and encoder kwargs from model kwargs. + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + encoder_signature = set(inspect.signature(encoder.forward).parameters) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + if not encoder_accepts_wildcard: + encoder_kwargs = { + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + } + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name + encoder_kwargs["return_dict"] = True + + if self.decoder.config.audio_channels == 1: + encoder_kwargs[model_input_name] = input_values + audio_encoder_outputs = encoder.encode(**encoder_kwargs) + audio_codes = audio_encoder_outputs.audio_codes + audio_scales = audio_encoder_outputs.audio_scales + + frames, bsz, codebooks, seq_len = audio_codes.shape + + else: + if input_values.shape[1] != 2: + raise ValueError( + f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel." + ) + + encoder_kwargs[model_input_name] = input_values[:, :1, :] + audio_encoder_outputs_left = encoder.encode(**encoder_kwargs) + audio_codes_left = audio_encoder_outputs_left.audio_codes + audio_scales_left = audio_encoder_outputs_left.audio_scales + + encoder_kwargs[model_input_name] = input_values[:, 1:, :] + audio_encoder_outputs_right = encoder.encode(**encoder_kwargs) + audio_codes_right = audio_encoder_outputs_right.audio_codes + audio_scales_right = audio_encoder_outputs_right.audio_scales + + frames, bsz, codebooks, seq_len = audio_codes_left.shape + # copy alternating left/right channel codes into stereo codebook + audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len)) + + audio_codes[:, :, ::2, :] = audio_codes_left + audio_codes[:, :, 1::2, :] = audio_codes_right + + if audio_scales_left != [None] or audio_scales_right != [None]: + audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1) + else: + audio_scales = [None] * bsz + + if frames != 1: + raise ValueError( + f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " + "disabled by setting `chunk_length=None` in the audio encoder." + ) + + decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) + + model_kwargs["decoder_input_ids"] = decoder_input_ids + model_kwargs["audio_scales"] = audio_scales + return model_kwargs + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the" + " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" + " model.decoder.resize_token_embeddings(...))" + ) + + def freeze_audio_encoder(self): + """ + Freeze the audio encoder weights. + """ + for param in self.audio_encoder.parameters(): + param.requires_grad = False + self.audio_encoder._requires_grad = False + + def freeze_text_encoder(self): + """ + Freeze the text encoder weights. + """ + for param in self.text_encoder.parameters(): + param.requires_grad = False + self.text_encoder._requires_grad = False + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs[0].size()[:-1] + return torch.ones(shape, dtype=torch.long, device=self.device) * -100 + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _get_decoder_start_token_id( + self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None + ) -> int: + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + synced_gpus: Optional[bool] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, + ): + """ + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects + if generation_config is None: + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() + self._validate_model_kwargs(model_kwargs.copy()) + + if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple: + # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate + model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0]) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + + # 3. Define model inputs + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) + + # 4. Define other model kwargs + model_kwargs["use_cache"] = generation_config.use_cache + model_kwargs["guidance_scale"] = generation_config.guidance_scale + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor + ) + + if "encoder_outputs" not in model_kwargs: + # encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_text_encoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name, generation_config + ) + + if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs: + model_kwargs = self._prepare_audio_encoder_kwargs_for_generation( + model_kwargs["input_values"], + model_kwargs, + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + bos_token_id=generation_config._bos_token_tensor, + device=inputs_tensor.device, + ) + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Moshi) + input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( + input_ids, + pad_token_id=generation_config._decoder_start_token_tensor, + max_length=generation_config.max_length, + ) + # stash the delay mask so that we don't have to recompute in each forward pass + model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask + + # input_ids are ready to be placed on the streamer (if used) + if streamer is not None: + streamer.put(input_ids.cpu()) + + # 7. determine generation mode + generation_mode = generation_config.get_generation_mode() + + # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) + if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: + logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) + generation_config.guidance_scale = None + + # 9. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=None, + logits_processor=logits_processor, + device=input_ids.device, + ) + + # 10. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + + if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 11. run sample + outputs = self._sample( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + else: + raise ValueError( + "Got incompatible mode for generation, should be one of greedy or sampling. " + "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." + ) + + if generation_config.return_dict_in_generate: + output_ids = outputs.sequences + else: + output_ids = outputs + + # apply the pattern mask to the final ids + output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) + + # revert the pattern delay mask by filtering the pad token id + output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( + batch_size, self.decoder.num_codebooks, -1 + ) + + # append the frame dimension back to the audio codes + output_ids = output_ids[None, ...] + + audio_scales = model_kwargs.get("audio_scales") + if audio_scales is None: + audio_scales = [None] * batch_size + + if self.decoder.config.audio_channels == 1: + output_values = self.audio_encoder.decode( + output_ids, + audio_scales=audio_scales, + ).audio_values + else: + codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales) + output_values_left = codec_outputs_left.audio_values + + codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales) + output_values_right = codec_outputs_right.audio_values + + output_values = torch.cat([output_values_left, output_values_right], dim=1) + + if generation_config.return_dict_in_generate: + outputs.sequences = output_values + return outputs + else: + return output_values \ No newline at end of file diff --git a/tests/models/moshi/__init__.py b/tests/models/moshi/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py new file mode 100644 index 00000000000000..218385478b999c --- /dev/null +++ b/tests/models/moshi/test_modeling_moshi.py @@ -0,0 +1,2587 @@ +# coding=utf-8 +# Copyright 2021, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Moshi model.""" + +import copy +import inspect +import math +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized +from pytest import mark + +from transformers import ( + EncodecConfig, + MoshiConfig, + MoshiDecoderConfig, + MoshiProcessor, + PretrainedConfig, + T5Config, +) +from transformers.testing_utils import ( + is_torch_available, + require_flash_attn, + require_torch, + require_torch_accelerator, + require_torch_fp16, + require_torch_gpu, + require_torch_sdpa, + slow, + torch_device, +) +from transformers.utils import cached_property, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + MoshiForCausalLM, + MoshiForConditionalGeneration, + MoshiModel, + set_seed, + ) + from transformers.generation import ( + GenerateDecoderOnlyOutput, + GenerateEncoderDecoderOutput, + ) + + +def _config_zero_init(config): + configs_no_init = copy.deepcopy(config) + for key in configs_no_init.__dict__.keys(): + if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: + setattr(configs_no_init, key, 1e-10) + if isinstance(getattr(configs_no_init, key, None), PretrainedConfig): + no_init_subconfig = _config_zero_init(getattr(configs_no_init, key)) + setattr(configs_no_init, key, no_init_subconfig) + return configs_no_init + + +def prepare_moshi_decoder_inputs_dict( + config, + input_ids, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + cross_attn_head_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :] + attention_mask = attention_mask.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device) + if encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device) + if cross_attn_head_mask is None: + cross_attn_head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + } + + +class MoshiDecoderTester: + def __init__( + self, + parent, + batch_size=4, # need batch_size != num_hidden_layers + seq_length=7, + is_training=True, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=100, + pad_token_id=99, + bos_token_id=99, + num_codebooks=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.num_codebooks = num_codebooks + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size) + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + + config = self.get_config() + inputs_dict = prepare_moshi_decoder_inputs_dict( + config, + input_ids, + encoder_hidden_states=encoder_hidden_states, + ) + return config, inputs_dict + + def get_config(self): + config = MoshiDecoderConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + d_ff=self.intermediate_size, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.bos_token_id, + bos_token_id=self.bos_token_id, + num_codebooks=self.num_codebooks, + tie_word_embeddings=False, + ) + return config + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + +@require_torch +class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (MoshiModel, MoshiForCausalLM) if is_torch_available() else () + greedy_sample_model_classes = ( + (MoshiForCausalLM,) if is_torch_available() else () + ) # we don't want to run all the generation tests, only a specific subset + test_pruning = False + test_resize_embeddings = False + + def setUp(self): + self.model_tester = MoshiDecoderTester(self) + self.config_tester = ConfigTester(self, config_class=MoshiDecoderConfig, hidden_size=16) + + def test_config(self): + self.config_tester.run_common_tests() + + # special case for labels + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + if not self.model_tester.is_training: + self.skipTest(reason="model_tester.is_training is set to False") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + model = MoshiForCausalLM(config) + + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + model.train() + + # Contrarily to the initial method, we don't unfreeze freezed parameters. + # Indeed, sinusoidal position embeddings have frozen weights that should stay frozen. + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + inputs = self._prepare_for_class(inputs_dict, MoshiForCausalLM, return_labels=True) + loss = model(**inputs).loss + loss.backward() + optimizer.step() + + for k, v in model.named_parameters(): + if v.requires_grad: + self.assertTrue(v.grad is not None, f"{k} in {MoshiForCausalLM.__name__} has no gradient!") + + # override since we have to compute the input embeddings over codebooks + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + + embed_tokens = model.get_input_embeddings() + + input_ids = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1]) + + inputs["inputs_embeds"] = sum( + [embed_tokens[codebook](input_ids[:, codebook]) for codebook in range(config.num_codebooks)] + ) + + with torch.no_grad(): + model(**inputs)[0] + + # override since we have embeddings / LM heads over multiple codebooks + def test_model_get_set_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + first_embed = model.get_input_embeddings()[0] + self.assertIsInstance(first_embed, torch.nn.Embedding) + lm_heads = model.get_output_embeddings() + self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear)) + + @unittest.skip(reason="Moshi does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Moshi does not support all arguments tested") + def test_model_outputs_equivalence(self): + pass + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied") + def test_tie_model_weights(self): + pass + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied") + def test_tied_weights_keys(self): + pass + + def _get_input_ids_and_config(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict["input_ids"] + + # take max batch_size + sequence_length = input_ids.shape[-1] + input_ids = input_ids[: batch_size * config.num_codebooks, :] + + attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) + return config, input_ids, attention_mask + + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = {} + return logits_processor_kwargs + + def test_greedy_generate_stereo_outputs(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.audio_channels = 2 + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + # Ignore copy + outputs = model(dummy_input, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # Ignore copy + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + # Ignore copy + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_left_padding(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_use_cache(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. + if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 + + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for batch_size in [1, 5]: + # Ignore copy + batch_size_input_ids = self.model_tester.num_codebooks * batch_size + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + # Ignore copy + dummy_input = dummy_input[:batch_size_input_ids] + # Ignore copy + if dummy_input.shape[0] != batch_size_input_ids: + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + # Ignore copy + extension = torch.rand( + batch_size_input_ids - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + else: + # Ignore copy + extension = torch.randint( + high=5, + size=(batch_size_input_ids - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + + if not use_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + if is_encoder_decoder: + seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + else: + seqlen = dummy_input.shape[-1] + dummy_attention_mask = ( + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + ) + + dummy_attention_mask = dummy_attention_mask[:batch_size] + if dummy_attention_mask.shape[0] != batch_size: + extension = torch.ones( + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, + device=torch_device, + ) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + dummy_attention_mask = dummy_attention_mask.to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :-1] = 1 + dummy_attention_mask[-1, -4:] = 0 + elif padding_side == "right": + dummy_attention_mask[-1, 1:] = 1 + dummy_attention_mask[-1, :3] = 0 + + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" + + other_inputs = { + "output_hidden_states": True, + } + + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + other_inputs["attention_mask"] = dummy_attention_mask + + # TODO: test gradients as well (& for FA2 as well!) + with torch.no_grad(): + with torch.backends.cuda.sdp_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + outputs_eager = model_eager(dummy_input, **other_inputs) + outputs_sdpa = model_sdpa(dummy_input, **other_inputs) + + logits_eager = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_sdpa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) + + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. + if use_mask: + if padding_side == "left": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, :-4] + sub_eager = logits_eager[-1, :-4] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, -4:] + # sub_eager = logits_eager[-1, -4:] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + elif padding_side == "right": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, 3:] + sub_eager = logits_eager[-1, 3:] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, :3] + # sub_eager = logits_eager[-1, :3] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + + else: + if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + + +def prepare_moshi_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + labels=None, +): + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.reshape( + -1, config.decoder.num_codebooks, decoder_input_ids.shape[-1] + )[:, 0, :] + decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id) + if head_mask is None: + head_mask = torch.ones( + config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device + ) + if decoder_head_mask is None: + decoder_head_mask = torch.ones( + config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device + ) + if cross_attn_head_mask is None: + cross_attn_head_mask = torch.ones( + config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device + ) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "labels": labels, + } + + +class MoshiTester: + def __init__( + self, + parent, + batch_size=4, # need batch_size != num_hidden_layers + seq_length=7, + is_training=True, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=100, + pad_token_id=99, + bos_token_id=99, + num_codebooks=4, + num_filters=4, + codebook_size=128, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.num_codebooks = num_codebooks + self.num_filters = num_filters + self.codebook_size = codebook_size + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + decoder_input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size) + + config = self.get_config() + inputs_dict = prepare_moshi_inputs_dict(config, input_ids, decoder_input_ids=decoder_input_ids) + return config, inputs_dict + + def get_config(self): + text_encoder_config = T5Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.intermediate_size, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + ) + audio_encoder_config = EncodecConfig( + hidden_size=self.vocab_size, + compress=1, + num_filters=self.num_filters, + codebook_size=self.codebook_size, + codebook_dim=self.vocab_size, + ) + decoder_config = MoshiDecoderConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + ffn_dim=self.intermediate_size, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.bos_token_id, + bos_token_id=self.bos_token_id, + num_codebooks=self.num_codebooks, + tie_word_embeddings=False, + ) + config = MoshiConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config) + return config + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + +@require_torch +class MoshiTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (MoshiForConditionalGeneration,) if is_torch_available() else () + greedy_sample_model_classes = (MoshiForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"text-to-audio": MoshiForConditionalGeneration} if is_torch_available() else {} + test_pruning = False # training is not supported yet for Moshi + test_headmasking = False + test_resize_embeddings = False + # not to test torchscript as the model tester doesn't prepare `input_values` and `padding_mask` + # (and `torchscript` hates `None` values). + test_torchscript = False + + def setUp(self): + self.model_tester = MoshiTester(self) + + # special case for labels + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + if not self.model_tester.is_training: + self.skipTest(reason="model_tester.is_training is set to False") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + model = model_class(config) + + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + model.train() + + # The audio encoder weights are not used during the forward pass (only during the generate pass) + # So we need to freeze it to be able to train. + model.freeze_audio_encoder() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + optimizer.step() + + for k, v in model.named_parameters(): + if v.requires_grad: + self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!") + + def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids): + text_encoder_config = config.text_encoder + decoder_config = config.decoder + + encoder_attentions = outputs["encoder_attentions"] + self.assertEqual(len(encoder_attentions), text_encoder_config.num_hidden_layers) + + self.assertEqual( + encoder_attentions[0].shape[-3:], + (text_encoder_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]), + ) + + decoder_attentions = outputs["decoder_attentions"] + num_decoder_layers = decoder_config.num_hidden_layers + self.assertEqual(len(decoder_attentions), num_decoder_layers) + + self.assertEqual( + decoder_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]), + ) + + cross_attentions = outputs["cross_attentions"] + self.assertEqual(len(cross_attentions), num_decoder_layers) + + cross_attention_input_seq_len = decoder_input_ids.shape[-1] + self.assertEqual( + cross_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]), + ) + + def check_moshi_model_output_attentions( + self, + model_class, + config, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + **kwargs, + ) + self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids) + + def check_moshi_model_output_attentions_from_config( + self, + model_class, + config, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ): + # Similar to `check_moshi_model_output_attentions`, but with `output_attentions` triggered from the + # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded + # from the inner models' configurations. + config.output_attentions = True # model config -> won't work + + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + **kwargs, + ) + self.assertTrue( + all(key not in outputs for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]) + ) + config.text_encoder.output_attentions = True # inner model config -> will work + config.audio_encoder.output_attentions = True + config.decoder.output_attentions = True + + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + **kwargs, + ) + self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids) + + # override since changing `output_attentions` from the top-level model config won't work + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + self.check_moshi_model_output_attentions(model_class, config, **inputs_dict) + self.check_moshi_model_output_attentions_from_config(model_class, config, **inputs_dict) + + # override since we have a specific forward signature for moshi + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = [ + "input_ids", + "attention_mask", + "input_values", + "padding_mask", + "decoder_input_ids", + "decoder_attention_mask", + ] + expected_arg_names.extend( + ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] + if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names + else ["encoder_outputs"] + ) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + # override since changing `gradient_checkpointing` from the top-level model config won't work + def test_gradient_checkpointing_backward_compatibility(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + config.text_encoder.gradient_checkpointing = True + config.audio_encoder.gradient_checkpointing = True + config.decoder.gradient_checkpointing = True + model = model_class(config) + self.assertTrue(model.is_gradient_checkpointing) + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied.") + def test_tie_model_weights(self): + pass + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied.") + def test_tied_model_weights_key_ignore(self): + pass + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied.") + def test_tied_weights_keys(self): + pass + + @unittest.skip(reason="No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip(reason="No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip(reason="No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + + # override since changing `output_hidden_states` / `output_attentions` from the top-level model config won't work + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.text_encoder.output_hidden_states = True + config.audio_encoder.output_hidden_states = True + config.decoder.output_hidden_states = True + + config.text_encoder.output_attentions = True + config.decoder.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_hidden_states.retain_grad() + + decoder_hidden_states = outputs.decoder_hidden_states[0] + decoder_hidden_states.retain_grad() + + if self.has_attentions: + encoder_attentions = outputs.encoder_attentions[0] + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(decoder_hidden_states.grad) + + if self.has_attentions: + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + # override since changing `output_hidden_states` from the top-level model config won't work + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states + + expected_num_layers = self.model_tester.num_hidden_layers + 1 + self.assertEqual(len(hidden_states), expected_num_layers) + + seq_length = self.model_tester.seq_length + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + hidden_states = outputs.decoder_hidden_states + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.text_encoder.output_hidden_states = True + config.audio_encoder.output_hidden_states = True + config.decoder.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # override since the conv layers and lstm's in encodec are exceptions + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = ["conv"] + ignore_init = ["lstm"] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + elif not any(x in name for x in ignore_init): + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # override since we have embeddings / LM heads over multiple codebooks + def test_model_get_set_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), torch.nn.Embedding) + lm_heads = model.get_output_embeddings() + self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear)) + + def _get_input_ids_and_config(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict["input_ids"] + + # take max batch_size + sequence_length = input_ids.shape[-1] + input_ids = input_ids[:batch_size, :] + attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) + + return config, input_ids, attention_mask + + # override since the `input_ids` cannot be used as the `decoder_input_ids` for moshi (input / outputs are + # different modalities -> different shapes) + def _greedy_generate( + self, + model, + input_ids, + attention_mask, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_generate = model.generate( + input_ids, + do_sample=False, + num_beams=1, + max_new_tokens=self.max_new_tokens, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, + **model_kwargs, + ) + + return output_generate + + # override since the `input_ids` cannot be used as the `decoder_input_ids` for moshi (input / outputs are + # different modalities -> different shapes) + def _sample_generate( + self, + model, + input_ids, + attention_mask, + num_return_sequences, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + torch.manual_seed(0) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_generate = model.generate( + input_ids, + do_sample=True, + num_beams=1, + max_new_tokens=self.max_new_tokens, + num_return_sequences=num_return_sequences, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, + **model_kwargs, + ) + + return output_generate + + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = {} + return logits_processor_kwargs + + def test_greedy_generate_dict_outputs(self): + for model_class in self.greedy_sample_model_classes: + # disable cache + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + def test_greedy_generate_dict_outputs_use_cache(self): + for model_class in self.greedy_sample_model_classes: + # enable cache + config, input_ids, attention_mask = self._get_input_ids_and_config() + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + def test_sample_generate(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + + # check `generate()` and `sample()` are equal + output_generate = self._sample_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + num_return_sequences=1, + ) + self.assertIsInstance(output_generate, torch.Tensor) + + def test_sample_generate_dict_output(self): + for model_class in self.greedy_sample_model_classes: + # disable cache + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + + output_generate = self._sample_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + num_return_sequences=3, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + def test_generate_without_input_ids(self): + config, _, _ = self._get_input_ids_and_config() + + # if no bos token id => cannot generate from None + if config.bos_token_id is None: + self.skipTest(reason="bos_token_id is None") + + for model_class in self.greedy_sample_model_classes: + model = model_class(config).to(torch_device) + model.eval() + + output_ids_generate = model.generate( + do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True + ) + self.assertIsNotNone(output_ids_generate) + + @require_torch_fp16 + @require_torch_accelerator # not all operations are supported in fp16 on CPU + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + + for model_class in self.greedy_sample_model_classes: + model = model_class(config).eval().to(torch_device) + model.half() + # greedy + model.generate(input_dict["input_ids"], attention_mask=input_dict["attention_mask"], max_new_tokens=10) + # sampling + model.generate( + input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10 + ) + + def test_greedy_generate_stereo_outputs(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.audio_channels = 2 + + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + @unittest.skip( + reason="MoshiModel is actually not the base of MoshiForCausalLM as the latter is a composit model" + ) + def test_save_load_fast_init_from_base(self): + pass + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + # Ignore copy + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + # Ignore copy + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + # Ignore copy + outputs = model(dummy_input, **other_inputs) + # Ignore copy + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + # Ignore copy + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + # Ignore copy + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + # Ignore copy + outputs = model(dummy_input, **other_inputs) + # Ignore copy + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_left_padding(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask") + if dummy_attention_mask is None: + dummy_attention_mask = torch.ones_like(dummy_input) + + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask") + if dummy_attention_mask is None: + dummy_attention_mask = torch.ones_like(dummy_input) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_use_cache(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. + if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 + + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for batch_size in [1, 5]: + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + dummy_input = dummy_input[:batch_size] + if dummy_input.shape[0] != batch_size: + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + extension = torch.rand( + batch_size - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + else: + extension = torch.randint( + high=5, + size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + + if not use_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + # Ignore copy + seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + # Ignore copy + dummy_attention_mask = ( + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + ) + + dummy_attention_mask = dummy_attention_mask[:batch_size] + if dummy_attention_mask.shape[0] != batch_size: + extension = torch.ones( + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, + device=torch_device, + ) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + dummy_attention_mask = dummy_attention_mask.to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :-1] = 1 + dummy_attention_mask[-1, -4:] = 0 + elif padding_side == "right": + dummy_attention_mask[-1, 1:] = 1 + dummy_attention_mask[-1, :3] = 0 + + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" + # Ignore copy + batch_size_input_ids = self.model_tester.num_codebooks * batch_size + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[ + :batch_size_input_ids + ] + # Ignore copy + if decoder_input_ids.shape[0] != batch_size_input_ids: + # Ignore copy + extension = torch.ones( + batch_size_input_ids - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) + + # TODO: never an `attention_mask` arg here? + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + + # TODO: test gradients as well (& for FA2 as well!) + # Ignore copy + with torch.no_grad(): + with torch.backends.cuda.sdp_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + outputs_eager = model_eager(dummy_input, **other_inputs) + outputs_sdpa = model_sdpa(dummy_input, **other_inputs) + + logits_eager = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_sdpa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) + + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. + if use_mask: + if padding_side == "left": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, :-4] + sub_eager = logits_eager[-1, :-4] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, -4:] + # sub_eager = logits_eager[-1, -4:] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + elif padding_side == "right": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, 3:] + sub_eager = logits_eager[-1, 3:] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, :3] + # sub_eager = logits_eager[-1, :3] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + + else: + if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + + def test_requires_grad_with_frozen_encoders(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + model.freeze_audio_encoder() + + audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()] + text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()] + + self.assertFalse(all(audio_encoder_grads)) + self.assertTrue(all(text_encoder_grads)) + + model = model_class(config) + model.freeze_text_encoder() + + audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()] + text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()] + + self.assertTrue(all(audio_encoder_grads)) + self.assertFalse(all(text_encoder_grads)) + + +def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): + """Produces a series of 'bip bip' sounds at a given frequency.""" + timesteps = np.arange(int(duration * sample_rate)) / sample_rate + wav = np.cos(2 * math.pi * 440 * timesteps) + time_period = (timesteps % (2 * bip_duration)) / (2 * bip_duration) + envelope = time_period >= 0.5 + return wav * envelope + + +def place_dict_on_device(dict_to_place, device): + for key in dict_to_place: + if dict_to_place[key] is not None and isinstance(dict_to_place[key], torch.Tensor): + dict_to_place[key] = dict_to_place[key].to(device) + return dict_to_place + + +@require_torch +class MoshiIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko").to(torch_device) + + @cached_property + def processor(self): + return MoshiProcessor.from_pretrained("kyutai/moshiko") + + @slow + def test_logits_text_prompt(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + # prepare the decoder inputs + pad_token_id = model.generation_config.pad_token_id + decoder_input_ids = ( + torch.ones((input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long).to(torch_device) + * pad_token_id + ) + + with torch.no_grad(): + logits = model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + ).logits + + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + -0.9708, -3.0149, -4.6415, -1.4754, -0.2786, -2.3523, -2.6049, -6.7467, + -1.0206, -3.2984, -3.3968, -1.5108, -1.5786, -3.1493, -1.1503, -0.0545, + ] + ) + # fmt: on + + self.assertTrue(logits.shape == (*decoder_input_ids.shape, model.decoder.config.vocab_size)) + self.assertTrue(torch.allclose(logits[0, 0, :16].cpu(), EXPECTED_LOGITS, atol=1e-4)) + + @slow + def test_logits_text_audio_prompt(self): + model = self.model + processor = self.processor + + audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt") + + # prepare the text encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + # prepare the audio encoder inputs + input_values = inputs.input_values.to(torch_device) + padding_mask = inputs.padding_mask.to(torch_device) + + with torch.no_grad(): + logits = model( + input_ids, + attention_mask=attention_mask, + input_values=input_values, + padding_mask=padding_mask, + ).logits + + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + 0.1841, -2.9324, -0.7898, 0.1857, 0.4971, -2.8685, -1.6525, -1.6541, + 2.7757, -2.5942, -3.0959, -1.0120, -1.0147, -0.4605, -0.8885, 0.6820, + ] + ) + # fmt: on + + self.assertTrue(logits.shape == (8, 50, 2048)) + self.assertTrue(torch.allclose(logits[0, -1, :16].cpu(), EXPECTED_LOGITS, atol=1e-4)) + + @slow + def test_generate_unconditional_greedy(self): + model = self.model + + # only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same + unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=5) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + 0.0056, 0.0064, 0.0063, 0.0054, 0.0042, 0.0033, 0.0024, 0.0015, + 0.0015, 0.0010, 0.0004, -0.0012, -0.0036, -0.0055, -0.0067, -0.0071, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (1, 1, 3200)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_unconditional_sampling(self): + model = self.model + + # for stochastic sampling we can generate multiple outputs + unconditional_inputs = model.get_unconditional_inputs(num_samples=2) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + set_seed(0) + output_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=10) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185, + 0.0105, 0.0059, 0.0329, 0.0249, -0.0204, -0.0341, -0.0465, 0.0053, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_greedy(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=None, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -1.1998e-04, -2.2302e-04, 4.6296e-04, 1.0524e-03, 2.4827e-04, + -4.0288e-05, -1.2468e-04, 4.9846e-05, 7.1485e-04, 4.4197e-04, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :10].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_greedy_with_classifier_free_guidance(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=3, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + 0.0283, 0.0246, 0.0650, 0.0640, 0.0599, 0.0711, 0.0420, 0.0112, + 0.0511, 0.0746, 0.1363, 0.1213, 0.0185, -0.0578, -0.0908, 0.0443, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_sampling(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + set_seed(0) + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=True, guidance_scale=None, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -0.0111, -0.0154, 0.0047, 0.0058, -0.0068, 0.0012, -0.0109, -0.0229, + 0.0010, -0.0038, 0.0167, 0.0042, -0.0421, -0.0610, -0.0764, -0.0326, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_audio_prompt(self): + model = self.model + processor = self.processor + + audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt") + inputs = place_dict_on_device(inputs, device=torch_device) + + output_values = model.generate(**inputs, do_sample=False, guidance_scale=None, max_new_tokens=10) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -0.0036, -0.0130, -0.0261, -0.0384, -0.0557, -0.0718, -0.0680, -0.0632, + -0.0529, -0.0403, -0.0289, -0.0198, -0.0136, -0.0101, -0.0095, -0.0040, + ] + ) + # fmt: on + + self.assertTrue( + output_values.shape == (2, 1, 36480) + ) # input values take shape 32000 and we generate from there + self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4)) + + +@require_torch +class MoshiStereoIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return MoshiForConditionalGeneration.from_pretrained("facebook/moshi-stereo-small").to(torch_device) + + @cached_property + def processor(self): + return MoshiProcessor.from_pretrained("facebook/moshi-stereo-small") + + @slow + def test_generate_unconditional_greedy(self): + model = self.model + + # only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same + unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=12) + + # fmt: off + EXPECTED_VALUES_LEFT = torch.tensor( + [ + 0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013, + -0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099, + ] + ) + EXPECTED_VALUES_RIGHT = torch.tensor( + [ + 0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019, + 0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096, + ] + ) + # fmt: on + + # (bsz, channels, seq_len) + self.assertTrue(output_values.shape == (1, 2, 5760)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4)) + self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4)) + + @slow + def test_generate_text_audio_prompt(self): + model = self.model + processor = self.processor + + # create stereo inputs + audio = [get_bip_bip(duration=0.5)[None, :].repeat(2, 0), get_bip_bip(duration=1.0)[None, :].repeat(2, 0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt") + inputs = place_dict_on_device(inputs, device=torch_device) + + output_values = model.generate(**inputs, do_sample=False, guidance_scale=3.0, max_new_tokens=12) + + # fmt: off + EXPECTED_VALUES_LEFT = torch.tensor( + [ + 0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728, + -0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430, + ] + ) + EXPECTED_VALUES_RIGHT = torch.tensor( + [ + 0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103, + -0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614, + ] + ) + # fmt: on + + # (bsz, channels, seq_len) + self.assertTrue(output_values.shape == (2, 2, 37760)) + # input values take shape 32000 and we generate from there - we check the last (generated) values + self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4)) + self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4)) From 2eaadcaf5f40082c014d4b9ca1747ae65f5c572d Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Mon, 16 Sep 2024 16:50:28 +0200 Subject: [PATCH 5/8] converting weights working + configuration + generation configuration --- src/transformers/__init__.py | 3 - src/transformers/models/moshi/__init__.py | 2 - .../models/moshi/configuration_moshi.py | 275 +- .../moshi/convert_moshi_transformers.py | 377 +-- .../moshi/generation_configuration_moshi.py | 87 + .../models/moshi/modeling_moshi.py | 2511 ++++++++--------- tests/models/moshi/test_modeling_moshi.py | 7 +- 7 files changed, 1515 insertions(+), 1747 deletions(-) create mode 100644 src/transformers/models/moshi/generation_configuration_moshi.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 049cd65b535e87..3f31ee7e4827b6 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -593,11 +593,9 @@ "models.mt5": ["MT5Config"], "models.musicgen": [ "MusicgenConfig", - "MusicgenDecoderConfig", ], "models.moshi": [ "MoshiConfig", - "MoshiDecoderConfig", ], "models.musicgen_melody": [ "MusicgenMelodyConfig", @@ -5395,7 +5393,6 @@ ) from .models.moshi import ( MoshiConfig, - MoshiDecoderConfig, ) from .models.musicgen_melody import ( MusicgenMelodyConfig, diff --git a/src/transformers/models/moshi/__init__.py b/src/transformers/models/moshi/__init__.py index 6d5fe355351744..c1c617a3e816d8 100644 --- a/src/transformers/models/moshi/__init__.py +++ b/src/transformers/models/moshi/__init__.py @@ -19,7 +19,6 @@ _import_structure = { "configuration_moshi": [ "MoshiConfig", - "MoshiDecoderConfig", ], } @@ -39,7 +38,6 @@ if TYPE_CHECKING: from .configuration_moshi import ( MoshiConfig, - MoshiDecoderConfig, ) try: diff --git a/src/transformers/models/moshi/configuration_moshi.py b/src/transformers/models/moshi/configuration_moshi.py index f6ea2139032825..9fff7c5c7124c8 100644 --- a/src/transformers/models/moshi/configuration_moshi.py +++ b/src/transformers/models/moshi/configuration_moshi.py @@ -22,153 +22,91 @@ logger = logging.get_logger(__name__) -class MoshiDecoderConfig(PretrainedConfig): + +class MoshiConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of an [`MoshiDecoder`]. It is used to instantiate a - Moshi decoder according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the Moshi - [kyutai/moshiko](https://huggingface.co/kyutai/moshiko) architecture. + This is the configuration class to store the configuration of a [`MoshiModel`]. It is used to instantiate a + Moshi model according to the specified arguments, defining the audio encoder, Moshi depth decoder and Moshi decoder + configs. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 2048): + vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the MoshiDecoder model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`MoshiDecoder`]. - hidden_size (`int`, *optional*, defaults to 1024): - Dimensionality of the layers and the pooler layer. - num_hidden_layers (`int`, *optional*, defaults to 24): + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the layers and the pooler layer of the main decoder. + num_hidden_layers (`int`, *optional*, defaults to 32): Number of decoder layers. - num_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer block. - ffn_dim (`int`, *optional*, defaults to 4096): - Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block. - activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`, - `"relu"`, `"silu"` and `"gelu_new"` are supported. - dropout (`float`, *optional*, defaults to 0.1): - The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - activation_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for activations inside the fully connected layer. - max_position_embeddings (`int`, *optional*, defaults to 2048): + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the main decoder block. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. + max_position_embeddings (`int`, *optional*, defaults to 3750): The maximum sequence length that this model might ever be used with. Typically, set this to something large just in case (e.g., 512 or 1024 or 2048). - initializer_factor (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - layerdrop (`float`, *optional*, defaults to 0.0): - The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) - for more details. - scale_embedding (`bool`, *optional*, defaults to `False`): - Scale embeddings by diving by sqrt(hidden_size). + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. use_cache (`bool`, *optional*, defaults to `True`): - Whether the model should return the last key/values attentions (not used by all models) - num_codebooks (`int`, *optional*, defaults to 4): - The number of parallel codebooks forwarded to the model. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether input and output word embeddings should be tied. - audio_channels (`int`, *optional*, defaults to 1 - Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate - audio stream for the left/right output channels. Mono models generate a single audio stream output. - """ - - model_type = "moshi_decoder" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=2048, - max_position_embeddings=2048, - num_hidden_layers=24, - ffn_dim=4096, - num_attention_heads=16, - layerdrop=0.0, - use_cache=True, - activation_function="gelu", - hidden_size=1024, - dropout=0.1, - attention_dropout=0.0, - activation_dropout=0.0, - initializer_factor=0.02, - scale_embedding=False, - num_codebooks=4, - audio_channels=1, - pad_token_id=2048, - bos_token_id=2048, - eos_token_id=None, - tie_word_embeddings=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.ffn_dim = ffn_dim - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.dropout = dropout - self.attention_dropout = attention_dropout - self.activation_dropout = activation_dropout - self.activation_function = activation_function - self.initializer_factor = initializer_factor - self.layerdrop = layerdrop - self.use_cache = use_cache - self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True - self.num_codebooks = num_codebooks - - if audio_channels not in [1, 2]: - raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.") - self.audio_channels = audio_channels - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -class MoshiConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`MoshiModel`]. It is used to instantiate a - Moshi model according to the specified arguments, defining the text encoder, audio encoder and Moshi decoder - configs. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + sliding_window (`int`, *optional*, defaults to 250): + Sliding window attention window size. If not specified, will default to `250`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ffn_dim (`int`, *optional*, defaults to 22528): + Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even. + num_codebooks (`int`, *optional*, defaults to 8): + The number of audio codebooks for each audio channels. + rms_norm_eps (`float`, *optional*, defaults to 1e-8): + The epsilon used by the rms normalization layers. + depth_hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer of the depth decoder. + depth_num_hidden_layers (`int`, *optional*, defaults to 6): + Number of depth decoder layers. + depth_num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the depth decoder block. + depth_max_position_embeddings (`int`, *optional*, defaults to 8): + The maximum sequence length that the depth decoder model might ever be used with. Typically, set this to the + number of codebooks. + depth_ffn_dim (`int`, *optional*, defaults to 5632): + Dimensionality of the "intermediate" (often named feed-forward) layer in the depth decoder block. Must be even. + depth_head_dim (`int`, *optional*, defaults to `depth_hidden_size // depth_num_attention_heads`): + The attention head dimension of the depth encoder layers. + depth_num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention in the depth decoder. + If it is not specified, will default to `depth_num_key_value_heads`. kwargs (*optional*): Dictionary of keyword arguments. Notably: - - - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that - defines the text encoder config. - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines the audio encoder config. - - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines - the decoder config. + Example: ```python >>> from transformers import ( ... MoshiConfig, - ... MoshiDecoderConfig, - ... T5Config, ... EncodecConfig, ... MoshiForConditionalGeneration, ... ) >>> # Initializing text encoder, audio encoder, and decoder model configurations - >>> text_encoder_config = T5Config() >>> audio_encoder_config = EncodecConfig() - >>> decoder_config = MoshiDecoderConfig() >>> configuration = MoshiConfig.from_sub_models_config( - ... text_encoder_config, audio_encoder_config, decoder_config + ... audio_encoder_config ... ) >>> # Initializing a MoshiForConditionalGeneration (with random weights) from the kyutai/moshiko style configuration @@ -190,45 +128,95 @@ class MoshiConfig(PretrainedConfig): model_type = "moshi" is_composition = True + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__(self, + vocab_size=32000, + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + audio_vocab_size=None, # TODO + max_position_embeddings=3750, + rope_theta=10000.0, + hidden_act="silu", + head_dim=None, + initializer_range=0.02, + use_cache=True, + sliding_window=3000, + attention_dropout=0.0, + ffn_dim=22528, + rms_norm_eps=1e-8, + num_codebooks=8, + depth_hidden_size=1024, + depth_num_hidden_layers=6, + depth_max_position_embeddings=8, + depth_num_attention_heads=16, + depth_ffn_dim=5632, + depth_head_dim=None, + depth_num_key_value_heads=None, + **kwargs): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.hidden_act = hidden_act + self.head_dim = head_dim or hidden_size // num_attention_heads + self.initializer_range = initializer_range + self.use_cache = use_cache + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + if ffn_dim % 2 == 1: + raise ValueError(f"`ffn_dim={ffn_dim}` must be even.") + self.ffn_dim = ffn_dim + self.rms_norm_eps = rms_norm_eps + self.num_codebooks = num_codebooks - def __init__(self, **kwargs): - super().__init__(**kwargs) - if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs: - raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config") - text_encoder_config = kwargs.pop("text_encoder") - text_encoder_model_type = text_encoder_config.pop("model_type") + self.depth_hidden_size = depth_hidden_size + self.depth_num_hidden_layers = depth_num_hidden_layers + self.depth_max_position_embeddings = depth_max_position_embeddings + self.depth_num_attention_heads = depth_num_attention_heads + if depth_ffn_dim % 2 == 1: + raise ValueError(f"`depth_ffn_dim={depth_ffn_dim}` must be even.") + self.depth_ffn_dim = depth_ffn_dim + self.depth_head_dim = depth_head_dim or depth_hidden_size // depth_num_attention_heads + self.depth_num_key_value_heads = depth_num_key_value_heads if depth_num_key_value_heads is not None else depth_num_attention_heads + + super().__init__(**kwargs) + + if "audio_encoder" not in kwargs: + raise ValueError("Config has to be initialized with audio_encoder config") audio_encoder_config = kwargs.pop("audio_encoder") audio_encoder_model_type = audio_encoder_config.pop("model_type") - decoder_config = kwargs.pop("decoder") - - self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config) self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) - self.decoder = MoshiDecoderConfig(**decoder_config) - self.is_encoder_decoder = True + if self.num_codebooks > self.audio_encoder.num_codebooks: + raise ValueError(f"`num_codebooks={num_codebooks}` is greater than the maximum number of codebooks that the audio encoder can deal with ({self.audio_encoder.num_codebooks}). Please lower it.") + + self.audio_vocab_size = self.audio_encoder.codebook_size if audio_vocab_size is None else audio_vocab_size + + @classmethod - def from_sub_models_config( + def from_audio_encoder_config( cls, - text_encoder_config: PretrainedConfig, audio_encoder_config: PretrainedConfig, - decoder_config: MoshiDecoderConfig, **kwargs, ): r""" - Instantiate a [`MoshiConfig`] (or a derived class) from text encoder, audio encoder and decoder - configurations. + Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration. Returns: [`MoshiConfig`]: An instance of a configuration object """ return cls( - text_encoder=text_encoder_config.to_dict(), audio_encoder=audio_encoder_config.to_dict(), - decoder=decoder_config.to_dict(), **kwargs, ) @@ -236,20 +224,3 @@ def from_sub_models_config( # This is a property because you might want to change the codec model on the fly def sampling_rate(self): return self.audio_encoder.sampling_rate - - @property - def _attn_implementation(self): - # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) - if hasattr(self, "_attn_implementation_internal"): - if self._attn_implementation_internal is None: - # `config.attn_implementation` should never be None, for backward compatibility. - return "eager" - else: - return self._attn_implementation_internal - else: - return "eager" - - @_attn_implementation.setter - def _attn_implementation(self, value): - self._attn_implementation_internal = value - self.decoder._attn_implementation = value diff --git a/src/transformers/models/moshi/convert_moshi_transformers.py b/src/transformers/models/moshi/convert_moshi_transformers.py index 7b146c44dfe647..cef03195914798 100644 --- a/src/transformers/models/moshi/convert_moshi_transformers.py +++ b/src/transformers/models/moshi/convert_moshi_transformers.py @@ -12,225 +12,232 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Convert Moshi checkpoints from the original repository.""" +"""Convert Moshi checkpoints.""" import argparse -from pathlib import Path -from typing import Dict, OrderedDict, Tuple +import safetensors import torch -from audiocraft.models import Moshi from transformers import ( - AutoFeatureExtractor, - AutoTokenizer, - EncodecModel, - MoshiDecoderConfig, + MoshiConfig, MoshiForConditionalGeneration, - MoshiProcessor, - T5EncoderModel, + MimiModel, # initial audio encoder + logging, ) -from transformers.models.moshi.modeling_moshi import MoshiForCausalLM -from transformers.utils import logging +# EncodecFeatureExtractor, #TODO(YL): add it here and as AutoFeatureExtractor logging.set_verbosity_info() -logger = logging.get_logger(__name__) - - -EXPECTED_MISSING_KEYS = ["model.decoder.embed_positions.weights"] - - -def rename_keys(name): - if "emb" in name: - name = name.replace("emb", "model.decoder.embed_tokens") - if "transformer" in name: - name = name.replace("transformer", "model.decoder") - if "cross_attention" in name: - name = name.replace("cross_attention", "encoder_attn") - if "linear1" in name: - name = name.replace("linear1", "fc1") - if "linear2" in name: - name = name.replace("linear2", "fc2") - if "norm1" in name: - name = name.replace("norm1", "self_attn_layer_norm") - if "norm_cross" in name: - name = name.replace("norm_cross", "encoder_attn_layer_norm") - if "norm2" in name: - name = name.replace("norm2", "final_layer_norm") - if "out_norm" in name: - name = name.replace("out_norm", "model.decoder.layer_norm") - if "linears" in name: - name = name.replace("linears", "lm_heads") - if "condition_provider.conditioners.description.output_proj" in name: - name = name.replace("condition_provider.conditioners.description.output_proj", "enc_to_dec_proj") - return name - - -def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, Dict]: - """Function that takes the fairseq Moshi state dict and renames it according to the HF - module names. It further partitions the state dict into the decoder (LM) state dict, and that for the - encoder-decoder projection.""" - keys = list(state_dict.keys()) - enc_dec_proj_state_dict = {} - for key in keys: - val = state_dict.pop(key) - key = rename_keys(key) - if "in_proj_weight" in key: - # split fused qkv proj - state_dict[key.replace("in_proj_weight", "q_proj.weight")] = val[:hidden_size, :] - state_dict[key.replace("in_proj_weight", "k_proj.weight")] = val[hidden_size : 2 * hidden_size, :] - state_dict[key.replace("in_proj_weight", "v_proj.weight")] = val[-hidden_size:, :] - elif "enc_to_dec_proj" in key: - enc_dec_proj_state_dict[key[len("enc_to_dec_proj.") :]] = val - else: - state_dict[key] = val - return state_dict, enc_dec_proj_state_dict - - -def decoder_config_from_checkpoint(checkpoint: str) -> MoshiDecoderConfig: - if checkpoint.endswith("small"): - # default config values - hidden_size = 1024 - num_hidden_layers = 24 - num_attention_heads = 16 - elif checkpoint.endswith("medium"): - hidden_size = 1536 - num_hidden_layers = 48 - num_attention_heads = 24 - elif checkpoint.endswith("large"): - hidden_size = 2048 - num_hidden_layers = 48 - num_attention_heads = 32 - else: - raise ValueError( - "Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, " - "`['facebook/moshi-stereo-small', 'facebook/moshi-stereo-medium', 'facebook/moshi-stereo-large']` " - f"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix, got {checkpoint}." - ) - - if "stereo" in checkpoint: - audio_channels = 2 - num_codebooks = 8 - else: - audio_channels = 1 - num_codebooks = 4 - - config = MoshiDecoderConfig( - hidden_size=hidden_size, - ffn_dim=hidden_size * 4, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - num_codebooks=num_codebooks, - audio_channels=audio_channels, - ) - return config - - -@torch.no_grad() -def convert_moshi_checkpoint( - checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu", safe_serialization=False -): - fairseq_model = Moshi.get_pretrained(checkpoint, device=device) - decoder_config = decoder_config_from_checkpoint(checkpoint) - - decoder_state_dict = fairseq_model.lm.state_dict() - decoder_state_dict, enc_dec_proj_state_dict = rename_state_dict( - decoder_state_dict, hidden_size=decoder_config.hidden_size - ) +logger = logging.get_logger("transformers.models.mimi") - text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-base") - audio_encoder = EncodecModel.from_pretrained("facebook/encodec_32khz") - decoder = MoshiForCausalLM(decoder_config).eval() - # load all decoder weights - expect that we'll be missing embeddings and enc-dec projection - missing_keys, unexpected_keys = decoder.load_state_dict(decoder_state_dict, strict=False) +def assert_param_count(model_1, model_2): + count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0]) + count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0]) + assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" - for key in missing_keys.copy(): - if key.startswith(("text_encoder", "audio_encoder")) or key in EXPECTED_MISSING_KEYS: - missing_keys.remove(key) - if len(missing_keys) > 0: - raise ValueError(f"Missing key(s) in state_dict: {missing_keys}") +def param_count(model): + return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0]) - if len(unexpected_keys) > 0: - raise ValueError(f"Unexpected key(s) in state_dict: {unexpected_keys}") - # init the composite model - model = MoshiForConditionalGeneration(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder) - - # load the pre-trained enc-dec projection (from the decoder state dict) - model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict) - - # check we can do a forward pass - input_ids = torch.arange(0, 2 * decoder_config.num_codebooks, dtype=torch.long).reshape(2, -1) - decoder_input_ids = input_ids.reshape(2 * decoder_config.num_codebooks, -1) +def _grab_best_device(use_gpu=True): + if torch.cuda.device_count() > 0 and use_gpu: + device = "cuda" + else: + device = "cpu" + return torch.device(device) + + +convert_list = [ + # GENERAL + ("out_norm", "decoder.model.norm"), + ("depformer_emb", "depth_decoder.emb"), + ("depformer_text_emb", "depth_decoder.text_emb"), + ("text_emb", "decoder.model.emb"), + + ("emb", "embed_tokens"), + ("text_linear", "decoder.lm_head"), + + ("depformer", "depth_decoder"), + ("transformer", "decoder.model"), + + # TRANSFORMERS PART + ("gating.linear_in", "mlp.fc1"), + ("gating.linear_out", "mlp.fc2"), + ("self_attn.out_proj", "self_attn.o_proj"), + ("norm1", "input_layernorm"), + ("norm2", "post_attention_layernorm"), + ("layer_scale_1", "self_attn_layer_scale"), + ("layer_scale_2", "mlp_layer_scale"), + ("alpha", "weight") +] + +def _preprocess_state_dict(state_dict, config): + # Moshi original weights are using a gating mechanism + + # pattern for depth transformer: + # stack(gating.{i}.linear_in)->mlp.fc1 + # stack(gating.{i}.linear_out)->mlp.fc2 + + for layer_idx in range(config.depth_num_hidden_layers): + linear_layers_in = [state_dict.pop(f"depformer.layers.{layer_idx}.gating.{i}.linear_in.weight") for i in range(config.num_codebooks)] + linear_layers_out = [state_dict.pop(f"depformer.layers.{layer_idx}.gating.{i}.linear_out.weight") for i in range(config.num_codebooks)] + + state_dict[f"depth_decoder.layers.{layer_idx}.mlp.fc1.weight"] = torch.stack(linear_layers_in) + state_dict[f"depth_decoder.layers.{layer_idx}.mlp.fc2.weight"] = torch.stack(linear_layers_out) + + input_projections = [] + lm_heads = [] + for codebook_idx in range(config.num_codebooks): + input_projections.append(state_dict.pop(f"depformer_in.{codebook_idx}.weight")) + lm_heads.append(state_dict.pop(f"linears.{codebook_idx}.weight")) + + state_dict["depth_decoder.input_projections.weight"] = torch.stack(input_projections, dim = 1) + state_dict["depth_decoder.lm_heads.weight"] = torch.stack(lm_heads, dim = 1) + + return state_dict + + +def _convert_model( + state_dict, + hf_model, + convert_list, + device, + config, + unwanted_prefix=None, +): + hidden_size = config.hidden_size + head_dim = config.head_dim + num_heads = int(config.hidden_size // config.head_dim) + num_key_value_heads = config.num_key_value_heads + key_value_head_dim = config.num_key_value_heads * head_dim + + + depth_hidden_size = config.depth_hidden_size + depth_head_dim = config.depth_head_dim + depth_num_heads = int(config.depth_hidden_size // config.depth_head_dim) + depth_num_key_value_heads = config.depth_num_key_value_heads + depth_key_value_head_dim = config.depth_num_key_value_heads * depth_head_dim + + state_dict = _preprocess_state_dict(state_dict, config) + + # permute for sliced rotary + def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + for k, v in list(state_dict.items()): + if "audio_encoder" not in k: + new_k = k if unwanted_prefix is None else k[len(unwanted_prefix) :] + for old_layer_name, new_layer_name in convert_list: + if old_layer_name in new_k: + new_k = new_k.replace(old_layer_name, new_layer_name) + + if "in_proj_weight" in new_k: + # split qkv into query key and value + mixed_qkv = state_dict.pop(k) + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + if "depth_decoder" in new_k: + state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = query_layer.view(config.num_codebooks, -1, query_layer.shape[-1]) + state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = key_layer.view(config.num_codebooks, -1, key_layer.shape[-1]) + state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer.view(config.num_codebooks, -1, value_layer.shape[-1]) + else: + state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = permute( + query_layer, num_heads, hidden_size, hidden_size) + state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = permute( + key_layer, num_key_value_heads, dim1=key_value_head_dim, dim2=hidden_size + ) + state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer + else: + state_dict[new_k] = state_dict.pop(k) + + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) + missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + hf_model.load_state_dict(state_dict, strict=True) + n_params = param_count(hf_model) + + logger.info(f"model loaded: {round(n_params/1e6,1)}M params") + + hf_model.eval() + hf_model.to(device) + del state_dict + + return hf_model - with torch.no_grad(): - logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits - if logits.shape != (2 * decoder_config.num_codebooks, 1, 2048): - raise ValueError("Incorrect shape for logits") +@torch.no_grad() +def convert_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + mimi_repo_id, + config_path=None, + repo_id=None, +): + """ + Copy/paste/tweak model's weights to transformers design. + """ + device = _grab_best_device() + + mimi_model = MimiModel.from_pretrained(mimi_repo_id) + + if config_path is not None: + config = MoshiConfig.from_pretrained(config_path) + else: + audio_encoder_config = mimi_model.config + config = MoshiConfig.from_audio_encoder_config(audio_encoder_config) - # now construct the processor - tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - feature_extractor = AutoFeatureExtractor.from_pretrained( - "facebook/encodec_32khz", padding_side="left", feature_size=decoder_config.audio_channels - ) + model = MoshiForConditionalGeneration(config) - processor = MoshiProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer) + # feature_extractor = EncodecFeatureExtractor( + # feature_size=config.audio_channels, + # sampling_rate=config.sampling_rate, + # ) + # feature_extractor.save_pretrained(pytorch_dump_folder_path) - # set the appropriate bos/pad token ids - model.generation_config.decoder_start_token_id = 2048 - model.generation_config.pad_token_id = 2048 + original_checkpoint = safetensors.torch.load_file(checkpoint_path) + if "best_state" in original_checkpoint: + # we might have a training state saved, in which case discard the yaml results and just retain the weights + original_checkpoint = original_checkpoint["best_state"] - # set other default generation config params - model.generation_config.max_length = int(30 * audio_encoder.config.frame_rate) - model.generation_config.do_sample = True - model.generation_config.guidance_scale = 3.0 + audio_checkpoint = mimi_model.state_dict() + original_checkpoint.update({f"audio_encoder.{key}": value for (key, value) in audio_checkpoint.items()}) + + model = _convert_model(original_checkpoint, model, convert_list, device, config) - if pytorch_dump_folder is not None: - Path(pytorch_dump_folder).mkdir(exist_ok=True) - logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}") - model.save_pretrained(pytorch_dump_folder, safe_serialization=safe_serialization) - processor.save_pretrained(pytorch_dump_folder) + model.save_pretrained(pytorch_dump_folder_path) if repo_id: - logger.info(f"Pushing model {checkpoint} to {repo_id}") - model.push_to_hub(repo_id, safe_serialization=safe_serialization) - processor.push_to_hub(repo_id) + print("Pushing to the hub...") + # feature_extractor.push_to_hub(repo_id) + model.push_to_hub(repo_id, private=True) if __name__ == "__main__": parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--checkpoint", - default="small", - type=str, - help="Checkpoint size of the Moshi model you'd like to convert. Can be one of: " - "`['small', 'medium', 'large']` for the mono checkpoints, " - "`['facebook/moshi-stereo-small', 'facebook/moshi-stereo-medium', 'facebook/moshi-stereo-large']` " - "for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix.", - ) + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") + parser.add_argument("--mimi_repo_id", required=True, default=None, type=str, help="Repository id to HF Mimi.") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") parser.add_argument( - "--pytorch_dump_folder", - required=True, - default=None, - type=str, - help="Path to the output PyTorch model directory.", + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." ) parser.add_argument( "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." ) - parser.add_argument( - "--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda." - ) - parser.add_argument( - "--safe_serialization", - action="store_true", - help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).", - ) args = parser.parse_args() - convert_moshi_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub) + convert_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.mimi_repo_id, + args.config_path, + args.push_to_hub, + ) diff --git a/src/transformers/models/moshi/generation_configuration_moshi.py b/src/transformers/models/moshi/generation_configuration_moshi.py new file mode 100644 index 00000000000000..954bab35e45105 --- /dev/null +++ b/src/transformers/models/moshi/generation_configuration_moshi.py @@ -0,0 +1,87 @@ +# coding=utf-8 +# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Moshigeneration configuration""" + +import copy +from typing import Dict + +from ...generation.configuration_utils import GenerationConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MoshiGenerationConfig(GenerationConfig): + model_type = "moshi" + is_composition = True + + # TODO (joao): nested from_dict + + def __init__( + self, + depth_decoder_config: Dict = None, + **kwargs, + ): + """Class that holds a generation configuration for [`MoshiForConditionalGeneration`]. + + The [`MoshiForConditionalGeneration`] model needs to encapsulates two generation config, the main one, and one for its depth decoder. + + This configuration inherit from [`GenerationConfig`] and can be used to control the model generation. Read the + documentation from [`GenerationConfig`] for more information. + + Args: + depth_decoder_config (`Dict`, *optional*): + Depth decoder generation configuration. + # TODO(YL): kwargs + + """ + super.__init__(**kwargs) + if depth_decoder_config is None: + depth_decoder_config = {} + logger.info("depth_decoder_config is None. initializing the semantic model with default values.") + + self.depth_decoder_config = GenerationConfig(**depth_decoder_config) + + @classmethod + def from_depth_decoder_config( + cls, + depth_decoder_config: GenerationConfig, + **kwargs, + ): + r""" + Instantiate a [`MoshiGenerationConfig`] (or a derived class) from Moshi depth decoder generation configuration. + + Returns: + [`MoshiGenerationConfig`]: An instance of a configuration object + """ + return cls( + depth_decoder_config=depth_decoder_config.to_dict(), + **kwargs, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + output["depth_decoder_config"] = self.depth_decoder_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 0ee858fd0b51a5..6096764e18745a 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -37,10 +37,9 @@ ) from ...modeling_outputs import ( BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, ModelOutput, Seq2SeqLMOutput, + CausalLMOutputWithPast, ) from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -48,6 +47,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -71,8 +71,9 @@ logging, replace_return_docstrings, ) +from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from .configuration_moshi import MoshiConfig, MoshiDecoderConfig +from .configuration_moshi import MoshiConfig if is_flash_attn_2_available(): @@ -87,6 +88,44 @@ _CHECKPOINT_FOR_DOC = "kyutai/moshiko" +@dataclass +class MoshiCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ @@ -107,6 +146,119 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->Moshi +class MoshiRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Moshi is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +ALL_LAYERNORM_LAYERS.append(MoshiRMSNorm) + + +class MoshiFlexibleLinear(nn.Module): + def __init__(self, input_size, output_size, num_layers): + super().__init__() + # Stack the weights for N layers into a single tensor (num_layers, output_size, input_size) + self.weight = nn.Parameter(torch.randn(num_layers, output_size, input_size)) + + def forward(self, x, layer_idx=None): + """ + `MoshiFlexibleLinear` creates one linear layer per codebook. There's multiple ways to use it. + In the default case, `sequence_length=num_layers`, so each element of the sequence will be matmul to the weights corresponding to its index on the sequence. + + For more advanced cases, one can specify which codebook's layer(s) to use with `layer_idx`. + If `layer_idx` indicates a single integer, all of the element of the sequence will be matmul to this single codebook's layer. + But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`. + + + Args: + x (`torch.FloatTensor): input to the layer of shape `(batch, num_layers, embed_dim)` or of shape `(batch, seq_length, embed_dim)` + layer_idx (`torch.Tensor`, *optional*): + Can be used to specify which codebook's layers(s) to use. + If it's a tensor of shape `(seq_length,)`, will matmul each element of the sequence to the corresponding weights. + But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`. + """ + if layer_idx is not None and layer_idx.dim()==0: + # Single layer case: select a specific layer (batch_size, 1 , input_size) -> (batch_size, 1, output_size) + return torch.matmul(x, self.weight[layer_idx].T) + elif layer_idx is not None: + # Use torch.gather to select the corresponding weights for each sample + selected_weights = torch.index_select(self.weight, 0, layer_idx) + return torch.einsum('bnh,noh->bno', x, selected_weights) + + # Multiple layers case: + # use einsum to batch the operations (batch_size, num_layers, input_size) -> (batch_size, num_layers, output_size) + return torch.einsum('bnh,noh->bno', x, self.weight) # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi @@ -121,7 +273,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward # TODO(joao): add me back asap :) def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] @@ -175,19 +327,27 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class MoshiMLP(nn.Module): - def __init__(self, config): +class MoshiGatingMLP(nn.Module): + def __init__(self, config, num_layers=1, is_depth_mlp=False): super().__init__() - self.config = config + self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) - - # Copied from transformers.models.clip.modeling_clip.CLIPMLP.forward - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) + ffn_dim = config.ffn_dim if not is_depth_mlp else config.depth_ffn_dim + hidden_size = config.hidden_size if not is_depth_mlp else config.depth_hidden_size + if num_layers == 1: + self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False) + self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False) + else: + self.fc1 = MoshiFlexibleLinear(hidden_size, ffn_dim, num_layers) + self.fc2 = MoshiFlexibleLinear(ffn_dim // 2, hidden_size, num_layers) + + def forward(self, hidden_states: torch.Tensor, layer_idx:int=None) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) if layer_idx is None else self.fc1(hidden_states, layer_idx) + + batch_size, sequence_length, _ = hidden_states.shape + hidden_states = hidden_states.view(batch_size, sequence_length, 2, -1) + hidden_states = self.activation_fn(hidden_states[..., 0, :]) * hidden_states[..., 1, :] + hidden_states = self.fc2(hidden_states) if layer_idx is None else self.fc2(hidden_states, layer_idx) return hidden_states @@ -204,14 +364,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Moshi class MoshiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None): + def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, is_depth_attention=False): super().__init__() self.config = config self.layer_idx = layer_idx + self.is_depth_attention = is_depth_attention if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " @@ -220,15 +380,15 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None): ) self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size if not is_depth_attention else config.depth_hidden_size + self.num_heads = config.num_attention_heads if not is_depth_attention else config.depth_num_attention_heads + self.head_dim = config.head_dim if not is_depth_attention else config.depth_head_dim + self.num_key_value_heads = config.num_key_value_heads if not is_depth_attention else config.depth_num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings + self.max_position_embeddings = config.max_position_embeddings if not is_depth_attention else config.depth_max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) + self.scaling = 1 / math.sqrt(self.head_dim) if self.hidden_size % self.num_heads != 0: raise ValueError( @@ -236,16 +396,25 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None): f" and `num_heads`: {self.num_heads})." ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + if not is_depth_attention: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + else: + self.q_proj = MoshiFlexibleLinear(self.hidden_size, self.num_heads * self.head_dim, num_layers=config.num_codebooks) + self.k_proj = MoshiFlexibleLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, num_layers=config.num_codebooks) + self.v_proj = MoshiFlexibleLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, num_layers=config.num_codebooks) + self.o_proj = MoshiFlexibleLinear(self.num_heads * self.head_dim, self.hidden_size, num_layers=config.num_codebooks) + + # rotary embeddings are not used in the depth decoder self.rotary_emb = MoshiRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, - ) + ) if not is_depth_attention else None + # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Moshi def forward( self, hidden_states: torch.Tensor, @@ -258,16 +427,17 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, position_ids) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, position_ids) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, position_ids) # Ignore copy query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -297,7 +467,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, position_ids) # Ignore copy if not output_attentions: attn_weights = None @@ -305,7 +475,7 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi +# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi class MoshiFlashAttention2(MoshiAttention): """ Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stays @@ -352,8 +522,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -416,7 +588,7 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi +# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi class MoshiSdpaAttention(MoshiAttention): """ Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -462,8 +634,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -512,15 +686,21 @@ def forward( } class MoshiDecoderLayer(nn.Module): - def __init__(self, config: MoshiConfig, layer_idx: int): + def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool=False, is_depth_layer=False): super().__init__() - self.hidden_size = config.hidden_size - + self.is_depth_layer = is_depth_layer + self.hidden_size = config.hidden_size if not is_depth_layer else config.depth_hidden_size + self.use_flexible_linear = use_flexible_linear + + # depth_ffn_dim + # depth_hidden_size + # depth doesn't use pos embedding + self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - self.mlp = MoshiMLP(config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + self.mlp = MoshiGatingMLP(config) if not use_flexible_linear else MoshiGatingMLP(config, config.num_codebooks, is_depth_layer) + self.input_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) def forward( @@ -573,7 +753,7 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states) if not self.use_flexible_linear else self.mlp(hidden_states, cache_position) hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -587,9 +767,171 @@ def forward( return outputs +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->Moshi +class MoshiPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MoshiConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MoshiDecoderLayer", "MoshiAttention"] + _supports_flash_attn_2 = True + _supports_sdpa = True + main_input_name = "input_ids" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MOSHI_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MoshiConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +# TODO: update +MOSHI_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `input_ids` and `inputs_embeds` are both unset, `inputs_embeds` takes the value + of `inputs_embeds`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +MOSHI_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): + Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. + + Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, + such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + + + + The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, + target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If + you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of + frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, + target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as + `input_ids`. + + + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" -# Copied from transformers.models.mimi.modeling_mimi.MimiTransformerModel with Mimi->Moshi, TransformerModel->Decoder, TransformerLayer->DecoderLayer -class MoshiDecoder(nn.Module): +# TODO: DO it as a depth decoder +class MoshiDepthDecoder(MoshiPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoshiTransformerLayer`] @@ -598,27 +940,54 @@ class MoshiDecoder(nn.Module): """ def __init__(self, config: MoshiConfig): - super().__init__() + super().__init__(config) + + self.text_embed_tokens = nn.Embedding(config.vocab_size + 1, config.depth_hidden_size) + + # the last codebook is never used as input + self.embed_tokens = nn.ModuleList( + [nn.Embedding(config.audio_vocab_size + 1, config.depth_hidden_size) for _ in range(config.num_codebooks - 1)] + ) + self.input_projections = MoshiFlexibleLinear(config.hidden_size, config.depth_hidden_size, config.num_codebooks) + + # TODO: remove if relevant + # nn.ModuleList( + # [nn.Linear(config.hidden_size, config.depth_hidden_size, bias=False) for _ in range(config.num_codebooks)] + # ) + self.layers = nn.ModuleList( - [MoshiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, is_depth_layer=True) for layer_idx in range(config.depth_num_hidden_layers)] ) + + self.lm_heads = MoshiFlexibleLinear(config.depth_hidden_size, config.audio_vocab_size, config.num_codebooks) + + # TODO: remove if relevant + # nn.ModuleList( + # [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] + # ) + self._attn_implementation = config._attn_implementation self.gradient_checkpointing = False self.config = config - def forward( + def forward( # TODO: update docstrings entirely self, - hidden_states: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + input_ids: Optional[torch.LongTensor] = None, # sequence of oracle input ids, i.e it must be the input ids that are predicted by the decoder # (B, S) + audio_codes: Optional[torch.Tensor] = None, # Same, shoud be oracle audio codebooks, but also with one channel less: # (B, C, S) or C-1 + last_hidden_states: torch.LongTensor = None, # use 8 times (B, S, H_in) | (B*S, H_in) + attention_mask: Optional[torch.BoolTensor] = None, + padding_mask: Optional[torch.BoolTensor] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: """ Args: @@ -635,7 +1004,7 @@ def forward( Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] @@ -677,7 +1046,26 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ + # here, we actually predict a sequence length of C + # independtly from the batch size and sequence length + # 1/ input ids is passed through text_embed_tokens -> (B, S, H) H=1024 + # 2/ each codebooks is passed through the embedding layer ase well -> (B, C-1, S, H) + # 3/ concat the two precedent results and get (B, C, S ,H) + # 4/ then we also pass the last hidden states through the input projection layers, one for each codebooks + # we get (B, C, S, H) + # 5/ sum one and another (B, C, S, H) + # 6/ pass (B*S, C, H) through the model and get (B*S, C, H_out) + # 7/ for each codebook, pass it through its own lm heads: (B*S, C, H) + # 8/ predict the codebook C1, C2 ... -> (B, S, C, H) + + # TODO: can we suppose B*S each time instead of B,S + # in the generation mode, it's different: + # S=1 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -688,18 +1076,12 @@ def forward( if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False - - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache) and not self.training: + + if use_cache and past_key_values is None and not self.training: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - return_legacy_cache = True - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" - ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -709,12 +1091,30 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, hidden_states, cache_position, past_key_values, use_cache, output_attentions - ) - - hidden_states = hidden_states + + # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used + if inputs_embeds is None: + if input_ids is None and audio_codes is None: + raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") + + if input_ids is not None: + inputs_embeds = self.text_embed_tokens(input_ids) + + # TODO: this should actually use embed_tokens depending on which position ids is asked for + # We should actually use one codebook embedding per element of the sequence + if audio_codes is not None: # TODO(YL): make sure it's C-1 + audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]) + inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds + + # TODO: check position ids shape + inputs_embeds += self.input_projections(last_hidden_states, position_ids) + + + causal_mask = None + if attention_mask is not None: + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions + ) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -760,61 +1160,310 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() + + + # TODO: check position ids shape + # TODO: remove the float() operation in v4.46 + logits = self.lm_heads(hidden_states, position_ids).float() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + loss = None + if labels is not None: + # TODO: it's probably not the right way to do it + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, + if not return_dict: + return tuple(v for v in [loss, logits, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) - - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._update_causal_mask - def _update_causal_mask( + + def prepare_inputs_for_generation( self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - use_cache: bool, - output_attentions: bool, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, ): - if self._attn_implementation == "flash_attention_2": - if attention_mask is not None and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask + # TODO(YL): on the first step, `input_ids` is used + # then every new input_ids are passed as `audio_codes` instead! + # Make sure cache_positions is correct + # do we use num_logits_to_keep? + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + + + +@add_start_docstrings( + "The bare Moshi Model outputting raw hidden-states without any specific head on top.", + MOSHI_START_DOCSTRING, +) +# Copied from transformers.models.gemma.modeling_gemma.GemmaModel with Gemma->Moshi +class MoshiModel(MoshiPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoshiDecoderLayer`] + + Args: + config: MoshiConfig + """ + + def __init__(self, config: MoshiConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size + 1, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MoshiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MoshiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False # noqa: F841 + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True # noqa: F841 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # normalized + # Moshi downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - - # cache_position must be valid here no matter which cache we use - past_seen_tokens = cache_position[0] if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, is_training=self.training, ): return None @@ -822,13 +1471,8 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - # SlidingWindowCache - if using_sliding_window_cache: - target_length = max(sequence_length, self.config.sliding_window) - # StaticCache - elif using_static_cache: + if using_static_cache: target_length = past_key_values.get_max_length() - # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -836,31 +1480,17 @@ def _update_causal_mask( else past_seen_tokens + sequence_length + 1 ) - if attention_mask is not None and attention_mask.dim() == 4: - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if self.config.sliding_window is not None: - if not using_sliding_window_cache or sequence_length > self.config.sliding_window: - exclude_mask.bitwise_or_( - torch.arange(target_length, device=device) - <= (cache_position.reshape(-1, 1) - self.config.sliding_window) - ) - causal_mask *= exclude_mask - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -875,768 +1505,225 @@ def _update_causal_mask( return causal_mask +@add_start_docstrings( + "The Moshi decoder model with a text language modelling head on top. Only usable for text.", + MOSHI_START_DOCSTRING, +) +# Copied from transformers.models.musicgen.modeling_gemma.GemmaForCausalLM with GEMMA->MOSHI,Gemma->Moshi,gemma-7b->moshiko,google->kyutai, CausalLM->MoshiCausalLM +class MoshiForCausalLM(MoshiPreTrainedModel): + _tied_weights_keys = None # Ignore copy + def __init__(self, config): + super().__init__(config) + self.model = MoshiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->Moshi -class MoshiPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = MoshiDecoderConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["MoshiDecoderLayer", "MoshiAttention"] - _supports_flash_attn_2 = True - _supports_sdpa = True - - def _init_weights(self, module): - std = self.config.initializer_factor - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - + # Initialize weights and apply final processing + self.post_init() -MOSHI_START_DOCSTRING = r""" + def get_input_embeddings(self): + return self.model.embed_tokens - The Moshi model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by - Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an - encoder decoder transformer trained on the task of conditional music generation + def set_input_embeddings(self, value): + self.model.embed_tokens = value - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) + def get_output_embeddings(self): + return self.lm_head - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings - Parameters: - config ([`MoshiConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" + def set_decoder(self, decoder): + self.model = decoder -MOSHI_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. + def get_decoder(self): + return self.model - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. + @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoshiCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, MoshiCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + Returns: - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. + Example: - Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, - such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - - - The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks, - target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If - you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of - frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, - target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as - `decoder_input_ids`. - - - - decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value - of `inputs_embeds`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -MOSHI_DECODER_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): - Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. - - Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, - such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - - - - The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, - target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If - you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of - frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, - target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as - `input_ids`. - - - - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of - the decoder. - encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing - cross-attention on hidden heads. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - - -@add_start_docstrings( - "The bare Moshi decoder model outputting raw hidden-states without any specific head on top.", - MOSHI_START_DOCSTRING, -) -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with MUSICGEN->MOSHI,Musicgen->Moshi -class MoshiModel(MoshiPreTrainedModel): - def __init__(self, config: MoshiDecoderConfig): - super().__init__(config) - self.decoder = MoshiDecoder(config) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.decoder.embed_tokens - - def set_input_embeddings(self, value): - self.decoder.embed_tokens = value - - def get_decoder(self): - return self.decoder - - @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - - -@add_start_docstrings( - "The Moshi decoder model with a language modelling head on top.", - MOSHI_START_DOCSTRING, -) -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with MUSICGEN->MOSHI,Musicgen->Moshi,musicgen->moshi,MusicGen->Moshi -class MoshiForCausalLM(MoshiPreTrainedModel): - def __init__(self, config: MoshiDecoderConfig): - super().__init__(config) - - self.model = MoshiModel(config) - - self.num_codebooks = config.num_codebooks - self.lm_heads = nn.ModuleList( - [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] - ) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.decoder.embed_tokens - - def set_input_embeddings(self, value): - self.model.decoder.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_heads - - def set_output_embeddings(self, new_embeddings): - self.lm_heads = new_embeddings - - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - - @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - Returns: - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (labels is not None) and (input_ids is None and inputs_embeds is None): - input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1) - - loss = None - if labels is not None: - # since encoder hidden states have been concatenated to the decoder hidden states, - # we take the last timestamps corresponding to labels - logits = lm_logits[:, :, -labels.shape[1] :] - - loss_fct = CrossEntropyLoss() - loss = torch.zeros([], device=self.device) - - # per codebook cross-entropy - # -100 labels are ignored - labels = labels.masked_fill(labels == self.config.pad_token_id, -100) - - # per codebook cross-entropy - # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/moshi.py#L242-L243 - for codebook in range(self.config.num_codebooks): - codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) - codebook_labels = labels[..., codebook].contiguous().view(-1) - loss += loss_fct(codebook_logits, codebook_labels) - - loss = loss / self.config.num_codebooks - - # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) - lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - use_cache=True, - delay_pattern_mask=None, - guidance_scale=None, - **kwargs, - ): - if delay_pattern_mask is None: - input_ids, delay_pattern_mask = self.build_delay_pattern_mask( - input_ids, - pad_token_id=self.generation_config.pad_token_id, - max_length=self.generation_config.max_length, - ) - - # apply the delay pattern mask - input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask) - - if guidance_scale is not None and guidance_scale > 1: - # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these - # before sampling) - input_ids = input_ids.repeat((2, 1)) - if attention_mask is not None: - attention_mask = attention_mask.repeat((2, 1)) - - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "encoder_hidden_states": encoder_hidden_states, - "encoder_attention_mask": encoder_attention_mask, - "head_mask": head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "past_key_values": past_key_values, - "use_cache": use_cache, - } - - def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None): - """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by - one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there - are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, - seq_len)`: - - [P, -1, -1, -1, -1, P, P, P] - - [P, P, -1, -1, -1, -1, P, P] - - [P, P, P, -1, -1, -1, -1, P] - - [P, P, P, P, -1, -1, -1, -1] - where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include - a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the - mask is set to the value in the prompt: - - [P, a, b, -1, -1, P, P, P] - - [P, P, c, d, -1, -1, P, P] - - [P, P, P, e, f, -1, -1, P] - - [P, P, P, P, g, h, -1, -1] - where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 - tokens in our prediction. - """ - # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) - input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) - bsz, num_codebooks, seq_len = input_ids.shape - - max_length = max_length if max_length is not None else self.generation_config.max_length - input_ids_shifted = ( - torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1 - ) - - channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks - # we only apply the mask if we have a large enough seq len - otherwise we return as is - if max_length < 2 * channel_codebooks - 1: - return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1) - - # fill the shifted ids with the prompt entries, offset by the codebook idx - for codebook in range(channel_codebooks): - if self.config.audio_channels == 1: - # mono channel - loop over the codebooks one-by-one - input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] - else: - # left/right channels are interleaved in the generated codebooks, so handle one then the other - input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook] - input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1] - - # construct a pattern mask that indicates the positions of padding tokens for each codebook - # first fill the upper triangular part (the EOS padding) - delay_pattern = torch.triu( - torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1 - ) - # then fill the lower triangular part (the BOS padding) - delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool)) - - if self.config.audio_channels == 2: - # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion - delay_pattern = delay_pattern.repeat_interleave(2, dim=0) - - mask = ~delay_pattern.to(input_ids.device) - input_ids = mask * input_ids_shifted + ~mask * pad_token_id - - # find the first position to start generating - this is the first place we have the -1 token - # and will always be in the first codebook (since it has no codebook offset) - first_codebook_ids = input_ids[:, 0, :] - start_ids = (first_codebook_ids == -1).nonzero()[:, 1] - if len(start_ids) > 0: - first_start_id = min(start_ids) - else: - # we have no tokens that need to be filled - return entire matrix of input ids - first_start_id = seq_len - - # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) - pattern_mask = input_ids.reshape(bsz * num_codebooks, -1) - input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) - return input_ids, pattern_mask - - @staticmethod - def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): - """Apply a delay pattern mask to the decoder input ids, only preserving predictions where - the mask is set to -1, and otherwise setting to the value detailed in the mask.""" - seq_len = input_ids.shape[-1] - decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len] - input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask) - return input_ids - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - synced_gpus: Optional[bool] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, - ): - """ - - Generates sequences of token ids for models with a language modeling head. - - - - Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the - model's default generation configuration. You can override any `generation_config` by passing the corresponding - parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - - For an overview of generation strategies and code examples, check out the [following - guide](./generation_strategies). - - - - Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - generation config. If a stopping criteria is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - kwargs (`Dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. - - If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateDecoderOnlyOutput`], - - [`~generation.GenerateBeamDecoderOnlyOutput`] - - If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateEncoderDecoderOutput`], - - [`~generation.GenerateBeamEncoderDecoderOutput`] - """ - # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects - if generation_config is None: - generation_config = self.generation_config - - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - generation_config.validate() - self._validate_model_kwargs(model_kwargs.copy()) - - # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - requires_attention_mask = "encoder_outputs" not in model_kwargs - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None - - # 3. Define model inputs` - input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) - batch_size = input_ids.shape[0] // self.num_codebooks - self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) - - # 4. Define other model kwargs - model_kwargs["use_cache"] = generation_config.use_cache - model_kwargs["guidance_scale"] = generation_config.guidance_scale - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor - ) - - # 5. Prepare `max_length` depending on other stopping criteria. - input_ids_length = input_ids.shape[-1] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None - generation_config = self._prepare_generated_length( - generation_config=generation_config, - has_default_max_length=has_default_max_length, - has_default_min_length=has_default_min_length, - model_input_name=model_input_name, - inputs_tensor=input_ids, - input_ids_length=input_ids_length, - ) - - # 6. Prepare `input_ids` which will be used for auto-regressive generation - # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Moshi) - input_ids, delay_pattern_mask = self.build_delay_pattern_mask( - input_ids, - pad_token_id=generation_config._decoder_start_token_tensor, - max_length=generation_config.max_length, - ) - - if streamer is not None: - streamer.put(input_ids.cpu()) - - # stash the delay mask so that we don't have to recompute it in each forward pass - model_kwargs["delay_pattern_mask"] = delay_pattern_mask - - # 7. determine generation mode - generation_mode = generation_config.get_generation_mode() - - # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) - if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: - logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) - generation_config.guidance_scale = None - - # 9. prepare distribution pre_processing samplers - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=None, - logits_processor=logits_processor, - device=input_ids.device, + ```python + >>> from transformers import AutoTokenizer, MoshiForCausalLM + + >>> model = MoshiForCausalLM.from_pretrained("kyutai/moshiko") + >>> tokenizer = AutoTokenizer.from_pretrained("kyutai/moshiko") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # 10. prepare stopping criteria - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, ) - if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # expand input_ids with `num_return_sequences` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, - **model_kwargs, + hidden_states = outputs[0] + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() - # 11. run sample - outputs = self._sample( - input_ids, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, - **model_kwargs, - ) + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) - else: - raise ValueError( - "Got incompatible mode for generation, should be one of greedy or sampling. " - "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." - ) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoshiCausalLMOutputWithPast( + loss=loss, + logits=logits, + last_hidden_state=hidden_states, # Ignore copy + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) - if generation_config.return_dict_in_generate: - output_ids = outputs.sequences + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - output_ids = outputs + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - # apply the pattern mask to the final ids - output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"]) + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device - # revert the pattern delay mask by filtering the pad token id - output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( - batch_size, self.num_codebooks, -1 - ) + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min - if generation_config.return_dict_in_generate: - outputs.sequences = output_ids - return outputs - else: - return output_ids + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs @add_start_docstrings( - "The composite Moshi model with a text encoder, audio encoder and Moshi decoder, " - "for music generation tasks with one or both of text and audio prompts.", + "The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, " + "for speech-to-speech.", MOSHI_START_DOCSTRING, ) -# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration with MUSICGEN->MOSHI,Musicgen->Moshi,musicgen->moshi,MusicGen->Moshi,facebook/musicgen-small->kyutai/moshiko -class MoshiForConditionalGeneration(PreTrainedModel): +class MoshiForConditionalGeneration(MoshiPreTrainedModel): # TODO(YL): don't think I can initialize like this for a composite model config_class = MoshiConfig - base_model_prefix = "encoder_decoder" main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True @@ -1645,124 +1732,84 @@ class MoshiForConditionalGeneration(PreTrainedModel): def __init__( self, config: Optional[MoshiConfig] = None, - text_encoder: Optional[PreTrainedModel] = None, audio_encoder: Optional[PreTrainedModel] = None, decoder: Optional[MoshiForCausalLM] = None, + depth_decoder: Optional[MoshiDepthDecoder] = None, ): - if config is None and (text_encoder is None or audio_encoder is None or decoder is None): + if config is None and (audio_encoder is None or decoder is None or depth_decoder is None): raise ValueError( - "Either a configuration has to be provided, or all three of text encoder, audio encoder and Moshi decoder." + "Either a configuration has to be provided, or all three of Moshi depth decoder, audio encoder and Moshi decoder." ) if config is None: - config = MoshiConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config) + config = MoshiConfig.from_audio_encoder_config(audio_encoder.config) else: if not isinstance(config, self.config_class): raise ValueError(f"Config: {config} has to be of type {self.config_class}") - if config.decoder.cross_attention_hidden_size is not None: - if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size: - raise ValueError( - "If `cross_attention_hidden_size` is specified in the Moshi decoder's configuration, it has to be equal" - f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" - f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for" - " `config.text_encoder.hidden_size`." - ) + # TODO: verify decoder and depth decoder not incompatible + # TODO: does the decoder and depth decoder makes sense? # initialize with config super().__init__(config) - if text_encoder is None: - from ..auto.modeling_auto import AutoModelForTextEncoding - - text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder) - if audio_encoder is None: from ..auto.modeling_auto import AutoModel audio_encoder = AutoModel.from_config(config.audio_encoder) if decoder is None: - decoder = MoshiForCausalLM(config.decoder) + decoder = MoshiForCausalLM(config) + + if depth_decoder is None: + depth_decoder = MoshiDepthDecoder(config) - self.text_encoder = text_encoder - self.audio_encoder = audio_encoder + self.depth_decoder = depth_decoder self.decoder = decoder + self.audio_encoder = audio_encoder - if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict(): + # We have 2 * num_codebooks audio embedding layers because we have the user input channel and the model output channel. + self.embed_tokens = nn.ModuleList( + [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)] + ) + + if self.decoder.config.to_dict() != self.config.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared config:" + f" {self.config}" + ) + if self.depth_decoder.config.to_dict() != self.config.to_dict(): logger.warning( - f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:" - f" {self.config.text_encoder}" + f"Config of the depth decoder: {depth_decoder.decoder.__class__} is overwritten by shared config:" + f" {self.config}" ) if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict(): logger.warning( f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:" f" {self.config.audio_encoder}" ) - if self.decoder.config.to_dict() != self.config.decoder.to_dict(): - logger.warning( - f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" - f" {self.config.decoder}" - ) # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced - self.text_encoder.config = self.config.text_encoder self.audio_encoder.config = self.config.audio_encoder - self.decoder.config = self.config.decoder - - # text encoder outputs might need to be projected to different dimension for decoder - if ( - self.text_encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size) - - if self.text_encoder.get_output_embeddings() is not None: - raise ValueError( - f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head" - ) - - decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys()) - if "encoder_hidden_states" not in decoder_signature: - raise ValueError( - "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the " - "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" - ) + self.decoder.config = self.config + self.depth_decoder.config = self.config + + self.num_codebooks = config.num_codebooks # tie text encoder, decoder weights if config set accordingly self.tie_weights() - def tie_weights(self): - # tie text encoder & decoder if needed - if self.config.tie_encoder_decoder: - # tie text encoder and decoder base model - decoder_base_model_prefix = self.decoder.base_model_prefix - tied_weights = self._tie_encoder_decoder_weights( - self.text_encoder, - self.decoder._modules[decoder_base_model_prefix], - self.decoder.base_model_prefix, - "text_encoder", - ) - # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class - # attributed not an instance member, therefore modifying it will modify the entire class - # Leading to issues on subsequent calls by different tests or subsequent calls. - self._dynamic_tied_weights_keys = tied_weights - def get_audio_encoder(self): return self.audio_encoder - def get_text_encoder(self): - return self.text_encoder - - def get_encoder(self): - # get the text encoder to compute the encoder hidden-states for generation - return self.get_text_encoder() + def get_depth_decoder(self): + return self.depth_decoder def get_decoder(self): return self.decoder def get_input_embeddings(self): - return self.text_encoder.get_input_embeddings() + return self.decoder.get_input_embeddings() def get_output_embeddings(self): return self.decoder.get_output_embeddings() @@ -1791,218 +1838,19 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - @classmethod - def from_sub_models_pretrained( - cls, - text_encoder_pretrained_model_name_or_path: str = None, - audio_encoder_pretrained_model_name_or_path: str = None, - decoder_pretrained_model_name_or_path: str = None, - *model_args, - **kwargs, - ) -> PreTrainedModel: - r""" - Instantiate a text encoder, an audio encoder, and a Moshi decoder from one, two or three base classes of the - library from pretrained model checkpoints. - - - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train - the model, you need to first set it back in training mode with `model.train()`. - - Params: - text_encoder_pretrained_model_name_or_path (`str`, *optional*): - Information necessary to initiate the text encoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - audio_encoder_pretrained_model_name_or_path (`str`, *optional*): - Information necessary to initiate the audio encoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): - Information necessary to initiate the decoder. Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - model_args (remaining positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). - - - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration - parameter. - - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration - parameter. - - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. - - To update the parent model configuration, do not use a prefix for each configuration parameter. - - Behaves differently depending on whether a `config` is provided or automatically loaded. - - Example: - - ```python - >>> from transformers import MoshiForConditionalGeneration - - >>> # initialize a moshi model from a t5 text encoder, encodec audio encoder, and moshi decoder - >>> model = MoshiForConditionalGeneration.from_sub_models_pretrained( - ... text_encoder_pretrained_model_name_or_path="google-t5/t5-base", - ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz", - ... decoder_pretrained_model_name_or_path="kyutai/moshiko", - ... ) - >>> # saving model after fine-tuning - >>> model.save_pretrained("./moshi-ft") - >>> # load fine-tuned model - >>> model = MoshiForConditionalGeneration.from_pretrained("./moshi-ft") - ```""" - - kwargs_text_encoder = { - argument[len("text_encoder_") :]: value - for argument, value in kwargs.items() - if argument.startswith("text_encoder_") - } - - kwargs_audio_encoder = { - argument[len("audio_encoder_") :]: value - for argument, value in kwargs.items() - if argument.startswith("audio_encoder_") - } - - kwargs_decoder = { - argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") - } - - # remove text encoder, audio encoder and decoder kwargs from kwargs - for key in kwargs_text_encoder.keys(): - del kwargs["text_encoder_" + key] - for key in kwargs_audio_encoder.keys(): - del kwargs["audio_encoder_" + key] - for key in kwargs_decoder.keys(): - del kwargs["decoder_" + key] - - # Load and initialize the encoder and decoder - # The distinction between encoder and decoder at the model level is made - # by the value of the flag `is_decoder` that we need to set correctly. - text_encoder = kwargs_text_encoder.pop("model", None) - if text_encoder is None: - if text_encoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_text_encoder: - encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained( - text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True - ) - - if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model " - "from a decoder model. Cross-attention and casual mask are disabled." - ) - encoder_config.is_decoder = False - encoder_config.add_cross_attention = False - - kwargs_text_encoder["config"] = encoder_config - - text_encoder = AutoModel.from_pretrained( - text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder - ) - - audio_encoder = kwargs_audio_encoder.pop("model", None) - if audio_encoder is None: - if audio_encoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_audio_encoder: - encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained( - audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True - ) - - if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model " - "from a decoder model. Cross-attention and casual mask are disabled." - ) - encoder_config.is_decoder = False - encoder_config.add_cross_attention = False - - kwargs_audio_encoder["config"] = encoder_config - - audio_encoder = AutoModel.from_pretrained( - audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder - ) - - decoder = kwargs_decoder.pop("model", None) - if decoder is None: - if decoder_pretrained_model_name_or_path is None: - raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " - "to be defined." - ) - - if "config" not in kwargs_decoder: - decoder_config, kwargs_decoder = AutoConfig.from_pretrained( - decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True - ) - - if isinstance(decoder_config, MoshiConfig): - decoder_config = decoder_config.decoder - - if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: - logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" - f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" - f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." - ) - decoder_config.is_decoder = True - decoder_config.add_cross_attention = True - - kwargs_decoder["config"] = decoder_config - - if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: - logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " - f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " - "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " - "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a " - "`decoder_config` to `.from_sub_models_pretrained(...)`" - ) - - decoder = MoshiForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - - # instantiate config with corresponding kwargs - config = MoshiConfig.from_sub_models_config( - text_encoder.config, audio_encoder.config, decoder.config, **kwargs - ) - return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config) - @add_start_docstrings_to_model_forward(MOSHI_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, - input_values: Optional[torch.FloatTensor] = None, + input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it padding_mask: Optional[torch.BoolTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings - and decide if it's 16 codebooks or (8 and another audio_values) past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, + text_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings + audio_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -2019,7 +1867,7 @@ def forward( >>> processor = AutoProcessor.from_pretrained("kyutai/moshiko") >>> model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko") - + >>> # TODO(YL): update >>> inputs = processor( ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], ... padding=True, @@ -2027,23 +1875,17 @@ def forward( ... ) >>> pad_token_id = model.generation_config.pad_token_id - >>> decoder_input_ids = ( + >>> input_ids = ( ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) ... * pad_token_id ... ) - >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits + >>> logits = model(**inputs, input_ids=input_ids).logits >>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size) torch.Size([8, 1, 2048]) ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - kwargs_text_encoder = { - argument[len("text_encoder_")]: value - for argument, value in kwargs.items() - if argument.startswith("text_encoder_") - } - kwargs_audio_encoder = { argument[len("audio_encoder_")]: value for argument, value in kwargs.items() @@ -2053,240 +1895,157 @@ def forward( kwargs_decoder = { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } - - if encoder_outputs is None: - encoder_outputs = self.text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - **kwargs_text_encoder, + + # TODO: encode input_values + # TODO: how to deal with both streams, we actually two input_values stream + if input_values is not None and audio_codes is None: + # TODO: should be 16 codebooks + audio_codes = self.audio_encoder.encode(input_values, padding_mask, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] + + + if (text_labels is not None) and (input_ids is None and inputs_embeds is None): + input_ids = shift_tokens_right( + text_labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id ) - elif isinstance(encoder_outputs, tuple): - encoder_outputs = BaseModelOutput(*encoder_outputs) - - encoder_hidden_states = encoder_outputs[0] - - # optionally project encoder_hidden_states - if ( - self.text_encoder.config.hidden_size != self.decoder.config.hidden_size - and self.decoder.config.cross_attention_hidden_size is None - ): - encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) - - if attention_mask is not None: - encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] - - if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): - decoder_input_ids = shift_tokens_right( - labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id - ) - - elif decoder_input_ids is None and decoder_inputs_embeds is None: - audio_encoder_outputs = self.audio_encoder( - input_values=input_values, - padding_mask=padding_mask, - **kwargs_audio_encoder, - ) - audio_codes = audio_encoder_outputs.audio_codes - frames, bsz, codebooks, seq_len = audio_codes.shape - if frames != 1: - raise ValueError( - f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " - "disabled by setting `chunk_length=None` in the audio encoder." - ) - - if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2: - # mono input through encodec that we convert to stereo - audio_codes = audio_codes.repeat_interleave(2, dim=2) - - decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) + # TODO: also do it with audio_labels? + + # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used + if inputs_embeds is None: + if input_ids is None and audio_codes is None: + raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") + + if input_ids is not None: + inputs_embeds = self.decoder.model.embed_tokens(input_ids) + + if audio_codes is not None: + audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]) + inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds # Decode decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=attention_mask, - inputs_embeds=decoder_inputs_embeds, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, past_key_values=past_key_values, - return_dict=return_dict, - labels=labels, + return_dict=True, + labels=text_labels, **kwargs_decoder, ) - + + # TODO: how to deal with loss here ? maybe we can do one loss for text + # and one loss for audio_labels? + decoder_last_hidden_states = decoder_outputs.last_hidden_state + # TODO: we want to pass the audio_codes and audio_labels from the model inputs + + depth_decoder_outputs = None + if text_labels is not None and audio_labels is not None: + # TODO: how to deal with padding mask and attention mask ? + + # To use depth decoder forward here, we actually need oracle input ids since we're supposed to pass the true input ids + depth_decoder_outputs = self.depth_decoder( + hidden_states=decoder_last_hidden_states, + input_ids=text_labels, # probably need to reshape to (B*S) + audio_codes=audio_labels, # probably need to reshape to (B*S) + attention_mask=attention_mask, + padding_mask=padding_mask, + ) + + if not return_dict: - return decoder_outputs + encoder_outputs + outputs = decoder_outputs.to_tuple() + if depth_decoder_outputs is not None: + outputs += depth_decoder_outputs.to_tuple() + return outputs# TODO + encoder_outputs + # TODO: change return Seq2SeqLMOutput( - loss=decoder_outputs.loss, + loss=decoder_outputs.loss, # TODO: it's the text loss logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, ) + def _prepare_inputs_embeds_for_generation( + self, + input_ids: Optional[torch.LongTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + audio_codes: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used + if inputs_embeds is None: + # TODO: here we have to decide how to deal with audio codes from the user + # also have to decide how to deal with number of channels + + if input_values is not None and audio_codes is None: + # TODO: should be 16 codebooks + audio_codes = self.audio_encoder.encode(input_values, num_quantizers=self.num_codebooks,)[0] + + + if input_ids is None and audio_codes is None: + raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") + + if input_ids is not None: + inputs_embeds = self.decoder.model.embed_tokens(input_ids) + + if audio_codes is not None: + audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]) + inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds + + return inputs_embeds + + def prepare_inputs_for_generation( self, - decoder_input_ids, + input_ids, past_key_values=None, attention_mask=None, - head_mask=None, - decoder_attention_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, use_cache=None, - encoder_outputs=None, decoder_delay_pattern_mask=None, guidance_scale=None, **kwargs, ): if decoder_delay_pattern_mask is None: - decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( - decoder_input_ids, + input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( + input_ids, self.generation_config.pad_token_id, max_length=self.generation_config.max_length, ) # apply the delay pattern mask - decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) + input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask) if guidance_scale is not None and guidance_scale > 1: # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these # before sampling) - decoder_input_ids = decoder_input_ids.repeat((2, 1)) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) + input_ids = input_ids.repeat((2, 1)) + if attention_mask is not None: + attention_mask = attention_mask.repeat((2, 1)) if past_key_values is not None: past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: + if input_ids.shape[1] > past_length: remove_prefix_length = past_length else: # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 + remove_prefix_length = input_ids.shape[1] - 1 - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + input_ids = input_ids[:, remove_prefix_length:] return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, + "input_ids": None, # TODO encoder_outputs is defined. input_ids not needed "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, + "input_ids": input_ids, "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, } - def _prepare_decoder_input_ids_for_generation( - self, - batch_size: int, - model_input_name: str, - model_kwargs: Dict[str, torch.Tensor], - decoder_start_token_id: int = None, - bos_token_id: int = None, - device: torch.device = None, - ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: - """Prepares `decoder_input_ids` for generation with encoder-decoder models""" - - # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, - # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. - if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - decoder_input_ids = model_kwargs.pop("decoder_input_ids") - elif "input_ids" in model_kwargs and model_input_name != "input_ids": - decoder_input_ids = model_kwargs.pop("input_ids") - else: - decoder_input_ids = None - - # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - if device is None: - device = self.device - decoder_input_ids_start = ( - torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device) - * decoder_start_token_id - ) - - # no user input -> use decoder_start_token_id as decoder_input_ids - if decoder_input_ids is None: - decoder_input_ids = decoder_input_ids_start - - # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust - # decoder_attention_mask if provided) - elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item(): - decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - decoder_attention_mask = torch.cat( - (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), - dim=-1, - ) - model_kwargs["decoder_attention_mask"] = decoder_attention_mask - - return decoder_input_ids, model_kwargs - - def _prepare_text_encoder_kwargs_for_generation( - self, - inputs_tensor: torch.Tensor, - model_kwargs, - model_input_name: Optional[str], - generation_config: GenerationConfig, - ) -> Dict[str, Any]: - # 1. get text encoder - encoder = self.get_text_encoder() - # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device - # as the inputs. - if hasattr(encoder, "_hf_hook"): - encoder._hf_hook.io_same_device = True - - # 2. Prepare encoder args and encoder kwargs from model kwargs. - irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] - encoder_kwargs = { - argument: value - for argument, value in model_kwargs.items() - if not any(argument.startswith(p) for p in irrelevant_prefix) - } - encoder_signature = set(inspect.signature(encoder.forward).parameters) - encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature - if not encoder_accepts_wildcard: - encoder_kwargs = { - argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature - } - encoder_kwargs["output_attentions"] = generation_config.output_attentions - encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states - guidance_scale = generation_config.guidance_scale - - # 3. make sure that encoder returns `ModelOutput` - model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name - encoder_kwargs["return_dict"] = True - encoder_kwargs[model_input_name] = inputs_tensor - last_hidden_state = encoder(**encoder_kwargs).last_hidden_state - - # for classifier free guidance we need to add a 'null' input to our encoder hidden states - if guidance_scale is not None and guidance_scale > 1: - last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0) - if "attention_mask" in model_kwargs: - model_kwargs["attention_mask"] = torch.concatenate( - [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0 - ) - - model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state) - - return model_kwargs def _prepare_audio_encoder_kwargs_for_generation( self, input_values, model_kwargs, model_input_name: Optional[str] = None @@ -2316,41 +2075,13 @@ def _prepare_audio_encoder_kwargs_for_generation( model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name encoder_kwargs["return_dict"] = True - if self.decoder.config.audio_channels == 1: - encoder_kwargs[model_input_name] = input_values - audio_encoder_outputs = encoder.encode(**encoder_kwargs) - audio_codes = audio_encoder_outputs.audio_codes - audio_scales = audio_encoder_outputs.audio_scales - - frames, bsz, codebooks, seq_len = audio_codes.shape - - else: - if input_values.shape[1] != 2: - raise ValueError( - f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel." - ) - - encoder_kwargs[model_input_name] = input_values[:, :1, :] - audio_encoder_outputs_left = encoder.encode(**encoder_kwargs) - audio_codes_left = audio_encoder_outputs_left.audio_codes - audio_scales_left = audio_encoder_outputs_left.audio_scales - - encoder_kwargs[model_input_name] = input_values[:, 1:, :] - audio_encoder_outputs_right = encoder.encode(**encoder_kwargs) - audio_codes_right = audio_encoder_outputs_right.audio_codes - audio_scales_right = audio_encoder_outputs_right.audio_scales + encoder_kwargs[model_input_name] = input_values + audio_encoder_outputs = encoder.encode(**encoder_kwargs) + audio_codes = audio_encoder_outputs.audio_codes + audio_scales = audio_encoder_outputs.audio_scales - frames, bsz, codebooks, seq_len = audio_codes_left.shape - # copy alternating left/right channel codes into stereo codebook - audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len)) + frames, bsz, codebooks, seq_len = audio_codes.shape - audio_codes[:, :, ::2, :] = audio_codes_left - audio_codes[:, :, 1::2, :] = audio_codes_right - - if audio_scales_left != [None] or audio_scales_right != [None]: - audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1) - else: - audio_scales = [None] * bsz if frames != 1: raise ValueError( @@ -2358,9 +2089,9 @@ def _prepare_audio_encoder_kwargs_for_generation( "disabled by setting `chunk_length=None` in the audio encoder." ) - decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) + input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) - model_kwargs["decoder_input_ids"] = decoder_input_ids + model_kwargs["input_ids"] = input_ids model_kwargs["audio_scales"] = audio_scales return model_kwargs @@ -2382,13 +2113,13 @@ def freeze_audio_encoder(self): param.requires_grad = False self.audio_encoder._requires_grad = False - def freeze_text_encoder(self): + def freeze_depth_decoder(self): """ - Freeze the text encoder weights. + Freeze the depth encoder weights. """ - for param in self.text_encoder.parameters(): + for param in self.depth_decoder.parameters(): param.requires_grad = False - self.text_encoder._requires_grad = False + self.depth_decoder._requires_grad = False def _maybe_initialize_input_ids_for_generation( self, @@ -2400,12 +2131,6 @@ def _maybe_initialize_input_ids_for_generation( if inputs is not None: return inputs - encoder_outputs = model_kwargs.get("encoder_outputs") - if encoder_outputs is not None: - # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding - shape = encoder_outputs[0].size()[:-1] - return torch.ones(shape, dtype=torch.long, device=self.device) * -100 - if bos_token_id is None: raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") @@ -2518,15 +2243,12 @@ def generate( generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) - if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple: - # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate - model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0]) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - requires_attention_mask = "encoder_outputs" not in model_kwargs + requires_attention_mask = False # TODO "encoder_outputs" not in model_kwargs kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs @@ -2545,13 +2267,8 @@ def generate( inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) - if "encoder_outputs" not in model_kwargs: - # encoder_outputs are created and added to `model_kwargs` - model_kwargs = self._prepare_text_encoder_kwargs_for_generation( - inputs_tensor, model_kwargs, model_input_name, generation_config - ) - if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs: + if "input_ids" not in model_kwargs and "input_values" in model_kwargs: model_kwargs = self._prepare_audio_encoder_kwargs_for_generation( model_kwargs["input_values"], model_kwargs, @@ -2662,19 +2379,11 @@ def generate( if audio_scales is None: audio_scales = [None] * batch_size - if self.decoder.config.audio_channels == 1: - output_values = self.audio_encoder.decode( - output_ids, - audio_scales=audio_scales, - ).audio_values - else: - codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales) - output_values_left = codec_outputs_left.audio_values - - codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales) - output_values_right = codec_outputs_right.audio_values + output_values = self.audio_encoder.decode( + output_ids, + audio_scales=audio_scales, + ).audio_values - output_values = torch.cat([output_values_left, output_values_right], dim=1) if generation_config.return_dict_in_generate: outputs.sequences = output_values diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index 218385478b999c..c9ded9097a02ff 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -27,7 +27,6 @@ from transformers import ( EncodecConfig, MoshiConfig, - MoshiDecoderConfig, MoshiProcessor, PretrainedConfig, T5Config, @@ -155,7 +154,7 @@ def prepare_config_and_inputs(self): return config, inputs_dict def get_config(self): - config = MoshiDecoderConfig( + config = MoshiConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, @@ -185,7 +184,7 @@ class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def setUp(self): self.model_tester = MoshiDecoderTester(self) - self.config_tester = ConfigTester(self, config_class=MoshiDecoderConfig, hidden_size=16) + self.config_tester = ConfigTester(self, config_class=MoshiConfig, hidden_size=16) def test_config(self): self.config_tester.run_common_tests() @@ -1025,7 +1024,7 @@ def get_config(self): codebook_size=self.codebook_size, codebook_dim=self.vocab_size, ) - decoder_config = MoshiDecoderConfig( + decoder_config = MoshiConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, From 016d538690bc8aca40b99c8d847aec78334730f2 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Mon, 16 Sep 2024 17:19:35 +0200 Subject: [PATCH 6/8] finalize converting script - still missing tokenizer and FE and processor --- .../models/moshi/configuration_moshi.py | 5 ++++- .../models/moshi/convert_moshi_transformers.py | 15 +++++++++++++-- src/transformers/models/moshi/modeling_moshi.py | 6 +----- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/moshi/configuration_moshi.py b/src/transformers/models/moshi/configuration_moshi.py index 9fff7c5c7124c8..9eda765c834477 100644 --- a/src/transformers/models/moshi/configuration_moshi.py +++ b/src/transformers/models/moshi/configuration_moshi.py @@ -87,6 +87,8 @@ class MoshiConfig(PretrainedConfig): depth_num_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention in the depth decoder. If it is not specified, will default to `depth_num_key_value_heads`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether input and output word embeddings should be tied. kwargs (*optional*): Dictionary of keyword arguments. Notably: - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that @@ -155,6 +157,7 @@ def __init__(self, depth_ffn_dim=5632, depth_head_dim=None, depth_num_key_value_heads=None, + tie_word_embeddings=False, **kwargs): self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -186,7 +189,7 @@ def __init__(self, self.depth_head_dim = depth_head_dim or depth_hidden_size // depth_num_attention_heads self.depth_num_key_value_heads = depth_num_key_value_heads if depth_num_key_value_heads is not None else depth_num_attention_heads - super().__init__(**kwargs) + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) if "audio_encoder" not in kwargs: raise ValueError("Config has to be initialized with audio_encoder config") diff --git a/src/transformers/models/moshi/convert_moshi_transformers.py b/src/transformers/models/moshi/convert_moshi_transformers.py index cef03195914798..90f3ab486d2871 100644 --- a/src/transformers/models/moshi/convert_moshi_transformers.py +++ b/src/transformers/models/moshi/convert_moshi_transformers.py @@ -94,8 +94,8 @@ def _preprocess_state_dict(state_dict, config): input_projections.append(state_dict.pop(f"depformer_in.{codebook_idx}.weight")) lm_heads.append(state_dict.pop(f"linears.{codebook_idx}.weight")) - state_dict["depth_decoder.input_projections.weight"] = torch.stack(input_projections, dim = 1) - state_dict["depth_decoder.lm_heads.weight"] = torch.stack(lm_heads, dim = 1) + state_dict["depth_decoder.input_projections.weight"] = torch.stack(input_projections, dim = 0) + state_dict["depth_decoder.lm_heads.weight"] = torch.stack(lm_heads, dim = 0) return state_dict @@ -133,6 +133,11 @@ def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size): for old_layer_name, new_layer_name in convert_list: if old_layer_name in new_k: new_k = new_k.replace(old_layer_name, new_layer_name) + + if "alpha" in k: + state_dict[k] = state_dict[k].squeeze() + + if "in_proj_weight" in new_k: # split qkv into query key and value @@ -153,9 +158,15 @@ def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size): key_layer, num_key_value_heads, dim1=key_value_head_dim, dim2=hidden_size ) state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer + elif "o_proj" in new_k and "depth_decoder" in new_k: + output_layer = state_dict.pop(k) + state_dict[new_k] = output_layer.view(config.num_codebooks, -1, output_layer.shape[-1]) else: state_dict[new_k] = state_dict.pop(k) + # Do the last one by hand + state_dict["depth_decoder.text_embed_tokens.weight"] = state_dict.pop("depth_decoder.decoder.model.embed_tokens.weight") + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) if len(extra_keys) != 0: diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 6096764e18745a..8afeecc4112a1d 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -692,11 +692,7 @@ def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: boo self.hidden_size = config.hidden_size if not is_depth_layer else config.depth_hidden_size self.use_flexible_linear = use_flexible_linear - # depth_ffn_dim - # depth_hidden_size - # depth doesn't use pos embedding - - self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, is_depth_attention=is_depth_layer) self.mlp = MoshiGatingMLP(config) if not use_flexible_linear else MoshiGatingMLP(config, config.num_codebooks, is_depth_layer) self.input_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) From 34b6e24398f54d7c84da6593c182d50f5be5d3d5 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Mon, 16 Sep 2024 17:28:44 +0200 Subject: [PATCH 7/8] fix saving model w/o default config --- src/transformers/configuration_utils.py | 1 + src/transformers/models/moshi/convert_moshi_transformers.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 2339c4cd6b51d0..1bbc9a5d8016ba 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -1033,6 +1033,7 @@ def _get_non_default_generation_parameters(self) -> Dict[str, Any]: if decoder_config is not self: default_config = decoder_config.__class__() else: + default_config = None decoder_config = None # If it is a composite model, we want to check the subconfig that will be used for generation diff --git a/src/transformers/models/moshi/convert_moshi_transformers.py b/src/transformers/models/moshi/convert_moshi_transformers.py index 90f3ab486d2871..1575a5110acccb 100644 --- a/src/transformers/models/moshi/convert_moshi_transformers.py +++ b/src/transformers/models/moshi/convert_moshi_transformers.py @@ -223,6 +223,8 @@ def convert_checkpoint( original_checkpoint.update({f"audio_encoder.{key}": value for (key, value) in audio_checkpoint.items()}) model = _convert_model(original_checkpoint, model, convert_list, device, config) + + # TODO: set generation config model.save_pretrained(pytorch_dump_folder_path) From 50f9eb80ac4b0433c7fe46b4572fef2b34e9ebd0 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Fri, 20 Sep 2024 15:49:25 +0200 Subject: [PATCH 8/8] working generation --- .../models/moshi/configuration_moshi.py | 26 +- .../models/moshi/modeling_moshi.py | 971 ++++++++---------- 2 files changed, 430 insertions(+), 567 deletions(-) diff --git a/src/transformers/models/moshi/configuration_moshi.py b/src/transformers/models/moshi/configuration_moshi.py index 9eda765c834477..eff73613a82b80 100644 --- a/src/transformers/models/moshi/configuration_moshi.py +++ b/src/transformers/models/moshi/configuration_moshi.py @@ -97,7 +97,7 @@ class MoshiConfig(PretrainedConfig): Example: - ```python + ```python # TODO(YL): update >>> from transformers import ( ... MoshiConfig, ... EncodecConfig, @@ -189,21 +189,24 @@ def __init__(self, self.depth_head_dim = depth_head_dim or depth_hidden_size // depth_num_attention_heads self.depth_num_key_value_heads = depth_num_key_value_heads if depth_num_key_value_heads is not None else depth_num_attention_heads - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - - if "audio_encoder" not in kwargs: + audio_encoder_config = kwargs.pop("audio_encoder", None) + if audio_encoder_config is None: raise ValueError("Config has to be initialized with audio_encoder config") - - audio_encoder_config = kwargs.pop("audio_encoder") + audio_encoder_model_type = audio_encoder_config.pop("model_type") self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) + if self.num_codebooks > self.audio_encoder.num_codebooks: raise ValueError(f"`num_codebooks={num_codebooks}` is greater than the maximum number of codebooks that the audio encoder can deal with ({self.audio_encoder.num_codebooks}). Please lower it.") self.audio_vocab_size = self.audio_encoder.codebook_size if audio_vocab_size is None else audio_vocab_size + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - + @property + def sampling_rate(self): + return self.audio_encoder.sampling_rate @classmethod def from_audio_encoder_config( @@ -213,17 +216,12 @@ def from_audio_encoder_config( ): r""" Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration. - + Returns: [`MoshiConfig`]: An instance of a configuration object """ - + return cls( audio_encoder=audio_encoder_config.to_dict(), **kwargs, ) - - @property - # This is a property because you might want to change the codec model on the fly - def sampling_rate(self): - return self.audio_encoder.sampling_rate diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 8afeecc4112a1d..e7efa12aab1729 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -91,7 +91,7 @@ @dataclass class MoshiCausalLMOutputWithPast(ModelOutput): """ - Base class for causal language model (or autoregressive) outputs. + `MoshiForCausalLM` outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): @@ -127,6 +127,60 @@ class MoshiCausalLMOutputWithPast(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor, ...]] = None +@dataclass +class MoshiConditionalGenerationOutputWithPast(ModelOutput): + """ + `MoshiForConditionalGeneration` outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `text_labels` is provided): + Text language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the text language modeling head (scores for each vocabulary token before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + depth_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `audio_labels` is provided): + Audio language modeling loss (for next-token prediction). + audio_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the audio language modeling heads. + depth_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Past key-values of the depth decoder. + depth_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Hidden states of the depth decoder + depth_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Depth decoder's Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_loss: Optional[torch.FloatTensor] = None + audio_logits: torch.FloatTensor = None + depth_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + depth_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. @@ -248,10 +302,7 @@ def forward(self, x, layer_idx=None): If it's a tensor of shape `(seq_length,)`, will matmul each element of the sequence to the corresponding weights. But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`. """ - if layer_idx is not None and layer_idx.dim()==0: - # Single layer case: select a specific layer (batch_size, 1 , input_size) -> (batch_size, 1, output_size) - return torch.matmul(x, self.weight[layer_idx].T) - elif layer_idx is not None: + if layer_idx is not None: # Use torch.gather to select the corresponding weights for each sample selected_weights = torch.index_select(self.weight, 0, layer_idx) return torch.einsum('bnh,noh->bno', x, selected_weights) @@ -427,9 +478,9 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, position_ids) # Ignore copy - key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, position_ids) # Ignore copy - value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, position_ids) # Ignore copy + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -441,7 +492,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -467,7 +518,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, position_ids) # Ignore copy + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy if not output_attentions: attn_weights = None @@ -511,9 +562,9 @@ def forward( bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -529,7 +580,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -580,7 +631,7 @@ def forward( ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy if not output_attentions: attn_weights = None @@ -626,9 +677,9 @@ def forward( bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -641,7 +692,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -674,7 +725,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy return attn_output, None, past_key_value @@ -926,7 +977,6 @@ def _init_weights(self, module): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ -# TODO: DO it as a depth decoder class MoshiDepthDecoder(MoshiPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoshiTransformerLayer`] @@ -947,21 +997,11 @@ def __init__(self, config: MoshiConfig): self.input_projections = MoshiFlexibleLinear(config.hidden_size, config.depth_hidden_size, config.num_codebooks) - # TODO: remove if relevant - # nn.ModuleList( - # [nn.Linear(config.hidden_size, config.depth_hidden_size, bias=False) for _ in range(config.num_codebooks)] - # ) - self.layers = nn.ModuleList( [MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, is_depth_layer=True) for layer_idx in range(config.depth_num_hidden_layers)] ) self.lm_heads = MoshiFlexibleLinear(config.depth_hidden_size, config.audio_vocab_size, config.num_codebooks) - - # TODO: remove if relevant - # nn.ModuleList( - # [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] - # ) self._attn_implementation = config._attn_implementation @@ -971,10 +1011,8 @@ def __init__(self, config: MoshiConfig): def forward( # TODO: update docstrings entirely self, input_ids: Optional[torch.LongTensor] = None, # sequence of oracle input ids, i.e it must be the input ids that are predicted by the decoder # (B, S) - audio_codes: Optional[torch.Tensor] = None, # Same, shoud be oracle audio codebooks, but also with one channel less: # (B, C, S) or C-1 - last_hidden_states: torch.LongTensor = None, # use 8 times (B, S, H_in) | (B*S, H_in) + last_hidden_state: torch.LongTensor = None, # shape: (B*S, 1, hidden_dim) # use 8 times (B, S, H_in) | (B*S, H_in) attention_mask: Optional[torch.BoolTensor] = None, - padding_mask: Optional[torch.BoolTensor] = None, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -983,7 +1021,7 @@ def forward( # TODO: update docstrings entirely return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - + cache_position: Optional[torch.LongTensor] = None, # TODO: add to docstrings ) -> Union[Tuple, BaseModelOutputWithPast]: """ Args: @@ -1049,19 +1087,23 @@ def forward( # TODO: update docstrings entirely """ # here, we actually predict a sequence length of C # independtly from the batch size and sequence length - # 1/ input ids is passed through text_embed_tokens -> (B, S, H) H=1024 - # 2/ each codebooks is passed through the embedding layer ase well -> (B, C-1, S, H) - # 3/ concat the two precedent results and get (B, C, S ,H) + # 1/ input ids is passed through text_embed_tokens -> (B * S, H) H=1024 + # 2/ each codebooks is passed through the embedding layer ase well -> (B*S, C-1, H) + # 3/ concat the two precedent results and get (B*S, C, ,H) # 4/ then we also pass the last hidden states through the input projection layers, one for each codebooks - # we get (B, C, S, H) - # 5/ sum one and another (B, C, S, H) + # we get (B*S, C, H) + # 5/ sum one and another (B*S, C, H) # 6/ pass (B*S, C, H) through the model and get (B*S, C, H_out) # 7/ for each codebook, pass it through its own lm heads: (B*S, C, H) # 8/ predict the codebook C1, C2 ... -> (B, S, C, H) + # generation: + # we start with last hidden states and text tokens + # depending on position ids chose which embedding layer + # TODO: can we suppose B*S each time instead of B,S # in the generation mode, it's different: - # S=1 + # text_token (B*S, ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1079,44 +1121,40 @@ def forward( # TODO: update docstrings entirely if use_cache and past_key_values is None and not self.training: past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device + past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used + # If inputs_embeds is provided, it has the priority over input_ids, which won't be used if inputs_embeds is None: - if input_ids is None and audio_codes is None: - raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") - - if input_ids is not None: - inputs_embeds = self.text_embed_tokens(input_ids) - - # TODO: this should actually use embed_tokens depending on which position ids is asked for - # We should actually use one codebook embedding per element of the sequence - if audio_codes is not None: # TODO(YL): make sure it's C-1 - audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]) - inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds - - # TODO: check position ids shape - inputs_embeds += self.input_projections(last_hidden_states, position_ids) + inputs_embeds = [] + for position_idx in cache_position: + if position_idx == 0: + inputs_embeds.append(self.text_embed_tokens(input_ids[:, [position_idx]])) + else: + inputs_embeds.append(self.embed_tokens[(position_idx-1)](input_ids[:, [position_idx - past_seen_tokens]])) + + inputs_embeds = torch.cat(inputs_embeds, dim=1) + + inputs_embeds += self.input_projections(last_hidden_state, cache_position) causal_mask = None if attention_mask is not None: causal_mask = self._update_causal_mask( - attention_mask, hidden_states, cache_position, past_key_values, output_attentions + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None - + hidden_states = inputs_embeds for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1157,26 +1195,13 @@ def forward( # TODO: update docstrings entirely next_cache = next_decoder_cache if use_cache else None - - # TODO: check position ids shape # TODO: remove the float() operation in v4.46 - logits = self.lm_heads(hidden_states, position_ids).float() + logits = self.lm_heads(hidden_states, cache_position).float() loss = None if labels is not None: - # TODO: it's probably not the right way to do it - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = 0 + # TODO(YL) if not return_dict: return tuple(v for v in [loss, logits, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1189,6 +1214,77 @@ def forward( # TODO: update docstrings entirely attentions=all_self_attns, ) + # Copied from transformers.models.gemma.modeling_gemma.GemmaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + # TODO: use this to make sure max_tokens = num_codebooks + super()._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # Copied from transformers.models.gemma.modeling_gemma.GemmaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1201,11 +1297,6 @@ def prepare_inputs_for_generation( num_logits_to_keep=None, **kwargs, ): - # TODO(YL): on the first step, `input_ids` is used - # then every new input_ids are passed as `audio_codes` instead! - # Make sure cache_positions is correct - # do we use num_logits_to_keep? - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here @@ -1263,6 +1354,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "last_hidden_state": kwargs.get("last_hidden_state") # Ignore copy } ) return model_inputs @@ -1316,7 +1408,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, # TODO(YL): add to docstrings ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1718,82 +1810,25 @@ def prepare_inputs_for_generation( "for speech-to-speech.", MOSHI_START_DOCSTRING, ) -class MoshiForConditionalGeneration(MoshiPreTrainedModel): # TODO(YL): don't think I can initialize like this for a composite model +class MoshiForConditionalGeneration(MoshiPreTrainedModel): config_class = MoshiConfig main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - def __init__( - self, - config: Optional[MoshiConfig] = None, - audio_encoder: Optional[PreTrainedModel] = None, - decoder: Optional[MoshiForCausalLM] = None, - depth_decoder: Optional[MoshiDepthDecoder] = None, - ): - if config is None and (audio_encoder is None or decoder is None or depth_decoder is None): - raise ValueError( - "Either a configuration has to be provided, or all three of Moshi depth decoder, audio encoder and Moshi decoder." - ) - if config is None: - config = MoshiConfig.from_audio_encoder_config(audio_encoder.config) - else: - if not isinstance(config, self.config_class): - raise ValueError(f"Config: {config} has to be of type {self.config_class}") - - # TODO: verify decoder and depth decoder not incompatible - # TODO: does the decoder and depth decoder makes sense? - - # initialize with config + def __init__(self, config: MoshiConfig): super().__init__(config) - - if audio_encoder is None: - from ..auto.modeling_auto import AutoModel - - audio_encoder = AutoModel.from_config(config.audio_encoder) - - if decoder is None: - decoder = MoshiForCausalLM(config) - - if depth_decoder is None: - depth_decoder = MoshiDepthDecoder(config) - - self.depth_decoder = depth_decoder - self.decoder = decoder - self.audio_encoder = audio_encoder - # We have 2 * num_codebooks audio embedding layers because we have the user input channel and the model output channel. self.embed_tokens = nn.ModuleList( [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)] ) - - if self.decoder.config.to_dict() != self.config.to_dict(): - logger.warning( - f"Config of the decoder: {self.decoder.__class__} is overwritten by shared config:" - f" {self.config}" - ) - if self.depth_decoder.config.to_dict() != self.config.to_dict(): - logger.warning( - f"Config of the depth decoder: {depth_decoder.decoder.__class__} is overwritten by shared config:" - f" {self.config}" - ) - if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict(): - logger.warning( - f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:" - f" {self.config.audio_encoder}" - ) - - # make sure that the individual model's config refers to the shared config - # so that the updates to the config will be synced - self.audio_encoder.config = self.config.audio_encoder - self.decoder.config = self.config - self.depth_decoder.config = self.config + self.audio_encoder = AutoModel.from_config(config.audio_encoder, attn_implementation=config._attn_implementation) + self.decoder = MoshiForCausalLM(config) + self.depth_decoder = MoshiDepthDecoder(config) self.num_codebooks = config.num_codebooks - - # tie text encoder, decoder weights if config set accordingly - self.tie_weights() + self.post_init() def get_audio_encoder(self): return self.audio_encoder @@ -1813,40 +1848,20 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Example: - - ```python - >>> from transformers import MoshiForConditionalGeneration - - >>> model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko") - ```""" - - # At the moment fast initialization is not supported for composite models - if kwargs.get("_fast_init", False): - logger.warning( - "Fast initialization is currently not supported for MoshiForConditionalGeneration. " - "Falling back to slow initialization..." - ) - kwargs["_fast_init"] = False - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - @add_start_docstrings_to_model_forward(MOSHI_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, - input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it - padding_mask: Optional[torch.BoolTensor] = None, - audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings - and decide if it's 16 codebooks or (8 and another audio_values) + user_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + user_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + moshi_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + moshi_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, text_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings - audio_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings + audio_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings - must be 16 channels (first user than moshi?) use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1892,13 +1907,13 @@ def forward( argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } - # TODO: encode input_values - # TODO: how to deal with both streams, we actually two input_values stream - if input_values is not None and audio_codes is None: - # TODO: should be 16 codebooks - audio_codes = self.audio_encoder.encode(input_values, padding_mask, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] - + kwargs_depth_decoder = { + argument[len("depth_decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("depth_decoder_") + } + # TODO: we need to have same number of timestamps, and same number of batch + + if (text_labels is not None) and (input_ids is None and inputs_embeds is None): input_ids = shift_tokens_right( text_labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id @@ -1907,6 +1922,15 @@ def forward( # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used if inputs_embeds is None: + if user_input_values is not None and user_audio_codes is None: + user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] + + if moshi_input_values is not None and moshi_audio_codes is None: + moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] + + # TODO: make sure it's the right order (user than moshi) + make sure it's done over the right dim + audio_codes = torch.cat([user_audio_codes, moshi_audio_codes], dim=1) + if input_ids is None and audio_codes is None: raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") @@ -1930,60 +1954,80 @@ def forward( **kwargs_decoder, ) - # TODO: how to deal with loss here ? maybe we can do one loss for text - # and one loss for audio_labels? - decoder_last_hidden_states = decoder_outputs.last_hidden_state - # TODO: we want to pass the audio_codes and audio_labels from the model inputs + decoder_last_hidden_state = decoder_outputs.last_hidden_state depth_decoder_outputs = None if text_labels is not None and audio_labels is not None: - # TODO: how to deal with padding mask and attention mask ? - # To use depth decoder forward here, we actually need oracle input ids since we're supposed to pass the true input ids + + # (batch_size, sequence_length) -> (batch_size * sequence_length, 1) + text_labels = text_labels.view(-1, 1) + # (batch_size, num_codebooks, sequence_length) -> (batch_size * sequence_length, num_codebooks) + audio_labels = audio_labels.transpose(1,2).reshape(-1, audio_labels.shape[1]) + + depth_input_ids = torch.cat([text_labels, audio_labels], dim=1) + # keep the last codebook out of input_ids + depth_input_ids = depth_input_ids[:, :-1] + + # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim) + decoder_last_hidden_state = decoder_last_hidden_state.view(-1, 1, decoder_last_hidden_state.shape[-1]) + depth_decoder_outputs = self.depth_decoder( - hidden_states=decoder_last_hidden_states, - input_ids=text_labels, # probably need to reshape to (B*S) - audio_codes=audio_labels, # probably need to reshape to (B*S) + last_hidden_state=decoder_last_hidden_state, + input_ids=depth_input_ids, attention_mask=attention_mask, - padding_mask=padding_mask, ) - if not return_dict: outputs = decoder_outputs.to_tuple() if depth_decoder_outputs is not None: outputs += depth_decoder_outputs.to_tuple() - return outputs# TODO + encoder_outputs + return outputs - # TODO: change - return Seq2SeqLMOutput( - loss=decoder_outputs.loss, # TODO: it's the text loss + return MoshiConditionalGenerationOutputWithPast( + loss=decoder_outputs.loss, logits=decoder_outputs.logits, + last_hidden_state=decoder_last_hidden_state, past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + depth_loss=None if depth_decoder_outputs is None else depth_decoder_outputs.loss, + audio_logits=None if depth_decoder_outputs is None else depth_decoder_outputs.logits, + depth_past_key_values=None if decoder_outputs is None else decoder_outputs.past_key_values, + depth_hidden_states=None if decoder_outputs is None else decoder_outputs.hidden_states, + depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions, ) def _prepare_inputs_embeds_for_generation( self, input_ids: Optional[torch.LongTensor] = None, - input_values: Optional[torch.FloatTensor] = None, - audio_codes: Optional[torch.Tensor] = None, + user_input_values: Optional[torch.FloatTensor] = None, + user_audio_codes: Optional[torch.Tensor] = None, + moshi_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + moshi_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings inputs_embeds: Optional[torch.FloatTensor] = None, ): # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used if inputs_embeds is None: - # TODO: here we have to decide how to deal with audio codes from the user - # also have to decide how to deal with number of channels - - if input_values is not None and audio_codes is None: - # TODO: should be 16 codebooks - audio_codes = self.audio_encoder.encode(input_values, num_quantizers=self.num_codebooks,)[0] + if input_ids is None and user_input_values is None and user_audio_codes is None and moshi_input_values is None and moshi_audio_codes is None: + raise ValueError("You must provide at least one of `input_ids`, `user_input_values`, `moshi_input_values`, `user_audio_codes` or `moshi_audio_codes`.") + # TODO: make sure batch size and sequence length is concording - if input_ids is None and audio_codes is None: - raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") + if user_input_values is not None and user_audio_codes is None: + user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks)[0] + + if moshi_input_values is not None and moshi_audio_codes is None: + moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks)[0] + + audio_codes = None + if user_audio_codes is not None and moshi_audio_codes is not None: + # TODO: make sure it's the right order (user than moshi) + make sure it's done over the right dim + audio_codes = torch.cat([user_audio_codes, moshi_audio_codes], dim=1) + elif user_audio_codes is not None: + audio_codes = user_audio_codes + elif moshi_audio_codes is not None: + audio_codes = moshi_audio_codes if input_ids is not None: inputs_embeds = self.decoder.model.embed_tokens(input_ids) @@ -1992,103 +2036,194 @@ def _prepare_inputs_embeds_for_generation( audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]) inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds - return inputs_embeds - + return inputs_embeds, moshi_audio_codes - def prepare_inputs_for_generation( + @torch.no_grad() + def generate( self, - input_ids, - past_key_values=None, - attention_mask=None, - use_cache=None, - decoder_delay_pattern_mask=None, - guidance_scale=None, + input_ids: Optional[torch.LongTensor] = None, + user_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + user_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + moshi_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + moshi_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + inputs_embeds: Optional[torch.FloatTensor] = None, **kwargs, - ): - if decoder_delay_pattern_mask is None: - input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( - input_ids, - self.generation_config.pad_token_id, - max_length=self.generation_config.max_length, - ) + ) -> torch.LongTensor: + """ + # TODO: modify + Generates sequences of token ids for models with a language modeling head. - # apply the delay pattern mask - input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask) + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. + kwargs (`Dict[str, Any]`, *optional*): + Remaining dictionary of keyword arguments that are passed to the `generate` method. Refers to the + original [`generate` docstrings](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate) + for more information on how to use them. + Note that keywords with a *depth_* prefix will be input for the `generate` method of the + depth decoder. Otherwise, the latter will use its default generation config. + + Return: # TODO + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. - if guidance_scale is not None and guidance_scale > 1: - # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these - # before sampling) - input_ids = input_ids.repeat((2, 1)) - if attention_mask is not None: - attention_mask = attention_mask.repeat((2, 1)) + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: - input_ids = input_ids[:, remove_prefix_length:] + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + inputs_embeds, moshi_audio_codes = self._prepare_inputs_embeds_for_generation( + input_ids=input_ids, + user_input_values=user_input_values, + user_audio_codes=user_audio_codes, + moshi_input_values=moshi_input_values, + moshi_audio_codes=moshi_audio_codes, + inputs_embeds=inputs_embeds, + ) + + self.generated_audio_codes = moshi_audio_codes + + outputs = super().generate(inputs_embeds=inputs_embeds, **kwargs) - return { - "input_ids": None, # TODO encoder_outputs is defined. input_ids not needed - "past_key_values": past_key_values, - "input_ids": input_ids, - "attention_mask": attention_mask, - "use_cache": use_cache, - } + # check if outputs is a dict or a Tensor (depending on unaccessed `generation_config.return_dict_in_generate`) + if isinstance(outputs, torch.Tensor): + output_text_ids = outputs + else: + output_text_ids = outputs.sequences + + output_audio_codes = self.generated_audio_codes + + + output_values = self.audio_encoder.decode( + output_audio_codes, + ).audio_values + + + return output_text_ids, output_values - def _prepare_audio_encoder_kwargs_for_generation( - self, input_values, model_kwargs, model_input_name: Optional[str] = None + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, ): - # 1. get audio encoder - encoder = self.get_audio_encoder() - # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device - # as the inputs. - if hasattr(encoder, "_hf_hook"): - encoder._hf_hook.io_same_device = True - - # 2. Prepare encoder args and encoder kwargs from model kwargs. - irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] - encoder_kwargs = { - argument: value - for argument, value in model_kwargs.items() - if not any(argument.startswith(p) for p in irrelevant_prefix) - } - encoder_signature = set(inspect.signature(encoder.forward).parameters) - encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature - if not encoder_accepts_wildcard: - encoder_kwargs = { - argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature - } + # 1. Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds - # 3. make sure that encoder returns `ModelOutput` - model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name - encoder_kwargs["return_dict"] = True + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) - encoder_kwargs[model_input_name] = input_values - audio_encoder_outputs = encoder.encode(**encoder_kwargs) - audio_codes = audio_encoder_outputs.audio_codes - audio_scales = audio_encoder_outputs.audio_scales + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - frames, bsz, codebooks, seq_len = audio_codes.shape + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min - if frames != 1: - raise ValueError( - f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " - "disabled by setting `chunk_length=None` in the audio encoder." + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, ) - input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) - - model_kwargs["input_ids"] = input_ids - model_kwargs["audio_scales"] = audio_scales + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + + # 2. Now that everything is prepared, generate audio_codes using the depth decoder + + # we want to do it after a first token has been generated + if model_inputs["input_ids"] is not None: + last_hidden_state = kwargs.get("last_hidden_state") + # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim) + last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1]) + + input_ids = model_inputs.pop("input_ids") + + # TODO: allow passing generation kwargs + generated_audio_codes = self.depth_decoder.generate( + last_hidden_state=last_hidden_state, + input_ids=input_ids.view(-1, 1), + min_length=self.num_codebooks + 1,# TODO: change + max_length=self.num_codebooks + 1,# TODO: change + ) + + # the first tokens are text tokens + generated_audio_codes = generated_audio_codes[:, 1:].unsqueeze(2) + + self.generated_audio_codes = torch.cat([self.generated_audio_codes, generated_audio_codes], dim=2) + + # TODO: for now, we don't use blank user input ids !! + inputs_embeds, _ = self._prepare_inputs_embeds_for_generation(input_ids, moshi_audio_codes=generated_audio_codes) + + model_inputs["input_ids"] = None + model_inputs["inputs_embeds"] = inputs_embeds + + return model_inputs + + def _update_model_kwargs_for_generation(self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder, num_new_tokens) + + # update last_hidden_state that'll be used in the depth decoder + model_kwargs["last_hidden_state"] = outputs.get("last_hidden_state") return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): @@ -2115,274 +2250,4 @@ def freeze_depth_decoder(self): """ for param in self.depth_decoder.parameters(): param.requires_grad = False - self.depth_decoder._requires_grad = False - - def _maybe_initialize_input_ids_for_generation( - self, - inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> torch.LongTensor: - """Initializes input ids for generation, if necessary.""" - if inputs is not None: - return inputs - - if bos_token_id is None: - raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") - - # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with - # soft-prompting or in multimodal implementations built on top of decoder-only language models. - batch_size = 1 - for value in model_kwargs.values(): - if isinstance(value, torch.Tensor): - batch_size = value.shape[0] - break - return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id - - def _get_decoder_start_token_id( - self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None - ) -> int: - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.generation_config.decoder_start_token_id - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id - - if decoder_start_token_id is not None: - return decoder_start_token_id - elif bos_token_id is not None: - return bos_token_id - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." - ) - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - synced_gpus: Optional[bool] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, - ): - """ - - Generates sequences of token ids for models with a language modeling head. - - - - Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the - model's default generation configuration. You can override any `generation_config` by passing the corresponding - parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - - For an overview of generation strategies and code examples, check out the [following - guide](./generation_strategies). - - - - Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - generation config. If a stopping criteria is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - kwargs (`Dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. - - If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateDecoderOnlyOutput`], - - [`~generation.GenerateBeamDecoderOnlyOutput`] - - If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateEncoderDecoderOutput`], - - [`~generation.GenerateBeamEncoderDecoderOutput`] - """ - # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects - if generation_config is None: - generation_config = self.generation_config - - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - generation_config.validate() - self._validate_model_kwargs(model_kwargs.copy()) - - - # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - requires_attention_mask = False # TODO "encoder_outputs" not in model_kwargs - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None - - # 3. Define model inputs - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) - batch_size = inputs_tensor.shape[0] - self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) - - # 4. Define other model kwargs - model_kwargs["use_cache"] = generation_config.use_cache - model_kwargs["guidance_scale"] = generation_config.guidance_scale - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor - ) - - - if "input_ids" not in model_kwargs and "input_values" in model_kwargs: - model_kwargs = self._prepare_audio_encoder_kwargs_for_generation( - model_kwargs["input_values"], - model_kwargs, - ) - - # 5. Prepare `input_ids` which will be used for auto-regressive generation - input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( - batch_size=batch_size, - model_input_name=model_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=generation_config._decoder_start_token_tensor, - bos_token_id=generation_config._bos_token_tensor, - device=inputs_tensor.device, - ) - - # 6. Prepare `max_length` depending on other stopping criteria. - input_ids_length = input_ids.shape[-1] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None - generation_config = self._prepare_generated_length( - generation_config=generation_config, - has_default_max_length=has_default_max_length, - has_default_min_length=has_default_min_length, - model_input_name=model_input_name, - inputs_tensor=inputs_tensor, - input_ids_length=input_ids_length, - ) - - # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Moshi) - input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( - input_ids, - pad_token_id=generation_config._decoder_start_token_tensor, - max_length=generation_config.max_length, - ) - # stash the delay mask so that we don't have to recompute in each forward pass - model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask - - # input_ids are ready to be placed on the streamer (if used) - if streamer is not None: - streamer.put(input_ids.cpu()) - - # 7. determine generation mode - generation_mode = generation_config.get_generation_mode() - - # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) - if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: - logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) - generation_config.guidance_scale = None - - # 9. prepare distribution pre_processing samplers - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_length, - encoder_input_ids=inputs_tensor, - prefix_allowed_tokens_fn=None, - logits_processor=logits_processor, - device=input_ids.device, - ) - - # 10. prepare stopping criteria - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - - if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # expand input_ids with `num_return_sequences` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - # 11. run sample - outputs = self._sample( - input_ids, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, - **model_kwargs, - ) - - else: - raise ValueError( - "Got incompatible mode for generation, should be one of greedy or sampling. " - "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." - ) - - if generation_config.return_dict_in_generate: - output_ids = outputs.sequences - else: - output_ids = outputs - - # apply the pattern mask to the final ids - output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) - - # revert the pattern delay mask by filtering the pad token id - output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( - batch_size, self.decoder.num_codebooks, -1 - ) - - # append the frame dimension back to the audio codes - output_ids = output_ids[None, ...] - - audio_scales = model_kwargs.get("audio_scales") - if audio_scales is None: - audio_scales = [None] * batch_size - - output_values = self.audio_encoder.decode( - output_ids, - audio_scales=audio_scales, - ).audio_values - - - if generation_config.return_dict_in_generate: - outputs.sequences = output_values - return outputs - else: - return output_values \ No newline at end of file + self.depth_decoder._requires_grad = False \ No newline at end of file