Skip to content

Commit 01d3932

Browse files
committed
release v1
1 parent 10d06b0 commit 01d3932

File tree

2 files changed

+21
-36
lines changed

2 files changed

+21
-36
lines changed

README.md

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# TW-autoencode - Tied-weights AutoEncoder
22

3-
This repository contains the source code of unrolled LRP model and baselines from the paper, **Model Guidance via Explanations Turns Image
4-
Classifiers into Segmentation Models**.
5-
6-
- TODO add link for the paper.
3+
This repository contains the source code of unrolled LRP model and baselines from the paper, [**Model Guidance via Explanations Turns Image
4+
Classifiers into Segmentation Models**](https://arxiv.org/pdf/2407.03009).
75

86

97
## Updates
108

11-
**Apr. 2024** -- Clean the code.
9+
**Apr. 2024** -- Start to clean the code.
10+
11+
**July. 2024** -- Publish the first version of code.
1212

1313
## Datasets
1414

@@ -31,58 +31,45 @@ data_parent_folder
3131
- Supervised training for UNets in 4 different cases (20, 100, 500, 1464 pixel-labeled data)
3232

3333
```
34-
qsub train.sh train.py --model std_unet --data_path /fast/AG_Kainmueller/xyu/ --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 40 --epochs 15000 --num_labels 20 --uniform_masks 1.0 --save_folder std_unet_resnet50_lab20_lr1e-5_s40
34+
python train.py --model std_unet --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 15000 --num_labels 20 --uniform_masks 1.0 --save_folder std_unet_resnet50_lab20_lr1e-5_s42
3535
```
3636

3737
```
38-
qsub train.sh train.py --model std_unet --data_path /fast/AG_Kainmueller/xyu/ --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 40 --epochs 3000 --num_labels 100 --uniform_masks 1.0 --save_folder std_unet_resnet50_lab100_lr1e-5_s40
38+
python train.py --model std_unet --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 3000 --num_labels 100 --uniform_masks 1.0 --save_folder std_unet_resnet50_lab100_lr1e-5_s42
3939
```
4040

4141
```
42-
qsub train.sh train.py --model std_unet --data_path /fast/AG_Kainmueller/xyu/ --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 40 --epochs 600 --num_labels 500 --uniform_masks 0.5 --save_folder std_unet_resnet50_lab500_lr1e-5_s40
42+
python train.py --model std_unet --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 600 --num_labels 500 --uniform_masks 0.5 --save_folder std_unet_resnet50_lab500_lr1e-5_s42
4343
```
4444

4545
```
46-
qsub train.sh train.py --model std_unet --data_path /fast/AG_Kainmueller/xyu/ --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 40 --epochs 200 --save_folder std_unet_resnet50_lab1464_lr1e-5_s40
46+
python train.py --model std_unet --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 200 --save_folder std_unet_resnet50_lab1464_lr1e-5_s42
4747
```
4848
- Semi-supervised training for multi-task UNets and Unrolled LRP models in 4 different cases (20, 100, 500, 1464 pixel-labeled data). For multi-task UNets, change `--model unrolled_lrp` to `--model mt_unet`.
4949

5050

5151
```
52-
qsub train.sh train.py --model unrolled_lrp --semisup_dataset --data_path /fast/AG_Kainmueller/xyu/ --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 40 --epochs 100 --iterative_gradients --add_classification --num_labels 20 --uniform_masks 1.0 --save_folder lrp0_resnet50_lab20_lr1e-5_s40
52+
python train.py --model unrolled_lrp --semisup_dataset --add_classification --iterative_gradients --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 100 --num_labels 20 --uniform_masks 1.0 --save_folder lrp0_resnet50_lab20_lr1e-5_s42
5353
```
5454

5555
```
56-
qsub train.sh train.py --model unrolled_lrp --semisup_dataset --data_path /fast/AG_Kainmueller/xyu/ --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 40 --epochs 100 --iterative_gradients --add_classification --num_labels 100 --uniform_masks 1.0 --save_folder lrp0_resnet50_lab100_lr1e-5_s40
56+
python train.py --model unrolled_lrp --semisup_dataset --add_classification --iterative_gradients --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 100 --num_labels 100 --uniform_masks 1.0 --save_folder lrp0_resnet50_lab100_lr1e-5_s42
5757
```
5858

5959
```
60-
qsub train.sh train.py --model unrolled_lrp --semisup_dataset --data_path /fast/AG_Kainmueller/xyu/ --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 40 --epochs 100 --iterative_gradients --add_classification --num_labels 500 --uniform_masks 0.5 --save_folder lrp0_resnet50_lab500_lr1e-5_s40
60+
python train.py --model unrolled_lrp --semisup_dataset --add_classification --iterative_gradients --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 100 --num_labels 500 --uniform_masks 0.5 --save_folder lrp0_resnet50_lab500_lr1e-5_s42
6161
```
6262

6363
```
64-
qsub train.sh train.py --model unrolled_lrp --semisup_dataset --data_path /fast/AG_Kainmueller/xyu/ --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 40 --epochs 100 --iterative_gradients --add_classification --save_folder lrp0_resnet50_lab1464_lr1e-5_s40
65-
```
66-
67-
68-
- TODO rewrite above comments which is used for running on the max cluster
69-
70-
example
71-
```
72-
python train.py --model std_unet --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 15000 --num_labels 20 --uniform_masks 1.0 --save_folder std_unet_resnet50_lab20_lr1e-5_s42
73-
```
74-
75-
76-
```
77-
python train.py --model unrolled_lrp --semisup_dataset --add_classification --iterative_gradients --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 15000 --num_labels 20 --uniform_masks 1.0 --save_folder lrp0_resnet50_lab20_lr1e-5_s42
64+
python train.py --model unrolled_lrp --semisup_dataset --add_classification --iterative_gradients --batch_size 10 --pretrain_weight_name ./snapshot/resnet50_10_pre_train_21 --encoder resnet50 --seed 42 --epochs 100 --save_folder lrp0_resnet50_lab1464_lr1e-5_s42
7865
```
7966

80-
- TODO write conda environment for running the code.
67+
Note: The way for counting epochs differs between the semisupervised and supervised datasets. Therefore, you need to set a higher value for the epochs argument when training the UNet.
8168

8269

8370
## Citation
8471

85-
- TODO
72+
Yu, Xiaoyan, et al. "Model Guidance via Explanations Turns Image Classifiers into Segmentation Models." World Conference on Explainable Artificial Intelligence. Cham: Springer Nature Switzerland, 2024.
8673

8774

8875
## Contact

models/multilabel.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, args, logger):
4646
self.train_sampler = semi_supervised_sampler(self.data.training_dataset.idx_mask, self.data.training_dataset.idx_no_mask,
4747
args.batch_size, len(self.data.training_dataset),0.5, 1.0,args.seed)
4848
self.train_loader=torch.utils.data.DataLoader(self.data.training_dataset,sampler=self.train_sampler.sample(),batch_size=args.batch_size)
49-
self.test_loader=torch.utils.data.DataLoader(self.data.testing_dataset, batch_size=40)
49+
self.test_loader=torch.utils.data.DataLoader(self.data.testing_dataset, batch_size=10)
5050

