[WIP] Gemma3 support. #2485
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2485
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Was able to run a full finetune on 1B (without multimodal).
Hey Mark, great work! As a sanity check, do you think you could compare vs HF for correctness with a larger sentence? You can follow this as a start: https://gist.github.com/felipemello1/e3f1b1c358e145c7a4d610cf44cca374
Hey Felipe! Yep, sure. It's still WIP until I'm confident (we will do multimodal runs) and some configs are fixed by the Gemma team.
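For reference, a parity check along the lines of the gist above could look like the minimal sketch below. It assumes you already have the torchtune Gemma3 model with converted weights (`tune_model`) plus the matching HF model and tokenizer; the output shape of the torchtune forward is an assumption here, not something verified in this PR.

```python
import torch

@torch.no_grad()
def max_logit_diff(hf_model, tune_model, tokenizer, text: str) -> float:
    """Return the max absolute difference between HF and torchtune logits."""
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    hf_logits = hf_model(input_ids).logits  # (1, seq_len, vocab_size)
    tune_logits = tune_model(input_ids)     # assumed to return the same shape
    return (hf_logits - tune_logits).abs().max().item()

# Assumed usage:
# print(max_logit_diff(hf_model, tune_model, tokenizer, "The capital of France is Paris because"))
```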
Curious about this: any chance it refers to huggingface/transformers#36683?
Hey @bzz. Exactly, the issue is very similar. Unfortunately, we need more information from the config at the conversion stage, and it is missing from the 4B config.
Thanks for your patience on this one! A couple of high-level comments:

1. The whole text-only 1B vs multimodal 4B+ thing is a bit awkward. I feel like in this PR you wrote a bunch of the SigLIP components but didn't really hook them up to anything, and so the actual builders are all text-only. I think that's fine if there's still stuff up in the air (viz. (2)), but I wonder if we should do something similar to our other multimodal models: provide gemma_decoder and gemma_vision_encoder, then for 4B+ we hook into EarlyFusion and for 1B we just use the decoder directly.
2. Can you share more on some of the blockers around the HF config for 4B+ models you were alluding to? I want to understand how much we should try to hack around things vs just hold off here.
torchtune/modules/transformer.py
Outdated
@@ -120,7 +120,7 @@ def forward(
        # Norm applied before self-attention
        h = self.sa_norm(x)
        attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
nit: just lint this file to remove the whitespace changes
@@ -0,0 +1,114 @@
# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py
Need to rename this file: 27_qlora_single_device.yaml -> 27B_qlora_single_device.yaml
# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma-3-4=12b-it/tokenizer.model
I think this path needs to be changed?
recipes/configs/gemma3/27B_lora.yaml
Outdated
checkpoint_files: [
  model-00001-of-00012.safetensors,
  model-00002-of-00012.safetensors,
  model-00003-of-00012.safetensors,
  model-00004-of-00012.safetensors,
  model-00005-of-00012.safetensors,
  model-00006-of-00012.safetensors,
  model-00007-of-00012.safetensors,
  model-00008-of-00012.safetensors,
  model-00009-of-00012.safetensors,
  model-00010-of-00012.safetensors,
  model-00011-of-00012.safetensors,
  model-00012-of-00012.safetensors,
]
For ones that are a bit longer, you can also do this. Personally I prefer it whenever there are >5 files, but no strong preference here
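The exact alternative being referred to isn't shown in this excerpt; the sketch below just illustrates the general idea of generating a long list of sharded checkpoint filenames from a format string instead of spelling out every entry. The helper name and signature are made up for illustration, not a torchtune API.

```python
def expand_checkpoint_files(filename_format: str, num_files: int, width: int = 5) -> list[str]:
    """Expand e.g. "model-{}-of-{}.safetensors" into the explicit shard list."""
    total = str(num_files).zfill(width)
    return [filename_format.format(str(i).zfill(width), total) for i in range(1, num_files + 1)]

# expand_checkpoint_files("model-{}-of-{}.safetensors", 12)
# -> ["model-00001-of-00012.safetensors", ..., "model-00012-of-00012.safetensors"]
```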
# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma-3-4=12b-it/tokenizer.model
need to update this one too
q_norm=GemmaRMSNorm(head_dim, eps=norm_eps),
attn_dropout=attn_dropout,
# use global attention only on every 6th layer, according to the tech report
sliding_window_size=sliding_window_size if (layer_idx % 6) != 0 or layer_idx == 0 else None,
Now that both this and Llama4 are doing this interleaving of local and global attention layers, we should think about whether there's a more general abstraction we can use to make this easier. (No need to worry about it for this PR though)
Let's do it in a follow up
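As a sketch of what such a shared abstraction might look like (not an existing torchtune API; the function name and signature are assumptions), the per-layer decision above could be pulled into one helper parameterized by the interleaving pattern:

```python
from typing import Optional

def sliding_window_for_layer(
    layer_idx: int,
    sliding_window_size: int,
    global_every_n: int = 6,
) -> Optional[int]:
    """Return the sliding window size for this layer, or None for global attention.

    Mirrors the condition in the snippet above: every ``global_every_n``-th layer
    (except layer 0) uses global attention; all other layers use local
    sliding-window attention.
    """
    is_global = layer_idx % global_every_n == 0 and layer_idx != 0
    return None if is_global else sliding_window_size
```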
local_rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=local_rope_base)
global_rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=global_rope_base)
I also see that they do linear RoPE scaling by a factor of 8, is that right? If so do we need to make any change here?
We have done this (according to the tech report):
"We increase RoPE base frequency from 10k to 1M on global self-attention layers, and keep the frequency of the local layers at 10k."
But yes, I'm not sure about this part:
"We find a scaling factor of 8 to work well in practice."
Let me investigate this a little bit further.
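For illustration only (not the final torchtune change): with linear RoPE scaling ("position interpolation"), position indices are divided by the scale factor before computing the rotary angles. Applied to the global layers with base 1M and factor 8, while local layers keep base 10k with no scaling. The dims and sequence length below are placeholders.

```python
import torch

def rope_angles(seq_len: int, dim: int, base: float, scale_factor: float = 1.0) -> torch.Tensor:
    """Rotary angles with optional linear position scaling."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_len).float() / scale_factor  # linear scaling step
    return torch.outer(positions, inv_freq)  # (seq_len, dim // 2)

local_angles = rope_angles(seq_len=4096, dim=256, base=10_000)                      # local layers
global_angles = rope_angles(seq_len=4096, dim=256, base=1_000_000, scale_factor=8)  # global layers
```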
self.final_norm = nn.LayerNorm(embed_dim, layer_norm_eps)
self.avg_pool = SiglipAveragePooling()

@torch.inference_mode
I saw this in the reference implementation -- do you know why they do it?
They do not update the SigLIP model during training and post-training; the same pre-trained vision model is used for all model sizes.
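A minimal sketch of keeping the vision tower frozen, consistent with the point above that SigLIP is not updated; `vision_encoder` is a placeholder for the actual module, not a torchtune builder.

```python
import torch.nn as nn

def freeze(module: nn.Module) -> nn.Module:
    """Disable gradient updates (and dropout/batch-norm updates) for a frozen encoder."""
    module.eval()
    for p in module.parameters():
        p.requires_grad_(False)
    return module

# vision_encoder = freeze(vision_encoder)  # assumed usage
```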
width = int(seq_len ** 0.5)
if width * width != seq_len:
    raise ValueError(
        f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image."
I saw that Gemma3 expects square images of a fixed size. Where are we doing the image processing to make that happen?
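For context, the perfect-square check above guards a reshape like the one sketched here: a flat sequence of patch embeddings is folded back into a square spatial grid (e.g. before pooling). The batch size, sequence length, and embed dim below are placeholders, not Gemma3's actual values.

```python
import torch

tokens = torch.randn(2, 4096, 1152)                        # (batch, seq_len, embed_dim)
width = int(tokens.shape[1] ** 0.5)                        # 64, since 64 * 64 == 4096
grid = tokens.reshape(2, width, width, tokens.shape[-1])   # (batch, height, width, embed_dim)
```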
See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251
"""

_GEMMA3_FROM_HF = {
To check my understanding here: the 1B model is text-only, 4B+ are all multimodal. This means that the HF weights for the 1B model will look like what you've given here, but the weights for 4B+ models would be language_model.{keys you have here}. And this is why they provide Gemma3ForConditionalGeneration and Gemma3ForCausalLM as separate classes for the different model sizes. But on our side, I think there are two options:

a) just include the vision keys in the mapping (they should be ignored for the 1B model), then add an optional prefix to every key in the mapping
b) provide a different model type for 4B+ Gemma3 models (e.g. Gemma3VLM)

Personally I think (a) is preferable if it's feasible.

You also mentioned there were some difficulties around getting the information you need for the 4B+ models from the config. Is there a hard blocker here? Naively looking at what you added in _checkpointer.py it seems to me that it should work, but maybe I am missing something obvious. (Maybe we can move that logic into a utility in this file and import it, so as not to clutter up the checkpointer code -- something like _infer_gemma3_attn_data_from_config.)
Assuming that (a) here means the first option, not the second: the problem is that it looks like we can't do it like this, because of the different structure of the checkpoints :/ As for the config.json problem, check this out: https://huggingface.co/google/gemma-3-4b-it/discussions/14
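For concreteness, option (a) from the review comment above could look roughly like the sketch below (the reply notes that the checkpoint structure makes this harder in practice, so treat it purely as an illustration of the proposal). The mapping keys here are placeholders, not the real Gemma3 mapping.

```python
from typing import Optional

_GEMMA3_FROM_HF_SKETCH = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.scale",
}

def remap_hf_key(hf_key: str, is_multimodal: bool) -> Optional[str]:
    """Map an HF key to a torchtune key; return None for unmapped (e.g. vision) keys."""
    prefix = "language_model." if is_multimodal else ""
    stripped = hf_key[len(prefix):] if hf_key.startswith(prefix) else hf_key
    return _GEMMA3_FROM_HF_SKETCH.get(stripped)
```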
Context
What is the purpose of this PR? Is it to add a new feature, fix a bug, or update tests and/or documentation?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
run pre-commit hooks and linters (installed via pre-commit install)
run unit tests via pytest tests
run recipe tests via pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.