From 0abeaab16fb636799492338509abaf8c32434ce3 Mon Sep 17 00:00:00 2001 From: Victor Turrisi Date: Sat, 15 Jun 2024 15:52:50 -0300 Subject: [PATCH 1/4] Add argument for low-res datasets --- solo/methods/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/solo/methods/base.py b/solo/methods/base.py index d771b4f3..1c614ef3 100644 --- a/solo/methods/base.py +++ b/solo/methods/base.py @@ -195,8 +195,8 @@ def __init__(self, cfg: omegaconf.DictConfig): self.features_dim: int = self.backbone.inplanes # remove fc layer self.backbone.fc = nn.Identity() - cifar = cfg.data.dataset in ["cifar10", "cifar100"] - if cifar: + low_res = cfg.data.dataset in ["cifar10", "cifar100"] or cfg.method_kwargs.get("low_res", False) + if low_res: self.backbone.conv1 = nn.Conv2d( 3, 64, kernel_size=3, stride=1, padding=2, bias=False ) @@ -644,8 +644,8 @@ def __init__( if self.backbone_name.startswith("resnet"): # remove fc layer self.momentum_backbone.fc = nn.Identity() - cifar = cfg.data.dataset in ["cifar10", "cifar100"] - if cifar: + low_res = cfg.data.dataset in ["cifar10", "cifar100"] or cfg.method_kwargs.get("low_res", False) + if low_res: self.momentum_backbone.conv1 = nn.Conv2d( 3, 64, kernel_size=3, stride=1, padding=2, bias=False ) From 3d58244de2720c14df8fb94641c20e079c2cde86 Mon Sep 17 00:00:00 2001 From: Victor Turrisi Date: Sat, 15 Jun 2024 15:56:21 -0300 Subject: [PATCH 2/4] Always save model at the end of training --- solo/utils/checkpointer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/solo/utils/checkpointer.py b/solo/utils/checkpointer.py index b376e7bb..a15683e4 100644 --- a/solo/utils/checkpointer.py +++ b/solo/utils/checkpointer.py @@ -164,7 +164,7 @@ def on_train_start(self, trainer: pl.Trainer, _): self.save_args(trainer) def on_train_epoch_end(self, trainer: pl.Trainer, _): - """Tries to save current checkpoint at the end of each train epoch. + """Tries to save the current checkpoint at the end of each train epoch. Args: trainer (pl.Trainer): pytorch lightning trainer object. @@ -173,3 +173,8 @@ def on_train_epoch_end(self, trainer: pl.Trainer, _): epoch = trainer.current_epoch # type: ignore if epoch % self.frequency == 0: self.save(trainer) + + def on_train_epoch_end(self, *args, **kwargs): + """Saves model at the end of training. """ + + self.save(trainer) From ccdba0b7e1b71f9fe3eb2f3138ae7be3a4ea5988 Mon Sep 17 00:00:00 2001 From: Victor Turrisi Date: Tue, 18 Jun 2024 08:55:33 -0300 Subject: [PATCH 3/4] Update checkpointer.py --- solo/utils/checkpointer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/solo/utils/checkpointer.py b/solo/utils/checkpointer.py index a15683e4..02ce4109 100644 --- a/solo/utils/checkpointer.py +++ b/solo/utils/checkpointer.py @@ -174,7 +174,12 @@ def on_train_epoch_end(self, trainer: pl.Trainer, _): if epoch % self.frequency == 0: self.save(trainer) - def on_train_epoch_end(self, *args, **kwargs): - """Saves model at the end of training. """ + def on_train_end(self, trainer: pl. +Trainer, _): + """Saves model at the end of training. + + Args: + trainer (pl.Trainer): pytorch lightning trainer object. + """ self.save(trainer) From 6985b84a5d8f65e01efa40ac5e285713e43f7cd2 Mon Sep 17 00:00:00 2001 From: Victor Turrisi Date: Tue, 18 Jun 2024 10:36:40 -0300 Subject: [PATCH 4/4] Update checkpointer.py --- solo/utils/checkpointer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/solo/utils/checkpointer.py b/solo/utils/checkpointer.py index 02ce4109..7d9b113c 100644 --- a/solo/utils/checkpointer.py +++ b/solo/utils/checkpointer.py @@ -174,8 +174,7 @@ def on_train_epoch_end(self, trainer: pl.Trainer, _): if epoch % self.frequency == 0: self.save(trainer) - def on_train_end(self, trainer: pl. -Trainer, _): + def on_train_end(self, trainer: pl.Trainer, _): """Saves model at the end of training. Args: