From 2f29b65a52d46238136908883e2503bc48baa732 Mon Sep 17 00:00:00 2001
From: nateraw
Date: Thu, 13 May 2021 19:10:44 -0600
Subject: [PATCH] :lipstick: style

---
 .../video_classification_example/data.py     | 98 ++++++++++++-------
 .../video_classification_example/finetune.py | 28 ++++--
 2 files changed, 82 insertions(+), 44 deletions(-)

diff --git a/tutorials/video_classification_example/data.py b/tutorials/video_classification_example/data.py
index 13344680..4dd379cb 100644
--- a/tutorials/video_classification_example/data.py
+++ b/tutorials/video_classification_example/data.py
@@ -1,8 +1,14 @@
-import requests
-from argparse import Namespace, ArgumentParser
-import pytorch_lightning
+import itertools
+from argparse import ArgumentParser, Namespace
 from pathlib import Path
+from random import shuffle
 from shutil import unpack_archive
+
+import pytorch_lightning
+import requests
+import torch
+from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
+from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset
 from pytorchvideo.transforms import (
     ApplyTransformToKey,
     Normalize,
@@ -11,9 +17,7 @@
     ShortSideScale,
     UniformTemporalSubsample,
 )
-from pytorchvideo.data import LabeledVideoDataset
-
-from torch.utils.data import DistributedSampler, RandomSampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from torchaudio.transforms import MelSpectrogram, Resample
 from torchvision.transforms import (
     CenterCrop,
@@ -22,12 +26,6 @@
     RandomCrop,
     RandomHorizontalFlip,
 )
-from pytorchvideo.data import make_clip_sampler
-from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset
-import torch
-import itertools
-from torch.utils.data import DataLoader
-from random import shuffle
 
 
 class LabeledVideoDataModule(pytorch_lightning.LightningDataModule):
@@ -43,7 +41,11 @@ def __init__(self, args):
 
         self.root = Path(self.args.data_path) / self.SOURCE_DIR_NAME
         if not (self.SOURCE_URL is None or self.SOURCE_DIR_NAME is None):
             if not self.root.exists():
-                download_and_unzip(self.SOURCE_URL, self.args.data_path, verify=getattr(self.args, 'verify', True))
+                download_and_unzip(
+                    self.SOURCE_URL,
+                    self.args.data_path,
+                    verify=getattr(self.args, "verify", True),
+                )
 
     def _make_transforms(self, mode: str):
@@ -91,8 +93,12 @@ def _video_transform(self, mode: str):
 
     def _audio_transform(self):
         args = self.args
-        n_fft = int(float(args.audio_resampled_rate) / 1000 * args.audio_mel_window_size)
-        hop_length = int(float(args.audio_resampled_rate) / 1000 * args.audio_mel_step_size)
+        n_fft = int(
+            float(args.audio_resampled_rate) / 1000 * args.audio_mel_window_size
+        )
+        hop_length = int(
+            float(args.audio_resampled_rate) / 1000 * args.audio_mel_step_size
+        )
         eps = 1e-10
         return ApplyTransformToKey(
             key="audio",
@@ -113,7 +119,9 @@ def _audio_transform(self):
                     Lambda(torch.log),
                     UniformTemporalSubsample(args.audio_mel_num_subsample),
                     Lambda(lambda x: x.transpose(1, 0)),  # (F, T) -> (T, F)
-                    Lambda(lambda x: x.view(1, x.size(0), 1, x.size(1))),  # (T, F) -> (1, T, 1, F)
+                    Lambda(
+                        lambda x: x.view(1, x.size(0), 1, x.size(1))
+                    ),  # (T, F) -> (1, T, 1, F)
                     Normalize((args.audio_logmel_mean,), (args.audio_logmel_std,)),
                 ]
             ),
@@ -122,21 +130,30 @@ def _audio_transform(self):
 
     def _make_ds_and_loader(self, mode: str):
         ds = LimitDataset(
             labeled_video_dataset(
-                data_path=str(Path(self.root) / (self.TRAIN_PATH if mode == 'train' else self.VAL_PATH)),
-                clip_sampler=make_clip_sampler("random" if mode == 'train' else 'uniform', self.args.clip_duration),
+                data_path=str(
+                    Path(self.root)
+                    / (self.TRAIN_PATH if mode == "train" else self.VAL_PATH)
+                ),
+                clip_sampler=make_clip_sampler(
+                    "random" if mode == "train" else "uniform", self.args.clip_duration
+                ),
                 video_path_prefix=self.args.video_path_prefix,
                 transform=self._make_transforms(mode=mode),
-                video_sampler=DistributedSampler if (self.trainer is not None and self.trainer.use_ddp) else RandomSampler,
+                video_sampler=DistributedSampler
+                if (self.trainer is not None and self.trainer.use_ddp)
+                else RandomSampler,
             )
         )
-        return ds, DataLoader(ds, batch_size=self.args.batch_size, num_workers=self.args.workers)
+        return ds, DataLoader(
+            ds, batch_size=self.args.batch_size, num_workers=self.args.workers
+        )
 
     def train_dataloader(self):
-        self.train_dataset, loader = self._make_ds_and_loader('train')
+        self.train_dataset, loader = self._make_ds_and_loader("train")
         return loader
 
     def val_dataloader(self):
-        self.val_dataset, loader = self._make_ds_and_loader('val')
+        self.val_dataset, loader = self._make_ds_and_loader("val")
         return loader
 
@@ -151,7 +168,9 @@ class LimitDataset(torch.utils.data.Dataset):
     def __init__(self, dataset):
         super().__init__()
         self.dataset = dataset
-        self.dataset_iter = itertools.chain.from_iterable(itertools.repeat(iter(dataset), 2))
+        self.dataset_iter = itertools.chain.from_iterable(
+            itertools.repeat(iter(dataset), 2)
+        )
 
     def __getitem__(self, index):
         return next(self.dataset_iter)
@@ -161,8 +180,8 @@ def __len__(self):
 
 
 class KineticsDataModule(LabeledVideoDataModule):
-    TRAIN_PATH = 'train.csv'
-    VAL_PATH = 'val.csv'
+    TRAIN_PATH = "train.csv"
+    VAL_PATH = "val.csv"
     NUM_CLASSES = 700
 
 
@@ -171,7 +190,7 @@ class MiniKineticsDataModule(LabeledVideoDataModule):
     TRAIN_PATH = "train"
     VAL_PATH = "val"
     SOURCE_URL = "https://pl-flash-data.s3.amazonaws.com/kinetics.zip"
-    SOURCE_DIR_NAME = 'kinetics'
+    SOURCE_DIR_NAME = "kinetics"
     NUM_CLASSES = 6
 
 
@@ -179,7 +198,7 @@ class UCF11DataModule(LabeledVideoDataModule):
     TRAIN_PATH = None
     VAL_PATH = None
    SOURCE_URL = "https://www.crcv.ucf.edu/data/YouTube_DataSet_Annotated.zip"
-    SOURCE_DIR_NAME = 'action_youtube_naudio'
+    SOURCE_DIR_NAME = "action_youtube_naudio"
     NUM_CLASSES = 11
 
     def __init__(self, args):
@@ -199,7 +218,11 @@ def __init__(self, args):
 
         for c in self.classes:
             # Scenes within each class directory
-            scene_names = sorted(x.name for x in (root / c).glob("*") if x.is_dir() and x.name != 'Annotation')
+            scene_names = sorted(
+                x.name
+                for x in (root / c).glob("*")
+                if x.is_dir() and x.name != "Annotation"
+            )
             shuffle(scene_names)
 
             # Holdout a random actor/scene
@@ -209,25 +232,30 @@ def __init__(self, args):
             # Keep track of which scenes we held out for each class w/ a dict
             self.holdout_scenes[c] = holdout_scene
 
-            for v in (root / c).glob('**/*.avi'):
+            for v in (root / c).glob("**/*.avi"):
                 labeled_path = (v, {"label": self.class_to_label[c]})
                 if v.parent.name != holdout_scene:
                     self.train_paths.append(labeled_path)
                 else:
                     self.val_paths.append(labeled_path)
 
-
     def _make_ds_and_loader(self, mode: str):
         ds = LimitDataset(
             LabeledVideoDataset(
-                self.train_paths if mode == 'train' else self.val_paths,
-                clip_sampler=make_clip_sampler("random" if mode == 'train' else 'uniform', self.args.clip_duration),
+                self.train_paths if mode == "train" else self.val_paths,
+                clip_sampler=make_clip_sampler(
+                    "random" if mode == "train" else "uniform", self.args.clip_duration
+                ),
                 decode_audio=False,
                 transform=self._make_transforms(mode=mode),
-                video_sampler=DistributedSampler if (self.trainer is not None and self.trainer.use_ddp) else RandomSampler,
+                video_sampler=DistributedSampler
+                if (self.trainer is not None and self.trainer.use_ddp)
+                else RandomSampler,
             )
         )
-        return ds, DataLoader(ds, batch_size=self.args.batch_size, num_workers=self.args.workers)
+        return ds, DataLoader(
+            ds, batch_size=self.args.batch_size, num_workers=self.args.workers
+        )
 
 
 def download_and_unzip(url, data_dir="./", verify=True):
@@ -246,5 +274,5 @@
 
 
 if __name__ == "__main__":
-    args = parse_args('--batch_size 4 --data_path ./yt_data'.split())
+    args = parse_args("--batch_size 4 --data_path ./yt_data".split())
     dm = UCF11DataModule(args)
diff --git a/tutorials/video_classification_example/finetune.py b/tutorials/video_classification_example/finetune.py
index 35055634..d77c9eb1 100644
--- a/tutorials/video_classification_example/finetune.py
+++ b/tutorials/video_classification_example/finetune.py
@@ -2,12 +2,11 @@
 
 import pytorch_lightning as pl
 import torch
+from data import KineticsDataModule, MiniKineticsDataModule, UCF11DataModule
+from models import Classifier
+from pytorchvideo.models.head import create_res_basic_head
 from torch import nn
 from torch.optim import Adam
-from pytorchvideo.models.head import create_res_basic_head
-
-from data import UCF11DataModule, KineticsDataModule, MiniKineticsDataModule
-from models import Classifier
 
 
 DATASET_MAP = {
@@ -18,13 +17,22 @@
 
 
 class Classifier(pl.LightningModule):
-
-    def __init__(self, num_classes: int = 11, lr: float = 2e-4, freeze_backbone: bool = True, pretrained: bool = True):
+    def __init__(
+        self,
+        num_classes: int = 11,
+        lr: float = 2e-4,
+        freeze_backbone: bool = True,
+        pretrained: bool = True,
+    ):
         super().__init__()
         self.save_hyperparameters()
 
         # Backbone
-        resnet = torch.hub.load("facebookresearch/pytorchvideo", 'slow_r50', pretrained=self.hparams.pretrained)
+        resnet = torch.hub.load(
+            "facebookresearch/pytorchvideo",
+            "slow_r50",
+            pretrained=self.hparams.pretrained,
+        )
         self.backbone = nn.Sequential(*list(resnet.children())[0][:-1])
 
         if self.hparams.freeze_backbone:
@@ -32,13 +40,15 @@ def __init__(self, num_classes: int = 11, lr: float = 2e-4, freeze_backbone: boo
                 param.requires_grad = False
 
         # Head
-        self.head = create_res_basic_head(in_features=2048, out_features=self.hparams.num_classes)
+        self.head = create_res_basic_head(
+            in_features=2048, out_features=self.hparams.num_classes
+        )
 
         # Metrics
         self.loss_fn = nn.CrossEntropyLoss()
         self.train_acc = pl.metrics.Accuracy()
         self.val_acc = pl.metrics.Accuracy()
-        self.accuracy = {'train': self.train_acc, 'val': self.val_acc}
+        self.accuracy = {"train": self.train_acc, "val": self.val_acc}
 
     def forward(self, x):
         if isinstance(x, dict):
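
A quick usage sketch, not part of the patch itself, showing how the restyled pieces fit together. It assumes PyTorch Lightning 1.x (the API the diff targets, e.g. trainer.use_ddp and pl.metrics), that parse_args is the helper defined in data.py and referenced in its __main__ block, and that the Trainer settings below are illustrative choices rather than anything the diff prescribes:

    # sketch.py -- hypothetical driver; assumptions noted above
    import pytorch_lightning as pl

    from data import UCF11DataModule, parse_args
    from finetune import Classifier

    # --batch_size and --data_path are the flags shown in data.py's __main__
    # block; any other flags would come from parse_args defaults.
    args = parse_args("--batch_size 4 --data_path ./yt_data".split())

    # Downloads and unpacks UCF11 on first run, then builds the train/val
    # split by holding out one random actor/scene per class.
    dm = UCF11DataModule(args)

    # slow_r50 backbone (frozen by default) with a fresh classification head.
    model = Classifier(num_classes=UCF11DataModule.NUM_CLASSES)

    trainer = pl.Trainer(gpus=1, max_epochs=5)  # illustrative settings
    trainer.fit(model, datamodule=dm)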