diff --git a/tutorials/video_classification_example/data.py b/tutorials/video_classification_example/data.py
new file mode 100644
index 00000000..49972d9b
--- /dev/null
+++ b/tutorials/video_classification_example/data.py
@@ -0,0 +1,344 @@
+import itertools
+from pathlib import Path
+from random import shuffle
+from shutil import unpack_archive
+from typing import Tuple
+
+import pytorch_lightning as pl
+import requests
+import torch
+from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
+from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset
+from pytorchvideo.transforms import (
+    ApplyTransformToKey,
+    Normalize,
+    RandomShortSideScale,
+    ShortSideScale,
+    UniformTemporalSubsample,
+)
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from torchvision.transforms import (
+    CenterCrop,
+    Compose,
+    Lambda,
+    RandomCrop,
+    RandomHorizontalFlip,
+)
+
+
+class LabeledVideoDataModule(pl.LightningDataModule):
+
+    SOURCE_URL: str = None
+    SOURCE_DIR_NAME: str = ""
+    NUM_CLASSES: int = 700
+    VERIFY_SSL: bool = True
+
+    def __init__(
+        self,
+        root: str = "./",
+        clip_duration: int = 2,
+        video_num_subsampled: int = 8,
+        video_crop_size: int = 224,
+        video_means: Tuple[float] = (0.45, 0.45, 0.45),
+        video_stds: Tuple[float] = (0.225, 0.225, 0.225),
+        video_min_short_side_scale: int = 256,
+        video_max_short_side_scale: int = 320,
+        video_horizontal_flip_p: float = 0.5,
+        batch_size: int = 4,
+        workers: int = 4,
+        **kwargs
+    ):
+        """
+        A LabeledVideoDataModule expects a dataset in the following format:
+
+        /root                                        # Root Folder
+        ├── train                                    # Split Folder
+        │   ├── archery                              # Class Folder
+        │   │   ├── -1q7jA3DXQM_000005_000015.mp4    # Videos
+        │   │   ├── -5NN5hdIwTc_000036_000046.mp4
+        │   │   ...
+        │   ├── bowling
+        │   │   ├── -5ExwuF5IUI_000030_000040.mp4
+        │   │   ...
+        │   ├── high_jump
+        │   │   ├── -5ExwuF5IUI_000030_000040.mp4
+        │   │   ...
+        ├── val
+        │   ├── archery
+        │   │   ├── -1q7jA3DXQM_000005_000015.mp4
+        │   │   ├── -5NN5hdIwTc_000036_000046.mp4
+        │   │   ...
+        │   ├── bowling
+        │   │   ├── -5ExwuF5IUI_000030_000040.mp4
+        │   │   ...
+
+        Args:
+            root (str, optional): Directory where your dataset is stored. Defaults to "./".
+            clip_duration (int, optional): Duration of clip samples. Defaults to 2.
+            video_num_subsampled (int, optional): Number of frames to subsample from each video clip. Defaults to 8.
+            video_crop_size (int, optional): Size to crop the video to. Defaults to 224.
+            video_means (Tuple[float], optional): Means used to normalize the dataset. Defaults to (0.45, 0.45, 0.45).
+            video_stds (Tuple[float], optional): Standard deviations used to normalize the dataset. Defaults to (0.225, 0.225, 0.225).
+            video_min_short_side_scale (int, optional): min_size arg passed to pytorchvideo.transforms.RandomShortSideScale. Defaults to 256.
+            video_max_short_side_scale (int, optional): max_size arg passed to pytorchvideo.transforms.RandomShortSideScale. Defaults to 320.
+            video_horizontal_flip_p (float, optional): Probability of flipping a training example horizontally. Defaults to 0.5.
+            batch_size (int, optional): Number of examples per batch. Defaults to 4.
+            workers (int, optional): Number of DataLoader workers. Defaults to 4.
+        """
+
+        super().__init__()
+        self.root = root
+        self.data_path = Path(self.root) / self.SOURCE_DIR_NAME
+        self.clip_duration = clip_duration
+        self.video_num_subsampled = video_num_subsampled
+        self.video_crop_size = video_crop_size
+        self.video_means = video_means
+        self.video_stds = video_stds
+        self.video_min_short_side_scale = video_min_short_side_scale
+        self.video_max_short_side_scale = video_max_short_side_scale
+        self.video_horizontal_flip_p = video_horizontal_flip_p
+        self.batch_size = batch_size
+        self.workers = workers
+
+        # Transforms applied to train dataset.
+        self.train_transform = ApplyTransformToKey(
+            key="video",
+            transform=Compose(
+                [
+                    UniformTemporalSubsample(self.video_num_subsampled),
+                    Lambda(lambda x: x / 255.0),
+                    Normalize(self.video_means, self.video_stds),
+                    RandomShortSideScale(
+                        min_size=self.video_min_short_side_scale,
+                        max_size=self.video_max_short_side_scale,
+                    ),
+                    RandomCrop(self.video_crop_size),
+                    RandomHorizontalFlip(p=self.video_horizontal_flip_p),
+                ]
+            ),
+        )
+
+        # Transforms applied on val dataset or for inference.
+        self.val_transform = ApplyTransformToKey(
+            key="video",
+            transform=Compose(
+                [
+                    UniformTemporalSubsample(self.video_num_subsampled),
+                    Lambda(lambda x: x / 255.0),
+                    Normalize(self.video_means, self.video_stds),
+                    ShortSideScale(self.video_min_short_side_scale),
+                    CenterCrop(self.video_crop_size),
+                ]
+            ),
+        )
+
+    def prepare_data(self):
+        """Download the dataset if it doesn't already exist. This runs only on rank 0."""
+        if not (self.SOURCE_URL is None or self.SOURCE_DIR_NAME is None):
+            if not self.data_path.exists():
+                download_and_unzip(self.SOURCE_URL, self.root, verify=self.VERIFY_SSL)
+
+    def train_dataloader(self):
+        do_use_ddp = self.trainer is not None and self.trainer.use_ddp
+        self.train_dataset = LimitDataset(
+            labeled_video_dataset(
+                data_path=str(Path(self.data_path) / "train"),
+                clip_sampler=make_clip_sampler("random", self.clip_duration),
+                transform=self.train_transform,
+                decode_audio=False,
+                video_sampler=DistributedSampler if do_use_ddp else RandomSampler,
+            )
+        )
+        return DataLoader(
+            self.train_dataset, batch_size=self.batch_size, num_workers=self.workers
+        )
+
+    def val_dataloader(self):
+        do_use_ddp = self.trainer is not None and self.trainer.use_ddp
+        self.val_dataset = LimitDataset(
+            labeled_video_dataset(
+                data_path=str(Path(self.data_path) / "val"),
+                clip_sampler=make_clip_sampler("uniform", self.clip_duration),
+                transform=self.val_transform,
+                decode_audio=False,
+                video_sampler=DistributedSampler if do_use_ddp else RandomSampler,
+            )
+        )
+        return DataLoader(
+            self.val_dataset, batch_size=self.batch_size, num_workers=self.workers
+        )
+
+
+class UCF11DataModule(LabeledVideoDataModule):
+
+    SOURCE_URL: str = "https://www.crcv.ucf.edu/data/YouTube_DataSet_Annotated.zip"
+    SOURCE_DIR_NAME: str = "action_youtube_naudio"
+    NUM_CLASSES: int = 11
+    VERIFY_SSL: bool = False
+
+    def __init__(self, **kwargs):
+        """
+        The UCF11 Dataset contains 11 action classes: basketball shooting, biking/cycling, diving,
+        golf swinging, horse back riding, soccer juggling, swinging, tennis swinging, trampoline jumping,
+        volleyball spiking, and walking with a dog.
+
+        For each class, the videos are grouped into 25 group/scene folders containing at least 4 video clips each.
+        The video clips in the same scene folder share some common features, such as the same actor, similar
+        background, similar viewpoint, and so on.
+
+        The folder structure looks like the following:
+
+        /root/action_youtube_naudio
+        ├── basketball                      # Class Folder Path
+        │   ├── v_shooting_01               # Scene/Group Folder Path
+        │   │   ├── v_shooting_01_01.avi    # Video Path
+        │   │   ├── v_shooting_01_02.avi
+        │   │   ├── v_shooting_01_03.avi
+        │   │   ├── ...
+        │   ├── v_shooting_02
+        │   ├── v_shooting_03
+        │   ├── ...
+        │   ...
+        ├── biking
+        │   ├── v_biking_01
+        │   │   ├── v_biking_01_01.avi
+        │   │   ├── v_biking_01_02.avi
+        │   │   ├── v_biking_01_03.avi
+        │   ├── v_biking_02
+        │   ├── v_biking_03
+        │   ...
+        ...
+
+        We take 80% of all scenes and use the videos within for training. The remaining scenes' videos
+        are used for validation. We do this so the validation data contains only videos from scenes/actors
+        that the model has not seen yet.
+        """
+        super().__init__(**kwargs)
+
+    def setup(self, stage: str = None):
+        """Set up anything needed for initializing train/val datasets. This runs on all nodes."""
+
+        # Names of classes to predict.
+        # Ex. ['basketball', 'biking', 'diving', ...]
+        self.classes = sorted(x.name for x in self.data_path.glob("*") if x.is_dir())
+
+        # Mapping from label to class id.
+        # Ex. {'basketball': 0, 'biking': 1, 'diving': 2, ...}
+        self.label_to_id = {}
+
+        # A list to hold all available scenes across all classes.
+        scene_folders = []
+
+        for class_id, class_name in enumerate(self.classes):
+
+            self.label_to_id[class_name] = class_id
+
+            # The path of a class folder within self.data_path.
+            # Ex. 'action_youtube_naudio/{basketball|biking|diving|...}'
+            class_folder = self.data_path / class_name
+
+            # Collect scene folders within this class.
+            # Ex. 'action_youtube_naudio/basketball/v_shooting_01'
+            for scene_folder in filter(Path.is_dir, class_folder.glob("v_*")):
+                scene_folders.append(scene_folder)
+
+        # Randomly shuffle the scene folders before splitting them into train/val.
+        shuffle(scene_folders)
+
+        # Determine number of scenes in train/validation splits.
+        self.num_train_scenes = int(0.8 * len(scene_folders))
+        self.num_val_scenes = len(scene_folders) - self.num_train_scenes
+
+        # Collect train/val paths to videos within each scene folder.
+        # Validation only uses videos from scenes not seen by the model during training.
+        self.train_paths = []
+        self.val_paths = []
+        for i, scene_path in enumerate(scene_folders):
+
+            # The actual name of the class (Ex. 'basketball').
+            class_name = scene_path.parent.name
+
+            # Loop over all the videos within the given scene folder.
+            for video_path in scene_path.glob("*.avi"):
+
+                # Construct a tuple containing (video path, label dict).
+                # In our case, we assign the class's ID as 'label'.
+                labeled_path = (video_path, {"label": self.label_to_id[class_name]})
+
+                if i < self.num_train_scenes:
+                    self.train_paths.append(labeled_path)
+                else:
+                    self.val_paths.append(labeled_path)
+
+    def train_dataloader(self):
+        self.train_dataset = LimitDataset(
+            LabeledVideoDataset(
+                self.train_paths,
+                clip_sampler=make_clip_sampler("random", self.clip_duration),
+                decode_audio=False,
+                transform=self.train_transform,
+                video_sampler=RandomSampler,
+            )
+        )
+        return DataLoader(
+            self.train_dataset, batch_size=self.batch_size, num_workers=self.workers
+        )
+
+    def val_dataloader(self):
+        self.val_dataset = LimitDataset(
+            LabeledVideoDataset(
+                self.val_paths,
+                clip_sampler=make_clip_sampler("uniform", self.clip_duration),
+                decode_audio=False,
+                transform=self.val_transform,
+                video_sampler=RandomSampler,
+            )
+        )
+        return DataLoader(
+            self.val_dataset, batch_size=self.batch_size, num_workers=self.workers
+        )
+
+
+def download_and_unzip(url, data_dir="./", verify=True):
+    """Download a zip file from a given URL and unpack it within data_dir.
+
+    Args:
+        url (str): A URL to a zip file.
+        data_dir (str, optional): Directory where the zip will be unpacked. Defaults to "./".
+        verify (bool, optional): Whether to verify SSL certificate when requesting the zip file. Defaults to True.
+    """
+    data_dir = Path(data_dir)
+    zipfile_name = url.split("/")[-1]
+    data_zip_path = data_dir / zipfile_name
+    data_dir.mkdir(exist_ok=True, parents=True)
+
+    if not data_zip_path.exists():
+        resp = requests.get(url, verify=verify)
+
+        with data_zip_path.open("wb") as f:
+            f.write(resp.content)
+
+    unpack_archive(data_zip_path, extract_dir=data_dir)
+
+
+class LimitDataset(torch.utils.data.Dataset):
+
+    """
+    To ensure a constant number of samples are retrieved from the dataset we use this
+    LimitDataset wrapper. This is necessary because several of the underlying videos
+    may be corrupted while fetching or decoding, however, we always want the same
+    number of steps per epoch.
+    """
+
+    def __init__(self, dataset):
+        super().__init__()
+        self.dataset = dataset
+        self.dataset_iter = itertools.chain.from_iterable(
+            itertools.repeat(iter(dataset), 2)
+        )
+
+    def __getitem__(self, index):
+        return next(self.dataset_iter)
+
+    def __len__(self):
+        return self.dataset.num_videos
diff --git a/tutorials/video_classification_example/finetune.py b/tutorials/video_classification_example/finetune.py
new file mode 100644
index 00000000..0dd05734
--- /dev/null
+++ b/tutorials/video_classification_example/finetune.py
@@ -0,0 +1,39 @@
+import pytorch_lightning as pl
+from data import UCF11DataModule
+from models import SlowResnet50LightningModel
+from train import parse_args
+
+
+def train(args):
+    pl.seed_everything(224)
+    dm = UCF11DataModule(**vars(args))
+    model = SlowResnet50LightningModel(num_classes=dm.NUM_CLASSES, **vars(args))
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, dm)
+
+
+def main():
+    args = parse_args()
+    if args.on_cluster:
+        from slurm import copy_and_run_with_config
+
+        copy_and_run_with_config(
+            train,
+            args,
+            args.working_directory,
+            job_name=args.job_name,
+            time="72:00:00",
+            partition=args.partition,
+            gpus_per_node=args.gpus,
+            ntasks_per_node=args.gpus,
+            cpus_per_task=10,
+            mem="470GB",
+            nodes=args.num_nodes,
+            constraint="volta32gb",
+        )
+    else:  # local
+        train(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tutorials/video_classification_example/models.py b/tutorials/video_classification_example/models.py
new file mode 100644
index 00000000..2ec97fa6
--- /dev/null
+++ b/tutorials/video_classification_example/models.py
@@ -0,0 +1,139 @@
+import pytorch_lightning as pl
+import torch
+from pytorchvideo.models.head import create_res_basic_head
+from pytorchvideo.models.resnet import create_resnet
+from torch import nn
+
+
+class VideoClassificationLightningModule(pl.LightningModule):
+    def __init__(self, num_classes: int = 11, lr: float = 2e-4, **kwargs):
+        """A classifier for finetuning pretrained video classification backbones from
+        torchhub. We use the slow_r50 model here, but you can edit this class to
+        use whatever backbone/head you'd like.
+
+        Args:
+            num_classes (int, optional): Number of output classes. Defaults to 11.
+            lr (float, optional): The learning rate for the Adam optimizer. Defaults to 2e-4.
+            freeze_backbone (bool, optional): Whether to freeze the backbone or leave it trainable. Defaults to True.
+            pretrained (bool, optional): Use the pretrained model from torchhub. When False, we initialize the
+                slow_r50 model from scratch. Defaults to True.
+
+        All extra kwargs will be available via self.hparams.<kwarg_name>. These will also be saved as
+        TensorBoard Hparams.
+        """
+        super().__init__()
+
+        # Saves all kwargs to self.hparams. Use references to self.hparams.<name>, not the init args themselves.
+        self.save_hyperparameters()
+
+        # Build the model in a separate function so it's easier to override.
+        self.model = self._build_model()
+
+        # Metrics we will keep track of.
+        self.loss_fn = nn.CrossEntropyLoss()
+        self.train_acc = pl.metrics.Accuracy()
+        self.val_acc = pl.metrics.Accuracy()
+        self.accuracy = {"train": self.train_acc, "val": self.val_acc}
+
+    def _build_model(self):
+        return create_resnet(model_num_class=self.hparams.num_classes)
+
+    def on_train_epoch_start(self):
+        """
+        For distributed training we need to set the dataset's video sampler epoch so
+        that shuffling is done correctly.
+        """
+        epoch = self.trainer.current_epoch
+        if self.trainer.use_ddp:
+            self.trainer.datamodule.train_dataset.dataset.video_sampler.set_epoch(epoch)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Forward defines the prediction/inference actions.
+        """
+        return self.model(x)
+
+    def shared_step(self, batch, mode: str):
+        """This shared step handles both the training and validation steps to avoid
+        re-writing the same code more than once. The given `mode` will change the name
+        of the logged metrics.
+
+        PyTorchVideo batches are dictionaries containing each modality or metadata of
+        the batch collated video clips. Kinetics contains the following notable keys:
+            {
+                'video': <video_tensor>,
+                'label': <action_label>,
+            }
+
+        - "video" is a Tensor of shape (batch, channels, time, height, width)
+        - "label" is a Tensor of shape (batch, 1)
+
+        The PyTorchVideo models and transforms expect the same input shapes and
+        dictionary structure, making this function just a matter of unwrapping the dict and
+        feeding it through the model/loss.
+
+        Args:
+            batch (dict): PyTorchVideo batch dictionary containing a single batch of data.
+            mode (str): The type of step. Can be 'train', 'val', or 'test'.
+
+        Returns:
+            torch.Tensor: The loss for a single batch step.
+        """
+
+        outputs = self(batch["video"])
+
+        loss = self.loss_fn(outputs, batch["label"])
+        self.log(f"{mode}_loss", loss)
+
+        proba = outputs.softmax(dim=1)
+        preds = proba.argmax(dim=1)
+
+        acc = self.accuracy[mode](preds, batch["label"])
+        self.log(f"{mode}_acc", acc, prog_bar=True)
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        """
+        This function is called in the inner loop of the training epoch. It must
+        return a loss that is used for loss.backward() internally.
+        """
+        return self.shared_step(batch, "train")
+
+    def validation_step(self, batch, batch_idx):
+        """
+        This function is called in the inner loop of the evaluation cycle. For this
+        simple example it's mostly the same as the training loop but with a different
+        metric name.
+        """
+        return self.shared_step(batch, "val")
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+
+
+class SlowResnet50LightningModel(VideoClassificationLightningModule):
+    def __init__(self, freeze_backbone: bool = True, pretrained: bool = True, **kwargs):
+        super().__init__(
+            freeze_backbone=freeze_backbone, pretrained=pretrained, **kwargs
+        )
+
+    def _build_model(self):
+        # The pretrained resnet model - we strip off its head to get the backbone.
+        resnet = torch.hub.load(
+            "facebookresearch/pytorchvideo",
+            "slow_r50",
+            pretrained=self.hparams.pretrained,
+        )
+        self.backbone = nn.Sequential(*list(resnet.children())[0][:-1])
+
+        # Freeze the backbone layers if specified.
+        if self.hparams.freeze_backbone:
+            for param in self.backbone.parameters():
+                param.requires_grad = False
+
+        # Create a new head we will train on top of the backbone.
+        self.head = create_res_basic_head(
+            in_features=2048, out_features=self.hparams.num_classes
+        )
+        return nn.Sequential(self.backbone, self.head)
diff --git a/tutorials/video_classification_example/train.py b/tutorials/video_classification_example/train.py
index b2d896ba..8568b9c2 100644
--- a/tutorials/video_classification_example/train.py
+++ b/tutorials/video_classification_example/train.py
@@ -1,379 +1,12 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+from argparse import ArgumentParser
 
-import argparse
-import itertools
-import logging
-import os
+import pytorch_lightning as pl
+from data import LabeledVideoDataModule
+from models import VideoClassificationLightningModule
 
-import pytorch_lightning
-import pytorchvideo.data
-import pytorchvideo.models.resnet
-import torch
-import torch.nn.functional as F
-from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorchvideo.transforms import (
-    ApplyTransformToKey,
-    Normalize,
-    RandomShortSideScale,
-    RemoveKey,
-    ShortSideScale,
-    UniformTemporalSubsample,
-)
-from slurm import copy_and_run_with_config
-from torch.utils.data import DistributedSampler, RandomSampler
-from torchaudio.transforms import MelSpectrogram, Resample
-from torchvision.transforms import (
-    CenterCrop,
-    Compose,
-    Lambda,
-    RandomCrop,
-    RandomHorizontalFlip,
-)
-
-"""
-This video classification example demonstrates how PyTorchVideo models, datasets and
-transforms can be used with PyTorch Lightning module. Specifically it shows how a
-simple pipeline to train a Resnet on the Kinetics video dataset can be built.
-
-Don't worry if you don't have PyTorch Lightning experience. We'll provide an explanation
-of how the PyTorch Lightning module works to accompany the example.
-
-The code can be separated into three main components:
-1. VideoClassificationLightningModule (pytorch_lightning.LightningModule), this defines:
-    - how the model is constructed,
-    - the inner train or validation loop (i.e. computing loss/metrics from a minibatch)
-    - optimizer configuration
-
-2. KineticsDataModule (pytorch_lightning.LightningDataModule), this defines:
-    - how to fetch/prepare the dataset
-    - the train and val dataloaders for the associated dataset
-
-3. pytorch_lightning.Trainer, this is a concrete PyTorch Lightning class that provides
-    the training pipeline configuration and a fit(<lightning_module>, <data_module>)
-    function to start the training/validation loop.
-
-All three components are combined in the train() function. We'll explain the rest of the
-details inline.
-"""
-
-
-class VideoClassificationLightningModule(pytorch_lightning.LightningModule):
-    def __init__(self, args):
-        """
-        This LightningModule implementation constructs a PyTorchVideo ResNet,
-        defines the train and val loss to be trained with (cross_entropy), and
-        configures the optimizer.
-        """
-        self.args = args
-        super().__init__()
-        self.train_accuracy = pytorch_lightning.metrics.Accuracy()
-        self.val_accuracy = pytorch_lightning.metrics.Accuracy()
-
-        #############
-        # PTV Model #
-        #############
-
-        # Here we construct the PyTorchVideo model. For this example we're using a
-        # ResNet that works with Kinetics (e.g. 400 num_classes). For your application,
-        # this could be changed to any other PyTorchVideo model (e.g. for SlowFast use
-        # create_slowfast).
-        if self.args.arch == "video_resnet":
-            self.model = pytorchvideo.models.resnet.create_resnet(
-                input_channel=3,
-                model_num_class=400,
-            )
-            self.batch_key = "video"
-        elif self.args.arch == "audio_resnet":
-            self.model = pytorchvideo.models.resnet.create_acoustic_resnet(
-                input_channel=1,
-                model_num_class=400,
-            )
-            self.batch_key = "audio"
-        else:
-            raise Exception("{self.args.arch} not supported")
-
-    def on_train_epoch_start(self):
-        """
-        For distributed training we need to set the datasets video sampler epoch so
-        that shuffling is done correctly
-        """
-        epoch = self.trainer.current_epoch
-        if self.trainer.use_ddp:
-            self.trainer.datamodule.train_dataset.dataset.video_sampler.set_epoch(epoch)
-
-    def forward(self, x):
-        """
-        Forward defines the prediction/inference actions.
-        """
-        return self.model(x)
-
-    def training_step(self, batch, batch_idx):
-        """
-        This function is called in the inner loop of the training epoch. It must
-        return a loss that is used for loss.backwards() internally. The self.log(...)
-        function can be used to log any training metrics.
-
-        PyTorchVideo batches are dictionaries containing each modality or metadata of
-        the batch collated video clips. Kinetics contains the following notable keys:
-            {
-                'video': <video_tensor>,
-                'audio': <audio_tensor>,
-                'label': <action_label>,
-            }
-
-        - "video" is a Tensor of shape (batch, channels, time, height, Width)
-        - "audio" is a Tensor of shape (batch, channels, time, 1, frequency)
-        - "label" is a Tensor of shape (batch, 1)
-
-        The PyTorchVideo models and transforms expect the same input shapes and
-        dictionary structure making this function just a matter of unwrapping the dict and
-        feeding it through the model/loss.
-        """
-        x = batch[self.batch_key]
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, batch["label"])
-        acc = self.train_accuracy(F.softmax(y_hat, dim=-1), batch["label"])
-        self.log("train_loss", loss)
-        self.log(
-            "train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True
-        )
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        """
-        This function is called in the inner loop of the evaluation cycle. For this
-        simple example it's mostly the same as the training loop but with a different
-        metric name.
-        """
-        x = batch[self.batch_key]
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, batch["label"])
-        acc = self.val_accuracy(F.softmax(y_hat, dim=-1), batch["label"])
-        self.log("val_loss", loss)
-        self.log(
-            "val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True
-        )
-        return loss
-
-    def configure_optimizers(self):
-        """
-        We use the SGD optimizer with per step cosine annealing scheduler.
-        """
-        optimizer = torch.optim.SGD(
-            self.parameters(),
-            lr=self.args.lr,
-            momentum=self.args.momentum,
-            weight_decay=self.args.weight_decay,
-        )
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, self.args.max_epochs, last_epoch=-1
-        )
-        return [optimizer], [scheduler]
-
-
-class KineticsDataModule(pytorch_lightning.LightningDataModule):
-    """
-    This LightningDataModule implementation constructs a PyTorchVideo Kinetics dataset for both
-    the train and val partitions. It defines each partition's augmentation and
-    preprocessing transforms and configures the PyTorch DataLoaders.
-    """
-
-    def __init__(self, args):
-        self.args = args
-        super().__init__()
-
-    def _make_transforms(self, mode: str):
-        """
-        ##################
-        # PTV Transforms #
-        ##################
-
-        # Each PyTorchVideo dataset has a "transform" arg. This arg takes a
-        # Callable[[Dict], Any], and is used on the output Dict of the dataset to
-        # define any application specific processing or augmentation. Transforms can
-        # either be implemented by the user application or reused from any library
-        # that's domain specific to the modality. E.g. for video we recommend using
-        # TorchVision, for audio we recommend TorchAudio.
-        #
-        # To improve interoperation between domain transform libraries, PyTorchVideo
-        # provides a dictionary transform API that provides:
-        #   - ApplyTransformToKey(key, transform) - applies a transform to specific modality
-        #   - RemoveKey(key) - remove a specific modality from the clip
-        #
-        # In the case that the recommended libraries don't provide transforms that
-        # are common enough for PyTorchVideo use cases, PyTorchVideo will provide them in
-        # the same structure as the recommended library. E.g. TorchVision didn't
-        # have a RandomShortSideScale video transform so it's been added to PyTorchVideo.
-        """
-        if self.args.data_type == "video":
-            transform = [
-                self._video_transform(mode),
-                RemoveKey("audio"),
-            ]
-        elif self.args.data_type == "audio":
-            transform = [
-                self._audio_transform(),
-                RemoveKey("video"),
-            ]
-        else:
-            raise Exception(f"{self.args.data_type} not supported")
-
-        return Compose(transform)
-
-    def _video_transform(self, mode: str):
-        """
-        This function contains example transforms using both PyTorchVideo and TorchVision
-        in the same Callable. For 'train' mode, we use augmentations (prepended with
-        'Random'), for 'val' mode we use the respective determinstic function.
-        """
-        args = self.args
-        return ApplyTransformToKey(
-            key="video",
-            transform=Compose(
-                [
-                    UniformTemporalSubsample(args.video_num_subsampled),
-                    Normalize(args.video_means, args.video_stds),
-                ]
-                + (
-                    [
-                        RandomShortSideScale(
-                            min_size=args.video_min_short_side_scale,
-                            max_size=args.video_max_short_side_scale,
-                        ),
-                        RandomCrop(args.video_crop_size),
-                        RandomHorizontalFlip(p=args.video_horizontal_flip_p),
-                    ]
-                    if mode == "train"
-                    else [
-                        ShortSideScale(args.video_min_short_side_scale),
-                        CenterCrop(args.video_crop_size),
-                    ]
-                )
-            ),
-        )
-
-    def _audio_transform(self):
-        """
-        This function contains example transforms using both PyTorchVideo and TorchAudio
-        in the same Callable.
-        """
-        args = self.args
-        n_fft = int(
-            float(args.audio_resampled_rate) / 1000 * args.audio_mel_window_size
-        )
-        hop_length = int(
-            float(args.audio_resampled_rate) / 1000 * args.audio_mel_step_size
-        )
-        eps = 1e-10
-        return ApplyTransformToKey(
-            key="audio",
-            transform=Compose(
-                [
-                    Resample(
-                        orig_freq=args.audio_raw_sample_rate,
-                        new_freq=args.audio_resampled_rate,
-                    ),
-                    MelSpectrogram(
-                        sample_rate=args.audio_resampled_rate,
-                        n_fft=n_fft,
-                        hop_length=hop_length,
-                        n_mels=args.audio_num_mels,
-                        center=False,
-                    ),
-                    Lambda(lambda x: x.clamp(min=eps)),
-                    Lambda(torch.log),
-                    UniformTemporalSubsample(args.audio_mel_num_subsample),
-                    Lambda(lambda x: x.transpose(1, 0)),  # (F, T) -> (T, F)
-                    Lambda(
-                        lambda x: x.view(1, x.size(0), 1, x.size(1))
-                    ),  # (T, F) -> (1, T, 1, F)
-                    Normalize((args.audio_logmel_mean,), (args.audio_logmel_std,)),
-                ]
-            ),
-        )
-
-    def train_dataloader(self):
-        """
-        Defines the train DataLoader that the PyTorch Lightning Trainer trains/tests with.
-        """
-        sampler = DistributedSampler if self.trainer.use_ddp else RandomSampler
-        train_transform = self._make_transforms(mode="train")
-        self.train_dataset = LimitDataset(
-            pytorchvideo.data.Kinetics(
-                data_path=os.path.join(self.args.data_path, "train.csv"),
-                clip_sampler=pytorchvideo.data.make_clip_sampler(
-                    "random", self.args.clip_duration
-                ),
-                video_path_prefix=self.args.video_path_prefix,
-                transform=train_transform,
-                video_sampler=sampler,
-            )
-        )
-        return torch.utils.data.DataLoader(
-            self.train_dataset,
-            batch_size=self.args.batch_size,
-            num_workers=self.args.workers,
-        )
-
-    def val_dataloader(self):
-        """
-        Defines the train DataLoader that the PyTorch Lightning Trainer trains/tests with.
-        """
-        sampler = DistributedSampler if self.trainer.use_ddp else RandomSampler
-        val_transform = self._make_transforms(mode="val")
-        self.val_dataset = LimitDataset(
-            pytorchvideo.data.Kinetics(
-                data_path=os.path.join(self.args.data_path, "val.csv"),
-                clip_sampler=pytorchvideo.data.make_clip_sampler(
-                    "uniform", self.args.clip_duration
-                ),
-                video_path_prefix=self.args.video_path_prefix,
-                transform=val_transform,
-                video_sampler=sampler,
-            )
-        )
-        return torch.utils.data.DataLoader(
-            self.val_dataset,
-            batch_size=self.args.batch_size,
-            num_workers=self.args.workers,
-        )
-
-
-class LimitDataset(torch.utils.data.Dataset):
-    """
-    To ensure a constant number of samples are retrieved from the dataset we use this
-    LimitDataset wrapper. This is necessary because several of the underlying videos
-    may be corrupted while fetching or decoding, however, we always want the same
-    number of steps per epoch.
-    """
-
-    def __init__(self, dataset):
-        super().__init__()
-        self.dataset = dataset
-        self.dataset_iter = itertools.chain.from_iterable(
-            itertools.repeat(iter(dataset), 2)
-        )
-
-    def __getitem__(self, index):
-        return next(self.dataset_iter)
-
-    def __len__(self):
-        return self.dataset.num_videos()
-
-
-def main():
-    """
-    To train the ResNet with the Kinetics dataset we construct the two modules above,
-    and pass them to the fit function of a pytorch_lightning.Trainer.
-
-    This example can be run either locally (with default parameters) or on a Slurm
-    cluster. To run on a Slurm cluster provide the --on_cluster argument.
-    """
-    setup_logger()
-
-    pytorch_lightning.trainer.seed_everything()
-    parser = argparse.ArgumentParser()
+def parse_args(args=None):
+    parser = ArgumentParser()
 
     # Cluster parameters.
     parser.add_argument("--on_cluster", action="store_true")
@@ -381,55 +14,35 @@ def main():
     parser.add_argument("--working_directory", default=".", type=str)
     parser.add_argument("--partition", default="dev", type=str)
 
-    # Model parameters.
-    parser.add_argument("--lr", "--learning-rate", default=0.1, type=float)
-    parser.add_argument("--momentum", default=0.9, type=float)
-    parser.add_argument("--weight_decay", default=1e-4, type=float)
-    parser.add_argument(
-        "--arch",
-        default="video_resnet",
-        choices=["video_resnet", "audio_resnet"],
-        type=str,
-    )
+    # Model Parameters.
+    parser.add_argument("--lr", "--learning_rate", default=2e-4, type=float)
 
-    # Data parameters.
-    parser.add_argument("--data_path", default=None, type=str, required=True)
-    parser.add_argument("--video_path_prefix", default="", type=str)
-    parser.add_argument("--workers", default=8, type=int)
-    parser.add_argument("--batch_size", default=32, type=int)
-    parser.add_argument("--clip_duration", default=2, type=float)
-    parser.add_argument(
-        "--data_type", default="video", choices=["video", "audio"], type=str
-    )
-    parser.add_argument("--video_num_subsampled", default=8, type=int)
-    parser.add_argument("--video_means", default=(0.45, 0.45, 0.45), type=tuple)
-    parser.add_argument("--video_stds", default=(0.225, 0.225, 0.225), type=tuple)
-    parser.add_argument("--video_crop_size", default=224, type=int)
-    parser.add_argument("--video_min_short_side_scale", default=256, type=int)
-    parser.add_argument("--video_max_short_side_scale", default=320, type=int)
-    parser.add_argument("--video_horizontal_flip_p", default=0.5, type=float)
-    parser.add_argument("--audio_raw_sample_rate", default=44100, type=int)
-    parser.add_argument("--audio_resampled_rate", default=16000, type=int)
-    parser.add_argument("--audio_mel_window_size", default=32, type=int)
-    parser.add_argument("--audio_mel_step_size", default=16, type=int)
-    parser.add_argument("--audio_num_mels", default=80, type=int)
-    parser.add_argument("--audio_mel_num_subsample", default=128, type=int)
-    parser.add_argument("--audio_logmel_mean", default=-7.03, type=float)
-    parser.add_argument("--audio_logmel_std", default=4.66, type=float)
+    # Data Parameters.
+    parser = LabeledVideoDataModule.add_argparse_args(parser)
 
-    # Trainer parameters.
-    parser = pytorch_lightning.Trainer.add_argparse_args(parser)
+    # Training Parameters.
+    parser = pl.Trainer.add_argparse_args(parser)
     parser.set_defaults(
-        max_epochs=200,
-        callbacks=[LearningRateMonitor()],
+        callbacks=[pl.callbacks.LearningRateMonitor()],
         replace_sampler_ddp=False,
-        reload_dataloaders_every_epoch=False,
     )
 
-    # Build trainer, ResNet lightning-module and Kinetics data-module.
-    args = parser.parse_args()
+    return parser.parse_args(args)
+
+
+def train(args):
+    pl.seed_everything(224)
+    dm = LabeledVideoDataModule.from_argparse_args(args)
+    model = VideoClassificationLightningModule(num_classes=dm.NUM_CLASSES, **vars(args))
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, dm)
+
+
+def main():
+    args = parse_args()
     if args.on_cluster:
+        from slurm import copy_and_run_with_config
+
         copy_and_run_with_config(
             train,
            args,
@@ -448,21 +61,5 @@ def main():
         train(args)
 
 
-def train(args):
-    trainer = pytorch_lightning.Trainer.from_argparse_args(args)
-    classification_module = VideoClassificationLightningModule(args)
-    data_module = KineticsDataModule(args)
-    trainer.fit(classification_module, data_module)
-
-
-def setup_logger():
-    ch = logging.StreamHandler()
-    formatter = logging.Formatter("\n%(asctime)s [%(levelname)s] %(name)s: %(message)s")
-    ch.setFormatter(formatter)
-    logger = logging.getLogger("pytorchvideo")
-    logger.setLevel(logging.DEBUG)
-    logger.addHandler(ch)
-
-
 if __name__ == "__main__":
     main()
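
Not part of the patch: a minimal usage sketch of the pieces this diff adds, mirroring what `finetune.py` does without going through argparse. All imported names come from the new `data.py` and `models.py` above; the Trainer settings are illustrative assumptions, not defaults defined in the diff.

```python
# Illustrative sketch only (assumes the new tutorial files are on the Python path).
import pytorch_lightning as pl

from data import UCF11DataModule
from models import SlowResnet50LightningModel

pl.seed_everything(224)

# prepare_data() downloads and unpacks UCF11 into ./action_youtube_naudio on first run.
dm = UCF11DataModule(root="./", batch_size=4, workers=4)

# Frozen slow_r50 backbone with a freshly initialized head for the 11 UCF11 classes.
model = SlowResnet50LightningModel(num_classes=UCF11DataModule.NUM_CLASSES, lr=2e-4)

# max_epochs here is an assumption for a quick local smoke test.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, dm)
```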