diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..bc50a73 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,51 @@ +name: CI + +on: + push: + pull_request: + +jobs: + lint-test-build: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - name: Install package + run: | + python -m pip install --upgrade pip + python -m pip install -e ".[dev,experiment,report]" + - name: Lint + run: ruff check . + - name: Format check + run: ruff format --check . + - name: Tests + run: pytest --cov + - name: CPU smoke benchmark + run: cl-bench run --config configs/smoke.yaml --method baseline --epochs 1 --device cpu --output-dir /tmp/cl-bench-runs + - name: CPU benchmark suite report + run: | + cl-bench suite \ + --config configs/smoke.yaml \ + --methods baseline replay derpp agem \ + --seeds 7 \ + --epochs 1 \ + --device cpu \ + --tracking json \ + --output-dir /tmp/cl-bench-runs \ + --report-dir /tmp/cl-bench-report + - name: Build package + run: python -m build + - name: Import check + run: python -c "import cl_bench; print(cl_bench.__version__)" + - uses: actions/upload-artifact@v4 + if: always() + with: + name: benchmark-report-${{ matrix.python-version }} + path: /tmp/cl-bench-report diff --git a/.gitignore b/.gitignore index 66b7405..33a4845 100644 --- a/.gitignore +++ b/.gitignore @@ -2,23 +2,39 @@ __pycache__/ *.py[cod] *$py.class -*.pyc -# Distribution / packaging +# Build and packaging artifacts *.egg-info/ dist/ build/ +.ruff_cache/ +.pytest_cache/ +.coverage +htmlcov/ # Virtual environments .venv/ venv/ env/ -# Jupyter Notebook checkpoints -.ipynb_checkpoints/ - -# Results and logs (generated at runtime) -results/logs/ +# Runtime artifacts generated by experiments +data/ +checkpoints/ +logs/ +results/ +runs/ +wandb/ +tensorboard/ +mlruns/ +.hydra/ +*.pt +*.pth +*.ckpt +*.npy +*.npz -# OS files +# Notebook/editor/OS noise +.ipynb_checkpoints/ .DS_Store +.idea/ +.vscode/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..83dc6af --- /dev/null +++ b/Dockerfile @@ -0,0 +1,16 @@ +FROM python:3.12-slim + +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 + +WORKDIR /app + +COPY pyproject.toml README.md License ./ +COPY src ./src +COPY configs ./configs +COPY tests ./tests + +RUN python -m pip install --upgrade pip \ + && python -m pip install -e ".[dev,experiment,report]" + +CMD ["cl-bench", "run", "--config", "configs/smoke.yaml", "--method", "baseline", "--epochs", "1", "--device", "cpu", "--output-dir", "/tmp/cl-bench-runs"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b580646 --- /dev/null +++ b/Makefile @@ -0,0 +1,39 @@ +PYTHON ?= python3 +VENV ?= .venv +BIN := $(VENV)/bin +export PYTHONPATH := src + +.PHONY: setup lint format test smoke suite benchmark build verify clean + +setup: + $(PYTHON) -m venv $(VENV) + $(BIN)/python -m pip install --upgrade pip + $(BIN)/python -m pip install -e ".[dev,experiment,report]" + +lint: + $(BIN)/ruff check . + $(BIN)/ruff format --check . + +format: + $(BIN)/ruff format . + $(BIN)/ruff check --fix . + +test: + $(BIN)/pytest --cov + +smoke: + $(BIN)/cl-bench run --config configs/smoke.yaml --method baseline --epochs 1 --device cpu + +suite: + $(BIN)/cl-bench suite --config configs/smoke.yaml --methods baseline ewc replay lwf derpp agem --seeds 7 --epochs 1 --device cpu --report-dir reports/smoke + +benchmark: + $(BIN)/cl-bench suite --config-name split_cifar10_headline --methods baseline ewc replay lwf derpp agem --seeds 13 21 --tracking both --report-dir docs/assets/split_cifar10_headline --title "Split CIFAR-10 Headline Benchmark" + +build: + $(BIN)/python -m build + +verify: lint test smoke build + +clean: + rm -rf build dist *.egg-info .pytest_cache .ruff_cache htmlcov .coverage diff --git a/README.md b/README.md index 49ee05b..3134779 100644 --- a/README.md +++ b/README.md @@ -1,245 +1,188 @@ -# 🧠 Continual Learning System +# Continual Learning Benchmark -
+[![CI](https://github.com/1Utkarsh1/Continual-Learning/actions/workflows/ci.yml/badge.svg)](https://github.com/1Utkarsh1/Continual-Learning/actions/workflows/ci.yml) -![Python](https://img.shields.io/badge/Python-3.8%2B-blue) -![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-orange) -![License](https://img.shields.io/badge/License-MIT-green) -![Status](https://img.shields.io/badge/Status-Active-brightgreen) +A PyTorch benchmark framework for comparing continual-learning strategies under +the same task stream, metric suite, artifact pipeline, local MLflow tracker, and +report generator. -**A robust framework for training neural networks that can learn sequentially without forgetting** +Experiments are config-driven, methods share a stable lifecycle interface, real +and synthetic benchmarks use the same runner, and every run writes +reproducibility artifacts that can be aggregated into leaderboard CSV/JSON files +and plots. The primary reported result is a verified Split CIFAR-10 suite. -[Overview](#overview) • -[Key Features](#key-features) • -[Techniques](#techniques) • -[Installation](#installation) • -[Usage](#usage) • -[Results](#results) • -[Contributing](#contributing) +![Split CIFAR-10 headline benchmark leaderboard](docs/assets/split_cifar10_headline/leaderboard.png) -
+## Project Scope -## 🔄 Overview +- Config-driven benchmark runner for single runs and multi-method suites. +- Implemented baseline fine-tuning, EWC, reservoir replay, LwF, DER++, and A-GEM. +- Deterministic synthetic CI benchmark plus real MNIST and CIFAR-10 task streams. +- Artifact tracking for config snapshots, metadata, JSONL events, CSV matrices, + checkpoints, MLflow runs, aggregate reports, and plots. +- Python package, CLI, Dockerfile, Makefile, Ruff, pytest coverage, and GitHub + Actions matrix across Python 3.10, 3.11, and 3.12. -The Continual Learning System is a comprehensive framework for developing neural networks that can learn tasks sequentially without suffering from catastrophic forgetting. This project implements several state-of-the-art techniques to mitigate forgetting in neural networks, allowing them to adapt to new tasks while retaining performance on previously learned ones. +## Quickstart - -## 🌟 Key Features - -- **Task Sequential Learning**: Train models on a sequence of tasks without complete retraining -- **Forgetting Mitigation**: Advanced techniques to prevent catastrophic forgetting -- **Performance Tracking**: Comprehensive metrics to monitor how well knowledge is retained -- **Experiment Framework**: Easily run and compare different continual learning approaches -- **Visualization Tools**: Track and visualize forgetting metrics across sequential tasks - -## 🧩 Techniques Implemented - -### Elastic Weight Consolidation (EWC) - -EWC measures the importance of neural network weights for previously learned tasks and penalizes changes to important weights when learning new tasks. - -```python -# Loss calculation with EWC -loss = task_loss + lambda_ewc * ewc_loss +```bash +git clone https://github.com/1Utkarsh1/Continual-Learning.git +cd Continual-Learning + +python3 -m venv .venv +source .venv/bin/activate +python -m pip install --upgrade pip +python -m pip install -e ".[dev,experiment,report]" + +ruff check . +ruff format --check . +pytest --cov +cl-bench run --config-name smoke --method baseline --epochs 1 --device cpu ``` -### Experience Replay +Or use the project automation: -This technique maintains a memory buffer of examples from previous tasks and periodically replays them during training on new tasks. - -```python -# Replay during training -combined_loss = current_task_loss + alpha * replay_loss +```bash +make setup +make verify +make benchmark ``` -### Learning without Forgetting (LwF) +## CLI -LwF uses knowledge distillation to preserve the model's behavior on previous tasks when learning new ones. +Run one benchmark: -```python -# LwF distillation loss -distillation_loss = KL_divergence(current_outputs, previous_outputs) +```bash +cl-bench run --config-name smoke --method replay --epochs 1 --device cpu ``` -### Task-specific Components - -For some approaches, we isolate or add task-specific parameters while sharing a common feature extraction backbone. - -## 🔧 Installation +Run the headline Split CIFAR-10 suite with local MLflow tracking and plots: ```bash -# Clone the repository -git clone https://github.com/1Utkarsh1/continual-learning.git -cd continual-learning - -# Create a virtual environment -python -m venv venv -source venv/bin/activate # On Windows: venv\Scripts\activate - -# Install dependencies -pip install -r requirements.txt +cl-bench suite \ + --config-name split_cifar10_headline \ + --methods baseline ewc replay lwf derpp agem \ + --seeds 13 21 \ + --tracking both \ + --report-dir docs/assets/split_cifar10_headline \ + --title "Split CIFAR-10 Headline Benchmark" ``` -## 📊 Usage - -### Quick Start +Use Hydra/OmegaConf-style overrides for quick experiments: ```bash -# Run baseline experiment (sequential training without any continual learning techniques) -python src/main.py --method baseline --tasks mnist_split - -# Run EWC experiment -python src/main.py --method ewc --tasks mnist_split --lambda_ewc 5000 - -# Run Experience Replay experiment -python src/main.py --method replay --tasks mnist_split --buffer_size 500 +cl-bench suite \ + --config-name split_cifar10_headline \ + --methods baseline derpp \ + --seeds 13 \ + --tracking json \ + training.epochs=1 strategy.replay_buffer_size=500 ``` -### Custom Task Sequences - -You can define your own task sequences in a YAML configuration file: - -```yaml -# config/tasks/custom_sequence.yaml -task_sequence: - - name: "mnist_digits_0_4" - dataset: "mnist" - classes: [0, 1, 2, 3, 4] - - - name: "mnist_digits_5_9" - dataset: "mnist" - classes: [5, 6, 7, 8, 9] - - - name: "fashion_mnist" - dataset: "fashion_mnist" - classes: "all" -``` - -## 📈 Experimental Results - -### Comparison of Methods - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
MethodAverage AccuracyAverage ForgettingTraining Time
Naïve Fine-tuning45.2%35.8%1.0x
EWC78.5%10.2%1.2x
Experience Replay82.3%7.5%1.5x
LwF75.7%12.8%1.3x
-
- - -## 📝 Recent Experiment Results - -The following experiment results were obtained on March 11, 2025 using the MNIST split task sequence: - -**Baseline (Naïve Fine-tuning):** -- Command: `python src/main.py --method baseline --tasks mnist_split --epochs 5` -- Task sequence: ['mnist_0_4', 'mnist_5_9'] -- Average final accuracy: 49.74% -- Average forgetting: 49.90% - -**Learning without Forgetting (LwF):** -- Command: `python src/main.py --method lwf --tasks mnist_split --epochs 5` -- Task sequence: ['mnist_0_4', 'mnist_5_9'] -- Average final accuracy: 49.67% -- Average forgetting: 49.83% - -## 🛠️ Project Structure +Inspect local experiment runs: +```bash +mlflow ui --backend-store-uri sqlite:///mlruns/mlflow.db ``` -continual_learning/ -├── src/ # Source code -│ ├── models/ # Neural network architectures -│ ├── data/ # Data loading and preprocessing -│ ├── methods/ # Continual learning algorithms -│ ├── utils/ # Utility functions -│ └── main.py # Main entry point -├── experiments/ # Jupyter notebooks for experiments -├── config/ # Configuration files -│ ├── models/ # Model configurations -│ └── tasks/ # Task sequence definitions -├── results/ # Saved results and visualizations -└── docs/ # Documentation -``` - -## 📝 Example Experiments - -1. **Split MNIST** - - Train on digits 0-4, then 5-9 - - Compare different methods' ability to remember the first task - -2. **Task Incremental Learning** - - Train on MNIST → Fashion-MNIST → KMNIST - - Measure accuracy on all previous datasets after each task - -3. **Class Incremental Learning** - - Add new classes (one at a time) to a classifier - - Test identification of all classes after each addition -## 🔮 Roadmap +Aggregate existing run directories or MLflow artifact exports: -- [x] Implement baseline sequential training -- [x] Implement Elastic Weight Consolidation (EWC) -- [x] Implement Experience Replay -- [x] Implement Learning without Forgetting (LwF) -- [x] Add support for generative replay -- [x] Implement parameter isolation methods -- [x] Add support for continual reinforcement learning -- [x] Develop benchmark suite for comparing methods - -## 🤝 Contributing - -Contributions are welcome! Please feel free to submit a Pull Request. - -1. Fork the repository -2. Create your feature branch (`git checkout -b feature/amazing-feature`) -3. Commit your changes (`git commit -m 'Add some amazing feature'`) -4. Push to the branch (`git push origin feature/amazing-feature`) -5. Open a Pull Request +```bash +cl-bench report \ + --runs runs \ + --output-dir reports/local \ + --title "Local continual-learning report" +``` -## 📚 References +## Verified Headline Benchmark -1. Kirkpatrick, J. et al. "Overcoming catastrophic forgetting in neural networks" - *Proceedings of the National Academy of Sciences* (2017) -2. Rebuffi, S. et al. "iCaRL: Incremental Classifier and Representation Learning" - *CVPR* (2017) -3. Li, Z. and Hoiem, D. "Learning without Forgetting" - *IEEE Transactions on Pattern Analysis and Machine Intelligence* (2018) -4. Chaudhry, A. et al. "Efficient Lifelong Learning with A-GEM" - *ICLR* (2019) +Local verification on 2026-05-25 used Python 3.11.15, PyTorch 2.12.0, +torchvision 0.27.0, NumPy 2.4.6, Hydra 1.3.2, MLflow 3.12.0, Ruff 0.15.14, +pytest 9.0.3, and Matplotlib 3.10.9. -## 📄 License +Command: -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. +```bash +cl-bench suite --config-name split_cifar10_headline --methods baseline ewc replay lwf derpp agem --seeds 13 21 --tracking both --report-dir docs/assets/split_cifar10_headline --title "Split CIFAR-10 Headline Benchmark" +``` ---- +The headline benchmark uses real CIFAR-10 images, five class-incremental tasks, +2,500 training examples per task, 1,000 test examples per task, two seeds, +5 epochs per task, a compact residual CIFAR ConvNet, and a 5,000-example replay +memory budget where applicable. It is a reproducible benchmark, not a paper +leaderboard claim. + +| Method | Average final accuracy | Average forgetting | Mean runtime | +| --- | ---: | ---: | ---: | +| DER++ | 51.15% +- 3.95% | 34.06% +- 4.74% | 578.7s | +| replay | 41.99% +- 0.27% | 45.27% +- 1.73% | 547.4s | +| LwF | 16.53% +- 0.13% | 76.71% +- 0.09% | 224.3s | +| A-GEM | 14.37% +- 0.39% | 79.34% +- 0.96% | 516.3s | +| baseline | 14.06% +- 0.10% | 79.14% +- 1.39% | 181.0s | +| EWC | 12.12% +- 0.74% | 69.20% +- 3.02% | 223.1s | + +Generated report artifacts live in +[`docs/assets/split_cifar10_headline`](docs/assets/split_cifar10_headline/README.md). + +## Architecture + +```text +src/cl_bench/ + cli.py # run, suite, report, config discovery, and overrides + config.py # TaskSpec, ExperimentConfig, BenchmarkResult + datasets.py # synthetic, MNIST-family, and CIFAR-10 task construction + experiments.py # seeded run orchestration and evaluation loop + metrics.py # accuracy, forgetting, transfer, and summary metrics + models.py # linear, MLP, small CNN, and CIFAR residual ConvNet factory + reporting.py # run aggregation, leaderboard CSV/JSON, and plots + tracking.py # JSON/JSONL/CSV artifacts and optional MLflow logging + strategies/ # baseline, EWC, replay, LwF, DER++, and A-GEM +configs/ + smoke.yaml # fast deterministic CPU benchmark + split_mnist_quick.yaml # bounded real MNIST suite for local CPU runs + split_mnist.yaml # full five-task MNIST stream + split_cifar10_headline.yaml # verified CIFAR-10 benchmark used in the README +docs/ + BENCHMARK_CARD.md # scope, metrics, limitations, reproducibility +tests/ # unit and integration coverage +``` -
- Made with ❤️ by the Continual Learning Team
- GitHub • - Website • - Contact -
+## Run Artifacts + +Each run is written to `runs/__/` and contains: + +- `config.yaml`: exact config snapshot. +- `run_metadata.json`: seed, device, and git commit when available. +- `metrics.jsonl`: event-level training and evaluation metrics. +- `metrics.json`: final run summary. +- `accuracy_matrix.{json,csv}` and `forgetting_matrix.{json,csv}`. +- Optional `checkpoints/final_model.pt`. +- Optional MLflow run entries with params, metrics, tags, and artifacts. + +The report command writes: + +- `leaderboard.csv` +- `summary.json` +- `README.md` +- `leaderboard.png` +- `retention_curves.png` +- `accuracy_matrices.png` + +Generated `data/`, `runs/`, `results/`, logs, checkpoints, and NumPy arrays are +ignored by git. Curated README assets under `docs/assets/` are intentionally kept. + +## Engineering Notes + +- EWC estimates the empirical Fisher from per-sample log-likelihood gradients and + normalizes by the actual number of samples used. +- Replay uses reservoir sampling so a bounded buffer represents the full observed + stream instead of only the newest examples. +- LwF stores a frozen teacher after each task and combines supervised loss with + temperature-scaled KL distillation. +- DER++ stores replay logits online and combines current CE, replay CE, and + logit-matching losses. +- A-GEM projects conflicting gradients against replay-memory reference gradients. +- Best validation checkpoints are deep-copied before restoration to avoid mutable + `state_dict` aliasing bugs. +- The suite/report layer separates expensive benchmark execution from cheap, + repeatable analysis over saved metrics. diff --git a/configs/smoke.yaml b/configs/smoke.yaml new file mode 100644 index 0000000..2cdc2e4 --- /dev/null +++ b/configs/smoke.yaml @@ -0,0 +1,33 @@ +name: smoke +method: baseline +seed: 7 +device: cpu +model: linear +data_dir: data +output_dir: runs +training: + epochs: 1 + batch_size: 16 + eval_batch_size: 64 + learning_rate: 0.05 + val_fraction: 0.2 + num_workers: 0 +strategy: + ewc_lambda: 10.0 + fisher_samples: 16 + replay_buffer_size: 32 + replay_batch_size: 8 + replay_loss_weight: 1.0 + lwf_alpha: 0.5 + lwf_temperature: 2.0 +tasks: + - name: synthetic_0_1 + dataset: synthetic + classes: [0, 1] + samples_per_class: 20 + test_samples_per_class: 10 + - name: synthetic_2_3 + dataset: synthetic + classes: [2, 3] + samples_per_class: 20 + test_samples_per_class: 10 diff --git a/configs/split_cifar10_headline.yaml b/configs/split_cifar10_headline.yaml new file mode 100644 index 0000000..c4c1831 --- /dev/null +++ b/configs/split_cifar10_headline.yaml @@ -0,0 +1,56 @@ +name: split_cifar10_headline +method: derpp +seed: 13 +device: auto +model: cifar_convnet +data_dir: data +output_dir: runs +tracking: + mode: both + mlflow_tracking_uri: sqlite:///mlruns/mlflow.db + mlflow_experiment: continual-learning-bench +training: + epochs: 5 + batch_size: 128 + eval_batch_size: 512 + learning_rate: 0.001 + val_fraction: 0.1 + num_workers: 0 + augment: true +strategy: + ewc_lambda: 50.0 + fisher_samples: 256 + replay_buffer_size: 5000 + replay_batch_size: 256 + replay_loss_weight: 3.0 + lwf_alpha: 0.5 + lwf_temperature: 2.0 + derpp_alpha: 0.1 + derpp_beta: 2.0 + agem_memory_batch_size: 256 +tasks: + - name: cifar10_airplane_automobile + dataset: cifar10 + classes: [0, 1] + train_limit: 2500 + test_limit: 1000 + - name: cifar10_bird_cat + dataset: cifar10 + classes: [2, 3] + train_limit: 2500 + test_limit: 1000 + - name: cifar10_deer_dog + dataset: cifar10 + classes: [4, 5] + train_limit: 2500 + test_limit: 1000 + - name: cifar10_frog_horse + dataset: cifar10 + classes: [6, 7] + train_limit: 2500 + test_limit: 1000 + - name: cifar10_ship_truck + dataset: cifar10 + classes: [8, 9] + train_limit: 2500 + test_limit: 1000 diff --git a/configs/split_mnist.yaml b/configs/split_mnist.yaml new file mode 100644 index 0000000..f3996f9 --- /dev/null +++ b/configs/split_mnist.yaml @@ -0,0 +1,38 @@ +name: split_mnist +method: baseline +seed: 42 +device: auto +model: small_cnn +data_dir: data +output_dir: runs +training: + epochs: 1 + batch_size: 64 + eval_batch_size: 256 + learning_rate: 0.001 + val_fraction: 0.1 + num_workers: 2 +strategy: + ewc_lambda: 100.0 + fisher_samples: 256 + replay_buffer_size: 500 + replay_batch_size: 64 + replay_loss_weight: 1.0 + lwf_alpha: 0.5 + lwf_temperature: 2.0 +tasks: + - name: mnist_0_1 + dataset: mnist + classes: [0, 1] + - name: mnist_2_3 + dataset: mnist + classes: [2, 3] + - name: mnist_4_5 + dataset: mnist + classes: [4, 5] + - name: mnist_6_7 + dataset: mnist + classes: [6, 7] + - name: mnist_8_9 + dataset: mnist + classes: [8, 9] diff --git a/configs/split_mnist_quick.yaml b/configs/split_mnist_quick.yaml new file mode 100644 index 0000000..1bbc1f4 --- /dev/null +++ b/configs/split_mnist_quick.yaml @@ -0,0 +1,48 @@ +name: split_mnist_quick +method: baseline +seed: 13 +device: auto +model: small_cnn +data_dir: data +output_dir: runs +training: + epochs: 3 + batch_size: 128 + eval_batch_size: 512 + learning_rate: 0.001 + val_fraction: 0.1 + num_workers: 0 +strategy: + ewc_lambda: 75.0 + fisher_samples: 64 + replay_buffer_size: 400 + replay_batch_size: 64 + replay_loss_weight: 1.0 + lwf_alpha: 0.5 + lwf_temperature: 2.0 +tasks: + - name: mnist_0_1 + dataset: mnist + classes: [0, 1] + train_limit: 600 + test_limit: 300 + - name: mnist_2_3 + dataset: mnist + classes: [2, 3] + train_limit: 600 + test_limit: 300 + - name: mnist_4_5 + dataset: mnist + classes: [4, 5] + train_limit: 600 + test_limit: 300 + - name: mnist_6_7 + dataset: mnist + classes: [6, 7] + train_limit: 600 + test_limit: 300 + - name: mnist_8_9 + dataset: mnist + classes: [8, 9] + train_limit: 600 + test_limit: 300 diff --git a/data/MNIST/raw/t10k-images-idx3-ubyte b/data/MNIST/raw/t10k-images-idx3-ubyte deleted file mode 100644 index 1170b2c..0000000 Binary files a/data/MNIST/raw/t10k-images-idx3-ubyte and /dev/null differ diff --git a/data/MNIST/raw/t10k-images-idx3-ubyte.gz b/data/MNIST/raw/t10k-images-idx3-ubyte.gz deleted file mode 100644 index 5ace8ea..0000000 Binary files a/data/MNIST/raw/t10k-images-idx3-ubyte.gz and /dev/null differ diff --git a/data/MNIST/raw/t10k-labels-idx1-ubyte b/data/MNIST/raw/t10k-labels-idx1-ubyte deleted file mode 100644 index d1c3a97..0000000 Binary files a/data/MNIST/raw/t10k-labels-idx1-ubyte and /dev/null differ diff --git a/data/MNIST/raw/t10k-labels-idx1-ubyte.gz b/data/MNIST/raw/t10k-labels-idx1-ubyte.gz deleted file mode 100644 index a7e1415..0000000 Binary files a/data/MNIST/raw/t10k-labels-idx1-ubyte.gz and /dev/null differ diff --git a/data/MNIST/raw/train-images-idx3-ubyte b/data/MNIST/raw/train-images-idx3-ubyte deleted file mode 100644 index bbce276..0000000 Binary files a/data/MNIST/raw/train-images-idx3-ubyte and /dev/null differ diff --git a/data/MNIST/raw/train-images-idx3-ubyte.gz b/data/MNIST/raw/train-images-idx3-ubyte.gz deleted file mode 100644 index b50e4b6..0000000 Binary files a/data/MNIST/raw/train-images-idx3-ubyte.gz and /dev/null differ diff --git a/data/MNIST/raw/train-labels-idx1-ubyte b/data/MNIST/raw/train-labels-idx1-ubyte deleted file mode 100644 index d6b4c5d..0000000 Binary files a/data/MNIST/raw/train-labels-idx1-ubyte and /dev/null differ diff --git a/data/MNIST/raw/train-labels-idx1-ubyte.gz b/data/MNIST/raw/train-labels-idx1-ubyte.gz deleted file mode 100644 index 707a576..0000000 Binary files a/data/MNIST/raw/train-labels-idx1-ubyte.gz and /dev/null differ diff --git a/docs/BENCHMARK_CARD.md b/docs/BENCHMARK_CARD.md new file mode 100644 index 0000000..d17bbe6 --- /dev/null +++ b/docs/BENCHMARK_CARD.md @@ -0,0 +1,45 @@ +# Benchmark Card + +## Scope + +This project evaluates continual-learning strategies on task streams where a model +sees one subset of classes at a time and is re-evaluated on all previously seen +tasks after every training step. + +## Implemented Methods + +- Baseline sequential fine-tuning. +- Elastic Weight Consolidation with empirical Fisher estimates. +- Reservoir replay with a bounded memory budget. +- Learning without Forgetting with temperature-scaled distillation. +- DER++ with online replay-logit storage. +- A-GEM with replay-memory gradient projection. + +## Datasets + +- `synthetic`: deterministic image-like tensors for fast CI and unit tests. +- `split_mnist_quick`: real MNIST images with bounded per-task subsets so a full + four-method comparison runs on CPU. +- `split_mnist`: full five-task MNIST stream for longer local experiments. +- `split_cifar10_headline`: real CIFAR-10 images, five class-incremental tasks, + two verification seeds, and a compact residual ConvNet. + +## Metrics + +- Average final accuracy: mean accuracy over seen tasks after the final task. +- Average learning accuracy: mean diagonal accuracy after each task is learned. +- Average forgetting: best previous accuracy minus final accuracy for prior tasks. +- Backward transfer: final accuracy minus first-learned accuracy on prior tasks. + +## Reproducibility + +Every run writes `config.yaml`, `run_metadata.json`, event-level `metrics.jsonl`, +CSV/JSON matrices, final `metrics.json`, and optionally MLflow params, metrics, +tags, and artifacts. The suite command can aggregate multiple methods and seeds +into a leaderboard plus report plots. + +## Limitations + +The Split CIFAR-10 reported result is designed for local reproducibility and +engineering verification. It is not a leaderboard claim. Serious research +comparisons should increase seeds, epochs, memory budgets, and dataset coverage. diff --git a/docs/assets/split_cifar10_headline/README.md b/docs/assets/split_cifar10_headline/README.md new file mode 100644 index 0000000..27f120e --- /dev/null +++ b/docs/assets/split_cifar10_headline/README.md @@ -0,0 +1,41 @@ +# Split CIFAR-10 Headline Benchmark + +Benchmarks: `split_cifar10_headline` +Runs: `12` +Tasks per run: `5` + +## Leaderboard + +| Method | Runs | Seeds | Final accuracy | Forgetting | Backward transfer | Mean runtime | +| --- | ---: | --- | ---: | ---: | ---: | ---: | +| derpp | 2 | 13,21 | 51.15 +- 3.95% | 34.06 +- 4.74% | -34.06% | 578.7s | +| replay | 2 | 13,21 | 41.99 +- 0.27% | 45.27 +- 1.73% | -45.27% | 547.4s | +| lwf | 2 | 13,21 | 16.53 +- 0.13% | 76.71 +- 0.09% | -76.71% | 224.3s | +| agem | 2 | 13,21 | 14.37 +- 0.39% | 79.34 +- 0.96% | -79.34% | 516.3s | +| baseline | 2 | 13,21 | 14.06 +- 0.10% | 79.14 +- 1.39% | -79.14% | 181.0s | +| ewc | 2 | 13,21 | 12.12 +- 0.74% | 69.20 +- 3.02% | -69.20% | 223.1s | + +## Plots + +![leaderboard](leaderboard.png) + +![retention_curves](retention_curves.png) + +![accuracy_matrices](accuracy_matrices.png) + +## Source Runs + +| Method | Seed | Run directory | +| --- | ---: | --- | +| agem | 13 | `runs/split_cifar10_headline_agem_20260525T133221Z` | +| agem | 21 | `runs/split_cifar10_headline_agem_20260525T141129Z` | +| baseline | 13 | `runs/split_cifar10_headline_baseline_20260525T130248Z` | +| baseline | 21 | `runs/split_cifar10_headline_baseline_20260525T134059Z` | +| derpp | 13 | `runs/split_cifar10_headline_derpp_20260525T132220Z` | +| derpp | 21 | `runs/split_cifar10_headline_derpp_20260525T140155Z` | +| ewc | 13 | `runs/split_cifar10_headline_ewc_20260525T130553Z` | +| ewc | 21 | `runs/split_cifar10_headline_ewc_20260525T134413Z` | +| lwf | 13 | `runs/split_cifar10_headline_lwf_20260525T131842Z` | +| lwf | 21 | `runs/split_cifar10_headline_lwf_20260525T135744Z` | +| replay | 13 | `runs/split_cifar10_headline_replay_20260525T130954Z` | +| replay | 21 | `runs/split_cifar10_headline_replay_20260525T134756Z` | diff --git a/docs/assets/split_cifar10_headline/accuracy_matrices.png b/docs/assets/split_cifar10_headline/accuracy_matrices.png new file mode 100644 index 0000000..2db9218 Binary files /dev/null and b/docs/assets/split_cifar10_headline/accuracy_matrices.png differ diff --git a/docs/assets/split_cifar10_headline/leaderboard.csv b/docs/assets/split_cifar10_headline/leaderboard.csv new file mode 100644 index 0000000..17dc817 --- /dev/null +++ b/docs/assets/split_cifar10_headline/leaderboard.csv @@ -0,0 +1,7 @@ +method,runs,seeds,average_final_accuracy_mean,average_final_accuracy_std,average_learning_accuracy_mean,average_forgetting_mean,average_forgetting_std,backward_transfer_mean,runtime_seconds_mean +derpp,2,"13,21",51.150000000000006,3.9499999999999993,78.39999999999999,34.0625,4.737499999999999,-34.0625,578.7151509795076 +replay,2,"13,21",41.99000000000001,0.2699999999999996,78.21,45.275,1.7250000000000014,-45.275,547.4045557500067 +lwf,2,"13,21",16.53,0.13000000000000078,77.9,76.7125,0.08749999999999858,-76.7125,224.33408429198607 +agem,2,"13,21",14.370000000000001,0.39000000000000057,77.84,79.3375,0.9624999999999986,-79.3375,516.2666623959958 +baseline,2,"13,21",14.059999999999999,0.10000000000000053,77.37,79.13749999999999,1.3874999999999957,-79.13749999999999,180.97890222900605 +ewc,2,"13,21",12.12,0.7400000000000002,67.48000000000002,69.19999999999999,3.0249999999999986,-69.19999999999999,223.06707881249895 diff --git a/docs/assets/split_cifar10_headline/leaderboard.png b/docs/assets/split_cifar10_headline/leaderboard.png new file mode 100644 index 0000000..49f4e7d Binary files /dev/null and b/docs/assets/split_cifar10_headline/leaderboard.png differ diff --git a/docs/assets/split_cifar10_headline/retention_curves.png b/docs/assets/split_cifar10_headline/retention_curves.png new file mode 100644 index 0000000..99c0277 Binary files /dev/null and b/docs/assets/split_cifar10_headline/retention_curves.png differ diff --git a/docs/assets/split_cifar10_headline/summary.json b/docs/assets/split_cifar10_headline/summary.json new file mode 100644 index 0000000..f5e8421 --- /dev/null +++ b/docs/assets/split_cifar10_headline/summary.json @@ -0,0 +1,408 @@ +{ + "benchmarks": [ + "split_cifar10_headline" + ], + "generated_at_utc": "2026-05-25T14:20:17.367419+00:00", + "leaderboard": [ + { + "average_final_accuracy_mean": 51.150000000000006, + "average_final_accuracy_std": 3.9499999999999993, + "average_forgetting_mean": 34.0625, + "average_forgetting_std": 4.737499999999999, + "average_learning_accuracy_mean": 78.39999999999999, + "backward_transfer_mean": -34.0625, + "method": "derpp", + "runs": 2, + "runtime_seconds_mean": 578.7151509795076, + "seeds": "13,21" + }, + { + "average_final_accuracy_mean": 41.99000000000001, + "average_final_accuracy_std": 0.2699999999999996, + "average_forgetting_mean": 45.275, + "average_forgetting_std": 1.7250000000000014, + "average_learning_accuracy_mean": 78.21, + "backward_transfer_mean": -45.275, + "method": "replay", + "runs": 2, + "runtime_seconds_mean": 547.4045557500067, + "seeds": "13,21" + }, + { + "average_final_accuracy_mean": 16.53, + "average_final_accuracy_std": 0.13000000000000078, + "average_forgetting_mean": 76.7125, + "average_forgetting_std": 0.08749999999999858, + "average_learning_accuracy_mean": 77.9, + "backward_transfer_mean": -76.7125, + "method": "lwf", + "runs": 2, + "runtime_seconds_mean": 224.33408429198607, + "seeds": "13,21" + }, + { + "average_final_accuracy_mean": 14.370000000000001, + "average_final_accuracy_std": 0.39000000000000057, + "average_forgetting_mean": 79.3375, + "average_forgetting_std": 0.9624999999999986, + "average_learning_accuracy_mean": 77.84, + "backward_transfer_mean": -79.3375, + "method": "agem", + "runs": 2, + "runtime_seconds_mean": 516.2666623959958, + "seeds": "13,21" + }, + { + "average_final_accuracy_mean": 14.059999999999999, + "average_final_accuracy_std": 0.10000000000000053, + "average_forgetting_mean": 79.13749999999999, + "average_forgetting_std": 1.3874999999999957, + "average_learning_accuracy_mean": 77.37, + "backward_transfer_mean": -79.13749999999999, + "method": "baseline", + "runs": 2, + "runtime_seconds_mean": 180.97890222900605, + "seeds": "13,21" + }, + { + "average_final_accuracy_mean": 12.12, + "average_final_accuracy_std": 0.7400000000000002, + "average_forgetting_mean": 69.19999999999999, + "average_forgetting_std": 3.0249999999999986, + "average_learning_accuracy_mean": 67.48000000000002, + "backward_transfer_mean": -69.19999999999999, + "method": "ewc", + "runs": 2, + "runtime_seconds_mean": 223.06707881249895, + "seeds": "13,21" + } + ], + "num_runs": 12, + "runs": [ + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "agem", + "run_dir": "runs/split_cifar10_headline_agem_20260525T133221Z", + "runtime_seconds": 505.9784695839917, + "seed": 13, + "summary": { + "average_final_accuracy": 14.760000000000002, + "average_forgetting": 80.3, + "average_learning_accuracy": 79.0, + "backward_transfer": -80.3, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 505.9784695839917, + "seed": 13 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "agem", + "run_dir": "runs/split_cifar10_headline_agem_20260525T141129Z", + "runtime_seconds": 526.5548552079999, + "seed": 21, + "summary": { + "average_final_accuracy": 13.98, + "average_forgetting": 78.375, + "average_learning_accuracy": 76.68, + "backward_transfer": -78.375, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 526.5548552079999, + "seed": 21 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "baseline", + "run_dir": "runs/split_cifar10_headline_baseline_20260525T130248Z", + "runtime_seconds": 175.88738362499862, + "seed": 13, + "summary": { + "average_final_accuracy": 14.16, + "average_forgetting": 77.75, + "average_learning_accuracy": 76.36, + "backward_transfer": -77.75, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 175.88738362499862, + "seed": 13 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "baseline", + "run_dir": "runs/split_cifar10_headline_baseline_20260525T134059Z", + "runtime_seconds": 186.0704208330135, + "seed": 21, + "summary": { + "average_final_accuracy": 13.959999999999999, + "average_forgetting": 80.52499999999999, + "average_learning_accuracy": 78.38, + "backward_transfer": -80.52499999999999, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 186.0704208330135, + "seed": 21 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "derpp", + "run_dir": "runs/split_cifar10_headline_derpp_20260525T132220Z", + "runtime_seconds": 592.4993861670082, + "seed": 13, + "summary": { + "average_final_accuracy": 55.1, + "average_forgetting": 29.325, + "average_learning_accuracy": 78.55999999999999, + "backward_transfer": -29.325, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 592.4993861670082, + "seed": 13 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "derpp", + "run_dir": "runs/split_cifar10_headline_derpp_20260525T140155Z", + "runtime_seconds": 564.930915792007, + "seed": 21, + "summary": { + "average_final_accuracy": 47.2, + "average_forgetting": 38.8, + "average_learning_accuracy": 78.24, + "backward_transfer": -38.8, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 564.930915792007, + "seed": 21 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "ewc", + "run_dir": "runs/split_cifar10_headline_ewc_20260525T130553Z", + "runtime_seconds": 231.59013450000202, + "seed": 13, + "summary": { + "average_final_accuracy": 12.86, + "average_forgetting": 72.225, + "average_learning_accuracy": 70.64000000000001, + "backward_transfer": -72.225, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 231.59013450000202, + "seed": 13 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "ewc", + "run_dir": "runs/split_cifar10_headline_ewc_20260525T134413Z", + "runtime_seconds": 214.54402312499587, + "seed": 21, + "summary": { + "average_final_accuracy": 11.379999999999999, + "average_forgetting": 66.175, + "average_learning_accuracy": 64.32000000000001, + "backward_transfer": -66.175, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 214.54402312499587, + "seed": 21 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "lwf", + "run_dir": "runs/split_cifar10_headline_lwf_20260525T131842Z", + "runtime_seconds": 210.86184020899236, + "seed": 13, + "summary": { + "average_final_accuracy": 16.4, + "average_forgetting": 76.625, + "average_learning_accuracy": 77.7, + "backward_transfer": -76.625, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 210.86184020899236, + "seed": 13 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "lwf", + "run_dir": "runs/split_cifar10_headline_lwf_20260525T135744Z", + "runtime_seconds": 237.80632837497978, + "seed": 21, + "summary": { + "average_final_accuracy": 16.66, + "average_forgetting": 76.8, + "average_learning_accuracy": 78.1, + "backward_transfer": -76.8, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 237.80632837497978, + "seed": 21 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "replay", + "run_dir": "runs/split_cifar10_headline_replay_20260525T130954Z", + "runtime_seconds": 519.1472583750146, + "seed": 13, + "summary": { + "average_final_accuracy": 42.260000000000005, + "average_forgetting": 43.55, + "average_learning_accuracy": 77.1, + "backward_transfer": -43.55, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 519.1472583750146, + "seed": 13 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + }, + { + "benchmark": "split_cifar10_headline", + "git_commit": "972afde98a891a3c5aa25c8602b4e266fce07178", + "method": "replay", + "run_dir": "runs/split_cifar10_headline_replay_20260525T134756Z", + "runtime_seconds": 575.6618531249987, + "seed": 21, + "summary": { + "average_final_accuracy": 41.720000000000006, + "average_forgetting": 47.0, + "average_learning_accuracy": 79.32, + "backward_transfer": -47.0, + "model": "cifar_convnet", + "num_tasks": 5, + "replay_batch_size": 256, + "replay_buffer_size": 5000, + "runtime_seconds": 575.6618531249987, + "seed": 21 + }, + "task_names": [ + "cifar10_airplane_automobile", + "cifar10_bird_cat", + "cifar10_deer_dog", + "cifar10_frog_horse", + "cifar10_ship_truck" + ] + } + ], + "title": "Split CIFAR-10 Headline Benchmark" +} diff --git a/experiments/split_mnist_comparison.ipynb b/experiments/split_mnist_comparison.ipynb deleted file mode 100644 index 03d5c53..0000000 --- a/experiments/split_mnist_comparison.ipynb +++ /dev/null @@ -1,266 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Split MNIST Continual Learning Comparison\n", - "\n", - "This notebook demonstrates a comparison of different continual learning methods on the Split MNIST benchmark.\n", - "\n", - "The task sequence consists of two tasks:\n", - "- **Task 1**: MNIST classes 0–4\n", - "- **Task 2**: MNIST classes 5–9\n", - "\n", - "Methods compared:\n", - "- **Baseline**: Naive fine-tuning (demonstrates catastrophic forgetting)\n", - "- **EWC**: Elastic Weight Consolidation\n", - "- **Replay**: Experience Replay\n", - "- **LwF**: Learning without Forgetting" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "\n", - "# Ensure the repository root is on the Python path\n", - "repo_root = os.path.abspath(os.path.join(os.getcwd(), '..'))\n", - "if repo_root not in sys.path:\n", - " sys.path.insert(0, repo_root)\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import matplotlib.pyplot as plt\n", - "\n", - "from src.data.data_loader import get_task_sequence\n", - "from src.models.model_factory import get_model\n", - "from src.methods.baseline import BaselineLearner\n", - "from src.methods.ewc import EWCLearner\n", - "from src.methods.replay import ExperienceReplayLearner\n", - "from src.methods.lwf import LwFLearner\n", - "from src.utils.metrics import evaluate_performance, compute_forgetting\n", - "from src.utils.visualization import plot_performance, plot_forgetting\n", - "\n", - "print('Imports successful')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Configuration\n", - "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", - "BATCH_SIZE = 64\n", - "EPOCHS = 3\n", - "LEARNING_RATE = 0.001\n", - "SEED = 42\n", - "\n", - "# Set random seed for reproducibility\n", - "torch.manual_seed(SEED)\n", - "np.random.seed(SEED)\n", - "\n", - "print(f'Using device: {DEVICE}')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Define the Split MNIST task sequence\n", - "task_sequence = [\n", - " {'name': 'mnist_0_4', 'dataset': 'mnist', 'classes': [0, 1, 2, 3, 4]},\n", - " {'name': 'mnist_5_9', 'dataset': 'mnist', 'classes': [5, 6, 7, 8, 9]},\n", - "]\n", - "\n", - "# Load data\n", - "task_data = get_task_sequence(task_sequence, BATCH_SIZE)\n", - "\n", - "# Determine input shape and number of classes\n", - "input_shape = task_data[0]['train_loader'].dataset[0][0].shape\n", - "all_classes = set()\n", - "for task in task_sequence:\n", - " all_classes.update(task['classes'])\n", - "num_classes = len(all_classes)\n", - "\n", - "print(f'Input shape: {input_shape}')\n", - "print(f'Number of classes: {num_classes}')\n", - "print(f'Tasks: {[t[\"name\"] for t in task_sequence]}')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def run_experiment(learner, task_data, task_sequence):\n", - " \"\"\"Train a learner sequentially on all tasks and record performance.\"\"\"\n", - " n_tasks = len(task_sequence)\n", - " performance_matrix = np.zeros((n_tasks, n_tasks))\n", - "\n", - " for task_id, task_dict in enumerate(task_data):\n", - " train_loader = task_dict['train_loader']\n", - " val_loader = task_dict['val_loader']\n", - " print(f' Training on task {task_id + 1}/{n_tasks}: {task_sequence[task_id][\"name\"]}')\n", - "\n", - " learner.train(\n", - " train_loader=train_loader,\n", - " val_loader=val_loader,\n", - " task_id=task_id,\n", - " epochs=EPOCHS,\n", - " eval_freq=1,\n", - " )\n", - "\n", - " for eval_task_id in range(task_id + 1):\n", - " eval_loader = task_data[eval_task_id]['test_loader']\n", - " accuracy = learner.evaluate(eval_loader, eval_task_id)\n", - " performance_matrix[task_id, eval_task_id] = accuracy\n", - " print(f' Accuracy on task {eval_task_id + 1}: {accuracy:.2f}%')\n", - "\n", - " return performance_matrix" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "results = {}\n", - "\n", - "# --- Baseline ---\n", - "print('=== Baseline (Fine-tuning) ===')\n", - "model = get_model('simple_cnn', input_shape, num_classes).to(DEVICE)\n", - "learner = BaselineLearner(model=model, device=DEVICE, learning_rate=LEARNING_RATE)\n", - "results['baseline'] = run_experiment(learner, task_data, task_sequence)\n", - "\n", - "print()\n", - "\n", - "# --- EWC ---\n", - "print('=== EWC ===')\n", - "model = get_model('simple_cnn', input_shape, num_classes).to(DEVICE)\n", - "learner = EWCLearner(model=model, device=DEVICE, learning_rate=LEARNING_RATE,\n", - " lambda_ewc=5000, fisher_sample_size=200)\n", - "results['ewc'] = run_experiment(learner, task_data, task_sequence)\n", - "\n", - "print()\n", - "\n", - "# --- Experience Replay ---\n", - "print('=== Experience Replay ===')\n", - "model = get_model('simple_cnn', input_shape, num_classes).to(DEVICE)\n", - "learner = ExperienceReplayLearner(model=model, device=DEVICE, learning_rate=LEARNING_RATE,\n", - " buffer_size=500, replay_batch_size=32)\n", - "results['replay'] = run_experiment(learner, task_data, task_sequence)\n", - "\n", - "print()\n", - "\n", - "# --- LwF ---\n", - "print('=== LwF ===')\n", - "model = get_model('simple_cnn', input_shape, num_classes).to(DEVICE)\n", - "learner = LwFLearner(model=model, device=DEVICE, learning_rate=LEARNING_RATE,\n", - " temperature=2.0, alpha=1.0)\n", - "results['lwf'] = run_experiment(learner, task_data, task_sequence)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "task_names = [t['name'] for t in task_sequence]\n", - "\n", - "# Summary table\n", - "print('\\n=== Final Accuracy on Each Task After All Training ===')\n", - "print(f'{\"Method\":<12}', end='')\n", - "for name in task_names:\n", - " print(f'{name:>16}', end='')\n", - "print(f'{\"Average\":>12}')\n", - "\n", - "for method, perf in results.items():\n", - " final_row = perf[-1, :len(task_names)]\n", - " print(f'{method:<12}', end='')\n", - " for acc in final_row:\n", - " print(f'{acc:>15.2f}%', end='')\n", - " print(f'{np.mean(final_row):>11.2f}%')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Plot performance matrices\n", - "fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n", - "method_names = list(results.keys())\n", - "\n", - "for ax, (method, perf) in zip(axes.flatten(), results.items()):\n", - " im = ax.imshow(perf, vmin=0, vmax=100, cmap='Blues', aspect='auto')\n", - " ax.set_title(method.upper())\n", - " ax.set_xlabel('Task evaluated on')\n", - " ax.set_ylabel('After training task')\n", - " ax.set_xticks(range(len(task_names)))\n", - " ax.set_yticks(range(len(task_names)))\n", - " ax.set_xticklabels(task_names, rotation=30, ha='right')\n", - " ax.set_yticklabels(task_names)\n", - " for i in range(len(task_names)):\n", - " for j in range(len(task_names)):\n", - " if perf[i, j] > 0:\n", - " ax.text(j, i, f'{perf[i, j]:.1f}', ha='center', va='center', fontsize=9)\n", - " plt.colorbar(im, ax=ax)\n", - "\n", - "plt.suptitle('Accuracy (%) — Split MNIST Comparison', fontsize=14)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Average final accuracy comparison bar chart\n", - "avg_accuracies = {\n", - " method: np.mean(perf[-1, :len(task_names)])\n", - " for method, perf in results.items()\n", - "}\n", - "\n", - "plt.figure(figsize=(8, 5))\n", - "bars = plt.bar(avg_accuracies.keys(), avg_accuracies.values(),\n", - " color=['#4c72b0', '#dd8452', '#55a868', '#c44e52'])\n", - "for bar, val in zip(bars, avg_accuracies.values()):\n", - " plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,\n", - " f'{val:.1f}%', ha='center', va='bottom')\n", - "plt.ylim(0, 110)\n", - "plt.ylabel('Average Accuracy (%)')\n", - "plt.title('Average Final Accuracy — Split MNIST')\n", - "plt.tight_layout()\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..da07a1a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,86 @@ +[build-system] +requires = ["hatchling>=1.24"] +build-backend = "hatchling.build" + +[project] +name = "continual-learning-bench" +version = "0.3.0" +description = "A PyTorch benchmark framework for reproducible continual-learning experiments." +readme = "README.md" +requires-python = ">=3.10" +license = { file = "License" } +authors = [{ name = "Utkarsh Rajput" }] +keywords = ["continual-learning", "pytorch", "benchmark", "machine-learning"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "numpy>=2.0", + "pyyaml>=6.0.3", + "torch>=2.6", + "torchvision>=0.21", +] + +[project.optional-dependencies] +dev = [ + "build>=1.2", + "pytest>=8.3", + "pytest-cov>=6.2", + "ruff>=0.12", +] +experiment = [ + "hydra-core>=1.3.2", + "mlflow>=3.12,<4", + "omegaconf>=2.3", + "pandas>=2.2", +] +report = [ + "matplotlib>=3.10", +] + +[project.scripts] +cl-bench = "cl_bench.cli:main" + +[project.urls] +Homepage = "https://github.com/1Utkarsh1/Continual-Learning" +Repository = "https://github.com/1Utkarsh1/Continual-Learning" + +[tool.hatch.build.targets.wheel] +packages = ["src/cl_bench"] + +[tool.pytest.ini_options] +addopts = "-q" +testpaths = ["tests"] + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.ruff.lint] +select = ["B", "E", "F", "I", "SIM", "UP"] +ignore = ["E501"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "auto" + +[tool.coverage.run] +branch = true +source = ["cl_bench"] + +[tool.coverage.report] +show_missing = true +skip_covered = true +exclude_also = [ + "if __name__ == .__main__.:", + "raise NotImplementedError", +] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index fe0b979..0000000 --- a/requirements.txt +++ /dev/null @@ -1,26 +0,0 @@ -# Core Deep Learning Frameworks -torch -torchvision -numpy -tqdm - -# Data Processing -pandas -scikit-learn -pillow - -# Visualization -matplotlib -seaborn - -# Experiment Tracking -tensorboard -jupyterlab -ipywidgets - -# Utilities -pyyaml -tabulate - -# Testing (Optional) -pytest \ No newline at end of file diff --git a/results/baseline_mnist_split_20250311_000335/final_model.pt b/results/baseline_mnist_split_20250311_000335/final_model.pt deleted file mode 100644 index 4569ece..0000000 Binary files a/results/baseline_mnist_split_20250311_000335/final_model.pt and /dev/null differ diff --git a/results/baseline_mnist_split_20250311_000335/forgetting.png b/results/baseline_mnist_split_20250311_000335/forgetting.png deleted file mode 100644 index 2576af8..0000000 Binary files a/results/baseline_mnist_split_20250311_000335/forgetting.png and /dev/null differ diff --git a/results/baseline_mnist_split_20250311_000335/forgetting_matrix.npy b/results/baseline_mnist_split_20250311_000335/forgetting_matrix.npy deleted file mode 100644 index 8cff948..0000000 Binary files a/results/baseline_mnist_split_20250311_000335/forgetting_matrix.npy and /dev/null differ diff --git a/results/baseline_mnist_split_20250311_000335/performance.png b/results/baseline_mnist_split_20250311_000335/performance.png deleted file mode 100644 index 58de4f3..0000000 Binary files a/results/baseline_mnist_split_20250311_000335/performance.png and /dev/null differ diff --git a/results/baseline_mnist_split_20250311_000335/performance_matrix.npy b/results/baseline_mnist_split_20250311_000335/performance_matrix.npy deleted file mode 100644 index 3977f84..0000000 Binary files a/results/baseline_mnist_split_20250311_000335/performance_matrix.npy and /dev/null differ diff --git a/results/logs/continual_learning_20250310_232605.log b/results/logs/continual_learning_20250310_232605.log deleted file mode 100644 index 003f125..0000000 --- a/results/logs/continual_learning_20250310_232605.log +++ /dev/null @@ -1,40 +0,0 @@ -2025-03-10 23:26:05,445 - __main__ - INFO - Starting Continual Learning experiment -2025-03-10 23:26:05,445 - __main__ - INFO - Arguments: Namespace(method='baseline', tasks='mnist_split', model='simple_cnn', epochs=5, batch_size=64, learning_rate=0.001, lambda_ewc=5000, fisher_sample_size=200, buffer_size=500, replay_batch_size=32, temperature=2.0, alpha=1.0, seed=42, device=None, eval_freq=1, save_dir='results') -2025-03-10 23:26:05,448 - __main__ - INFO - Using device: cpu -2025-03-10 23:26:05,449 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9'] -2025-03-10 23:26:05,449 - src.data.data_loader - INFO - Loading task mnist_0_4 (dataset: mnist, classes: [0, 1, 2, 3, 4]) -2025-03-10 23:26:33,270 - src.data.data_loader - INFO - Loaded task mnist_0_4 with 27537 training, 3059 validation, and 5139 test samples -2025-03-10 23:26:33,271 - src.data.data_loader - INFO - Loading task mnist_5_9 (dataset: mnist, classes: [5, 6, 7, 8, 9]) -2025-03-10 23:26:33,316 - src.data.data_loader - INFO - Loaded task mnist_5_9 with 26464 training, 2940 validation, and 4861 test samples -2025-03-10 23:26:33,325 - src.models.model_factory - INFO - Created simple_cnn model with input shape torch.Size([1, 28, 28]) and 5 output classes -2025-03-10 23:26:33,326 - __main__ - INFO - Starting training on task 1/2: mnist_0_4 -2025-03-10 23:27:31,606 - src.methods.baseline - INFO - Task 1 - Epoch 1/5: Train Loss: 0.1076, Train Acc: 96.27%, Val Loss: 0.0239, Val Acc: 99.15% -2025-03-10 23:28:29,398 - src.methods.baseline - INFO - Task 1 - Epoch 2/5: Train Loss: 0.0236, Train Acc: 99.28%, Val Loss: 0.0255, Val Acc: 99.28% -2025-03-10 23:29:27,271 - src.methods.baseline - INFO - Task 1 - Epoch 3/5: Train Loss: 0.0148, Train Acc: 99.53%, Val Loss: 0.0198, Val Acc: 99.44% -2025-03-10 23:30:24,558 - src.methods.baseline - INFO - Task 1 - Epoch 4/5: Train Loss: 0.0119, Train Acc: 99.63%, Val Loss: 0.0255, Val Acc: 99.18% -2025-03-10 23:31:22,322 - src.methods.baseline - INFO - Task 1 - Epoch 5/5: Train Loss: 0.0083, Train Acc: 99.75%, Val Loss: 0.0211, Val Acc: 99.58% -2025-03-10 23:31:22,322 - src.methods.baseline - INFO - Loaded best model for task 1 with validation loss: 0.0198 -2025-03-10 23:31:38,202 - __main__ - INFO - After task 1, accuracy on task 1: 99.86% -2025-03-10 23:31:38,202 - __main__ - INFO - Starting training on task 2/2: mnist_5_9 -2025-03-10 23:31:51,157 - __main__ - ERROR - Experiment failed with error: Target 8 is out of bounds. -Traceback (most recent call last): - File "C:\Users\kumar\Desktop\github\continual_learning\src\main.py", line 302, in main - run_continual_learning(args, logger) - File "C:\Users\kumar\Desktop\github\continual_learning\src\main.py", line 245, in run_continual_learning - learner.train( - File "C:\Users\kumar\Desktop\github\continual_learning\src\methods\baseline.py", line 92, in train - loss = self.criterion(outputs, targets) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "C:\Users\kumar\AppData\Roaming\Python\Python312\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "C:\Users\kumar\AppData\Roaming\Python\Python312\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "C:\Users\kumar\AppData\Roaming\Python\Python312\site-packages\torch\nn\modules\loss.py", line 1188, in forward - return F.cross_entropy(input, target, weight=self.weight, - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "C:\Users\kumar\AppData\Roaming\Python\Python312\site-packages\torch\nn\functional.py", line 3104, in cross_entropy - return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -IndexError: Target 8 is out of bounds. diff --git a/results/logs/continual_learning_20250311_000335.log b/results/logs/continual_learning_20250311_000335.log deleted file mode 100644 index 6f9f3b0..0000000 --- a/results/logs/continual_learning_20250311_000335.log +++ /dev/null @@ -1,35 +0,0 @@ -2025-03-11 00:03:35,631 - __main__ - INFO - Starting Continual Learning experiment -2025-03-11 00:03:35,631 - __main__ - INFO - Arguments: Namespace(method='baseline', tasks='mnist_split', model='simple_cnn', epochs=5, batch_size=64, learning_rate=0.001, lambda_ewc=5000, fisher_sample_size=200, buffer_size=500, replay_batch_size=32, temperature=2.0, alpha=1.0, seed=42, device=None, eval_freq=1, save_dir='results') -2025-03-11 00:03:35,634 - __main__ - INFO - Using device: cpu -2025-03-11 00:03:35,634 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9'] -2025-03-11 00:03:35,634 - src.data.data_loader - INFO - Loading task mnist_0_4 (dataset: mnist, classes: [0, 1, 2, 3, 4]) -2025-03-11 00:03:35,697 - src.data.data_loader - INFO - Loaded task mnist_0_4 with 27537 training, 3059 validation, and 5139 test samples -2025-03-11 00:03:35,698 - src.data.data_loader - INFO - Loading task mnist_5_9 (dataset: mnist, classes: [5, 6, 7, 8, 9]) -2025-03-11 00:03:35,749 - src.data.data_loader - INFO - Loaded task mnist_5_9 with 26464 training, 2940 validation, and 4861 test samples -2025-03-11 00:03:35,755 - src.models.model_factory - INFO - Created simple_cnn model with input shape torch.Size([1, 28, 28]) and 10 output classes -2025-03-11 00:03:35,756 - __main__ - INFO - Starting training on task 1/2: mnist_0_4 -2025-03-11 00:04:32,911 - src.methods.baseline - INFO - Task 1 - Epoch 1/5: Train Loss: 0.1474, Train Acc: 94.77%, Val Loss: 0.0358, Val Acc: 98.79% -2025-03-11 00:05:29,314 - src.methods.baseline - INFO - Task 1 - Epoch 2/5: Train Loss: 0.0251, Train Acc: 99.26%, Val Loss: 0.0307, Val Acc: 99.02% -2025-03-11 00:06:25,855 - src.methods.baseline - INFO - Task 1 - Epoch 3/5: Train Loss: 0.0152, Train Acc: 99.55%, Val Loss: 0.0241, Val Acc: 99.64% -2025-03-11 00:07:22,073 - src.methods.baseline - INFO - Task 1 - Epoch 4/5: Train Loss: 0.0118, Train Acc: 99.67%, Val Loss: 0.0301, Val Acc: 99.28% -2025-03-11 00:08:18,914 - src.methods.baseline - INFO - Task 1 - Epoch 5/5: Train Loss: 0.0099, Train Acc: 99.71%, Val Loss: 0.0186, Val Acc: 99.51% -2025-03-11 00:08:18,915 - src.methods.baseline - INFO - Loaded best model for task 1 with validation loss: 0.0186 -2025-03-11 00:08:33,903 - __main__ - INFO - After task 1, accuracy on task 1: 99.81% -2025-03-11 00:08:33,904 - __main__ - INFO - Starting training on task 2/2: mnist_5_9 -2025-03-11 00:09:28,940 - src.methods.baseline - INFO - Task 2 - Epoch 1/5: Train Loss: 0.4041, Train Acc: 90.08%, Val Loss: 0.0462, Val Acc: 98.33% -2025-03-11 00:10:24,280 - src.methods.baseline - INFO - Task 2 - Epoch 2/5: Train Loss: 0.0377, Train Acc: 98.87%, Val Loss: 0.0326, Val Acc: 98.88% -2025-03-11 00:11:19,652 - src.methods.baseline - INFO - Task 2 - Epoch 3/5: Train Loss: 0.0294, Train Acc: 99.04%, Val Loss: 0.0263, Val Acc: 99.39% -2025-03-11 00:12:14,864 - src.methods.baseline - INFO - Task 2 - Epoch 4/5: Train Loss: 0.0178, Train Acc: 99.44%, Val Loss: 0.0326, Val Acc: 99.18% -2025-03-11 00:13:10,415 - src.methods.baseline - INFO - Task 2 - Epoch 5/5: Train Loss: 0.0166, Train Acc: 99.49%, Val Loss: 0.0328, Val Acc: 99.15% -2025-03-11 00:13:10,416 - src.methods.baseline - INFO - Loaded best model for task 2 with validation loss: 0.0263 -2025-03-11 00:13:25,385 - __main__ - INFO - After task 2, accuracy on task 1: 0.00% -2025-03-11 00:13:40,495 - __main__ - INFO - After task 2, accuracy on task 2: 99.49% -2025-03-11 00:13:40,507 - src.methods.baseline - INFO - Model saved to C:\Users\kumar\Desktop\github\continual_learning\results\baseline_mnist_split_20250311_000335\final_model.pt -2025-03-11 00:13:41,358 - __main__ - INFO - -Experiment summary: -2025-03-11 00:13:41,358 - __main__ - INFO - Method: baseline -2025-03-11 00:13:41,359 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9'] -2025-03-11 00:13:41,359 - __main__ - INFO - Average final accuracy: 49.74% -2025-03-11 00:13:41,359 - __main__ - INFO - Average forgetting: 49.90% -2025-03-11 00:13:41,359 - __main__ - INFO - Results saved to: C:\Users\kumar\Desktop\github\continual_learning\results\baseline_mnist_split_20250311_000335 -2025-03-11 00:13:41,380 - __main__ - INFO - Experiment completed successfully diff --git a/results/logs/continual_learning_20250311_001410.log b/results/logs/continual_learning_20250311_001410.log deleted file mode 100644 index bb736b1..0000000 --- a/results/logs/continual_learning_20250311_001410.log +++ /dev/null @@ -1,23 +0,0 @@ -2025-03-11 00:14:10,411 - __main__ - INFO - Starting Continual Learning experiment -2025-03-11 00:14:10,411 - __main__ - INFO - Arguments: Namespace(method='lwf', tasks='mnist_split', model='simple_cnn', epochs=5, batch_size=64, learning_rate=0.001, lambda_ewc=5000, fisher_sample_size=200, buffer_size=500, replay_batch_size=32, temperature=2.0, alpha=1.0, seed=42, device=None, eval_freq=1, save_dir='results') -2025-03-11 00:14:10,414 - __main__ - INFO - Using device: cpu -2025-03-11 00:14:10,415 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9'] -2025-03-11 00:14:10,415 - src.data.data_loader - INFO - Loading task mnist_0_4 (dataset: mnist, classes: [0, 1, 2, 3, 4]) -2025-03-11 00:14:10,483 - src.data.data_loader - INFO - Loaded task mnist_0_4 with 27537 training, 3059 validation, and 5139 test samples -2025-03-11 00:14:10,484 - src.data.data_loader - INFO - Loading task mnist_5_9 (dataset: mnist, classes: [5, 6, 7, 8, 9]) -2025-03-11 00:14:10,535 - src.data.data_loader - INFO - Loaded task mnist_5_9 with 26464 training, 2940 validation, and 4861 test samples -2025-03-11 00:14:10,540 - src.models.model_factory - INFO - Created simple_cnn model with input shape torch.Size([1, 28, 28]) and 10 output classes -2025-03-11 00:14:10,541 - __main__ - INFO - Starting training on task 1/2: mnist_0_4 -2025-03-11 00:17:51,424 - __main__ - INFO - After task 1, accuracy on task 1: 99.67% -2025-03-11 00:17:51,424 - __main__ - INFO - Starting training on task 2/2: mnist_5_9 -2025-03-11 00:22:30,045 - __main__ - INFO - After task 2, accuracy on task 1: 0.00% -2025-03-11 00:22:45,006 - __main__ - INFO - After task 2, accuracy on task 2: 99.34% -2025-03-11 00:22:45,013 - src.methods.baseline - INFO - Model saved to C:\Users\kumar\Desktop\github\continual_learning\results\lwf_mnist_split_20250311_001410\final_model.pt -2025-03-11 00:22:45,578 - __main__ - INFO - -Experiment summary: -2025-03-11 00:22:45,578 - __main__ - INFO - Method: lwf -2025-03-11 00:22:45,578 - __main__ - INFO - Task sequence: ['mnist_0_4', 'mnist_5_9'] -2025-03-11 00:22:45,579 - __main__ - INFO - Average final accuracy: 49.67% -2025-03-11 00:22:45,579 - __main__ - INFO - Average forgetting: 49.83% -2025-03-11 00:22:45,579 - __main__ - INFO - Results saved to: C:\Users\kumar\Desktop\github\continual_learning\results\lwf_mnist_split_20250311_001410 -2025-03-11 00:22:45,594 - __main__ - INFO - Experiment completed successfully diff --git a/results/lwf_mnist_split_20250311_001410/final_model.pt b/results/lwf_mnist_split_20250311_001410/final_model.pt deleted file mode 100644 index c10c8d1..0000000 Binary files a/results/lwf_mnist_split_20250311_001410/final_model.pt and /dev/null differ diff --git a/results/lwf_mnist_split_20250311_001410/forgetting.png b/results/lwf_mnist_split_20250311_001410/forgetting.png deleted file mode 100644 index 33065df..0000000 Binary files a/results/lwf_mnist_split_20250311_001410/forgetting.png and /dev/null differ diff --git a/results/lwf_mnist_split_20250311_001410/forgetting_matrix.npy b/results/lwf_mnist_split_20250311_001410/forgetting_matrix.npy deleted file mode 100644 index e90967a..0000000 Binary files a/results/lwf_mnist_split_20250311_001410/forgetting_matrix.npy and /dev/null differ diff --git a/results/lwf_mnist_split_20250311_001410/performance.png b/results/lwf_mnist_split_20250311_001410/performance.png deleted file mode 100644 index 54e0635..0000000 Binary files a/results/lwf_mnist_split_20250311_001410/performance.png and /dev/null differ diff --git a/results/lwf_mnist_split_20250311_001410/performance_matrix.npy b/results/lwf_mnist_split_20250311_001410/performance_matrix.npy deleted file mode 100644 index 7b431a9..0000000 Binary files a/results/lwf_mnist_split_20250311_001410/performance_matrix.npy and /dev/null differ diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/cl_bench/__init__.py b/src/cl_bench/__init__.py new file mode 100644 index 0000000..3cba1fc --- /dev/null +++ b/src/cl_bench/__init__.py @@ -0,0 +1,6 @@ +"""Continual-learning benchmark framework.""" + +from cl_bench.config import BenchmarkResult, ExperimentConfig, TaskSpec + +__all__ = ["BenchmarkResult", "ExperimentConfig", "TaskSpec"] +__version__ = "0.3.0" diff --git a/src/cl_bench/__main__.py b/src/cl_bench/__main__.py new file mode 100644 index 0000000..1e145c0 --- /dev/null +++ b/src/cl_bench/__main__.py @@ -0,0 +1,4 @@ +from cl_bench.cli import main + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/cl_bench/cli.py b/src/cl_bench/cli.py new file mode 100644 index 0000000..b4da137 --- /dev/null +++ b/src/cl_bench/cli.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import argparse +from collections.abc import Sequence +from dataclasses import replace +from pathlib import Path + +from cl_bench.config import ExperimentConfig, load_config_with_overrides +from cl_bench.experiments import run_experiment +from cl_bench.reporting import collect_runs, write_report +from cl_bench.tracking import MLflowRunLogger + +METHODS = ("baseline", "ewc", "replay", "lwf", "derpp", "agem") + + +def main(argv: Sequence[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + + if args.command == "list-configs": + for path in sorted((Path.cwd() / "configs").glob("*.yaml")): + print(path) + return 0 + + if args.command == "run": + config = apply_cli_overrides(load_cli_config(args), args) + result = run_experiment(config) + print(f"Run directory: {result.run_dir}") + print(f"Metrics: {result.metrics_path}") + print(f"Average final accuracy: {result.summary['average_final_accuracy']:.2f}%") + print(f"Average forgetting: {result.summary['average_forgetting']:.2f}%") + return 0 + + if args.command == "suite": + base_config = apply_cli_overrides(load_cli_config(args), args) + seeds = args.seeds or [base_config.seed] + run_dirs: list[Path] = [] + + for seed in seeds: + for method in args.methods: + config = replace(base_config, method=method, seed=seed) + result = run_experiment(config) + run_dirs.append(result.run_dir) + print( + f"{method} seed={seed}: " + f"{result.summary['average_final_accuracy']:.2f}% final accuracy, " + f"{result.summary['average_forgetting']:.2f}% forgetting " + f"({result.run_dir})" + ) + + if args.report_dir: + records = collect_runs(run_dirs) + report = write_report( + records, + output_dir=args.report_dir, + title=args.title or f"{base_config.name} benchmark report", + make_plots=not args.no_plots, + ) + log_suite_report_to_mlflow(base_config, report.report_dir, len(records)) + print(f"Report directory: {report.report_dir}") + print(f"Leaderboard: {report.leaderboard_csv}") + return 0 + + if args.command == "report": + records = collect_runs(args.runs) + report = write_report( + records, + output_dir=args.output_dir, + title=args.title, + make_plots=not args.no_plots, + ) + print(f"Report directory: {report.report_dir}") + print(f"Leaderboard: {report.leaderboard_csv}") + if report.plots: + print("Plots:") + for plot in report.plots: + print(f" {plot}") + return 0 + + parser.print_help() + return 1 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="cl-bench", + description="Run reproducible continual-learning benchmarks.", + ) + subparsers = parser.add_subparsers(dest="command") + + run_parser = subparsers.add_parser("run", help="Run a benchmark from a YAML config.") + add_config_arguments(run_parser) + run_parser.add_argument("--method", choices=METHODS) + add_runtime_overrides(run_parser) + run_parser.add_argument("overrides", nargs="*", help="Hydra/OmegaConf key=value overrides.") + + suite_parser = subparsers.add_parser( + "suite", + help="Run multiple methods/seeds from one config and optionally build a report.", + ) + add_config_arguments(suite_parser) + suite_parser.add_argument("--methods", nargs="+", choices=METHODS, default=list(METHODS)) + suite_parser.add_argument("--seeds", nargs="+", type=int) + suite_parser.add_argument("--report-dir") + suite_parser.add_argument("--title") + suite_parser.add_argument("--no-plots", action="store_true") + add_runtime_overrides(suite_parser) + suite_parser.add_argument("overrides", nargs="*", help="Hydra/OmegaConf key=value overrides.") + + report_parser = subparsers.add_parser( + "report", + help="Aggregate existing run directories or metrics.json files into a report.", + ) + report_parser.add_argument("--runs", nargs="+", required=True) + report_parser.add_argument("--output-dir", required=True) + report_parser.add_argument("--title", default="Continual-learning benchmark report") + report_parser.add_argument("--no-plots", action="store_true") + + subparsers.add_parser("list-configs", help="List YAML configs in ./configs.") + return parser + + +def add_config_arguments(parser: argparse.ArgumentParser) -> None: + config_group = parser.add_mutually_exclusive_group(required=True) + config_group.add_argument("--config", help="Config path or config name.") + config_group.add_argument("--config-name", help="Hydra-style config name from ./configs.") + + +def add_runtime_overrides(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--model", choices=["linear", "mlp", "small_cnn", "cnn"]) + parser.add_argument("--epochs", type=int) + parser.add_argument("--seed", type=int) + parser.add_argument("--device") + parser.add_argument("--output-dir") + parser.add_argument("--data-dir") + parser.add_argument("--batch-size", type=int) + parser.add_argument("--eval-batch-size", type=int) + parser.add_argument("--learning-rate", type=float) + parser.add_argument("--tracking", choices=["json", "mlflow", "both"]) + parser.add_argument("--save-checkpoint", action="store_true") + + +def apply_cli_overrides(config: ExperimentConfig, args: argparse.Namespace) -> ExperimentConfig: + updates = {} + for arg_name, field_name in [ + ("method", "method"), + ("model", "model"), + ("epochs", "epochs"), + ("seed", "seed"), + ("device", "device"), + ("output_dir", "output_dir"), + ("data_dir", "data_dir"), + ("batch_size", "batch_size"), + ("eval_batch_size", "eval_batch_size"), + ("learning_rate", "learning_rate"), + ("tracking", "tracking"), + ]: + value = getattr(args, arg_name, None) + if value is not None: + updates[field_name] = value + + if getattr(args, "save_checkpoint", False): + updates["save_checkpoint"] = True + + return replace(config, **updates) + + +def load_cli_config(args: argparse.Namespace) -> ExperimentConfig: + source = args.config_name or args.config + overrides = getattr(args, "overrides", None) or [] + return load_config_with_overrides(source, overrides) + + +def log_suite_report_to_mlflow(config: ExperimentConfig, report_dir: Path, run_count: int) -> None: + if config.tracking.lower() not in {"mlflow", "both"}: + return + with MLflowRunLogger( + tracking_uri=config.mlflow_tracking_uri, + experiment_name=config.mlflow_experiment, + run_name=f"{config.name}_suite_report", + enabled=True, + ) as logger: + logger.log_params( + { + "benchmark": config.name, + "report_dir": str(report_dir), + "run_count": run_count, + "artifact_type": "suite_report", + } + ) + logger.log_artifacts(report_dir) diff --git a/src/cl_bench/config.py b/src/cl_bench/config.py new file mode 100644 index 0000000..6dc5d3a --- /dev/null +++ b/src/cl_bench/config.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import yaml + + +@dataclass +class TaskSpec: + """Declarative description of one task in a continual-learning benchmark.""" + + name: str + dataset: str + classes: list[int] | str + samples_per_class: int | None = None + test_samples_per_class: int | None = None + train_limit: int | None = None + test_limit: int | None = None + + @classmethod + def from_dict(cls, raw: dict[str, Any]) -> TaskSpec: + classes = raw["classes"] + if classes != "all": + classes = [int(label) for label in classes] + return cls( + name=str(raw["name"]), + dataset=str(raw["dataset"]), + classes=classes, + samples_per_class=_optional_int(raw.get("samples_per_class")), + test_samples_per_class=_optional_int(raw.get("test_samples_per_class")), + train_limit=_optional_int(raw.get("train_limit")), + test_limit=_optional_int(raw.get("test_limit")), + ) + + +@dataclass +class ExperimentConfig: + """Runtime configuration for a benchmark run.""" + + name: str + method: str + tasks: list[TaskSpec] + seed: int = 42 + device: str = "auto" + model: str = "mlp" + data_dir: str = "data" + output_dir: str = "runs" + tracking: str = "json" + mlflow_tracking_uri: str = "sqlite:///mlruns/mlflow.db" + mlflow_experiment: str = "continual-learning-bench" + epochs: int = 1 + batch_size: int = 64 + eval_batch_size: int = 256 + learning_rate: float = 1e-3 + val_fraction: float = 0.1 + num_workers: int = 0 + augment: bool = True + ewc_lambda: float = 50.0 + fisher_samples: int = 128 + replay_buffer_size: int = 512 + replay_batch_size: int = 32 + replay_loss_weight: float = 1.0 + lwf_alpha: float = 0.5 + lwf_temperature: float = 2.0 + derpp_alpha: float = 0.5 + derpp_beta: float = 1.0 + agem_memory_batch_size: int = 64 + save_checkpoint: bool = False + + @classmethod + def from_dict(cls, raw: dict[str, Any]) -> ExperimentConfig: + training = raw.get("training", {}) + strategy = raw.get("strategy", {}) + tracking = raw.get("tracking", {}) + if not isinstance(tracking, dict): + tracking = {"mode": tracking} + tasks = [TaskSpec.from_dict(task) for task in raw["tasks"]] + return cls( + name=str(raw.get("name", "continual_learning")), + method=str(raw.get("method", "baseline")), + tasks=tasks, + seed=int(raw.get("seed", 42)), + device=str(raw.get("device", "auto")), + model=str(raw.get("model", "mlp")), + data_dir=str(raw.get("data_dir", "data")), + output_dir=str(raw.get("output_dir", "runs")), + tracking=str(tracking.get("mode", raw.get("tracking", "json"))), + mlflow_tracking_uri=str( + tracking.get( + "mlflow_tracking_uri", + raw.get("mlflow_tracking_uri", "sqlite:///mlruns/mlflow.db"), + ) + ), + mlflow_experiment=str( + tracking.get( + "mlflow_experiment", + raw.get("mlflow_experiment", "continual-learning-bench"), + ) + ), + epochs=int(training.get("epochs", raw.get("epochs", 1))), + batch_size=int(training.get("batch_size", raw.get("batch_size", 64))), + eval_batch_size=int(training.get("eval_batch_size", raw.get("eval_batch_size", 256))), + learning_rate=float(training.get("learning_rate", raw.get("learning_rate", 1e-3))), + val_fraction=float(training.get("val_fraction", raw.get("val_fraction", 0.1))), + num_workers=int(training.get("num_workers", raw.get("num_workers", 0))), + augment=bool(training.get("augment", raw.get("augment", True))), + ewc_lambda=float(strategy.get("ewc_lambda", raw.get("ewc_lambda", 50.0))), + fisher_samples=int(strategy.get("fisher_samples", raw.get("fisher_samples", 128))), + replay_buffer_size=int( + strategy.get("replay_buffer_size", raw.get("replay_buffer_size", 512)) + ), + replay_batch_size=int( + strategy.get("replay_batch_size", raw.get("replay_batch_size", 32)) + ), + replay_loss_weight=float( + strategy.get("replay_loss_weight", raw.get("replay_loss_weight", 1.0)) + ), + lwf_alpha=float(strategy.get("lwf_alpha", raw.get("lwf_alpha", 0.5))), + lwf_temperature=float(strategy.get("lwf_temperature", raw.get("lwf_temperature", 2.0))), + derpp_alpha=float(strategy.get("derpp_alpha", raw.get("derpp_alpha", 0.5))), + derpp_beta=float(strategy.get("derpp_beta", raw.get("derpp_beta", 1.0))), + agem_memory_batch_size=int( + strategy.get("agem_memory_batch_size", raw.get("agem_memory_batch_size", 64)) + ), + save_checkpoint=bool(raw.get("save_checkpoint", False)), + ) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +@dataclass +class BenchmarkResult: + """Summary returned by an experiment run.""" + + run_dir: Path + method: str + task_names: list[str] + accuracy_matrix: list[list[float | None]] + forgetting_matrix: list[list[float | None]] + summary: dict[str, float | int | str | None] + metrics_path: Path + config_path: Path + runtime_seconds: float + git_commit: str | None = None + + +def load_config(source: str | Path) -> ExperimentConfig: + """Load an experiment config from a YAML path or a known config name.""" + + path = resolve_config_path(source) + with path.open("r", encoding="utf-8") as handle: + raw = yaml.safe_load(handle) + if not isinstance(raw, dict): + raise ValueError(f"Config {path} must contain a YAML mapping.") + return ExperimentConfig.from_dict(raw) + + +def load_config_with_overrides( + source: str | Path, + overrides: list[str] | None = None, +) -> ExperimentConfig: + """Load config and apply Hydra/OmegaConf dot-list overrides.""" + + overrides = overrides or [] + if not overrides: + return load_config(source) + + try: + from omegaconf import OmegaConf + except ImportError as exc: + raise RuntimeError( + "Hydra-style overrides require the experiment extra: " + 'python -m pip install -e ".[experiment]"' + ) from exc + + path = resolve_config_path(source) + base = OmegaConf.load(path) + override_conf = OmegaConf.from_dotlist(overrides) + merged = OmegaConf.merge(base, override_conf) + raw = OmegaConf.to_container(merged, resolve=True) + if not isinstance(raw, dict): + raise ValueError(f"Config {path} must contain a YAML mapping.") + return ExperimentConfig.from_dict(raw) + + +def resolve_config_path(source: str | Path) -> Path: + candidate = Path(source) + if candidate.exists(): + return candidate + + name = str(source) + if not name.endswith((".yaml", ".yml")): + name = f"{name}.yaml" + + search_roots = [ + Path.cwd() / "configs", + Path.cwd() / "configs" / "experiments", + Path(__file__).resolve().parents[2] / "configs", + Path(__file__).resolve().parents[2] / "configs" / "experiments", + ] + for root in search_roots: + path = root / name + if path.exists(): + return path + + searched = ", ".join(str(root / name) for root in search_roots) + raise FileNotFoundError(f"Could not find config '{source}'. Searched: {searched}") + + +def dump_config(config: ExperimentConfig, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + yaml.safe_dump(config.to_dict(), handle, sort_keys=False) + + +def _optional_int(value: Any) -> int | None: + if value is None: + return None + return int(value) diff --git a/src/cl_bench/datasets.py b/src/cl_bench/datasets.py new file mode 100644 index 0000000..585d1b9 --- /dev/null +++ b/src/cl_bench/datasets.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path + +import torch +from torch.utils.data import DataLoader, Dataset, Subset, random_split +from torchvision import datasets as tv_datasets +from torchvision import transforms + +from cl_bench.config import ExperimentConfig, TaskSpec + + +@dataclass +class TaskLoaders: + name: str + dataset: str + classes: list[int] + train_loader: DataLoader + val_loader: DataLoader + test_loader: DataLoader + + +class SyntheticImageDataset(Dataset): + """Small deterministic image classification dataset for CI and smoke runs.""" + + def __init__( + self, + classes: Sequence[int], + samples_per_class: int, + image_shape: tuple[int, int, int] = (1, 8, 8), + noise_std: float = 0.12, + seed: int = 0, + ): + self.classes = [int(label) for label in classes] + self.data: list[torch.Tensor] = [] + self.targets: list[int] = [] + noise_generator = torch.Generator().manual_seed(seed) + + for class_id in self.classes: + prototype_generator = torch.Generator().manual_seed(10_003 + class_id * 997) + prototype = torch.randn(image_shape, generator=prototype_generator) * 0.6 + prototype = prototype + float(class_id) * 0.15 + for _ in range(samples_per_class): + noise = torch.randn(image_shape, generator=noise_generator) * noise_std + self.data.append((prototype + noise).float()) + self.targets.append(class_id) + + def __len__(self) -> int: + return len(self.targets) + + def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: + return self.data[index], self.targets[index] + + +def build_task_loaders(config: ExperimentConfig) -> tuple[list[TaskLoaders], tuple[int, ...], int]: + task_loaders: list[TaskLoaders] = [] + all_classes: set[int] = set() + + for task_id, task in enumerate(config.tasks): + train_dataset, test_dataset, classes = _build_datasets(task, config, task_id) + all_classes.update(classes) + + train_subset, val_subset = _split_train_validation( + train_dataset, + val_fraction=config.val_fraction, + seed=config.seed + task_id, + ) + train_loader = DataLoader( + train_subset, + batch_size=config.batch_size, + shuffle=True, + generator=torch.Generator().manual_seed(config.seed + task_id), + num_workers=config.num_workers, + ) + val_loader = DataLoader( + val_subset, + batch_size=config.eval_batch_size, + shuffle=False, + num_workers=config.num_workers, + ) + test_loader = DataLoader( + test_dataset, + batch_size=config.eval_batch_size, + shuffle=False, + num_workers=config.num_workers, + ) + task_loaders.append( + TaskLoaders( + name=task.name, + dataset=task.dataset, + classes=classes, + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + ) + ) + + if not task_loaders: + raise ValueError("At least one task must be configured.") + + first_inputs, _ = task_loaders[0].train_loader.dataset[0] + input_shape = tuple(first_inputs.shape) + num_classes = max(all_classes) + 1 + return task_loaders, input_shape, num_classes + + +def _build_datasets( + task: TaskSpec, config: ExperimentConfig, task_id: int +) -> tuple[Dataset, Dataset, list[int]]: + dataset_name = task.dataset.lower() + if dataset_name == "synthetic": + if task.classes == "all": + raise ValueError("Synthetic tasks must list explicit class ids.") + train_samples = task.samples_per_class or 32 + test_samples = task.test_samples_per_class or max(8, train_samples // 2) + train_dataset = SyntheticImageDataset( + task.classes, + samples_per_class=train_samples, + seed=config.seed + task_id * 100, + ) + test_dataset = SyntheticImageDataset( + task.classes, + samples_per_class=test_samples, + seed=config.seed + task_id * 100 + 50_000, + ) + return train_dataset, test_dataset, [int(label) for label in task.classes] + + train_dataset = _torchvision_dataset( + dataset_name, + Path(config.data_dir), + train=True, + augment=config.augment, + ) + test_dataset = _torchvision_dataset( + dataset_name, + Path(config.data_dir), + train=False, + augment=False, + ) + classes = _resolve_classes(task, train_dataset) + train_subset = _class_subset(train_dataset, classes, task.train_limit, config.seed + task_id) + test_subset = _class_subset( + test_dataset, classes, task.test_limit, config.seed + task_id + 10_000 + ) + return train_subset, test_subset, classes + + +def _torchvision_dataset(dataset_name: str, data_dir: Path, train: bool, augment: bool) -> Dataset: + transform = _torchvision_transform(dataset_name, train=train, augment=augment) + if dataset_name == "mnist": + return tv_datasets.MNIST(data_dir, train=train, download=True, transform=transform) + if dataset_name == "fashion_mnist": + return tv_datasets.FashionMNIST(data_dir, train=train, download=True, transform=transform) + if dataset_name == "kmnist": + return tv_datasets.KMNIST(data_dir, train=train, download=True, transform=transform) + if dataset_name == "cifar10": + return tv_datasets.CIFAR10(data_dir, train=train, download=True, transform=transform) + raise ValueError(f"Unsupported dataset: {dataset_name}") + + +def _torchvision_transform(dataset_name: str, train: bool, augment: bool) -> transforms.Compose: + if dataset_name == "cifar10": + steps: list[object] = [] + if train and augment: + steps.extend([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]) + steps.extend( + [ + transforms.ToTensor(), + transforms.Normalize( + mean=(0.4914, 0.4822, 0.4465), + std=(0.2470, 0.2435, 0.2616), + ), + ] + ) + return transforms.Compose(steps) + + return transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + + +def _resolve_classes(task: TaskSpec, dataset: Dataset) -> list[int]: + if task.classes == "all": + if not hasattr(dataset, "classes"): + raise ValueError(f"Dataset {task.dataset} does not expose class metadata.") + return list(range(len(dataset.classes))) + return [int(label) for label in task.classes] + + +def _class_subset(dataset: Dataset, classes: list[int], limit: int | None, seed: int) -> Subset: + targets = getattr(dataset, "targets", None) + if targets is None: + targets = [dataset[index][1] for index in range(len(dataset))] + targets_tensor = torch.as_tensor(targets) + class_tensor = torch.tensor(classes, dtype=targets_tensor.dtype) + mask = torch.isin(targets_tensor, class_tensor) + indices = torch.nonzero(mask, as_tuple=False).flatten().tolist() + + if limit is not None and limit < len(indices): + generator = torch.Generator().manual_seed(seed) + selected = torch.randperm(len(indices), generator=generator)[:limit].tolist() + indices = [indices[index] for index in selected] + + return Subset(dataset, indices) + + +def _split_train_validation( + dataset: Dataset, val_fraction: float, seed: int +) -> tuple[Subset, Subset]: + if not 0.0 <= val_fraction < 1.0: + raise ValueError("val_fraction must be in [0.0, 1.0).") + + total = len(dataset) + if total < 2 or val_fraction == 0.0: + return Subset(dataset, list(range(total))), Subset(dataset, list(range(total))) + + val_size = max(1, int(total * val_fraction)) + train_size = total - val_size + if train_size <= 0: + train_size, val_size = total - 1, 1 + + train_subset, val_subset = random_split( + dataset, + [train_size, val_size], + generator=torch.Generator().manual_seed(seed), + ) + return train_subset, val_subset diff --git a/src/cl_bench/experiments.py b/src/cl_bench/experiments.py new file mode 100644 index 0000000..b400053 --- /dev/null +++ b/src/cl_bench/experiments.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import random +import time +from pathlib import Path + +import numpy as np +import torch + +from cl_bench.config import BenchmarkResult, ExperimentConfig, dump_config +from cl_bench.datasets import build_task_loaders +from cl_bench.metrics import compute_forgetting, matrix_to_jsonable, summarize_accuracy +from cl_bench.models import get_model +from cl_bench.strategies import create_strategy +from cl_bench.tracking import ExperimentTracker, MLflowRunLogger, create_run_dir, git_commit + + +def run_experiment(config: ExperimentConfig, repo_dir: str | Path | None = None) -> BenchmarkResult: + set_seed(config.seed) + device = resolve_device(config.device) + tasks, input_shape, num_classes = build_task_loaders(config) + model = get_model(config.model, input_shape=input_shape, num_classes=num_classes).to(device) + strategy = create_strategy(config, model, device) + + run_dir = create_run_dir(config.output_dir, config.name, config.method) + tracker = ExperimentTracker(run_dir) + config_path = run_dir / "config.yaml" + dump_config(config, config_path) + + commit = git_commit(repo_dir or Path.cwd()) + metadata = { + "benchmark": config.name, + "method": config.method, + "seed": config.seed, + "device": str(device), + "git_commit": commit, + } + tracker.write_json("run_metadata.json", metadata) + mlflow_enabled = config.tracking.lower() in {"mlflow", "both"} + + with MLflowRunLogger( + tracking_uri=config.mlflow_tracking_uri, + experiment_name=config.mlflow_experiment, + run_name=f"{config.name}_{config.method}_seed_{config.seed}", + enabled=mlflow_enabled, + ) as mlflow_logger: + mlflow_logger.log_params(config.to_dict()) + mlflow_logger.set_tags(metadata) + mlflow_logger.log_environment() + + start_time = time.perf_counter() + accuracy_matrix = np.full((len(tasks), len(tasks)), np.nan, dtype=float) + + for task_id, task in enumerate(tasks): + tracker.log_event( + { + "event": "task_started", + "task_id": task_id, + "task_name": task.name, + "classes": task.classes, + } + ) + history = strategy.train_task( + task.train_loader, + task.val_loader, + task_id=task_id, + epochs=config.epochs, + ) + for epoch_metrics in history: + tracker.log_event({"event": "epoch_finished", **epoch_metrics}) + mlflow_logger.log_metrics( + { + f"task_{task_id}_{key}": value + for key, value in epoch_metrics.items() + if key not in {"task_id", "epoch"} + }, + step=task_id * config.epochs + int(epoch_metrics["epoch"]), + ) + + for eval_task_id in range(task_id + 1): + eval_task = tasks[eval_task_id] + metrics = strategy.evaluate(eval_task.test_loader) + accuracy_matrix[task_id, eval_task_id] = metrics["accuracy"] + tracker.log_event( + { + "event": "evaluation", + "after_task_id": task_id, + "eval_task_id": eval_task_id, + "eval_task_name": eval_task.name, + **metrics, + } + ) + mlflow_logger.log_metrics( + { + f"eval_after_{task_id}_task_{eval_task_id}_accuracy": metrics["accuracy"], + f"eval_after_{task_id}_task_{eval_task_id}_loss": metrics["loss"], + }, + step=task_id, + ) + + runtime_seconds = time.perf_counter() - start_time + forgetting_matrix = compute_forgetting(accuracy_matrix) + summary = summarize_accuracy(accuracy_matrix) + summary.update( + { + "runtime_seconds": runtime_seconds, + "seed": config.seed, + "num_tasks": len(tasks), + "model": config.model, + "replay_buffer_size": config.replay_buffer_size, + "replay_batch_size": config.replay_batch_size, + } + ) + + tracker.write_json("accuracy_matrix.json", matrix_to_jsonable(accuracy_matrix)) + tracker.write_json("forgetting_matrix.json", matrix_to_jsonable(forgetting_matrix)) + tracker.write_matrix_csv("accuracy_matrix.csv", accuracy_matrix) + tracker.write_matrix_csv("forgetting_matrix.csv", forgetting_matrix) + + if config.save_checkpoint: + strategy.save_checkpoint(run_dir / "checkpoints" / "final_model.pt") + + metrics_path = tracker.write_json( + "metrics.json", + { + "benchmark": config.name, + "method": config.method, + "task_names": [task.name for task in tasks], + "summary": summary, + "accuracy_matrix": matrix_to_jsonable(accuracy_matrix), + "forgetting_matrix": matrix_to_jsonable(forgetting_matrix), + "runtime_seconds": runtime_seconds, + "seed": config.seed, + "git_commit": commit, + }, + ) + mlflow_logger.log_metrics(summary) + mlflow_logger.log_artifacts(run_dir) + + return BenchmarkResult( + run_dir=run_dir, + method=config.method, + task_names=[task.name for task in tasks], + accuracy_matrix=matrix_to_jsonable(accuracy_matrix), + forgetting_matrix=matrix_to_jsonable(forgetting_matrix), + summary=summary, + metrics_path=metrics_path, + config_path=config_path, + runtime_seconds=runtime_seconds, + git_commit=commit, + ) + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def resolve_device(device_name: str) -> torch.device: + requested = device_name.lower() + if requested == "auto": + if torch.cuda.is_available(): + return torch.device("cuda") + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cpu") + device = torch.device(requested) + if device.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("CUDA was requested but is not available.") + if device.type == "mps" and not ( + hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + ): + raise RuntimeError("MPS was requested but is not available.") + return device diff --git a/src/cl_bench/metrics.py b/src/cl_bench/metrics.py new file mode 100644 index 0000000..db8c860 --- /dev/null +++ b/src/cl_bench/metrics.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from collections.abc import Iterable + +import numpy as np +import torch + + +def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> float: + """Return classification accuracy as a percentage.""" + + if targets.numel() == 0: + return 0.0 + predictions = logits.argmax(dim=1) + return float((predictions == targets).float().mean().item() * 100.0) + + +def compute_forgetting(accuracy_matrix: np.ndarray) -> np.ndarray: + """Compute best-so-far forgetting for every evaluated task.""" + + matrix = np.asarray(accuracy_matrix, dtype=float) + forgetting = np.full(matrix.shape, np.nan, dtype=float) + + for train_step in range(matrix.shape[0]): + for task_id in range(matrix.shape[1]): + current = matrix[train_step, task_id] + if np.isnan(current) or task_id > train_step: + continue + if train_step == task_id: + forgetting[train_step, task_id] = 0.0 + continue + + previous = matrix[:train_step, task_id] + if np.isnan(previous).all(): + forgetting[train_step, task_id] = 0.0 + continue + best_previous = float(np.nanmax(previous)) + forgetting[train_step, task_id] = max(0.0, best_previous - current) + + return forgetting + + +def summarize_accuracy(accuracy_matrix: np.ndarray) -> dict[str, float]: + """Summarize continual-learning accuracy and forgetting metrics.""" + + matrix = np.asarray(accuracy_matrix, dtype=float) + final_row = matrix[-1] + final_seen = final_row[~np.isnan(final_row)] + average_final_accuracy = float(np.mean(final_seen)) if final_seen.size else 0.0 + + diagonal = np.diag(matrix) + learned = diagonal[~np.isnan(diagonal)] + average_learning_accuracy = float(np.mean(learned)) if learned.size else 0.0 + + forgetting = compute_forgetting(matrix) + if matrix.shape[0] > 1: + final_forgetting = forgetting[-1, : matrix.shape[0] - 1] + final_forgetting = final_forgetting[~np.isnan(final_forgetting)] + average_forgetting = float(np.mean(final_forgetting)) if final_forgetting.size else 0.0 + else: + average_forgetting = 0.0 + + if matrix.shape[0] > 1: + previous_task_ids = range(matrix.shape[0] - 1) + transfers = [ + matrix[-1, task_id] - matrix[task_id, task_id] + for task_id in previous_task_ids + if not np.isnan(matrix[-1, task_id]) and not np.isnan(matrix[task_id, task_id]) + ] + backward_transfer = float(np.mean(transfers)) if transfers else 0.0 + else: + backward_transfer = 0.0 + + return { + "average_final_accuracy": average_final_accuracy, + "average_learning_accuracy": average_learning_accuracy, + "average_forgetting": average_forgetting, + "backward_transfer": backward_transfer, + } + + +def matrix_to_jsonable(matrix: np.ndarray) -> list[list[float | None]]: + """Convert a NumPy matrix to JSON-safe nested lists.""" + + result: list[list[float | None]] = [] + for row in np.asarray(matrix, dtype=float): + result.append([None if np.isnan(value) else float(value) for value in row]) + return result + + +def mean_or_zero(values: Iterable[float]) -> float: + values = list(values) + return float(np.mean(values)) if values else 0.0 diff --git a/src/cl_bench/models.py b/src/cl_bench/models.py new file mode 100644 index 0000000..91c8696 --- /dev/null +++ b/src/cl_bench/models.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from functools import reduce +from operator import mul + +import torch +from torch import nn + + +class LinearClassifier(nn.Module): + def __init__(self, input_shape: tuple[int, ...], num_classes: int): + super().__init__() + self.net = nn.Sequential(nn.Flatten(), nn.Linear(_num_features(input_shape), num_classes)) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.net(inputs) + + +class MLP(nn.Module): + def __init__(self, input_shape: tuple[int, ...], num_classes: int, hidden_dim: int = 128): + super().__init__() + input_dim = _num_features(input_shape) + self.net = nn.Sequential( + nn.Flatten(), + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, num_classes), + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.net(inputs) + + +class SmallCNN(nn.Module): + def __init__(self, input_shape: tuple[int, ...], num_classes: int): + super().__init__() + if len(input_shape) != 3: + raise ValueError(f"SmallCNN expects CHW input, got {input_shape}.") + channels, height, width = input_shape + self.features = nn.Sequential( + nn.Conv2d(channels, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + ) + pooled_height = max(1, height // 4) + pooled_width = max(1, width // 4) + self.classifier = nn.Sequential( + nn.Flatten(), + nn.Linear(64 * pooled_height * pooled_width, 128), + nn.ReLU(), + nn.Linear(128, num_classes), + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.classifier(self.features(inputs)) + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, stride: int = 1): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False + ), + _norm(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + _norm(out_channels), + ) + if stride != 1 or in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), + _norm(out_channels), + ) + else: + self.shortcut = nn.Identity() + self.activation = nn.ReLU(inplace=True) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.activation(self.net(inputs) + self.shortcut(inputs)) + + +class CifarConvNet(nn.Module): + """Compact residual CNN tuned for 32x32 RGB continual-learning benchmarks.""" + + def __init__(self, input_shape: tuple[int, ...], num_classes: int): + super().__init__() + if len(input_shape) != 3: + raise ValueError(f"CifarConvNet expects CHW input, got {input_shape}.") + channels, _, _ = input_shape + self.features = nn.Sequential( + nn.Conv2d(channels, 64, kernel_size=3, padding=1, bias=False), + _norm(64), + nn.ReLU(inplace=True), + ResidualBlock(64, 64), + ResidualBlock(64, 128, stride=2), + ResidualBlock(128, 128), + ResidualBlock(128, 256, stride=2), + ResidualBlock(256, 256), + nn.AdaptiveAvgPool2d((1, 1)), + ) + self.classifier = nn.Linear(256, num_classes) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + features = self.features(inputs) + return self.classifier(torch.flatten(features, 1)) + + +def get_model(model_name: str, input_shape: tuple[int, ...], num_classes: int) -> nn.Module: + name = model_name.lower().replace("-", "_") + if name == "linear": + return LinearClassifier(input_shape, num_classes) + if name == "mlp": + return MLP(input_shape, num_classes) + if name in {"small_cnn", "cnn"}: + return SmallCNN(input_shape, num_classes) + if name in {"cifar_convnet", "resnet18_cifar"}: + return CifarConvNet(input_shape, num_classes) + raise ValueError(f"Unknown model architecture: {model_name}") + + +def _num_features(input_shape: tuple[int, ...]) -> int: + return int(reduce(mul, input_shape, 1)) + + +def _norm(channels: int) -> nn.BatchNorm2d: + return nn.BatchNorm2d(channels) diff --git a/src/cl_bench/reporting.py b/src/cl_bench/reporting.py new file mode 100644 index 0000000..3aa8cf0 --- /dev/null +++ b/src/cl_bench/reporting.py @@ -0,0 +1,487 @@ +from __future__ import annotations + +import csv +import json +from collections import defaultdict +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import numpy as np + + +@dataclass(frozen=True) +class RunRecord: + """Parsed metrics for one completed benchmark run.""" + + run_dir: Path + benchmark: str + method: str + seed: int + task_names: list[str] + summary: dict[str, float | int | str | None] + accuracy_matrix: np.ndarray + forgetting_matrix: np.ndarray + runtime_seconds: float + git_commit: str | None + + +@dataclass(frozen=True) +class ReportArtifacts: + """Files written by the reporting pipeline.""" + + report_dir: Path + leaderboard_csv: Path + summary_json: Path + markdown: Path + plots: list[Path] + + +def collect_runs(sources: Sequence[str | Path]) -> list[RunRecord]: + """Load metrics from run directories, metrics files, or parent directories.""" + + metric_paths = discover_metrics(sources) + if not metric_paths: + raise FileNotFoundError("No metrics.json files were found in the supplied run paths.") + return [load_run(path) for path in metric_paths] + + +def discover_metrics(sources: Sequence[str | Path]) -> list[Path]: + """Return sorted metrics.json paths discovered from files or directories.""" + + discovered: set[Path] = set() + for source in sources: + path = Path(source) + if path.is_file(): + if path.name != "metrics.json": + raise ValueError(f"Expected metrics.json file, got: {path}") + discovered.add(path) + continue + + if not path.exists(): + raise FileNotFoundError(path) + + direct_metrics = path / "metrics.json" + if direct_metrics.exists(): + discovered.add(direct_metrics) + continue + + for metrics_path in path.rglob("metrics.json"): + discovered.add(metrics_path) + + return sorted(discovered) + + +def load_run(metrics_path: str | Path) -> RunRecord: + """Parse one run metrics artifact.""" + + path = Path(metrics_path) + if path.is_dir(): + path = path / "metrics.json" + with path.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + + summary = dict(payload["summary"]) + runtime_seconds = float(payload.get("runtime_seconds", summary.get("runtime_seconds", 0.0))) + seed = int(payload.get("seed", summary.get("seed", 0))) + return RunRecord( + run_dir=path.parent, + benchmark=str(payload["benchmark"]), + method=str(payload["method"]), + seed=seed, + task_names=[str(name) for name in payload["task_names"]], + summary=summary, + accuracy_matrix=_matrix_from_json(payload["accuracy_matrix"]), + forgetting_matrix=_matrix_from_json(payload["forgetting_matrix"]), + runtime_seconds=runtime_seconds, + git_commit=payload.get("git_commit"), + ) + + +def aggregate_records(records: Sequence[RunRecord]) -> list[dict[str, float | int | str]]: + """Aggregate run summaries by method for leaderboard-style reporting.""" + + by_method: dict[str, list[RunRecord]] = defaultdict(list) + for record in records: + by_method[record.method].append(record) + + rows: list[dict[str, float | int | str]] = [] + for method, method_records in sorted(by_method.items()): + final_accuracy = [_metric(record, "average_final_accuracy") for record in method_records] + learning_accuracy = [ + _metric(record, "average_learning_accuracy") for record in method_records + ] + forgetting = [_metric(record, "average_forgetting") for record in method_records] + backward_transfer = [_metric(record, "backward_transfer") for record in method_records] + runtimes = [record.runtime_seconds for record in method_records] + seeds = ",".join( + str(record.seed) for record in sorted(method_records, key=lambda item: item.seed) + ) + + rows.append( + { + "method": method, + "runs": len(method_records), + "seeds": seeds, + "average_final_accuracy_mean": _mean(final_accuracy), + "average_final_accuracy_std": _std(final_accuracy), + "average_learning_accuracy_mean": _mean(learning_accuracy), + "average_forgetting_mean": _mean(forgetting), + "average_forgetting_std": _std(forgetting), + "backward_transfer_mean": _mean(backward_transfer), + "runtime_seconds_mean": _mean(runtimes), + } + ) + + return sorted(rows, key=lambda row: float(row["average_final_accuracy_mean"]), reverse=True) + + +def write_report( + records: Sequence[RunRecord], + output_dir: str | Path, + title: str, + make_plots: bool = True, +) -> ReportArtifacts: + """Write CSV, JSON, Markdown, and optional plot artifacts for a benchmark suite.""" + + if not records: + raise ValueError("At least one run record is required to write a report.") + + report_dir = Path(output_dir) + report_dir.mkdir(parents=True, exist_ok=True) + leaderboard = aggregate_records(records) + leaderboard_csv = _write_leaderboard_csv(report_dir / "leaderboard.csv", leaderboard) + summary_json = _write_summary_json(report_dir / "summary.json", title, records, leaderboard) + plots: list[Path] = [] + if make_plots: + plots = _write_plots(records, leaderboard, report_dir, title) + markdown = _write_markdown(report_dir / "README.md", title, records, leaderboard, plots) + return ReportArtifacts( + report_dir=report_dir, + leaderboard_csv=leaderboard_csv, + summary_json=summary_json, + markdown=markdown, + plots=plots, + ) + + +def _write_leaderboard_csv(path: Path, rows: Sequence[dict[str, float | int | str]]) -> Path: + fieldnames = [ + "method", + "runs", + "seeds", + "average_final_accuracy_mean", + "average_final_accuracy_std", + "average_learning_accuracy_mean", + "average_forgetting_mean", + "average_forgetting_std", + "backward_transfer_mean", + "runtime_seconds_mean", + ] + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(row) + return path + + +def _write_summary_json( + path: Path, + title: str, + records: Sequence[RunRecord], + leaderboard: Sequence[dict[str, float | int | str]], +) -> Path: + payload = { + "title": title, + "generated_at_utc": datetime.now(timezone.utc).isoformat(), + "benchmarks": sorted({record.benchmark for record in records}), + "num_runs": len(records), + "leaderboard": list(leaderboard), + "runs": [ + { + "run_dir": str(record.run_dir), + "benchmark": record.benchmark, + "method": record.method, + "seed": record.seed, + "task_names": record.task_names, + "summary": record.summary, + "runtime_seconds": record.runtime_seconds, + "git_commit": record.git_commit, + } + for record in records + ], + } + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + return path + + +def _write_markdown( + path: Path, + title: str, + records: Sequence[RunRecord], + leaderboard: Sequence[dict[str, float | int | str]], + plots: Sequence[Path], +) -> Path: + benchmark_names = ", ".join(sorted({record.benchmark for record in records})) + task_count = max(len(record.task_names) for record in records) + lines = [ + f"# {title}", + "", + f"Benchmarks: `{benchmark_names}`", + f"Runs: `{len(records)}`", + f"Tasks per run: `{task_count}`", + "", + "## Leaderboard", + "", + "| Method | Runs | Seeds | Final accuracy | Forgetting | Backward transfer | Mean runtime |", + "| --- | ---: | --- | ---: | ---: | ---: | ---: |", + ] + for row in leaderboard: + lines.append( + "| {method} | {runs} | {seeds} | {accuracy} | {forgetting} | {bwt} | {runtime} |".format( + method=row["method"], + runs=row["runs"], + seeds=row["seeds"], + accuracy=_format_with_std( + float(row["average_final_accuracy_mean"]), + float(row["average_final_accuracy_std"]), + suffix="%", + ), + forgetting=_format_with_std( + float(row["average_forgetting_mean"]), + float(row["average_forgetting_std"]), + suffix="%", + ), + bwt=f"{float(row['backward_transfer_mean']):.2f}%", + runtime=f"{float(row['runtime_seconds_mean']):.1f}s", + ) + ) + + if plots: + lines.extend(["", "## Plots", ""]) + for plot in plots: + lines.append(f"![{plot.stem}]({plot.name})") + lines.append("") + + lines.extend( + [ + "## Source Runs", + "", + "| Method | Seed | Run directory |", + "| --- | ---: | --- |", + ] + ) + for record in sorted(records, key=lambda item: (item.method, item.seed, str(item.run_dir))): + lines.append(f"| {record.method} | {record.seed} | `{record.run_dir}` |") + + path.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8") + return path + + +def _write_plots( + records: Sequence[RunRecord], + leaderboard: Sequence[dict[str, float | int | str]], + report_dir: Path, + title: str, +) -> list[Path]: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + paths = [ + _plot_leaderboard(plt, leaderboard, report_dir / "leaderboard.png", title), + _plot_retention_curves(plt, records, report_dir / "retention_curves.png"), + _plot_accuracy_matrices(plt, records, report_dir / "accuracy_matrices.png"), + ] + plt.close("all") + return paths + + +def _plot_leaderboard( + plt: Any, leaderboard: Sequence[dict[str, float | int | str]], path: Path, title: str +) -> Path: + methods = [str(row["method"]) for row in leaderboard] + final_means = [float(row["average_final_accuracy_mean"]) for row in leaderboard] + final_stds = [float(row["average_final_accuracy_std"]) for row in leaderboard] + forgetting_means = [float(row["average_forgetting_mean"]) for row in leaderboard] + forgetting_stds = [float(row["average_forgetting_std"]) for row in leaderboard] + colors = [_method_color(method) for method in methods] + + fig, axes = plt.subplots(1, 2, figsize=(12, 5.8), constrained_layout=True) + fig.suptitle(title, fontsize=16, fontweight="bold") + _bar_chart( + axes[0], + methods, + final_means, + final_stds, + colors, + title="Average final accuracy", + ylabel="Accuracy (%)", + higher_is_better=True, + ) + _bar_chart( + axes[1], + methods, + forgetting_means, + forgetting_stds, + colors, + title="Average forgetting", + ylabel="Forgetting (%)", + higher_is_better=False, + ) + fig.savefig(path, dpi=180, bbox_inches="tight") + return path + + +def _bar_chart( + axis: Any, + labels: Sequence[str], + means: Sequence[float], + stds: Sequence[float], + colors: Sequence[str], + title: str, + ylabel: str, + higher_is_better: bool, +) -> None: + positions = np.arange(len(labels)) + axis.bar( + positions, means, yerr=stds, color=colors, capsize=4, edgecolor="#111827", linewidth=0.8 + ) + axis.set_title(title, fontsize=12, fontweight="bold") + axis.set_ylabel(ylabel) + axis.set_xticks(positions, labels) + axis.set_ylim(0, max(100.0 if higher_is_better else 5.0, max(means, default=0.0) * 1.2 + 2.0)) + axis.grid(axis="y", alpha=0.22) + for position, value in zip(positions, means, strict=True): + axis.text(position, value + 1.0, f"{value:.1f}", ha="center", va="bottom", fontsize=9) + + +def _plot_retention_curves(plt: Any, records: Sequence[RunRecord], path: Path) -> Path: + fig, axis = plt.subplots(figsize=(9.5, 5.5), constrained_layout=True) + for method, method_records in _records_by_method(records).items(): + curves = [] + for record in method_records: + curve = [] + for step in range(record.accuracy_matrix.shape[0]): + seen = record.accuracy_matrix[step, : step + 1] + curve.append(float(np.nanmean(seen))) + curves.append(curve) + matrix = np.asarray(curves, dtype=float) + x_values = np.arange(1, matrix.shape[1] + 1) + mean_curve = np.nanmean(matrix, axis=0) + std_curve = np.nanstd(matrix, axis=0) + color = _method_color(method) + axis.plot(x_values, mean_curve, marker="o", linewidth=2.2, label=method, color=color) + if matrix.shape[0] > 1: + axis.fill_between( + x_values, mean_curve - std_curve, mean_curve + std_curve, color=color, alpha=0.16 + ) + + axis.set_title("Retention across the task stream", fontsize=13, fontweight="bold") + axis.set_xlabel("After training task") + axis.set_ylabel("Mean accuracy on seen tasks (%)") + axis.set_ylim(0, 100) + axis.set_xticks(np.arange(1, max(record.accuracy_matrix.shape[0] for record in records) + 1)) + axis.grid(alpha=0.25) + axis.legend(frameon=False) + fig.savefig(path, dpi=180, bbox_inches="tight") + return path + + +def _plot_accuracy_matrices(plt: Any, records: Sequence[RunRecord], path: Path) -> Path: + representative_records = [ + max(method_records, key=lambda record: _metric(record, "average_final_accuracy")) + for method_records in _records_by_method(records).values() + ] + columns = min(2, len(representative_records)) + rows = int(np.ceil(len(representative_records) / columns)) + fig, axes = plt.subplots( + rows, columns, figsize=(6.4 * columns, 5.4 * rows), squeeze=False, constrained_layout=True + ) + + for axis in axes.ravel()[len(representative_records) :]: + axis.axis("off") + + image = None + for axis, record in zip(axes.ravel(), representative_records, strict=False): + matrix = np.ma.masked_invalid(record.accuracy_matrix) + image = axis.imshow(matrix, vmin=0, vmax=100, cmap="viridis") + axis.set_title(f"{record.method} accuracy matrix", fontsize=12, fontweight="bold") + axis.set_xlabel("Evaluated task") + axis.set_ylabel("After training task") + axis.set_xticks(range(len(record.task_names))) + axis.set_yticks(range(len(record.task_names))) + axis.set_xticklabels( + [_short_task_name(name) for name in record.task_names], rotation=35, ha="right" + ) + axis.set_yticklabels([str(index + 1) for index in range(len(record.task_names))]) + for row in range(record.accuracy_matrix.shape[0]): + for column in range(record.accuracy_matrix.shape[1]): + value = record.accuracy_matrix[row, column] + if np.isnan(value): + continue + axis.text( + column, row, f"{value:.0f}", ha="center", va="center", color="white", fontsize=8 + ) + + if image is not None: + fig.colorbar(image, ax=axes.ravel().tolist(), shrink=0.78, label="Accuracy (%)") + fig.savefig(path, dpi=180, bbox_inches="tight") + return path + + +def _records_by_method(records: Sequence[RunRecord]) -> dict[str, list[RunRecord]]: + by_method: dict[str, list[RunRecord]] = defaultdict(list) + for record in sorted(records, key=lambda item: (item.method, item.seed)): + by_method[record.method].append(record) + return dict(by_method) + + +def _matrix_from_json(values: Sequence[Sequence[float | None]]) -> np.ndarray: + return np.array( + [[np.nan if value is None else float(value) for value in row] for row in values] + ) + + +def _metric(record: RunRecord, key: str) -> float: + value = record.summary.get(key, 0.0) + return 0.0 if value is None else float(value) + + +def _mean(values: Iterable[float]) -> float: + values = list(values) + return float(np.mean(values)) if values else 0.0 + + +def _std(values: Iterable[float]) -> float: + values = list(values) + return float(np.std(values, ddof=0)) if len(values) > 1 else 0.0 + + +def _format_with_std(mean: float, std: float, suffix: str) -> str: + if std == 0.0: + return f"{mean:.2f}{suffix}" + return f"{mean:.2f} +- {std:.2f}{suffix}" + + +def _method_color(method: str) -> str: + return { + "baseline": "#475569", + "ewc": "#2563eb", + "replay": "#16a34a", + "lwf": "#c2410c", + "derpp": "#7c3aed", + "agem": "#0f766e", + }.get(method, "#7c3aed") + + +def _short_task_name(name: str) -> str: + if "_" not in name: + return name + return name.split("_", 1)[1] diff --git a/src/cl_bench/strategies/__init__.py b/src/cl_bench/strategies/__init__.py new file mode 100644 index 0000000..e2d90ea --- /dev/null +++ b/src/cl_bench/strategies/__init__.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import torch +from torch import nn + +from cl_bench.config import ExperimentConfig +from cl_bench.strategies.agem import AGEMStrategy, build_agem +from cl_bench.strategies.baseline import BaselineStrategy, build_baseline +from cl_bench.strategies.derpp import DERPPStrategy, build_derpp +from cl_bench.strategies.ewc import EWCStrategy, build_ewc +from cl_bench.strategies.lwf import LwFStrategy, build_lwf +from cl_bench.strategies.replay import ReplayStrategy, build_replay + +Strategy = ( + BaselineStrategy | EWCStrategy | ReplayStrategy | LwFStrategy | DERPPStrategy | AGEMStrategy +) + + +def create_strategy(config: ExperimentConfig, model: nn.Module, device: torch.device) -> Strategy: + method = config.method.lower().replace("-", "_") + if method == "baseline": + return build_baseline(model, device, config.learning_rate) + if method == "ewc": + return build_ewc( + model, + device, + learning_rate=config.learning_rate, + ewc_lambda=config.ewc_lambda, + fisher_samples=config.fisher_samples, + ) + if method == "replay": + return build_replay( + model, + device, + learning_rate=config.learning_rate, + buffer_size=config.replay_buffer_size, + replay_batch_size=config.replay_batch_size, + replay_loss_weight=config.replay_loss_weight, + seed=config.seed, + ) + if method == "derpp": + return build_derpp( + model, + device, + learning_rate=config.learning_rate, + buffer_size=config.replay_buffer_size, + replay_batch_size=config.replay_batch_size, + alpha=config.derpp_alpha, + beta=config.derpp_beta, + seed=config.seed, + ) + if method == "agem": + return build_agem( + model, + device, + learning_rate=config.learning_rate, + buffer_size=config.replay_buffer_size, + memory_batch_size=config.agem_memory_batch_size, + seed=config.seed, + ) + if method == "lwf": + return build_lwf( + model, + device, + learning_rate=config.learning_rate, + alpha=config.lwf_alpha, + temperature=config.lwf_temperature, + ) + raise ValueError(f"Unknown continual-learning method: {config.method}") + + +__all__ = [ + "AGEMStrategy", + "BaselineStrategy", + "DERPPStrategy", + "EWCStrategy", + "LwFStrategy", + "ReplayStrategy", + "Strategy", + "create_strategy", +] diff --git a/src/cl_bench/strategies/agem.py b/src/cl_bench/strategies/agem.py new file mode 100644 index 0000000..3fb8621 --- /dev/null +++ b/src/cl_bench/strategies/agem.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import torch +from torch import nn + +from cl_bench.strategies.base import ContinualLearningStrategy +from cl_bench.strategies.replay import ReservoirReplayBuffer + + +class AGEMStrategy(ContinualLearningStrategy): + """A-GEM gradient projection against replay-memory reference gradients.""" + + def __init__( + self, + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + memory_batch_size: int, + seed: int, + ): + super().__init__(model=model, device=device, learning_rate=learning_rate) + self.buffer = ReservoirReplayBuffer(capacity=buffer_size, seed=seed) + self.memory_batch_size = memory_batch_size + self.last_gradient_dot = 0.0 + self.last_projection_applied = 0.0 + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + del task_id + logits = self.model(inputs) + loss = self.criterion(logits, targets) + return ( + loss, + logits, + { + "ce_loss": float(loss.detach().item()), + "agem_gradient_dot": self.last_gradient_dot, + "agem_projection_applied": self.last_projection_applied, + }, + ) + + def after_backward( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + task_id: int, + ) -> None: + del inputs, targets, task_id + self.last_gradient_dot = 0.0 + self.last_projection_applied = 0.0 + if len(self.buffer) == 0 or self.memory_batch_size <= 0: + return + + current_grad = _flatten_gradients(self.model) + replay_inputs, replay_targets = self.buffer.sample(self.memory_batch_size) + replay_inputs = replay_inputs.to(self.device) + replay_targets = replay_targets.to(self.device) + + self.optimizer.zero_grad(set_to_none=True) + reference_loss = self.criterion(self.model(replay_inputs), replay_targets) + reference_loss.backward() + reference_grad = _flatten_gradients(self.model) + + dot_product = torch.dot(current_grad, reference_grad) + self.last_gradient_dot = float(dot_product.detach().cpu().item()) + if dot_product < 0: + reference_norm = torch.dot(reference_grad, reference_grad).clamp_min(1e-12) + current_grad = current_grad - (dot_product / reference_norm) * reference_grad + self.last_projection_applied = 1.0 + + _assign_flattened_gradients(self.model, current_grad) + + def observe_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor, + task_id: int, + ) -> None: + del logits, task_id + self.buffer.add_batch(inputs, targets) + + def extra_state_dict(self) -> dict[str, object]: + return { + "buffer_inputs": [sample.inputs for sample in self.buffer.samples], + "buffer_targets": [sample.target for sample in self.buffer.samples], + "seen_count": self.buffer.seen_count, + "memory_batch_size": self.memory_batch_size, + } + + +def _flatten_gradients(model: nn.Module) -> torch.Tensor: + pieces = [] + for parameter in model.parameters(): + if not parameter.requires_grad: + continue + if parameter.grad is None: + pieces.append(torch.zeros_like(parameter).reshape(-1)) + else: + pieces.append(parameter.grad.detach().clone().reshape(-1)) + if not pieces: + return torch.zeros((), device=next(model.parameters()).device) + return torch.cat(pieces) + + +def _assign_flattened_gradients(model: nn.Module, flat_gradient: torch.Tensor) -> None: + offset = 0 + for parameter in model.parameters(): + if not parameter.requires_grad: + continue + count = parameter.numel() + gradient = flat_gradient[offset : offset + count].view_as(parameter).clone() + parameter.grad = gradient + offset += count + + +def build_agem( + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + memory_batch_size: int, + seed: int, +) -> AGEMStrategy: + return AGEMStrategy( + model=model, + device=device, + learning_rate=learning_rate, + buffer_size=buffer_size, + memory_batch_size=memory_batch_size, + seed=seed, + ) diff --git a/src/cl_bench/strategies/base.py b/src/cl_bench/strategies/base.py new file mode 100644 index 0000000..85632b8 --- /dev/null +++ b/src/cl_bench/strategies/base.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import torch +from torch import nn +from torch.utils.data import DataLoader + + +class ContinualLearningStrategy(ABC): + """Stable lifecycle interface shared by all continual-learning methods.""" + + def __init__(self, model: nn.Module, device: torch.device, learning_rate: float): + self.model = model + self.device = device + self.learning_rate = learning_rate + self.criterion = nn.CrossEntropyLoss() + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate) + self.current_task = -1 + self.seen_tasks = 0 + + def train_task( + self, + train_loader: DataLoader, + val_loader: DataLoader, + task_id: int, + epochs: int, + ) -> list[dict[str, float | int]]: + self.current_task = task_id + self.seen_tasks = max(self.seen_tasks, task_id + 1) + self.before_task(task_id) + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) + + best_state: dict[str, torch.Tensor] | None = None + best_val_loss = float("inf") + history: list[dict[str, float | int]] = [] + + for epoch in range(epochs): + train_metrics = self._train_epoch(train_loader, task_id) + val_metrics = self.evaluate(val_loader) + epoch_metrics = { + "task_id": task_id, + "epoch": epoch + 1, + **{f"train_{key}": value for key, value in train_metrics.items()}, + **{f"val_{key}": value for key, value in val_metrics.items()}, + } + history.append(epoch_metrics) + + if val_metrics["loss"] < best_val_loss: + best_val_loss = float(val_metrics["loss"]) + best_state = clone_state_dict(self.model) + + if best_state is not None: + load_state_dict(self.model, best_state, self.device) + + self.after_task(train_loader, task_id) + return history + + def evaluate(self, data_loader: DataLoader) -> dict[str, float]: + self.model.eval() + total_loss = 0.0 + total_correct = 0 + total_examples = 0 + + with torch.no_grad(): + for inputs, targets in data_loader: + inputs = inputs.to(self.device) + targets = targets.to(self.device) + logits = self.model(inputs) + loss = self.criterion(logits, targets) + total_loss += float(loss.item()) * inputs.size(0) + total_correct += int((logits.argmax(dim=1) == targets).sum().item()) + total_examples += int(targets.numel()) + + if total_examples == 0: + return {"loss": 0.0, "accuracy": 0.0, "examples": 0.0} + return { + "loss": total_loss / total_examples, + "accuracy": 100.0 * total_correct / total_examples, + "examples": float(total_examples), + } + + def save_checkpoint(self, path: str | Path) -> None: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "current_task": self.current_task, + "seen_tasks": self.seen_tasks, + "strategy_state": self.extra_state_dict(), + }, + path, + ) + + def load_checkpoint(self, path: str | Path) -> None: + checkpoint = torch.load(path, map_location=self.device) + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + self.current_task = int(checkpoint["current_task"]) + self.seen_tasks = int(checkpoint["seen_tasks"]) + self.load_extra_state_dict(checkpoint.get("strategy_state", {})) + + def before_task(self, task_id: int) -> None: + del task_id + + def after_task(self, train_loader: DataLoader, task_id: int) -> None: + del train_loader, task_id + + def extra_state_dict(self) -> dict[str, Any]: + return {} + + def load_extra_state_dict(self, state: dict[str, Any]) -> None: + del state + + def _train_epoch(self, train_loader: DataLoader, task_id: int) -> dict[str, float]: + self.model.train() + totals: dict[str, float] = {"loss": 0.0, "ce_loss": 0.0} + total_correct = 0 + total_examples = 0 + + for inputs, targets in train_loader: + inputs = inputs.to(self.device) + targets = targets.to(self.device) + + self.optimizer.zero_grad(set_to_none=True) + loss, logits, components = self.compute_loss(inputs, targets, task_id) + loss.backward() + self.after_backward(inputs, targets, task_id) + self.optimizer.step() + self.observe_batch(inputs, targets, logits, task_id) + + batch_size = int(targets.numel()) + total_examples += batch_size + total_correct += int((logits.argmax(dim=1) == targets).sum().item()) + totals["loss"] += float(loss.item()) * batch_size + for name, value in components.items(): + totals[name] = totals.get(name, 0.0) + float(value) * batch_size + + if total_examples == 0: + return {"loss": 0.0, "accuracy": 0.0} + + metrics = {name: value / total_examples for name, value in totals.items()} + metrics["accuracy"] = 100.0 * total_correct / total_examples + return metrics + + @abstractmethod + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + raise NotImplementedError + + def after_backward( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + task_id: int, + ) -> None: + del inputs, targets, task_id + + def observe_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor, + task_id: int, + ) -> None: + del inputs, targets, logits, task_id + + +def clone_state_dict(module: nn.Module) -> dict[str, torch.Tensor]: + """Deep-copy a module state dict onto CPU to avoid mutable checkpoint aliases.""" + + return {name: tensor.detach().cpu().clone() for name, tensor in module.state_dict().items()} + + +def load_state_dict( + module: nn.Module, state_dict: dict[str, torch.Tensor], device: torch.device +) -> None: + module.load_state_dict({name: tensor.to(device) for name, tensor in state_dict.items()}) diff --git a/src/cl_bench/strategies/baseline.py b/src/cl_bench/strategies/baseline.py new file mode 100644 index 0000000..05d7b9c --- /dev/null +++ b/src/cl_bench/strategies/baseline.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import torch +from torch import nn + +from cl_bench.strategies.base import ContinualLearningStrategy + + +class BaselineStrategy(ContinualLearningStrategy): + """Naive sequential fine-tuning baseline.""" + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + del task_id + logits = self.model(inputs) + loss = self.criterion(logits, targets) + return loss, logits, {"ce_loss": float(loss.detach().item())} + + +def build_baseline( + model: nn.Module, device: torch.device, learning_rate: float +) -> BaselineStrategy: + return BaselineStrategy(model=model, device=device, learning_rate=learning_rate) diff --git a/src/cl_bench/strategies/derpp.py b/src/cl_bench/strategies/derpp.py new file mode 100644 index 0000000..2adc1a1 --- /dev/null +++ b/src/cl_bench/strategies/derpp.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import nn + +from cl_bench.strategies.base import ContinualLearningStrategy +from cl_bench.strategies.replay import ReservoirReplayBuffer + + +class DERPPStrategy(ContinualLearningStrategy): + """Dark Experience Replay++ with stored logits and replay labels.""" + + def __init__( + self, + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + replay_batch_size: int, + alpha: float, + beta: float, + seed: int, + ): + super().__init__(model=model, device=device, learning_rate=learning_rate) + self.buffer = ReservoirReplayBuffer(capacity=buffer_size, seed=seed) + self.replay_batch_size = replay_batch_size + self.alpha = alpha + self.beta = beta + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + del task_id + logits = self.model(inputs) + ce_loss = self.criterion(logits, targets) + distillation_loss = torch.zeros((), device=self.device) + replay_ce_loss = torch.zeros((), device=self.device) + + if len(self.buffer) > 0 and self.replay_batch_size > 0: + replay_inputs, replay_targets, replay_logits = self.buffer.sample_tensors( + self.replay_batch_size, + require_logits=True, + ) + replay_inputs = replay_inputs.to(self.device) + replay_targets = replay_targets.to(self.device) + replay_logits = replay_logits.to(self.device) + current_replay_logits = self.model(replay_inputs) + distillation_loss = F.mse_loss(current_replay_logits, replay_logits) + replay_ce_loss = self.criterion(current_replay_logits, replay_targets) + + loss = ce_loss + self.alpha * distillation_loss + self.beta * replay_ce_loss + return ( + loss, + logits, + { + "ce_loss": float(ce_loss.detach().item()), + "derpp_distillation_loss": float(distillation_loss.detach().item()), + "derpp_replay_ce_loss": float(replay_ce_loss.detach().item()), + }, + ) + + def observe_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor, + task_id: int, + ) -> None: + del task_id + self.buffer.add_batch(inputs, targets, logits=logits) + + def extra_state_dict(self) -> dict[str, object]: + return { + "buffer_inputs": [sample.inputs for sample in self.buffer.samples], + "buffer_targets": [sample.target for sample in self.buffer.samples], + "buffer_logits": [sample.logits for sample in self.buffer.samples], + "seen_count": self.buffer.seen_count, + "alpha": self.alpha, + "beta": self.beta, + } + + +def build_derpp( + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + replay_batch_size: int, + alpha: float, + beta: float, + seed: int, +) -> DERPPStrategy: + return DERPPStrategy( + model=model, + device=device, + learning_rate=learning_rate, + buffer_size=buffer_size, + replay_batch_size=replay_batch_size, + alpha=alpha, + beta=beta, + seed=seed, + ) diff --git a/src/cl_bench/strategies/ewc.py b/src/cl_bench/strategies/ewc.py new file mode 100644 index 0000000..e860ccf --- /dev/null +++ b/src/cl_bench/strategies/ewc.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader + +from cl_bench.strategies.base import ContinualLearningStrategy, clone_state_dict + + +class EWCStrategy(ContinualLearningStrategy): + """Elastic Weight Consolidation with deterministic empirical Fisher estimates.""" + + def __init__( + self, + model: nn.Module, + device: torch.device, + learning_rate: float, + ewc_lambda: float, + fisher_samples: int, + ): + super().__init__(model=model, device=device, learning_rate=learning_rate) + self.ewc_lambda = ewc_lambda + self.fisher_samples = fisher_samples + self.fishers: list[dict[str, torch.Tensor]] = [] + self.optimal_parameters: list[dict[str, torch.Tensor]] = [] + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + del task_id + logits = self.model(inputs) + ce_loss = self.criterion(logits, targets) + ewc_penalty = self._ewc_penalty() + loss = ce_loss + self.ewc_lambda * ewc_penalty + return ( + loss, + logits, + { + "ce_loss": float(ce_loss.detach().item()), + "ewc_penalty": float(ewc_penalty.detach().item()), + }, + ) + + def after_task(self, train_loader: DataLoader, task_id: int) -> None: + del task_id + self.fishers.append(self._estimate_fisher(train_loader)) + self.optimal_parameters.append(clone_state_dict(self.model)) + + def extra_state_dict(self) -> dict[str, object]: + return { + "ewc_lambda": self.ewc_lambda, + "fisher_samples": self.fisher_samples, + "fishers": self.fishers, + "optimal_parameters": self.optimal_parameters, + } + + def load_extra_state_dict(self, state: dict[str, object]) -> None: + self.ewc_lambda = float(state.get("ewc_lambda", self.ewc_lambda)) + self.fisher_samples = int(state.get("fisher_samples", self.fisher_samples)) + self.fishers = [ + {name: tensor.to(self.device) for name, tensor in fisher.items()} + for fisher in state.get("fishers", []) + ] + self.optimal_parameters = [ + {name: tensor.to(self.device) for name, tensor in params.items()} + for params in state.get("optimal_parameters", []) + ] + + def _ewc_penalty(self) -> torch.Tensor: + penalty = torch.zeros((), device=self.device) + if not self.fishers: + return penalty + + named_parameters = dict(self.model.named_parameters()) + for fisher, optimum in zip(self.fishers, self.optimal_parameters, strict=True): + for name, parameter in named_parameters.items(): + if not parameter.requires_grad: + continue + fisher_tensor = fisher[name].to(self.device) + optimum_tensor = optimum[name].to(self.device) + penalty = penalty + 0.5 * torch.sum( + fisher_tensor * torch.square(parameter - optimum_tensor) + ) + return penalty + + def _estimate_fisher(self, data_loader: DataLoader) -> dict[str, torch.Tensor]: + self.model.eval() + fisher_loader = DataLoader( + data_loader.dataset, + batch_size=data_loader.batch_size, + shuffle=False, + num_workers=0, + ) + fisher = { + name: torch.zeros_like(parameter, device=self.device) + for name, parameter in self.model.named_parameters() + if parameter.requires_grad + } + sample_count = 0 + max_samples = self.fisher_samples if self.fisher_samples > 0 else float("inf") + + for inputs, targets in fisher_loader: + inputs = inputs.to(self.device) + targets = targets.to(self.device) + for index in range(inputs.size(0)): + if sample_count >= max_samples: + break + self.model.zero_grad(set_to_none=True) + logits = self.model(inputs[index : index + 1]) + log_probability = F.log_softmax(logits, dim=1)[0, targets[index]] + log_probability.backward() + for name, parameter in self.model.named_parameters(): + if parameter.grad is not None and parameter.requires_grad: + fisher[name] += torch.square(parameter.grad.detach()) + sample_count += 1 + if sample_count >= max_samples: + break + + self.model.zero_grad(set_to_none=True) + if sample_count == 0: + return {name: value.detach().cpu() for name, value in fisher.items()} + + return {name: (value / sample_count).detach().cpu() for name, value in fisher.items()} + + +def build_ewc( + model: nn.Module, + device: torch.device, + learning_rate: float, + ewc_lambda: float, + fisher_samples: int, +) -> EWCStrategy: + return EWCStrategy( + model=model, + device=device, + learning_rate=learning_rate, + ewc_lambda=ewc_lambda, + fisher_samples=fisher_samples, + ) diff --git a/src/cl_bench/strategies/lwf.py b/src/cl_bench/strategies/lwf.py new file mode 100644 index 0000000..da15fa7 --- /dev/null +++ b/src/cl_bench/strategies/lwf.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import copy +from typing import Any + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader + +from cl_bench.strategies.base import ContinualLearningStrategy + + +class LwFStrategy(ContinualLearningStrategy): + """Learning without Forgetting using a frozen teacher from the previous task.""" + + def __init__( + self, + model: nn.Module, + device: torch.device, + learning_rate: float, + alpha: float, + temperature: float, + ): + super().__init__(model=model, device=device, learning_rate=learning_rate) + self.alpha = alpha + self.temperature = temperature + self.teacher_model: nn.Module | None = None + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + del task_id + logits = self.model(inputs) + ce_loss = self.criterion(logits, targets) + distillation_loss = torch.zeros((), device=self.device) + + if self.teacher_model is not None: + with torch.no_grad(): + teacher_logits = self.teacher_model(inputs) + temperature = self.temperature + student_log_probs = F.log_softmax(logits / temperature, dim=1) + teacher_probs = F.softmax(teacher_logits / temperature, dim=1) + distillation_loss = ( + F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") + * temperature + * temperature + ) + + loss = ce_loss + self.alpha * distillation_loss + return ( + loss, + logits, + { + "ce_loss": float(ce_loss.detach().item()), + "distillation_loss": float(distillation_loss.detach().item()), + }, + ) + + def after_task(self, train_loader: DataLoader, task_id: int) -> None: + del train_loader, task_id + self.teacher_model = copy.deepcopy(self.model).to(self.device) + self.teacher_model.eval() + for parameter in self.teacher_model.parameters(): + parameter.requires_grad_(False) + + def extra_state_dict(self) -> dict[str, Any]: + return { + "alpha": self.alpha, + "temperature": self.temperature, + "teacher_state_dict": None + if self.teacher_model is None + else self.teacher_model.state_dict(), + } + + def load_extra_state_dict(self, state: dict[str, Any]) -> None: + self.alpha = float(state.get("alpha", self.alpha)) + self.temperature = float(state.get("temperature", self.temperature)) + teacher_state = state.get("teacher_state_dict") + if teacher_state is None: + self.teacher_model = None + return + self.teacher_model = copy.deepcopy(self.model).to(self.device) + self.teacher_model.load_state_dict(teacher_state) + self.teacher_model.eval() + for parameter in self.teacher_model.parameters(): + parameter.requires_grad_(False) + + +def build_lwf( + model: nn.Module, + device: torch.device, + learning_rate: float, + alpha: float, + temperature: float, +) -> LwFStrategy: + return LwFStrategy( + model=model, + device=device, + learning_rate=learning_rate, + alpha=alpha, + temperature=temperature, + ) diff --git a/src/cl_bench/strategies/replay.py b/src/cl_bench/strategies/replay.py new file mode 100644 index 0000000..c41d60e --- /dev/null +++ b/src/cl_bench/strategies/replay.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import random +from dataclasses import dataclass + +import torch +from torch import nn +from torch.utils.data import DataLoader + +from cl_bench.strategies.base import ContinualLearningStrategy + + +@dataclass +class ReplaySample: + inputs: torch.Tensor + target: int + logits: torch.Tensor | None = None + + +class ReservoirReplayBuffer: + """Bounded replay buffer using reservoir sampling over the observed stream.""" + + def __init__(self, capacity: int, seed: int = 0): + if capacity < 0: + raise ValueError("Replay buffer capacity must be non-negative.") + self.capacity = capacity + self.rng = random.Random(seed) + self.samples: list[ReplaySample] = [] + self.seen_count = 0 + + def __len__(self) -> int: + return len(self.samples) + + def add_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor | None = None, + ) -> None: + logits_cpu = None if logits is None else logits.detach().cpu() + for index, (input_tensor, target) in enumerate( + zip(inputs.detach().cpu(), targets.detach().cpu(), strict=True) + ): + sample_logits = None if logits_cpu is None else logits_cpu[index] + self.add(input_tensor, int(target.item()), sample_logits) + + def add(self, inputs: torch.Tensor, target: int, logits: torch.Tensor | None = None) -> None: + self.seen_count += 1 + if self.capacity == 0: + return + sample = ReplaySample( + inputs=inputs.detach().cpu().clone(), + target=int(target), + logits=None if logits is None else logits.detach().cpu().clone(), + ) + if len(self.samples) < self.capacity: + self.samples.append(sample) + return + + replacement_index = self.rng.randrange(self.seen_count) + if replacement_index < self.capacity: + self.samples[replacement_index] = sample + + def sample(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]: + inputs, targets, _ = self.sample_tensors(batch_size) + return inputs, targets + + def sample_tensors( + self, batch_size: int, require_logits: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + chosen = self.sample_samples(batch_size, require_logits=require_logits) + inputs = torch.stack([sample.inputs for sample in chosen]) + targets = torch.tensor([sample.target for sample in chosen], dtype=torch.long) + if require_logits: + logits = torch.stack([sample.logits for sample in chosen if sample.logits is not None]) + else: + logits = None + return inputs, targets, logits + + def sample_samples(self, batch_size: int, require_logits: bool = False) -> list[ReplaySample]: + candidates = [ + sample for sample in self.samples if not require_logits or sample.logits is not None + ] + if not candidates: + raise ValueError("Cannot sample from an empty replay buffer.") + return self.rng.sample(candidates, k=min(batch_size, len(candidates))) + + +class ReplayStrategy(ContinualLearningStrategy): + def __init__( + self, + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + replay_batch_size: int, + replay_loss_weight: float, + seed: int, + ): + super().__init__(model=model, device=device, learning_rate=learning_rate) + self.buffer = ReservoirReplayBuffer(capacity=buffer_size, seed=seed) + self.replay_batch_size = replay_batch_size + self.replay_loss_weight = replay_loss_weight + + def compute_loss( + self, inputs: torch.Tensor, targets: torch.Tensor, task_id: int + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: + del task_id + logits = self.model(inputs) + ce_loss = self.criterion(logits, targets) + replay_loss = torch.zeros((), device=self.device) + + if len(self.buffer) > 0 and self.replay_batch_size > 0: + replay_inputs, replay_targets = self.buffer.sample(self.replay_batch_size) + replay_inputs = replay_inputs.to(self.device) + replay_targets = replay_targets.to(self.device) + replay_logits = self.model(replay_inputs) + replay_loss = self.criterion(replay_logits, replay_targets) + + loss = ce_loss + self.replay_loss_weight * replay_loss + return ( + loss, + logits, + { + "ce_loss": float(ce_loss.detach().item()), + "replay_loss": float(replay_loss.detach().item()), + }, + ) + + def observe_batch( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + logits: torch.Tensor, + task_id: int, + ) -> None: + del logits, task_id + self.buffer.add_batch(inputs, targets) + + def after_task(self, train_loader: DataLoader, task_id: int) -> None: + del train_loader, task_id + + def extra_state_dict(self) -> dict[str, object]: + return { + "buffer_inputs": [sample.inputs for sample in self.buffer.samples], + "buffer_targets": [sample.target for sample in self.buffer.samples], + "buffer_logits": [sample.logits for sample in self.buffer.samples], + "seen_count": self.buffer.seen_count, + } + + +def build_replay( + model: nn.Module, + device: torch.device, + learning_rate: float, + buffer_size: int, + replay_batch_size: int, + replay_loss_weight: float, + seed: int, +) -> ReplayStrategy: + return ReplayStrategy( + model=model, + device=device, + learning_rate=learning_rate, + buffer_size=buffer_size, + replay_batch_size=replay_batch_size, + replay_loss_weight=replay_loss_weight, + seed=seed, + ) diff --git a/src/cl_bench/tracking.py b/src/cl_bench/tracking.py new file mode 100644 index 0000000..d1e6eae --- /dev/null +++ b/src/cl_bench/tracking.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import json +import platform +import subprocess +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import numpy as np +import torch + + +class ExperimentTracker: + """Writes reproducibility artifacts for one benchmark run.""" + + def __init__(self, run_dir: Path): + self.run_dir = run_dir + self.run_dir.mkdir(parents=True, exist_ok=True) + self.events_path = self.run_dir / "metrics.jsonl" + + def log_event(self, event: dict[str, Any]) -> None: + payload = {"time_utc": utc_now(), **event} + with self.events_path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") + + def write_json(self, name: str, payload: dict[str, Any] | list[Any]) -> Path: + path = self.run_dir / name + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + return path + + def write_matrix_csv(self, name: str, matrix: np.ndarray) -> Path: + path = self.run_dir / name + np.savetxt(path, matrix, delimiter=",", fmt="%.6f") + return path + + +def create_run_dir(output_dir: str | Path, benchmark_name: str, method: str) -> Path: + timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + safe_name = _safe_slug(benchmark_name) + safe_method = _safe_slug(method) + return Path(output_dir) / f"{safe_name}_{safe_method}_{timestamp}" + + +def git_commit(repo_dir: str | Path | None = None) -> str | None: + try: + completed = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=repo_dir, + check=True, + capture_output=True, + text=True, + ) + except (OSError, subprocess.CalledProcessError): + return None + commit = completed.stdout.strip() + return commit or None + + +def utc_now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _safe_slug(value: str) -> str: + return "".join(char if char.isalnum() or char in {"-", "_"} else "-" for char in value).strip( + "-" + ) + + +class MLflowRunLogger: + """Optional local MLflow logger layered on top of JSON artifacts.""" + + def __init__( + self, + tracking_uri: str | Path, + experiment_name: str, + run_name: str, + enabled: bool, + ): + self.tracking_uri = tracking_uri + self.experiment_name = experiment_name + self.run_name = run_name + self.enabled = enabled + self._mlflow: Any | None = None + self._active_run: Any | None = None + + def __enter__(self) -> MLflowRunLogger: + if not self.enabled: + return self + try: + import mlflow + except ImportError as exc: + raise RuntimeError( + "MLflow tracking was requested but mlflow is not installed. " + 'Install with: python -m pip install -e ".[experiment]"' + ) from exc + + self._mlflow = mlflow + mlflow.set_tracking_uri(_normalize_mlflow_uri(self.tracking_uri)) + mlflow.set_experiment(self.experiment_name) + self._active_run = mlflow.start_run(run_name=self.run_name) + return self + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + if self._mlflow is not None and self._active_run is not None: + self._mlflow.end_run(status="FAILED" if exc_type else "FINISHED") + + def log_params(self, params: dict[str, Any]) -> None: + if self._mlflow is None: + return + for key, value in _flatten_params(params).items(): + self._mlflow.log_param(key, value) + + def set_tags(self, tags: dict[str, Any]) -> None: + if self._mlflow is None: + return + for key, value in tags.items(): + if value is not None: + self._mlflow.set_tag(key, str(value)) + + def log_environment(self) -> None: + if self._mlflow is None: + return + self.set_tags( + { + "python": platform.python_version(), + "platform": platform.platform(), + "torch": torch.__version__, + "cuda_available": torch.cuda.is_available(), + "mps_available": hasattr(torch.backends, "mps") + and torch.backends.mps.is_available(), + } + ) + + def log_metrics( + self, metrics: dict[str, float | int | str | None], step: int | None = None + ) -> None: + if self._mlflow is None: + return + for key, value in metrics.items(): + if isinstance(value, (int, float)): + self._mlflow.log_metric(key, float(value), step=step) + + def log_artifacts(self, run_dir: str | Path) -> None: + if self._mlflow is None: + return + self._mlflow.log_artifacts(str(run_dir)) + + +def _normalize_mlflow_uri(uri: str | Path) -> str: + value = str(uri) + if value.startswith("sqlite:///"): + db_path = Path(value.removeprefix("sqlite:///")) + db_path.parent.mkdir(parents=True, exist_ok=True) + return value + if "://" in value: + return value + return Path(value).resolve().as_uri() + + +def _flatten_params( + params: dict[str, Any], prefix: str = "" +) -> dict[str, str | int | float | bool]: + flattened: dict[str, str | int | float | bool] = {} + for key, value in params.items(): + name = f"{prefix}.{key}" if prefix else str(key) + if isinstance(value, dict): + flattened.update(_flatten_params(value, name)) + elif isinstance(value, (str, int, float, bool)): + flattened[name] = value + elif value is None: + flattened[name] = "null" + else: + flattened[name] = json.dumps(value, sort_keys=True, default=str) + return flattened diff --git a/src/data/__init__.py b/src/data/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/data/data_loader.py b/src/data/data_loader.py deleted file mode 100644 index 618adcc..0000000 --- a/src/data/data_loader.py +++ /dev/null @@ -1,253 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -Data loading module for the Continual Learning System. -Handles loading and preprocessing of different datasets and task sequences. -""" - -import os -import logging -import numpy as np -from typing import List, Dict, Any, Union, Tuple - -import torch -import torchvision -import torchvision.transforms as transforms -from torch.utils.data import DataLoader, Subset, ConcatDataset, random_split - -logger = logging.getLogger(__name__) - - -def get_dataset(dataset_name: str, root: str = './data', train: bool = True, transform=None): - """ - Get the specified dataset. - - Args: - dataset_name (str): Name of the dataset ('mnist', 'fashion_mnist', 'kmnist', etc.) - root (str): Root directory for dataset storage - train (bool): Whether to load the training set or the test set - transform: Transformations to apply to the data - - Returns: - torch.utils.data.Dataset: The requested dataset - """ - dataset_name = dataset_name.lower() - - # Create directory if it doesn't exist - os.makedirs(root, exist_ok=True) - - # Default transformations if none provided - if transform is None: - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) # MNIST-like normalization - ]) - - # Load dataset based on name - if dataset_name == 'mnist': - dataset = torchvision.datasets.MNIST( - root=root, train=train, download=True, transform=transform - ) - elif dataset_name == 'fashion_mnist': - dataset = torchvision.datasets.FashionMNIST( - root=root, train=train, download=True, transform=transform - ) - elif dataset_name == 'kmnist': - dataset = torchvision.datasets.KMNIST( - root=root, train=train, download=True, transform=transform - ) - elif dataset_name == 'cifar10': - if transform is None: - # Default CIFAR10 transformation - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) - ]) - - dataset = torchvision.datasets.CIFAR10( - root=root, train=train, download=True, transform=transform - ) - elif dataset_name == 'cifar100': - if transform is None: - # Default CIFAR100 transformation - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) - ]) - - dataset = torchvision.datasets.CIFAR100( - root=root, train=train, download=True, transform=transform - ) - else: - raise ValueError(f"Unknown dataset: {dataset_name}") - - return dataset - - -def create_class_subset(dataset, classes: Union[List[int], str]): - """ - Create a subset of a dataset containing only the specified classes. - - Args: - dataset: PyTorch dataset - classes: List of class indices to include, or 'all' to include all classes - - Returns: - tuple: (Subset of the dataset, list of classes included) - """ - if classes == 'all': - # If we want all classes, return the whole dataset and a list of all class indices - if hasattr(dataset, 'classes'): - all_classes = list(range(len(dataset.classes))) - else: - # Try to infer the number of classes - targets = dataset.targets if hasattr(dataset, 'targets') else dataset.targets - all_classes = list(set(targets.numpy() if torch.is_tensor(targets) else targets)) - - return dataset, all_classes - - # Get targets from the dataset - if hasattr(dataset, 'targets'): - targets = dataset.targets - elif hasattr(dataset, 'target'): - targets = dataset.target - else: - # For datasets that store targets differently (e.g., as part of __getitem__) - # We'll need to iterate through the dataset - targets = torch.tensor([dataset[i][1] for i in range(len(dataset))]) - - # Convert targets to numpy array for easier filtering - if torch.is_tensor(targets): - targets = targets.numpy() - elif not isinstance(targets, np.ndarray): - targets = np.array(targets) - - # Get indices for the requested classes - indices = [i for i, label in enumerate(targets) if label in classes] - - # Create and return the subset - return Subset(dataset, indices), classes - - -def split_dataset(dataset, val_split: float = 0.1): - """ - Split a dataset into training and validation sets. - - Args: - dataset: PyTorch dataset - val_split (float): Fraction of data to use for validation - - Returns: - tuple: (training subset, validation subset) - """ - val_size = int(len(dataset) * val_split) - train_size = len(dataset) - val_size - - train_dataset, val_dataset = random_split( - dataset, - [train_size, val_size], - generator=torch.Generator().manual_seed(42) - ) - - return train_dataset, val_dataset - - -def get_task_sequence(task_configs: List[Dict[str, Any]], batch_size: int = 32): - """ - Create data loaders for a sequence of tasks. - - Args: - task_configs (list): List of task configuration dictionaries - batch_size (int): Batch size for data loaders - - Returns: - list: List of task data dictionaries, each containing name, classes, and data loaders - """ - # Data root directory is in the project root - data_root = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'data') - - task_data = [] - - for task_config in task_configs: - dataset_name = task_config['dataset'] - task_name = task_config['name'] - classes = task_config['classes'] - - logger.info(f"Loading task {task_name} (dataset: {dataset_name}, classes: {classes})") - - # Get training and test datasets - train_dataset = get_dataset(dataset_name, root=data_root, train=True) - test_dataset = get_dataset(dataset_name, root=data_root, train=False) - - # Create class-specific subsets if needed - train_subset, actual_classes = create_class_subset(train_dataset, classes) - test_subset, _ = create_class_subset(test_dataset, classes) - - # Split training set into train and validation - train_split, val_split = split_dataset(train_subset) - - # Create data loaders - train_loader = DataLoader( - train_split, - batch_size=batch_size, - shuffle=True, - num_workers=2, - pin_memory=True - ) - - val_loader = DataLoader( - val_split, - batch_size=batch_size, - shuffle=False, - num_workers=2, - pin_memory=True - ) - - test_loader = DataLoader( - test_subset, - batch_size=batch_size, - shuffle=False, - num_workers=2, - pin_memory=True - ) - - # Store task data - task_data.append({ - 'name': task_name, - 'dataset': dataset_name, - 'classes': actual_classes, - 'train_loader': train_loader, - 'val_loader': val_loader, - 'test_loader': test_loader - }) - - logger.info(f"Loaded task {task_name} with {len(train_split)} training, " - f"{len(val_split)} validation, and {len(test_subset)} test samples") - - return task_data - - -if __name__ == '__main__': - # Test the data loading functionality - logging.basicConfig(level=logging.INFO) - - # Example task sequence - task_configs = [ - {'name': 'mnist_0_4', 'dataset': 'mnist', 'classes': [0, 1, 2, 3, 4]}, - {'name': 'mnist_5_9', 'dataset': 'mnist', 'classes': [5, 6, 7, 8, 9]} - ] - - task_data = get_task_sequence(task_configs) - - # Print information about the tasks - for i, task in enumerate(task_data): - logger.info(f"Task {i+1}: {task['name']}") - logger.info(f" Classes: {task['classes']}") - logger.info(f" Training samples: {len(task['train_loader'].dataset)}") - logger.info(f" Validation samples: {len(task['val_loader'].dataset)}") - logger.info(f" Test samples: {len(task['test_loader'].dataset)}") - - # Get a batch of data - images, labels = next(iter(task['train_loader'])) - logger.info(f" Batch shape: {images.shape}, Labels: {labels.numpy()[:5]} ...") \ No newline at end of file diff --git a/src/main.py b/src/main.py deleted file mode 100644 index e126705..0000000 --- a/src/main.py +++ /dev/null @@ -1,310 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -Main entry point for the Continual Learning System. -This script orchestrates the training and evaluation of continual learning approaches. -""" - -import os -import sys -import argparse -import logging -import yaml -import torch -import numpy as np -import random -from datetime import datetime - -# Add the 'src' directory to the Python path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from src.data.data_loader import get_task_sequence -from src.models.model_factory import get_model -from src.methods.baseline import BaselineLearner -from src.methods.ewc import EWCLearner -from src.methods.replay import ExperienceReplayLearner -from src.methods.lwf import LwFLearner -from src.utils.metrics import evaluate_performance, compute_forgetting -from src.utils.visualization import plot_performance, plot_forgetting - - -def setup_logging(): - """Set up logging configuration.""" - log_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'results', 'logs') - os.makedirs(log_dir, exist_ok=True) - - log_file = os.path.join(log_dir, f'continual_learning_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log') - - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_file), - logging.StreamHandler() - ] - ) - return logging.getLogger(__name__) - - -def set_seed(seed): - """Set random seed for reproducibility.""" - if seed is not None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def parse_arguments(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description='Continual Learning System') - - parser.add_argument('--method', type=str, required=True, - choices=['baseline', 'ewc', 'replay', 'lwf'], - help='Continual learning method to use') - - parser.add_argument('--tasks', type=str, default='mnist_split', - help='Predefined task sequence (or path to custom YAML file)') - - parser.add_argument('--model', type=str, default='simple_cnn', - help='Base model architecture') - - parser.add_argument('--epochs', type=int, default=5, - help='Number of epochs per task') - - parser.add_argument('--batch_size', type=int, default=64, - help='Batch size for training') - - parser.add_argument('--learning_rate', type=float, default=0.001, - help='Learning rate for optimizer') - - # EWC specific arguments - parser.add_argument('--lambda_ewc', type=float, default=5000, - help='Regularization strength for EWC') - parser.add_argument('--fisher_sample_size', type=int, default=200, - help='Number of samples to estimate Fisher information') - - # Experience Replay specific arguments - parser.add_argument('--buffer_size', type=int, default=500, - help='Size of the replay buffer') - parser.add_argument('--replay_batch_size', type=int, default=32, - help='Batch size for replayed samples') - - # LwF specific arguments - parser.add_argument('--temperature', type=float, default=2.0, - help='Temperature for knowledge distillation in LwF') - parser.add_argument('--alpha', type=float, default=1.0, - help='Weight for distillation loss in LwF') - - # General arguments - parser.add_argument('--seed', type=int, default=42, - help='Random seed for reproducibility') - - parser.add_argument('--device', type=str, default=None, - help='Device to use (cuda or cpu)') - - parser.add_argument('--eval_freq', type=int, default=1, - help='Frequency of evaluation during training (in epochs)') - - parser.add_argument('--save_dir', type=str, default='results', - help='Directory to save results') - - return parser.parse_args() - - -def load_task_config(task_name_or_path): - """ - Load task configuration from predefined sequences or a custom YAML file. - - Args: - task_name_or_path (str): Name of predefined task or path to custom YAML file - - Returns: - dict: Task configuration - """ - predefined_tasks = { - 'mnist_split': [ - {'name': 'mnist_0_4', 'dataset': 'mnist', 'classes': [0, 1, 2, 3, 4]}, - {'name': 'mnist_5_9', 'dataset': 'mnist', 'classes': [5, 6, 7, 8, 9]} - ], - 'multi_dataset': [ - {'name': 'mnist', 'dataset': 'mnist', 'classes': 'all'}, - {'name': 'fashion_mnist', 'dataset': 'fashion_mnist', 'classes': 'all'}, - {'name': 'kmnist', 'dataset': 'kmnist', 'classes': 'all'} - ], - } - - if task_name_or_path in predefined_tasks: - return {'task_sequence': predefined_tasks[task_name_or_path]} - - # Load from YAML file - if os.path.exists(task_name_or_path): - with open(task_name_or_path, 'r') as f: - return yaml.safe_load(f) - - raise ValueError(f"Task sequence '{task_name_or_path}' not recognized and file not found.") - - -def run_continual_learning(args, logger): - """ - Run the continual learning experiment. - - Args: - args: Command line arguments - logger: Logger instance - """ - # Set random seed - set_seed(args.seed) - - # Set device - if args.device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - else: - device = torch.device(args.device) - - logger.info(f"Using device: {device}") - - # Load task sequence - task_config = load_task_config(args.tasks) - task_sequence = task_config['task_sequence'] - - logger.info(f"Task sequence: {[task['name'] for task in task_sequence]}") - - # Create data loaders for all tasks - task_data = get_task_sequence(task_sequence, args.batch_size) - - # Initialize model and learner - input_shape = task_data[0]['train_loader'].dataset[0][0].shape - all_classes = set() - for task in task_sequence: - if task['classes'] == 'all': - # 'all' is a sentinel meaning all classes in the dataset (e.g. 10 for MNIST variants) - all_classes.update(range(10)) - else: - all_classes.update(task['classes']) - num_classes = len(all_classes) - - model = get_model(args.model, input_shape, num_classes) - model = model.to(device) - - # Select appropriate learner based on method - if args.method == 'baseline': - learner = BaselineLearner( - model=model, - device=device, - learning_rate=args.learning_rate - ) - elif args.method == 'ewc': - learner = EWCLearner( - model=model, - device=device, - learning_rate=args.learning_rate, - lambda_ewc=args.lambda_ewc, - fisher_sample_size=args.fisher_sample_size - ) - elif args.method == 'replay': - learner = ExperienceReplayLearner( - model=model, - device=device, - learning_rate=args.learning_rate, - buffer_size=args.buffer_size, - replay_batch_size=args.replay_batch_size - ) - elif args.method == 'lwf': - learner = LwFLearner( - model=model, - device=device, - learning_rate=args.learning_rate, - temperature=args.temperature, - alpha=args.alpha - ) - else: - raise ValueError(f"Unknown method: {args.method}") - - # Create directory to save results - results_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), args.save_dir) - experiment_dir = os.path.join(results_dir, f"{args.method}_{args.tasks}_{datetime.now().strftime('%Y%m%d_%H%M%S')}") - os.makedirs(experiment_dir, exist_ok=True) - - # Track performance after each task - performance_matrix = np.zeros((len(task_sequence), len(task_sequence))) - - # Sequentially train on each task - for task_id, task_data_dict in enumerate(task_data): - task_name = task_data_dict['name'] - train_loader = task_data_dict['train_loader'] - val_loader = task_data_dict['val_loader'] - - logger.info(f"Starting training on task {task_id+1}/{len(task_data)}: {task_name}") - - # Train on current task - learner.train( - train_loader=train_loader, - val_loader=val_loader, - task_id=task_id, - epochs=args.epochs, - eval_freq=args.eval_freq - ) - - # Evaluate on all tasks seen so far - for eval_task_id in range(task_id + 1): - eval_task_data = task_data[eval_task_id] - eval_loader = eval_task_data['test_loader'] - - accuracy = learner.evaluate(eval_loader, eval_task_id) - performance_matrix[task_id, eval_task_id] = accuracy - - logger.info(f"After task {task_id+1}, accuracy on task {eval_task_id+1}: {accuracy:.2f}%") - - # Calculate forgetting - forgetting_matrix = compute_forgetting(performance_matrix) - - # Save performance metrics - np.save(os.path.join(experiment_dir, 'performance_matrix.npy'), performance_matrix) - np.save(os.path.join(experiment_dir, 'forgetting_matrix.npy'), forgetting_matrix) - - # Save model - learner.save(os.path.join(experiment_dir, 'final_model.pt')) - - # Generate plots - task_names = [task['name'] for task in task_sequence] - performance_plot = plot_performance(performance_matrix, task_names) - forgetting_plot = plot_forgetting(forgetting_matrix, task_names) - - performance_plot.savefig(os.path.join(experiment_dir, 'performance.png')) - forgetting_plot.savefig(os.path.join(experiment_dir, 'forgetting.png')) - - # Print summary - logger.info("\nExperiment summary:") - logger.info(f"Method: {args.method}") - logger.info(f"Task sequence: {[task['name'] for task in task_sequence]}") - logger.info(f"Average final accuracy: {np.mean(performance_matrix[-1, :]):.2f}%") - logger.info(f"Average forgetting: {np.mean(forgetting_matrix[-1, :]):.2f}%") - logger.info(f"Results saved to: {experiment_dir}") - - -def main(): - """Main function.""" - # Parse arguments - args = parse_arguments() - - # Set up logging - logger = setup_logging() - logger.info("Starting Continual Learning experiment") - logger.info(f"Arguments: {args}") - - # Run experiment - try: - run_continual_learning(args, logger) - logger.info("Experiment completed successfully") - except Exception as e: - logger.exception(f"Experiment failed with error: {e}") - sys.exit(1) - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/src/methods/__init__.py b/src/methods/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/methods/baseline.py b/src/methods/baseline.py deleted file mode 100644 index 0cfdbd6..0000000 --- a/src/methods/baseline.py +++ /dev/null @@ -1,239 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -Baseline learner that simply fine-tunes the model on new tasks. -This implementation will demonstrate catastrophic forgetting. -""" - -import os -import logging -import time -from typing import Dict, List, Optional - -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader -from tqdm import tqdm - -logger = logging.getLogger(__name__) - - -class BaselineLearner: - """ - Naive fine-tuning implementation that trains sequentially on tasks. - Used as a baseline to demonstrate catastrophic forgetting. - """ - - def __init__(self, model: nn.Module, device: torch.device, learning_rate: float = 0.001): - """ - Initialize the baseline learner. - - Args: - model (nn.Module): The neural network model - device (torch.device): Device to run the model on - learning_rate (float): Learning rate for the optimizer - """ - self.model = model - self.device = device - self.learning_rate = learning_rate - - # Initialize optimizer - self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate) - - # Loss function - self.criterion = nn.CrossEntropyLoss() - - # Keep track of tasks - self.current_task = 0 - self.seen_tasks = 0 - - def train(self, train_loader: DataLoader, val_loader: DataLoader, - task_id: int, epochs: int, eval_freq: int = 1): - """ - Train the model on a new task. - - Args: - train_loader (DataLoader): Training data loader - val_loader (DataLoader): Validation data loader - task_id (int): ID of the current task - epochs (int): Number of training epochs - eval_freq (int): Frequency of evaluation during training (in epochs) - """ - # Update task info - self.current_task = task_id - self.seen_tasks = max(self.seen_tasks, task_id + 1) - - # Create a new optimizer for the new task with the original learning rate - self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) - - # Training loop - best_val_loss = float('inf') - best_model_state = None - - for epoch in range(epochs): - # Training phase - self.model.train() - train_loss = 0.0 - train_correct = 0 - train_total = 0 - - # Use tqdm for a nice progress bar - train_pbar = tqdm(train_loader, desc=f"Task {task_id+1} - Epoch {epoch+1}/{epochs} [Train]") - for inputs, targets in train_pbar: - inputs, targets = inputs.to(self.device), targets.to(self.device) - - # Zero the parameter gradients - self.optimizer.zero_grad() - - # Forward pass - outputs = self.model(inputs) - loss = self.criterion(outputs, targets) - - # Backward pass and optimize - loss.backward() - self.optimizer.step() - - # Update statistics - train_loss += loss.item() * inputs.size(0) - _, predicted = torch.max(outputs.data, 1) - train_total += targets.size(0) - train_correct += (predicted == targets).sum().item() - - # Update progress bar - train_pbar.set_postfix({'loss': loss.item(), - 'acc': 100.0 * train_correct / train_total}) - - train_loss = train_loss / train_total - train_acc = 100.0 * train_correct / train_total - - # Validation phase (run every eval_freq epochs) - if epoch % eval_freq == 0: - val_loss, val_acc = self._evaluate_training(val_loader) - - logger.info(f"Task {task_id+1} - Epoch {epoch+1}/{epochs}: " - f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, " - f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%") - - # Save the best model based on validation loss - if val_loss < best_val_loss: - best_val_loss = val_loss - best_model_state = self.model.state_dict().copy() - - # Load the best model from this task's training - if best_model_state is not None: - self.model.load_state_dict(best_model_state) - logger.info(f"Loaded best model for task {task_id+1} with validation loss: {best_val_loss:.4f}") - - def _evaluate_training(self, val_loader: DataLoader) -> tuple: - """ - Evaluate the model on validation data during training. - - Args: - val_loader (DataLoader): Validation data loader - - Returns: - tuple: (validation loss, validation accuracy) - """ - self.model.eval() - val_loss = 0.0 - val_correct = 0 - val_total = 0 - - with torch.no_grad(): - for inputs, targets in val_loader: - inputs, targets = inputs.to(self.device), targets.to(self.device) - - # Forward pass - outputs = self.model(inputs) - loss = self.criterion(outputs, targets) - - # Update statistics - val_loss += loss.item() * inputs.size(0) - _, predicted = torch.max(outputs.data, 1) - val_total += targets.size(0) - val_correct += (predicted == targets).sum().item() - - val_loss = val_loss / val_total - val_acc = 100.0 * val_correct / val_total - - return val_loss, val_acc - - def evaluate(self, test_loader: DataLoader, task_id: Optional[int] = None) -> float: - """ - Evaluate the model on test data. - - Args: - test_loader (DataLoader): Test data loader - task_id (int, optional): ID of the task to evaluate - - Returns: - float: Test accuracy as a percentage - """ - self.model.eval() - test_correct = 0 - test_total = 0 - - with torch.no_grad(): - for inputs, targets in test_loader: - inputs, targets = inputs.to(self.device), targets.to(self.device) - - # Forward pass - outputs = self.model(inputs) - - # Calculate accuracy - _, predicted = torch.max(outputs.data, 1) - test_total += targets.size(0) - test_correct += (predicted == targets).sum().item() - - test_acc = 100.0 * test_correct / test_total - - return test_acc - - def save(self, path: str): - """ - Save the model to a file. - - Args: - path (str): Path to save the model - """ - # Create directory if it doesn't exist - os.makedirs(os.path.dirname(path), exist_ok=True) - - # Save model state - torch.save({ - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'seen_tasks': self.seen_tasks, - 'current_task': self.current_task - }, path) - - logger.info(f"Model saved to {path}") - - def load(self, path: str): - """ - Load the model from a file. - - Args: - path (str): Path to load the model from - """ - if not os.path.exists(path): - logger.error(f"Model file not found: {path}") - return - - # Load model state - checkpoint = torch.load(path, map_location=self.device) - self.model.load_state_dict(checkpoint['model_state_dict']) - self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - self.seen_tasks = checkpoint['seen_tasks'] - self.current_task = checkpoint['current_task'] - - logger.info(f"Model loaded from {path}") - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - - logger.info("This is a baseline learner that demonstrates catastrophic forgetting.") - logger.info("It should be used as a comparison point for more advanced continual learning methods.") \ No newline at end of file diff --git a/src/methods/ewc.py b/src/methods/ewc.py deleted file mode 100644 index e1f12a4..0000000 --- a/src/methods/ewc.py +++ /dev/null @@ -1,330 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -Elastic Weight Consolidation (EWC) learner for continual learning. -EWC prevents catastrophic forgetting by penalizing changes to parameters -that are important for previously learned tasks. - -Reference: -Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G., Rusu, A. A., ... & Hadsell, R. (2017). -"Overcoming catastrophic forgetting in neural networks." -Proceedings of the National Academy of Sciences, 114(13), 3521-3526. -""" - -import os -import logging -import copy -from typing import Dict, List, Optional - -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -from torch.utils.data import DataLoader -from tqdm import tqdm - -from .baseline import BaselineLearner - -logger = logging.getLogger(__name__) - - -class EWCLearner(BaselineLearner): - """ - Elastic Weight Consolidation (EWC) learner. - Extends the baseline learner with a regularization term to prevent forgetting. - """ - - def __init__(self, model: nn.Module, device: torch.device, learning_rate: float = 0.001, - lambda_ewc: float = 5000, fisher_sample_size: int = 200): - """ - Initialize the EWC learner. - - Args: - model (nn.Module): The neural network model - device (torch.device): Device to run the model on - learning_rate (float): Learning rate for the optimizer - lambda_ewc (float): Regularization strength for EWC - fisher_sample_size (int): Number of samples to estimate Fisher information - """ - super().__init__(model, device, learning_rate) - - self.lambda_ewc = lambda_ewc - self.fisher_sample_size = fisher_sample_size - - # Store parameters and Fisher information for each task - self.fisher_matrices = {} # Fisher information for each task - self.optimal_parameters = {} # Optimal parameters for each task - - def _compute_fisher_information(self, data_loader: DataLoader, task_id: int): - """ - Compute the Fisher information matrix for the current task. - The Fisher information measures how much the model parameters affect the output. - - Args: - data_loader (DataLoader): Data loader for the current task - task_id (int): ID of the current task - """ - logger.info(f"Computing Fisher information matrix for task {task_id+1}") - - # Initialize Fisher information matrix - fisher = {} - for name, param in self.model.named_parameters(): - if param.requires_grad: - fisher[name] = torch.zeros_like(param.data) - - # Set model to evaluation mode - self.model.eval() - - # Sample a subset of data for Fisher computation - sample_count = 0 - - for inputs, targets in data_loader: - if sample_count >= self.fisher_sample_size: - break - - inputs, targets = inputs.to(self.device), targets.to(self.device) - batch_size = inputs.shape[0] - - # Get model outputs - log_probs = F.log_softmax(self.model(inputs), dim=1) - - # Compute gradients for each sample in the batch - for i in range(batch_size): - if sample_count >= self.fisher_sample_size: - break - - sample_log_prob = log_probs[i, targets[i]] - - # Compute gradients - self.optimizer.zero_grad() - sample_log_prob.backward(retain_graph=(i < batch_size - 1)) - - # Accumulate squared gradients - for name, param in self.model.named_parameters(): - if param.requires_grad and param.grad is not None: - fisher[name] += param.grad.data.pow(2) / self.fisher_sample_size - - sample_count += 1 - - # Store the computed Fisher information - self.fisher_matrices[task_id] = fisher - - logger.info(f"Fisher information matrix computed using {sample_count} samples") - - def _store_optimal_parameters(self, task_id: int): - """ - Store the optimal parameters for the current task. - - Args: - task_id (int): ID of the current task - """ - logger.info(f"Storing optimal parameters for task {task_id+1}") - - optimal_params = {} - for name, param in self.model.named_parameters(): - if param.requires_grad: - optimal_params[name] = param.data.clone() - - self.optimal_parameters[task_id] = optimal_params - - def _compute_ewc_loss(self): - """ - Compute the EWC regularization loss based on stored Fisher information. - This penalties changes to parameters that were important for previous tasks. - - Returns: - torch.Tensor: EWC regularization loss - """ - ewc_loss = 0 - - for task_id in range(self.current_task): - if task_id not in self.fisher_matrices or task_id not in self.optimal_parameters: - continue - - for name, param in self.model.named_parameters(): - if name in self.fisher_matrices[task_id] and name in self.optimal_parameters[task_id]: - # Compute the squared difference between current and optimal parameters - # weighted by the Fisher information - fisher = self.fisher_matrices[task_id][name] - optimal_param = self.optimal_parameters[task_id][name] - ewc_loss += torch.sum(fisher * (param - optimal_param).pow(2)) / 2 - - return ewc_loss * self.lambda_ewc - - def train(self, train_loader: DataLoader, val_loader: DataLoader, - task_id: int, epochs: int, eval_freq: int = 1): - """ - Train the model on a new task with EWC regularization. - - Args: - train_loader (DataLoader): Training data loader - val_loader (DataLoader): Validation data loader - task_id (int): ID of the current task - epochs (int): Number of training epochs - eval_freq (int): Frequency of evaluation during training (in epochs) - """ - # Update task info - self.current_task = task_id - self.seen_tasks = max(self.seen_tasks, task_id + 1) - - # Create a new optimizer for the new task with the original learning rate - self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) - - # Training loop - best_val_loss = float('inf') - best_model_state = None - - for epoch in range(epochs): - # Training phase - self.model.train() - train_loss = 0.0 - train_task_loss = 0.0 - train_ewc_loss = 0.0 - train_correct = 0 - train_total = 0 - - # Use tqdm for a nice progress bar - train_pbar = tqdm(train_loader, desc=f"Task {task_id+1} - Epoch {epoch+1}/{epochs} [Train]") - for inputs, targets in train_pbar: - inputs, targets = inputs.to(self.device), targets.to(self.device) - - # Zero the parameter gradients - self.optimizer.zero_grad() - - # Forward pass - outputs = self.model(inputs) - task_loss = self.criterion(outputs, targets) - - # Compute EWC regularization loss (for tasks > 0) - ewc_loss = self._compute_ewc_loss() if task_id > 0 else 0 - - # Total loss - loss = task_loss + ewc_loss - - # Backward pass and optimize - loss.backward() - self.optimizer.step() - - # Update statistics - train_loss += loss.item() * inputs.size(0) - train_task_loss += task_loss.item() * inputs.size(0) - if task_id > 0: - train_ewc_loss += ewc_loss.item() * inputs.size(0) - - _, predicted = torch.max(outputs.data, 1) - train_total += targets.size(0) - train_correct += (predicted == targets).sum().item() - - # Update progress bar - train_pbar.set_postfix({ - 'loss': loss.item(), - 'task_loss': task_loss.item(), - 'ewc_loss': ewc_loss.item() if task_id > 0 else 0, - 'acc': 100.0 * train_correct / train_total - }) - - train_loss = train_loss / train_total - train_task_loss = train_task_loss / train_total - train_ewc_loss = train_ewc_loss / train_total if task_id > 0 else 0 - train_acc = 100.0 * train_correct / train_total - - # Validation phase (run every eval_freq epochs) - if epoch % eval_freq == 0: - val_loss, val_acc = self._evaluate_training(val_loader) - - logger.info(f"Task {task_id+1} - Epoch {epoch+1}/{epochs}: " - f"Train Loss: {train_loss:.4f} (Task: {train_task_loss:.4f}, EWC: {train_ewc_loss:.4f}), " - f"Train Acc: {train_acc:.2f}%, " - f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%") - - # Save the best model based on validation loss - if val_loss < best_val_loss: - best_val_loss = val_loss - best_model_state = self.model.state_dict().copy() - - # Load the best model from this task's training - if best_model_state is not None: - self.model.load_state_dict(best_model_state) - logger.info(f"Loaded best model for task {task_id+1} with validation loss: {best_val_loss:.4f}") - - # After training on this task, compute and store Fisher information and optimal parameters - self._compute_fisher_information(train_loader, task_id) - self._store_optimal_parameters(task_id) - - def save(self, path: str): - """ - Save the model and EWC-specific information to a file. - - Args: - path (str): Path to save the model - """ - # Create directory if it doesn't exist - os.makedirs(os.path.dirname(path), exist_ok=True) - - # Convert fisher matrices and optimal parameters to CPU tensors for serialization - fisher_cpu = {} - for task_id, task_fisher in self.fisher_matrices.items(): - fisher_cpu[task_id] = {name: tensor.cpu() for name, tensor in task_fisher.items()} - - params_cpu = {} - for task_id, task_params in self.optimal_parameters.items(): - params_cpu[task_id] = {name: tensor.cpu() for name, tensor in task_params.items()} - - # Save model state - torch.save({ - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'seen_tasks': self.seen_tasks, - 'current_task': self.current_task, - 'fisher_matrices': fisher_cpu, - 'optimal_parameters': params_cpu, - 'lambda_ewc': self.lambda_ewc, - 'fisher_sample_size': self.fisher_sample_size - }, path) - - logger.info(f"Model saved to {path}") - - def load(self, path: str): - """ - Load the model and EWC-specific information from a file. - - Args: - path (str): Path to load the model from - """ - if not os.path.exists(path): - logger.error(f"Model file not found: {path}") - return - - # Load model state - checkpoint = torch.load(path, map_location=self.device) - self.model.load_state_dict(checkpoint['model_state_dict']) - self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - self.seen_tasks = checkpoint['seen_tasks'] - self.current_task = checkpoint['current_task'] - - # Load EWC-specific information - self.lambda_ewc = checkpoint.get('lambda_ewc', self.lambda_ewc) - self.fisher_sample_size = checkpoint.get('fisher_sample_size', self.fisher_sample_size) - - # Load Fisher matrices and optimal parameters, moving them to the correct device - if 'fisher_matrices' in checkpoint: - self.fisher_matrices = {} - for task_id, task_fisher in checkpoint['fisher_matrices'].items(): - self.fisher_matrices[int(task_id)] = {name: tensor.to(self.device) - for name, tensor in task_fisher.items()} - - if 'optimal_parameters' in checkpoint: - self.optimal_parameters = {} - for task_id, task_params in checkpoint['optimal_parameters'].items(): - self.optimal_parameters[int(task_id)] = {name: tensor.to(self.device) - for name, tensor in task_params.items()} - - logger.info(f"Model loaded from {path}") - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - - logger.info("This is an implementation of Elastic Weight Consolidation for continual learning.") - logger.info("Reference: Kirkpatrick et al. (2017) - 'Overcoming catastrophic forgetting in neural networks'") \ No newline at end of file diff --git a/src/methods/lwf.py b/src/methods/lwf.py deleted file mode 100644 index ad38ec3..0000000 --- a/src/methods/lwf.py +++ /dev/null @@ -1,41 +0,0 @@ -from src.methods.baseline import BaselineLearner -import torch -import torch.nn.functional as F -import copy - -class LwFLearner(BaselineLearner): - def __init__(self, model, device, learning_rate, temperature, alpha): - super().__init__(model, device, learning_rate) - self.temperature = temperature - self.alpha = alpha - self.teacher_model = None # Will hold a copy of the model from previous tasks - - def train(self, train_loader, val_loader=None, task_id=0, epochs=1, eval_freq=1, **kwargs): - self.model.train() - optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) - criterion_ce = torch.nn.CrossEntropyLoss() - criterion_kd = torch.nn.KLDivLoss(reduction='batchmean') - T = self.temperature - - for epoch in range(epochs): - for batch in train_loader: - inputs, targets = batch - inputs, targets = inputs.to(self.device), targets.to(self.device) - optimizer.zero_grad() - outputs = self.model(inputs) - loss_ce = criterion_ce(outputs, targets) - loss_kd = 0.0 - if self.teacher_model is not None: - self.teacher_model.eval() - with torch.no_grad(): - teacher_outputs = self.teacher_model(inputs) - # Compute soft targets: student uses log_softmax, teacher uses softmax - soft_student = F.log_softmax(outputs / T, dim=1) - soft_teacher = F.softmax(teacher_outputs / T, dim=1) - loss_kd = criterion_kd(soft_student, soft_teacher) * (T * T) - loss = self.alpha * loss_ce + (1 - self.alpha) * loss_kd - loss.backward() - optimizer.step() - - # Update teacher model after training the current task - self.teacher_model = copy.deepcopy(self.model) \ No newline at end of file diff --git a/src/methods/replay.py b/src/methods/replay.py deleted file mode 100644 index e9f3d11..0000000 --- a/src/methods/replay.py +++ /dev/null @@ -1,44 +0,0 @@ -from src.methods.baseline import BaselineLearner -import torch -import random - -class ExperienceReplayLearner(BaselineLearner): - def __init__(self, model, device, learning_rate, buffer_size, replay_batch_size): - super().__init__(model, device, learning_rate) - self.buffer_size = buffer_size - self.replay_batch_size = replay_batch_size - self.buffer = [] - - def train(self, train_loader, val_loader=None, task_id=0, epochs=1, eval_freq=1, **kwargs): - self.model.train() - optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) - criterion = torch.nn.CrossEntropyLoss() - - for epoch in range(epochs): - for batch in train_loader: - inputs, targets = batch - inputs, targets = inputs.to(self.device), targets.to(self.device) - optimizer.zero_grad() - outputs = self.model(inputs) - loss = criterion(outputs, targets) - - # Incorporate replay loss if buffer is not empty - if len(self.buffer) > 0: - replay_samples = random.sample(self.buffer, min(self.replay_batch_size, len(self.buffer))) - replay_inputs = torch.stack([sample[0] for sample in replay_samples]).to(self.device) - replay_targets = torch.tensor([sample[1] for sample in replay_samples]).to(self.device) - replay_outputs = self.model(replay_inputs) - replay_loss = criterion(replay_outputs, replay_targets) - loss = loss + replay_loss - - loss.backward() - optimizer.step() - - # Update the replay buffer with current task samples - for batch in train_loader: - inputs, targets = batch - for i in range(inputs.size(0)): - self.buffer.append((inputs[i].detach().cpu(), targets[i].detach().cpu())) - # Keep buffer size within limits - if len(self.buffer) > self.buffer_size: - self.buffer = self.buffer[-self.buffer_size:] \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/models/__pycache__/model_factory.cpython-312.pyc b/src/models/__pycache__/model_factory.cpython-312.pyc deleted file mode 100644 index cd7cb47..0000000 Binary files a/src/models/__pycache__/model_factory.cpython-312.pyc and /dev/null differ diff --git a/src/models/model_factory.py b/src/models/model_factory.py deleted file mode 100644 index 789ed61..0000000 --- a/src/models/model_factory.py +++ /dev/null @@ -1,318 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -Model factory module for the Continual Learning System. -Provides different neural network architectures. -""" - -import logging -from typing import Tuple, List, Dict - -import torch -import torch.nn as nn -import torch.nn.functional as F - -logger = logging.getLogger(__name__) - - -class SimpleCNN(nn.Module): - """ - A simple CNN model for image classification tasks. - """ - def __init__(self, input_shape: Tuple[int, int, int], num_classes: int): - """ - Initialize the model. - - Args: - input_shape (tuple): Shape of input images (channels, height, width) - num_classes (int): Number of output classes - """ - super(SimpleCNN, self).__init__() - - # Extract input dimensions - channels, height, width = input_shape - - # Feature extractor - self.features = nn.Sequential( - nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2) - ) - - # Calculate the size of feature maps after convolutions and pooling - feature_size = (height // 8) * (width // 8) * 128 - - # Classifier - self.classifier = nn.Sequential( - nn.Linear(feature_size, 256), - nn.ReLU(inplace=True), - nn.Dropout(0.5), - nn.Linear(256, num_classes) - ) - - def forward(self, x): - """Forward pass through the network.""" - x = self.features(x) - x = torch.flatten(x, 1) - x = self.classifier(x) - return x - - -class MLP(nn.Module): - """ - A simple multi-layer perceptron for simpler tasks. - """ - def __init__(self, input_shape: Tuple[int, int, int], num_classes: int, hidden_dim: int = 256): - """ - Initialize the model. - - Args: - input_shape (tuple): Shape of input images (channels, height, width) - num_classes (int): Number of output classes - hidden_dim (int): Dimensionality of hidden layers - """ - super(MLP, self).__init__() - - # Calculate input dimensionality - input_dim = input_shape[0] * input_shape[1] * input_shape[2] - - self.layers = nn.Sequential( - nn.Flatten(), - nn.Linear(input_dim, hidden_dim), - nn.ReLU(inplace=True), - nn.Dropout(0.5), - nn.Linear(hidden_dim, hidden_dim), - nn.ReLU(inplace=True), - nn.Dropout(0.5), - nn.Linear(hidden_dim, num_classes) - ) - - def forward(self, x): - """Forward pass through the network.""" - return self.layers(x) - - -class LeNet5(nn.Module): - """ - LeNet-5 convolutional network architecture for image classification. - """ - def __init__(self, input_shape: Tuple[int, int, int], num_classes: int): - """ - Initialize the model. - - Args: - input_shape (tuple): Shape of input images (channels, height, width) - num_classes (int): Number of output classes - """ - super(LeNet5, self).__init__() - - # Extract input dimensions - channels, height, width = input_shape - - # Feature extraction - self.conv1 = nn.Conv2d(channels, 6, kernel_size=5) - self.relu1 = nn.ReLU() - self.pool1 = nn.MaxPool2d(kernel_size=2) - self.conv2 = nn.Conv2d(6, 16, kernel_size=5) - self.relu2 = nn.ReLU() - self.pool2 = nn.MaxPool2d(kernel_size=2) - - # Calculate the size of feature maps before the fully connected layer - # For MNIST (1x28x28), after two conv+pool layers, we get 16x4x4 - # Need to adjust dynamically for different input shapes - conv1_out_size = (height - 5 + 1) // 2 # Conv 5x5 then pool 2x2 - conv2_out_size = (conv1_out_size - 5 + 1) // 2 # Conv 5x5 then pool 2x2 - fc_input_size = 16 * conv2_out_size * conv2_out_size - - # Classification - self.fc1 = nn.Linear(fc_input_size, 120) - self.relu3 = nn.ReLU() - self.fc2 = nn.Linear(120, 84) - self.relu4 = nn.ReLU() - self.fc3 = nn.Linear(84, num_classes) - - def forward(self, x): - """Forward pass through the network.""" - x = self.conv1(x) - x = self.relu1(x) - x = self.pool1(x) - x = self.conv2(x) - x = self.relu2(x) - x = self.pool2(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = self.relu3(x) - x = self.fc2(x) - x = self.relu4(x) - x = self.fc3(x) - return x - - -class SmallResNet(nn.Module): - """ - A small ResNet-like architecture suitable for continual learning experiments. - """ - class ResBlock(nn.Module): - """Basic residual block with two 3x3 convolutions.""" - def __init__(self, in_channels: int, out_channels: int, stride: int = 1): - super(SmallResNet.ResBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, - stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(out_channels) - - # Shortcut connection - self.shortcut = nn.Sequential() - if stride != 1 or in_channels != out_channels: - self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, - stride=stride, bias=False), - nn.BatchNorm2d(out_channels) - ) - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - out += self.shortcut(residual) - out = self.relu(out) - - return out - - def __init__(self, input_shape: Tuple[int, int, int], num_classes: int): - """ - Initialize the model. - - Args: - input_shape (tuple): Shape of input images (channels, height, width) - num_classes (int): Number of output classes - """ - super(SmallResNet, self).__init__() - - channels, height, width = input_shape - - self.in_channels = 16 - - self.conv1 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(16) - self.relu = nn.ReLU(inplace=True) - - # Create residual blocks - self.layer1 = self._make_layer(16, 2, stride=1) - self.layer2 = self._make_layer(32, 2, stride=2) - self.layer3 = self._make_layer(64, 2, stride=2) - - # Calculate the size after the feature extraction layers - # Each layer with stride=2 reduces dim by 2 - final_size = height // 4 if height > 8 else 2 # Don't reduce below 2x2 - - # Global average pooling - self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) - - # Classifier - self.fc = nn.Linear(64, num_classes) - - # Initialize weights - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def _make_layer(self, out_channels: int, num_blocks: int, stride: int): - """Create a layer of residual blocks.""" - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(self.ResBlock(self.in_channels, out_channels, stride)) - self.in_channels = out_channels - - return nn.Sequential(*layers) - - def forward(self, x): - """Forward pass through the network.""" - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - - x = self.avg_pool(x) - x = torch.flatten(x, 1) - x = self.fc(x) - - return x - - -def get_model(model_name: str, input_shape: Tuple[int, int, int], num_classes: int): - """ - Factory function to get the specified model. - - Args: - model_name (str): Name of the model architecture - input_shape (tuple): Shape of input data (channels, height, width) - num_classes (int): Number of output classes - - Returns: - nn.Module: Instantiated model - """ - model_name = model_name.lower() - - if model_name == 'simple_cnn': - model = SimpleCNN(input_shape, num_classes) - elif model_name == 'mlp': - model = MLP(input_shape, num_classes) - elif model_name == 'lenet5': - model = LeNet5(input_shape, num_classes) - elif model_name == 'small_resnet': - model = SmallResNet(input_shape, num_classes) - else: - raise ValueError(f"Unknown model architecture: {model_name}") - - logger.info(f"Created {model_name} model with input shape {input_shape} and {num_classes} output classes") - - return model - - -if __name__ == "__main__": - # Test the model factory - logging.basicConfig(level=logging.INFO) - - # Test creating different models - models = { - 'simple_cnn': get_model('simple_cnn', (1, 28, 28), 10), - 'mlp': get_model('mlp', (1, 28, 28), 10), - 'lenet5': get_model('lenet5', (1, 28, 28), 10), - 'small_resnet': get_model('small_resnet', (1, 28, 28), 10) - } - - # Print model architectures - for name, model in models.items(): - print(f"\nModel: {name}") - print(model) - - # Test forward pass - x = torch.randn(2, 1, 28, 28) # Batch of 2 MNIST-like images - y = model(x) - print(f"Output shape: {y.shape}") \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/utils/metrics.py b/src/utils/metrics.py deleted file mode 100644 index c38c54a..0000000 --- a/src/utils/metrics.py +++ /dev/null @@ -1,290 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -Metrics module for the Continual Learning System. -This module provides functions for calculating and tracking -performance metrics in continual learning. -""" - -import os -import logging -import numpy as np -from typing import List, Dict, Any, Optional, Tuple - -logger = logging.getLogger(__name__) - - -def evaluate_performance( - model, - task_loaders: List[Any], - device, - use_task_id: bool = False -) -> np.ndarray: - """ - Evaluate the performance of a model on all provided tasks. - - Args: - model: PyTorch model to evaluate - task_loaders (list): List of data loaders for each task - device: Device to run the model on - use_task_id (bool): Whether to provide task_id to the model's forward method - - Returns: - np.ndarray: Vector of accuracies for each task - """ - model.eval() - accuracies = [] - - for task_id, loader in enumerate(task_loaders): - correct = 0 - total = 0 - - for inputs, targets in loader: - inputs, targets = inputs.to(device), targets.to(device) - - # Forward pass with or without task_id - if use_task_id: - outputs = model(inputs, task_id=task_id) - else: - outputs = model(inputs) - - # Calculate accuracy - _, predicted = outputs.max(1) - total += targets.size(0) - correct += predicted.eq(targets).sum().item() - - # Calculate accuracy for this task - accuracy = 100.0 * correct / total if total > 0 else 0.0 - accuracies.append(accuracy) - - return np.array(accuracies) - - -def compute_forgetting(performance_matrix: np.ndarray) -> np.ndarray: - """ - Compute the forgetting matrix from a performance matrix. - - Args: - performance_matrix (np.ndarray): Matrix where rows represent training steps - (task i trained) and columns represent - performance on each task - - Returns: - np.ndarray: Forgetting matrix where each element [i,j] represents how much - task j was forgotten after training on task i - """ - num_tasks = performance_matrix.shape[0] - forgetting_matrix = np.zeros_like(performance_matrix) - - for i in range(1, num_tasks): # For each training step (except the first) - for j in range(i): # For each previously learned task - # Forgetting is the difference between the best previous performance - # and the current performance - best_previous = np.max(performance_matrix[:i, j]) - forgetting_matrix[i, j] = max(0, best_previous - performance_matrix[i, j]) - - return forgetting_matrix - - -def average_forgetting(forgetting_matrix: np.ndarray) -> List[float]: - """ - Compute the average forgetting after each task. - - Args: - forgetting_matrix (np.ndarray): Matrix where each element [i,j] represents - how much task j was forgotten after training - on task i - - Returns: - list: Average forgetting after each task - """ - num_tasks = forgetting_matrix.shape[0] - avg_forgetting = [] - - for i in range(1, num_tasks): # For each training step (except the first) - # Average forgetting for all previous tasks - if i > 0: - avg_forgetting.append(np.mean(forgetting_matrix[i, :i])) - - return avg_forgetting - - -def backward_transfer(performance_matrix: np.ndarray) -> List[float]: - """ - Compute the backward transfer after each task. - Backward transfer measures how training on later tasks affects - performance on earlier tasks. - - Args: - performance_matrix (np.ndarray): Matrix where rows represent training steps - (task i trained) and columns represent - performance on each task - - Returns: - list: Backward transfer after each task - """ - num_tasks = performance_matrix.shape[0] - bt_values = [] - - for i in range(1, num_tasks): # For each training step (except the first) - # Sum of differences between current and original performance - bt_sum = 0 - for j in range(i): # For each previously learned task - bt_sum += performance_matrix[i, j] - performance_matrix[j, j] - - # Average backward transfer - bt_values.append(bt_sum / i if i > 0 else 0) - - return bt_values - - -def forward_transfer(performance_matrix: np.ndarray, random_performance: List[float]) -> List[float]: - """ - Compute the forward transfer after each task. - Forward transfer measures how training on earlier tasks affects - the initial performance on later tasks. - - Args: - performance_matrix (np.ndarray): Matrix where rows represent training steps - (task i trained) and columns represent - performance on each task - random_performance (list): Performance on each task with random initialization - - Returns: - list: Forward transfer after each task - """ - num_tasks = performance_matrix.shape[0] - ft_values = [] - - for i in range(1, num_tasks): # For each task (except the first) - # Performance on task i before training on it - initial_perf = performance_matrix[i-1, i] - # Performance on task i with random initialization - random_perf = random_performance[i] - - # Forward transfer is the difference - ft_values.append(initial_perf - random_perf) - - return ft_values - - -def learning_curve_area(learning_curves: Dict[str, List[float]]) -> Dict[str, float]: - """ - Compute the area under the learning curve for different methods. - Higher area means faster learning. - - Args: - learning_curves (dict): Dictionary mapping method names to lists of - performance values during training - - Returns: - dict: Dictionary mapping method names to ALC values - """ - alc_values = {} - - for method, curve in learning_curves.items(): - # Normalize curve to [0, 1] - min_val = min(curve) - max_val = max(curve) - - if max_val > min_val: - normalized_curve = [(v - min_val) / (max_val - min_val) for v in curve] - else: - normalized_curve = [0.5] * len(curve) - - # Compute area under the curve - alc = np.trapz(normalized_curve, dx=1.0/len(normalized_curve)) - alc_values[method] = alc - - return alc_values - - -def average_accuracy(performance_matrix: np.ndarray) -> float: - """ - Compute the average accuracy across all tasks after training is complete. - - Args: - performance_matrix (np.ndarray): Matrix where rows represent training steps - (task i trained) and columns represent - performance on each task - - Returns: - float: Average final accuracy across all tasks - """ - # Final performance is the last row of the matrix - final_performance = performance_matrix[-1, :] - return np.mean(final_performance) - - -def compute_metrics_summary(performance_matrix: np.ndarray, random_performance: Optional[List[float]] = None) -> Dict[str, float]: - """ - Compute a summary of continual learning metrics. - - Args: - performance_matrix (np.ndarray): Matrix where rows represent training steps - (task i trained) and columns represent - performance on each task - random_performance (list, optional): Performance on each task with random initialization - - Returns: - dict: Dictionary of metric values - """ - # Compute forgetting - forgetting_matrix = compute_forgetting(performance_matrix) - avg_forget = np.mean(forgetting_matrix[-1, :-1]) if performance_matrix.shape[0] > 1 else 0.0 - - # Compute backward transfer - bt_values = backward_transfer(performance_matrix) - avg_bt = np.mean(bt_values) if bt_values else 0.0 - - # Compute forward transfer if random performance is provided - avg_ft = 0.0 - if random_performance is not None: - ft_values = forward_transfer(performance_matrix, random_performance) - avg_ft = np.mean(ft_values) if ft_values else 0.0 - - # Compute average accuracy - avg_acc = average_accuracy(performance_matrix) - - # Return metrics summary - return { - 'average_accuracy': avg_acc, - 'average_forgetting': avg_forget, - 'backward_transfer': avg_bt, - 'forward_transfer': avg_ft - } - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - - # Example performance matrix - num_tasks = 3 - performance = np.array([ - [90.0, np.nan, np.nan], # After task 1, only task 1 evaluated - [85.0, 92.0, np.nan], # After task 2, tasks 1-2 evaluated - [82.0, 88.0, 94.0] # After task 3, all tasks evaluated - ]) - - # Example random performance - random_perf = [40.0, 42.0, 38.0] - - # Compute and print metrics - forgetting_mat = compute_forgetting(performance) - print("Forgetting matrix:") - print(forgetting_mat) - - avg_forgetting_values = average_forgetting(forgetting_mat) - print(f"Average forgetting after each task: {avg_forgetting_values}") - - bt_values = backward_transfer(performance) - print(f"Backward transfer after each task: {bt_values}") - - ft_values = forward_transfer(performance, random_perf) - print(f"Forward transfer after each task: {ft_values}") - - metrics_summary = compute_metrics_summary(performance, random_perf) - print("Metrics summary:") - for metric, value in metrics_summary.items(): - print(f" {metric}: {value:.2f}") \ No newline at end of file diff --git a/src/utils/visualization.py b/src/utils/visualization.py deleted file mode 100644 index 470175b..0000000 --- a/src/utils/visualization.py +++ /dev/null @@ -1,313 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -Visualization utilities for the Continual Learning System. -""" - -import os -import logging -import numpy as np -from typing import List, Dict, Any, Optional -import matplotlib.pyplot as plt -import seaborn as sns - -logger = logging.getLogger(__name__) - - -def plot_performance(performance_matrix: np.ndarray, task_names: List[str] = None): - """ - Plot the performance of a model across sequential tasks. - - Args: - performance_matrix (np.ndarray): Matrix where rows represent training progress - (task_i trained) and columns represent - performance on each task - task_names (list): Names of the tasks - - Returns: - matplotlib.figure.Figure: The generated figure - """ - # Create task names if not provided - if task_names is None: - task_names = [f"Task {i+1}" for i in range(performance_matrix.shape[1])] - - # Create figure - fig, ax = plt.subplots(figsize=(10, 6)) - - # Set seaborn style - sns.set_style("whitegrid") - - # Plot matrix as heatmap - sns.heatmap(performance_matrix, annot=True, fmt=".1f", cmap="YlGnBu", - xticklabels=task_names, yticklabels=[f"After Task {i+1}" for i in range(performance_matrix.shape[0])], - cbar_kws={'label': 'Accuracy (%)'}) - - # Add labels and title - plt.xlabel("Evaluated Task") - plt.ylabel("Training Progress") - plt.title("Model Performance After Each Task") - - plt.tight_layout() - - return fig - - -def plot_forgetting(forgetting_matrix: np.ndarray, task_names: List[str] = None): - """ - Plot the forgetting of a model across sequential tasks. - - Args: - forgetting_matrix (np.ndarray): Matrix where rows represent training progress - (task_i trained) and columns represent - forgetting on each task - task_names (list): Names of the tasks - - Returns: - matplotlib.figure.Figure: The generated figure - """ - # Create task names if not provided - if task_names is None: - task_names = [f"Task {i+1}" for i in range(forgetting_matrix.shape[1])] - - # Create figure - fig, ax = plt.subplots(figsize=(10, 6)) - - # Set seaborn style - sns.set_style("whitegrid") - - # Plot matrix as heatmap with a different colormap for forgetting - sns.heatmap(forgetting_matrix, annot=True, fmt=".1f", cmap="Reds", - xticklabels=task_names, yticklabels=[f"After Task {i+1}" for i in range(forgetting_matrix.shape[0])], - cbar_kws={'label': 'Forgetting (%)'}) - - # Add labels and title - plt.xlabel("Task Forgotten") - plt.ylabel("Training Progress") - plt.title("Forgetting After Each Task") - - plt.tight_layout() - - return fig - - -def plot_accuracy_over_time(accuracies: List[float], task_boundaries: List[int], - task_names: List[str] = None, title: str = "Accuracy Over Time"): - """ - Plot the accuracy of a model throughout training across multiple tasks. - - Args: - accuracies (list): Validation accuracies throughout training - task_boundaries (list): Epoch indices where tasks change - task_names (list): Names of the tasks - title (str): Title for the plot - - Returns: - matplotlib.figure.Figure: The generated figure - """ - # Create task names if not provided - num_tasks = len(task_boundaries) + 1 - if task_names is None: - task_names = [f"Task {i+1}" for i in range(num_tasks)] - - # Create figure - fig, ax = plt.subplots(figsize=(12, 6)) - - # Set seaborn style - sns.set_style("whitegrid") - - # Plot accuracy - epochs = np.arange(1, len(accuracies) + 1) - plt.plot(epochs, accuracies, 'b-', linewidth=2) - - # Add task boundary lines - for boundary in task_boundaries: - plt.axvline(x=boundary, color='r', linestyle='--') - - # Add task labels - task_midpoints = [0] + task_boundaries - task_midpoints.append(len(accuracies)) - for i in range(num_tasks): - midpoint = (task_midpoints[i] + task_midpoints[i+1]) / 2 - plt.text(midpoint, min(accuracies) - 5, task_names[i], - horizontalalignment='center', fontsize=12) - - # Add labels and title - plt.xlabel("Epoch") - plt.ylabel("Accuracy (%)") - plt.title(title) - plt.ylim(min(accuracies) - 10, max(accuracies) + 5) - - plt.tight_layout() - - return fig - - -def plot_task_comparison(final_accuracies: Dict[str, List[float]], task_names: List[str] = None): - """ - Plot a comparison of final accuracies across different methods. - - Args: - final_accuracies (dict): Dictionary mapping method names to lists of final accuracies on each task - task_names (list): Names of the tasks - - Returns: - matplotlib.figure.Figure: The generated figure - """ - # Get number of tasks - num_tasks = len(next(iter(final_accuracies.values()))) - - # Create task names if not provided - if task_names is None: - task_names = [f"Task {i+1}" for i in range(num_tasks)] - - # Create figure - fig, ax = plt.subplots(figsize=(12, 8)) - - # Set seaborn style - sns.set_style("whitegrid") - - # Prepare data for grouped bar chart - methods = list(final_accuracies.keys()) - x = np.arange(len(task_names)) - width = 0.8 / len(methods) - - # Plot bars for each method - for i, method in enumerate(methods): - offset = (i - len(methods)/2 + 0.5) * width - plt.bar(x + offset, final_accuracies[method], width, label=method) - - # Add labels and title - plt.xlabel("Task") - plt.ylabel("Final Accuracy (%)") - plt.title("Comparison of Methods Across Tasks") - plt.xticks(x, task_names) - plt.legend() - - plt.tight_layout() - - return fig - - -def plot_average_metrics(metrics: Dict[str, Dict[str, float]], metrics_to_plot: List[str] = None): - """ - Plot average metrics across different methods. - - Args: - metrics (dict): Dictionary mapping method names to dictionaries of metrics - metrics_to_plot (list): List of metric names to plot - - Returns: - matplotlib.figure.Figure: The generated figure - """ - # Determine which metrics to plot - if metrics_to_plot is None: - # Get all unique metrics - metrics_to_plot = set() - for method_metrics in metrics.values(): - metrics_to_plot.update(method_metrics.keys()) - metrics_to_plot = sorted(list(metrics_to_plot)) - - # Create figure - fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=(15, 6)) - if len(metrics_to_plot) == 1: - axes = [axes] # Ensure axes is always a list - - # Set seaborn style - sns.set_style("whitegrid") - - # Plot each metric - for i, metric in enumerate(metrics_to_plot): - ax = axes[i] - - # Extract values for this metric from each method - methods = [] - values = [] - for method, method_metrics in metrics.items(): - if metric in method_metrics: - methods.append(method) - values.append(method_metrics[metric]) - - # Plot bar chart - colors = sns.color_palette("viridis", len(methods)) - ax.bar(methods, values, color=colors) - - # Add labels - ax.set_title(metric) - ax.set_ylabel("Value") - ax.set_xticks(range(len(methods))) - ax.set_xticklabels(methods, rotation=45, ha='right') - - plt.tight_layout() - - return fig - - -def plot_forgetting_curve(forgetting_values: Dict[str, List[float]], task_names: List[str] = None): - """ - Plot forgetting curves for different methods. - - Args: - forgetting_values (dict): Dictionary mapping method names to lists of forgetting values - task_names (list): Names of the tasks - - Returns: - matplotlib.figure.Figure: The generated figure - """ - # Get number of tasks - num_tasks = len(next(iter(forgetting_values.values()))) - - # Create task names if not provided - if task_names is None: - task_names = [f"Task {i+1}" for i in range(num_tasks)] - - # Create figure - fig, ax = plt.subplots(figsize=(10, 6)) - - # Set seaborn style - sns.set_style("whitegrid") - - # Plot forgetting curves for each method - for method, values in forgetting_values.items(): - plt.plot(range(1, num_tasks), values, marker='o', linewidth=2, label=method) - - # Add labels and title - plt.xlabel("Tasks Learned") - plt.ylabel("Average Forgetting (%)") - plt.title("Forgetting Curve for Different Methods") - plt.xticks(range(1, num_tasks), [f"After Task {i+1}" for i in range(1, num_tasks)]) - plt.legend() - plt.grid(True) - - plt.tight_layout() - - return fig - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - - # Demo visualization with random data - num_tasks = 3 - performance = np.random.uniform(60, 100, size=(num_tasks, num_tasks)) - # Set upper triangle to NaN (tasks not seen yet) - for i in range(num_tasks): - for j in range(i+1, num_tasks): - performance[i, j] = np.nan - - # Example forgetting matrix - forgetting = np.zeros((num_tasks, num_tasks)) - for i in range(1, num_tasks): - for j in range(i): - forgetting[i, j] = performance[j, j] - performance[i, j] - - # Create demo plots - task_names = ["Digits 0-4", "Digits 5-9", "Fashion MNIST"] - - perf_fig = plot_performance(performance, task_names) - perf_fig.savefig("demo_performance.png") - - forget_fig = plot_forgetting(forgetting, task_names) - forget_fig.savefig("demo_forgetting.png") - - logger.info("Created demo visualizations: demo_performance.png and demo_forgetting.png") \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..b664ef7 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from cl_bench.config import ExperimentConfig, load_config, load_config_with_overrides + + +def test_load_named_smoke_config() -> None: + config = load_config("smoke") + + assert isinstance(config, ExperimentConfig) + assert config.name == "smoke" + assert config.method == "baseline" + assert config.tasks[0].classes == [0, 1] + assert config.tasks[1].dataset == "synthetic" + + +def test_load_real_data_quick_config() -> None: + config = load_config("split_mnist_quick") + + assert config.name == "split_mnist_quick" + assert config.model == "small_cnn" + assert config.epochs == 3 + assert len(config.tasks) == 5 + assert all(task.dataset == "mnist" for task in config.tasks) + assert all(task.train_limit == 600 for task in config.tasks) + + +def test_load_cifar_headline_config_and_overrides() -> None: + config = load_config("split_cifar10_headline") + + assert config.name == "split_cifar10_headline" + assert config.method == "derpp" + assert config.model == "cifar_convnet" + assert config.tracking == "both" + assert config.epochs == 5 + assert config.replay_buffer_size == 5000 + assert config.replay_batch_size == 256 + assert config.replay_loss_weight == 3.0 + assert config.derpp_alpha == 0.1 + assert config.derpp_beta == 2.0 + assert len(config.tasks) == 5 + assert all(task.dataset == "cifar10" for task in config.tasks) + + overridden = load_config_with_overrides( + "split_cifar10_headline", + ["method=agem", "training.epochs=1", "tracking.mode=json"], + ) + assert overridden.method == "agem" + assert overridden.epochs == 1 + assert overridden.tracking == "json" + + +def test_nested_training_and_strategy_values_are_parsed() -> None: + config = ExperimentConfig.from_dict( + { + "name": "unit", + "method": "ewc", + "training": {"epochs": 3, "batch_size": 12, "learning_rate": 0.02}, + "strategy": {"ewc_lambda": 123.0, "fisher_samples": 7}, + "tasks": [ + { + "name": "a", + "dataset": "synthetic", + "classes": [0, 1], + } + ], + } + ) + + assert config.epochs == 3 + assert config.batch_size == 12 + assert config.learning_rate == 0.02 + assert config.ewc_lambda == 123.0 + assert config.fisher_samples == 7 diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 0000000..da821bd --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import torch + +from cl_bench.config import ExperimentConfig, TaskSpec +from cl_bench.datasets import build_task_loaders + + +def test_synthetic_task_construction_is_deterministic() -> None: + config = ExperimentConfig( + name="unit", + method="baseline", + seed=11, + model="linear", + batch_size=8, + eval_batch_size=16, + val_fraction=0.25, + tasks=[ + TaskSpec( + name="first", + dataset="synthetic", + classes=[0, 1], + samples_per_class=8, + test_samples_per_class=4, + ), + TaskSpec( + name="second", + dataset="synthetic", + classes=[2, 3], + samples_per_class=8, + test_samples_per_class=4, + ), + ], + ) + + first_tasks, input_shape, num_classes = build_task_loaders(config) + second_tasks, second_shape, second_num_classes = build_task_loaders(config) + + assert [task.name for task in first_tasks] == ["first", "second"] + assert input_shape == (1, 8, 8) + assert num_classes == 4 + assert second_shape == input_shape + assert second_num_classes == num_classes + + first_batch = next(iter(first_tasks[0].train_loader)) + second_batch = next(iter(second_tasks[0].train_loader)) + assert first_batch[1].tolist() == second_batch[1].tolist() + + +def test_cifar_task_construction_uses_rgb_shape_and_limits(monkeypatch) -> None: + class FakeCIFAR10: + classes = [str(index) for index in range(10)] + + def __init__(self, root, train: bool, download: bool, transform): + del root, download + self.transform = transform + examples_per_class = 6 if train else 4 + self.targets = [class_id for class_id in range(10) for _ in range(examples_per_class)] + self.data = [ + torch.full((3, 32, 32), float(class_id) / 10.0) + for class_id in range(10) + for _ in range(examples_per_class) + ] + + def __len__(self) -> int: + return len(self.targets) + + def __getitem__(self, index: int): + return self.data[index], self.targets[index] + + monkeypatch.setattr("cl_bench.datasets.tv_datasets.CIFAR10", FakeCIFAR10) + config = ExperimentConfig( + name="cifar_unit", + method="baseline", + seed=3, + model="cifar_convnet", + batch_size=4, + eval_batch_size=8, + val_fraction=0.0, + augment=False, + tasks=[ + TaskSpec( + name="cifar_0_1", + dataset="cifar10", + classes=[0, 1], + train_limit=5, + test_limit=4, + ) + ], + ) + + tasks, input_shape, num_classes = build_task_loaders(config) + + assert input_shape == (3, 32, 32) + assert num_classes == 2 + assert len(tasks[0].train_loader.dataset) == 5 + assert len(tasks[0].test_loader.dataset) == 4 diff --git a/tests/test_integration_smoke.py b/tests/test_integration_smoke.py new file mode 100644 index 0000000..1149060 --- /dev/null +++ b/tests/test_integration_smoke.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from dataclasses import replace + +from cl_bench.config import load_config +from cl_bench.experiments import run_experiment + + +def test_cpu_smoke_benchmark_writes_reproducibility_artifacts(tmp_path) -> None: + config = replace(load_config("smoke"), output_dir=str(tmp_path), method="derpp") + + result = run_experiment(config) + + assert result.run_dir.exists() + assert result.config_path.exists() + assert result.metrics_path.exists() + assert (result.run_dir / "metrics.jsonl").exists() + assert (result.run_dir / "accuracy_matrix.csv").exists() + assert (result.run_dir / "forgetting_matrix.csv").exists() + assert result.summary["num_tasks"] == 2 + + +def test_synthetic_suite_runs_core_memory_methods(tmp_path) -> None: + for method in ["baseline", "replay", "derpp", "agem"]: + config = replace(load_config("smoke"), output_dir=str(tmp_path), method=method) + result = run_experiment(config) + + assert result.method == method + assert result.metrics_path.exists() + assert result.summary["num_tasks"] == 2 diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..703c908 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import numpy as np + +from cl_bench.metrics import compute_forgetting, matrix_to_jsonable, summarize_accuracy + + +def test_forgetting_uses_best_previous_accuracy() -> None: + matrix = np.array( + [ + [80.0, np.nan, np.nan], + [75.0, 70.0, np.nan], + [78.0, 65.0, 90.0], + ] + ) + + forgetting = compute_forgetting(matrix) + + assert forgetting[0, 0] == 0.0 + assert forgetting[1, 0] == 5.0 + assert forgetting[2, 0] == 2.0 + assert forgetting[2, 1] == 5.0 + assert np.isnan(forgetting[0, 1]) + + +def test_summary_and_json_matrix_are_nan_safe() -> None: + matrix = np.array([[50.0, np.nan], [40.0, 75.0]]) + + summary = summarize_accuracy(matrix) + json_matrix = matrix_to_jsonable(matrix) + + assert summary["average_final_accuracy"] == 57.5 + assert summary["average_forgetting"] == 10.0 + assert json_matrix == [[50.0, None], [40.0, 75.0]] diff --git a/tests/test_replay.py b/tests/test_replay.py new file mode 100644 index 0000000..2738ad4 --- /dev/null +++ b/tests/test_replay.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import torch + +from cl_bench.strategies.replay import ReservoirReplayBuffer + + +def test_reservoir_replay_buffer_is_bounded_and_not_last_n() -> None: + buffer = ReservoirReplayBuffer(capacity=3, seed=3) + + for value in range(10): + inputs = torch.full((1, 2), float(value)) + targets = torch.tensor([value]) + buffer.add_batch(inputs, targets) + + stored_targets = [sample.target for sample in buffer.samples] + + assert len(buffer) == 3 + assert buffer.seen_count == 10 + assert stored_targets != [7, 8, 9] + + +def test_replay_sampling_returns_tensors() -> None: + buffer = ReservoirReplayBuffer(capacity=5, seed=0) + buffer.add_batch(torch.randn(4, 1, 8, 8), torch.tensor([0, 1, 2, 3])) + + inputs, targets = buffer.sample(batch_size=2) + + assert inputs.shape == (2, 1, 8, 8) + assert targets.shape == (2,) + assert targets.dtype == torch.long diff --git a/tests/test_reporting.py b/tests/test_reporting.py new file mode 100644 index 0000000..165fc5b --- /dev/null +++ b/tests/test_reporting.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import json + +from cl_bench.reporting import aggregate_records, collect_runs, write_report + + +def _write_metrics(run_dir, method: str, seed: int, final_accuracy: float) -> None: + run_dir.mkdir(parents=True) + payload = { + "benchmark": "unit", + "method": method, + "task_names": ["a", "b"], + "summary": { + "average_final_accuracy": final_accuracy, + "average_learning_accuracy": 80.0, + "average_forgetting": 5.0, + "backward_transfer": -5.0, + "runtime_seconds": 1.5, + "seed": seed, + }, + "accuracy_matrix": [[80.0, None], [70.0, final_accuracy]], + "forgetting_matrix": [[0.0, None], [10.0, 0.0]], + "runtime_seconds": 1.5, + "seed": seed, + "git_commit": None, + } + (run_dir / "metrics.json").write_text(json.dumps(payload), encoding="utf-8") + + +def test_collect_and_aggregate_runs(tmp_path) -> None: + _write_metrics(tmp_path / "baseline_seed_1", "baseline", 1, 60.0) + _write_metrics(tmp_path / "baseline_seed_2", "baseline", 2, 80.0) + _write_metrics(tmp_path / "replay_seed_1", "replay", 1, 90.0) + + records = collect_runs([tmp_path]) + leaderboard = aggregate_records(records) + + assert [row["method"] for row in leaderboard] == ["replay", "baseline"] + baseline = next(row for row in leaderboard if row["method"] == "baseline") + assert baseline["runs"] == 2 + assert baseline["seeds"] == "1,2" + assert baseline["average_final_accuracy_mean"] == 70.0 + + +def test_write_report_without_plots(tmp_path) -> None: + _write_metrics(tmp_path / "replay_seed_1", "replay", 1, 90.0) + records = collect_runs([tmp_path]) + + report = write_report(records, tmp_path / "report", "Unit report", make_plots=False) + + assert report.leaderboard_csv.exists() + assert report.summary_json.exists() + assert report.markdown.exists() + assert report.plots == [] + + +def test_collect_runs_from_mlflow_artifact_export_shape(tmp_path) -> None: + export_dir = tmp_path / "mlruns" / "0" / "run-id" / "artifacts" / "run" + _write_metrics(export_dir, "derpp", 13, 88.0) + + records = collect_runs([tmp_path / "mlruns"]) + + assert len(records) == 1 + assert records[0].method == "derpp" + assert records[0].summary["average_final_accuracy"] == 88.0 diff --git a/tests/test_strategies.py b/tests/test_strategies.py new file mode 100644 index 0000000..3ce3518 --- /dev/null +++ b/tests/test_strategies.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import torch + +from cl_bench.config import ExperimentConfig, TaskSpec +from cl_bench.datasets import build_task_loaders +from cl_bench.models import get_model +from cl_bench.strategies.agem import AGEMStrategy +from cl_bench.strategies.base import clone_state_dict +from cl_bench.strategies.baseline import BaselineStrategy +from cl_bench.strategies.derpp import DERPPStrategy +from cl_bench.strategies.ewc import EWCStrategy +from cl_bench.strategies.lwf import LwFStrategy + + +def _tiny_config() -> ExperimentConfig: + return ExperimentConfig( + name="tiny", + method="baseline", + seed=5, + model="linear", + epochs=1, + batch_size=8, + eval_batch_size=16, + learning_rate=0.05, + val_fraction=0.2, + tasks=[ + TaskSpec( + name="a", + dataset="synthetic", + classes=[0, 1], + samples_per_class=8, + test_samples_per_class=4, + ) + ], + ) + + +def test_strategy_lifecycle_trains_and_evaluates() -> None: + config = _tiny_config() + tasks, input_shape, num_classes = build_task_loaders(config) + model = get_model("linear", input_shape, num_classes) + strategy = BaselineStrategy(model=model, device=torch.device("cpu"), learning_rate=0.05) + + history = strategy.train_task(tasks[0].train_loader, tasks[0].val_loader, task_id=0, epochs=1) + metrics = strategy.evaluate(tasks[0].test_loader) + + assert len(history) == 1 + assert strategy.current_task == 0 + assert strategy.seen_tasks == 1 + assert 0.0 <= metrics["accuracy"] <= 100.0 + + +def test_clone_state_dict_does_not_alias_model_parameters() -> None: + model = torch.nn.Linear(2, 2) + cloned = clone_state_dict(model) + + with torch.no_grad(): + model.weight.add_(10.0) + + assert not torch.equal(cloned["weight"], model.state_dict()["weight"]) + + +def test_ewc_fisher_is_deterministic_and_normalized() -> None: + config = _tiny_config() + tasks, input_shape, num_classes = build_task_loaders(config) + model_a = get_model("linear", input_shape, num_classes) + model_b = get_model("linear", input_shape, num_classes) + model_b.load_state_dict(model_a.state_dict()) + + strategy_a = EWCStrategy(model_a, torch.device("cpu"), 0.01, ewc_lambda=1.0, fisher_samples=4) + strategy_b = EWCStrategy(model_b, torch.device("cpu"), 0.01, ewc_lambda=1.0, fisher_samples=4) + + fisher_a = strategy_a._estimate_fisher(tasks[0].train_loader) + fisher_b = strategy_b._estimate_fisher(tasks[0].train_loader) + + assert fisher_a.keys() == fisher_b.keys() + for name in fisher_a: + assert torch.allclose(fisher_a[name], fisher_b[name]) + assert torch.all(fisher_a[name] >= 0) + + +def test_lwf_creates_frozen_teacher_after_task() -> None: + config = _tiny_config() + tasks, input_shape, num_classes = build_task_loaders(config) + model = get_model("linear", input_shape, num_classes) + strategy = LwFStrategy(model, torch.device("cpu"), 0.05, alpha=0.5, temperature=2.0) + + strategy.train_task(tasks[0].train_loader, tasks[0].val_loader, task_id=0, epochs=1) + + assert strategy.teacher_model is not None + assert all(not parameter.requires_grad for parameter in strategy.teacher_model.parameters()) + + +def test_derpp_stores_logits_and_uses_replay_components() -> None: + config = _tiny_config() + tasks, input_shape, num_classes = build_task_loaders(config) + model = get_model("linear", input_shape, num_classes) + strategy = DERPPStrategy( + model, + torch.device("cpu"), + learning_rate=0.05, + buffer_size=16, + replay_batch_size=4, + alpha=0.5, + beta=1.0, + seed=1, + ) + inputs, targets = next(iter(tasks[0].train_loader)) + logits = strategy.model(inputs) + + strategy.observe_batch(inputs, targets, logits, task_id=0) + loss, _, components = strategy.compute_loss(inputs, targets, task_id=1) + + assert len(strategy.buffer) == inputs.size(0) + assert all(sample.logits is not None for sample in strategy.buffer.samples) + assert loss.item() > 0.0 + assert "derpp_distillation_loss" in components + assert "derpp_replay_ce_loss" in components + + +def test_agem_projects_conflicting_gradient() -> None: + model = torch.nn.Linear(2, 2, bias=False) + with torch.no_grad(): + model.weight.zero_() + strategy = AGEMStrategy( + model, + torch.device("cpu"), + learning_rate=0.1, + buffer_size=4, + memory_batch_size=2, + seed=1, + ) + strategy.buffer.add_batch( + torch.tensor([[2.0, 0.0], [2.0, 0.0]]), + torch.tensor([0, 0]), + ) + current_inputs = torch.tensor([[2.0, 0.0], [2.0, 0.0]]) + current_targets = torch.tensor([1, 1]) + + loss, _, _ = strategy.compute_loss(current_inputs, current_targets, task_id=1) + strategy.optimizer.zero_grad(set_to_none=True) + loss.backward() + strategy.after_backward(current_inputs, current_targets, task_id=1) + + assert strategy.last_gradient_dot < 0.0 + assert strategy.last_projection_applied == 1.0 diff --git a/tests/test_tracking.py b/tests/test_tracking.py new file mode 100644 index 0000000..d0a7ab3 --- /dev/null +++ b/tests/test_tracking.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from dataclasses import replace +from pathlib import Path + +import pytest + +from cl_bench.config import load_config +from cl_bench.experiments import run_experiment + +mlflow = pytest.importorskip("mlflow") + + +def test_mlflow_tracking_logs_params_metrics_and_artifacts(tmp_path) -> None: + tracking_uri = f"sqlite:///{tmp_path / 'mlflow.db'}" + config = replace( + load_config("smoke"), + output_dir=str(tmp_path / "runs"), + method="baseline", + tracking="both", + mlflow_tracking_uri=tracking_uri, + mlflow_experiment="unit-cl-bench", + ) + + result = run_experiment(config) + + mlflow.set_tracking_uri(tracking_uri) + experiment = mlflow.get_experiment_by_name("unit-cl-bench") + assert experiment is not None + runs = mlflow.search_runs([experiment.experiment_id], output_format="list") + assert len(runs) == 1 + run = runs[0] + assert run.data.params["method"] == "baseline" + assert "average_final_accuracy" in run.data.metrics + artifact_path = mlflow.artifacts.download_artifacts( + run_id=run.info.run_id, + artifact_path="metrics.json", + ) + assert Path(artifact_path).exists() + assert result.metrics_path.exists()