diff --git a/dinov2/configs/ssl_default_config.yaml b/dinov2/configs/ssl_default_config.yaml
index 3320172..673f9f8 100644
--- a/dinov2/configs/ssl_default_config.yaml
+++ b/dinov2/configs/ssl_default_config.yaml
@@ -70,6 +70,10 @@ train:
   cache_dataset: true
   centering: "centering" # or "sinkhorn_knopp"
   unfreeze_last_n_blocks: 40
+  # Path to a teacher_checkpoint.pth (or raw backbone .pth) from a prior run or
+  # an external source (e.g. a DINOv3 pretrained backbone). When non-empty and
+  # use_pretrained is False, this checkpoint is used to warm-start the model.
+  pretrained_weights: ''
 student:
   arch: vit_large
   patch_size: 16
@@ -119,3 +123,14 @@ crops:
 evaluation:
   eval_period_iterations: 12500
   bach_root: /data/eva-data/bach
+# ── DINOv3-style Gram (patch-similarity) loss ─────────────────────────────────
+# Disabled by default. Set gram.use_loss: true and gram.img_level/loss_weight
+# in a per-run config to enable. No separate teacher model is needed — the
+# EMA teacher's patch tokens are used directly as the Gram targets.
+gram:
+  use_loss: false                # master switch
+  img_level: true                # true → per-image Gram; false → batch-level Gram
+  loss_weight: 0.1               # scalar weight on the gram MSE term
+  normalized: true               # L2-normalise patch features before computing Gram
+  remove_neg: false              # zero-clip negative similarities in both matrices
+  remove_only_teacher_neg: false # zero-clip teacher negatives only
diff --git a/dinov2/configs/train/vitg14_reg4_dinov3init.yaml b/dinov2/configs/train/vitg14_reg4_dinov3init.yaml
new file mode 100644
index 0000000..aac9a89
--- /dev/null
+++ b/dinov2/configs/train/vitg14_reg4_dinov3init.yaml
@@ -0,0 +1,81 @@
+# ViT-G/14 fine-tuned from an EXTERNAL / DINOv3-FORMAT checkpoint.
+#
+# Use this config when you have a pretrained backbone that did NOT come from
+# Meta's released DINOv2 torch.hub weights — for example:
+#   • A checkpoint produced by path-fm-dinov3 (MedARC-AI/path-fm-dinov3)
+#   • Any teacher_checkpoint.pth saved by a prior OpenMidnight run
+#   • A raw backbone .pth in flat or {'model': ...} format
+#
+# The loader (_load_from_external_checkpoint in train.py) auto-detects the
+# format and loads backbone weights with strict=False, so extra keys specific
+# to DINOv3 (e.g. RoPE tensors absent from this model) are safely skipped.
+#
+# How to use
+# ----------
+# Override train.pretrained_weights at launch time:
+#
+#   torchrun ... train.py \
+#     --config-file .../vitg14_reg4_dinov3init.yaml \
+#     train.pretrained_weights=/path/to/teacher_checkpoint.pth
+#
+# OR set STAGE1_CHECKPOINT in run_dinov3init.sh and run that script.
+#
+# Key differences vs. vitg14_reg4.yaml (DINOv2 fine-tuning from torch.hub):
+#
+#   use_pretrained: False   — skips torch.hub download
+#   pretrained_weights: ''  — overridden at launch (see above)
+#   base_lr: 5.0e-04        — somewhat higher than pure fine-tuning because
+#                             the external checkpoint may not perfectly align
+#                             with this model's architecture
+#   warmup_epochs: 10       — same as fine-tuning default
+#   layerwise_decay: 1.0    — same as fine-tuning default
+dino:
+  head_n_prototypes: 131072
+  head_bottleneck_dim: 384
+  do_kde: True
+  kde_loss_weight: .05
+  koleo_loss_weight: 0
+  do_koleo: False
+ibot:
+  loss_weight: 1.0
+  mask_sample_probability: 0.5
+  mask_ratio_min_max:
+  - 0.1
+  - 0.45
+  separate_head: true
+  head_n_prototypes: 131072
+train:
+  sample_list_path: /block/TCGA/sample_dataset_30.txt
+  streaming_from_hf: false
+  streaming_dataset_path: medarc/TCGA-12K-parquet
+  batch_size_per_gpu: 48
+  centering: sinkhorn_knopp
+  use_pretrained: False  # ← skip torch.hub; we use pretrained_weights below
+  pretrained_weights: ''  # ← OVERRIDE at launch: train.pretrained_weights=/path/to/ckpt.pth
+  OFFICIAL_EPOCH_LENGTH: 1250
+  num_workers: 24
+  prefetch_factor: 8
+  skip_checkpointer: true
+student:
+  arch: vit_giant2
+  patch_size: 14
+  drop_path_rate: 0.4
+  ffn_layer: swiglufused
+  block_chunks: 4
+  num_register_tokens: 4
+teacher:
+  momentum_teacher: 0.994
+optim:
+  epochs: 200
+  early_stop: 200
+  weight_decay_end: 0.2
+  base_lr: 5.0e-04  # ← slightly higher than pure fine-tuning (2e-4)
+  warmup_epochs: 10
+  layerwise_decay: 1.0
+crops:
+  local_crops_size: 98
+evaluation:
+  eval_period_iterations: 5000
+  bach_root: /block/eva-data/bach
+  breakhis_root: /block/eva-data/breakhis
+  pcam_root: /block/eva-data/patch_camelyon
diff --git a/dinov2/configs/train/vitg14_reg4_gram.yaml b/dinov2/configs/train/vitg14_reg4_gram.yaml
new file mode 100644
index 0000000..14c68ee
--- /dev/null
+++ b/dinov2/configs/train/vitg14_reg4_gram.yaml
@@ -0,0 +1,79 @@
+# ViT-G/14 fine-tuning with the DINOv3-style GRAM (patch-similarity) loss.
+#
+# The Gram loss (dinov2/loss/gram_loss.py) adds an MSE term that enforces the
+# student to reproduce the pairwise cosine-similarity structure of the teacher's
+# patch tokens. This is the key new objective introduced in DINOv3 / path-fm.
+#
+# In this config we use "EMA teacher mode": the existing EMA teacher's patch
+# tokens are used directly as Gram targets — no separate frozen model is needed.
+#
+# Key differences vs. vitg14_reg4.yaml (standard DINOv2 fine-tuning):
+#
+#   gram.use_loss: true    — enables the Gram MSE loss
+#   gram.loss_weight: 0.1  — conservative starting weight; increase toward
+#                            1.0 if loss curves remain stable
+#   gram.img_level: true   — per-image Gram (B×N×N) rather than batch-level;
+#                            more memory-efficient and matches DINOv3 defaults
+#   gram.normalized: true  — L2-normalise before Gram (cosine similarity)
+#
+# All other settings are identical to vitg14_reg4.yaml so you can do a clean
+# ablation: the only variable is the addition of the Gram loss term.
+#
+# Launch with run_gram.sh.
+dino:
+  head_n_prototypes: 131072
+  head_bottleneck_dim: 384
+  do_kde: True
+  kde_loss_weight: .05
+  koleo_loss_weight: 0
+  do_koleo: False
+ibot:
+  loss_weight: 1.0
+  mask_sample_probability: 0.5
+  mask_ratio_min_max:
+  - 0.1
+  - 0.45
+  separate_head: true
+  head_n_prototypes: 131072
+gram:
+  use_loss: true  # ← enable Gram (patch-similarity) loss
+  img_level: true  # per-image Gram matrices
+  loss_weight: 0.1  # tune up to ~1.0 if training is stable
+  normalized: true  # L2-normalise patch features before Gram
+  remove_neg: false
+  remove_only_teacher_neg: false
+train:
+  sample_list_path: /block/TCGA/sample_dataset_30.txt
+  streaming_from_hf: false
+  streaming_dataset_path: medarc/TCGA-12K-parquet
+  batch_size_per_gpu: 48
+  centering: sinkhorn_knopp
+  use_pretrained: True  # start from Meta's DINOv2 weights (same as baseline)
+  pretrained_weights: ''
+  OFFICIAL_EPOCH_LENGTH: 1250
+  num_workers: 24
+  prefetch_factor: 8
+  skip_checkpointer: true
+student:
+  arch: vit_giant2
+  patch_size: 14
+  drop_path_rate: 0.4
+  ffn_layer: swiglufused
+  block_chunks: 4
+  num_register_tokens: 4
+teacher:
+  momentum_teacher: 0.994
+optim:
+  epochs: 200
+  early_stop: 200
+  weight_decay_end: 0.2
+  base_lr: 2.0e-04
+  warmup_epochs: 10
+  layerwise_decay: 1.0
+crops:
+  local_crops_size: 98
+evaluation:
+  eval_period_iterations: 5000
+  bach_root: /block/eva-data/bach
+  breakhis_root: /block/eva-data/breakhis
+  pcam_root: /block/eva-data/patch_camelyon
diff --git a/dinov2/configs/train/vitg14_reg4_scratch.yaml b/dinov2/configs/train/vitg14_reg4_scratch.yaml
new file mode 100644
index 0000000..13ecf22
--- /dev/null
+++ b/dinov2/configs/train/vitg14_reg4_scratch.yaml
@@ -0,0 +1,66 @@
+# ViT-G/14 trained FROM SCRATCH (no pretrained starting point).
+#
+# Key differences vs. vitg14_reg4.yaml (DINOv2-pretrained fine-tuning):
+#
+#   use_pretrained: False     — random initialisation; no torch.hub download
+#   base_lr: 1.5e-03          — ~7.5× higher LR; from-scratch needs a larger
+#                               signal to escape random init (DINOv2 paper uses
+#                               2e-3 for ViT-g at batch 3072; we scale down
+#                               proportionally for our smaller effective batch)
+#   warmup_epochs: 20         — 2× longer warmup for stable early training
+#   layerwise_decay: 0.9      — standard LLRD (vs 1.0 for fine-tuning) lets
+#                               lower backbone layers learn more conservatively
+#   epochs / early_stop: 200  — same wall-time as the pretrained run; you may
+#                               want to increase this for a full from-scratch run
+#
+# Launch with run_scratch.sh.
+dino:
+  head_n_prototypes: 131072
+  head_bottleneck_dim: 384
+  do_kde: True
+  kde_loss_weight: .05
+  koleo_loss_weight: 0
+  do_koleo: False
+ibot:
+  loss_weight: 1.0
+  mask_sample_probability: 0.5
+  mask_ratio_min_max:
+  - 0.1
+  - 0.45
+  separate_head: true
+  head_n_prototypes: 131072
+train:
+  sample_list_path: /block/TCGA/sample_dataset_30.txt
+  streaming_from_hf: false
+  streaming_dataset_path: medarc/TCGA-12K-parquet
+  batch_size_per_gpu: 48
+  centering: sinkhorn_knopp
+  use_pretrained: False  # ← key: train from random initialisation
+  pretrained_weights: ''  # empty → no external checkpoint either
+  OFFICIAL_EPOCH_LENGTH: 1250
+  num_workers: 24
+  prefetch_factor: 8
+  skip_checkpointer: true
+student:
+  arch: vit_giant2
+  patch_size: 14
+  drop_path_rate: 0.4
+  ffn_layer: swiglufused
+  block_chunks: 4
+  num_register_tokens: 4
+teacher:
+  momentum_teacher: 0.994
+optim:
+  epochs: 200
+  early_stop: 200
+  weight_decay_end: 0.2
+  base_lr: 1.5e-03  # ← tuned for from-scratch (≈7.5× fine-tuning LR)
+  warmup_epochs: 20  # ← longer warmup for random-init stability
+  layerwise_decay: 0.9  # ← standard LLRD (fine-tuning uses 1.0)
+crops:
+  local_crops_size: 98
+evaluation:
+  eval_period_iterations: 5000
+  bach_root: /block/eva-data/bach
+  breakhis_root: /block/eva-data/breakhis
+  pcam_root: /block/eva-data/patch_camelyon
diff --git a/dinov2/loss/__init__.py b/dinov2/loss/__init__.py
index c9da02a..a21b2d0 100644
--- a/dinov2/loss/__init__.py
+++ b/dinov2/loss/__init__.py
@@ -7,3 +7,4 @@
 from .ibot_patch_loss import iBOTPatchLoss
 from .koleo_loss import KoLeoLoss
 from .kde_loss import KDELoss
