NaN loss when training Llama-3.2-vision #84
cc: @awni
So there are a couple of things you should change in general about your Llama implementation:
- Use the built-in RMSNorm (nn.RMSNorm / mx.fast.rms_norm) instead of a custom implementation.
- Likewise, use the built-in RoPE (nn.RoPE / mx.fast.rope).
These will both be (much) faster and numerically more stable. The NaNs are getting introduced during overflow in your RMSNorm implementation. Typically, whenever you accumulate a lot of numbers you need to accumulate the result in a higher precision, so I'd double-check what precision most of your model files are using for those accumulations.
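For illustration, a minimal sketch of the two options discussed above: the fused op, and a hand-rolled norm that upcasts before the reduction. The function and variable names here are illustrative, not code from the repo.

```python
import mlx.core as mx

# Option 1: the fused op, which handles the reduction precision internally.
def rms_norm_fast(x, weight, eps=1e-5):
    return mx.fast.rms_norm(x, weight, eps)

# Option 2: a hand-written version. The key point from the comment above:
# upcast to float32 before the reduction so the sum of squares does not
# overflow when the activations are float16/bfloat16.
def rms_norm_manual(x, weight, eps=1e-5):
    orig_dtype = x.dtype
    xf = x.astype(mx.float32)  # accumulate in higher precision
    norm = mx.rsqrt(mx.mean(mx.square(xf), axis=-1, keepdims=True) + eps)
    return (xf * norm).astype(orig_dtype) * weight
```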
Thanks a lot! Yes, I was using a custom RMSNorm; I changed it to the built-in one. When it comes to RoPE, I was already using the built-in implementation.
How did you check this?
@awni I made the recommended changes, but I can't seem to run training on my machine (M3 Max, 96GB). It throws an error after processing 3 samples, even with a batch size of 1.
Could you please try it on your M2 Ultra and see if the NaN loss persists?
It's running on my M1 Max (32GB) with this command:
and the modifications to the dataset you posted above. So far it has processed 11 steps with no problem (I modified the print statement to print every step):
However, you should not be getting a segfault. That isn't good. Which version of MLX are you running? Anything else different in your setup? Also, I notice the GPU utilization is pretty poor, which is also not good. It should be close to 100% during training, so there is likely a bottleneck somewhere that needs fixing.
Thanks! Wow, that's really weird. Here is my setup:
I suspect the dataset loading function. I know it's not the best, but I figured optimizing it could wait for the next release; this one already took long enough. https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/trainer/trainer.py#L58
Could you try upgrading to the latest MLX (0.18.1) (and, if it's used here, MLX LM (0.19.1)), just to be sure it isn't something we already fixed? (I think this PR may be related: ml-explore/mlx#1452.) Also, remind me: what's your machine and OS?
Data loading is often the issue. And yes, leaving that for the next release is quite reasonable; I'm just letting you know in case you hadn't noticed it.
Upgrading to v0.18.1 fixed it! 🚀
Thank you! Do you have any tips specific to MLX? When I started getting the error, I figured it could be the data loading, so I made some initial rough optimizations to the loading code.
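For context, a minimal sketch of one such rough optimization: prefetching the next batch on a background thread while the current step runs. `load_batch`, `step`, and `num_steps` are hypothetical stand-ins here, not the actual trainer.py code.

```python
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(load_batch, 0)              # start loading batch 0

for i in range(num_steps):
    batch = future.result()                          # wait for the prefetched batch
    if i + 1 < num_steps:
        future = executor.submit(load_batch, i + 1)  # overlap loading with compute
    loss = step(batch)                               # training step runs on the GPU
```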
MacBook Pro 14-inch
First, verify that data loading is in fact the issue. I would do that by using the same batch over and over instead of loading a new one each step, and make sure the GPU utilization is close to 100%. If data loading is the problem, then look into what's actually slow: is it the I/O itself or some preprocessing step?
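As an illustration of that check, a sketch of a loop that reuses one batch, assuming a typical MLX training step; `model`, `loss_fn`, `optimizer`, and `data_iter` are hypothetical names, not mlx-vlm's actual API.

```python
import mlx.core as mx
import mlx.nn as nn

batch = next(data_iter)  # load a single batch once and reuse it

loss_and_grad = nn.value_and_grad(model, loss_fn)

for _ in range(100):
    # Feed the same batch every step; if GPU utilization now sits near 100%,
    # data loading (not the model) was the bottleneck.
    loss, grads = loss_and_grad(model, *batch)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
```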
I wouldn't manually clear the cache unless you have a really good reason. That will typically just slow everything down.
Awesome, thanks!
I preload/prefetch the batch before running it. Then the HF processor I use for preparing the inputs is probably the bottleneck.
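One way to confirm that, sketched below: time the processor call against the rest of the step. `processor`, `raw_batch`, and `step` are hypothetical placeholders, and `step` is assumed to call `mx.eval` so the compute actually happens inside the timed region.

```python
import time

t0 = time.perf_counter()
inputs = processor(raw_batch)   # HF preprocessing on the CPU
t1 = time.perf_counter()
loss = step(inputs)             # forward/backward + update (evaluated inside)
t2 = time.perf_counter()

print(f"preprocess: {t1 - t0:.3f}s  train step: {t2 - t1:.3f}s")
```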
Could you elaborate here? I didn't quite get it.
Issue
I keep getting nan loss when training Llama-3.2-vision. I tried:
But with no success.
Steps to reproduce:
- Check out the pc/llama3.2-vision branch.
- Run lora.py, modified to limit the dataset.