diff --git a/solo/methods/base.py b/solo/methods/base.py index d771b4f3..1c614ef3 100644 --- a/solo/methods/base.py +++ b/solo/methods/base.py @@ -195,8 +195,8 @@ def __init__(self, cfg: omegaconf.DictConfig): self.features_dim: int = self.backbone.inplanes # remove fc layer self.backbone.fc = nn.Identity() - cifar = cfg.data.dataset in ["cifar10", "cifar100"] - if cifar: + low_res = cfg.data.dataset in ["cifar10", "cifar100"] or cfg.method_kwargs.get("low_res", False) + if low_res: self.backbone.conv1 = nn.Conv2d( 3, 64, kernel_size=3, stride=1, padding=2, bias=False ) @@ -644,8 +644,8 @@ def __init__( if self.backbone_name.startswith("resnet"): # remove fc layer self.momentum_backbone.fc = nn.Identity() - cifar = cfg.data.dataset in ["cifar10", "cifar100"] - if cifar: + low_res = cfg.data.dataset in ["cifar10", "cifar100"] or cfg.method_kwargs.get("low_res", False) + if low_res: self.momentum_backbone.conv1 = nn.Conv2d( 3, 64, kernel_size=3, stride=1, padding=2, bias=False ) diff --git a/solo/utils/checkpointer.py b/solo/utils/checkpointer.py index b376e7bb..7d9b113c 100644 --- a/solo/utils/checkpointer.py +++ b/solo/utils/checkpointer.py @@ -164,7 +164,7 @@ def on_train_start(self, trainer: pl.Trainer, _): self.save_args(trainer) def on_train_epoch_end(self, trainer: pl.Trainer, _): - """Tries to save current checkpoint at the end of each train epoch. + """Tries to save the current checkpoint at the end of each train epoch. Args: trainer (pl.Trainer): pytorch lightning trainer object. @@ -173,3 +173,12 @@ def on_train_epoch_end(self, trainer: pl.Trainer, _): epoch = trainer.current_epoch # type: ignore if epoch % self.frequency == 0: self.save(trainer) + + def on_train_end(self, trainer: pl.Trainer, _): + """Saves model at the end of training. + + Args: + trainer (pl.Trainer): pytorch lightning trainer object. + """ + + self.save(trainer)