Add Nemotron/Minitron GGUF Conversion & Inference Support #8922

Merged
merged 10 commits into ggerganov:master on Aug 16, 2024

Conversation

@suhara (Contributor) commented Aug 8, 2024

This PR adds HF->GGUF conversion and inference support for Nemotron models, including Nemotron-3, Nemotron-4, and "Minitron" models.

The PR should support any Nemotron/Minitron model, but it has been primarily tested with the following Minitron model.

HF support for Nemotron was recently added; as of Transformers 4.44.0, Nemotron is supported (thank you @Vaibhavs10 for the information!). You may need to install a newer version of the transformers library, e.g. by running pip install "transformers>=4.44.0".

Please see this PR for details.

The Nemotron architecture is similar to the Llama-2 architecture, with a few key differences (see the sketch after this list):

  • Vocabulary size: Nemotron uses a 256k SentencePiece tokenizer
  • FFN layer: Nemotron uses a squared-ReLU activation with only up and down projections (no gate projection)
  • RoPE: Nemotron uses partial (50%) RoPE, i.e. only half of each head's dimensions are rotated
  • Layer Normalization: Nemotron adds 1 to the LayerNorm weight for better numerical stability
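
For illustration, here is a minimal, self-contained C++ sketch of the squared-ReLU activation and the partial-RoPE rotation described above. This is not llama.cpp code; the function names, the RoPE pairing convention, and the base frequency are assumptions made for the example, and the actual PR implements these via ggml operators.

#include <cmath>
#include <vector>

// Squared ReLU: max(0, x)^2, applied element-wise after the FFN up projection.
static std::vector<float> relu_squared(const std::vector<float> & x) {
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        const float r = x[i] > 0.0f ? x[i] : 0.0f;
        y[i] = r * r;
    }
    return y;
}

// Partial (50%) RoPE: only the first head_dim/2 dimensions of a head are rotated;
// the rest pass through unchanged. Pairs (2k, 2k+1) get position-dependent rotations.
static void apply_partial_rope(std::vector<float> & head, int pos, float freq_base = 10000.0f) {
    const int head_dim = (int) head.size();
    const int n_rot    = head_dim / 2;  // 50% of the head dimensions
    for (int i = 0; i + 1 < n_rot; i += 2) {
        const float theta = pos * std::pow(freq_base, -(float) i / n_rot);
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = head[i], x1 = head[i + 1];
        head[i]     = x0 * c - x1 * s;
        head[i + 1] = x0 * s + x1 * c;
    }
}

// The LayerNorm "+1" is simpler: the effective scale is (1 + stored_weight), so either
// the conversion or the inference code must account for the offset when applying LayerNorm.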

You can find details about the model architecture in the following papers:



This PR was created in collaboration with @SpaceCowboy850, a co-contributor to this PR.

@github-actions bot added the "python" (python script changes) label on Aug 8, 2024
@Vaibhavs10 (Collaborator) left a comment

Note: As of Transformers 4.44.0, Nemotron is supported, so there is no need to install transformers from source.

@suhara (Contributor, Author) commented Aug 8, 2024

Awesome! Thank you for sharing @Vaibhavs10! I've updated the original PR description.

@mofosyne added the "Review Complexity: Medium" label (generally requires more time to grok, but manageable by beginner to medium expertise level) on Aug 8, 2024
(Resolved review comments on src/llama.cpp and convert_hf_to_gguf.py; now outdated.)
@suhara (Contributor, Author) commented Aug 9, 2024

Thank you @compilade for the comments and suggestions! Committed changes accordingly.

@Vaibhavs10 (Collaborator) left a comment

Hi @suhara, can you rebase onto main and, specifically, make sure this commit is included? This should fix the failing requirements.txt test.

@suhara (Contributor, Author) commented Aug 13, 2024

Hi @Vaibhavs10, thanks for reviewing! I rebased it onto the latest main branch.

@Vaibhavs10 (Collaborator) left a comment

Not sure what I am missing here, but I wasn't able to get the GGUF to run. I tried to test the PR via this:

git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp && gh pr checkout 8922
huggingface-cli download nvidia/Minitron-4B-Base --local-dir minitron --local-dir-use-symlinks False
python convert_hf_to_gguf.py minitron --outtype f16 --outfile model.gguf
llama-cli -m model.gguf -p "Meaning to life is"

I get: error loading model architecture: unknown model architecture: 'nemotron'

EDIT: I'm stupid, I was using an older binary!

@Vaibhavs10 (Collaborator) left a comment

Tested it (using the steps mentioned above), and it works quite well!

Let's wait for @compilade to review + approve then we can merge! 🤗

src/llama.cpp (outdated), comment on lines 7630 to 7632:
// optional MLP bias
layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
@slaren (Collaborator) left a comment

It's not correct to use ctx_split for bias tensors; they should use ctx_layer instead.

@suhara (Contributor, Author) left a comment

Thank you for your comment @slaren !

Sorry for the naive question. What's the difference between ctx_split and ctx_layer?

One thing that's not clear to me is that some parts of llama.cpp use ctx_split for bias tensors as well.

For example,

Should they be corrected? (That is out of the scope of this PR, but I wanted to ask to better understand them.)

A collaborator left a comment

ctx_split only makes a difference when using tensor parallelism with -sm row, which is only supported on the CUDA backend when using multiple GPUs. When using -sm row, ctx_split splits the rows of the matrix between the available GPUs. This is only supported for matrix multiplication, so it should only be used with the matrix portion of linear/dense layers. The other cases are also wrong and should be corrected as well, but it doesn't need to be done here.

@suhara (Contributor, Author) left a comment

Thanks for the explanation! I updated the two lines accordingly. I agree that the other parts should be fixed outside this PR.
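
For reference, the corrected lines would look roughly like the following, a sketch based on the snippet quoted above with ctx_split replaced by ctx_layer; the exact code in the PR may differ slightly:

// optional MLP bias (per-layer tensors, not row-split across GPUs)
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff},   llama_model_loader::TENSOR_NOT_REQUIRED);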

@suhara (Contributor, Author) commented Aug 14, 2024

Hi @compilade
Can you take a look and see if it looks good to you? Thank you!

@compilade (Collaborator) left a comment

Looks good to me.

(Resolved review comment on src/llama.cpp.)
@suhara (Contributor, Author) commented Aug 15, 2024

Thank you all for your reviews and support @compilade @Vaibhavs10 @ggerganov @slaren !

Could anybody help merge this PR? Thank you!

@slaren merged commit 2a24c8c into ggerganov:master on Aug 16, 2024
54 checks passed
@schmorp commented Aug 16, 2024

Sorry for disturbing, but when I try to convert the linked minitron-4b model with transformers 4.44.0 and current llama.cpp, it simply complains about a missing tokenizer.model. Any idea why that could be?

@suhara (Contributor, Author) commented Aug 17, 2024

Hi @schmorp

I think the repo has been updated, and tokenizer.model (in the SentencePiece format) is no longer hosted there.

You can extract tokenizer.model from nemo/minitron-4b-base.nemo:

$ cd minitron/nemo
$ tar -xf minitron-4b-base.nemo
$ ls
914829c706e34a92ab89d5213695f4e5_nemotron_2_256k.model
b1bc02bf987043f3884c39152f183238_nemotron_2_256k.model
minitron-4b-base.nemo
model_config.yaml
model_weights
$ cp 914829c706e34a92ab89d5213695f4e5_nemotron_2_256k.model ../tokenizer.model

$ cd ../../
$ python convert_hf_to_gguf.py minitron --outtype f16 --outfile model.gguf

There are two tokenizer files, but they are identical, and either can be renamed to tokenizer.model:

  • 914829c706e34a92ab89d5213695f4e5_nemotron_2_256k.model
  • b1bc02bf987043f3884c39152f183238_nemotron_2_256k.model

@schmorp commented Aug 17, 2024

@suhara thanks a lot!

@schmorp commented Aug 17, 2024

Minitron-8B converts, but then can't be used:

llm_load_tensors: ggml ctx size = 0.15 MiB
llama_model_load: error loading model: check_tensor_dims: tensor 'blk.0.attn_q.weight' has wrong shape; expected 4096, 4096, got 4096, 6144, 1, 1
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/tmp/Minitron-8B-Base.gguf'
main : failed to init

@schmorp commented Aug 17, 2024

Minitron-4B seems to work, so it appears Minitron-8B is not quite supported yet.

@suhara (Contributor, Author) commented Aug 17, 2024

I'll look into this, but I think I know the root cause.

The 8B model uses head_dim: 128, and that is likely the cause:
https://huggingface.co/nvidia/Minitron-8B-Base/blob/main/config.json#L25

Many HF models, including Llama, assume head_dim == hidden_size // num_attention_heads.

llama_model_load: error loading model: check_tensor_dims: tensor 'blk.0.attn_q.weight' has wrong shape; expected 4096, 4096, got 4096, 6144, 1, 1

6144 = 128 * 48, so the conversion seems to be correct; the expected value (4096) is wrong.

FYI, for the 4B model, head_dim (128) == hidden_size (3072) // num_attention_heads (24), so the issue does not occur there.
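
To make the arithmetic concrete, here is a small illustrative check using the values quoted above (hidden sizes and head counts as discussed for Minitron-8B and Minitron-4B); this is not llama.cpp code, just a standalone sanity check:

#include <cstdio>

int main() {
    // Minitron-8B: hidden_size = 4096, num_attention_heads = 48, head_dim = 128
    const int hidden_8b = 4096, heads_8b = 48, head_dim_8b = 128;
    // The q-projection has heads * head_dim rows: 48 * 128 = 6144, not hidden_size (4096).
    printf("8B q-proj rows: %d, loader expects: %d\n", heads_8b * head_dim_8b, hidden_8b);

    // Minitron-4B: hidden_size = 3072, num_attention_heads = 24, head_dim = 128
    const int hidden_4b = 3072, heads_4b = 24, head_dim_4b = 128;
    // Here 24 * 128 = 3072 == hidden_size, so the head_dim assumption happens to hold.
    printf("4B q-proj rows: %d, loader expects: %d\n", heads_4b * head_dim_4b, hidden_4b);
    return 0;
}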

@schmorp commented Aug 18, 2024

That's good news, thanks for looking into this. I'll have a try at the 340B.

@schmorp commented Aug 18, 2024

For the 340B, conversion fails immediately because there isn't a config.json file.

@nicoboss (Contributor) commented Aug 18, 2024

I tried nvidia/Nemotron-4-340B-Instruct as well. It turns out that even if you add a config.json, the conversion results in a metadata-only GGUF, as all Nemotron-3 and Nemotron-4 models lack pytorch_model.bin or any safetensors files.

The only option seems to be the safetensors conversions provided by @mgoin at https://huggingface.co/collections/mgoin/nemotron-in-vllm-66a151b4240bcd9c28735ec5. Unfortunately, he never shared how he converted the .nemo checkpoints into safetensors.

@mgoin commented Aug 18, 2024

@nicoboss if the conversion steps and script would be useful, I can document this tomorrow!

@nicoboss (Contributor) commented Aug 18, 2024

@nicoboss if the conversion steps and script would be useful, I can document this tomorrow!

This would be absolutely awesome. Thanks a lot! I'm very interested in how the conversion works. Maybe it would even be possible to implement it inside convert_hf_to_gguf.py. I'm currently working together with @schmorp to GGUF-quantize all Nemotron-3, Nemotron-4 and "Minitron" models. While your collection is great, it unfortunately misses many Nemotron-3 models, which we could convert on our own if you shared your tools and knowledge. Nemotron-4-340B-Instruct is one of my favorite models, and I can't thank you enough for converting it into a usable format.

@schmorp commented Aug 22, 2024

And just to document this here, Llama-3.1-Minitron-4B-Width-Base fails with:

cvs/llama.cpp/ggml/src/ggml.c:6399: GGML_ASSERT(c->ne[0] >= n_dims / 2) failed

Labels: python (python script changes), Review Complexity: Medium
9 participants