Fix Gemma 3 VLM detection for dual-registered configs #77
Conversation
📝 Walkthrough

`is_vlm_with_causal_lm` and `get_model_class_from_config` now prefer resolving the text backbone's CausalLM class, by inspecting the nested `text_config`, when the mapped class name includes "ForConditionalGeneration".
🚥 Pre-merge checks: ✅ 3 passed
Codecov Report: ❌ Patch coverage is
Force-pushed from 3601a80 to 6f2ddbf
🧹 Nitpick comments (1)
src/mini_trainer/utils.py (1)
**159-162**: Replace class-name heuristic with explicit dual-registration check.

Line 159's reliance on `"ForConditionalGeneration" in resolved_cls.__name__` is fragile; class names can change across transformers versions. Use an explicit check against `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING` (already used as a fallback) to detect dual-registered configs more reliably.

Proposed refactor
```diff
 def get_model_class_from_config(model_path):
     """Get the actual model class (not just the name) from a pretrained path.
@@ -168,11 +168,11 @@ def get_model_class_from_config(model_path):
     config_class = config.__class__
     if config_class in mapping:
         resolved_cls = mapping[config_class]
-        # Some models (e.g. Gemma 3) are dual-registered: the top-level config
-        # maps to a ForConditionalGeneration VLM, not a text-only CausalLM.
-        # In that case, prefer the text_config's CausalLM class instead.
-        if "ForConditionalGeneration" in resolved_cls.__name__:
+        # Dual-registered VLM configs may map to multimodal generation classes;
+        # prefer text backbone CausalLM when available.
+        from transformers.models.auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
+        if config_class in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING:
             text_config = getattr(config, "text_config", None)
             if text_config is not None and text_config.__class__ in mapping:
                 return mapping[text_config.__class__]
         return resolved_cls
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/mini_trainer/utils.py` around lines 159-162, the code currently detects dual-registered text+image models by checking if "ForConditionalGeneration" is in `resolved_cls.__name__`, which is fragile. Instead, inside the branch where `resolved_cls` and `text_config` are obtained, replace the string-name heuristic with a membership test against the explicit `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING` mapping and then return `mapping[text_config.__class__]`; keep the existing fallback behavior if that explicit mapping check fails.
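The membership test the reviewer suggests can be illustrated with plain dictionaries standing in for the transformers auto-mappings. A minimal sketch; every class and mapping below is a stand-in, not a real transformers type:

```python
# Mock config/model classes standing in for e.g. Gemma3Config and friends.
class TextConfig: pass
class VLMConfig:
    text_config = TextConfig()

class TextCausalLM: pass
class VLMForConditionalGeneration: pass

# Dual registration: the VLM config appears in BOTH mappings, and the
# CausalLM mapping resolves to the full VLM class.
MODEL_FOR_CAUSAL_LM_MAPPING = {
    VLMConfig: VLMForConditionalGeneration,
    TextConfig: TextCausalLM,
}
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = {
    VLMConfig: VLMForConditionalGeneration,
}

def resolve_model_class(config, mapping=MODEL_FOR_CAUSAL_LM_MAPPING):
    resolved_cls = mapping[config.__class__]
    # Explicit dual-registration check instead of matching on the class name.
    if config.__class__ in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING:
        text_config = getattr(config, "text_config", None)
        if text_config is not None and text_config.__class__ in mapping:
            return mapping[text_config.__class__]
    return resolved_cls

print(resolve_model_class(VLMConfig()).__name__)   # → TextCausalLM
print(resolve_model_class(TextConfig()).__name__)  # → TextCausalLM
```

A text-only config passes straight through, while the dual-registered config is redirected to its text backbone's class without ever inspecting `__name__`.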
📒 Files selected for processing (2)

- src/mini_trainer/utils.py
- src/mini_trainer/vlm_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
- src/mini_trainer/vlm_utils.py
Three fixes for models with dual-registered or unsupported VLM configs:

1. `vlm_utils.py`: `is_vlm_with_causal_lm()` — when the top-level config is in `MODEL_FOR_CAUSAL_LM_MAPPING` but resolves to a `ForConditionalGeneration` class (a VLM), treat it as needing backbone extraction.
2. `utils.py`: `get_model_class_from_config()` — same issue for the OSFT path. When the CausalLM mapping resolves to a VLM class, fall through to use `text_config`'s CausalLM class so OSFT doesn't wrap the VLM.
3. `setup_model_for_training.py`: `_apply_liger_kernels_if_requested()` — gracefully skip Liger with a warning instead of crashing when the model type is not supported (gemma3n, mistral3, qwen3_5). Also tries the `text_config` model_type as a fallback for extracted VLM models.

Tested: gemma3, gemma3n, ministral, mistral3-vlm, and qwen3.5 all pass SFT and OSFT (with Liger gracefully skipped) to loss 0.0000 on an overfit dataset.

Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
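The graceful Liger skip in fix 3 can be sketched with a plain dict standing in for liger_kernel's `MODEL_TYPE_TO_APPLY_LIGER_FN`. Function and variable names here are illustrative, not the actual mini_trainer implementation:

```python
import warnings

# Stand-in for liger_kernel's model-type → apply-function registry.
MODEL_TYPE_TO_APPLY_LIGER_FN = {
    "llama": lambda model: f"liger applied to {model}",
}

def apply_liger_if_supported(model, model_type, text_model_type=None):
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    # For extracted VLM backbones, retry with the text_config's model_type.
    if apply_fn is None and text_model_type is not None:
        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_type)
    if apply_fn is None:
        # Warn and skip instead of raising, so training proceeds without Liger.
        warnings.warn(
            f"Liger kernels do not support model type '{model_type}'; "
            "skipping Liger optimization."
        )
        return None
    return apply_fn(model)
```

An unsupported type such as `gemma3n` now produces a warning and `None` rather than a crash, while an extracted VLM whose `text_config` model type is supported still gets the kernels applied.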
Force-pushed from 6f2ddbf to dc4b28b
🧹 Nitpick comments (1)
src/mini_trainer/setup_model_for_training.py (1)
**71-89**: Well-structured fallback logic for VLM model types.

The `text_config` fallback aligns with the patterns established in `vlm_utils.py` and `utils.py`, and the graceful degradation (warning + skip instead of raising) is appropriate for the Gemma 3 VLM scenario described in the PR.

One minor UX improvement: when both the original `model_type` and `text_model_type` were attempted, the warning only reports the original. Consider enhancing clarity:

💬 Optional: Show both attempted model types in warning
```diff
     if apply_fn is None:
+        attempted = [model_type]
+        if text_model_type and text_model_type != model_type:
+            attempted.append(text_model_type)
         log_rank_0(
-            f"⚠️ Liger kernels do not support model type '{model_type}' — "
+            f"⚠️ Liger kernels do not support model type(s) {attempted} — "
             f"skipping Liger optimization. Training will proceed without it."
         )
         return
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/mini_trainer/setup_model_for_training.py` around lines 71-89, the warning only reports `model_type` even when the `text_config` fallback was attempted; update the log produced by `log_rank_0` in the `apply_fn is None` path to mention both the original `model_type` and, if present and different, `text_model_type` (obtained via `getattr(model_config, "text_config", None)` and `getattr(text_config, "model_type", None)`), so the message clearly lists which model types were tried against `MODEL_TYPE_TO_APPLY_LIGER_FN` and that Liger optimization is being skipped; keep the existing early-return behavior unchanged.
📒 Files selected for processing (3)

- src/mini_trainer/setup_model_for_training.py
- src/mini_trainer/utils.py
- src/mini_trainer/vlm_utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/mini_trainer/utils.py
- src/mini_trainer/vlm_utils.py
Merged commit 5730d5a into Red-Hat-AI-Innovation-Team:main
Summary
Same fix as instructlab/training#695 but for mini_trainer.
`is_vlm_with_causal_lm()` returns `False` for Gemma 3 because `Gemma3Config` is dual-registered in both `MODEL_FOR_CAUSAL_LM_MAPPING` and `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING`. The function bails out early when the config is in the CausalLM mapping, but `AutoModelForCausalLM` actually resolves to `Gemma3ForConditionalGeneration` (a full VLM), not a text-only CausalLM.

The fix checks what class the mapping resolves to. If it's a `ForConditionalGeneration` VLM, the model is treated as needing backbone extraction.

Note: The extraction itself works correctly, but `Gemma3ForCausalLM.__init__` then fails with `'Gemma3TextConfig' object has no attribute 'vision_config'`. That's a separate issue in the transformers Gemma 3 model code, where the CausalLM class references `self.config.vision_config` during init, which doesn't exist on `Gemma3TextConfig`. This PR fixes the detection logic; the transformers bug will need a separate upstream fix.

Test plan
- Text-only models: `AutoModelForCausalLM`
- VLMs with extractable backbones (Gemma 3, Gemma 3n, Ministral, Mistral3-VLM): correctly route to `extract_causal_lm` (✅ Extracted Gemma3ForCausalLM successfully)
- VLMs without a CausalLM variant (Qwen3-VL): use `direct_vlm_load`
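The detection change described in the summary — inspecting what class the CausalLM mapping actually resolves to — can be sketched with mock classes. None of the names below are real transformers types; this is only an illustration of the check the PR adds:

```python
# Mock-ups standing in for Gemma3Config / Gemma3ForConditionalGeneration.
class Gemma3LikeConfig: pass
class Gemma3LikeForConditionalGeneration: pass

# A plain text-only model for contrast.
class PlainTextConfig: pass
class PlainTextForCausalLM: pass

# Stand-in for MODEL_FOR_CAUSAL_LM_MAPPING, showing the dual-registration
# problem: the VLM config resolves to a ForConditionalGeneration class.
MOCK_CAUSAL_LM_MAPPING = {
    Gemma3LikeConfig: Gemma3LikeForConditionalGeneration,
    PlainTextConfig: PlainTextForCausalLM,
}

def needs_backbone_extraction(config):
    resolved_cls = MOCK_CAUSAL_LM_MAPPING.get(config.__class__)
    if resolved_cls is None:
        return False
    # The PR's fix: don't bail out just because the config is in the
    # CausalLM mapping; check what the mapping actually resolves to.
    return "ForConditionalGeneration" in resolved_cls.__name__

print(needs_backbone_extraction(Gemma3LikeConfig()))  # → True
print(needs_backbone_extraction(PlainTextConfig()))   # → False
```

With this check, a dual-registered config is flagged for backbone extraction even though it appears in the CausalLM mapping, while genuine text-only models are left alone.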