Facing various issues with validation loop when using IterableDataset that implements __len__ #19413
Unanswered
arzaatri
asked this question in
Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 1 comment
For anyone reading this, setting
Hello all,
I'm trying to train a neural network on a tabular Parquet dataset that cannot fit into memory. As a workaround, I've been using PyArrow to load one row at a time, leaving the DataLoader to handle batching. I decided to wrap this in a PyTorch `IterableDataset` that implements the `__len__` method, so that Lightning's Trainer can maintain the notion of an epoch.

I'll post code below, but broadly speaking, the structure is: I create a generator for each Parquet file using PyArrow's `iter_batches` method, chain these together using `itertools.chain`, and return an iterator over the resulting chain from the `__iter__` method of a PyTorch `IterableDataset`.

In addition, as described in PyTorch's documentation, I use `torch.utils.data.get_worker_info` to assign each worker a subset of these Parquet file generators, to avoid redundant data. I implement `__len__` by iterating over the Parquet files in my dataset and summing the number of rows in each.

Here's the code:
Using the datamodule in isolation, I've confirmed that the dataloader runs for the expected number of steps. However, when using the Trainer, I've experienced various issues with this setup:

- The validation epoch does not run for `ceiling(len(dataset) / batch_size)` steps, as expected (e.g. `len(dataset) / batch_size = 40`, and when I use `num_workers=2` in my dataloader, the validation epoch only goes for 34 steps according to the progress bar)
- The same behavior occurs with `num_workers = 0`
Most of my testing was performed with `num_workers=1` for simplicity, but all of the above issues still occur. I've also run this code using a standard `Dataset` that loads a subset of the data into memory, and everything worked fine, which has made me stop investigating `format_data`, `get_s3_uris`, or the Parquet files themselves as culprits.

Can anyone lend a hand? Am I doing something wrong, forgetting some setting, or is there an issue in how Lightning treats finite-length `IterableDataset`s?

(I had a thought that maybe the shortened val epochs have to do with sanity checking. These checks run two steps, but I'm not sure whether they run once per process plus once at the start of training. I don't believe this would explain the train dataloaders, though, unless sanity checking has some hidden interaction with them. I suspect some overall inconsistency in resetting iterators at the appropriate time.)
EDIT: I didn't get any of these issues in pure PyTorch. I'm going to test a bit more and file this as a proper issue once I'm sure it's a Lightning thing.