@@ -129,9 +129,9 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
129
129
return self .module (* inputs , ** kwargs )
130
130
131
131
def reduce (self ) -> None :
132
- """ .. deprecated:: 0.0.4
132
+ """.. deprecated:: 0.0.4
133
133
134
- This does not need to be called, the gradient reduction is done automatically during the BW pass
134
+ This does not need to be called, the gradient reduction is done automatically during the BW pass
135
135
"""
136
136
logging .warning ("This is not useful anymore, gradients have been reduced automatically with the backward pass" )
137
137
@@ -157,8 +157,7 @@ def no_sync(self) -> Generator:
157
157
self .should_accumulate_grads = old_should_accumulate_grads
158
158
159
159
def _clear_counters (self ) -> None :
160
- """ Reset all the grad reduce and call counters
161
- """
160
+ """Reset all the grad reduce and call counters"""
162
161
self ._grad_to_be_reduced = [True for _ in self ._grad_to_be_reduced ]
163
162
self ._reduced_grads = {o : 0 for o in self .sharded_optimizers }
164
163
@@ -254,14 +253,14 @@ def _sync_params_and_buffers(self) -> None:
254
253
255
254
_ = list (map (lambda x : x .wait (), work_handles ))
256
255
257
- def _passing_sync_batchnorm_handle (self , module ) :
256
+ def _passing_sync_batchnorm_handle (self , module : nn . Module ) -> None :
258
257
"""
259
258
Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
260
259
Adapted from ``torch.nn.distributed.DistributedDataParallel``.
261
260
"""
262
261
for layer in module .modules ():
263
262
if isinstance (layer , torch .nn .modules .SyncBatchNorm ):
264
- assert self .device_type != ' cpu' , "SyncBatchNorm layers only work with GPU modules"
263
+ assert self .device_type != " cpu" , "SyncBatchNorm layers only work with GPU modules"
265
264
# device_id logic has not been handled, assume single-process single-device
266
265
# SyncBatchNorm only supports DDP with single-process single-device anyway'
267
- layer ._specify_ddp_gpu_num (1 )
266
+ layer ._specify_ddp_gpu_num (1 ) # type: ignore
0 commit comments