From 1926065263d2612ba3aad1ab1e10fee079a00beb Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 29 Sep 2024 19:44:44 +0200 Subject: [PATCH] Add support for Pixtral-12B (#67) * add pixtral (working example) * convert to mlx * add prompt utils * add kwargs to all models and formatting * formatting * fix phi3v processor * refactor cache and mask * fix pixel val loading * formatting --- mlx_vlm/models/idefics2/idefics2.py | 7 +- mlx_vlm/models/idefics2/language.py | 11 +- mlx_vlm/models/llava/language.py | 11 +- mlx_vlm/models/llava/llava.py | 7 +- mlx_vlm/models/llava_bunny/language.py | 7 +- mlx_vlm/models/llava_bunny/llava_bunny.py | 1 + mlx_vlm/models/llava_next/language.py | 11 +- mlx_vlm/models/llava_next/llava_next.py | 7 +- mlx_vlm/models/multi_modality/language.py | 11 +- .../models/multi_modality/multi_modality.py | 7 +- mlx_vlm/models/paligemma/language.py | 12 +- mlx_vlm/models/phi3_v/phi3_v.py | 16 +- mlx_vlm/models/pixtral/__init__.py | 8 + mlx_vlm/models/pixtral/language.py | 220 ++++++++++++ mlx_vlm/models/pixtral/pixtral.py | 193 +++++++++++ mlx_vlm/models/pixtral/vision.py | 324 ++++++++++++++++++ mlx_vlm/models/qwen2_vl/vision.py | 4 +- mlx_vlm/prompt_utils.py | 9 +- mlx_vlm/utils.py | 34 +- 19 files changed, 842 insertions(+), 58 deletions(-) create mode 100644 mlx_vlm/models/pixtral/__init__.py create mode 100644 mlx_vlm/models/pixtral/language.py create mode 100644 mlx_vlm/models/pixtral/pixtral.py create mode 100644 mlx_vlm/models/pixtral/vision.py diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py index 1703aff..1c78365 100644 --- a/mlx_vlm/models/idefics2/idefics2.py +++ b/mlx_vlm/models/idefics2/idefics2.py @@ -251,7 +251,12 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id return mx.concatenate(final_embeddings, axis=1) def __call__( - self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None + self, + input_ids: mx.array, + pixel_values: mx.array, + mask: mx.array, + cache=None, + **kwargs, ): input_embeddings = self.get_input_embeddings(input_ids, pixel_values) logits = self.language_model( diff --git a/mlx_vlm/models/idefics2/language.py b/mlx_vlm/models/idefics2/language.py index ff67afe..66f1bb8 100644 --- a/mlx_vlm/models/idefics2/language.py +++ b/mlx_vlm/models/idefics2/language.py @@ -6,6 +6,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import KVCache, create_attention_mask + @dataclass class TextConfig: @@ -62,7 +64,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -116,7 +118,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -153,10 +155,7 @@ def __call__( else: h = inputs_embeds - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) + mask = create_attention_mask(h) if cache is None: cache = [None] * len(self.layers) diff --git a/mlx_vlm/models/llava/language.py b/mlx_vlm/models/llava/language.py index 8492002..732b636 100644 --- a/mlx_vlm/models/llava/language.py +++ b/mlx_vlm/models/llava/language.py @@ -5,6 +5,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import KVCache, create_attention_mask + @dataclass class TextConfig: @@ -78,7 +80,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -132,7 +134,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -166,10 +168,7 @@ def __call__( else: h = inputs_embeds - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) + mask = create_attention_mask(h) if cache is None: cache = [None] * len(self.layers) diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py index 7a7209a..39aae4a 100644 --- a/mlx_vlm/models/llava/llava.py +++ b/mlx_vlm/models/llava/llava.py @@ -132,7 +132,12 @@ def _merge_input_ids_with_image_features( return mx.concatenate(final_embeddings, axis=1) def __call__( - self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None + self, + input_ids: mx.array, + pixel_values: mx.array, + mask: mx.array, + cache=None, + **kwargs, ): input_embddings = self.get_input_embeddings(input_ids, pixel_values) logits = self.language_model( diff --git a/mlx_vlm/models/llava_bunny/language.py b/mlx_vlm/models/llava_bunny/language.py index e4010da..153a650 100644 --- a/mlx_vlm/models/llava_bunny/language.py +++ b/mlx_vlm/models/llava_bunny/language.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from ..base import KVCache +from ..base import KVCache, create_attention_mask @dataclass @@ -174,10 +174,7 @@ def __call__( else: h = inputs_embeds - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) + mask = create_attention_mask(h) if cache is None: cache = [None] * len(self.layers) diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index cf210f9..0e145a2 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -183,6 +183,7 @@ def __call__( pixel_values: mx.array, mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, + **kwargs, ): input_embeddings = self.get_input_embeddings(input_ids, pixel_values) logits = self.language_model( diff --git a/mlx_vlm/models/llava_next/language.py b/mlx_vlm/models/llava_next/language.py index a374a4e..497e3b3 100644 --- a/mlx_vlm/models/llava_next/language.py +++ b/mlx_vlm/models/llava_next/language.py @@ -5,6 +5,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import KVCache, create_attention_mask + @dataclass class TextConfig: @@ -78,7 +80,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -132,7 +134,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -166,10 +168,7 @@ def __call__( else: h = inputs_embeds - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) + mask = create_attention_mask(h) if cache is None: cache = [None] * len(self.layers) diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py index dcc6018..878d7ca 100644 --- a/mlx_vlm/models/llava_next/llava_next.py +++ b/mlx_vlm/models/llava_next/llava_next.py @@ -136,7 +136,12 @@ def _merge_input_ids_with_image_features( return mx.concatenate(final_embeddings, axis=1) def __call__( - self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None + self, + input_ids: mx.array, + pixel_values: mx.array, + mask: mx.array, + cache=None, + **kwargs, ): input_embddings = self.get_input_embeddings(input_ids, pixel_values) diff --git a/mlx_vlm/models/multi_modality/language.py b/mlx_vlm/models/multi_modality/language.py index e598bf8..22a85d8 100644 --- a/mlx_vlm/models/multi_modality/language.py +++ b/mlx_vlm/models/multi_modality/language.py @@ -5,6 +5,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import KVCache, create_attention_mask + @dataclass class TextConfig: @@ -78,7 +80,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -132,7 +134,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -166,10 +168,7 @@ def __call__( else: h = inputs_embeds - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) + mask = create_attention_mask(h) if cache is None: cache = [None] * len(self.layers) diff --git a/mlx_vlm/models/multi_modality/multi_modality.py b/mlx_vlm/models/multi_modality/multi_modality.py index de86be8..52a0bc9 100644 --- a/mlx_vlm/models/multi_modality/multi_modality.py +++ b/mlx_vlm/models/multi_modality/multi_modality.py @@ -360,7 +360,12 @@ def _merge_input_ids_with_image_features( return mx.concatenate(final_embeddings, axis=1) def __call__( - self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None + self, + input_ids: mx.array, + pixel_values: mx.array, + mask: mx.array, + cache=None, + **kwargs, ): input_embeddings = self.get_input_embeddings(input_ids, pixel_values) diff --git a/mlx_vlm/models/paligemma/language.py b/mlx_vlm/models/paligemma/language.py index 6c8c57e..32cdecd 100644 --- a/mlx_vlm/models/paligemma/language.py +++ b/mlx_vlm/models/paligemma/language.py @@ -5,6 +5,8 @@ import mlx.core as mx import mlx.nn as nn +from ..base import KVCache, create_attention_mask + @dataclass class TextConfig: @@ -66,7 +68,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -120,7 +122,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -155,11 +157,9 @@ def __call__( else: h = inputs_embeds - h = h * (self.config.hidden_size**0.5) + h *= self.config.hidden_size**0.5 - if cache is not None: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) + mask = create_attention_mask(h) if cache is None: cache = [None] * len(self.layers) diff --git a/mlx_vlm/models/phi3_v/phi3_v.py b/mlx_vlm/models/phi3_v/phi3_v.py index 07fa579..770f001 100644 --- a/mlx_vlm/models/phi3_v/phi3_v.py +++ b/mlx_vlm/models/phi3_v/phi3_v.py @@ -8,6 +8,7 @@ import mlx.nn as nn import numpy as np +from ..base import KVCache, create_attention_mask from .language import LanguageModel, TextConfig from .su_rope import Phi3SuScaledRotaryEmbedding from .vision import VisionConfig, VisionModel @@ -90,7 +91,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -148,7 +149,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -179,16 +180,18 @@ def __call__( ): h = self.embed_tokens(inputs) p = np.argwhere(inputs < 0).tolist() + if pixel_values is not None: h = self.vision_embed_tokens(pixel_values, h, image_sizes, p) - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) + + mask = create_attention_mask(h) + if cache is None: cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): h = layer(h, mask, c) + return self.norm(h) @@ -206,6 +209,7 @@ def __call__( pixel_values=None, mask=None, cache=None, + **kwargs, ): out = self.model(inputs, pixel_values, mask, cache) return self.lm_head(out).astype(self.lm_head.weight.dtype) diff --git a/mlx_vlm/models/pixtral/__init__.py b/mlx_vlm/models/pixtral/__init__.py new file mode 100644 index 0000000..41a9815 --- /dev/null +++ b/mlx_vlm/models/pixtral/__init__.py @@ -0,0 +1,8 @@ +from .pixtral import ( + LanguageModel, + Model, + ModelConfig, + TextConfig, + VisionConfig, + VisionModel, +) diff --git a/mlx_vlm/models/pixtral/language.py b/mlx_vlm/models/pixtral/language.py new file mode 100644 index 0000000..da8482c --- /dev/null +++ b/mlx_vlm/models/pixtral/language.py @@ -0,0 +1,220 @@ +import inspect +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from ..base import KVCache, create_attention_mask + + +@dataclass +class TextConfig: + model_type: str + hidden_size: int = 5120 + head_dim: int = 128 + num_hidden_layers: int = 40 + intermediate_size: int = 14336 + num_attention_heads: int = 32 + rms_norm_eps: float = 1e-06 + vocab_size: int = 131072 + num_key_value_heads: int = 8 + rope_theta: float = 1000000000.0 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +class Attention(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + + dim = config.hidden_size + self.n_heads = n_heads = config.num_attention_heads + self.n_kv_heads = n_kv_heads = config.num_key_value_heads + + head_dim = config.head_dim + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + rope_scale = ( + 1 / config.rope_scaling["factor"] + if config.rope_scaling is not None + and config.rope_scaling["type"] == "linear" + else 1 + ) + self.rope = nn.RoPE( + head_dim, + traditional=config.rope_traditional, + base=config.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.self_attn = Attention(config) + self.mlp = MLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.config = config + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class Mistral(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [ + TransformerBlock(config=config) for _ in range(config.num_hidden_layers) + ] + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + # for passing merged input embeddings + if inputs_embeds is None: + h = self.embed_tokens(inputs) + else: + h = inputs_embeds + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class LanguageModel(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.config = config + self.model_type = config.model_type + if self.model_type != "mistral": + raise ValueError( + f"Model type {self.model_type} not supported. Currently only 'mistral' is supported" + ) + self.model = Mistral(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + mask: Optional[mx.array] = None, + ): + out = self.model(inputs, cache, inputs_embeds) + return self.lm_head(out) + + @staticmethod + def sanitize(weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return self.config.head_dim + + @property + def n_kv_heads(self): + return self.config.num_key_value_heads diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py new file mode 100644 index 0000000..b49397b --- /dev/null +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -0,0 +1,193 @@ +import glob +import inspect +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from huggingface_hub import snapshot_download + +from .language import LanguageModel, TextConfig +from .vision import VisionConfig, VisionModel + + +@dataclass +class ModelConfig: + text_config: TextConfig + vision_config: VisionConfig + model_type: str + ignore_index: int = -100 + image_token_index: int = 10 + vision_feature_select_strategy: str = "full" + vision_feature_layer: int = -1 + vocab_size: int = 32000 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=True + ) + self.gelu = nn.GELU() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=True + ) + + def __call__(self, x: mx.array) -> mx.array: + x = self.linear_1(x) + x = self.gelu(x) + x = self.linear_2(x) + return x + + +class Model(nn.Module): + def __init__(self, config: ModelConfig): + self.config = config + self.vision_tower = VisionModel(config.vision_config) + self.language_model = LanguageModel(config.text_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy + + def get_input_embeddings( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + ): + if pixel_values is None: + return self.language_model(input_ids) + + # Get the input embeddings from the language model + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + # Get the ouptut hidden states from the vision model + *_, hidden_states = self.vision_tower( + pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + ) + + # Select the hidden states from the desired layer + selected_image_feature = hidden_states[self.vision_feature_layer] + + if self.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + "Unexpected feature selection strategy: " + f"{self.vision_feature_select_strategy}" + ) + + # Pass image features through the multi-modal projector + image_features = self.multi_modal_projector(selected_image_feature) + + # Insert special image tokens in the input_ids + final_inputs_embeds = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids + ) + return final_inputs_embeds + + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids + ): + image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + + # Positions of tokens in input_ids, assuming batch size is 1 + image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + + text_segments = [] + start_idx = 0 + + for position in image_positions: + text_segments.append(inputs_embeds[:, start_idx:position]) + start_idx = position + 1 + + image_embeddings = mx.split(image_features, image_features.shape[0]) + final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] + final_embeddings += [inputs_embeds[:, start_idx:]] + + # Create a final embedding of shape + # (1, num_image_patches*num_images + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=1) + + def __call__( + self, + input_ids: mx.array, + pixel_values: mx.array, + mask: mx.array, + cache=None, + **kwargs, + ): + input_embddings = self.get_input_embeddings(input_ids, pixel_values) + logits = self.language_model( + input_ids, cache=cache, inputs_embeds=input_embddings + ) + return logits + + @staticmethod + def from_pretrained(path_or_hf_repo: str): + path = Path(path_or_hf_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], + ) + ) + + with open(path / "config.json", "r") as f: + model_config = json.load(f) + + model_config = ModelConfig.from_dict(model_config) + + model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) + model_config.text_config = TextConfig.from_dict(model_config.text_config) + + model = Model(model_config) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors found in {path}") + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + weights = VisionModel.sanitize(weights) + weights = LanguageModel.sanitize(weights) + + model.load_weights(list(weights.items())) + return model + + def sanitize(self, weights): + def transform_key(key): + if "vision_tower" in key: + if "transformer" in key: + key = key.replace("vision_tower", "vision_tower.vision_model") + if "patch_conv" in key: + key = key.replace("vision_tower", "vision_tower.vision_model") + if "ln_pre" in key: + key = key.replace("vision_tower", "vision_tower.vision_model") + return key + + return {transform_key(k): v for k, v in weights.items()} diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py new file mode 100644 index 0000000..2db77f4 --- /dev/null +++ b/mlx_vlm/models/pixtral/vision.py @@ -0,0 +1,324 @@ +import inspect +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class VisionConfig: + model_type: str + num_hidden_layers: int = 24 + hidden_size: int = 1024 + head_dim: int = 64 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +def check_array_shape(arr): + shape = arr.shape + + # Check if the shape has 4 dimensions + if len(shape) != 4: + return False + + out_channels, kH, KW, _ = shape + + # Check if out_channels is the largest, and kH and KW are the same + if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): + return True + else: + return False + + +def position_ids_in_meshgrid(patch_embeds_list, max_width): + positions = [] + for patch in patch_embeds_list: + height, width = patch.shape[1], patch.shape[2] + h_grid, v_grid = mx.meshgrid(mx.arange(height), mx.arange(width), indexing="ij") + h_grid = h_grid.reshape(-1, 1) + v_grid = v_grid.reshape(-1, 1) + ids = h_grid * max_width + v_grid + positions.append(ids.flatten()) + return mx.concatenate(positions) + + +def generate_block_attention_mask(patch_embeds_list, tensor): + seq_len = tensor.shape[1] + d_min = -1e9 # Using a large negative value as MLX doesn't have finfo + + causal_mask = mx.full((seq_len, seq_len), vals=d_min) + + block_end_idx = mx.cumsum(mx.array(patch_embeds_list)) + block_start_idx = mx.concatenate([mx.array([0]), mx.array(patch_embeds_list[:-1])]) + block_start_idx = mx.cumsum(block_start_idx) + + for start, end in zip(block_start_idx, block_end_idx): + start, end = int(start), int(end) # Convert to integers for indexing + causal_mask[start:end, start:end] = 0 + + causal_mask = mx.broadcast_to( + causal_mask[None, None, :, :], (tensor.shape[0], 1, seq_len, seq_len) + ) + return causal_mask + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mx.concatenate((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + cos = mx.expand_dims(cos, axis=unsqueeze_dim) + sin = mx.expand_dims(sin, axis=unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Attention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError( + "The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0" + ) + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.embed_dim = dims + self.num_heads = num_heads + self.head_dim = self.embed_dim // self.num_heads + + self.scale = self.head_dim**-0.5 + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + + def __call__(self, queries, keys, values, position_embeddings, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + cos, sin = position_embeddings + queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, unsqueeze_dim=0) + + attn_weights = mx.matmul(queries, keys.transpose(0, 1, 3, 2)) * self.scale + + if mask is not None: + attn_weights = attn_weights + mask + + attn_weights = mx.softmax(attn_weights, axis=-1) + output = mx.matmul(attn_weights, values) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + dim = config.hidden_size + hidden_dim = config.intermediate_size + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class EncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.attention = Attention( + config.hidden_size, config.num_attention_heads, bias=True + ) + self.attention_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + self.feed_forward = MLP(config) + self.ffn_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + + def __call__( + self, + x: mx.array, + position_embeddings: mx.array, + mask: Optional[mx.array] = None, + ) -> mx.array: + y = self.attention_norm(x) + y = self.attention(y, y, y, position_embeddings, mask) + x = x + y + y = self.ffn_norm(x) + y = self.feed_forward(y) + return x + y + + +class Encoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] + + +class PixtralRotaryEmbedding: + def __init__(self, config): + self.dim = config.head_dim + self.base = config.rope_theta + max_patches_per_side = config.image_size // config.patch_size + freqs = 1.0 / ( + self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim) + ) + + h = mx.arange(max_patches_per_side) + w = mx.arange(max_patches_per_side) + + freqs_h = mx.outer(h, freqs[::2]).astype(mx.float32) + freqs_w = mx.outer(w, freqs[1::2]).astype(mx.float32) + inv_freq = mx.concatenate( + [ + mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)), + mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1)), + ], + axis=-1, + ).reshape(-1, self.dim // 2) + + self.inv_freq = mx.concatenate((inv_freq, inv_freq), axis=-1) + + def __call__(self, x, position_ids): + freqs = self.inv_freq[position_ids] + emb = freqs + cos = mx.cos(emb) + sin = mx.sin(emb) + return cos.astype(x.dtype), sin.astype(x.dtype) + + +class PixtralVisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.config = config + self.patch_conv = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + self.ln_pre = nn.RMSNorm(config.hidden_size) + self.transformer = Encoder(config) + self.patch_positional_embedding = PixtralRotaryEmbedding(config) + + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + B, H, W, C = x.shape + patch_embeds_list = [self.patch_conv(img[None, :]) for img in x] + + patch_embeds = mx.concatenate( + [p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1 + ) + + patch_embeds = self.ln_pre(patch_embeds) + + position_ids = position_ids_in_meshgrid( + patch_embeds_list, + max_width=self.config.image_size // self.config.patch_size, + ) + + position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) + + mask = generate_block_attention_mask( + [p.shape[2] * p.shape[1] for p in patch_embeds_list], patch_embeds + ) + + encoder_states = (patch_embeds,) if output_hidden_states else None + + for l in self.transformer.layers: + patch_embeds = l( + patch_embeds, mask=mask, position_embeddings=position_embedding + ) + if output_hidden_states: + encoder_states = encoder_states + (patch_embeds,) + + return patch_embeds, encoder_states + + +class VisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + + self.model_type = config.model_type + if self.model_type not in ["clip_vision_model", "pixtral"]: + raise ValueError(f"Unsupported model type: {self.model_type}") + + self.vision_model = PixtralVisionModel(config) + + def __call__( + self, x: mx.array, output_hidden_states: Optional[bool] = None + ) -> mx.array: + return self.vision_model(x, output_hidden_states) + + def sanitize(self, weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + # Remove unused position_ids + continue + elif "patch_conv.weight" in k: + # PyTorch conv2d weight tensors have shape: + # [out_channels, in_channels, kH, KW] + # MLX conv2d expects the weight be of shape: + # [out_channels, kH, KW, in_channels] + if check_array_shape(v): + sanitized_weights[k] = v + else: + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index ddc43ec..f07f880 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -310,9 +310,9 @@ def rot_pos_emb(self, grid_thw): max_grid_size = mx.max(grid_thw[:, 1:]) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb_np = rotary_pos_emb_full[pos_ids] + rotary_pos_emb_full = rotary_pos_emb_full[pos_ids] - return rotary_pos_emb_np.reshape(pos_ids.shape[0], -1) + return rotary_pos_emb_full.reshape(pos_ids.shape[0], -1) def __call__( self, diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index e8ce18d..b47a9a1 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -11,13 +11,13 @@ def get_message_json(model_name, prompt): Returns: dict: A dictionary representing the JSON message for the specified model. """ - if model_name.lower() in ["idefics2", "qwen2_vl"]: + if model_name.lower() in ["idefics2", "qwen2_vl", "llava"]: message = { "role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}], } - elif model_name.lower() in ["llava-qwen2", "llava", "llava_next", "bunny-llama"]: + elif model_name.lower() in ["llava-qwen2", "llava_next", "bunny-llama"]: message = {"role": "user", "content": f"\n{prompt}"} elif model_name.lower() == "phi3_v": message = {"role": "user", "content": f"<|image_1|>\n{prompt}"} @@ -25,6 +25,11 @@ def get_message_json(model_name, prompt): message = {"role": "user", "content": f"{prompt}"} elif model_name.lower() == "paligemma": message = prompt + elif model_name.lower() == "pixtral": + message = { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "content": prompt}], + } else: raise ValueError(f"Unsupported model: {model_name}") diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 7adb01a..9545311 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -711,21 +711,37 @@ def prepare_inputs(image_processor, processor, image, prompt, image_token_index) if isinstance(image, str): image = load_image(image) + image_grid_thw = None if image_processor is not None: text_chunks = [processor(chunk).input_ids for chunk in prompt.split("")] input_ids = mx.array([text_chunks[0] + [image_token_index] + text_chunks[1]]) - pixel_values = image_processor.preprocess(images=[image])[0] - pixel_values = mx.array(np.expand_dims(pixel_values, axis=0)) + pixel_values = mx.array(image_processor.preprocess(images=[image])[0]) + pixel_values = mx.array(mx.expand_dims(pixel_values, axis=0)) else: - inputs = processor( - text=[prompt], images=[image], padding=True, return_tensors="np" - ) - pixel_values = mx.array(inputs["pixel_values"]) + processor.tokenizer.pad_token = processor.tokenizer.eos_token + try: + inputs = processor( + text=[prompt], images=[image], padding=True, return_tensors="mlx" + ) + except Exception as e: + inputs = processor( + text=prompt, images=[image], padding=True, return_tensors="mlx" + ) # for phi3_v model + + if isinstance(inputs["pixel_values"], list): + pixel_values = mx.array(inputs["pixel_values"][0][0])[None, :] + elif isinstance(inputs["pixel_values"], np.ndarray): + pixel_values = mx.array(inputs["pixel_values"]) + else: + raise ValueError( + f"Invalid pixel_values type: {type(inputs['pixel_values'])}" + ) + input_ids = mx.array(inputs["input_ids"]) - mask = mx.array(inputs["attention_mask"]) - if "image_sizes" in inputs: - return input_ids, pixel_values, inputs["image_sizes"] + mask = inputs["attention_mask"] image_grid_thw = inputs.get("image_grid_thw", None) + if "image_sizes" in inputs: + return input_ids, pixel_values, inputs["image_sizes"], image_grid_thw return input_ids, pixel_values, mask, image_grid_thw