diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..bc50a73
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,60 @@
+# CI pipeline: lint, test, run the CPU smoke benchmarks, and build the package
+# on every supported Python version.
+name: CI
+
+on:
+  push:
+  pull_request:
+
+# No step writes to the repository, so restrict GITHUB_TOKEN to read-only.
+permissions:
+  contents: read
+
+jobs:
+  lint-test-build:
+    runs-on: ubuntu-latest
+    strategy:
+      # Let every Python version finish so one failure does not hide others.
+      fail-fast: false
+      matrix:
+        # Quoted so YAML does not parse 3.10 as the float 3.1.
+        python-version: ["3.10", "3.11", "3.12"]
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: pip
+      - name: Install package
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e ".[dev,experiment,report]"
+      - name: Lint
+        run: ruff check .
+      - name: Format check
+        run: ruff format --check .
+      - name: Tests
+        run: pytest --cov
+      - name: CPU smoke benchmark
+        run: cl-bench run --config configs/smoke.yaml --method baseline --epochs 1 --device cpu --output-dir /tmp/cl-bench-runs
+      - name: CPU benchmark suite report
+        run: |
+          cl-bench suite \
+            --config configs/smoke.yaml \
+            --methods baseline replay derpp agem \
+            --seeds 7 \
+            --epochs 1 \
+            --device cpu \
+            --tracking json \
+            --output-dir /tmp/cl-bench-runs \
+            --report-dir /tmp/cl-bench-report
+      - name: Build package
+        run: python -m build
+      - name: Import check
+        run: python -c "import cl_bench; print(cl_bench.__version__)"
+      # Upload even when an earlier step failed so partial reports can be inspected.
+      - uses: actions/upload-artifact@v4
+        if: always()
+        with:
+          name: benchmark-report-${{ matrix.python-version }}
+          path: /tmp/cl-bench-report
diff --git a/.gitignore b/.gitignore
index 66b7405..33a4845 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,23 +2,40 @@
__pycache__/
*.py[cod]
*$py.class
-*.pyc
-# Distribution / packaging
+# Build and packaging artifacts
*.egg-info/
dist/
build/
+.ruff_cache/
+.pytest_cache/
+.coverage
+htmlcov/
# Virtual environments
.venv/
venv/
env/
-# Jupyter Notebook checkpoints
-.ipynb_checkpoints/
-
-# Results and logs (generated at runtime)
-results/logs/
+# Runtime artifacts generated by experiments
+data/
+checkpoints/
+logs/
+results/
+reports/
+runs/
+wandb/
+tensorboard/
+mlruns/
+.hydra/
+*.pt
+*.pth
+*.ckpt
+*.npy
+*.npz
-# OS files
+# Notebook/editor/OS noise
+.ipynb_checkpoints/
.DS_Store
+.idea/
+.vscode/
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..83dc6af
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,20 @@
+# Image that installs cl-bench with all extras and runs the CPU smoke benchmark by default.
+FROM python:3.12-slim
+
+# Skip .pyc files and leave stdout/stderr unbuffered so container logs stream live.
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+
+WORKDIR /app
+
+COPY pyproject.toml README.md License ./
+COPY src ./src
+COPY configs ./configs
+COPY tests ./tests
+
+# --no-cache-dir keeps pip's download cache out of the image layer.
+RUN python -m pip install --no-cache-dir --upgrade pip \
+    && python -m pip install --no-cache-dir -e ".[dev,experiment,report]"
+
+# Default command: one-epoch baseline run on the synthetic smoke config.
+CMD ["cl-bench", "run", "--config", "configs/smoke.yaml", "--method", "baseline", "--epochs", "1", "--device", "cpu", "--output-dir", "/tmp/cl-bench-runs"]
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..b580646
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,45 @@
+# Developer automation. Run `make setup` once; every other target calls the
+# tools installed into $(VENV).
+PYTHON ?= python3
+VENV ?= .venv
+BIN := $(VENV)/bin
+export PYTHONPATH := src
+
+.PHONY: setup lint format test smoke suite benchmark build verify clean
+
+setup:
+	$(PYTHON) -m venv $(VENV)
+	$(BIN)/python -m pip install --upgrade pip
+	$(BIN)/python -m pip install -e ".[dev,experiment,report]"
+
+lint:
+	$(BIN)/ruff check .
+	$(BIN)/ruff format --check .
+
+format:
+	$(BIN)/ruff format .
+	$(BIN)/ruff check --fix .
+
+test:
+	$(BIN)/pytest --cov
+
+smoke:
+	$(BIN)/cl-bench run --config configs/smoke.yaml --method baseline --epochs 1 --device cpu
+
+suite:
+	$(BIN)/cl-bench suite --config configs/smoke.yaml --methods baseline ewc replay lwf derpp agem --seeds 7 --epochs 1 --device cpu --report-dir reports/smoke
+
+# Regenerates the headline report committed under docs/assets/.
+benchmark:
+	$(BIN)/cl-bench suite --config-name split_cifar10_headline --methods baseline ewc replay lwf derpp agem --seeds 13 21 --tracking both --report-dir docs/assets/split_cifar10_headline --title "Split CIFAR-10 Headline Benchmark"
+
+build:
+	$(BIN)/python -m build
+
+# Local gate: lint, tests, smoke run, and package build.
+verify: lint test smoke build
+
+# Also clear src/*.egg-info: with the src/ layout, editable-install metadata
+# may be written next to the package rather than at the repository root.
+clean:
+	rm -rf build dist *.egg-info src/*.egg-info .pytest_cache .ruff_cache htmlcov .coverage
diff --git a/README.md b/README.md
index 49ee05b..3134779 100644
--- a/README.md
+++ b/README.md
@@ -1,245 +1,188 @@
-# 🧠 Continual Learning System
+# Continual Learning Benchmark
-
+[![CI](https://github.com/1Utkarsh1/Continual-Learning/actions/workflows/ci.yml/badge.svg)](https://github.com/1Utkarsh1/Continual-Learning/actions/workflows/ci.yml)
-
-
-
-
+A PyTorch benchmark framework for comparing continual-learning strategies under
+the same task stream, metric suite, artifact pipeline, local MLflow tracker, and
+report generator.
-**A robust framework for training neural networks that can learn sequentially without forgetting**
+Experiments are config-driven, methods share a stable lifecycle interface, real
+and synthetic benchmarks use the same runner, and every run writes
+reproducibility artifacts that can be aggregated into leaderboard CSV/JSON files
+and plots. The primary reported result is a verified Split CIFAR-10 suite.
-[Overview](#overview) •
-[Key Features](#key-features) •
-[Techniques](#techniques) •
-[Installation](#installation) •
-[Usage](#usage) •
-[Results](#results) •
-[Contributing](#contributing)
+
-
+## Project Scope
-## 🔄 Overview
+- Config-driven benchmark runner for single runs and multi-method suites.
+- Implemented baseline fine-tuning, EWC, reservoir replay, LwF, DER++, and A-GEM.
+- Deterministic synthetic CI benchmark plus real MNIST and CIFAR-10 task streams.
+- Artifact tracking for config snapshots, metadata, JSONL events, CSV matrices,
+ checkpoints, MLflow runs, aggregate reports, and plots.
+- Python package, CLI, Dockerfile, Makefile, Ruff, pytest coverage, and GitHub
+ Actions matrix across Python 3.10, 3.11, and 3.12.
-The Continual Learning System is a comprehensive framework for developing neural networks that can learn tasks sequentially without suffering from catastrophic forgetting. This project implements several state-of-the-art techniques to mitigate forgetting in neural networks, allowing them to adapt to new tasks while retaining performance on previously learned ones.
+## Quickstart
-
-## 🌟 Key Features
-
-- **Task Sequential Learning**: Train models on a sequence of tasks without complete retraining
-- **Forgetting Mitigation**: Advanced techniques to prevent catastrophic forgetting
-- **Performance Tracking**: Comprehensive metrics to monitor how well knowledge is retained
-- **Experiment Framework**: Easily run and compare different continual learning approaches
-- **Visualization Tools**: Track and visualize forgetting metrics across sequential tasks
-
-## 🧩 Techniques Implemented
-
-### Elastic Weight Consolidation (EWC)
-
-EWC measures the importance of neural network weights for previously learned tasks and penalizes changes to important weights when learning new tasks.
-
-```python
-# Loss calculation with EWC
-loss = task_loss + lambda_ewc * ewc_loss
+```bash
+git clone https://github.com/1Utkarsh1/Continual-Learning.git
+cd Continual-Learning
+
+python3 -m venv .venv
+source .venv/bin/activate
+python -m pip install --upgrade pip
+python -m pip install -e ".[dev,experiment,report]"
+
+ruff check .
+ruff format --check .
+pytest --cov
+cl-bench run --config-name smoke --method baseline --epochs 1 --device cpu
```
-### Experience Replay
+Or use the project automation:
-This technique maintains a memory buffer of examples from previous tasks and periodically replays them during training on new tasks.
-
-```python
-# Replay during training
-combined_loss = current_task_loss + alpha * replay_loss
+```bash
+make setup
+make verify
+make benchmark
```
-### Learning without Forgetting (LwF)
+## CLI
-LwF uses knowledge distillation to preserve the model's behavior on previous tasks when learning new ones.
+Run one benchmark:
-```python
-# LwF distillation loss
-distillation_loss = KL_divergence(current_outputs, previous_outputs)
+```bash
+cl-bench run --config-name smoke --method replay --epochs 1 --device cpu
```
-### Task-specific Components
-
-For some approaches, we isolate or add task-specific parameters while sharing a common feature extraction backbone.
-
-## 🔧 Installation
+Run the headline Split CIFAR-10 suite with local MLflow tracking and plots:
```bash
-# Clone the repository
-git clone https://github.com/1Utkarsh1/continual-learning.git
-cd continual-learning
-
-# Create a virtual environment
-python -m venv venv
-source venv/bin/activate # On Windows: venv\Scripts\activate
-
-# Install dependencies
-pip install -r requirements.txt
+cl-bench suite \
+ --config-name split_cifar10_headline \
+ --methods baseline ewc replay lwf derpp agem \
+ --seeds 13 21 \
+ --tracking both \
+ --report-dir docs/assets/split_cifar10_headline \
+ --title "Split CIFAR-10 Headline Benchmark"
```
-## 📊 Usage
-
-### Quick Start
+Use Hydra/OmegaConf-style overrides for quick experiments:
```bash
-# Run baseline experiment (sequential training without any continual learning techniques)
-python src/main.py --method baseline --tasks mnist_split
-
-# Run EWC experiment
-python src/main.py --method ewc --tasks mnist_split --lambda_ewc 5000
-
-# Run Experience Replay experiment
-python src/main.py --method replay --tasks mnist_split --buffer_size 500
+cl-bench suite \
+ --config-name split_cifar10_headline \
+ --methods baseline derpp \
+ --seeds 13 \
+ --tracking json \
+ training.epochs=1 strategy.replay_buffer_size=500
```
-### Custom Task Sequences
-
-You can define your own task sequences in a YAML configuration file:
-
-```yaml
-# config/tasks/custom_sequence.yaml
-task_sequence:
- - name: "mnist_digits_0_4"
- dataset: "mnist"
- classes: [0, 1, 2, 3, 4]
-
- - name: "mnist_digits_5_9"
- dataset: "mnist"
- classes: [5, 6, 7, 8, 9]
-
- - name: "fashion_mnist"
- dataset: "fashion_mnist"
- classes: "all"
-```
-
-## 📈 Experimental Results
-
-### Comparison of Methods
-
-
-
-
- | Method |
- Average Accuracy |
- Average Forgetting |
- Training Time |
-
-
- | Naïve Fine-tuning |
- 45.2% |
- 35.8% |
- 1.0x |
-
-
- | EWC |
- 78.5% |
- 10.2% |
- 1.2x |
-
-
- | Experience Replay |
- 82.3% |
- 7.5% |
- 1.5x |
-
-
- | LwF |
- 75.7% |
- 12.8% |
- 1.3x |
-
-
-
-
-
-## 📝 Recent Experiment Results
-
-The following experiment results were obtained on March 11, 2025 using the MNIST split task sequence:
-
-**Baseline (Naïve Fine-tuning):**
-- Command: `python src/main.py --method baseline --tasks mnist_split --epochs 5`
-- Task sequence: ['mnist_0_4', 'mnist_5_9']
-- Average final accuracy: 49.74%
-- Average forgetting: 49.90%
-
-**Learning without Forgetting (LwF):**
-- Command: `python src/main.py --method lwf --tasks mnist_split --epochs 5`
-- Task sequence: ['mnist_0_4', 'mnist_5_9']
-- Average final accuracy: 49.67%
-- Average forgetting: 49.83%
-
-## 🛠️ Project Structure
+Inspect local experiment runs:
+```bash
+mlflow ui --backend-store-uri sqlite:///mlruns/mlflow.db
```
-continual_learning/
-├── src/ # Source code
-│ ├── models/ # Neural network architectures
-│ ├── data/ # Data loading and preprocessing
-│ ├── methods/ # Continual learning algorithms
-│ ├── utils/ # Utility functions
-│ └── main.py # Main entry point
-├── experiments/ # Jupyter notebooks for experiments
-├── config/ # Configuration files
-│ ├── models/ # Model configurations
-│ └── tasks/ # Task sequence definitions
-├── results/ # Saved results and visualizations
-└── docs/ # Documentation
-```
-
-## 📝 Example Experiments
-
-1. **Split MNIST**
- - Train on digits 0-4, then 5-9
- - Compare different methods' ability to remember the first task
-
-2. **Task Incremental Learning**
- - Train on MNIST → Fashion-MNIST → KMNIST
- - Measure accuracy on all previous datasets after each task
-
-3. **Class Incremental Learning**
- - Add new classes (one at a time) to a classifier
- - Test identification of all classes after each addition
-## 🔮 Roadmap
+Aggregate existing run directories or MLflow artifact exports:
-- [x] Implement baseline sequential training
-- [x] Implement Elastic Weight Consolidation (EWC)
-- [x] Implement Experience Replay
-- [x] Implement Learning without Forgetting (LwF)
-- [x] Add support for generative replay
-- [x] Implement parameter isolation methods
-- [x] Add support for continual reinforcement learning
-- [x] Develop benchmark suite for comparing methods
-
-## 🤝 Contributing
-
-Contributions are welcome! Please feel free to submit a Pull Request.
-
-1. Fork the repository
-2. Create your feature branch (`git checkout -b feature/amazing-feature`)
-3. Commit your changes (`git commit -m 'Add some amazing feature'`)
-4. Push to the branch (`git push origin feature/amazing-feature`)
-5. Open a Pull Request
+```bash
+cl-bench report \
+ --runs runs \
+ --output-dir reports/local \
+ --title "Local continual-learning report"
+```
-## 📚 References
+## Verified Headline Benchmark
-1. Kirkpatrick, J. et al. "Overcoming catastrophic forgetting in neural networks" - *Proceedings of the National Academy of Sciences* (2017)
-2. Rebuffi, S. et al. "iCaRL: Incremental Classifier and Representation Learning" - *CVPR* (2017)
-3. Li, Z. and Hoiem, D. "Learning without Forgetting" - *IEEE Transactions on Pattern Analysis and Machine Intelligence* (2018)
-4. Chaudhry, A. et al. "Efficient Lifelong Learning with A-GEM" - *ICLR* (2019)
+Local verification on 2026-05-25 used Python 3.11.15, PyTorch 2.12.0,
+torchvision 0.27.0, NumPy 2.4.6, Hydra 1.3.2, MLflow 3.12.0, Ruff 0.15.14,
+pytest 9.0.3, and Matplotlib 3.10.9.
-## 📄 License
+Command:
-This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+```bash
+cl-bench suite --config-name split_cifar10_headline --methods baseline ewc replay lwf derpp agem --seeds 13 21 --tracking both --report-dir docs/assets/split_cifar10_headline --title "Split CIFAR-10 Headline Benchmark"
+```
----
+The headline benchmark uses real CIFAR-10 images, five class-incremental tasks,
+2,500 training examples per task, 1,000 test examples per task, two seeds,
+5 epochs per task, a compact residual CIFAR ConvNet, and a 5,000-example replay
+memory budget where applicable. It is a reproducible benchmark, not a paper
+leaderboard claim.
+
+| Method | Average final accuracy | Average forgetting | Mean runtime |
+| --- | ---: | ---: | ---: |
+| DER++ | 51.15% +- 3.95% | 34.06% +- 4.74% | 578.7s |
+| replay | 41.99% +- 0.27% | 45.27% +- 1.73% | 547.4s |
+| LwF | 16.53% +- 0.13% | 76.71% +- 0.09% | 224.3s |
+| A-GEM | 14.37% +- 0.39% | 79.34% +- 0.96% | 516.3s |
+| baseline | 14.06% +- 0.10% | 79.14% +- 1.39% | 181.0s |
+| EWC | 12.12% +- 0.74% | 69.20% +- 3.02% | 223.1s |
+
+Generated report artifacts live in
+[`docs/assets/split_cifar10_headline`](docs/assets/split_cifar10_headline/README.md).
+
+## Architecture
+
+```text
+src/cl_bench/
+ cli.py # run, suite, report, config discovery, and overrides
+ config.py # TaskSpec, ExperimentConfig, BenchmarkResult
+ datasets.py # synthetic, MNIST-family, and CIFAR-10 task construction
+ experiments.py # seeded run orchestration and evaluation loop
+ metrics.py # accuracy, forgetting, transfer, and summary metrics
+ models.py # linear, MLP, small CNN, and CIFAR residual ConvNet factory
+ reporting.py # run aggregation, leaderboard CSV/JSON, and plots
+ tracking.py # JSON/JSONL/CSV artifacts and optional MLflow logging
+ strategies/ # baseline, EWC, replay, LwF, DER++, and A-GEM
+configs/
+ smoke.yaml # fast deterministic CPU benchmark
+ split_mnist_quick.yaml # bounded real MNIST suite for local CPU runs
+ split_mnist.yaml # full five-task MNIST stream
+ split_cifar10_headline.yaml # verified CIFAR-10 benchmark used in the README
+docs/
+ BENCHMARK_CARD.md # scope, metrics, limitations, reproducibility
+tests/ # unit and integration coverage
+```
-
+## Run Artifacts
+
+Each run is written to `runs/<benchmark>_<method>_<timestamp>/` and contains:
+
+- `config.yaml`: exact config snapshot.
+- `run_metadata.json`: seed, device, and git commit when available.
+- `metrics.jsonl`: event-level training and evaluation metrics.
+- `metrics.json`: final run summary.
+- `accuracy_matrix.{json,csv}` and `forgetting_matrix.{json,csv}`.
+- Optional `checkpoints/final_model.pt`.
+- Optional MLflow run entries with params, metrics, tags, and artifacts.
+
+The report command writes:
+
+- `leaderboard.csv`
+- `summary.json`
+- `README.md`
+- `leaderboard.png`
+- `retention_curves.png`
+- `accuracy_matrices.png`
+
+Generated `data/`, `runs/`, `results/`, logs, checkpoints, and NumPy arrays are
+ignored by git. Curated README assets under `docs/assets/` are intentionally kept.
+
+## Engineering Notes
+
+- EWC estimates the empirical Fisher from per-sample log-likelihood gradients and
+ normalizes by the actual number of samples used.
+- Replay uses reservoir sampling so a bounded buffer represents the full observed
+ stream instead of only the newest examples.
+- LwF stores a frozen teacher after each task and combines supervised loss with
+ temperature-scaled KL distillation.
+- DER++ stores replay logits online and combines current CE, replay CE, and
+ logit-matching losses.
+- A-GEM projects conflicting gradients against replay-memory reference gradients.
+- Best validation checkpoints are deep-copied before restoration to avoid mutable
+ `state_dict` aliasing bugs.
+- The suite/report layer separates expensive benchmark execution from cheap,
+ repeatable analysis over saved metrics.
diff --git a/configs/smoke.yaml b/configs/smoke.yaml
new file mode 100644
index 0000000..2cdc2e4
--- /dev/null
+++ b/configs/smoke.yaml
@@ -0,0 +1,33 @@
+name: smoke  # fast deterministic CPU benchmark, also run by CI
+method: baseline  # default only; `cl-bench run --method` / `suite --methods` select the method
+seed: 7
+device: cpu
+model: linear
+data_dir: data
+output_dir: runs  # default run directory; CI overrides it with --output-dir
+training:
+ epochs: 1
+ batch_size: 16
+ eval_batch_size: 64
+ learning_rate: 0.05
+ val_fraction: 0.2
+ num_workers: 0
+strategy:
+ ewc_lambda: 10.0
+ fisher_samples: 16  # presumably the sample count for the EWC Fisher estimate; confirm in strategies/
+ replay_buffer_size: 32  # replay memory budget in examples
+ replay_batch_size: 8
+ replay_loss_weight: 1.0
+ lwf_alpha: 0.5
+ lwf_temperature: 2.0  # temperature for the LwF distillation loss
+tasks:
+ - name: synthetic_0_1
+ dataset: synthetic  # deterministic image-like tensors for fast CI and unit tests
+ classes: [0, 1]
+ samples_per_class: 20
+ test_samples_per_class: 10
+ - name: synthetic_2_3
+ dataset: synthetic
+ classes: [2, 3]
+ samples_per_class: 20
+ test_samples_per_class: 10
diff --git a/configs/split_cifar10_headline.yaml b/configs/split_cifar10_headline.yaml
new file mode 100644
index 0000000..c4c1831
--- /dev/null
+++ b/configs/split_cifar10_headline.yaml
@@ -0,0 +1,56 @@
+name: split_cifar10_headline  # verified CIFAR-10 benchmark reported in the README
+method: derpp  # default only; the suite command selects methods via --methods
+seed: 13  # headline results were produced with seeds 13 and 21 via --seeds
+device: auto
+model: cifar_convnet
+data_dir: data
+output_dir: runs
+tracking:
+ mode: both  # same value the README passes as --tracking both; presumably JSON artifacts plus MLflow
+ mlflow_tracking_uri: sqlite:///mlruns/mlflow.db  # local SQLite store; matches the README's `mlflow ui` command
+ mlflow_experiment: continual-learning-bench
+training:
+ epochs: 5  # per task
+ batch_size: 128
+ eval_batch_size: 512
+ learning_rate: 0.001
+ val_fraction: 0.1
+ num_workers: 0
+ augment: true
+strategy:
+ ewc_lambda: 50.0
+ fisher_samples: 256
+ replay_buffer_size: 5000  # replay memory budget in examples
+ replay_batch_size: 256
+ replay_loss_weight: 3.0
+ lwf_alpha: 0.5
+ lwf_temperature: 2.0
+ derpp_alpha: 0.1
+ derpp_beta: 2.0
+ agem_memory_batch_size: 256
+tasks:
+ - name: cifar10_airplane_automobile
+ dataset: cifar10
+ classes: [0, 1]
+ train_limit: 2500  # training examples kept for this task
+ test_limit: 1000  # test examples kept for this task
+ - name: cifar10_bird_cat
+ dataset: cifar10
+ classes: [2, 3]
+ train_limit: 2500
+ test_limit: 1000
+ - name: cifar10_deer_dog
+ dataset: cifar10
+ classes: [4, 5]
+ train_limit: 2500
+ test_limit: 1000
+ - name: cifar10_frog_horse
+ dataset: cifar10
+ classes: [6, 7]
+ train_limit: 2500
+ test_limit: 1000
+ - name: cifar10_ship_truck
+ dataset: cifar10
+ classes: [8, 9]
+ train_limit: 2500
+ test_limit: 1000
diff --git a/configs/split_mnist.yaml b/configs/split_mnist.yaml
new file mode 100644
index 0000000..f3996f9
--- /dev/null
+++ b/configs/split_mnist.yaml
@@ -0,0 +1,38 @@
+name: split_mnist  # full five-task MNIST stream for longer local experiments
+method: baseline  # default only; the CLI --method / --methods flags select the method
+seed: 42
+device: auto
+model: small_cnn
+data_dir: data
+output_dir: runs
+training:
+ epochs: 1
+ batch_size: 64
+ eval_batch_size: 256
+ learning_rate: 0.001
+ val_fraction: 0.1
+ num_workers: 2
+strategy:
+ ewc_lambda: 100.0
+ fisher_samples: 256
+ replay_buffer_size: 500  # replay memory budget in examples
+ replay_batch_size: 64
+ replay_loss_weight: 1.0
+ lwf_alpha: 0.5
+ lwf_temperature: 2.0
+tasks:  # no train_limit/test_limit here; presumably each task uses every example of its classes
+ - name: mnist_0_1
+ dataset: mnist
+ classes: [0, 1]
+ - name: mnist_2_3
+ dataset: mnist
+ classes: [2, 3]
+ - name: mnist_4_5
+ dataset: mnist
+ classes: [4, 5]
+ - name: mnist_6_7
+ dataset: mnist
+ classes: [6, 7]
+ - name: mnist_8_9
+ dataset: mnist
+ classes: [8, 9]
diff --git a/configs/split_mnist_quick.yaml b/configs/split_mnist_quick.yaml
new file mode 100644
index 0000000..1bbc1f4
--- /dev/null
+++ b/configs/split_mnist_quick.yaml
@@ -0,0 +1,48 @@
+name: split_mnist_quick  # bounded real MNIST suite for local CPU runs
+method: baseline  # default only; the CLI --method / --methods flags select the method
+seed: 13
+device: auto
+model: small_cnn
+data_dir: data
+output_dir: runs
+training:
+ epochs: 3
+ batch_size: 128
+ eval_batch_size: 512
+ learning_rate: 0.001
+ val_fraction: 0.1
+ num_workers: 0
+strategy:
+ ewc_lambda: 75.0
+ fisher_samples: 64
+ replay_buffer_size: 400  # replay memory budget in examples
+ replay_batch_size: 64
+ replay_loss_weight: 1.0
+ lwf_alpha: 0.5
+ lwf_temperature: 2.0
+tasks:
+ - name: mnist_0_1
+ dataset: mnist
+ classes: [0, 1]
+ train_limit: 600  # bounded per-task training subset
+ test_limit: 300  # bounded per-task test subset
+ - name: mnist_2_3
+ dataset: mnist
+ classes: [2, 3]
+ train_limit: 600
+ test_limit: 300
+ - name: mnist_4_5
+ dataset: mnist
+ classes: [4, 5]
+ train_limit: 600
+ test_limit: 300
+ - name: mnist_6_7
+ dataset: mnist
+ classes: [6, 7]
+ train_limit: 600
+ test_limit: 300
+ - name: mnist_8_9
+ dataset: mnist
+ classes: [8, 9]
+ train_limit: 600
+ test_limit: 300
diff --git a/data/MNIST/raw/t10k-images-idx3-ubyte b/data/MNIST/raw/t10k-images-idx3-ubyte
deleted file mode 100644
index 1170b2c..0000000
Binary files a/data/MNIST/raw/t10k-images-idx3-ubyte and /dev/null differ
diff --git a/data/MNIST/raw/t10k-images-idx3-ubyte.gz b/data/MNIST/raw/t10k-images-idx3-ubyte.gz
deleted file mode 100644
index 5ace8ea..0000000
Binary files a/data/MNIST/raw/t10k-images-idx3-ubyte.gz and /dev/null differ
diff --git a/data/MNIST/raw/t10k-labels-idx1-ubyte b/data/MNIST/raw/t10k-labels-idx1-ubyte
deleted file mode 100644
index d1c3a97..0000000
Binary files a/data/MNIST/raw/t10k-labels-idx1-ubyte and /dev/null differ
diff --git a/data/MNIST/raw/t10k-labels-idx1-ubyte.gz b/data/MNIST/raw/t10k-labels-idx1-ubyte.gz
deleted file mode 100644
index a7e1415..0000000
Binary files a/data/MNIST/raw/t10k-labels-idx1-ubyte.gz and /dev/null differ
diff --git a/data/MNIST/raw/train-images-idx3-ubyte b/data/MNIST/raw/train-images-idx3-ubyte
deleted file mode 100644
index bbce276..0000000
Binary files a/data/MNIST/raw/train-images-idx3-ubyte and /dev/null differ
diff --git a/data/MNIST/raw/train-images-idx3-ubyte.gz b/data/MNIST/raw/train-images-idx3-ubyte.gz
deleted file mode 100644
index b50e4b6..0000000
Binary files a/data/MNIST/raw/train-images-idx3-ubyte.gz and /dev/null differ
diff --git a/data/MNIST/raw/train-labels-idx1-ubyte b/data/MNIST/raw/train-labels-idx1-ubyte
deleted file mode 100644
index d6b4c5d..0000000
Binary files a/data/MNIST/raw/train-labels-idx1-ubyte and /dev/null differ
diff --git a/data/MNIST/raw/train-labels-idx1-ubyte.gz b/data/MNIST/raw/train-labels-idx1-ubyte.gz
deleted file mode 100644
index 707a576..0000000
Binary files a/data/MNIST/raw/train-labels-idx1-ubyte.gz and /dev/null differ
diff --git a/docs/BENCHMARK_CARD.md b/docs/BENCHMARK_CARD.md
new file mode 100644
index 0000000..d17bbe6
--- /dev/null
+++ b/docs/BENCHMARK_CARD.md
@@ -0,0 +1,45 @@
+# Benchmark Card
+
+## Scope
+
+This project evaluates continual-learning strategies on task streams where a model
+sees one subset of classes at a time and is re-evaluated on all previously seen
+tasks after each task has been trained.
+
+## Implemented Methods
+
+- Baseline sequential fine-tuning.
+- Elastic Weight Consolidation with empirical Fisher estimates.
+- Reservoir replay with a bounded memory budget.
+- Learning without Forgetting with temperature-scaled distillation.
+- DER++ with online replay-logit storage.
+- A-GEM with replay-memory gradient projection.
+
+## Datasets
+
+- `synthetic`: deterministic image-like tensors for fast CI and unit tests.
+- `split_mnist_quick`: real MNIST images with bounded per-task subsets so a full
+ four-method comparison runs on CPU.
+- `split_mnist`: full five-task MNIST stream for longer local experiments.
+- `split_cifar10_headline`: real CIFAR-10 images, five class-incremental tasks,
+ two verification seeds, and a compact residual ConvNet.
+
+## Metrics
+
+- Average final accuracy: mean accuracy over seen tasks after the final task.
+- Average learning accuracy: mean diagonal accuracy after each task is learned.
+- Average forgetting: best previous accuracy minus final accuracy for prior tasks.
+- Backward transfer: final accuracy minus first-learned accuracy on prior tasks.
+
+## Reproducibility
+
+Every run writes `config.yaml`, `run_metadata.json`, event-level `metrics.jsonl`,
+CSV/JSON matrices, final `metrics.json`, and optionally MLflow params, metrics,
+tags, and artifacts. The suite command can aggregate multiple methods and seeds
+into a leaderboard plus report plots.
+
+## Limitations
+
+The Split CIFAR-10 reported result is designed for local reproducibility and
+engineering verification. It is not a leaderboard claim. Serious research
+comparisons should increase seeds, epochs, memory budgets, and dataset coverage.
diff --git a/docs/assets/split_cifar10_headline/README.md b/docs/assets/split_cifar10_headline/README.md
new file mode 100644
index 0000000..27f120e
--- /dev/null
+++ b/docs/assets/split_cifar10_headline/README.md
@@ -0,0 +1,41 @@
+# Split CIFAR-10 Headline Benchmark
+
+Benchmarks: `split_cifar10_headline`
+Runs: `12`
+Tasks per run: `5`
+
+## Leaderboard
+
+| Method | Runs | Seeds | Final accuracy | Forgetting | Backward transfer | Mean runtime |
+| --- | ---: | --- | ---: | ---: | ---: | ---: |
+| derpp | 2 | 13,21 | 51.15 +- 3.95% | 34.06 +- 4.74% | -34.06% | 578.7s |
+| replay | 2 | 13,21 | 41.99 +- 0.27% | 45.27 +- 1.73% | -45.27% | 547.4s |
+| lwf | 2 | 13,21 | 16.53 +- 0.13% | 76.71 +- 0.09% | -76.71% | 224.3s |
+| agem | 2 | 13,21 | 14.37 +- 0.39% | 79.34 +- 0.96% | -79.34% | 516.3s |
+| baseline | 2 | 13,21 | 14.06 +- 0.10% | 79.14 +- 1.39% | -79.14% | 181.0s |
+| ewc | 2 | 13,21 | 12.12 +- 0.74% | 69.20 +- 3.02% | -69.20% | 223.1s |
+
+## Plots
+
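+![Leaderboard](leaderboard.png)
+
+![Retention curves](retention_curves.png)
+
+![Accuracy matrices](accuracy_matrices.png)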
+
+## Source Runs
+
+| Method | Seed | Run directory |
+| --- | ---: | --- |
+| agem | 13 | `runs/split_cifar10_headline_agem_20260525T133221Z` |
+| agem | 21 | `runs/split_cifar10_headline_agem_20260525T141129Z` |
+| baseline | 13 | `runs/split_cifar10_headline_baseline_20260525T130248Z` |
+| baseline | 21 | `runs/split_cifar10_headline_baseline_20260525T134059Z` |
+| derpp | 13 | `runs/split_cifar10_headline_derpp_20260525T132220Z` |
+| derpp | 21 | `runs/split_cifar10_headline_derpp_20260525T140155Z` |
+| ewc | 13 | `runs/split_cifar10_headline_ewc_20260525T130553Z` |
+| ewc | 21 | `runs/split_cifar10_headline_ewc_20260525T134413Z` |
+| lwf | 13 | `runs/split_cifar10_headline_lwf_20260525T131842Z` |
+| lwf | 21 | `runs/split_cifar10_headline_lwf_20260525T135744Z` |
+| replay | 13 | `runs/split_cifar10_headline_replay_20260525T130954Z` |
+| replay | 21 | `runs/split_cifar10_headline_replay_20260525T134756Z` |
diff --git a/docs/assets/split_cifar10_headline/accuracy_matrices.png b/docs/assets/split_cifar10_headline/accuracy_matrices.png
new file mode 100644
index 0000000..2db9218
Binary files /dev/null and b/docs/assets/split_cifar10_headline/accuracy_matrices.png differ
diff --git a/docs/assets/split_cifar10_headline/leaderboard.csv b/docs/assets/split_cifar10_headline/leaderboard.csv
new file mode 100644
index 0000000..17dc817
--- /dev/null
+++ b/docs/assets/split_cifar10_headline/leaderboard.csv
@@ -0,0 +1,7 @@
+method,runs,seeds,average_final_accuracy_mean,average_final_accuracy_std,average_learning_accuracy_mean,average_forgetting_mean,average_forgetting_std,backward_transfer_mean,runtime_seconds_mean
+derpp,2,"13,21",51.150000000000006,3.9499999999999993,78.39999999999999,34.0625,4.737499999999999,-34.0625,578.7151509795076
+replay,2,"13,21",41.99000000000001,0.2699999999999996,78.21,45.275,1.7250000000000014,-45.275,547.4045557500067
+lwf,2,"13,21",16.53,0.13000000000000078,77.9,76.7125,0.08749999999999858,-76.7125,224.33408429198607
+agem,2,"13,21",14.370000000000001,0.39000000000000057,77.84,79.3375,0.9624999999999986,-79.3375,516.2666623959958
+baseline,2,"13,21",14.059999999999999,0.10000000000000053,77.37,79.13749999999999,1.3874999999999957,-79.13749999999999,180.97890222900605
+ewc,2,"13,21",12.12,0.7400000000000002,67.48000000000002,69.19999999999999,3.0249999999999986,-69.19999999999999,223.06707881249895
diff --git a/docs/assets/split_cifar10_headline/leaderboard.png b/docs/assets/split_cifar10_headline/leaderboard.png
new file mode 100644
index 0000000..49f4e7d
Binary files /dev/null and b/docs/assets/split_cifar10_headline/leaderboard.png differ
diff --git a/docs/assets/split_cifar10_headline/retention_curves.png b/docs/assets/split_cifar10_headline/retention_curves.png
new file mode 100644
index 0000000..99c0277
Binary files /dev/null and b/docs/assets/split_cifar10_headline/retention_curves.png differ
diff --git a/docs/assets/split_cifar10_headline/summary.json b/docs/assets/split_cifar10_headline/summary.json
new file mode 100644
index 0000000..f5e8421
--- /dev/null
+++ b/docs/assets/split_cifar10_headline/summary.json
@@ -0,0 +1,408 @@
+{
+ "benchmarks": [
+ "split_cifar10_headline"
+ ],
+ "generated_at_utc": "2026-05-25T14:20:17.367419+00:00",
+ "leaderboard": [
+ {
+ "average_final_accuracy_mean": 51.150000000000006,
+ "average_final_accuracy_std": 3.9499999999999993,
+ "average_forgetting_mean": 34.0625,
+ "average_forgetting_std": 4.737499999999999,
+ "average_learning_accuracy_mean": 78.39999999999999,
+ "backward_transfer_mean": -34.0625,
+ "method": "derpp",
+ "runs": 2,
+ "runtime_seconds_mean": 578.7151509795076,
+ "seeds": "13,21"
+ },
+ {
+ "average_final_accuracy_mean": 41.99000000000001,
+ "average_final_accuracy_std": 0.2699999999999996,
+ "average_forgetting_mean": 45.275,
+ "average_forgetting_std": 1.7250000000000014,
+ "average_learning_accuracy_mean": 78.21,
+ "backward_transfer_mean": -45.275,
+ "method": "replay",
+ "runs": 2,
+ "runtime_seconds_mean": 547.4045557500067,
+ "seeds": "13,21"
+ },
+ {
+ "average_final_accuracy_mean": 16.53,
+ "average_final_accuracy_std": 0.13000000000000078,
+ "average_forgetting_mean": 76.7125,
+ "average_forgetting_std": 0.08749999999999858,
+ "average_learning_accuracy_mean": 77.9,
+ "backward_transfer_mean": -76.7125,
+ "method": "lwf",
+ "runs": 2,
+ "runtime_seconds_mean": 224.33408429198607,
+ "seeds": "13,21"
+ },
+ {
+ "average_final_accuracy_mean": 14.370000000000001,
+ "average_final_accuracy_std": 0.39000000000000057,
+ "average_forgetting_mean": 79.3375,
+ "average_forgetting_std": 0.9624999999999986,
+ "average_learning_accuracy_mean": 77.84,
+ "backward_transfer_mean": -79.3375,
+ "method": "agem",
+ "runs": 2,
+ "runtime_seconds_mean": 516.2666623959958,
+ "seeds": "13,21"
+ },
+ {
+ "average_final_accuracy_mean": 14.059999999999999,
+ "average_final_accuracy_std": 0.10000000000000053,
+ "average_forgetting_mean": 79.13749999999999,
+ "average_forgetting_std": 1.3874999999999957,
+ "average_learning_accuracy_mean": 77.37,
+ "backward_transfer_mean": -79.13749999999999,
+ "method": "baseline",
+ "runs": 2,
+ "runtime_seconds_mean": 180.97890222900605,
+ "seeds": "13,21"
+ },
+ {
+ "average_final_accuracy_mean": 12.12,
+ "average_final_accuracy_std": 0.7400000000000002,
+ "average_forgetting_mean": 69.19999999999999,
+ "average_forgetting_std": 3.0249999999999986,
+ "average_learning_accuracy_mean": 67.48000000000002,
+ "backward_transfer_mean": -69.19999999999999,
+ "method": "ewc",
+ "runs": 2,
+ "runtime_seconds_mean": 223.06707881249895,
+ "seeds": "13,21"
+ }
+ ],
+ "num_runs": 12,
+ "runs": [
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "agem",
+ "run_dir": "runs/split_cifar10_headline_agem_20260525T133221Z",
+ "runtime_seconds": 505.9784695839917,
+ "seed": 13,
+ "summary": {
+ "average_final_accuracy": 14.760000000000002,
+ "average_forgetting": 80.3,
+ "average_learning_accuracy": 79.0,
+ "backward_transfer": -80.3,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 505.9784695839917,
+ "seed": 13
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "agem",
+ "run_dir": "runs/split_cifar10_headline_agem_20260525T141129Z",
+ "runtime_seconds": 526.5548552079999,
+ "seed": 21,
+ "summary": {
+ "average_final_accuracy": 13.98,
+ "average_forgetting": 78.375,
+ "average_learning_accuracy": 76.68,
+ "backward_transfer": -78.375,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 526.5548552079999,
+ "seed": 21
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "baseline",
+ "run_dir": "runs/split_cifar10_headline_baseline_20260525T130248Z",
+ "runtime_seconds": 175.88738362499862,
+ "seed": 13,
+ "summary": {
+ "average_final_accuracy": 14.16,
+ "average_forgetting": 77.75,
+ "average_learning_accuracy": 76.36,
+ "backward_transfer": -77.75,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 175.88738362499862,
+ "seed": 13
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "baseline",
+ "run_dir": "runs/split_cifar10_headline_baseline_20260525T134059Z",
+ "runtime_seconds": 186.0704208330135,
+ "seed": 21,
+ "summary": {
+ "average_final_accuracy": 13.959999999999999,
+ "average_forgetting": 80.52499999999999,
+ "average_learning_accuracy": 78.38,
+ "backward_transfer": -80.52499999999999,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 186.0704208330135,
+ "seed": 21
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "derpp",
+ "run_dir": "runs/split_cifar10_headline_derpp_20260525T132220Z",
+ "runtime_seconds": 592.4993861670082,
+ "seed": 13,
+ "summary": {
+ "average_final_accuracy": 55.1,
+ "average_forgetting": 29.325,
+ "average_learning_accuracy": 78.55999999999999,
+ "backward_transfer": -29.325,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 592.4993861670082,
+ "seed": 13
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "derpp",
+ "run_dir": "runs/split_cifar10_headline_derpp_20260525T140155Z",
+ "runtime_seconds": 564.930915792007,
+ "seed": 21,
+ "summary": {
+ "average_final_accuracy": 47.2,
+ "average_forgetting": 38.8,
+ "average_learning_accuracy": 78.24,
+ "backward_transfer": -38.8,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 564.930915792007,
+ "seed": 21
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "ewc",
+ "run_dir": "runs/split_cifar10_headline_ewc_20260525T130553Z",
+ "runtime_seconds": 231.59013450000202,
+ "seed": 13,
+ "summary": {
+ "average_final_accuracy": 12.86,
+ "average_forgetting": 72.225,
+ "average_learning_accuracy": 70.64000000000001,
+ "backward_transfer": -72.225,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 231.59013450000202,
+ "seed": 13
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "ewc",
+ "run_dir": "runs/split_cifar10_headline_ewc_20260525T134413Z",
+ "runtime_seconds": 214.54402312499587,
+ "seed": 21,
+ "summary": {
+ "average_final_accuracy": 11.379999999999999,
+ "average_forgetting": 66.175,
+ "average_learning_accuracy": 64.32000000000001,
+ "backward_transfer": -66.175,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 214.54402312499587,
+ "seed": 21
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "lwf",
+ "run_dir": "runs/split_cifar10_headline_lwf_20260525T131842Z",
+ "runtime_seconds": 210.86184020899236,
+ "seed": 13,
+ "summary": {
+ "average_final_accuracy": 16.4,
+ "average_forgetting": 76.625,
+ "average_learning_accuracy": 77.7,
+ "backward_transfer": -76.625,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 210.86184020899236,
+ "seed": 13
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "lwf",
+ "run_dir": "runs/split_cifar10_headline_lwf_20260525T135744Z",
+ "runtime_seconds": 237.80632837497978,
+ "seed": 21,
+ "summary": {
+ "average_final_accuracy": 16.66,
+ "average_forgetting": 76.8,
+ "average_learning_accuracy": 78.1,
+ "backward_transfer": -76.8,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 237.80632837497978,
+ "seed": 21
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "replay",
+ "run_dir": "runs/split_cifar10_headline_replay_20260525T130954Z",
+ "runtime_seconds": 519.1472583750146,
+ "seed": 13,
+ "summary": {
+ "average_final_accuracy": 42.260000000000005,
+ "average_forgetting": 43.55,
+ "average_learning_accuracy": 77.1,
+ "backward_transfer": -43.55,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 519.1472583750146,
+ "seed": 13
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ },
+ {
+ "benchmark": "split_cifar10_headline",
+ "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178",
+ "method": "replay",
+ "run_dir": "runs/split_cifar10_headline_replay_20260525T134756Z",
+ "runtime_seconds": 575.6618531249987,
+ "seed": 21,
+ "summary": {
+ "average_final_accuracy": 41.720000000000006,
+ "average_forgetting": 47.0,
+ "average_learning_accuracy": 79.32,
+ "backward_transfer": -47.0,
+ "model": "cifar_convnet",
+ "num_tasks": 5,
+ "replay_batch_size": 256,
+ "replay_buffer_size": 5000,
+ "runtime_seconds": 575.6618531249987,
+ "seed": 21
+ },
+ "task_names": [
+ "cifar10_airplane_automobile",
+ "cifar10_bird_cat",
+ "cifar10_deer_dog",
+ "cifar10_frog_horse",
+ "cifar10_ship_truck"
+ ]
+ }
+ ],
+ "title": "Split CIFAR-10 Headline Benchmark"
+}
diff --git a/experiments/split_mnist_comparison.ipynb b/experiments/split_mnist_comparison.ipynb
deleted file mode 100644
index 03d5c53..0000000
--- a/experiments/split_mnist_comparison.ipynb
+++ /dev/null
@@ -1,266 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Split MNIST Continual Learning Comparison\n",
- "\n",
- "This notebook demonstrates a comparison of different continual learning methods on the Split MNIST benchmark.\n",
- "\n",
- "The task sequence consists of two tasks:\n",
- "- **Task 1**: MNIST classes 0–4\n",
- "- **Task 2**: MNIST classes 5–9\n",
- "\n",
- "Methods compared:\n",
- "- **Baseline**: Naive fine-tuning (demonstrates catastrophic forgetting)\n",
- "- **EWC**: Elastic Weight Consolidation\n",
- "- **Replay**: Experience Replay\n",
- "- **LwF**: Learning without Forgetting"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import sys\n",
- "\n",
- "# Ensure the repository root is on the Python path\n",
- "repo_root = os.path.abspath(os.path.join(os.getcwd(), '..'))\n",
- "if repo_root not in sys.path:\n",
- " sys.path.insert(0, repo_root)\n",
- "\n",
- "import numpy as np\n",
- "import torch\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "from src.data.data_loader import get_task_sequence\n",
- "from src.models.model_factory import get_model\n",
- "from src.methods.baseline import BaselineLearner\n",
- "from src.methods.ewc import EWCLearner\n",
- "from src.methods.replay import ExperienceReplayLearner\n",
- "from src.methods.lwf import LwFLearner\n",
- "from src.utils.metrics import evaluate_performance, compute_forgetting\n",
- "from src.utils.visualization import plot_performance, plot_forgetting\n",
- "\n",
- "print('Imports successful')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Configuration\n",
- "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
- "BATCH_SIZE = 64\n",
- "EPOCHS = 3\n",
- "LEARNING_RATE = 0.001\n",
- "SEED = 42\n",
- "\n",
- "# Set random seed for reproducibility\n",
- "torch.manual_seed(SEED)\n",
- "np.random.seed(SEED)\n",
- "\n",
- "print(f'Using device: {DEVICE}')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define the Split MNIST task sequence\n",
- "task_sequence = [\n",
- " {'name': 'mnist_0_4', 'dataset': 'mnist', 'classes': [0, 1, 2, 3, 4]},\n",
- " {'name': 'mnist_5_9', 'dataset': 'mnist', 'classes': [5, 6, 7, 8, 9]},\n",
- "]\n",
- "\n",
- "# Load data\n",
- "task_data = get_task_sequence(task_sequence, BATCH_SIZE)\n",
- "\n",
- "# Determine input shape and number of classes\n",
- "input_shape = task_data[0]['train_loader'].dataset[0][0].shape\n",
- "all_classes = set()\n",
- "for task in task_sequence:\n",
- " all_classes.update(task['classes'])\n",
- "num_classes = len(all_classes)\n",
- "\n",
- "print(f'Input shape: {input_shape}')\n",
- "print(f'Number of classes: {num_classes}')\n",
- "print(f'Tasks: {[t[\"name\"] for t in task_sequence]}')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def run_experiment(learner, task_data, task_sequence):\n",
- " \"\"\"Train a learner sequentially on all tasks and record performance.\"\"\"\n",
- " n_tasks = len(task_sequence)\n",
- " performance_matrix = np.zeros((n_tasks, n_tasks))\n",
- "\n",
- " for task_id, task_dict in enumerate(task_data):\n",
- " train_loader = task_dict['train_loader']\n",
- " val_loader = task_dict['val_loader']\n",
- " print(f' Training on task {task_id + 1}/{n_tasks}: {task_sequence[task_id][\"name\"]}')\n",
- "\n",
- " learner.train(\n",
- " train_loader=train_loader,\n",
- " val_loader=val_loader,\n",
- " task_id=task_id,\n",
- " epochs=EPOCHS,\n",
- " eval_freq=1,\n",
- " )\n",
- "\n",
- " for eval_task_id in range(task_id + 1):\n",
- " eval_loader = task_data[eval_task_id]['test_loader']\n",
- " accuracy = learner.evaluate(eval_loader, eval_task_id)\n",
- " performance_matrix[task_id, eval_task_id] = accuracy\n",
- " print(f' Accuracy on task {eval_task_id + 1}: {accuracy:.2f}%')\n",
- "\n",
- " return performance_matrix"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "results = {}\n",
- "\n",
- "# --- Baseline ---\n",
- "print('=== Baseline (Fine-tuning) ===')\n",
- "model = get_model('simple_cnn', input_shape, num_classes).to(DEVICE)\n",
- "learner = BaselineLearner(model=model, device=DEVICE, learning_rate=LEARNING_RATE)\n",
- "results['baseline'] = run_experiment(learner, task_data, task_sequence)\n",
- "\n",
- "print()\n",
- "\n",
- "# --- EWC ---\n",
- "print('=== EWC ===')\n",
- "model = get_model('simple_cnn', input_shape, num_classes).to(DEVICE)\n",
- "learner = EWCLearner(model=model, device=DEVICE, learning_rate=LEARNING_RATE,\n",
- " lambda_ewc=5000, fisher_sample_size=200)\n",
- "results['ewc'] = run_experiment(learner, task_data, task_sequence)\n",
- "\n",
- "print()\n",
- "\n",
- "# --- Experience Replay ---\n",
- "print('=== Experience Replay ===')\n",
- "model = get_model('simple_cnn', input_shape, num_classes).to(DEVICE)\n",
- "learner = ExperienceReplayLearner(model=model, device=DEVICE, learning_rate=LEARNING_RATE,\n",
- " buffer_size=500, replay_batch_size=32)\n",
- "results['replay'] = run_experiment(learner, task_data, task_sequence)\n",
- "\n",
- "print()\n",
- "\n",
- "# --- LwF ---\n",
- "print('=== LwF ===')\n",
- "model = get_model('simple_cnn', input_shape, num_classes).to(DEVICE)\n",
- "learner = LwFLearner(model=model, device=DEVICE, learning_rate=LEARNING_RATE,\n",
- " temperature=2.0, alpha=1.0)\n",
- "results['lwf'] = run_experiment(learner, task_data, task_sequence)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "task_names = [t['name'] for t in task_sequence]\n",
- "\n",
- "# Summary table\n",
- "print('\\n=== Final Accuracy on Each Task After All Training ===')\n",
- "print(f'{\"Method\":<12}', end='')\n",
- "for name in task_names:\n",
- " print(f'{name:>16}', end='')\n",
- "print(f'{\"Average\":>12}')\n",
- "\n",
- "for method, perf in results.items():\n",
- " final_row = perf[-1, :len(task_names)]\n",
- " print(f'{method:<12}', end='')\n",
- " for acc in final_row:\n",
- " print(f'{acc:>15.2f}%', end='')\n",
- " print(f'{np.mean(final_row):>11.2f}%')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Plot performance matrices\n",
- "fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n",
- "method_names = list(results.keys())\n",
- "\n",
- "for ax, (method, perf) in zip(axes.flatten(), results.items()):\n",
- " im = ax.imshow(perf, vmin=0, vmax=100, cmap='Blues', aspect='auto')\n",
- " ax.set_title(method.upper())\n",
- " ax.set_xlabel('Task evaluated on')\n",
- " ax.set_ylabel('After training task')\n",
- " ax.set_xticks(range(len(task_names)))\n",
- " ax.set_yticks(range(len(task_names)))\n",
- " ax.set_xticklabels(task_names, rotation=30, ha='right')\n",
- " ax.set_yticklabels(task_names)\n",
- " for i in range(len(task_names)):\n",
- " for j in range(len(task_names)):\n",
- " if perf[i, j] > 0:\n",
- " ax.text(j, i, f'{perf[i, j]:.1f}', ha='center', va='center', fontsize=9)\n",
- " plt.colorbar(im, ax=ax)\n",
- "\n",
- "plt.suptitle('Accuracy (%) — Split MNIST Comparison', fontsize=14)\n",
- "plt.tight_layout()\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Average final accuracy comparison bar chart\n",
- "avg_accuracies = {\n",
- " method: np.mean(perf[-1, :len(task_names)])\n",
- " for method, perf in results.items()\n",
- "}\n",
- "\n",
- "plt.figure(figsize=(8, 5))\n",
- "bars = plt.bar(avg_accuracies.keys(), avg_accuracies.values(),\n",
- " color=['#4c72b0', '#dd8452', '#55a868', '#c44e52'])\n",
- "for bar, val in zip(bars, avg_accuracies.values()):\n",
- " plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,\n",
- " f'{val:.1f}%', ha='center', va='bottom')\n",
- "plt.ylim(0, 110)\n",
- "plt.ylabel('Average Accuracy (%)')\n",
- "plt.title('Average Final Accuracy — Split MNIST')\n",
- "plt.tight_layout()\n",
- "plt.show()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.8.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..da07a1a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,86 @@
+[build-system]  # PEP 517 backend used by the CI "python -m build" step
+requires = ["hatchling>=1.24"]
+build-backend = "hatchling.build"
+
+[project]
+name = "continual-learning-bench"
+version = "0.3.0"
+description = "A PyTorch benchmark framework for reproducible continual-learning experiments."
+readme = "README.md"
+requires-python = ">=3.10"  # floor of the CI matrix (3.10, 3.11, 3.12)
+license = { file = "License" }  # NOTE(review): path is case-sensitive on Linux; confirm the file is spelled "License"
+authors = [{ name = "Utkarsh Rajput" }]
+keywords = ["continual-learning", "pytorch", "benchmark", "machine-learning"]
+classifiers = [
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+dependencies = [  # runtime requirements installed with the base package
+ "numpy>=2.0",
+ "pyyaml>=6.0.3",
+ "torch>=2.6",
+ "torchvision>=0.21",
+]
+
+[project.optional-dependencies]  # extras; CI installs ".[dev,experiment,report]"
+dev = [  # build, test and lint tools invoked by the CI steps
+ "build>=1.2",
+ "pytest>=8.3",
+ "pytest-cov>=6.2",
+ "ruff>=0.12",
+]
+experiment = [
+ "hydra-core>=1.3.2",
+ "mlflow>=3.12,<4",  # capped below the next major version
+ "omegaconf>=2.3",
+ "pandas>=2.2",
+]
+report = [
+ "matplotlib>=3.10",
+]
+
+[project.scripts]
+cl-bench = "cl_bench.cli:main"  # console command; CI calls "cl-bench run" and "cl-bench suite"
+
+[project.urls]
+Homepage = "https://github.com/1Utkarsh1/Continual-Learning"
+Repository = "https://github.com/1Utkarsh1/Continual-Learning"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/cl_bench"]  # src layout: shipped in the wheel as top-level "cl_bench"
+
+[tool.pytest.ini_options]
+addopts = "-q"  # quiet output by default
+testpaths = ["tests"]
+
+[tool.ruff]
+line-length = 100
+target-version = "py310"  # keep in step with requires-python
+
+[tool.ruff.lint]
+select = ["B", "E", "F", "I", "SIM", "UP"]  # bugbear, pycodestyle errors, pyflakes, isort, flake8-simplify, pyupgrade
+ignore = ["E501"]  # E501 = line-too-long
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+line-ending = "auto"
+
+[tool.coverage.run]
+branch = true  # record branch coverage in addition to line coverage
+source = ["cl_bench"]  # measured by import name, not by filesystem path
+
+[tool.coverage.report]
+show_missing = true
+skip_covered = true  # leave fully covered files out of the report
+exclude_also = [  # regexes added on top of coverage's default exclusions
+ "if __name__ == .__main__.:",  # "." matches either quote character
+ "raise NotImplementedError",
+]
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index fe0b979..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,26 +0,0 @@
-# Core Deep Learning Frameworks
-torch
-torchvision
-numpy
-tqdm
-
-# Data Processing
-pandas
-scikit-learn
-pillow
-
-# Visualization
-matplotlib
-seaborn
-
-# Experiment Tracking
-tensorboard
-jupyterlab
-ipywidgets
-
-# Utilities
-pyyaml
-tabulate
-
-# Testing (Optional)
-pytest
\ No newline at end of file
diff --git a/results/baseline_mnist_split_20250311_000335/final_model.pt b/results/baseline_mnist_split_20250311_000335/final_model.pt
deleted file mode 100644
index 4569ece..0000000
Binary files a/results/baseline_mnist_split_20250311_000335/final_model.pt and /dev/null differ
diff --git a/results/baseline_mnist_split_20250311_000335/forgetting.png b/results/baseline_mnist_split_20250311_000335/forgetting.png
deleted file mode 100644
index 2576af8..0000000
Binary files a/results/baseline_mnist_split_20250311_000335/forgetting.png and /dev/null differ
diff --git a/results/baseline_mnist_split_20250311_000335/forgetting_matrix.npy b/results/baseline_mnist_split_20250311_000335/forgetting_matrix.npy
deleted file mode 100644
index 8cff948..0000000
Binary files a/results/baseline_mnist_split_20250311_000335/forgetting_matrix.npy and /dev/null differ
diff --git a/results/baseline_mnist_split_20250311_000335/performance.png b/results/baseline_mnist_split_20250311_000335/performance.png
deleted file mode 100644
index 58de4f3..0000000
Binary files a/results/baseline_mnist_split_20250311_000335/performance.png and /dev/null differ
diff --git a/results/baseline_mnist_split_20250311_000335/performance_matrix.npy b/results/baseline_mnist_split_20250311_000335/performance_matrix.npy
deleted file mode 100644
index 3977f84..0000000
Binary files a/results/baseline_mnist_split_20250311_000335/performance_matrix.npy and /dev/null differ
diff --git a/results/logs/continual_learning_20250310_232605.log b/results/logs/continual_learning_20250310_232605.log
deleted file mode 100644
index 003f125..0000000
--- a/results/logs/continual_learning_20250310_232605.log
+++ /dev/null
@@ -1,40 +0,0 @@
-2025-03-10 23:26:05,445 - __main__ - INFO - Starting Continual Learning experiment
-2025-03-10 23:26:05,445 - __main__ - INFO - Arguments: Namespace(method='baseline', tasks='mnist_split', model='simple_cnn', epochs=5, batch_size=64, learning_rate=0.001, lambda_ewc=5000, fisher_sample_size=200, buffer_size=500, replay_batch_size=32, temperature=2.0, alpha=1.0, seed=42, device=None, eval_freq=1, save_dir='results')
-2025-03-10 23:26:05,448 - __main__ - INFO - Using device: cpu
-2025-03-10 23:26:05,449 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9']
-2025-03-10 23:26:05,449 - src.data.data_loader - INFO - Loading task mnist_0_4 (dataset: mnist, classes: [0, 1, 2, 3, 4])
-2025-03-10 23:26:33,270 - src.data.data_loader - INFO - Loaded task mnist_0_4 with 27537 training, 3059 validation, and 5139 test samples
-2025-03-10 23:26:33,271 - src.data.data_loader - INFO - Loading task mnist_5_9 (dataset: mnist, classes: [5, 6, 7, 8, 9])
-2025-03-10 23:26:33,316 - src.data.data_loader - INFO - Loaded task mnist_5_9 with 26464 training, 2940 validation, and 4861 test samples
-2025-03-10 23:26:33,325 - src.models.model_factory - INFO - Created simple_cnn model with input shape torch.Size([1, 28, 28]) and 5 output classes
-2025-03-10 23:26:33,326 - __main__ - INFO - Starting training on task 1/2: mnist_0_4
-2025-03-10 23:27:31,606 - src.methods.baseline - INFO - Task 1 - Epoch 1/5: Train Loss: 0.1076, Train Acc: 96.27%, Val Loss: 0.0239, Val Acc: 99.15%
-2025-03-10 23:28:29,398 - src.methods.baseline - INFO - Task 1 - Epoch 2/5: Train Loss: 0.0236, Train Acc: 99.28%, Val Loss: 0.0255, Val Acc: 99.28%
-2025-03-10 23:29:27,271 - src.methods.baseline - INFO - Task 1 - Epoch 3/5: Train Loss: 0.0148, Train Acc: 99.53%, Val Loss: 0.0198, Val Acc: 99.44%
-2025-03-10 23:30:24,558 - src.methods.baseline - INFO - Task 1 - Epoch 4/5: Train Loss: 0.0119, Train Acc: 99.63%, Val Loss: 0.0255, Val Acc: 99.18%
-2025-03-10 23:31:22,322 - src.methods.baseline - INFO - Task 1 - Epoch 5/5: Train Loss: 0.0083, Train Acc: 99.75%, Val Loss: 0.0211, Val Acc: 99.58%
-2025-03-10 23:31:22,322 - src.methods.baseline - INFO - Loaded best model for task 1 with validation loss: 0.0198
-2025-03-10 23:31:38,202 - __main__ - INFO - After task 1, accuracy on task 1: 99.86%
-2025-03-10 23:31:38,202 - __main__ - INFO - Starting training on task 2/2: mnist_5_9
-2025-03-10 23:31:51,157 - __main__ - ERROR - Experiment failed with error: Target 8 is out of bounds.
-Traceback (most recent call last):
- File "C:\Users\kumar\Desktop\github\continual_learning\src\main.py", line 302, in main
- run_continual_learning(args, logger)
- File "C:\Users\kumar\Desktop\github\continual_learning\src\main.py", line 245, in run_continual_learning
- learner.train(
- File "C:\Users\kumar\Desktop\github\continual_learning\src\methods\baseline.py", line 92, in train
- loss = self.criterion(outputs, targets)
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- File "C:\Users\kumar\AppData\Roaming\Python\Python312\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
- return self._call_impl(*args, **kwargs)
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- File "C:\Users\kumar\AppData\Roaming\Python\Python312\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
- return forward_call(*args, **kwargs)
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- File "C:\Users\kumar\AppData\Roaming\Python\Python312\site-packages\torch\nn\modules\loss.py", line 1188, in forward
- return F.cross_entropy(input, target, weight=self.weight,
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- File "C:\Users\kumar\AppData\Roaming\Python\Python312\site-packages\torch\nn\functional.py", line 3104, in cross_entropy
- return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-IndexError: Target 8 is out of bounds.
diff --git a/results/logs/continual_learning_20250311_000335.log b/results/logs/continual_learning_20250311_000335.log
deleted file mode 100644
index 6f9f3b0..0000000
--- a/results/logs/continual_learning_20250311_000335.log
+++ /dev/null
@@ -1,35 +0,0 @@
-2025-03-11 00:03:35,631 - __main__ - INFO - Starting Continual Learning experiment
-2025-03-11 00:03:35,631 - __main__ - INFO - Arguments: Namespace(method='baseline', tasks='mnist_split', model='simple_cnn', epochs=5, batch_size=64, learning_rate=0.001, lambda_ewc=5000, fisher_sample_size=200, buffer_size=500, replay_batch_size=32, temperature=2.0, alpha=1.0, seed=42, device=None, eval_freq=1, save_dir='results')
-2025-03-11 00:03:35,634 - __main__ - INFO - Using device: cpu
-2025-03-11 00:03:35,634 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9']
-2025-03-11 00:03:35,634 - src.data.data_loader - INFO - Loading task mnist_0_4 (dataset: mnist, classes: [0, 1, 2, 3, 4])
-2025-03-11 00:03:35,697 - src.data.data_loader - INFO - Loaded task mnist_0_4 with 27537 training, 3059 validation, and 5139 test samples
-2025-03-11 00:03:35,698 - src.data.data_loader - INFO - Loading task mnist_5_9 (dataset: mnist, classes: [5, 6, 7, 8, 9])
-2025-03-11 00:03:35,749 - src.data.data_loader - INFO - Loaded task mnist_5_9 with 26464 training, 2940 validation, and 4861 test samples
-2025-03-11 00:03:35,755 - src.models.model_factory - INFO - Created simple_cnn model with input shape torch.Size([1, 28, 28]) and 10 output classes
-2025-03-11 00:03:35,756 - __main__ - INFO - Starting training on task 1/2: mnist_0_4
-2025-03-11 00:04:32,911 - src.methods.baseline - INFO - Task 1 - Epoch 1/5: Train Loss: 0.1474, Train Acc: 94.77%, Val Loss: 0.0358, Val Acc: 98.79%
-2025-03-11 00:05:29,314 - src.methods.baseline - INFO - Task 1 - Epoch 2/5: Train Loss: 0.0251, Train Acc: 99.26%, Val Loss: 0.0307, Val Acc: 99.02%
-2025-03-11 00:06:25,855 - src.methods.baseline - INFO - Task 1 - Epoch 3/5: Train Loss: 0.0152, Train Acc: 99.55%, Val Loss: 0.0241, Val Acc: 99.64%
-2025-03-11 00:07:22,073 - src.methods.baseline - INFO - Task 1 - Epoch 4/5: Train Loss: 0.0118, Train Acc: 99.67%, Val Loss: 0.0301, Val Acc: 99.28%
-2025-03-11 00:08:18,914 - src.methods.baseline - INFO - Task 1 - Epoch 5/5: Train Loss: 0.0099, Train Acc: 99.71%, Val Loss: 0.0186, Val Acc: 99.51%
-2025-03-11 00:08:18,915 - src.methods.baseline - INFO - Loaded best model for task 1 with validation loss: 0.0186
-2025-03-11 00:08:33,903 - __main__ - INFO - After task 1, accuracy on task 1: 99.81%
-2025-03-11 00:08:33,904 - __main__ - INFO - Starting training on task 2/2: mnist_5_9
-2025-03-11 00:09:28,940 - src.methods.baseline - INFO - Task 2 - Epoch 1/5: Train Loss: 0.4041, Train Acc: 90.08%, Val Loss: 0.0462, Val Acc: 98.33%
-2025-03-11 00:10:24,280 - src.methods.baseline - INFO - Task 2 - Epoch 2/5: Train Loss: 0.0377, Train Acc: 98.87%, Val Loss: 0.0326, Val Acc: 98.88%
-2025-03-11 00:11:19,652 - src.methods.baseline - INFO - Task 2 - Epoch 3/5: Train Loss: 0.0294, Train Acc: 99.04%, Val Loss: 0.0263, Val Acc: 99.39%
-2025-03-11 00:12:14,864 - src.methods.baseline - INFO - Task 2 - Epoch 4/5: Train Loss: 0.0178, Train Acc: 99.44%, Val Loss: 0.0326, Val Acc: 99.18%
-2025-03-11 00:13:10,415 - src.methods.baseline - INFO - Task 2 - Epoch 5/5: Train Loss: 0.0166, Train Acc: 99.49%, Val Loss: 0.0328, Val Acc: 99.15%
-2025-03-11 00:13:10,416 - src.methods.baseline - INFO - Loaded best model for task 2 with validation loss: 0.0263
-2025-03-11 00:13:25,385 - __main__ - INFO - After task 2, accuracy on task 1: 0.00%
-2025-03-11 00:13:40,495 - __main__ - INFO - After task 2, accuracy on task 2: 99.49%
-2025-03-11 00:13:40,507 - src.methods.baseline - INFO - Model saved to C:\Users\kumar\Desktop\github\continual_learning\results\baseline_mnist_split_20250311_000335\final_model.pt
-2025-03-11 00:13:41,358 - __main__ - INFO -
-Experiment summary:
-2025-03-11 00:13:41,358 - __main__ - INFO - Method: baseline
-2025-03-11 00:13:41,359 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9']
-2025-03-11 00:13:41,359 - __main__ - INFO - Average final accuracy: 49.74%
-2025-03-11 00:13:41,359 - __main__ - INFO - Average forgetting: 49.90%
-2025-03-11 00:13:41,359 - __main__ - INFO - Results saved to: C:\Users\kumar\Desktop\github\continual_learning\results\baseline_mnist_split_20250311_000335
-2025-03-11 00:13:41,380 - __main__ - INFO - Experiment completed successfully
diff --git a/results/logs/continual_learning_20250311_001410.log b/results/logs/continual_learning_20250311_001410.log
deleted file mode 100644
index bb736b1..0000000
--- a/results/logs/continual_learning_20250311_001410.log
+++ /dev/null
@@ -1,23 +0,0 @@
-2025-03-11 00:14:10,411 - __main__ - INFO - Starting Continual Learning experiment
-2025-03-11 00:14:10,411 - __main__ - INFO - Arguments: Namespace(method='lwf', tasks='mnist_split', model='simple_cnn', epochs=5, batch_size=64, learning_rate=0.001, lambda_ewc=5000, fisher_sample_size=200, buffer_size=500, replay_batch_size=32, temperature=2.0, alpha=1.0, seed=42, device=None, eval_freq=1, save_dir='results')
-2025-03-11 00:14:10,414 - __main__ - INFO - Using device: cpu
-2025-03-11 00:14:10,415 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9']
-2025-03-11 00:14:10,415 - src.data.data_loader - INFO - Loading task mnist_0_4 (dataset: mnist, classes: [0, 1, 2, 3, 4])
-2025-03-11 00:14:10,483 - src.data.data_loader - INFO - Loaded task mnist_0_4 with 27537 training, 3059 validation, and 5139 test samples
-2025-03-11 00:14:10,484 - src.data.data_loader - INFO - Loading task mnist_5_9 (dataset: mnist, classes: [5, 6, 7, 8, 9])
-2025-03-11 00:14:10,535 - src.data.data_loader - INFO - Loaded task mnist_5_9 with 26464 training, 2940 validation, and 4861 test samples
-2025-03-11 00:14:10,540 - src.models.model_factory - INFO - Created simple_cnn model with input shape torch.Size([1, 28, 28]) and 10 output classes
-2025-03-11 00:14:10,541 - __main__ - INFO - Starting training on task 1/2: mnist_0_4
-2025-03-11 00:17:51,424 - __main__ - INFO - After task 1, accuracy on task 1: 99.67%
-2025-03-11 00:17:51,424 - __main__ - INFO - Starting training on task 2/2: mnist_5_9
-2025-03-11 00:22:30,045 - __main__ - INFO - After task 2, accuracy on task 1: 0.00%
-2025-03-11 00:22:45,006 - __main__ - INFO - After task 2, accuracy on task 2: 99.34%
-2025-03-11 00:22:45,013 - src.methods.baseline - INFO - Model saved to C:\Users\kumar\Desktop\github\continual_learning\results\lwf_mnist_split_20250311_001410\final_model.pt
-2025-03-11 00:22:45,578 - __main__ - INFO -
-Experiment summary:
-2025-03-11 00:22:45,578 - __main__ - INFO - Method: lwf
-2025-03-11 00:22:45,578 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9']
-2025-03-11 00:22:45,579 - __main__ - INFO - Average final accuracy: 49.67%
-2025-03-11 00:22:45,579 - __main__ - INFO - Average forgetting: 49.83%
-2025-03-11 00:22:45,579 - __main__ - INFO - Results saved to: C:\Users\kumar\Desktop\github\continual_learning\results\lwf_mnist_split_20250311_001410
-2025-03-11 00:22:45,594 - __main__ - INFO - Experiment completed successfully
diff --git a/results/lwf_mnist_split_20250311_001410/final_model.pt b/results/lwf_mnist_split_20250311_001410/final_model.pt
deleted file mode 100644
index c10c8d1..0000000
Binary files a/results/lwf_mnist_split_20250311_001410/final_model.pt and /dev/null differ
diff --git a/results/lwf_mnist_split_20250311_001410/forgetting.png b/results/lwf_mnist_split_20250311_001410/forgetting.png
deleted file mode 100644
index 33065df..0000000
Binary files a/results/lwf_mnist_split_20250311_001410/forgetting.png and /dev/null differ
diff --git a/results/lwf_mnist_split_20250311_001410/forgetting_matrix.npy b/results/lwf_mnist_split_20250311_001410/forgetting_matrix.npy
deleted file mode 100644
index e90967a..0000000
Binary files a/results/lwf_mnist_split_20250311_001410/forgetting_matrix.npy and /dev/null differ
diff --git a/results/lwf_mnist_split_20250311_001410/performance.png b/results/lwf_mnist_split_20250311_001410/performance.png
deleted file mode 100644
index 54e0635..0000000
Binary files a/results/lwf_mnist_split_20250311_001410/performance.png and /dev/null differ
diff --git a/results/lwf_mnist_split_20250311_001410/performance_matrix.npy b/results/lwf_mnist_split_20250311_001410/performance_matrix.npy
deleted file mode 100644
index 7b431a9..0000000
Binary files a/results/lwf_mnist_split_20250311_001410/performance_matrix.npy and /dev/null differ
diff --git a/src/__init__.py b/src/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/cl_bench/__init__.py b/src/cl_bench/__init__.py
new file mode 100644
index 0000000..3cba1fc
--- /dev/null
+++ b/src/cl_bench/__init__.py
@@ -0,0 +1,6 @@
+"""Continual-learning benchmark framework."""
+
+from cl_bench.config import BenchmarkResult, ExperimentConfig, TaskSpec
+
+# Names re-exported at package level; all three are imported from cl_bench.config above.
+__all__ = ["BenchmarkResult", "ExperimentConfig", "TaskSpec"]
+# Package version string exposed as cl_bench.__version__.
+__version__ = "0.3.0"
diff --git a/src/cl_bench/__main__.py b/src/cl_bench/__main__.py
new file mode 100644
index 0000000..1e145c0
--- /dev/null
+++ b/src/cl_bench/__main__.py
@@ -0,0 +1,4 @@
+from cl_bench.cli import main
+
+# Executed for `python -m cl_bench`: main()'s integer result becomes the process exit status.
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/src/cl_bench/cli.py b/src/cl_bench/cli.py
new file mode 100644
index 0000000..b4da137
--- /dev/null
+++ b/src/cl_bench/cli.py
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import argparse
+from collections.abc import Sequence
+from dataclasses import replace
+from pathlib import Path
+
+from cl_bench.config import ExperimentConfig, load_config_with_overrides
+from cl_bench.experiments import run_experiment
+from cl_bench.reporting import collect_runs, write_report
+from cl_bench.tracking import MLflowRunLogger
+
+METHODS = ("baseline", "ewc", "replay", "lwf", "derpp", "agem")
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+ """Run the ``cl-bench`` command-line interface.
+
+ Parses ``argv`` with the parser from ``build_parser`` and dispatches on the
+ selected sub-command: ``list-configs``, ``run``, ``suite`` or ``report``.
+
+ Args:
+ argv: Argument list passed unchanged to ``parse_args``.
+
+ Returns:
+ ``0`` once a recognised sub-command has finished; ``1`` after printing
+ the help text when no sub-command was selected.
+ """
+ parser = build_parser()
+ args = parser.parse_args(argv)
+
+ # list-configs: only ./configs below the current working directory is scanned.
+ if args.command == "list-configs":
+ for path in sorted((Path.cwd() / "configs").glob("*.yaml")):
+ print(path)
+ return 0
+
+ # run: a single experiment; explicit CLI flags are applied on top of the loaded config.
+ if args.command == "run":
+ config = apply_cli_overrides(load_cli_config(args), args)
+ result = run_experiment(config)
+ print(f"Run directory: {result.run_dir}")
+ print(f"Metrics: {result.metrics_path}")
+ print(f"Average final accuracy: {result.summary['average_final_accuracy']:.2f}%")
+ print(f"Average forgetting: {result.summary['average_forgetting']:.2f}%")
+ return 0
+
+ # suite: one run per (seed, method) pair, all derived from the same base config.
+ if args.command == "suite":
+ base_config = apply_cli_overrides(load_cli_config(args), args)
+ # Without --seeds, the single seed from the config is used.
+ seeds = args.seeds or [base_config.seed]
+ run_dirs: list[Path] = []
+
+ for seed in seeds:
+ for method in args.methods:
+ config = replace(base_config, method=method, seed=seed)
+ result = run_experiment(config)
+ run_dirs.append(result.run_dir)
+ print(
+ f"{method} seed={seed}: "
+ f"{result.summary['average_final_accuracy']:.2f}% final accuracy, "
+ f"{result.summary['average_forgetting']:.2f}% forgetting "
+ f"({result.run_dir})"
+ )
+
+ # The aggregated report is only built when --report-dir was given.
+ if args.report_dir:
+ records = collect_runs(run_dirs)
+ report = write_report(
+ records,
+ output_dir=args.report_dir,
+ title=args.title or f"{base_config.name} benchmark report",
+ make_plots=not args.no_plots,
+ )
+ # No-op unless the config's tracking mode is "mlflow" or "both".
+ log_suite_report_to_mlflow(base_config, report.report_dir, len(records))
+ print(f"Report directory: {report.report_dir}")
+ print(f"Leaderboard: {report.leaderboard_csv}")
+ return 0
+
+ # report: aggregate runs that already exist on disk; nothing is trained.
+ if args.command == "report":
+ records = collect_runs(args.runs)
+ report = write_report(
+ records,
+ output_dir=args.output_dir,
+ title=args.title,
+ make_plots=not args.no_plots,
+ )
+ print(f"Report directory: {report.report_dir}")
+ print(f"Leaderboard: {report.leaderboard_csv}")
+ if report.plots:
+ print("Plots:")
+ for plot in report.plots:
+ print(f" {plot}")
+ return 0
+
+ # No sub-command selected: show usage and signal failure.
+ parser.print_help()
+ return 1
+
+
+def build_parser() -> argparse.ArgumentParser:
+    """Assemble the ``cl-bench`` parser with its four sub-commands.
+
+    Returns:
+        A parser that stores the chosen sub-command name (``run``, ``suite``,
+        ``report`` or ``list-configs``) in the ``command`` attribute.
+    """
+    override_help = "Hydra/OmegaConf key=value overrides."
+    root = argparse.ArgumentParser(
+        prog="cl-bench",
+        description="Run reproducible continual-learning benchmarks.",
+    )
+    commands = root.add_subparsers(dest="command")
+
+    # run: one config, one method.
+    run_cmd = commands.add_parser("run", help="Run a benchmark from a YAML config.")
+    add_config_arguments(run_cmd)
+    run_cmd.add_argument("--method", choices=METHODS)
+    add_runtime_overrides(run_cmd)
+    run_cmd.add_argument("overrides", nargs="*", help=override_help)
+
+    # suite: the same config swept over several methods and seeds.
+    suite_cmd = commands.add_parser(
+        "suite",
+        help="Run multiple methods/seeds from one config and optionally build a report.",
+    )
+    add_config_arguments(suite_cmd)
+    suite_cmd.add_argument("--methods", nargs="+", choices=METHODS, default=list(METHODS))
+    suite_cmd.add_argument("--seeds", nargs="+", type=int)
+    for plain_flag in ("--report-dir", "--title"):
+        suite_cmd.add_argument(plain_flag)
+    suite_cmd.add_argument("--no-plots", action="store_true")
+    add_runtime_overrides(suite_cmd)
+    suite_cmd.add_argument("overrides", nargs="*", help=override_help)
+
+    # report: aggregate runs that already exist on disk.
+    report_cmd = commands.add_parser(
+        "report",
+        help="Aggregate existing run directories or metrics.json files into a report.",
+    )
+    report_cmd.add_argument("--runs", nargs="+", required=True)
+    report_cmd.add_argument("--output-dir", required=True)
+    report_cmd.add_argument("--title", default="Continual-learning benchmark report")
+    report_cmd.add_argument("--no-plots", action="store_true")
+
+    commands.add_parser("list-configs", help="List YAML configs in ./configs.")
+    return root
+
+
+def add_config_arguments(parser: argparse.ArgumentParser) -> None:
+    """Register ``--config`` and ``--config-name``; exactly one of them is required."""
+    selector = parser.add_mutually_exclusive_group(required=True)
+    help_by_flag = {
+        "--config": "Config path or config name.",
+        "--config-name": "Hydra-style config name from ./configs.",
+    }
+    for flag, help_text in help_by_flag.items():
+        selector.add_argument(flag, help=help_text)
+
+
+def add_runtime_overrides(parser: argparse.ArgumentParser) -> None:
+    """Register the optional flags that override individual config fields.
+
+    Every value flag defaults to ``None`` so that "not given" can be told apart
+    from an explicit value; ``--save-checkpoint`` is a plain on/off switch.
+    """
+    option_table = (
+        ("--model", {"choices": ["linear", "mlp", "small_cnn", "cnn"]}),
+        ("--epochs", {"type": int}),
+        ("--seed", {"type": int}),
+        ("--device", {}),
+        ("--output-dir", {}),
+        ("--data-dir", {}),
+        ("--batch-size", {"type": int}),
+        ("--eval-batch-size", {"type": int}),
+        ("--learning-rate", {"type": float}),
+        ("--tracking", {"choices": ["json", "mlflow", "both"]}),
+        ("--save-checkpoint", {"action": "store_true"}),
+    )
+    for flag, options in option_table:
+        parser.add_argument(flag, **options)
+
+
+def apply_cli_overrides(config: ExperimentConfig, args: argparse.Namespace) -> ExperimentConfig:
+    """Return a copy of ``config`` with every explicitly given CLI flag applied.
+
+    A flag counts as given when its attribute exists on ``args`` and is not
+    ``None``; ``save_checkpoint`` is only ever switched on, never off.
+    """
+    overridable = (
+        "method",
+        "model",
+        "epochs",
+        "seed",
+        "device",
+        "output_dir",
+        "data_dir",
+        "batch_size",
+        "eval_batch_size",
+        "learning_rate",
+        "tracking",
+    )
+    # Attributes can be missing altogether (e.g. "method" on the suite parser).
+    candidates = ((field, getattr(args, field, None)) for field in overridable)
+    changes = {field: value for field, value in candidates if value is not None}
+
+    if getattr(args, "save_checkpoint", False):
+        changes["save_checkpoint"] = True
+
+    return replace(config, **changes)
+
+
+def load_cli_config(args: argparse.Namespace) -> ExperimentConfig:
+    """Load the config chosen on the command line together with its positional overrides.
+
+    ``config_name`` takes precedence over ``config`` when it is set; a missing
+    or empty ``overrides`` attribute means "no overrides".
+    """
+    chosen = args.config_name if args.config_name else args.config
+    dotlist = getattr(args, "overrides", None)
+    if not dotlist:
+        dotlist = []
+    return load_config_with_overrides(chosen, dotlist)
+
+
+def log_suite_report_to_mlflow(config: ExperimentConfig, report_dir: Path, run_count: int) -> None:
+    """Upload a finished suite report to MLflow as its own run.
+
+    Does nothing unless the config's tracking mode is ``mlflow`` or ``both``
+    (compared case-insensitively).
+    """
+    wants_mlflow = config.tracking.lower() in {"mlflow", "both"}
+    if not wants_mlflow:
+        return
+
+    report_params = {
+        "benchmark": config.name,
+        "report_dir": str(report_dir),
+        "run_count": run_count,
+        "artifact_type": "suite_report",
+    }
+    suite_run = MLflowRunLogger(
+        tracking_uri=config.mlflow_tracking_uri,
+        experiment_name=config.mlflow_experiment,
+        run_name=f"{config.name}_suite_report",
+        enabled=True,
+    )
+    with suite_run as logger:
+        logger.log_params(report_params)
+        logger.log_artifacts(report_dir)
diff --git a/src/cl_bench/config.py b/src/cl_bench/config.py
new file mode 100644
index 0000000..6dc5d3a
--- /dev/null
+++ b/src/cl_bench/config.py
@@ -0,0 +1,222 @@
+from __future__ import annotations
+
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+
+@dataclass
+class TaskSpec:
+    """One task of a continual-learning benchmark, as declared in a config."""
+
+    name: str
+    dataset: str
+    # Explicit class ids, or the literal string "all".
+    classes: list[int] | str
+    samples_per_class: int | None = None
+    test_samples_per_class: int | None = None
+    train_limit: int | None = None
+    test_limit: int | None = None
+
+    @classmethod
+    def from_dict(cls, raw: dict[str, Any]) -> TaskSpec:
+        """Build a task from a raw mapping.
+
+        ``name``, ``dataset`` and ``classes`` are required keys. The remaining
+        fields become ``None`` when absent. Class ids are coerced to ``int``
+        unless ``classes`` is the string ``"all"``.
+        """
+        declared = raw["classes"]
+        selected = declared if declared == "all" else [int(label) for label in declared]
+        task_name = str(raw["name"])
+        dataset_name = str(raw["dataset"])
+        optional_keys = (
+            "samples_per_class",
+            "test_samples_per_class",
+            "train_limit",
+            "test_limit",
+        )
+        limits = {key: _optional_int(raw.get(key)) for key in optional_keys}
+        return cls(name=task_name, dataset=dataset_name, classes=selected, **limits)
+
+
+@dataclass
+class ExperimentConfig:
+    """Runtime configuration for a benchmark run.
+
+    ``from_dict`` accepts either a flat mapping or one that nests values under
+    the ``training``, ``strategy`` and ``tracking`` sections.
+    """
+
+    name: str
+    method: str
+    tasks: list[TaskSpec]
+    seed: int = 42
+    # "auto" is the default device selector.
+    device: str = "auto"
+    model: str = "mlp"
+    data_dir: str = "data"
+    output_dir: str = "runs"
+    # Tracking mode; "mlflow" and "both" enable MLflow logging.
+    tracking: str = "json"
+    mlflow_tracking_uri: str = "sqlite:///mlruns/mlflow.db"
+    mlflow_experiment: str = "continual-learning-bench"
+    epochs: int = 1
+    batch_size: int = 64
+    eval_batch_size: int = 256
+    learning_rate: float = 1e-3
+    val_fraction: float = 0.1
+    num_workers: int = 0
+    augment: bool = True
+    ewc_lambda: float = 50.0
+    fisher_samples: int = 128
+    replay_buffer_size: int = 512
+    replay_batch_size: int = 32
+    replay_loss_weight: float = 1.0
+    lwf_alpha: float = 0.5
+    lwf_temperature: float = 2.0
+    derpp_alpha: float = 0.5
+    derpp_beta: float = 1.0
+    agem_memory_batch_size: int = 64
+    save_checkpoint: bool = False
+
+    @classmethod
+    def from_dict(cls, raw: dict[str, Any]) -> ExperimentConfig:
+        """Build a config from a parsed YAML mapping.
+
+        For each setting the nested section wins over the same key at the top
+        level, which in turn wins over the built-in default.
+
+        Args:
+            raw: Mapping with a required ``tasks`` list. ``tracking`` may be a
+                mapping (``mode``, ``mlflow_tracking_uri``, ``mlflow_experiment``)
+                or a bare mode string such as ``mlflow``.
+
+        Returns:
+            The populated ``ExperimentConfig``.
+
+        Raises:
+            KeyError: If ``tasks`` is missing.
+        """
+        # A section written as a bare key (``training:``) loads as None; treat it as empty.
+        training = raw.get("training") or {}
+        strategy = raw.get("strategy") or {}
+        tracking = raw.get("tracking")
+        if tracking is None:
+            tracking = {}
+        elif not isinstance(tracking, dict):
+            # Shorthand form: ``tracking: mlflow``.
+            tracking = {"mode": tracking}
+
+        def pick(section: dict[str, Any], key: str, default: Any) -> Any:
+            # Nested section first, then the flat top-level key, then the default.
+            return section.get(key, raw.get(key, default))
+
+        tasks = [TaskSpec.from_dict(task) for task in raw["tasks"]]
+        return cls(
+            name=str(raw.get("name", "continual_learning")),
+            method=str(raw.get("method", "baseline")),
+            tasks=tasks,
+            seed=int(raw.get("seed", 42)),
+            device=str(raw.get("device", "auto")),
+            model=str(raw.get("model", "mlp")),
+            data_dir=str(raw.get("data_dir", "data")),
+            output_dir=str(raw.get("output_dir", "runs")),
+            # Do not fall back to raw["tracking"] here: when the section is a
+            # mapping without "mode" that would stringify the whole mapping.
+            tracking=str(tracking.get("mode") or "json"),
+            mlflow_tracking_uri=str(
+                pick(tracking, "mlflow_tracking_uri", "sqlite:///mlruns/mlflow.db")
+            ),
+            mlflow_experiment=str(pick(tracking, "mlflow_experiment", "continual-learning-bench")),
+            epochs=int(pick(training, "epochs", 1)),
+            batch_size=int(pick(training, "batch_size", 64)),
+            eval_batch_size=int(pick(training, "eval_batch_size", 256)),
+            learning_rate=float(pick(training, "learning_rate", 1e-3)),
+            val_fraction=float(pick(training, "val_fraction", 0.1)),
+            num_workers=int(pick(training, "num_workers", 0)),
+            augment=bool(pick(training, "augment", True)),
+            ewc_lambda=float(pick(strategy, "ewc_lambda", 50.0)),
+            fisher_samples=int(pick(strategy, "fisher_samples", 128)),
+            replay_buffer_size=int(pick(strategy, "replay_buffer_size", 512)),
+            replay_batch_size=int(pick(strategy, "replay_batch_size", 32)),
+            replay_loss_weight=float(pick(strategy, "replay_loss_weight", 1.0)),
+            lwf_alpha=float(pick(strategy, "lwf_alpha", 0.5)),
+            lwf_temperature=float(pick(strategy, "lwf_temperature", 2.0)),
+            derpp_alpha=float(pick(strategy, "derpp_alpha", 0.5)),
+            derpp_beta=float(pick(strategy, "derpp_beta", 1.0)),
+            agem_memory_batch_size=int(pick(strategy, "agem_memory_batch_size", 64)),
+            save_checkpoint=bool(raw.get("save_checkpoint", False)),
+        )
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return the config, including its tasks, as plain nested dicts and lists."""
+        return asdict(self)
+
+
+@dataclass
+class BenchmarkResult:
+ """Summary returned by an experiment run."""
+
+ # Directory that holds the files written for this run.
+ run_dir: Path
+ method: str
+ # Task names in the order the tasks were trained.
+ task_names: list[str]
+ # Row = index of the task just trained, column = index of the evaluated task.
+ # NOTE(review): None presumably marks entries that were never evaluated — confirm
+ # against matrix_to_jsonable.
+ accuracy_matrix: list[list[float | None]]
+ forgetting_matrix: list[list[float | None]]
+ # Aggregate metrics plus run facts such as seed, model and runtime.
+ summary: dict[str, float | int | str | None]
+ # Locations of metrics.json and config.yaml inside run_dir.
+ metrics_path: Path
+ config_path: Path
+ runtime_seconds: float
+ git_commit: str | None = None
+
+
+def load_config(source: str | Path) -> ExperimentConfig:
+    """Read an experiment config given either a YAML file path or a config name.
+
+    Raises:
+        FileNotFoundError: If ``source`` cannot be resolved to a config file.
+        ValueError: If the top level of the document is not a mapping.
+    """
+    config_path = resolve_config_path(source)
+    with config_path.open("r", encoding="utf-8") as stream:
+        document = yaml.safe_load(stream)
+    if isinstance(document, dict):
+        return ExperimentConfig.from_dict(document)
+    raise ValueError(f"Config {config_path} must contain a YAML mapping.")
+
+
+def load_config_with_overrides(
+    source: str | Path,
+    overrides: list[str] | None = None,
+) -> ExperimentConfig:
+    """Load a config and merge Hydra/OmegaConf dot-list overrides on top of it.
+
+    Without overrides this is exactly ``load_config(source)`` and OmegaConf is
+    never imported.
+
+    Raises:
+        RuntimeError: If overrides were given but ``omegaconf`` is not installed.
+        ValueError: If the merged document is not a mapping.
+    """
+    if not overrides:
+        return load_config(source)
+
+    # Imported lazily so the dependency is only needed when overrides are used.
+    try:
+        from omegaconf import OmegaConf
+    except ImportError as exc:
+        raise RuntimeError(
+            "Hydra-style overrides require the experiment extra: "
+            'python -m pip install -e ".[experiment]"'
+        ) from exc
+
+    config_path = resolve_config_path(source)
+    file_conf = OmegaConf.load(config_path)
+    merged = OmegaConf.merge(file_conf, OmegaConf.from_dotlist(overrides))
+    document = OmegaConf.to_container(merged, resolve=True)
+    if isinstance(document, dict):
+        return ExperimentConfig.from_dict(document)
+    raise ValueError(f"Config {config_path} must contain a YAML mapping.")
+
+
+def resolve_config_path(source: str | Path) -> Path:
+    """Resolve a config path or bare config name to an existing YAML file.
+
+    ``source`` is first tried as a path as given. Otherwise it is treated as a
+    name (``.yaml`` is appended unless it already ends in ``.yaml``/``.yml``)
+    and looked up in ``configs`` and ``configs/experiments``, first under the
+    current working directory and then two directory levels above this module.
+
+    Args:
+        source: File path or config name.
+
+    Returns:
+        Path of the first matching regular file.
+
+    Raises:
+        FileNotFoundError: If no candidate is an existing file; the message
+            lists every location that was searched.
+    """
+    candidate = Path(source)
+    # is_file() rather than exists(): a directory that merely shares the name
+    # (e.g. "configs" or "experiments") must not be returned, because the
+    # caller would then fail when trying to open it.
+    if candidate.is_file():
+        return candidate
+
+    name = str(source)
+    if not name.endswith((".yaml", ".yml")):
+        name = f"{name}.yaml"
+
+    package_root = Path(__file__).resolve().parents[2]
+    search_roots = [
+        Path.cwd() / "configs",
+        Path.cwd() / "configs" / "experiments",
+        package_root / "configs",
+        package_root / "configs" / "experiments",
+    ]
+    for root in search_roots:
+        path = root / name
+        if path.is_file():
+            return path
+
+    searched = ", ".join(str(root / name) for root in search_roots)
+    raise FileNotFoundError(f"Could not find config '{source}'. Searched: {searched}")
+
+
+def dump_config(config: ExperimentConfig, path: Path) -> None:
+    """Write ``config`` to ``path`` as YAML, creating parent directories as needed.
+
+    Keys are written in the order ``config.to_dict()`` yields them (no sorting).
+    """
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open(mode="w", encoding="utf-8") as stream:
+        yaml.safe_dump(config.to_dict(), stream=stream, sort_keys=False)
+
+
+def _optional_int(value: Any) -> int | None:
+    """Convert ``value`` with ``int()``, passing ``None`` through untouched."""
+    return None if value is None else int(value)
diff --git a/src/cl_bench/datasets.py b/src/cl_bench/datasets.py
new file mode 100644
index 0000000..585d1b9
--- /dev/null
+++ b/src/cl_bench/datasets.py
@@ -0,0 +1,227 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from dataclasses import dataclass
+from pathlib import Path
+
+import torch
+from torch.utils.data import DataLoader, Dataset, Subset, random_split
+from torchvision import datasets as tv_datasets
+from torchvision import transforms
+
+from cl_bench.config import ExperimentConfig, TaskSpec
+
+
+@dataclass
+class TaskLoaders:
+ """Data loaders and metadata for a single benchmark task."""
+
+ # Task name and dataset name as given in the task spec.
+ name: str
+ dataset: str
+ # Class ids that belong to this task.
+ classes: list[int]
+ # Shuffled loader over the training split.
+ train_loader: DataLoader
+ # Unshuffled loaders over the validation split and the test set.
+ val_loader: DataLoader
+ test_loader: DataLoader
+
+
+class SyntheticImageDataset(Dataset):
+    """Deterministic toy image-classification data for CI and smoke runs.
+
+    Each class has a fixed prototype image that depends only on the class id.
+    A sample is that prototype plus normal noise scaled by ``noise_std``, with
+    all noise drawn from one generator seeded with ``seed``.
+    """
+
+    def __init__(
+        self,
+        classes: Sequence[int],
+        samples_per_class: int,
+        image_shape: tuple[int, int, int] = (1, 8, 8),
+        noise_std: float = 0.12,
+        seed: int = 0,
+    ):
+        self.classes = [int(label) for label in classes]
+        self.data: list[torch.Tensor] = []
+        self.targets: list[int] = []
+        # A single noise stream is shared by all classes, so the order in which
+        # samples are generated determines their noise values.
+        noise_rng = torch.Generator().manual_seed(seed)
+
+        for label in self.classes:
+            prototype = self._prototype(label, image_shape)
+            for _ in range(samples_per_class):
+                jitter = torch.randn(image_shape, generator=noise_rng) * noise_std
+                self.data.append((prototype + jitter).float())
+                self.targets.append(label)
+
+    @staticmethod
+    def _prototype(label: int, image_shape: tuple[int, int, int]) -> torch.Tensor:
+        """Return the class prototype: seeded random image plus a per-class offset."""
+        prototype_rng = torch.Generator().manual_seed(10_003 + label * 997)
+        base = torch.randn(image_shape, generator=prototype_rng) * 0.6
+        return base + float(label) * 0.15
+
+    def __len__(self) -> int:
+        return len(self.targets)
+
+    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
+        return self.data[index], self.targets[index]
+
+
+def build_task_loaders(config: ExperimentConfig) -> tuple[list[TaskLoaders], tuple[int, ...], int]:
+    """Create train/validation/test loaders for every task in ``config``.
+
+    Returns:
+        A triple of (per-task loaders in config order, shape of the first
+        training sample of the first task, ``max(class id) + 1`` over all tasks).
+
+    Raises:
+        ValueError: If the config lists no tasks.
+    """
+    bundles: list[TaskLoaders] = []
+    seen_classes: set[int] = set()
+
+    def eval_loader(dataset: Dataset) -> DataLoader:
+        # Validation and test data are never shuffled.
+        return DataLoader(
+            dataset,
+            batch_size=config.eval_batch_size,
+            shuffle=False,
+            num_workers=config.num_workers,
+        )
+
+    for task_id, task in enumerate(config.tasks):
+        # The same per-task seed drives both the train/val split and the shuffling.
+        task_seed = config.seed + task_id
+        train_dataset, test_dataset, classes = _build_datasets(task, config, task_id)
+        seen_classes.update(classes)
+
+        train_part, val_part = _split_train_validation(
+            train_dataset,
+            val_fraction=config.val_fraction,
+            seed=task_seed,
+        )
+        shuffled_train = DataLoader(
+            train_part,
+            batch_size=config.batch_size,
+            shuffle=True,
+            generator=torch.Generator().manual_seed(task_seed),
+            num_workers=config.num_workers,
+        )
+        bundles.append(
+            TaskLoaders(
+                name=task.name,
+                dataset=task.dataset,
+                classes=classes,
+                train_loader=shuffled_train,
+                val_loader=eval_loader(val_part),
+                test_loader=eval_loader(test_dataset),
+            )
+        )
+
+    if not bundles:
+        raise ValueError("At least one task must be configured.")
+
+    sample_inputs, _ = bundles[0].train_loader.dataset[0]
+    return bundles, tuple(sample_inputs.shape), max(seen_classes) + 1
+
+
+def _build_datasets(
+    task: TaskSpec, config: ExperimentConfig, task_id: int
+) -> tuple[Dataset, Dataset, list[int]]:
+    """Build the train and test datasets of one task plus its resolved class ids.
+
+    ``synthetic`` tasks are generated in memory; every other dataset name goes
+    through torchvision and is then restricted to the task's classes and limits.
+
+    Raises:
+        ValueError: If a synthetic task uses ``classes: all``.
+    """
+    dataset_name = task.dataset.lower()
+
+    if dataset_name != "synthetic":
+        data_root = Path(config.data_dir)
+        full_train = _torchvision_dataset(
+            dataset_name, data_root, train=True, augment=config.augment
+        )
+        # Augmentation is never applied to the test split.
+        full_test = _torchvision_dataset(dataset_name, data_root, train=False, augment=False)
+        classes = _resolve_classes(task, full_train)
+        subset_seed = config.seed + task_id
+        return (
+            _class_subset(full_train, classes, task.train_limit, subset_seed),
+            _class_subset(full_test, classes, task.test_limit, subset_seed + 10_000),
+            classes,
+        )
+
+    if task.classes == "all":
+        raise ValueError("Synthetic tasks must list explicit class ids.")
+    per_class_train = task.samples_per_class or 32
+    per_class_test = task.test_samples_per_class or max(8, per_class_train // 2)
+    # Train and test use different noise seeds so the two sets do not coincide.
+    base_seed = config.seed + task_id * 100
+    synthetic_train = SyntheticImageDataset(
+        task.classes,
+        samples_per_class=per_class_train,
+        seed=base_seed,
+    )
+    synthetic_test = SyntheticImageDataset(
+        task.classes,
+        samples_per_class=per_class_test,
+        seed=base_seed + 50_000,
+    )
+    return synthetic_train, synthetic_test, [int(label) for label in task.classes]
+
+
+def _torchvision_dataset(dataset_name: str, data_dir: Path, train: bool, augment: bool) -> Dataset:
+    """Instantiate a supported torchvision dataset, downloading it when missing.
+
+    Raises:
+        ValueError: If ``dataset_name`` is not one of ``mnist``,
+            ``fashion_mnist``, ``kmnist`` or ``cifar10``.
+    """
+    transform = _torchvision_transform(dataset_name, train=train, augment=augment)
+    factories = {
+        "mnist": tv_datasets.MNIST,
+        "fashion_mnist": tv_datasets.FashionMNIST,
+        "kmnist": tv_datasets.KMNIST,
+        "cifar10": tv_datasets.CIFAR10,
+    }
+    factory = factories.get(dataset_name)
+    if factory is None:
+        raise ValueError(f"Unsupported dataset: {dataset_name}")
+    return factory(data_dir, train=train, download=True, transform=transform)
+
+
+def _torchvision_transform(dataset_name: str, train: bool, augment: bool) -> transforms.Compose:
+    """Return the preprocessing pipeline for ``dataset_name``.
+
+    Only ``cifar10`` gets its own normalisation and (for augmented training
+    data) random crop plus horizontal flip; every other name receives the same
+    tensor conversion and single-channel normalisation.
+    """
+    if dataset_name != "cifar10":
+        return transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        )
+
+    pipeline: list[object] = []
+    if train and augment:
+        pipeline.append(transforms.RandomCrop(32, padding=4))
+        pipeline.append(transforms.RandomHorizontalFlip())
+    pipeline.append(transforms.ToTensor())
+    pipeline.append(
+        transforms.Normalize(
+            mean=(0.4914, 0.4822, 0.4465),
+            std=(0.2470, 0.2435, 0.2616),
+        )
+    )
+    return transforms.Compose(pipeline)
+
+
+def _resolve_classes(task: TaskSpec, dataset: Dataset) -> list[int]:
+    """Return the task's class ids, expanding ``"all"`` from the dataset's ``classes``.
+
+    Raises:
+        ValueError: If ``"all"`` is requested but the dataset has no ``classes``
+            attribute.
+    """
+    if task.classes != "all":
+        return [int(label) for label in task.classes]
+    if hasattr(dataset, "classes"):
+        return list(range(len(dataset.classes)))
+    raise ValueError(f"Dataset {task.dataset} does not expose class metadata.")
+
+
+def _class_subset(dataset: Dataset, classes: list[int], limit: int | None, seed: int) -> Subset:
+    """Restrict ``dataset`` to the samples whose label is in ``classes``.
+
+    Labels come from ``dataset.targets`` when that attribute exists, otherwise
+    from indexing every sample. When ``limit`` is smaller than the number of
+    matches, a seeded random selection of ``limit`` samples is kept.
+    """
+    labels = getattr(dataset, "targets", None)
+    if labels is None:
+        # Fallback: read each sample's label through __getitem__.
+        labels = [dataset[position][1] for position in range(len(dataset))]
+
+    label_tensor = torch.as_tensor(labels)
+    wanted = torch.tensor(classes, dtype=label_tensor.dtype)
+    matches = torch.isin(label_tensor, wanted).nonzero(as_tuple=False).flatten().tolist()
+
+    if limit is None or limit >= len(matches):
+        return Subset(dataset, matches)
+
+    rng = torch.Generator().manual_seed(seed)
+    picks = torch.randperm(len(matches), generator=rng)[:limit].tolist()
+    return Subset(dataset, [matches[pick] for pick in picks])
+
+
+def _split_train_validation(
+    dataset: Dataset, val_fraction: float, seed: int
+) -> tuple[Subset, Subset]:
+    """Split ``dataset`` into seeded random train and validation subsets.
+
+    With fewer than two samples or ``val_fraction == 0.0`` no split happens:
+    both returned subsets cover the whole dataset. Otherwise the validation
+    part holds at least one sample and the training part at least one.
+
+    Raises:
+        ValueError: If ``val_fraction`` is outside ``[0.0, 1.0)``.
+    """
+    if not (0.0 <= val_fraction < 1.0):
+        raise ValueError("val_fraction must be in [0.0, 1.0).")
+
+    total = len(dataset)
+    if total < 2 or val_fraction == 0.0:
+        everything = list(range(total))
+        return Subset(dataset, everything), Subset(dataset, list(everything))
+
+    val_size = max(1, int(total * val_fraction))
+    if total - val_size <= 0:
+        # Never let validation swallow the whole dataset.
+        val_size = 1
+    train_size = total - val_size
+
+    rng = torch.Generator().manual_seed(seed)
+    train_subset, val_subset = random_split(dataset, [train_size, val_size], generator=rng)
+    return train_subset, val_subset
diff --git a/src/cl_bench/experiments.py b/src/cl_bench/experiments.py
new file mode 100644
index 0000000..b400053
--- /dev/null
+++ b/src/cl_bench/experiments.py
@@ -0,0 +1,179 @@
+from __future__ import annotations
+
+import random
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from cl_bench.config import BenchmarkResult, ExperimentConfig, dump_config
+from cl_bench.datasets import build_task_loaders
+from cl_bench.metrics import compute_forgetting, matrix_to_jsonable, summarize_accuracy
+from cl_bench.models import get_model
+from cl_bench.strategies import create_strategy
+from cl_bench.tracking import ExperimentTracker, MLflowRunLogger, create_run_dir, git_commit
+
+
+def run_experiment(config: ExperimentConfig, repo_dir: str | Path | None = None) -> BenchmarkResult:
+ """Train on every task in sequence, evaluate after each one, and record the run.
+
+ After training task ``i`` the model is evaluated on the test loaders of
+ tasks ``0..i``; results land in row ``i`` of the accuracy matrix. Entries
+ that are never evaluated stay ``NaN``. The run directory receives the
+ config, run metadata, an event log, accuracy/forgetting matrices (JSON and
+ CSV), ``metrics.json`` and, if requested, a final checkpoint.
+
+ Args:
+ config: Fully resolved experiment configuration.
+ repo_dir: Directory passed to ``git_commit``; the current working
+ directory is used when omitted.
+
+ Returns:
+ A ``BenchmarkResult`` with the matrices, summary metrics and file paths.
+ """
+ set_seed(config.seed)
+ device = resolve_device(config.device)
+ tasks, input_shape, num_classes = build_task_loaders(config)
+ model = get_model(config.model, input_shape=input_shape, num_classes=num_classes).to(device)
+ strategy = create_strategy(config, model, device)
+
+ run_dir = create_run_dir(config.output_dir, config.name, config.method)
+ tracker = ExperimentTracker(run_dir)
+ config_path = run_dir / "config.yaml"
+ dump_config(config, config_path)
+
+ commit = git_commit(repo_dir or Path.cwd())
+ metadata = {
+ "benchmark": config.name,
+ "method": config.method,
+ "seed": config.seed,
+ "device": str(device),
+ "git_commit": commit,
+ }
+ tracker.write_json("run_metadata.json", metadata)
+ mlflow_enabled = config.tracking.lower() in {"mlflow", "both"}
+
+ # The logger is always entered; `enabled` carries whether MLflow tracking was requested.
+ with MLflowRunLogger(
+ tracking_uri=config.mlflow_tracking_uri,
+ experiment_name=config.mlflow_experiment,
+ run_name=f"{config.name}_{config.method}_seed_{config.seed}",
+ enabled=mlflow_enabled,
+ ) as mlflow_logger:
+ mlflow_logger.log_params(config.to_dict())
+ mlflow_logger.set_tags(metadata)
+ mlflow_logger.log_environment()
+
+ start_time = time.perf_counter()
+ # Row = task just trained, column = task evaluated; NaN means "not evaluated".
+ accuracy_matrix = np.full((len(tasks), len(tasks)), np.nan, dtype=float)
+
+ for task_id, task in enumerate(tasks):
+ tracker.log_event(
+ {
+ "event": "task_started",
+ "task_id": task_id,
+ "task_name": task.name,
+ "classes": task.classes,
+ }
+ )
+ history = strategy.train_task(
+ task.train_loader,
+ task.val_loader,
+ task_id=task_id,
+ epochs=config.epochs,
+ )
+ for epoch_metrics in history:
+ tracker.log_event({"event": "epoch_finished", **epoch_metrics})
+ # Offsetting by task_id * epochs keeps MLflow steps from overlapping across tasks.
+ mlflow_logger.log_metrics(
+ {
+ f"task_{task_id}_{key}": value
+ for key, value in epoch_metrics.items()
+ if key not in {"task_id", "epoch"}
+ },
+ step=task_id * config.epochs + int(epoch_metrics["epoch"]),
+ )
+
+ # Re-evaluate every task seen so far, including the one just trained.
+ for eval_task_id in range(task_id + 1):
+ eval_task = tasks[eval_task_id]
+ metrics = strategy.evaluate(eval_task.test_loader)
+ accuracy_matrix[task_id, eval_task_id] = metrics["accuracy"]
+ tracker.log_event(
+ {
+ "event": "evaluation",
+ "after_task_id": task_id,
+ "eval_task_id": eval_task_id,
+ "eval_task_name": eval_task.name,
+ **metrics,
+ }
+ )
+ mlflow_logger.log_metrics(
+ {
+ f"eval_after_{task_id}_task_{eval_task_id}_accuracy": metrics["accuracy"],
+ f"eval_after_{task_id}_task_{eval_task_id}_loss": metrics["loss"],
+ },
+ step=task_id,
+ )
+
+ # Runtime covers training and evaluation only, not setup or artifact writing.
+ runtime_seconds = time.perf_counter() - start_time
+ forgetting_matrix = compute_forgetting(accuracy_matrix)
+ summary = summarize_accuracy(accuracy_matrix)
+ summary.update(
+ {
+ "runtime_seconds": runtime_seconds,
+ "seed": config.seed,
+ "num_tasks": len(tasks),
+ "model": config.model,
+ "replay_buffer_size": config.replay_buffer_size,
+ "replay_batch_size": config.replay_batch_size,
+ }
+ )
+
+ tracker.write_json("accuracy_matrix.json", matrix_to_jsonable(accuracy_matrix))
+ tracker.write_json("forgetting_matrix.json", matrix_to_jsonable(forgetting_matrix))
+ tracker.write_matrix_csv("accuracy_matrix.csv", accuracy_matrix)
+ tracker.write_matrix_csv("forgetting_matrix.csv", forgetting_matrix)
+
+ if config.save_checkpoint:
+ strategy.save_checkpoint(run_dir / "checkpoints" / "final_model.pt")
+
+ metrics_path = tracker.write_json(
+ "metrics.json",
+ {
+ "benchmark": config.name,
+ "method": config.method,
+ "task_names": [task.name for task in tasks],
+ "summary": summary,
+ "accuracy_matrix": matrix_to_jsonable(accuracy_matrix),
+ "forgetting_matrix": matrix_to_jsonable(forgetting_matrix),
+ "runtime_seconds": runtime_seconds,
+ "seed": config.seed,
+ "git_commit": commit,
+ },
+ )
+ # NOTE(review): `summary` now contains the string-valued "model" entry; confirm that
+ # MLflowRunLogger.log_metrics tolerates or filters non-numeric values.
+ mlflow_logger.log_metrics(summary)
+ mlflow_logger.log_artifacts(run_dir)
+
+ return BenchmarkResult(
+ run_dir=run_dir,
+ method=config.method,
+ task_names=[task.name for task in tasks],
+ accuracy_matrix=matrix_to_jsonable(accuracy_matrix),
+ forgetting_matrix=matrix_to_jsonable(forgetting_matrix),
+ summary=summary,
+ metrics_path=metrics_path,
+ config_path=config_path,
+ runtime_seconds=runtime_seconds,
+ git_commit=commit,
+ )
+
+
+def set_seed(seed: int) -> None:
+ """Seed the Python, NumPy and PyTorch random number generators with ``seed``.
+
+ Also seeds all CUDA devices when CUDA is available, and puts cuDNN into
+ deterministic mode with benchmarking switched off.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ # NOTE(review): indentation is not recoverable here, so it is unclear whether the two
+ # cuDNN flags below are set unconditionally or only under the CUDA check — confirm.
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def resolve_device(device_name: str) -> torch.device:
+    """Turn a device name into a ``torch.device``, validating availability.
+
+    ``"auto"`` (case-insensitive) picks CUDA, then MPS, then CPU. Any other
+    name is passed to ``torch.device`` and checked afterwards.
+
+    Raises:
+        RuntimeError: If CUDA or MPS is requested explicitly but is not available.
+    """
+
+    def mps_ready() -> bool:
+        # Guard with hasattr: torch.backends may not have an "mps" attribute.
+        return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
+    requested = device_name.lower()
+    if requested != "auto":
+        device = torch.device(requested)
+        if device.type == "cuda" and not torch.cuda.is_available():
+            raise RuntimeError("CUDA was requested but is not available.")
+        if device.type == "mps" and not mps_ready():
+            raise RuntimeError("MPS was requested but is not available.")
+        return device
+
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    return torch.device("mps" if mps_ready() else "cpu")
diff --git a/src/cl_bench/metrics.py b/src/cl_bench/metrics.py
new file mode 100644
index 0000000..db8c860
--- /dev/null
+++ b/src/cl_bench/metrics.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+import numpy as np
+import torch
+
+
+def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> float:
+    """Return the percentage of rows whose arg-max class equals the target.
+
+    An empty ``targets`` tensor yields ``0.0`` instead of a NaN mean.
+    """
+
+    if targets.numel() == 0:
+        return 0.0
+    hits = logits.argmax(dim=1) == targets
+    fraction = hits.float().mean().item()
+    return float(fraction * 100.0)
+
+
+def compute_forgetting(accuracy_matrix: np.ndarray) -> np.ndarray:
+    """Compute best-so-far forgetting for every evaluated task.
+
+    Entry ``[step, task]`` is the drop from the best accuracy recorded for ``task``
+    in earlier rows to the accuracy at ``step``, clipped at zero. Cells above the
+    diagonal and cells whose accuracy is NaN stay NaN; the diagonal is ``0.0``.
+    """
+
+    scores = np.asarray(accuracy_matrix, dtype=float)
+    result = np.full(scores.shape, np.nan, dtype=float)
+
+    for step in range(scores.shape[0]):
+        # Only tasks on or below the diagonal have been trained by this step.
+        for task in range(min(step + 1, scores.shape[1])):
+            now = scores[step, task]
+            if np.isnan(now):
+                continue
+            earlier = scores[:step, task]
+            if task == step or np.isnan(earlier).all():
+                result[step, task] = 0.0
+            else:
+                result[step, task] = max(0.0, float(np.nanmax(earlier)) - now)
+
+    return result
+
+
+def summarize_accuracy(accuracy_matrix: np.ndarray) -> dict[str, float]:
+    """Summarize continual-learning accuracy and forgetting metrics.
+
+    Args:
+        accuracy_matrix: ``[train_step, task]`` accuracies; NaN marks cells that
+            were not evaluated.
+
+    Returns:
+        A dict with ``average_final_accuracy`` (mean of the non-NaN last row),
+        ``average_learning_accuracy`` (mean of the non-NaN diagonal),
+        ``average_forgetting`` (mean last-row forgetting over earlier tasks) and
+        ``backward_transfer`` (mean of last-row minus diagonal accuracy over
+        earlier tasks). Every metric falls back to ``0.0`` when it has no data,
+        including when the matrix is empty.
+    """
+
+    matrix = np.asarray(accuracy_matrix, dtype=float)
+    if matrix.size == 0:
+        # Nothing was evaluated; indexing the last row below would raise IndexError.
+        return {
+            "average_final_accuracy": 0.0,
+            "average_learning_accuracy": 0.0,
+            "average_forgetting": 0.0,
+            "backward_transfer": 0.0,
+        }
+
+    final_row = matrix[-1]
+    final_seen = final_row[~np.isnan(final_row)]
+    average_final_accuracy = float(np.mean(final_seen)) if final_seen.size else 0.0
+
+    diagonal = np.diag(matrix)
+    learned = diagonal[~np.isnan(diagonal)]
+    average_learning_accuracy = float(np.mean(learned)) if learned.size else 0.0
+
+    # Forgetting and backward transfer are only defined once a later task exists.
+    average_forgetting = 0.0
+    backward_transfer = 0.0
+    num_steps = matrix.shape[0]
+    if num_steps > 1:
+        forgetting = compute_forgetting(matrix)
+        final_forgetting = forgetting[-1, : num_steps - 1]
+        final_forgetting = final_forgetting[~np.isnan(final_forgetting)]
+        if final_forgetting.size:
+            average_forgetting = float(np.mean(final_forgetting))
+
+        transfers = [
+            matrix[-1, task_id] - matrix[task_id, task_id]
+            for task_id in range(num_steps - 1)
+            if not np.isnan(matrix[-1, task_id]) and not np.isnan(matrix[task_id, task_id])
+        ]
+        if transfers:
+            backward_transfer = float(np.mean(transfers))
+
+    return {
+        "average_final_accuracy": average_final_accuracy,
+        "average_learning_accuracy": average_learning_accuracy,
+        "average_forgetting": average_forgetting,
+        "backward_transfer": backward_transfer,
+    }
+
+
+def matrix_to_jsonable(matrix: np.ndarray) -> list[list[float | None]]:
+    """Convert a NumPy matrix to JSON-safe nested lists, mapping NaN to ``None``."""
+
+    rows = np.asarray(matrix, dtype=float)
+    return [[float(cell) if not np.isnan(cell) else None for cell in row] for row in rows]
+
+
+def mean_or_zero(values: Iterable[float]) -> float:
+    """Return the arithmetic mean of ``values``, or ``0.0`` when there are none."""
+
+    items = list(values)
+    if not items:
+        return 0.0
+    return float(np.mean(items))
diff --git a/src/cl_bench/models.py b/src/cl_bench/models.py
new file mode 100644
index 0000000..91c8696
--- /dev/null
+++ b/src/cl_bench/models.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from functools import reduce
+from operator import mul
+
+import torch
+from torch import nn
+
+
+class LinearClassifier(nn.Module):
+    """Single linear layer applied to the flattened input."""
+
+    def __init__(self, input_shape: tuple[int, ...], num_classes: int):
+        super().__init__()
+        flatten = nn.Flatten()
+        head = nn.Linear(_num_features(input_shape), num_classes)
+        self.net = nn.Sequential(flatten, head)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        logits = self.net(inputs)
+        return logits
+
+
+class MLP(nn.Module):
+    """Two-hidden-layer ReLU perceptron over the flattened input."""
+
+    def __init__(self, input_shape: tuple[int, ...], num_classes: int, hidden_dim: int = 128):
+        super().__init__()
+        widths = [_num_features(input_shape), hidden_dim, hidden_dim]
+        layers: list[nn.Module] = [nn.Flatten()]
+        # Two Linear+ReLU stages, then the classification head.
+        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
+            layers.append(nn.Linear(fan_in, fan_out))
+            layers.append(nn.ReLU())
+        layers.append(nn.Linear(hidden_dim, num_classes))
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        logits = self.net(inputs)
+        return logits
+
+
+class SmallCNN(nn.Module):
+    """Two conv/ReLU/max-pool stages followed by a two-layer classifier head.
+
+    Each pooling stage halves the spatial size (floor division), so the input must
+    be at least 4x4; smaller inputs would pool down to zero elements.
+    """
+
+    def __init__(self, input_shape: tuple[int, ...], num_classes: int):
+        """Build the network for a ``(channels, height, width)`` input.
+
+        Raises:
+            ValueError: if ``input_shape`` is not CHW or is spatially smaller than 4x4.
+        """
+        super().__init__()
+        if len(input_shape) != 3:
+            raise ValueError(f"SmallCNN expects CHW input, got {input_shape}.")
+        channels, height, width = input_shape
+        if height < 4 or width < 4:
+            # Fail at construction instead of with an opaque pooling error in forward().
+            raise ValueError(
+                f"SmallCNN needs height and width of at least 4, got {input_shape}."
+            )
+        self.features = nn.Sequential(
+            nn.Conv2d(channels, 32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Conv2d(32, 64, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+        )
+        # Two floor-halvings are equivalent to one floor division by four.
+        pooled_height = height // 4
+        pooled_width = width // 4
+        self.classifier = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(64 * pooled_height * pooled_width, 128),
+            nn.ReLU(),
+            nn.Linear(128, num_classes),
+        )
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        """Return class logits for a batch of CHW inputs."""
+        return self.classifier(self.features(inputs))
+
+
+class ResidualBlock(nn.Module):
+    """Two 3x3 conv/norm layers with a skip connection and a final ReLU.
+
+    The skip path is a 1x1 conv + norm projection when the stride or channel count
+    changes, and an identity otherwise.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
+        super().__init__()
+        main_path = [
+            nn.Conv2d(
+                in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
+            ),
+            _norm(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            _norm(out_channels),
+        ]
+        self.net = nn.Sequential(*main_path)
+
+        shape_preserved = stride == 1 and in_channels == out_channels
+        if shape_preserved:
+            self.shortcut = nn.Identity()
+        else:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                _norm(out_channels),
+            )
+        self.activation = nn.ReLU(inplace=True)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        transformed = self.net(inputs)
+        skipped = self.shortcut(inputs)
+        return self.activation(transformed + skipped)
+
+
+class CifarConvNet(nn.Module):
+    """Compact residual CNN tuned for 32x32 RGB continual-learning benchmarks."""
+
+    def __init__(self, input_shape: tuple[int, ...], num_classes: int):
+        super().__init__()
+        if len(input_shape) != 3:
+            raise ValueError(f"CifarConvNet expects CHW input, got {input_shape}.")
+        channels = input_shape[0]
+
+        stem = [
+            nn.Conv2d(channels, 64, kernel_size=3, padding=1, bias=False),
+            _norm(64),
+            nn.ReLU(inplace=True),
+        ]
+        # (in_channels, out_channels, stride) for each residual stage, in order.
+        stage_plan = ((64, 64, 1), (64, 128, 2), (128, 128, 1), (128, 256, 2), (256, 256, 1))
+        stages = [
+            ResidualBlock(fan_in, fan_out, stride=stride) for fan_in, fan_out, stride in stage_plan
+        ]
+        self.features = nn.Sequential(*stem, *stages, nn.AdaptiveAvgPool2d((1, 1)))
+        self.classifier = nn.Linear(256, num_classes)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        pooled = self.features(inputs)
+        flat = torch.flatten(pooled, 1)
+        return self.classifier(flat)
+
+
+def get_model(model_name: str, input_shape: tuple[int, ...], num_classes: int) -> nn.Module:
+    """Instantiate the architecture registered under ``model_name``.
+
+    Names are matched case-insensitively with ``-`` treated as ``_``. Unknown names
+    raise ``ValueError``.
+    """
+
+    registry = {
+        "linear": LinearClassifier,
+        "mlp": MLP,
+        "small_cnn": SmallCNN,
+        "cnn": SmallCNN,
+        "cifar_convnet": CifarConvNet,
+        "resnet18_cifar": CifarConvNet,
+    }
+    key = model_name.lower().replace("-", "_")
+    factory = registry.get(key)
+    if factory is None:
+        raise ValueError(f"Unknown model architecture: {model_name}")
+    return factory(input_shape, num_classes)
+
+
+def _num_features(input_shape: tuple[int, ...]) -> int:
+    """Return the product of all dimensions in ``input_shape`` (1 for an empty shape)."""
+
+    total = 1
+    for size in input_shape:
+        total = total * size
+    return int(total)
+
+
+def _norm(channels: int) -> nn.BatchNorm2d:
+    """Create the 2-D batch-norm layer used by the convolutional blocks."""
+
+    layer = nn.BatchNorm2d(channels)
+    return layer
diff --git a/src/cl_bench/reporting.py b/src/cl_bench/reporting.py
new file mode 100644
index 0000000..3aa8cf0
--- /dev/null
+++ b/src/cl_bench/reporting.py
@@ -0,0 +1,487 @@
+from __future__ import annotations
+
+import csv
+import json
+from collections import defaultdict
+from collections.abc import Iterable, Sequence
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+
+
+@dataclass(frozen=True)
+class RunRecord:
+    """Parsed metrics for one completed benchmark run."""
+
+    # Directory that holds the run's metrics.json.
+    run_dir: Path
+    # Benchmark name as stored in the metrics payload.
+    benchmark: str
+    # Continual-learning method name; reports group runs by this value.
+    method: str
+    # Random seed recorded for the run.
+    seed: int
+    # Task names in the order stored in the metrics payload.
+    task_names: list[str]
+    # Scalar summary metrics keyed by name (e.g. "average_final_accuracy").
+    summary: dict[str, float | int | str | None]
+    # Accuracy matrix as floats; JSON nulls are loaded as NaN.
+    accuracy_matrix: np.ndarray
+    # Forgetting matrix as floats; JSON nulls are loaded as NaN.
+    forgetting_matrix: np.ndarray
+    # Runtime of the run in seconds.
+    runtime_seconds: float
+    # Git commit recorded for the run, or None when absent.
+    git_commit: str | None
+
+
+@dataclass(frozen=True)
+class ReportArtifacts:
+    """Files written by the reporting pipeline."""
+
+    # Directory that contains every artifact below.
+    report_dir: Path
+    # Per-method leaderboard table in CSV form.
+    leaderboard_csv: Path
+    # JSON summary holding the leaderboard and per-run details.
+    summary_json: Path
+    # Markdown report.
+    markdown: Path
+    # Plot image paths; empty when plotting is disabled.
+    plots: list[Path]
+
+
+def collect_runs(sources: Sequence[str | Path]) -> list[RunRecord]:
+    """Load every run found under ``sources`` (run dirs, metrics files, or parent dirs).
+
+    Raises ``FileNotFoundError`` when no metrics.json file is discovered.
+    """
+
+    found = discover_metrics(sources)
+    if found:
+        return [load_run(metrics_file) for metrics_file in found]
+    raise FileNotFoundError("No metrics.json files were found in the supplied run paths.")
+
+
+def discover_metrics(sources: Sequence[str | Path]) -> list[Path]:
+    """Return the sorted, de-duplicated metrics.json paths reachable from ``sources``.
+
+    A file source must itself be named ``metrics.json`` (``ValueError`` otherwise).
+    A directory contributes its own ``metrics.json`` when present, otherwise every
+    ``metrics.json`` found recursively below it. Missing paths raise
+    ``FileNotFoundError``.
+    """
+
+    found: set[Path] = set()
+    for source in sources:
+        candidate = Path(source)
+        if candidate.is_file():
+            if candidate.name != "metrics.json":
+                raise ValueError(f"Expected metrics.json file, got: {candidate}")
+            found.add(candidate)
+        elif not candidate.exists():
+            raise FileNotFoundError(candidate)
+        elif (candidate / "metrics.json").exists():
+            # A run directory: do not descend into it any further.
+            found.add(candidate / "metrics.json")
+        else:
+            found.update(candidate.rglob("metrics.json"))
+
+    return sorted(found)
+
+
+def load_run(metrics_path: str | Path) -> RunRecord:
+    """Parse one metrics.json artifact (or the one inside a run directory).
+
+    ``runtime_seconds`` and ``seed`` are read from the top level of the payload,
+    then from its summary, and default to ``0.0`` and ``0``.
+    """
+
+    location = Path(metrics_path)
+    if location.is_dir():
+        location = location / "metrics.json"
+    payload = json.loads(location.read_text(encoding="utf-8"))
+
+    summary = dict(payload["summary"])
+    raw_runtime = payload.get("runtime_seconds", summary.get("runtime_seconds", 0.0))
+    raw_seed = payload.get("seed", summary.get("seed", 0))
+    runtime_seconds = float(raw_runtime)
+    seed = int(raw_seed)
+
+    benchmark = str(payload["benchmark"])
+    method = str(payload["method"])
+    task_names = [str(name) for name in payload["task_names"]]
+    accuracy_matrix = _matrix_from_json(payload["accuracy_matrix"])
+    forgetting_matrix = _matrix_from_json(payload["forgetting_matrix"])
+    return RunRecord(
+        run_dir=location.parent,
+        benchmark=benchmark,
+        method=method,
+        seed=seed,
+        task_names=task_names,
+        summary=summary,
+        accuracy_matrix=accuracy_matrix,
+        forgetting_matrix=forgetting_matrix,
+        runtime_seconds=runtime_seconds,
+        git_commit=payload.get("git_commit"),
+    )
+
+
+def aggregate_records(records: Sequence[RunRecord]) -> list[dict[str, float | int | str]]:
+    """Aggregate run summaries by method for leaderboard-style reporting.
+
+    Produces one row per method with run count, comma-joined sorted seeds, and
+    mean/std statistics, ordered by mean final accuracy (best first).
+    """
+
+    grouped: dict[str, list[RunRecord]] = {}
+    for record in records:
+        grouped.setdefault(record.method, []).append(record)
+
+    def values_of(group: list[RunRecord], key: str) -> list[float]:
+        return [_metric(item, key) for item in group]
+
+    rows: list[dict[str, float | int | str]] = []
+    for method in sorted(grouped):
+        group = grouped[method]
+        final_accuracy = values_of(group, "average_final_accuracy")
+        learning_accuracy = values_of(group, "average_learning_accuracy")
+        forgetting = values_of(group, "average_forgetting")
+        backward_transfer = values_of(group, "backward_transfer")
+        runtimes = [item.runtime_seconds for item in group]
+        ordered_seeds = sorted(item.seed for item in group)
+
+        row: dict[str, float | int | str] = {
+            "method": method,
+            "runs": len(group),
+            "seeds": ",".join(str(seed) for seed in ordered_seeds),
+            "average_final_accuracy_mean": _mean(final_accuracy),
+            "average_final_accuracy_std": _std(final_accuracy),
+            "average_learning_accuracy_mean": _mean(learning_accuracy),
+            "average_forgetting_mean": _mean(forgetting),
+            "average_forgetting_std": _std(forgetting),
+            "backward_transfer_mean": _mean(backward_transfer),
+            "runtime_seconds_mean": _mean(runtimes),
+        }
+        rows.append(row)
+
+    # Stable sort: methods with equal accuracy keep their alphabetical order.
+    rows.sort(key=lambda row: float(row["average_final_accuracy_mean"]), reverse=True)
+    return rows
+
+
+def write_report(
+    records: Sequence[RunRecord],
+    output_dir: str | Path,
+    title: str,
+    make_plots: bool = True,
+) -> ReportArtifacts:
+    """Write CSV, JSON, Markdown, and optional plot artifacts for a benchmark suite.
+
+    ``output_dir`` is created when missing. Raises ``ValueError`` when ``records``
+    is empty.
+    """
+
+    if len(records) == 0:
+        raise ValueError("At least one run record is required to write a report.")
+
+    target = Path(output_dir)
+    target.mkdir(parents=True, exist_ok=True)
+
+    table = aggregate_records(records)
+    csv_path = _write_leaderboard_csv(target / "leaderboard.csv", table)
+    json_path = _write_summary_json(target / "summary.json", title, records, table)
+    plot_paths: list[Path] = _write_plots(records, table, target, title) if make_plots else []
+    # Markdown is written last so that it can reference the plots.
+    markdown_path = _write_markdown(target / "README.md", title, records, table, plot_paths)
+
+    return ReportArtifacts(
+        report_dir=target,
+        leaderboard_csv=csv_path,
+        summary_json=json_path,
+        markdown=markdown_path,
+        plots=plot_paths,
+    )
+
+
+def _write_leaderboard_csv(path: Path, rows: Sequence[dict[str, float | int | str]]) -> Path:
+    """Write the leaderboard rows to ``path`` as CSV with a fixed column order."""
+
+    columns = (
+        "method",
+        "runs",
+        "seeds",
+        "average_final_accuracy_mean",
+        "average_final_accuracy_std",
+        "average_learning_accuracy_mean",
+        "average_forgetting_mean",
+        "average_forgetting_std",
+        "backward_transfer_mean",
+        "runtime_seconds_mean",
+    )
+    # newline="" lets the csv module control line endings itself.
+    with path.open("w", encoding="utf-8", newline="") as handle:
+        writer = csv.DictWriter(handle, fieldnames=list(columns))
+        writer.writeheader()
+        writer.writerows(rows)
+    return path
+
+
+def _write_summary_json(
+    path: Path,
+    title: str,
+    records: Sequence[RunRecord],
+    leaderboard: Sequence[dict[str, float | int | str]],
+) -> Path:
+    """Write the report summary (leaderboard plus per-run details) to ``path`` as JSON."""
+
+    def describe(record: RunRecord) -> dict[str, Any]:
+        return {
+            "run_dir": str(record.run_dir),
+            "benchmark": record.benchmark,
+            "method": record.method,
+            "seed": record.seed,
+            "task_names": record.task_names,
+            "summary": record.summary,
+            "runtime_seconds": record.runtime_seconds,
+            "git_commit": record.git_commit,
+        }
+
+    payload = {
+        "title": title,
+        "generated_at_utc": datetime.now(timezone.utc).isoformat(),
+        "benchmarks": sorted({record.benchmark for record in records}),
+        "num_runs": len(records),
+        "leaderboard": list(leaderboard),
+        "runs": [describe(record) for record in records],
+    }
+    text = json.dumps(payload, indent=2, sort_keys=True)
+    path.write_text(text + "\n", encoding="utf-8")
+    return path
+
+
+def _write_markdown(
+    path: Path,
+    title: str,
+    records: Sequence[RunRecord],
+    leaderboard: Sequence[dict[str, float | int | str]],
+    plots: Sequence[Path],
+) -> Path:
+    """Render the Markdown report to ``path`` and return ``path``.
+
+    The document contains a header, the leaderboard table, an optional "Plots"
+    section and a table of source runs. Plots are embedded as Markdown images that
+    reference each plot by file name only, so the images are expected to sit in the
+    same directory as ``path``.
+
+    Raises:
+        ValueError: if ``records`` is empty (``max`` over no task counts).
+    """
+
+    benchmark_names = ", ".join(sorted({record.benchmark for record in records}))
+    task_count = max(len(record.task_names) for record in records)
+    lines = [
+        f"# {title}",
+        "",
+        f"Benchmarks: `{benchmark_names}`",
+        f"Runs: `{len(records)}`",
+        f"Tasks per run: `{task_count}`",
+        "",
+        "## Leaderboard",
+        "",
+        "| Method | Runs | Seeds | Final accuracy | Forgetting | Backward transfer | Mean runtime |",
+        "| --- | ---: | --- | ---: | ---: | ---: | ---: |",
+    ]
+    for row in leaderboard:
+        lines.append(
+            "| {method} | {runs} | {seeds} | {accuracy} | {forgetting} | {bwt} | {runtime} |".format(
+                method=row["method"],
+                runs=row["runs"],
+                seeds=row["seeds"],
+                accuracy=_format_with_std(
+                    float(row["average_final_accuracy_mean"]),
+                    float(row["average_final_accuracy_std"]),
+                    suffix="%",
+                ),
+                forgetting=_format_with_std(
+                    float(row["average_forgetting_mean"]),
+                    float(row["average_forgetting_std"]),
+                    suffix="%",
+                ),
+                bwt=f"{float(row['backward_transfer_mean']):.2f}%",
+                runtime=f"{float(row['runtime_seconds_mean']):.1f}s",
+            )
+        )
+
+    # A blank line must end the leaderboard table before the next heading,
+    # whether or not a Plots section follows.
+    lines.append("")
+
+    if plots:
+        lines.extend(["## Plots", ""])
+        for plot in plots:
+            # Embed each plot as an image; the alt text is the file stem.
+            lines.append(f"![{plot.stem}]({plot.name})")
+            lines.append("")
+
+    lines.extend(
+        [
+            "## Source Runs",
+            "",
+            "| Method | Seed | Run directory |",
+            "| --- | ---: | --- |",
+        ]
+    )
+    for record in sorted(records, key=lambda item: (item.method, item.seed, str(item.run_dir))):
+        lines.append(f"| {record.method} | {record.seed} | `{record.run_dir}` |")
+
+    path.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
+    return path
+
+
+def _write_plots(
+    records: Sequence[RunRecord],
+    leaderboard: Sequence[dict[str, float | int | str]],
+    report_dir: Path,
+    title: str,
+) -> list[Path]:
+    """Render the leaderboard, retention and accuracy-matrix PNGs into ``report_dir``."""
+
+    # Imported lazily so the Agg backend is selected before pyplot is loaded.
+    import matplotlib
+
+    matplotlib.use("Agg")
+    import matplotlib.pyplot as plt
+
+    written: list[Path] = []
+    written.append(_plot_leaderboard(plt, leaderboard, report_dir / "leaderboard.png", title))
+    written.append(_plot_retention_curves(plt, records, report_dir / "retention_curves.png"))
+    written.append(_plot_accuracy_matrices(plt, records, report_dir / "accuracy_matrices.png"))
+    plt.close("all")
+    return written
+
+
+def _plot_leaderboard(
+    plt: Any, leaderboard: Sequence[dict[str, float | int | str]], path: Path, title: str
+) -> Path:
+    """Save a two-panel bar chart (final accuracy, forgetting) of the leaderboard."""
+
+    def column(key: str) -> list[float]:
+        return [float(row[key]) for row in leaderboard]
+
+    methods = [str(row["method"]) for row in leaderboard]
+    # Every column is extracted before the figure is created.
+    accuracy_means = column("average_final_accuracy_mean")
+    accuracy_stds = column("average_final_accuracy_std")
+    forgetting_means = column("average_forgetting_mean")
+    forgetting_stds = column("average_forgetting_std")
+    colors = [_method_color(name) for name in methods]
+
+    panels = (
+        (accuracy_means, accuracy_stds, "Average final accuracy", "Accuracy (%)", True),
+        (forgetting_means, forgetting_stds, "Average forgetting", "Forgetting (%)", False),
+    )
+
+    fig, axes = plt.subplots(1, 2, figsize=(12, 5.8), constrained_layout=True)
+    fig.suptitle(title, fontsize=16, fontweight="bold")
+    for index, (means, stds, heading, ylabel, higher) in enumerate(panels):
+        _bar_chart(
+            axes[index],
+            methods,
+            means,
+            stds,
+            colors,
+            title=heading,
+            ylabel=ylabel,
+            higher_is_better=higher,
+        )
+    fig.savefig(path, dpi=180, bbox_inches="tight")
+    return path
+
+
+def _bar_chart(
+    axis: Any,
+    labels: Sequence[str],
+    means: Sequence[float],
+    stds: Sequence[float],
+    colors: Sequence[str],
+    title: str,
+    ylabel: str,
+    higher_is_better: bool,
+) -> None:
+    """Draw one labelled bar chart with error bars and value annotations on ``axis``."""
+
+    slots = np.arange(len(labels))
+    axis.bar(
+        slots, means, yerr=stds, color=colors, capsize=4, edgecolor="#111827", linewidth=0.8
+    )
+    axis.set_title(title, fontsize=12, fontweight="bold")
+    axis.set_ylabel(ylabel)
+    axis.set_xticks(slots, labels)
+
+    # The axis spans at least 0-100 for accuracy-like panels and 0-5 otherwise,
+    # and grows when the tallest bar needs more headroom.
+    minimum_top = 100.0 if higher_is_better else 5.0
+    headroom_top = max(means, default=0.0) * 1.2 + 2.0
+    axis.set_ylim(0, max(minimum_top, headroom_top))
+    axis.grid(axis="y", alpha=0.22)
+
+    for slot, height in zip(slots, means, strict=True):
+        axis.text(slot, height + 1.0, f"{height:.1f}", ha="center", va="bottom", fontsize=9)
+
+
+def _plot_retention_curves(plt: Any, records: Sequence[RunRecord], path: Path) -> Path:
+    """Save a line plot of mean accuracy on seen tasks after each training step.
+
+    One line is drawn per method, averaged over that method's runs; a shaded
+    +/- one standard deviation band is added when a method has more than one run.
+    Returns ``path``.
+    """
+
+    fig, axis = plt.subplots(figsize=(9.5, 5.5), constrained_layout=True)
+    for method, method_records in _records_by_method(records).items():
+        curves = []
+        for record in method_records:
+            curve = []
+            for step in range(record.accuracy_matrix.shape[0]):
+                # Columns 0..step of row `step`, averaged while ignoring NaN cells.
+                seen = record.accuracy_matrix[step, : step + 1]
+                curve.append(float(np.nanmean(seen)))
+            curves.append(curve)
+        # Rows are runs, columns are training steps.
+        # NOTE(review): assumes every run of a method has the same number of steps — confirm.
+        matrix = np.asarray(curves, dtype=float)
+        x_values = np.arange(1, matrix.shape[1] + 1)
+        mean_curve = np.nanmean(matrix, axis=0)
+        std_curve = np.nanstd(matrix, axis=0)
+        color = _method_color(method)
+        axis.plot(x_values, mean_curve, marker="o", linewidth=2.2, label=method, color=color)
+        # A spread band is only drawn when there is more than one run.
+        if matrix.shape[0] > 1:
+            axis.fill_between(
+                x_values, mean_curve - std_curve, mean_curve + std_curve, color=color, alpha=0.16
+            )
+
+    axis.set_title("Retention across the task stream", fontsize=13, fontweight="bold")
+    axis.set_xlabel("After training task")
+    axis.set_ylabel("Mean accuracy on seen tasks (%)")
+    axis.set_ylim(0, 100)
+    # One tick per training step of the longest run.
+    axis.set_xticks(np.arange(1, max(record.accuracy_matrix.shape[0] for record in records) + 1))
+    axis.grid(alpha=0.25)
+    axis.legend(frameon=False)
+    fig.savefig(path, dpi=180, bbox_inches="tight")
+    return path
+
+
+def _plot_accuracy_matrices(plt: Any, records: Sequence[RunRecord], path: Path) -> Path:
+    """Save a grid of accuracy-matrix heatmaps, one per method.
+
+    Each method is represented by its run with the highest
+    ``average_final_accuracy``. Panels are laid out in at most two columns and
+    share one colorbar. Returns ``path``.
+    """
+
+    # The best run (by final accuracy) stands in for each method.
+    representative_records = [
+        max(method_records, key=lambda record: _metric(record, "average_final_accuracy"))
+        for method_records in _records_by_method(records).values()
+    ]
+    columns = min(2, len(representative_records))
+    rows = int(np.ceil(len(representative_records) / columns))
+    fig, axes = plt.subplots(
+        rows, columns, figsize=(6.4 * columns, 5.4 * rows), squeeze=False, constrained_layout=True
+    )
+
+    # Hide grid cells left over when the method count does not fill the grid.
+    for axis in axes.ravel()[len(representative_records) :]:
+        axis.axis("off")
+
+    image = None
+    for axis, record in zip(axes.ravel(), representative_records, strict=False):
+        # NaN cells are masked before the heatmap is drawn.
+        matrix = np.ma.masked_invalid(record.accuracy_matrix)
+        image = axis.imshow(matrix, vmin=0, vmax=100, cmap="viridis")
+        axis.set_title(f"{record.method} accuracy matrix", fontsize=12, fontweight="bold")
+        axis.set_xlabel("Evaluated task")
+        axis.set_ylabel("After training task")
+        axis.set_xticks(range(len(record.task_names)))
+        axis.set_yticks(range(len(record.task_names)))
+        axis.set_xticklabels(
+            [_short_task_name(name) for name in record.task_names], rotation=35, ha="right"
+        )
+        axis.set_yticklabels([str(index + 1) for index in range(len(record.task_names))])
+        # Print the rounded accuracy inside every cell that has a value.
+        for row in range(record.accuracy_matrix.shape[0]):
+            for column in range(record.accuracy_matrix.shape[1]):
+                value = record.accuracy_matrix[row, column]
+                if np.isnan(value):
+                    continue
+                axis.text(
+                    column, row, f"{value:.0f}", ha="center", va="center", color="white", fontsize=8
+                )
+
+    # The colorbar is attached to the last drawn image and spans all panels.
+    if image is not None:
+        fig.colorbar(image, ax=axes.ravel().tolist(), shrink=0.78, label="Accuracy (%)")
+    fig.savefig(path, dpi=180, bbox_inches="tight")
+    return path
+
+
+def _records_by_method(records: Sequence[RunRecord]) -> dict[str, list[RunRecord]]:
+    """Group records by method; methods and the runs inside them are sorted (by seed)."""
+
+    grouped: dict[str, list[RunRecord]] = {}
+    ordered = sorted(records, key=lambda item: (item.method, item.seed))
+    for record in ordered:
+        grouped.setdefault(record.method, []).append(record)
+    return grouped
+
+
+def _matrix_from_json(values: Sequence[Sequence[float | None]]) -> np.ndarray:
+    """Build a NumPy array from nested JSON lists, turning ``None`` cells into NaN."""
+
+    def to_number(cell: float | None) -> float:
+        if cell is None:
+            return np.nan
+        return float(cell)
+
+    return np.array([[to_number(cell) for cell in row] for row in values])
+
+
+def _metric(record: RunRecord, key: str) -> float:
+    """Read a summary metric as a float; missing or ``None`` values become ``0.0``."""
+
+    raw = record.summary.get(key, 0.0)
+    if raw is None:
+        return 0.0
+    return float(raw)
+
+
+def _mean(values: Iterable[float]) -> float:
+    """Return the arithmetic mean of ``values``, or ``0.0`` when there are none."""
+
+    items = list(values)
+    if not items:
+        return 0.0
+    return float(np.mean(items))
+
+
+def _std(values: Iterable[float]) -> float:
+    """Return the population standard deviation, or ``0.0`` for fewer than two values."""
+
+    items = list(values)
+    if len(items) <= 1:
+        return 0.0
+    return float(np.std(items, ddof=0))
+
+
+def _format_with_std(mean: float, std: float, suffix: str) -> str:
+    """Format ``mean`` with two decimals, adding ``+- std`` when ``std`` is non-zero."""
+
+    text = f"{mean:.2f}"
+    if std != 0.0:
+        text += f" +- {std:.2f}"
+    return text + suffix
+
+
+def _method_color(method: str) -> str:
+    """Return the hex plot colour for ``method``.
+
+    Known methods get a fixed colour. Any other name gets a neutral grey that is
+    not used by a known method, so an unrecognised method cannot be mistaken for
+    ``derpp`` (whose violet was previously also the fallback).
+    """
+
+    palette = {
+        "baseline": "#475569",
+        "ewc": "#2563eb",
+        "replay": "#16a34a",
+        "lwf": "#c2410c",
+        "derpp": "#7c3aed",
+        "agem": "#0f766e",
+    }
+    return palette.get(method, "#9ca3af")
+
+
+def _short_task_name(name: str) -> str:
+    """Drop everything up to and including the first underscore, if there is one."""
+
+    _, separator, remainder = name.partition("_")
+    return remainder if separator else name
diff --git a/src/cl_bench/strategies/__init__.py b/src/cl_bench/strategies/__init__.py
new file mode 100644
index 0000000..e2d90ea
--- /dev/null
+++ b/src/cl_bench/strategies/__init__.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from cl_bench.config import ExperimentConfig
+from cl_bench.strategies.agem import AGEMStrategy, build_agem
+from cl_bench.strategies.baseline import BaselineStrategy, build_baseline
+from cl_bench.strategies.derpp import DERPPStrategy, build_derpp
+from cl_bench.strategies.ewc import EWCStrategy, build_ewc
+from cl_bench.strategies.lwf import LwFStrategy, build_lwf
+from cl_bench.strategies.replay import ReplayStrategy, build_replay
+
+# Union of the concrete strategy classes; used as the return type of create_strategy().
+Strategy = (
+    BaselineStrategy | EWCStrategy | ReplayStrategy | LwFStrategy | DERPPStrategy | AGEMStrategy
+)
+
+
+def create_strategy(config: ExperimentConfig, model: nn.Module, device: torch.device) -> Strategy:
+    """Build the strategy selected by ``config.method``.
+
+    The method name is matched case-insensitively with ``-`` treated as ``_``.
+    Only the builder for the selected method runs, so only that method's
+    hyper-parameters are read from ``config``. Unknown names raise ``ValueError``.
+    """
+
+    builders = {
+        "baseline": lambda: build_baseline(model, device, config.learning_rate),
+        "ewc": lambda: build_ewc(
+            model,
+            device,
+            learning_rate=config.learning_rate,
+            ewc_lambda=config.ewc_lambda,
+            fisher_samples=config.fisher_samples,
+        ),
+        "replay": lambda: build_replay(
+            model,
+            device,
+            learning_rate=config.learning_rate,
+            buffer_size=config.replay_buffer_size,
+            replay_batch_size=config.replay_batch_size,
+            replay_loss_weight=config.replay_loss_weight,
+            seed=config.seed,
+        ),
+        "derpp": lambda: build_derpp(
+            model,
+            device,
+            learning_rate=config.learning_rate,
+            buffer_size=config.replay_buffer_size,
+            replay_batch_size=config.replay_batch_size,
+            alpha=config.derpp_alpha,
+            beta=config.derpp_beta,
+            seed=config.seed,
+        ),
+        "agem": lambda: build_agem(
+            model,
+            device,
+            learning_rate=config.learning_rate,
+            buffer_size=config.replay_buffer_size,
+            memory_batch_size=config.agem_memory_batch_size,
+            seed=config.seed,
+        ),
+        "lwf": lambda: build_lwf(
+            model,
+            device,
+            learning_rate=config.learning_rate,
+            alpha=config.lwf_alpha,
+            temperature=config.lwf_temperature,
+        ),
+    }
+
+    key = config.method.lower().replace("-", "_")
+    if key not in builders:
+        raise ValueError(f"Unknown continual-learning method: {config.method}")
+    return builders[key]()
+
+
+# Names exported by `from cl_bench.strategies import *`.
+__all__ = [
+    "AGEMStrategy",
+    "BaselineStrategy",
+    "DERPPStrategy",
+    "EWCStrategy",
+    "LwFStrategy",
+    "ReplayStrategy",
+    "Strategy",
+    "create_strategy",
+]
diff --git a/src/cl_bench/strategies/agem.py b/src/cl_bench/strategies/agem.py
new file mode 100644
index 0000000..3fb8621
--- /dev/null
+++ b/src/cl_bench/strategies/agem.py
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from cl_bench.strategies.base import ContinualLearningStrategy
+from cl_bench.strategies.replay import ReservoirReplayBuffer
+
+
+class AGEMStrategy(ContinualLearningStrategy):
+    """A-GEM gradient projection against replay-memory reference gradients.
+
+    Batches passed to ``observe_batch`` are stored in a ``ReservoirReplayBuffer``.
+    ``after_backward`` compares the current gradient with a reference gradient
+    computed on a batch sampled from that buffer and, when their dot product is
+    negative, removes the current gradient's component along the reference.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        learning_rate: float,
+        buffer_size: int,
+        memory_batch_size: int,
+        seed: int,
+    ):
+        """Create the replay buffer and reset the projection diagnostics.
+
+        Args:
+            model: Network being trained.
+            device: Device that sampled replay batches are moved to.
+            learning_rate: Forwarded to the base strategy.
+            buffer_size: Capacity of the replay buffer.
+            memory_batch_size: Number of replay samples requested per reference
+                gradient; values <= 0 disable the projection.
+            seed: Seed forwarded to the replay buffer.
+        """
+        super().__init__(model=model, device=device, learning_rate=learning_rate)
+        self.buffer = ReservoirReplayBuffer(capacity=buffer_size, seed=seed)
+        self.memory_batch_size = memory_batch_size
+        # Diagnostics from the most recent after_backward() call.
+        self.last_gradient_dot = 0.0
+        self.last_projection_applied = 0.0
+
+    def compute_loss(
+        self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
+        """Return the criterion loss, the logits and a metrics dict for one batch.
+
+        The reported A-GEM values are the ones stored by the latest
+        ``after_backward`` call, not values computed for this batch.
+        """
+        del task_id
+        logits = self.model(inputs)
+        loss = self.criterion(logits, targets)
+        return (
+            loss,
+            logits,
+            {
+                "ce_loss": float(loss.detach().item()),
+                "agem_gradient_dot": self.last_gradient_dot,
+                "agem_projection_applied": self.last_projection_applied,
+            },
+        )
+
+    def after_backward(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        task_id: int,
+    ) -> None:
+        """Project the model's current gradients against a replay reference gradient.
+
+        Does nothing except reset the diagnostics when the buffer is empty or
+        ``memory_batch_size`` is not positive.
+        NOTE(review): presumably invoked by the base training loop right after the
+        loss backward pass and before the optimizer step — confirm in base.py.
+        """
+        del inputs, targets, task_id
+        self.last_gradient_dot = 0.0
+        self.last_projection_applied = 0.0
+        if len(self.buffer) == 0 or self.memory_batch_size <= 0:
+            return
+
+        # Keep a copy of the gradients currently on the model; they are cleared below.
+        current_grad = _flatten_gradients(self.model)
+        replay_inputs, replay_targets = self.buffer.sample(self.memory_batch_size)
+        replay_inputs = replay_inputs.to(self.device)
+        replay_targets = replay_targets.to(self.device)
+
+        # Compute the reference gradient on the replay batch from a clean slate.
+        self.optimizer.zero_grad(set_to_none=True)
+        reference_loss = self.criterion(self.model(replay_inputs), replay_targets)
+        reference_loss.backward()
+        reference_grad = _flatten_gradients(self.model)
+
+        dot_product = torch.dot(current_grad, reference_grad)
+        self.last_gradient_dot = float(dot_product.detach().cpu().item())
+        if dot_product < 0:
+            # g <- g - (g . g_ref / g_ref . g_ref) * g_ref; the clamp avoids dividing by zero.
+            reference_norm = torch.dot(reference_grad, reference_grad).clamp_min(1e-12)
+            current_grad = current_grad - (dot_product / reference_norm) * reference_grad
+            self.last_projection_applied = 1.0
+
+        # Always write back: the model currently holds the reference gradients.
+        _assign_flattened_gradients(self.model, current_grad)
+
+    def observe_batch(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        logits: torch.Tensor,
+        task_id: int,
+    ) -> None:
+        """Add the batch's inputs and targets to the replay buffer."""
+        del logits, task_id
+        self.buffer.add_batch(inputs, targets)
+
+    def extra_state_dict(self) -> dict[str, object]:
+        """Return the buffer contents and A-GEM settings as a plain dict.
+
+        NOTE(review): presumably merged into saved checkpoints by the base class — confirm.
+        """
+        return {
+            "buffer_inputs": [sample.inputs for sample in self.buffer.samples],
+            "buffer_targets": [sample.target for sample in self.buffer.samples],
+            "seen_count": self.buffer.seen_count,
+            "memory_batch_size": self.memory_batch_size,
+        }
+
+
+def _flatten_gradients(model: nn.Module) -> torch.Tensor:
+    """Concatenate copies of the gradients of all trainable parameters into one 1-D tensor.
+
+    Parameters with ``requires_grad=False`` are skipped. A trainable parameter
+    without a gradient contributes zeros, so the layout always matches
+    ``_assign_flattened_gradients``.
+
+    Returns:
+        A 1-D tensor. When the model has no trainable parameters the result is an
+        empty 1-D tensor (shape ``(0,)``), which keeps it usable with ``torch.dot``
+        and works for models that have no parameters at all.
+    """
+
+    pieces = []
+    fallback_device = None
+    for parameter in model.parameters():
+        if fallback_device is None:
+            # Remember where the model lives in case nothing is trainable.
+            fallback_device = parameter.device
+        if not parameter.requires_grad:
+            continue
+        if parameter.grad is None:
+            pieces.append(torch.zeros_like(parameter).reshape(-1))
+        else:
+            pieces.append(parameter.grad.detach().clone().reshape(-1))
+    if not pieces:
+        return torch.zeros(0, device=fallback_device)
+    return torch.cat(pieces)
+
+
+def _assign_flattened_gradients(model: nn.Module, flat_gradient: torch.Tensor) -> None:
+    """Write consecutive slices of ``flat_gradient`` back as ``.grad`` of each trainable parameter.
+
+    Parameters are visited in ``model.parameters()`` order and those with
+    ``requires_grad=False`` are skipped. Each parameter receives its own copy.
+    """
+
+    cursor = 0
+    for parameter in model.parameters():
+        if parameter.requires_grad:
+            size = parameter.numel()
+            chunk = flat_gradient[cursor : cursor + size]
+            parameter.grad = chunk.view_as(parameter).clone()
+            cursor += size
+
+
+def build_agem(
+    model: nn.Module,
+    device: torch.device,
+    learning_rate: float,
+    buffer_size: int,
+    memory_batch_size: int,
+    seed: int,
+) -> AGEMStrategy:
+    """Factory that forwards every argument unchanged to ``AGEMStrategy``."""
+
+    settings = {
+        "model": model,
+        "device": device,
+        "learning_rate": learning_rate,
+        "buffer_size": buffer_size,
+        "memory_batch_size": memory_batch_size,
+        "seed": seed,
+    }
+    return AGEMStrategy(**settings)
diff --git a/src/cl_bench/strategies/base.py b/src/cl_bench/strategies/base.py
new file mode 100644
index 0000000..85632b8
--- /dev/null
+++ b/src/cl_bench/strategies/base.py
@@ -0,0 +1,183 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+
+class ContinualLearningStrategy(ABC):
+    """Stable lifecycle interface shared by all continual-learning methods.
+
+    A concrete strategy implements ``compute_loss`` and may override the
+    optional hooks ``before_task``, ``after_backward``, ``observe_batch``,
+    ``after_task``, ``extra_state_dict`` and ``load_extra_state_dict``.
+    """
+
+    def __init__(self, model: nn.Module, device: torch.device, learning_rate: float):
+        self.model = model
+        self.device = device
+        self.learning_rate = learning_rate
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
+        # -1 until ``train_task`` records the task being trained.
+        self.current_task = -1
+        self.seen_tasks = 0
+
+    def train_task(
+        self,
+        train_loader: DataLoader,
+        val_loader: DataLoader,
+        task_id: int,
+        epochs: int,
+    ) -> list[dict[str, float | int]]:
+        """Train on one task and return the per-epoch metric history.
+
+        A fresh Adam optimizer is created after ``before_task`` has run. Once
+        the last epoch finishes, the weights of the epoch with the lowest
+        validation loss are restored and ``after_task`` is called.
+
+        Args:
+            train_loader: Yields ``(inputs, targets)`` batches for training.
+            val_loader: Yields the batches for the per-epoch validation pass.
+            task_id: Identifier of the task; ``seen_tasks`` becomes at least
+                ``task_id + 1``.
+            epochs: Number of passes over ``train_loader``.
+
+        Returns:
+            One dict per epoch holding ``task_id``, ``epoch`` (1-based) and the
+            training/validation metrics prefixed with ``train_``/``val_``.
+        """
+        self.current_task = task_id
+        self.seen_tasks = max(self.seen_tasks, task_id + 1)
+        self.before_task(task_id)
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
+
+        best_state: dict[str, torch.Tensor] | None = None
+        best_val_loss = float("inf")
+        history: list[dict[str, float | int]] = []
+
+        for epoch in range(epochs):
+            train_metrics = self._train_epoch(train_loader, task_id)
+            val_metrics = self.evaluate(val_loader)
+            epoch_metrics = {
+                "task_id": task_id,
+                "epoch": epoch + 1,
+                **{f"train_{key}": value for key, value in train_metrics.items()},
+                **{f"val_{key}": value for key, value in val_metrics.items()},
+            }
+            history.append(epoch_metrics)
+
+            # ``evaluate`` reports a loss of 0.0 when the validation loader is
+            # empty. Treating that placeholder as a real score would snapshot
+            # the first epoch and roll every later epoch back, so the best
+            # weights are only tracked when validation examples were seen.
+            has_validation = val_metrics.get("examples", 1.0) > 0
+            if has_validation and val_metrics["loss"] < best_val_loss:
+                best_val_loss = float(val_metrics["loss"])
+                best_state = clone_state_dict(self.model)
+
+        if best_state is not None:
+            load_state_dict(self.model, best_state, self.device)
+
+        self.after_task(train_loader, task_id)
+        return history
+
+    def evaluate(self, data_loader: DataLoader) -> dict[str, float]:
+        """Compute loss and accuracy over a loader without tracking gradients.
+
+        The model is switched to eval mode and left there.
+
+        Returns:
+            ``loss`` (mean per example), ``accuracy`` (percent) and
+            ``examples`` (count as a float). All three are 0.0 when the loader
+            yields no examples.
+        """
+        self.model.eval()
+        total_loss = 0.0
+        total_correct = 0
+        total_examples = 0
+
+        with torch.no_grad():
+            for inputs, targets in data_loader:
+                inputs = inputs.to(self.device)
+                targets = targets.to(self.device)
+                logits = self.model(inputs)
+                loss = self.criterion(logits, targets)
+                # The criterion returns a batch mean; weight it by batch size so
+                # the final division gives a per-example mean.
+                total_loss += float(loss.item()) * inputs.size(0)
+                total_correct += int((logits.argmax(dim=1) == targets).sum().item())
+                total_examples += int(targets.numel())
+
+        if total_examples == 0:
+            return {"loss": 0.0, "accuracy": 0.0, "examples": 0.0}
+        return {
+            "loss": total_loss / total_examples,
+            "accuracy": 100.0 * total_correct / total_examples,
+            "examples": float(total_examples),
+        }
+
+    def save_checkpoint(self, path: str | Path) -> None:
+        """Save model, optimizer, task counters and strategy state to ``path``.
+
+        Missing parent directories are created first.
+        """
+        path = Path(path)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        torch.save(
+            {
+                "model_state_dict": self.model.state_dict(),
+                "optimizer_state_dict": self.optimizer.state_dict(),
+                "current_task": self.current_task,
+                "seen_tasks": self.seen_tasks,
+                "strategy_state": self.extra_state_dict(),
+            },
+            path,
+        )
+
+    def load_checkpoint(self, path: str | Path) -> None:
+        """Restore the state written by ``save_checkpoint``.
+
+        Tensors are mapped onto ``self.device``. A checkpoint without a
+        ``strategy_state`` entry passes an empty dict to
+        ``load_extra_state_dict``.
+        """
+        checkpoint = torch.load(path, map_location=self.device)
+        self.model.load_state_dict(checkpoint["model_state_dict"])
+        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        self.current_task = int(checkpoint["current_task"])
+        self.seen_tasks = int(checkpoint["seen_tasks"])
+        self.load_extra_state_dict(checkpoint.get("strategy_state", {}))
+
+    def before_task(self, task_id: int) -> None:
+        """Hook called at the start of ``train_task``; does nothing by default."""
+        del task_id
+
+    def after_task(self, train_loader: DataLoader, task_id: int) -> None:
+        """Hook called at the end of ``train_task``; does nothing by default."""
+        del train_loader, task_id
+
+    def extra_state_dict(self) -> dict[str, Any]:
+        """Return strategy-specific state to store in checkpoints (empty by default)."""
+        return {}
+
+    def load_extra_state_dict(self, state: dict[str, Any]) -> None:
+        """Restore strategy-specific state; the default ignores ``state``."""
+        del state
+
+    def _train_epoch(self, train_loader: DataLoader, task_id: int) -> dict[str, float]:
+        """Run one optimisation pass over ``train_loader``.
+
+        Per batch: zero the gradients, call ``compute_loss``, backpropagate,
+        call ``after_backward``, step the optimizer, then ``observe_batch``.
+
+        Returns:
+            Per-example means of ``loss`` and of every component reported by
+            ``compute_loss``, plus ``accuracy`` in percent.
+        """
+        self.model.train()
+        totals: dict[str, float] = {"loss": 0.0, "ce_loss": 0.0}
+        total_correct = 0
+        total_examples = 0
+
+        for inputs, targets in train_loader:
+            inputs = inputs.to(self.device)
+            targets = targets.to(self.device)
+
+            self.optimizer.zero_grad(set_to_none=True)
+            loss, logits, components = self.compute_loss(inputs, targets, task_id)
+            loss.backward()
+            self.after_backward(inputs, targets, task_id)
+            self.optimizer.step()
+            self.observe_batch(inputs, targets, logits, task_id)
+
+            batch_size = int(targets.numel())
+            total_examples += batch_size
+            total_correct += int((logits.argmax(dim=1) == targets).sum().item())
+            totals["loss"] += float(loss.item()) * batch_size
+            for name, value in components.items():
+                totals[name] = totals.get(name, 0.0) + float(value) * batch_size
+
+        if total_examples == 0:
+            # Report every tracked total (still zero here) so the metric keys
+            # do not depend on whether the loader yielded a batch.
+            return {**totals, "accuracy": 0.0}
+
+        metrics = {name: value / total_examples for name, value in totals.items()}
+        metrics["accuracy"] = 100.0 * total_correct / total_examples
+        return metrics
+
+    @abstractmethod
+    def compute_loss(
+        self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
+        """Return ``(loss, logits, components)`` for one training batch.
+
+        ``loss`` is backpropagated by ``_train_epoch``, ``logits`` are used for
+        accuracy, and ``components`` holds named float values that are averaged
+        into the epoch metrics.
+        """
+        raise NotImplementedError
+
+    def after_backward(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        task_id: int,
+    ) -> None:
+        """Hook called between ``loss.backward()`` and the optimizer step."""
+        del inputs, targets, task_id
+
+    def observe_batch(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        logits: torch.Tensor,
+        task_id: int,
+    ) -> None:
+        """Hook called after the optimizer step with the batch just trained on."""
+        del inputs, targets, logits, task_id
+
+
+def clone_state_dict(module: nn.Module) -> dict[str, torch.Tensor]:
+    """Snapshot a module's state dict as independent, detached CPU tensors.
+
+    Each entry is cloned, so later in-place updates to the module do not
+    change the returned snapshot.
+    """
+    snapshot: dict[str, torch.Tensor] = {}
+    for key, value in module.state_dict().items():
+        snapshot[key] = value.detach().cpu().clone()
+    return snapshot
+
+
+def load_state_dict(
+    module: nn.Module, state_dict: dict[str, torch.Tensor], device: torch.device
+) -> None:
+    """Move every tensor of ``state_dict`` to ``device`` and load it into ``module``."""
+    on_device: dict[str, torch.Tensor] = {}
+    for key, value in state_dict.items():
+        on_device[key] = value.to(device)
+    module.load_state_dict(on_device)
diff --git a/src/cl_bench/strategies/baseline.py b/src/cl_bench/strategies/baseline.py
new file mode 100644
index 0000000..05d7b9c
--- /dev/null
+++ b/src/cl_bench/strategies/baseline.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from cl_bench.strategies.base import ContinualLearningStrategy
+
+
+class BaselineStrategy(ContinualLearningStrategy):
+    """Naive sequential fine-tuning baseline."""
+
+    def compute_loss(
+        self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
+        """Return the criterion loss on the batch, the logits and a ``ce_loss`` entry.
+
+        ``task_id`` is accepted for interface compatibility and is not used.
+        """
+        _ = task_id
+        predictions = self.model(inputs)
+        cross_entropy = self.criterion(predictions, targets)
+        components = {"ce_loss": float(cross_entropy.detach().item())}
+        return cross_entropy, predictions, components
+
+
+def build_baseline(
+    model: nn.Module, device: torch.device, learning_rate: float
+) -> BaselineStrategy:
+    """Create a ``BaselineStrategy``, forwarding every setting by keyword."""
+    settings = dict(model=model, device=device, learning_rate=learning_rate)
+    return BaselineStrategy(**settings)
diff --git a/src/cl_bench/strategies/derpp.py b/src/cl_bench/strategies/derpp.py
new file mode 100644
index 0000000..2adc1a1
--- /dev/null
+++ b/src/cl_bench/strategies/derpp.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from cl_bench.strategies.base import ContinualLearningStrategy
+from cl_bench.strategies.replay import ReservoirReplayBuffer
+
+
+class DERPPStrategy(ContinualLearningStrategy):
+    """Dark Experience Replay++ with stored logits and replay labels."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        learning_rate: float,
+        buffer_size: int,
+        replay_batch_size: int,
+        alpha: float,
+        beta: float,
+        seed: int,
+    ):
+        super().__init__(model=model, device=device, learning_rate=learning_rate)
+        self.buffer = ReservoirReplayBuffer(capacity=buffer_size, seed=seed)
+        self.replay_batch_size = replay_batch_size
+        # Weight of the logit-matching (MSE) term.
+        self.alpha = alpha
+        # Weight of the cross-entropy term on replayed labels.
+        self.beta = beta
+
+    def compute_loss(
+        self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
+        """Combine the batch loss with the two replay terms.
+
+        Returns:
+            ``(loss, logits, components)`` where
+            ``loss = ce + alpha * mse(replay logits) + beta * ce(replay labels)``.
+            Both replay terms are zero while the buffer is empty or
+            ``replay_batch_size`` is not positive.
+        """
+        del task_id
+        logits = self.model(inputs)
+        ce_loss = self.criterion(logits, targets)
+        distillation_loss = torch.zeros((), device=self.device)
+        replay_ce_loss = torch.zeros((), device=self.device)
+
+        if len(self.buffer) > 0 and self.replay_batch_size > 0:
+            replay_inputs, replay_targets, replay_logits = self.buffer.sample_tensors(
+                self.replay_batch_size,
+                require_logits=True,
+            )
+            replay_inputs = replay_inputs.to(self.device)
+            replay_targets = replay_targets.to(self.device)
+            replay_logits = replay_logits.to(self.device)
+            current_replay_logits = self.model(replay_inputs)
+            distillation_loss = F.mse_loss(current_replay_logits, replay_logits)
+            replay_ce_loss = self.criterion(current_replay_logits, replay_targets)
+
+        loss = ce_loss + self.alpha * distillation_loss + self.beta * replay_ce_loss
+        return (
+            loss,
+            logits,
+            {
+                "ce_loss": float(ce_loss.detach().item()),
+                "derpp_distillation_loss": float(distillation_loss.detach().item()),
+                "derpp_replay_ce_loss": float(replay_ce_loss.detach().item()),
+            },
+        )
+
+    def observe_batch(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        logits: torch.Tensor,
+        task_id: int,
+    ) -> None:
+        """Store the batch together with the logits computed for it."""
+        del task_id
+        self.buffer.add_batch(inputs, targets, logits=logits)
+
+    def extra_state_dict(self) -> dict[str, object]:
+        """Return the buffer contents, ``seen_count`` and loss weights for checkpoints."""
+        return {
+            "buffer_inputs": [sample.inputs for sample in self.buffer.samples],
+            "buffer_targets": [sample.target for sample in self.buffer.samples],
+            "buffer_logits": [sample.logits for sample in self.buffer.samples],
+            "seen_count": self.buffer.seen_count,
+            "alpha": self.alpha,
+            "beta": self.beta,
+        }
+
+    def load_extra_state_dict(self, state: dict[str, object]) -> None:
+        """Restore the state written by ``extra_state_dict``.
+
+        Without this override the inherited no-op would discard the saved
+        buffer, so a resumed run would start with empty replay memory. The
+        buffer's random generator state is not part of the checkpoint and is
+        left untouched. A state without ``buffer_inputs`` keeps the current
+        buffer as it is.
+        """
+        self.alpha = float(state.get("alpha", self.alpha))
+        self.beta = float(state.get("beta", self.beta))
+        if "buffer_inputs" not in state:
+            return
+
+        inputs = list(state["buffer_inputs"])
+        targets = list(state.get("buffer_targets", []))
+        logits = list(state.get("buffer_logits") or [None] * len(inputs))
+
+        # Refill through the buffer's own ``add`` so samples are copied to CPU
+        # and the capacity limit is still honoured.
+        self.buffer.samples = []
+        self.buffer.seen_count = 0
+        for sample_inputs, target, sample_logits in zip(inputs, targets, logits, strict=True):
+            self.buffer.add(sample_inputs, int(target), sample_logits)
+        # ``add`` counted only the restored samples; put back the stream length
+        # so later reservoir decisions use the saved count.
+        self.buffer.seen_count = int(state.get("seen_count", len(inputs)))
+
+
+def build_derpp(
+    model: nn.Module,
+    device: torch.device,
+    learning_rate: float,
+    buffer_size: int,
+    replay_batch_size: int,
+    alpha: float,
+    beta: float,
+    seed: int,
+) -> DERPPStrategy:
+    """Create a ``DERPPStrategy``, forwarding every setting by keyword."""
+    settings = dict(
+        model=model,
+        device=device,
+        learning_rate=learning_rate,
+        buffer_size=buffer_size,
+        replay_batch_size=replay_batch_size,
+        alpha=alpha,
+        beta=beta,
+        seed=seed,
+    )
+    return DERPPStrategy(**settings)
diff --git a/src/cl_bench/strategies/ewc.py b/src/cl_bench/strategies/ewc.py
new file mode 100644
index 0000000..e860ccf
--- /dev/null
+++ b/src/cl_bench/strategies/ewc.py
@@ -0,0 +1,140 @@
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.data import DataLoader
+
+from cl_bench.strategies.base import ContinualLearningStrategy, clone_state_dict
+
+
+class EWCStrategy(ContinualLearningStrategy):
+    """Elastic Weight Consolidation with deterministic empirical Fisher estimates."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        learning_rate: float,
+        ewc_lambda: float,
+        fisher_samples: int,
+    ):
+        super().__init__(model=model, device=device, learning_rate=learning_rate)
+        self.ewc_lambda = ewc_lambda
+        # Maximum number of examples per Fisher estimate; <= 0 means no limit.
+        self.fisher_samples = fisher_samples
+        # One entry per finished task, index-aligned with each other.
+        self.fishers: list[dict[str, torch.Tensor]] = []
+        self.optimal_parameters: list[dict[str, torch.Tensor]] = []
+
+    def compute_loss(
+        self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
+        """Return the batch loss plus ``ewc_lambda`` times the EWC penalty."""
+        del task_id
+        logits = self.model(inputs)
+        ce_loss = self.criterion(logits, targets)
+        ewc_penalty = self._ewc_penalty()
+        loss = ce_loss + self.ewc_lambda * ewc_penalty
+        return (
+            loss,
+            logits,
+            {
+                "ce_loss": float(ce_loss.detach().item()),
+                "ewc_penalty": float(ewc_penalty.detach().item()),
+            },
+        )
+
+    def after_task(self, train_loader: DataLoader, task_id: int) -> None:
+        """Record the Fisher estimate and the weights reached on the finished task."""
+        del task_id
+        # Keep both on the training device so ``_ewc_penalty`` does not copy
+        # them from host memory on every optimisation step.
+        self.fishers.append(self._on_device(self._estimate_fisher(train_loader)))
+        self.optimal_parameters.append(self._on_device(clone_state_dict(self.model)))
+
+    def extra_state_dict(self) -> dict[str, object]:
+        """Return hyper-parameters and the per-task Fisher/optimum tensors."""
+        return {
+            "ewc_lambda": self.ewc_lambda,
+            "fisher_samples": self.fisher_samples,
+            "fishers": self.fishers,
+            "optimal_parameters": self.optimal_parameters,
+        }
+
+    def load_extra_state_dict(self, state: dict[str, object]) -> None:
+        """Restore the state written by ``extra_state_dict`` onto ``self.device``.
+
+        Missing hyper-parameters keep their current values; missing tensor
+        lists become empty.
+        """
+        self.ewc_lambda = float(state.get("ewc_lambda", self.ewc_lambda))
+        self.fisher_samples = int(state.get("fisher_samples", self.fisher_samples))
+        self.fishers = [self._on_device(fisher) for fisher in state.get("fishers", [])]
+        self.optimal_parameters = [
+            self._on_device(params) for params in state.get("optimal_parameters", [])
+        ]
+
+    def _on_device(self, tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        """Return a dict with every tensor moved to ``self.device``."""
+        return {name: tensor.to(self.device) for name, tensor in tensors.items()}
+
+    def _ewc_penalty(self) -> torch.Tensor:
+        """Sum ``0.5 * F * (theta - theta*)^2`` over trainable parameters and past tasks.
+
+        Returns a zero scalar on ``self.device`` while no task has finished.
+        """
+        penalty = torch.zeros((), device=self.device)
+        if not self.fishers:
+            return penalty
+
+        named_parameters = dict(self.model.named_parameters())
+        for fisher, optimum in zip(self.fishers, self.optimal_parameters, strict=True):
+            for name, parameter in named_parameters.items():
+                if not parameter.requires_grad:
+                    continue
+                # ``.to`` returns the tensor itself when it already lives on the
+                # device, which is the case for state stored by this class.
+                fisher_tensor = fisher[name].to(self.device)
+                optimum_tensor = optimum[name].to(self.device)
+                penalty = penalty + 0.5 * torch.sum(
+                    fisher_tensor * torch.square(parameter - optimum_tensor)
+                )
+        return penalty
+
+    def _estimate_fisher(self, data_loader: DataLoader) -> dict[str, torch.Tensor]:
+        """Estimate the diagonal empirical Fisher from the loader's dataset.
+
+        The dataset is re-read in order (no shuffling, no workers). For each
+        example the gradient of the log-probability of its label is squared and
+        accumulated; the result is the mean over the examples used. The model
+        is left in eval mode with its gradients cleared.
+
+        Returns:
+            CPU tensors keyed by the names of the trainable parameters (all
+            zeros when no example was processed).
+        """
+        self.model.eval()
+        fisher_loader = DataLoader(
+            data_loader.dataset,
+            # ``batch_size`` is None for loaders built from a batch sampler;
+            # passing None on would disable batching and the per-example loop
+            # below would then index into a single unbatched sample.
+            batch_size=data_loader.batch_size or 1,
+            shuffle=False,
+            num_workers=0,
+        )
+        fisher = {
+            name: torch.zeros_like(parameter, device=self.device)
+            for name, parameter in self.model.named_parameters()
+            if parameter.requires_grad
+        }
+        sample_count = 0
+        max_samples = self.fisher_samples if self.fisher_samples > 0 else float("inf")
+
+        for inputs, targets in fisher_loader:
+            inputs = inputs.to(self.device)
+            targets = targets.to(self.device)
+            for index in range(inputs.size(0)):
+                if sample_count >= max_samples:
+                    break
+                self.model.zero_grad(set_to_none=True)
+                # One example at a time: squared per-example gradients are
+                # needed, not the square of a batch-averaged gradient.
+                logits = self.model(inputs[index : index + 1])
+                log_probability = F.log_softmax(logits, dim=1)[0, targets[index]]
+                log_probability.backward()
+                for name, parameter in self.model.named_parameters():
+                    if parameter.grad is not None and parameter.requires_grad:
+                        fisher[name] += torch.square(parameter.grad.detach())
+                sample_count += 1
+            if sample_count >= max_samples:
+                break
+
+        self.model.zero_grad(set_to_none=True)
+        if sample_count == 0:
+            return {name: value.detach().cpu() for name, value in fisher.items()}
+
+        return {name: (value / sample_count).detach().cpu() for name, value in fisher.items()}
+
+
+def build_ewc(
+    model: nn.Module,
+    device: torch.device,
+    learning_rate: float,
+    ewc_lambda: float,
+    fisher_samples: int,
+) -> EWCStrategy:
+    """Create an ``EWCStrategy``, forwarding every setting by keyword."""
+    settings = dict(
+        model=model,
+        device=device,
+        learning_rate=learning_rate,
+        ewc_lambda=ewc_lambda,
+        fisher_samples=fisher_samples,
+    )
+    return EWCStrategy(**settings)
diff --git a/src/cl_bench/strategies/lwf.py b/src/cl_bench/strategies/lwf.py
new file mode 100644
index 0000000..da15fa7
--- /dev/null
+++ b/src/cl_bench/strategies/lwf.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import copy
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.data import DataLoader
+
+from cl_bench.strategies.base import ContinualLearningStrategy
+
+
+class LwFStrategy(ContinualLearningStrategy):
+    """Learning without Forgetting using a frozen teacher from the previous task."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        learning_rate: float,
+        alpha: float,
+        temperature: float,
+    ):
+        super().__init__(model=model, device=device, learning_rate=learning_rate)
+        self.alpha = alpha
+        self.temperature = temperature
+        # Stays None until a task has finished or a teacher has been loaded.
+        self.teacher_model: nn.Module | None = None
+
+    def compute_loss(
+        self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
+        """Return the batch loss plus ``alpha`` times the distillation term.
+
+        The distillation term is the temperature-scaled KL divergence between
+        the teacher's and the model's softened predictions, multiplied by the
+        squared temperature. It is zero while there is no teacher.
+        """
+        _ = task_id
+        logits = self.model(inputs)
+        ce_loss = self.criterion(logits, targets)
+
+        if self.teacher_model is None:
+            distillation_loss = torch.zeros((), device=self.device)
+        else:
+            with torch.no_grad():
+                teacher_logits = self.teacher_model(inputs)
+            tau = self.temperature
+            divergence = F.kl_div(
+                F.log_softmax(logits / tau, dim=1),
+                F.softmax(teacher_logits / tau, dim=1),
+                reduction="batchmean",
+            )
+            distillation_loss = divergence * tau * tau
+
+        components = {
+            "ce_loss": float(ce_loss.detach().item()),
+            "distillation_loss": float(distillation_loss.detach().item()),
+        }
+        return ce_loss + self.alpha * distillation_loss, logits, components
+
+    def after_task(self, train_loader: DataLoader, task_id: int) -> None:
+        """Replace the teacher with a frozen copy of the current model."""
+        _ = (train_loader, task_id)
+        self.teacher_model = copy.deepcopy(self.model).to(self.device)
+        self._freeze_teacher()
+
+    def extra_state_dict(self) -> dict[str, Any]:
+        """Return ``alpha``, ``temperature`` and the teacher weights (or None)."""
+        teacher = self.teacher_model
+        teacher_state = teacher.state_dict() if teacher is not None else None
+        return {
+            "alpha": self.alpha,
+            "temperature": self.temperature,
+            "teacher_state_dict": teacher_state,
+        }
+
+    def load_extra_state_dict(self, state: dict[str, Any]) -> None:
+        """Restore hyper-parameters and rebuild the teacher from saved weights.
+
+        When the state holds no teacher weights the teacher is cleared.
+        """
+        self.alpha = float(state.get("alpha", self.alpha))
+        self.temperature = float(state.get("temperature", self.temperature))
+        teacher_state = state.get("teacher_state_dict")
+        if teacher_state is not None:
+            # The teacher shares the model's architecture, so start from a copy
+            # of the model and overwrite its weights.
+            self.teacher_model = copy.deepcopy(self.model).to(self.device)
+            self.teacher_model.load_state_dict(teacher_state)
+            self._freeze_teacher()
+        else:
+            self.teacher_model = None
+
+    def _freeze_teacher(self) -> None:
+        """Put the current teacher in eval mode and disable its gradients."""
+        self.teacher_model.eval()
+        for weight in self.teacher_model.parameters():
+            weight.requires_grad_(False)
+
+
+def build_lwf(
+    model: nn.Module,
+    device: torch.device,
+    learning_rate: float,
+    alpha: float,
+    temperature: float,
+) -> LwFStrategy:
+    """Create an ``LwFStrategy``, forwarding every setting by keyword."""
+    settings = dict(
+        model=model,
+        device=device,
+        learning_rate=learning_rate,
+        alpha=alpha,
+        temperature=temperature,
+    )
+    return LwFStrategy(**settings)
diff --git a/src/cl_bench/strategies/replay.py b/src/cl_bench/strategies/replay.py
new file mode 100644
index 0000000..c41d60e
--- /dev/null
+++ b/src/cl_bench/strategies/replay.py
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+import random
+from dataclasses import dataclass
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+from cl_bench.strategies.base import ContinualLearningStrategy
+
+
+@dataclass
+class ReplaySample:
+    """One example held by the replay buffer."""
+
+    # Input tensor of the example.
+    inputs: torch.Tensor
+    # Label of the example as a plain int.
+    target: int
+    # Logits stored with the example, or None when none were recorded.
+    logits: torch.Tensor | None = None
+
+
+class ReservoirReplayBuffer:
+    """Bounded replay buffer using reservoir sampling over the observed stream."""
+
+    def __init__(self, capacity: int, seed: int = 0):
+        """Create an empty buffer.
+
+        Args:
+            capacity: Maximum number of stored samples; 0 stores nothing.
+            seed: Seed of the private random generator used for replacement
+                and for sampling.
+
+        Raises:
+            ValueError: If ``capacity`` is negative.
+        """
+        if capacity < 0:
+            raise ValueError("Replay buffer capacity must be non-negative.")
+        self.capacity = capacity
+        self.rng = random.Random(seed)
+        self.samples: list[ReplaySample] = []
+        # Number of examples offered to the buffer, stored or not.
+        self.seen_count = 0
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def add_batch(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        logits: torch.Tensor | None = None,
+    ) -> None:
+        """Offer every example of a batch to the buffer, one at a time.
+
+        ``inputs`` and ``targets`` must have the same length along the first
+        dimension (enforced by ``zip(..., strict=True)``).
+        """
+        logits_cpu = None if logits is None else logits.detach().cpu()
+        for index, (input_tensor, target) in enumerate(
+            zip(inputs.detach().cpu(), targets.detach().cpu(), strict=True)
+        ):
+            sample_logits = None if logits_cpu is None else logits_cpu[index]
+            self.add(input_tensor, int(target.item()), sample_logits)
+
+    def add(self, inputs: torch.Tensor, target: int, logits: torch.Tensor | None = None) -> None:
+        """Offer one example; it is stored, replaces a stored one, or is dropped.
+
+        While the buffer is not full the example is appended. Afterwards it
+        replaces slot ``j`` when ``j = randrange(seen_count)`` falls below the
+        capacity, and is discarded otherwise.
+        """
+        self.seen_count += 1
+        if self.capacity == 0:
+            return
+
+        if len(self.samples) < self.capacity:
+            slot = None  # append
+        else:
+            slot = self.rng.randrange(self.seen_count)
+            if slot >= self.capacity:
+                # Discarded by the reservoir: skip the copies below.
+                return
+
+        # Copy only now that the example is known to be kept. The clone
+        # detaches the sample from the batch tensor it may be a view of.
+        sample = ReplaySample(
+            inputs=inputs.detach().cpu().clone(),
+            target=int(target),
+            logits=None if logits is None else logits.detach().cpu().clone(),
+        )
+        if slot is None:
+            self.samples.append(sample)
+        else:
+            self.samples[slot] = sample
+
+    def sample(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
+        """Draw up to ``batch_size`` stored examples as ``(inputs, targets)``."""
+        inputs, targets, _ = self.sample_tensors(batch_size)
+        return inputs, targets
+
+    def sample_tensors(
+        self, batch_size: int, require_logits: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
+        """Draw stored examples and stack them into tensors.
+
+        Returns:
+            ``(inputs, targets, logits)``; ``targets`` is a long tensor and
+            ``logits`` is None unless ``require_logits`` is set.
+        """
+        chosen = self.sample_samples(batch_size, require_logits=require_logits)
+        inputs = torch.stack([sample.inputs for sample in chosen])
+        targets = torch.tensor([sample.target for sample in chosen], dtype=torch.long)
+        if require_logits:
+            logits = torch.stack([sample.logits for sample in chosen if sample.logits is not None])
+        else:
+            logits = None
+        return inputs, targets, logits
+
+    def sample_samples(self, batch_size: int, require_logits: bool = False) -> list[ReplaySample]:
+        """Draw ``min(batch_size, available)`` distinct stored samples.
+
+        With ``require_logits`` only samples that carry logits are eligible.
+
+        Raises:
+            ValueError: If ``batch_size`` is not positive or no sample is
+                eligible.
+        """
+        if batch_size <= 0:
+            # Fail here with a clear message instead of letting an empty
+            # selection reach ``torch.stack`` in ``sample_tensors``.
+            raise ValueError("Replay batch size must be positive.")
+        candidates = [
+            sample for sample in self.samples if not require_logits or sample.logits is not None
+        ]
+        if not candidates:
+            raise ValueError("Cannot sample from an empty replay buffer.")
+        return self.rng.sample(candidates, k=min(batch_size, len(candidates)))
+
+
+class ReplayStrategy(ContinualLearningStrategy):
+    """Experience replay: adds a loss on examples drawn from a reservoir buffer."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        learning_rate: float,
+        buffer_size: int,
+        replay_batch_size: int,
+        replay_loss_weight: float,
+        seed: int,
+    ):
+        super().__init__(model=model, device=device, learning_rate=learning_rate)
+        self.buffer = ReservoirReplayBuffer(capacity=buffer_size, seed=seed)
+        self.replay_batch_size = replay_batch_size
+        self.replay_loss_weight = replay_loss_weight
+
+    def compute_loss(
+        self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
+        """Return the batch loss plus the weighted loss on a replayed batch.
+
+        The replay term is zero while the buffer is empty or
+        ``replay_batch_size`` is not positive.
+        """
+        del task_id
+        logits = self.model(inputs)
+        ce_loss = self.criterion(logits, targets)
+        replay_loss = torch.zeros((), device=self.device)
+
+        if len(self.buffer) > 0 and self.replay_batch_size > 0:
+            replay_inputs, replay_targets = self.buffer.sample(self.replay_batch_size)
+            replay_inputs = replay_inputs.to(self.device)
+            replay_targets = replay_targets.to(self.device)
+            replay_logits = self.model(replay_inputs)
+            replay_loss = self.criterion(replay_logits, replay_targets)
+
+        loss = ce_loss + self.replay_loss_weight * replay_loss
+        return (
+            loss,
+            logits,
+            {
+                "ce_loss": float(ce_loss.detach().item()),
+                "replay_loss": float(replay_loss.detach().item()),
+            },
+        )
+
+    def observe_batch(
+        self,
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        logits: torch.Tensor,
+        task_id: int,
+    ) -> None:
+        """Offer the batch just trained on to the buffer (logits are not stored)."""
+        del logits, task_id
+        self.buffer.add_batch(inputs, targets)
+
+    def after_task(self, train_loader: DataLoader, task_id: int) -> None:
+        """Nothing to do at the end of a task; the buffer is filled per batch."""
+        del train_loader, task_id
+
+    def extra_state_dict(self) -> dict[str, object]:
+        """Return the buffer contents and ``seen_count`` for checkpoints."""
+        return {
+            "buffer_inputs": [sample.inputs for sample in self.buffer.samples],
+            "buffer_targets": [sample.target for sample in self.buffer.samples],
+            "buffer_logits": [sample.logits for sample in self.buffer.samples],
+            "seen_count": self.buffer.seen_count,
+        }
+
+    def load_extra_state_dict(self, state: dict[str, object]) -> None:
+        """Restore the buffer written by ``extra_state_dict``.
+
+        Without this override the inherited no-op would discard the saved
+        buffer, so a resumed run would start with empty replay memory. The
+        buffer's random generator state is not part of the checkpoint and is
+        left untouched. A state without ``buffer_inputs`` keeps the current
+        buffer as it is.
+        """
+        if "buffer_inputs" not in state:
+            return
+
+        inputs = list(state["buffer_inputs"])
+        targets = list(state.get("buffer_targets", []))
+        logits = list(state.get("buffer_logits") or [None] * len(inputs))
+
+        # Refill through the buffer's own ``add`` so samples are copied to CPU
+        # and the capacity limit is still honoured.
+        self.buffer.samples = []
+        self.buffer.seen_count = 0
+        for sample_inputs, target, sample_logits in zip(inputs, targets, logits, strict=True):
+            self.buffer.add(sample_inputs, int(target), sample_logits)
+        # ``add`` counted only the restored samples; put back the stream length
+        # so later reservoir decisions use the saved count.
+        self.buffer.seen_count = int(state.get("seen_count", len(inputs)))
+
+
+def build_replay(
+    model: nn.Module,
+    device: torch.device,
+    learning_rate: float,
+    buffer_size: int,
+    replay_batch_size: int,
+    replay_loss_weight: float,
+    seed: int,
+) -> ReplayStrategy:
+    """Create a ``ReplayStrategy``, forwarding every setting by keyword."""
+    settings = dict(
+        model=model,
+        device=device,
+        learning_rate=learning_rate,
+        buffer_size=buffer_size,
+        replay_batch_size=replay_batch_size,
+        replay_loss_weight=replay_loss_weight,
+        seed=seed,
+    )
+    return ReplayStrategy(**settings)
diff --git a/src/cl_bench/tracking.py b/src/cl_bench/tracking.py
new file mode 100644
index 0000000..d1e6eae
--- /dev/null
+++ b/src/cl_bench/tracking.py
@@ -0,0 +1,177 @@
+from __future__ import annotations
+
+import json
+import platform
+import subprocess
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+
+
+class ExperimentTracker:
+    """Writes reproducibility artifacts for one benchmark run."""
+
+    def __init__(self, run_dir: Path):
+        """Remember ``run_dir`` and create it (with parents) if it is missing."""
+        self.run_dir = run_dir
+        self.run_dir.mkdir(parents=True, exist_ok=True)
+        self.events_path = self.run_dir / "metrics.jsonl"
+
+    def log_event(self, event: dict[str, Any]) -> None:
+        """Append ``event`` as one JSON line to ``metrics.jsonl``.
+
+        A ``time_utc`` timestamp is added first; a ``time_utc`` key in
+        ``event`` overrides it.
+        """
+        record = {"time_utc": utc_now()}
+        record.update(event)
+        with open(self.events_path, "a", encoding="utf-8") as stream:
+            stream.write(json.dumps(record, sort_keys=True) + "\n")
+
+    def write_json(self, name: str, payload: dict[str, Any] | list[Any]) -> Path:
+        """Write ``payload`` as indented, key-sorted JSON and return the file path."""
+        target = self.run_dir / name
+        with open(target, "w", encoding="utf-8") as stream:
+            json.dump(payload, stream, indent=2, sort_keys=True)
+            stream.write("\n")
+        return target
+
+    def write_matrix_csv(self, name: str, matrix: np.ndarray) -> Path:
+        """Write ``matrix`` as comma-separated text with six decimals per value."""
+        target = self.run_dir / name
+        np.savetxt(target, matrix, delimiter=",", fmt="%.6f")
+        return target
+
+
+def create_run_dir(output_dir: str | Path, benchmark_name: str, method: str) -> Path:
+    """Build the path ``<output_dir>/<benchmark>_<method>_<UTC timestamp>``.
+
+    Both names are passed through ``_safe_slug``. The path is only composed
+    here; this function does not create the directory.
+    """
+    stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
+    leaf = "_".join((_safe_slug(benchmark_name), _safe_slug(method), stamp))
+    return Path(output_dir) / leaf
+
+
+def git_commit(repo_dir: str | Path | None = None) -> str | None:
+    """Return the hash printed by ``git rev-parse HEAD`` run in ``repo_dir``.
+
+    Returns None when git cannot be started, exits with an error, or prints
+    nothing.
+    """
+    command = ["git", "rev-parse", "HEAD"]
+    try:
+        result = subprocess.run(
+            command,
+            cwd=repo_dir,
+            check=True,
+            capture_output=True,
+            text=True,
+        )
+    except (OSError, subprocess.CalledProcessError):
+        return None
+    revision = result.stdout.strip()
+    return revision if revision else None
+
+
+def utc_now() -> str:
+    """Return the current time in UTC as an ISO 8601 string with its offset."""
+    moment = datetime.now(tz=timezone.utc)
+    return moment.isoformat()
+
+
+def _safe_slug(value: str) -> str:
+    """Replace every character that is not alphanumeric, ``-`` or ``_`` with ``-``.
+
+    Leading and trailing ``-`` characters are stripped from the result.
+    """
+    kept_punctuation = "-_"
+    cleaned = [
+        symbol if (symbol.isalnum() or symbol in kept_punctuation) else "-" for symbol in value
+    ]
+    return "".join(cleaned).strip("-")
+
+
+class MLflowRunLogger:
+    """Optional local MLflow logger layered on top of JSON artifacts.
+
+    Used as a context manager. When ``enabled`` is False, mlflow is never
+    imported and every logging method returns without doing anything.
+    """
+
+    def __init__(
+        self,
+        tracking_uri: str | Path,
+        experiment_name: str,
+        run_name: str,
+        enabled: bool,
+    ):
+        self.tracking_uri = tracking_uri
+        self.experiment_name = experiment_name
+        self.run_name = run_name
+        self.enabled = enabled
+        # Both stay None until ``__enter__`` starts a run.
+        self._mlflow: Any | None = None
+        self._active_run: Any | None = None
+
+    def __enter__(self) -> MLflowRunLogger:
+        """Import mlflow, select the experiment and start the run when enabled.
+
+        Raises:
+            RuntimeError: If logging is enabled but mlflow cannot be imported.
+        """
+        if self.enabled:
+            try:
+                import mlflow
+            except ImportError as exc:
+                raise RuntimeError(
+                    "MLflow tracking was requested but mlflow is not installed. "
+                    'Install with: python -m pip install -e ".[experiment]"'
+                ) from exc
+
+            self._mlflow = mlflow
+            mlflow.set_tracking_uri(_normalize_mlflow_uri(self.tracking_uri))
+            mlflow.set_experiment(self.experiment_name)
+            self._active_run = mlflow.start_run(run_name=self.run_name)
+        return self
+
+    def __exit__(self, exc_type: object, exc: object, traceback: object) -> None:
+        """End the run, marking it FAILED when the block raised."""
+        if self._mlflow is None or self._active_run is None:
+            return
+        status = "FAILED" if exc_type else "FINISHED"
+        self._mlflow.end_run(status=status)
+
+    def log_params(self, params: dict[str, Any]) -> None:
+        """Log ``params`` after flattening nested dicts with ``_flatten_params``."""
+        if self._mlflow is not None:
+            for key, value in _flatten_params(params).items():
+                self._mlflow.log_param(key, value)
+
+    def set_tags(self, tags: dict[str, Any]) -> None:
+        """Set every tag whose value is not None, converted to ``str``."""
+        if self._mlflow is not None:
+            for key, value in tags.items():
+                if value is None:
+                    continue
+                self._mlflow.set_tag(key, str(value))
+
+    def log_environment(self) -> None:
+        """Tag the run with Python, platform, torch and accelerator availability."""
+        if self._mlflow is None:
+            return
+        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        environment = {
+            "python": platform.python_version(),
+            "platform": platform.platform(),
+            "torch": torch.__version__,
+            "cuda_available": torch.cuda.is_available(),
+            "mps_available": mps_available,
+        }
+        self.set_tags(environment)
+
+    def log_metrics(
+        self, metrics: dict[str, float | int | str | None], step: int | None = None
+    ) -> None:
+        """Log the int/float entries of ``metrics``; other values are skipped."""
+        if self._mlflow is not None:
+            for key, value in metrics.items():
+                if not isinstance(value, (int, float)):
+                    continue
+                self._mlflow.log_metric(key, float(value), step=step)
+
+    def log_artifacts(self, run_dir: str | Path) -> None:
+        """Upload the contents of ``run_dir`` as artifacts of the run."""
+        if self._mlflow is not None:
+            self._mlflow.log_artifacts(str(run_dir))
+
+
+def _normalize_mlflow_uri(uri: str | Path) -> str:
+    """Turn a tracking location into a URI string for mlflow.
+
+    * ``sqlite:///...`` is returned unchanged after the parent directory of
+      the database file has been created.
+    * Any other value containing ``://`` is returned unchanged.
+    * Everything else is treated as a filesystem path and converted to an
+      absolute ``file://`` URI.
+    """
+    text = str(uri)
+    sqlite_prefix = "sqlite:///"
+    if text.startswith(sqlite_prefix):
+        database_file = Path(text.removeprefix(sqlite_prefix))
+        database_file.parent.mkdir(parents=True, exist_ok=True)
+        return text
+    if "://" not in text:
+        return Path(text).resolve().as_uri()
+    return text
+
+
+def _flatten_params(
+    params: dict[str, Any], prefix: str = ""
+) -> dict[str, str | int | float | bool]:
+    """Flatten nested parameter dicts into dotted keys with scalar values.
+
+    Nested dicts are expanded recursively (``{"a": {"b": 1}}`` becomes
+    ``{"a.b": 1}``). Strings, ints, floats and bools are kept as they are,
+    None becomes the string ``"null"``, and anything else is stored as its
+    JSON text (falling back to ``str`` for values JSON cannot encode).
+    """
+    result: dict[str, str | int | float | bool] = {}
+    for key, value in params.items():
+        dotted = str(key) if not prefix else f"{prefix}.{key}"
+        if isinstance(value, dict):
+            result.update(_flatten_params(value, dotted))
+            continue
+        if isinstance(value, (str, int, float, bool)):
+            result[dotted] = value
+            continue
+        if value is None:
+            result[dotted] = "null"
+            continue
+        result[dotted] = json.dumps(value, sort_keys=True, default=str)
+    return result
diff --git a/src/data/__init__.py b/src/data/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/data/data_loader.py b/src/data/data_loader.py
deleted file mode 100644
index 618adcc..0000000
--- a/src/data/data_loader.py
+++ /dev/null
@@ -1,253 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""
-Data loading module for the Continual Learning System.
-Handles loading and preprocessing of different datasets and task sequences.
-"""
-
-import os
-import logging
-import numpy as np
-from typing import List, Dict, Any, Union, Tuple
-
-import torch
-import torchvision
-import torchvision.transforms as transforms
-from torch.utils.data import DataLoader, Subset, ConcatDataset, random_split
-
-logger = logging.getLogger(__name__)
-
-
-def get_dataset(dataset_name: str, root: str = './data', train: bool = True, transform=None):
- """
- Get the specified dataset.
-
- Args:
- dataset_name (str): Name of the dataset ('mnist', 'fashion_mnist', 'kmnist', etc.)
- root (str): Root directory for dataset storage
- train (bool): Whether to load the training set or the test set
- transform: Transformations to apply to the data
-
- Returns:
- torch.utils.data.Dataset: The requested dataset
- """
- dataset_name = dataset_name.lower()
-
- # Create directory if it doesn't exist
- os.makedirs(root, exist_ok=True)
-
- # Default transformations if none provided
- if transform is None:
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,)) # MNIST-like normalization
- ])
-
- # Load dataset based on name
- if dataset_name == 'mnist':
- dataset = torchvision.datasets.MNIST(
- root=root, train=train, download=True, transform=transform
- )
- elif dataset_name == 'fashion_mnist':
- dataset = torchvision.datasets.FashionMNIST(
- root=root, train=train, download=True, transform=transform
- )
- elif dataset_name == 'kmnist':
- dataset = torchvision.datasets.KMNIST(
- root=root, train=train, download=True, transform=transform
- )
- elif dataset_name == 'cifar10':
- if transform is None:
- # Default CIFAR10 transformation
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
- ])
-
- dataset = torchvision.datasets.CIFAR10(
- root=root, train=train, download=True, transform=transform
- )
- elif dataset_name == 'cifar100':
- if transform is None:
- # Default CIFAR100 transformation
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
- ])
-
- dataset = torchvision.datasets.CIFAR100(
- root=root, train=train, download=True, transform=transform
- )
- else:
- raise ValueError(f"Unknown dataset: {dataset_name}")
-
- return dataset
-
-
-def create_class_subset(dataset, classes: Union[List[int], str]):
- """
- Create a subset of a dataset containing only the specified classes.
-
- Args:
- dataset: PyTorch dataset
- classes: List of class indices to include, or 'all' to include all classes
-
- Returns:
- tuple: (Subset of the dataset, list of classes included)
- """
- if classes == 'all':
- # If we want all classes, return the whole dataset and a list of all class indices
- if hasattr(dataset, 'classes'):
- all_classes = list(range(len(dataset.classes)))
- else:
- # Try to infer the number of classes
- targets = dataset.targets if hasattr(dataset, 'targets') else dataset.targets
- all_classes = list(set(targets.numpy() if torch.is_tensor(targets) else targets))
-
- return dataset, all_classes
-
- # Get targets from the dataset
- if hasattr(dataset, 'targets'):
- targets = dataset.targets
- elif hasattr(dataset, 'target'):
- targets = dataset.target
- else:
- # For datasets that store targets differently (e.g., as part of __getitem__)
- # We'll need to iterate through the dataset
- targets = torch.tensor([dataset[i][1] for i in range(len(dataset))])
-
- # Convert targets to numpy array for easier filtering
- if torch.is_tensor(targets):
- targets = targets.numpy()
- elif not isinstance(targets, np.ndarray):
- targets = np.array(targets)
-
- # Get indices for the requested classes
- indices = [i for i, label in enumerate(targets) if label in classes]
-
- # Create and return the subset
- return Subset(dataset, indices), classes
-
-
-def split_dataset(dataset, val_split: float = 0.1):
- """
- Split a dataset into training and validation sets.
-
- Args:
- dataset: PyTorch dataset
- val_split (float): Fraction of data to use for validation
-
- Returns:
- tuple: (training subset, validation subset)
- """
- val_size = int(len(dataset) * val_split)
- train_size = len(dataset) - val_size
-
- train_dataset, val_dataset = random_split(
- dataset,
- [train_size, val_size],
- generator=torch.Generator().manual_seed(42)
- )
-
- return train_dataset, val_dataset
-
-
-def get_task_sequence(task_configs: List[Dict[str, Any]], batch_size: int = 32):
- """
- Create data loaders for a sequence of tasks.
-
- Args:
- task_configs (list): List of task configuration dictionaries
- batch_size (int): Batch size for data loaders
-
- Returns:
- list: List of task data dictionaries, each containing name, classes, and data loaders
- """
- # Data root directory is in the project root
- data_root = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'data')
-
- task_data = []
-
- for task_config in task_configs:
- dataset_name = task_config['dataset']
- task_name = task_config['name']
- classes = task_config['classes']
-
- logger.info(f"Loading task {task_name} (dataset: {dataset_name}, classes: {classes})")
-
- # Get training and test datasets
- train_dataset = get_dataset(dataset_name, root=data_root, train=True)
- test_dataset = get_dataset(dataset_name, root=data_root, train=False)
-
- # Create class-specific subsets if needed
- train_subset, actual_classes = create_class_subset(train_dataset, classes)
- test_subset, _ = create_class_subset(test_dataset, classes)
-
- # Split training set into train and validation
- train_split, val_split = split_dataset(train_subset)
-
- # Create data loaders
- train_loader = DataLoader(
- train_split,
- batch_size=batch_size,
- shuffle=True,
- num_workers=2,
- pin_memory=True
- )
-
- val_loader = DataLoader(
- val_split,
- batch_size=batch_size,
- shuffle=False,
- num_workers=2,
- pin_memory=True
- )
-
- test_loader = DataLoader(
- test_subset,
- batch_size=batch_size,
- shuffle=False,
- num_workers=2,
- pin_memory=True
- )
-
- # Store task data
- task_data.append({
- 'name': task_name,
- 'dataset': dataset_name,
- 'classes': actual_classes,
- 'train_loader': train_loader,
- 'val_loader': val_loader,
- 'test_loader': test_loader
- })
-
- logger.info(f"Loaded task {task_name} with {len(train_split)} training, "
- f"{len(val_split)} validation, and {len(test_subset)} test samples")
-
- return task_data
-
-
-if __name__ == '__main__':
- # Test the data loading functionality
- logging.basicConfig(level=logging.INFO)
-
- # Example task sequence
- task_configs = [
- {'name': 'mnist_0_4', 'dataset': 'mnist', 'classes': [0, 1, 2, 3, 4]},
- {'name': 'mnist_5_9', 'dataset': 'mnist', 'classes': [5, 6, 7, 8, 9]}
- ]
-
- task_data = get_task_sequence(task_configs)
-
- # Print information about the tasks
- for i, task in enumerate(task_data):
- logger.info(f"Task {i+1}: {task['name']}")
- logger.info(f" Classes: {task['classes']}")
- logger.info(f" Training samples: {len(task['train_loader'].dataset)}")
- logger.info(f" Validation samples: {len(task['val_loader'].dataset)}")
- logger.info(f" Test samples: {len(task['test_loader'].dataset)}")
-
- # Get a batch of data
- images, labels = next(iter(task['train_loader']))
- logger.info(f" Batch shape: {images.shape}, Labels: {labels.numpy()[:5]} ...")
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
deleted file mode 100644
index e126705..0000000
--- a/src/main.py
+++ /dev/null
@@ -1,310 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""
-Main entry point for the Continual Learning System.
-This script orchestrates the training and evaluation of continual learning approaches.
-"""
-
-import os
-import sys
-import argparse
-import logging
-import yaml
-import torch
-import numpy as np
-import random
-from datetime import datetime
-
-# Add the 'src' directory to the Python path
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-from src.data.data_loader import get_task_sequence
-from src.models.model_factory import get_model
-from src.methods.baseline import BaselineLearner
-from src.methods.ewc import EWCLearner
-from src.methods.replay import ExperienceReplayLearner
-from src.methods.lwf import LwFLearner
-from src.utils.metrics import evaluate_performance, compute_forgetting
-from src.utils.visualization import plot_performance, plot_forgetting
-
-
-def setup_logging():
- """Set up logging configuration."""
- log_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'results', 'logs')
- os.makedirs(log_dir, exist_ok=True)
-
- log_file = os.path.join(log_dir, f'continual_learning_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
-
- logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler(log_file),
- logging.StreamHandler()
- ]
- )
- return logging.getLogger(__name__)
-
-
-def set_seed(seed):
- """Set random seed for reproducibility."""
- if seed is not None:
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed(seed)
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
-
-
-def parse_arguments():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description='Continual Learning System')
-
- parser.add_argument('--method', type=str, required=True,
- choices=['baseline', 'ewc', 'replay', 'lwf'],
- help='Continual learning method to use')
-
- parser.add_argument('--tasks', type=str, default='mnist_split',
- help='Predefined task sequence (or path to custom YAML file)')
-
- parser.add_argument('--model', type=str, default='simple_cnn',
- help='Base model architecture')
-
- parser.add_argument('--epochs', type=int, default=5,
- help='Number of epochs per task')
-
- parser.add_argument('--batch_size', type=int, default=64,
- help='Batch size for training')
-
- parser.add_argument('--learning_rate', type=float, default=0.001,
- help='Learning rate for optimizer')
-
- # EWC specific arguments
- parser.add_argument('--lambda_ewc', type=float, default=5000,
- help='Regularization strength for EWC')
- parser.add_argument('--fisher_sample_size', type=int, default=200,
- help='Number of samples to estimate Fisher information')
-
- # Experience Replay specific arguments
- parser.add_argument('--buffer_size', type=int, default=500,
- help='Size of the replay buffer')
- parser.add_argument('--replay_batch_size', type=int, default=32,
- help='Batch size for replayed samples')
-
- # LwF specific arguments
- parser.add_argument('--temperature', type=float, default=2.0,
- help='Temperature for knowledge distillation in LwF')
- parser.add_argument('--alpha', type=float, default=1.0,
- help='Weight for distillation loss in LwF')
-
- # General arguments
- parser.add_argument('--seed', type=int, default=42,
- help='Random seed for reproducibility')
-
- parser.add_argument('--device', type=str, default=None,
- help='Device to use (cuda or cpu)')
-
- parser.add_argument('--eval_freq', type=int, default=1,
- help='Frequency of evaluation during training (in epochs)')
-
- parser.add_argument('--save_dir', type=str, default='results',
- help='Directory to save results')
-
- return parser.parse_args()
-
-
-def load_task_config(task_name_or_path):
- """
- Load task configuration from predefined sequences or a custom YAML file.
-
- Args:
- task_name_or_path (str): Name of predefined task or path to custom YAML file
-
- Returns:
- dict: Task configuration
- """
- predefined_tasks = {
- 'mnist_split': [
- {'name': 'mnist_0_4', 'dataset': 'mnist', 'classes': [0, 1, 2, 3, 4]},
- {'name': 'mnist_5_9', 'dataset': 'mnist', 'classes': [5, 6, 7, 8, 9]}
- ],
- 'multi_dataset': [
- {'name': 'mnist', 'dataset': 'mnist', 'classes': 'all'},
- {'name': 'fashion_mnist', 'dataset': 'fashion_mnist', 'classes': 'all'},
- {'name': 'kmnist', 'dataset': 'kmnist', 'classes': 'all'}
- ],
- }
-
- if task_name_or_path in predefined_tasks:
- return {'task_sequence': predefined_tasks[task_name_or_path]}
-
- # Load from YAML file
- if os.path.exists(task_name_or_path):
- with open(task_name_or_path, 'r') as f:
- return yaml.safe_load(f)
-
- raise ValueError(f"Task sequence '{task_name_or_path}' not recognized and file not found.")
-
-
-def run_continual_learning(args, logger):
- """
- Run the continual learning experiment.
-
- Args:
- args: Command line arguments
- logger: Logger instance
- """
- # Set random seed
- set_seed(args.seed)
-
- # Set device
- if args.device is None:
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- else:
- device = torch.device(args.device)
-
- logger.info(f"Using device: {device}")
-
- # Load task sequence
- task_config = load_task_config(args.tasks)
- task_sequence = task_config['task_sequence']
-
- logger.info(f"Task sequence: {[task['name'] for task in task_sequence]}")
-
- # Create data loaders for all tasks
- task_data = get_task_sequence(task_sequence, args.batch_size)
-
- # Initialize model and learner
- input_shape = task_data[0]['train_loader'].dataset[0][0].shape
- all_classes = set()
- for task in task_sequence:
- if task['classes'] == 'all':
- # 'all' is a sentinel meaning all classes in the dataset (e.g. 10 for MNIST variants)
- all_classes.update(range(10))
- else:
- all_classes.update(task['classes'])
- num_classes = len(all_classes)
-
- model = get_model(args.model, input_shape, num_classes)
- model = model.to(device)
-
- # Select appropriate learner based on method
- if args.method == 'baseline':
- learner = BaselineLearner(
- model=model,
- device=device,
- learning_rate=args.learning_rate
- )
- elif args.method == 'ewc':
- learner = EWCLearner(
- model=model,
- device=device,
- learning_rate=args.learning_rate,
- lambda_ewc=args.lambda_ewc,
- fisher_sample_size=args.fisher_sample_size
- )
- elif args.method == 'replay':
- learner = ExperienceReplayLearner(
- model=model,
- device=device,
- learning_rate=args.learning_rate,
- buffer_size=args.buffer_size,
- replay_batch_size=args.replay_batch_size
- )
- elif args.method == 'lwf':
- learner = LwFLearner(
- model=model,
- device=device,
- learning_rate=args.learning_rate,
- temperature=args.temperature,
- alpha=args.alpha
- )
- else:
- raise ValueError(f"Unknown method: {args.method}")
-
- # Create directory to save results
- results_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), args.save_dir)
- experiment_dir = os.path.join(results_dir, f"{args.method}_{args.tasks}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
- os.makedirs(experiment_dir, exist_ok=True)
-
- # Track performance after each task
- performance_matrix = np.zeros((len(task_sequence), len(task_sequence)))
-
- # Sequentially train on each task
- for task_id, task_data_dict in enumerate(task_data):
- task_name = task_data_dict['name']
- train_loader = task_data_dict['train_loader']
- val_loader = task_data_dict['val_loader']
-
- logger.info(f"Starting training on task {task_id+1}/{len(task_data)}: {task_name}")
-
- # Train on current task
- learner.train(
- train_loader=train_loader,
- val_loader=val_loader,
- task_id=task_id,
- epochs=args.epochs,
- eval_freq=args.eval_freq
- )
-
- # Evaluate on all tasks seen so far
- for eval_task_id in range(task_id + 1):
- eval_task_data = task_data[eval_task_id]
- eval_loader = eval_task_data['test_loader']
-
- accuracy = learner.evaluate(eval_loader, eval_task_id)
- performance_matrix[task_id, eval_task_id] = accuracy
-
- logger.info(f"After task {task_id+1}, accuracy on task {eval_task_id+1}: {accuracy:.2f}%")
-
- # Calculate forgetting
- forgetting_matrix = compute_forgetting(performance_matrix)
-
- # Save performance metrics
- np.save(os.path.join(experiment_dir, 'performance_matrix.npy'), performance_matrix)
- np.save(os.path.join(experiment_dir, 'forgetting_matrix.npy'), forgetting_matrix)
-
- # Save model
- learner.save(os.path.join(experiment_dir, 'final_model.pt'))
-
- # Generate plots
- task_names = [task['name'] for task in task_sequence]
- performance_plot = plot_performance(performance_matrix, task_names)
- forgetting_plot = plot_forgetting(forgetting_matrix, task_names)
-
- performance_plot.savefig(os.path.join(experiment_dir, 'performance.png'))
- forgetting_plot.savefig(os.path.join(experiment_dir, 'forgetting.png'))
-
- # Print summary
- logger.info("\nExperiment summary:")
- logger.info(f"Method: {args.method}")
- logger.info(f"Task sequence: {[task['name'] for task in task_sequence]}")
- logger.info(f"Average final accuracy: {np.mean(performance_matrix[-1, :]):.2f}%")
- logger.info(f"Average forgetting: {np.mean(forgetting_matrix[-1, :]):.2f}%")
- logger.info(f"Results saved to: {experiment_dir}")
-
-
-def main():
- """Main function."""
- # Parse arguments
- args = parse_arguments()
-
- # Set up logging
- logger = setup_logging()
- logger.info("Starting Continual Learning experiment")
- logger.info(f"Arguments: {args}")
-
- # Run experiment
- try:
- run_continual_learning(args, logger)
- logger.info("Experiment completed successfully")
- except Exception as e:
- logger.exception(f"Experiment failed with error: {e}")
- sys.exit(1)
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/src/methods/__init__.py b/src/methods/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/methods/baseline.py b/src/methods/baseline.py
deleted file mode 100644
index 0cfdbd6..0000000
--- a/src/methods/baseline.py
+++ /dev/null
@@ -1,239 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""
-Baseline learner that simply fine-tunes the model on new tasks.
-This implementation will demonstrate catastrophic forgetting.
-"""
-
-import os
-import logging
-import time
-from typing import Dict, List, Optional
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-logger = logging.getLogger(__name__)
-
-
-class BaselineLearner:
- """
- Naive fine-tuning implementation that trains sequentially on tasks.
- Used as a baseline to demonstrate catastrophic forgetting.
- """
-
- def __init__(self, model: nn.Module, device: torch.device, learning_rate: float = 0.001):
- """
- Initialize the baseline learner.
-
- Args:
- model (nn.Module): The neural network model
- device (torch.device): Device to run the model on
- learning_rate (float): Learning rate for the optimizer
- """
- self.model = model
- self.device = device
- self.learning_rate = learning_rate
-
- # Initialize optimizer
- self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
-
- # Loss function
- self.criterion = nn.CrossEntropyLoss()
-
- # Keep track of tasks
- self.current_task = 0
- self.seen_tasks = 0
-
- def train(self, train_loader: DataLoader, val_loader: DataLoader,
- task_id: int, epochs: int, eval_freq: int = 1):
- """
- Train the model on a new task.
-
- Args:
- train_loader (DataLoader): Training data loader
- val_loader (DataLoader): Validation data loader
- task_id (int): ID of the current task
- epochs (int): Number of training epochs
- eval_freq (int): Frequency of evaluation during training (in epochs)
- """
- # Update task info
- self.current_task = task_id
- self.seen_tasks = max(self.seen_tasks, task_id + 1)
-
- # Create a new optimizer for the new task with the original learning rate
- self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
-
- # Training loop
- best_val_loss = float('inf')
- best_model_state = None
-
- for epoch in range(epochs):
- # Training phase
- self.model.train()
- train_loss = 0.0
- train_correct = 0
- train_total = 0
-
- # Use tqdm for a nice progress bar
- train_pbar = tqdm(train_loader, desc=f"Task {task_id+1} - Epoch {epoch+1}/{epochs} [Train]")
- for inputs, targets in train_pbar:
- inputs, targets = inputs.to(self.device), targets.to(self.device)
-
- # Zero the parameter gradients
- self.optimizer.zero_grad()
-
- # Forward pass
- outputs = self.model(inputs)
- loss = self.criterion(outputs, targets)
-
- # Backward pass and optimize
- loss.backward()
- self.optimizer.step()
-
- # Update statistics
- train_loss += loss.item() * inputs.size(0)
- _, predicted = torch.max(outputs.data, 1)
- train_total += targets.size(0)
- train_correct += (predicted == targets).sum().item()
-
- # Update progress bar
- train_pbar.set_postfix({'loss': loss.item(),
- 'acc': 100.0 * train_correct / train_total})
-
- train_loss = train_loss / train_total
- train_acc = 100.0 * train_correct / train_total
-
- # Validation phase (run every eval_freq epochs)
- if epoch % eval_freq == 0:
- val_loss, val_acc = self._evaluate_training(val_loader)
-
- logger.info(f"Task {task_id+1} - Epoch {epoch+1}/{epochs}: "
- f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
- f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
-
- # Save the best model based on validation loss
- if val_loss < best_val_loss:
- best_val_loss = val_loss
- best_model_state = self.model.state_dict().copy()
-
- # Load the best model from this task's training
- if best_model_state is not None:
- self.model.load_state_dict(best_model_state)
- logger.info(f"Loaded best model for task {task_id+1} with validation loss: {best_val_loss:.4f}")
-
- def _evaluate_training(self, val_loader: DataLoader) -> tuple:
- """
- Evaluate the model on validation data during training.
-
- Args:
- val_loader (DataLoader): Validation data loader
-
- Returns:
- tuple: (validation loss, validation accuracy)
- """
- self.model.eval()
- val_loss = 0.0
- val_correct = 0
- val_total = 0
-
- with torch.no_grad():
- for inputs, targets in val_loader:
- inputs, targets = inputs.to(self.device), targets.to(self.device)
-
- # Forward pass
- outputs = self.model(inputs)
- loss = self.criterion(outputs, targets)
-
- # Update statistics
- val_loss += loss.item() * inputs.size(0)
- _, predicted = torch.max(outputs.data, 1)
- val_total += targets.size(0)
- val_correct += (predicted == targets).sum().item()
-
- val_loss = val_loss / val_total
- val_acc = 100.0 * val_correct / val_total
-
- return val_loss, val_acc
-
- def evaluate(self, test_loader: DataLoader, task_id: Optional[int] = None) -> float:
- """
- Evaluate the model on test data.
-
- Args:
- test_loader (DataLoader): Test data loader
- task_id (int, optional): ID of the task to evaluate
-
- Returns:
- float: Test accuracy as a percentage
- """
- self.model.eval()
- test_correct = 0
- test_total = 0
-
- with torch.no_grad():
- for inputs, targets in test_loader:
- inputs, targets = inputs.to(self.device), targets.to(self.device)
-
- # Forward pass
- outputs = self.model(inputs)
-
- # Calculate accuracy
- _, predicted = torch.max(outputs.data, 1)
- test_total += targets.size(0)
- test_correct += (predicted == targets).sum().item()
-
- test_acc = 100.0 * test_correct / test_total
-
- return test_acc
-
- def save(self, path: str):
- """
- Save the model to a file.
-
- Args:
- path (str): Path to save the model
- """
- # Create directory if it doesn't exist
- os.makedirs(os.path.dirname(path), exist_ok=True)
-
- # Save model state
- torch.save({
- 'model_state_dict': self.model.state_dict(),
- 'optimizer_state_dict': self.optimizer.state_dict(),
- 'seen_tasks': self.seen_tasks,
- 'current_task': self.current_task
- }, path)
-
- logger.info(f"Model saved to {path}")
-
- def load(self, path: str):
- """
- Load the model from a file.
-
- Args:
- path (str): Path to load the model from
- """
- if not os.path.exists(path):
- logger.error(f"Model file not found: {path}")
- return
-
- # Load model state
- checkpoint = torch.load(path, map_location=self.device)
- self.model.load_state_dict(checkpoint['model_state_dict'])
- self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- self.seen_tasks = checkpoint['seen_tasks']
- self.current_task = checkpoint['current_task']
-
- logger.info(f"Model loaded from {path}")
-
-
-if __name__ == "__main__":
- logging.basicConfig(level=logging.INFO)
-
- logger.info("This is a baseline learner that demonstrates catastrophic forgetting.")
- logger.info("It should be used as a comparison point for more advanced continual learning methods.")
\ No newline at end of file
diff --git a/src/methods/ewc.py b/src/methods/ewc.py
deleted file mode 100644
index e1f12a4..0000000
--- a/src/methods/ewc.py
+++ /dev/null
@@ -1,330 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""
-Elastic Weight Consolidation (EWC) learner for continual learning.
-EWC prevents catastrophic forgetting by penalizing changes to parameters
-that are important for previously learned tasks.
-
-Reference:
-Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G., Rusu, A. A., ... & Hadsell, R. (2017).
-"Overcoming catastrophic forgetting in neural networks."
-Proceedings of the National Academy of Sciences, 114(13), 3521-3526.
-"""
-
-import os
-import logging
-import copy
-from typing import Dict, List, Optional
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-from .baseline import BaselineLearner
-
-logger = logging.getLogger(__name__)
-
-
-class EWCLearner(BaselineLearner):
- """
- Elastic Weight Consolidation (EWC) learner.
- Extends the baseline learner with a regularization term to prevent forgetting.
- """
-
- def __init__(self, model: nn.Module, device: torch.device, learning_rate: float = 0.001,
- lambda_ewc: float = 5000, fisher_sample_size: int = 200):
- """
- Initialize the EWC learner.
-
- Args:
- model (nn.Module): The neural network model
- device (torch.device): Device to run the model on
- learning_rate (float): Learning rate for the optimizer
- lambda_ewc (float): Regularization strength for EWC
- fisher_sample_size (int): Number of samples to estimate Fisher information
- """
- super().__init__(model, device, learning_rate)
-
- self.lambda_ewc = lambda_ewc
- self.fisher_sample_size = fisher_sample_size
-
- # Store parameters and Fisher information for each task
- self.fisher_matrices = {} # Fisher information for each task
- self.optimal_parameters = {} # Optimal parameters for each task
-
- def _compute_fisher_information(self, data_loader: DataLoader, task_id: int):
- """
- Compute the Fisher information matrix for the current task.
- The Fisher information measures how much the model parameters affect the output.
-
- Args:
- data_loader (DataLoader): Data loader for the current task
- task_id (int): ID of the current task
- """
- logger.info(f"Computing Fisher information matrix for task {task_id+1}")
-
- # Initialize Fisher information matrix
- fisher = {}
- for name, param in self.model.named_parameters():
- if param.requires_grad:
- fisher[name] = torch.zeros_like(param.data)
-
- # Set model to evaluation mode
- self.model.eval()
-
- # Sample a subset of data for Fisher computation
- sample_count = 0
-
- for inputs, targets in data_loader:
- if sample_count >= self.fisher_sample_size:
- break
-
- inputs, targets = inputs.to(self.device), targets.to(self.device)
- batch_size = inputs.shape[0]
-
- # Get model outputs
- log_probs = F.log_softmax(self.model(inputs), dim=1)
-
- # Compute gradients for each sample in the batch
- for i in range(batch_size):
- if sample_count >= self.fisher_sample_size:
- break
-
- sample_log_prob = log_probs[i, targets[i]]
-
- # Compute gradients
- self.optimizer.zero_grad()
- sample_log_prob.backward(retain_graph=(i < batch_size - 1))
-
- # Accumulate squared gradients
- for name, param in self.model.named_parameters():
- if param.requires_grad and param.grad is not None:
- fisher[name] += param.grad.data.pow(2) / self.fisher_sample_size
-
- sample_count += 1
-
- # Store the computed Fisher information
- self.fisher_matrices[task_id] = fisher
-
- logger.info(f"Fisher information matrix computed using {sample_count} samples")
-
- def _store_optimal_parameters(self, task_id: int):
- """
- Store the optimal parameters for the current task.
-
- Args:
- task_id (int): ID of the current task
- """
- logger.info(f"Storing optimal parameters for task {task_id+1}")
-
- optimal_params = {}
- for name, param in self.model.named_parameters():
- if param.requires_grad:
- optimal_params[name] = param.data.clone()
-
- self.optimal_parameters[task_id] = optimal_params
-
- def _compute_ewc_loss(self):
- """
- Compute the EWC regularization loss based on stored Fisher information.
- This penalties changes to parameters that were important for previous tasks.
-
- Returns:
- torch.Tensor: EWC regularization loss
- """
- ewc_loss = 0
-
- for task_id in range(self.current_task):
- if task_id not in self.fisher_matrices or task_id not in self.optimal_parameters:
- continue
-
- for name, param in self.model.named_parameters():
- if name in self.fisher_matrices[task_id] and name in self.optimal_parameters[task_id]:
- # Compute the squared difference between current and optimal parameters
- # weighted by the Fisher information
- fisher = self.fisher_matrices[task_id][name]
- optimal_param = self.optimal_parameters[task_id][name]
- ewc_loss += torch.sum(fisher * (param - optimal_param).pow(2)) / 2
-
- return ewc_loss * self.lambda_ewc
-
- def train(self, train_loader: DataLoader, val_loader: DataLoader,
- task_id: int, epochs: int, eval_freq: int = 1):
- """
- Train the model on a new task with EWC regularization.
-
- Args:
- train_loader (DataLoader): Training data loader
- val_loader (DataLoader): Validation data loader
- task_id (int): ID of the current task
- epochs (int): Number of training epochs
- eval_freq (int): Frequency of evaluation during training (in epochs)
- """
- # Update task info
- self.current_task = task_id
- self.seen_tasks = max(self.seen_tasks, task_id + 1)
-
- # Create a new optimizer for the new task with the original learning rate
- self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
-
- # Training loop
- best_val_loss = float('inf')
- best_model_state = None
-
- for epoch in range(epochs):
- # Training phase
- self.model.train()
- train_loss = 0.0
- train_task_loss = 0.0
- train_ewc_loss = 0.0
- train_correct = 0
- train_total = 0
-
- # Use tqdm for a nice progress bar
- train_pbar = tqdm(train_loader, desc=f"Task {task_id+1} - Epoch {epoch+1}/{epochs} [Train]")
- for inputs, targets in train_pbar:
- inputs, targets = inputs.to(self.device), targets.to(self.device)
-
- # Zero the parameter gradients
- self.optimizer.zero_grad()
-
- # Forward pass
- outputs = self.model(inputs)
- task_loss = self.criterion(outputs, targets)
-
- # Compute EWC regularization loss (for tasks > 0)
- ewc_loss = self._compute_ewc_loss() if task_id > 0 else 0
-
- # Total loss
- loss = task_loss + ewc_loss
-
- # Backward pass and optimize
- loss.backward()
- self.optimizer.step()
-
- # Update statistics
- train_loss += loss.item() * inputs.size(0)
- train_task_loss += task_loss.item() * inputs.size(0)
- if task_id > 0:
- train_ewc_loss += ewc_loss.item() * inputs.size(0)
-
- _, predicted = torch.max(outputs.data, 1)
- train_total += targets.size(0)
- train_correct += (predicted == targets).sum().item()
-
- # Update progress bar
- train_pbar.set_postfix({
- 'loss': loss.item(),
- 'task_loss': task_loss.item(),
- 'ewc_loss': ewc_loss.item() if task_id > 0 else 0,
- 'acc': 100.0 * train_correct / train_total
- })
-
- train_loss = train_loss / train_total
- train_task_loss = train_task_loss / train_total
- train_ewc_loss = train_ewc_loss / train_total if task_id > 0 else 0
- train_acc = 100.0 * train_correct / train_total
-
- # Validation phase (run every eval_freq epochs)
- if epoch % eval_freq == 0:
- val_loss, val_acc = self._evaluate_training(val_loader)
-
- logger.info(f"Task {task_id+1} - Epoch {epoch+1}/{epochs}: "
- f"Train Loss: {train_loss:.4f} (Task: {train_task_loss:.4f}, EWC: {train_ewc_loss:.4f}), "
- f"Train Acc: {train_acc:.2f}%, "
- f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
-
- # Save the best model based on validation loss
- if val_loss < best_val_loss:
- best_val_loss = val_loss
- best_model_state = self.model.state_dict().copy()
-
- # Load the best model from this task's training
- if best_model_state is not None:
- self.model.load_state_dict(best_model_state)
- logger.info(f"Loaded best model for task {task_id+1} with validation loss: {best_val_loss:.4f}")
-
- # After training on this task, compute and store Fisher information and optimal parameters
- self._compute_fisher_information(train_loader, task_id)
- self._store_optimal_parameters(task_id)
-
- def save(self, path: str):
- """
- Save the model and EWC-specific information to a file.
-
- Args:
- path (str): Path to save the model
- """
- # Create directory if it doesn't exist
- os.makedirs(os.path.dirname(path), exist_ok=True)
-
- # Convert fisher matrices and optimal parameters to CPU tensors for serialization
- fisher_cpu = {}
- for task_id, task_fisher in self.fisher_matrices.items():
- fisher_cpu[task_id] = {name: tensor.cpu() for name, tensor in task_fisher.items()}
-
- params_cpu = {}
- for task_id, task_params in self.optimal_parameters.items():
- params_cpu[task_id] = {name: tensor.cpu() for name, tensor in task_params.items()}
-
- # Save model state
- torch.save({
- 'model_state_dict': self.model.state_dict(),
- 'optimizer_state_dict': self.optimizer.state_dict(),
- 'seen_tasks': self.seen_tasks,
- 'current_task': self.current_task,
- 'fisher_matrices': fisher_cpu,
- 'optimal_parameters': params_cpu,
- 'lambda_ewc': self.lambda_ewc,
- 'fisher_sample_size': self.fisher_sample_size
- }, path)
-
- logger.info(f"Model saved to {path}")
-
- def load(self, path: str):
- """
- Load the model and EWC-specific information from a file.
-
- Args:
- path (str): Path to load the model from
- """
- if not os.path.exists(path):
- logger.error(f"Model file not found: {path}")
- return
-
- # Load model state
- checkpoint = torch.load(path, map_location=self.device)
- self.model.load_state_dict(checkpoint['model_state_dict'])
- self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- self.seen_tasks = checkpoint['seen_tasks']
- self.current_task = checkpoint['current_task']
-
- # Load EWC-specific information
- self.lambda_ewc = checkpoint.get('lambda_ewc', self.lambda_ewc)
- self.fisher_sample_size = checkpoint.get('fisher_sample_size', self.fisher_sample_size)
-
- # Load Fisher matrices and optimal parameters, moving them to the correct device
- if 'fisher_matrices' in checkpoint:
- self.fisher_matrices = {}
- for task_id, task_fisher in checkpoint['fisher_matrices'].items():
- self.fisher_matrices[int(task_id)] = {name: tensor.to(self.device)
- for name, tensor in task_fisher.items()}
-
- if 'optimal_parameters' in checkpoint:
- self.optimal_parameters = {}
- for task_id, task_params in checkpoint['optimal_parameters'].items():
- self.optimal_parameters[int(task_id)] = {name: tensor.to(self.device)
- for name, tensor in task_params.items()}
-
- logger.info(f"Model loaded from {path}")
-
-
-if __name__ == "__main__":
- logging.basicConfig(level=logging.INFO)
-
- logger.info("This is an implementation of Elastic Weight Consolidation for continual learning.")
- logger.info("Reference: Kirkpatrick et al. (2017) - 'Overcoming catastrophic forgetting in neural networks'")
\ No newline at end of file
diff --git a/src/methods/lwf.py b/src/methods/lwf.py
deleted file mode 100644
index ad38ec3..0000000
--- a/src/methods/lwf.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from src.methods.baseline import BaselineLearner
-import torch
-import torch.nn.functional as F
-import copy
-
-class LwFLearner(BaselineLearner):
- def __init__(self, model, device, learning_rate, temperature, alpha):
- super().__init__(model, device, learning_rate)
- self.temperature = temperature
- self.alpha = alpha
- self.teacher_model = None # Will hold a copy of the model from previous tasks
-
- def train(self, train_loader, val_loader=None, task_id=0, epochs=1, eval_freq=1, **kwargs):
- self.model.train()
- optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
- criterion_ce = torch.nn.CrossEntropyLoss()
- criterion_kd = torch.nn.KLDivLoss(reduction='batchmean')
- T = self.temperature
-
- for epoch in range(epochs):
- for batch in train_loader:
- inputs, targets = batch
- inputs, targets = inputs.to(self.device), targets.to(self.device)
- optimizer.zero_grad()
- outputs = self.model(inputs)
- loss_ce = criterion_ce(outputs, targets)
- loss_kd = 0.0
- if self.teacher_model is not None:
- self.teacher_model.eval()
- with torch.no_grad():
- teacher_outputs = self.teacher_model(inputs)
- # Compute soft targets: student uses log_softmax, teacher uses softmax
- soft_student = F.log_softmax(outputs / T, dim=1)
- soft_teacher = F.softmax(teacher_outputs / T, dim=1)
- loss_kd = criterion_kd(soft_student, soft_teacher) * (T * T)
- loss = self.alpha * loss_ce + (1 - self.alpha) * loss_kd
- loss.backward()
- optimizer.step()
-
- # Update teacher model after training the current task
- self.teacher_model = copy.deepcopy(self.model)
\ No newline at end of file
diff --git a/src/methods/replay.py b/src/methods/replay.py
deleted file mode 100644
index e9f3d11..0000000
--- a/src/methods/replay.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from src.methods.baseline import BaselineLearner
-import torch
-import random
-
-class ExperienceReplayLearner(BaselineLearner):
- def __init__(self, model, device, learning_rate, buffer_size, replay_batch_size):
- super().__init__(model, device, learning_rate)
- self.buffer_size = buffer_size
- self.replay_batch_size = replay_batch_size
- self.buffer = []
-
- def train(self, train_loader, val_loader=None, task_id=0, epochs=1, eval_freq=1, **kwargs):
- self.model.train()
- optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
- criterion = torch.nn.CrossEntropyLoss()
-
- for epoch in range(epochs):
- for batch in train_loader:
- inputs, targets = batch
- inputs, targets = inputs.to(self.device), targets.to(self.device)
- optimizer.zero_grad()
- outputs = self.model(inputs)
- loss = criterion(outputs, targets)
-
- # Incorporate replay loss if buffer is not empty
- if len(self.buffer) > 0:
- replay_samples = random.sample(self.buffer, min(self.replay_batch_size, len(self.buffer)))
- replay_inputs = torch.stack([sample[0] for sample in replay_samples]).to(self.device)
- replay_targets = torch.tensor([sample[1] for sample in replay_samples]).to(self.device)
- replay_outputs = self.model(replay_inputs)
- replay_loss = criterion(replay_outputs, replay_targets)
- loss = loss + replay_loss
-
- loss.backward()
- optimizer.step()
-
- # Update the replay buffer with current task samples
- for batch in train_loader:
- inputs, targets = batch
- for i in range(inputs.size(0)):
- self.buffer.append((inputs[i].detach().cpu(), targets[i].detach().cpu()))
- # Keep buffer size within limits
- if len(self.buffer) > self.buffer_size:
- self.buffer = self.buffer[-self.buffer_size:]
\ No newline at end of file
diff --git a/src/models/__init__.py b/src/models/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/models/__pycache__/model_factory.cpython-312.pyc b/src/models/__pycache__/model_factory.cpython-312.pyc
deleted file mode 100644
index cd7cb47..0000000
Binary files a/src/models/__pycache__/model_factory.cpython-312.pyc and /dev/null differ
diff --git a/src/models/model_factory.py b/src/models/model_factory.py
deleted file mode 100644
index 789ed61..0000000
--- a/src/models/model_factory.py
+++ /dev/null
@@ -1,318 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""
-Model factory module for the Continual Learning System.
-Provides different neural network architectures.
-"""
-
-import logging
-from typing import Tuple, List, Dict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-logger = logging.getLogger(__name__)
-
-
-class SimpleCNN(nn.Module):
- """
- A simple CNN model for image classification tasks.
- """
- def __init__(self, input_shape: Tuple[int, int, int], num_classes: int):
- """
- Initialize the model.
-
- Args:
- input_shape (tuple): Shape of input images (channels, height, width)
- num_classes (int): Number of output classes
- """
- super(SimpleCNN, self).__init__()
-
- # Extract input dimensions
- channels, height, width = input_shape
-
- # Feature extractor
- self.features = nn.Sequential(
- nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
- nn.ReLU(inplace=True),
- nn.MaxPool2d(kernel_size=2, stride=2),
- nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
- nn.ReLU(inplace=True),
- nn.MaxPool2d(kernel_size=2, stride=2),
- nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
- nn.ReLU(inplace=True),
- nn.MaxPool2d(kernel_size=2, stride=2)
- )
-
- # Calculate the size of feature maps after convolutions and pooling
- feature_size = (height // 8) * (width // 8) * 128
-
- # Classifier
- self.classifier = nn.Sequential(
- nn.Linear(feature_size, 256),
- nn.ReLU(inplace=True),
- nn.Dropout(0.5),
- nn.Linear(256, num_classes)
- )
-
- def forward(self, x):
- """Forward pass through the network."""
- x = self.features(x)
- x = torch.flatten(x, 1)
- x = self.classifier(x)
- return x
-
-
-class MLP(nn.Module):
- """
- A simple multi-layer perceptron for simpler tasks.
- """
- def __init__(self, input_shape: Tuple[int, int, int], num_classes: int, hidden_dim: int = 256):
- """
- Initialize the model.
-
- Args:
- input_shape (tuple): Shape of input images (channels, height, width)
- num_classes (int): Number of output classes
- hidden_dim (int): Dimensionality of hidden layers
- """
- super(MLP, self).__init__()
-
- # Calculate input dimensionality
- input_dim = input_shape[0] * input_shape[1] * input_shape[2]
-
- self.layers = nn.Sequential(
- nn.Flatten(),
- nn.Linear(input_dim, hidden_dim),
- nn.ReLU(inplace=True),
- nn.Dropout(0.5),
- nn.Linear(hidden_dim, hidden_dim),
- nn.ReLU(inplace=True),
- nn.Dropout(0.5),
- nn.Linear(hidden_dim, num_classes)
- )
-
- def forward(self, x):
- """Forward pass through the network."""
- return self.layers(x)
-
-
-class LeNet5(nn.Module):
- """
- LeNet-5 convolutional network architecture for image classification.
- """
- def __init__(self, input_shape: Tuple[int, int, int], num_classes: int):
- """
- Initialize the model.
-
- Args:
- input_shape (tuple): Shape of input images (channels, height, width)
- num_classes (int): Number of output classes
- """
- super(LeNet5, self).__init__()
-
- # Extract input dimensions
- channels, height, width = input_shape
-
- # Feature extraction
- self.conv1 = nn.Conv2d(channels, 6, kernel_size=5)
- self.relu1 = nn.ReLU()
- self.pool1 = nn.MaxPool2d(kernel_size=2)
- self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
- self.relu2 = nn.ReLU()
- self.pool2 = nn.MaxPool2d(kernel_size=2)
-
- # Calculate the size of feature maps before the fully connected layer
- # For MNIST (1x28x28), after two conv+pool layers, we get 16x4x4
- # Need to adjust dynamically for different input shapes
- conv1_out_size = (height - 5 + 1) // 2 # Conv 5x5 then pool 2x2
- conv2_out_size = (conv1_out_size - 5 + 1) // 2 # Conv 5x5 then pool 2x2
- fc_input_size = 16 * conv2_out_size * conv2_out_size
-
- # Classification
- self.fc1 = nn.Linear(fc_input_size, 120)
- self.relu3 = nn.ReLU()
- self.fc2 = nn.Linear(120, 84)
- self.relu4 = nn.ReLU()
- self.fc3 = nn.Linear(84, num_classes)
-
- def forward(self, x):
- """Forward pass through the network."""
- x = self.conv1(x)
- x = self.relu1(x)
- x = self.pool1(x)
- x = self.conv2(x)
- x = self.relu2(x)
- x = self.pool2(x)
- x = torch.flatten(x, 1)
- x = self.fc1(x)
- x = self.relu3(x)
- x = self.fc2(x)
- x = self.relu4(x)
- x = self.fc3(x)
- return x
-
-
-class SmallResNet(nn.Module):
- """
- A small ResNet-like architecture suitable for continual learning experiments.
- """
- class ResBlock(nn.Module):
- """Basic residual block with two 3x3 convolutions."""
- def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
- super(SmallResNet.ResBlock, self).__init__()
-
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
- stride=stride, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(out_channels)
- self.relu = nn.ReLU(inplace=True)
-
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
- stride=1, padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(out_channels)
-
- # Shortcut connection
- self.shortcut = nn.Sequential()
- if stride != 1 or in_channels != out_channels:
- self.shortcut = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, kernel_size=1,
- stride=stride, bias=False),
- nn.BatchNorm2d(out_channels)
- )
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
-
- out += self.shortcut(residual)
- out = self.relu(out)
-
- return out
-
- def __init__(self, input_shape: Tuple[int, int, int], num_classes: int):
- """
- Initialize the model.
-
- Args:
- input_shape (tuple): Shape of input images (channels, height, width)
- num_classes (int): Number of output classes
- """
- super(SmallResNet, self).__init__()
-
- channels, height, width = input_shape
-
- self.in_channels = 16
-
- self.conv1 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(16)
- self.relu = nn.ReLU(inplace=True)
-
- # Create residual blocks
- self.layer1 = self._make_layer(16, 2, stride=1)
- self.layer2 = self._make_layer(32, 2, stride=2)
- self.layer3 = self._make_layer(64, 2, stride=2)
-
- # Calculate the size after the feature extraction layers
- # Each layer with stride=2 reduces dim by 2
- final_size = height // 4 if height > 8 else 2 # Don't reduce below 2x2
-
- # Global average pooling
- self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
-
- # Classifier
- self.fc = nn.Linear(64, num_classes)
-
- # Initialize weights
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- def _make_layer(self, out_channels: int, num_blocks: int, stride: int):
- """Create a layer of residual blocks."""
- strides = [stride] + [1] * (num_blocks - 1)
- layers = []
- for stride in strides:
- layers.append(self.ResBlock(self.in_channels, out_channels, stride))
- self.in_channels = out_channels
-
- return nn.Sequential(*layers)
-
- def forward(self, x):
- """Forward pass through the network."""
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
-
- x = self.avg_pool(x)
- x = torch.flatten(x, 1)
- x = self.fc(x)
-
- return x
-
-
-def get_model(model_name: str, input_shape: Tuple[int, int, int], num_classes: int):
- """
- Factory function to get the specified model.
-
- Args:
- model_name (str): Name of the model architecture
- input_shape (tuple): Shape of input data (channels, height, width)
- num_classes (int): Number of output classes
-
- Returns:
- nn.Module: Instantiated model
- """
- model_name = model_name.lower()
-
- if model_name == 'simple_cnn':
- model = SimpleCNN(input_shape, num_classes)
- elif model_name == 'mlp':
- model = MLP(input_shape, num_classes)
- elif model_name == 'lenet5':
- model = LeNet5(input_shape, num_classes)
- elif model_name == 'small_resnet':
- model = SmallResNet(input_shape, num_classes)
- else:
- raise ValueError(f"Unknown model architecture: {model_name}")
-
- logger.info(f"Created {model_name} model with input shape {input_shape} and {num_classes} output classes")
-
- return model
-
-
-if __name__ == "__main__":
- # Test the model factory
- logging.basicConfig(level=logging.INFO)
-
- # Test creating different models
- models = {
- 'simple_cnn': get_model('simple_cnn', (1, 28, 28), 10),
- 'mlp': get_model('mlp', (1, 28, 28), 10),
- 'lenet5': get_model('lenet5', (1, 28, 28), 10),
- 'small_resnet': get_model('small_resnet', (1, 28, 28), 10)
- }
-
- # Print model architectures
- for name, model in models.items():
- print(f"\nModel: {name}")
- print(model)
-
- # Test forward pass
- x = torch.randn(2, 1, 28, 28) # Batch of 2 MNIST-like images
- y = model(x)
- print(f"Output shape: {y.shape}")
\ No newline at end of file
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
deleted file mode 100644
index c38c54a..0000000
--- a/src/utils/metrics.py
+++ /dev/null
@@ -1,290 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""
-Metrics module for the Continual Learning System.
-This module provides functions for calculating and tracking
-performance metrics in continual learning.
-"""
-
-import os
-import logging
-import numpy as np
-from typing import List, Dict, Any, Optional, Tuple
-
-logger = logging.getLogger(__name__)
-
-
-def evaluate_performance(
- model,
- task_loaders: List[Any],
- device,
- use_task_id: bool = False
-) -> np.ndarray:
- """
- Evaluate the performance of a model on all provided tasks.
-
- Args:
- model: PyTorch model to evaluate
- task_loaders (list): List of data loaders for each task
- device: Device to run the model on
- use_task_id (bool): Whether to provide task_id to the model's forward method
-
- Returns:
- np.ndarray: Vector of accuracies for each task
- """
- model.eval()
- accuracies = []
-
- for task_id, loader in enumerate(task_loaders):
- correct = 0
- total = 0
-
- for inputs, targets in loader:
- inputs, targets = inputs.to(device), targets.to(device)
-
- # Forward pass with or without task_id
- if use_task_id:
- outputs = model(inputs, task_id=task_id)
- else:
- outputs = model(inputs)
-
- # Calculate accuracy
- _, predicted = outputs.max(1)
- total += targets.size(0)
- correct += predicted.eq(targets).sum().item()
-
- # Calculate accuracy for this task
- accuracy = 100.0 * correct / total if total > 0 else 0.0
- accuracies.append(accuracy)
-
- return np.array(accuracies)
-
-
-def compute_forgetting(performance_matrix: np.ndarray) -> np.ndarray:
- """
- Compute the forgetting matrix from a performance matrix.
-
- Args:
- performance_matrix (np.ndarray): Matrix where rows represent training steps
- (task i trained) and columns represent
- performance on each task
-
- Returns:
- np.ndarray: Forgetting matrix where each element [i,j] represents how much
- task j was forgotten after training on task i
- """
- num_tasks = performance_matrix.shape[0]
- forgetting_matrix = np.zeros_like(performance_matrix)
-
- for i in range(1, num_tasks): # For each training step (except the first)
- for j in range(i): # For each previously learned task
- # Forgetting is the difference between the best previous performance
- # and the current performance
- best_previous = np.max(performance_matrix[:i, j])
- forgetting_matrix[i, j] = max(0, best_previous - performance_matrix[i, j])
-
- return forgetting_matrix
-
-
-def average_forgetting(forgetting_matrix: np.ndarray) -> List[float]:
- """
- Compute the average forgetting after each task.
-
- Args:
- forgetting_matrix (np.ndarray): Matrix where each element [i,j] represents
- how much task j was forgotten after training
- on task i
-
- Returns:
- list: Average forgetting after each task
- """
- num_tasks = forgetting_matrix.shape[0]
- avg_forgetting = []
-
- for i in range(1, num_tasks): # For each training step (except the first)
- # Average forgetting for all previous tasks
- if i > 0:
- avg_forgetting.append(np.mean(forgetting_matrix[i, :i]))
-
- return avg_forgetting
-
-
-def backward_transfer(performance_matrix: np.ndarray) -> List[float]:
- """
- Compute the backward transfer after each task.
- Backward transfer measures how training on later tasks affects
- performance on earlier tasks.
-
- Args:
- performance_matrix (np.ndarray): Matrix where rows represent training steps
- (task i trained) and columns represent
- performance on each task
-
- Returns:
- list: Backward transfer after each task
- """
- num_tasks = performance_matrix.shape[0]
- bt_values = []
-
- for i in range(1, num_tasks): # For each training step (except the first)
- # Sum of differences between current and original performance
- bt_sum = 0
- for j in range(i): # For each previously learned task
- bt_sum += performance_matrix[i, j] - performance_matrix[j, j]
-
- # Average backward transfer
- bt_values.append(bt_sum / i if i > 0 else 0)
-
- return bt_values
-
-
-def forward_transfer(performance_matrix: np.ndarray, random_performance: List[float]) -> List[float]:
- """
- Compute the forward transfer after each task.
- Forward transfer measures how training on earlier tasks affects
- the initial performance on later tasks.
-
- Args:
- performance_matrix (np.ndarray): Matrix where rows represent training steps
- (task i trained) and columns represent
- performance on each task
- random_performance (list): Performance on each task with random initialization
-
- Returns:
- list: Forward transfer after each task
- """
- num_tasks = performance_matrix.shape[0]
- ft_values = []
-
- for i in range(1, num_tasks): # For each task (except the first)
- # Performance on task i before training on it
- initial_perf = performance_matrix[i-1, i]
- # Performance on task i with random initialization
- random_perf = random_performance[i]
-
- # Forward transfer is the difference
- ft_values.append(initial_perf - random_perf)
-
- return ft_values
-
-
-def learning_curve_area(learning_curves: Dict[str, List[float]]) -> Dict[str, float]:
- """
- Compute the area under the learning curve for different methods.
- Higher area means faster learning.
-
- Args:
- learning_curves (dict): Dictionary mapping method names to lists of
- performance values during training
-
- Returns:
- dict: Dictionary mapping method names to ALC values
- """
- alc_values = {}
-
- for method, curve in learning_curves.items():
- # Normalize curve to [0, 1]
- min_val = min(curve)
- max_val = max(curve)
-
- if max_val > min_val:
- normalized_curve = [(v - min_val) / (max_val - min_val) for v in curve]
- else:
- normalized_curve = [0.5] * len(curve)
-
- # Compute area under the curve
- alc = np.trapz(normalized_curve, dx=1.0/len(normalized_curve))
- alc_values[method] = alc
-
- return alc_values
-
-
-def average_accuracy(performance_matrix: np.ndarray) -> float:
- """
- Compute the average accuracy across all tasks after training is complete.
-
- Args:
- performance_matrix (np.ndarray): Matrix where rows represent training steps
- (task i trained) and columns represent
- performance on each task
-
- Returns:
- float: Average final accuracy across all tasks
- """
- # Final performance is the last row of the matrix
- final_performance = performance_matrix[-1, :]
- return np.mean(final_performance)
-
-
-def compute_metrics_summary(performance_matrix: np.ndarray, random_performance: Optional[List[float]] = None) -> Dict[str, float]:
- """
- Compute a summary of continual learning metrics.
-
- Args:
- performance_matrix (np.ndarray): Matrix where rows represent training steps
- (task i trained) and columns represent
- performance on each task
- random_performance (list, optional): Performance on each task with random initialization
-
- Returns:
- dict: Dictionary of metric values
- """
- # Compute forgetting
- forgetting_matrix = compute_forgetting(performance_matrix)
- avg_forget = np.mean(forgetting_matrix[-1, :-1]) if performance_matrix.shape[0] > 1 else 0.0
-
- # Compute backward transfer
- bt_values = backward_transfer(performance_matrix)
- avg_bt = np.mean(bt_values) if bt_values else 0.0
-
- # Compute forward transfer if random performance is provided
- avg_ft = 0.0
- if random_performance is not None:
- ft_values = forward_transfer(performance_matrix, random_performance)
- avg_ft = np.mean(ft_values) if ft_values else 0.0
-
- # Compute average accuracy
- avg_acc = average_accuracy(performance_matrix)
-
- # Return metrics summary
- return {
- 'average_accuracy': avg_acc,
- 'average_forgetting': avg_forget,
- 'backward_transfer': avg_bt,
- 'forward_transfer': avg_ft
- }
-
-
-if __name__ == "__main__":
- logging.basicConfig(level=logging.INFO)
-
- # Example performance matrix
- num_tasks = 3
- performance = np.array([
- [90.0, np.nan, np.nan], # After task 1, only task 1 evaluated
- [85.0, 92.0, np.nan], # After task 2, tasks 1-2 evaluated
- [82.0, 88.0, 94.0] # After task 3, all tasks evaluated
- ])
-
- # Example random performance
- random_perf = [40.0, 42.0, 38.0]
-
- # Compute and print metrics
- forgetting_mat = compute_forgetting(performance)
- print("Forgetting matrix:")
- print(forgetting_mat)
-
- avg_forgetting_values = average_forgetting(forgetting_mat)
- print(f"Average forgetting after each task: {avg_forgetting_values}")
-
- bt_values = backward_transfer(performance)
- print(f"Backward transfer after each task: {bt_values}")
-
- ft_values = forward_transfer(performance, random_perf)
- print(f"Forward transfer after each task: {ft_values}")
-
- metrics_summary = compute_metrics_summary(performance, random_perf)
- print("Metrics summary:")
- for metric, value in metrics_summary.items():
- print(f" {metric}: {value:.2f}")
\ No newline at end of file
diff --git a/src/utils/visualization.py b/src/utils/visualization.py
deleted file mode 100644
index 470175b..0000000
--- a/src/utils/visualization.py
+++ /dev/null
@@ -1,313 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-"""
-Visualization utilities for the Continual Learning System.
-"""
-
-import os
-import logging
-import numpy as np
-from typing import List, Dict, Any, Optional
-import matplotlib.pyplot as plt
-import seaborn as sns
-
-logger = logging.getLogger(__name__)
-
-
-def plot_performance(performance_matrix: np.ndarray, task_names: List[str] = None):
- """
- Plot the performance of a model across sequential tasks.
-
- Args:
- performance_matrix (np.ndarray): Matrix where rows represent training progress
- (task_i trained) and columns represent
- performance on each task
- task_names (list): Names of the tasks
-
- Returns:
- matplotlib.figure.Figure: The generated figure
- """
- # Create task names if not provided
- if task_names is None:
- task_names = [f"Task {i+1}" for i in range(performance_matrix.shape[1])]
-
- # Create figure
- fig, ax = plt.subplots(figsize=(10, 6))
-
- # Set seaborn style
- sns.set_style("whitegrid")
-
- # Plot matrix as heatmap
- sns.heatmap(performance_matrix, annot=True, fmt=".1f", cmap="YlGnBu",
- xticklabels=task_names, yticklabels=[f"After Task {i+1}" for i in range(performance_matrix.shape[0])],
- cbar_kws={'label': 'Accuracy (%)'})
-
- # Add labels and title
- plt.xlabel("Evaluated Task")
- plt.ylabel("Training Progress")
- plt.title("Model Performance After Each Task")
-
- plt.tight_layout()
-
- return fig
-
-
-def plot_forgetting(forgetting_matrix: np.ndarray, task_names: List[str] = None):
- """
- Plot the forgetting of a model across sequential tasks.
-
- Args:
- forgetting_matrix (np.ndarray): Matrix where rows represent training progress
- (task_i trained) and columns represent
- forgetting on each task
- task_names (list): Names of the tasks
-
- Returns:
- matplotlib.figure.Figure: The generated figure
- """
- # Create task names if not provided
- if task_names is None:
- task_names = [f"Task {i+1}" for i in range(forgetting_matrix.shape[1])]
-
- # Create figure
- fig, ax = plt.subplots(figsize=(10, 6))
-
- # Set seaborn style
- sns.set_style("whitegrid")
-
- # Plot matrix as heatmap with a different colormap for forgetting
- sns.heatmap(forgetting_matrix, annot=True, fmt=".1f", cmap="Reds",
- xticklabels=task_names, yticklabels=[f"After Task {i+1}" for i in range(forgetting_matrix.shape[0])],
- cbar_kws={'label': 'Forgetting (%)'})
-
- # Add labels and title
- plt.xlabel("Task Forgotten")
- plt.ylabel("Training Progress")
- plt.title("Forgetting After Each Task")
-
- plt.tight_layout()
-
- return fig
-
-
-def plot_accuracy_over_time(accuracies: List[float], task_boundaries: List[int],
- task_names: List[str] = None, title: str = "Accuracy Over Time"):
- """
- Plot the accuracy of a model throughout training across multiple tasks.
-
- Args:
- accuracies (list): Validation accuracies throughout training
- task_boundaries (list): Epoch indices where tasks change
- task_names (list): Names of the tasks
- title (str): Title for the plot
-
- Returns:
- matplotlib.figure.Figure: The generated figure
- """
- # Create task names if not provided
- num_tasks = len(task_boundaries) + 1
- if task_names is None:
- task_names = [f"Task {i+1}" for i in range(num_tasks)]
-
- # Create figure
- fig, ax = plt.subplots(figsize=(12, 6))
-
- # Set seaborn style
- sns.set_style("whitegrid")
-
- # Plot accuracy
- epochs = np.arange(1, len(accuracies) + 1)
- plt.plot(epochs, accuracies, 'b-', linewidth=2)
-
- # Add task boundary lines
- for boundary in task_boundaries:
- plt.axvline(x=boundary, color='r', linestyle='--')
-
- # Add task labels
- task_midpoints = [0] + task_boundaries
- task_midpoints.append(len(accuracies))
- for i in range(num_tasks):
- midpoint = (task_midpoints[i] + task_midpoints[i+1]) / 2
- plt.text(midpoint, min(accuracies) - 5, task_names[i],
- horizontalalignment='center', fontsize=12)
-
- # Add labels and title
- plt.xlabel("Epoch")
- plt.ylabel("Accuracy (%)")
- plt.title(title)
- plt.ylim(min(accuracies) - 10, max(accuracies) + 5)
-
- plt.tight_layout()
-
- return fig
-
-
-def plot_task_comparison(final_accuracies: Dict[str, List[float]], task_names: List[str] = None):
- """
- Plot a comparison of final accuracies across different methods.
-
- Args:
- final_accuracies (dict): Dictionary mapping method names to lists of final accuracies on each task
- task_names (list): Names of the tasks
-
- Returns:
- matplotlib.figure.Figure: The generated figure
- """
- # Get number of tasks
- num_tasks = len(next(iter(final_accuracies.values())))
-
- # Create task names if not provided
- if task_names is None:
- task_names = [f"Task {i+1}" for i in range(num_tasks)]
-
- # Create figure
- fig, ax = plt.subplots(figsize=(12, 8))
-
- # Set seaborn style
- sns.set_style("whitegrid")
-
- # Prepare data for grouped bar chart
- methods = list(final_accuracies.keys())
- x = np.arange(len(task_names))
- width = 0.8 / len(methods)
-
- # Plot bars for each method
- for i, method in enumerate(methods):
- offset = (i - len(methods)/2 + 0.5) * width
- plt.bar(x + offset, final_accuracies[method], width, label=method)
-
- # Add labels and title
- plt.xlabel("Task")
- plt.ylabel("Final Accuracy (%)")
- plt.title("Comparison of Methods Across Tasks")
- plt.xticks(x, task_names)
- plt.legend()
-
- plt.tight_layout()
-
- return fig
-
-
-def plot_average_metrics(metrics: Dict[str, Dict[str, float]], metrics_to_plot: List[str] = None):
- """
- Plot average metrics across different methods.
-
- Args:
- metrics (dict): Dictionary mapping method names to dictionaries of metrics
- metrics_to_plot (list): List of metric names to plot
-
- Returns:
- matplotlib.figure.Figure: The generated figure
- """
- # Determine which metrics to plot
- if metrics_to_plot is None:
- # Get all unique metrics
- metrics_to_plot = set()
- for method_metrics in metrics.values():
- metrics_to_plot.update(method_metrics.keys())
- metrics_to_plot = sorted(list(metrics_to_plot))
-
- # Create figure
- fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=(15, 6))
- if len(metrics_to_plot) == 1:
- axes = [axes] # Ensure axes is always a list
-
- # Set seaborn style
- sns.set_style("whitegrid")
-
- # Plot each metric
- for i, metric in enumerate(metrics_to_plot):
- ax = axes[i]
-
- # Extract values for this metric from each method
- methods = []
- values = []
- for method, method_metrics in metrics.items():
- if metric in method_metrics:
- methods.append(method)
- values.append(method_metrics[metric])
-
- # Plot bar chart
- colors = sns.color_palette("viridis", len(methods))
- ax.bar(methods, values, color=colors)
-
- # Add labels
- ax.set_title(metric)
- ax.set_ylabel("Value")
- ax.set_xticks(range(len(methods)))
- ax.set_xticklabels(methods, rotation=45, ha='right')
-
- plt.tight_layout()
-
- return fig
-
-
-def plot_forgetting_curve(forgetting_values: Dict[str, List[float]], task_names: List[str] = None):
- """
- Plot forgetting curves for different methods.
-
- Args:
- forgetting_values (dict): Dictionary mapping method names to lists of forgetting values
- task_names (list): Names of the tasks
-
- Returns:
- matplotlib.figure.Figure: The generated figure
- """
- # Get number of tasks
- num_tasks = len(next(iter(forgetting_values.values())))
-
- # Create task names if not provided
- if task_names is None:
- task_names = [f"Task {i+1}" for i in range(num_tasks)]
-
- # Create figure
- fig, ax = plt.subplots(figsize=(10, 6))
-
- # Set seaborn style
- sns.set_style("whitegrid")
-
- # Plot forgetting curves for each method
- for method, values in forgetting_values.items():
- plt.plot(range(1, num_tasks), values, marker='o', linewidth=2, label=method)
-
- # Add labels and title
- plt.xlabel("Tasks Learned")
- plt.ylabel("Average Forgetting (%)")
- plt.title("Forgetting Curve for Different Methods")
- plt.xticks(range(1, num_tasks), [f"After Task {i+1}" for i in range(1, num_tasks)])
- plt.legend()
- plt.grid(True)
-
- plt.tight_layout()
-
- return fig
-
-
-if __name__ == "__main__":
- logging.basicConfig(level=logging.INFO)
-
- # Demo visualization with random data
- num_tasks = 3
- performance = np.random.uniform(60, 100, size=(num_tasks, num_tasks))
- # Set upper triangle to NaN (tasks not seen yet)
- for i in range(num_tasks):
- for j in range(i+1, num_tasks):
- performance[i, j] = np.nan
-
- # Example forgetting matrix
- forgetting = np.zeros((num_tasks, num_tasks))
- for i in range(1, num_tasks):
- for j in range(i):
- forgetting[i, j] = performance[j, j] - performance[i, j]
-
- # Create demo plots
- task_names = ["Digits 0-4", "Digits 5-9", "Fashion MNIST"]
-
- perf_fig = plot_performance(performance, task_names)
- perf_fig.savefig("demo_performance.png")
-
- forget_fig = plot_forgetting(forgetting, task_names)
- forget_fig.savefig("demo_forgetting.png")
-
- logger.info("Created demo visualizations: demo_performance.png and demo_forgetting.png")
\ No newline at end of file
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..b664ef7
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from cl_bench.config import ExperimentConfig, load_config, load_config_with_overrides
+
+
+def test_load_named_smoke_config() -> None:
+    """``load_config("smoke")`` returns an ``ExperimentConfig`` with the smoke settings."""
+    smoke = load_config("smoke")
+    assert isinstance(smoke, ExperimentConfig)
+    assert (smoke.name, smoke.method) == ("smoke", "baseline")
+    # The first two task entries pin the class split and the dataset kind.
+    assert smoke.tasks[0].classes == [0, 1]
+    assert smoke.tasks[1].dataset == "synthetic"
+
+
+def test_load_real_data_quick_config() -> None:
+    """``split_mnist_quick`` loads as a five-task MNIST run with capped training sets."""
+    quick = load_config("split_mnist_quick")
+    assert quick.name == "split_mnist_quick"
+    assert quick.model == "small_cnn"
+    assert quick.epochs == 3
+    assert len(quick.tasks) == 5
+    assert {spec.dataset for spec in quick.tasks} == {"mnist"}
+    assert [spec.train_limit for spec in quick.tasks] == [600] * 5
+
+
+def test_load_cifar_headline_config_and_overrides() -> None:
+    """The headline CIFAR-10 config exposes its values and accepts ``key=value`` overrides."""
+    headline = load_config("split_cifar10_headline")
+    expected_scalars = {
+        "name": "split_cifar10_headline",
+        "method": "derpp",
+        "model": "cifar_convnet",
+        "tracking": "both",
+        "epochs": 5,
+        "replay_buffer_size": 5000,
+        "replay_batch_size": 256,
+        "replay_loss_weight": 3.0,
+        "derpp_alpha": 0.1,
+        "derpp_beta": 2.0,
+    }
+    for field, wanted in expected_scalars.items():
+        assert getattr(headline, field) == wanted, field
+    assert len(headline.tasks) == 5
+    assert all(spec.dataset == "cifar10" for spec in headline.tasks)
+    # Overrides are "key=value" strings; two of them address nested keys with a dot.
+    overrides = ["method=agem", "training.epochs=1", "tracking.mode=json"]
+    tweaked = load_config_with_overrides("split_cifar10_headline", overrides)
+    assert (tweaked.method, tweaked.epochs, tweaked.tracking) == ("agem", 1, "json")
+
+
+def test_nested_training_and_strategy_values_are_parsed() -> None:
+    """``from_dict`` exposes nested ``training``/``strategy`` values as config attributes."""
+    training_section = {"epochs": 3, "batch_size": 12, "learning_rate": 0.02}
+    strategy_section = {"ewc_lambda": 123.0, "fisher_samples": 7}
+    single_task = {"name": "a", "dataset": "synthetic", "classes": [0, 1]}
+    raw = {
+        "name": "unit",
+        "method": "ewc",
+        "training": training_section,
+        "strategy": strategy_section,
+        "tasks": [single_task],
+    }
+
+    parsed = ExperimentConfig.from_dict(raw)
+
+    # Values given under "training" are read back as top-level attributes.
+    assert parsed.epochs == 3
+    assert parsed.batch_size == 12
+    assert parsed.learning_rate == 0.02
+    # Likewise for the method-specific values given under "strategy".
+    assert parsed.ewc_lambda == 123.0
+    assert parsed.fisher_samples == 7
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
new file mode 100644
index 0000000..da821bd
--- /dev/null
+++ b/tests/test_datasets.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+import torch
+
+from cl_bench.config import ExperimentConfig, TaskSpec
+from cl_bench.datasets import build_task_loaders
+
+
+def test_synthetic_task_construction_is_deterministic() -> None:
+    """Building loaders twice from one config yields matching metadata and label order."""
+    # Two synthetic tasks that differ only in their name and class ids.
+    task_specs = [
+        TaskSpec(
+            name=task_name,
+            dataset="synthetic",
+            classes=class_ids,
+            samples_per_class=8,
+            test_samples_per_class=4,
+        )
+        for task_name, class_ids in (("first", [0, 1]), ("second", [2, 3]))
+    ]
+
+    cfg = ExperimentConfig(
+        name="unit",
+        method="baseline",
+        seed=11,
+        model="linear",
+        batch_size=8,
+        eval_batch_size=16,
+        val_fraction=0.25,
+        tasks=task_specs,
+    )
+
+    # Two independent builds from the very same config object.
+    tasks_a, shape_a, classes_a = build_task_loaders(cfg)
+    tasks_b, shape_b, classes_b = build_task_loaders(cfg)
+
+    assert [task.name for task in tasks_a] == ["first", "second"]
+    assert shape_a == (1, 8, 8)
+    assert classes_a == 4
+    assert shape_b == shape_a
+    assert classes_b == classes_a
+
+    # Only element 1 (the targets) of each first training batch is compared.
+    labels_a = next(iter(tasks_a[0].train_loader))[1]
+    labels_b = next(iter(tasks_b[0].train_loader))[1]
+    assert labels_a.tolist() == labels_b.tolist()
+
+
+def test_cifar_task_construction_uses_rgb_shape_and_limits(monkeypatch) -> None:
+    """CIFAR-10 tasks report a (3, 32, 32) input shape and honour train/test limits."""
+
+    class FakeCIFAR10:
+        classes = [str(index) for index in range(10)]  # label names "0" .. "9"
+
+        def __init__(self, root, train: bool, download: bool, transform):
+            del root, download  # accepted but deliberately unused
+            self.transform = transform  # stored only; __getitem__ does not apply it
+            examples_per_class = 6 if train else 4  # 60 train / 40 test examples in total
+            self.targets = [class_id for class_id in range(10) for _ in range(examples_per_class)]
+            self.data = [
+                torch.full((3, 32, 32), float(class_id) / 10.0)  # constant image per class
+                for class_id in range(10)
+                for _ in range(examples_per_class)
+            ]
+
+        def __len__(self) -> int:
+            return len(self.targets)
+
+        def __getitem__(self, index: int):
+            return self.data[index], self.targets[index]  # (image tensor, int label)
+
+    monkeypatch.setattr("cl_bench.datasets.tv_datasets.CIFAR10", FakeCIFAR10)
+    config = ExperimentConfig(
+        name="cifar_unit",
+        method="baseline",
+        seed=3,
+        model="cifar_convnet",
+        batch_size=4,
+        eval_batch_size=8,
+        val_fraction=0.0,
+        augment=False,
+        tasks=[
+            TaskSpec(
+                name="cifar_0_1",
+                dataset="cifar10",
+                classes=[0, 1],
+                train_limit=5,
+                test_limit=4,
+            )
+        ],
+    )
+    tasks, input_shape, num_classes = build_task_loaders(config)
+    assert input_shape == (3, 32, 32)
+    assert num_classes == 2  # the single task requests classes [0, 1]
+    assert len(tasks[0].train_loader.dataset) == 5  # equals train_limit
+    assert len(tasks[0].test_loader.dataset) == 4  # equals test_limit
diff --git a/tests/test_integration_smoke.py b/tests/test_integration_smoke.py
new file mode 100644
index 0000000..1149060
--- /dev/null
+++ b/tests/test_integration_smoke.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from dataclasses import replace
+
+from cl_bench.config import load_config
+from cl_bench.experiments import run_experiment
+
+
+def test_cpu_smoke_benchmark_writes_reproducibility_artifacts(tmp_path) -> None:
+    """A ``derpp`` smoke run writes config, metrics and both matrix CSV files."""
+    smoke_derpp = replace(load_config("smoke"), output_dir=str(tmp_path), method="derpp")
+    outcome = run_experiment(smoke_derpp)
+
+    assert outcome.run_dir.exists()
+    assert outcome.config_path.exists()
+    assert outcome.metrics_path.exists()
+    # Files expected directly inside the run directory.
+    for filename in ("metrics.jsonl", "accuracy_matrix.csv", "forgetting_matrix.csv"):
+        assert (outcome.run_dir / filename).exists(), filename
+    assert outcome.summary["num_tasks"] == 2
+
+
+def test_synthetic_suite_runs_core_memory_methods(tmp_path) -> None:
+    """Each listed method completes a smoke run and reports two tasks."""
+    for method_name in ("baseline", "replay", "derpp", "agem"):
+        base = load_config("smoke")  # reloaded per method, as a fresh config object
+        outcome = run_experiment(replace(base, output_dir=str(tmp_path), method=method_name))
+        assert outcome.method == method_name
+        assert outcome.metrics_path.exists()
+        assert outcome.summary["num_tasks"] == 2
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000..703c908
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import numpy as np
+
+from cl_bench.metrics import compute_forgetting, matrix_to_jsonable, summarize_accuracy
+
+
+def test_forgetting_uses_best_previous_accuracy() -> None:
+    """Check ``compute_forgetting`` entries for a 3x3 accuracy matrix with NaN gaps."""
+    nan = np.nan
+    rows = [
+        [80.0, nan, nan],
+        [75.0, 70.0, nan],
+        [78.0, 65.0, 90.0],
+    ]
+    result = compute_forgetting(np.array(rows))
+
+    # (row, column) -> expected value; column 0 peaks at 80.0 and column 1 at 70.0.
+    expected = {(0, 0): 0.0, (1, 0): 5.0, (2, 0): 2.0, (2, 1): 5.0}
+    for (row, col), value in expected.items():
+        assert result[row, col] == value, (row, col)
+    # A cell that is NaN in the input is NaN in the output as well.
+    assert np.isnan(result[0, 1])
+
+
+def test_summary_and_json_matrix_are_nan_safe() -> None:
+    """NaN cells must not distort the summary and must serialise as ``None``."""
+    accuracy = np.array([[50.0, np.nan], [40.0, 75.0]])
+    stats = summarize_accuracy(accuracy)
+    as_json = matrix_to_jsonable(accuracy)
+
+    assert stats["average_final_accuracy"] == 57.5
+    assert stats["average_forgetting"] == 10.0
+    assert as_json == [[50.0, None], [40.0, 75.0]]
diff --git a/tests/test_replay.py b/tests/test_replay.py
new file mode 100644
index 0000000..2738ad4
--- /dev/null
+++ b/tests/test_replay.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+import torch
+
+from cl_bench.strategies.replay import ReservoirReplayBuffer
+
+
+def test_reservoir_replay_buffer_is_bounded_and_not_last_n() -> None:
+    """The buffer stays at capacity and does not end up holding just the last N samples."""
+    buffer = ReservoirReplayBuffer(capacity=3, seed=3)
+    for value in range(10):
+        # One sample per batch, labelled with its arrival index.
+        buffer.add_batch(torch.full((1, 2), float(value)), torch.tensor([value]))
+
+    # int() accepts plain ints and one-element tensors alike; sorting makes the check
+    # order-blind, so the last three samples stored in shuffled slots are still caught.
+    stored_targets = sorted(int(sample.target) for sample in buffer.samples)
+    assert len(buffer) == 3
+    assert buffer.seen_count == 10
+    assert stored_targets != [7, 8, 9]
+
+
+def test_replay_sampling_returns_tensors() -> None:
+    """Sampling two items returns a (2, 1, 8, 8) input batch and a 1-D ``long`` target batch."""
+    replay = ReservoirReplayBuffer(capacity=5, seed=0)
+    stored_inputs, stored_targets = torch.randn(4, 1, 8, 8), torch.tensor([0, 1, 2, 3])
+    replay.add_batch(stored_inputs, stored_targets)
+
+    sampled_inputs, sampled_targets = replay.sample(batch_size=2)
+    assert tuple(sampled_inputs.shape) == (2, 1, 8, 8)
+    assert (tuple(sampled_targets.shape), sampled_targets.dtype) == ((2,), torch.long)
diff --git a/tests/test_reporting.py b/tests/test_reporting.py
new file mode 100644
index 0000000..165fc5b
--- /dev/null
+++ b/tests/test_reporting.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+import json
+
+from cl_bench.reporting import aggregate_records, collect_runs, write_report
+
+
+def _write_metrics(run_dir, method: str, seed: int, final_accuracy: float) -> None:
+    """Create ``run_dir`` and write a minimal two-task ``metrics.json`` fixture into it."""
+    run_dir.mkdir(parents=True)
+
+    summary = {
+        "average_final_accuracy": final_accuracy,
+        "average_learning_accuracy": 80.0,
+        "average_forgetting": 5.0,
+        "backward_transfer": -5.0,
+        "runtime_seconds": 1.5,
+        "seed": seed,
+    }
+    header = {"benchmark": "unit", "method": method, "task_names": ["a", "b"], "summary": summary}
+    matrices = {
+        "accuracy_matrix": [[80.0, None], [70.0, final_accuracy]],
+        "forgetting_matrix": [[0.0, None], [10.0, 0.0]],
+    }
+    footer = {"runtime_seconds": 1.5, "seed": seed, "git_commit": None}
+    # The merge order fixes the key order of the serialised document.
+    document = json.dumps({**header, **matrices, **footer})
+    (run_dir / "metrics.json").write_text(document, encoding="utf-8")
+
+
+def test_collect_and_aggregate_runs(tmp_path) -> None:
+    """Runs aggregate per method; the leaderboard lists ``replay`` before ``baseline``."""
+    fixtures = [("baseline", 1, 60.0), ("baseline", 2, 80.0), ("replay", 1, 90.0)]
+    for method, seed, final_accuracy in fixtures:
+        _write_metrics(tmp_path / f"{method}_seed_{seed}", method, seed, final_accuracy)
+
+    leaderboard = aggregate_records(collect_runs([tmp_path]))
+
+    assert [row["method"] for row in leaderboard] == ["replay", "baseline"]
+    baseline_row = next(filter(lambda row: row["method"] == "baseline", leaderboard))
+    assert baseline_row["runs"] == 2
+    assert baseline_row["seeds"] == "1,2"
+    assert baseline_row["average_final_accuracy_mean"] == 70.0
+
+
+def test_write_report_without_plots(tmp_path) -> None:
+    """With ``make_plots=False`` the report files are written and no plots are listed."""
+    _write_metrics(tmp_path / "replay_seed_1", "replay", 1, 90.0)
+    report_dir = tmp_path / "report"
+
+    report = write_report(collect_runs([tmp_path]), report_dir, "Unit report", make_plots=False)
+
+    for artifact in (report.leaderboard_csv, report.summary_json, report.markdown):
+        assert artifact.exists()
+    assert report.plots == []
+
+
+def test_collect_runs_from_mlflow_artifact_export_shape(tmp_path) -> None:
+    """``collect_runs`` finds a ``metrics.json`` nested several levels below the root."""
+    mlruns_root = tmp_path / "mlruns"
+    nested_run_dir = mlruns_root.joinpath("0", "run-id", "artifacts", "run")
+    _write_metrics(nested_run_dir, "derpp", 13, 88.0)
+    found = collect_runs([mlruns_root])
+    assert len(found) == 1
+    assert found[0].method == "derpp"
+    assert found[0].summary["average_final_accuracy"] == 88.0
diff --git a/tests/test_strategies.py b/tests/test_strategies.py
new file mode 100644
index 0000000..3ce3518
--- /dev/null
+++ b/tests/test_strategies.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+import torch
+
+from cl_bench.config import ExperimentConfig, TaskSpec
+from cl_bench.datasets import build_task_loaders
+from cl_bench.models import get_model
+from cl_bench.strategies.agem import AGEMStrategy
+from cl_bench.strategies.base import clone_state_dict
+from cl_bench.strategies.baseline import BaselineStrategy
+from cl_bench.strategies.derpp import DERPPStrategy
+from cl_bench.strategies.ewc import EWCStrategy
+from cl_bench.strategies.lwf import LwFStrategy
+
+
+def _tiny_config() -> ExperimentConfig:
+    """Return a one-task synthetic ``linear``/``baseline`` config shared by these tests."""
+    only_task = TaskSpec(
+        name="a",
+        dataset="synthetic",
+        classes=[0, 1],
+        samples_per_class=8,
+        test_samples_per_class=4,
+    )
+    return ExperimentConfig(
+        name="tiny",
+        method="baseline",
+        seed=5,
+        model="linear",
+        epochs=1,
+        batch_size=8,
+        eval_batch_size=16,
+        learning_rate=0.05,
+        val_fraction=0.2,
+        tasks=[only_task],
+    )
+
+
+def test_strategy_lifecycle_trains_and_evaluates() -> None:
+    """One epoch on one task updates strategy bookkeeping and yields an accuracy in [0, 100]."""
+    tasks, input_shape, num_classes = build_task_loaders(_tiny_config())
+    task = tasks[0]
+    net = get_model("linear", input_shape, num_classes)
+    baseline = BaselineStrategy(model=net, device=torch.device("cpu"), learning_rate=0.05)
+
+    epoch_history = baseline.train_task(task.train_loader, task.val_loader, task_id=0, epochs=1)
+    test_metrics = baseline.evaluate(task.test_loader)
+
+    assert len(epoch_history) == 1
+    assert (baseline.current_task, baseline.seen_tasks) == (0, 1)
+    assert 0.0 <= test_metrics["accuracy"] <= 100.0
+
+
+def test_clone_state_dict_does_not_alias_model_parameters() -> None:
+    """In-place updates to the live weights must not show up in the cloned state dict."""
+    layer = torch.nn.Linear(2, 2)
+    snapshot = clone_state_dict(layer)
+    with torch.no_grad():
+        layer.weight.add_(10.0)
+    live_weight = layer.state_dict()["weight"]
+    assert not torch.equal(snapshot["weight"], live_weight)
+
+
+def test_ewc_fisher_is_deterministic_and_normalized() -> None:
+    """Identically initialised EWC strategies produce equal, non-negative Fisher estimates."""
+    tasks, input_shape, num_classes = build_task_loaders(_tiny_config())
+    train_loader, cpu = tasks[0].train_loader, torch.device("cpu")
+    reference = get_model("linear", input_shape, num_classes)
+    twin = get_model("linear", input_shape, num_classes)
+    twin.load_state_dict(reference.state_dict())  # start both models from the same weights
+
+    ewc_a = EWCStrategy(reference, cpu, 0.01, ewc_lambda=1.0, fisher_samples=4)
+    ewc_b = EWCStrategy(twin, cpu, 0.01, ewc_lambda=1.0, fisher_samples=4)
+    fisher_a = ewc_a._estimate_fisher(train_loader)
+    fisher_b = ewc_b._estimate_fisher(train_loader)
+
+    assert fisher_a.keys() == fisher_b.keys()
+    for name, estimate in fisher_a.items():
+        assert torch.allclose(estimate, fisher_b[name])
+        assert torch.all(estimate >= 0)
+
+
+def test_lwf_creates_frozen_teacher_after_task() -> None:
+    """After one task, LwF keeps a teacher model whose parameters all have grad disabled."""
+    tasks, input_shape, num_classes = build_task_loaders(_tiny_config())
+    student = get_model("linear", input_shape, num_classes)
+    lwf = LwFStrategy(student, torch.device("cpu"), 0.05, alpha=0.5, temperature=2.0)
+
+    lwf.train_task(tasks[0].train_loader, tasks[0].val_loader, task_id=0, epochs=1)
+
+    assert lwf.teacher_model is not None
+    assert not any(weight.requires_grad for weight in lwf.teacher_model.parameters())
+
+
+def test_derpp_stores_logits_and_uses_replay_components() -> None:
+    """Observed batches are buffered with logits and both ``derpp_*`` losses are reported."""
+    tasks, input_shape, num_classes = build_task_loaders(_tiny_config())
+    derpp_kwargs = {
+        "learning_rate": 0.05,
+        "buffer_size": 16,
+        "replay_batch_size": 4,
+        "alpha": 0.5,
+        "beta": 1.0,
+        "seed": 1,
+    }
+
+    net = get_model("linear", input_shape, num_classes)
+    derpp = DERPPStrategy(net, torch.device("cpu"), **derpp_kwargs)
+
+    # Store one batch under task 0, then request the loss for task 1.
+    batch_inputs, batch_targets = next(iter(tasks[0].train_loader))
+    derpp.observe_batch(batch_inputs, batch_targets, derpp.model(batch_inputs), task_id=0)
+    loss, _, components = derpp.compute_loss(batch_inputs, batch_targets, task_id=1)
+
+    assert len(derpp.buffer) == batch_inputs.size(0)
+    assert all(sample.logits is not None for sample in derpp.buffer.samples)
+    assert loss.item() > 0.0
+    for key in ("derpp_distillation_loss", "derpp_replay_ce_loss"):
+        assert key in components
+
+
+def test_agem_projects_conflicting_gradient() -> None:
+    """Conflicting labels on identical inputs give a negative gradient dot and a projection."""
+    layer = torch.nn.Linear(2, 2, bias=False)
+    with torch.no_grad():
+        layer.weight.zero_()
+    agem = AGEMStrategy(
+        layer,
+        torch.device("cpu"),
+        learning_rate=0.1,
+        buffer_size=4,
+        memory_batch_size=2,
+        seed=1,
+    )
+
+    # Same inputs in memory and in the current batch, but different labels (0 vs 1).
+    features = [[2.0, 0.0], [2.0, 0.0]]
+    agem.buffer.add_batch(torch.tensor(features), torch.tensor([0, 0]))
+    current_inputs, current_targets = torch.tensor(features), torch.tensor([1, 1])
+
+    loss, _, _ = agem.compute_loss(current_inputs, current_targets, task_id=1)
+    agem.optimizer.zero_grad(set_to_none=True)
+    loss.backward()
+    agem.after_backward(current_inputs, current_targets, task_id=1)
+
+    assert agem.last_gradient_dot < 0.0
+    assert agem.last_projection_applied == 1.0
diff --git a/tests/test_tracking.py b/tests/test_tracking.py
new file mode 100644
index 0000000..d0a7ab3
--- /dev/null
+++ b/tests/test_tracking.py
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from dataclasses import replace
+from pathlib import Path
+
+import pytest
+
+from cl_bench.config import load_config
+from cl_bench.experiments import run_experiment
+
+mlflow = pytest.importorskip("mlflow")
+
+
+def test_mlflow_tracking_logs_params_metrics_and_artifacts(tmp_path) -> None:
+    """A run with ``tracking="both"`` logs params, metrics and ``metrics.json`` to MLflow."""
+    tracking_uri = f"sqlite:///{tmp_path / 'mlflow.db'}"
+    config = replace(
+        load_config("smoke"),
+        output_dir=str(tmp_path / "runs"),
+        method="baseline",
+        tracking="both",
+        mlflow_tracking_uri=tracking_uri,
+        mlflow_experiment="unit-cl-bench",
+    )
+    previous_uri = mlflow.get_tracking_uri()  # process-global setting; restored in finally
+    try:
+        result = run_experiment(config)
+        mlflow.set_tracking_uri(tracking_uri)
+        experiment = mlflow.get_experiment_by_name("unit-cl-bench")
+        assert experiment is not None
+        runs = mlflow.search_runs([experiment.experiment_id], output_format="list")
+        assert len(runs) == 1
+        run_data, run_id = runs[0].data, runs[0].info.run_id
+        assert run_data.params["method"] == "baseline"
+        assert "average_final_accuracy" in run_data.metrics
+        artifact = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path="metrics.json")
+        assert Path(artifact).exists()
+        assert result.metrics_path.exists()
+    finally:
+        mlflow.set_tracking_uri(previous_uri)