Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added download link to repvit and weights extraction code #31

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README_TRAIN.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ Note: adjust the number of GPUs and the batch size to fit your experiment enviro

Run the following commands to start encoder-only KD:

Download [rep_vit weight](https://github.com/THU-MIG/RepViT/releases).

```
wget https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_300e.pth

mv repvit_m0_9_distill_300e.pth weights/repvit_m1_distill_300.pth
```


```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --master_port 29501 --nproc_per_node 8 \
training/train.py --cfg training/configs/rep_vit_m1_fuse_sa_distill.yaml \
Expand All @@ -62,6 +71,7 @@ python scripts/convert_weights.py output/rep_vit_m1_fuse_sa_distill/default/ckpt
```

## (Phase 2) Prompt-in-the-Loop Knowledge Distillation <a name="prompt"></a>
Use script/extract_weights.py to extract prompt encoder and decoder weights from sam_vit_h_4b8939.pth.

Run the following commands to start prompt-in-the-loop KD:

Expand Down
10 changes: 10 additions & 0 deletions scripts/extract_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch

sam_weight = torch.load('weights/sam_vit_h_4b8939.pth')
key_word = 'encoder'
new_weight = {}
for key in sam_weight.keys():
if key_word in key:
new_weight[key] = sam_weight[key]

torch.save(new_weight, f'weights/sam_vit_h_{key_word}.pth')