Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions dinov2/configs/ssl_default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ train:
cache_dataset: true
centering: "centering" # or "sinkhorn_knopp"
unfreeze_last_n_blocks: 40
# Path to a teacher_checkpoint.pth (or raw backbone .pth) from a prior run or
# an external source (e.g. a DINOv3 pretrained backbone). When non-empty and
# use_pretrained is False, this checkpoint is used to warm-start the model.
pretrained_weights: ''
student:
arch: vit_large
patch_size: 16
Expand Down Expand Up @@ -119,3 +123,14 @@ crops:
evaluation:
eval_period_iterations: 12500
bach_root: /data/eva-data/bach
# ── DINOv3-style Gram (patch-similarity) loss ─────────────────────────────────
# Disabled by default. Set gram.use_loss: true and gram.img_level/loss_weight
# in a per-run config to enable. With ema_teacher: true (the default here) no
# separate model is needed — the EMA teacher's patch tokens are used as targets.
gram:
use_loss: false # master switch
img_level: true # true → per-image Gram; false → batch-level Gram
loss_weight: 0.1 # scalar weight on the gram MSE term
normalized: true # L2-normalise patch features before computing Gram
remove_neg: false # zero-clip negative similarities in both matrices
remove_only_teacher_neg: false # zero-clip teacher negatives only
81 changes: 81 additions & 0 deletions dinov2/configs/train/vitg14_reg4_dinov3init.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# ViT-G/14 fine-tuned from an EXTERNAL / DINOv3-FORMAT checkpoint.
#
# Use this config when you have a pretrained backbone that did NOT come from
# Meta's released DINOv2 torch.hub weights — for example:
# • A checkpoint produced by path-fm-dinov3 (MedARC-AI/path-fm-dinov3)
# • Any teacher_checkpoint.pth saved by a prior OpenMidnight run
# • A raw backbone .pth in flat or {'model': ...} format
#
# The loader (_load_from_external_checkpoint in train.py) auto-detects the
# format and loads backbone weights with strict=False, so extra keys specific
# to DINOv3 (e.g. RoPE tensors absent from this model) are safely skipped.
#
# How to use
# ----------
# Override train.pretrained_weights at launch time:
#
# torchrun ... train.py \
# --config-file .../vitg14_reg4_dinov3init.yaml \
# train.pretrained_weights=/path/to/teacher_checkpoint.pth
#
# OR set STAGE1_CHECKPOINT in run_dinov3init.sh and run that script.
#
# Key differences vs. vitg14_reg4.yaml (DINOv2 fine-tuning from torch.hub):
#
# use_pretrained: False — skips torch.hub download
# pretrained_weights: '' — overridden at launch (see above)
# base_lr: 5.0e-04 — somewhat higher than pure fine-tuning because
# the external checkpoint may not perfectly align
# with this model's architecture
# warmup_epochs: 10 — same as fine-tuning default
# layerwise_decay: 1.0 — same as fine-tuning default
dino:
head_n_prototypes: 131072
head_bottleneck_dim: 384
do_kde: True
kde_loss_weight: .05
koleo_loss_weight: 0
do_koleo: False
ibot:
loss_weight: 1.0
mask_sample_probability: 0.5
mask_ratio_min_max:
- 0.1
- 0.45
separate_head: true
head_n_prototypes: 131072
train:
sample_list_path: /block/TCGA/sample_dataset_30.txt
streaming_from_hf: false
streaming_dataset_path: medarc/TCGA-12K-parquet
batch_size_per_gpu: 48
centering: sinkhorn_knopp
use_pretrained: False # ← skip torch.hub; we use pretrained_weights below
pretrained_weights: '' # ← OVERRIDE at launch: train.pretrained_weights=/path/to/ckpt.pth
OFFICIAL_EPOCH_LENGTH: 1250
num_workers: 24
prefetch_factor: 8
skip_checkpointer: true
student:
arch: vit_giant2
patch_size: 14
drop_path_rate: 0.4
ffn_layer: swiglufused
block_chunks: 4
num_register_tokens: 4
teacher:
momentum_teacher: 0.994
optim:
epochs: 200
early_stop: 200
weight_decay_end: 0.2
  base_lr: 5.0e-04             # ← 2.5× higher than pure fine-tuning (2e-4)
