
Commit 84a3bdb

[fix] Typo in ShardedDDP unit test (#282)
* fix typo, backend for CPU test
1 parent 1c8d219 commit 84a3bdb


2 files changed: +8 -13 lines changed


fairscale/nn/data_parallel/sharded_ddp.py

Lines changed: 6 additions & 7 deletions
@@ -129,9 +129,9 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         return self.module(*inputs, **kwargs)
 
     def reduce(self) -> None:
-        """ .. deprecated:: 0.0.4
+        """.. deprecated:: 0.0.4
 
-            This does not need to be called, the gradient reduction is done automatically during the BW pass
+        This does not need to be called, the gradient reduction is done automatically during the BW pass
         """
         logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
 
@@ -157,8 +157,7 @@ def no_sync(self) -> Generator:
         self.should_accumulate_grads = old_should_accumulate_grads
 
     def _clear_counters(self) -> None:
-        """ Reset all the grad reduce and call counters
-        """
+        """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
 
@@ -254,14 +253,14 @@ def _sync_params_and_buffers(self) -> None:
 
         _ = list(map(lambda x: x.wait(), work_handles))
 
-    def _passing_sync_batchnorm_handle(self, module):
+    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
         Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
         Adapted from ``torch.nn.distributed.DistributedDataParallel``.
         """
         for layer in module.modules():
             if isinstance(layer, torch.nn.modules.SyncBatchNorm):
-                assert self.device_type != 'cpu', "SyncBatchNorm layers only work with GPU modules"
+                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                 # device_id logic has not been handled, assume single-process single-device
                 # SyncBatchNorm only supports DDP with single-process single-device anyway'
-                layer._specify_ddp_gpu_num(1)
+                layer._specify_ddp_gpu_num(1)  # type: ignore
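
For context on the reduce() docstring above: gradient reduction in ShardedDataParallel happens automatically during the backward pass, so user code never calls reduce() explicitly. A minimal CPU sketch of that usage, assuming the ShardedDataParallel(module, sharded_optimizer) constructor and the OSS wrapper that appear in the fairscale tests; the model, hyperparameters, and single-rank gloo setup are illustrative only:

import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# Single-rank process group on CPU; gloo is the CPU-capable backend (illustrative setup).
url = "file://" + tempfile.mkstemp()[1]
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)

model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

loss = ddp_model(torch.rand(8, 2)).abs().sum()
loss.backward()  # gradients are reduced to their owning shards here; no reduce() call needed
optimizer.step()

dist.destroy_process_group()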

tests/nn/data_parallel/test_sharded_ddp.py

Lines changed: 2 additions & 6 deletions
@@ -316,8 +316,7 @@ def test_ddp_attributes():
     # - device_type
 
     url = "file://" + tempfile.mkstemp()[1]
-    backend = dist.Backend.NCCL
-    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
 
     model = Sequential(Linear(2, 3), Linear(3, 3))
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
@@ -352,10 +351,7 @@ def test_ddp_sync_batch_norm():
     temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
     mp.spawn(
-        run_test_ddp_sync_batch_norm,
-        args=(world_size, backend, device, temp_file_name),
-        nprocs=world_size,
-        join=True
+        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
     )
 
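
On the backend change above: test_ddp_attributes runs in a single CPU process, and NCCL only supports CUDA tensors, so the test now initializes its process group with gloo. A hedged sketch of that selection logic, where init_single_process_group is a hypothetical helper written for illustration rather than a function from fairscale or its tests:

import tempfile

import torch
import torch.distributed as dist


def init_single_process_group() -> str:
    # NCCL requires CUDA; gloo works on CPU, which is what the fixed test relies on.
    backend = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
    url = "file://" + tempfile.mkstemp()[1]
    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
    return backend


if __name__ == "__main__":
    print("process group backend:", init_single_process_group())
    dist.destroy_process_group()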
