Commit 1847b8e

More improvements to FSDP, benchmark against DDP (#13)
* leave root params in mem
* fix logic
* don't show mem usage all the time
* fix
* make configurable
* add alloc/free for unsharded data
* add alloc/free for unsharded grad
* fix
* record for
* revert 69d74c4 - alloc/free for unsharded grad
* revert alloc/free trick for unsharded params data
* add support for DDP in benchmark
* set device ids explicitly
* fix
* change up how weights are initialized
* fix test
* Add back alloc/free hack for unsharded data
* Revert "Add back alloc/free hack for unsharded data"
  This reverts commit 0386841.
* Handle frozen layers with reshard-only post-backward hook
* Revert "Handle frozen layers with reshard-only post-backward hook"
  This reverts commit 0f408d2.
* add to test
* add to test
* Fixes for frozen modules
* Divide grad before and after reducing for stability
* Add support for hybrid sharding
* make grad clipping optional
* clean up
* calculate grad norm more efficiently
* Revert "calculate grad norm more efficiently"
  This reverts commit d66a683.
* fix
1 parent a4e0ccf commit 1847b8e

File tree

8 files changed (+280 −39 lines)

docs/source/distributed/fsdp.rst

Lines changed: 1 addition & 1 deletion

@@ -2,5 +2,5 @@
 ====================

 .. automodule:: olmo_core.distributed.fsdp
-    :members: FSDP, FSDPPrecision
+    :members: FSDP, FSDPPrecision, FSDPShardingStrategy
     :member-order: bysource

src/benchmarks/fsdp/common.py

Lines changed: 6 additions & 1 deletion

@@ -151,7 +151,7 @@ def build_components(
     config: TransformerConfig,
     batch_size: int,
     num_batches: int = 100,
-    fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
+    fsdp_wrapper: Literal["torch", "olmo_core", "ddp"] = "olmo_core",
     wrap_blocks: bool = True,
     mixed_precision: bool = True,
     max_prefetch_count: int = 1,
@@ -204,6 +204,11 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
         )

         model.apply(init_function)  # just in case
+    elif fsdp_wrapper == "ddp":
+        from torch.nn.parallel import DistributedDataParallel as DDP
+
+        model = DDP(model.cuda(), device_ids=[dist.get_rank()])
+        model.apply(init_function)
     else:
         raise NotImplementedError(fsdp_wrapper)
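Note: the new "ddp" branch above is the whole DDP integration, so here is a standalone sketch of the same wrapping pattern for context. It assumes torch.distributed is already initialized with one process per GPU on a single node, so the global rank can double as the CUDA device index (on multiple nodes you would use the local rank instead); wrap_ddp is an illustrative name, not part of the benchmark.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_ddp(model: nn.Module) -> DDP:
    # One process per GPU on a single node: global rank == CUDA device index.
    device = torch.device("cuda", dist.get_rank())
    torch.cuda.set_device(device)
    # Setting device_ids explicitly pins the DDP replica and its communication
    # buckets to this process's GPU instead of relying on the current device.
    return DDP(model.to(device), device_ids=[device.index])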

src/benchmarks/fsdp/train.py

Lines changed: 19 additions & 5 deletions

@@ -13,6 +13,7 @@

 import torch
 import torch.distributed as dist
+from torch.nn.utils import clip_grad_norm_

 from olmo_core.distributed.checkpoint import (
     load_model_and_optim_state,
@@ -28,13 +29,14 @@ def main(
     config: TransformerConfig,
     batch_size: int,
     num_batches: int = 100,
-    fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
+    fsdp_wrapper: Literal["torch", "olmo_core", "ddp"] = "olmo_core",
     dry_run: bool = False,
     save_path: Optional[str] = None,
     load_path: Optional[str] = None,
     mixed_precision: bool = True,
     profile: bool = False,
     trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
+    max_grad_norm: Optional[float] = None,
     **kwargs,
 ):
     model, optim, dataloader = build_components(
@@ -98,21 +100,28 @@ def on_trace_ready(p):
         loss.backward()

         # Clip gradient norms.
-        model.clip_grad_norm_(1.0)
+        norm: Optional[torch.Tensor] = None
+        if max_grad_norm is not None:
+            if hasattr(model, "clip_grad_norm_"):
+                norm = model.clip_grad_norm_(max_grad_norm)
+            else:
+                norm = clip_grad_norm_(model.parameters(), max_grad_norm)

         # Take optimizer step.
         optim.step()

         batch_time = time.monotonic() - batch_start
         if i > 0:
             batch_times.append(batch_time)
+        norm_str = f"{norm.item():.3f}" if norm is not None else "n/a"
         print_rank0(
             f"Batch [{i+1}/{num_batches}]:\n"
             f" loss={loss.item():.3f}\n"
-            f" throughput/seconds_per_batch={batch_time:.3f}",
+            f" throughput/seconds_per_batch={batch_time:.3f}\n"
+            f" grad/total_norm={norm_str}"
         )

-        if i == 2:
+        if profile and i == 2:
             print_rank0(torch.cuda.memory_summary())

         if p is not None:
@@ -134,7 +143,7 @@ def on_trace_ready(p):
     parser = argparse.ArgumentParser(prog="train.py", description="Train an FSDP model")
     parser.add_argument(
         "--fsdp",
-        choices=["torch", "olmo_core"],
+        choices=["torch", "olmo_core", "ddp"],
         default="olmo_core",
         help="""The FSDP implementation.""",
     )
@@ -190,6 +199,10 @@ def on_trace_ready(p):
         type=int,
         default=1,
     )
+    parser.add_argument(
+        "--max-grad-norm",
+        type=float,
+    )
     parser.add_argument(
         "--lr",
         type=float,
@@ -237,5 +250,6 @@ def on_trace_ready(p):
         mixed_precision=mixed_precision,
         max_prefetch_count=args.max_prefetch_count,
         learning_rate=args.lr,
+        max_grad_norm=args.max_grad_norm,
         seed=args.seed,
     )
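The clipping change above duck-types on the wrapper: both FSDP wrappers expose a clip_grad_norm_ method that accounts for sharded parameters, while a DDP-wrapped model does not, so the benchmark falls back to torch.nn.utils.clip_grad_norm_ (correct under DDP because every rank holds the full gradients). A minimal sketch of that pattern, with maybe_clip_grads as an illustrative name:

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_


def maybe_clip_grads(model: nn.Module, max_grad_norm: Optional[float]) -> Optional[torch.Tensor]:
    """Clip gradients only when a max norm is given; return the total norm, else None."""
    if max_grad_norm is None:
        return None  # clipping disabled (the new default)
    if hasattr(model, "clip_grad_norm_"):
        # FSDP wrappers implement clipping that reduces the norm across shards.
        return model.clip_grad_norm_(max_grad_norm)
    # Plain or DDP-wrapped modules: every rank has full grads, so the standard utility works.
    return clip_grad_norm_(model.parameters(), max_grad_norm)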

src/olmo_core/distributed/fsdp/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -77,6 +77,6 @@
 -------------
 """

-from .fsdp import FSDP, FSDPDebugConfig, FSDPPrecision
+from .fsdp import FSDP, FSDPDebugConfig, FSDPPrecision, FSDPShardingStrategy

-__all__ = ["FSDP", "FSDPDebugConfig", "FSDPPrecision"]
+__all__ = ["FSDP", "FSDPDebugConfig", "FSDPPrecision", "FSDPShardingStrategy"]

src/olmo_core/distributed/fsdp/flat_param_handle.py

Lines changed: 43 additions & 5 deletions

@@ -15,7 +15,11 @@
     ShardedFlatTensor,
     ShardingSpec,
 )
-from olmo_core.distributed.utils import get_rank, get_world_size
+from olmo_core.distributed.utils import (
+    get_gradient_divide_factor,
+    get_rank,
+    get_world_size,
+)
 from olmo_core.stream import Stream
 from olmo_core.utils import get_default_device

@@ -62,21 +66,41 @@ class FlatParamHandle:
     """

     process_group: Optional[dist.ProcessGroup] = None
+    """
+    Process group containing all shards.
+    """
+
+    inter_group_process_group: Optional[dist.ProcessGroup] = None
+    """
+    Process group for between-group reductions with hybrid sharding.
+    """

     device: Optional[torch.device] = None

     requires_grad: bool = True

+    pre_reduce_grad_divide_factor: float = 1.0
+
+    post_reduce_grad_divide_factor: float = 1.0
+
     _ran_pre_unshard: bool = False

     _ran_pre_reduce_scatter_grads: bool = False

+    def __post_init__(self):
+        data_parallel_world_size = get_world_size(self.process_group)
+        if self.inter_group_process_group is not None:
+            data_parallel_world_size *= self.inter_group_process_group.size()
+        self.pre_reduce_grad_divide_factor = get_gradient_divide_factor(data_parallel_world_size)
+        self.post_reduce_grad_divide_factor = data_parallel_world_size / self.pre_reduce_grad_divide_factor
+
     @classmethod
     def shard_params(
         cls,
         params: Iterable[nn.Parameter],
         param_fqns: Iterable[str],
         process_group: Optional[dist.ProcessGroup] = None,
+        inter_group_process_group: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
     ) -> FlatParamHandle:
         """
@@ -183,6 +207,7 @@ def shard_params(
                 )
             else:
                 flat_param = ShardedFlatParameter(torch.empty(0, device=device))
+                flat_param.requires_grad = param.requires_grad
                 flat_param.mark_as_sharded(sharding_spec, process_group=process_group)

             flat_params.append(flat_param)
@@ -224,6 +249,7 @@ def shard_params(
             param_fqns=list(param_fqns),
             params_data=params_data,
             process_group=process_group,
+            inter_group_process_group=inter_group_process_group,
             device=device,
             requires_grad=requires_grad,
         )
@@ -253,7 +279,7 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
            self.params_sharded_data_lp.copy_(self.params_data.sharded_data)

         # Initialize unsharded, padded gradient.
-        if set_grads and self.params_unsharded_grad is None:
+        if set_grads and self.requires_grad and self.params_unsharded_grad is None:
             self.params_unsharded_grad = torch.zeros_like(all_params_unsharded_data)

     def unshard_(
@@ -288,7 +314,7 @@ def unshard_(
         if rank0_only or dist.get_backend() == dist.Backend.GLOO:
             assert self.params_data.is_sharded
             self.params_data.unshard_(dtype=dtype, rank0_only=rank0_only)
-            if set_grads:
+            if set_grads and self.requires_grad:
                 self.params_unsharded_grad = torch.zeros_like(self.params_data)
         else:
             assert not self.params_data.is_sharded
@@ -318,7 +344,7 @@ def unshard_(

             param.unshard_(unsharded_data, dtype=dtype, rank0_only=rank0_only)

-            if set_grads:
+            if set_grads and self.requires_grad:
                 if param.grad is None and self.params_sharded_grad is not None:
                     self.params_sharded_grad = None
                 assert self.params_unsharded_grad is not None
@@ -360,6 +386,9 @@ def pre_reduce_scatter_grads_(
             Stream.current(self.device).record_for(self.params_unsharded_grad)
             self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)

+        if self.pre_reduce_grad_divide_factor > 1.0:
+            self.params_unsharded_grad.div_(self.pre_reduce_grad_divide_factor)
+
     def reduce_scatter_grads_(
         self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
     ):
@@ -368,6 +397,7 @@ def reduce_scatter_grads_(
         parameter as a view into the new sharded grad.
         """
         if not self.requires_grad or self.params_unsharded_grad is None:
+            self._ran_pre_reduce_scatter_grads = False
             return

         if not self._ran_pre_reduce_scatter_grads:
@@ -398,12 +428,20 @@ def post_reduce_scatter_grads_(
         """
         Finalize sharded gradients after the reduce-scatter.
         """
+        if not self.requires_grad or self.params_unsharded_grad is None:
+            return
+
         grad_dtype = grad_dtype or self.params_data.dtype
         grad_reduce_dtype = grad_reduce_dtype or grad_dtype

-        assert self.params_unsharded_grad is not None
         new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad)

+        if self.inter_group_process_group is not None:
+            dist.all_reduce(new_sharded_grad, group=self.inter_group_process_group)
+
+        if self.post_reduce_grad_divide_factor > 1.0:
+            new_sharded_grad.div_(self.post_reduce_grad_divide_factor)
+
         # Cast the new sharded gradient to the right dtype, potentially accumulating it into
         # the existing sharded gradient.
         if self.params_sharded_grad is None:
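The divide factors introduced above split the usual divide-by-world-size gradient averaging into two smaller divisions: one before the reduce-scatter (and inter-group all-reduce under hybrid sharding), one after. Dividing before the reduction keeps the summed values from overflowing when grads are reduced in fp16/bf16, and deferring the rest avoids underflowing small gradients; the two factors multiply back to the data-parallel world size. The sketch below only illustrates the arithmetic: gradient_divide_factor is an assumed heuristic modeled on PyTorch FSDP's gradient pre-divide factor, not the actual get_gradient_divide_factor from olmo_core.

import torch


def gradient_divide_factor(world_size: int) -> float:
    # Assumed heuristic, modeled on PyTorch FSDP's pre-divide factor: roughly
    # sqrt(world_size), constrained to a power of two that divides world_size.
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)


# Example: hybrid sharding with 4 replica groups of 8 shards each gives an
# effective data-parallel world size of 32.
world_size = 32
pre = gradient_divide_factor(world_size)   # 8.0
post = world_size / pre                    # 4.0

grad = torch.full((4,), 32.0)
grad.div_(pre)   # applied in pre_reduce_scatter_grads_ before the reduction
# ... reduce-scatter within each shard group, then all-reduce across groups ...
grad.div_(post)  # applied in post_reduce_scatter_grads_ after the reduction
# Net effect: grad / world_size (a mean over all data-parallel ranks), but each
# intermediate stays closer to 1, which is friendlier to low-precision reductions.
assert torch.allclose(grad, torch.full((4,), 1.0))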
