Skip to content

Commit b1915d2

Browse files
authored
Merge pull request #6391 from hpcaitech/grpo-zero-bubble-rebase
[feat] Add zero-bubble support for RL
2 parents e5fdefa + eb158eb commit b1915d2

File tree

16 files changed

+2429
-23
lines changed

16 files changed

+2429
-23
lines changed

.github/workflows/run_chatgpt_examples.yml

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
2020
runs-on: [self-hosted, ubuntu-latest]
2121
container:
22-
image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
22+
image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.5.1-12.4.1
2323
options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data --shm-size=10.24gb
2424
timeout-minutes: 180
2525
defaults:
@@ -29,24 +29,32 @@ jobs:
2929
- name: Checkout ColossalAI
3030
uses: actions/checkout@v2
3131

32+
- name: Install torch
33+
run: |
34+
pip uninstall flash-attn
35+
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
36+
37+
- name: Install flash-attn
38+
run: |
39+
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
40+
3241
- name: Install Colossal-AI
3342
run: |
34-
pip install --no-cache-dir -v -e .
43+
BUILD_EXT=1 pip install --no-cache-dir -v -e .
3544
3645
- name: Install ChatGPT
3746
env:
3847
CFLAGS: "-O1"
3948
CXXFLAGS: "-O1"
4049
MAX_JOBS: 4
4150
run: |
42-
pip install flash-attn --no-build-isolation
4351
cd applications/ColossalChat
44-
pip install --no-cache-dir -v .
52+
pip install --no-cache-dir -v -e .
4553
pip install --no-cache-dir -r examples/requirements.txt
4654
47-
- name: Install Transformers
48-
run: |
49-
pip install --no-cache-dir transformers==4.36.2
55+
# - name: Install Transformers
56+
# run: |
57+
# pip install --no-cache-dir transformers==4.36.2
5058

5159
- name: Execute Examples
5260
run: |

applications/ColossalChat/coati/distributed/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ This repository implements a distributed Reinforcement Learning (RL) training fr
1414
* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through parallel inferencer-trainer architecture.
1515
* **Evaluation Integration**: Easily plug in task-specific eval datasets.
1616
* **Checkpoints and Logging**: Configurable intervals and directories.
17+
* **[New] Zero Bubble Training**: A zero-bubble training framework that supports GRPO and DAPO. [(read more)](./zero_bubble/README.md)
1718

1819
---
1920

applications/ColossalChat/coati/distributed/comm.py

Lines changed: 106 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import copy
12
from typing import Any, Dict
23

4+
import ray
35
import ray.util.collective as cc
46
import torch
57
import torch.distributed.distributed_c10d as c10d
@@ -32,26 +34,121 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str =
3234

3335

