By Zhongzhan Huang, Senwei Liang, Mingfu Liang, Wei He and Haizhao Yang.
This is the implementation of the paper "Efficient Attention Network: Accelerate Attention by Searching Where to Plug".
Efficient Attention Network (EAN) is a framework for improving the efficiency of existing attention modules in computer vision. In EAN, we leverage the sharing mechanism (Huang et al. 2020) to share one attention module within the backbone, and we search where to connect the shared attention module via reinforcement learning.
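The snippet below is a toy, framework-free sketch of the sharing idea (plain Python lists instead of PyTorch tensors). A single attention module instance is reused by every block that a binary connection scheme selects, so plugging it into more blocks adds no extra parameters. The names SharedAttention, Block and build_network, and the sigmoid gate itself, are hypothetical stand-ins, not the repository's classes.

```python
import math


class SharedAttention:
    """Toy channel gate x * sigmoid(w * mean(x)); hypothetical stand-in for SGE/DIA.

    One instance (one weight w) is shared by every block it is plugged into.
    """

    def __init__(self, w=1.0):
        self.w = w
        self.calls = 0  # counts how many blocks invoked the shared module

    def __call__(self, x):
        self.calls += 1
        m = sum(x) / len(x)
        gate = 1.0 / (1.0 + math.exp(-self.w * m))
        return [v * gate for v in x]


class Block:
    """Toy residual block; attn is None when the scheme leaves this block unplugged."""

    def __init__(self, attn=None):
        self.attn = attn

    def __call__(self, x):
        out = [2.0 * v for v in x]  # stand-in for the block's convolutions
        if self.attn is not None:
            out = self.attn(out)
        return [a + b for a, b in zip(x, out)]  # residual connection


def build_network(scheme, attn):
    # scheme[i] == 1 means: connect the shared module to block i
    return [Block(attn if bit else None) for bit in scheme]


def forward(blocks, x):
    for block in blocks:
        x = block(x)
    return x
```

For instance, build_network([1, 0, 1, 0], attn) plugs the same attn object into blocks 0 and 2 only, so after one forward pass attn.calls is 2.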
- Python 3.6 and PyTorch 1.0
Our implementation is divided into three parts. First, we pre-train a supernet. Second, we use a policy-gradient-based method to search the supernet for an optimal connection scheme. Finally, we train the network with the searched scheme from scratch.
First, we pre-train a supernet; the checkpoint is saved in NAS_ckpts. For example, to train an SGE supernet:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_imagenet/train_imagenet_ensemble_subset.py -a forward_config_share_sge_resnet50 -data /home/jovyan/ILSVRC2012_Data --checkpoint NAS_ckpts/ensemble_sge_train_on_subset
or, to train a DIA supernet:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_imagenet/train_imagenet_ensemble_subset.py -a forward_dia_fbresnet50 -data /home/jovyan/ILSVRC2012_Data --checkpoint NAS_ckpts/ensemble_dia_train_on_subset
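The supernet weights presumably have to work under many connection schemes, so that the search step can score candidate schemes with the shared weights. One common way to achieve this, and only our assumption about what train_imagenet_ensemble_subset.py does, is to sample a fresh random scheme for every mini-batch. The sketch below shows just that sampling loop; train_step, sample_scheme and the connection probability p=0.5 are hypothetical placeholders. A ResNet-50 backbone has 16 bottleneck blocks (3 + 4 + 6 + 3), hence the default n_blocks=16.

```python
import random


def sample_scheme(n_blocks, p, rng):
    """Random connection scheme: bit i == 1 plugs the shared module into block i."""
    return [1 if rng.random() < p else 0 for _ in range(n_blocks)]


def train_supernet(train_step, batches, n_blocks=16, p=0.5, seed=0):
    """Hypothetical supernet loop: each mini-batch uses a freshly sampled scheme.

    train_step(batch, scheme) is a placeholder for one forward/backward pass in
    which only the blocks with scheme[i] == 1 call the shared attention module.
    """
    rng = random.Random(seed)
    used = []
    for batch in batches:
        scheme = sample_scheme(n_blocks, p, rng)
        train_step(batch, scheme)
        used.append(scheme)
    return used
```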
Second, we search the supernet for an optimal connection scheme.
For SGE:
python search_imagenet/run_code_search_sge.py
For DIA:
python search_imagenet/run_code_search_dia.py
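The search step uses a policy-gradient method, and the repository builds on Taehoon Kim's PyTorch ENAS framework, whose controller is an LSTM. The sketch below swaps that controller for the simplest possible policy, one independent Bernoulli probability per block, and applies the REINFORCE update with a moving-average baseline. reward_fn stands in for whatever score the real search uses (for example, validation accuracy of the supernet under the sampled scheme); it and every other name here are hypothetical, not the API of search_imagenet/run_code_search_*.py.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def reinforce_search(reward_fn, n_blocks, steps=2000, lr=0.5, decay=0.9, seed=0):
    """Search a binary connection scheme with REINFORCE (hypothetical sketch).

    logits[i] parameterises p_i = sigmoid(logits[i]), the probability of
    plugging the shared attention module into block i.
    """
    rng = random.Random(seed)
    logits = [0.0] * n_blocks
    baseline = 0.0
    for _ in range(steps):
        probs = [sigmoid(z) for z in logits]
        scheme = [1 if rng.random() < p else 0 for p in probs]
        reward = reward_fn(scheme)
        # Moving-average baseline reduces the variance of the gradient estimate.
        baseline = decay * baseline + (1.0 - decay) * reward
        advantage = reward - baseline
        for i in range(n_blocks):
            # d/d logit of log Bernoulli(scheme[i]; sigmoid(logit)) = scheme[i] - p_i
            logits[i] += lr * advantage * (scheme[i] - probs[i])
    # Return the most likely scheme under the learned policy.
    return [1 if z > 0 else 0 for z in logits]
```

With a toy reward such as the fraction of bits matching a fixed target scheme, each update pushes the per-block probabilities toward the target bits in expectation; in the real search, every reward_fn call would be a supernet evaluation, which dominates the cost.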
Finally, we train the attention network from scratch with the connection scheme found in the second step. Note that to train the attention network with a different scheme, we need to edit train_imagenet/run_codes_train_from_scratch.py.
python train_imagenet/run_codes_train_from_scratch.py
The checkpoints will be saved in NAS_ckpts.
If you find this paper helpful in your research, please kindly cite it.
We would like to thank Taehoon Kim for his PyTorch version of the ENAS framework and Xiang Li for his attention network framework.
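For convenience, the three steps can be chained from Python. This is a hypothetical helper, not part of the repository: it only rebuilds the commands shown above (same scripts, flags and paths) and, by default, prints them instead of running them. CUDA_VISIBLE_DEVICES is set only for the supernet step, as in the example commands.

```python
import os
import subprocess

# Values copied from the example commands in this README; adjust to your setup.
DATA_DIR = "/home/jovyan/ILSVRC2012_Data"
SUPERNETS = {
    "sge": ("forward_config_share_sge_resnet50", "NAS_ckpts/ensemble_sge_train_on_subset"),
    "dia": ("forward_dia_fbresnet50", "NAS_ckpts/ensemble_dia_train_on_subset"),
}


def pipeline(module, data_dir=DATA_DIR):
    """Return (argv, needs_gpus) for the three steps, for module 'sge' or 'dia'."""
    arch, ckpt = SUPERNETS[module]
    return [
        (["python", "train_imagenet/train_imagenet_ensemble_subset.py",
          "-a", arch, "-data", data_dir, "--checkpoint", ckpt], True),
        (["python", "search_imagenet/run_code_search_{}.py".format(module)], False),
        # Step 3 trains whatever scheme is set inside this file; edit it first.
        (["python", "train_imagenet/run_codes_train_from_scratch.py"], False),
    ]


def run(module, steps=(1, 2, 3), gpus="0,1,2,3", dry_run=True):
    """Print (and, if dry_run is False, execute) the selected 1-based steps."""
    for i, (argv, needs_gpus) in enumerate(pipeline(module), start=1):
        if i not in steps:
            continue
        env = dict(os.environ)
        if needs_gpus:
            env["CUDA_VISIBLE_DEVICES"] = gpus  # only step 1 sets this in the README
        print(" ".join(argv))
        if not dry_run:
            subprocess.run(argv, env=env, check=True)
```

A typical session would call run("sge", steps=(1, 2), dry_run=False), edit train_imagenet/run_codes_train_from_scratch.py to use the searched scheme, and then call run("sge", steps=(3,), dry_run=False).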