warmup_epochs: 10
layerwise_decay: 1.0
crops:
local_crops_size: 98
evaluation:
eval_period_iterations: 5000
bach_root: /block/eva-data/bach
breakhis_root: /block/eva-data/breakhis
pcam_root: /block/eva-data/patch_camelyon
79 changes: 79 additions & 0 deletions dinov2/configs/train/vitg14_reg4_gram.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# ViT-G/14 fine-tuning with the DINOv3-style GRAM (patch-similarity) loss.
#
# The Gram loss (dinov2/loss/gram_loss.py) adds an MSE term that enforces the
# student to reproduce the pairwise cosine-similarity structure of the teacher's
# patch tokens. This is the key new objective introduced in DINOv3 / path-fm.
#
# In this config we use "EMA teacher mode": the existing EMA teacher's patch
# tokens are used directly as Gram targets — no separate frozen model is needed.
#
# Key differences vs. vitg14_reg4.yaml (standard DINOv2 fine-tuning):
#
# gram.use_loss: true — enables the Gram MSE loss
# gram.loss_weight: 0.1 — conservative starting weight; increase toward
# 1.0 if loss curves remain stable
# gram.img_level: true — per-image Gram (B×N×N) rather than batch-level;
# more memory-efficient and matches DINOv3 defaults
# gram.normalized: true — L2-normalise before Gram (cosine similarity)
#
# All other settings are identical to vitg14_reg4.yaml so you can do a clean
# ablation: the only variable is the addition of the Gram loss term.
#
# Launch with run_gram.sh.
dino:
head_n_prototypes: 131072
head_bottleneck_dim: 384
do_kde: True
kde_loss_weight: .05
koleo_loss_weight: 0
do_koleo: False
ibot:
loss_weight: 1.0
mask_sample_probability: 0.5
mask_ratio_min_max:
- 0.1
- 0.45
separate_head: true
head_n_prototypes: 131072
gram:
use_loss: true # ← enable Gram (patch-similarity) loss
img_level: true # per-image Gram matrices
loss_weight: 0.1 # tune up to ~1.0 if training is stable
normalized: true # L2-normalise patch features before Gram
remove_neg: false
remove_only_teacher_neg: false
train:
sample_list_path: /block/TCGA/sample_dataset_30.txt
streaming_from_hf: false
streaming_dataset_path: medarc/TCGA-12K-parquet
batch_size_per_gpu: 48
centering: sinkhorn_knopp
use_pretrained: True # start from Meta's DINOv2 weights (same as baseline)
pretrained_weights: ''
OFFICIAL_EPOCH_LENGTH: 1250
num_workers: 24
prefetch_factor: 8
skip_checkpointer: true
student:
arch: vit_giant2
patch_size: 14
drop_path_rate: 0.4
ffn_layer: swiglufused
block_chunks: 4
num_register_tokens: 4
teacher:
momentum_teacher: 0.994
optim:
epochs: 200
early_stop: 200
weight_decay_end: 0.2
base_lr: 2.0e-04
warmup_epochs: 10
layerwise_decay: 1.0
crops:
local_crops_size: 98
evaluation:
eval_period_iterations: 5000
bach_root: /block/eva-data/bach
breakhis_root: /block/eva-data/breakhis
pcam_root: /block/eva-data/patch_camelyon
66 changes: 66 additions & 0 deletions dinov2/configs/train/vitg14_reg4_scratch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# ViT-G/14 trained FROM SCRATCH (no pretrained starting point).
#
# Key differences vs. vitg14_reg4.yaml (DINOv2-pretrained fine-tuning):
#
# use_pretrained: False — random initialisation; no torch.hub download
# base_lr: 1.5e-03 — ~7.5× higher LR; from-scratch needs a larger
# signal to escape random init (DINOv2 paper uses
# 2e-3 for ViT-g at batch 3072; we scale down
# proportionally for our smaller effective batch)
# warmup_epochs: 20 — 2× longer warmup for stable early training
# layerwise_decay: 0.9 — standard LLRD (vs 1.0 for fine-tuning) lets
# lower backbone layers learn more conservatively
# epochs / early_stop: 200 — same wall-time as the pretrained run; you may
# want to increase this for a full from-scratch run
#
# Launch with run_scratch.sh.
dino:
head_n_prototypes: 131072
head_bottleneck_dim: 384
do_kde: True
kde_loss_weight: .05
koleo_loss_weight: 0
do_koleo: False
ibot:
loss_weight: 1.0
mask_sample_probability: 0.5
mask_ratio_min_max:
- 0.1
- 0.45
separate_head: true
head_n_prototypes: 131072
train:
sample_list_path: /block/TCGA/sample_dataset_30.txt
streaming_from_hf: false
streaming_dataset_path: medarc/TCGA-12K-parquet
batch_size_per_gpu: 48
centering: sinkhorn_knopp
use_pretrained: False # ← key: train from random initialisation
pretrained_weights: '' # empty → no external checkpoint either
OFFICIAL_EPOCH_LENGTH: 1250
num_workers: 24
prefetch_factor: 8
skip_checkpointer: true
student:
arch: vit_giant2
patch_size: 14
drop_path_rate: 0.4
ffn_layer: swiglufused
block_chunks: 4
num_register_tokens: 4
teacher:
momentum_teacher: 0.994
optim:
epochs: 200
early_stop: 200
weight_decay_end: 0.2
base_lr: 1.5e-03 # ← tuned for from-scratch (≈7.5× fine-tuning LR)
warmup_epochs: 20 # ← longer warmup for random-init stability
layerwise_decay: 0.9 # ← standard LLRD (fine-tuning uses 1.0)
crops:
local_crops_size: 98
evaluation:
eval_period_iterations: 5000
bach_root: /block/eva-data/bach
breakhis_root: /block/eva-data/breakhis
pcam_root: /block/eva-data/patch_camelyon
1 change: 1 addition & 0 deletions dinov2/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .ibot_patch_loss import iBOTPatchLoss
from .koleo_loss import KoLeoLoss
from .kde_loss import KDELoss
from .gram_loss import GramLoss
101 changes: 101 additions & 0 deletions dinov2/loss/gram_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Gram (patch-similarity) loss from the DINOv3 training recipe.

