pip install lightning-tutorial
PyTorch Datasets and DataLoaders
- Key module: torch.utils.data.Dataset
- Key module: torch.utils.data.DataLoader
- Other essential functions

Key module: torch.utils.data.Dataset
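The Dataset class is meant to be subclassed. You can customize your subclass at will, as long as it implements the following three methods:
- __init__
- __len__
- __getitem__
These special ("dunder") method names are what torch relies on under the hood, e.g., when a DataLoader fetches samples to pass through a model. Strictly speaking, only __getitem__ is required, but __len__ is expected by the default samplers and most DataLoader options.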
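Because these are ordinary Python special methods, the built-ins `len(obj)` and `obj[i]` call `__len__` and `__getitem__` for you, and torch's samplers and DataLoader rely on the same protocol. The minimal sketch below uses a hypothetical `Squares` class, with no torch dependency, purely to show the dispatch.

```python
class Squares:
    """Hypothetical map-style dataset: sample i is i**2."""

    def __init__(self, n):
        # store whatever __len__ and __getitem__ need
        self.n = n

    def __len__(self):
        # called by len(obj)
        return self.n

    def __getitem__(self, idx):
        # called by obj[idx]; valid indices are 0 .. len(self) - 1
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx ** 2


ds = Squares(5)
print(len(ds))   # 5
print(ds[3])     # 9
# without __iter__, Python iterates via __getitem__(0), (1), ... until IndexError
print(list(ds))  # [0, 1, 4, 9, 16]
```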
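```python
import torch
from torch.utils.data import Dataset


class TurtleData(Dataset):
    def __init__(self, n_samples=20_000, n_features=10):
        """
        Pass (or build) whatever __len__() and __getitem__() need.
        Here: hypothetical placeholder data of random features and binary labels.
        """
        self.X = torch.randn(n_samples, n_features)
        self.y = torch.randint(0, 2, (n_samples,))

    def __len__(self):
        """
        Returns the length/size/# of samples in the dataset,
        e.g., a 20,000-cell dataset returns `20_000`.
        """
        return self.X.shape[0]

    def __getitem__(self, idx):
        """
        Returns the sample at position `idx`.
        With a default DataLoader, `idx` is a single integer in
        [0, len(self) - 1]; the DataLoader collects `batch_size`
        such samples and collates them into a batch.
        """
        return self.X[idx], self.y[idx]
```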
- Try it for yourself! Colab Dataset tutorial notebook
Much like AnnData, the Dataset class provides a standard base unit for storing and serving data. On top of it, we can use several torch built-ins that make data handling not only more organized but also faster, e.g., batching, shuffling, and parallel loading with worker processes.
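In practice the DataLoader is what calls `__getitem__`. As a rough mental model (a simplified, single-process approximation, not torch's implementation), it draws indices, fetches one sample per index, and collates the samples into a batch. `simple_loader` below is a hypothetical helper illustrating this; the real DataLoader adds samplers, worker processes, and a tensor-aware `collate_fn`.

```python
import random


def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """
    Simplified approximation of a DataLoader with automatic batching:
    sample indices, fetch one item per index via dataset[i], then
    "collate" the items into one batch.
    """
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        samples = [dataset[i] for i in batch_idx]  # one __getitem__ call per index
        # toy collate: turn a list of (x, y) pairs into ([x, ...], [y, ...])
        yield tuple(list(field) for field in zip(*samples))


# hypothetical dataset: a plain list of (feature, label) pairs
data = [(i, i % 2) for i in range(10)]
for xs, ys in simple_loader(data, batch_size=4):
    print(xs, ys)
# [0, 1, 2, 3] [0, 1, 0, 1]
# [4, 5, 6, 7] [0, 1, 0, 1]
# [8, 9] [0, 1]
```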
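Key module: torch.utils.data.DataLoader
```python
from torch.utils.data import DataLoader

dataset = TurtleData()
data_size = len(dataset)  # calls dataset.__len__()
print(data_size)  # 20000

# the DataLoader calls __getitem__ and collates samples into batches
data_loader = DataLoader(dataset, batch_size=128, shuffle=True)  # batch size chosen for illustration
```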
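To see how batching plays out, here is a minimal sketch on a hypothetical 10-sample `TensorDataset`: the number of batches is `ceil(N / batch_size)`, the last batch may be smaller, and `drop_last=True` discards it.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# hypothetical toy data: 10 samples with 3 features each, labels 0..9
dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))

loader = DataLoader(dataset, batch_size=4, shuffle=False)
print(len(loader))  # 3 batches: 4 + 4 + 2 samples
for X_batch, y_batch in loader:
    print(X_batch.shape, y_batch.tolist())
# torch.Size([4, 3]) [0, 1, 2, 3]
# torch.Size([4, 3]) [4, 5, 6, 7]
# torch.Size([2, 3]) [8, 9]

# drop_last=True discards the final, incomplete batch
print(len(DataLoader(dataset, batch_size=4, drop_last=True)))  # 2
```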
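Other essential functions
```python
from torch.utils.data import DataLoader, random_split

# split 20,000 samples into 18,000 training and 2,000 validation samples
train_dataset, val_dataset = random_split(dataset, [18_000, 2_000])

# each split is itself a Dataset and can be fed to a DataLoader
train_loader = DataLoader(train_dataset, shuffle=True)
val_loader = DataLoader(val_dataset)
```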
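`random_split` shuffles indices at random, so two runs give different splits unless you seed it. A minimal sketch, on a hypothetical 10-sample dataset, of making the split reproducible with the `generator` argument; each split is a `Subset` whose `indices` record which samples it holds.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))  # hypothetical 10-sample dataset

# seeding the generator makes the split reproducible across runs
train_a, val_a = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(42))
train_b, val_b = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(42))

# the two subsets are disjoint and together cover every sample
print(sorted(list(train_a.indices) + list(val_a.indices)) == list(range(10)))  # True
# same seed, same split
print(list(train_a.indices) == list(train_b.indices))  # True
```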
- Parent module: torch.utils.data
- Datasets and DataLoaders tutorial
Create PyTorch Datasets from AnnData
- Note: if you've installed this tutorial's associated package, this is already done for you.
pip install torch-adata
The base class, AnnDataset, is a subclass of the widely used torch.utils.data.Dataset.
```python
import anndata as a
import torch_adata

adata = a.read_h5ad("/path/to/data.h5ad")
dataset = torch_adata.AnnDataset(adata)
```
Indexing the dataset returns the sampled data, X_batch, as a torch.Tensor.
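```python
import numpy as np

# create a dummy index: 5 distinct sample indices drawn at random
idx = np.random.choice(len(dataset), 5, replace=False)
X_batch = dataset[idx]  # equivalent to dataset.__getitem__(idx)
```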
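The call above assumes `__getitem__` accepts a whole collection of indices at once. A minimal sketch of a dataset with that behavior, using a hypothetical `DenseMatrixDataset` (not torch_adata's implementation) on random placeholder data, plus one way to feed such index batches through a DataLoader: with `batch_size=None`, automatic batching is turned off, so each list of indices produced by a `BatchSampler` reaches `__getitem__` in a single call.

```python
import numpy as np
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler


class DenseMatrixDataset(Dataset):
    """Hypothetical stand-in for a matrix-backed dataset (not torch_adata's code)."""

    def __init__(self, X):
        self.X = torch.as_tensor(X, dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # `idx` may be a single int or a list/array of row indices
        return self.X[idx]


X = np.random.rand(100, 20)  # hypothetical: 100 cells x 20 genes
dataset = DenseMatrixDataset(X)

idx = np.random.choice(len(dataset), 5, replace=False)
print(dataset[idx].shape)  # torch.Size([5, 20])

# each element yielded by the BatchSampler is a list of 32 (or fewer) indices
loader = DataLoader(
    dataset,
    sampler=BatchSampler(RandomSampler(dataset), batch_size=32, drop_last=False),
    batch_size=None,
)
print([batch.shape[0] for batch in loader])  # [32, 32, 32, 4]
```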
TimeResolvedAnnDataset is a specialized class for time-resolved datasets and a subclass of AnnDataset.
```python
import anndata as a
import torch_adata

adata = a.read_h5ad("/path/to/data.h5ad")
# time_key names the time-point annotation (here, "Time point")
dataset = torch_adata.TimeResolvedAnnDataset(adata, time_key="Time point")
```
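This page does not describe how TimeResolvedAnnDataset uses `time_key` internally. Purely as an illustration, one common way to organize time-resolved samples is to group sample indices by their time label, so that data can later be drawn per time point. The labels below are hypothetical stand-ins for the values stored under a time-point annotation.

```python
import numpy as np

# hypothetical per-cell time labels, e.g. the values of a "Time point" annotation
time_labels = np.array([2, 4, 2, 6, 4, 2, 6])

# map each distinct time point to the indices of the cells observed at it
indices_by_time = {
    t: np.flatnonzero(time_labels == t).tolist()
    for t in np.unique(time_labels).tolist()
}
print(indices_by_time)  # {2: [0, 2, 5], 4: [1, 4], 6: [3, 6]}
```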
Lightning basics and the LightningModule
```python
import torch
from pytorch_lightning import LightningModule


class YourSOTAModel(LightningModule):
    def __init__(
        self,
        net,
        loss_func=torch.nn.MSELoss(reduction="none"),  # placeholder: pick a loss suited to your task
        optimizer_kwargs=None,
        scheduler_kwargs=None,
    ):
        super().__init__()
        self.net = net
        self.loss_func = loss_func
        # avoid mutable default arguments; fall back to illustrative defaults
        self.optimizer_kwargs = optimizer_kwargs or {"lr": 1e-3}
        self.scheduler_kwargs = scheduler_kwargs or {"step_size": 10}  # StepLR requires `step_size`

    def forward(self, batch):
        x, y = batch
        y_hat = self.net(x)
        loss = self.loss_func(y_hat, y)
        return y_hat, loss

    def training_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        loss = loss.sum()
        self.log("train_loss", loss)
        return loss  # Lightning runs backward() and the optimizer step on this

    def validation_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        self.log("val_loss", loss.sum())

    def test_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        self.log("test_loss", loss.sum())

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), **self.optimizer_kwargs)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **self.scheduler_kwargs)
        return [optimizer], [scheduler]
```
- Try it for yourself! Lightning Classifier tutorial notebook
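For intuition, here is roughly what the Trainer does with `training_step` and `configure_optimizers` under automatic optimization, written as a plain PyTorch loop. It is a simplified sketch (no validation, logging, device placement, or checkpointing) on a hypothetical toy regression problem; the learning rate and epoch count are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# hypothetical toy regression problem: y = 3x + noise
X = torch.randn(256, 1)
y = 3 * X + 0.1 * torch.randn(256, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

net = torch.nn.Linear(1, 1)
loss_func = torch.nn.MSELoss(reduction="none")
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

epoch_losses = []
for epoch in range(20):
    total = 0.0
    for x_b, y_b in loader:
        loss = loss_func(net(x_b), y_b).sum()  # what training_step returns
        optimizer.zero_grad()  # Lightning performs these three steps for you
        loss.backward()
        optimizer.step()
        total += loss.item()
    scheduler.step()  # by default, Lightning steps the scheduler once per epoch
    epoch_losses.append(total / len(X))  # mean per-sample loss for this epoch

print(epoch_losses[-1] < epoch_losses[0])  # True
```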
The LightningDataModule
Purpose: make your model independent of any given dataset, while at the same time making your dataset reproducible and, perhaps just as importantly, easily shareable.
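```python
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset, random_split


class YourDataModule(LightningDataModule):
    def __init__(self, batch_size=32, num_workers=0):
        super().__init__()
        # define any setup computations / hyperparameters
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # download data if applicable (runs on a single process; don't assign state here)
        pass

    def setup(self, stage=None):
        # assign data to `Dataset`(s); random tensors are a hypothetical stand-in for real data
        dataset = TensorDataset(torch.randn(1_000, 10), torch.randn(1_000, 1))
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(dataset, [800, 100, 100])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
```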
- Try it for yourself! LightningDataModule tutorial notebook
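Reproducibility is one of the stated goals of a LightningDataModule. One way to get it for shuffled training data (an assumption; the template above does not do this) is to pass a seeded `torch.Generator` to DataLoader, which then drives the shuffling. `make_loader` and `epoch_order` are hypothetical helpers; Lightning's `seed_everything` is an alternative that seeds globally.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))  # hypothetical 10-sample dataset


def make_loader(seed):
    # a seeded generator drives the RandomSampler, so the shuffle order is repeatable
    return DataLoader(dataset, batch_size=5, shuffle=True, generator=torch.Generator().manual_seed(seed))


def epoch_order(loader):
    # flatten one epoch of batches into the order in which samples were served
    return [i for (batch,) in loader for i in batch.tolist()]


print(epoch_order(make_loader(0)) == epoch_order(make_loader(0)))  # True
print(sorted(epoch_order(make_loader(0))) == list(range(10)))      # True
```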
Using a LightningDataModule together with a LightningModule looks something like the following:
```python
import torch
from pytorch_lightning import Trainer

# Init the LightningDataModule as well as the LightningModule
data = YourDataModule()
model = YourSOTAModel(net=torch.nn.Linear(10, 1))  # hypothetical net; match it to your data

# Define the trainer
trainer = Trainer(accelerator="auto", devices=1)

# Ultimately, both model and data are passed as arguments to trainer.fit
trainer.fit(model, data)
```
- Try it for yourself! LightningGAN tutorial notebook
Here's an example of a LightningDataModule implemented in practice, using the LARRY single-cell dataset: link. The initial download and formatting happen only once but take several minutes, so we leave them outside the scope of this tutorial.
I'd love to get in touch. Send me an email or open an issue!
⚡