Description
🐛 Describe the bug
Memory increase at the start of every iteration (epoch) after the first
I have been trying to use DataLoader2 with multiprocessing (and, in some cases, distributed training). In general, its behavior is quite strange relative to the original data loader implementation (which I'll call DataLoader1 below). After an epoch (one full iteration) completes, the dataloader appears to hold on to all of its internal state instead of resetting. As a result, memory usage carries over from the train epoch into the validation epoch.
More problematic still: when the next epoch starts, the previous epoch's state is still held and memory usage spikes upward. I suspect this is behind some of the many recently reported memory errors, and it certainly caused them in my case when training with DDP. DataLoader1 has none of these issues.
I tested with a relatively complicated datapipe that uses Multiplexer, several intermediate one-to-many (1:N) yielding stages, and produces a pair (audio: tensor, metadata: dict).
I saw a recent post claiming that dictionaries were the issue. From what I have seen, the reading service is the culprit rather than the dictionaries.
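To give a sense of the pipeline shape described above, here is a simplified, hypothetical stand-in (the names and sources are mine; the real pipeline decodes audio from tar shards):

```python
import torch
from torchdata.datapipes.iter import IterableWrapper

# Two hypothetical sources standing in for the real tar-backed shards.
source_a = IterableWrapper(range(0, 1000, 2))
source_b = IterableWrapper(range(1, 1000, 2))

def fan_out(x):
    # One of the intermediate 1:N yielding stages: each input item
    # produces several (audio, metadata) samples.
    for crop in range(4):
        yield torch.randn(16000), {'id': x, 'crop': crop}

datapipe = source_a.mux(source_b).flatmap(fan_out)
```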
DataLoader2 compared with torch.utils.data.DataLoader (DataLoader1)
Here is the code that I used to produce the results below.
```python
def BenchmarkLoading(Pipe, N=10000, NumPrints=10, phrase='Decoding Tars, DataPipe'):
    # Benchmarker is a small project-local timing helper.
    bm = Benchmarker(f'{phrase}, {N} samples')
    for i, x in enumerate(Pipe):
        if i % N == 0 and i:
            print('Iteration', i, 'of', N * NumPrints)
            bm()
        if i == N * NumPrints:
            break
    bm.compute()


batch_size = 1024  # Fixed for testing.
num_workers = 8

rs = reader.MultiProcessingReadingService(num_workers=num_workers)
# Collator is a project-local datapipe that collates the batches.
dataloader = DataLoader2(Collator(datapipe.batch(batch_size)), reading_service=rs)
# dataloader = torch.utils.data.DataLoader(datapipe, num_workers=num_workers, batch_size=batch_size)

for i in range(100):
    print(i)
    dataloader.seed(i)
    BenchmarkLoading(dataloader, N=100, NumPrints=100,
                     phrase=f'{task_name}, Data loader (not 2), batches {batch_size}, workers {num_workers}')
    print('\n\n')
```
- Epoch 0 behaves very similarly to DL1, maybe slightly faster, but it was only a single run.
- Starting from Epoch 1, DataLoader2 shows very strange and risky start-up behavior.
- It appears not to clean up after the previous run; instead, memory usage spikes upward (see the measurement sketch after this list).
- Additionally, the startup time is much higher than in "Epoch 0": 144 s vs. 55 s.
- This long ~144 s startup persists for every subsequent epoch.
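For anyone reproducing this, a minimal way to watch the per-epoch memory growth (my own psutil snippet, not part of the benchmark above):

```python
import os
import psutil

def rss_gib():
    # Resident set size of the main process, in GiB (worker processes
    # are not included; they can be summed over Process.children()).
    return psutil.Process(os.getpid()).memory_info().rss / 1024**3

for epoch in range(4):
    print(f'epoch {epoch}: RSS before = {rss_gib():.2f} GiB')
    for batch in dataloader:
        pass
    print(f'epoch {epoch}: RSS after  = {rss_gib():.2f} GiB')
```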
Results:
```
# DL1: Very consistent behavior.
"Epoch 0"
Iteration 100 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 54.31092572212219 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 12.599214553833008 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 13.167328834533691 seconds.
"Epoch 1"
Iteration 100 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 54.44526529312134 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 11.10575008392334 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 16.893246173858643 seconds.
"Epoch 2"
Iteration 100 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 54.24127960205078 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 10.061070203781128 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, Data loader (not 2), batches 1024, workers 8, 100 samples took 10.461547136306763 seconds.
# DL2: Very consistently poor performance at the start of every epoch after "Epoch 0".
"Epoch 0"
Iteration 100 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 55.075947523117065 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.734715938568115 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.988599061965942 seconds.
"Epoch 1"
Iteration 100 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 144.52063298225403 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 8.694239616394043 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 11.868495464324951 seconds.
"Epoch 2"
Iteration 100 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 144.5120747089386 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.244807004928589 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 12.066998481750488 seconds.
"Epoch 3"
Iteration 100 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 147.51504135131836 seconds.
Iteration 200 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.539249181747437 seconds.
Iteration 300 of 10000
Multiplexer 100% Augmentation, DataLoader2, batches 1024, workers 8, 100 samples took 10.461547136306763 seconds.
```
Attempts to get the resetting behavior of DL1
I inspected the internal state held by DataLoader2 and its reading service.
In the reading service, the per-worker processes, queues, and datapipes stick around after the epoch ends:
```
{'num_workers': 8,
'multiprocessing_context': None,
'worker_prefetch_cnt': 10,
'main_prefetch_cnt': 10,
'worker_init_fn': None,
'worker_reset_fn': None,
'_worker_processes': [(<ForkProcess name='ForkProcess-9' pid=27584 parent=23977 started daemon>,
<multiprocessing.queues.Queue at 0x7f8980c57b20>,
<multiprocessing.queues.Queue at 0x7f8971607940>),
(<ForkProcess name='ForkProcess-10' pid=27585 parent=23977 started daemon>,
<multiprocessing.queues.Queue at 0x7f8980c57df0>,
<multiprocessing.queues.Queue at 0x7f8971607e80>),
(<ForkProcess name='ForkProcess-11' pid=27617 parent=23977 started daemon>,
<multiprocessing.queues.Queue at 0x7f89716079a0>,
<multiprocessing.queues.Queue at 0x7f8971680490>),
(<ForkProcess name='ForkProcess-12' pid=27649 parent=23977 started daemon>,
<multiprocessing.queues.Queue at 0x7f89716800d0>,
<multiprocessing.queues.Queue at 0x7f8971680a60>),
(<ForkProcess name='ForkProcess-13' pid=27669 parent=23977 started daemon>,
<multiprocessing.queues.Queue at 0x7f8971680670>,
<multiprocessing.queues.Queue at 0x7f8971681030>),
(<ForkProcess name='ForkProcess-14' pid=27682 parent=23977 started daemon>,
<multiprocessing.queues.Queue at 0x7f8971680c40>,
<multiprocessing.queues.Queue at 0x7f89716814b0>),
(<ForkProcess name='ForkProcess-15' pid=27745 parent=23977 started daemon>,
<multiprocessing.queues.Queue at 0x7f8971681120>,
<multiprocessing.queues.Queue at 0x7f8971681a80>),
(<ForkProcess name='ForkProcess-16' pid=27777 parent=23977 started daemon>,
<multiprocessing.queues.Queue at 0x7f8971681690>,
<multiprocessing.queues.Queue at 0x7f8971682050>)],
'_dispatch_process': None,
'_worker_datapipes': [QueueWrapper,
QueueWrapper,
QueueWrapper,
QueueWrapper,
QueueWrapper,
QueueWrapper,
QueueWrapper,
QueueWrapper],
'_worker_consumer_datapipe': _IterateQueueDataPipes,
'_main_prefetch_datapipe': PrefetcherIterDataPipe,
'_end_datapipe': PrefetcherIterDataPipe,
'_mp': True}
```
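For reference, the dump above is just the reading service's instance dictionary; something like the following reproduces it:

```python
from pprint import pprint
pprint(vars(dataloader.reading_service))
```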
I was able to eliminate the slow startup by resetting all of these attributes to their initial values. However, this effectively creates a whole new dataloader and doubled the memory usage (attached image).
```python
dataloader.reading_service._worker_processes = []
dataloader.reading_service._worker_datapipes = []
dataloader.reading_service._worker_consumer_datapipe = None
dataloader.reading_service._main_prefetch_datapipe = None
dataloader.reading_service._end_datapipe = None
dataloader.datapipe = dataloader._datapipe_before_reading_service_adapt
dataloader._datapipe_iter = None
dataloader.valid_iterator_id = None
dataloader._adapted = False
```
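For convenience I call these assignments at each epoch boundary via a small helper (a sketch; `hard_reset` is my own name, and the temporary memory doubling comes from the old workers only being reclaimed by garbage collection):

```python
def hard_reset(dl):
    # Force DataLoader2 back to its pre-adapt state so that the next
    # iteration spins up fresh workers. Workaround, not a supported API.
    rs = dl.reading_service
    rs._worker_processes = []
    rs._worker_datapipes = []
    rs._worker_consumer_datapipe = None
    rs._main_prefetch_datapipe = None
    rs._end_datapipe = None
    dl.datapipe = dl._datapipe_before_reading_service_adapt
    dl._datapipe_iter = None
    dl.valid_iterator_id = None
    dl._adapted = False

for epoch in range(3):
    dataloader.seed(epoch)
    BenchmarkLoading(dataloader, N=100, NumPrints=100, phrase=f'epoch {epoch}')
    hard_reset(dataloader)  # old worker state lingers until GC runs
```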
Q: Is there a way to embed this reset behavior into the reading service's `worker_reset_fn` without causing the memory increase?
Are there other recommendations for hard-resetting the data loader every epoch? Compared to DL1, it is much less efficient to keep the old state in memory and, on reset, to briefly hold two dataloaders' worth of RAM. It also nearly triples the per-epoch startup time of my jobs before they proceed as normal.
I left my original comment here: #1150
A small comment about datapipes, isolating the issue to the reading service
Datapipe performance is very consistent after resetting the iterator. This may already be clear from the DL1 results, but I ran the test anyway, so I'm showing it here:
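The test itself was just the benchmark above run on the bare datapipe, with no reading service attached, roughly:

```python
pipe = Collator(datapipe.batch(batch_size))
for epoch in range(3):
    # Each call re-iterates the pipe from scratch, i.e. a fresh iterator.
    BenchmarkLoading(pipe, N=100, NumPrints=100,
                     phrase=f'datapipe only, epoch {epoch}')
```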