diff --git a/tutorials/video_classification_example/data.py b/tutorials/video_classification_example/data.py
index 16eb3857..988e47fa 100644
--- a/tutorials/video_classification_example/data.py
+++ b/tutorials/video_classification_example/data.py
@@ -128,6 +128,12 @@ def _audio_transform(self):
         )
 
     def _make_ds_and_loader(self, mode: str):
+        """Creates the dataset and dataloader for the given dataset split `mode`. Returns
+        both the dataset and the dataloader, and should be called from self.{train|val|test}_dataloader().
+
+        Args:
+            mode (str): The dataset split to create. Should be 'train' or 'val'.
+        """
         ds = LimitDataset(
             labeled_video_dataset(
                 data_path=str(
@@ -259,6 +265,13 @@ def _make_ds_and_loader(self, mode: str):
 
 
 def download_and_unzip(url, data_dir="./", verify=True):
+    """Download a zip file from the given URL and unpack it into data_dir.
+
+    Args:
+        url (str): A URL to a zip file.
+        data_dir (str, optional): Directory where the zip will be unpacked. Defaults to "./".
+        verify (bool, optional): Whether to verify the SSL certificate when requesting the zip file. Defaults to True.
+    """
     data_dir = Path(data_dir)
     zipfile_name = url.split("/")[-1]
     data_zip_path = data_dir / zipfile_name
@@ -271,19 +284,3 @@ def download_and_unzip(url, data_dir="./", verify=True):
         f.write(resp.content)
 
     unpack_archive(data_zip_path, extract_dir=data_dir)
-
-
-if __name__ == "__main__":
-    from finetune import parse_args
-    from train import LearningRateMonitor, VideoClassificationLightningModule
-    args = parse_args("--gpus 1 --precision 16 --batch_size 8 --data_path ./yt_data".split())
-    args.max_epochs = 200
-    args.callbacks = [LearningRateMonitor()]
-    args.replace_sampler_ddp = False
-    args.reload_dataloaders_every_epoch = False
-
-    pytorch_lightning.trainer.seed_everything(244)
-    dm = UCF11DataModule(args)
-    model = VideoClassificationLightningModule(args)
-    trainer = pytorch_lightning.Trainer.from_argparse_args(args)
-    trainer.fit(model, dm)
diff --git a/tutorials/video_classification_example/finetune.py b/tutorials/video_classification_example/finetune.py
index cff7fa99..2aca590d 100644
--- a/tutorials/video_classification_example/finetune.py
+++ b/tutorials/video_classification_example/finetune.py
@@ -2,7 +2,7 @@
 import pytorch_lightning as pl
 import torch
-from .data import KineticsDataModule, MiniKineticsDataModule, UCF11DataModule
+from data import KineticsDataModule, MiniKineticsDataModule, UCF11DataModule
 from pytorchvideo.models.head import create_res_basic_head
 from torch import nn
 from torch.optim import Adam
@@ -16,8 +16,7 @@
 
 
 class Classifier(pl.LightningModule):
-    """
-    """
+
     def __init__(
         self,
         num_classes: int = 11,
@@ -35,6 +34,8 @@ def __init__(
             lr (float, optional): The learning rate for the Adam optimizer. Defaults to 2e-4.
             freeze_backbone (bool, optional): Whether to freeze the backbone or leave it trainable. Defaults to True.
             pretrained (bool, optional): Whether to use the pretrained model from Torch Hub. When False, we initialize the slow_r50 model from scratch. Defaults to True.
+
+        Any extra kwargs will be available via self.hparams and will also be saved as TensorBoard hparams.
         """
         super().__init__()
         self.save_hyperparameters()
@@ -64,9 +65,23 @@ def __init__(
         self.accuracy = {"train": self.train_acc, "val": self.val_acc}
 
     def forward(self, x: torch.Tensor):
+        """
+        Forward defines the prediction/inference actions.
+        """
         return self.head(self.backbone(x))
 
     def shared_step(self, batch, mode: str):
+        """This shared step handles both the training and validation steps to avoid
+        rewriting the same code more than once. The given `mode` will change the name
+        of the logged metrics.
+
+        Args:
+            batch (dict): PyTorchVideo batch dictionary containing a single batch of data.
+            mode (str): The type of step. Can be 'train', 'val', or 'test'.
+
+        Returns:
+            torch.Tensor: The loss for a single batch step.
+        """
         y_hat = self(batch["video"])
         loss = self.loss_fn(y_hat, batch["label"])
         self.log(f"{mode}_loss", loss)
@@ -79,9 +94,35 @@ def shared_step(self, batch, mode: str):
         return loss
 
     def training_step(self, batch, batch_idx):
+        """
+        This function is called in the inner loop of the training epoch. It must
+        return a loss that is used for loss.backward() internally. The self.log(...)
+        function can be used to log any training metrics.
+
+        PyTorchVideo batches are dictionaries containing each modality or metadata of
+        the batch's collated video clips. Kinetics contains the following notable keys:
+            {
+                'video': <video_tensor>,
+                'audio': <audio_tensor>,
+                'label': <action_label>,
+            }
+
+        - "video" is a Tensor of shape (batch, channels, time, height, width)
+        - "audio" is a Tensor of shape (batch, channels, time, 1, frequency)
+        - "label" is a Tensor of shape (batch, 1)
+
+        The PyTorchVideo models and transforms expect the same input shapes and
+        dictionary structure, making this function just a matter of unwrapping the
+        dict and feeding it through the model/loss.
+        """
         return self.shared_step(batch, "train")
 
     def validation_step(self, batch, batch_idx):
+        """
+        This function is called in the inner loop of the evaluation cycle. For this
+        simple example it's mostly the same as the training loop but with a different
+        metric name.
+        """
         return self.shared_step(batch, "val")
 
     def test_step(self, batch, batch_idx):
@@ -133,13 +174,9 @@ def parse_args(args=None):
     parser.add_argument("--audio_logmel_mean", default=-7.03, type=float)
     parser.add_argument("--audio_logmel_std", default=4.66, type=float)
 
-    # Trainer parameters.
+    # Add PyTorch Lightning's Trainer init arguments as parser flags
     parser = pl.Trainer.add_argparse_args(parser)
-    parser.set_defaults(
-        max_epochs=200,
-        replace_sampler_ddp=False,
-        reload_dataloaders_every_epoch=False,
-    )
+
     return parser.parse_args(args=args)
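
With the ad-hoc __main__ block removed from data.py and the hard-coded Trainer defaults dropped from parse_args, a run is now configured entirely through CLI flags. A minimal sketch of a driver script under these assumptions (it mirrors the deleted block; the Classifier wiring and the exact flag string are illustrative, not part of this diff):

    # hypothetical run.py, placed next to data.py/finetune.py so the
    # absolute imports introduced in this diff resolve
    import pytorch_lightning as pl

    from data import UCF11DataModule
    from finetune import Classifier, parse_args

    if __name__ == "__main__":
        # --max_epochs comes from pl.Trainer.add_argparse_args; the other
        # flags are defined in parse_args (see the deleted __main__ block)
        args = parse_args(
            "--gpus 1 --precision 16 --batch_size 8 "
            "--data_path ./yt_data --max_epochs 200".split()
        )
        pl.seed_everything(244)

        dm = UCF11DataModule(args)
        model = Classifier()  # defaults: num_classes=11, pretrained, frozen backbone
        trainer = pl.Trainer.from_argparse_args(args)
        trainer.fit(model, dm)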