Improve docs, rename model
NielsRogge committed Sep 14, 2024
1 parent 8bd2b1e commit 628331e
Showing 9 changed files with 44 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.md
@@ -253,7 +253,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Phi3](model_doc/phi3) ||||
| [PhoBERT](model_doc/phobert) ||||
| [Pix2Struct](model_doc/pix2struct) ||||
- | [Pixtral](model_doc/pixtral) | |||
+ | [Pixtral](model_doc/pixtral) | |||
| [PLBart](model_doc/plbart) ||||
| [PoolFormer](model_doc/poolformer) ||||
| [Pop2Piano](model_doc/pop2piano) ||||
16 changes: 9 additions & 7 deletions docs/source/en/model_doc/pixtral.md
@@ -18,20 +18,22 @@ rendered properly in your Markdown viewer.

## Overview

- The Pixtral model was released by the Mistral AI team on [Vllm](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!
+ The Pixtral model was released by the Mistral AI team on [vLLM](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!

Tips:

- - Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
- - This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
+ - Pixtral is a multimodal model, taking images and text as input, and producing text as output.
+ - This model follows the [Llava](llava) family, meaning image embeddings are placed instead of the `[IMG]` token placeholders. The model uses [`PixtralVisionModel`] for its vision encoder.
+ - The main contribution is the 2D ROPE (rotary position embeddings) on the images, and support for arbitrary image sizes (the images are not padded together nor are they resized).
- The format for one or multiple prompts is the following:
```
"<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
```
Then, the processor will replace each `[IMG]` token with a number of `[IMG]` tokens that depends on the height and the width of the image. Each *row* of the image is separated by an `[IMG_BREAK]` token, and each image is separated by an `[IMG_END]` token.
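A minimal sketch of that expansion is below; the checkpoint id and image URL are assumptions for illustration, not part of this commit:

```python
# Sketch of the [IMG] expansion described above. The checkpoint id and image
# URL are assumptions; any Pixtral processor checkpoint should behave the same.
from io import BytesIO

import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")  # assumed checkpoint

prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"  # assumed example image
image = Image.open(BytesIO(requests.get(url).content))

inputs = processor(text=prompt, images=image, return_tensors="pt")
# The single [IMG] placeholder is now expanded to one token per image patch,
# with [IMG_BREAK] closing each row of patches and [IMG_END] closing the image.
print(inputs["input_ids"].shape)
```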

- This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ)
+ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ). The original code can be found [here](https://github.com/vllm-project/vllm/pull/8377).

## Usage

Here is an example of how to run it:
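The example itself is collapsed in this diff view; a minimal sketch of such a generation call, assuming the Llava-style wrapper class, a public checkpoint id, and placeholder image URLs (none of which are confirmed by this diff), might look like:

```python
# A minimal generation sketch. The wrapper class, checkpoint id, and image
# URLs are assumptions based on the Llava family noted in the tips above.
from io import BytesIO

import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "mistral-community/pixtral-12b"  # assumed checkpoint
model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

prompt = "<s>[INST]Describe the images.\n[IMG][IMG][/INST]"
urls = [
    "https://picsum.photos/id/237/400/300",  # assumed example images
    "https://picsum.photos/seed/picsum/500/500",
]
images = [Image.open(BytesIO(requests.get(u).content)) for u in urls]

inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_new_tokens=100)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```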

@@ -83,9 +85,9 @@ Each image captures a different scene, from a close-up of a dog to expansive nat

[[autodoc]] PixtralVisionConfig

- ## PixtralModel
+ ## PixtralVisionModel

- [[autodoc]] PixtralModel
+ [[autodoc]] PixtralVisionModel
- forward

## PixtralImageProcessor
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -2978,7 +2978,7 @@
"Pix2StructVisionModel",
]
)
_import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"])
_import_structure["models.pixtral"].extend(["PixtralVisionModel", "PixtralPreTrainedModel"])
_import_structure["models.plbart"].extend(
[
"PLBartForCausalLM",
@@ -7456,8 +7456,8 @@
Pix2StructVisionModel,
)
from .models.pixtral import (
-    PixtralModel,
    PixtralPreTrainedModel,
+    PixtralVisionModel,
)
from .models.plbart import (
PLBartForCausalLM,
2 changes: 1 addition & 1 deletion src/transformers/models/auto/modeling_auto.py
@@ -193,7 +193,7 @@
("persimmon", "PersimmonModel"),
("phi", "PhiModel"),
("phi3", "Phi3Model"),
("pixtral", "PixtralModel"),
("pixtral", "PixtralVisionModel"),
("plbart", "PLBartModel"),
("poolformer", "PoolFormerModel"),
("prophetnet", "ProphetNetModel"),
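One practical consequence of this mapping change: `AutoModel` now resolves a Pixtral config to the vision encoder. A small sketch, assuming `PixtralVisionConfig` is registered under the `pixtral` model type as the diff suggests:

```python
# After this mapping change, AutoModel resolves "pixtral" to PixtralVisionModel.
from transformers import AutoModel, PixtralVisionConfig

config = PixtralVisionConfig()  # default vision-encoder configuration
model = AutoModel.from_config(config)
print(type(model).__name__)  # expected: PixtralVisionModel
```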
4 changes: 2 additions & 2 deletions src/transformers/models/pixtral/__init__.py
@@ -29,7 +29,7 @@
pass
else:
_import_structure["modeling_pixtral"] = [
"PixtralModel",
"PixtralVisionModel",
"PixtralPreTrainedModel",
]

@@ -52,8 +52,8 @@
pass
else:
from .modeling_pixtral import (
-    PixtralModel,
    PixtralPreTrainedModel,
+    PixtralVisionModel,
)

try:
6 changes: 3 additions & 3 deletions src/transformers/models/pixtral/configuration_pixtral.py
@@ -22,7 +22,7 @@

class PixtralVisionConfig(PretrainedConfig):
r"""
- This is the configuration class to store the configuration of a [`PixtralModel`]. It is used to instantiate an
+ This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate a
Pixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Pixtral-9B.
@@ -58,13 +58,13 @@ class PixtralVisionConfig(PretrainedConfig):
Example:
```python
>>> from transformers import PixtralModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig
>>> from transformers import PixtralVisionModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig
>>> # Initializing a Pixtral 12B style configuration
>>> config = PixtralVisionConfig()
>>> # Initializing a model from the pixtral 12B style configuration
>>> model = PixtralModel(configuration)
>>> model = PixtralVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
38 changes: 14 additions & 24 deletions src/transformers/models/pixtral/modeling_pixtral.py
@@ -1,5 +1,5 @@
# coding=utf-8
- # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+ # Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -48,15 +48,13 @@ def position_ids_in_meshgrid(patch_embeds_list, max_width):
class PixtralRotaryEmbedding(nn.Module):
    """
    The key with pixtral embedding is just that you have a frequency for each pixel position.
-    If you have height x width pixels (or embedding pixels)
-    then the frequency used for ROPE is given by indexing the pre_computed frequency on the
-    width and height.
+    If you have height x width pixels (or embedding pixels), then the frequency used for ROPE
+    is given by indexing the pre_computed frequency on the width and height.

-    What you output is of dimension batch, height * width, dim with dim the embed dim.
+    What you output is of dimension (batch, height * width, dim) with dim the embed dim.

-    This simply means that for each image hidden states, you are going to add
-    a corresponding positional embedding, based on it's index in the grid.
+    This simply means that for each image hidden state, you are going to add
+    a corresponding positional embedding, based on its index in the grid.
    """

def __init__(self, config, device):
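For intuition, here is a tiny self-contained sketch of the indexing scheme the docstring describes; the dimensions and the even/odd frequency split are assumptions for illustration, not the module's exact code:

```python
import torch

# Illustrative 2D RoPE frequency lookup: each patch at grid position (i, j)
# picks its frequencies from precomputed height and width tables.
dim, max_size, theta = 8, 16, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)

h, w = 3, 4  # a 3 x 4 grid of patches
freqs_h = torch.outer(torch.arange(max_size), inv_freq[::2])   # height table
freqs_w = torch.outer(torch.arange(max_size), inv_freq[1::2])  # width table

# meshgrid position ids: patch (i, j) gets height index i and width index j
ids_h, ids_w = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
freqs = torch.cat([freqs_h[ids_h.flatten()], freqs_w[ids_w.flatten()]], dim=-1)
print(freqs.shape)  # torch.Size([12, 4]) -> one (dim // 2)-vector per patch
```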
@@ -319,9 +317,7 @@ def forward(
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-    Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
-    This is useful if you want more control over how to convert `input_ids` indices into associated vectors
-    than the model's internal embedding lookup matrix.
+    Embeddings which serve as input to the Transformer.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -392,17 +388,13 @@ def forward(
and behavior.
Parameters:
-    config ([`PixtralVisionConfig`] or [`PixtralVisionConfig`]):
-        Model configuration class with all the parameters of the model. Initializing with a config file does not
+    config ([`PixtralVisionConfig`]):
+        Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


- @add_start_docstrings(
-     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-     PIXTRAL_START_DOCSTRING,
- )
class PixtralPreTrainedModel(PreTrainedModel):
config_class = PixtralVisionConfig
base_model_prefix = "model"
@@ -412,9 +404,6 @@ class PixtralPreTrainedModel(PreTrainedModel):
_supports_cache_class = True

def _init_weights(self, module):
-        # important: this ported version of Pixtral isn't meant for training from scratch - only
-        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
-        # https://github.com/haotian-liu/LLaVA/tree/main/pixtral should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
@@ -433,8 +422,9 @@ def _init_weights(self, module):

PIXTRAL_INPUTS_DOCSTRING = r"""
Args:
-    pixel_values: list of N_img images of variable sizes,
-        each of shape (C, H, W)
+    pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`]
+        for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@@ -463,10 +453,10 @@ def generate_block_attention_mask(patch_embeds_list, tensor):


@add_start_docstrings(
-    """The PIXTRAL model which consists of a vision backbone and a language model.""",
+    "The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.",
    PIXTRAL_START_DOCSTRING,
)
- class PixtralModel(PixtralPreTrainedModel):
+ class PixtralVisionModel(PixtralPreTrainedModel):
base_model_prefix = "vision_encoder"

def __init__(self, config):
4 changes: 2 additions & 2 deletions src/transformers/utils/dummy_pt_objects.py
@@ -7067,14 +7067,14 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


- class PixtralModel(metaclass=DummyObject):
+ class PixtralPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


- class PixtralPreTrainedModel(metaclass=DummyObject):
+ class PixtralVisionModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
20 changes: 10 additions & 10 deletions tests/models/pixtral/test_modeling_pixtral.py
@@ -21,8 +21,8 @@

from transformers import (
AutoProcessor,
-    PixtralModel,
    PixtralVisionConfig,
+    PixtralVisionModel,
is_torch_available,
is_vision_available,
)
@@ -46,7 +46,7 @@
from PIL import Image


- class PixtralModelTester:
+ class PixtralVisionModelTester:
def __init__(
self,
parent,
@@ -107,7 +107,7 @@ def get_config(self):
)

def create_and_check_model(self, config, pixel_values):
-        model = PixtralModel(config=config)
+        model = PixtralVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
@@ -120,7 +120,7 @@ def create_and_check_model(self, config, pixel_values):
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

def create_and_check_model_with_projection(self, config, pixel_values):
-        model = PixtralModel(config=config)
+        model = PixtralVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
Expand All @@ -140,17 +140,17 @@ def prepare_config_and_inputs_for_common(self):


@require_torch
- class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
+ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
    """
-    Model tester for `PixtralModel`.
+    Model tester for `PixtralVisionModel`.
    """

-    all_model_classes = (PixtralModel,) if is_torch_available() else ()
+    all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False

def setUp(self):
-        self.model_tester = PixtralModelTester(self)
+        self.model_tester = PixtralVisionModelTester(self)
self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)

@unittest.skip("model does not support input embeds")
@@ -261,7 +261,7 @@ def test_determinism(self):


@require_torch
- class PixtralModelIntegrationTest(unittest.TestCase):
+ class PixtralVisionModelIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")

@@ -273,7 +273,7 @@ def tearDown(self):
@require_bitsandbytes
def test_small_model_integration_test(self):
# Let's make sure we test the preprocessing to replace what is used
-        model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
+        model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)

prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
