
Commit 2cfd3a7

More FSDP optimizations (#10)
* profile training loop
* add `--trace-output` arg
* skip norm for empty grads
* expand user
* debug logging
* post-backward final hook
* better logging
* mark backward exec order finalized for children
* fix lint
1 parent f235164 commit 2cfd3a7

File tree: 2 files changed (+115 -46 lines)


src/benchmarks/fsdp/train.py

Lines changed: 58 additions & 22 deletions
@@ -4,8 +4,8 @@
 """
 
 import argparse
+import contextlib
 import logging
-import os
 import time
 from pathlib import Path
 from typing import Literal, Optional
@@ -32,6 +32,8 @@ def main(
     save_path: Optional[str] = None,
     load_path: Optional[str] = None,
     mixed_precision: bool = True,
+    profile: bool = False,
+    trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
     **kwargs,
 ):
     model, optim, dataloader = build_components(
@@ -56,33 +58,56 @@ def main(
         print_rank0(f"Saving checkpoint to {checkpoint_dir}...")
         save_model_and_optim_state(checkpoint_dir, model, optim)
 
+    profiler = contextlib.nullcontext()
+    if profile:
+        from torch.profiler import ProfilerActivity, schedule
+
+        def on_trace_ready(p):
+            trace_path = Path(trace_output).expanduser()
+            trace_path.parent.mkdir(exist_ok=True, parents=True)
+            p.export_chrome_trace(str(trace_path))
+            print_rank0(f"Tracing complete, saved to '{trace_path}'")
+
+        profiler = torch.profiler.profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=False,
+            profile_memory=False,
+            with_stack=True,
+            schedule=schedule(wait=1, warmup=5, active=3, repeat=1),
+            on_trace_ready=on_trace_ready,
+        )
+
     print_rank0("Starting training...")
-    for i, batch in enumerate(iter(dataloader)):
-        log.debug("Batch: %s", batch)
-        batch_start = time.monotonic()
+    with profiler as p:
+        for i, batch in enumerate(iter(dataloader)):
+            log.debug("Batch: %s", batch)
+            batch_start = time.monotonic()
 
-        # Zero-gradients.
-        optim.zero_grad()
+            # Zero-gradients.
+            optim.zero_grad()
 
-        # Run forward pass.
-        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
-            loss = compute_loss(model, batch)
+            # Run forward pass.
+            with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
+                loss = compute_loss(model, batch)
 
-        # Trigger backward pass.
-        loss.backward()
+            # Trigger backward pass.
+            loss.backward()
 
-        # Clip gradient norms.
-        model.clip_grad_norm_(1.0)
+            # Clip gradient norms.
+            model.clip_grad_norm_(1.0)
 
-        # Take optimizer step.
-        optim.step()
+            # Take optimizer step.
+            optim.step()
 
-        batch_end = time.monotonic()
-        print_rank0(
-            f"Batch [{i+1}/{num_batches}]:\n"
-            f" loss={loss.item():.3f}\n"
-            f" throughput/seconds_per_batch={batch_end-batch_start:.3f}",
-        )
+            batch_end = time.monotonic()
+            print_rank0(
+                f"Batch [{i+1}/{num_batches}]:\n"
+                f" loss={loss.item():.3f}\n"
+                f" throughput/seconds_per_batch={batch_end-batch_start:.3f}",
+            )
+
+            if p is not None:
+                p.step()
 
     if save_path is not None:
         checkpoint_dir = Path(save_path) / "final"
@@ -126,6 +151,15 @@ def main(
         "--debug",
         action="store_true",
     )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--trace-output",
+        type=str,
+        default="/tmp/traces/olmo_core.chrome_trace.json.gz",
+    )
     parser.add_argument(
         "--save-path",
         type=str,
@@ -168,7 +202,7 @@ def main(
         raise NotImplementedError(args.model_size)
 
     if args.debug:
-        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+        # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
         config.debug = True
 
     dist.init_process_group(backend="nccl")
@@ -185,6 +219,8 @@ def main(
         dry_run=args.dry_run,
         save_path=args.save_path,
         load_path=args.load_path,
+        profile=args.profile,
+        trace_output=args.trace_output,
         mixed_precision=mixed_precision,
         max_prefetch_count=args.max_prefetch_count,
         learning_rate=args.lr,
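For readers unfamiliar with `torch.profiler`, the pattern the diff introduces can be reproduced in isolation. The sketch below is a self-contained approximation, not part of this commit: the toy model, optimizer, and `/tmp/example_trace.json` path are placeholders. It shows how `schedule(wait=1, warmup=5, active=3, repeat=1)` skips one step, warms up for five, then records three, and why `p.step()` must be called once per iteration so the schedule advances; `on_trace_ready` then exports a Chrome trace, just as `train.py` does above.

# Minimal sketch of the profiling pattern used in train.py above.
# The tiny model and the "/tmp/example_trace.json" path are illustrative only.
from pathlib import Path

import torch
from torch.profiler import ProfilerActivity, profile, schedule


def on_trace_ready(p):
    trace_path = Path("/tmp/example_trace.json").expanduser()
    trace_path.parent.mkdir(exist_ok=True, parents=True)
    p.export_chrome_trace(str(trace_path))  # viewable in chrome://tracing or Perfetto
    print(f"Trace saved to '{trace_path}'")


model = torch.nn.Linear(8, 8)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA when a GPU is present
    with_stack=True,
    # Skip 1 step, warm up for 5, record the next 3, and do this cycle once.
    schedule=schedule(wait=1, warmup=5, active=3, repeat=1),
    on_trace_ready=on_trace_ready,
) as p:
    for _ in range(12):
        optim.zero_grad()
        loss = model(torch.randn(4, 8)).sum()
        loss.backward()
        optim.step()
        p.step()  # advance the profiler schedule once per training step

Falling back to `contextlib.nullcontext()` when `--profile` is off, as the diff does, keeps the loop body identical in both modes, at the cost of indenting the whole loop under `with profiler as p`.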

src/olmo_core/distributed/fsdp/fsdp.py

Lines changed: 57 additions & 24 deletions
@@ -26,6 +26,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.autograd import Variable
 
 from olmo_core.distributed.tensors import ShardedFlatParameter
 from olmo_core.stream import Stream
@@ -322,7 +323,7 @@ def clip_grad_norm_(self, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
         nonsharded_params: Set[nn.Parameter] = set()
         grads: List[torch.Tensor] = []
         for param in self.parameters():
-            if param.grad is None:
+            if param.grad is None or param.grad.numel() == 0:
                 continue
 
             if isinstance(param, ShardedFlatParameter):
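A note on the `numel() == 0` guard added above ("skip norm for empty grads" in the commit message): with sharded flat parameters, a rank's local shard of a parameter can be empty, in which case its gradient exists but holds zero elements. Such a gradient contributes nothing to the total norm, and a max-style (infinity) norm over zero elements is not even well defined, so it is skipped. Below is a simplified, single-process sketch of the same filtering idea; it is not the repo's `clip_grad_norm_`, which additionally has to combine per-shard norms across ranks.

# Simplified, single-process sketch of a gradient-norm computation that skips
# missing and zero-element gradients, mirroring the guard added in the diff.
from typing import Iterable

import torch
import torch.nn as nn


def grad_norm(params: Iterable[nn.Parameter], norm_type: float = 2.0) -> torch.Tensor:
    # Collect gradients, skipping None grads and zero-element grads.
    grads = [p.grad for p in params if p.grad is not None and p.grad.numel() > 0]
    if not grads:
        return torch.tensor(0.0)
    # The norm of the per-gradient norms equals the norm over all elements.
    per_grad = torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads])
    return torch.linalg.vector_norm(per_grad, norm_type)


model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
print(grad_norm(model.parameters()))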
@@ -394,7 +395,11 @@ def _lazy_init(self):
             self.state.forward_execution_order.append(self)
             return
 
-        log.debug("Completing lazy initialization from root FSDP for %s...", self.module.__class__.__name__)
+        log.debug(
+            "Completing lazy initialization from root FSDP for %s (%s)...",
+            self.module.__class__.__name__,
+            id(self.module),
+        )
 
         # Initialize streams.
         self.state.compute_stream = Stream.default(self.device)
@@ -494,7 +499,7 @@ def _shard(self):
 
         This should only be called once at initialization.
         """
-        log.debug("Sharding %s...", self.module.__class__.__name__)
+        log.debug("Sharding %s (%s)...", self.module.__class__.__name__, id(self.module))
 
         params_with_grads: List[nn.Parameter] = []
         params_with_grads_fqns: List[str] = []
@@ -568,7 +573,7 @@ def _unshard(
 
         kwargs = dict(cast=cast, set_grads=set_grads, recurse=recurse, rank0_only=rank0_only)
 
-        log.debug("Unsharding %s...", self.module.__class__.__name__)
+        log.debug("Unsharding %s (%s)...", self.module.__class__.__name__, id(self.module))
         self.state.params_prefetched = True
 
         # NOTE: `unshard_stream` should wait on current stream (usually `compute_stream` / `default_stream`)
@@ -600,7 +605,11 @@
     def _prefetch(self, prefetch_from: deque[FSDP], **kwargs):
         for module in self._deque_from(prefetch_from):
             log.debug(
-                "Prefetching %s from %s...", module.module.__class__.__name__, self.module.__class__.__name__
+                "Prefetching %s (%s) from %s (%s)...",
+                module.module.__class__.__name__,
+                id(module.module),
+                self.module.__class__.__name__,
+                id(self.module),
             )
             module._unshard(**kwargs)
 
@@ -611,7 +620,7 @@ def _reshard(self, writeback: bool = False, recurse: bool = False):
         """
         kwargs = dict(writeback=writeback, recurse=recurse)
 
-        log.debug("Resharding %s...", self.module.__class__.__name__)
+        log.debug("Resharding %s (%s)...", self.module.__class__.__name__, id(self.module))
         self.state.params_prefetched = False
 
         for handle in self.state.flat_param_handles:
@@ -637,7 +646,7 @@ def _reduce_scatter_grads(self):
 
         grad_reduce_dtype: Optional[torch.dtype] = self.precision.reduce_dtype or self.precision.param_dtype
         with self.state.reduce_stream(wait_stream=self.state.current_stream):
-            log.debug("Reduce-scattering grads for %s", self.module.__class__.__name__)
+            log.debug("Reduce-scattering grads for %s (%s)", self.module.__class__.__name__, id(self.module))
             for handle in self.state.flat_param_handles:
                 handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
 
@@ -659,13 +668,16 @@ def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]:
     @torch.no_grad()
     def _pre_backward_hook(self, *unused: Any):
         del unused
-        log.debug("Running pre-backward hook for %s...", self.module.__class__.__name__)
+        log.debug("Running pre-backward hook for %s (%s)...", self.module.__class__.__name__, id(self.module))
 
         # Remove all pre backward hooks for this FSDP instance since they all do the same thing.
         for handle in self.state.pre_backward_hook_handles:
             handle.remove()
         self.state.pre_backward_hook_handles.clear()
 
+        if self.is_root:
+            self._register_post_backward_final_hook()
+
         # Unshard parameters in place.
         self._unshard(set_grads=True)
 
@@ -684,10 +696,12 @@ def _register_pre_backward_hook(self, x: torch.Tensor):
         self.state.pre_backward_hook_handles.append(handle)
 
     def _register_pre_backward_hooks(self, output: Any):
-        log.debug("Registering pre-backward hooks for %s...", self.module.__class__.__name__)
+        log.debug("Registering pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module))
         # Clear existing hooks if there are any.
         if self.state.pre_backward_hook_handles:
-            log.debug("Removing old pre-backward hooks for %s...", self.module.__class__.__name__)
+            log.debug(
+                "Removing old pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+            )
             for handle in self.state.pre_backward_hook_handles:
                 handle.remove()
             self.state.pre_backward_hook_handles.clear()
@@ -699,29 +713,19 @@ def _register_pre_backward_hooks(self, output: Any):
     @torch.no_grad()
     def _post_backward_hook(self, param_name: str, *unused: Any):
         del unused
-        log.debug("Running post-backward hook for %s.%s...", self.module.__class__.__name__, param_name)
         self.state.post_backward_hook_handles.pop(param_name).remove()
 
         # If there are still more handles then there are still more post-backward hooks to be ran
         # in the current FSDP node. Only the last handle should do the work.
         if self.state.post_backward_hook_handles:
             return
 
+        log.debug("Running post-backward hook for %s (%s)", self.module.__class__.__name__, id(self.module))
+
         # NOTE: reshard *before* reducing grads to correctly handle precision settings.
         self._reshard()
         self._reduce_scatter_grads()
 
-        # The root FSDP instance needs to do some final cleanup.
-        if not self.is_root:
-            return
-
-        # Mark backward execution order as finalized.
-        self.state.backward_execution_order_finalized = True
-
-        # Wait for unsharding and reducing streams to complete so the model is not left in a bad
-        # state before grad clipping, optimizer step, or whatever else.
-        self.state.current_stream.wait_stream(self.state.reduce_stream)
-
     def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParameter):
         # Force creation of a `grad_fn` in order to register a hook that will run *after* this param's
         # backward pass.
@@ -733,13 +737,42 @@ def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParameter):
         self.state.post_backward_hook_handles[param_name] = handle
 
     def _register_post_backward_hooks(self):
-        log.debug("Registering post-backward hooks for %s...", self.module.__class__.__name__)
+        log.debug(
+            "Registering post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+        )
         # Clear existing hooks if there are any.
         if self.state.post_backward_hook_handles:
-            log.debug("Removing old post-backward hooks for %s...", self.module.__class__.__name__)
+            log.debug(
+                "Removing old post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+            )
            for handle in self.state.post_backward_hook_handles.values():
                handle.remove()
            self.state.post_backward_hook_handles.clear()
        for param_name, param in self._managed_named_parameters():
            if param.requires_grad:
                self._register_post_backward_hook(param_name, param)
+
+    @torch.no_grad()
+    def _post_backward_final_hook(self):
+        if not self.is_root:
+            return
+
+        log.debug("Running post-backward final hook for %s (%s)", self.module.__class__.__name__, id(self.module))
+
+        # Mark backward execution order as finalized.
+        self.state.backward_execution_order_finalized = True
+        for child in self._fsdp_children(recurse=True):
+            child.state.backward_execution_order_finalized = True
+
+        # Wait for unsharding and reducing streams to complete so the model is not left in a bad
+        # state before grad clipping, optimizer step, or whatever else.
+        self.state.current_stream.wait_stream(self.state.reduce_stream)
+
+    def _register_post_backward_final_hook(self):
+        if not self.is_root:
+            return
+
+        log.debug(
+            "Registering post-backward final hook for %s (%s)...", self.module.__class__.__name__, id(self.module)
+        )
+        Variable._execution_engine.queue_callback(self._post_backward_final_hook)
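The new final hook relies on `Variable._execution_engine.queue_callback`, a low-level autograd API that schedules a callable to run once, after the current backward pass has finished executing. Because the root registers it from its pre-backward hook (see the `_pre_backward_hook` change above), the cleanup work (marking the backward execution order finalized for the root and all FSDP children, and making the current stream wait on the reduce stream) now runs after every per-parameter post-backward hook has fired, rather than inside whichever post-backward hook happens to run last, which is presumably the point of the refactor. The toy example below illustrates only the callback mechanism; the tensors, hook, and `events` list are illustrative and not part of the commit.

# Toy illustration of torch.autograd's end-of-backward callback mechanism used
# by the final hook above.
import torch
from torch.autograd import Variable

events = []


def final_callback():
    # Runs once, after the whole backward pass has finished.
    events.append("final")


def pre_backward_hook(grad):
    # Queue the callback from inside the backward pass, the same way the root
    # FSDP instance registers its final hook from a pre-backward hook.
    Variable._execution_engine.queue_callback(final_callback)
    events.append("pre")
    return grad


x = torch.randn(4, requires_grad=True)
loss = (x * 2).sum()
loss.register_hook(pre_backward_hook)
loss.backward()
assert events == ["pre", "final"]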
