Skip to content

Commit

Permalink
Add All4One (#382)
Browse files Browse the repository at this point in the history
Co-authored-by: ImaGonEs <[email protected]>
  • Loading branch information
ImaGonEs and StepCrow authored Jan 8, 2024
1 parent 5b564c2 commit a151225
Show file tree
Hide file tree
Showing 11 changed files with 987 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ The library is self-contained, but it is possible to use the models outside of s
---

## Methods available
* [All4One](https://openaccess.thecvf.com/content/ICCV2023/html/Estepa_All4One_Symbiotic_Neighbour_Contrastive_Learning_via_Self-Attention_and_Redundancy_Reduction_ICCV_2023_paper.html)
* [Barlow Twins](https://arxiv.org/abs/2103.03230)
* [BYOL](https://arxiv.org/abs/2006.07733)
* [DeepCluster V2](https://arxiv.org/abs/2006.09882)
Expand Down Expand Up @@ -216,6 +217,7 @@ All pretrained models avaiable can be downloaded directly via the tables below o

| Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint |
|--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:|
| All4One | ResNet18 | 1000 | :x: | 93.24 | 99.88 | [:link:](https://drive.google.com/drive/folders/1dtYmZiftruQ7B2PQ8fo44wguCZ0eSzAd?usp=sharing) |
| Barlow Twins | ResNet18 | 1000 | :x: | 92.10 | 99.73 | [:link:](https://drive.google.com/drive/folders/1L5RAM3lCSViD2zEqLtC-GQKVw6mxtxJ_?usp=sharing) |
| BYOL | ResNet18 | 1000 | :x: | 92.58 | 99.79 | [:link:](https://drive.google.com/drive/folders/1KxeYAEE7Ev9kdFFhXWkPZhG-ya3_UwGP?usp=sharing) |
|DeepCluster V2| ResNet18 | 1000 | :x: | 88.85 | 99.58 | [:link:](https://drive.google.com/drive/folders/1tkEbiDQ38vZaQUsT6_vEpxbDxSUAGwF-?usp=sharing) |
Expand All @@ -237,6 +239,7 @@ All pretrained models avaiable can be downloaded directly via the tables below o

| Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint |
|--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:|
| All4One | ResNet18 | 1000 | :x: | 72.17 | 93.35 | [:link:](https://drive.google.com/drive/folders/1oQcC80XPr-Wxhjs-PEqD_8VhUa_izqeZ?usp=sharing) |
| Barlow Twins | ResNet18 | 1000 | :x: | 70.90 | 91.91 | [:link:](https://drive.google.com/drive/folders/1hDLSApF3zSMAKco1Ck4DMjyNxhsIR2yq?usp=sharing) |
| BYOL | ResNet18 | 1000 | :x: | 70.46 | 91.96 | [:link:](https://drive.google.com/drive/folders/1hwsEdsfsUulD2tAwa4epKK9pkSuvFv6m?usp=sharing) |
|DeepCluster V2| ResNet18 | 1000 | :x: | 63.61 | 88.09 | [:link:](https://drive.google.com/drive/folders/1gAKyMz41mvGh1BBOYdc_xu6JPSkKlWqK?usp=sharing) |
Expand All @@ -257,6 +260,7 @@ All pretrained models avaiable can be downloaded directly via the tables below o

| Method | Backbone | Epochs | Dali | Acc@1 (online) | Acc@1 (offline) | Acc@5 (online) | Acc@5 (offline) | Checkpoint |
|-------------------------|:--------:|:------:|:------------------:|:--------------:|:---------------:|:--------------:|:---------------:|:----------:|
| All4One | ResNet18 | 400 | :heavy_check_mark: | 81.93 | - | 96.23 | - | [:link:](https://drive.google.com/drive/folders/1bJCRLP5Rz_JEylNq9C4sY3ccYZSchUGR?usp=sharing) |
| Barlow Twins :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.38 | 80.16 | 95.28 | 95.14 | [:link:](https://drive.google.com/drive/folders/1rj8RbER9E71mBlCHIZEIhKPUFn437D5O?usp=sharing) |
| BYOL :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.16 | 80.32 | 95.02 | 94.94 | [:link:](https://drive.google.com/drive/folders/1riOLjMawD_znO4HYj8LBN2e1X4jXpDE1?usp=sharing) |
| DeepCluster V2 | ResNet18 | 400 | :x: | 75.36 | 75.4 | 93.22 | 93.10 | [:link:](https://drive.google.com/drive/folders/1d5jPuavrQ7lMlQZn5m2KnN5sPMGhHFo8?usp=sharing) |
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ While the library is self contained, it is possible to use the models outside of

solo/methods/base
solo/methods/linear
solo/methods/all4one
solo/methods/barlow
solo/methods/byol
solo/methods/deepclusterv2
Expand Down
48 changes: 48 additions & 0 deletions docs/source/solo/methods/all4one.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
All4One
======

.. automethod:: solo.methods.all4one.All4One.__init__
:noindex:


add_model_specific_args
~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.add_model_specific_args
:noindex:

learnable_params
~~~~~~~~~~~~~~~~
.. autoattribute:: solo.methods.all4one.All4One.learnable_params
:noindex:

dequeue_and_enqueue
~~~~~~~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.dequeue_and_enqueue
:noindex:

find_nn
~~~~~~~~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.find_nn
:noindex:

off_diagonal
~~~~~~~~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.off_diagonal
:noindex:


save_NN
~~~~~~~~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.save_NN
:noindex:


forward
~~~~~~~
.. automethod:: solo.methods.all4one.All4One.forward
:noindex:

training_step
~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.training_step
:noindex:
100 changes: 100 additions & 0 deletions docs/source/solo/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,103 @@ Whitening

.. automethod:: solo.utils.whitening.Whitening2d.__init__
:noindex:


PositionalEncoding1D
---------------------
:class:`PositionalEncoding1D` applies positional encoding to the last dimension of a 3D tensor.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding1D.__init__
:noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding1D.forward
:noindex:

PositionalEncodingPermute1D
---------------------------
:class:`PositionalEncodingPermute1D` permutes the input tensor and applies 1D positional encoding.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute1D.__init__
:noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute1D.forward
:noindex:

PositionalEncoding2D
---------------------
:class:`PositionalEncoding2D` applies positional encoding to the last two dimensions of a 4D tensor.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding2D.__init__
:noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding2D.forward
:noindex:

PositionalEncodingPermute2D
---------------------------
:class:`PositionalEncodingPermute2D` permutes the input tensor and applies 2D positional encoding.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute2D.__init__
:noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute2D.forward
:noindex:

PositionalEncoding3D
---------------------
:class:`PositionalEncoding3D` applies positional encoding to the last three dimensions of a 5D tensor.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding3D.__init__
:noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding3D.forward
:noindex:

PositionalEncodingPermute3D
---------------------------
:class:`PositionalEncodingPermute3D` permutes the input tensor and applies 3D positional encoding.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute3D.__init__
:noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute3D.forward
:noindex:

Summer
------
:class:`Summer` adds positional encoding to the original tensor.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.Summer.__init__
:noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.Summer.forward
:noindex:

58 changes: 58 additions & 0 deletions scripts/pretrain/cifar/all4one.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
defaults:
- _self_
- augmentations: asymmetric.yaml
- wandb: private.yaml
- override hydra/hydra_logging: disabled
- override hydra/job_logging: disabled

# disable hydra outputs
hydra:
output_subdir: null
run:
dir: .

name: "All4One-cifar100" # change here for cifar10
method: "all4one"
backbone:
name: "resnet18"
method_kwargs:
temperature: 0.2
proj_hidden_dim: 2048
pred_hidden_dim: 4096
proj_output_dim: 256
queue_size: 98304
momentum:
base_tau: 0.99
final_tau: 1.0
data:
dataset: cifar100 # change here for cifar10
train_path: "./datasets/"
val_path: "./datasets/"
format: "image_folder"
num_workers: 4
optimizer:
name: "lars"
batch_size: 256
lr: 1.0
classifier_lr: 0.1
weight_decay: 1e-5
kwargs:
clip_lr: True
eta: 0.02
exclude_bias_n_norm: True
scheduler:
name: "warmup_cosine"
checkpoint:
enabled: True
dir: "trained_models"
frequency: 1
auto_resume:
enabled: False

# overwrite PL stuff
max_epochs: 1000
devices: [0]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
55 changes: 55 additions & 0 deletions scripts/pretrain/imagenet-100/all4one.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
defaults:
- _self_
- augmentations: asymmetric.yaml
- wandb: private.yaml
- override hydra/hydra_logging: disabled
- override hydra/job_logging: disabled

# disable hydra outputs
hydra:
output_subdir: null
run:
dir: .

name: "all4one-imagenet100"
method: "all4one"
backbone:
name: "resnet18"
method_kwargs:
temperature: 0.2
proj_hidden_dim: 2048
pred_hidden_dim: 4096
proj_output_dim: 256
queue_size: 98340
data:
dataset: imagenet100
train_path: "./datasets/imagenet-100/train"
val_path: "./datasets/imagenet-100/val"
format: "dali"
num_workers: 4
optimizer:
name: "lars"
batch_size: 128
lr: 1.0
classifier_lr: 0.1
weight_decay: 1e-5
kwargs:
clip_lr: True
eta: 0.02
exclude_bias_n_norm: True
scheduler:
name: "warmup_cosine"
checkpoint:
enabled: True
dir: "trained_models"
frequency: 1
auto_resume:
enabled: True

# overwrite PL stuff
max_epochs: 400
devices: [0, 1]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
4 changes: 4 additions & 0 deletions solo/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from solo.methods.vibcreg import VIbCReg
from solo.methods.vicreg import VICReg
from solo.methods.wmse import WMSE
from solo.methods.all4one import All4One


METHODS = {
# base classes
Expand All @@ -61,6 +63,7 @@
"vibcreg": VIbCReg,
"vicreg": VICReg,
"wmse": WMSE,
"all4one": All4One,
}
__all__ = [
"BarlowTwins",
Expand All @@ -83,4 +86,5 @@
"VIbCReg",
"VICReg",
"WMSE",
"All4One",
]
Loading

0 comments on commit a151225

Please sign in to comment.