JorisMonnet/ProjectSNAIL

CS-502 Project Option 2 - A Simple Neural Attentive Meta-Learner (SNAIL) - Team 7

Colin Pelletier, Matthieu Burguburu, Joris Monnet

Few Shot Benchmark for Biomedical Datasets

Installation

We used the provided fewshotbench.zip file containing the benchmark code.

Conda

Create a conda env and install requirements with:

conda env create -f environment.yml 

Before each run, activate the environment with:

conda activate few-shot-benchmark 

Pip

Alternatively, for environments that do not support conda (e.g. Google Colab), install requirements with:

python -m pip install -r requirements.txt

Usage

Training

python run.py exp.name={exp_name} method=snail dataset=tabula_muris

By default, method is set to MAML, and dataset is set to Tabula Muris. The experiment name must always be specified.

Testing

The training process will automatically evaluate at the end. To only evaluate without running training, use the following:

python run.py exp.name={exp_name} method=snail dataset=tabula_muris mode=test

Run run.py with the same parameters as the training run, plus mode=test, and it will automatically use the best checkpoint (as measured by validation accuracy) from the most recent training run with that combination of exp.name/method/dataset/model. To choose a run from a different time (i.e. not the latest), pass its timestamp as checkpoint.time={yyyymmdd_hhmmss}. To choose a model from a specific epoch, pass that epoch as checkpoint.iter, e.g. checkpoint.iter=40.
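The checkpoint selection rules can be illustrated with a small, self-contained sketch. It keeps runs in an in-memory dict instead of on disk, and the helper names (latest_run, pick_checkpoint) are hypothetical rather than functions from this repository; it only mirrors the documented behavior: latest run by default, best validation accuracy by default, with checkpoint.time and checkpoint.iter as overrides.

```python
from datetime import datetime

# Illustrative only: runs live in a dict instead of on disk, and these helper
# names are hypothetical, not functions from this repository.
TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S"  # matches checkpoint.time={yyyymmdd_hhmmss}


def latest_run(runs):
    """Return the most recent run timestamp among the keys of `runs`."""
    return max(runs, key=lambda ts: datetime.strptime(ts, TIMESTAMP_FORMAT))


def pick_checkpoint(runs, time=None, epoch=None):
    """Select (run timestamp, epoch) following the README's selection rules.

    runs:  {timestamp: {epoch: validation_accuracy}}
    time:  optional run timestamp, analogous to checkpoint.time
    epoch: optional epoch, analogous to checkpoint.iter
    """
    run_time = time if time is not None else latest_run(runs)
    val_accs = runs[run_time]
    if epoch is not None:
        if epoch not in val_accs:
            raise KeyError(f"no checkpoint for epoch {epoch} in run {run_time}")
        return run_time, epoch
    # Default: the epoch with the highest validation accuracy.
    return run_time, max(val_accs, key=val_accs.get)


# Made-up validation accuracies, for illustration only.
runs = {
    "20231101_120000": {10: 0.71, 20: 0.78, 30: 0.75},
    "20231105_093000": {10: 0.69, 20: 0.74, 40: 0.80},
}
print(pick_checkpoint(runs))                          # ('20231105_093000', 40)
print(pick_checkpoint(runs, time="20231101_120000"))  # ('20231101_120000', 20)
print(pick_checkpoint(runs, epoch=10))                # ('20231105_093000', 10)
```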

Usage for SNAIL

SNAIL Training

To train SNAIL, add the n_query=1 parameter to the command:

python run.py exp.name=snail_{exp_name} method=snail dataset=tabula_muris n_query=1
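In Mishra et al. (2017), SNAIL reads an episode as a sequence: each labeled support example is one time step (its embedding concatenated with a one-hot label), and the unlabeled query comes last with its label slot set to zero, so the model predicts the query label from the whole sequence. The NumPy sketch below builds one such sequence per query example. It is a hedged illustration with a hypothetical function name, not code taken from methods/snail/, and it does not claim to reproduce exactly how that code uses n_query=1.

```python
import numpy as np


def build_snail_sequences(support, support_labels, queries, n_way):
    """Build one SNAIL input sequence per query example.

    support:        (n_way * n_shot, d) support embeddings
    support_labels: (n_way * n_shot,) integer labels in [0, n_way)
    queries:        (n_q, d) query embeddings
    Returns an array of shape (n_q, n_way * n_shot + 1, d + n_way).
    """
    # Each support step is its embedding concatenated with its one-hot label.
    one_hot = np.eye(n_way)[support_labels]
    support_steps = np.concatenate([support, one_hot], axis=1)
    sequences = []
    for q in queries:
        # The query step gets an all-zero label slot: its label is what SNAIL predicts.
        query_step = np.concatenate([q, np.zeros(n_way)])[None, :]
        sequences.append(np.concatenate([support_steps, query_step], axis=0))
    return np.stack(sequences)


# 5-way 1-shot episode with 8-dim embeddings and n_query=1 (one query per class).
rng = np.random.default_rng(0)
n_way, n_shot, d = 5, 1, 8
support = rng.normal(size=(n_way * n_shot, d))
labels = np.repeat(np.arange(n_way), n_shot)
queries = rng.normal(size=(n_way, d))
seqs = build_snail_sequences(support, labels, queries, n_way)
print(seqs.shape)  # (5, 6, 13)
```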

Hyperparameter Tuning

Use the ./hyperparamters_optim.sh bash script to run hyperparameter optimization. It contains multiple command lines, each of which runs the training script. You can manually change the learning rate in each command line and add or remove command lines as needed. To run the script, use:

./hyperparamters_optim.sh
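Instead of editing each command line by hand, the sweep's command lines could be generated. The sketch below is an assumption-laden helper: the function name is hypothetical, and the override key "lr" for the learning rate is a guess that should be checked against conf/main.yaml and conf/method/snail.yaml. It only builds the strings; you would paste or redirect them into the shell script.

```python
# Hypothetical helper: "lr" as the override key is an assumption; check
# conf/main.yaml and conf/method/snail.yaml for the actual learning-rate field.
def sweep_commands(exp_prefix, learning_rates, lr_key="lr"):
    """Return one SNAIL training command per learning rate."""
    return [
        f"python run.py exp.name={exp_prefix}_lr{lr} method=snail "
        f"dataset=tabula_muris n_query=1 {lr_key}={lr}"
        for lr in learning_rates
    ]


for cmd in sweep_commands("snail_sweep", [1e-3, 1e-4]):
    print(cmd)
```

Encoding the learning rate in exp.name keeps each run's checkpoints separate, since test mode looks up checkpoints by exp.name/method/dataset/model.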

Ablation Study

Use the ./ablation.sh bash script to run an ablation study. In this file, specify the architecture you want to test in JSON format. To run the script, use:

./ablation.sh
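The exact JSON format that ablation.sh expects is not documented here, so the following is only a sketch under a hypothetical schema: a list of blocks, each with a "type" of "tc" (temporal convolution) or "attention", the two block kinds SNAIL combines in Mishra et al. (2017). It shows how such a spec could be parsed and validated before launching a run; the field names and sizes are illustrative.

```python
import json

# Hypothetical schema: the JSON format ablation.sh actually expects may differ.
# Block types follow the SNAIL paper: "tc" (temporal convolution) and "attention".
REQUIRED_KEYS = {"tc": {"filters"}, "attention": {"key_size", "value_size"}}

EXAMPLE_SPEC = """
[
  {"type": "attention", "key_size": 64, "value_size": 32},
  {"type": "tc", "filters": 128},
  {"type": "attention", "key_size": 256, "value_size": 128}
]
"""


def parse_architecture(text):
    """Parse a JSON list of blocks and check each has the keys its type needs."""
    blocks = json.loads(text)
    if not isinstance(blocks, list):
        raise ValueError("architecture must be a JSON list of blocks")
    for i, block in enumerate(blocks):
        if not isinstance(block, dict):
            raise ValueError(f"block {i}: expected a JSON object")
        kind = block.get("type")
        if kind not in REQUIRED_KEYS:
            raise ValueError(f"block {i}: unknown type {kind!r}")
        missing = REQUIRED_KEYS[kind] - set(block)
        if missing:
            raise ValueError(f"block {i}: missing keys {sorted(missing)}")
    return blocks


for block in parse_architecture(EXAMPLE_SPEC):
    print(block)
```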

Datasets

We provide a set of datasets in datasets/. The data itself is not included in the GitHub repository: it is either downloaded automatically (Tabula Muris) or must be downloaded manually from here (SwissProt). These files should be unzipped and placed under data/{dataset_name}.

The configurations for each dataset are located at conf/dataset/{dataset_name}.yaml. To add a dataset, subclass the FewShotDataset class to create a SimpleDataset (for baseline/transfer-learning methods) and a SetDataset (for the few-shot setting), then create a new config file for the dataset that points to these classes.
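In the few-shot setting, data is typically served as N-way K-shot episodes: pick N classes, then draw K support and some query examples per class. The sketch below is not the repository's SetDataset; it is a framework-free illustration of that episodic sampling, with a hypothetical function name and a toy label list.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, n_support, n_query, rng=random):
    """Sample one N-way episode as lists of dataset indices.

    Returns (classes, support, query) where support[i] and query[i] hold
    n_support and n_query distinct indices of class classes[i].
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Keep only classes with enough examples for both support and query sets.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= n_support + n_query]
    if len(eligible) < n_way:
        raise ValueError(f"only {len(eligible)} classes have enough examples")
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for c in classes:
        picked = rng.sample(by_class[c], n_support + n_query)
        support.append(picked[:n_support])
        query.append(picked[n_support:])
    return classes, support, query


labels = [i % 10 for i in range(200)]  # toy dataset: 10 classes x 20 examples
classes, support, query = sample_episode(labels, 5, 5, 15, rng=random.Random(0))
print(classes)
```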

The provided datasets are:

| Dataset | Task | Modality | Type | Source |
|---|---|---|---|---|
| Tabula Muris | Cell-type prediction | Gene expression | Classification | Cao et al. (2021) |
| SwissProt | Protein function prediction | Protein sequence | Classification | UniProt |

Methods

We provide a set of methods in methods/, including a baseline method that does typical transfer learning, and meta-learning methods like Prototypical Networks (protonet), Matching Networks (matchingnet), and Model-Agnostic Meta-Learning (MAML). To create a new method, subclass the MetaTemplate class and create a new method config file at conf/method/{method_name}.yaml that points to the new class. Here, we added a new method, the Simple Neural Attentive Meta-Learner (SNAIL) from Mishra et al. (2017). The files for this method are located in the methods/snail/ folder, and its config file is conf/method/snail.yaml.

The methods include:

| Method | Source |
|---|---|
| Baseline, Baseline++ | Chen et al. (2019) |
| ProtoNet | Snell et al. (2017) |
| MatchingNet | Vinyals et al. (2016) |
| MAML | Finn et al. (2017) |
| SNAIL | Mishra et al. (2017) |
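In Mishra et al. (2017), SNAIL interleaves temporal-convolution blocks with causally masked attention blocks. As a hedged illustration of the attention part only, the NumPy sketch below follows the paper's description: project the inputs to keys, queries and values, apply a causally masked softmax over scaled dot products, and concatenate the read-out to the input. It is not the code in methods/snail/; names and shapes are our own, and bias terms are omitted.

```python
import numpy as np


def causal_attention_block(x, w_q, w_k, w_v):
    """SNAIL-style attention block, after the description in Mishra et al. (2017).

    x: (T, d_in) input sequence
    w_q, w_k: (d_in, key_size) query/key projections
    w_v: (d_in, value_size) value projection (biases omitted for brevity)
    Returns (T, d_in + value_size): the input with the attention read-out appended.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    logits = (q @ k.T) / np.sqrt(w_k.shape[1])
    # Causal mask: position t may attend only to positions <= t.
    t = x.shape[0]
    future = np.triu(np.ones((t, t), dtype=bool), k=1)
    logits = np.where(future, -np.inf, logits)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    return np.concatenate([x, probs @ v], axis=1)


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 13))  # e.g. a 5-way 1-shot sequence plus one query
w_q, w_k = rng.normal(size=(13, 4)), rng.normal(size=(13, 4))
w_v = rng.normal(size=(13, 3))
out = causal_attention_block(x, w_q, w_k, w_v)
print(out.shape)  # (6, 16)
```

The causal mask is what lets the query, placed last in the sequence, attend to every support example while no support step can see the query.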

Models

We provide a set of backbone layers, blocks, and models in backbone.py, including a 2-layer fully connected network as well as ConvNets and ResNets. The default backbone for each dataset is set in that dataset's config file, e.g. conf/dataset/tabula_muris.yaml.
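As a rough picture of what a 2-layer fully connected backbone computes, here is a NumPy sketch whose hidden sizes come from a layer_dim list, mirroring the backbone.layer_dim option. It is a simplification under stated assumptions: the networks in backbone.py may add normalization or dropout, the function names are hypothetical, and the input size of 100 is arbitrary.

```python
import numpy as np


def init_fc_backbone(input_dim, layer_dim, rng):
    """Create (weight, bias) pairs for hidden sizes given by layer_dim."""
    dims = [input_dim] + list(layer_dim)
    return [
        # He-style initialization, suited to ReLU activations.
        (rng.normal(scale=np.sqrt(2.0 / d_in), size=(d_in, d_out)), np.zeros(d_out))
        for d_in, d_out in zip(dims[:-1], dims[1:])
    ]


def fc_forward(x, params):
    """Apply Linear -> ReLU per layer; the last activation is the embedding."""
    for w, b in params:
        x = np.maximum(x @ w + b, 0.0)
    return x


rng = np.random.default_rng(0)
params = init_fc_backbone(input_dim=100, layer_dim=[20, 20], rng=rng)
emb = fc_forward(rng.normal(size=(4, 100)), params)
print(emb.shape)  # (4, 20)
```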

Configurations

This repository uses the Hydra framework for configuration management. The top-level configurations are specified in the conf/main.yaml file. Dataset-specific values are set in files in the conf/dataset/ directory, and few-shot method-specific values are set in files in the conf/method/ directory.

Note that the files in the dataset directory are in the top-level package, so their configurations can be set directly at the command line, e.g. n_shot=5 or backbone.layer_dim=[20,20]. However, configurations in conf/method are in the method package, so they must be prefixed with method., e.g. method.stop_epoch=20.
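To make the package distinction concrete, the sketch below applies dotted key=value overrides to a nested dict. It is a simplified model of Hydra's behavior with a hypothetical function name and made-up default values, not Hydra itself; in particular, Hydra rejects overrides for keys that do not exist unless they are prefixed with "+", whereas this sketch creates them silently, which makes the effect of a missing method. prefix easy to see.

```python
import ast


def apply_overrides(config, overrides):
    """Apply key=value overrides to a nested dict (simplified Hydra behavior).

    Unlike Hydra, unknown keys are created silently here; Hydra rejects them
    unless the override is prefixed with "+".
    """
    for item in overrides:
        key, _, raw = item.partition("=")
        try:
            value = ast.literal_eval(raw)  # Python literals: numbers, lists, ...
        except (ValueError, SyntaxError):
            value = raw  # bare strings such as snail or disabled
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return config


# Illustrative defaults; the real ones live in conf/main.yaml, conf/dataset/ and conf/method/.
config = {"n_shot": 1, "backbone": {"layer_dim": [64, 64]}, "method": {"stop_epoch": 60}}
apply_overrides(config, ["n_shot=5", "backbone.layer_dim=[20,20]", "method.stop_epoch=20"])
print(config)
# {'n_shot': 5, 'backbone': {'layer_dim': [20, 20]}, 'method': {'stop_epoch': 20}}

# Without the method. prefix, the value lands at the top level instead:
print(apply_overrides({"method": {"stop_epoch": 60}}, ["stop_epoch=20"]))
# {'method': {'stop_epoch': 60}, 'stop_epoch': 20}
```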

Note also that in Hydra, configurations are inherited through the specification of defaults. For instance, conf/method/maml.yaml inherits from conf/method/meta_base.yaml, which itself inherits from conf/method/method_base.yaml. Each configuration file then only needs to specify the deltas/differences to the file it is inheriting from.
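The "only specify the deltas" idea can be modeled as a recursive dict merge applied from the most generic file to the most specific one. This is only a mental model: Hydra actually composes configs through each file's defaults list, with the _self_ entry controlling where a file's own values are merged, and the keys and values below are invented stand-ins for method_base.yaml, meta_base.yaml and maml.yaml.

```python
def deep_merge(base, delta):
    """Return a copy of base with delta's values recursively taking precedence."""
    merged = dict(base)
    for key, value in delta.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def resolve(chain):
    """Merge configs from the most generic (first) to the most specific (last)."""
    result = {}
    for cfg in chain:
        result = deep_merge(result, cfg)
    return result


# Made-up contents standing in for method_base.yaml -> meta_base.yaml -> maml.yaml.
method_base = {"stop_epoch": 60, "optimizer": {"name": "adam", "lr": 0.001}}
meta_base = {"type": "meta"}
maml = {"stop_epoch": 40, "optimizer": {"lr": 0.01}}
print(resolve([method_base, meta_base, maml]))
# {'stop_epoch': 40, 'optimizer': {'name': 'adam', 'lr': 0.01}, 'type': 'meta'}
```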

For more on Hydra, see their tutorial. For an example of a benchmark that uses Hydra for configuration management, see BenchMD.

Experiment Tracking

We use Weights and Biases (WandB) for tracking experiments and results during training. All hydra configurations, as well as training loss, validation accuracy, and post-train eval results are logged. To disable WandB, use wandb.mode=disabled.

After creating a project on WandB, you must update the project and entity fields in conf/main.yaml to match your own project and entity.

To log in to WandB, run wandb login and enter the API key provided on the website for your account.

References

Algorithm implementations based on COMET and CloserLookFewShot. Dataset preprocessing code is modified from each respective dataset paper, where applicable.

About

Project for the EPFL CS-502 course, Team 7.
