Memory requirements to replicate on Pems-Bay #88
I managed to replicate the result on PEMS-Bay; with a batch_size of 4 it trains fine. Reading the paper, I saw these suggestions for low-memory scenarios, and I quote: "Most results in this paper were collected using fast attention alone with less than 40GBs of GPU memory. Strided convolutions are only necessary on the largest datasets." The part that is not clear to me is this one: "Shifted window attention saves meaningful amounts of memory when using quadratic attention, so we primarily use it when we need to mask padded sequences."
I'm surprised batch size 4 worked that well; that's interesting. The PEMS-Bay results in v3 of the paper (the current one) were run on multiple A100s. I don't think this was necessary; I just had compute at the end of the project and scaled way up to see what would happen. All the other results (including arXiv v1 and v2 PEMS-Bay) were on more accessible GPUs. I meant to circle back and publish training commands for low-GPU settings, but this project dragged on in peer review for so long that I had switched institutions and research topics by the time it was over and couldn't return to it :)
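If you want to try the multi-GPU route yourself, DistributedDataParallel through Lightning's Trainer is the usual approach; the older "dp" mode is discouraged and removed in recent Lightning releases. A rough sketch is below. It is not the exact configuration used for the paper runs, argument names vary across Lightning versions, and train.py may expose these options through its own command-line flags.

```python
import pytorch_lightning as pl

# Hypothetical sketch (Lightning >= 1.7 argument names); `model` and the data
# module stand in for whatever spacetimeformer's train.py constructs.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,        # spread each batch across 4 GPUs
    strategy="ddp",   # DistributedDataParallel instead of the older "dp"
)
# trainer.fit(model, datamodule=data_module)
```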
The shifted window attention moves memory requirements from the length dimension (where attention is normally quadratic) to the batch dimension (where it is linear). So if you are using a linear attention approximation like Performer, you really aren't saving much memory overall. We mask padded sequences for the mixed-length datasets that are included in the codebase (m4, wikipedia) but are not heavily discussed in the paper. Basically, if your context sequences have mixed lengths, the flattened spatiotemporal sequences require an unusual attention mask that does not fit most efficient attention implementations. In that situation it makes sense to revert to vanilla quadratic attention, where we can easily use any mask we'd like, and reduce computation with shifted windows.

If you are actively working on this now, I think it would be an interesting update to do this with Flash Attention instead of vanilla attention. You still save compute with shifted windows, and the performance gains are significant enough that I think it is hard to justify approximations like Performer today, at least at these sequence lengths.
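Roughly, the windowing idea looks like the sketch below: fold fixed-size windows into the batch dimension, run attention inside each window, and (in shifted layers) roll the sequence so window boundaries move between layers. This is a simplified illustration, not the actual spacetimeformer code; the use of torch.nn.functional.scaled_dot_product_attention (which can dispatch to Flash Attention on recent PyTorch/GPU combinations) is the suggestion from above, not what the released code does.

```python
import torch
import torch.nn.functional as F

def windowed_attention(q, k, v, window: int, shift: int = 0):
    """Attention restricted to fixed-size windows folded into the batch dim.

    q, k, v: (batch, heads, length, dim). For simplicity `length` must be
    divisible by `window`; real code would pad and mask the remainder.
    """
    B, H, L, D = q.shape
    if shift:
        # Shifted windows: roll the sequence so window boundaries move,
        # letting information mix across windows over successive layers.
        q, k, v = (torch.roll(t, shifts=-shift, dims=2) for t in (q, k, v))

    def fold(t):
        # (B, H, L, D) -> (B * L//window, H, window, D): cost is quadratic in
        # `window` but only linear in the number of windows (batch dimension).
        return (t.reshape(B, H, L // window, window, D)
                 .permute(0, 2, 1, 3, 4)
                 .reshape(-1, H, window, D))

    out = F.scaled_dot_product_attention(fold(q), fold(k), fold(v))

    # Undo the folding and the shift.
    out = (out.reshape(B, L // window, H, window, D)
              .permute(0, 2, 1, 3, 4)
              .reshape(B, H, L, D))
    if shift:
        out = torch.roll(out, shifts=shift, dims=2)
    return out
```

With a Performer-style linear approximation the full-length attention is already linear in L, so this folding saves little; with quadratic or Flash attention it cuts the per-layer attention cost from O(L^2) to O(L * window).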
Thanks to the authors for the great work and for sharing the code!
I am interested in replicating the results on Pems-Bay before trying the model on a custom dataset of similar size.
I am using the suggested command from the README, with accelerator="gpu" set in the Trainer, and it gives an OOM error on my GPU (which has more than 20GB of memory). How many GBs of GPU memory are required? Did you use multiple GPUs to train spacetimeformer on Pems-Bay?

I also tried keeping accelerator="dp", but that gives an error. I hope someone can help me run spacetimeformer in multi-GPU mode. In case it is not possible, do you have any suggestions for reducing memory requirements while maintaining good performance? For now I could fit the model on a single GPU only by setting batch_size=4, but I fear this would lead to a bad model fit.
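One idea I am considering to compensate for the small batch is gradient accumulation, since Lightning's Trainer supports it. If I understand correctly, something like the sketch below would keep per-step memory at batch_size=4 while recovering a larger effective batch (argument names may differ across Lightning versions, and I am assuming train.py lets me control the Trainer this way):

```python
import pytorch_lightning as pl

# Sketch of what I have in mind: 8 accumulation steps with batch_size=4
# gives an effective batch size of 32 without increasing per-step memory.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,                # mixed precision to reduce activation memory
    accumulate_grad_batches=8,   # accumulate gradients over 8 small batches
)
```

Would that be a reasonable workaround here?

Thank you!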