diff --git a/docs/source/tutorials/overview.rst b/docs/source/tutorials/overview.rst
index f53d78eb..82798f86 100644
--- a/docs/source/tutorials/overview.rst
+++ b/docs/source/tutorials/overview.rst
@@ -20,7 +20,7 @@ Let's first go through how the library is organized.
 Now, let's assume that we want to train Barlow Twins on CIFAR10 for 100 epochs.
 For this, we won't use the ``main_pretrain.py`` file directly, but we'll build a minimal version of it in order to give a general overview of the library.
 
-We start by importing everything that we will need (we will be relying on Pytorch Lightning to use our already implemented training/validation steps:
+We start by importing everything that we will need (we will be relying on Pytorch Lightning to use our already implemented training/validation steps):
 
 .. code-block:: python
 
@@ -30,20 +30,22 @@ We start by importing everything that we will need (we will be relying on Pytorc
     from pytorch_lightning.loggers import WandbLogger
     from pytorch_lightning.plugins import DDPPlugin
 
+    # solo-learn now uses OmegaConf and Hydra to manage config files
+    from omegaconf import DictConfig
     from solo.methods import BarlowTwins # imports the method class
     from solo.utils.checkpointer import Checkpointer
 
     # some data utilities
     # we need one dataloader to train an online linear classifier
     # (don't worry, the rest of the model has no idea of this classifier, so it doesn't use label info)
-    from solo.utils.classification_dataloader import prepare_data as prepare_data_classification
+    from solo.data.classification_dataloader import prepare_data as prepare_classification_dataloader
 
     # and some utilities to perform data loading for the method itself, including augmentation pipelines
-    from solo.utils.pretrain_dataloader import (
+    from solo.data.pretrain_dataloader import (
+        build_transform_pipeline,
         prepare_dataloader,
         prepare_datasets,
         prepare_n_crop_transform,
-        prepare_transform,
     )
@@ -56,38 +58,39 @@ However, for now, we won't rely on this, so let's just define all the needed par
     # common parameters for all methods
     # some parameters for extra functionally are missing, but don't mind this for now.
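+    # note: the parameters are now grouped into nested sub-dicts (backbone, data,
+    # optimizer, scheduler); everything is wrapped into an omegaconf DictConfig below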
     base_kwargs = {
-        "backbone": "resnet18",
-        "num_classes": 10,
+        "name": "barlow_twins-cifar10", # change here for cifar100
+        "backbone": {
+            "name": "resnet18",
+            "kwargs": {},
+        },
+        "data": {
+            "dataset": "cifar10",
+            "num_classes": 10,
+            "train_path": "./data", # replace with your own path
+            "val_path": "./data", # replace with your own path
+            "num_large_crops": 2, # must equal 2 for barlow twins
+            "num_small_crops": 0, # must equal 0 for barlow twins
+            "num_workers": 4,
+        },
         "cifar": True,
         "zero_init_residual": True,
         "max_epochs": 100,
-        "optimizer": "sgd",
-        "lars": True,
-        "lr": 0.01,
-        "gpus": "0",
-        "grad_clip_lars": True,
-        "weight_decay": 0.00001,
-        "classifier_lr": 0.5,
-        "exclude_bias_n_norm_lars": True,
-        "accumulate_grad_batches": 1,
-        "extra_optimizer_args": {"momentum": 0.9},
-        "scheduler": "warmup_cosine",
-        "min_lr": 0.0,
-        "warmup_start_lr": 0.0,
-        "warmup_epochs": 10,
-        "num_crops_per_aug": [2, 0],
-        "num_large_crops": 2,
-        "num_small_crops": 0,
-        "eta_lars": 0.02,
-        "lr_decay_steps": None,
+        "optimizer": {
+            "name": "lars",
+            "lr": 0.01,
+            "batch_size": 256,
+            "weight_decay": 0.00001,
+            "classifier_lr": 0.1, # mandatory: lr of the online linear classifier
+        },
+        "scheduler": {
+            "name": "warmup_cosine",
+            "min_lr": 0.0,
+            "warmup_start_lr": 0.0,
+            "warmup_epochs": 10,
+        },
+        "method": "barlow_twins",
         "dali_device": "gpu",
-        "batch_size": 256,
-        "num_workers": 4,
-        "data_dir": "/data/datasets",
-        "train_dir": "cifar10/train",
-        "val_dir": "cifar10/val",
-        "dataset": "cifar10",
-        "name": "barlow-cifar10",
     }
 
     # barlow specific parameters
@@ -99,53 +102,69 @@ However, for now, we won't rely on this, so let's just define all the needed par
         "backbone_args": {"cifar": True, "zero_init_residual": True},
     }
 
-    kwargs = {**base_kwargs, **method_kwargs}
-
-    model = BarlowTwins(**kwargs)
+    cfg = DictConfig({**base_kwargs, "method_kwargs": method_kwargs})
+    model = BarlowTwins(cfg)
 
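+In ``main_pretrain.py`` you would not build this dictionary by hand: the configuration lives in YAML files that Hydra/OmegaConf load for you.
+A minimal sketch of that workflow (the file name ``barlow_cifar10.yaml`` is just a placeholder; dump the dict above into it once and reuse it):
+
+.. code-block:: python
+
+    from omegaconf import OmegaConf
+
+    # load the same nested config from disk instead of defining it inline
+    cfg = OmegaConf.load("barlow_cifar10.yaml")
+    model = BarlowTwins(cfg)
+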
 Now, let's create all the necessary data loaders.
 
 .. code-block:: python
 
-    # we first prepare our single transformation pipeline
+    # we first prepare our single transformation pipeline config
     transform_kwargs = {
-        "brightness": 0.4,
-        "contrast": 0.4,
-        "saturation": 0.2,
-        "hue": 0.1,
-        "gaussian_prob": 0.0,
-        "solarization_prob": 0.0,
+        "crop_size": 32,
+        "num_crops": 1,
+        "rrc": {
+            "enabled": True,
+            "crop_min_scale": 0.08,
+            "crop_max_scale": 1.0,
+        },
+        "color_jitter": {
+            "prob": 0.8,
+            "brightness": 0.4,
+            "contrast": 0.4,
+            "saturation": 0.2,
+            "hue": 0.1,
+        },
+        # all below need to be specified but are unused
+        "grayscale": {"prob": 0.0},
+        "gaussian_blur": {"prob": 0.0},
+        "solarization": {"prob": 0.0},
+        "equalization": {"prob": 0.0},
+        "horizontal_flip": {"prob": 0.0},
     }
-    transform = [prepare_transform("cifar10", **transform_kwargs)]
+    aug_cfg = DictConfig(transform_kwargs)
+    augs = build_transform_pipeline("cifar10", aug_cfg)
 
-    # then, we wrap the pipepline using this utility function
+    # then, we wrap the pipeline using this utility function
     # to make it produce an arbitrary number of crops
-    transform = prepare_n_crop_transform(transform, num_crops_per_aug=[2])
+    transform = prepare_n_crop_transform([augs], num_crops_per_aug=[2])
 
     # finally, we produce the Dataset/Dataloader classes
     train_dataset = prepare_datasets(
-        "cifar10",
-        transform,
-        data_dir="./",
-        train_dir=None,
+        dataset="cifar10",
+        transform=transform,
+        train_data_path=base_kwargs["data"]["train_path"],
         no_labels=False,
     )
     train_loader = prepare_dataloader(
-        train_dataset, batch_size=base_kwargs["batch_size"], num_workers=base_kwargs["num_workers"]
+        train_dataset=train_dataset,
+        batch_size=base_kwargs["optimizer"]["batch_size"],
+        num_workers=base_kwargs["data"]["num_workers"],
     )
 
     # we will also create a validation dataloader to automatically
     # check how well our models is doing in an online fashion.
-    _, val_loader = prepare_data_classification(
-        "cifar10",
-        data_dir="./",
-        train_dir=None,
-        val_dir=None,
-        batch_size=base_kwargs["batch_size"],
-        num_workers=base_kwargs["num_workers"],
+    _, val_loader = prepare_classification_dataloader(
+        dataset=base_kwargs["data"]["dataset"], # "cifar10"
+        train_data_path=base_kwargs["data"]["train_path"],
+        val_data_path=base_kwargs["data"]["val_path"],
+        batch_size=base_kwargs["optimizer"]["batch_size"],
+        num_workers=base_kwargs["data"]["num_workers"],
     )
 
 Now, we just need to define some extra magic for Pytorch Lightning to automatically log some stuff for us and then we can just create our lightning Trainer.
 
 .. code-block:: python
 
@@ -165,26 +184,21 @@ Now, we just need to define some extra magic for Pytorch Lightning to automatica
     callbacks.append(lr_monitor)
 
     # checkpointer can automatically log your parameters,
-    # but we need to wrap it on a Namespace object
-    from argparse import Namespace
-    args = Namespace(**kwargs)
     # saves the checkout after every epoch
     ckpt = Checkpointer(
-        args,
+        cfg,
         logdir="checkpoints/barlow",
         frequency=1,
     )
     callbacks.append(ckpt)
 
     trainer = Trainer.from_argparse_args(
-        args,
+        cfg,
         logger=wandb_logger,
         callbacks=callbacks,
-        plugins=DDPPlugin(find_unused_parameters=True),
-        checkpoint_callback=False,
-        terminate_on_nan=True,
-        accelerator="ddp",
+        accelerator="auto", # use whatever is available
+        strategy="ddp", # could change depending on your setup
     )
 
     trainer.fit(model, train_loader, val_loader)
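+
+After ``fit`` finishes, you will probably want to reuse the pretrained encoder for downstream tasks.
+A minimal sketch of saving just its weights (this assumes the method exposes its encoder as ``model.backbone``; adapt the attribute name to your version of the library):
+
+.. code-block:: python
+
+    import torch
+
+    # keep only the backbone weights for later linear evaluation or fine-tuning
+    torch.save(model.backbone.state_dict(), "barlow_twins_backbone.pt")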