Hangs during vllm rollout, no error message #12
Hi @Vamix, glad to hear that you successfully ran qwen2-7b using 2 nodes. We haven't run into this hanging issue before. From the vLLM debug message, I guess it's stuck at the prefill stage (as it shows Running: 2 reqs). Let's see if these debugging methods could help.
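For reference, a minimal sketch of the kind of debugging environment one might enable here. These are standard vLLM / NCCL / CUDA variables, not verl-specific settings, and they must be set (or exported in the launch script) before the engine and the process groups initialize:

```python
import os

# Debugging sketch only: set these before vLLM / NCCL / CUDA initialize
# in the worker processes, or export them in the launch script instead.
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"   # verbose vLLM engine logs
os.environ["VLLM_TRACE_FUNCTION"] = "1"      # very verbose: trace vLLM calls to locate a hang
os.environ["NCCL_DEBUG"] = "INFO"            # log NCCL communicator setup and collectives
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"     # surface async CUDA errors at the failing call
```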
Hi @PeterSH6, thanks for your reply. I found two problems in the code that may lead to the hanging issue: (1) when …; (2) when the evaluation dataset size is not divisible by the world_size and tensor parallelism is enabled for vLLM. I hope you can add more checks in the code so that errors are raised when some hyper-parameters are not set correctly. More instructions on how to set the hyper-parameters would also be appreciated (e.g., which hyper-parameter should be divisible by which). However, even after fixing the above issues, I still cannot run the distributed training with tensor-parallel rollout successfully; I'm now facing a new CUDA error:
I suspect it is still related to a wrong setting of some hyper-parameters; my script is as below:
Could you give some suggestions for debugging this? Thanks a lot!
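For what it's worth, a minimal sketch of the kind of up-front check being asked for, assuming the relevant quantities are the eval dataset size, the world size, and the vLLM tensor-parallel size (the function and argument names are placeholders, not verl's actual config fields):

```python
def check_rollout_config(eval_dataset_size: int, world_size: int, tp_size: int) -> None:
    """Fail fast on hyper-parameter combinations that can hang the rollout.

    Hypothetical helper (not part of verl): with tensor-parallel rollout the
    data-parallel group size is world_size // tp_size, and every DP rank must
    receive the same number of samples, otherwise some ranks wait on
    collectives that never complete.
    """
    assert world_size % tp_size == 0, (
        f"world_size ({world_size}) must be divisible by the tensor-parallel "
        f"size ({tp_size})"
    )
    dp_size = world_size // tp_size
    assert eval_dataset_size % dp_size == 0, (
        f"eval dataset size ({eval_dataset_size}) must be divisible by the "
        f"data-parallel size ({dp_size}); pad or trim the eval set"
    )
```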
Hi @Vamix, Thanks for your constructive advice!
The reason that … You are right that this may be confusing for users. We will discuss this issue and possibly add more tutorials/assertions for better usability.
Great findings! This could be a defect; we may add some dummy samples to align the eval dataset size. Do you have a quick workaround you could contribute to verl directly?
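A rough sketch of the dummy-sample idea, assuming the eval set can be treated as a plain list of samples and that repeating the last sample is acceptable filler (padded copies would still need to be masked out when aggregating eval metrics):

```python
from typing import Any, List

def pad_eval_samples(samples: List[Any], dp_size: int) -> List[Any]:
    """Pad with dummy copies so the eval set splits evenly across DP ranks.

    Hypothetical workaround (not verl code): duplicates the last sample
    until len(samples) is a multiple of dp_size.
    """
    remainder = len(samples) % dp_size
    if remainder == 0:
        return samples
    return samples + [samples[-1]] * (dp_size - remainder)
```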
Will do so. Thanks for your suggestion.
We also encountered this issue at random times when using qwen2.5 (other models seem to be fine). We found that it may be related to an internal bug in flash_attn or vLLM; see vllm-project/vllm#5687 and vllm-project/vllm#5376. Could you try using a different backend for vLLM?
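Assuming the suggestion refers to vLLM's attention backend, one way to switch it is the `VLLM_ATTENTION_BACKEND` environment variable (a sketch; `XFORMERS` is just one example value, and which backends are available depends on the installed vLLM build):

```python
import os

# Must be set before vLLM constructs its engine; "XFORMERS" is one example
# value -- check which attention backends your vLLM version actually ships.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"
```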
Hi @PeterSH6, thanks for your reply, and sorry for my late response.
Hi veRL team, thanks for open-sourcing this great framework. I have successfully run PPO training of qwen2-7b using 2 nodes, so I think there is no problem with my environment. But I encountered an issue when trying to run PPO training of the qwen2.5 32B model with 8 nodes.
The config is https://github.com/volcengine/verl/blob/main/examples/ppo_trainer/run_qwen2.5-32b.sh. First, I found it triggered OOM with the default setting, so I changed trainer.nnodes to 8. Then, when using 8 nodes to run the PPO training, it gets stuck during the vLLM rollout. Even with the vLLM debug flags turned on, there is no error message; the last output I can see is:
INFO 11-12 22:04:23 metrics.py:406] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 103.0 tokens/s, Running: 2 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 11.0%, CPU KV cache usage: 0.0%.
Have you ever run into this hang issue? Hope you can share some suggestions for debugging. Thanks a lot!
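In case it helps anyone debugging a silent hang like this, a generic standard-library sketch (not something verl provides): register a faulthandler signal early in the training entry point, then send that signal to the stuck worker to dump every thread's stack without killing the process.

```python
import faulthandler
import signal

# Register near the top of the training entry point. When the rollout hangs,
# run `kill -USR1 <worker pid>` to print all thread tracebacks to stderr.
faulthandler.register(signal.SIGUSR1, all_threads=True)
```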