Facing various issues with validation loop when using IterableDataset that implements __len__ #19413
Unanswered
arzaatri asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 1 comment
For anyone reading this, setting
Hello all,
I'm trying to train a neural network on a tabular Parquet dataset which cannot fit into memory. As a solution, I've been using PyArrow to load one row at a time, leaving the DataLoader to handle batching. I decided to wrap this in a PyTorch `IterableDataset` which implements the `__len__` method, so that Lightning's Trainer can maintain the notion of an epoch.

I'll post code below, but broadly speaking, my structure is to create a generator for each Parquet file using PyArrow's `iter_batches` method. I chain these together using `itertools.chain` and return an iterator over the resulting chain in the `__iter__` method of a PyTorch `IterableDataset`. In addition, as described in PyTorch's documentation, I use `torch.utils.data.get_worker_info` to assign each worker a subset of these Parquet file generators, to avoid redundant data. I implement `__len__` by iterating over the Parquet files in my dataset and adding up the number of rows in each.

Here's the code:
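For concreteness, here is only a minimal sketch of the setup described above, not the exact code from the post; `get_s3_uris` and `format_data` stand in for the helpers mentioned further down, and the class names are illustrative:

```python
import itertools
from typing import Iterator, List

import lightning.pytorch as pl  # or `pytorch_lightning`, depending on version
import pyarrow.parquet as pq
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


def get_s3_uris(split: str) -> List[str]:
    """Stand-in for the helper from the post: return the parquet paths for a split."""
    raise NotImplementedError


def format_data(row: dict):
    """Stand-in for the helper from the post: turn one parquet row into model inputs."""
    raise NotImplementedError


class ParquetIterableDataset(IterableDataset):
    def __init__(self, paths: List[str], record_batch_size: int = 1024):
        self.paths = paths
        # pyarrow record-batch size; the DataLoader still does its own batching
        self.record_batch_size = record_batch_size

    def _rows_from_file(self, path: str) -> Iterator:
        # Stream one parquet file and yield it row by row.
        pf = pq.ParquetFile(path)
        for record_batch in pf.iter_batches(batch_size=self.record_batch_size):
            for row in record_batch.to_pylist():
                yield format_data(row)

    def __iter__(self) -> Iterator:
        info = get_worker_info()
        if info is None:
            # Single-process loading: this process handles every file.
            paths = self.paths
        else:
            # Shard files across workers so no row is yielded twice.
            paths = self.paths[info.id :: info.num_workers]
        return itertools.chain.from_iterable(self._rows_from_file(p) for p in paths)

    def __len__(self) -> int:
        # Total row count across all files, read from parquet metadata only.
        return sum(pq.ParquetFile(p).metadata.num_rows for p in self.paths)


class ParquetDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 256, num_workers: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_ds = ParquetIterableDataset(get_s3_uris("train"))
        self.val_ds = ParquetIterableDataset(get_s3_uris("val"))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size,
                          num_workers=self.num_workers)
```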
Using the datamodule in isolation, I've confirmed that the dataloader runs for the expected number of steps (a rough version of this check is sketched below, after the list). However, when using the Trainer, I've experienced various issues with this setup:

- The validation epoch does not run for `ceiling(len(dataset) / batch_size)` steps, as expected (e.g. `len(dataset) / batch_size = 40`, and I use `num_workers=2` in my dataloader, but the validation epoch only goes for 34 steps according to the progress bar)
- Other problems show up with `num_workers = 0` as well
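The standalone check mentioned above might look roughly like this; it assumes the hypothetical `ParquetDataModule` from the earlier sketch and is not taken from the original post:

```python
import math

# Count the batches the val dataloader yields outside of the Trainer and
# compare against ceiling(len(dataset) / batch_size).
dm = ParquetDataModule(batch_size=256, num_workers=2)
dm.setup("fit")

val_loader = dm.val_dataloader()
observed = sum(1 for _ in val_loader)
expected = math.ceil(len(val_loader.dataset) / val_loader.batch_size)
print(f"observed={observed}, expected={expected}")
```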
Most of my testing was performed with `num_workers=1` for simplicity, but all of the above issues still occur. Also, I've run this code using a standard Dataset which loads a subset of the data into memory, and everything worked fine, which has made me stop investigating `format_data`, `get_s3_uris`, or the Parquet files themselves as culprits.

Can anyone lend a hand? Am I doing something wrong, forgetting some setting, or is there an issue in how Lightning treats finite-length `IterableDataset`s?

(I had a thought that maybe the shortened val epochs have to do with sanity checking. These checks run two steps, but I'm not sure whether they're run once per process plus once at the start of training. I don't believe this would account for the train dataloaders, though, unless sanity checking has some hidden interaction with them. I suspect some overall inconsistency in resetting iterators at the appropriate time.)
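One way to rule the sanity check in or out, sketched here rather than taken from the thread, is to disable it explicitly through the Trainer's `num_sanity_val_steps` argument (`model` and `dm` are placeholders):

```python
import lightning.pytorch as pl  # or `pytorch_lightning`, depending on version

trainer = pl.Trainer(
    max_epochs=1,
    num_sanity_val_steps=0,  # skip the two pre-training sanity validation batches
    limit_val_batches=1.0,   # run the full validation set each epoch (the default)
)
# trainer.fit(model, datamodule=dm)  # `model` and `dm` are placeholders
```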
EDIT: I didn't get any of these issues in pure PyTorch. Gonna test a bit more and make this an issue once I'm sure that it's a Lightning thing.