Skip to content

Commit

Permalink
💄 style
Browse files Browse the repository at this point in the history
  • Loading branch information
nateraw committed May 14, 2021
1 parent 72124c4 commit 2f29b65
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 44 deletions.
98 changes: 63 additions & 35 deletions tutorials/video_classification_example/data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import requests
from argparse import Namespace, ArgumentParser
import pytorch_lightning
import itertools
from argparse import ArgumentParser, Namespace
from pathlib import Path
from random import shuffle
from shutil import unpack_archive

import pytorch_lightning
import requests
import torch
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset
from pytorchvideo.transforms import (
ApplyTransformToKey,
Normalize,
Expand All @@ -11,9 +17,7 @@
ShortSideScale,
UniformTemporalSubsample,
)
from pytorchvideo.data import LabeledVideoDataset

from torch.utils.data import DistributedSampler, RandomSampler
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from torchaudio.transforms import MelSpectrogram, Resample
from torchvision.transforms import (
CenterCrop,
Expand All @@ -22,12 +26,6 @@
RandomCrop,
RandomHorizontalFlip,
)
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset
import torch
import itertools
from torch.utils.data import DataLoader
from random import shuffle


class LabeledVideoDataModule(pytorch_lightning.LightningDataModule):
Expand All @@ -43,7 +41,11 @@ def __init__(self, args):
self.root = Path(self.args.data_path) / self.SOURCE_DIR_NAME
if not (self.SOURCE_URL is None or self.SOURCE_DIR_NAME is None):
if not self.root.exists():
download_and_unzip(self.SOURCE_URL, self.args.data_path, verify=getattr(self.args, 'verify', True))
download_and_unzip(
self.SOURCE_URL,
self.args.data_path,
verify=getattr(self.args, "verify", True),
)

def _make_transforms(self, mode: str):

Expand Down Expand Up @@ -91,8 +93,12 @@ def _video_transform(self, mode: str):

def _audio_transform(self):
args = self.args
n_fft = int(float(args.audio_resampled_rate) / 1000 * args.audio_mel_window_size)
hop_length = int(float(args.audio_resampled_rate) / 1000 * args.audio_mel_step_size)
n_fft = int(
float(args.audio_resampled_rate) / 1000 * args.audio_mel_window_size
)
hop_length = int(
float(args.audio_resampled_rate) / 1000 * args.audio_mel_step_size
)
eps = 1e-10
return ApplyTransformToKey(
key="audio",
Expand All @@ -113,7 +119,9 @@ def _audio_transform(self):
Lambda(torch.log),
UniformTemporalSubsample(args.audio_mel_num_subsample),
Lambda(lambda x: x.transpose(1, 0)), # (F, T) -> (T, F)
Lambda(lambda x: x.view(1, x.size(0), 1, x.size(1))), # (T, F) -> (1, T, 1, F)
Lambda(
lambda x: x.view(1, x.size(0), 1, x.size(1))
), # (T, F) -> (1, T, 1, F)
Normalize((args.audio_logmel_mean,), (args.audio_logmel_std,)),
]
),
Expand All @@ -122,21 +130,30 @@ def _audio_transform(self):
def _make_ds_and_loader(self, mode: str):
ds = LimitDataset(
labeled_video_dataset(
data_path=str(Path(self.root) / (self.TRAIN_PATH if mode == 'train' else self.VAL_PATH)),
clip_sampler=make_clip_sampler("random" if mode == 'train' else 'uniform', self.args.clip_duration),
data_path=str(
Path(self.root)
/ (self.TRAIN_PATH if mode == "train" else self.VAL_PATH)
),
clip_sampler=make_clip_sampler(
"random" if mode == "train" else "uniform", self.args.clip_duration
),
video_path_prefix=self.args.video_path_prefix,
transform=self._make_transforms(mode=mode),
video_sampler=DistributedSampler if (self.trainer is not None and self.trainer.use_ddp) else RandomSampler,
video_sampler=DistributedSampler
if (self.trainer is not None and self.trainer.use_ddp)
else RandomSampler,
)
)
return ds, DataLoader(ds, batch_size=self.args.batch_size, num_workers=self.args.workers)
return ds, DataLoader(
ds, batch_size=self.args.batch_size, num_workers=self.args.workers
)

def train_dataloader(self):
self.train_dataset, loader = self._make_ds_and_loader('train')
self.train_dataset, loader = self._make_ds_and_loader("train")
return loader

def val_dataloader(self):
self.val_dataset, loader = self._make_ds_and_loader('val')
self.val_dataset, loader = self._make_ds_and_loader("val")
return loader