Reference: MedARC-AI/path-fm-dinov3 (dinov3/loss/gram_loss.py)

The loss measures the MSE between the pairwise cosine-similarity matrices
("Gram matrices") of the student's and the teacher's patch-token features.
This encourages the student to reproduce the full relational structure of
patch embeddings captured by the teacher, beyond just the CLS token.

Usage (EMA-teacher mode, the default in this codebase):
In forward_backward(), pass the teacher's x_norm_patchtokens as
target_feats and the student's x_norm_patchtokens as output_feats.

Shapes
------
img_level=True (default) : (B, N, D) — one Gram matrix per image
img_level=False : (B*N, D) — one Gram matrix per batch
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class GramLoss(nn.Module):
    """
    MSE between normalised pairwise patch-similarity matrices of student and
    teacher, as introduced in the DINOv3 training objective.

    Args:
        apply_norm: L2-normalise patch features before computing
                    similarities (recommended; matches DINOv3).
        remove_neg: Zero-clip negative similarities in *both*
                    student and teacher matrices.
        remove_only_teacher_neg: Zero-clip teacher negatives; clip student
                    only where the teacher was also negative.
                    Mutually exclusive with remove_neg.
    """

    def __init__(
        self,
        apply_norm: bool = True,
        remove_neg: bool = False,
        remove_only_teacher_neg: bool = False,
    ):
        super().__init__()
        assert not (remove_neg and remove_only_teacher_neg), (
            "remove_neg and remove_only_teacher_neg are mutually exclusive"
        )
        self.mse = nn.MSELoss()
        self.apply_norm = apply_norm
        self.remove_neg = remove_neg
        self.remove_only_teacher_neg = remove_only_teacher_neg

    def _gram_matrix(self, feats: torch.Tensor, img_level: bool) -> torch.Tensor:
        """Pairwise similarity ("Gram") matrix of patch features.

        Args:
            feats: (B, N, D) if img_level else (B*N, D); a 3-D input with
                   img_level=False is flattened to (B*N, D) first.
            img_level: True → one (N, N) Gram per image; False → a single
                   (B*N, B*N) Gram over the whole local batch.

        Returns:
            Similarity matrix; cosine similarities when apply_norm is set.
        """
        if self.apply_norm:
            feats = F.normalize(feats, dim=-1)
        if not img_level and feats.dim() == 3:
            feats = feats.flatten(0, 1)  # (B, N, D) → (B*N, D)
        return torch.matmul(feats, feats.transpose(-1, -2))

    def forward(
        self,
        output_feats: torch.Tensor,
        target_feats: torch.Tensor,
        img_level: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            output_feats: Student patch features.
                Shape (B, N, D) if img_level else (B*N, D).
            target_feats: Teacher patch features; detached here so no
                gradient flows into the teacher path.
                Same shape as output_feats.
            img_level: If True compute one Gram per image; if False compute
                one Gram across the whole local batch.

        Returns:
            Scalar MSE loss averaged over all similarity pairs.
        """
        # Always use fp32 for stability (patch tokens may be fp16).
        output_feats = output_feats.float()
        # FIX: the documented contract is that targets are gradient-free, but
        # the original code never detached them — detach explicitly so callers
        # passing grad-enabled teacher tokens don't backprop into the teacher.
        target_feats = target_feats.float().detach()

        target_sim = self._gram_matrix(target_feats, img_level)
        student_sim = self._gram_matrix(output_feats, img_level)

        # ── Optional negative clipping ────────────────────────────────────────
        if self.remove_neg:
            target_sim = target_sim.clamp(min=0.0)
            student_sim = student_sim.clamp(min=0.0)
        elif self.remove_only_teacher_neg:
            neg_mask = target_sim < 0
            target_sim = target_sim.clamp(min=0.0)
            # Out-of-place equivalent of the masked in-place clamp: clip the
            # student only where the teacher similarity was negative.
            student_sim = torch.where(neg_mask, student_sim.clamp(min=0.0), student_sim)

        return self.mse(student_sim, target_sim)
Loading