+from .gram_loss import GramLoss
diff --git a/dinov2/loss/gram_loss.py b/dinov2/loss/gram_loss.py
new file mode 100644
index 0000000..d3b7276
--- /dev/null
+++ b/dinov2/loss/gram_loss.py
@@ -0,0 +1,101 @@
+"""
+Gram (patch-similarity) loss from the DINOv3 training recipe.
+
+Reference: MedARC-AI/path-fm-dinov3 (dinov3/loss/gram_loss.py)
+
+The loss measures the MSE between the pairwise cosine-similarity matrices
+("Gram matrices") of the student's and the teacher's patch-token features.
+This encourages the student to reproduce the full relational structure of
+patch embeddings captured by the teacher, beyond just the CLS token.
+
+Usage (EMA-teacher mode, the default in this codebase):
+    In forward_backward(), pass the teacher's x_norm_patchtokens as
+    target_feats and the student's x_norm_patchtokens as output_feats.
+
+Shapes
+------
+img_level=True (default) : (B, N, D) — one Gram matrix per image
+img_level=False          : (B*N, D)  — one Gram matrix per batch
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GramLoss(nn.Module):
+    """
+    MSE between normalised pairwise patch-similarity matrices of student and
+    teacher, as introduced in the DINOv3 training objective.
+
+    Args:
+        apply_norm:              L2-normalise patch features before computing
+                                 similarities (recommended; matches DINOv3).
+        remove_neg:              Zero-clip negative similarities in *both*
+                                 student and teacher matrices.
+        remove_only_teacher_neg: Zero-clip teacher negatives; clip student
+                                 only where the teacher was also negative.
+                                 Mutually exclusive with remove_neg.
+ """ + + def __init__( + self, + apply_norm: bool = True, + remove_neg: bool = False, + remove_only_teacher_neg: bool = False, + ): + super().__init__() + assert not (remove_neg and remove_only_teacher_neg), ( + "remove_neg and remove_only_teacher_neg are mutually exclusive" + ) + self.mse = nn.MSELoss() + self.apply_norm = apply_norm + self.remove_neg = remove_neg + self.remove_only_teacher_neg = remove_only_teacher_neg + + def forward( + self, + output_feats: torch.Tensor, + target_feats: torch.Tensor, + img_level: bool = True, + ) -> torch.Tensor: + """ + Args: + output_feats: Student patch features. + Shape (B, N, D) if img_level else (B*N, D). + target_feats: Teacher patch features (no gradient flows). + Same shape as output_feats. + img_level: If True compute one Gram per image; if False compute + one Gram across the whole local batch. + + Returns: + Scalar MSE loss averaged over all similarity pairs. + """ + # Always use fp32 for stability (patch tokens may be fp16) + output_feats = output_feats.float() + target_feats = target_feats.float() + + # ── Teacher Gram matrix ─────────────────────────────────────────────── + if self.apply_norm: + target_feats = F.normalize(target_feats, dim=-1) + if not img_level and target_feats.dim() == 3: + target_feats = target_feats.flatten(0, 1) # (B*N, D) + target_sim = torch.matmul(target_feats, target_feats.transpose(-1, -2)) + + # ── Student Gram matrix ─────────────────────────────────────────────── + if self.apply_norm: + output_feats = F.normalize(output_feats, dim=-1) + if not img_level and output_feats.dim() == 3: + output_feats = output_feats.flatten(0, 1) + student_sim = torch.matmul(output_feats, output_feats.transpose(-1, -2)) + + # ── Optional negative clipping ───────────────────────────────────────── + if self.remove_neg: + target_sim = target_sim.clamp(min=0.0) + student_sim = student_sim.clamp(min=0.0) + elif self.remove_only_teacher_neg: + neg_mask = target_sim < 0 + target_sim = target_sim.clamp(min=0.0) 
+            student_sim[neg_mask] = student_sim[neg_mask].clamp(min=0.0)
+
+        return self.mse(student_sim, target_sim)
diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py
index db95104..42179c7 100644
--- a/dinov2/train/ssl_meta_arch.py
+++ b/dinov2/train/ssl_meta_arch.py
@@ -9,7 +9,7 @@
 import torch
 from torch import nn
 
-from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss, KDELoss
+from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss, KDELoss, GramLoss
 from dinov2.models import build_model_from_cfg
 from dinov2.layers import DINOHead
 from dinov2.utils.utils import has_batchnorms
@@ -114,6 +114,20 @@ def __init__(self, cfg):
         else:
             logger.info("OPTIONS -- IBOT -- head shared with DINO")
 
+        # ── Gram (patch-similarity) loss — DINOv3-style ───────────────────────
+        gram_cfg = getattr(cfg, "gram", None)
+        self.do_gram = gram_cfg is not None and getattr(gram_cfg, "use_loss", False)
+        if self.do_gram:
+            logger.info("OPTIONS -- GRAM loss enabled (EMA teacher mode)")
+            logger.info(f"OPTIONS -- GRAM -- loss_weight: {gram_cfg.loss_weight}")
+            logger.info(f"OPTIONS -- GRAM -- img_level: {gram_cfg.img_level}")
+            logger.info(f"OPTIONS -- GRAM -- normalized: {gram_cfg.normalized}")
+            self.gram_loss_fn = GramLoss(
+                apply_norm=gram_cfg.normalized,
+                remove_neg=gram_cfg.remove_neg,
+                remove_only_teacher_neg=gram_cfg.remove_only_teacher_neg,
+            )
+
         self.need_to_synchronize_fsdp_streams = True
 
         self.student = nn.ModuleDict(student_model_dict)
@@ -228,9 +242,21 @@ def get_teacher_output():
             else:
                 raise NotImplementedError
 
-            return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered
+            # Return raw patch tokens so the Gram loss can use them when enabled.
+            # When do_gram is False this tensor is simply ignored by the caller.
+            gram_teacher_patch_tokens = ibot_teacher_patch_tokens if self.do_gram else None
+
+            return (
+                teacher_dino_softmaxed_centered_list,
+                masked_teacher_ibot_softmaxed_centered,
+                gram_teacher_patch_tokens,
+            )
 
-        teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output()
+        (
+            teacher_dino_softmaxed_centered_list,
+            masked_teacher_ibot_softmaxed_centered,
+            gram_teacher_patch_tokens,
+        ) = get_teacher_output()
         reshard_fsdp_model(self.teacher)
 
         loss_dict = {}
@@ -355,6 +381,21 @@ def get_teacher_output():
         # accumulate loss
         loss_accumulator += self.ibot_loss_weight * ibot_patch_loss
 
+        # ── Gram (patch-similarity) loss ─────────────────────────────────────
+        if self.do_gram:
+            gram_cfg = self.cfg.gram
+            # Use full (unmasked) global crop patch tokens from both student and teacher.
+            # Teacher tokens come from get_teacher_output() under @no_grad so they
+            # already have requires_grad=False.
+            student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"]
+            gram_loss_val = gram_cfg.loss_weight * self.gram_loss_fn(
+                student_patch_tokens,
+                gram_teacher_patch_tokens,
+                img_level=gram_cfg.img_level,
+            )
+            loss_accumulator += gram_loss_val
+            loss_dict["gram_loss"] = gram_loss_val.detach()
+
         self.backprop_loss(loss_accumulator)
 
         self.fsdp_synchronize_streams()
diff --git a/dinov2/train/train.py b/dinov2/train/train.py
index a824e63..40f13f6 100644
--- a/dinov2/train/train.py
+++ b/dinov2/train/train.py
@@ -290,6 +290,85 @@ def _mlp_kind(block):
         student_backbone.norm.bias.copy_(model_pretrained.norm.bias)
 
 
+def _load_from_external_checkpoint(cfg, model):
+    """Warm-start student **and** teacher backbones from an external checkpoint.
+
+    Supports three checkpoint formats that naturally arise in practice:
+
+    1. **OpenMidnight / DINOv3 teacher checkpoint** — saved by ``do_test()``:
+       ``{"teacher": {"backbone.<name>": ..., "dino_head.<name>": ..., ...}}``
+
+    2. **Raw backbone state-dict** — produced by frameworks that save only the
+       backbone (e.g. ``torch.save(model.state_dict(), path)``):
+       ``{"model": {"<name>": ...}}`` *or* ``{"backbone": {"<name>": ...}}``
+
+    3. **Flat state-dict** — bare ``{<name>: tensor}`` mapping without any
+       nesting prefix (the format Meta uses for released ViT checkpoints).
+
+    In all cases the backbone weights are loaded with ``strict=False`` so that
+    architecture mismatches (e.g. loading a DINOv3 checkpoint into a model
+    without RoPE) are handled gracefully — matching parameters are copied and
+    unexpected keys are simply skipped with a warning.
+
+    Both the student and teacher backbone are initialised from the same
+    checkpoint. The DINO / iBOT head weights that may be present in a teacher
+    checkpoint are **not** loaded here; the heads are always freshly
+    initialised.
+
+    Config: ``train.pretrained_weights`` — path to the checkpoint file.
+    """
+    path = cfg.train.pretrained_weights
+    if not path:
+        return
+    logger.info("Loading external backbone checkpoint from: %s", path)
+    raw = torch.load(path, map_location="cpu")
+
+    # ── Detect format and extract backbone state-dict ─────────────────────────
+    if isinstance(raw, dict):
+        if "teacher" in raw:
+            # OpenMidnight / DINOv3 teacher checkpoint
+            sd = raw["teacher"]
+            # strip leading "backbone." prefix → plain backbone state-dict
+            sd = {k[len("backbone."):]: v for k, v in sd.items() if k.startswith("backbone.")}
+            logger.info("Detected teacher-checkpoint format; extracted backbone weights.")
+        elif "model" in raw:
+            sd = raw["model"]
+            logger.info("Detected 'model' key format.")
+        elif "backbone" in raw:
+            sd = raw["backbone"]
+            logger.info("Detected 'backbone' key format.")
+        else:
+            # Assume flat state-dict
+            sd = raw
+            logger.info("Detected flat state-dict format.")
+    else:
+        raise ValueError(f"Unrecognised checkpoint type: {type(raw)}")
+
+    student_backbone = model.student.backbone
+    teacher_backbone = model.teacher.backbone
+
+    missing_s, unexpected_s = student_backbone.load_state_dict(sd, strict=False)
+    missing_t, unexpected_t = teacher_backbone.load_state_dict(sd, strict=False)
+
+    if missing_s:
+        logger.warning(
+            "External checkpoint: %d missing keys in student backbone (e.g. %s)",
+            len(missing_s),
+            missing_s[:3],
+        )
+    if unexpected_s:
+        logger.warning(
+            "External checkpoint: %d unexpected keys in student backbone (e.g. %s)",
+            len(unexpected_s),
+            unexpected_s[:3],
+        )
+    logger.info(
+        "External checkpoint loaded into student and teacher backbones "
+        "(%d keys matched).",
+        len(sd) - len(unexpected_s),
+    )
+
+
 def _freeze_student_backbone_except_last_n(cfg, model):
     n_unfrozen = cfg.train.unfreeze_last_n_blocks
     student_backbone = model.student.backbone
@@ -1262,9 +1341,12 @@ def main(args):
     cfg = setup(args)
     print(cfg)
     model = SSLMetaArch(cfg).to(torch.device("cuda"))
-    #Load model here from pretrained.
     if cfg.train.use_pretrained:
+        # Load Meta's released DINOv2 weights via torch.hub
        _load_pretrained_backbone(cfg, model)
+    elif getattr(cfg.train, "pretrained_weights", ""):
+        # Load from an external checkpoint (local teacher ckpt, DINOv3 backbone, etc.)
+        _load_from_external_checkpoint(cfg, model)
 
     _freeze_student_backbone_except_last_n(cfg, model)
     model.prepare_for_distributed_training()
diff --git a/run_dinov3init.sh b/run_dinov3init.sh
new file mode 100644
index 0000000..ab610d0
--- /dev/null
+++ b/run_dinov3init.sh
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+# run_dinov3init.sh — fine-tune from an external / DINOv3-format checkpoint.
+#
+# Use this when your starting weights come from:
+#   • A prior OpenMidnight teacher_checkpoint.pth
+#   • A path-fm-dinov3 (MedARC-AI) backbone checkpoint
+#   • Any flat backbone .pth or {'model': ...} format file
+#
+# The loader auto-detects the checkpoint format and loads backbone weights with
+# strict=False, so architecture-specific keys (e.g. DINOv3 RoPE tensors) that
+# are absent from this model are safely skipped.
+#
+# Usage:
+#   STAGE1_CHECKPOINT=/path/to/teacher_checkpoint.pth bash run_dinov3init.sh
+#
+# You MUST set STAGE1_CHECKPOINT.
+
+set -euo pipefail
+
+NUM_GPUS=${NUM_GPUS:-8}
+MASTER_PORT=${MASTER_PORT:-12357}
+
+: "${STAGE1_CHECKPOINT:?Please set STAGE1_CHECKPOINT to the path of your pretrained checkpoint.}"
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+CONFIG_FILE="${SCRIPT_DIR}/dinov2/configs/train/vitg14_reg4_dinov3init.yaml"
+OUTPUT_DIR="${OUTPUT_DIR:-/block/openmidnight-dinov3init-$(date +%Y%m%d_%H%M%S)}"
+
+echo "=== OpenMidnight: external-checkpoint fine-tuning ==="
+echo "  GPUs       : ${NUM_GPUS}"
+echo "  Config     : ${CONFIG_FILE}"
+echo "  Checkpoint : ${STAGE1_CHECKPOINT}"
+echo "  Output dir : ${OUTPUT_DIR}"
+echo "======================================================"
+
+mkdir -p "${OUTPUT_DIR}"
+
+torchrun \
+  --nproc_per_node="${NUM_GPUS}" \
+  --master_port="${MASTER_PORT}" \
+  "${SCRIPT_DIR}/dinov2/train/train.py" \
+  --config-file "${CONFIG_FILE}" \
+  --no-resume \
+  train.output_dir="${OUTPUT_DIR}" \
+  train.pretrained_weights="${STAGE1_CHECKPOINT}"
diff --git a/run_gram.sh b/run_gram.sh
new file mode 100644
index 0000000..7e2d062
--- /dev/null
+++ b/run_gram.sh
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+# run_gram.sh — launch ViT-G/14 fine-tuning WITH the DINOv3-style Gram loss.
+#
+# The Gram loss adds a patch-similarity MSE term (dinov2/loss/gram_loss.py) on
+# top of the standard DINO + iBOT objective. The EMA teacher's patch tokens
+# are used as reference (no separate frozen model needed).
+#
+# Compare results against a baseline run of run_train.sh (vitg14_reg4.yaml) to
+# ablate the effect of the Gram objective.
+#
+# Usage:
+#   bash run_gram.sh
+#
+# Set GRAM_LOSS_WEIGHT to tune the weight of the Gram term (default 0.1).
+
+set -euo pipefail
+
+NUM_GPUS=${NUM_GPUS:-8}
+MASTER_PORT=${MASTER_PORT:-12356}
+GRAM_LOSS_WEIGHT=${GRAM_LOSS_WEIGHT:-0.1}
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+CONFIG_FILE="${SCRIPT_DIR}/dinov2/configs/train/vitg14_reg4_gram.yaml"
+OUTPUT_DIR="${OUTPUT_DIR:-/block/openmidnight-gram-$(date +%Y%m%d_%H%M%S)}"
+
+echo "=== OpenMidnight: DINOv3-style Gram loss training ==="
+echo "  GPUs            : ${NUM_GPUS}"
+echo "  Config          : ${CONFIG_FILE}"
+echo "  Gram loss weight: ${GRAM_LOSS_WEIGHT}"
+echo "  Output dir      : ${OUTPUT_DIR}"
+echo "======================================================"
+
+mkdir -p "${OUTPUT_DIR}"
+
+torchrun \
+  --nproc_per_node="${NUM_GPUS}" \
+  --master_port="${MASTER_PORT}" \
+  "${SCRIPT_DIR}/dinov2/train/train.py" \
+  --config-file "${CONFIG_FILE}" \
+  --no-resume \
+  train.output_dir="${OUTPUT_DIR}" \
+  gram.loss_weight="${GRAM_LOSS_WEIGHT}"
diff --git a/run_scratch.sh b/run_scratch.sh
new file mode 100644
index 0000000..1c6fe0b
--- /dev/null
+++ b/run_scratch.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+# run_scratch.sh — launch from-scratch ViT-G/14 training (no pretrained init).
+#
+# This script trains the model from random initialisation using the tuned
+# hyperparameters in vitg14_reg4_scratch.yaml (higher LR, longer warmup,
+# layerwise decay 0.9).
+#
+# Usage:
+#   bash run_scratch.sh
+#
+# Adjust NUM_GPUS, OUTPUT_DIR, and data paths as needed.
+
+set -euo pipefail
+
+NUM_GPUS=${NUM_GPUS:-8}
+MASTER_PORT=${MASTER_PORT:-12355}
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+CONFIG_FILE="${SCRIPT_DIR}/dinov2/configs/train/vitg14_reg4_scratch.yaml"
+OUTPUT_DIR="${OUTPUT_DIR:-/block/openmidnight-scratch-$(date +%Y%m%d_%H%M%S)}"
+
+echo "=== OpenMidnight: from-scratch training ==="
+echo "  GPUs       : ${NUM_GPUS}"
+echo "  Config     : ${CONFIG_FILE}"
+echo "  Output dir : ${OUTPUT_DIR}"
+echo "============================================"
+
+mkdir -p "${OUTPUT_DIR}"
+
+torchrun \
+  --nproc_per_node="${NUM_GPUS}" \
+  --master_port="${MASTER_PORT}" \
+  "${SCRIPT_DIR}/dinov2/train/train.py" \
+  --config-file "${CONFIG_FILE}" \
+  --no-resume \
+  train.output_dir="${OUTPUT_DIR}"