dataset_mask_val.py
import random
import os
import torchvision
import torch
from PIL import Image
import torchvision.transforms.functional as F
import torch.nn.functional as F_tensor
import numpy as np
from torch.utils.data import DataLoader
import time
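

# Validation-time dataset for one-shot segmentation on a fold-based split of
# PASCAL VOC (Binary_map_aug). It pre-samples a fixed list of 1000
# (query image, class, support image) triplets for the held-out fold and, per
# item, returns normalized RGB tensors, binary masks, and a per-pair
# "history" mask (2 x 41 x 41) that is read here but maintained externally by
# the evaluation loop.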
class Dataset(object):

    def __init__(self, data_dir, fold, input_size=[321, 321], normalize_mean=[0, 0, 0],
                 normalize_std=[1, 1, 1]):
        self.data_dir = data_dir
        self.input_size = input_size

        # Randomly sample 1000 query/support pairs: replicate the fold's image
        # list three times, shuffle the copies, and keep the first 1000 entries.
        self.chosen_data_list_1 = self.get_new_exist_class_dict(fold=fold)
        chosen_data_list_2 = self.chosen_data_list_1[:]
        chosen_data_list_3 = self.chosen_data_list_1[:]
        random.shuffle(chosen_data_list_2)
        random.shuffle(chosen_data_list_3)
        self.chosen_data_list = self.chosen_data_list_1 + chosen_data_list_2 + chosen_data_list_3
        self.chosen_data_list = self.chosen_data_list[:1000]

        # dict mapping each class to all images that contain that class
        self.binary_pair_list = self.get_binary_pair_list()
        self.history_mask_list = [None] * 1000
        self.query_class_support_list = [None] * 1000

        for index in range(1000):
            query_name = self.chosen_data_list[index][0]
            sample_class = self.chosen_data_list[index][1]
            # all images that contain sample_class
            support_img_list = self.binary_pair_list[sample_class]
            # randomly sample a support image that differs from the query
            while True:
                support_name = support_img_list[random.randint(0, len(support_img_list) - 1)]
                if support_name != query_name:
                    break
            self.query_class_support_list[index] = [query_name, sample_class, support_name]

        self.initialize_transformation(normalize_mean, normalize_std, input_size)

    def get_new_exist_class_dict(self, fold):
        # parse the image name and class id from each line of the fold's split file
        new_exist_class_list = []
        split_path = os.path.join(self.data_dir, 'Binary_map_aug', 'val', 'split%1d_val.txt' % fold)
        with open(split_path) as f:
            for item in f:
                img_name = item[:11]
                cat = int(item[13:15])
                new_exist_class_list.append([img_name, cat])
        return new_exist_class_list

    def initialize_transformation(self, normalize_mean, normalize_std, input_size):
        self.ToTensor = torchvision.transforms.ToTensor()
        self.normalize = torchvision.transforms.Normalize(normalize_mean, normalize_std)

    def get_binary_pair_list(self):
        # dict mapping each class id (1-20) to the names of all images that contain it
        binary_pair_list = {}
        for Class in range(1, 21):
            binary_pair_list[Class] = self.read_txt(
                os.path.join(self.data_dir, 'Binary_map_aug', 'val', '%d.txt' % Class))
        return binary_pair_list

    def read_txt(self, path):
        # the first whitespace-separated token on each line is an image name
        out_list = []
        with open(path) as f:
            for line in f:
                out_list.append(line.split()[0])
        return out_list

    def __getitem__(self, index):
        # look up the pre-sampled (query image, target class, support image) triplet
        query_name = self.query_class_support_list[index][0]
        sample_class = self.query_class_support_list[index][1]
        support_name = self.query_class_support_list[index][2]

        input_size = self.input_size[0]

        # random scale, horizontal flip, and crop for the support image
        scaled_size = int(random.uniform(1, 1.5) * input_size)
        scale_transform_mask = torchvision.transforms.Resize([scaled_size, scaled_size],
                                                             interpolation=Image.NEAREST)
        scale_transform_rgb = torchvision.transforms.Resize([scaled_size, scaled_size],
                                                            interpolation=Image.BILINEAR)
        flip_flag = random.random()

        support_rgb = self.normalize(
            self.ToTensor(
                scale_transform_rgb(
                    self.flip(flip_flag,
                              Image.open(
                                  os.path.join(self.data_dir, 'JPEGImages', support_name + '.jpg'))))))
        support_mask = self.ToTensor(
            scale_transform_mask(
                self.flip(flip_flag,
                          Image.open(
                              os.path.join(self.data_dir, 'Binary_map_aug', 'val', str(sample_class),
                                           support_name + '.png')))))

        margin_h = random.randint(0, scaled_size - input_size)
        margin_w = random.randint(0, scaled_size - input_size)
        support_rgb = support_rgb[:, margin_h:margin_h + input_size, margin_w:margin_w + input_size]
        support_mask = support_mask[:, margin_h:margin_h + input_size, margin_w:margin_w + input_size]

        # fixed 321x321 resize for the query, no flipping at validation time
        scaled_size = 321
        scale_transform_mask = torchvision.transforms.Resize([scaled_size, scaled_size],
                                                             interpolation=Image.NEAREST)
        scale_transform_rgb = torchvision.transforms.Resize([scaled_size, scaled_size],
                                                            interpolation=Image.BILINEAR)
        flip_flag = 0  # query is never flipped

        query_rgb = self.normalize(
            self.ToTensor(
                scale_transform_rgb(
                    self.flip(flip_flag,
                              Image.open(
                                  os.path.join(self.data_dir, 'JPEGImages', query_name + '.jpg'))))))
        query_mask = self.ToTensor(
            scale_transform_mask(
                self.flip(flip_flag,
                          Image.open(
                              os.path.join(self.data_dir, 'Binary_map_aug', 'val', str(sample_class),
                                           query_name + '.png')))))

        # with the default input_size of 321 these margins are 0, so the crop keeps the whole image
        margin_h = random.randint(0, scaled_size - input_size)
        margin_w = random.randint(0, scaled_size - input_size)
        query_rgb = query_rgb[:, margin_h:margin_h + input_size, margin_w:margin_w + input_size]
        query_mask = query_mask[:, margin_h:margin_h + input_size, margin_w:margin_w + input_size]

        # the history mask starts as all zeros and is expected to be filled in externally
        if self.history_mask_list[index] is None:
            history_mask = torch.zeros(2, 41, 41)
        else:
            history_mask = self.history_mask_list[index]

        return query_rgb, query_mask, support_rgb, support_mask, history_mask, sample_class, index

    def flip(self, flag, img):
        if flag > 0.5:
            return F.hflip(img)
        else:
            return img

    def __len__(self):
        return 1000
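

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original repository code). It shows
# how this Dataset can be wrapped in a DataLoader for one validation pass.
# The data_dir path and the ImageNet normalization statistics below are
# assumptions/placeholders; the real values come from the evaluation script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    val_dataset = Dataset(data_dir='./VOCdevkit/VOC2012',  # placeholder path
                          fold=0,
                          normalize_mean=[0.485, 0.456, 0.406],  # assumed ImageNet mean
                          normalize_std=[0.229, 0.224, 0.225])   # assumed ImageNet std
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

    for query_rgb, query_mask, support_rgb, support_mask, history_mask, sample_class, index in val_loader:
        # each tensor gains a leading batch dimension of 1 from the DataLoader
        print(query_rgb.shape, support_rgb.shape, history_mask.shape, int(sample_class))
        break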