Skip to content

Commit

Permalink
🚧 wip
Browse files — browse the repository at this point in the history
  • Loading branch information
nateraw committed May 20, 2021
1 parent 7fd0880 commit 0ec28aa
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
16 changes: 14 additions & 2 deletions tutorials/video_classification_example/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _video_transform(self, mode: str):
transform=Compose(
[
UniformTemporalSubsample(args.video_num_subsampled),
Lambda(lambda x: x / 255.0),
Normalize(args.video_means, args.video_stds),
]
+ (
Expand Down Expand Up @@ -217,7 +218,7 @@ def __init__(self, args):
for c in self.classes:

# Scenes within each class directory
scene_names = sorted(
scene_names = list(
x.name
for x in (root / c).glob("*")
if x.is_dir() and x.name != "Annotation"
Expand Down Expand Up @@ -273,5 +274,16 @@ def download_and_unzip(url, data_dir="./", verify=True):


if __name__ == "__main__":
args = parse_args("--batch_size 4 --data_path ./yt_data".split())
from finetune import parse_args
from train import LearningRateMonitor, VideoClassificationLightningModule
args = parse_args("--gpus 1 --precision 16 --batch_size 8 --data_path ./yt_data".split())
args.max_epochs = 200
args.callbacks = [LearningRateMonitor()]
args.replace_sampler_ddp = False
args.reload_dataloaders_every_epoch = False

pytorch_lightning.trainer.seed_everything(244)
dm = UCF11DataModule(args)
model = VideoClassificationLightningModule(args)
trainer = pytorch_lightning.Trainer.from_argparse_args(args)
trainer.fit(model, dm)
33 changes: 21 additions & 12 deletions tutorials/video_classification_example/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import pytorch_lightning as pl
import torch
from data import KineticsDataModule, MiniKineticsDataModule, UCF11DataModule
from models import Classifier
from .data import KineticsDataModule, MiniKineticsDataModule, UCF11DataModule
from pytorchvideo.models.head import create_res_basic_head
from torch import nn
from torch.optim import Adam
Expand All @@ -17,44 +16,55 @@


class Classifier(pl.LightningModule):
"""
"""
def __init__(
self,
num_classes: int = 11,
lr: float = 2e-4,
freeze_backbone: bool = True,
pretrained: bool = True,
**kwargs
):
"""A classifier for finetuning pretrained video classification backbones from
torchhub. We use the slow_r50 model here, but you can edit this class to
use whatever backbone/head you'd like.
Args:
num_classes (int, optional): Number of output classes. Defaults to 11.
lr (float, optional): The learning rate for the Adam optimizer. Defaults to 2e-4.
freeze_backbone (bool, optional): Whether to freeze the backbone or leave it trainable. Defaults to True.
pretrained (bool, optional): Use the pretrained model from torchhub. When False, we initialize the slow_r50 model from scratch. Defaults to True.
"""
super().__init__()
self.save_hyperparameters()

# Backbone
# The pretrained resnet model - we strip off its head to get the backbone
resnet = torch.hub.load(
"facebookresearch/pytorchvideo",
"slow_r50",
pretrained=self.hparams.pretrained,
)
self.backbone = nn.Sequential(*list(resnet.children())[0][:-1])

# Freeze the backbone layers if specified
if self.hparams.freeze_backbone:
for param in self.backbone.parameters():
param.requires_grad = False

# Head
# Create a new head we will train on top of the backbone
self.head = create_res_basic_head(
in_features=2048, out_features=self.hparams.num_classes
)

# Metrics
# Metrics we will keep track of
self.loss_fn = nn.CrossEntropyLoss()
self.train_acc = pl.metrics.Accuracy()
self.val_acc = pl.metrics.Accuracy()
self.accuracy = {"train": self.train_acc, "val": self.val_acc}

def forward(self, x):
if isinstance(x, dict):
x = x["video"]
feats = self.backbone(x)
return self.head(feats)
def forward(self, x: torch.Tensor):
return self.head(self.backbone(x))

def shared_step(self, batch, mode: str):
y_hat = self(batch["video"])
Expand Down Expand Up @@ -127,7 +137,6 @@ def parse_args(args=None):
parser = pl.Trainer.add_argparse_args(parser)
parser.set_defaults(
max_epochs=200,
callbacks=[pl.callbacks.LearningRateMonitor()],
replace_sampler_ddp=False,
reload_dataloaders_every_epoch=False,
)
Expand All @@ -138,7 +147,7 @@ def main(args):
pl.trainer.seed_everything()
dm_cls = DATASET_MAP.get(args.dataset)
dm = dm_cls(args)
model = Classifier(num_classes=dm_cls.NUM_CLASSES)
model = Classifier(num_classes=dm_cls.NUM_CLASSES, **vars(args))
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, dm)

Expand Down
3 changes: 2 additions & 1 deletion tutorials/video_classification_example/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self, args):
if self.args.arch == "video_resnet":
self.model = pytorchvideo.models.resnet.create_resnet(
input_channel=3,
model_num_class=400,
model_num_class=11 # 400,
)
self.batch_key = "video"
elif self.args.arch == "audio_resnet":
Expand Down Expand Up @@ -235,6 +235,7 @@ def _video_transform(self, mode: str):
transform=Compose(
[
UniformTemporalSubsample(args.video_num_subsampled),
Lambda(lambda x: x/255.0),
Normalize(args.video_means, args.video_stds),
]
+ (
Expand Down

0 comments on commit 0ec28aa

Please sign in to comment.