Commit d9acc2d

Group allreduce futures

1 parent 7ab4fd4

File tree: 5 files changed, +102 -67 lines

torchft/collectives.py

Lines changed: 6 additions & 3 deletions
@@ -135,7 +135,7 @@ def allocate_reduce_scatter_output(
     return tensor, padded_sizes
 
 
-class _QuantizedOpFuture(Future[None]):
+class _QuantizedOpFuture(Future[list[torch.Tensor]]):
     def __init__(
         self,
         sync_stream: cuda.Stream,
@@ -145,11 +145,12 @@ def __init__(
         self._sync_stream = sync_stream
         self._keep_alive_tensors = keep_alive_tensors
 
-    def wait(self) -> None:
+    def wait(self) -> list[torch.Tensor]:
         # Wait for the synchronization to complete.
         cuda.current_stream().wait_stream(self._sync_stream)
         # Clean up intermediate buffers.
         del self._keep_alive_tensors
+        return []
 
 
 def reduce_scatter_quantized(
@@ -284,7 +285,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Future[list[torch.Tensor]]:
     """
     Performs a quantized all-reduce operation on a list of tensors.
@@ -314,6 +315,8 @@ def allreduce_quantized(
         A Future that can be used to wait for the operation to complete and
         clean up intermediate buffers.
 
+        The future's value is set to an empty list
+
     Raises:
         NotImplementedError: If the reduce operation is not ReduceOp.AVG.
     """

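Why the typed return matters downstream: a future whose value is a list composes with .then() the same way the tensor-carrying future on the non-quantized path does (see torchft/manager.py below). A minimal illustrative sketch using only torch.futures; the names here are not from the commit:

import torch

# A Future whose value is a list can be chained with .then(), exactly like the
# tensor-carrying future on the non-quantized allreduce path.
fut: torch.futures.Future = torch.futures.Future()

def callback(f: torch.futures.Future) -> int:
    payload = f.value()  # the (possibly empty) list the op resolves to
    return len(payload)

chained = fut.then(callback)
fut.set_result([])       # mirrors "the future's value is set to an empty list"
print(chained.wait())    # 0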
torchft/local_sgd.py

Lines changed: 41 additions & 33 deletions
@@ -148,15 +148,15 @@ def _average(self) -> list[torch.Tensor]:
         """
         Averages the model parameters across the manager and returns the averaged parameters.
         """
-        works = []
         averaged_parameters = []
         for p in self._model.parameters():
             # Create a new tensor to store the averaged parameter
             avg_param = extract_local_tensor(p)
-            works.append(self._manager.allreduce(avg_param))
             averaged_parameters.append(avg_param)
-        for work in works:
-            work.wait()
+
+        work = self._manager.collect_all_allreduce(averaged_parameters)
+        work.wait()
+
         return averaged_parameters
 
 
@@ -197,9 +197,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer
 
         # Stores pending all reduce
-        self._allreduce_futures: list[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[torch.futures.Future[None]] = []
 
         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
@@ -324,18 +322,27 @@ def _average_grads(self) -> None:
 
     def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
+        tensors = []
+
         for p in self._model_fragment.parameters():
             # Perform allreduce on the pseudogradients
             assert p.grad is not None
             if isinstance(p, DTensor):
-                work = self._manager.allreduce(
-                    p.grad._local_tensor, should_quantize=self.should_quantize
-                )
+                tensors.append(p.grad._local_tensor)
             else:
-                work = self._manager.allreduce(
-                    p.grad, should_quantize=self.should_quantize
-                )
-            self._allreduce_futures.append(work)
+                tensors.append(p.grad)
+
+        work = self._manager.collect_all_allreduce(
+            tensors, should_quantize=self.should_quantize
+        )
+
+        def callback(
+            fut: torch.futures.Future[List[torch.futures.Future[torch.Tensor]]],
+        ) -> None:
+            return
+
+        work = work.then(callback)
+        self._allreduce_futures.append(work)
 
     def bucketize_and_allreduce(
         self,
@@ -355,6 +362,9 @@ def bucketize_and_allreduce(
         total_size = sum(t.numel() for t in tensors)
         dtype, device = tensors[0].dtype, tensors[0].device
 
+        flat_buffers: list[torch.Tensor] = []
+        all_bucket_tensors: list[list[Tuple[torch.Tensor, int, int]]] = []
+
         offset = 0
         flat_index = 0
         while offset < total_size:
@@ -376,19 +386,27 @@ def bucketize_and_allreduce(
                 pack_offset += numel
                 flat_index += 1
 
-            work = self._manager.allreduce(
-                flat_buffer, should_quantize=self.should_quantize
-            )
+            flat_buffers.append(flat_buffer)
+            all_bucket_tensors.append(bucket_tensors)
+
+            offset += chunk_size
 
-            def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
-                nonlocal bucket_tensors, flat_buffer
+        def callback(
+            fut: torch.futures.Future[List[torch.futures.Future[torch.Tensor]]],
+        ) -> None:
+            nonlocal all_bucket_tensors, flat_buffers
+
+            for i in range(len(flat_buffers)):
+                bucket_tensors = all_bucket_tensors[i]
+                flat_buffer = flat_buffers[i]
                 for t, pack_offset, numel in bucket_tensors:
                     t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
 
-            work = work.then(callback)
-            self._allreduce_futures.append(work)
-
-            offset += chunk_size
+        work = self._manager.collect_all_allreduce(
+            flat_buffers, should_quantize=self.should_quantize
+        )
+        work = work.then(callback)
+        self._allreduce_futures.append(work)
 
     def _allreduce_bucketized(self) -> None:
         """
@@ -465,16 +483,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")
 
-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(
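A minimal sketch of the grouping pattern bucketize_and_allreduce now uses: all per-bucket futures are collapsed into one via torch.futures.collect_all, and a single callback unpacks every flat buffer. The fake_allreduce helper and tensor shapes below are illustrative stand-ins, not torchft code:

import torch

# fake_allreduce stands in for self._manager.allreduce and resolves
# immediately; in the real code each future completes asynchronously.
def fake_allreduce(buf: torch.Tensor) -> torch.futures.Future:
    fut = torch.futures.Future()
    fut.set_result(buf)
    return fut

params = [torch.randn(4), torch.randn(2, 3)]
flat_buffers = [p.clone().flatten() for p in params]

# One grouped future for all buckets instead of one future per bucket.
grouped = torch.futures.collect_all([fake_allreduce(b) for b in flat_buffers])

def callback(fut: torch.futures.Future) -> None:
    # A single callback unpacks every flat buffer back into its tensor.
    for p, buf in zip(params, flat_buffers):
        p.copy_(buf.view_as(p))

pending = grouped.then(callback)  # one entry in _allreduce_futures
pending.wait()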

torchft/local_sgd_test.py

Lines changed: 18 additions & 15 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict
+from typing import Dict, List
 from unittest import TestCase
 from unittest.mock import MagicMock, create_autospec
 
@@ -74,7 +74,7 @@ def test_local_sgd_healthy(self) -> None:
         manager.should_commit.return_value = True
         self.assertEqual(local_sgd._local_step, 0)
         self.assertEqual(manager.should_commit.call_count, 1)
-        self.assertEqual(manager.allreduce.call_count, 4)
+        self.assertEqual(manager.collect_all_allreduce.call_count, 1)
 
     def test_extract_local_tensor(self) -> None:
         regular_tensor = torch.rand(3, 3, requires_grad=True)
@@ -160,7 +160,7 @@ def test_diloco_healthy(self) -> None:
             diloco._fragments[0].original_parameters, _params_dict(model)
         )
         self.assertEqual(manager.should_commit.call_count, 1)
-        self.assertEqual(manager.allreduce.call_count, parameter_count)
+        self.assertEqual(manager.collect_all_allreduce.call_count, 1)
 
         outer_opt_state = outer_optimizer.state_dict()
         self.assertEqual(len(outer_opt_state["state"]), parameter_count)
@@ -207,13 +207,12 @@ def test_diloco_allreduce_call_efficiency(
             loss.backward()
             inner_optimizer.step()
 
-        allreduce_calls = manager.allreduce.call_count
-        param_count = len([p for p in model.parameters() if p.requires_grad])
+        allreduce_calls = manager.collect_all_allreduce.call_count
 
         if expect_fewer_calls:
-            self.assertLess(int(allreduce_calls), int(param_count))
+            self.assertEqual(int(allreduce_calls), 1)
         else:
-            self.assertEqual(int(allreduce_calls), int(param_count))
+            self.assertEqual(int(allreduce_calls), 1)
 
     def test_bucketization_correctness(self) -> None:
         class TinyModel(nn.Module):
@@ -238,16 +237,20 @@ def forward(self, x):
         manager._use_async_quorum = False
         manager.should_commit.return_value = True
 
-        # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
-            tensor.mul_(2)
+        # Define fake collect_all_allreduce: multiplies all buffers by 2
+        def fake_collect_all_allreduce(
+            tensors: List[Tensor], should_quantize: bool
+        ) -> torch.futures.Future[List[torch.futures.Future[Tensor]]]:
+            for tensor in tensors:
+                tensor.mul_(2)
             fut = torch.futures.Future() # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            fut.set_result(tensors)
 
-        manager.allreduce.side_effect = fake_allreduce
+            futs = torch.futures.Future() # pyre-fixme[29]: not a function
+            futs.set_result([fut])
+            return futs
+
+        manager.collect_all_allreduce.side_effect = fake_collect_all_allreduce
 
         diloco = DiLoCo(
             manager, [model], inner_opt, outer_opt, sync_every=2, use_bucketization=True
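For reference, a small sketch of the nested future shape the fake above hands back: the outer future resolves to a list of inner futures, and each inner future's value is the reduced payload. The tensor values here are purely illustrative:

import torch

# Same shape as futs.set_result([fut]) in the test fake above.
inner = torch.futures.Future()
inner.set_result([torch.ones(2) * 2])  # "reduced" payload

outer = torch.futures.Future()
outer.set_result([inner])              # Future whose value is a list of futures

for f in outer.wait():                 # what the waiting side iterates over
    print(f.value())                   # -> [tensor([2., 2.])]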

torchft/manager.py

Lines changed: 37 additions & 13 deletions
@@ -288,8 +288,33 @@ def shutdown(self, wait: bool = True) -> None:
         self._manager.shutdown()
         self._executor.shutdown(wait=wait)
 
+    def collect_all_allreduce(
+        self, tensors: List[torch.Tensor], should_quantize: bool = False
+    ) -> torch.futures.Future[List[torch.futures.Future[torch.Tensor]]]:
+        futs: List[torch.futures.Future[torch.Tensor]] = []
+        default_futs: List[torch.futures.Future[torch.Tensor]] = []
+
+        for tensor in tensors:
+            fut = self.allreduce(tensor, should_quantize=should_quantize)
+            futs.append(fut)
+
+            default_fut = torch.futures.Future() # pyre-fixme[29]: not a function
+            default_fut.set_result(tensor)
+            default_futs.append(default_fut)
+
+        fut = torch.futures.collect_all(futs)
+
+        return self.wrap_future(fut, default_futs)
+
     def allreduce(
         self, tensor: torch.Tensor, should_quantize: bool = False
+    ) -> torch.futures.Future[torch.Tensor]:
+        fut = self._allreduce(tensor, should_quantize=should_quantize)
+        fut = self.wrap_future(fut, tensor)
+        return fut
+
+    def _allreduce(
+        self, tensor: torch.Tensor, should_quantize: bool = False
     ) -> torch.futures.Future[torch.Tensor]:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
@@ -324,9 +349,8 @@ def allreduce(
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             fut: Optional[
-                torch.futures.Future[None]
+                torch.futures.Future[List[torch.Tensor]]
                 | torch.futures.Future[torch.Tensor]
-                | torch.futures.Future[List[torch.Tensor]]
             ] = None
             if should_quantize and IS_TRITON_AVAILABLE:
                 fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
@@ -341,19 +365,16 @@ def callback(
             ) -> torch.Tensor:
                 nonlocal tensor
 
-                # check for exceptions
                 fut.value()
 
-                tensor /= self.num_participants()
+                if not should_quantize:
+                    tensor /= self.num_participants()
 
                 return tensor
 
             assert fut is not None
-            if not should_quantize:
-                fut = fut.then(callback)
-                fut = self.wrap_future(fut, tensor)
+            fut = fut.then(callback)
             return fut
-
         except Exception as e:
             self._logger.exception(
                 f"got exception in all reduce -- skipping remaining: {e}"
@@ -686,21 +707,24 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         Raises:
             RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
         """
-        for work in self._pending_work:
-            # check at the beginning of since .wait() may trigger errors
-            if self._errored is not None:
+        while True:
+            if len(self._pending_work) == 0:
                 break
 
+            work = self._pending_work.pop(0)
             # We swallow the error at in a future then callback so this will
            # never return an error.
             work.wait()
 
+            # Remove all work if there was an error.
+            # We won't commit in this case as well.
+            if self._errored is None:
+                break
+
         # make sure recovery is complete before committing
         if self._recovery_stream is not None:
             self._recovery_stream.synchronize()
 
-        self._pending_work = []
-
         if err := self._pg.errored():
             self.report_error(err)
torchft/process_group.py

Lines changed: 0 additions & 3 deletions
@@ -774,9 +774,6 @@ def abort(self) -> None:
         super().abort()
 
     def errored(self) -> Optional[Exception]:
-        # force a synchronization to ensure all work is complete
-        torch.cuda.synchronize()
-
         return self._errored
 
     def getBackendName(self) -> str:
