import numpy as np
import cv2
from glob import glob
import os
import os.path as path
import random
import pickle

import scipy.stats as st

import torch
import torch.utils.data as Data


def process_pts(line):
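    """Parse one annotation line into (fname, pts, ang).

    ``line`` is expected to look like
    ``"<fname> x1 y1 z1 ... x68 y68 z68 a1 a2 a3"`` (commas tolerated):
    everything after the filename except the last three numbers is read as
    68 (x, y, z) landmarks, and the last three as angles (presumably head pose).
    """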
    line = line.replace(',', '')
    line = line.split(' ')
    fname = line[0]
    pts = line[1:-3]
    ang = line[-3:]
    ang = np.float32([float(i) for i in ang])
    pts = np.float32([float(i) for i in pts])
    pts = pts.reshape([-1, 3])
    return fname, pts, ang

def plot_gaussian_kernel(pos, size=25):
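    """Return a ``size`` x ``size`` Gaussian kernel whose peak lies at the
    sub-pixel position ``pos = [x, y]`` inside the kernel window,
    normalised so the maximum value is 1."""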
    sigma = (size - 1) / 6
    xx = np.linspace(-3, 3, size)
    x, y = pos[0], pos[1]
    # Shift the sampling grid so the pdf peaks at the requested sub-pixel
    # position (note the subtraction: adding the bias instead would mirror
    # the offset about the kernel centre).
    xbias = (x - (size - 1) / 2) / sigma
    x = xx - xbias
    ybias = (y - (size - 1) / 2) / sigma
    y = xx - ybias
    x = st.norm.pdf(x)
    y = st.norm.pdf(y)
    exp = np.outer(y, x)
    hmap = exp / exp.max()
    return hmap

def plot_gaussian(hmap, pos, size, ksize=25):
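    """Paste a ``ksize`` x ``ksize`` Gaussian centred on the landmark ``pos``
    into the single-channel ``size`` x ``size`` heatmap ``hmap`` (in place),
    clipping the kernel at the borders. Landmark coordinates are assumed to
    come from a 384 x 384 frame and are rescaled to the heatmap resolution."""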
    x, y = pos[0] / (384 / size), pos[1] / (384 / size)
    x1 = int(np.floor(x - ksize // 2))
    x2 = x1 + ksize
    y1 = int(np.floor(y - ksize // 2))
    y2 = y1 + ksize
    x = x - x1
    y = y - y1
    kernel = plot_gaussian_kernel([x, y], size=ksize)

    # Clip the paste window (and the matching kernel window) to the heatmap.
    kernel_x1 = kernel_y1 = 0
    kernel_x2 = kernel_y2 = ksize
    if x1 < 0:
        kernel_x1 = -x1
        x1 = 0

    if y1 < 0:
        kernel_y1 = -y1
        y1 = 0

    if y2 > size:
        kernel_y2 = ksize - (y2 - size)
        y2 = size

    if x2 > size:
        kernel_x2 = ksize - (x2 - size)
        x2 = size

    hmap[y1:y2, x1:x2] = kernel[kernel_y1:kernel_y2, kernel_x1:kernel_x2]

def get_hmap(pts, size=128):
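    """Stack one Gaussian heatmap per landmark into a (size, size, 68) array.
    Only the first two coordinates of each landmark are used."""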
    hmap = np.zeros([size, size, 68])
    for i in range(len(pts)):
        plot_gaussian(hmap[:, :, i], pts[i], size=size)
    return hmap

# Alternative dense implementation (slower, kept for reference):
# def get_hmap(pts, size=256):
#     pos = np.dstack(np.mgrid[0:size:1, 0:size:1])
#     hmap = np.zeros([size, size, 68])
#     for i, point in enumerate(pts):
#         p_resize = point / 256 * size
#         hmap[:, :, i] = st.multivariate_normal(mean=[p_resize[1], p_resize[0]], cov=16).pdf(pos)
#     return hmap

def Seg2map(seg, size=128, interpolation=cv2.INTER_NEAREST):
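    """Binarise a segmentation image (foreground where the raw pixel value
    exceeds 7.5) and resize it to ``size`` x ``size``."""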
    seg_new = np.zeros(seg.shape, dtype='float32')
    seg_new[seg > 7.5] = 1
    seg = np.copy(cv2.resize(seg_new, (size, size), interpolation=interpolation))
    return seg

def Cv2tensor(img):
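    """Convert an HWC numpy image to a CHW float32 torch tensor."""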
    img = img.transpose(2, 0, 1)
    # ascontiguousarray guards against negative strides (e.g. after
    # np.fliplr), which torch.from_numpy cannot handle.
    img = torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32))
    return img

class Reenactset(Data.Dataset):
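    """Reenactment dataset backed by a pickle mapping identity IDs to
    annotation lines. Each item is a random (source, target) frame pair from
    one identity, plus ``consistency_iter - 1`` extra heatmaps of intermediate
    frames for consistency training."""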
    def __init__(self, pkl_path='', img_path='', max_iter=80000, consistency_iter=2, image_size=128):
        super(Reenactset, self).__init__()
        self.img_path = img_path
        with open(pkl_path, 'rb') as f:
            self.data = pickle.load(f)
        self.idx = list(self.data.keys())
        self.size = max_iter
        self.image_size = image_size
        self.consistency_iter = consistency_iter
        assert self.consistency_iter > 0

    def __getitem__(self, index):
        # index is ignored; an identity and its frames are drawn at random.
        ID = random.choice(self.idx)
        samples = random.sample(self.data[ID], self.consistency_iter + 1)
        source = samples[0]
        target = samples[1]
        mid_samples = samples[2:]

        source_name, _, _ = process_pts(source)
        target_name, pts, _ = process_pts(target)

        m_pts = []
        for m_s in mid_samples:
            _, m_pt, _ = process_pts(m_s)
            m_pts.append(m_pt)

        m_hmaps = []
        hmap = Cv2tensor(get_hmap(pts, size=self.image_size))
        for m_pt in m_pts:
            m_hmaps.append(Cv2tensor(get_hmap(m_pt, size=self.image_size)))

        source_file = self.img_path + f'/img/{ID}/{source_name}'
        target_file = self.img_path + f'/img/{ID}/{target_name}'
        target_seg_file = self.img_path + f'/seg/{ID}/seg_{target_name}'

        source_img = cv2.imread(source_file)
        source_img = cv2.resize(source_img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        source_img = source_img / 255
        source_img = Cv2tensor(source_img)

        target_img = cv2.imread(target_file)
        target_img = cv2.resize(target_img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        target_img = target_img / 255
        target_img = Cv2tensor(target_img)

        target_seg = cv2.imread(target_seg_file)
        target_seg = Seg2map(target_seg, size=self.image_size)
        target_seg = Cv2tensor(target_seg)

        return source_img, hmap, target_img, target_seg, m_hmaps

    def __len__(self):
        return self.size

class Reenactset_author(Data.Dataset):
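    """Variant of ``Reenactset`` that reads per-clip ``*.txt`` annotation
    files from ``img_path`` instead of a single pickle. Clips with fewer than
    ``consistency_iter + 1`` frames are dropped at construction."""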
    def __init__(self, img_path='', seg_path='', max_iter=80000, consistency_iter=3, image_size=256):
        super(Reenactset_author, self).__init__()
        self.img_path = img_path
        self.seg_path = seg_path
        self.data_list = sorted(glob(self.img_path + '/*.txt'))
        self.size = max_iter
        self.image_size = image_size
        self.consistency_iter = consistency_iter
        assert self.consistency_iter > 0

        # Filter out clips that are too short. Removing items from a list
        # while iterating over it skips entries, so build a new list instead.
        keep = []
        for data in self.data_list:
            with open(data, 'r') as f:
                lineList = [line.rstrip('\n') for line in f]
            if len(lineList) >= self.consistency_iter + 1:
                keep.append(data)
        self.data_list = keep

    def __getitem__(self, index):
        while True:
            ID = random.choice(self.data_list)
            with open(ID, 'r') as f:
                lineList = [line.rstrip('\n') for line in f]
            samples = random.sample(lineList, self.consistency_iter + 1)
            source = samples[0]
            target = samples[1]
            mid_samples = samples[2:]

            source_name, _, _ = process_pts(source)
            target_name, pts, _ = process_pts(target)

            m_pts = []
            for m_s in mid_samples:
                _, m_pt, _ = process_pts(m_s)
                m_pts.append(m_pt)

            m_hmaps = []
            try:
                hmap = Cv2tensor(get_hmap(pts, size=self.image_size))
                for m_pt in m_pts:
                    m_hmaps.append(Cv2tensor(get_hmap(m_pt, size=self.image_size)))
            except Exception:
                # Heatmap pasting can fail for landmarks far outside the
                # frame; resample another pair instead of crashing.
                continue

            source_file = self.img_path + f'/img/{path.basename(path.split(source_name)[0])}/{path.basename(source_name)}'
            target_file = self.img_path + f'/img/{path.basename(path.split(target_name)[0])}/{path.basename(target_name)}'
            # The slice strips the image extension from the basename part
            # before appending '.png'.
            target_seg_file = self.seg_path + f'/seg/{path.basename(path.split(target_name)[0])}/seg_{path.basename(target_name)}'[:-4] + '.png'

            source_img = cv2.imread(source_file)
            source_img = cv2.resize(source_img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
            source_img = source_img / 255
            source_img = Cv2tensor(source_img)

            target_img = cv2.imread(target_file)
            target_img = cv2.resize(target_img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
            target_img = target_img / 255
            target_img = Cv2tensor(target_img)

            target_seg = cv2.imread(target_seg_file)
            target_seg = Seg2map(target_seg, size=self.image_size)
            target_seg = Cv2tensor(target_seg)
            break

        return source_img, hmap, target_img, target_seg, m_hmaps

    def __len__(self):
        return self.size

def extract_img_hmap(img_dir, line, augmentation=False, is_target=False):
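    """Load the image referenced by annotation ``line`` from ``img_dir`` and
    return ``[img]``; when ``is_target`` is set, also append its segmentation
    map and 68-channel landmark heatmap. With ``augmentation``, the sample is
    horizontally flipped with probability 0.5."""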
    line = line.split(', ')
    _, img_file = os.path.split(line[0])
    img = cv2.imread(os.path.join(img_dir, img_file))
    img = img / 255

    if augmentation:
        is_flip = np.random.choice([True, False])
        if is_flip:
            img = np.fliplr(img)

    img = Cv2tensor(img)

    output = [img]

    if is_target:
        seg = cv2.imread(os.path.join(img_dir, 'seg_' + img_file))
        seg = Seg2map(seg)
        landmark = np.array(line[1].split(), dtype=np.float32).reshape((68, 2))
        heatmap = get_hmap(landmark)

        if augmentation:
            if is_flip:
                # Note: flipping mirrors the maps but does not remap the 68
                # landmark channels, so left/right identities are swapped.
                seg = np.fliplr(seg)
                heatmap = np.fliplr(heatmap)

        seg = Cv2tensor(seg)
        heatmap = Cv2tensor(heatmap)
        output.append(seg)
        output.append(heatmap)

    return output

class Allen_KangHui(Data.Dataset):
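    """Dataset for a flat directory of frames plus one landmark file: each
    item pairs a random source frame with a random target frame, the target's
    heatmap, and its segmentation."""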
    def __init__(self, lm_path='', max_iter=80000):
        super(Allen_KangHui, self).__init__()
        self.img_path, _ = os.path.split(lm_path)
        self.img_list = sorted(glob(self.img_path + '/*.jpg'))  # currently unused
        with open(lm_path, 'r') as f:
            self.lm_list = f.read().splitlines()
        self.size = max_iter

    def __getitem__(self, index):
        source, target = random.sample(self.lm_list, 2)
        source_img = extract_img_hmap(self.img_path, source, True, False)[0]
        target_img, target_seg, hmap = extract_img_hmap(self.img_path, target, True, True)

        return source_img, hmap, target_img, target_seg

    def __len__(self):
        return self.size
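

# --- Usage sketch (illustrative only; the paths below are hypothetical) ---
# Wrapping one of the datasets in a DataLoader yields batches of
# (source image, target heatmap, target image, target segmentation, mid heatmaps).
if __name__ == '__main__':
    dataset = Reenactset(pkl_path='data/annotations.pkl',  # hypothetical pickle
                         img_path='data',                  # hypothetical root with img/ and seg/
                         max_iter=100, consistency_iter=2, image_size=128)
    loader = Data.DataLoader(dataset, batch_size=4, num_workers=0)
    src, hmap, tgt, seg, m_hmaps = next(iter(loader))
    # Expected shapes: src/tgt/seg -> [4, 3, 128, 128], hmap -> [4, 68, 128, 128],
    # m_hmaps -> list of consistency_iter - 1 tensors shaped like hmap.
    print(src.shape, hmap.shape, tgt.shape, seg.shape, len(m_hmaps))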