
Fix Gemma 3 VLM detection for dual-registered configs#77

Merged
RobotSail merged 1 commit into Red-Hat-AI-Innovation-Team:main from RobotSail:fix/gemma3-vlm-extraction
Mar 24, 2026

Conversation

@RobotSail
Collaborator

@RobotSail RobotSail commented Mar 20, 2026

Summary

Same fix as instructlab/training#695 but for mini_trainer.

is_vlm_with_causal_lm() returns False for Gemma 3 because Gemma3Config is dual-registered in both MODEL_FOR_CAUSAL_LM_MAPPING and MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING. The function bails out early when the config is in the CausalLM mapping, but AutoModelForCausalLM actually resolves to Gemma3ForConditionalGeneration (a full VLM), not a text-only CausalLM.

The fix checks what class the mapping resolves to. If it's a ForConditionalGeneration VLM, the model is treated as needing backbone extraction.

Note: The extraction itself works correctly, but Gemma3ForCausalLM.__init__ then fails with 'Gemma3TextConfig' object has no attribute 'vision_config' — that's a separate issue in the transformers Gemma 3 model code where the CausalLM class references self.config.vision_config during init, which doesn't exist on Gemma3TextConfig. This PR fixes the detection logic; the transformers bug will need a separate upstream fix.
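The detection change can be sketched as follows. This is an illustrative sketch only: a plain dict and dummy classes stand in for transformers' MODEL_FOR_CAUSAL_LM_MAPPING and the real Gemma 3 classes, so the control flow of the fix is visible without the library.

```python
class Gemma3TextConfig: ...

class Gemma3Config:
    text_config = Gemma3TextConfig()

class Gemma3ForCausalLM: ...
class Gemma3ForConditionalGeneration: ...  # full VLM

# Dual registration: in the CausalLM mapping the top-level config resolves to
# the VLM class, while the nested text config resolves to the text-only model.
CAUSAL_LM_MAPPING = {
    Gemma3Config: Gemma3ForConditionalGeneration,
    Gemma3TextConfig: Gemma3ForCausalLM,
}

def is_vlm_with_causal_lm(config) -> bool:
    """Return True when the config resolves to a VLM needing backbone extraction."""
    cls = CAUSAL_LM_MAPPING.get(type(config))
    if cls is None:
        return False
    # The old logic bailed out here ("it's in the CausalLM mapping, so it's a
    # CausalLM"). The fix inspects what class the mapping actually resolves to.
    if "ForConditionalGeneration" in cls.__name__:
        text_config = getattr(config, "text_config", None)
        return text_config is not None and type(text_config) in CAUSAL_LM_MAPPING
    return False
```

With these stand-ins, `is_vlm_with_causal_lm(Gemma3Config())` is True (dual-registered VLM, backbone extractable) while `is_vlm_with_causal_lm(Gemma3TextConfig())` is False (plain text-only CausalLM).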

Test plan

  • Verified routing for all models:
      • text-only models (Qwen2, Llama, Granite, Mistral) still use AutoModelForCausalLM
      • VLMs with extractable backbones (Gemma 3, Gemma 3n, Ministral, Mistral3-VLM) correctly route to extract_causal_lm
      • VLMs without a CausalLM variant (Qwen3-VL) use direct_vlm_load
  • Gemma 3 backbone extraction succeeds (logs show ✅ Extracted Gemma3ForCausalLM successfully)
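The routing verified above can be sketched as a simple dispatch. The path names follow the PR text (AutoModelForCausalLM, extract_causal_lm, direct_vlm_load), but this dispatch function and the model-type strings are hypothetical, for illustration only.

```python
# Hypothetical summary of the load-path routing the test plan verifies.
TEXT_ONLY = {"qwen2", "llama", "granite", "mistral"}
VLM_WITH_BACKBONE = {"gemma3", "gemma3n", "ministral", "mistral3"}
VLM_DIRECT = {"qwen3_vl"}

def choose_load_path(model_type: str) -> str:
    """Map a model type to the load path the trainer should take."""
    if model_type in TEXT_ONLY:
        return "AutoModelForCausalLM"   # plain text-only CausalLM load
    if model_type in VLM_WITH_BACKBONE:
        return "extract_causal_lm"      # extract the text backbone from the VLM
    if model_type in VLM_DIRECT:
        return "direct_vlm_load"        # no CausalLM variant exists; load the VLM
    raise ValueError(f"unknown model type: {model_type}")
```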

Summary by CodeRabbit

  • Bug Fixes
    • Improved vision–language model detection and model-class resolution to better handle dual-registered configurations, reducing misclassification and improving selection of text-generation components.
    • Liger kernel application now falls back to the model's text-config type when available and otherwise skips with a warning instead of erroring, avoiding hard failures for unsupported types.

@coderabbitai

coderabbitai bot commented Mar 20, 2026

📝 Walkthrough


is_vlm_with_causal_lm and get_model_class_from_config now prefer resolving mapped classes that include "ForConditionalGeneration" by inspecting a nested text_config; setup_model_for_training's Liger kernel application now falls back to text_config.model_type and logs and skips unsupported types instead of raising.

Changes

  • VLM detection logic (src/mini_trainer/vlm_utils.py)
    Refined is_vlm_with_causal_lm(config) to resolve the mapped class and only treat configs as VLM-backed CausalLMs when the resolved class name includes ForConditionalGeneration and the nested config.text_config exists and maps to a causal-LM entry; otherwise it returns False or falls back to the prior text_config check.
  • Model class resolution (src/mini_trainer/utils.py)
    Adjusted get_model_class_from_config() so that, when the initial mapping's resolved class name contains ForConditionalGeneration, it prefers the mapped class for config.text_config if present and mapped; otherwise it returns the resolved top-level mapped class.
  • Liger kernel application (src/mini_trainer/setup_model_for_training.py)
    Changed _apply_liger_kernels_if_requested to try model_config.text_config.model_type as a fallback for finding an apply_fn; if none is found, it logs a rank-0 warning and skips Liger optimization instead of raising an error.
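The Liger fallback in the last change can be sketched as follows. The registry dict here is a stand-in for liger_kernel's MODEL_TYPE_TO_APPLY_LIGER_FN, and the function name and config shapes are assumed for illustration.

```python
import logging
from types import SimpleNamespace

log = logging.getLogger(__name__)

# Stand-in for liger_kernel's per-model-type patch registry
# (MODEL_TYPE_TO_APPLY_LIGER_FN in the real code); the entry is a dummy.
APPLY_LIGER_FN = {"llama": lambda: "patched"}

def apply_liger_if_supported(model_config):
    """Try model_type, then text_config.model_type; warn and skip otherwise."""
    model_type = getattr(model_config, "model_type", None)
    apply_fn = APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        # Extracted VLM backbones may expose the supported type on text_config.
        text_config = getattr(model_config, "text_config", None)
        apply_fn = APPLY_LIGER_FN.get(getattr(text_config, "model_type", None))
    if apply_fn is None:
        # Previously unsupported types raised; now they are skipped gracefully.
        log.warning("Liger kernels do not support model type %r; skipping.", model_type)
        return None
    return apply_fn()
```

For example, a config with `model_type="gemma3"` and `text_config.model_type="llama"` would be patched via the fallback, while an unsupported type like `mistral3` logs a warning and returns None.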

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • Maxusmusti

Poem

🐰 A hop through configs, curious and spry,

I peek at text_config before I try,
ForConditionalGeneration is my clue,
I nudge kernels gently — skip if they stew,
Hooray for care that keeps builds spry! 🎉

🚥 Pre-merge checks (3 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title accurately describes the main change: fixing VLM detection specifically for Gemma 3's dual-registered config. The change across three files addresses exactly this issue by improving detection logic for models registered in multiple mappings.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.


@codecov

codecov bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 4.54545% with 21 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Missing lines
src/mini_trainer/setup_model_for_training.py | 0.00% | 9 ⚠️
src/mini_trainer/utils.py | 0.00% | 6 ⚠️
src/mini_trainer/vlm_utils.py | 14.28% | 6 ⚠️


@RobotSail force-pushed the fix/gemma3-vlm-extraction branch from 3601a80 to 6f2ddbf on March 23, 2026 17:06

@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
src/mini_trainer/utils.py (1)

159-162: Replace class-name heuristic with explicit dual-registration check.

Line 159's reliance on "ForConditionalGeneration" in resolved_cls.__name__ is fragile—class names can change across transformers versions. Use an explicit check against MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING (already used as a fallback) to detect dual-registered configs more reliably.

Proposed refactor
 def get_model_class_from_config(model_path):
     """Get the actual model class (not just the name) from a pretrained path.
@@ -168,11 +168,11 @@ def get_model_class_from_config(model_path):
     config_class = config.__class__
     if config_class in mapping:
         resolved_cls = mapping[config_class]
-        # Some models (e.g. Gemma 3) are dual-registered: the top-level config
-        # maps to a ForConditionalGeneration VLM, not a text-only CausalLM.
-        # In that case, prefer the text_config's CausalLM class instead.
-        if "ForConditionalGeneration" in resolved_cls.__name__:
+        # Dual-registered VLM configs may map to multimodal generation classes;
+        # prefer text backbone CausalLM when available.
+        from transformers.models.auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
+        if config_class in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING:
             text_config = getattr(config, "text_config", None)
             if text_config is not None and text_config.__class__ in mapping:
                 return mapping[text_config.__class__]
         return resolved_cls
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mini_trainer/utils.py` around lines 159 - 162, The code currently detects
dual-registered text+image models by checking if "ForConditionalGeneration" is
in resolved_cls.__name__, which is fragile; instead check whether the config
class is present in the explicit MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING mapping:
inside the branch where resolved_cls and text_config are obtained (symbols:
resolved_cls, text_config, mapping), replace the string-name heuristic with a
membership test like if text_config.__class__ in
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING (or check mapping is
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING) and then return
mapping[text_config.__class__]; keep the existing fallback behavior if that
explicit mapping check fails.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5d01063b-f811-4c72-aa17-9fbdf934e369

📥 Commits

Reviewing files that changed from the base of the PR and between 3601a80 and 6f2ddbf.

📒 Files selected for processing (2)
  • src/mini_trainer/utils.py
  • src/mini_trainer/vlm_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/mini_trainer/vlm_utils.py


Three fixes for models with dual-registered or unsupported VLM configs:

1. vlm_utils.py: is_vlm_with_causal_lm() — when the top-level config is in
   MODEL_FOR_CAUSAL_LM_MAPPING but resolves to a ForConditionalGeneration
   class (a VLM), treat it as needing backbone extraction.

2. utils.py: get_model_class_from_config() — same issue for OSFT path.
   When the CausalLM mapping resolves to a VLM class, fall through to use
   text_config's CausalLM class so OSFT doesn't wrap the VLM.

3. setup_model_for_training.py: _apply_liger_kernels_if_requested() —
   gracefully skip Liger with a warning instead of crashing when the model
   type is not supported (gemma3n, mistral3, qwen3_5). Also tries the
   text_config model_type as fallback for extracted VLM models.

Tested: gemma3, gemma3n, ministral, mistral3-vlm, qwen3.5 all pass SFT
and OSFT (with Liger graceful skip) to loss 0.0000 on overfit dataset.

Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
@RobotSail force-pushed the fix/gemma3-vlm-extraction branch from 6f2ddbf to dc4b28b on March 24, 2026 07:37

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
src/mini_trainer/setup_model_for_training.py (1)

71-89: Well-structured fallback logic for VLM model types.

The text_config fallback aligns with the patterns established in vlm_utils.py and utils.py, and the graceful degradation (warning + skip instead of raising) is appropriate for the Gemma 3 VLM scenario described in the PR.

One minor UX improvement: when both the original model_type and text_model_type were attempted, the warning only reports the original. Consider enhancing clarity:

💬 Optional: Show both attempted model types in warning
 if apply_fn is None:
+    attempted = [model_type]
+    if text_model_type and text_model_type != model_type:
+        attempted.append(text_model_type)
     log_rank_0(
-        f"⚠️  Liger kernels do not support model type '{model_type}' — "
+        f"⚠️  Liger kernels do not support model type(s) {attempted} — "
         f"skipping Liger optimization. Training will proceed without it."
     )
     return
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mini_trainer/setup_model_for_training.py` around lines 71 - 89, The
warning only reports model_type even when you attempted the text_config
fallback; update the log produced by log_rank_0 in the apply_fn None path to
mention both the original model_type and, if present and different,
text_model_type (obtained from getattr(model_config, "text_config", None) and
getattr(text_config, "model_type", None)), so the message clearly lists which
model types were tried against MODEL_TYPE_TO_APPLY_LIGER_FN and that Liger
optimization is being skipped; keep the existing early-return behavior
unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cfa178cb-ab65-4e6c-a980-986987c1b99f

📥 Commits

Reviewing files that changed from the base of the PR and between 6f2ddbf and dc4b28b.

📒 Files selected for processing (3)
  • src/mini_trainer/setup_model_for_training.py
  • src/mini_trainer/utils.py
  • src/mini_trainer/vlm_utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/mini_trainer/utils.py
  • src/mini_trainer/vlm_utils.py

@RobotSail RobotSail merged commit 5730d5a into Red-Hat-AI-Innovation-Team:main Mar 24, 2026
10 of 11 checks passed