5151

5252
if args.pre_batch_size>0 and args.pre_epochs>0:
@@ -193,7 +193,7 @@ def pretraining_test(self,epoch):
193193

194194
test_loss /= len(self.pre_test_loader.dataset)
195195

196-
_,_,_,_,match_ratio,avg_class_acc,_,_,_ = gen_log_message("pre-test",epoch,test_loss, metrics,len(self.pre_test_loader.dataset),return_value=True,print_class_iou=False)
196+
_,_,match_ratio,avg_class_acc,_,_,_ = gen_log_message("pre-test",epoch,test_loss, metrics,len(self.pre_test_loader.dataset),return_value=True,print_class_iou=False)
197197

198198

199199
if self.args.wandb != 'None':
@@ -217,7 +217,7 @@ def train(self, epoch):
217217

218218
for batch_idx, (data, class_labels,sem_gts,sem_gt_exist) in enumerate(self.train_loader):
219219

220-
# measure number of samples before it goes to if function about self.args.only_send_labeled_data
220+
# measure number of samples
221221
num_samples=len(data)
222222
# 1) prepare data
223223
data, class_labels, sem_gts = data.to(self.device),class_labels.to(self.device), sem_gts.long().to(self.device)
@@ -239,10 +239,6 @@ def train(self, epoch):
239239
loss1.backward(retain_graph=True) # backpropagate heatmap loss individually
240240
loss2.backward() # backpropagate classification loss
241241
else:
242-
if self.args.only_send_labeled_data: #only design for ablation test
243-
data=data[sem_gt_exist==True]
244-
class_labels=class_labels[sem_gt_exist==True]
245-
sem_gts=sem_gts[sem_gt_exist==True]
246242
class_scores,heatmaps = self.model(data,class_labels)
247243
loss, loss1, loss2= self.loss_function(class_scores,class_labels,heatmaps,sem_gts,sem_gt_exist)
248244

@@ -285,7 +281,6 @@ def train(self, epoch):
285281
def test(self, epoch):
286282

287283
self.model.eval()
288-
# init parameters etc. for tracking
289284
test_loss = 0
290285

291286
metrics=init_metrics(self.test_metrics,self.args.num_classes)
@@ -315,6 +310,9 @@ def test(self, epoch):
315310

316311

317312
test_loss /= len(self.test_loader.dataset)
313+
314+
# Reminder: match_ratio and avg_class_acc are not good for accessing the classification perfromance
315+
# we check F1 score after saving the best checkpoint for segmentation task
318316
mIoU,iou_class,match_ratio,avg_class_acc,class_acc,pixel_acc,ap = gen_log_message("test",epoch,test_loss, metrics,len(self.test_loader.dataset),return_value=True,print_class_iou=True)
319317

320318

0 commit comments

Comments
 (0)