[misc] fix: weak reference of WorkerDict in RayTrainer (#65)
* [misc] fix: weak reference of WorkerDict in RayTrainer

* move docker changes to the next commit
PeterSH6 authored Dec 30, 2024
1 parent 8254cb1 commit 94f4ca0
Showing 5 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/start/install.rst
@@ -46,7 +46,7 @@ found in :doc:`FSDP Workers<../workers/fsdp_workers>`.
 # install vllm
 pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1
-pip3 install ray==2.10 # other version may have bug
+pip3 install ray
 # flash attention 2
 pip3 install flash-attn --no-build-isolation
6 changes: 3 additions & 3 deletions docs/workers/ray_trainer.rst
@@ -39,9 +39,9 @@ We first introduce a basic implementation of initializing the

 .. code:: python

-    # Due to the Ray issue, we can only support max_colocate_count=1 for now.
-    # This means that each GPU can only have one process.
-    # We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385
+    # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool.
+    # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
+    # For the Megatron backend, we recommend max_colocate_count>1, which can use a different WorkerGroup for each model.
     resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,
                                     use_gpu=True,
                                     max_colocate_count=1)
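
For readers skimming the diff, here is a minimal sketch of how the two recommendations in the new comment might look side by side. The RayResourcePool arguments mirror the snippet above; the import path, the 2-node x 8-GPU cluster shape, and the value max_colocate_count=3 are illustrative assumptions, not part of this commit.

# Sketch only: the import path and cluster shape below are assumptions.
from verl.single_controller.ray import RayResourcePool  # assumed location of RayResourcePool

nnodes, n_gpus_per_node = 2, 8  # hypothetical cluster

# FSDP backend: one WorkerGroup (process) per GPU, everything merged into a single group.
fsdp_pool = RayResourcePool(process_on_nodes=[n_gpus_per_node] * nnodes,
                            use_gpu=True,
                            max_colocate_count=1)

# Megatron backend: allow several WorkerGroups (e.g. one per model) to colocate on each GPU.
megatron_pool = RayResourcePool(process_on_nodes=[n_gpus_per_node] * nnodes,
                                use_gpu=True,
                                max_colocate_count=3)  # >1; exact value depends on how many models share a GPU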
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -37,7 +37,7 @@ dependencies = [
     "hydra-core",
     "numpy",
     "pybind11",
-    "ray==2.10",
+    "ray",
     "tensordict",
     "transformers",
     "vllm<=0.6.3",
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,7 +5,7 @@ dill
 hydra-core
 numpy
 pybind11
-ray==2.10
+ray
 tensordict<0.6
 transformers
 vllm<=0.6.3
9 changes: 6 additions & 3 deletions verl/trainer/ppo/ray_trainer.py
@@ -60,9 +60,9 @@ class ResourcePoolManager:

     def create_resource_pool(self):
         for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
-            # Due to the Ray issue, we can only support max_colocate_count=1 for now.
-            # This means that each GPU can only have one process.
-            # We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385
+            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool.
+            # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
+            # For the Megatron backend, we recommend max_colocate_count>1, which can use a different WorkerGroup for each model.
             resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
                                             use_gpu=True,
                                             max_colocate_count=1,
@@ -377,11 +377,14 @@ def init_workers(self):
         # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
         # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
         all_wg = {}
+        self.wg_dicts = []
         for resource_pool, class_dict in self.resource_pool_to_cls.items():
             worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
             wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
             spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
             all_wg.update(spawn_wg)
+            # keep a reference to each WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
+            self.wg_dicts.append(wg_dict)

         if self.use_critic:
             self.critic_wg = all_wg['critic']
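
The added self.wg_dicts list is what actually pins the colocated WorkerDict actors in memory. As a rough illustration of the underlying Ray behavior (not verl code): an actor is terminated once its last strong handle is garbage-collected, which is why the handle created for each WorkerDict has to be stored somewhere long-lived. The Worker class and make_worker helper below are hypothetical stand-ins.

# Minimal sketch of the Ray lifetime rule this fix relies on; not part of verl.
import ray

ray.init()

@ray.remote
class Worker:                      # hypothetical stand-in for the colocated WorkerDict actor
    def ping(self):
        return "alive"

def make_worker():
    return Worker.remote()         # the returned handle is the only strong reference

handles = []                       # plays the role of self.wg_dicts in init_workers
handles.append(make_worker())      # keeping the handle keeps the actor alive
print(ray.get(handles[0].ping.remote()))  # -> "alive"

# If only derived or weakly-held references survived (the situation with spawned
# worker groups under ray >= 2.31 before this fix), the actor could be
# garbage-collected and later calls would fail because the actor has exited.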
