Self-Supervised Learning (SSL) is the last frontier of AI. But quick experimentation is not possible today, as no library offers stable, modular, turnkey solutions. Existing libraries are either static (lightly-ai, solo-learn) or not concerned with SSL, which is what justifies our existence. Our goal is to provide a flexible, full-fledged, optimized framework to conduct rapid SSL research and scale as needed.
To reach flexibility, scalability and stability, we rely on battle-tested third-party libraries: PyTorch, Lightning, HuggingFace and TorchMetrics, amongst a few others. Those dependencies allow us to focus on one thing: assembling everything into a powerful SSL research framework. `stable-ssl` adopts a flexible and modular design for seamless integration of components from external libraries, including architectures, loss functions, evaluation metrics, and augmentations.
The key to SSL research is to log and monitor everything. This is what we bring to a new level with `stable-ssl`, by providing extremely rich logging and numerous callbacks that can be added to your trainer and combined in any way you like, such as `stable_ssl.callbacks.OnlineProbe`, `stable_ssl.callbacks.OnlineKNN`, `stable_ssl.callbacks.RankMe`, `stable_ssl.callbacks.LiDAR`, `stable_ssl.callbacks.OnlineWriter`, and so on.
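To build intuition for what an online k-NN probe reports, here is a minimal pure-Python sketch of the underlying idea: keep a bounded queue of recent (embedding, label) pairs and classify new embeddings by majority vote among their `k` most cosine-similar neighbors. This is not the implementation of `stable_ssl.callbacks.OnlineKNN`; the class `ToyKNNProbe`, its `queue_size` and `k` parameters, and the choice of cosine similarity are illustrative assumptions.

```python
import math
from collections import Counter, deque


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


class ToyKNNProbe:
    """Hypothetical k-NN probe over a bounded queue of labeled embeddings."""

    def __init__(self, queue_size, k):
        self.k = k
        # Oldest entries are evicted automatically once the queue is full.
        self.queue = deque(maxlen=queue_size)

    def update(self, embeddings, labels):
        # During training, store the latest embeddings with their labels.
        for emb, lab in zip(embeddings, labels):
            self.queue.append((emb, lab))

    def predict(self, embedding):
        # Rank stored embeddings by similarity and vote among the top k.
        ranked = sorted(
            self.queue, key=lambda item: cosine(embedding, item[0]), reverse=True
        )
        votes = Counter(lab for _, lab in ranked[: self.k])
        return votes.most_common(1)[0][0]

    def accuracy(self, embeddings, labels):
        # Fraction of evaluation embeddings whose vote matches the true label.
        hits = sum(self.predict(e) == y for e, y in zip(embeddings, labels))
        return hits / len(labels)


probe = ToyKNNProbe(queue_size=4, k=3)
probe.update([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]], [0, 0, 1, 1])
print(probe.accuracy([[1.0, 0.05], [0.05, 1.0]], [0, 1]))  # 1.0
```

Because the probe only needs embeddings and labels that already flow through training, it can track representation quality without any extra backward pass, which is what makes this family of callbacks cheap to combine.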
`stable-ssl` only requires you to get familiar with three components:

- data: the dataset should be a HuggingFace dataset. If it already exists on the Hub, you can load it directly, e.g.

  ```python
  import stable_ssl as ssl

  # train_transform: see the minimal example below
  train_dataset = ssl.data.HFDataset(
      path="frgfm/imagenette",
      name="160px",
      split="train",
      transform=train_transform,
  )
  ```

  Otherwise, you can wrap your own dataset into a HF dataset. Why? Imposing that format ensures consistent behavior (each sample is a dictionary) and leverages powerful utilities from the `datasets` package. Once the datasets (train, val, etc.) are created, they can be used as-is with `torch.utils.data.DataLoader`. However, we recommend putting them into our `DataModule` to ensure precise logging and easy debugging, e.g.

  ```python
  datamodule = ssl.data.DataModule(train=train_dataset, val=val_dataset, ...)
  ```
- module, models, forward: the overall orchestration leverages `ssl.Module`, which inherits from `lightning.LightningModule`. We provide all the basic required utilities (optimizer/scheduler creation, etc.), so the only method the user has to implement is `forward`. For example, a supervised learning run would define

  ```python
  import torch


  def forward(self, batch, stage):
      batch["embedding"] = self.backbone(batch["image"])["logits"]
      if self.training:
          preds = self.classifier(batch["embedding"])
          batch["loss"] = torch.nn.functional.cross_entropy(preds, batch["label"])
      return batch
  ```

  The `forward` method takes in a dictionary (the `batch` from the data loader) and should return a dictionary. If any module has to be trained, then a `loss` key must be present. Further customization can be done (see the examples), ensuring that any desired behavior can be achieved. Here `self` is a `LightningModule` holding every attribute passed during module creation:

  ```python
  import torch
  from transformers import AutoConfig, AutoModelForImageClassification

  import stable_ssl as ssl

  config = AutoConfig.from_pretrained("microsoft/resnet-18")
  backbone = AutoModelForImageClassification.from_config(config)
  backbone.classifier[1] = torch.nn.Identity()  # expose the 512-d pooled features
  projector = torch.nn.Linear(512, 128)
  module = ssl.Module(
      backbone=backbone,
      projector=projector,
      forward=forward,
      simclr_loss=ssl.losses.NTXEntLoss(temperature=0.1),
  )
  ```

  Any `kwarg` passed to `stable_ssl.Module` is automatically set as an attribute; the only reserved `kwarg` is `forward`.
- trainer: the final step is to describe how training will happen! This is done with the `lightning.Trainer` class, for example

  ```python
  trainer = pl.Trainer(
      max_epochs=10,
      num_sanity_val_steps=1,
      callbacks=[linear_probe, knn_probe, rankme],
      precision="16-mixed",
      logger=False,
      enable_checkpointing=False,
  )
  ```

  Once this is specified, simply pipe everything into our manager class, which will connect everything and launch fitting! This extra wrapper is needed to make logging as precise as possible.

  ```python
  manager = ssl.Manager(trainer=trainer, module=module, data=data)
  manager()
  ```
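The contract described in the list above (each sample is a dictionary, every keyword argument becomes an attribute, `forward` is reserved, and a training step must return a `loss` key) can be mimicked in a few lines of plain Python. The sketch below is a hypothetical stand-in written only to make that contract concrete; it is not `ssl.Module` or `ssl.Manager`, and `ToyModule`, `collate`, `training_step` and the toy `scale` attribute are illustrative names.

```python
import types


def collate(samples):
    """Merge a list of per-sample dicts into one dict of lists (a toy batch)."""
    return {key: [s[key] for s in samples] for key in samples[0]}


class ToyModule:
    """Hypothetical stand-in: stores every kwarg as an attribute, binds `forward`."""

    def __init__(self, forward, **kwargs):
        # `forward` is an explicit parameter, so it can never land in kwargs.
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Bind the user-provided function so it receives `self` like a method.
        self.forward = types.MethodType(forward, self)
        self.training = True

    def training_step(self, batch):
        out = self.forward(batch, stage="fit")
        if "loss" not in out:
            raise KeyError("forward must return a 'loss' key when training")
        return out["loss"]


def forward(self, batch, stage):
    # Toy "backbone": scale each input; toy "loss": mean squared error.
    batch["embedding"] = [self.scale * x for x in batch["image"]]
    if self.training:
        errors = [(e - y) ** 2 for e, y in zip(batch["embedding"], batch["label"])]
        batch["loss"] = sum(errors) / len(errors)
    return batch


samples = [{"image": 1.0, "label": 2.0}, {"image": 2.0, "label": 4.0}]
module = ToyModule(forward=forward, scale=2.0)
print(module.training_step(collate(samples)))  # 0.0, since 2 * image == label
```

Keeping the batch as a dictionary is what lets callbacks such as probes read intermediate outputs (here `"embedding"`) by key without the user wiring them manually.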
Minimal Example: SimCLR INET10
```python
import stable_ssl as ssl
import torch
from transformers import AutoModelForImageClassification, AutoConfig
import lightning as pl
from stable_ssl.data import transforms
import torchmetrics

# ImageNet normalization statistics
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# SimCLR-style augmentations, applied independently to each view
train_transform = transforms.Compose(
    transforms.RGB(),
    transforms.RandomResizedCrop((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(
        brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=(5, 5), p=1.0),
    transforms.ToImage(mean=mean, std=std),
)
train_dataset = ssl.data.HFDataset(
    path="frgfm/imagenette",
    name="160px",
    split="train",
    transform=train_transform,
)
# Each sample index is drawn n_views=2 times, so a batch holds two
# augmented views of every image
train = torch.utils.data.DataLoader(
    dataset=train_dataset,
    sampler=ssl.data.sampler.RepeatedRandomSampler(train_dataset, n_views=2),
    batch_size=64,
    num_workers=20,
    drop_last=True,
)

# Deterministic preprocessing for evaluation
val_transform = transforms.Compose(
    transforms.RGB(),
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToImage(mean=mean, std=std),
)
val = torch.utils.data.DataLoader(
    dataset=ssl.data.HFDataset(
        path="frgfm/imagenette",
        name="160px",
        split="validation",
        transform=val_transform,
    ),
    batch_size=128,
    num_workers=10,
)
data = ssl.data.DataModule(train=train, val=val)


def forward(self, batch, stage):
    batch["embedding"] = self.backbone(batch["image"])["logits"]
    if self.training:
        proj = self.projector(batch["embedding"])
        # Group the projections of the two views of each image
        views = ssl.data.fold_views(proj, batch["sample_idx"])
        batch["loss"] = self.simclr_loss(views[0], views[1])
    return batch


config = AutoConfig.from_pretrained("microsoft/resnet-18")
backbone = AutoModelForImageClassification.from_config(config)
projector = torch.nn.Linear(512, 128)
# Drop the final Linear layer so the backbone outputs 512-d features
backbone.classifier[1] = torch.nn.Identity()

module = ssl.Module(
    backbone=backbone,
    projector=projector,
    forward=forward,
    simclr_loss=ssl.losses.NTXEntLoss(temperature=0.1),
)

# Online evaluation callbacks, both reading the "embedding" key of the batch
linear_probe = ssl.callbacks.OnlineProbe(
    "linear_probe",
    module,
    "embedding",
    "label",
    probe=torch.nn.Linear(512, 10),
    loss_fn=torch.nn.CrossEntropyLoss(),
    metrics={
        "top1": torchmetrics.classification.MulticlassAccuracy(10),
        "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5),
    },
)
knn_probe = ssl.callbacks.OnlineKNN(
    module,
    "knn_probe",
    "embedding",
    "label",
    20000,
    metrics=torchmetrics.classification.MulticlassAccuracy(10),
    k=10,
    features_dim=512,
)

trainer = pl.Trainer(
    max_epochs=6,
    num_sanity_val_steps=1,
    callbacks=[linear_probe, knn_probe],
    precision="16-mixed",
    logger=False,
    enable_checkpointing=False,
)
manager = ssl.Manager(trainer=trainer, module=module, data=data)
manager()
```
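The two less familiar steps in the `forward` above are grouping the two augmented views of each image and computing the NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR. The pure-Python sketch below illustrates both ideas; it is not the code of `ssl.data.fold_views` or `ssl.losses.NTXEntLoss`. It assumes that views are identified only by a repeated `sample_idx` (the first occurrence of an index is treated as view 0), and `toy_nt_xent` follows the standard SimCLR formulation: for each of the 2N embeddings, the other view of the same image is the positive and the remaining 2N - 2 embeddings are negatives. The names `toy_fold_views`, `toy_nt_xent` and `normalize` are hypothetical.

```python
import math


def toy_fold_views(embeddings, sample_idx):
    """Split a flat batch into views, aligned so that row i of every view
    comes from the same sample (first occurrence of an index -> view 0)."""
    groups = {}  # dicts keep insertion order, so rows follow first appearance
    for emb, idx in zip(embeddings, sample_idx):
        groups.setdefault(idx, []).append(emb)
    n_views = len(next(iter(groups.values())))
    return [[embs[v] for embs in groups.values()] for v in range(n_views)]


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def toy_nt_xent(z1, z2, temperature=0.1):
    """Standard SimCLR NT-Xent loss for two aligned lists of embeddings."""
    z = [normalize(v) for v in z1 + z2]  # 2N unit vectors
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # index of the other view of sample i
        logits = {
            j: sum(a * b for a, b in zip(z[i], z[j])) / temperature
            for j in range(2 * n)
            if j != i  # exclude self-similarity from the denominator
        }
        log_denom = math.log(sum(math.exp(s) for s in logits.values()))
        total += log_denom - logits[positive]  # -log softmax of the positive
    return total / (2 * n)


# A batch where each sample index appears twice, in shuffled order.
embeddings = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]
sample_idx = [7, 3, 7, 3]
view0, view1 = toy_fold_views(embeddings, sample_idx)
print(view0)  # [[1.0, 0.0], [0.0, 1.0]]
print(view1)  # [[0.9, 0.1], [0.1, 0.9]]
print(toy_nt_xent(view0, view1) < toy_nt_xent(view0, view1[::-1]))  # True
```

The loss is low when the two views of each image are closer to each other than to every other embedding in the batch, which is why the correctly aligned pairing scores lower than the mismatched one.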
The library is not yet available on PyPI. You can install it from the source code, as follows.
- conda (optional): first, use your favorite environment manager and install your favorite PyTorch version. We provide an example with conda:

  ```bash
  wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
  bash Miniconda3-latest-Linux-x86_64.sh
  ```

  Follow the installation instructions. Once completed, create your environment

  ```bash
  conda create -n my_env python=3.11
  ```

  with your environment name (here, `my_env`) and your favorite Python version (here, `3.11`). Once completed, make sure to activate your environment (`conda activate my_env`) before proceeding to the next steps!
- PyTorch and our library (we recommend using `uv` for quicker package management), run from the root of the cloned repository:

  ```bash
  pip3 install uv
  uv pip install torch torchvision torchaudio
  uv pip install -e .
  ```

  If you do not want to use `uv`, skip the first command and drop the `uv` prefix from the others.
- API login (optional):

  ```bash
  wandb login
  huggingface-cli login
  ```
- LaTeX support in Matplotlib (optional)
  - Install the LaTeX font (Computer Modern):
    - we provide the ttf files in the repo to make things simple
    - create your local folder (if not present) and copy the ttf files there

      ```bash
      mkdir -p ~/.local/share/fonts
      cp assets/cm-unicode-0.7.0\ 2/*ttf ~/.local/share/fonts/
      ```

    - refresh the font cache with `fc-cache -f -v`
    - validate that the fonts are listed in your system with `fc-list | grep cmu`
    - refresh the matplotlib cache

      ```python
      import shutil

      import matplotlib

      shutil.rmtree(matplotlib.get_cachedir())
      ```

  - Install the TeX compiler (optional, if not available on your system):
    - install TeX Live locally following https://tug.org/texlive/quickinstall.html#running, where you can use `-texdir your_path` to install to a local path (so you don't need sudo privileges)
    - follow the instructions at the end of the installation to edit the `PATH` variable; you can set that variable for a conda environment with `conda env config vars set PATH=$PATH`
    - make sure that, inside the conda environment, you point to the right binaries, e.g. `whereis latex` and `whereis mktexfmt`
    - if at some point there is an error that the file `latex.fmt` is not found, you can generate it with `pdftex -ini -jobname=latex -progname=latex -translate-file=cp227.tcx *latex.ini`
    - alternatively (unverified), try `fmtutil-sys --all`
- rc config (optional):

  ```
  font.family: serif
  font.serif: cmr10
  font.sans-serif: cmss10
  font.monospace: cmtt10
  text.usetex: True
  text.latex.preamble: \usepackage{amssymb} \usepackage{amsmath} \usepackage{bm}
  xtick.labelsize: 14
  ytick.labelsize: 14
  legend.fontsize: 14
  axes.labelsize: 16
  axes.titlesize: 16
  axes.formatter.use_mathtext: True
  ```

  These settings can be written to a file, e.g., `~/.config/matplotlib/matplotlibrc`, or set via `rc` directly in your script. See the Matplotlib documentation on customizing rc settings for more details.
- Example of a matplotlib script to run for a quick test (optional):

  ```python
  import matplotlib.pyplot as plt
  import numpy as np
  from matplotlib import rc

  rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"]})
  rc("text", usetex=True)

  t = np.arange(0.0, 1.0 + 0.01, 0.01)
  s = np.cos(4 * np.pi * t) + 2

  plt.rc("text", usetex=True)
  plt.rc("font", family="serif")
  plt.plot(t, s)
  plt.xlabel(r"\textbf{time} (s)")
  plt.ylabel(r"\textit{voltage} (mV)", fontsize=16)
  plt.title(
      r"\TeX\ is Number "
      r"$\displaystyle\sum_{n=1}^\infty\frac{-e^{i\pi}}{2^n}$!",
      fontsize=16,
      color="gray",
  )
  # Make room for the ridiculously large title.
  plt.subplots_adjust(top=0.8)
  plt.savefig("tex_demo")
  plt.show()
  ```
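If you prefer to script the rc config step rather than editing the file by hand, a small standard-library helper can write the settings listed above to a `matplotlibrc` file. This is a convenience sketch, not part of `stable-ssl`: the function name `write_matplotlibrc`, its `overwrite` flag, and the choice to refuse overwriting an existing file by default are assumptions. The default path is the `~/.config/matplotlib/matplotlibrc` location mentioned above, which Matplotlib reads on Linux when `MPLCONFIGDIR` is not set.

```python
from pathlib import Path

# Settings copied from the rc config step above.
RC_SETTINGS = {
    "font.family": "serif",
    "font.serif": "cmr10",
    "font.sans-serif": "cmss10",
    "font.monospace": "cmtt10",
    "text.usetex": "True",
    "text.latex.preamble": r"\usepackage{amssymb} \usepackage{amsmath} \usepackage{bm}",
    "xtick.labelsize": "14",
    "ytick.labelsize": "14",
    "legend.fontsize": "14",
    "axes.labelsize": "16",
    "axes.titlesize": "16",
    "axes.formatter.use_mathtext": "True",
}


def write_matplotlibrc(
    path="~/.config/matplotlib/matplotlibrc", settings=RC_SETTINGS, overwrite=False
):
    """Write `key: value` lines to a matplotlibrc file (hypothetical helper)."""
    target = Path(path).expanduser()
    if target.exists() and not overwrite:
        raise FileExistsError(f"{target} already exists; pass overwrite=True")
    # Create ~/.config/matplotlib (or any parent folder) if it is missing.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text("".join(f"{key}: {value}\n" for key, value in settings.items()))
    return target
```

Calling `write_matplotlibrc()` writes the file; start a new Python session afterwards so Matplotlib re-reads it on import.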
- If you'd like to contribute new features, bug fixes, or improvements to the documentation, please refer to our contributing guide for detailed instructions on how to get started.
- You can also contribute by adding new methods, datasets, or configurations that improve the current performance of a method in the benchmark section.
`stable-ssl` was started by Randall Balestriero circa 2020 for internal research projects. After numerous refactorings and simplifications, it became practical for external use circa 2024, at which point Hugues Van Assel and Lucas Maes joined as core contributors.