Expand All @@ -151,7 +168,9 @@ class LimitDataset(torch.utils.data.Dataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
self.dataset_iter = itertools.chain.from_iterable(itertools.repeat(iter(dataset), 2))
self.dataset_iter = itertools.chain.from_iterable(
itertools.repeat(iter(dataset), 2)
)

def __getitem__(self, index):
return next(self.dataset_iter)
Expand All @@ -161,8 +180,8 @@ def __len__(self):


class KineticsDataModule(LabeledVideoDataModule):
TRAIN_PATH = 'train.csv'
VAL_PATH = 'val.csv'
TRAIN_PATH = "train.csv"
VAL_PATH = "val.csv"
NUM_CLASSES = 700


Expand All @@ -171,15 +190,15 @@ class MiniKineticsDataModule(LabeledVideoDataModule):
TRAIN_PATH = "train"
VAL_PATH = "val"
SOURCE_URL = "https://pl-flash-data.s3.amazonaws.com/kinetics.zip"
SOURCE_DIR_NAME = 'kinetics'
SOURCE_DIR_NAME = "kinetics"
NUM_CLASSES = 6


class UCF11DataModule(LabeledVideoDataModule):
TRAIN_PATH = None
VAL_PATH = None
SOURCE_URL = "https://www.crcv.ucf.edu/data/YouTube_DataSet_Annotated.zip"
SOURCE_DIR_NAME = 'action_youtube_naudio'
SOURCE_DIR_NAME = "action_youtube_naudio"
NUM_CLASSES = 11

def __init__(self, args):
Expand All @@ -199,7 +218,11 @@ def __init__(self, args):
for c in self.classes:

# Scenes within each class directory
scene_names = sorted(x.name for x in (root / c).glob("*") if x.is_dir() and x.name != 'Annotation')
scene_names = sorted(
x.name
for x in (root / c).glob("*")
if x.is_dir() and x.name != "Annotation"
)
shuffle(scene_names)

# Holdout a random actor/scene
Expand All @@ -209,25 +232,30 @@ def __init__(self, args):
# Keep track of which scenes we held out for each class w/ a dict
self.holdout_scenes[c] = holdout_scene

for v in (root / c).glob('**/*.avi'):
for v in (root / c).glob("**/*.avi"):
labeled_path = (v, {"label": self.class_to_label[c]})
if v.parent.name != holdout_scene:
self.train_paths.append(labeled_path)
else:
self.val_paths.append(labeled_path)


def _make_ds_and_loader(self, mode: str):
ds = LimitDataset(
LabeledVideoDataset(
self.train_paths if mode == 'train' else self.val_paths,
clip_sampler=make_clip_sampler("random" if mode == 'train' else 'uniform', self.args.clip_duration),
self.train_paths if mode == "train" else self.val_paths,
clip_sampler=make_clip_sampler(
"random" if mode == "train" else "uniform", self.args.clip_duration
),
decode_audio=False,
transform=self._make_transforms(mode=mode),
video_sampler=DistributedSampler if (self.trainer is not None and self.trainer.use_ddp) else RandomSampler,
video_sampler=DistributedSampler
if (self.trainer is not None and self.trainer.use_ddp)
else RandomSampler,
)
)
return ds, DataLoader(ds, batch_size=self.args.batch_size, num_workers=self.args.workers)
return ds, DataLoader(
ds, batch_size=self.args.batch_size, num_workers=self.args.workers
)


def download_and_unzip(url, data_dir="./", verify=True):
Expand All @@ -246,5 +274,5 @@ def download_and_unzip(url, data_dir="./", verify=True):


if __name__ == "__main__":
args = parse_args('--batch_size 4 --data_path ./yt_data'.split())
args = parse_args("--batch_size 4 --data_path ./yt_data".split())
dm = UCF11DataModule(args)
28 changes: 19 additions & 9 deletions tutorials/video_classification_example/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import pytorch_lightning as pl
import torch
from data import KineticsDataModule, MiniKineticsDataModule, UCF11DataModule
from models import Classifier
from pytorchvideo.models.head import create_res_basic_head
from torch import nn
from torch.optim import Adam
from pytorchvideo.models.head import create_res_basic_head

from data import UCF11DataModule, KineticsDataModule, MiniKineticsDataModule
from models import Classifier


DATASET_MAP = {
Expand All @@ -18,27 +17,38 @@


class Classifier(pl.LightningModule):

def __init__(self, num_classes: int = 11, lr: float = 2e-4, freeze_backbone: bool = True, pretrained: bool = True):
def __init__(
self,
num_classes: int = 11,
lr: float = 2e-4,
freeze_backbone: bool = True,
pretrained: bool = True,
):
super().__init__()
self.save_hyperparameters()

# Backbone
resnet = torch.hub.load("facebookresearch/pytorchvideo", 'slow_r50', pretrained=self.hparams.pretrained)
resnet = torch.hub.load(
"facebookresearch/pytorchvideo",
"slow_r50",
pretrained=self.hparams.pretrained,
)
self.backbone = nn.Sequential(*list(resnet.children())[0][:-1])

if self.hparams.freeze_backbone:
for param in self.backbone.parameters():
param.requires_grad = False

# Head
self.head = create_res_basic_head(in_features=2048, out_features=self.hparams.num_classes)
self.head = create_res_basic_head(
in_features=2048, out_features=self.hparams.num_classes
)

# Metrics
self.loss_fn = nn.CrossEntropyLoss()
self.train_acc = pl.metrics.Accuracy()
self.val_acc = pl.metrics.Accuracy()
self.accuracy = {'train': self.train_acc, 'val': self.val_acc}
self.accuracy = {"train": self.train_acc, "val": self.val_acc}

def forward(self, x):
if isinstance(x, dict):
Expand Down

0 comments on commit 2f29b65

Please sign in to comment.