
Nan loss when training Llama-3.2-vision #84

Open · Blaizzy opened this issue Oct 13, 2024 · 11 comments
Labels: enhancement (New feature or request)

Blaizzy (Owner) commented Oct 13, 2024

Issue

I keep getting NaN loss when training Llama-3.2-vision.

I tried:

  • gradient clipping (see the sketch after this list)
  • a lower learning rate
  • a higher batch size, LoRA rank, and alpha

But with no success.
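
For reference, this is roughly the kind of gradient clipping I mean (illustrative only, not the exact code I used; it assumes `grads` is the pytree of gradients from an MLX training step):

```python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_map

def clip_grad_norm(grads, max_norm=1.0):
    # Global L2 norm over all gradient arrays in the pytree.
    total_norm = mx.sqrt(sum(mx.sum(g * g) for _, g in tree_flatten(grads)))
    # Scale all gradients down if the global norm exceeds max_norm.
    scale = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
    return tree_map(lambda g: g * scale, grads)
```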

Steps to reproduce:

  1. Install the pc/llama3.2-vision branch:
     pip install -U git+https://github.com/Blaizzy/mlx-vlm.git@pc/llama3.2-vision
  2. Add these two lines (31-32) to lora.py to limit the dataset:
     dataset = load_dataset(args.dataset, split=args.split+"[:20%]")
     dataset = dataset.rename_columns({"image": "images", "conversations": "messages"})
  3. Quantize the model (optional):
     python -m mlx_vlm.convert --hf-path unsloth/Llama-3.2-11B-Vision-Instruct -q --mlx-path Llama-3.2-11B-Vision-Instruct-4bit
  4. Start training:
     python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit --dataset 5CD-AI/Viet-ShareGPT-4o-Text-VQA --split Viet_OCR_VQA --steps 100 --learning-rate 5e-6 --lora-rank 16 --lora-alpha 16
[Screenshot: training output showing NaN loss]
Blaizzy (Owner) commented Oct 13, 2024

cc: @awni

awni (Contributor) commented Oct 14, 2024

So there are a couple things you should change in general about your Llama implementation:

  1. Use nn.RMSNorm instead of rolling your own
  2. Use nn.RoPE instead of rolling your own

These will both be (much) faster and numerically more stable. The NaNs are getting introduced during overflow in your RMSNorm implementation. Typically whenever you accumulate a lot of numbers you need to accumulate the result in a higher precision (so mean in your case). The nn.RMSNorm does this implicitly without the need for casting between mx.float32 and mx.float16.

I would double-check that most of your model files are using nn.RMSNorm or nn.LayerNorm where possible, and the same for RoPE. Inference especially will be much faster.
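
To illustrate, a rough sketch (not the actual model code): a hand-rolled RMSNorm that accumulates in the input dtype can overflow in float16, while nn.RMSNorm handles the accumulation in higher precision internally:

```python
import mlx.core as mx
import mlx.nn as nn

def naive_rms_norm(x, weight, eps=1e-5):
    # Accumulating x * x in float16 can overflow (float16 tops out around 65k),
    # which turns the normalization -- and eventually the loss -- into inf/NaN.
    variance = mx.mean(x * x, axis=-1, keepdims=True)
    return weight * x * mx.rsqrt(variance + eps)

# nn.RMSNorm does the accumulation in higher precision internally,
# so no manual casting between mx.float16 and mx.float32 is needed.
norm = nn.RMSNorm(dims=4096, eps=1e-5)
x = mx.random.normal((1, 16, 4096)).astype(mx.float16)
y = norm(x)
```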

Blaizzy (Owner) commented Oct 15, 2024

Thanks a lot!

Yes, I was using a custom RMSNorm. I changed it to nn.RMSNorm and it's 3.25x faster 🚀.

As for RoPE, I was already using nn.RoPE, since no changes were needed and it's easier to integrate with the cache.
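
Roughly how it slots in (illustrative names and values, not the exact layer code):

```python
import mlx.nn as nn

head_dim = 128
# Llama 3.x uses a large RoPE base (rope_theta); the value here is illustrative.
rope = nn.RoPE(head_dim, traditional=False, base=500000.0)

def apply_rope(queries, keys, cache_offset=0):
    # queries/keys: (batch, n_heads, seq_len, head_dim)
    # `offset` continues the positions from where the KV cache left off.
    return rope(queries, offset=cache_offset), rope(keys, offset=cache_offset)
```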

> The NaNs are getting introduced during overflow in your RMSNorm implementation. Typically whenever you accumulate a lot of numbers you need to accumulate the result in a higher precision (so mean in your case). The nn.RMSNorm does this implicitly without the need for casting between mx.float32 and mx.float16.

How did you check this?

Blaizzy (Owner) commented Oct 15, 2024

@awni I made the recommended changes, but I can't run training on my machine (M3 Max, 96GB).

It throws an error after processing 3 samples, even with a batch size of 1.

{'Epoch': 0, 'Step': 0, 'Loss': '1.5820'}
  3%|█▍                                            | 3/100 [00:09<05:03,  3.13s/it, Epoch=0, Step=0, Loss=1.5820]
zsh: segmentation fault  python -m mlx_vlm.lora --model-path Llama-3.2-11B-Vision-Instruct-4bit
/opt/homebrew/Caskroom/miniconda/base/envs/mlx_code/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Could you please try it on your M2 Ultra and see if the NaN loss persists?

awni (Contributor) commented Oct 15, 2024

It's running on my M1 Max (32GB) with this command:

python -m mlx_vlm.lora --model-path  Llama-3.2-11B-Vision-Instruct-4bit --dataset 5CD-AI/Viet-ShareGPT-4o-Text-VQA --split Viet_OCR_VQA --steps 100 --learning-rate 5e-6 --lora-rank 16 --lora-alpha 16

and the modifications to the dataset you posted above. So far it has processed 11 steps with no problem (I modified the print statement to log every step):

{'Epoch': 0, 'Step': 0, 'Loss': '1.5796'}
{'Epoch': 0, 'Step': 1, 'Loss': '1.8235'}
{'Epoch': 0, 'Step': 2, 'Loss': '1.9262'}
{'Epoch': 0, 'Step': 3, 'Loss': '1.5627'}
{'Epoch': 0, 'Step': 4, 'Loss': '1.5274'}
{'Epoch': 0, 'Step': 5, 'Loss': '1.7451'}
{'Epoch': 0, 'Step': 6, 'Loss': '1.9609'}
{'Epoch': 0, 'Step': 7, 'Loss': '0.9124'}
{'Epoch': 0, 'Step': 8, 'Loss': '1.7157'}
{'Epoch': 0, 'Step': 9, 'Loss': '1.6776'}
{'Epoch': 0, 'Step': 10, 'Loss': '1.8323'}
{'Epoch': 0, 'Step': 11, 'Loss': '1.4830'}

However, you should not be getting a segfault. That isn't good. Which version of MLX are you running? Anything else different in your setup?

Also, I notice the GPU utilization is pretty poor, which is also not good. It should be close to 100% during training, so there's likely a bottleneck somewhere that needs fixing.

Blaizzy (Owner) commented Oct 15, 2024

Thanks!

Wow, that's really weird.

Here is my setup:

prince_canuma@MacBook-Pro-3 ~ % pip list | grep mlx
fastmlx                                   0.2.1
mlx                                       0.18.0
mlx-embeddings                            0.0.1             /Users/prince_canuma/Documents/Projects/LLMs/mlx-embeddings
mlx-lm                                    0.19.0
mlx-vlm                                   0.1.0             /Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm

> Also, I notice the GPU utilization is pretty poor, which is also not good. It should be close to 100% during training, so there's likely a bottleneck somewhere that needs fixing.

I suspect the dataset loading function. I know it's not the best, but I figured that optimizing it could wait for the next release since this one already took long enough.

https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/trainer/trainer.py#L58
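
A simple background prefetcher might help here so the processor work overlaps with GPU compute (hypothetical sketch, not the current trainer code):

```python
import queue
import threading

def prefetch(batch_iter, depth=2):
    # Fill a small queue from a worker thread so the next batch is being
    # prepared while the current one trains on the GPU.
    q = queue.Queue(maxsize=depth)
    done = object()

    def worker():
        for batch in batch_iter:
            q.put(batch)
        q.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while (item := q.get()) is not done:
        yield item
```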

awni (Contributor) commented Oct 15, 2024

Could you try upgrading to the latest MLX (0.18.1), and MLX LM (0.19.1) if it's used here, just to be sure it isn't something we've already fixed (I think this PR may be related: ml-explore/mlx#1452).

Also remind me what's your machine and OS?

> I suspect the dataset loading function. I know it's not the best, but I figured that optimizing it could wait for the next release since this one already took long enough.

Data loading is often the issue. And yes, leaving it for the next release is quite reasonable; just letting you know in case you hadn't noticed it.

Blaizzy (Owner) commented Oct 15, 2024

Upgrading to v0.18.1 fixed it! 🚀

> Data loading is often the issue. And yes, leaving it for the next release is quite reasonable; just letting you know in case you hadn't noticed it.

Thank you! Do you have any tips specific to MLX?

When I started getting the error, I figured it could be the data loading, so I made some initial rough optimizations: using a generator, deleting the batch after processing, and calling the Metal clear-cache command.
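
Roughly what those optimizations look like (illustrative placeholder names; `batch_generator` and `train_step` are not the actual mlx-vlm functions):

```python
import mlx.core as mx

for batch in batch_generator():   # hypothetical generator yielding batches
    loss = train_step(batch)      # hypothetical forward/backward + optimizer update
    mx.eval(loss)                 # force evaluation before moving on
    del batch                     # drop the reference after the step
    mx.metal.clear_cache()        # release MLX's cached Metal buffers
```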

Blaizzy (Owner) commented Oct 15, 2024

> Also remind me what's your machine and OS?

MacBook Pro 14-inch
Chip: M3 Max
Unified memory: 96GB
OS: Sonoma 14.5

awni (Contributor) commented Oct 15, 2024

> Thank you! Do you have any tips specific to MLX?

First verify that data loading is in fact the issue. I would do that by using the same batch over and over instead of loading it and make sure the GPU utilization is close to 100%.
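For example, something like this (a rough sketch; `batch_generator` and `train_step` stand in for your actual loop):

```python
import mlx.core as mx

# Grab one batch once, then train on it repeatedly. If GPU utilization
# jumps to ~100%, the training step itself is fine and data loading is
# the bottleneck.
batch = next(iter(batch_generator()))   # hypothetical generator
for step in range(100):
    loss = train_step(batch)            # hypothetical training step
    mx.eval(loss)
```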

If data loading is the problem then look into what's actually slow. Is it the IO itself or some preprocessing steps?

  • If you preload the dataset into RAM it probably isn't the IO
  • Do you do the preprocessing in MLX? If not, maybe try doing that so it runs fast on the GPU..
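
For instance, a step like image normalization can be moved from numpy/PIL onto MLX arrays (illustrative sketch, not your actual preprocessing code):

```python
import mlx.core as mx

def normalize(images, mean, std):
    # images: (batch, H, W, C) uint8 MLX array -> normalized float16
    x = images.astype(mx.float32) / 255.0
    x = (x - mx.array(mean)) / mx.array(std)
    return x.astype(mx.float16)
```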

> and calling the Metal clear-cache command.

I wouldn't manually clear the cache unless you have a really good reason. That will typically just slow everything down.

Blaizzy (Owner) commented Oct 15, 2024

Awesome, thanks!

> If you preload the dataset into RAM it probably isn't the IO

> Do you do the preprocessing in MLX? If not, maybe try doing that so it runs fast on the GPU..

I preload/prefetch the batch before running it.

Then the HF processor I use for preparing the inputs is probably the bottleneck.

> I would do that by using the same batch over and over instead of loading it and make sure the GPU utilization is close to 100%.

Could you elaborate here? I didn't quite get it.
