|
| 1 | +import copy |
1 | 2 | from typing import Any, Dict |
2 | 3 |
|
| 4 | +import ray |
3 | 5 | import ray.util.collective as cc |
4 | 6 | import torch |
5 | 7 | import torch.distributed.distributed_c10d as c10d |
@@ -32,26 +34,121 @@ def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = |
32 | 34 |
|
33 | 35 |
|
def ray_broadcast_tensor_dict(
    tensor_dict: Dict[str, torch.Tensor],
    src: int = 0,
    device=None,
    group_name: str = "default",
    backend: str = "nccl",
    offload_to_cpu: bool = False,
    pin_memory: bool = False,
) -> Dict[str, torch.Tensor]:
    """Broadcast a dict of tensors from rank ``src`` to all ranks in ``group_name``.

    The (key, shape, dtype) metadata is broadcast first so receiving ranks can
    allocate matching buffers, then each tensor is broadcast in metadata order.
    This is a collective: every rank in the group must call it together.

    Args:
        tensor_dict: On ``src``, the tensors to send. On other ranks, optional
            pre-existing buffers to receive into; ``None`` is treated as an
            empty dict. NOTE: updated in place on receiver ranks and returned.
        src: Source rank within the collective group.
        device: Device the collective communicates on (e.g. a CUDA device for
            NCCL); receive buffers are allocated here.
        group_name: Name of the ray collective group.
        backend: Collective backend; ``"gloo"`` gets a bfloat16 workaround
            because gloo cannot transport bfloat16 directly.
        offload_to_cpu: If True, ``tensor_dict`` holds CPU copies; tensors are
            staged onto ``device`` for the broadcast and stored back on CPU.
        pin_memory: Allocate receive buffers in pinned memory. Only meaningful
            for CPU allocations; ignored when ``device`` is not CPU.

    Returns:
        ``tensor_dict`` with every broadcast tensor present.
    """
    rank = cc.get_rank(group_name)
    if tensor_dict is None:
        tensor_dict = {}
    if rank == src:
        metadata = [(k, v.shape, v.dtype) for k, v in tensor_dict.items()]
    else:
        metadata = None
    metadata = ray_broadcast_object(metadata, src, device, group_name)
    for k, shape, dtype in metadata:
        if rank == src:
            # Stage CPU-offloaded tensors onto the communication device.
            tensor = tensor_dict[k].to(device) if offload_to_cpu else tensor_dict[k]
        else:
            tensor = tensor_dict.get(k)
            if tensor is None:
                # Allocate the receive buffer lazily: the previous
                # ``dict.get(k, torch.zeros(...))`` form built the default
                # tensor even when the buffer already existed. ``empty`` is
                # enough because the broadcast fully overwrites it, and
                # pin_memory is only legal for CPU tensors, so it is dropped
                # for non-CPU devices instead of raising.
                on_cpu = device is None or torch.device(device).type == "cpu"
                tensor = torch.empty(
                    shape, dtype=dtype, device=device, pin_memory=pin_memory and on_cpu
                )
            elif offload_to_cpu:
                # Mirror the src branch: with offloading, buffers stored by a
                # previous call live on CPU and must be staged back onto the
                # communication device before a (possibly NCCL) broadcast.
                tensor = tensor.to(device)
        if backend == "gloo" and dtype == torch.bfloat16:
            # Gloo does not support bfloat16: reinterpret the bits as float16
            # for transport (bitwise view, not a numeric cast).
            tensor = tensor.view(torch.float16)
        cc.broadcast(tensor, src, group_name)
        if backend == "gloo" and dtype == torch.bfloat16:
            # Reinterpret the received bits back as bfloat16.
            tensor = tensor.view(torch.bfloat16)
        if rank != src:
            tensor_dict[k] = tensor.cpu() if offload_to_cpu else tensor
    return tensor_dict
| 76 | + |
| 77 | + |
@ray.remote
class SharedVariableActor:
    """Ray actor holding shared state between rollout producers and consumers.

    Maintains a bounded queue of rollout data (each item readable by a fixed
    number of readers before eviction), simple named locks, and a free-form
    signal dictionary.
    """

    def __init__(self, number_of_readers: int = 0, buffer_size_limit: int = 1000):
        # Queue entries are [data_uid, data, access_count].
        self.data_queue = []
        self.data_uid = 0
        self.number_of_readers = number_of_readers
        # Approximate number of queued samples (incremented when a producer
        # claims a task, decremented when data is fully consumed).
        self.queue_size = 0
        self.signals = {}
        self.process_locks = {}
        self.signal_procs_meet_count = {}
        self.buffer_size_limit = buffer_size_limit

    def pickup_rollout_task(self, num_tasks: int):
        """
        use queue size to control whether producers should generating new rollouts or wait
        for consumer to consumer more data. if queue size is less than threshold,
        it means consumer is consuming data fast enough, so producers can generate new rollouts.
        if queue size is greater than threshold, it means consumer is consuming data slowly,
        so producers should wait for consumer to consume more data.

        Any free producer can pick up the task to generate rollout then increase the queued_data_size
        to prevent other producer to pick up the task redundantly, Note it is not the real
        queue length as data may still be generating
        """
        # NOTE(review): signals are set via set_signal(key, signal: str) but
        # "sample_utilization" is used here as a number — confirm callers store
        # a float for this key.
        threshold = self.buffer_size_limit / max(0.1, self.signals.get("sample_utilization", 1.0))
        if not (self.queue_size < threshold):
            return False
        self.queue_size += num_tasks
        return True

    def append_data(self, data):
        """Enqueue a new data item with a fresh uid and zero access count."""
        entry = [self.data_uid, data, 0]  # [data_uid, data, access_count]
        self.data_uid += 1
        self.data_queue.append(entry)
        return True

    def get_data(self, data_uid: int):
        """Return a deep copy of the item with ``data_uid``, or None if absent.

        Each reader bumps the item's access count; once all readers have seen
        it, the item is evicted and the queue size reduced by its batch size.
        """
        # for multi-process data reading
        if not self.data_queue:
            # no data in the queue, return None
            return None
        for index, entry in enumerate(self.data_queue):
            if entry[0] != data_uid:
                continue
            # found the data with the given uid
            entry[2] += 1
            result = copy.deepcopy(entry[1])
            if entry[2] == self.number_of_readers:
                # evict once every reader has consumed this item
                self.data_queue.pop(index)
                # assumes data is a dict of batched tensors with "input_ids"
                self.queue_size -= entry[1]["input_ids"].size(0)
            return result
        return None

    def acquire_process_lock(self, key: str):
        """Try to take the named lock; return 0 on success, 1 if already held."""
        if self.process_locks.get(key, 0) == 0:
            self.process_locks[key] = 1  # lock the process
            return 0
        return 1

    def release_process_lock(self, key: str):
        """Release the named lock; the lock must currently be held."""
        # NOTE(review): assert is stripped under -O; kept for behavior parity.
        assert self.process_locks.get(key, 0) == 1, f"Releasing a process lock {key} that is not locked."
        self.process_locks[key] = 0

    def set_signal(self, key: str, signal: str):
        """Store an arbitrary signal value under ``key``."""
        self.signals[key] = signal

    def get_signal(self):
        """Return the full signal dictionary."""
        return self.signals
0 commit comments