diff --git a/tutorials/video_classification_example/data.py b/tutorials/video_classification_example/data.py
index 6c23dce2..16eb3857 100644
--- a/tutorials/video_classification_example/data.py
+++ b/tutorials/video_classification_example/data.py
@@ -70,6 +70,7 @@ def _video_transform(self, mode: str):
             transform=Compose(
                 [
                     UniformTemporalSubsample(args.video_num_subsampled),
+                    Lambda(lambda x: x / 255.0),  # scale uint8 pixel values to [0, 1]
                     Normalize(args.video_means, args.video_stds),
                 ]
                 + (
@@ -217,7 +218,7 @@ def __init__(self, args):
 
         for c in self.classes:
             # Scenes within each class directory
-            scene_names = sorted(
+            scene_names = list(
                 x.name
                 for x in (root / c).glob("*")
                 if x.is_dir() and x.name != "Annotation"
@@ -273,5 +274,16 @@ def download_and_unzip(url, data_dir="./", verify=True):
 
 if __name__ == "__main__":
-    args = parse_args("--batch_size 4 --data_path ./yt_data".split())
+    from finetune import parse_args
+    from train import LearningRateMonitor, VideoClassificationLightningModule
+    args = parse_args("--gpus 1 --precision 16 --batch_size 8 --data_path ./yt_data".split())
+    args.max_epochs = 200
+    args.callbacks = [LearningRateMonitor()]
+    args.replace_sampler_ddp = False
+    args.reload_dataloaders_every_epoch = False
+
+    pytorch_lightning.trainer.seed_everything(244)
     dm = UCF11DataModule(args)
+    model = VideoClassificationLightningModule(args)
+    trainer = pytorch_lightning.Trainer.from_argparse_args(args)
+    trainer.fit(model, dm)
diff --git a/tutorials/video_classification_example/finetune.py b/tutorials/video_classification_example/finetune.py
index d77c9eb1..cff7fa99 100644
--- a/tutorials/video_classification_example/finetune.py
+++ b/tutorials/video_classification_example/finetune.py
@@ -2,8 +2,7 @@
 import pytorch_lightning as pl
 import torch
 from data import KineticsDataModule, MiniKineticsDataModule, UCF11DataModule
-from models import Classifier
 from pytorchvideo.models.head import create_res_basic_head
 from torch import nn
 from torch.optim import Adam
@@ -17,17 +16,30 @@
 class Classifier(pl.LightningModule):
+    """A LightningModule that finetunes a pretrained video classification
+    backbone behind a freshly initialized head."""
     def __init__(
         self,
         num_classes: int = 11,
         lr: float = 2e-4,
         freeze_backbone: bool = True,
         pretrained: bool = True,
+        **kwargs
     ):
+        """A classifier for finetuning pretrained video classification backbones from
+        torchhub. We use the slow_r50 model here, but you can edit this class to
+        use whatever backbone/head you'd like.
+
+        Args:
+            num_classes (int, optional): Number of output classes. Defaults to 11.
+            lr (float, optional): The learning rate for the Adam optimizer. Defaults to 2e-4.
+            freeze_backbone (bool, optional): Whether to freeze the backbone or leave it trainable. Defaults to True.
+            pretrained (bool, optional): Use the pretrained model from torchhub. When False, we initialize the slow_r50 model from scratch. Defaults to True.
+        """
         super().__init__()
         self.save_hyperparameters()
 
-        # Backbone
+        # The pretrained resnet model - we strip off its head to get the backbone
         resnet = torch.hub.load(
             "facebookresearch/pytorchvideo",
             "slow_r50",
@@ -35,26 +47,24 @@ def __init__(
         )
         self.backbone = nn.Sequential(*list(resnet.children())[0][:-1])
 
+        # Freeze the backbone layers if specified
         if self.hparams.freeze_backbone:
             for param in self.backbone.parameters():
                 param.requires_grad = False
 
-        # Head
+        # Create a new head we will train on top of the backbone
         self.head = create_res_basic_head(
             in_features=2048, out_features=self.hparams.num_classes
         )
-        # Metrics
+        # Metrics we will keep track of
         self.loss_fn = nn.CrossEntropyLoss()
         self.train_acc = pl.metrics.Accuracy()
         self.val_acc = pl.metrics.Accuracy()
         self.accuracy = {"train": self.train_acc, "val": self.val_acc}
 
-    def forward(self, x):
-        if isinstance(x, dict):
-            x = x["video"]
-        feats = self.backbone(x)
-        return self.head(feats)
+    def forward(self, x: torch.Tensor):
+        return self.head(self.backbone(x))
 
     def shared_step(self, batch, mode: str):
         y_hat = self(batch["video"])
@@ -127,7 +137,6 @@ def parse_args(args=None):
     parser = pl.Trainer.add_argparse_args(parser)
     parser.set_defaults(
         max_epochs=200,
-        callbacks=[pl.callbacks.LearningRateMonitor()],
         replace_sampler_ddp=False,
         reload_dataloaders_every_epoch=False,
     )
@@ -138,7 +147,7 @@ def main(args):
     pl.trainer.seed_everything()
     dm_cls = DATASET_MAP.get(args.dataset)
     dm = dm_cls(args)
-    model = Classifier(num_classes=dm_cls.NUM_CLASSES)
+    model = Classifier(num_classes=dm_cls.NUM_CLASSES, **vars(args))
     trainer = pl.Trainer.from_argparse_args(args)
     trainer.fit(model, dm)
diff --git a/tutorials/video_classification_example/train.py b/tutorials/video_classification_example/train.py
index 17beb719..8588d129 100644
--- a/tutorials/video_classification_example/train.py
+++ b/tutorials/video_classification_example/train.py
@@ -80,7 +80,7 @@ def __init__(self, args):
         if self.args.arch == "video_resnet":
             self.model = pytorchvideo.models.resnet.create_resnet(
                 input_channel=3,
-                model_num_class=400,
+                model_num_class=11,  # UCF11 has 11 classes (the Kinetics default was 400)
             )
             self.batch_key = "video"
         elif self.args.arch == "audio_resnet":
@@ -235,6 +235,7 @@ def _video_transform(self, mode: str):
             transform=Compose(
                 [
                     UniformTemporalSubsample(args.video_num_subsampled),
+                    Lambda(lambda x: x / 255.0),  # scale uint8 pixel values to [0, 1]
                     Normalize(args.video_means, args.video_stds),
                 ]
                 + (
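
Note: both _video_transform hunks assume Lambda is importable in data.py and train.py, e.g. via "from torchvision.transforms import Lambda" alongside the existing Compose import; neither file's import block appears in this diff. For reference, a minimal sketch of driving the updated finetune.py entrypoint after this change. The "--dataset ucf11" value is an assumption based on the DATASET_MAP lookup in main(), and since the LearningRateMonitor default was dropped from parse_args, the caller re-attaches it:

    import pytorch_lightning as pl
    from finetune import main, parse_args

    # Hypothetical flag values; the valid --dataset keys come from DATASET_MAP,
    # which is not shown in this diff.
    args = parse_args("--dataset ucf11 --batch_size 8 --data_path ./yt_data".split())
    args.callbacks = [pl.callbacks.LearningRateMonitor()]  # no longer a parse_args default
    main(args)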