Added: Qwen2-VL Unit Tests, Refactored Weight Sanitization (#63)
* fix: extra 'language_model' prefixes in sanitization

* remove .astype(mx.float32) to support float16, add mRoPE config validation

* add qwen2_vl tests, assert vision layer feature dimension against hidden size

---------

Co-authored-by: b.zimring <[email protected]>
benzimring and b.zimring authored Oct 6, 2024
1 parent 1926065 commit d4b562f
Showing 3 changed files with 79 additions and 11 deletions.
mlx_vlm/models/qwen2_vl/language.py (9 additions, 1 deletion)
@@ -29,6 +29,14 @@ def __post_init__(self):
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads

+        if self.rope_scaling:
+            required_keys = {"mrope_section", "type"}
+            if not all(key in self.rope_scaling for key in required_keys):
+                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+            if not self.rope_scaling["type"] == "mrope":
+                raise ValueError(f"rope_scaling type must be 'mrope'")
+
     @classmethod
     def from_dict(cls, params):
         return cls(
@@ -253,7 +261,7 @@ def __call__(
         inputs_embeds: Optional[mx.array] = None,
     ):
         if inputs_embeds is None:
-            h = self.embed_tokens(inputs).astype(mx.float32)
+            h = self.embed_tokens(inputs)
         else:
             h = inputs_embeds

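In practice, the new mRoPE validation means a rope_scaling dict must carry both required keys and declare type "mrope" before a TextConfig will construct. A minimal usage sketch (constructor arguments copied from the test added below; illustrative, not part of this commit):

from mlx_vlm.models import qwen2_vl

base = dict(
    model_type="qwen2_vl",
    hidden_size=32,
    num_hidden_layers=4,
    intermediate_size=37,
    num_attention_heads=4,
    rms_norm_eps=1e-6,
    vocab_size=152064,
    num_key_value_heads=4,
    max_position_embeddings=512,
    rope_theta=10000,
    tie_word_embeddings=False,
)

# Accepted: both "mrope_section" and "type" are present, and type is "mrope".
qwen2_vl.TextConfig(**base, rope_scaling={"type": "mrope", "mrope_section": [2, 1, 1]})

# Rejected by the new __post_init__ check: "mrope_section" is missing.
try:
    qwen2_vl.TextConfig(**base, rope_scaling={"type": "mrope"})
except ValueError as err:
    print(err)  # -> rope_scaling must contain keys {...}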
mlx_vlm/models/qwen2_vl/qwen2_vl.py (5 additions, 7 deletions)
@@ -142,13 +142,11 @@ def sanitize(self, weights):
         def transform_key(key):
             if "vision_tower" not in key:
                 key = key.replace("visual", "vision_tower")
-            if "language_model.model" not in key and key.split(".")[0] not in [
-                "lm_head",
-                "language_model",
-            ]:
-                key = key.replace("model", "language_model.model")
-            if "lm_head" in key and key.split(".")[0] == "lm_head":
-                key = key.replace("lm_head", "language_model.lm_head")
+            if "language_model" not in key:
+                if "model" in key:
+                    key = key.replace("model", "language_model.model")
+                elif "lm_head" in key:
+                    key = key.replace("lm_head", "language_model.lm_head")
             return key

         return {transform_key(k): v for k, v in weights.items()}
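For context, the refactored transform_key leaves any key already namespaced under language_model untouched and only re-prefixes bare model / lm_head keys, which is what prevents the doubled 'language_model' prefixes mentioned in the commit message. A sketch with illustrative checkpoint-style key names (not taken from a real checkpoint):

def transform_key(key):
    if "vision_tower" not in key:
        key = key.replace("visual", "vision_tower")
    if "language_model" not in key:
        if "model" in key:
            key = key.replace("model", "language_model.model")
        elif "lm_head" in key:
            key = key.replace("lm_head", "language_model.lm_head")
    return key

for key in [
    "visual.blocks.0.attn.qkv.weight",         # -> vision_tower.blocks.0.attn.qkv.weight
    "model.layers.0.self_attn.q_proj.weight",  # -> language_model.model.layers.0.self_attn.q_proj.weight
    "lm_head.weight",                          # -> language_model.lm_head.weight
    "language_model.model.norm.weight",        # already prefixed, left unchanged
]:
    print(key, "->", transform_key(key))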
mlx_vlm/tests/test_models.py (65 additions, 3 deletions)
@@ -52,6 +52,7 @@ def vision_test_runner(
         num_channels,
         image_size: tuple,
         vision_feature_layer=-2,
+        **kwargs,
     ):
         self.assertEqual(vision_tower.model_type, model_type)

@@ -62,10 +63,11 @@ def vision_test_runner(
         )

         # Perform a forward pass
-        *_, hidden_states = vision_tower(input_tensor, output_hidden_states=True)
-        # Check the output tensor shape
+        hidden_states = vision_tower(input_tensor, output_hidden_states=True, **kwargs)
+
+        # Check vision hidden feature layer's shape matches the expected hidden size
         self.assertEqual(
-            hidden_states[vision_feature_layer][-1][-1].shape, (vision_hidden_size,)
+            hidden_states[vision_feature_layer].shape[-1], vision_hidden_size
         )

     def test_llava_bunny(self):
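Two effects of this runner change are worth spelling out: extra keyword arguments (such as Qwen2-VL's grid_thw in the test below) are now forwarded to the vision tower, and the assertion compares only the feature dimension of the selected hidden-states layer instead of a full shape tuple. A tiny sketch with made-up shapes:

import mlx.core as mx

vision_hidden_size = 32
# Stand-in for per-layer hidden states returned by a vision tower.
hidden_states = [mx.zeros((256, vision_hidden_size)) for _ in range(3)]

# The updated check: the last dimension of the chosen layer equals the hidden size.
assert hidden_states[-1].shape[-1] == vision_hidden_size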
@@ -618,6 +620,66 @@ def test_phi3_v(self):
             (config.vision_config.image_size, config.vision_config.image_size),
         )

+    def test_qwen2_vl(self):
+        from mlx_vlm.models import qwen2_vl
+
+        text_config = qwen2_vl.TextConfig(
+            model_type="qwen2_vl",
+            hidden_size=32,
+            num_hidden_layers=4,
+            intermediate_size=37,
+            num_attention_heads=4,
+            rms_norm_eps=1e-6,
+            vocab_size=152064,
+            num_key_value_heads=4,
+            max_position_embeddings=512,
+            rope_theta=10000,
+            rope_scaling={"type": "mrope", "mrope_section": [2, 1, 1]},
+            tie_word_embeddings=False,
+        )
+
+        vision_config = qwen2_vl.VisionConfig(
+            model_type="qwen2_vl",
+            depth=2,
+            embed_dim=32,
+            hidden_size=32,
+            image_size=224,
+            num_heads=4,
+            patch_size=14,
+            mlp_ratio=4,
+            in_channels=3,
+            spatial_merge_size=1,
+            temporal_patch_size=2,
+        )
+
+        config = qwen2_vl.ModelConfig(
+            model_type="qwen2_vl",
+            text_config=text_config,
+            vision_config=vision_config,
+            rope_scaling=text_config.rope_scaling,
+            image_token_index=151655,
+            vocab_size=32000,
+        )
+
+        model = qwen2_vl.Model(config)
+
+        self.language_test_runner(
+            model.language_model,
+            config.text_config.model_type,
+            config.text_config.vocab_size,
+            config.text_config.num_hidden_layers,
+        )
+
+        self.vision_test_runner(
+            model.vision_tower,
+            config.vision_config.model_type,
+            config.vision_config.hidden_size,
+            config.vision_config.in_channels,
+            (config.vision_config.image_size, config.vision_config.image_size),
+            vision_feature_layer=-1,
+            grid_thw=mx.ones((1, 3)),  # image temporals shape (num_images, 3)
+        )
+

 if __name__ == "__main__":
     unittest.main()
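To run the new case after pulling this commit, the standard unittest tooling is enough (assuming the package is importable as mlx_vlm); a sketch:

import unittest

# Loads and runs every case in mlx_vlm/tests/test_models.py; from a shell,
#   python -m unittest mlx_vlm.tests.test_models -k test_qwen2_vl
# filters down to just the new Qwen2-VL test.
suite = unittest.defaultTestLoader.loadTestsFromName("mlx_vlm.tests.test_models")
unittest.TextTestRunner(verbosity=2).run(suite)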
