
Commit a4e0ccf

FSDP memory usage improvements (#12)
* set to none
* avoid allocating another sharded tensor
* try allocating one fewer
* clone when needed
* post reduce-scatter
* show memory usage when profiling
1 parent 8c75d23 commit a4e0ccf

File tree

3 files changed: +40 -25 lines changed

src/benchmarks/fsdp/train.py
src/olmo_core/distributed/fsdp/flat_param_handle.py
src/olmo_core/distributed/fsdp/fsdp.py

src/benchmarks/fsdp/train.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -78,6 +78,8 @@ def on_trace_ready(p):
         on_trace_ready=on_trace_ready,
     )
 
+    print_rank0(torch.cuda.memory_summary())
+
     print_rank0("Starting training...")
     batch_times: deque[float] = deque([], 50)
     with profiler as p:
@@ -86,7 +88,7 @@ def on_trace_ready(p):
             batch_start = time.monotonic()
 
             # Zero-gradients.
-            optim.zero_grad()
+            optim.zero_grad(set_to_none=True)
 
             # Run forward pass.
             with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
@@ -110,6 +112,9 @@ def on_trace_ready(p):
                 f" throughput/seconds_per_batch={batch_time:.3f}",
             )
 
+            if i == 2:
+                print_rank0(torch.cuda.memory_summary())
+
             if p is not None:
                 p.step()
 
```
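The two training-loop tweaks above are generic PyTorch idioms rather than anything FSDP-specific. A minimal sketch of what they buy, with hypothetical `model`, `optim`, and `batch` arguments (not the benchmark script's own objects): `zero_grad(set_to_none=True)` releases gradient storage between steps instead of keeping zero-filled tensors resident, and `torch.cuda.memory_summary()` reports the caching-allocator state once the loop has reached a steady state.

```python
import torch


def train_step(model: torch.nn.Module, optim: torch.optim.Optimizer, batch: torch.Tensor, step: int) -> None:
    # Setting grads to None frees their storage until backward re-allocates them,
    # rather than keeping zero-filled gradient tensors alive between steps.
    optim.zero_grad(set_to_none=True)

    loss = model(batch).sum()
    loss.backward()
    optim.step()

    # Dump the CUDA caching-allocator report once, a few steps in, after params,
    # grads, activations, and optimizer state have all been allocated at least once.
    if step == 2 and torch.cuda.is_available():
        print(torch.cuda.memory_summary())
```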

src/olmo_core/distributed/fsdp/flat_param_handle.py

Lines changed: 30 additions & 24 deletions
```diff
@@ -61,11 +61,6 @@ class FlatParamHandle:
     the same shape as the sharded version of ``params_data``.
     """
 
-    params_sharded_grad_tmp: Optional[torch.Tensor] = None
-    """
-    Temporary storage for the local consolidated sharded grads during the reduce-scatter.
-    """
-
     process_group: Optional[dist.ProcessGroup] = None
 
     device: Optional[torch.device] = None
@@ -254,7 +249,8 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
 
         # Cast sharded ``params_data`` to ``dtype``.
         if dtype is not None:
-            self.params_sharded_data_lp = self.params_data.sharded_data.to(dtype)
+            self.params_sharded_data_lp = self.params_data.sharded_chunk(all_params_unsharded_data)
+            self.params_sharded_data_lp.copy_(self.params_data.sharded_data)
 
         # Initialize unsharded, padded gradient.
         if set_grads and self.params_unsharded_grad is None:
@@ -364,10 +360,6 @@ def pre_reduce_scatter_grads_(
             Stream.current(self.device).record_for(self.params_unsharded_grad)
             self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)
 
-        self.params_sharded_grad_tmp = torch.empty(
-            self.params_data.sharded_shape, dtype=self.params_unsharded_grad.dtype, device=self.device
-        )
-
     def reduce_scatter_grads_(
         self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
     ):
@@ -380,11 +372,8 @@
 
         if not self._ran_pre_reduce_scatter_grads:
             self.pre_reduce_scatter_grads_(grad_dtype=grad_dtype, grad_reduce_dtype=grad_reduce_dtype)
-            assert self.params_sharded_grad_tmp is not None
         else:
-            assert self.params_sharded_grad_tmp is not None
             Stream.current(self.device).record_for(self.params_unsharded_grad)
-            Stream.current(self.device).record_for(self.params_sharded_grad_tmp)
 
         self._ran_pre_reduce_scatter_grads = False
 
@@ -397,23 +386,40 @@
         if dist.get_backend() == dist.Backend.NCCL:
             # Get chunks corresponding to each rank.
             grad_chunks = self.params_data.chunk_unsharded(self.params_unsharded_grad)
-            dist.reduce_scatter(self.params_sharded_grad_tmp, grad_chunks, group=self.process_group)
+            dist.reduce_scatter(
+                grad_chunks[get_rank(group=self.process_group)], grad_chunks, group=self.process_group
+            )
         else:
             dist.all_reduce(self.params_unsharded_grad, group=self.process_group)
-            self.params_sharded_grad_tmp.copy_(self.params_data.sharded_chunk(self.params_unsharded_grad))
 
-        # Deallocate the unsharded padded grad.
-        # NOTE: Since we're potentially using a separate stream for this reduce-scatter, we need to make
-        # sure `params_unsharded_grad` is not deallocated before the reduce-scatter finishes.
-        Stream.current(self.device).record_for(self.params_unsharded_grad)
-        self.params_unsharded_grad = None
+    def post_reduce_scatter_grads_(
+        self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
+    ):
+        """
+        Finalize sharded gradients after the reduce-scatter.
+        """
+        grad_dtype = grad_dtype or self.params_data.dtype
+        grad_reduce_dtype = grad_reduce_dtype or grad_dtype
+
+        assert self.params_unsharded_grad is not None
+        new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad)
 
-        # Cast the reduce-scatter target to the right dtype, potentially accumulating it into
-        # the existing gradient.
+        # Cast the new sharded gradient to the right dtype, potentially accumulating it into
+        # the existing sharded gradient.
         if self.params_sharded_grad is None:
-            self.params_sharded_grad = self.params_sharded_grad_tmp.to(grad_dtype)
+            if new_sharded_grad.dtype == grad_dtype:
+                self.params_sharded_grad = new_sharded_grad.clone()
+            else:
+                self.params_sharded_grad = new_sharded_grad.to(grad_dtype)
         else:
-            self.params_sharded_grad.add_(self.params_sharded_grad_tmp)
+            self.params_sharded_grad.add_(new_sharded_grad)
+
+        # Deallocate the unsharded padded grad.
+        # NOTE: Since we're potentially using a separate stream here, we need to make
+        # sure `params_unsharded_grad` is not deallocated before this finishes.
+        Stream.current(self.device).record_for(self.params_unsharded_grad)
+        self.params_unsharded_grad = None
+        del new_sharded_grad
 
         # At this point each param will be sharded again, and we set the grad for each param as a view
         # into the sharded grad.
```
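The memory win in this file comes from two things: the reduce-scatter now writes straight into this rank's chunk of the flat unsharded gradient (no dedicated `params_sharded_grad_tmp` output buffer), and the new `post_reduce_scatter_grads_` only clones or casts that small sharded view before the big flat buffer is released. A standalone sketch of the same pattern, assuming an already-initialized NCCL process group and a flat buffer padded to a multiple of the world size; the function and variable names below are illustrative, not the handle's API:

```python
import torch
import torch.distributed as dist


def reduce_scatter_into_own_chunk(flat_unsharded_grad: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter a flat, padded gradient buffer without allocating a separate output tensor."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # `chunk` returns views into the flat buffer, so no new memory is allocated here.
    grad_chunks = list(flat_unsharded_grad.chunk(world_size))

    # Reduce directly into this rank's own view instead of a scratch tensor.
    dist.reduce_scatter(grad_chunks[rank], grad_chunks)

    # The result is still a view into `flat_unsharded_grad`, so clone it (as
    # `post_reduce_scatter_grads_` does) before the flat buffer is dropped.
    return grad_chunks[rank].clone()
```

The clone is what allows `params_unsharded_grad` to be set to `None` right afterwards while only the per-rank shard stays resident.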

src/olmo_core/distributed/fsdp/fsdp.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -657,6 +657,10 @@ def _reduce_scatter_grads(self):
         for handle in self.state.flat_param_handles:
             handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
 
+        with self.state.post_backward_stream(wait_stream=self.state.reduce_stream):
+            for handle in self.state.flat_param_handles:
+                handle.post_reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
+
     def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]:
         count = 0
         while prefetch_queue and count < self.max_prefetch_count:
```
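The finalization above is queued on a post-backward stream that waits on the reduce stream, so the clone/cast only runs after the reduce-scatter has completed, and `record_for` keeps the flat gradient buffer alive until then. A rough sketch of that ordering with plain `torch.cuda` streams; the repo's `Stream.record_for` wrapper presumably plays the role of `Tensor.record_stream` here, and the names below are illustrative:

```python
import torch


def finalize_sharded_grad(
    flat_grad: torch.Tensor,
    rank: int,
    world_size: int,
    reduce_stream: torch.cuda.Stream,
    post_backward_stream: torch.cuda.Stream,
) -> torch.Tensor:
    # Order the side stream after everything already queued on the reduce stream,
    # e.g. the reduce-scatter that produced this rank's chunk of `flat_grad`.
    post_backward_stream.wait_stream(reduce_stream)

    with torch.cuda.stream(post_backward_stream):
        sharded_view = flat_grad.chunk(world_size)[rank]
        sharded_grad = sharded_view.clone()
        # Tell the caching allocator that `flat_grad` is still in use by this
        # stream, so releasing the Python reference elsewhere cannot recycle
        # its memory before the clone has actually executed.
        flat_grad.record_stream(post_backward_stream)

    return sharded_grad
```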
