diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 59f0ff48d22a75..4ccbc3e232e3c5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -726,6 +726,8 @@ title: Mimi - local: model_doc/mms title: MMS + - local: model_doc/moshi + title: Moshi - local: model_doc/musicgen title: MusicGen - local: model_doc/musicgen_melody diff --git a/docs/source/en/model_doc/mimi.md b/docs/source/en/model_doc/mimi.md index 486d1836334949..58a431d31e6525 100644 --- a/docs/source/en/model_doc/mimi.md +++ b/docs/source/en/model_doc/mimi.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer. ## Overview + The Mimi model was proposed in [Moshi: a speech-text foundation model for real-time dialogue](https://kyutai.org/Moshi.pdf) by Alexandre Défossez, Laurent Mazaré, Manu Orsini, Amélie Royer, Patrick Pérez, Hervé Jégou, Edouard Grave and Neil Zeghidour. Mimi is a high-fidelity audio codec model developed by the Kyutai team, that combines semantic and acoustic information into audio tokens running at 12Hz and a bitrate of 1.1kbps. In other words, it can be used to map audio waveforms into “audio tokens”, known as “codebooks”. The abstract from the paper is the following: @@ -29,6 +30,7 @@ Its architecture is based on [Encodec](model_doc/encodec) with several major dif * it uses additional transformers for encoding and decoding for better latent contextualization * it uses a different quantization scheme: one codebook is dedicated to semantic projection. + ## Usage example Here is a quick example of how to encode and decode an audio using this model: @@ -54,6 +56,7 @@ Here is a quick example of how to encode and decode an audio using this model: ``` This model was contributed by [Yoach Lacombe (ylacombe)](https://huggingface.co/ylacombe). + The original code can be found [here](https://github.com/kyutai-labs/moshi). diff --git a/docs/source/en/model_doc/moshi.md b/docs/source/en/model_doc/moshi.md new file mode 100644 index 00000000000000..3958a4dc5c1104 --- /dev/null +++ b/docs/source/en/model_doc/moshi.md @@ -0,0 +1,53 @@ + + +# Moshi + +## Overview + +The Moshi model was proposed in [Moshi: a speech-text foundation model for real-time dialogue](https://kyutai.org/Moshi.pdf) by Alexandre Défossez, Laurent Mazaré, Manu Orsini, Amélie Royer, Patrick Pérez, Hervé Jégou, Edouard Grave and Neil Zeghidour. + +Moshi is a speech-text foundation model for real-time dialogue developed by the Kyutai team. It combines the [Mimi](model_doc/mimi) audio codec, used as the audio encoder, with a main text/audio decoder and a smaller depth decoder that models the audio codebooks. + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [Yoach Lacombe (ylacombe)](https://huggingface.co/ylacombe). +The original code can be found [here](https://github.com/kyutai-labs/moshi).
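+
+## Usage example
+
+Moshi integration is still a work in progress, so the snippet below is only a sketch of the intended high-level API. It assumes a converted checkpoint is available on the Hub under `kyutai/moshiko` (the checkpoint id used in the `modeling_moshi.py` documentation examples); the exact repository name may change.
+
+```python
+from transformers import MoshiForConditionalGeneration
+
+# hypothetical converted checkpoint id, produced with the conversion script added in this PR
+model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko")
+
+# the composite config exposes the Mimi audio encoder sub-config and its sampling rate
+print(model.config.audio_encoder.model_type)
+print(model.config.sampling_rate)
+```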
+ + +## MoshiConfig + +[[autodoc]] MoshiConfig + +## MoshiModel + +[[autodoc]] MoshiModel + - forward + +## MoshiForCausalLM + +[[autodoc]] MoshiForCausalLM + - forward + +## MoshiForConditionalGeneration + +[[autodoc]] MoshiForConditionalGeneration + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index aa13a97fe46150..3cb7872786bec2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -594,7 +594,9 @@ "models.mt5": ["MT5Config"], "models.musicgen": [ "MusicgenConfig", - "MusicgenDecoderConfig", + ], + "models.moshi": [ + "MoshiConfig", ], "models.musicgen_melody": [ "MusicgenMelodyConfig", @@ -2790,6 +2792,15 @@ "MusicgenProcessor", ] ) + _import_structure["models.moshi"].extend( + [ + "MoshiForCausalLM", + "MoshiForConditionalGeneration", + "MoshiModel", + "MoshiPreTrainedModel", + "MoshiProcessor", + ] + ) _import_structure["models.musicgen_melody"].extend( [ "MusicgenMelodyForCausalLM", @@ -5384,6 +5395,9 @@ MusicgenConfig, MusicgenDecoderConfig, ) + from .models.moshi import ( + MoshiConfig, + ) from .models.musicgen_melody import ( MusicgenMelodyConfig, MusicgenMelodyDecoderConfig, @@ -7319,6 +7333,13 @@ MusicgenPreTrainedModel, MusicgenProcessor, ) + from .models.moshi import ( + MoshiForCausalLM, + MoshiForConditionalGeneration, + MoshiModel, + MoshiPreTrainedModel, + MoshiProcessor, + ) from .models.musicgen_melody import ( MusicgenMelodyForCausalLM, MusicgenMelodyForConditionalGeneration, diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 2339c4cd6b51d0..1bbc9a5d8016ba 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -1033,6 +1033,7 @@ def _get_non_default_generation_parameters(self) -> Dict[str, Any]: if decoder_config is not self: default_config = decoder_config.__class__() else: + default_config = None decoder_config = None # If it is a composite model, we want to check the subconfig that will be used for generation diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5b5d1e7902bd67..b243bf6671f657 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -163,6 +163,7 @@ mra, mt5, musicgen, + moshi, musicgen_melody, mvp, nemotron, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 5a6ec14e78cd43..6893a247cba041 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -180,6 +180,7 @@ ("mra", "MraConfig"), ("mt5", "MT5Config"), ("musicgen", "MusicgenConfig"), + ("moshi", "MoshiConfig"), ("musicgen_melody", "MusicgenMelodyConfig"), ("mvp", "MvpConfig"), ("nat", "NatConfig"), @@ -484,6 +485,7 @@ ("mra", "MRA"), ("mt5", "MT5"), ("musicgen", "MusicGen"), + ("moshi", "Moshi"), ("musicgen_melody", "MusicGen Melody"), ("mvp", "MVP"), ("nat", "NAT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2bc71f07970aee..68cafca6e4e05e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -171,6 +171,7 @@ ("mra", "MraModel"), ("mt5", "MT5Model"), ("musicgen", "MusicgenModel"), + ("moshi", "MoshiModel"), ("musicgen_melody", "MusicgenMelodyModel"), ("mvp", "MvpModel"), ("nat", "NatModel"), @@ -498,6 +499,7 @@ ("mixtral", "MixtralForCausalLM"), ("mpt", "MptForCausalLM"), ("musicgen", 
"MusicgenForCausalLM"), + ("moshi", "MoshiForCausalLM"), ("musicgen_melody", "MusicgenMelodyForCausalLM"), ("mvp", "MvpForCausalLM"), ("nemotron", "NemotronForCausalLM"), @@ -1261,6 +1263,7 @@ ("bark", "BarkModel"), ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"), ("musicgen", "MusicgenForConditionalGeneration"), + ("moshi", "MoshiForConditionalGeneration"), ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"), ("seamless_m4t", "SeamlessM4TForTextToSpeech"), ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e735579108d857..2f678b02edd75e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -318,6 +318,7 @@ ), ), ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), + ("moshi", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)), ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/mimi/configuration_mimi.py b/src/transformers/models/mimi/configuration_mimi.py index 5564b1a54ba63b..da0ba1e8a90823 100644 --- a/src/transformers/models/mimi/configuration_mimi.py +++ b/src/transformers/models/mimi/configuration_mimi.py @@ -30,6 +30,7 @@ class MimiConfig(PretrainedConfig): This is the configuration class to store the configuration of an [`MimiModel`]. It is used to instantiate a Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the + [kyutai/mimi](https://huggingface.co/kyutai/mimi) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the @@ -126,6 +127,7 @@ class MimiConfig(PretrainedConfig): ```python >>> from transformers import MimiModel, MimiConfig + >>> # Initializing a "kyutai/mimi" style configuration >>> configuration = MimiConfig() diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index d91b057ef28ec4..b50e7fa64ff02d 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1000,6 +1000,7 @@ def forward( ) use_cache = False + if use_cache and not isinstance(past_key_values, Cache): if past_key_values is None: past_key_values = DynamicCache() @@ -1687,6 +1688,7 @@ def forward( >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") >>> audio_sample = dataset["train"]["audio"][0]["array"] + >>> model_id = "kyutai/mimi" >>> model = MimiModel.from_pretrained(model_id) >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) diff --git a/src/transformers/models/moshi/__init__.py b/src/transformers/models/moshi/__init__.py new file mode 100644 index 00000000000000..c1c617a3e816d8 --- /dev/null +++ b/src/transformers/models/moshi/__init__.py @@ -0,0 +1,59 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_moshi": [ + "MoshiConfig", + ], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_moshi"] = [ + "MoshiForConditionalGeneration", + "MoshiForCausalLM", + "MoshiModel", + "MoshiPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_moshi import ( + MoshiConfig, + ) + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_moshi import ( + MoshiForCausalLM, + MoshiForConditionalGeneration, + MoshiModel, + MoshiPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/moshi/configuration_moshi.py b/src/transformers/models/moshi/configuration_moshi.py new file mode 100644 index 00000000000000..eff73613a82b80 --- /dev/null +++ b/src/transformers/models/moshi/configuration_moshi.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Moshi model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import AutoConfig + + +logger = logging.get_logger(__name__) + + + +class MoshiConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MoshiModel`]. It is used to instantiate a + Moshi model according to the specified arguments, defining the audio encoder, Moshi depth decoder and Moshi decoder + configs. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MoshiDecoder model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`MoshiDecoder`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the layers and the pooler layer of the main decoder. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of decoder layers. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the main decoder block. 
num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details check out [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. + max_position_embeddings (`int`, *optional*, defaults to 3750): + The maximum sequence length that this model might ever be used with. Typically, set this to something large + just in case (e.g., 512 or 1024 or 2048). + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + sliding_window (`int`, *optional*, defaults to 3000): + Sliding window attention window size. If not specified, will default to `3000`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ffn_dim (`int`, *optional*, defaults to 22528): + Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even. + num_codebooks (`int`, *optional*, defaults to 8): + The number of audio codebooks for each audio channel. + rms_norm_eps (`float`, *optional*, defaults to 1e-8): + The epsilon used by the rms normalization layers. + depth_hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer of the depth decoder. + depth_num_hidden_layers (`int`, *optional*, defaults to 6): + Number of depth decoder layers. + depth_num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the depth decoder block. + depth_max_position_embeddings (`int`, *optional*, defaults to 8): + The maximum sequence length that the depth decoder model might ever be used with. Typically, set this to the + number of codebooks. + depth_ffn_dim (`int`, *optional*, defaults to 5632): + Dimensionality of the "intermediate" (often named feed-forward) layer in the depth decoder block. Must be even. + depth_head_dim (`int`, *optional*, defaults to `depth_hidden_size // depth_num_attention_heads`): + The attention head dimension of the depth decoder layers. + depth_num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention in the depth decoder. + If it is not specified, will default to `depth_num_attention_heads`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether input and output word embeddings should be tied. + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the audio encoder config.
+ + + Example: + + ```python + >>> from transformers import ( + ... MoshiConfig, + ... MimiConfig, + ... MoshiForConditionalGeneration, + ... ) + + >>> # Initializing a Mimi audio encoder configuration + >>> audio_encoder_config = MimiConfig() + + >>> configuration = MoshiConfig.from_audio_encoder_config( + ... audio_encoder_config + ... ) + + >>> # Initializing a MoshiForConditionalGeneration (with random weights) from the kyutai/moshiko style configuration + >>> model = MoshiForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + >>> config_audio_encoder = model.config.audio_encoder + + >>> # Saving the model, including its configuration + >>> model.save_pretrained("moshi-model") + + >>> # Loading model and config from pretrained folder + >>> moshi_config = MoshiConfig.from_pretrained("moshi-model") + >>> model = MoshiForConditionalGeneration.from_pretrained("moshi-model", config=moshi_config) + ```""" + + model_type = "moshi" + is_composition = True + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__(self, + vocab_size=32000, + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + audio_vocab_size=None, # TODO + max_position_embeddings=3750, + rope_theta=10000.0, + hidden_act="silu", + head_dim=None, + initializer_range=0.02, + use_cache=True, + sliding_window=3000, + attention_dropout=0.0, + ffn_dim=22528, + rms_norm_eps=1e-8, + num_codebooks=8, + depth_hidden_size=1024, + depth_num_hidden_layers=6, + depth_max_position_embeddings=8, + depth_num_attention_heads=16, + depth_ffn_dim=5632, + depth_head_dim=None, + depth_num_key_value_heads=None, + tie_word_embeddings=False, + **kwargs): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.hidden_act = hidden_act + self.head_dim = head_dim or hidden_size // num_attention_heads + self.initializer_range = initializer_range + self.use_cache = use_cache + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + if ffn_dim % 2 == 1: + raise ValueError(f"`ffn_dim={ffn_dim}` must be even.") + self.ffn_dim = ffn_dim + self.rms_norm_eps = rms_norm_eps + self.num_codebooks = num_codebooks + + + self.depth_hidden_size = depth_hidden_size + self.depth_num_hidden_layers = depth_num_hidden_layers + self.depth_max_position_embeddings = depth_max_position_embeddings + self.depth_num_attention_heads = depth_num_attention_heads + if depth_ffn_dim % 2 == 1: + raise ValueError(f"`depth_ffn_dim={depth_ffn_dim}` must be even.") + self.depth_ffn_dim = depth_ffn_dim + self.depth_head_dim = depth_head_dim or depth_hidden_size // depth_num_attention_heads + self.depth_num_key_value_heads = depth_num_key_value_heads if depth_num_key_value_heads is not None else depth_num_attention_heads + + audio_encoder_config = kwargs.pop("audio_encoder", None) + if audio_encoder_config is None: + raise ValueError("Config has to be initialized with audio_encoder config") + + audio_encoder_model_type = audio_encoder_config.pop("model_type") + + self.audio_encoder = 
AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) + + if self.num_codebooks > self.audio_encoder.num_codebooks: + raise ValueError(f"`num_codebooks={num_codebooks}` is greater than the maximum number of codebooks that the audio encoder can deal with ({self.audio_encoder.num_codebooks}). Please lower it.") + + self.audio_vocab_size = self.audio_encoder.codebook_size if audio_vocab_size is None else audio_vocab_size + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + @property + def sampling_rate(self): + return self.audio_encoder.sampling_rate + + @classmethod + def from_audio_encoder_config( + cls, + audio_encoder_config: PretrainedConfig, + **kwargs, + ): + r""" + Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration. + + Returns: + [`MoshiConfig`]: An instance of a configuration object + """ + + return cls( + audio_encoder=audio_encoder_config.to_dict(), + **kwargs, + ) diff --git a/src/transformers/models/moshi/convert_moshi_transformers.py b/src/transformers/models/moshi/convert_moshi_transformers.py new file mode 100644 index 00000000000000..1575a5110acccb --- /dev/null +++ b/src/transformers/models/moshi/convert_moshi_transformers.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert Moshi checkpoints.""" + +import argparse + +import safetensors +import torch + +from transformers import ( + MoshiConfig, + MoshiForConditionalGeneration, + MimiModel, # initial audio encoder + logging, +) +# EncodecFeatureExtractor, #TODO(YL): add it here and as AutoFeatureExtractor + + +logging.set_verbosity_info() +logger = logging.get_logger("transformers.models.mimi") + + +def assert_param_count(model_1, model_2): + count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0]) + count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0]) + assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" + + +def param_count(model): + return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0]) + + +def _grab_best_device(use_gpu=True): + if torch.cuda.device_count() > 0 and use_gpu: + device = "cuda" + else: + device = "cpu" + return torch.device(device) + + +convert_list = [ + # GENERAL + ("out_norm", "decoder.model.norm"), + ("depformer_emb", "depth_decoder.emb"), + ("depformer_text_emb", "depth_decoder.text_emb"), + ("text_emb", "decoder.model.emb"), + + ("emb", "embed_tokens"), + ("text_linear", "decoder.lm_head"), + + ("depformer", "depth_decoder"), + ("transformer", "decoder.model"), + + # TRANSFORMERS PART + ("gating.linear_in", "mlp.fc1"), + ("gating.linear_out", "mlp.fc2"), + ("self_attn.out_proj", "self_attn.o_proj"), + ("norm1", "input_layernorm"), + ("norm2", "post_attention_layernorm"), + ("layer_scale_1", "self_attn_layer_scale"), + ("layer_scale_2", "mlp_layer_scale"), + ("alpha", "weight") +] + +def _preprocess_state_dict(state_dict, config): + # Moshi original weights are using a gating mechanism + + # pattern for depth transformer: + # stack(gating.{i}.linear_in)->mlp.fc1 + # stack(gating.{i}.linear_out)->mlp.fc2 + + for layer_idx in range(config.depth_num_hidden_layers): + linear_layers_in = [state_dict.pop(f"depformer.layers.{layer_idx}.gating.{i}.linear_in.weight") for i in range(config.num_codebooks)] + linear_layers_out = [state_dict.pop(f"depformer.layers.{layer_idx}.gating.{i}.linear_out.weight") for i in range(config.num_codebooks)] + + state_dict[f"depth_decoder.layers.{layer_idx}.mlp.fc1.weight"] = torch.stack(linear_layers_in) + state_dict[f"depth_decoder.layers.{layer_idx}.mlp.fc2.weight"] = torch.stack(linear_layers_out) + + input_projections = [] + lm_heads = [] + for codebook_idx in range(config.num_codebooks): + input_projections.append(state_dict.pop(f"depformer_in.{codebook_idx}.weight")) + lm_heads.append(state_dict.pop(f"linears.{codebook_idx}.weight")) + + state_dict["depth_decoder.input_projections.weight"] = torch.stack(input_projections, dim = 0) + state_dict["depth_decoder.lm_heads.weight"] = torch.stack(lm_heads, dim = 0) + + return state_dict + + +def _convert_model( + state_dict, + hf_model, + convert_list, + device, + config, + unwanted_prefix=None, +): + hidden_size = config.hidden_size + head_dim = config.head_dim + num_heads = int(config.hidden_size // config.head_dim) + num_key_value_heads = config.num_key_value_heads + key_value_head_dim = config.num_key_value_heads * head_dim + + + depth_hidden_size = config.depth_hidden_size + depth_head_dim = config.depth_head_dim + depth_num_heads = int(config.depth_hidden_size // config.depth_head_dim) + depth_num_key_value_heads = config.depth_num_key_value_heads + depth_key_value_head_dim = config.depth_num_key_value_heads * depth_head_dim + + state_dict = 
_preprocess_state_dict(state_dict, config) + + # permute for sliced rotary + def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + for k, v in list(state_dict.items()): + if "audio_encoder" not in k: + new_k = k if unwanted_prefix is None else k[len(unwanted_prefix) :] + for old_layer_name, new_layer_name in convert_list: + if old_layer_name in new_k: + new_k = new_k.replace(old_layer_name, new_layer_name) + + if "alpha" in k: + state_dict[k] = state_dict[k].squeeze() + + + + if "in_proj_weight" in new_k: + # split qkv into query key and value + mixed_qkv = state_dict.pop(k) + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + if "depth_decoder" in new_k: + state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = query_layer.view(config.num_codebooks, -1, query_layer.shape[-1]) + state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = key_layer.view(config.num_codebooks, -1, key_layer.shape[-1]) + state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer.view(config.num_codebooks, -1, value_layer.shape[-1]) + else: + state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = permute( + query_layer, num_heads, hidden_size, hidden_size) + state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = permute( + key_layer, num_key_value_heads, dim1=key_value_head_dim, dim2=hidden_size + ) + state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer + elif "o_proj" in new_k and "depth_decoder" in new_k: + output_layer = state_dict.pop(k) + state_dict[new_k] = output_layer.view(config.num_codebooks, -1, output_layer.shape[-1]) + else: + state_dict[new_k] = state_dict.pop(k) + + # Do the last one by hand + state_dict["depth_decoder.text_embed_tokens.weight"] = state_dict.pop("depth_decoder.decoder.model.embed_tokens.weight") + + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) + missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + hf_model.load_state_dict(state_dict, strict=True) + n_params = param_count(hf_model) + + logger.info(f"model loaded: {round(n_params/1e6,1)}M params") + + hf_model.eval() + hf_model.to(device) + del state_dict + + return hf_model + + +@torch.no_grad() +def convert_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + mimi_repo_id, + config_path=None, + repo_id=None, +): + """ + Copy/paste/tweak model's weights to transformers design. 
+ """ + device = _grab_best_device() + + mimi_model = MimiModel.from_pretrained(mimi_repo_id) + + if config_path is not None: + config = MoshiConfig.from_pretrained(config_path) + else: + audio_encoder_config = mimi_model.config + config = MoshiConfig.from_audio_encoder_config(audio_encoder_config) + + model = MoshiForConditionalGeneration(config) + + # feature_extractor = EncodecFeatureExtractor( + # feature_size=config.audio_channels, + # sampling_rate=config.sampling_rate, + # ) + # feature_extractor.save_pretrained(pytorch_dump_folder_path) + + original_checkpoint = safetensors.torch.load_file(checkpoint_path) + if "best_state" in original_checkpoint: + # we might have a training state saved, in which case discard the yaml results and just retain the weights + original_checkpoint = original_checkpoint["best_state"] + + audio_checkpoint = mimi_model.state_dict() + original_checkpoint.update({f"audio_encoder.{key}": value for (key, value) in audio_checkpoint.items()}) + + model = _convert_model(original_checkpoint, model, convert_list, device, config) + + # TODO: set generation config + + model.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + print("Pushing to the hub...") + # feature_extractor.push_to_hub(repo_id) + model.push_to_hub(repo_id, private=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") + parser.add_argument("--mimi_repo_id", required=True, default=None, type=str, help="Repository id to HF Mimi.") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." + ) + + args = parser.parse_args() + convert_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.mimi_repo_id, + args.config_path, + args.push_to_hub, + ) diff --git a/src/transformers/models/moshi/generation_configuration_moshi.py b/src/transformers/models/moshi/generation_configuration_moshi.py new file mode 100644 index 00000000000000..954bab35e45105 --- /dev/null +++ b/src/transformers/models/moshi/generation_configuration_moshi.py @@ -0,0 +1,87 @@ +# coding=utf-8 +# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Moshigeneration configuration""" + +import copy +from typing import Dict + +from ...generation.configuration_utils import GenerationConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MoshiGenerationConfig(GenerationConfig): + model_type = "moshi" + is_composition = True + + # TODO (joao): nested from_dict + + def __init__( + self, + depth_decoder_config: Dict = None, + **kwargs, + ): + """Class that holds a generation configuration for [`MoshiForConditionalGeneration`]. + + The [`MoshiForConditionalGeneration`] model needs to encapsulates two generation config, the main one, and one for its depth decoder. + + This configuration inherit from [`GenerationConfig`] and can be used to control the model generation. Read the + documentation from [`GenerationConfig`] for more information. + + Args: + depth_decoder_config (`Dict`, *optional*): + Depth decoder generation configuration. + # TODO(YL): kwargs + + """ + super.__init__(**kwargs) + if depth_decoder_config is None: + depth_decoder_config = {} + logger.info("depth_decoder_config is None. initializing the semantic model with default values.") + + self.depth_decoder_config = GenerationConfig(**depth_decoder_config) + + @classmethod + def from_depth_decoder_config( + cls, + depth_decoder_config: GenerationConfig, + **kwargs, + ): + r""" + Instantiate a [`MoshiGenerationConfig`] (or a derived class) from Moshi depth decoder generation configuration. + + Returns: + [`MoshiGenerationConfig`]: An instance of a configuration object + """ + return cls( + depth_decoder_config=depth_decoder_config.to_dict(), + **kwargs, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + output["depth_decoder_config"] = self.depth_decoder_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py new file mode 100644 index 00000000000000..e7efa12aab1729 --- /dev/null +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -0,0 +1,2253 @@ +# coding=utf-8 +# Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Moshi model.""" + +import copy +import inspect +import math +import random +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...generation.configuration_utils import GenerationConfig, GenerationMode +from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList +from ...generation.stopping_criteria import StoppingCriteriaList +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutput, + ModelOutput, + Seq2SeqLMOutput, + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...pytorch_utils import ALL_LAYERNORM_LAYERS + +from .configuration_moshi import MoshiConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +if TYPE_CHECKING: + from ...generation.streamers import BaseStreamer + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MoshiConfig" +_CHECKPOINT_FOR_DOC = "kyutai/moshiko" + + +@dataclass +class MoshiCausalLMOutputWithPast(ModelOutput): + """ + `MoshiForCausalLM` outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoshiConditionalGenerationOutputWithPast(ModelOutput): + """ + `MoshiForConditionalGeneration` outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `text_labels` is provided): + Text language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the text language modeling head (scores for each vocabulary token before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + depth_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `audio_labels` is provided): + Audio language modeling loss (for next-token prediction). + audio_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the audio language modeling heads. 
+ depth_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Past key-values of the depth decoder. + depth_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Hidden states of the depth decoder + depth_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Depth decoder's Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_loss: Optional[torch.FloatTensor] = None + audio_logits: torch.FloatTensor = None + depth_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + depth_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + # transpose to get (bsz, num_codebooks, seq_len) + input_ids = input_ids.transpose(1, 2) + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->Moshi +class MoshiRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Moshi is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +ALL_LAYERNORM_LAYERS.append(MoshiRMSNorm) + + +class MoshiFlexibleLinear(nn.Module): + def __init__(self, input_size, output_size, num_layers): + super().__init__() + # Stack the weights for N layers into a single tensor (num_layers, output_size, input_size) + self.weight = nn.Parameter(torch.randn(num_layers, output_size, input_size)) + + def forward(self, x, layer_idx=None): + """ + `MoshiFlexibleLinear` creates one linear layer per codebook. There are multiple ways to use it. + In the default case, `sequence_length=num_layers`, so each element of the sequence is multiplied by the weights of the layer corresponding to its index in the sequence. + + For more advanced cases, one can specify which codebook's layer(s) to use with `layer_idx`. + If `layer_idx` indicates a single integer, every element of the sequence is multiplied by that single codebook's layer. + If `layer_idx` is a tensor of shape `(seq_length,)`, the i-th element of the input sequence is multiplied by the layer `weight[layer_idx[i]]`. + + + Args: + x (`torch.FloatTensor`): Input to the layer, of shape `(batch, num_layers, embed_dim)` or of shape `(batch, seq_length, embed_dim)` + layer_idx (`torch.Tensor`, *optional*): + Can be used to specify which codebook's layer(s) to use. + If it's a tensor of shape `(seq_length,)`, the i-th element of the input sequence is multiplied by the corresponding layer `weight[layer_idx[i]]`.
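+
+        Example (an illustrative sketch with toy sizes):
+
+        ```python
+        >>> import torch
+        >>> # illustrative only: randomly initialized weights, toy dimensions
+        >>> layer = MoshiFlexibleLinear(input_size=16, output_size=32, num_layers=8)
+        >>> x = torch.randn(2, 8, 16)  # (batch, num_layers, embed_dim): one position per codebook
+        >>> layer(x).shape  # each position goes through its own codebook's linear layer
+        torch.Size([2, 8, 32])
+        >>> # select only codebook 0's layer for a length-1 sequence
+        >>> layer(torch.randn(2, 1, 16), layer_idx=torch.tensor([0])).shape
+        torch.Size([2, 1, 32])
+        ```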
+ """ + if layer_idx is not None: + # Use torch.gather to select the corresponding weights for each sample + selected_weights = torch.index_select(self.weight, 0, layer_idx) + return torch.einsum('bnh,noh->bno', x, selected_weights) + + # Multiple layers case: + # use einsum to batch the operations (batch_size, num_layers, input_size) -> (batch_size, num_layers, output_size) + return torch.einsum('bnh,noh->bno', x, self.weight) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi +class MoshiRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + # TODO(joao): add me back asap :) + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MoshiGatingMLP(nn.Module): + def __init__(self, config, num_layers=1, is_depth_mlp=False): + super().__init__() + + self.activation_fn = ACT2FN[config.hidden_act] + ffn_dim = config.ffn_dim if not is_depth_mlp else config.depth_ffn_dim + hidden_size = config.hidden_size if not is_depth_mlp else config.depth_hidden_size + if num_layers == 1: + self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False) + self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False) + else: + self.fc1 = MoshiFlexibleLinear(hidden_size, ffn_dim, num_layers) + self.fc2 = MoshiFlexibleLinear(ffn_dim // 2, hidden_size, num_layers) + + def forward(self, hidden_states: torch.Tensor, layer_idx:int=None) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) if layer_idx is None else self.fc1(hidden_states, layer_idx) + + batch_size, sequence_length, _ = hidden_states.shape + hidden_states = hidden_states.view(batch_size, sequence_length, 2, -1) + hidden_states = self.activation_fn(hidden_states[..., 0, :]) * hidden_states[..., 1, :] + hidden_states = self.fc2(hidden_states) if layer_idx is None else self.fc2(hidden_states, layer_idx) + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MoshiAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, is_depth_attention=False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.is_depth_attention = is_depth_attention + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size if not is_depth_attention else config.depth_hidden_size + self.num_heads = config.num_attention_heads if not is_depth_attention else config.depth_num_attention_heads + self.head_dim = config.head_dim if not is_depth_attention else config.depth_head_dim + self.num_key_value_heads = config.num_key_value_heads if not is_depth_attention else config.depth_num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings if not is_depth_attention else config.depth_max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = 1 / math.sqrt(self.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + if not is_depth_attention: + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + else: + self.q_proj = MoshiFlexibleLinear(self.hidden_size, self.num_heads * self.head_dim, num_layers=config.num_codebooks) + self.k_proj = MoshiFlexibleLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, num_layers=config.num_codebooks) + self.v_proj = MoshiFlexibleLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, num_layers=config.num_codebooks) + self.o_proj = MoshiFlexibleLinear(self.num_heads * self.head_dim, self.hidden_size, num_layers=config.num_codebooks) + + # rotary embeddings are not used in the depth decoder + self.rotary_emb = MoshiRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) if not is_depth_attention else None + + # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Moshi + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi +class MoshiFlashAttention2(MoshiAttention): + """ + Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(MoshiRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi +class MoshiSdpaAttention(MoshiAttention): + """ + Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MoshiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MoshiAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MoshiModel is using MoshiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
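+        # Note: when q_len == 1 (decoding with a cache), the single query attends to every cached key anyway,
+        # so no causal mask is required and `is_causal` can safely be False in that case.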
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy + + return attn_output, None, past_key_value + + +MOSHI_ATTENTION_CLASSES = { + "eager": MoshiAttention, + "flash_attention_2": MoshiFlashAttention2, + "sdpa": MoshiSdpaAttention, +} + +class MoshiDecoderLayer(nn.Module): + def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool=False, is_depth_layer=False): + super().__init__() + self.is_depth_layer = is_depth_layer + self.hidden_size = config.hidden_size if not is_depth_layer else config.depth_hidden_size + self.use_flexible_linear = use_flexible_linear + + self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, is_depth_attention=is_depth_layer) + + self.mlp = MoshiGatingMLP(config) if not use_flexible_linear else MoshiGatingMLP(config, config.num_codebooks, is_depth_layer) + self.input_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) if not self.use_flexible_linear else self.mlp(hidden_states, cache_position) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenPreTrainedModel with Musicgen->Moshi +class MoshiPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MoshiConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MoshiDecoderLayer", "MoshiAttention"] + _supports_flash_attn_2 = True + _supports_sdpa = True + main_input_name = "input_ids" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MOSHI_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MoshiConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +# TODO: update +MOSHI_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `input_ids` of shape `(batch_size, sequence_length)`.
+
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+            If both `input_ids` and `inputs_embeds` are unset, the input embeddings are derived from the provided
+            audio codes or input values.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+MOSHI_DECODER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+            Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+            such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+ + [What are input IDs?](../glossary#input-ids) + + + + The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, + target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If + you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of + frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, + target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as + `input_ids`. + + + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +class MoshiDepthDecoder(MoshiPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MoshiDecoderLayer`]
+
+    Args:
+        config: MoshiConfig
+    """
+
+    def __init__(self, config: MoshiConfig):
+        super().__init__(config)
+
+        self.text_embed_tokens = nn.Embedding(config.vocab_size + 1, config.depth_hidden_size)
+
+        # the last codebook is never used as input
+        self.embed_tokens = nn.ModuleList(
+            [nn.Embedding(config.audio_vocab_size + 1, config.depth_hidden_size) for _ in range(config.num_codebooks - 1)]
+        )
+
+        self.input_projections = MoshiFlexibleLinear(config.hidden_size, config.depth_hidden_size, config.num_codebooks)
+
+        self.layers = nn.ModuleList(
+            [MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, is_depth_layer=True) for layer_idx in range(config.depth_num_hidden_layers)]
+        )
+
+        self.lm_heads = MoshiFlexibleLinear(config.depth_hidden_size, config.audio_vocab_size, config.num_codebooks)
+
+        self._attn_implementation = config._attn_implementation
+
+        self.gradient_checkpointing = False
+        self.config = config
+
+    def forward(  # TODO: update docstrings entirely
+        self,
+        input_ids: Optional[torch.LongTensor] = None,  # sequence of oracle input ids, i.e. it must be the input ids that are predicted by the decoder  # (B, S)
+        last_hidden_state: torch.FloatTensor = None,  # shape: (B*S, 1, hidden_dim)  # use 8 times (B, S, H_in) | (B*S, H_in)
+        attention_mask: Optional[torch.BoolTensor] = None,
+        past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Embedded representation that will be contextualized by the model
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+                `past_key_values`).
+
+                If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+                and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+                information on the default strategy.
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+                config.n_positions - 1]`.
+
+                [What are position IDs?](../glossary#position-ids)
+            past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+                Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+                blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+                returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+                Two formats are allowed:
+                - a [`~cache_utils.Cache`] instance;
+                - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
+                cache format.
+
+                The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+                legacy cache format will be returned.
+
+                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+                have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+                of shape `(batch_size, sequence_length)`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+                `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+                tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+                more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        """
+        # Here, we actually predict a sequence of length C (the number of codebooks),
+        # independently from the batch size and sequence length:
+        # 1/ input_ids are passed through text_embed_tokens -> (B * S, H) with H=1024
+        # 2/ each codebook is passed through its embedding layer as well -> (B*S, C-1, H)
+        # 3/ concatenate the two previous results to get (B*S, C, H)
+        # 4/ then we also pass the last hidden states through the input projection layers, one per codebook
+        # we get (B*S, C, H)
+        # 5/ sum the two together (B*S, C, H)
+        # 6/ pass (B*S, C, H) through the model and get (B*S, C, H_out)
+        # 7/ for each codebook, pass it through its own lm head: (B*S, C, H)
+        # 8/ predict the codebooks C1, C2 ... -> (B, S, C, H)
+
+        # generation:
+        # we start with the last hidden states and the text tokens
+        # depending on position ids, choose which embedding layer to use
+
+        # TODO: can we suppose B*S each time instead of B,S
+        # in the generation mode, it's different:
+        # text_token (B*S, )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ ) + use_cache = False + + if use_cache and past_key_values is None and not self.training: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # If inputs_embeds is provided, it has the priority over input_ids, which won't be used + if inputs_embeds is None: + inputs_embeds = [] + for position_idx in cache_position: + if position_idx == 0: + inputs_embeds.append(self.text_embed_tokens(input_ids[:, [position_idx]])) + else: + inputs_embeds.append(self.embed_tokens[(position_idx-1)](input_ids[:, [position_idx - past_seen_tokens]])) + + inputs_embeds = torch.cat(inputs_embeds, dim=1) + + inputs_embeds += self.input_projections(last_hidden_state, cache_position) + + + causal_mask = None + if attention_mask is not None: + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + hidden_states = inputs_embeds + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + # TODO: remove the float() operation in v4.46 + logits = self.lm_heads(hidden_states, cache_position).float() + + loss = None + if labels is not None: + loss = 0 + # TODO(YL) + + if not return_dict: + return tuple(v for v in [loss, logits, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.gemma.modeling_gemma.GemmaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + # TODO: use this to make sure max_tokens = num_codebooks + super()._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # Copied from transformers.models.gemma.modeling_gemma.GemmaForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. 
Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "last_hidden_state": kwargs.get("last_hidden_state") # Ignore copy + } + ) + return model_inputs + + + + +@add_start_docstrings( + "The bare Moshi Model outputting raw hidden-states without any specific head on top.", + MOSHI_START_DOCSTRING, +) +# Copied from transformers.models.gemma.modeling_gemma.GemmaModel with Gemma->Moshi +class MoshiModel(MoshiPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MoshiDecoderLayer`] + + Args: + config: MoshiConfig + """ + + def __init__(self, config: MoshiConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size + 1, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MoshiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MoshiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, # TODO(YL): add to docstrings + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False # noqa: F841 + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True # noqa: F841 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # normalized + # Moshi downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not 
None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +@add_start_docstrings( + "The Moshi decoder model with a text language modelling head on top. 
Only usable for text.", + MOSHI_START_DOCSTRING, +) +# Copied from transformers.models.musicgen.modeling_gemma.GemmaForCausalLM with GEMMA->MOSHI,Gemma->Moshi,gemma-7b->moshiko,google->kyutai, CausalLM->MoshiCausalLM +class MoshiForCausalLM(MoshiPreTrainedModel): + _tied_weights_keys = None # Ignore copy + + def __init__(self, config): + super().__init__(config) + self.model = MoshiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoshiCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, MoshiCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MoshiForCausalLM + + >>> model = MoshiForCausalLM.from_pretrained("kyutai/moshiko") + >>> tokenizer = AutoTokenizer.from_pretrained("kyutai/moshiko") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoshiCausalLMOutputWithPast( + loss=loss, + logits=logits, + last_hidden_state=hidden_states, # Ignore copy + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. 
Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + +@add_start_docstrings( + "The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, " + "for speech-to-speech.", + MOSHI_START_DOCSTRING, +) +class MoshiForConditionalGeneration(MoshiPreTrainedModel): + config_class = MoshiConfig + main_input_name = "input_ids" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, config: MoshiConfig): + super().__init__(config) + # We have 2 * num_codebooks audio embedding layers because we have the user input channel and the model output channel. 
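+        # As currently wired in `forward`, embeddings 0..num_codebooks-1 embed the user's audio codebooks and
+        # embeddings num_codebooks..2*num_codebooks-1 embed Moshi's own codebooks; the resulting per-codebook
+        # embeddings are summed with the text token embeddings to build `inputs_embeds`.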
+ self.embed_tokens = nn.ModuleList( + [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)] + ) + self.audio_encoder = AutoModel.from_config(config.audio_encoder, attn_implementation=config._attn_implementation) + self.decoder = MoshiForCausalLM(config) + self.depth_decoder = MoshiDepthDecoder(config) + + self.num_codebooks = config.num_codebooks + self.post_init() + + def get_audio_encoder(self): + return self.audio_encoder + + def get_depth_decoder(self): + return self.depth_decoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.decoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(MOSHI_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + user_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + user_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + moshi_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + moshi_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + text_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings + audio_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings - must be 16 channels (first user than moshi?) + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + Returns: + + Examples: + ```python + >>> from transformers import AutoProcessor, MoshiForConditionalGeneration + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("kyutai/moshiko") + >>> model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko") + >>> # TODO(YL): update + >>> inputs = processor( + ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], + ... padding=True, + ... return_tensors="pt", + ... ) + + >>> pad_token_id = model.generation_config.pad_token_id + >>> input_ids = ( + ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) + ... * pad_token_id + ... 
) + + >>> logits = model(**inputs, input_ids=input_ids).logits + >>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size) + torch.Size([8, 1, 2048]) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_audio_encoder = { + argument[len("audio_encoder_")]: value + for argument, value in kwargs.items() + if argument.startswith("audio_encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + kwargs_depth_decoder = { + argument[len("depth_decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("depth_decoder_") + } + + # TODO: we need to have same number of timestamps, and same number of batch + + + if (text_labels is not None) and (input_ids is None and inputs_embeds is None): + input_ids = shift_tokens_right( + text_labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id + ) + # TODO: also do it with audio_labels? + + # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used + if inputs_embeds is None: + if user_input_values is not None and user_audio_codes is None: + user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] + + if moshi_input_values is not None and moshi_audio_codes is None: + moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] + + # TODO: make sure it's the right order (user than moshi) + make sure it's done over the right dim + audio_codes = torch.cat([user_audio_codes, moshi_audio_codes], dim=1) + + if input_ids is None and audio_codes is None: + raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") + + if input_ids is not None: + inputs_embeds = self.decoder.model.embed_tokens(input_ids) + + if audio_codes is not None: + audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]) + inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds + + # Decode + decoder_outputs = self.decoder( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=True, + labels=text_labels, + **kwargs_decoder, + ) + + decoder_last_hidden_state = decoder_outputs.last_hidden_state + + depth_decoder_outputs = None + if text_labels is not None and audio_labels is not None: + # To use depth decoder forward here, we actually need oracle input ids since we're supposed to pass the true input ids + + # (batch_size, sequence_length) -> (batch_size * sequence_length, 1) + text_labels = text_labels.view(-1, 1) + # (batch_size, num_codebooks, sequence_length) -> (batch_size * sequence_length, num_codebooks) + audio_labels = audio_labels.transpose(1,2).reshape(-1, audio_labels.shape[1]) + + depth_input_ids = torch.cat([text_labels, audio_labels], dim=1) + # keep the last codebook out of input_ids + depth_input_ids = depth_input_ids[:, :-1] + + # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim) + decoder_last_hidden_state = decoder_last_hidden_state.view(-1, 1, decoder_last_hidden_state.shape[-1]) + + depth_decoder_outputs = self.depth_decoder( + 
last_hidden_state=decoder_last_hidden_state,
+            input_ids=depth_input_ids,
+            attention_mask=attention_mask,
+        )
+
+        if not return_dict:
+            outputs = decoder_outputs.to_tuple()
+            if depth_decoder_outputs is not None:
+                outputs += depth_decoder_outputs.to_tuple()
+            return outputs
+
+        return MoshiConditionalGenerationOutputWithPast(
+            loss=decoder_outputs.loss,
+            logits=decoder_outputs.logits,
+            last_hidden_state=decoder_last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            hidden_states=decoder_outputs.hidden_states,
+            attentions=decoder_outputs.attentions,
+            depth_loss=None if depth_decoder_outputs is None else depth_decoder_outputs.loss,
+            audio_logits=None if depth_decoder_outputs is None else depth_decoder_outputs.logits,
+            depth_past_key_values=None if depth_decoder_outputs is None else depth_decoder_outputs.past_key_values,
+            depth_hidden_states=None if depth_decoder_outputs is None else depth_decoder_outputs.hidden_states,
+            depth_attentions=None if depth_decoder_outputs is None else depth_decoder_outputs.attentions,
+        )
+
+    def _prepare_inputs_embeds_for_generation(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        user_input_values: Optional[torch.FloatTensor] = None,
+        user_audio_codes: Optional[torch.Tensor] = None,
+        moshi_input_values: Optional[torch.FloatTensor] = None,  # audio_codes takes priority over input_values - make this explicit in the docstrings
+        moshi_audio_codes: Optional[torch.Tensor] = None,  # TODO: add to docstrings
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ):
+        # If inputs_embeds is provided, it has priority over input_ids and audio_codes, which won't be used
+        if inputs_embeds is None:
+            if input_ids is None and user_input_values is None and user_audio_codes is None and moshi_input_values is None and moshi_audio_codes is None:
+                raise ValueError("You must provide at least one of `input_ids`, `user_input_values`, `moshi_input_values`, `user_audio_codes` or `moshi_audio_codes`.")
+
+            # TODO: make sure the batch sizes and sequence lengths are consistent
+
+            if user_input_values is not None and user_audio_codes is None:
+                user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks)[0]
+
+            if moshi_input_values is not None and moshi_audio_codes is None:
+                moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks)[0]
+
+            audio_codes = None
+            if user_audio_codes is not None and moshi_audio_codes is not None:
+                # TODO: make sure it's the right order (user then moshi) + make sure it's done over the right dim
+                audio_codes = torch.cat([user_audio_codes, moshi_audio_codes], dim=1)
+            elif user_audio_codes is not None:
+                audio_codes = user_audio_codes
+            elif moshi_audio_codes is not None:
+                audio_codes = moshi_audio_codes
+
+            if input_ids is not None:
+                inputs_embeds = self.decoder.model.embed_tokens(input_ids)
+
+            if audio_codes is not None:
+                audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])])
+                inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds
+
+        return inputs_embeds, moshi_audio_codes
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        user_input_values: Optional[torch.FloatTensor] = None,  # audio_codes takes priority over input_values - make this explicit in the docstrings
+        user_audio_codes: Optional[torch.Tensor] = None,  # TODO: add to docstrings
+        moshi_input_values: Optional[torch.FloatTensor] = None,  # audio_codes takes priority over input_values - make this explicit in the docstrings
+
moshi_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> torch.LongTensor: + """ + # TODO: modify + Generates sequences of token ids for models with a language modeling head. + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. + kwargs (`Dict[str, Any]`, *optional*): + Remaining dictionary of keyword arguments that are passed to the `generate` method. Refers to the + original [`generate` docstrings](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate) + for more information on how to use them. + Note that keywords with a *depth_* prefix will be input for the `generate` method of the + depth decoder. Otherwise, the latter will use its default generation config. + + Return: # TODO + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + inputs_embeds, moshi_audio_codes = self._prepare_inputs_embeds_for_generation( + input_ids=input_ids, + user_input_values=user_input_values, + user_audio_codes=user_audio_codes, + moshi_input_values=moshi_input_values, + moshi_audio_codes=moshi_audio_codes, + inputs_embeds=inputs_embeds, + ) + + self.generated_audio_codes = moshi_audio_codes + + outputs = super().generate(inputs_embeds=inputs_embeds, **kwargs) + + # check if outputs is a dict or a Tensor (depending on unaccessed `generation_config.return_dict_in_generate`) + if isinstance(outputs, torch.Tensor): + output_text_ids = outputs + else: + output_text_ids = outputs.sequences + + output_audio_codes = self.generated_audio_codes + + + output_values = self.audio_encoder.decode( + output_audio_codes, + ).audio_values + + + return output_text_ids, output_values + + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # 1. 
Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + + # 2. 
Now that everything is prepared, generate audio codes with the depth decoder
+
+        # we only want to do this once a first text token has been generated
+        if model_inputs["input_ids"] is not None:
+            last_hidden_state = kwargs.get("last_hidden_state")
+            # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim)
+            last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1])
+
+            input_ids = model_inputs.pop("input_ids")
+
+            # TODO: allow passing generation kwargs to the depth decoder
+            generated_audio_codes = self.depth_decoder.generate(
+                last_hidden_state=last_hidden_state,
+                input_ids=input_ids.view(-1, 1),
+                min_length=self.num_codebooks + 1,  # TODO: change
+                max_length=self.num_codebooks + 1,  # TODO: change
+            )
+
+            # the first tokens are text tokens
+            generated_audio_codes = generated_audio_codes[:, 1:].unsqueeze(2)
+
+            self.generated_audio_codes = torch.cat([self.generated_audio_codes, generated_audio_codes], dim=2)
+
+            # TODO: for now, blank user input ids are not used
+            inputs_embeds, _ = self._prepare_inputs_embeds_for_generation(input_ids, moshi_audio_codes=generated_audio_codes)
+
+            model_inputs["input_ids"] = None
+            model_inputs["inputs_embeds"] = inputs_embeds
+
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
+    ) -> Dict[str, Any]:
+        model_kwargs = super()._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder, num_new_tokens)
+
+        # update `last_hidden_state`, which the depth decoder consumes at the next generation step
+        model_kwargs["last_hidden_state"] = outputs.get("last_hidden_state")
+        return model_kwargs
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
+
+    def resize_token_embeddings(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Resizing the embedding layers via `MoshiForConditionalGeneration` directly is not supported. Please use the"
+            " respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...) or"
+            " model.depth_decoder.resize_token_embeddings(...))"
+        )
+
+    def freeze_audio_encoder(self):
+        """
+        Freeze the audio encoder weights.
+        """
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+        self.audio_encoder._requires_grad = False
+
+    def freeze_depth_decoder(self):
+        """
+        Freeze the depth decoder weights.
+ """ + for param in self.depth_decoder.parameters(): + param.requires_grad = False + self.depth_decoder._requires_grad = False \ No newline at end of file diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py index dd0f77421be728..c5dd6d9daa79d9 100644 --- a/tests/models/mimi/test_modeling_mimi.py +++ b/tests/models/mimi/test_modeling_mimi.py @@ -762,6 +762,7 @@ def test_torch_compile(self): for i in range(n_iter): _ = model(inputs_dict["input_values"].to(torch_device)) + @is_flaky() def test_batching_equivalence(self): super().test_batching_equivalence() @@ -791,6 +792,7 @@ def test_integration_using_cache_decode(self): } librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + model_id = "kyutai/mimi" model = MimiModel.from_pretrained(model_id, use_cache=True).to(torch_device) @@ -841,6 +843,7 @@ def test_integration(self): "32": 1803071, } librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + model_id = "kyutai/mimi" processor = AutoFeatureExtractor.from_pretrained(model_id) diff --git a/tests/models/moshi/__init__.py b/tests/models/moshi/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py new file mode 100644 index 00000000000000..c9ded9097a02ff --- /dev/null +++ b/tests/models/moshi/test_modeling_moshi.py @@ -0,0 +1,2586 @@ +# coding=utf-8 +# Copyright 2021, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch Moshi model.""" + +import copy +import inspect +import math +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized +from pytest import mark + +from transformers import ( + EncodecConfig, + MoshiConfig, + MoshiProcessor, + PretrainedConfig, + T5Config, +) +from transformers.testing_utils import ( + is_torch_available, + require_flash_attn, + require_torch, + require_torch_accelerator, + require_torch_fp16, + require_torch_gpu, + require_torch_sdpa, + slow, + torch_device, +) +from transformers.utils import cached_property, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + MoshiForCausalLM, + MoshiForConditionalGeneration, + MoshiModel, + set_seed, + ) + from transformers.generation import ( + GenerateDecoderOnlyOutput, + GenerateEncoderDecoderOutput, + ) + + +def _config_zero_init(config): + configs_no_init = copy.deepcopy(config) + for key in configs_no_init.__dict__.keys(): + if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: + setattr(configs_no_init, key, 1e-10) + if isinstance(getattr(configs_no_init, key, None), PretrainedConfig): + no_init_subconfig = _config_zero_init(getattr(configs_no_init, key)) + setattr(configs_no_init, key, no_init_subconfig) + return configs_no_init + + +def prepare_moshi_decoder_inputs_dict( + config, + input_ids, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + cross_attn_head_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :] + attention_mask = attention_mask.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device) + if encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device) + if cross_attn_head_mask is None: + cross_attn_head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + } + + +class MoshiDecoderTester: + def __init__( + self, + parent, + batch_size=4, # need batch_size != num_hidden_layers + seq_length=7, + is_training=True, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=100, + pad_token_id=99, + bos_token_id=99, + num_codebooks=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + 
self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.num_codebooks = num_codebooks + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size) + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + + config = self.get_config() + inputs_dict = prepare_moshi_decoder_inputs_dict( + config, + input_ids, + encoder_hidden_states=encoder_hidden_states, + ) + return config, inputs_dict + + def get_config(self): + config = MoshiConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + d_ff=self.intermediate_size, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.bos_token_id, + bos_token_id=self.bos_token_id, + num_codebooks=self.num_codebooks, + tie_word_embeddings=False, + ) + return config + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + +@require_torch +class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (MoshiModel, MoshiForCausalLM) if is_torch_available() else () + greedy_sample_model_classes = ( + (MoshiForCausalLM,) if is_torch_available() else () + ) # we don't want to run all the generation tests, only a specific subset + test_pruning = False + test_resize_embeddings = False + + def setUp(self): + self.model_tester = MoshiDecoderTester(self) + self.config_tester = ConfigTester(self, config_class=MoshiConfig, hidden_size=16) + + def test_config(self): + self.config_tester.run_common_tests() + + # special case for labels + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + if not self.model_tester.is_training: + self.skipTest(reason="model_tester.is_training is set to False") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + model = MoshiForCausalLM(config) + + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + model.train() + + # Contrarily to the initial method, we don't unfreeze freezed parameters. + # Indeed, sinusoidal position embeddings have frozen weights that should stay frozen. 
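The comment above points out that frozen parameters (such as sinusoidal position embeddings) are deliberately left frozen before training. A minimal sketch with a toy module (not the Moshi test itself) of the behaviour this relies on: parameters with `requires_grad=False` receive no gradient and are left untouched by the optimizer step.

```python
import torch
import torch.nn as nn

# A toy module with one trainable layer and one frozen embedding table
# (standing in for frozen sinusoidal position embeddings).
model = nn.ModuleDict({"proj": nn.Linear(4, 4), "frozen": nn.Embedding(10, 4)})
model["frozen"].weight.requires_grad_(False)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

out = model["proj"](model["frozen"](torch.tensor([[1, 2, 3, 4]])))
out.sum().backward()
optimizer.step()

print(model["proj"].weight.grad is not None)  # True  -> trainable weights get gradients
print(model["frozen"].weight.grad)            # None  -> frozen weights stay frozen
```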
+ + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + inputs = self._prepare_for_class(inputs_dict, MoshiForCausalLM, return_labels=True) + loss = model(**inputs).loss + loss.backward() + optimizer.step() + + for k, v in model.named_parameters(): + if v.requires_grad: + self.assertTrue(v.grad is not None, f"{k} in {MoshiForCausalLM.__name__} has no gradient!") + + # override since we have to compute the input embeddings over codebooks + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + + embed_tokens = model.get_input_embeddings() + + input_ids = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1]) + + inputs["inputs_embeds"] = sum( + [embed_tokens[codebook](input_ids[:, codebook]) for codebook in range(config.num_codebooks)] + ) + + with torch.no_grad(): + model(**inputs)[0] + + # override since we have embeddings / LM heads over multiple codebooks + def test_model_get_set_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + first_embed = model.get_input_embeddings()[0] + self.assertIsInstance(first_embed, torch.nn.Embedding) + lm_heads = model.get_output_embeddings() + self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear)) + + @unittest.skip(reason="Moshi does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Moshi does not support all arguments tested") + def test_model_outputs_equivalence(self): + pass + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied") + def test_tie_model_weights(self): + pass + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied") + def test_tied_weights_keys(self): + pass + + def _get_input_ids_and_config(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict["input_ids"] + + # take max batch_size + sequence_length = input_ids.shape[-1] + input_ids = input_ids[: batch_size * config.num_codebooks, :] + + attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) + return config, input_ids, attention_mask + + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = {} + return logits_processor_kwargs + + def test_greedy_generate_stereo_outputs(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.audio_channels = 2 + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + if not 
model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + # Ignore copy + outputs = model(dummy_input, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # Ignore copy + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, 
decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + # Ignore copy + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_left_padding(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do right padding + 
dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_use_cache(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. 
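For context on the numerical comparison this test performs further below, a tiny, self-contained sketch (random toy tensors) of the two quantities involved: the `torch.allclose` check with per-device/per-dtype tolerances, and the mean relative difference reported when a case fails.

```python
import torch

# Toy reference output and a slightly perturbed copy, standing in for eager vs. SDPA hidden states.
ref = torch.randn(2, 5, 8)
x = ref + 1e-4 * torch.randn_like(ref)

atol, rtol = 1e-2, 1e-2  # e.g. the bfloat16 tolerances used below
print(torch.allclose(x, ref, atol=atol, rtol=rtol))           # True for such a small perturbation
print(((x - ref).abs() / (ref.abs() + 1e-12)).mean().item())  # the "mean relative difference" reported on failure
```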
+ if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 + + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for batch_size in [1, 5]: + # Ignore copy + batch_size_input_ids = self.model_tester.num_codebooks * batch_size + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + # Ignore copy + dummy_input = dummy_input[:batch_size_input_ids] + # Ignore copy + if dummy_input.shape[0] != batch_size_input_ids: + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + # Ignore copy + extension = torch.rand( + batch_size_input_ids - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + else: + # Ignore copy + extension = torch.randint( + high=5, + 
size=(batch_size_input_ids - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + + if not use_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + if is_encoder_decoder: + seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + else: + seqlen = dummy_input.shape[-1] + dummy_attention_mask = ( + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + ) + + dummy_attention_mask = dummy_attention_mask[:batch_size] + if dummy_attention_mask.shape[0] != batch_size: + extension = torch.ones( + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, + device=torch_device, + ) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + dummy_attention_mask = dummy_attention_mask.to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :-1] = 1 + dummy_attention_mask[-1, -4:] = 0 + elif padding_side == "right": + dummy_attention_mask[-1, 1:] = 1 + dummy_attention_mask[-1, :3] = 0 + + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" + + other_inputs = { + "output_hidden_states": True, + } + + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + other_inputs["attention_mask"] = dummy_attention_mask + + # TODO: test gradients as well (& for FA2 as well!) + with torch.no_grad(): + with torch.backends.cuda.sdp_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + outputs_eager = model_eager(dummy_input, **other_inputs) + outputs_sdpa = model_sdpa(dummy_input, **other_inputs) + + logits_eager = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_sdpa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) + + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. 
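To make the comment above concrete, the sketch below (toy tensors, not the actual test data) shows how padded positions are excluded from the comparison: with the right-padding pattern used below (`dummy_attention_mask[-1, :3] = 0`), only the real tokens of the padded sample are compared.

```python
import torch

# Toy "hidden state" outputs for a batch of 3, seq_len 5; only the last sample has padded positions.
logits_eager = torch.randn(3, 5, 8)
logits_sdpa = logits_eager.clone()
logits_sdpa[-1, :3] += 1.0  # large deviation, but only on the padded positions

# Fully-attended samples are compared whole; for the padded sample only the real tokens are kept.
print(torch.allclose(logits_sdpa[:-1], logits_eager[:-1]))        # True
print(torch.allclose(logits_sdpa[-1, 3:], logits_eager[-1, 3:]))  # True -> padded positions are skipped
```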
+ if use_mask: + if padding_side == "left": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, :-4] + sub_eager = logits_eager[-1, :-4] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, -4:] + # sub_eager = logits_eager[-1, -4:] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + elif padding_side == "right": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, 3:] + sub_eager = logits_eager[-1, 3:] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, :3] + # sub_eager = logits_eager[-1, :3] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + + else: + if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The 
SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + + +def prepare_moshi_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + labels=None, +): + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.reshape( + -1, config.decoder.num_codebooks, decoder_input_ids.shape[-1] + )[:, 0, :] + decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id) + if head_mask is None: + head_mask = torch.ones( + config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device + ) + if decoder_head_mask is None: + decoder_head_mask = torch.ones( + config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device + ) + if cross_attn_head_mask is None: + cross_attn_head_mask = torch.ones( + config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device + ) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "labels": labels, + } + + +class MoshiTester: + def __init__( + self, + parent, + batch_size=4, # need batch_size != num_hidden_layers + seq_length=7, + is_training=True, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=100, + pad_token_id=99, + bos_token_id=99, + num_codebooks=4, + num_filters=4, + codebook_size=128, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.num_codebooks = num_codebooks + self.num_filters = num_filters + self.codebook_size = codebook_size + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + decoder_input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size) + + config = self.get_config() + inputs_dict = prepare_moshi_inputs_dict(config, input_ids, decoder_input_ids=decoder_input_ids) + return config, inputs_dict + + def get_config(self): + text_encoder_config = T5Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.intermediate_size, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + ) + audio_encoder_config = EncodecConfig( + 
hidden_size=self.vocab_size, + compress=1, + num_filters=self.num_filters, + codebook_size=self.codebook_size, + codebook_dim=self.vocab_size, + ) + decoder_config = MoshiConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + ffn_dim=self.intermediate_size, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.bos_token_id, + bos_token_id=self.bos_token_id, + num_codebooks=self.num_codebooks, + tie_word_embeddings=False, + ) + config = MoshiConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config) + return config + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + +@require_torch +class MoshiTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (MoshiForConditionalGeneration,) if is_torch_available() else () + greedy_sample_model_classes = (MoshiForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"text-to-audio": MoshiForConditionalGeneration} if is_torch_available() else {} + test_pruning = False # training is not supported yet for Moshi + test_headmasking = False + test_resize_embeddings = False + # not to test torchscript as the model tester doesn't prepare `input_values` and `padding_mask` + # (and `torchscript` hates `None` values). + test_torchscript = False + + def setUp(self): + self.model_tester = MoshiTester(self) + + # special case for labels + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict + + def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + if not self.model_tester.is_training: + self.skipTest(reason="model_tester.is_training is set to False") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + model = model_class(config) + + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + model.train() + + # The audio encoder weights are not used during the forward pass (only during the generate pass) + # So we need to freeze it to be able to train. 
+ model.freeze_audio_encoder() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + optimizer.step() + + for k, v in model.named_parameters(): + if v.requires_grad: + self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!") + + def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids): + text_encoder_config = config.text_encoder + decoder_config = config.decoder + + encoder_attentions = outputs["encoder_attentions"] + self.assertEqual(len(encoder_attentions), text_encoder_config.num_hidden_layers) + + self.assertEqual( + encoder_attentions[0].shape[-3:], + (text_encoder_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]), + ) + + decoder_attentions = outputs["decoder_attentions"] + num_decoder_layers = decoder_config.num_hidden_layers + self.assertEqual(len(decoder_attentions), num_decoder_layers) + + self.assertEqual( + decoder_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]), + ) + + cross_attentions = outputs["cross_attentions"] + self.assertEqual(len(cross_attentions), num_decoder_layers) + + cross_attention_input_seq_len = decoder_input_ids.shape[-1] + self.assertEqual( + cross_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]), + ) + + def check_moshi_model_output_attentions( + self, + model_class, + config, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + **kwargs, + ) + self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids) + + def check_moshi_model_output_attentions_from_config( + self, + model_class, + config, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ): + # Similar to `check_moshi_model_output_attentions`, but with `output_attentions` triggered from the + # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded + # from the inner models' configurations. 
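The comment above is the important subtlety of this check: for a composite model, setting the flag on the top-level config has no effect, and it has to be set on each sub-configuration instead. A purely illustrative sketch with hypothetical plain-Python config classes (not the real `MoshiConfig` API):

```python
# Hypothetical stand-ins for a composite configuration and its sub-configurations.
class SubConfig:
    def __init__(self):
        self.output_attentions = False


class CompositeConfig:
    def __init__(self):
        self.output_attentions = False  # never read by the sub-models
        self.text_encoder = SubConfig()
        self.audio_encoder = SubConfig()
        self.decoder = SubConfig()


config = CompositeConfig()
config.output_attentions = True               # top-level flag: ignored, attentions are still not returned
config.text_encoder.output_attentions = True  # sub-config flags: these are the ones the sub-models consume
config.audio_encoder.output_attentions = True
config.decoder.output_attentions = True
```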
+ config.output_attentions = True # model config -> won't work + + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + **kwargs, + ) + self.assertTrue( + all(key not in outputs for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]) + ) + config.text_encoder.output_attentions = True # inner model config -> will work + config.audio_encoder.output_attentions = True + config.decoder.output_attentions = True + + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + **kwargs, + ) + self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids) + + # override since changing `output_attentions` from the top-level model config won't work + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + self.check_moshi_model_output_attentions(model_class, config, **inputs_dict) + self.check_moshi_model_output_attentions_from_config(model_class, config, **inputs_dict) + + # override since we have a specific forward signature for moshi + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = [ + "input_ids", + "attention_mask", + "input_values", + "padding_mask", + "decoder_input_ids", + "decoder_attention_mask", + ] + expected_arg_names.extend( + ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] + if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names + else ["encoder_outputs"] + ) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + # override since changing `gradient_checkpointing` from the top-level model config won't work + def test_gradient_checkpointing_backward_compatibility(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + config.text_encoder.gradient_checkpointing = True + config.audio_encoder.gradient_checkpointing = True + config.decoder.gradient_checkpointing = True + model = model_class(config) + self.assertTrue(model.is_gradient_checkpointing) + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied.") + def test_tie_model_weights(self): + pass + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied.") + def test_tied_model_weights_key_ignore(self): + pass + + @unittest.skip(reason="Moshi has multiple inputs embeds and lm heads that should not be tied.") + def test_tied_weights_keys(self): + pass + + @unittest.skip(reason="No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip(reason="No support for low_cpu_mem_usage=True.") + def 
test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip(reason="No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + + # override since changing `output_hidden_states` / `output_attentions` from the top-level model config won't work + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.text_encoder.output_hidden_states = True + config.audio_encoder.output_hidden_states = True + config.decoder.output_hidden_states = True + + config.text_encoder.output_attentions = True + config.decoder.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_hidden_states.retain_grad() + + decoder_hidden_states = outputs.decoder_hidden_states[0] + decoder_hidden_states.retain_grad() + + if self.has_attentions: + encoder_attentions = outputs.encoder_attentions[0] + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(decoder_hidden_states.grad) + + if self.has_attentions: + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + # override since changing `output_hidden_states` from the top-level model config won't work + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states + + expected_num_layers = self.model_tester.num_hidden_layers + 1 + self.assertEqual(len(hidden_states), expected_num_layers) + + seq_length = self.model_tester.seq_length + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + hidden_states = outputs.decoder_hidden_states + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.text_encoder.output_hidden_states = True + config.audio_encoder.output_hidden_states = True + config.decoder.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # override since the conv layers and lstm's in encodec are exceptions + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = 
_config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = ["conv"] + ignore_init = ["lstm"] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + elif not any(x in name for x in ignore_init): + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # override since we have embeddings / LM heads over multiple codebooks + def test_model_get_set_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), torch.nn.Embedding) + lm_heads = model.get_output_embeddings() + self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear)) + + def _get_input_ids_and_config(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict["input_ids"] + + # take max batch_size + sequence_length = input_ids.shape[-1] + input_ids = input_ids[:batch_size, :] + attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) + + return config, input_ids, attention_mask + + # override since the `input_ids` cannot be used as the `decoder_input_ids` for moshi (input / outputs are + # different modalities -> different shapes) + def _greedy_generate( + self, + model, + input_ids, + attention_mask, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_generate = model.generate( + input_ids, + do_sample=False, + num_beams=1, + max_new_tokens=self.max_new_tokens, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, + **model_kwargs, + ) + + return output_generate + + # override since the `input_ids` cannot be used as the `decoder_input_ids` for moshi (input / outputs are + # different modalities -> different shapes) + def _sample_generate( + self, + model, + input_ids, + attention_mask, + num_return_sequences, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + torch.manual_seed(0) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_generate = model.generate( + input_ids, + do_sample=True, + num_beams=1, + max_new_tokens=self.max_new_tokens, + num_return_sequences=num_return_sequences, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, + **model_kwargs, + ) + + return output_generate + + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = {} + return logits_processor_kwargs + + def test_greedy_generate_dict_outputs(self): + for model_class in self.greedy_sample_model_classes: + # disable cache + config, input_ids, attention_mask = 
self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + def test_greedy_generate_dict_outputs_use_cache(self): + for model_class in self.greedy_sample_model_classes: + # enable cache + config, input_ids, attention_mask = self._get_input_ids_and_config() + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + def test_sample_generate(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + + # check `generate()` and `sample()` are equal + output_generate = self._sample_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + num_return_sequences=1, + ) + self.assertIsInstance(output_generate, torch.Tensor) + + def test_sample_generate_dict_output(self): + for model_class in self.greedy_sample_model_classes: + # disable cache + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + + output_generate = self._sample_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + num_return_sequences=3, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + def test_generate_without_input_ids(self): + config, _, _ = self._get_input_ids_and_config() + + # if no bos token id => cannot generate from None + if config.bos_token_id is None: + self.skipTest(reason="bos_token_id is None") + + for model_class in self.greedy_sample_model_classes: + model = model_class(config).to(torch_device) + model.eval() + + output_ids_generate = model.generate( + do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True + ) + self.assertIsNotNone(output_ids_generate) + + @require_torch_fp16 + @require_torch_accelerator # not all operations are supported in fp16 on CPU + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + + for model_class in self.greedy_sample_model_classes: + model = model_class(config).eval().to(torch_device) + model.half() + # greedy + model.generate(input_dict["input_ids"], attention_mask=input_dict["attention_mask"], max_new_tokens=10) + # sampling + model.generate( + input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10 + ) + + def test_greedy_generate_stereo_outputs(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask = 
self._get_input_ids_and_config() + config.audio_channels = 2 + + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + @unittest.skip( + reason="MoshiModel is actually not the base of MoshiForCausalLM as the latter is a composit model" + ) + def test_save_load_fast_init_from_base(self): + pass + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + # Ignore copy + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + # Ignore copy + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + # Ignore copy + outputs = model(dummy_input, **other_inputs) + # Ignore copy + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + 
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + # Ignore copy + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + # Ignore copy + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + # Ignore copy + outputs = model(dummy_input, **other_inputs) + # Ignore copy + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_left_padding(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask") + if dummy_attention_mask is None: + dummy_attention_mask = torch.ones_like(dummy_input) + + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + 
torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask") + if dummy_attention_mask is None: + dummy_attention_mask = torch.ones_like(dummy_input) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_use_cache(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + 
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. + if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 + + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for batch_size in [1, 5]: + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + 
dummy_input = dummy_input[:batch_size] + if dummy_input.shape[0] != batch_size: + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + extension = torch.rand( + batch_size - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + else: + extension = torch.randint( + high=5, + size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + + if not use_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + # Ignore copy + seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + # Ignore copy + dummy_attention_mask = ( + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + ) + + dummy_attention_mask = dummy_attention_mask[:batch_size] + if dummy_attention_mask.shape[0] != batch_size: + extension = torch.ones( + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, + device=torch_device, + ) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + dummy_attention_mask = dummy_attention_mask.to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :-1] = 1 + dummy_attention_mask[-1, -4:] = 0 + elif padding_side == "right": + dummy_attention_mask[-1, 1:] = 1 + dummy_attention_mask[-1, :3] = 0 + + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" + # Ignore copy + batch_size_input_ids = self.model_tester.num_codebooks * batch_size + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[ + :batch_size_input_ids + ] + # Ignore copy + if decoder_input_ids.shape[0] != batch_size_input_ids: + # Ignore copy + extension = torch.ones( + batch_size_input_ids - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) + + # TODO: never an `attention_mask` arg here? + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + + # TODO: test gradients as well (& for FA2 as well!) + # Ignore copy + with torch.no_grad(): + with torch.backends.cuda.sdp_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + outputs_eager = model_eager(dummy_input, **other_inputs) + outputs_sdpa = model_sdpa(dummy_input, **other_inputs) + + logits_eager = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_sdpa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) + + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. 
+ if use_mask: + if padding_side == "left": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, :-4] + sub_eager = logits_eager[-1, :-4] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, -4:] + # sub_eager = logits_eager[-1, -4:] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + elif padding_side == "right": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, 3:] + sub_eager = logits_eager[-1, 3:] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, :3] + # sub_eager = logits_eager[-1, :3] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + + else: + if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The 
SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + + def test_requires_grad_with_frozen_encoders(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + model.freeze_audio_encoder() + + audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()] + text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()] + + self.assertFalse(all(audio_encoder_grads)) + self.assertTrue(all(text_encoder_grads)) + + model = model_class(config) + model.freeze_text_encoder() + + audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()] + text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()] + + self.assertTrue(all(audio_encoder_grads)) + self.assertFalse(all(text_encoder_grads)) + + +def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): + """Produces a series of 'bip bip' sounds at a given frequency.""" + timesteps = np.arange(int(duration * sample_rate)) / sample_rate + wav = np.cos(2 * math.pi * 440 * timesteps) + time_period = (timesteps % (2 * bip_duration)) / (2 * bip_duration) + envelope = time_period >= 0.5 + return wav * envelope + + +def place_dict_on_device(dict_to_place, device): + for key in dict_to_place: + if dict_to_place[key] is not None and isinstance(dict_to_place[key], torch.Tensor): + dict_to_place[key] = dict_to_place[key].to(device) + return dict_to_place + + +@require_torch +class MoshiIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko").to(torch_device) + + @cached_property + def processor(self): + return MoshiProcessor.from_pretrained("kyutai/moshiko") + + @slow + def test_logits_text_prompt(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + # prepare the decoder inputs + pad_token_id = model.generation_config.pad_token_id + decoder_input_ids = ( + torch.ones((input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long).to(torch_device) + * pad_token_id + ) + + with torch.no_grad(): + logits = model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + ).logits + + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + -0.9708, -3.0149, -4.6415, -1.4754, -0.2786, -2.3523, -2.6049, -6.7467, + -1.0206, -3.2984, -3.3968, -1.5108, -1.5786, -3.1493, -1.1503, -0.0545, + ] + ) + # fmt: on + + self.assertTrue(logits.shape == (*decoder_input_ids.shape, model.decoder.config.vocab_size)) + self.assertTrue(torch.allclose(logits[0, 0, :16].cpu(), EXPECTED_LOGITS, atol=1e-4)) + + @slow + def test_logits_text_audio_prompt(self): + model = self.model + processor = self.processor + + audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, 
return_tensors="pt") + + # prepare the text encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + # prepare the audio encoder inputs + input_values = inputs.input_values.to(torch_device) + padding_mask = inputs.padding_mask.to(torch_device) + + with torch.no_grad(): + logits = model( + input_ids, + attention_mask=attention_mask, + input_values=input_values, + padding_mask=padding_mask, + ).logits + + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + 0.1841, -2.9324, -0.7898, 0.1857, 0.4971, -2.8685, -1.6525, -1.6541, + 2.7757, -2.5942, -3.0959, -1.0120, -1.0147, -0.4605, -0.8885, 0.6820, + ] + ) + # fmt: on + + self.assertTrue(logits.shape == (8, 50, 2048)) + self.assertTrue(torch.allclose(logits[0, -1, :16].cpu(), EXPECTED_LOGITS, atol=1e-4)) + + @slow + def test_generate_unconditional_greedy(self): + model = self.model + + # only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same + unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=5) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + 0.0056, 0.0064, 0.0063, 0.0054, 0.0042, 0.0033, 0.0024, 0.0015, + 0.0015, 0.0010, 0.0004, -0.0012, -0.0036, -0.0055, -0.0067, -0.0071, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (1, 1, 3200)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_unconditional_sampling(self): + model = self.model + + # for stochastic sampling we can generate multiple outputs + unconditional_inputs = model.get_unconditional_inputs(num_samples=2) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + set_seed(0) + output_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=10) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185, + 0.0105, 0.0059, 0.0329, 0.0249, -0.0204, -0.0341, -0.0465, 0.0053, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_greedy(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=None, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -1.1998e-04, -2.2302e-04, 4.6296e-04, 1.0524e-03, 2.4827e-04, + -4.0288e-05, -1.2468e-04, 4.9846e-05, 7.1485e-04, 4.4197e-04, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :10].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_greedy_with_classifier_free_guidance(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + 
attention_mask = inputs.attention_mask.to(torch_device) + + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=3, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + 0.0283, 0.0246, 0.0650, 0.0640, 0.0599, 0.0711, 0.0420, 0.0112, + 0.0511, 0.0746, 0.1363, 0.1213, 0.0185, -0.0578, -0.0908, 0.0443, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_sampling(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + set_seed(0) + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=True, guidance_scale=None, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -0.0111, -0.0154, 0.0047, 0.0058, -0.0068, 0.0012, -0.0109, -0.0229, + 0.0010, -0.0038, 0.0167, 0.0042, -0.0421, -0.0610, -0.0764, -0.0326, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_audio_prompt(self): + model = self.model + processor = self.processor + + audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt") + inputs = place_dict_on_device(inputs, device=torch_device) + + output_values = model.generate(**inputs, do_sample=False, guidance_scale=None, max_new_tokens=10) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -0.0036, -0.0130, -0.0261, -0.0384, -0.0557, -0.0718, -0.0680, -0.0632, + -0.0529, -0.0403, -0.0289, -0.0198, -0.0136, -0.0101, -0.0095, -0.0040, + ] + ) + # fmt: on + + self.assertTrue( + output_values.shape == (2, 1, 36480) + ) # input values take shape 32000 and we generate from there + self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4)) + + +@require_torch +class MoshiStereoIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return MoshiForConditionalGeneration.from_pretrained("facebook/moshi-stereo-small").to(torch_device) + + @cached_property + def processor(self): + return MoshiProcessor.from_pretrained("facebook/moshi-stereo-small") + + @slow + def test_generate_unconditional_greedy(self): + model = self.model + + # only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same + unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=12) + + # fmt: off + EXPECTED_VALUES_LEFT = torch.tensor( + [ + 0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013, + -0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099, + ] + ) + EXPECTED_VALUES_RIGHT = torch.tensor( + [ + 0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019, + 0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096, + ] + ) + # fmt: on + + # (bsz, channels, seq_len) + 
self.assertTrue(output_values.shape == (1, 2, 5760)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4)) + self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4)) + + @slow + def test_generate_text_audio_prompt(self): + model = self.model + processor = self.processor + + # create stereo inputs + audio = [get_bip_bip(duration=0.5)[None, :].repeat(2, 0), get_bip_bip(duration=1.0)[None, :].repeat(2, 0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt") + inputs = place_dict_on_device(inputs, device=torch_device) + + output_values = model.generate(**inputs, do_sample=False, guidance_scale=3.0, max_new_tokens=12) + + # fmt: off + EXPECTED_VALUES_LEFT = torch.tensor( + [ + 0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728, + -0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430, + ] + ) + EXPECTED_VALUES_RIGHT = torch.tensor( + [ + 0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103, + -0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614, + ] + ) + # fmt: on + + # (bsz, channels, seq_len) + self.assertTrue(output_values.shape == (2, 2, 37760)) + # input values take shape 32000 and we generate from there - we check the last (generated) values + self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4)) + self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
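
A minimal sketch of the unconditional-generation path exercised by the integration tests above. This is an untested illustration, not an official example: `MoshiForConditionalGeneration` and `get_unconditional_inputs` are the APIs introduced in this PR, and `"kyutai/moshiko"` is simply the checkpoint id the integration tests load, which may not be the final published name.

```python
import torch
from transformers import MoshiForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint id taken from the integration tests above; it may change before release.
model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko").to(device)

# Build prompt-free inputs and move every tensor to the target device,
# mirroring get_unconditional_inputs + place_dict_on_device in the tests.
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
for key in unconditional_inputs:
    if isinstance(unconditional_inputs[key], torch.Tensor):
        unconditional_inputs[key] = unconditional_inputs[key].to(device)

# Greedy decoding is deterministic, so one sample is enough.
audio_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=5)
print(audio_values.shape)  # (batch, channels, samples), e.g. (1, 1, 3200) in the mono test above
```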
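
Along the same lines, a sketch of text-prompted generation as covered by `test_generate_text_prompt_greedy` and its classifier-free-guidance variant. The prompts and checkpoint id are copied from those tests; `MoshiProcessor` is assumed to follow the text/audio processor call signature used there.

```python
from transformers import MoshiForConditionalGeneration, MoshiProcessor

processor = MoshiProcessor.from_pretrained("kyutai/moshiko")
model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko")

# Text-only conditioning; an audio prompt can be passed through the same processor
# call via `audio=...`, as test_generate_text_audio_prompt does.
inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")

audio_values = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    do_sample=False,
    guidance_scale=None,  # e.g. guidance_scale=3 enables classifier-free guidance, as in the tests
    max_new_tokens=10,
)
```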
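
The Flash Attention 2 and SDPA equivalence tests above all reduce to loading the same weights with a different `attn_implementation` and checking that greedy outputs agree. A hedged sketch of that comparison for the eager/SDPA pair follows; Flash Attention 2 additionally requires a CUDA GPU, half-precision weights, and the `flash-attn` package, as the `@require_flash_attn` decorators reflect.

```python
import torch
from transformers import MoshiForConditionalGeneration, MoshiProcessor

processor = MoshiProcessor.from_pretrained("kyutai/moshiko")
inputs = processor(text=["80s music"], padding=True, return_tensors="pt")

# Same checkpoint, two attention backends.
model_eager = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko", attn_implementation="eager")
model_sdpa = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko", attn_implementation="sdpa")

out_eager = model_eager.generate(**inputs, do_sample=False, max_new_tokens=10)
out_sdpa = model_sdpa.generate(**inputs, do_sample=False, max_new_tokens=10)

# Greedy outputs are expected to match, as the SDPA equivalence test above asserts.
print(torch.allclose(out_eager, out_sdpa))
```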
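
Finally, a sketch of the encoder-freezing behaviour that both the gradient-checkpointing test and `test_requires_grad_with_frozen_encoders` rely on: the audio encoder is only needed at generation time, so it can be frozen before fine-tuning, and the text encoder can be frozen independently.

```python
from transformers import MoshiForConditionalGeneration

model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko")

# The audio encoder weights are not used in the forward pass (only during generate),
# so they can be frozen before training, as the tests above do.
model.freeze_audio_encoder()
assert not any(p.requires_grad for p in model.audio_encoder.parameters())
assert all(p.requires_grad for p in model.text_encoder.parameters())

# model.freeze_text_encoder() freezes the text encoder in the same way.
```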