This repo contains the official implementation of the paper Enhancing Few-shot Image Classification with Cosine Transformer (IEEE Access). In this project, we developed FS-CT, a transformer-based algorithm for few-shot classification, together with a cross-attention mechanism, and showed that cosine similarity benefits the attention mechanism and improves few-shot algorithms across settings and datasets. In particular, the proposed Cosine attention produces a more stable and consistent correlation map between support and query features, which improves the performance of ViT-based few-shot algorithms.
The overall architecture of the proposed Few-shot Cosine Transformer includes two main components: (a) a learnable prototypical embedding, which computes the categorical prototype representation from random support features that may lie at the far margin of the distribution or very close to each other, and (b) a Cosine transformer, which computes the similarity matrix between the prototype representations and the query samples for the few-shot classification task. The heart of the transformer architecture is Cosine attention, an attention mechanism that uses cosine similarity and no softmax function to relate two different sets of features. The Cosine transformer shares its architecture with a standard transformer encoder block: two skip connections to preserve information, a two-layer feed-forward network, and layer normalization between them to reduce noise. The output is passed through a cosine linear layer, in which cosine similarity replaces the dot product, before being fed to softmax for query prediction.
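To make the idea of Cosine attention concrete, here is a minimal pure-Python sketch: the attention weights are cosine similarities between two sets of feature vectors, and they are applied to the values without a softmax. This is an illustration under assumptions, not the repository's implementation. The function names `cosine_similarity` and `cosine_attention` and the toy vectors are hypothetical, and the actual model operates on batched tensors with learned projections.

```python
import math


def cosine_similarity(u, v, eps=1e-8):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + eps)


def cosine_attention(queries, keys, values):
    """Attention whose weights are raw cosine similarities (no softmax).

    queries: list of n vectors; keys and values: lists of m vectors.
    Returns (weights, outputs): weights is an n x m matrix with entries
    in [-1, 1], and outputs[i] is the weights[i]-weighted sum of values.
    """
    weights = [[cosine_similarity(q, k) for k in keys] for q in queries]
    dim = len(values[0])
    outputs = [
        [sum(w * v[d] for w, v in zip(row, values)) for d in range(dim)]
        for row in weights
    ]
    return weights, outputs


# Toy example (hypothetical data): rescaling a feature vector leaves its
# cosine attention weights unchanged.
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
w_small, out_small = cosine_attention([[1.0, 1.0]], keys, values)
w_large, out_large = cosine_attention([[100.0, 100.0]], keys, values)
```

Because every weight is bounded in [-1, 1] regardless of feature magnitude, the resulting correlation map does not depend on the scale of the inputs, unlike an unnormalized dot product.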
pip install -r requirements.txt
- mini-ImageNet:
  - Go to `/dataset/miniImagenet/`
  - Download the dataset from `download_miniImagenet.txt`
  - Run `source download_miniImagenet.sh` to process the dataset
  - When complete, there are three JSON files, `base.json`, `val.json`, and `novel.json`, for experiments
- CUB-200:
  - Go to `/dataset/CUB/`
  - Process the data in the same way as the mini-ImageNet dataset
- CIFAR-FS:
  - Go to `/dataset/CIFAR_FS/`
  - Process the data in the same way as the mini-ImageNet dataset
- Omniglot:
  - Go to `/dataset/Omniglot/`
  - Run `source download_Omniglot.sh`
- Yoga:
  - This is our custom dataset with 50 yoga pose categories and 2480 images, including 50 categories for training, 13 for validation, and 12 for testing
  - Go to `/dataset/Yoga/`
  - Run `source yoga_processing.sh`
- Custom dataset:
  - Requires three data split JSON files: `base.json`, `val.json`, and `novel.json`
  - The format should follow: `{"label_names": ["class0","class1",...], "image_names": ["filepath1","filepath2",...], "image_labels": [l1,l2,l3,...]}`
  - Put these files in the same folder and change `data_dir['DATASETNAME']` in `configs.py` to the corresponding folder path
  - See the other dataset folders for examples
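For a custom dataset, the split files can be generated with a short script. The sketch below assumes a hypothetical layout in which each split directory holds one subfolder of images per class (for example `my_dataset/base/class0/img.jpg`). The helper name `build_split` and that layout are illustrative assumptions, not part of this repository. It writes a JSON file with the three keys described above.

```python
import json
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def build_split(split_dir, out_file):
    """Write a split file with label_names, image_names, image_labels.

    Assumes split_dir contains one subfolder per class (hypothetical layout).
    Labels are the indices of the sorted class folder names.
    """
    split_dir = Path(split_dir)
    label_names = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    image_names, image_labels = [], []
    for label, name in enumerate(label_names):
        for img in sorted((split_dir / name).iterdir()):
            # Skip anything that is not an image file.
            if img.suffix.lower() in IMAGE_SUFFIXES:
                image_names.append(str(img))
                image_labels.append(label)
    split = {
        "label_names": label_names,
        "image_names": image_names,
        "image_labels": image_labels,
    }
    with open(out_file, "w") as f:
        json.dump(split, f)
    return split
```

Calling `build_split` once each for the base, validation, and novel class folders yields `base.json`, `val.json`, and `novel.json`, which can then be placed in the folder that `data_dir` points to.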
- Python scripts
  - Testing only: `test.py` (does not support WandB)
  - Training and testing: `train_test.py`
- Configurations pool:
  - Backbones: `Conv4`/`Conv6`/`ResNet18`/`ResNet34`
  - Methods: `CTX_softmax`/`CTX_cosine`/`FSCT_softmax`/`FSCT_cosine`
    - `softmax` is the baseline scaled dot-product attention mechanism
    - `cosine` is our proposed Cosine attention mechanism
  - Dataset: `miniImagenet`/`CUB`/`CIFAR`/`Omniglot`/`Yoga`
- Main parameters:
  - `--backbone`: backbone model (default `ResNet34`)
  - `--FETI`: use FETI (Feature Extractor Trained partially on ImageNet) for the ResNet backbone if `1`, none if `0` (default `0`)
  - `--method`: few-shot method (default `FSCT_cosine`)
  - `--n_way`: number of categories for classification (default `5`)
  - `--k_shot`: number of shots per category in the support set (default `5`)
  - `--n_episode`: number of training/validation episodic batches per epoch
  - `--train_aug`: apply augmentation if `1`, none if `0` (default `0`)
  - `--num_epoch`: number of training epochs (default `50`)
  - `--wandb`: save training logs and plot visualizations to the WandB server if `1`, none if `0` (default `0`)
  - For other parameters, see `io_utils.py` for details.
- Example: `python train_test.py --method FSCT_cosine --dataset miniImagenet --backbone ResNet34 --FETI 1 --n_way 5 --k_shot 5 --train_aug 0 --wandb 1`
- Bash script for multiple runs: `source run_script.sh`
  - Parameters can be modified within the script for specific experiments, including dataset, backbone, method, n_way, k_shot, and augmentation
  - All methods automatically push the training loss and validation logs to the WandB server. Set `--wandb 0` to turn this off
- Result logs after testing are saved in `record/results.txt`
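When a shell script is inconvenient, a similar sweep can be driven from Python. The sketch below only assembles command lines from the flags documented above and does not claim to reproduce `run_script.sh`. The helper name `build_commands` and the chosen grid are illustrative assumptions.

```python
import itertools
import shlex


def build_commands(datasets, backbones, methods, k_shots, n_way=5,
                   train_aug=0, wandb=0):
    """Return one train_test.py argument list per configuration."""
    commands = []
    for dataset, backbone, method, k_shot in itertools.product(
            datasets, backbones, methods, k_shots):
        commands.append([
            "python", "train_test.py",
            "--method", method,
            "--dataset", dataset,
            "--backbone", backbone,
            "--n_way", str(n_way),
            "--k_shot", str(k_shot),
            "--train_aug", str(train_aug),
            "--wandb", str(wandb),
        ])
    return commands


# Hypothetical grid: 2 datasets x 1 backbone x 2 methods x 2 shot settings.
commands = build_commands(
    datasets=["miniImagenet", "CUB"],
    backbones=["ResNet34"],
    methods=["FSCT_softmax", "FSCT_cosine"],
    k_shots=[1, 5],
)
for args in commands:
    # Print only; to execute, pass `args` to subprocess.run(args, check=True).
    print(shlex.join(args))
```

Printing the commands first makes it easy to check the grid before launching long training runs.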
Our Few-shot Cosine Transformer (FS-CT) achieves the following accuracy:
| Dataset | 1-shot Accuracy | 5-shot Accuracy |
|---|---|---|
| mini-ImageNet | 55.87 ± 0.86% | 73.42 ± 0.67% |
| CIFAR-FS | 67.06 ± 0.89% | 82.89 ± 0.61% |
| CUB | 81.23 ± 0.77% | 92.25 ± 0.37% |
All results are stored in `record/official_results.txt`
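This README does not state how the ± margins are computed. A common convention in few-shot benchmarks, used in the CloserLookFewShot codebase, is to report the mean accuracy over test episodes together with a 95% confidence half-width of 1.96 standard errors. Assuming this repository follows the same convention, a number of that form can be computed as sketched below. The function name `mean_with_ci95` and the toy accuracies are hypothetical.

```python
import math
import statistics


def mean_with_ci95(episode_accuracies):
    """Mean accuracy and 95% confidence half-width over test episodes."""
    n = len(episode_accuracies)
    mean = statistics.fmean(episode_accuracies)
    # Population standard deviation, matching numpy.std's default.
    std = statistics.pstdev(episode_accuracies)
    return mean, 1.96 * std / math.sqrt(n)


# Hypothetical per-episode accuracies (in %), not results from the paper.
toy_accuracies = [70.0, 75.0, 80.0, 75.0]
mean, half_width = mean_with_ci95(toy_accuracies)
print(f"{mean:.2f} +- {half_width:.2f}%")
```

Under this convention the half-width shrinks with the square root of the number of test episodes, so margins are only comparable between runs evaluated on the same number of episodes.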
- A Closer Look at Few-shot Classification (ICLR 2019)
- CrossTransformers: spatially-aware few-shot transfer (NeurIPS 2020)
- This repository is mainly based on "A Closer Look at Few-shot Classification" official GitHub Repository: wyharveychen/CloserLookFewShot
- The CrossTransformer (CTX) implementation in this repository is modified from lucidrains/cross-transformers-pytorch
If you find our code useful, please consider citing our work using the bibtex:
@article{nguyen2023FSCT,
author={Nguyen, Quang-Huy and Nguyen, Cuong Q. and Le, Dung D. and Pham, Hieu H.},
journal={IEEE Access},
title={Enhancing Few-Shot Image Classification With Cosine Transformer},
year={2023},
volume={11},
number={},
pages={79659-79672},
doi={10.1109/ACCESS.2023.3298299}}
If you have any concerns or need support with this repository, please send me an email at [email protected]