[WIP] Gemma3 support. #2485


Open · wants to merge 27 commits into main

Conversation

krammnic
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Mar 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2485

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Mar 12, 2025
@krammnic
Contributor Author

krammnic commented Mar 12, 2025

  • Recipes, recipe registry
  • Component builders (with all changes according to the tech report)
  • Model builders (with all correct params according to the tech report)
  • SigLIP
  • Weight conversion
  • Multimodal model
  • Tokenizer? (I assume it is almost the same as in Gemma2)
  • Was able to run 1B without multimodal
  • Manual runs for multimodal versions

krammnic changed the title from [Early WIP] Gemma3 support. to [WIP] Gemma3 support. on Mar 12, 2025
@krammnic
Contributor Author

Was able to run a full finetune on 1B (without multimodal)

@felipemello1
Contributor

Hey Mark, great work! As a sanity check, do you think you could compare vs HF for correctness with a larger sentence?

You can follow this as a start: https://gist.github.com/felipemello1/e3f1b1c358e145c7a4d610cf44cca374

@krammnic
Contributor Author

Hey Mark, great work! As a sanity check, do you think you could compare vs HF for correctness with a larger sentence?

You can follow this as a start: https://gist.github.com/felipemello1/e3f1b1c358e145c7a4d610cf44cca374

Hey Felipe! Yep, sure. It will stay WIP until I'm confident in it (we will do multimodal runs) and some configs are fixed by the Gemma team.
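
A minimal parity-check sketch along the lines of the gist above, assuming an eventual `torchtune.models.gemma3.gemma3_1b` builder and a `GEMMA3` model type for the HF checkpointer (both names are assumptions, not the merged API):

```python
# Hypothetical logit comparison between the converted torchtune 1B model and the
# HF reference. Builder name (gemma3_1b) and model_type string are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchtune.models.gemma3 import gemma3_1b  # assumed builder
from torchtune.training import FullModelHFCheckpointer

prompt = "Tell me a long story about the history of rotary position embeddings."

# Reference logits from Hugging Face
hf_tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it", torch_dtype=torch.float32)
tokens = hf_tok(prompt, return_tensors="pt")["input_ids"]
with torch.no_grad():
    hf_logits = hf_model(tokens).logits

# torchtune logits from the same checkpoint, loaded through the HF checkpointer
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/gemma-3-1b-it",
    checkpoint_files=["model.safetensors"],
    model_type="GEMMA3",  # assumed model type
    output_dir="/tmp/gemma3_parity",
)
tt_model = gemma3_1b()
tt_model.load_state_dict(checkpointer.load_checkpoint()["model"])
tt_model.eval()
with torch.no_grad():
    tt_logits = tt_model(tokens)

print("max abs diff:", (hf_logits - tt_logits).abs().max().item())
```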

@bzz
Contributor

bzz commented Apr 4, 2025

Curious about

some configs will be fixed by Gemma team

any chance it refers to huggingface/transformers#36683? That seems to have been patched in Transformers

@krammnic
Contributor Author

krammnic commented Apr 6, 2025

Curious about

some configs will be fixed by Gemma team

any chance it refers to huggingface/transformers#36683? That seems to have been patched in Transformers

Hey @bzz. Exactly, the issue is very similar. Unfortunately, we need more information from the config at the conversion stage, and it is missing from the 4B config.

Contributor

ebsmothers left a comment


Thanks for your patience on this one! A couple high-level comments:

  1. The whole text-only 1B vs multimodal 4B+ thing is a bit awkward. I feel like in this PR you wrote a bunch of the SigLIP components but didn't really hook them up to anything, and so the actual builders are all text-only. I think that's fine if there's still stuff up in the air (viz. (2)), but wonder if we should do something similar to our other multimodal models: provide gemma_decoder, gemma_vision_encoder, then for 4B+ we hook into EarlyFusion and for 1B we just use the decoder directly. (A rough sketch of this split follows after this list.)
  2. Can you share more on some of the blockers around HF config for 4B+ models you were alluding to? I want to understand how much we should try to hack around things vs just hold off here.
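
Purely as an illustration of the split proposed in (1), and not the torchtune API (the fusion wiring, builder names, and forward details are all placeholders):

```python
# Illustrative-only sketch: 1B uses the text decoder directly; 4B+ wraps the
# decoder and a SigLIP vision encoder in an early-fusion container that would
# splice projected image embeddings into the token sequence. All names here are
# placeholders, not torchtune components.
from typing import Optional

import torch.nn as nn


class EarlyFusionSketch(nn.Module):
    """Container for decoder + vision encoder; fusion happens at the embedding level."""

    def __init__(self, decoder: nn.Module, vision_encoder: nn.Module):
        super().__init__()
        self.decoder = decoder
        self.vision_encoder = vision_encoder  # SigLIP tower + projection to embed_dim
        # forward() would: encode images -> project -> merge at <image> positions
        # -> run the decoder over the merged sequence.


def build_gemma3(decoder: nn.Module, vision_encoder: Optional[nn.Module] = None) -> nn.Module:
    """1B (text-only): return the decoder as-is. 4B+: return the early-fusion wrapper."""
    if vision_encoder is None:
        return decoder
    return EarlyFusionSketch(decoder, vision_encoder)
```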

@@ -120,7 +120,7 @@ def forward(
# Norm applied before self-attention
h = self.sa_norm(x)
attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)

Contributor

nit: just lint this file to remove the whitespace changes

@@ -0,0 +1,114 @@
# Config for single-device QLoRA finetuning in lora_finetune_single_device.py
Contributor

Need to rename this file: 27_qlora_single_device.yaml -> 27B_qlora_single_device.yaml

# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma-3-4=12b-it/tokenizer.model
Contributor

I think this path needs to be changed?

Comment on lines 44 to 57
checkpoint_files: [
model-00001-of-00012.safetensors,
model-00002-of-00012.safetensors,
model-00003-of-00012.safetensors,
model-00004-of-00012.safetensors,
model-00005-of-00012.safetensors,
model-00006-of-00012.safetensors,
model-00007-of-00012.safetensors,
model-00008-of-00012.safetensors,
model-00009-of-00012.safetensors,
model-00010-of-00012.safetensors,
model-00011-of-00012.safetensors,
model-00012-of-00012.safetensors,
]
Contributor

For ones that are a bit longer, you can also do this. Personally I prefer it whenever there are >5 files, but no strong preference here

# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma-3-4=12b-it/tokenizer.model
Contributor

need to update this one too

q_norm=GemmaRMSNorm(head_dim, eps=norm_eps),
attn_dropout=attn_dropout,
# global attention only on every 6th layer, per the tech report; sliding window (local) otherwise
sliding_window_size=sliding_window_size if (layer_idx % 6) != 0 or layer_idx == 0 else None,
Contributor

Now that both this and Llama4 are doing this interleaving of local and global attention layers, we should think about whether there's a more general abstraction we can use to make this easier. (No need to worry about it for this PR though)

Contributor Author

Let's do it in a follow-up.
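
Not needed for this PR, but a sketch of what such a shared helper could look like (the function name and defaults are made up; the logic mirrors the condition above):

```python
# Hypothetical shared helper for local/global attention interleaving. The name
# and defaults are invented; the pattern matches the Gemma3 condition above:
# every 6th layer (except layer 0) is global, the rest use sliding-window attention.
from typing import Optional


def sliding_window_for_layer(
    layer_idx: int,
    sliding_window_size: int,
    global_every_n: int = 6,
) -> Optional[int]:
    """Return the sliding-window size for this layer, or None for global attention."""
    is_global = layer_idx % global_every_n == 0 and layer_idx != 0
    return None if is_global else sliding_window_size


# Layers 0-5 local, layer 6 global, 7-11 local, 12 global, ...
print([sliding_window_for_layer(i, 1024) for i in range(13)])
```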

Comment on lines +204 to +205
local_rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=local_rope_base)
global_rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=global_rope_base)
Contributor

I also see that they do linear RoPE scaling by a factor of 8, is that right? If so do we need to make any change here?

Contributor Author

We have done this (according to the tech report):

"We increase RoPE base frequency from 10k to 1M on globa lself-attention layers, and keep the frequency of the local layers at 10k."

But yes, I'm not sure about this:
"We find a scaling factor of 8 to work well in practice."

Let me investigate this a little bit further.
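
If the factor-8 linear scaling does apply to the global layers, this is a standalone sketch of what it would mean (plain position interpolation, not the torchtune RotaryPositionalEmbeddings code; the factor-8 value and restricting it to global layers are assumptions taken from the tech-report quote above):

```python
# Standalone sketch of linear RoPE scaling ("position interpolation"): positions
# are divided by the scaling factor before the rotary angles are computed.
import torch


def rope_angles(
    seq_len: int, head_dim: int, base: float = 1_000_000.0, scaling_factor: float = 8.0
) -> torch.Tensor:
    """Return (seq_len, head_dim // 2) rotary angles with linear position scaling."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float() / scaling_factor  # linear scaling
    return torch.outer(positions, inv_freq)


angles = rope_angles(seq_len=8, head_dim=256)
cos, sin = angles.cos(), angles.sin()  # then applied to query/key pairs as usual
```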

self.final_norm = nn.LayerNorm(embed_dim, layer_norm_eps)
self.avg_pool = SiglipAveragePooling()

@torch.inference_mode
Contributor

I saw this in the reference implementation -- do you know why they do it?

Contributor Author

They do not update the SigLIP model during training or post-training; the same pre-trained vision encoder is reused across all model sizes.
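
If that's the intent, an alternative (or complement) to decorating forward with @torch.inference_mode would be to freeze the tower explicitly in the training setup; a small sketch, with a hypothetical attribute name for the vision tower:

```python
# Sketch only: freeze the SigLIP tower so it is never updated during (post-)training.
# `vision_encoder_attr` is a hypothetical attribute name, not the PR's actual field.
import torch.nn as nn


def freeze_vision_encoder(model: nn.Module, vision_encoder_attr: str = "vision_encoder") -> None:
    """Disable grads on the vision tower and keep it in eval mode (norms/dropout inert)."""
    encoder = getattr(model, vision_encoder_attr)
    encoder.eval()
    for param in encoder.parameters():
        param.requires_grad_(False)
```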

width = int(seq_len ** 0.5)
if width * width != seq_len:
raise ValueError(
f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image."
Contributor

I saw that Gemma3 expects square images of a fixed size. Where are we doing the image processing to make that happen?

See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251
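
For reference, this is the kind of preprocessing that question implies: resize and normalize to the fixed square resolution SigLIP expects, so seq_len is always a perfect square after patchification. The 896x896 size is taken from the Gemma3 tech report; the exact transform pipeline below is an assumption, not what this PR ships.

```python
# Sketch of fixed-square image preprocessing for the SigLIP tower.
# Size and normalization constants here are illustrative.
from torchvision import transforms

gemma3_image_transform = transforms.Compose(
    [
        transforms.Resize((896, 896)),  # force a fixed square input
        transforms.ToTensor(),          # [0, 255] -> [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# usage: pixel_values = gemma3_image_transform(pil_image).unsqueeze(0)
```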
"""

_GEMMA3_FROM_HF = {
Contributor

To check my understanding here: the 1B model is text-only, 4B+ are all multimodal. This means that the HF weights for the 1B model will look like what you've given here, but the weights for 4B+ models would be language_model.{keys you have here}. And this is why they provide Gemma3ForConditionalGeneration and Gemma3ForCausalLM as the separate classes for the different model sizes. But on our side, I think there are two options:

a) just include the vision keys in the mapping, they should be ignored for the 1B model, then add an optional prefix to every key in the mapping
b) provide a different model type for 4B+ Gemma3 models (e.g. Gemma3VLM)

Personally I think (a) is preferable if it's feasible.

You also mentioned there were some difficulties around getting the information you need for the 4B+ models from the config. Is there a hard blocker here? Naively looking at what you added in _checkpointer.py it seems to me that it should work, but maybe I am missing something obvious. (Maybe we can move that logic into a utility in this file and import it, so as not to clutter up the checkpointer code -- something like _infer_gemma3_attn_data_from_config)

Contributor Author

Assuming (a) means the first option, not the second: the problem is that it doesn't look like we can do it that way, because of the different structure of the checkpoints :/ As for the config.json problem, check this out: https://huggingface.co/google/gemma-3-4b-it/discussions/14
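
For reference, a sketch of what option (a) could look like if the structural differences can be worked around (the mapping entries below are placeholders; the real mapping is _GEMMA3_FROM_HF in this file, and whether this is feasible is exactly the open question above):

```python
# Sketch of option (a): a single text-key mapping plus an optional
# "language_model." prefix for the multimodal (4B+) checkpoints.
_GEMMA3_TEXT_FROM_HF = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.scale",
    # ... rest of the decoder mapping ...
}


def gemma3_hf_to_tune_key(hf_key: str, multimodal: bool) -> str:
    """Map an HF key to the torchtune name, stripping the language_model. prefix for 4B+."""
    prefix = "language_model."
    if multimodal and hf_key.startswith(prefix):
        hf_key = hf_key[len(prefix):]
    return _GEMMA3_TEXT_FROM_HF[hf_key]
```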

@krammnic
Contributor Author

...so the actual builders are all text-only. I think that's fine if there's still stuff up in the air (viz. (2)), but wonder if we should do something similar to our other multimodal models: provide gemma_decoder, gemma_vision_encoder...

  1. Sure! I will push these changes (with EarlyFusion) after we decide on the structure of 1B vs 4B+ and discuss the blocker.
  2. Speaking about the blocker: https://huggingface.co/google/gemma-3-4b-it/discussions/14
