Potentially incorrect calculation of total_updates on >=4.46.0 since #34198 affecting multi-GPU training #35387
System Info
Okay, I have been pulling my hair out for a few hours, and it turns out this bug only happens when `average_tokens_across_devices` is True and epochs > 1.

Simplest case to reproduce:
- DDP
- world size 2
- dataset length = 4
- epochs = 2
- micro batch size = 1 (aka per-GPU batch size)
- gradient accumulation = 1
- average_tokens_across_devices = True
So every epoch there are 2 steps on each device, but as the first epoch finishes, we get a crash.
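For reference, a quick arithmetic sketch of the expected per-rank step count under this configuration:

```python
# Expected per-rank step count for the setup above
dataset_length = 4
world_size = 2
per_device_batch_size = 1        # micro batch size
gradient_accumulation_steps = 1

batches_per_rank = dataset_length // (world_size * per_device_batch_size)   # 2
steps_per_rank = batches_per_rank // gradient_accumulation_steps            # 2
print(batches_per_rank, steps_per_rank)  # 2 2
```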
The main culprit here is transformers/src/transformers/trainer.py, lines 2468 to 2473 in 8f38f58.

`steps_in_epoch` per rank is correctly calculated as 2, but `total_updates` is 3. Normally that is harmless, because the dataloader would be exhausted, which results in an empty batch, and it won't enter the loop on line 2473.
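A simplified sketch of the divergence, assuming the count at the referenced lines is derived roughly as `steps_in_epoch // gradient_accumulation_steps + 1` (paraphrased, not the verbatim trainer code):

```python
steps_in_epoch = 2               # correct per-rank value for the setup above
gradient_accumulation_steps = 1

# The epoch loop is driven by an update count that reserves one extra
# iteration for a potential remainder of accumulation micro-batches:
total_updates = steps_in_epoch // gradient_accumulation_steps + 1
print(total_updates)  # 3

# Updates 1 and 2 consume the two real batches on each rank. Update 3 finds
# the per-rank dataloader exhausted, so it collects an empty batch list and
# the token count stays None -- harmless on its own, because the inner loop
# over batches then never executes.
```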
However, when using the recently added option `average_tokens_across_devices`, it will try to gather the total number of items in the batches across all ranks, and gather doesn't like broadcasting `None` (see transformers/src/transformers/trainer.py, lines 5139 to 5156 in 8f38f58).
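A paraphrased sketch of what those lines do, based on the description above (not the verbatim source; the token-counting details are simplified):

```python
from typing import Iterator, List, Optional, Tuple

import torch
from accelerate import Accelerator


def get_batch_samples_sketch(
    epoch_iterator: Iterator[dict],
    num_batches: int,
    accelerator: Accelerator,
    average_tokens_across_devices: bool,
) -> Tuple[List[dict], Optional[torch.Tensor]]:
    batch_samples: List[dict] = []
    num_items_in_batch: Optional[torch.Tensor] = None
    for _ in range(num_batches):
        try:
            batch_samples.append(next(epoch_iterator))
        except StopIteration:  # per-rank dataloader already exhausted (the extra update)
            break
    if batch_samples:
        # count the label tokens that contribute to the loss in the collected batches
        num_items_in_batch = sum(batch["labels"].ne(-100).sum() for batch in batch_samples)
    if average_tokens_across_devices:
        # On the extra update step batch_samples is empty, so num_items_in_batch is
        # still None here, and gather() cannot broadcast None -> the crash reported above.
        num_items_in_batch = accelerator.gather(num_items_in_batch).sum()
    return batch_samples, num_items_in_batch
```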
This problem does not surface with 1 GPU, because `average_tokens_across_devices` is automatically set to `False`, nor with epochs = 1, because `DefaultFlowCallback` stops the training process once the global step reaches the expected max steps (see the sketch below).
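A hypothetical stand-in for the relevant part of `DefaultFlowCallback`, just to illustrate why a single epoch masks the bug: once the global step reaches the expected max steps, training stops before the extra update iteration is ever entered.

```python
from transformers import TrainerCallback


class StopAtMaxStepsCallback(TrainerCallback):
    """Illustrative only; DefaultFlowCallback does more than this."""

    def on_step_end(self, args, state, control, **kwargs):
        # With epochs = 1, global_step hits max_steps right at the end of the
        # only epoch, so the training loop breaks out before the extra update.
        if state.global_step >= state.max_steps:
            control.should_training_stop = True
        return control
```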
Who can help?

@muellerzr
Information

Tasks

Reproduction
Train any LM on more than one GPU; set at least `average_tokens_across_devices = True` and epochs > 1. A minimal sketch is below.
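A minimal reproduction sketch (model, tokenizer, and dataset are placeholder choices; any small causal LM with a 4-sample dataset should do):

```python
# repro.py -- launch on 2 GPUs, e.g.: torchrun --nproc_per_node=2 repro.py
from datasets import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")


def tokenize(batch):
    out = tokenizer(batch["text"], padding="max_length", max_length=16, truncation=True)
    out["labels"] = [ids.copy() for ids in out["input_ids"]]
    return out


# dataset length = 4
dataset = Dataset.from_dict({"text": ["hello world"] * 4}).map(tokenize, batched=True)

args = TrainingArguments(
    output_dir="repro",
    per_device_train_batch_size=1,        # micro batch size = 1
    gradient_accumulation_steps=1,
    num_train_epochs=2,                   # epochs > 1
    average_tokens_across_devices=True,
    report_to="none",
)

Trainer(model=model, args=args, train_dataset=dataset).train()
```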
Expected behavior
Either we fix the `total_updates` count, or we handle `None` for gather (a possible guard is sketched below).
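For the second option, a minimal sketch of what the guard could look like; `maybe_average_tokens` is a hypothetical helper, not the actual patch:

```python
from typing import Optional

import torch
from accelerate import Accelerator


def maybe_average_tokens(
    accelerator: Accelerator,
    num_items_in_batch: Optional[torch.Tensor],
    average_tokens_across_devices: bool,
) -> Optional[torch.Tensor]:
    # Only gather when there is something to gather. On the extra update step
    # every rank exhausts its dataloader at the same time, so all ranks skip
    # the collective together and nobody is left waiting on it.
    if average_tokens_across_devices and num_items_in_batch is not None:
        num_items_in_batch = accelerator.gather(num_items_in_batch).sum()
    return num_items_in_batch
```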