Memory spikes with large DataPipes #1150
Comments
Have you tried the same pipeline with `DataLoader2`?
I just tried it (I had previously tried DataLoader2, but perhaps with torch 1.13.1) and the spikes still occur. This makes sense to me, because the datapipe graph seems to be traversed in the same way.
I posted another MemoryError that may be related here: https://discuss.pytorch.org/t/torchdata-w-ddp-start-of-epoch-2-get-memoryerror/179523 My MemoryError also occurs at the start of an epoch while using DDP and distributed multiprocessing. It seems to depend on the sizes of the shuffle buffers I put into the datapipe (one for files, one for fixed-length decoding, one for augmentations); most recently I got through 9 epochs before reaching the OOM error. It's really weird: I use < 150 GB of RAM during training, yet my 500 GB of RAM gets overwhelmed at the beginning of epoch 2. I considered shutting down and restarting the pipe to resolve it.
@sesquipedalianist
@andrew-bydlon
It's difficult to provide code for this purpose as the code is the property of a large corp. Some other notes, expanding on the thoughts above: shuffling causing memory increases is mentioned in pytorch/pytorch#13246 (comment). I am generally storing data in tars in the form (arbitrary-length audio, {labels: tensor, dataset: string, ID: string}). And here is an expansion of my list of pipes:
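The actual list of pipes is not preserved in this copy of the thread. A hypothetical sketch of a pipeline along the lines described (tar shards of audio, shuffles for files / fixed-length chunks / augmentations, a prefetch factor of 10) might look like the following; `decode_audio`, `fixed_length_chunks`, and `augment` are illustrative stand-ins rather than code from the thread:

```python
import torch
from torchdata.datapipes.iter import FileLister, FileOpener

def decode_audio(item):
    # Illustrative stand-in: turn a tar member's byte stream into a "waveform".
    path, stream = item
    return torch.frombuffer(bytearray(stream.read()), dtype=torch.uint8).float(), path

def fixed_length_chunks(sample, chunk=16000):
    # Illustrative stand-in: split a waveform into fixed-length pieces.
    waveform, path = sample
    return [(waveform[i:i + chunk], path) for i in range(0, len(waveform), chunk)]

def augment(sample):
    # Illustrative stand-in: add a little noise.
    waveform, path = sample
    return waveform + 0.01 * torch.randn_like(waveform), path

def build_pipe(root):
    dp = FileLister(root, masks="*.tar").shuffle()      # shuffle the tar shards
    dp = dp.sharding_filter()                           # split shards across workers / DDP ranks
    dp = FileOpener(dp, mode="b").load_from_tar()       # yields (member_path, byte stream) pairs
    dp = dp.map(decode_audio)
    dp = dp.flatmap(fixed_length_chunks).shuffle(buffer_size=1000)
    dp = dp.map(augment).shuffle(buffer_size=1000)
    return dp.prefetch(10)                              # default prefetch factor mentioned above
```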
For now I have solved my issue by monkeypatching the per-epoch seeding step. This effectively disables the graph traversal that causes the spikes (see the sketch below).
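The exact patch is not shown in the thread. A minimal sketch, assuming the target is `torch.utils.data.graph_settings.apply_random_seed` (the function the issue body below identifies as the source of the spikes):

```python
import torch.utils.data.graph_settings as graph_settings

def _apply_random_seed_noop(datapipe, rng):
    # Skip traversing (and pickling) the datapipe graph entirely. The trade-off is
    # that shuffler datapipes are no longer re-seeded at the start of each epoch.
    return datapipe

# Assumes callers resolve apply_random_seed through the module at call time,
# so replacing the module attribute is enough for the patch to take effect.
graph_settings.apply_random_seed = _apply_random_seed_noop
```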
@andrew-bydlon are you saving anything in memory (like audio samples)? That would likely cause the same issue I was having.
I'm not saving anything in memory other than what is prefetched. I'm using iterable datapipes to do all of the above, per the recommendations on the homepage; these default to a prefetch factor of 10. The augmentation operations take some compute, but all of this happens at the start of epoch 2 (memory going from 20% -> 100%), so it seems extremely unexpected.
Thank you both for your help. I have finally deep-dived this topic and made an issue. There is a lot of talk about memory leaks in the Issues. I really like the DataLoader2 API, but I will be temporarily switching back to DL1 because of the issues that I mention.
I tried this out without success. Glad it worked for you!
🐛 Describe the bug
I’ve noticed large “spikes” in memory usage at the start of epochs when using IterDataPipes with attributes that take a lot of memory. These can cause my training jobs to fail with out-of-memory errors.
Here’s a minimal example to reproduce:
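The original snippet did not survive in this copy of the issue. A minimal sketch in the spirit of the description (an IterDataPipe holding a large in-memory attribute, iterated for several epochs while RSS is logged with psutil); the sizes, names, and inline logging are illustrative:

```python
import psutil
import torch
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

def rss_gb():
    return psutil.Process().memory_info().rss / 1e9

# A datapipe whose attributes occupy a lot of memory: IterableWrapper holds the
# ~800 MB tensor, and in_memory_cache() additionally caches every item it yields.
data = torch.randn(200_000, 1_000)
pipe = IterableWrapper(data).in_memory_cache().shuffle()
loader = DataLoader(pipe, batch_size=256, num_workers=0)

for epoch in range(4):
    print(f"start_epoch {epoch}: rss = {rss_gb():.2f} GB")
    for step, batch in enumerate(loader):
        if step == 0:
            print(f"first_iter {epoch}: rss = {rss_gb():.2f} GB")
        pass  # training step would go here
```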
The memory usage (logged with psutil) looks like this:
Here, `start_epoch` indicates the start of an epoch and `first_iter` corresponds to the first time each epoch that we reach the `pass` statement in the dataloader loop. (To simplify the example code above I removed the code that logs `start_epoch` and `first_iter`; I logged the memory usage from a separate process.)

After some debugging, I can say that the memory spikes occur during the traversal of the graph that occurs in `torch/utils/data/graph_settings.py::apply_random_seed()` at the beginning of each epoch. Disabling the body of this function removes the memory spikes.

The spikes seem to be caused by the pickling in https://github.com/pytorch/pytorch/blob/99ded8bbcea896b02f1c0babb055329c503ca95e/torch/utils/data/graph.py#L23

The code there defines `f = io.BytesIO()` and pickles to `f`. If there are large datapipes to be pickled, it makes sense that memory usage blows up quickly and then falls again when `f` goes out of scope.

I tried replacing `f = io.BytesIO()` with `f = open(os.devnull, "wb")` (and adding `f.close()` at the end of the function). This didn't eliminate the memory spikes, but it did make them a bit smaller; a small standalone illustration of the difference follows the notes below.

A few notes:

- It is not necessary to use `.in_memory_cache()` to see these spikes; it seems that any datapipe that occupies a lot of memory will cause them.

Versions
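As mentioned above, here is a small standalone illustration (not a patch to torch itself) of why writing the pickle stream to `os.devnull` shrinks the spike; the list size is arbitrary:

```python
import io
import os
import pickle

import psutil

def rss_gb():
    return psutil.Process().memory_info().rss / 1e9

big = [float(i) for i in range(20_000_000)]  # stand-in for a datapipe with large attributes
print(f"baseline:               {rss_gb():.2f} GB")

# Pickling into an in-memory buffer keeps a full serialized copy alive until f is released.
f = io.BytesIO()
pickle.Pickler(f).dump(big)
print(f"after BytesIO dump:     {rss_gb():.2f} GB")
del f
print(f"after releasing buffer: {rss_gb():.2f} GB")

# Pickling to os.devnull discards the serialized bytes as they are written, so no full
# copy accumulates; any remaining spike comes from temporaries created during
# serialization itself (presumably why the change above only shrank the spikes).
with open(os.devnull, "wb") as f:
    pickle.Pickler(f).dump(big)
print(f"after devnull dump:     {rss_gb():.2f} GB")
```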
I have tested the above with both
I observed the same behavior in both cases.