OOM issue with same batch size that was running ok on 0.0.80 #184
Comments
Hello, and thanks for reporting the issue. Can you share the code, please?
Hi @erfanzar,
Can you rerun the code? There was an issue with the loss function, which wasn't using the fused version.
And since the sharding mechanism you're using is tensor parallelism, you can expect OOM, but not at a 1k sequence length.
In v0.0.80 the trainer automatically used gradient checkpointing. This behavior was removed in 0.1.0, and you should pass gradient_checkpointing to model_kwargs (I'll take the blame for not having good documentation).
You are right! In 0.0.80 it was part of the training arguments, as we can see in this example:
However, it was removed in the recent updates.
After updating the code with:
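(The exact snippet wasn't captured in this thread; below is a minimal sketch of what such an update might look like, assuming the 0.1.x convention of requesting gradient checkpointing through the model/config kwargs at load time. The parameter names and the policy string are illustrative, not the definitive API.)

```python
import easydel as ed

# Illustrative only: in 0.1.x the trainer no longer enables gradient checkpointing
# for you, so it is requested when the model is created instead.
model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    sharding_axis_dims=(1, 1, -1, 1),
    config_kwargs={"gradient_checkpointing": "nothing_saveable"},  # policy name may differ by version
)
```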
I was able to run the SFT code with a batch size of 8, but I got a couple of warnings:
The NaN issue is present even with 0.0.80 every time I use the meta-llama/Llama-3.1-8B-Instruct model with Packing=True. It goes away in some runs and appears in others; I have worked around it by re-running the script multiple times until I had no NaN, without changing any arguments. With Packing=False, the issue disappears. This issue is not present with the Llama 3.2 models. A last note on the sharding_axis_dims = (1, 1, -1, 1) choice: this setting gives me 113 TFLOPS (0.0.80) on TPUv4-8 against 98 TFLOPS with the other sharding-axis settings, which is why I chose it over the other options.
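(As an aside for readers following the tuple notation: EasyDeL's sharding_axis_dims appears to map, in order, onto the (dp, fsdp, tp, sp) mesh axes, with -1 meaning "all remaining devices". That ordering is inferred from this thread, where (1, 1, -1, 1) is described as tensor parallel, rather than from documentation. A small sketch of how such a tuple resolves into a JAX device mesh:)

```python
import numpy as np
import jax
from jax.sharding import Mesh

def make_mesh(sharding_axis_dims, axis_names=("dp", "fsdp", "tp", "sp")):
    """Resolve a -1 entry against the available device count and build the mesh."""
    dims = list(sharding_axis_dims)
    n_devices = len(jax.devices())
    if -1 in dims:
        known = int(np.prod([d for d in dims if d != -1]))
        dims[dims.index(-1)] = n_devices // known
    return Mesh(np.array(jax.devices()).reshape(dims), axis_names)

# On a TPU v4-8 (4 JAX devices), (1, 1, -1, 1) resolves to a pure tensor-parallel
# mesh: {'dp': 1, 'fsdp': 1, 'tp': 4, 'sp': 1}.
print(make_mesh((1, 1, -1, 1)).shape)
```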
Look, the FLOPS calculation method changed in the last version. Previously everything was calculated manually, but in this version it's derived from JAX's analysis, so expect it to be off: for example, you might actually be running at 160 TFLOPS, but in some parts XLA plays a bit dumb and shows 130 TFLOPS. Check this out
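(To make the distinction concrete, here is a toy comparison of the two approaches; this is not EasyDeL's actual FLOPs code, just a sketch of an analytic estimate versus XLA's own cost analysis of a jitted function:)

```python
import jax
import jax.numpy as jnp

def manual_train_flops(n_params: float, tokens_per_step: float) -> float:
    # Common analytic estimate: ~6 FLOPs per parameter per trained token
    # (2x forward, 4x backward), ignoring the attention-score terms.
    return 6.0 * n_params * tokens_per_step

def toy_step(w, x):
    return jnp.tanh(x @ w).sum()

w = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
x = jnp.ones((8, 1024), dtype=jnp.bfloat16)

compiled = jax.jit(toy_step).lower(w, x).compile()
# XLA's own estimate of the compiled computation; the reported numbers (and even the
# returned structure) vary across jax versions and backends, which is one reason the
# two methods can disagree.
print(compiled.cost_analysis())
print(f"analytic estimate, 8B params, 8x1024 tokens/step: {manual_train_flops(8e9, 8 * 1024):.3e} FLOPs")
```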
Thank you for the detailed reply. I have tested the TFLOPs in terms of runtime speed on both 0.0.80 and 0.1dev using the same WebQuestions example with LLaMA 3.1 8B. As you can see from the results, there is a problem with speed in the recent update, even when using different sharding strategies. I let the code run for a while until the s/it metric became stable. You can see that we get double the speed with 0.0.80. Also notice how (1,1,-1,1) is the best setting in terms of speed for TPUv4-8, as I stated earlier. In fact, the gap between (1,1,-1,1) and (1,1,1,-1) becomes worse (almost double) with the new update. I have also noticed that it takes a while (3-5 minutes) before the script starts running with the 0.1.0.dev update, so we should add 3-5 minutes to the runtime for an accurate head-to-head comparison.

Results for each sharding_axis_dims setting:
- 0.0.80: (1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)
- 0.1.0.dev: (1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)

Script to run the code on 0.0.80
Script to run the code on 0.1.0.dev
Thank you @salrowili for bringing up these issues and for your detailed feedback!
Great! Thank you @erfanzar for opening the topic. I have one question. I am planning to start sharing my code under the topic you just opened (#185), but I am still struggling to run my code on the new EasyDeL 0.1dev release. It's very slow compared to 0.0.80, and you have told me that it's due to the Flax/NNX integration. Inference with the new 0.1dev is fast, but the problem is with the SFT code. Do you have any estimate of when the issue will be fixed? If it will be soon, I will wait until it's fixed and share my code with the 0.1dev release.
Hi @salrowili, many performance issues related to the new arguments and the updated base trainer have been resolved. These include fixes for duplicated … You can rerun your benchmark to see if there are any remaining performance issues (avoid using ahead-of-time compilation). With Qwen-2 7B, batch size 8, and full sequence parallelism, I was able to achieve 6 seconds per iteration. Let me know how it goes!
Hi @erfanzar. That's great news! Can you share the code you used to achieve this performance?
@salrowili I'm using tests/trainer_test.py
Hi @salrowili, I hope the performance issue is resolved and training is running smoothly now. If you're still encountering any other problems, please let me know. I'm currently working on improving the speed of the training process and would be happy to assist you further. For quicker communication and to resolve any issues more efficiently, feel free to connect with me on Discord. My user ID is …
Hi @erfanzar. The issue is still there, and you can verify it by using the SFT code that I posted earlier. Currently, I am doing the SFT training with version 0.0.80. How can I disable AOT compilation? I thought I should slow down in reporting issues since you seem busy with other EasyDeL issues, which is why I did not update you that the issue still exists (: . I will try to DM you on Discord soon.
0.0.80 is much faster, especially when you run a larger model. I ran LLaMA 3.1 8B: 0.1dev gave me 1.9 s/it (~2 h total) against 1.24 s/it (1 h 15 m) with 0.0.80.
Hey @salrowili, the bugs are now fixed. I'll push the fix commit today, and 0.1 is now 5~10% faster.
Referenced commit: "…various components, fixes issues related to trainers being slow in #184"
Hi @salrowili, could you please test and confirm whether the execution time and … are now as expected on your side? Based on my tests using a V3-8:
Let me know your findings!
Hi @erfanzar. I think I have figured out the root of the speed issue. I think you are testing the code on TPUv3-8, which has 8 chips, each with a single core. In contrast, TPUv4-8 has 4 chips, each with 2 cores. So maybe only a single core from each chip is being used on TPUv4-8? This may also affect how FLOPs are calculated. I have tested 0.1 on both TPUv3-8 and TPUv4-8, and TPUv3-8 is much faster (almost double). Also, the trainer output is messed up: it prints output vertically, not horizontally as it used to. Here is the trainer output for TPUv3-8 and TPUv4-8; the code for both is identical, including the sharding method. Observe the step time.

TPUv3-8
TPUv4-8
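(One quick, version-agnostic way to see what JAX itself thinks it is running on is to print the device topology on each machine: a TPU v3-8 normally shows up as 8 JAX devices, one per TensorCore, while a TPU v4-8 normally shows up as 4 devices because the two cores of each v4 chip are fused into a single "megacore". A small check along these lines:)

```python
import jax

print("backend:        ", jax.default_backend())
print("process count:  ", jax.process_count())
print("local devices:  ", jax.local_device_count())
print("global devices: ", jax.device_count())
for d in jax.devices():
    # coords / core_on_chip exist on TPU devices; getattr keeps this runnable on CPU/GPU too.
    print(d.id, d.device_kind, getattr(d, "coords", None), getattr(d, "core_on_chip", None))
```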
@salrowili For logging during training, you have the flexibility to customize how you log or to add logging hooks. By default, the logging method is set to json. If the current logging format doesn't suit your preferences, reverting to tqdm is an option.
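(A hedged sketch of what switching the logger might look like; the field names in EasyDeL's TrainingArguments used here are an assumption on my part, so check the 0.1.x source for the real ones:)

```python
import easydel as ed

# Hypothetical field names: "progress_bar_type" and "log_steps" are assumptions used
# only to illustrate switching from the json logger back to tqdm; the real names may
# differ in your EasyDeL version.
training_arguments = ed.TrainingArguments(
    model_name="sft-llama31-8b",   # other (required) arguments omitted for brevity
    progress_bar_type="tqdm",
    log_steps=50,
)
```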
You are right. However, I think it would be better to have the default progress bar type be tqdm, because json floods the terminal with logging, especially when you set the log step to a smaller value. I have re-run the code and fixed an issue where the TPUv3-8 run was using the cached dataset, not the updated one. Looking at total runtime, we can see that TPUv3-8 is still much faster than TPUv4-8, so the issue is not related to how the metrics (e.g. step time) are calculated.

TPUv3-8
TPUv4-8
Also, one thing related to the jax/jaxlib version: we need to pin the jax/jaxlib requirement to 0.4.35 in the meantime.

Lines 36 to 37 in d13ecbb

This is because the function core.new_main was deprecated after jax 0.4.35, and it is used by the fjformer repo. However, removing the new_main usage and updating the code would be a better solution, since the new jax 0.5.0 has been released.
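(For illustration, a pin along these lines in a requirements-style file; the actual constraint lives at the lines referenced above:)

```
# Temporary pin: jax.core.new_main, which fjformer still relies on, is gone after 0.4.35
jax==0.4.35
jaxlib==0.4.35
```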
Hi,
I've noticed that recent updates are causing the SFT trainer code to throw an OutOfMemory (OOM) error with the same batch size that previously ran without issue on version 0.0.80.
I attempted SFT tuning using bfloat16 (no LoRA) with LLaMA 3.1 8B, max_length=1024, and batch=8 on TPUv4-8, but encountered an OOM error. This fine-tuning setup was working fine on 0.0.80.
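(For context on why this configuration is memory-tight once automatic gradient checkpointing is gone, here is a rough back-of-the-envelope estimate with illustrative numbers, assuming bf16 weights/gradients, fp32 Adam moments, and a TPU v4-8 with 4 chips x 32 GiB HBM:)

```python
GIB = 1024**3

n_params = 8.03e9                   # LLaMA 3.1 8B parameter count
weights  = n_params * 2             # bf16 parameters
grads    = n_params * 2             # bf16 gradients
adam     = n_params * (4 + 4)       # two fp32 moment buffers (m, v)

static_state = weights + grads + adam
hbm_total    = 4 * 32 * GIB         # TPU v4-8: 4 chips x 32 GiB HBM

print(f"static training state: {static_state / GIB:.0f} GiB")   # ~90 GiB
print(f"total v4-8 HBM:        {hbm_total / GIB:.0f} GiB")       # 128 GiB
# Even fully sharded, only roughly 38 GiB remain for activations, XLA scratch space and
# fragmentation; without gradient checkpointing, batch 8 x 1024 tokens on an 8B model
# can exceed that, which is consistent with the OOM reported above.
```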