diff --git a/README.md b/README.md index 3134779..dae5c02 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,11 @@ and plots. The primary reported result is a verified Split CIFAR-10 suite. ## Project Scope - Config-driven benchmark runner for single runs and multi-method suites. -- Implemented baseline fine-tuning, EWC, reservoir replay, LwF, DER++, and A-GEM. +- Implemented baseline fine-tuning, EWC, reservoir replay, LwF, DER++, A-GEM, + ER-ACE, GDumb, and experimental Calibrated Anchor Replay. +- Includes CAR-component ablations exposed as `bic`, `icarl`, and `x_der_lite`; + these are lightweight protocol baselines, not exact reproductions of the + original papers. - Deterministic synthetic CI benchmark plus real MNIST and CIFAR-10 task streams. - Artifact tracking for config snapshots, metadata, JSONL events, CSV matrices, checkpoints, MLflow runs, aggregate reports, and plots. @@ -68,6 +72,43 @@ cl-bench suite \ --title "Split CIFAR-10 Headline Benchmark" ``` +Run the high-memory GDumb comparison used in the report: + +```bash +cl-bench suite \ + --config-name split_cifar10_headline \ + --methods gdumb \ + --seeds 13 21 \ + --tracking both \ + strategy.replay_buffer_size=10000 \ + strategy.gdumb_epochs=20 +``` + +Run a matched-memory paper suite: + +```bash +cl-bench suite \ + --config-name paper/split_cifar10_full \ + --methods replay derpp er_ace gdumb car bic icarl x_der_lite \ + --seeds 13 21 34 55 89 \ + --memory-budgets 200 500 1000 2000 5000 \ + --tracking both \ + --paper \ + --report-dir docs/paper/assets/split_cifar10_full \ + --title "Split CIFAR-10 Full-Data Paper Protocol" +``` + +Run a focused CAR hyperparameter sweep: + +```bash +cl-bench sweep \ + --config-name paper/split_cifar10_full \ + --method car \ + --study-name car_split_cifar10 \ + --n-trials 50 \ + --tracking both +``` + Use Hydra/OmegaConf-style overrides for quick experiments: ```bash @@ -94,32 +135,43 @@ cl-bench report \ --title "Local continual-learning report" ``` +Generate 
paper-oriented reports and comparison exports: + +```bash +cl-bench report --runs runs --output-dir docs/paper/assets/local --title "Paper report" --paper +cl-bench export --runs runs --output-dir docs/paper/exports --format mammoth +cl-bench export --runs runs --output-dir docs/paper/exports --format avalanche +``` + ## Verified Headline Benchmark Local verification on 2026-05-25 used Python 3.11.15, PyTorch 2.12.0, torchvision 0.27.0, NumPy 2.4.6, Hydra 1.3.2, MLflow 3.12.0, Ruff 0.15.14, pytest 9.0.3, and Matplotlib 3.10.9. -Command: +Commands: ```bash cl-bench suite --config-name split_cifar10_headline --methods baseline ewc replay lwf derpp agem --seeds 13 21 --tracking both --report-dir docs/assets/split_cifar10_headline --title "Split CIFAR-10 Headline Benchmark" +cl-bench suite --config-name split_cifar10_headline --methods gdumb --seeds 13 21 --tracking both strategy.replay_buffer_size=10000 strategy.gdumb_epochs=20 ``` The headline benchmark uses real CIFAR-10 images, five class-incremental tasks, 2,500 training examples per task, 1,000 test examples per task, two seeds, -5 epochs per task, a compact residual CIFAR ConvNet, and a 5,000-example replay -memory budget where applicable. It is a reproducible benchmark, not a paper -leaderboard claim. - -| Method | Average final accuracy | Average forgetting | Mean runtime | -| --- | ---: | ---: | ---: | -| DER++ | 51.15% +- 3.95% | 34.06% +- 4.74% | 578.7s | -| replay | 41.99% +- 0.27% | 45.27% +- 1.73% | 547.4s | -| LwF | 16.53% +- 0.13% | 76.71% +- 0.09% | 224.3s | -| A-GEM | 14.37% +- 0.39% | 79.34% +- 0.96% | 516.3s | -| baseline | 14.06% +- 0.10% | 79.14% +- 1.39% | 181.0s | -| EWC | 12.12% +- 0.74% | 69.20% +- 3.02% | 223.1s | +5 epochs per task, and a compact residual CIFAR ConvNet. The main suite uses a +5,000-example memory budget where applicable; the GDumb row is explicitly marked +as a 10,000-example high-memory comparison. This is a reproducible benchmark, not +a paper leaderboard claim. 
+ +| Method | Memory | Average final accuracy | Average forgetting | Mean runtime | +| --- | ---: | ---: | ---: | ---: | +| GDumb | 10000 | 68.78% +- 0.22% | 12.89% +- 0.71% | 2020.4s | +| DER++ | 5000 | 51.15% +- 3.95% | 34.06% +- 4.74% | 578.7s | +| replay | 5000 | 41.99% +- 0.27% | 45.27% +- 1.73% | 547.4s | +| LwF | 5000 | 16.53% +- 0.13% | 76.71% +- 0.09% | 224.3s | +| A-GEM | 5000 | 14.37% +- 0.39% | 79.34% +- 0.96% | 516.3s | +| baseline | 5000 | 14.06% +- 0.10% | 79.14% +- 1.39% | 181.0s | +| EWC | 5000 | 12.12% +- 0.74% | 69.20% +- 3.02% | 223.1s | Generated report artifacts live in [`docs/assets/split_cifar10_headline`](docs/assets/split_cifar10_headline/README.md). @@ -136,14 +188,18 @@ src/cl_bench/ models.py # linear, MLP, small CNN, and CIFAR residual ConvNet factory reporting.py # run aggregation, leaderboard CSV/JSON, and plots tracking.py # JSON/JSONL/CSV artifacts and optional MLflow logging - strategies/ # baseline, EWC, replay, LwF, DER++, and A-GEM + strategies/ # baseline, EWC, replay, LwF, DER++, A-GEM, ER-ACE, GDumb, CAR configs/ smoke.yaml # fast deterministic CPU benchmark split_mnist_quick.yaml # bounded real MNIST suite for local CPU runs split_mnist.yaml # full five-task MNIST stream split_cifar10_headline.yaml # verified CIFAR-10 benchmark used in the README + paper/ # full-data CIFAR-10, CIFAR-100, and TinyImageNet protocols + method/ # method snippets such as CAR defaults + model/ # model/training snippets such as CIFAR ResNet-18 docs/ BENCHMARK_CARD.md # scope, metrics, limitations, reproducibility + paper/ # manuscript scaffold, claims table, and run checklist tests/ # unit and integration coverage ``` @@ -182,6 +238,13 @@ ignored by git. Curated README assets under `docs/assets/` are intentionally kep - DER++ stores replay logits online and combines current CE, replay CE, and logit-matching losses. - A-GEM projects conflicting gradients against replay-memory reference gradients. 
+- ER-ACE masks the current-task loss so new examples do not directly suppress + old classes, while replay examples still use full cross-entropy. +- GDumb keeps a class-balanced memory and retrains from scratch on stored + exemplars after each task. +- CAR keeps class-balanced exemplars with logit and feature anchors, refreshes + per-class prototypes after each task, and fits a lightweight temperature/bias + calibrator over memory before evaluation. - Best validation checkpoints are deep-copied before restoration to avoid mutable `state_dict` aliasing bugs. - The suite/report layer separates expensive benchmark execution from cheap, diff --git a/configs/method/car.yaml b/configs/method/car.yaml new file mode 100644 index 0000000..bb078e5 --- /dev/null +++ b/configs/method/car.yaml @@ -0,0 +1,13 @@ +method: car +strategy: + replay_buffer_size: 2000 + replay_batch_size: 128 + car_logit_anchor_weight: 0.25 + car_replay_ce_weight: 1.0 + car_feature_anchor_weight: 0.05 + car_prototype_anchor_weight: 0.05 + car_calibration_epochs: 10 + car_calibration_lr: 0.01 + car_calibration_weight_decay: 0.0 + car_replay_augment: true + car_use_current_task_mask: true diff --git a/configs/model/frozen_dinov2.yaml b/configs/model/frozen_dinov2.yaml new file mode 100644 index 0000000..cc9d47a --- /dev/null +++ b/configs/model/frozen_dinov2.yaml @@ -0,0 +1,15 @@ +model: linear +feature_protocol: + backbone: dinov2 + cache_dir: data/feature_cache/dinov2 + freeze_backbone: true + note: "Use cached frozen features for the modern-backbone protocol; extraction is run before cl-bench experiments." 
+training: + optimizer: adamw + learning_rate: 0.001 + weight_decay: 0.0001 + scheduler: cosine + warmup_epochs: 0 + batch_size: 256 + eval_batch_size: 1024 + augment: false diff --git a/configs/model/resnet18_cifar.yaml b/configs/model/resnet18_cifar.yaml new file mode 100644 index 0000000..9b505c9 --- /dev/null +++ b/configs/model/resnet18_cifar.yaml @@ -0,0 +1,12 @@ +model: resnet18_cifar +training: + optimizer: sgd + learning_rate: 0.05 + momentum: 0.9 + weight_decay: 0.0005 + scheduler: cosine + warmup_epochs: 1 + label_smoothing: 0.05 + batch_size: 128 + eval_batch_size: 512 + augment: true diff --git a/configs/paper/split_cifar100_full.yaml b/configs/paper/split_cifar100_full.yaml new file mode 100644 index 0000000..910ba01 --- /dev/null +++ b/configs/paper/split_cifar100_full.yaml @@ -0,0 +1,70 @@ +name: split_cifar100_full +method: car +seed: 13 +device: auto +model: resnet18_cifar +data_dir: data +output_dir: runs +tracking: + mode: both + mlflow_tracking_uri: sqlite:///mlruns/mlflow.db + mlflow_experiment: continual-learning-paper +training: + epochs: 20 + batch_size: 128 + eval_batch_size: 512 + learning_rate: 0.05 + optimizer: sgd + momentum: 0.9 + weight_decay: 0.0005 + scheduler: cosine + warmup_epochs: 1 + label_smoothing: 0.05 + val_fraction: 0.1 + num_workers: 2 + augment: true +strategy: + replay_buffer_size: 2000 + replay_batch_size: 128 + replay_loss_weight: 1.0 + derpp_alpha: 0.1 + derpp_beta: 1.0 + car_logit_anchor_weight: 0.25 + car_replay_ce_weight: 1.0 + car_feature_anchor_weight: 0.05 + car_prototype_anchor_weight: 0.05 + car_calibration_epochs: 10 + car_calibration_lr: 0.01 + car_replay_augment: true + car_use_current_task_mask: true +tasks: + - name: cifar100_00_09 + dataset: cifar100 + classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + - name: cifar100_10_19 + dataset: cifar100 + classes: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + - name: cifar100_20_29 + dataset: cifar100 + classes: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29] + - name: 
cifar100_30_39 + dataset: cifar100 + classes: [30, 31, 32, 33, 34, 35, 36, 37, 38, 39] + - name: cifar100_40_49 + dataset: cifar100 + classes: [40, 41, 42, 43, 44, 45, 46, 47, 48, 49] + - name: cifar100_50_59 + dataset: cifar100 + classes: [50, 51, 52, 53, 54, 55, 56, 57, 58, 59] + - name: cifar100_60_69 + dataset: cifar100 + classes: [60, 61, 62, 63, 64, 65, 66, 67, 68, 69] + - name: cifar100_70_79 + dataset: cifar100 + classes: [70, 71, 72, 73, 74, 75, 76, 77, 78, 79] + - name: cifar100_80_89 + dataset: cifar100 + classes: [80, 81, 82, 83, 84, 85, 86, 87, 88, 89] + - name: cifar100_90_99 + dataset: cifar100 + classes: [90, 91, 92, 93, 94, 95, 96, 97, 98, 99] diff --git a/configs/paper/split_cifar10_full.yaml b/configs/paper/split_cifar10_full.yaml new file mode 100644 index 0000000..afd7d8c --- /dev/null +++ b/configs/paper/split_cifar10_full.yaml @@ -0,0 +1,56 @@ +name: split_cifar10_full +method: car +seed: 13 +device: auto +model: resnet18_cifar +data_dir: data +output_dir: runs +tracking: + mode: both + mlflow_tracking_uri: sqlite:///mlruns/mlflow.db + mlflow_experiment: continual-learning-paper +training: + epochs: 20 + batch_size: 128 + eval_batch_size: 512 + learning_rate: 0.05 + optimizer: sgd + momentum: 0.9 + weight_decay: 0.0005 + scheduler: cosine + warmup_epochs: 1 + label_smoothing: 0.05 + val_fraction: 0.1 + num_workers: 2 + augment: true +strategy: + replay_buffer_size: 2000 + replay_batch_size: 128 + replay_loss_weight: 1.0 + derpp_alpha: 0.1 + derpp_beta: 1.0 + car_logit_anchor_weight: 0.25 + car_replay_ce_weight: 1.0 + car_feature_anchor_weight: 0.05 + car_prototype_anchor_weight: 0.05 + car_calibration_epochs: 10 + car_calibration_lr: 0.01 + car_calibration_weight_decay: 0.0 + car_replay_augment: true + car_use_current_task_mask: true +tasks: + - name: cifar10_airplane_automobile + dataset: cifar10 + classes: [0, 1] + - name: cifar10_bird_cat + dataset: cifar10 + classes: [2, 3] + - name: cifar10_deer_dog + dataset: cifar10 + classes: [4, 5] + 
- name: cifar10_frog_horse + dataset: cifar10 + classes: [6, 7] + - name: cifar10_ship_truck + dataset: cifar10 + classes: [8, 9] diff --git a/configs/paper/split_tinyimagenet.yaml b/configs/paper/split_tinyimagenet.yaml new file mode 100644 index 0000000..a725571 --- /dev/null +++ b/configs/paper/split_tinyimagenet.yaml @@ -0,0 +1,100 @@ +name: split_tinyimagenet +method: car +seed: 13 +device: auto +model: resnet18_cifar +data_dir: data +output_dir: runs +tracking: + mode: both + mlflow_tracking_uri: sqlite:///mlruns/mlflow.db + mlflow_experiment: continual-learning-paper +training: + epochs: 30 + batch_size: 128 + eval_batch_size: 512 + learning_rate: 0.05 + optimizer: sgd + momentum: 0.9 + weight_decay: 0.0005 + scheduler: cosine + warmup_epochs: 1 + label_smoothing: 0.05 + val_fraction: 0.1 + num_workers: 4 + augment: true +strategy: + replay_buffer_size: 5000 + replay_batch_size: 128 + replay_loss_weight: 1.0 + derpp_alpha: 0.1 + derpp_beta: 1.0 + car_logit_anchor_weight: 0.25 + car_replay_ce_weight: 1.0 + car_feature_anchor_weight: 0.05 + car_prototype_anchor_weight: 0.05 + car_calibration_epochs: 10 + car_calibration_lr: 0.01 + car_replay_augment: true + car_use_current_task_mask: true +tasks: + - name: tinyimagenet_000_009 + dataset: tinyimagenet + classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + - name: tinyimagenet_010_019 + dataset: tinyimagenet + classes: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + - name: tinyimagenet_020_029 + dataset: tinyimagenet + classes: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29] + - name: tinyimagenet_030_039 + dataset: tinyimagenet + classes: [30, 31, 32, 33, 34, 35, 36, 37, 38, 39] + - name: tinyimagenet_040_049 + dataset: tinyimagenet + classes: [40, 41, 42, 43, 44, 45, 46, 47, 48, 49] + - name: tinyimagenet_050_059 + dataset: tinyimagenet + classes: [50, 51, 52, 53, 54, 55, 56, 57, 58, 59] + - name: tinyimagenet_060_069 + dataset: tinyimagenet + classes: [60, 61, 62, 63, 64, 65, 66, 67, 68, 69] + - name: tinyimagenet_070_079 + 
dataset: tinyimagenet + classes: [70, 71, 72, 73, 74, 75, 76, 77, 78, 79] + - name: tinyimagenet_080_089 + dataset: tinyimagenet + classes: [80, 81, 82, 83, 84, 85, 86, 87, 88, 89] + - name: tinyimagenet_090_099 + dataset: tinyimagenet + classes: [90, 91, 92, 93, 94, 95, 96, 97, 98, 99] + - name: tinyimagenet_100_109 + dataset: tinyimagenet + classes: [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] + - name: tinyimagenet_110_119 + dataset: tinyimagenet + classes: [110, 111, 112, 113, 114, 115, 116, 117, 118, 119] + - name: tinyimagenet_120_129 + dataset: tinyimagenet + classes: [120, 121, 122, 123, 124, 125, 126, 127, 128, 129] + - name: tinyimagenet_130_139 + dataset: tinyimagenet + classes: [130, 131, 132, 133, 134, 135, 136, 137, 138, 139] + - name: tinyimagenet_140_149 + dataset: tinyimagenet + classes: [140, 141, 142, 143, 144, 145, 146, 147, 148, 149] + - name: tinyimagenet_150_159 + dataset: tinyimagenet + classes: [150, 151, 152, 153, 154, 155, 156, 157, 158, 159] + - name: tinyimagenet_160_169 + dataset: tinyimagenet + classes: [160, 161, 162, 163, 164, 165, 166, 167, 168, 169] + - name: tinyimagenet_170_179 + dataset: tinyimagenet + classes: [170, 171, 172, 173, 174, 175, 176, 177, 178, 179] + - name: tinyimagenet_180_189 + dataset: tinyimagenet + classes: [180, 181, 182, 183, 184, 185, 186, 187, 188, 189] + - name: tinyimagenet_190_199 + dataset: tinyimagenet + classes: [190, 191, 192, 193, 194, 195, 196, 197, 198, 199] diff --git a/configs/smoke.yaml b/configs/smoke.yaml index 2cdc2e4..f11f207 100644 --- a/configs/smoke.yaml +++ b/configs/smoke.yaml @@ -20,6 +20,7 @@ strategy: replay_loss_weight: 1.0 lwf_alpha: 0.5 lwf_temperature: 2.0 + gdumb_epochs: 3 tasks: - name: synthetic_0_1 dataset: synthetic diff --git a/configs/split_cifar10_headline.yaml b/configs/split_cifar10_headline.yaml index c4c1831..2e18e9b 100644 --- a/configs/split_cifar10_headline.yaml +++ b/configs/split_cifar10_headline.yaml @@ -28,6 +28,7 @@ strategy: derpp_alpha: 0.1 
derpp_beta: 2.0 agem_memory_batch_size: 256 + gdumb_epochs: 20 tasks: - name: cifar10_airplane_automobile dataset: cifar10 diff --git a/docs/BENCHMARK_CARD.md b/docs/BENCHMARK_CARD.md index d17bbe6..348efd1 100644 --- a/docs/BENCHMARK_CARD.md +++ b/docs/BENCHMARK_CARD.md @@ -14,6 +14,13 @@ tasks after every training step. - Learning without Forgetting with temperature-scaled distillation. - DER++ with online replay-logit storage. - A-GEM with replay-memory gradient projection. +- ER-ACE with asymmetric current-task cross-entropy plus replay. +- GDumb with class-balanced memory and from-scratch memory training. +- Calibrated Anchor Replay with balanced exemplars, logit anchors, feature + anchors, per-class prototypes, and a post-task calibration head. +- `bic`, `icarl`, and `x_der_lite` are lightweight protocol baselines built from + the CAR components for ablation and comparison; they are not drop-in + reproductions of the original papers. ## Datasets @@ -23,6 +30,10 @@ tasks after every training step. - `split_mnist`: full five-task MNIST stream for longer local experiments. - `split_cifar10_headline`: real CIFAR-10 images, five class-incremental tasks, two verification seeds, and a compact residual ConvNet. +- `paper/split_cifar10_full`: full Split CIFAR-10 protocol for paper runs. +- `paper/split_cifar100_full`: full Split CIFAR-100 protocol for paper runs. +- `paper/split_tinyimagenet`: Split TinyImageNet protocol; requires the dataset + to be downloaded into the configured `data_dir`. ## Metrics @@ -43,3 +54,7 @@ into a leaderboard plus report plots. The Split CIFAR-10 reported result is designed for local reproducibility and engineering verification. It is not a leaderboard claim. Serious research comparisons should increase seeds, epochs, memory budgets, and dataset coverage. +The GDumb comparison uses a larger memory budget than the main 5,000-example +suite and should be read as a high-memory result rather than a same-budget claim. 
+Paper claims must be made only from matched memory budgets, matched model +families, full protocol configs, and multi-seed confidence intervals. diff --git a/docs/assets/split_cifar10_headline/README.md b/docs/assets/split_cifar10_headline/README.md index 27f120e..5e58987 100644 --- a/docs/assets/split_cifar10_headline/README.md +++ b/docs/assets/split_cifar10_headline/README.md @@ -1,19 +1,20 @@ -# Split CIFAR-10 Headline Benchmark +# Split CIFAR-10 Benchmark Benchmarks: `split_cifar10_headline` -Runs: `12` +Runs: `14` Tasks per run: `5` ## Leaderboard -| Method | Runs | Seeds | Final accuracy | Forgetting | Backward transfer | Mean runtime | -| --- | ---: | --- | ---: | ---: | ---: | ---: | -| derpp | 2 | 13,21 | 51.15 +- 3.95% | 34.06 +- 4.74% | -34.06% | 578.7s | -| replay | 2 | 13,21 | 41.99 +- 0.27% | 45.27 +- 1.73% | -45.27% | 547.4s | -| lwf | 2 | 13,21 | 16.53 +- 0.13% | 76.71 +- 0.09% | -76.71% | 224.3s | -| agem | 2 | 13,21 | 14.37 +- 0.39% | 79.34 +- 0.96% | -79.34% | 516.3s | -| baseline | 2 | 13,21 | 14.06 +- 0.10% | 79.14 +- 1.39% | -79.14% | 181.0s | -| ewc | 2 | 13,21 | 12.12 +- 0.74% | 69.20 +- 3.02% | -69.20% | 223.1s | +| Method | Runs | Seeds | Memory | Final accuracy | Forgetting | Backward transfer | Mean runtime | +| --- | ---: | --- | ---: | ---: | ---: | ---: | ---: | +| gdumb | 2 | 13,21 | 10000 | 68.78 +- 0.22% | 12.89 +- 0.71% | -11.93% | 2020.4s | +| derpp | 2 | 13,21 | 5000 | 51.15 +- 3.95% | 34.06 +- 4.74% | -34.06% | 578.7s | +| replay | 2 | 13,21 | 5000 | 41.99 +- 0.27% | 45.27 +- 1.73% | -45.27% | 547.4s | +| lwf | 2 | 13,21 | 5000 | 16.53 +- 0.13% | 76.71 +- 0.09% | -76.71% | 224.3s | +| agem | 2 | 13,21 | 5000 | 14.37 +- 0.39% | 79.34 +- 0.96% | -79.34% | 516.3s | +| baseline | 2 | 13,21 | 5000 | 14.06 +- 0.10% | 79.14 +- 1.39% | -79.14% | 181.0s | +| ewc | 2 | 13,21 | 5000 | 12.12 +- 0.74% | 69.20 +- 3.02% | -69.20% | 223.1s | ## Plots @@ -23,6 +24,10 @@ Tasks per run: `5` ![accuracy_matrices](accuracy_matrices.png) +## 
Protocol Notes + +Rows are aggregated by method and replay-memory budget. Compare rows with the same memory budget and model family for matched-protocol claims. + ## Source Runs | Method | Seed | Run directory | @@ -35,6 +40,8 @@ Tasks per run: `5` | derpp | 21 | `runs/split_cifar10_headline_derpp_20260525T140155Z` | | ewc | 13 | `runs/split_cifar10_headline_ewc_20260525T130553Z` | | ewc | 21 | `runs/split_cifar10_headline_ewc_20260525T134413Z` | +| gdumb | 13 | `runs/split_cifar10_headline_gdumb_20260525T160424Z` | +| gdumb | 21 | `runs/split_cifar10_headline_gdumb_20260525T163726Z` | | lwf | 13 | `runs/split_cifar10_headline_lwf_20260525T131842Z` | | lwf | 21 | `runs/split_cifar10_headline_lwf_20260525T135744Z` | | replay | 13 | `runs/split_cifar10_headline_replay_20260525T130954Z` | diff --git a/docs/assets/split_cifar10_headline/accuracy_matrices.png b/docs/assets/split_cifar10_headline/accuracy_matrices.png index 2db9218..65cf932 100644 Binary files a/docs/assets/split_cifar10_headline/accuracy_matrices.png and b/docs/assets/split_cifar10_headline/accuracy_matrices.png differ diff --git a/docs/assets/split_cifar10_headline/leaderboard.csv b/docs/assets/split_cifar10_headline/leaderboard.csv index 17dc817..0112104 100644 --- a/docs/assets/split_cifar10_headline/leaderboard.csv +++ b/docs/assets/split_cifar10_headline/leaderboard.csv @@ -1,7 +1,8 @@ -method,runs,seeds,average_final_accuracy_mean,average_final_accuracy_std,average_learning_accuracy_mean,average_forgetting_mean,average_forgetting_std,backward_transfer_mean,runtime_seconds_mean -derpp,2,"13,21",51.150000000000006,3.9499999999999993,78.39999999999999,34.0625,4.737499999999999,-34.0625,578.7151509795076 -replay,2,"13,21",41.99000000000001,0.2699999999999996,78.21,45.275,1.7250000000000014,-45.275,547.4045557500067 -lwf,2,"13,21",16.53,0.13000000000000078,77.9,76.7125,0.08749999999999858,-76.7125,224.33408429198607 
-agem,2,"13,21",14.370000000000001,0.39000000000000057,77.84,79.3375,0.9624999999999986,-79.3375,516.2666623959958 -baseline,2,"13,21",14.059999999999999,0.10000000000000053,77.37,79.13749999999999,1.3874999999999957,-79.13749999999999,180.97890222900605 -ewc,2,"13,21",12.12,0.7400000000000002,67.48000000000002,69.19999999999999,3.0249999999999986,-69.19999999999999,223.06707881249895 +method,runs,seeds,protocol_key,models,average_final_accuracy_mean,average_final_accuracy_std,average_learning_accuracy_mean,average_forgetting_mean,average_forgetting_std,backward_transfer_mean,runtime_seconds_mean,memory_budget_mean +gdumb,2,"13,21",gdumb@memory10000,cifar_convnet,68.78,0.21999999999999886,78.32,12.8875,0.7125000000000012,-11.925,2020.4386042919941,10000.0 +derpp,2,"13,21",derpp@memory5000,cifar_convnet,51.150000000000006,3.9499999999999993,78.39999999999999,34.0625,4.737499999999999,-34.0625,578.7151509795076,5000.0 +replay,2,"13,21",replay@memory5000,cifar_convnet,41.99000000000001,0.2699999999999996,78.21,45.275,1.7250000000000014,-45.275,547.4045557500067,5000.0 +lwf,2,"13,21",lwf@memory5000,cifar_convnet,16.53,0.13000000000000078,77.9,76.7125,0.08749999999999858,-76.7125,224.33408429198607,5000.0 +agem,2,"13,21",agem@memory5000,cifar_convnet,14.370000000000001,0.39000000000000057,77.84,79.3375,0.9624999999999986,-79.3375,516.2666623959958,5000.0 +baseline,2,"13,21",baseline@memory5000,cifar_convnet,14.059999999999999,0.10000000000000053,77.37,79.13749999999999,1.3874999999999957,-79.13749999999999,180.97890222900605,5000.0 +ewc,2,"13,21",ewc@memory5000,cifar_convnet,12.12,0.7400000000000002,67.48000000000002,69.19999999999999,3.0249999999999986,-69.19999999999999,223.06707881249895,5000.0 diff --git a/docs/assets/split_cifar10_headline/leaderboard.png b/docs/assets/split_cifar10_headline/leaderboard.png index 49f4e7d..9288bfe 100644 Binary files a/docs/assets/split_cifar10_headline/leaderboard.png and b/docs/assets/split_cifar10_headline/leaderboard.png differ 
diff --git a/docs/assets/split_cifar10_headline/retention_curves.png b/docs/assets/split_cifar10_headline/retention_curves.png index 99c0277..b647afa 100644 Binary files a/docs/assets/split_cifar10_headline/retention_curves.png and b/docs/assets/split_cifar10_headline/retention_curves.png differ diff --git a/docs/assets/split_cifar10_headline/summary.json b/docs/assets/split_cifar10_headline/summary.json index f5e8421..0894866 100644 --- a/docs/assets/split_cifar10_headline/summary.json +++ b/docs/assets/split_cifar10_headline/summary.json @@ -2,8 +2,23 @@ "benchmarks": [ "split_cifar10_headline" ], - "generated_at_utc": "2026-05-25T14:20:17.367419+00:00", + "generated_at_utc": "2026-05-25T17:50:27.144363+00:00", "leaderboard": [ + { + "average_final_accuracy_mean": 68.78, + "average_final_accuracy_std": 0.21999999999999886, + "average_forgetting_mean": 12.8875, + "average_forgetting_std": 0.7125000000000012, + "average_learning_accuracy_mean": 78.32, + "backward_transfer_mean": -11.925, + "memory_budget_mean": 10000.0, + "method": "gdumb", + "models": "cifar_convnet", + "protocol_key": "gdumb@memory10000", + "runs": 2, + "runtime_seconds_mean": 2020.4386042919941, + "seeds": "13,21" + }, { "average_final_accuracy_mean": 51.150000000000006, "average_final_accuracy_std": 3.9499999999999993, @@ -11,7 +26,10 @@ "average_forgetting_std": 4.737499999999999, "average_learning_accuracy_mean": 78.39999999999999, "backward_transfer_mean": -34.0625, + "memory_budget_mean": 5000.0, "method": "derpp", + "models": "cifar_convnet", + "protocol_key": "derpp@memory5000", "runs": 2, "runtime_seconds_mean": 578.7151509795076, "seeds": "13,21" @@ -23,7 +41,10 @@ "average_forgetting_std": 1.7250000000000014, "average_learning_accuracy_mean": 78.21, "backward_transfer_mean": -45.275, + "memory_budget_mean": 5000.0, "method": "replay", + "models": "cifar_convnet", + "protocol_key": "replay@memory5000", "runs": 2, "runtime_seconds_mean": 547.4045557500067, "seeds": "13,21" @@ -35,7 
+56,10 @@ "average_forgetting_std": 0.08749999999999858, "average_learning_accuracy_mean": 77.9, "backward_transfer_mean": -76.7125, + "memory_budget_mean": 5000.0, "method": "lwf", + "models": "cifar_convnet", + "protocol_key": "lwf@memory5000", "runs": 2, "runtime_seconds_mean": 224.33408429198607, "seeds": "13,21" @@ -47,7 +71,10 @@ "average_forgetting_std": 0.9624999999999986, "average_learning_accuracy_mean": 77.84, "backward_transfer_mean": -79.3375, + "memory_budget_mean": 5000.0, "method": "agem", + "models": "cifar_convnet", + "protocol_key": "agem@memory5000", "runs": 2, "runtime_seconds_mean": 516.2666623959958, "seeds": "13,21" @@ -59,7 +86,10 @@ "average_forgetting_std": 1.3874999999999957, "average_learning_accuracy_mean": 77.37, "backward_transfer_mean": -79.13749999999999, + "memory_budget_mean": 5000.0, "method": "baseline", + "models": "cifar_convnet", + "protocol_key": "baseline@memory5000", "runs": 2, "runtime_seconds_mean": 180.97890222900605, "seeds": "13,21" @@ -71,13 +101,16 @@ "average_forgetting_std": 3.0249999999999986, "average_learning_accuracy_mean": 67.48000000000002, "backward_transfer_mean": -69.19999999999999, + "memory_budget_mean": 5000.0, "method": "ewc", + "models": "cifar_convnet", + "protocol_key": "ewc@memory5000", "runs": 2, "runtime_seconds_mean": 223.06707881249895, "seeds": "13,21" } ], - "num_runs": 12, + "num_runs": 14, "runs": [ { "benchmark": "split_cifar10_headline", @@ -295,6 +328,62 @@ "cifar10_ship_truck" ] }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "eea83cbce1406e91f4c2929c8467b70506f28e4a", + "method": "gdumb", + "run_dir": "runs/split_cifar10_headline_gdumb_20260525T160424Z", + "runtime_seconds": 1971.4124877499999, + "seed": 13, + "summary": { + "average_final_accuracy": 69.0, + "average_forgetting": 12.174999999999999, + "average_learning_accuracy": 77.2, + "backward_transfer": -10.25, + "gdumb_epochs": 20, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + 
"replay_buffer_size": 10000, + "runtime_seconds": 1971.4124877499999, + "seed": 13 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "eea83cbce1406e91f4c2929c8467b70506f28e4a", + "method": "gdumb", + "run_dir": "runs/split_cifar10_headline_gdumb_20260525T163726Z", + "runtime_seconds": 2069.4647208339884, + "seed": 21, + "summary": { + "average_final_accuracy": 68.56, + "average_forgetting": 13.600000000000001, + "average_learning_accuracy": 79.44, + "backward_transfer": -13.600000000000001, + "gdumb_epochs": 20, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 10000, + "runtime_seconds": 2069.4647208339884, + "seed": 21 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, { "benchmark": "split_cifar10_headline", "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", @@ -404,5 +493,5 @@ ] } ], - "title": "Split CIFAR-10 Headline Benchmark" + "title": "Split CIFAR-10 Benchmark" } diff --git a/docs/paper/README.md b/docs/paper/README.md new file mode 100644 index 0000000..225d3d1 --- /dev/null +++ b/docs/paper/README.md @@ -0,0 +1,25 @@ +# Calibrated Anchor Replay Paper Scaffold + +This directory contains the manuscript scaffold and reproducibility checklist for +turning the benchmark into a research paper. + +## Research Claim Gate + +The paper may claim a memory-accuracy Pareto improvement only if CAR beats DER++, +ER-ACE, and replay at matched memory budgets on at least two full-data protocols. +Rows that differ in memory budget, backbone, epochs, task order, or seed count +must be described as protocol variants rather than direct wins. 
+
+## Primary Experiment Commands
+
+```bash
+cl-bench suite --config-name paper/split_cifar10_full --methods replay derpp er_ace gdumb car bic icarl x_der_lite --seeds 13 21 34 55 89 --memory-budgets 200 500 1000 2000 5000 --tracking both --paper --report-dir docs/paper/assets/split_cifar10_full --title "Split CIFAR-10 Full-Data Paper Protocol"
+cl-bench suite --config-name paper/split_cifar100_full --methods replay derpp er_ace gdumb car bic icarl x_der_lite --seeds 13 21 34 55 89 --memory-budgets 200 500 1000 2000 5000 --tracking both --paper --report-dir docs/paper/assets/split_cifar100_full --title "Split CIFAR-100 Full-Data Paper Protocol"
+cl-bench suite --config-name paper/split_tinyimagenet --methods replay derpp er_ace gdumb car bic icarl x_der_lite --seeds 13 21 34 --memory-budgets 500 1000 2000 5000 10000 --tracking both --paper --report-dir docs/paper/assets/split_tinyimagenet --title "Split TinyImageNet Paper Protocol"
+```
+
+## Manuscript Status
+
+- `manuscript.tex`: outline and claims skeleton.
+- `reproducibility_checklist.md`: run and reporting checklist.
+- `claims_table.md`: public claim discipline before any SOTA wording.
diff --git a/docs/paper/claims_table.md b/docs/paper/claims_table.md
new file mode 100644
index 0000000..90e70fa
--- /dev/null
+++ b/docs/paper/claims_table.md
@@ -0,0 +1,9 @@
+# Claims Table
+
+| Claim | Current status | Evidence required |
+| --- | --- | --- |
+| CAR improves the memory-accuracy Pareto frontier | Not yet established | Matched full-data runs across CIFAR-10, CIFAR-100, and TinyImageNet with confidence intervals. |
+| CAR reduces recency bias | Mechanistically plausible | Calibration ablation plus old/new class bias analysis. |
+| CAR reduces representation drift | Mechanistically plausible | Feature-anchor ablation plus prototype drift curves. |
+| CAR is state of the art | Not claimed | Equal-protocol wins over DER++, ER-ACE, GDumb, iCaRL/BiC, and external reference numbers. |
+| Current README headline is a paper result | Not claimed | The existing headline is a local-budget engineering benchmark only. |
diff --git a/docs/paper/manuscript.tex b/docs/paper/manuscript.tex
new file mode 100644
index 0000000..6c9e491
--- /dev/null
+++ b/docs/paper/manuscript.tex
@@ -0,0 +1,51 @@
+\documentclass{article}
+
+\usepackage{booktabs}
+\usepackage{graphicx}
+\usepackage{hyperref}
+
+\title{Calibrated Anchor Replay: Reducing Recency Bias and Representation Drift in Class-Incremental Learning}
+\author{Utkarsh Rajput}
+\date{}
+
+\begin{document}
+\maketitle
+
+\begin{abstract}
+Class-incremental continual learning requires models to acquire new classes
+without overwriting earlier decision boundaries. We study whether replay methods
+can be strengthened with lightweight logit anchors, feature anchors, class
+prototypes, and post-task calibration while preserving fixed memory budgets. The
+central claim is evaluated through memory-accuracy Pareto curves under matched
+Split CIFAR-10, Split CIFAR-100, and TinyImageNet protocols.
+\end{abstract}
+
+\section{Introduction}
+State the plasticity-stability problem, why replay remains a strong baseline,
+and why matched memory budgets are necessary for fair claims.
+
+\section{Method}
+Define Calibrated Anchor Replay. The method stores exemplars, labels, teacher
+logits, penultimate features, and prototype statistics. Training combines current
+cross-entropy, replay cross-entropy, logit matching, feature anchoring, and
+prototype anchoring. After each task, a temperature and bias calibration head is
+fit over balanced memory.
+
+\section{Experimental Protocol}
+Report full-data Split CIFAR-10, Split CIFAR-100, and TinyImageNet. Compare
+baseline, replay, DER++, ER-ACE, A-GEM, GDumb, iCaRL-style replay, BiC-style
+calibration, X-DER-lite, and CAR under matched memory budgets.
+
+\section{Results}
+Insert generated tables and figures from \texttt{docs/paper/assets}. Claims must
+be restricted to matched protocols.
+
+\section{Ablations}
+Remove balanced memory, logit anchors, feature anchors, prototype anchors, and
+calibration one at a time.
+
+\section{Limitations}
+Discuss compute budget, dataset scope, external protocol differences, and any
+cases where CAR fails to beat strong baselines.
+
+\end{document}
diff --git a/docs/paper/reproducibility_checklist.md b/docs/paper/reproducibility_checklist.md
new file mode 100644
index 0000000..fa01988
--- /dev/null
+++ b/docs/paper/reproducibility_checklist.md
@@ -0,0 +1,12 @@
+# Reproducibility Checklist
+
+- Use only full-data paper configs for paper claims.
+- Run matched memory budgets before comparing methods.
+- Use at least five seeds for CIFAR-10 and CIFAR-100.
+- Use at least three seeds for TinyImageNet if compute is tight.
+- Log JSON artifacts and MLflow artifacts for every run.
+- Keep raw `runs/`, `mlruns/`, datasets, and checkpoints out of git.
+- Commit only curated paper assets, tables, and figures.
+- Include per-seed CSVs and confidence intervals in every paper report.
+- Separate classic from-scratch results from frozen-backbone results.
+- Mark external Mammoth, Avalanche, and ContinualAI numbers with protocol caveats.
diff --git a/pyproject.toml b/pyproject.toml
index da07a1a..57e26d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "continual-learning-bench"
-version = "0.3.0"
+version = "0.4.0"
 description = "A PyTorch benchmark framework for reproducible continual-learning experiments."
readme = "README.md" requires-python = ">=3.10" @@ -40,6 +40,7 @@ experiment = [ "hydra-core>=1.3.2", "mlflow>=3.12,<4", "omegaconf>=2.3", + "optuna>=4.0", "pandas>=2.2", ] report = [ diff --git a/src/cl_bench/__init__.py b/src/cl_bench/__init__.py index 3cba1fc..b2313ae 100644 --- a/src/cl_bench/__init__.py +++ b/src/cl_bench/__init__.py @@ -3,4 +3,4 @@ from cl_bench.config import BenchmarkResult, ExperimentConfig, TaskSpec __all__ = ["BenchmarkResult", "ExperimentConfig", "TaskSpec"] -__version__ = "0.3.0" +__version__ = "0.4.0" diff --git a/src/cl_bench/cli.py b/src/cl_bench/cli.py index b4da137..f3d59df 100644 --- a/src/cl_bench/cli.py +++ b/src/cl_bench/cli.py @@ -7,10 +7,23 @@ from cl_bench.config import ExperimentConfig, load_config_with_overrides from cl_bench.experiments import run_experiment -from cl_bench.reporting import collect_runs, write_report +from cl_bench.reporting import collect_runs, write_export, write_report from cl_bench.tracking import MLflowRunLogger -METHODS = ("baseline", "ewc", "replay", "lwf", "derpp", "agem") +METHODS = ( + "baseline", + "ewc", + "replay", + "lwf", + "derpp", + "agem", + "er_ace", + "gdumb", + "car", + "bic", + "icarl", + "x_der_lite", +) def main(argv: Sequence[str] | None = None) -> int: @@ -18,7 +31,7 @@ def main(argv: Sequence[str] | None = None) -> int: args = parser.parse_args(argv) if args.command == "list-configs": - for path in sorted((Path.cwd() / "configs").glob("*.yaml")): + for path in sorted((Path.cwd() / "configs").rglob("*.yaml")): print(path) return 0 @@ -34,19 +47,26 @@ def main(argv: Sequence[str] | None = None) -> int: if args.command == "suite": base_config = apply_cli_overrides(load_cli_config(args), args) seeds = args.seeds or [base_config.seed] + memory_budgets = args.memory_budgets or [base_config.replay_buffer_size] run_dirs: list[Path] = [] - for seed in seeds: - for method in args.methods: - config = replace(base_config, method=method, seed=seed) - result = run_experiment(config) - 
run_dirs.append(result.run_dir) - print( - f"{method} seed={seed}: " - f"{result.summary['average_final_accuracy']:.2f}% final accuracy, " - f"{result.summary['average_forgetting']:.2f}% forgetting " - f"({result.run_dir})" - ) + for memory_budget in memory_budgets: + for seed in seeds: + for method in args.methods: + config = replace( + base_config, + method=method, + seed=seed, + replay_buffer_size=memory_budget, + ) + result = run_experiment(config) + run_dirs.append(result.run_dir) + print( + f"{method} seed={seed} memory={memory_budget}: " + f"{result.summary['average_final_accuracy']:.2f}% final accuracy, " + f"{result.summary['average_forgetting']:.2f}% forgetting " + f"({result.run_dir})" + ) if args.report_dir: records = collect_runs(run_dirs) @@ -55,6 +75,7 @@ def main(argv: Sequence[str] | None = None) -> int: output_dir=args.report_dir, title=args.title or f"{base_config.name} benchmark report", make_plots=not args.no_plots, + paper=args.paper, ) log_suite_report_to_mlflow(base_config, report.report_dir, len(records)) print(f"Report directory: {report.report_dir}") @@ -68,6 +89,7 @@ def main(argv: Sequence[str] | None = None) -> int: output_dir=args.output_dir, title=args.title, make_plots=not args.no_plots, + paper=args.paper, ) print(f"Report directory: {report.report_dir}") print(f"Leaderboard: {report.leaderboard_csv}") @@ -77,6 +99,19 @@ def main(argv: Sequence[str] | None = None) -> int: print(f" {plot}") return 0 + if args.command == "export": + records = collect_runs(args.runs) + paths = write_export(records, args.output_dir, args.format) + print("Exported:") + for path in paths: + print(f" {path}") + return 0 + + if args.command == "sweep": + base_config = apply_cli_overrides(load_cli_config(args), args) + run_sweep(base_config, args) + return 0 + parser.print_help() return 1 @@ -101,9 +136,11 @@ def build_parser() -> argparse.ArgumentParser: add_config_arguments(suite_parser) suite_parser.add_argument("--methods", nargs="+", choices=METHODS, 
default=list(METHODS)) suite_parser.add_argument("--seeds", nargs="+", type=int) + suite_parser.add_argument("--memory-budgets", nargs="+", type=int) suite_parser.add_argument("--report-dir") suite_parser.add_argument("--title") suite_parser.add_argument("--no-plots", action="store_true") + suite_parser.add_argument("--paper", action="store_true") add_runtime_overrides(suite_parser) suite_parser.add_argument("overrides", nargs="*", help="Hydra/OmegaConf key=value overrides.") @@ -115,6 +152,28 @@ def build_parser() -> argparse.ArgumentParser: report_parser.add_argument("--output-dir", required=True) report_parser.add_argument("--title", default="Continual-learning benchmark report") report_parser.add_argument("--no-plots", action="store_true") + report_parser.add_argument("--paper", action="store_true") + + export_parser = subparsers.add_parser( + "export", + help="Export run summaries in comparison-friendly CSV/JSON formats.", + ) + export_parser.add_argument("--runs", nargs="+", required=True) + export_parser.add_argument("--output-dir", required=True) + export_parser.add_argument("--format", choices=["csv", "mammoth", "avalanche"], required=True) + + sweep_parser = subparsers.add_parser( + "sweep", + help="Run an Optuna hyperparameter sweep for a method/config.", + ) + add_config_arguments(sweep_parser) + sweep_parser.add_argument("--method", choices=METHODS, default="car") + sweep_parser.add_argument("--study-name", required=True) + sweep_parser.add_argument("--n-trials", type=int, default=20) + sweep_parser.add_argument("--storage", default="sqlite:///mlruns/optuna.db") + sweep_parser.add_argument("--direction", choices=["maximize", "minimize"], default="maximize") + add_runtime_overrides(sweep_parser) + sweep_parser.add_argument("overrides", nargs="*", help="Hydra/OmegaConf key=value overrides.") subparsers.add_parser("list-configs", help="List YAML configs in ./configs.") return parser @@ -127,7 +186,18 @@ def add_config_arguments(parser: 
argparse.ArgumentParser) -> None: def add_runtime_overrides(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--model", choices=["linear", "mlp", "small_cnn", "cnn"]) + parser.add_argument( + "--model", + choices=[ + "linear", + "mlp", + "small_cnn", + "cnn", + "cifar_convnet", + "resnet18_cifar", + "cifar_resnet18", + ], + ) parser.add_argument("--epochs", type=int) parser.add_argument("--seed", type=int) parser.add_argument("--device") @@ -189,3 +259,48 @@ def log_suite_report_to_mlflow(config: ExperimentConfig, report_dir: Path, run_c } ) logger.log_artifacts(report_dir) + + +def run_sweep(config: ExperimentConfig, args: argparse.Namespace) -> None: + try: + import optuna + except ImportError as exc: + raise RuntimeError( + "The sweep command requires Optuna. Install with: " + 'python -m pip install -e ".[experiment]"' + ) from exc + + study = optuna.create_study( + study_name=args.study_name, + storage=args.storage, + direction=args.direction, + load_if_exists=True, + ) + + def objective(trial: optuna.Trial) -> float: + trial_config = replace( + config, + method=args.method, + learning_rate=trial.suggest_float("learning_rate", 1e-4, 5e-2, log=True), + replay_buffer_size=trial.suggest_categorical( + "replay_buffer_size", + [200, 500, 1000, 2000, 5000], + ), + replay_batch_size=trial.suggest_categorical("replay_batch_size", [32, 64, 128, 256]), + car_logit_anchor_weight=trial.suggest_float("car_logit_anchor_weight", 0.0, 1.0), + car_replay_ce_weight=trial.suggest_float("car_replay_ce_weight", 0.25, 4.0), + car_feature_anchor_weight=trial.suggest_float("car_feature_anchor_weight", 0.0, 0.5), + car_prototype_anchor_weight=trial.suggest_float( + "car_prototype_anchor_weight", + 0.0, + 0.5, + ), + ) + result = run_experiment(trial_config) + return float(result.summary["average_final_accuracy"]) - 0.2 * float( + result.summary["average_forgetting"] + ) + + study.optimize(objective, n_trials=args.n_trials) + print(f"Best value: {study.best_value:.4f}") + 
print(f"Best params: {study.best_params}") diff --git a/src/cl_bench/config.py b/src/cl_bench/config.py index 6dc5d3a..2b8f674 100644 --- a/src/cl_bench/config.py +++ b/src/cl_bench/config.py @@ -18,6 +18,8 @@ class TaskSpec: test_samples_per_class: int | None = None train_limit: int | None = None test_limit: int | None = None + train_feature_cache: str | None = None + test_feature_cache: str | None = None @classmethod def from_dict(cls, raw: dict[str, Any]) -> TaskSpec: @@ -32,6 +34,8 @@ def from_dict(cls, raw: dict[str, Any]) -> TaskSpec: test_samples_per_class=_optional_int(raw.get("test_samples_per_class")), train_limit=_optional_int(raw.get("train_limit")), test_limit=_optional_int(raw.get("test_limit")), + train_feature_cache=_optional_str(raw.get("train_feature_cache")), + test_feature_cache=_optional_str(raw.get("test_feature_cache")), ) @@ -54,6 +58,12 @@ class ExperimentConfig: batch_size: int = 64 eval_batch_size: int = 256 learning_rate: float = 1e-3 + optimizer: str = "adam" + momentum: float = 0.9 + weight_decay: float = 0.0 + scheduler: str = "none" + warmup_epochs: int = 0 + label_smoothing: float = 0.0 val_fraction: float = 0.1 num_workers: int = 0 augment: bool = True @@ -67,6 +77,16 @@ class ExperimentConfig: derpp_alpha: float = 0.5 derpp_beta: float = 1.0 agem_memory_batch_size: int = 64 + gdumb_epochs: int = 20 + car_logit_anchor_weight: float = 0.25 + car_replay_ce_weight: float = 1.0 + car_feature_anchor_weight: float = 0.05 + car_prototype_anchor_weight: float = 0.05 + car_calibration_epochs: int = 10 + car_calibration_lr: float = 0.01 + car_calibration_weight_decay: float = 0.0 + car_replay_augment: bool = True + car_use_current_task_mask: bool = True save_checkpoint: bool = False @classmethod @@ -103,6 +123,12 @@ def from_dict(cls, raw: dict[str, Any]) -> ExperimentConfig: batch_size=int(training.get("batch_size", raw.get("batch_size", 64))), eval_batch_size=int(training.get("eval_batch_size", raw.get("eval_batch_size", 256))), 
learning_rate=float(training.get("learning_rate", raw.get("learning_rate", 1e-3))), + optimizer=str(training.get("optimizer", raw.get("optimizer", "adam"))), + momentum=float(training.get("momentum", raw.get("momentum", 0.9))), + weight_decay=float(training.get("weight_decay", raw.get("weight_decay", 0.0))), + scheduler=str(training.get("scheduler", raw.get("scheduler", "none"))), + warmup_epochs=int(training.get("warmup_epochs", raw.get("warmup_epochs", 0))), + label_smoothing=float(training.get("label_smoothing", raw.get("label_smoothing", 0.0))), val_fraction=float(training.get("val_fraction", raw.get("val_fraction", 0.1))), num_workers=int(training.get("num_workers", raw.get("num_workers", 0))), augment=bool(training.get("augment", raw.get("augment", True))), @@ -124,6 +150,49 @@ def from_dict(cls, raw: dict[str, Any]) -> ExperimentConfig: agem_memory_batch_size=int( strategy.get("agem_memory_batch_size", raw.get("agem_memory_batch_size", 64)) ), + gdumb_epochs=int(strategy.get("gdumb_epochs", raw.get("gdumb_epochs", 20))), + car_logit_anchor_weight=float( + strategy.get( + "car_logit_anchor_weight", + raw.get("car_logit_anchor_weight", 0.25), + ) + ), + car_replay_ce_weight=float( + strategy.get("car_replay_ce_weight", raw.get("car_replay_ce_weight", 1.0)) + ), + car_feature_anchor_weight=float( + strategy.get( + "car_feature_anchor_weight", + raw.get("car_feature_anchor_weight", 0.05), + ) + ), + car_prototype_anchor_weight=float( + strategy.get( + "car_prototype_anchor_weight", + raw.get("car_prototype_anchor_weight", 0.05), + ) + ), + car_calibration_epochs=int( + strategy.get("car_calibration_epochs", raw.get("car_calibration_epochs", 10)) + ), + car_calibration_lr=float( + strategy.get("car_calibration_lr", raw.get("car_calibration_lr", 0.01)) + ), + car_calibration_weight_decay=float( + strategy.get( + "car_calibration_weight_decay", + raw.get("car_calibration_weight_decay", 0.0), + ) + ), + car_replay_augment=bool( + strategy.get("car_replay_augment", 
raw.get("car_replay_augment", True)) + ), + car_use_current_task_mask=bool( + strategy.get( + "car_use_current_task_mask", + raw.get("car_use_current_task_mask", True), + ) + ), save_checkpoint=bool(raw.get("save_checkpoint", False)), ) @@ -206,6 +275,14 @@ def resolve_config_path(source: str | Path) -> Path: if path.exists(): return path + for root in [Path.cwd() / "configs", Path(__file__).resolve().parents[2] / "configs"]: + matches = sorted(root.rglob(name)) + if len(matches) == 1: + return matches[0] + if len(matches) > 1: + joined = ", ".join(str(match) for match in matches) + raise FileExistsError(f"Config name '{source}' is ambiguous: {joined}") + searched = ", ".join(str(root / name) for root in search_roots) raise FileNotFoundError(f"Could not find config '{source}'. Searched: {searched}") @@ -220,3 +297,9 @@ def _optional_int(value: Any) -> int | None: if value is None: return None return int(value) + + +def _optional_str(value: Any) -> str | None: + if value is None: + return None + return str(value) diff --git a/src/cl_bench/datasets.py b/src/cl_bench/datasets.py index 585d1b9..a2eda54 100644 --- a/src/cl_bench/datasets.py +++ b/src/cl_bench/datasets.py @@ -54,6 +54,29 @@ def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: return self.data[index], self.targets[index] +class FeatureCacheDataset(Dataset): + """Dataset backed by a cached tensor feature file.""" + + def __init__(self, cache_path: str | Path): + payload = torch.load(cache_path, map_location="cpu") + if not isinstance(payload, dict) or "features" not in payload or "targets" not in payload: + raise ValueError( + "Feature cache must be a torch-saved mapping with 'features' and 'targets'." 
+ ) + self.data = torch.as_tensor(payload["features"]).float() + self.targets = torch.as_tensor(payload["targets"]).long().tolist() + if self.data.size(0) != len(self.targets): + raise ValueError("Feature cache features and targets have different lengths.") + classes = payload.get("classes") + self.classes = list(classes) if classes is not None else sorted(set(self.targets)) + + def __len__(self) -> int: + return len(self.targets) + + def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: + return self.data[index], int(self.targets[index]) + + def build_task_loaders(config: ExperimentConfig) -> tuple[list[TaskLoaders], tuple[int, ...], int]: task_loaders: list[TaskLoaders] = [] all_classes: set[int] = set() @@ -127,6 +150,20 @@ def _build_datasets( ) return train_dataset, test_dataset, [int(label) for label in task.classes] + if dataset_name in {"feature_cache", "cached_features"}: + if task.train_feature_cache is None or task.test_feature_cache is None: + raise ValueError( + "Feature-cache tasks must define train_feature_cache and test_feature_cache." 
+ ) + train_dataset = FeatureCacheDataset(Path(config.data_dir) / task.train_feature_cache) + test_dataset = FeatureCacheDataset(Path(config.data_dir) / task.test_feature_cache) + classes = _resolve_classes(task, train_dataset) + return ( + _class_subset(train_dataset, classes, task.train_limit, config.seed + task_id), + _class_subset(test_dataset, classes, task.test_limit, config.seed + task_id + 10_000), + classes, + ) + train_dataset = _torchvision_dataset( dataset_name, Path(config.data_dir), @@ -157,28 +194,61 @@ def _torchvision_dataset(dataset_name: str, data_dir: Path, train: bool, augment return tv_datasets.KMNIST(data_dir, train=train, download=True, transform=transform) if dataset_name == "cifar10": return tv_datasets.CIFAR10(data_dir, train=train, download=True, transform=transform) + if dataset_name == "cifar100": + return tv_datasets.CIFAR100(data_dir, train=train, download=True, transform=transform) + if dataset_name == "tinyimagenet": + split = "train" if train else "val" + root = data_dir / "tiny-imagenet-200" / split + if not root.exists(): + raise FileNotFoundError( + "TinyImageNet is not auto-downloaded. Expected ImageFolder layout at " + f"{root}. Download tiny-imagenet-200 under the configured data_dir." 
+ ) + return tv_datasets.ImageFolder(root, transform=transform) raise ValueError(f"Unsupported dataset: {dataset_name}") def _torchvision_transform(dataset_name: str, train: bool, augment: bool) -> transforms.Compose: - if dataset_name == "cifar10": + if dataset_name in {"cifar10", "cifar100", "tinyimagenet"}: steps: list[object] = [] if train and augment: - steps.extend([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]) - steps.extend( - [ - transforms.ToTensor(), - transforms.Normalize( - mean=(0.4914, 0.4822, 0.4465), - std=(0.2470, 0.2435, 0.2616), - ), - ] - ) + size = 64 if dataset_name == "tinyimagenet" else 32 + steps.extend( + [ + transforms.RandomCrop(size, padding=4), + transforms.RandomHorizontalFlip(), + transforms.RandAugment(num_ops=2, magnitude=9), + ] + ) + mean = (0.485, 0.456, 0.406) if dataset_name == "tinyimagenet" else (0.4914, 0.4822, 0.4465) + std = (0.229, 0.224, 0.225) if dataset_name == "tinyimagenet" else (0.2470, 0.2435, 0.2616) + steps.extend([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) + if train and augment: + steps.append(Cutout(size=8 if dataset_name != "tinyimagenet" else 16)) return transforms.Compose(steps) return transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) +class Cutout: + """Randomly mask one square region in a tensor image.""" + + def __init__(self, size: int): + self.size = int(size) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + if image.ndim != 3 or self.size <= 0: + return image + _, height, width = image.shape + cutout_height = min(self.size, height) + cutout_width = min(self.size, width) + top = int(torch.randint(0, height - cutout_height + 1, (1,)).item()) + left = int(torch.randint(0, width - cutout_width + 1, (1,)).item()) + image = image.clone() + image[:, top : top + cutout_height, left : left + cutout_width] = 0.0 + return image + + def _resolve_classes(task: TaskSpec, dataset: Dataset) -> list[int]: if 
task.classes == "all": if not hasattr(dataset, "classes"): diff --git a/src/cl_bench/experiments.py b/src/cl_bench/experiments.py index b400053..3a5e1a9 100644 --- a/src/cl_bench/experiments.py +++ b/src/cl_bench/experiments.py @@ -1,5 +1,6 @@ from __future__ import annotations +import platform import random import time from pathlib import Path @@ -21,6 +22,7 @@ def run_experiment(config: ExperimentConfig, repo_dir: str | Path | None = None) tasks, input_shape, num_classes = build_task_loaders(config) model = get_model(config.model, input_shape=input_shape, num_classes=num_classes).to(device) strategy = create_strategy(config, model, device) + model_parameter_count = sum(parameter.numel() for parameter in model.parameters()) run_dir = create_run_dir(config.output_dir, config.name, config.method) tracker = ExperimentTracker(run_dir) @@ -34,6 +36,11 @@ def run_experiment(config: ExperimentConfig, repo_dir: str | Path | None = None) "seed": config.seed, "device": str(device), "git_commit": commit, + "python": platform.python_version(), + "torch": torch.__version__, + "model_parameter_count": model_parameter_count, + "task_order": [task.name for task in tasks], + "task_classes": [task.classes for task in tasks], } tracker.write_json("run_metadata.json", metadata) mlflow_enabled = config.tracking.lower() in {"mlflow", "both"} @@ -109,8 +116,15 @@ def run_experiment(config: ExperimentConfig, repo_dir: str | Path | None = None) "model": config.model, "replay_buffer_size": config.replay_buffer_size, "replay_batch_size": config.replay_batch_size, + "gdumb_epochs": config.gdumb_epochs, + "optimizer": config.optimizer, + "scheduler": config.scheduler, + "weight_decay": config.weight_decay, + "label_smoothing": config.label_smoothing, + "model_parameter_count": model_parameter_count, } ) + summary.update(strategy.run_summary()) tracker.write_json("accuracy_matrix.json", matrix_to_jsonable(accuracy_matrix)) tracker.write_json("forgetting_matrix.json", 
matrix_to_jsonable(forgetting_matrix)) diff --git a/src/cl_bench/models.py b/src/cl_bench/models.py index 91c8696..8c26776 100644 --- a/src/cl_bench/models.py +++ b/src/cl_bench/models.py @@ -111,6 +111,46 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return self.classifier(torch.flatten(features, 1)) +class CifarResNet18(nn.Module): + """ResNet-18 variant with a CIFAR stem and no initial max-pool.""" + + def __init__(self, input_shape: tuple[int, ...], num_classes: int): + super().__init__() + if len(input_shape) != 3: + raise ValueError(f"CifarResNet18 expects CHW input, got {input_shape}.") + channels, _, _ = input_shape + self.in_channels = 64 + self.stem = nn.Sequential( + nn.Conv2d(channels, 64, kernel_size=3, stride=1, padding=1, bias=False), + _norm(64), + nn.ReLU(inplace=True), + ) + self.layer1 = self._make_layer(64, blocks=2, stride=1) + self.layer2 = self._make_layer(128, blocks=2, stride=2) + self.layer3 = self._make_layer(256, blocks=2, stride=2) + self.layer4 = self._make_layer(512, blocks=2, stride=2) + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + self.features = nn.Sequential( + self.stem, + self.layer1, + self.layer2, + self.layer3, + self.layer4, + self.pool, + ) + self.classifier = nn.Linear(512, num_classes) + + def _make_layer(self, out_channels: int, blocks: int, stride: int) -> nn.Sequential: + layers = [ResidualBlock(self.in_channels, out_channels, stride=stride)] + self.in_channels = out_channels + layers.extend(ResidualBlock(out_channels, out_channels) for _ in range(blocks - 1)) + return nn.Sequential(*layers) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + features = torch.flatten(self.features(inputs), 1) + return self.classifier(features) + + def get_model(model_name: str, input_shape: tuple[int, ...], num_classes: int) -> nn.Module: name = model_name.lower().replace("-", "_") if name == "linear": @@ -119,8 +159,10 @@ def get_model(model_name: str, input_shape: tuple[int, ...], num_classes: int) - return 
MLP(input_shape, num_classes) if name in {"small_cnn", "cnn"}: return SmallCNN(input_shape, num_classes) - if name in {"cifar_convnet", "resnet18_cifar"}: + if name == "cifar_convnet": return CifarConvNet(input_shape, num_classes) + if name in {"resnet18_cifar", "cifar_resnet18"}: + return CifarResNet18(input_shape, num_classes) raise ValueError(f"Unknown model architecture: {model_name}") diff --git a/src/cl_bench/reporting.py b/src/cl_bench/reporting.py index 3aa8cf0..f6d4358 100644 --- a/src/cl_bench/reporting.py +++ b/src/cl_bench/reporting.py @@ -103,12 +103,13 @@ def load_run(metrics_path: str | Path) -> RunRecord: def aggregate_records(records: Sequence[RunRecord]) -> list[dict[str, float | int | str]]: """Aggregate run summaries by method for leaderboard-style reporting.""" - by_method: dict[str, list[RunRecord]] = defaultdict(list) + by_method: dict[tuple[str, int], list[RunRecord]] = defaultdict(list) for record in records: - by_method[record.method].append(record) + memory_budget = int(_metric(record, "replay_buffer_size")) + by_method[(record.method, memory_budget)].append(record) rows: list[dict[str, float | int | str]] = [] - for method, method_records in sorted(by_method.items()): + for (method, memory_budget), method_records in sorted(by_method.items()): final_accuracy = [_metric(record, "average_final_accuracy") for record in method_records] learning_accuracy = [ _metric(record, "average_learning_accuracy") for record in method_records @@ -116,6 +117,10 @@ def aggregate_records(records: Sequence[RunRecord]) -> list[dict[str, float | in forgetting = [_metric(record, "average_forgetting") for record in method_records] backward_transfer = [_metric(record, "backward_transfer") for record in method_records] runtimes = [record.runtime_seconds for record in method_records] + memory_budgets = [_metric(record, "replay_buffer_size") for record in method_records] + models = ",".join( + sorted({str(record.summary.get("model", "")) for record in 
method_records}) + ) seeds = ",".join( str(record.seed) for record in sorted(method_records, key=lambda item: item.seed) ) @@ -125,6 +130,8 @@ def aggregate_records(records: Sequence[RunRecord]) -> list[dict[str, float | in "method": method, "runs": len(method_records), "seeds": seeds, + "protocol_key": f"{method}@memory{memory_budget}", + "models": models, "average_final_accuracy_mean": _mean(final_accuracy), "average_final_accuracy_std": _std(final_accuracy), "average_learning_accuracy_mean": _mean(learning_accuracy), @@ -132,6 +139,7 @@ def aggregate_records(records: Sequence[RunRecord]) -> list[dict[str, float | in "average_forgetting_std": _std(forgetting), "backward_transfer_mean": _mean(backward_transfer), "runtime_seconds_mean": _mean(runtimes), + "memory_budget_mean": _mean(memory_budgets), } ) @@ -143,6 +151,7 @@ def write_report( output_dir: str | Path, title: str, make_plots: bool = True, + paper: bool = False, ) -> ReportArtifacts: """Write CSV, JSON, Markdown, and optional plot artifacts for a benchmark suite.""" @@ -157,7 +166,11 @@ def write_report( plots: list[Path] = [] if make_plots: plots = _write_plots(records, leaderboard, report_dir, title) + if paper: + plots.extend(_write_paper_plots(records, leaderboard, report_dir)) markdown = _write_markdown(report_dir / "README.md", title, records, leaderboard, plots) + if paper: + _write_paper_tables(report_dir, records, leaderboard) return ReportArtifacts( report_dir=report_dir, leaderboard_csv=leaderboard_csv, @@ -172,6 +185,8 @@ def _write_leaderboard_csv(path: Path, rows: Sequence[dict[str, float | int | st "method", "runs", "seeds", + "protocol_key", + "models", "average_final_accuracy_mean", "average_final_accuracy_std", "average_learning_accuracy_mean", @@ -179,6 +194,7 @@ def _write_leaderboard_csv(path: Path, rows: Sequence[dict[str, float | int | st "average_forgetting_std", "backward_transfer_mean", "runtime_seconds_mean", + "memory_budget_mean", ] with path.open("w", encoding="utf-8", 
newline="") as handle: writer = csv.DictWriter(handle, fieldnames=fieldnames) @@ -238,15 +254,16 @@ def _write_markdown( "", "## Leaderboard", "", - "| Method | Runs | Seeds | Final accuracy | Forgetting | Backward transfer | Mean runtime |", - "| --- | ---: | --- | ---: | ---: | ---: | ---: |", + "| Method | Runs | Seeds | Memory | Final accuracy | Forgetting | Backward transfer | Mean runtime |", + "| --- | ---: | --- | ---: | ---: | ---: | ---: | ---: |", ] for row in leaderboard: lines.append( - "| {method} | {runs} | {seeds} | {accuracy} | {forgetting} | {bwt} | {runtime} |".format( + "| {method} | {runs} | {seeds} | {memory:.0f} | {accuracy} | {forgetting} | {bwt} | {runtime} |".format( method=row["method"], runs=row["runs"], seeds=row["seeds"], + memory=float(row["memory_budget_mean"]), accuracy=_format_with_std( float(row["average_final_accuracy_mean"]), float(row["average_final_accuracy_std"]), @@ -268,6 +285,15 @@ def _write_markdown( lines.append(f"![{plot.stem}]({plot.name})") lines.append("") + lines.extend( + [ + "## Protocol Notes", + "", + "Rows are aggregated by method and replay-memory budget. 
Compare rows with the same memory budget and model family for matched-protocol claims.", + "", + ] + ) + lines.extend( [ "## Source Runs", @@ -303,6 +329,144 @@ def _write_plots( return paths +def _write_paper_plots( + records: Sequence[RunRecord], + leaderboard: Sequence[dict[str, float | int | str]], + report_dir: Path, +) -> list[Path]: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + paths = [ + _plot_memory_accuracy_pareto( + plt, + leaderboard, + report_dir / "memory_accuracy_pareto.png", + ), + _plot_memory_forgetting_pareto( + plt, + leaderboard, + report_dir / "memory_forgetting_pareto.png", + ), + _plot_runtime_memory_tradeoff( + plt, + leaderboard, + report_dir / "runtime_memory_accuracy.png", + ), + ] + if any("car_calibration_temperature" in record.summary for record in records): + paths.append(_plot_calibration(plt, records, report_dir / "calibration_temperatures.png")) + plt.close("all") + return paths + + +def _write_paper_tables( + report_dir: Path, + records: Sequence[RunRecord], + leaderboard: Sequence[dict[str, float | int | str]], +) -> None: + per_seed_path = report_dir / "per_seed_results.csv" + with per_seed_path.open("w", encoding="utf-8", newline="") as handle: + fieldnames = [ + "benchmark", + "method", + "seed", + "memory", + "model", + "average_final_accuracy", + "average_forgetting", + "backward_transfer", + "runtime_seconds", + "git_commit", + ] + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for record in records: + writer.writerow( + { + "benchmark": record.benchmark, + "method": record.method, + "seed": record.seed, + "memory": _metric(record, "replay_buffer_size"), + "model": record.summary.get("model", ""), + "average_final_accuracy": _metric(record, "average_final_accuracy"), + "average_forgetting": _metric(record, "average_forgetting"), + "backward_transfer": _metric(record, "backward_transfer"), + "runtime_seconds": record.runtime_seconds, + "git_commit": 
record.git_commit, + } + ) + + latex_lines = [ + r"\begin{tabular}{lrrrr}", + r"\toprule", + r"Method & Memory & Final Acc. & Forgetting & Runtime (s) \\", + r"\midrule", + ] + for row in leaderboard: + latex_lines.append( + "{method} & {memory:.0f} & {acc:.2f} $\\pm$ {acc_std:.2f} & {forget:.2f} $\\pm$ {forget_std:.2f} & {runtime:.1f} \\\\".format( + method=row["method"], + memory=float(row["memory_budget_mean"]), + acc=float(row["average_final_accuracy_mean"]), + acc_std=float(row["average_final_accuracy_std"]), + forget=float(row["average_forgetting_mean"]), + forget_std=float(row["average_forgetting_std"]), + runtime=float(row["runtime_seconds_mean"]), + ) + ) + latex_lines.extend([r"\bottomrule", r"\end{tabular}", ""]) + (report_dir / "leaderboard_table.tex").write_text("\n".join(latex_lines), encoding="utf-8") + + claims_lines = [ + "# Claims Table", + "", + "| Claim | Evidence status | Notes |", + "| --- | --- | --- |", + "| CAR improves the memory-accuracy Pareto frontier | Pending matched full-data runs | Requires equal memory, model, epochs, and seeds. |", + "| High-memory methods outperform naive fine-tuning | Supported if leaderboard contains matched runs | Must not mix budgets in the same claim. |", + "| Results reproduce external libraries | Pending external protocol match | Compare against Mammoth, Avalanche, and ContinualAI only with protocol caveats. 
|", + ] + (report_dir / "claims_table.md").write_text("\n".join(claims_lines) + "\n", encoding="utf-8") + + +def write_export( + records: Sequence[RunRecord], + output_dir: str | Path, + export_format: str, +) -> list[Path]: + output = Path(output_dir) + output.mkdir(parents=True, exist_ok=True) + rows = [ + { + "benchmark": record.benchmark, + "method": record.method, + "seed": record.seed, + "memory": _metric(record, "replay_buffer_size"), + "model": record.summary.get("model", ""), + "average_final_accuracy": _metric(record, "average_final_accuracy"), + "average_forgetting": _metric(record, "average_forgetting"), + "runtime_seconds": record.runtime_seconds, + "run_dir": str(record.run_dir), + "git_commit": record.git_commit, + } + for record in records + ] + prefix = export_format.lower() + csv_path = output / f"{prefix}_runs.csv" + json_path = output / f"{prefix}_runs.json" + with csv_path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()) if rows else []) + writer.writeheader() + writer.writerows(rows) + with json_path.open("w", encoding="utf-8") as handle: + json.dump({"format": export_format, "runs": rows}, handle, indent=2, sort_keys=True) + handle.write("\n") + return [csv_path, json_path] + + def _plot_leaderboard( plt: Any, leaderboard: Sequence[dict[str, float | int | str]], path: Path, title: str ) -> Path: @@ -364,16 +528,17 @@ def _bar_chart( def _plot_retention_curves(plt: Any, records: Sequence[RunRecord], path: Path) -> Path: fig, axis = plt.subplots(figsize=(9.5, 5.5), constrained_layout=True) + max_steps = max(record.accuracy_matrix.shape[0] for record in records) for method, method_records in _records_by_method(records).items(): curves = [] for record in method_records: - curve = [] + curve = [np.nan] * max_steps for step in range(record.accuracy_matrix.shape[0]): seen = record.accuracy_matrix[step, : step + 1] - curve.append(float(np.nanmean(seen))) + curve[step] = 
float(np.nanmean(seen)) curves.append(curve) matrix = np.asarray(curves, dtype=float) - x_values = np.arange(1, matrix.shape[1] + 1) + x_values = np.arange(1, max_steps + 1) mean_curve = np.nanmean(matrix, axis=0) std_curve = np.nanstd(matrix, axis=0) color = _method_color(method) @@ -387,7 +552,7 @@ def _plot_retention_curves(plt: Any, records: Sequence[RunRecord], path: Path) - axis.set_xlabel("After training task") axis.set_ylabel("Mean accuracy on seen tasks (%)") axis.set_ylim(0, 100) - axis.set_xticks(np.arange(1, max(record.accuracy_matrix.shape[0] for record in records) + 1)) + axis.set_xticks(np.arange(1, max_steps + 1)) axis.grid(alpha=0.25) axis.legend(frameon=False) fig.savefig(path, dpi=180, bbox_inches="tight") @@ -436,6 +601,120 @@ def _plot_accuracy_matrices(plt: Any, records: Sequence[RunRecord], path: Path) return path +def _plot_memory_accuracy_pareto( + plt: Any, + leaderboard: Sequence[dict[str, float | int | str]], + path: Path, +) -> Path: + fig, axis = plt.subplots(figsize=(8.5, 5.5), constrained_layout=True) + for row in leaderboard: + method = str(row["method"]) + axis.scatter( + float(row["memory_budget_mean"]), + float(row["average_final_accuracy_mean"]), + s=95, + color=_method_color(method), + edgecolor="#111827", + linewidth=0.8, + ) + axis.text( + float(row["memory_budget_mean"]), + float(row["average_final_accuracy_mean"]) + 0.8, + method, + ha="center", + fontsize=8, + ) + axis.set_title("Memory-accuracy Pareto view", fontsize=13, fontweight="bold") + axis.set_xlabel("Replay memory budget") + axis.set_ylabel("Average final accuracy (%)") + axis.grid(alpha=0.25) + fig.savefig(path, dpi=180, bbox_inches="tight") + return path + + +def _plot_memory_forgetting_pareto( + plt: Any, + leaderboard: Sequence[dict[str, float | int | str]], + path: Path, +) -> Path: + fig, axis = plt.subplots(figsize=(8.5, 5.5), constrained_layout=True) + for row in leaderboard: + method = str(row["method"]) + axis.scatter( + float(row["memory_budget_mean"]), 
+ float(row["average_forgetting_mean"]), + s=95, + color=_method_color(method), + edgecolor="#111827", + linewidth=0.8, + ) + axis.text( + float(row["memory_budget_mean"]), + float(row["average_forgetting_mean"]) + 0.8, + method, + ha="center", + fontsize=8, + ) + axis.set_title("Memory-forgetting Pareto view", fontsize=13, fontweight="bold") + axis.set_xlabel("Replay memory budget") + axis.set_ylabel("Average forgetting (%)") + axis.grid(alpha=0.25) + fig.savefig(path, dpi=180, bbox_inches="tight") + return path + + +def _plot_runtime_memory_tradeoff( + plt: Any, + leaderboard: Sequence[dict[str, float | int | str]], + path: Path, +) -> Path: + fig, axis = plt.subplots(figsize=(8.5, 5.5), constrained_layout=True) + runtimes = [float(row["runtime_seconds_mean"]) for row in leaderboard] + max_runtime = max(runtimes, default=1.0) + for row in leaderboard: + method = str(row["method"]) + axis.scatter( + float(row["memory_budget_mean"]), + float(row["average_final_accuracy_mean"]), + s=60 + 240 * float(row["runtime_seconds_mean"]) / max_runtime, + color=_method_color(method), + alpha=0.82, + edgecolor="#111827", + linewidth=0.8, + ) + axis.text( + float(row["memory_budget_mean"]), + float(row["average_final_accuracy_mean"]) + 0.8, + method, + ha="center", + fontsize=8, + ) + axis.set_title("Runtime, memory, and accuracy tradeoff", fontsize=13, fontweight="bold") + axis.set_xlabel("Replay memory budget") + axis.set_ylabel("Average final accuracy (%)") + axis.grid(alpha=0.25) + fig.savefig(path, dpi=180, bbox_inches="tight") + return path + + +def _plot_calibration(plt: Any, records: Sequence[RunRecord], path: Path) -> Path: + fig, axis = plt.subplots(figsize=(8.5, 5.0), constrained_layout=True) + car_records = [record for record in records if "car_calibration_temperature" in record.summary] + labels = [f"{record.method}-{record.seed}" for record in car_records] + values = [ + float(record.summary.get("car_calibration_temperature", 1.0)) for record in car_records + ] + 
axis.bar( + range(len(values)), values, color=[_method_color(record.method) for record in car_records] + ) + axis.set_title("Post-task calibration temperatures", fontsize=13, fontweight="bold") + axis.set_ylabel("Temperature") + axis.set_xticks(range(len(values)), labels, rotation=35, ha="right") + axis.grid(axis="y", alpha=0.25) + fig.savefig(path, dpi=180, bbox_inches="tight") + return path + + def _records_by_method(records: Sequence[RunRecord]) -> dict[str, list[RunRecord]]: by_method: dict[str, list[RunRecord]] = defaultdict(list) for record in sorted(records, key=lambda item: (item.method, item.seed)): @@ -478,6 +757,12 @@ def _method_color(method: str) -> str: "lwf": "#c2410c", "derpp": "#7c3aed", "agem": "#0f766e", + "er_ace": "#0891b2", + "gdumb": "#db2777", + "car": "#dc2626", + "bic": "#9333ea", + "icarl": "#ca8a04", + "x_der_lite": "#4f46e5", }.get(method, "#7c3aed") diff --git a/src/cl_bench/strategies/__init__.py b/src/cl_bench/strategies/__init__.py index e2d90ea..6305c97 100644 --- a/src/cl_bench/strategies/__init__.py +++ b/src/cl_bench/strategies/__init__.py @@ -6,30 +6,42 @@ from cl_bench.config import ExperimentConfig from cl_bench.strategies.agem import AGEMStrategy, build_agem from cl_bench.strategies.baseline import BaselineStrategy, build_baseline +from cl_bench.strategies.car import CARStrategy, build_car from cl_bench.strategies.derpp import DERPPStrategy, build_derpp +from cl_bench.strategies.er_ace import ERACEStrategy, build_er_ace from cl_bench.strategies.ewc import EWCStrategy, build_ewc +from cl_bench.strategies.gdumb import GDumbStrategy, build_gdumb from cl_bench.strategies.lwf import LwFStrategy, build_lwf from cl_bench.strategies.replay import ReplayStrategy, build_replay Strategy = ( - BaselineStrategy | EWCStrategy | ReplayStrategy | LwFStrategy | DERPPStrategy | AGEMStrategy + BaselineStrategy + | EWCStrategy + | ReplayStrategy + | LwFStrategy + | DERPPStrategy + | AGEMStrategy + | ERACEStrategy + | GDumbStrategy + | 
CARStrategy ) def create_strategy(config: ExperimentConfig, model: nn.Module, device: torch.device) -> Strategy: method = config.method.lower().replace("-", "_") + task_classes = _task_classes(config) if method == "baseline": - return build_baseline(model, device, config.learning_rate) - if method == "ewc": - return build_ewc( + strategy = build_baseline(model, device, config.learning_rate) + elif method == "ewc": + strategy = build_ewc( model, device, learning_rate=config.learning_rate, ewc_lambda=config.ewc_lambda, fisher_samples=config.fisher_samples, ) - if method == "replay": - return build_replay( + elif method == "replay": + strategy = build_replay( model, device, learning_rate=config.learning_rate, @@ -38,8 +50,19 @@ def create_strategy(config: ExperimentConfig, model: nn.Module, device: torch.de replay_loss_weight=config.replay_loss_weight, seed=config.seed, ) - if method == "derpp": - return build_derpp( + elif method == "er_ace": + strategy = build_er_ace( + model, + device, + learning_rate=config.learning_rate, + buffer_size=config.replay_buffer_size, + replay_batch_size=config.replay_batch_size, + replay_loss_weight=config.replay_loss_weight, + task_classes=task_classes, + seed=config.seed, + ) + elif method == "derpp": + strategy = build_derpp( model, device, learning_rate=config.learning_rate, @@ -49,8 +72,18 @@ def create_strategy(config: ExperimentConfig, model: nn.Module, device: torch.de beta=config.derpp_beta, seed=config.seed, ) - if method == "agem": - return build_agem( + elif method == "gdumb": + strategy = build_gdumb( + model, + device, + learning_rate=config.learning_rate, + buffer_size=config.replay_buffer_size, + memory_epochs=config.gdumb_epochs, + batch_size=config.batch_size, + seed=config.seed, + ) + elif method == "agem": + strategy = build_agem( model, device, learning_rate=config.learning_rate, @@ -58,22 +91,72 @@ def create_strategy(config: ExperimentConfig, model: nn.Module, device: torch.de 
memory_batch_size=config.agem_memory_batch_size, seed=config.seed, ) - if method == "lwf": - return build_lwf( + elif method == "lwf": + strategy = build_lwf( model, device, learning_rate=config.learning_rate, alpha=config.lwf_alpha, temperature=config.lwf_temperature, ) - raise ValueError(f"Unknown continual-learning method: {config.method}") + elif method in {"car", "bic", "icarl", "x_der_lite"}: + use_calibration = method != "icarl" + strategy = build_car( + model, + device, + learning_rate=config.learning_rate, + buffer_size=config.replay_buffer_size, + replay_batch_size=config.replay_batch_size, + task_classes=task_classes, + num_classes=_num_classes(config), + logit_anchor_weight=0.0 if method == "bic" else config.car_logit_anchor_weight, + replay_ce_weight=config.car_replay_ce_weight, + feature_anchor_weight=0.0 if method == "bic" else config.car_feature_anchor_weight, + prototype_anchor_weight=0.0 if method == "bic" else config.car_prototype_anchor_weight, + calibration_epochs=config.car_calibration_epochs if use_calibration else 0, + calibration_lr=config.car_calibration_lr, + calibration_weight_decay=config.car_calibration_weight_decay, + replay_augment=config.car_replay_augment, + use_current_task_mask=config.car_use_current_task_mask, + seed=config.seed, + ) + else: + raise ValueError(f"Unknown continual-learning method: {config.method}") + + strategy.configure_training( + optimizer=config.optimizer, + momentum=config.momentum, + weight_decay=config.weight_decay, + scheduler=config.scheduler, + warmup_epochs=config.warmup_epochs, + label_smoothing=config.label_smoothing, + ) + return strategy + + +def _task_classes(config: ExperimentConfig) -> list[list[int]]: + return [ + [int(label) for label in task.classes] for task in config.tasks if task.classes != "all" + ] + + +def _num_classes(config: ExperimentConfig) -> int: + explicit_classes = [ + int(label) for task in config.tasks if task.classes != "all" for label in task.classes + ] + if 
explicit_classes: + return max(explicit_classes) + 1 + return 1000 __all__ = [ "AGEMStrategy", "BaselineStrategy", + "CARStrategy", "DERPPStrategy", + "ERACEStrategy", "EWCStrategy", + "GDumbStrategy", "LwFStrategy", "ReplayStrategy", "Strategy", diff --git a/src/cl_bench/strategies/base.py b/src/cl_bench/strategies/base.py index 85632b8..39d41ca 100644 --- a/src/cl_bench/strategies/base.py +++ b/src/cl_bench/strategies/base.py @@ -16,11 +16,36 @@ def __init__(self, model: nn.Module, device: torch.device, learning_rate: float) self.model = model self.device = device self.learning_rate = learning_rate + self.optimizer_name = "adam" + self.momentum = 0.9 + self.weight_decay = 0.0 + self.scheduler_name = "none" + self.warmup_epochs = 0 + self.label_smoothing = 0.0 self.criterion = nn.CrossEntropyLoss() - self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) + self.optimizer = self._build_optimizer() + self.scheduler: torch.optim.lr_scheduler.LRScheduler | None = None self.current_task = -1 self.seen_tasks = 0 + def configure_training( + self, + optimizer: str, + momentum: float, + weight_decay: float, + scheduler: str, + warmup_epochs: int, + label_smoothing: float, + ) -> None: + self.optimizer_name = optimizer.lower().replace("-", "_") + self.momentum = momentum + self.weight_decay = weight_decay + self.scheduler_name = scheduler.lower().replace("-", "_") + self.warmup_epochs = warmup_epochs + self.label_smoothing = label_smoothing + self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing) + self.optimizer = self._build_optimizer() + def train_task( self, train_loader: DataLoader, @@ -31,7 +56,8 @@ def train_task( self.current_task = task_id self.seen_tasks = max(self.seen_tasks, task_id + 1) self.before_task(task_id) - self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) + self.optimizer = self._build_optimizer() + self.scheduler = self._build_scheduler(epochs) best_state: dict[str, torch.Tensor] | 
None = None best_val_loss = float("inf") @@ -52,6 +78,9 @@ def train_task( best_val_loss = float(val_metrics["loss"]) best_state = clone_state_dict(self.model) + if self.scheduler is not None: + self.scheduler.step() + if best_state is not None: load_state_dict(self.model, best_state, self.device) @@ -113,9 +142,50 @@ def after_task(self, train_loader: DataLoader, task_id: int) -> None: def extra_state_dict(self) -> dict[str, Any]: return {} + def run_summary(self) -> dict[str, float | int | str | None]: + return {} + def load_extra_state_dict(self, state: dict[str, Any]) -> None: del state + def _build_optimizer(self) -> torch.optim.Optimizer: + if self.optimizer_name == "sgd": + return torch.optim.SGD( + self.model.parameters(), + lr=self.learning_rate, + momentum=self.momentum, + weight_decay=self.weight_decay, + ) + if self.optimizer_name == "adamw": + return torch.optim.AdamW( + self.model.parameters(), + lr=self.learning_rate, + weight_decay=self.weight_decay, + ) + if self.optimizer_name == "adam": + return torch.optim.Adam( + self.model.parameters(), + lr=self.learning_rate, + weight_decay=self.weight_decay, + ) + raise ValueError(f"Unsupported optimizer: {self.optimizer_name}") + + def _build_scheduler(self, epochs: int) -> torch.optim.lr_scheduler.LRScheduler | None: + if self.scheduler_name in {"none", "constant", ""}: + return None + if self.scheduler_name == "cosine": + warmup_epochs = min(max(0, self.warmup_epochs), max(0, epochs - 1)) + + def lr_lambda(epoch: int) -> float: + if warmup_epochs and epoch < warmup_epochs: + return float(epoch + 1) / float(warmup_epochs) + cosine_epochs = max(1, epochs - warmup_epochs) + progress = (epoch - warmup_epochs + 1) / float(cosine_epochs) + return 0.5 * (1.0 + torch.cos(torch.tensor(progress * torch.pi)).item()) + + return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) + raise ValueError(f"Unsupported scheduler: {self.scheduler_name}") + def _train_epoch(self, train_loader: DataLoader, task_id: 
int) -> dict[str, float]: self.model.train() totals: dict[str, float] = {"loss": 0.0, "ce_loss": 0.0} diff --git a/src/cl_bench/strategies/car.py b/src/cl_bench/strategies/car.py new file mode 100644 index 0000000..946fd99 --- /dev/null +++ b/src/cl_bench/strategies/car.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +from cl_bench.strategies.base import ContinualLearningStrategy +from cl_bench.strategies.replay import BalancedReplayBuffer + + +@dataclass(frozen=True) +class CARWeights: + logit_anchor: float + replay_ce: float + feature_anchor: float + prototype_anchor: float + + +class BiasTemperatureCalibration(nn.Module): + """Small post-task calibration head used for old/new class bias correction.""" + + def __init__(self, num_classes: int): + super().__init__() + self.log_temperature = nn.Parameter(torch.zeros(())) + self.bias = nn.Parameter(torch.zeros(num_classes)) + + def forward(self, logits: torch.Tensor) -> torch.Tensor: + temperature = self.log_temperature.exp().clamp(min=0.05, max=20.0) + return logits / temperature + self.bias + + @property + def temperature(self) -> float: + return float(self.log_temperature.detach().exp().clamp(min=0.05, max=20.0).item()) + + +class CARStrategy(ContinualLearningStrategy): + """Calibrated Anchor Replay for class-incremental continual learning.""" + + def __init__( + self, + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + replay_batch_size: int, + task_classes: list[list[int]], + num_classes: int, + weights: CARWeights, + calibration_epochs: int, + calibration_lr: float, + calibration_weight_decay: float, + replay_augment: bool, + use_current_task_mask: bool, + seed: int, + ): + super().__init__(model=model, device=device, 
learning_rate=learning_rate) + self.buffer = BalancedReplayBuffer(capacity=buffer_size, seed=seed) + self.replay_batch_size = replay_batch_size + self.task_classes = task_classes + self.weights = weights + self.calibration_epochs = calibration_epochs + self.calibration_lr = calibration_lr + self.calibration_weight_decay = calibration_weight_decay + self.replay_augment = replay_augment + self.use_current_task_mask = use_current_task_mask + self.calibrator = BiasTemperatureCalibration(num_classes).to(device) + self.class_prototypes: dict[int, torch.Tensor] = {} + self._last_batch_features: torch.Tensor | None = None + self._last_batch_logits: torch.Tensor | None = None + self.last_calibration_loss = 0.0 + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + logits, features = _forward_with_features(self.model, inputs) + self._last_batch_logits = logits.detach().cpu() + self._last_batch_features = features.detach().cpu() + + current_logits = logits + if self.use_current_task_mask and task_id < len(self.task_classes): + current_logits = _mask_to_classes(logits, self.task_classes[task_id]) + ce_loss = self.criterion(current_logits, targets) + + replay_ce_loss = torch.zeros((), device=self.device) + logit_anchor_loss = torch.zeros((), device=self.device) + feature_anchor_loss = torch.zeros((), device=self.device) + prototype_anchor_loss = torch.zeros((), device=self.device) + + if len(self.buffer) > 0 and self.replay_batch_size > 0: + samples = self.buffer.sample_samples( + self.replay_batch_size, + require_logits=True, + require_features=True, + ) + replay_inputs = torch.stack([sample.inputs for sample in samples]).to(self.device) + replay_inputs = ( + _augment_replay_batch(replay_inputs) if self.replay_augment else replay_inputs + ) + replay_targets = torch.tensor( + [sample.target for sample in samples], + dtype=torch.long, + device=self.device, + ) + stored_logits = 
torch.stack( + [sample.logits for sample in samples if sample.logits is not None] + ).to(self.device) + stored_features = torch.stack( + [sample.features for sample in samples if sample.features is not None] + ).to(self.device) + + replay_logits, replay_features = _forward_with_features(self.model, replay_inputs) + replay_ce_loss = self.criterion(replay_logits, replay_targets) + logit_anchor_loss = F.mse_loss(replay_logits, stored_logits) + feature_anchor_loss = F.mse_loss(replay_features, stored_features) + prototypes = self._prototype_targets(replay_targets) + if prototypes is not None: + prototype_anchor_loss = F.mse_loss(replay_features, prototypes) + + loss = ( + ce_loss + + self.weights.replay_ce * replay_ce_loss + + self.weights.logit_anchor * logit_anchor_loss + + self.weights.feature_anchor * feature_anchor_loss + + self.weights.prototype_anchor * prototype_anchor_loss + ) + return ( + loss, + logits, + { + "ce_loss": float(ce_loss.detach().item()), + "car_replay_ce_loss": float(replay_ce_loss.detach().item()), + "car_logit_anchor_loss": float(logit_anchor_loss.detach().item()), + "car_feature_anchor_loss": float(feature_anchor_loss.detach().item()), + "car_prototype_anchor_loss": float(prototype_anchor_loss.detach().item()), + }, + ) + + def observe_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor, + task_id: int, + ) -> None: + del logits, task_id + if self._last_batch_logits is None or self._last_batch_features is None: + with torch.no_grad(): + batch_logits, batch_features = _forward_with_features(self.model, inputs) + self._last_batch_logits = batch_logits.detach().cpu() + self._last_batch_features = batch_features.detach().cpu() + self.buffer.add_batch( + inputs, + targets, + logits=self._last_batch_logits, + features=self._last_batch_features, + ) + self._last_batch_logits = None + self._last_batch_features = None + + def after_task(self, train_loader: DataLoader, task_id: int) -> None: + del train_loader, 
task_id + self._refresh_buffer_anchors() + self._fit_calibration() + + def evaluate(self, data_loader: DataLoader) -> dict[str, float]: + self.model.eval() + self.calibrator.eval() + total_loss = 0.0 + total_correct = 0 + total_examples = 0 + + with torch.no_grad(): + for inputs, targets in data_loader: + inputs = inputs.to(self.device) + targets = targets.to(self.device) + logits = self.calibrator(self.model(inputs)) + loss = self.criterion(logits, targets) + total_loss += float(loss.item()) * inputs.size(0) + total_correct += int((logits.argmax(dim=1) == targets).sum().item()) + total_examples += int(targets.numel()) + + if total_examples == 0: + return {"loss": 0.0, "accuracy": 0.0, "examples": 0.0} + return { + "loss": total_loss / total_examples, + "accuracy": 100.0 * total_correct / total_examples, + "examples": float(total_examples), + } + + def run_summary(self) -> dict[str, float]: + return { + "car_calibration_temperature": self.calibrator.temperature, + "car_calibration_loss": self.last_calibration_loss, + "car_num_prototypes": float(len(self.class_prototypes)), + } + + def extra_state_dict(self) -> dict[str, Any]: + return { + "buffer_inputs": [sample.inputs for sample in self.buffer.samples], + "buffer_targets": [sample.target for sample in self.buffer.samples], + "buffer_logits": [sample.logits for sample in self.buffer.samples], + "buffer_features": [sample.features for sample in self.buffer.samples], + "seen_count": self.buffer.seen_count, + "class_counts": dict(self.buffer.class_counts()), + "calibration_temperature": self.calibrator.temperature, + "calibration_bias": self.calibrator.bias.detach().cpu().tolist(), + "weights": self.weights.__dict__, + } + + def _prototype_targets(self, targets: torch.Tensor) -> torch.Tensor | None: + prototypes: list[torch.Tensor] = [] + for target in targets.detach().cpu().tolist(): + prototype = self.class_prototypes.get(int(target)) + if prototype is None: + return None + prototypes.append(prototype) + return 
torch.stack(prototypes).to(self.device) + + def _refresh_buffer_anchors(self) -> None: + if len(self.buffer) == 0: + return + + self.model.eval() + feature_sums: dict[int, torch.Tensor] = defaultdict(torch.Tensor) + feature_counts: dict[int, int] = defaultdict(int) + with torch.no_grad(): + for sample in self.buffer.samples: + inputs = sample.inputs.unsqueeze(0).to(self.device) + logits, features = _forward_with_features(self.model, inputs) + sample.logits = logits.squeeze(0).detach().cpu() + sample.features = features.squeeze(0).detach().cpu() + label = int(sample.target) + if feature_counts[label] == 0: + feature_sums[label] = sample.features.clone() + else: + feature_sums[label] = feature_sums[label] + sample.features + feature_counts[label] += 1 + + self.class_prototypes = { + label: feature_sums[label] / float(count) + for label, count in feature_counts.items() + if count > 0 + } + + def _fit_calibration(self) -> None: + if len(self.buffer) == 0 or self.calibration_epochs <= 0: + return + + inputs, targets = self.buffer.tensors() + dataset = TensorDataset(inputs, targets) + loader = DataLoader(dataset, batch_size=min(256, max(1, len(dataset))), shuffle=True) + optimizer = torch.optim.AdamW( + self.calibrator.parameters(), + lr=self.calibration_lr, + weight_decay=self.calibration_weight_decay, + ) + self.model.eval() + self.calibrator.train() + last_loss = 0.0 + for _ in range(self.calibration_epochs): + for batch_inputs, batch_targets in loader: + batch_inputs = batch_inputs.to(self.device) + batch_targets = batch_targets.to(self.device) + with torch.no_grad(): + logits = self.model(batch_inputs) + optimizer.zero_grad(set_to_none=True) + calibrated_logits = self.calibrator(logits) + loss = self.criterion(calibrated_logits, batch_targets) + loss.backward() + optimizer.step() + last_loss = float(loss.detach().item()) + self.last_calibration_loss = last_loss + + +def _forward_with_features( + model: nn.Module, inputs: torch.Tensor +) -> tuple[torch.Tensor, 
torch.Tensor]: + features_module = getattr(model, "features", None) + classifier = getattr(model, "classifier", None) + if features_module is not None and classifier is not None: + features = torch.flatten(features_module(inputs), 1) + logits = classifier(features) + return logits, features + + logits = model(inputs) + return logits, logits.detach() if logits.requires_grad else logits + + +def _mask_to_classes(logits: torch.Tensor, classes: list[int]) -> torch.Tensor: + masked_logits = logits.clone() + allowed = torch.zeros(logits.size(1), dtype=torch.bool, device=logits.device) + allowed[torch.tensor(classes, dtype=torch.long, device=logits.device)] = True + masked_logits[:, ~allowed] = -1e9 + return masked_logits + + +def _augment_replay_batch(inputs: torch.Tensor) -> torch.Tensor: + if inputs.ndim != 4 or inputs.size(-1) < 8 or inputs.size(-2) < 8: + return inputs + + batch_size, _, height, width = inputs.shape + padded = F.pad(inputs, (4, 4, 4, 4), mode="reflect") + max_top = padded.size(-2) - height + max_left = padded.size(-1) - width + crops = [] + for index in range(batch_size): + top = int(torch.randint(0, max_top + 1, (1,), device=inputs.device).item()) + left = int(torch.randint(0, max_left + 1, (1,), device=inputs.device).item()) + crop = padded[index : index + 1, :, top : top + height, left : left + width] + if bool(torch.rand((), device=inputs.device) < 0.5): + crop = torch.flip(crop, dims=(-1,)) + crops.append(crop) + return torch.cat(crops, dim=0) + + +def build_car( + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + replay_batch_size: int, + task_classes: list[list[int]], + num_classes: int, + logit_anchor_weight: float, + replay_ce_weight: float, + feature_anchor_weight: float, + prototype_anchor_weight: float, + calibration_epochs: int, + calibration_lr: float, + calibration_weight_decay: float, + replay_augment: bool, + use_current_task_mask: bool, + seed: int, +) -> CARStrategy: + return CARStrategy( + 
model=model, + device=device, + learning_rate=learning_rate, + buffer_size=buffer_size, + replay_batch_size=replay_batch_size, + task_classes=task_classes, + num_classes=num_classes, + weights=CARWeights( + logit_anchor=logit_anchor_weight, + replay_ce=replay_ce_weight, + feature_anchor=feature_anchor_weight, + prototype_anchor=prototype_anchor_weight, + ), + calibration_epochs=calibration_epochs, + calibration_lr=calibration_lr, + calibration_weight_decay=calibration_weight_decay, + replay_augment=replay_augment, + use_current_task_mask=use_current_task_mask, + seed=seed, + ) diff --git a/src/cl_bench/strategies/er_ace.py b/src/cl_bench/strategies/er_ace.py new file mode 100644 index 0000000..67e697d --- /dev/null +++ b/src/cl_bench/strategies/er_ace.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import torch +from torch import nn + +from cl_bench.strategies.base import ContinualLearningStrategy +from cl_bench.strategies.replay import ReservoirReplayBuffer + + +class ERACEStrategy(ContinualLearningStrategy): + """Experience replay with asymmetric cross-entropy for class-incremental streams.""" + + def __init__( + self, + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + replay_batch_size: int, + replay_loss_weight: float, + task_classes: list[list[int]], + seed: int, + ): + super().__init__(model=model, device=device, learning_rate=learning_rate) + self.buffer = ReservoirReplayBuffer(capacity=buffer_size, seed=seed) + self.replay_batch_size = replay_batch_size + self.replay_loss_weight = replay_loss_weight + self.task_classes = task_classes + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + logits = self.model(inputs) + ce_loss = self.criterion(_mask_to_current_task(logits, self.task_classes[task_id]), targets) + replay_loss = torch.zeros((), device=self.device) + + if len(self.buffer) > 0 and self.replay_batch_size 
> 0: + replay_inputs, replay_targets = self.buffer.sample(self.replay_batch_size) + replay_inputs = replay_inputs.to(self.device) + replay_targets = replay_targets.to(self.device) + replay_logits = self.model(replay_inputs) + replay_loss = self.criterion(replay_logits, replay_targets) + + loss = ce_loss + self.replay_loss_weight * replay_loss + return ( + loss, + logits, + { + "er_ace_current_ce_loss": float(ce_loss.detach().item()), + "replay_loss": float(replay_loss.detach().item()), + }, + ) + + def observe_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor, + task_id: int, + ) -> None: + del logits, task_id + self.buffer.add_batch(inputs, targets) + + def extra_state_dict(self) -> dict[str, object]: + return { + "buffer_inputs": [sample.inputs for sample in self.buffer.samples], + "buffer_targets": [sample.target for sample in self.buffer.samples], + "seen_count": self.buffer.seen_count, + "replay_loss_weight": self.replay_loss_weight, + } + + +def _mask_to_current_task(logits: torch.Tensor, classes: list[int]) -> torch.Tensor: + masked_logits = logits.clone() + allowed = torch.zeros(logits.size(1), dtype=torch.bool, device=logits.device) + allowed[torch.tensor(classes, dtype=torch.long, device=logits.device)] = True + masked_logits[:, ~allowed] = -1e9 + return masked_logits + + +def build_er_ace( + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + replay_batch_size: int, + replay_loss_weight: float, + task_classes: list[list[int]], + seed: int, +) -> ERACEStrategy: + return ERACEStrategy( + model=model, + device=device, + learning_rate=learning_rate, + buffer_size=buffer_size, + replay_batch_size=replay_batch_size, + replay_loss_weight=replay_loss_weight, + task_classes=task_classes, + seed=seed, + ) diff --git a/src/cl_bench/strategies/gdumb.py b/src/cl_bench/strategies/gdumb.py new file mode 100644 index 0000000..a0b87c8 --- /dev/null +++ b/src/cl_bench/strategies/gdumb.py @@ -0,0 
+1,128 @@ +from __future__ import annotations + +import torch +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +from cl_bench.strategies.base import ContinualLearningStrategy, clone_state_dict, load_state_dict +from cl_bench.strategies.replay import BalancedReplayBuffer + + +class GDumbStrategy(ContinualLearningStrategy): + """Greedy class-balanced memory with from-scratch training on stored examples.""" + + def __init__( + self, + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + memory_epochs: int, + batch_size: int, + seed: int, + ): + super().__init__(model=model, device=device, learning_rate=learning_rate) + self.buffer = BalancedReplayBuffer(capacity=buffer_size, seed=seed) + self.initial_state = clone_state_dict(model) + self.memory_epochs = memory_epochs + self.batch_size = batch_size + self.seed = seed + + def train_task( + self, + train_loader: DataLoader, + val_loader: DataLoader, + task_id: int, + epochs: int, + ) -> list[dict[str, float | int]]: + del epochs + self.current_task = task_id + self.seen_tasks = max(self.seen_tasks, task_id + 1) + self.before_task(task_id) + + example_count = 0 + for inputs, targets in train_loader: + self.buffer.add_batch(inputs, targets) + example_count += int(targets.numel()) + + self._fit_memory(task_id) + val_metrics = self.evaluate(val_loader) + return [ + { + "task_id": task_id, + "epoch": 1, + "train_loss": 0.0, + "train_accuracy": 0.0, + "train_examples": float(example_count), + **{f"val_{key}": value for key, value in val_metrics.items()}, + } + ] + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + del targets, task_id + logits = self.model(inputs) + return logits.sum() * 0.0, logits, {"gdumb_online_loss": 0.0} + + def observe_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor, + task_id: int, + ) -> None: + del 
logits, task_id + self.buffer.add_batch(inputs, targets) + + def _fit_memory(self, task_id: int) -> None: + if len(self.buffer) == 0 or self.memory_epochs <= 0: + return + load_state_dict(self.model, self.initial_state, self.device) + self.optimizer = self._build_optimizer() + inputs, targets = self.buffer.tensors() + memory_dataset = TensorDataset(inputs, targets) + memory_loader = DataLoader( + memory_dataset, + batch_size=self.batch_size, + shuffle=True, + generator=torch.Generator().manual_seed(self.seed + task_id), + ) + for _ in range(self.memory_epochs): + self.model.train() + for memory_inputs, memory_targets in memory_loader: + memory_inputs = memory_inputs.to(self.device) + memory_targets = memory_targets.to(self.device) + self.optimizer.zero_grad(set_to_none=True) + loss = self.criterion(self.model(memory_inputs), memory_targets) + loss.backward() + self.optimizer.step() + + def extra_state_dict(self) -> dict[str, object]: + return { + "buffer_inputs": [sample.inputs for sample in self.buffer.samples], + "buffer_targets": [sample.target for sample in self.buffer.samples], + "seen_count": self.buffer.seen_count, + "memory_epochs": self.memory_epochs, + "class_counts": dict(self.buffer.class_counts()), + } + + +def build_gdumb( + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + memory_epochs: int, + batch_size: int, + seed: int, +) -> GDumbStrategy: + return GDumbStrategy( + model=model, + device=device, + learning_rate=learning_rate, + buffer_size=buffer_size, + memory_epochs=memory_epochs, + batch_size=batch_size, + seed=seed, + ) diff --git a/src/cl_bench/strategies/replay.py b/src/cl_bench/strategies/replay.py index c41d60e..083d248 100644 --- a/src/cl_bench/strategies/replay.py +++ b/src/cl_bench/strategies/replay.py @@ -1,6 +1,7 @@ from __future__ import annotations import random +from collections import Counter from dataclasses import dataclass import torch @@ -15,6 +16,7 @@ class ReplaySample: inputs: 
torch.Tensor target: int logits: torch.Tensor | None = None + features: torch.Tensor | None = None class ReservoirReplayBuffer: @@ -36,15 +38,24 @@ def add_batch( inputs: torch.Tensor, targets: torch.Tensor, logits: torch.Tensor | None = None, + features: torch.Tensor | None = None, ) -> None: logits_cpu = None if logits is None else logits.detach().cpu() + features_cpu = None if features is None else features.detach().cpu() for index, (input_tensor, target) in enumerate( zip(inputs.detach().cpu(), targets.detach().cpu(), strict=True) ): sample_logits = None if logits_cpu is None else logits_cpu[index] - self.add(input_tensor, int(target.item()), sample_logits) + sample_features = None if features_cpu is None else features_cpu[index] + self.add(input_tensor, int(target.item()), sample_logits, sample_features) - def add(self, inputs: torch.Tensor, target: int, logits: torch.Tensor | None = None) -> None: + def add( + self, + inputs: torch.Tensor, + target: int, + logits: torch.Tensor | None = None, + features: torch.Tensor | None = None, + ) -> None: self.seen_count += 1 if self.capacity == 0: return @@ -52,6 +63,7 @@ def add(self, inputs: torch.Tensor, target: int, logits: torch.Tensor | None = N inputs=inputs.detach().cpu().clone(), target=int(target), logits=None if logits is None else logits.detach().cpu().clone(), + features=None if features is None else features.detach().cpu().clone(), ) if len(self.samples) < self.capacity: self.samples.append(sample) @@ -66,20 +78,132 @@ def sample(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]: return inputs, targets def sample_tensors( - self, batch_size: int, require_logits: bool = False + self, + batch_size: int, + require_logits: bool = False, + require_features: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - chosen = self.sample_samples(batch_size, require_logits=require_logits) + chosen = self.sample_samples( + batch_size, + require_logits=require_logits, + 
require_features=require_features, + ) inputs = torch.stack([sample.inputs for sample in chosen]) targets = torch.tensor([sample.target for sample in chosen], dtype=torch.long) if require_logits: logits = torch.stack([sample.logits for sample in chosen if sample.logits is not None]) + elif require_features: + logits = torch.stack( + [sample.features for sample in chosen if sample.features is not None] + ) else: logits = None return inputs, targets, logits - def sample_samples(self, batch_size: int, require_logits: bool = False) -> list[ReplaySample]: + def sample_samples( + self, + batch_size: int, + require_logits: bool = False, + require_features: bool = False, + ) -> list[ReplaySample]: + candidates = [ + sample + for sample in self.samples + if (not require_logits or sample.logits is not None) + and (not require_features or sample.features is not None) + ] + if not candidates: + raise ValueError("Cannot sample from an empty replay buffer.") + return self.rng.sample(candidates, k=min(batch_size, len(candidates))) + + +class BalancedReplayBuffer: + """Class-balanced memory buffer for exemplar-only baselines such as GDumb.""" + + def __init__(self, capacity: int, seed: int = 0): + if capacity < 0: + raise ValueError("Replay buffer capacity must be non-negative.") + self.capacity = capacity + self.rng = random.Random(seed) + self.samples: list[ReplaySample] = [] + self.seen_count = 0 + + def __len__(self) -> int: + return len(self.samples) + + def add_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor | None = None, + features: torch.Tensor | None = None, + ) -> None: + logits_cpu = None if logits is None else logits.detach().cpu() + features_cpu = None if features is None else features.detach().cpu() + for index, (input_tensor, target) in enumerate( + zip(inputs.detach().cpu(), targets.detach().cpu(), strict=True) + ): + sample_logits = None if logits_cpu is None else logits_cpu[index] + sample_features = None if features_cpu is 
None else features_cpu[index] + self.add(input_tensor, int(target.item()), sample_logits, sample_features) + + def add( + self, + inputs: torch.Tensor, + target: int, + logits: torch.Tensor | None = None, + features: torch.Tensor | None = None, + ) -> None: + self.seen_count += 1 + if self.capacity == 0: + return + + sample = ReplaySample( + inputs=inputs.detach().cpu().clone(), + target=int(target), + logits=None if logits is None else logits.detach().cpu().clone(), + features=None if features is None else features.detach().cpu().clone(), + ) + if len(self.samples) < self.capacity: + self.samples.append(sample) + return + + counts = self.class_counts() + target_count = counts.get(int(target), 0) + max_count = max(counts.values(), default=0) + if target_count >= max_count: + return + + replacement_classes = [label for label, count in counts.items() if count == max_count] + replacement_class = self.rng.choice(replacement_classes) + replacement_indices = [ + index + for index, existing_sample in enumerate(self.samples) + if existing_sample.target == replacement_class + ] + self.samples[self.rng.choice(replacement_indices)] = sample + + def class_counts(self) -> Counter[int]: + return Counter(sample.target for sample in self.samples) + + def tensors(self) -> tuple[torch.Tensor, torch.Tensor]: + if not self.samples: + raise ValueError("Cannot materialize an empty replay buffer.") + inputs = torch.stack([sample.inputs for sample in self.samples]) + targets = torch.tensor([sample.target for sample in self.samples], dtype=torch.long) + return inputs, targets + + def sample_samples( + self, + batch_size: int, + require_logits: bool = False, + require_features: bool = False, + ) -> list[ReplaySample]: candidates = [ - sample for sample in self.samples if not require_logits or sample.logits is not None + sample + for sample in self.samples + if (not require_logits or sample.logits is not None) + and (not require_features or sample.features is not None) ] if not candidates: 
raise ValueError("Cannot sample from an empty replay buffer.") diff --git a/src/cl_bench/tracking.py b/src/cl_bench/tracking.py index d1e6eae..7ba8322 100644 --- a/src/cl_bench/tracking.py +++ b/src/cl_bench/tracking.py @@ -41,7 +41,14 @@ def create_run_dir(output_dir: str | Path, benchmark_name: str, method: str) -> timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") safe_name = _safe_slug(benchmark_name) safe_method = _safe_slug(method) - return Path(output_dir) / f"{safe_name}_{safe_method}_{timestamp}" + base = Path(output_dir) / f"{safe_name}_{safe_method}_{timestamp}" + if not base.exists(): + return base + for suffix in range(1, 10_000): + candidate = Path(f"{base}_{suffix:03d}") + if not candidate.exists(): + return candidate + raise RuntimeError(f"Could not allocate a unique run directory for {base}.") def git_commit(repo_dir: str | Path | None = None) -> str | None: diff --git a/tests/test_config.py b/tests/test_config.py index b664ef7..4ace326 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -37,6 +37,7 @@ def test_load_cifar_headline_config_and_overrides() -> None: assert config.replay_loss_weight == 3.0 assert config.derpp_alpha == 0.1 assert config.derpp_beta == 2.0 + assert config.gdumb_epochs == 20 assert len(config.tasks) == 5 assert all(task.dataset == "cifar10" for task in config.tasks) @@ -49,6 +50,20 @@ def test_load_cifar_headline_config_and_overrides() -> None: assert overridden.tracking == "json" +def test_load_nested_paper_configs() -> None: + config = load_config("paper/split_cifar10_full") + + assert config.name == "split_cifar10_full" + assert config.method == "car" + assert config.model == "resnet18_cifar" + assert config.optimizer == "sgd" + assert config.scheduler == "cosine" + assert config.label_smoothing == 0.05 + assert config.car_logit_anchor_weight == 0.25 + assert len(config.tasks) == 5 + assert all(task.train_limit is None for task in config.tasks) + + def 
test_nested_training_and_strategy_values_are_parsed() -> None: config = ExperimentConfig.from_dict( { diff --git a/tests/test_datasets.py b/tests/test_datasets.py index da821bd..8c73e5f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -95,3 +95,42 @@ def __getitem__(self, index: int): assert num_classes == 2 assert len(tasks[0].train_loader.dataset) == 5 assert len(tasks[0].test_loader.dataset) == 4 + + +def test_feature_cache_task_construction_is_deterministic(tmp_path) -> None: + train_cache = tmp_path / "train_features.pt" + test_cache = tmp_path / "test_features.pt" + payload = { + "features": torch.arange(24, dtype=torch.float32).reshape(6, 4), + "targets": torch.tensor([0, 0, 1, 1, 2, 2]), + "classes": [0, 1, 2], + } + torch.save(payload, train_cache) + torch.save(payload, test_cache) + config = ExperimentConfig( + name="feature_unit", + method="baseline", + seed=3, + model="linear", + data_dir=str(tmp_path), + batch_size=2, + eval_batch_size=4, + val_fraction=0.0, + tasks=[ + TaskSpec( + name="features_0_1", + dataset="feature_cache", + classes=[0, 1], + train_feature_cache=train_cache.name, + test_feature_cache=test_cache.name, + ) + ], + ) + + tasks, input_shape, num_classes = build_task_loaders(config) + + assert input_shape == (4,) + assert num_classes == 2 + assert len(tasks[0].train_loader.dataset) == 4 + first_batch = next(iter(tasks[0].train_loader)) + assert first_batch[0].shape[-1] == 4 diff --git a/tests/test_integration_smoke.py b/tests/test_integration_smoke.py index 1149060..473f3f4 100644 --- a/tests/test_integration_smoke.py +++ b/tests/test_integration_smoke.py @@ -21,7 +21,7 @@ def test_cpu_smoke_benchmark_writes_reproducibility_artifacts(tmp_path) -> None: def test_synthetic_suite_runs_core_memory_methods(tmp_path) -> None: - for method in ["baseline", "replay", "derpp", "agem"]: + for method in ["baseline", "replay", "derpp", "agem", "er_ace", "gdumb", "car"]: config = replace(load_config("smoke"), 
output_dir=str(tmp_path), method=method) result = run_experiment(config) diff --git a/tests/test_reporting.py b/tests/test_reporting.py index 165fc5b..f103882 100644 --- a/tests/test_reporting.py +++ b/tests/test_reporting.py @@ -2,7 +2,7 @@ import json -from cl_bench.reporting import aggregate_records, collect_runs, write_report +from cl_bench.reporting import aggregate_records, collect_runs, write_export, write_report def _write_metrics(run_dir, method: str, seed: int, final_accuracy: float) -> None: @@ -18,6 +18,8 @@ def _write_metrics(run_dir, method: str, seed: int, final_accuracy: float) -> No "backward_transfer": -5.0, "runtime_seconds": 1.5, "seed": seed, + "replay_buffer_size": 500, + "model": "linear", }, "accuracy_matrix": [[80.0, None], [70.0, final_accuracy]], "forgetting_matrix": [[0.0, None], [10.0, 0.0]], @@ -64,3 +66,17 @@ def test_collect_runs_from_mlflow_artifact_export_shape(tmp_path) -> None: assert len(records) == 1 assert records[0].method == "derpp" assert records[0].summary["average_final_accuracy"] == 88.0 + + +def test_write_paper_report_and_export_without_plots(tmp_path) -> None: + _write_metrics(tmp_path / "car_seed_1", "car", 1, 90.0) + records = collect_runs([tmp_path]) + + report = write_report(records, tmp_path / "paper", "Paper report", make_plots=False, paper=True) + exports = write_export(records, tmp_path / "export", "mammoth") + + assert report.markdown.exists() + assert (tmp_path / "paper" / "leaderboard_table.tex").exists() + assert (tmp_path / "paper" / "per_seed_results.csv").exists() + assert (tmp_path / "paper" / "claims_table.md").exists() + assert all(path.exists() for path in exports) diff --git a/tests/test_strategies.py b/tests/test_strategies.py index 3ce3518..1b40ee3 100644 --- a/tests/test_strategies.py +++ b/tests/test_strategies.py @@ -8,9 +8,13 @@ from cl_bench.strategies.agem import AGEMStrategy from cl_bench.strategies.base import clone_state_dict from cl_bench.strategies.baseline import BaselineStrategy 
+from cl_bench.strategies.car import CARStrategy, CARWeights, _augment_replay_batch from cl_bench.strategies.derpp import DERPPStrategy +from cl_bench.strategies.er_ace import ERACEStrategy from cl_bench.strategies.ewc import EWCStrategy +from cl_bench.strategies.gdumb import GDumbStrategy from cl_bench.strategies.lwf import LwFStrategy +from cl_bench.strategies.replay import BalancedReplayBuffer def _tiny_config() -> ExperimentConfig: @@ -145,3 +149,147 @@ def test_agem_projects_conflicting_gradient() -> None: assert strategy.last_gradient_dot < 0.0 assert strategy.last_projection_applied == 1.0 + + +def test_er_ace_masks_current_task_logits_and_replays_memory() -> None: + config = _tiny_config() + tasks, input_shape, num_classes = build_task_loaders(config) + model = get_model("linear", input_shape, num_classes) + strategy = ERACEStrategy( + model, + torch.device("cpu"), + learning_rate=0.05, + buffer_size=16, + replay_batch_size=4, + replay_loss_weight=1.0, + task_classes=[[0, 1]], + seed=1, + ) + inputs, targets = next(iter(tasks[0].train_loader)) + strategy.observe_batch(inputs, targets, strategy.model(inputs), task_id=0) + + loss, _, components = strategy.compute_loss(inputs, targets, task_id=0) + + assert loss.item() > 0.0 + assert len(strategy.buffer) == inputs.size(0) + assert "er_ace_current_ce_loss" in components + assert "replay_loss" in components + + +def test_balanced_replay_buffer_rebalances_classes() -> None: + buffer = BalancedReplayBuffer(capacity=4, seed=1) + buffer.add_batch(torch.randn(4, 1, 2, 2), torch.tensor([0, 0, 0, 0])) + buffer.add_batch(torch.randn(4, 1, 2, 2), torch.tensor([1, 1, 1, 1])) + + counts = buffer.class_counts() + + assert counts[0] == 2 + assert counts[1] == 2 + + +def test_gdumb_collects_balanced_memory_and_trains_after_task() -> None: + config = _tiny_config() + tasks, input_shape, num_classes = build_task_loaders(config) + model = get_model("linear", input_shape, num_classes) + strategy = GDumbStrategy( + model, + 
torch.device("cpu"), + learning_rate=0.05, + buffer_size=16, + memory_epochs=2, + batch_size=8, + seed=1, + ) + + strategy.train_task(tasks[0].train_loader, tasks[0].val_loader, task_id=0, epochs=1) + metrics = strategy.evaluate(tasks[0].test_loader) + + assert len(strategy.buffer) > 0 + assert 0.0 <= metrics["accuracy"] <= 100.0 + + +def test_car_stores_anchors_and_uses_all_loss_components() -> None: + config = _tiny_config() + tasks, input_shape, num_classes = build_task_loaders(config) + model = get_model("linear", input_shape, num_classes) + strategy = CARStrategy( + model, + torch.device("cpu"), + learning_rate=0.05, + buffer_size=16, + replay_batch_size=4, + task_classes=[[0, 1]], + num_classes=num_classes, + weights=CARWeights( + logit_anchor=0.25, + replay_ce=1.0, + feature_anchor=0.1, + prototype_anchor=0.1, + ), + calibration_epochs=1, + calibration_lr=0.01, + calibration_weight_decay=0.0, + replay_augment=False, + use_current_task_mask=True, + seed=1, + ) + inputs, targets = next(iter(tasks[0].train_loader)) + logits = strategy.model(inputs) + strategy.observe_batch(inputs, targets, logits, task_id=0) + strategy._refresh_buffer_anchors() + + loss, _, components = strategy.compute_loss(inputs, targets, task_id=0) + + assert loss.item() > 0.0 + assert len(strategy.buffer) == inputs.size(0) + assert all(sample.logits is not None for sample in strategy.buffer.samples) + assert all(sample.features is not None for sample in strategy.buffer.samples) + assert "car_replay_ce_loss" in components + assert "car_logit_anchor_loss" in components + assert "car_feature_anchor_loss" in components + assert "car_prototype_anchor_loss" in components + + +def test_car_calibration_updates_logits_without_backbone_change() -> None: + config = _tiny_config() + tasks, input_shape, num_classes = build_task_loaders(config) + model = get_model("linear", input_shape, num_classes) + strategy = CARStrategy( + model, + torch.device("cpu"), + learning_rate=0.05, + buffer_size=16, + 
replay_batch_size=4, + task_classes=[[0, 1]], + num_classes=num_classes, + weights=CARWeights(0.0, 1.0, 0.0, 0.0), + calibration_epochs=3, + calibration_lr=0.1, + calibration_weight_decay=0.0, + replay_augment=False, + use_current_task_mask=True, + seed=1, + ) + inputs, targets = next(iter(tasks[0].train_loader)) + with torch.no_grad(): + before_model = clone_state_dict(strategy.model) + before_logits = strategy.calibrator(strategy.model(inputs)).detach().clone() + strategy.observe_batch(inputs, targets, strategy.model(inputs), task_id=0) + strategy._refresh_buffer_anchors() + strategy._fit_calibration() + with torch.no_grad(): + after_logits = strategy.calibrator(strategy.model(inputs)).detach() + + for name, tensor in before_model.items(): + assert torch.equal(tensor, strategy.model.state_dict()[name]) + assert not torch.equal(before_logits, after_logits) + + +def test_replay_augmentation_does_not_mutate_stored_exemplars() -> None: + stored = torch.randn(4, 3, 32, 32) + original = stored.clone() + + augmented = _augment_replay_batch(stored) + + assert torch.equal(stored, original) + assert augmented.shape == stored.shape
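Reviewer note on `BalancedReplayBuffer.add`: the replacement rule that `test_balanced_replay_buffer_rebalances_classes` exercises can be hard to follow inside the diff, so here is a minimal sketch of it on bare labels. It assumes the same three steps as the diff (fill until capacity, skip a sample whose class is already at least as large as the largest class, otherwise evict a random sample from one of the largest classes). The helper name `balanced_insert` and the labels-only memory are hypothetical simplifications for illustration and are not part of `cl_bench`.

```python
import random
from collections import Counter


def balanced_insert(
    labels: list[int], new_label: int, capacity: int, rng: random.Random
) -> bool:
    """Apply the class-balanced rule to a labels-only memory.

    Hypothetical helper: returns True if new_label was stored.
    """
    if capacity == 0:
        return False
    # Below capacity, every sample is stored unconditionally.
    if len(labels) < capacity:
        labels.append(new_label)
        return True

    counts = Counter(labels)
    max_count = max(counts.values(), default=0)
    # A class already as large as the largest class gains nothing by growing.
    if counts.get(new_label, 0) >= max_count:
        return False

    # Evict a random sample from one of the currently largest classes.
    largest = [label for label, count in counts.items() if count == max_count]
    victim_class = rng.choice(largest)
    victim_indices = [i for i, label in enumerate(labels) if label == victim_class]
    labels[rng.choice(victim_indices)] = new_label
    return True


# Same stream as the unit test: four samples of class 0, then four of class 1.
rng = random.Random(1)
memory: list[int] = []
for label in [0, 0, 0, 0, 1, 1, 1, 1]:
    balanced_insert(memory, label, capacity=4, rng=rng)
final_counts = Counter(memory)
```

With a capacity of 4, the first two class-1 samples each displace a class-0 sample and the last two are skipped, so the memory ends with two samples per class regardless of the seed. This matches the `counts[0] == 2` and `counts[1] == 2` assertions in the test. Note that the real buffer recomputes `class_counts()` on every insert once it is full, so the per-sample cost grows with capacity. That is probably fine for the memory budgets in this PR, but it may be worth caching the counts if budgets grow.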