3436
def ray_broadcast_tensor_dict(
35-
tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
37+
tensor_dict: Dict[str, torch.Tensor],
38+
src: int = 0,
39+
device=None,
40+
group_name: str = "default",
41+
backend: str = "nccl",
42+
offload_to_cpu: bool = False,
43+
pin_memory: bool = False,
3644
) -> Dict[str, torch.Tensor]:
3745
rank = cc.get_rank(group_name)
46+
if tensor_dict is None:
47+
tensor_dict = {}
3848
if rank == src:
3949
metadata = []
4050
for k, v in tensor_dict.items():
4151
metadata.append((k, v.shape, v.dtype))
4252
else:
4353
metadata = None
4454
metadata = ray_broadcast_object(metadata, src, device, group_name)
45-
if rank != src:
46-
out_dict = {}
4755
for k, shape, dtype in metadata:
4856
if rank == src:
49-
tensor = tensor_dict[k]
57+
if offload_to_cpu:
58+
tensor = tensor_dict[k].to(device)
59+
else:
60+
tensor = tensor_dict[k]
5061
else:
51-
tensor = torch.empty(shape, dtype=dtype, device=device)
62+
tensor = tensor_dict.get(k, torch.zeros(shape, dtype=dtype, device=device, pin_memory=pin_memory))
63+
if backend == "gloo" and dtype == torch.bfloat16:
64+
# Gloo does not support bfloat16, so reinterpret the bits as float16 (same element size) for the broadcast
65+
tensor = tensor.view(torch.float16)
5266
cc.broadcast(tensor, src, group_name)
67+
if backend == "gloo" and dtype == torch.bfloat16:
68+
# Reinterpret the bits back as bfloat16 if the tensor was viewed as float16 for the broadcast
69+
tensor = tensor.view(torch.bfloat16)
5370
if rank != src:
54-
out_dict[k] = tensor
55-
if rank == src:
56-
out_dict = tensor_dict
57-
return out_dict
71+
if offload_to_cpu:
72+
tensor_dict[k] = tensor.cpu()
73+
else:
74+
tensor_dict[k] = tensor
75+
return tensor_dict
76+
77+
78+
@ray.remote
79+
class SharedVariableActor:
80+
def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000):
81+
self.data_queue = []
82+
self.data_uid = 0
83+
self.number_of_readers = number_of_readers
84+
self.queue_size = 0
85+
self.signals = {}
86+
self.process_locks = {}
87+
self.signal_procs_meet_count = {}
88+
self.buffer_size_limit = buffer_size_limit
89+
90+
def pickup_rollout_task(self, num_tasks: int):
91+
"""
92+
Use the queue size to decide whether producers should generate new rollouts or wait
93+
for the consumer to consume more data. If the queue size is below the threshold,
94+
the consumer is consuming data fast enough, so producers can generate new rollouts.
95+
If the queue size has reached the threshold, the consumer is consuming data too slowly,
96+
so producers should wait for the consumer to consume more data.
97+
98+
Any free producer can pick up the task to generate rollout then increase the queued_data_size
99+
to prevent other producer to pick up the task redundantly, Note it is not the real
100+
queue length as data may still be generating
101+
"""
102+
ret = False
103+
if self.queue_size < (self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))):
104+
ret = True
105+
self.queue_size += num_tasks
106+
return ret
107+
108+
def append_data(self, data):
109+
self.data_queue.append([self.data_uid, data, 0]) # [data_uid, data, access_count]
110+
self.data_uid += 1
111+
return True
112+
113+
def get_data(self, data_uid: int):
114+
# for multi-process data reading
115+
if not self.data_queue:
116+
# no data in the queue, return None
117+
return None
118+
to_pop_index = None
119+
ret = None
120+
for i, (uid, data, access_count) in enumerate(self.data_queue):
121+
if uid == data_uid:
122+
# found the data with the given uid
123+
self.data_queue[i][2] += 1
124+
ret = copy.deepcopy(data)
125+
if self.data_queue[i][2] == self.number_of_readers:
126+
to_pop_index = i
127+
break
128+
if to_pop_index is not None:
129+
# remove the data from the queue if it has been accessed by all readers
130+
self.data_queue.pop(to_pop_index)
131+
self.queue_size -= data["input_ids"].size(0)
132+
return ret
133+
134+
def acquire_process_lock(self, key: str):
135+
# atomic lock for process
136+
if key not in self.process_locks:
137+
self.process_locks[key] = 1 # locked
138+
return 0
139+
if self.process_locks[key] == 0:
140+
self.process_locks[key] = 1 # lock the process
141+
return 0
142+
else:
143+
return 1
144+
145+
def release_process_lock(self, key: str):
146+
# atomic unlock for process
147+
assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked."
148+
self.process_locks[key] = 0
149+
150+
def set_signal(self, key: str, signal: str):
151+
self.signals[key] = signal
152+
153+
def get_signal(self):
154+
return self.signals

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
generate_config: Dict[str, Any],
6060
tokenizer: PreTrainedTokenizer,
6161
num_generations: int = 8,
62+
tokenizer_config: Dict[str, Any] = None,
6263
):
6364
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
6465
model_config.update(self.FORCE_MODEL_CONFIG)
@@ -132,6 +133,7 @@ def __init__(
132133
generate_config: Dict[str, Any],
133134
tokenizer: PreTrainedTokenizer,
134135
num_generations: int = 8,
136+
tokenizer_config: Dict[str, Any] = None,
135137
):
136138
if sgl is None:
137139
raise ImportError("sglang is not installed")
@@ -196,12 +198,14 @@ def __init__(
196198
generate_config: Dict[str, Any],
197199
tokenizer: PreTrainedTokenizer,
198200
num_generations: int = 8,
201+
tokenizer_config: Dict[str, Any] = None,
199202
):
200203
if LLM is None:
201204
raise ImportError("vllm is not installed")
202205
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
203206
path = model_config.pop("path")
204-
self.llm = LLM(model=path, **model_config)
207+
tokenizer_path = tokenizer_config.get("path", None) if tokenizer_config is not None else None
208+
self.llm = LLM(model=path, tokenizer=tokenizer_path, **model_config)
205209
generate_config = generate_config.copy()
206210
generate_config.update(self.FORCE_GENERATE_CONFIG)
207211
generate_config.update({"n": num_generations})

0 commit comments

Comments
 (0)