diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index c18426de4c031c..97148840a2d2ea 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -253,7 +253,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ |
 | [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ |
 | [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ |
-| [Pixtral](model_doc/pixtral) | ❌ | ❌ | ❌ |
+| [Pixtral](model_doc/pixtral) | ✅ | ❌ | ❌ |
 | [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ |
 | [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ |
 | [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/pixtral.md b/docs/source/en/model_doc/pixtral.md
index 8df2bf5af5f9ca..dfb3df7477708a 100644
--- a/docs/source/en/model_doc/pixtral.md
+++ b/docs/source/en/model_doc/pixtral.md
@@ -18,20 +18,22 @@ rendered properly in your Markdown viewer.
 
 ## Overview
 
-The Pixtral model was released by the Mistral AI team on [Vllm](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!
-
+The Pixtral model was released by the Mistral AI team on [vLLM](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!
 
 Tips:
 
-- Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
-- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
+- Pixtral is a multimodal model, taking images and text as input, and producing text as output.
+- This model follows the [Llava](llava) family, meaning image embeddings are placed instead of the `[IMG]` token placeholders. The model uses [`PixtralVisionModel`] for its vision encoder.
+- The main contribution is the 2D RoPE (rotary position embeddings) on the images, and support for arbitrary image sizes (the images are not padded together nor are they resized).
 - The format for one or mulitple prompts is the following:
 ```
 "[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
 ```
 
 Then, the processor will replace each `[IMG]` token with a number of `[IMG]` token that depends on the height and the width of the image. Each *row* of the image is separated by a `[IMG_BREAK]` token, and each image is separated by a `[IMG_END]` token.
 
-This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ)
+This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ). The original code can be found [here](https://github.com/vllm-project/vllm/pull/8377).
+
+## Usage
 
 Here is an example of how to run it:
@@ -83,9 +85,9 @@ Each image captures a different scene, from a close-up of a dog to expansive nat
 
 [[autodoc]] PixtralVisionConfig
 
-## PixtralModel
+## PixtralVisionModel
 
-[[autodoc]] PixtralModel
+[[autodoc]] PixtralVisionModel
     - forward
 
 ## PixtralImageProcessor
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 36775d8454ab8c..bcb1217f6c7b00 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -2978,7 +2978,7 @@
             "Pix2StructVisionModel",
         ]
     )
-    _import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"])
+    _import_structure["models.pixtral"].extend(["PixtralVisionModel", "PixtralPreTrainedModel"])
     _import_structure["models.plbart"].extend(
         [
             "PLBartForCausalLM",
@@ -7456,8 +7456,8 @@
            Pix2StructVisionModel,
        )
        from .models.pixtral import (
-            PixtralModel,
            PixtralPreTrainedModel,
+            PixtralVisionModel,
        )
        from .models.plbart import (
            PLBartForCausalLM,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index e0d15f1e236590..6db0f97016dd18 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -193,7 +193,7 @@
        ("persimmon", "PersimmonModel"),
        ("phi", "PhiModel"),
        ("phi3", "Phi3Model"),
-        ("pixtral", "PixtralModel"),
+        ("pixtral", "PixtralVisionModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
diff --git a/src/transformers/models/pixtral/__init__.py b/src/transformers/models/pixtral/__init__.py
index e09ed8e60127dd..69335eeab902d3 100644
--- a/src/transformers/models/pixtral/__init__.py
+++ b/src/transformers/models/pixtral/__init__.py
@@ -29,7 +29,7 @@
    pass
 else:
    _import_structure["modeling_pixtral"] = [
-        "PixtralModel",
+        "PixtralVisionModel",
        "PixtralPreTrainedModel",
    ]
 
@@ -52,8 +52,8 @@
        pass
    else:
        from .modeling_pixtral import (
-            PixtralModel,
            PixtralPreTrainedModel,
+            PixtralVisionModel,
        )
 
    try:
diff --git a/src/transformers/models/pixtral/configuration_pixtral.py b/src/transformers/models/pixtral/configuration_pixtral.py
index dcc1e458ca78a3..a2fdc9b2ed303a 100644
--- a/src/transformers/models/pixtral/configuration_pixtral.py
+++ b/src/transformers/models/pixtral/configuration_pixtral.py
@@ -22,7 +22,7 @@
 
 class PixtralVisionConfig(PretrainedConfig):
    r"""
-    This is the configuration class to store the configuration of a [`PixtralModel`]. It is used to instantiate an
+    This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate a
    Pixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Pixtral-9B.
@@ -58,13 +58,13 @@ class PixtralVisionConfig(PretrainedConfig):
    Example:
 
    ```python
-    >>> from transformers import PixtralModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig
+    >>> from transformers import PixtralVisionModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig
 
    >>> # Initializing a Pixtral 12B style configuration
    >>> config = PixtralVisionConfig()
 
    >>> # Initializing a model from the pixtral 12B style configuration
-    >>> model = PixtralModel(configuration)
+    >>> model = PixtralVisionModel(config)
 
    >>> # Accessing the model configuration
    >>> configuration = model.config
diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py
index 0e10c78b7852af..06b9701a75661a 100644
--- a/src/transformers/models/pixtral/modeling_pixtral.py
+++ b/src/transformers/models/pixtral/modeling_pixtral.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,15 +48,13 @@ def position_ids_in_meshgrid(patch_embeds_list, max_width):
 class PixtralRotaryEmbedding(nn.Module):
    """
    The key with pixtral embedding is just that you have a frequency for each pixel positions.
-    If you have height x width pixels (or embedding pixels)
+    If you have height x width pixels (or embedding pixels), then the frequency used for ROPE
+    is given by indexing the pre_computed frequency on the width and height.
 
-    then the frequency used for ROPE is given by indexing the pre_computed frequency on the
-    width and height.
+    What you output is of dimension (batch, height * width, dim) with dim the embed dim.
 
-    What you output is of dimension batch, height * width, dim with dim the embed dim.
-
-    This simply means that for each image hidden states, you are going to add
-    a corresponding positional embedding, based on it's index in the grid.
+    This simply means that for each image hidden state, you are going to add
+    a corresponding positional embedding, based on its index in the grid.
    """
 
    def __init__(self, config, device):
@@ -319,9 +317,7 @@ def forward(
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
-                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
-                than the model's internal embedding lookup matrix.
+                Embeddings which serve as input to the Transformer.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -392,17 +388,13 @@
    and behavior.
 
    Parameters:
-        config ([`PixtralVisionConfig`] or [`PixtralVisionConfig`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
+        config ([`PixtralVisionConfig`]):
+            Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
""" -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - PIXTRAL_START_DOCSTRING, -) class PixtralPreTrainedModel(PreTrainedModel): config_class = PixtralVisionConfig base_model_prefix = "model" @@ -412,9 +404,6 @@ class PixtralPreTrainedModel(PreTrainedModel): _supports_cache_class = True def _init_weights(self, module): - # important: this ported version of Pixtral isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/LLaVA/tree/main/pixtral should serve for that purpose std = ( self.config.initializer_range if hasattr(self.config, "initializer_range") @@ -433,8 +422,9 @@ def _init_weights(self, module): PIXTRAL_INPUTS_DOCSTRING = r""" Args: - pixel_values: list of N_img images of variable sizes, - each of shape (C, H, W) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`] + for details. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -463,10 +453,10 @@ def generate_block_attention_mask(patch_embeds_list, tensor): @add_start_docstrings( - """The PIXTRAL model which consists of a vision backbone and a language model.""", + "The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.", PIXTRAL_START_DOCSTRING, ) -class PixtralModel(PixtralPreTrainedModel): +class PixtralVisionModel(PixtralPreTrainedModel): base_model_prefix = "vision_encoder" def __init__(self, config): diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2db7b38b580375..c16893d06bcbf5 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7067,14 +7067,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class PixtralModel(metaclass=DummyObject): +class PixtralPreTrainedModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class PixtralPreTrainedModel(metaclass=DummyObject): +class PixtralVisionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py index bd41fa1c9e62fb..9a128f6ad28823 100644 --- a/tests/models/pixtral/test_modeling_pixtral.py +++ b/tests/models/pixtral/test_modeling_pixtral.py @@ -21,8 +21,8 @@ from transformers import ( AutoProcessor, - PixtralModel, PixtralVisionConfig, + PixtralVisionModel, is_torch_available, is_vision_available, ) @@ -46,7 +46,7 @@ from PIL import Image -class PixtralModelTester: +class PixtralVisionModelTester: def __init__( self, parent, @@ -107,7 +107,7 @@ def get_config(self): ) def create_and_check_model(self, config, pixel_values): - model = PixtralModel(config=config) + model = PixtralVisionModel(config=config) model.to(torch_device) model.eval() with torch.no_grad(): @@ -120,7 +120,7 @@ def create_and_check_model(self, config, pixel_values): self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) def create_and_check_model_with_projection(self, config, pixel_values): - model = PixtralModel(config=config) + 
+        model = PixtralVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
@@ -140,17 +140,17 @@ def prepare_config_and_inputs_for_common(self):
 
 
 @require_torch
-class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
+class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
    """
-    Model tester for `PixtralModel`.
+    Model tester for `PixtralVisionModel`.
    """
 
-    all_model_classes = (PixtralModel,) if is_torch_available() else ()
+    all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
    test_pruning = False
    test_head_masking = False
 
    def setUp(self):
-        self.model_tester = PixtralModelTester(self)
+        self.model_tester = PixtralVisionModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)
 
    @unittest.skip("model does not support input embeds")
@@ -261,7 +261,7 @@ def test_determinism(self):
 
 
 @require_torch
-class PixtralModelIntegrationTest(unittest.TestCase):
+class PixtralVisionModelIntegrationTest(unittest.TestCase):
    def setUp(self):
        self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
 
@@ -273,7 +273,7 @@ def tearDown(self):
    @require_bitsandbytes
    def test_small_model_integration_test(self):
        # Let' s make sure we test the preprocessing to replace what is used
-        model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
+        model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
 
        prompt = "[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
        image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
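
For quick reference, here is a minimal sketch (not part of the patch above) of how the renamed vision encoder can be exercised after this change. It assumes the default `PixtralVisionConfig` values and randomly initialized weights, and that a batch of same-sized images can be passed as `pixel_values` in the format described by the updated `PIXTRAL_INPUTS_DOCSTRING`; real inputs would normally come from `PixtralImageProcessor`/`AutoProcessor`, and the full image-plus-text pipeline is covered by the Usage section of the model doc.

```python
# Minimal sketch (assumption: not taken from this patch). It instantiates the
# renamed PixtralVisionModel from a default PixtralVisionConfig with random
# weights and runs a single forward pass on one dummy RGB image.
import torch

from transformers import PixtralVisionConfig, PixtralVisionModel

config = PixtralVisionConfig()          # default Pixtral vision-encoder settings
model = PixtralVisionModel(config).eval()

# One dummy image; height and width are kept multiples of config.patch_size.
# Real pixel values should come from PixtralImageProcessor / AutoProcessor.
pixel_values = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    outputs = model(pixel_values)

# Patch-level hidden states for the image (one embedding per image patch).
print(outputs.last_hidden_state.shape)
```

The integration test at the end of the diff drives the same class through `from_pretrained` with the `hf-internal-testing/pixtral-12b` checkpoint instead of a fresh config.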