diff --git a/mlx_vlm/models/qwen2_vl/language.py b/mlx_vlm/models/qwen2_vl/language.py
index 614bc3d..cebcd22 100644
--- a/mlx_vlm/models/qwen2_vl/language.py
+++ b/mlx_vlm/models/qwen2_vl/language.py
@@ -29,6 +29,14 @@ def __post_init__(self):
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
 
+        if self.rope_scaling:
+            required_keys = {"mrope_section", "type"}
+            if not all(key in self.rope_scaling for key in required_keys):
+                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+            if not self.rope_scaling["type"] == "mrope":
+                raise ValueError(f"rope_scaling type must be 'mrope'")
+
     @classmethod
     def from_dict(cls, params):
         return cls(
@@ -253,7 +261,7 @@ def __call__(
         inputs_embeds: Optional[mx.array] = None,
     ):
         if inputs_embeds is None:
-            h = self.embed_tokens(inputs).astype(mx.float32)
+            h = self.embed_tokens(inputs)
         else:
             h = inputs_embeds
 
diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py
index dcdbbef..8746eee 100644
--- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py
+++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py
@@ -142,13 +142,11 @@ def sanitize(self, weights):
         def transform_key(key):
             if "vision_tower" not in key:
                 key = key.replace("visual", "vision_tower")
-            if "language_model.model" not in key and key.split(".")[0] not in [
-                "lm_head",
-                "language_model",
-            ]:
-                key = key.replace("model", "language_model.model")
-            if "lm_head" in key and key.split(".")[0] == "lm_head":
-                key = key.replace("lm_head", "language_model.lm_head")
+            if "language_model" not in key:
+                if "model" in key:
+                    key = key.replace("model", "language_model.model")
+                elif "lm_head" in key:
+                    key = key.replace("lm_head", "language_model.lm_head")
             return key
 
         return {transform_key(k): v for k, v in weights.items()}
diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py
index db92cb6..56c6366 100644
--- a/mlx_vlm/tests/test_models.py
+++ b/mlx_vlm/tests/test_models.py
@@ -52,6 +52,7 @@ def vision_test_runner(
         num_channels,
         image_size: tuple,
         vision_feature_layer=-2,
+        **kwargs,
     ):
         self.assertEqual(vision_tower.model_type, model_type)
 
@@ -62,10 +63,11 @@ def vision_test_runner(
         )
 
         # Perform a forward pass
-        *_, hidden_states = vision_tower(input_tensor, output_hidden_states=True)
-        # Check the output tensor shape
+        hidden_states = vision_tower(input_tensor, output_hidden_states=True, **kwargs)
+
+        # Check that the vision feature layer's hidden size matches the expected value
         self.assertEqual(
-            hidden_states[vision_feature_layer][-1][-1].shape, (vision_hidden_size,)
+            hidden_states[vision_feature_layer].shape[-1], vision_hidden_size
         )
 
     def test_llava_bunny(self):
@@ -618,6 +620,66 @@ def test_phi3_v(self):
             (config.vision_config.image_size, config.vision_config.image_size),
         )
 
+    def test_qwen2_vl(self):
+        from mlx_vlm.models import qwen2_vl
+
+        text_config = qwen2_vl.TextConfig(
+            model_type="qwen2_vl",
+            hidden_size=32,
+            num_hidden_layers=4,
+            intermediate_size=37,
+            num_attention_heads=4,
+            rms_norm_eps=1e-6,
+            vocab_size=152064,
+            num_key_value_heads=4,
+            max_position_embeddings=512,
+            rope_theta=10000,
+            rope_scaling={"type": "mrope", "mrope_section": [2, 1, 1]},
+            tie_word_embeddings=False,
+        )
+
+        vision_config = qwen2_vl.VisionConfig(
+            model_type="qwen2_vl",
+            depth=2,
+            embed_dim=32,
+            hidden_size=32,
+            image_size=224,
+            num_heads=4,
+            patch_size=14,
+            mlp_ratio=4,
+            in_channels=3,
+            spatial_merge_size=1,
+            temporal_patch_size=2,
+        )
+
+        config = qwen2_vl.ModelConfig(
+            model_type="qwen2_vl",
+            text_config=text_config,
+            vision_config=vision_config,
+            rope_scaling=text_config.rope_scaling,
+            image_token_index=151655,
+            vocab_size=32000,
+        )
+
+        model = qwen2_vl.Model(config)
+
+        self.language_test_runner(
+            model.language_model,
+            config.text_config.model_type,
+            config.text_config.vocab_size,
+            config.text_config.num_hidden_layers,
+        )
+
+        self.vision_test_runner(
+            model.vision_tower,
+            config.vision_config.model_type,
+            config.vision_config.hidden_size,
+            config.vision_config.in_channels,
+            (config.vision_config.image_size, config.vision_config.image_size),
+            vision_feature_layer=-1,
+            grid_thw=mx.ones((1, 3)),  # per-image (t, h, w) grid, shape (num_images, 3)
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
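
The simplified transform_key in the sanitize hunk should be behavior-preserving for the usual checkpoint layouts. Below is a minimal standalone sketch of the same logic, runnable on its own; the sample keys follow the common Qwen2-VL checkpoint naming (visual.*, model.*, lm_head.*) but are illustrative assumptions, not an exhaustive list.

def transform_key(key: str) -> str:
    # Vision weights: visual.* -> vision_tower.*
    if "vision_tower" not in key:
        key = key.replace("visual", "vision_tower")
    # Text weights: prefix with language_model unless already prefixed
    if "language_model" not in key:
        if "model" in key:
            key = key.replace("model", "language_model.model")
        elif "lm_head" in key:
            key = key.replace("lm_head", "language_model.lm_head")
    return key

assert transform_key("visual.blocks.0.attn.qkv.weight") == "vision_tower.blocks.0.attn.qkv.weight"
assert transform_key("model.layers.0.mlp.gate_proj.weight") == "language_model.model.layers.0.mlp.gate_proj.weight"
assert transform_key("lm_head.weight") == "language_model.lm_head.weight"
# Already-prefixed keys pass through unchanged.
assert transform_key("language_model.model.embed_tokens.weight") == "language_model.model.embed_tokens.weight"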
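
The rope_scaling check added to __post_init__ only validates key presence and the type value. A standalone restatement of its behavior (the validate_rope_scaling helper name is hypothetical, for illustration only):

def validate_rope_scaling(rope_scaling):
    # Mirrors the __post_init__ check: rope_scaling is optional, but when
    # given it must carry an mrope_section and declare type "mrope".
    if rope_scaling:
        required_keys = {"mrope_section", "type"}
        if not all(key in rope_scaling for key in required_keys):
            raise ValueError(f"rope_scaling must contain keys {required_keys}")
        if rope_scaling["type"] != "mrope":
            raise ValueError("rope_scaling type must be 'mrope'")

validate_rope_scaling(None)  # fine: rope scaling is optional
validate_rope_scaling({"type": "mrope", "mrope_section": [2, 1, 1]})  # fine
try:
    validate_rope_scaling({"type": "linear", "mrope_section": [2, 1, 1]})
except ValueError as err:
    print(err)  # rope_scaling type must be 'mrope'

Note that the test's mrope_section [2, 1, 1] sums to 4, half of the toy config's head_dim (hidden_size / num_attention_heads = 32 / 4 = 8); the released Qwen2-VL configs appear to follow the same convention, e.g. [16, 24, 24] against head_dim 128.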