Add a post-backward stream for more throughput improvement (#11)
* limit all-gathers

* Add another stream for post-backward stuff

* post-reduce-scatter

* updates

* fix

* revert

* show average throughput at the end

* exclude first batch

* init fn

* make mypy happy
epwalsh authored Apr 17, 2024
1 parent 2cfd3a7 commit 8c75d23
Showing 5 changed files with 83 additions and 22 deletions.
3 changes: 3 additions & 0 deletions src/benchmarks/fsdp/common.py
@@ -200,7 +200,10 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
use_orig_params=True,
param_init_fn=lambda m: init_function(m.to_empty(device=get_default_device())),
device_id=dist.get_rank(),
limit_all_gathers=True,
)

model.apply(init_function) # just in case
else:
raise NotImplementedError(fsdp_wrapper)

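For context, here is a minimal, hypothetical sketch (not part of the diff) of how these options fit together with PyTorch's built-in FSDP wrapper. `init_function` and the toy model are stand-ins for the benchmark's own code, and the script assumes a `torchrun` launch with one GPU per rank.

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def init_function(module: nn.Module) -> None:
    # Toy re-initialization; the benchmark's real init function lives in common.py.
    for p in module.parameters():
        if p.dim() > 1:
            nn.init.trunc_normal_(p, std=0.02)
        else:
            nn.init.zeros_(p)


def wrap(model: nn.Module) -> FSDP:
    device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
    torch.cuda.set_device(device)
    fsdp_model = FSDP(
        model,
        use_orig_params=True,
        # Materialize meta-device params on the GPU, then initialize them.
        param_init_fn=lambda m: init_function(m.to_empty(device=device)),
        device_id=device,
        # Rate-limit all-gathers so only a bounded number of unsharded
        # parameter groups are alive at once.
        limit_all_gathers=True,
    )
    fsdp_model.apply(init_function)  # just in case, as in the diff above
    return fsdp_model


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    with torch.device("meta"):
        model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
    print(wrap(model))
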
12 changes: 10 additions & 2 deletions src/benchmarks/fsdp/train.py
@@ -7,6 +7,7 @@
import contextlib
import logging
import time
from collections import deque
from pathlib import Path
from typing import Literal, Optional

@@ -78,6 +79,7 @@ def on_trace_ready(p)
)

print_rank0("Starting training...")
batch_times: deque[float] = deque([], 50)
with profiler as p:
for i, batch in enumerate(iter(dataloader)):
log.debug("Batch: %s", batch)
@@ -99,16 +101,22 @@ def on_trace_ready(p)
# Take optimizer step.
optim.step()

batch_end = time.monotonic()
batch_time = time.monotonic() - batch_start
if i > 0:
batch_times.append(batch_time)
print_rank0(
f"Batch [{i+1}/{num_batches}]:\n"
f" loss={loss.item():.3f}\n"
f" throughput/seconds_per_batch={batch_end-batch_start:.3f}",
f" throughput/seconds_per_batch={batch_time:.3f}",
)

if p is not None:
p.step()

if batch_times:
time_per_batch = sum(batch_times) / len(batch_times)
print_rank0(f"Average throughput: {time_per_batch:.3f}s/b")

if save_path is not None:
checkpoint_dir = Path(save_path) / "final"
print_rank0(f"Saving checkpoint to {checkpoint_dir}...")
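As a self-contained illustration (not part of the diff), the throughput bookkeeping above boils down to a bounded deque of per-batch times: skip the first (warm-up) batch, keep the last 50 times, and average at the end. The loop body here is simulated with a sleep.

import time
from collections import deque

num_batches = 5
batch_times: deque[float] = deque([], 50)  # second argument is maxlen: keep only the last 50 times

for i in range(num_batches):
    batch_start = time.monotonic()
    time.sleep(0.01)  # stand-in for the forward/backward/optimizer step
    batch_time = time.monotonic() - batch_start
    if i > 0:  # exclude the first batch, which pays one-time warm-up costs
        batch_times.append(batch_time)
    print(f"Batch [{i + 1}/{num_batches}]: throughput/seconds_per_batch={batch_time:.3f}")

if batch_times:
    time_per_batch = sum(batch_times) / len(batch_times)
    print(f"Average: {time_per_batch:.3f}s/b")
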
74 changes: 56 additions & 18 deletions src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -61,6 +61,11 @@ class FlatParamHandle:
the same shape as the sharded version of ``params_data``.
"""

params_sharded_grad_tmp: Optional[torch.Tensor] = None
"""
Temporary storage for the local consolidated sharded grads during the reduce-scatter.
"""

process_group: Optional[dist.ProcessGroup] = None

device: Optional[torch.device] = None
@@ -69,6 +74,8 @@

_ran_pre_unshard: bool = False

_ran_pre_reduce_scatter_grads: bool = False

@classmethod
def shard_params(
cls,
@@ -279,6 +286,8 @@ def unshard_(
if self.params_sharded_data_lp is not None:
Stream.current(self.device).record_for(self.params_sharded_data_lp)

self._ran_pre_unshard = False

# Gather full, padded, unsharded data for all params.
if rank0_only or dist.get_backend() == dist.Backend.GLOO:
assert self.params_data.is_sharded
@@ -325,8 +334,6 @@ def reshard_(self, writeback: bool = False):
"""
Reshard the handle's managed flat parameters in-place.
"""
self._ran_pre_unshard = False

if not self.params:
return

@@ -339,43 +346,74 @@
flat_param.data = self.params_data[offset : offset + flat_param.sharded_numel]
offset += flat_param.sharded_numel

def pre_reduce_scatter_grads_(
self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
):
"""
Cast the unsharded, padded gradient to the reduce dtype if needed.
"""
self._ran_pre_reduce_scatter_grads = True

if not self.requires_grad or self.params_unsharded_grad is None:
return

grad_dtype = grad_dtype or self.params_data.dtype
grad_reduce_dtype = grad_reduce_dtype or grad_dtype

if grad_reduce_dtype != self.params_unsharded_grad.dtype:
Stream.current(self.device).record_for(self.params_unsharded_grad)
self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)

self.params_sharded_grad_tmp = torch.empty(
self.params_data.sharded_shape, dtype=self.params_unsharded_grad.dtype, device=self.device
)

def reduce_scatter_grads_(
self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
):
"""
Reduce-scatter the unsharded, padded gradient, and set the ``.grad`` attribute of each
parameter as a view into the new sharded grad.
"""
if not self.requires_grad or self.params_unsharded_grad is None:
return

local_rank = get_rank(self.process_group)
if not self._ran_pre_reduce_scatter_grads:
self.pre_reduce_scatter_grads_(grad_dtype=grad_dtype, grad_reduce_dtype=grad_reduce_dtype)
assert self.params_sharded_grad_tmp is not None
else:
assert self.params_sharded_grad_tmp is not None
Stream.current(self.device).record_for(self.params_unsharded_grad)
Stream.current(self.device).record_for(self.params_sharded_grad_tmp)

self._ran_pre_reduce_scatter_grads = False

grad_dtype = grad_dtype or self.params_data.dtype
grad_reduce_dtype = grad_reduce_dtype or grad_dtype
assert self.params_unsharded_grad.dtype == grad_reduce_dtype

# Reduce the unsharded padded grad for all params.
# NOTE: Only NCCL supports reduce-scatter. So with other backends we use all-reduce.
all_params_unsharded_padded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype or grad_dtype)
if dist.get_backend() == dist.Backend.NCCL:
# Get chunks corresponding to each rank.
grad_chunks = self.params_data.chunk_unsharded(all_params_unsharded_padded_grad)
new_sharded_grad = torch.empty_like(grad_chunks[local_rank])
dist.reduce_scatter(new_sharded_grad, grad_chunks, group=self.process_group)
new_sharded_grad = new_sharded_grad.to(dtype=grad_dtype)
grad_chunks = self.params_data.chunk_unsharded(self.params_unsharded_grad)
dist.reduce_scatter(self.params_sharded_grad_tmp, grad_chunks, group=self.process_group)
else:
dist.all_reduce(all_params_unsharded_padded_grad, group=self.process_group)
new_sharded_grad = self.params_data.sharded_chunk(all_params_unsharded_padded_grad).to(
dtype=grad_dtype
)
dist.all_reduce(self.params_unsharded_grad, group=self.process_group)
self.params_sharded_grad_tmp.copy_(self.params_data.sharded_chunk(self.params_unsharded_grad))

# Deallocate the unsharded padded grad.
del all_params_unsharded_padded_grad
# Since we're potentially using a separate stream for this reduce-scatter, we need to make
# NOTE: Since we're potentially using a separate stream for this reduce-scatter, we need to make
# sure `params_unsharded_grad` is not deallocated before the reduce-scatter finishes.
Stream.current(self.device).record_for(self.params_unsharded_grad)
self.params_unsharded_grad = None

# Cast the reduce-scatter target to the right dtype, potentially accumulating it into
# the existing gradient.
if self.params_sharded_grad is None:
self.params_sharded_grad = new_sharded_grad
self.params_sharded_grad = self.params_sharded_grad_tmp.to(grad_dtype)
else:
self.params_sharded_grad.add_(new_sharded_grad)

del new_sharded_grad
self.params_sharded_grad.add_(self.params_sharded_grad_tmp)

# At this point each param will be sharded again, and we set the grad for each param as a view
# into the sharded grad.
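To make the two reduction paths above concrete, here is a single-process sketch (not part of the diff) that simulates them with plain tensors. The real code issues dist.reduce_scatter or dist.all_reduce on the handle's buffers, but the arithmetic is the same: rank r ends up holding the sum, across ranks, of the r-th chunk of the padded unsharded gradient. The sizes below are made up for illustration.

import torch

world_size = 4
unsharded_numel = 10
sharded_numel = (unsharded_numel + world_size - 1) // world_size
padded_numel = sharded_numel * world_size

# One full, padded, unsharded gradient per simulated rank (trailing entries are padding).
grads = [torch.randn(padded_numel) for _ in range(world_size)]

# NCCL path: dist.reduce_scatter(out, chunks) -- rank r receives the sum of every
# rank's r-th chunk, written into the preallocated `params_sharded_grad_tmp`.
reduce_scatter_out = [
    torch.stack([g[r * sharded_numel : (r + 1) * sharded_numel] for g in grads]).sum(dim=0)
    for r in range(world_size)
]

# Fallback path (non-NCCL backends): all-reduce the whole padded gradient on every
# rank, then each rank copies out just its own sharded chunk.
all_reduced = torch.stack(grads).sum(dim=0)
fallback_out = [all_reduced[r * sharded_numel : (r + 1) * sharded_numel] for r in range(world_size)]

for r in range(world_size):
    assert torch.allclose(reduce_scatter_out[r], fallback_out[r])
print("reduce-scatter and all-reduce + local chunk agree")
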
11 changes: 9 additions & 2 deletions src/olmo_core/distributed/fsdp/fsdp.py
@@ -405,6 +405,7 @@ def _lazy_init(self):
self.state.compute_stream = Stream.default(self.device)
self.state.pre_unshard_stream = Stream.new(self.device)
self.state.unshard_stream = Stream.new(self.device)
self.state.post_backward_stream = Stream.new(self.device)
self.state.reduce_stream = Stream.new(self.device)

# Initialize execution order.
@@ -419,6 +420,7 @@ def _lazy_init(self):
compute_stream=self.state.compute_stream,
pre_unshard_stream=self.state.pre_unshard_stream,
unshard_stream=self.state.unshard_stream,
post_backward_stream=self.state.post_backward_stream,
reduce_stream=self.state.reduce_stream,
forward_execution_order=self.state.forward_execution_order,
forward_prefetch_queue=self.state.forward_prefetch_queue,
@@ -644,9 +646,14 @@ def _reduce_scatter_grads(self):
)
return

log.debug("Reduce-scattering grads for %s (%s)", self.module.__class__.__name__, id(self.module))
grad_reduce_dtype: Optional[torch.dtype] = self.precision.reduce_dtype or self.precision.param_dtype
with self.state.reduce_stream(wait_stream=self.state.current_stream):
log.debug("Reduce-scattering grads for %s (%s)", self.module.__class__.__name__, id(self.module))

with self.state.post_backward_stream(wait_stream=self.state.current_stream):
for handle in self.state.flat_param_handles:
handle.pre_reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)

with self.state.reduce_stream(wait_stream=self.state.post_backward_stream):
for handle in self.state.flat_param_handles:
handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)

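The stream choreography above can be sketched with raw torch.cuda streams standing in for olmo_core's Stream wrapper (a simplified, hypothetical rendering, not the library's actual code): the post-backward stream waits on the compute stream and performs the dtype cast, and the reduce stream waits on the post-backward stream before the collective would be issued. The idea is that the cast and the reduction can overlap with backward compute while the ordering the reduce-scatter needs is preserved.

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    compute_stream = torch.cuda.current_stream(device)
    post_backward_stream = torch.cuda.Stream(device=device)
    reduce_stream = torch.cuda.Stream(device=device)

    unsharded_grad = torch.randn(1 << 20, device=device)  # stand-in for the padded unsharded grad

    post_backward_stream.wait_stream(compute_stream)
    with torch.cuda.stream(post_backward_stream):
        # pre_reduce_scatter_grads_: cast to the reduce dtype off the compute stream.
        grad_for_reduce = unsharded_grad.to(torch.float16)
        # Tell the caching allocator this tensor is also used on this stream, so its
        # memory is not reclaimed while the cast is still in flight.
        unsharded_grad.record_stream(post_backward_stream)

    reduce_stream.wait_stream(post_backward_stream)
    with torch.cuda.stream(reduce_stream):
        # reduce_scatter_grads_: the reduce-scatter collective would be issued here.
        sharded_grad = grad_for_reduce.float()
        grad_for_reduce.record_stream(reduce_stream)

    torch.cuda.current_stream(device).wait_stream(reduce_stream)
    torch.cuda.synchronize()
    print(sharded_grad.shape)
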
5 changes: 5 additions & 0 deletions src/olmo_core/distributed/fsdp/state.py
@@ -96,6 +96,11 @@ class FSDPState:
Stream used during the all-gather for unsharding parameters.
"""

post_backward_stream: Stream = field(default_factory=Stream.default)
"""
Stream used during the post-backward hook to cast gradients in preparation for the reduce-scatter.
"""

reduce_stream: Stream = field(default_factory=Stream.default)
"""
Stream used during the reduce-scatter for reducing gradients after the backward pass.
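A tiny illustrative sketch (not the library's code) of the pattern FSDPState uses: stream handles live on a dataclass with field(default_factory=...) defaults, and lazy init later swaps in dedicated side streams. DummyStream below is a made-up stand-in for olmo_core's Stream.

from dataclasses import dataclass, field


class DummyStream:
    def __init__(self, name: str = "default") -> None:
        self.name = name

    @classmethod
    def default(cls) -> "DummyStream":
        return cls("default")

    @classmethod
    def new(cls, name: str) -> "DummyStream":
        return cls(name)


@dataclass
class MiniFSDPState:
    # Until lazy init runs, every field falls back to the default (compute) stream.
    unshard_stream: DummyStream = field(default_factory=DummyStream.default)
    post_backward_stream: DummyStream = field(default_factory=DummyStream.default)
    reduce_stream: DummyStream = field(default_factory=DummyStream.default)


state = MiniFSDPState()
state.post_backward_stream = DummyStream.new("post-backward")  # roughly what _lazy_init does
print([s.name for s in (state.unshard_stream, state.post_backward_stream, state.reduce_stream)])
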
