Skip to content

Commit a3f3609

Browse files
Add files via upload
0 parents  commit a3f3609

File tree

4 files changed

+931
-0
lines changed

4 files changed

+931
-0
lines changed

criterion.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
import math
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as Functional
7+
from torch.nn import Parameter
8+
9+
10+
class LSGanLoss(nn.Module):
11+
def __init__(self, layer=3):
12+
super(LSGanLoss, self).__init__()
13+
self.layer = layer
14+
15+
def forward(self, real, fake):
16+
loss_G = 0
17+
loss_D = 0
18+
for i in range(self.layer):
19+
loss_G = loss_G + torch.mean((fake[i] - torch.ones_like(fake[i])) ** 2)
20+
loss_D = loss_D + 0.5 * (torch.mean((fake[i] - torch.zeros_like(fake[i])) ** 2) + torch.mean((real[i] - torch.ones_like(real[i])) ** 2))
21+
return loss_G, loss_D

dataset.py

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
import numpy as np
2+
import cv2
3+
from glob import glob
4+
import os
5+
import os.path as path
6+
import random
7+
import pickle
8+
9+
import scipy.stats as st
10+
11+
import torch
12+
import torch.utils.data as Data
13+
14+
15+
def process_pts(line):
16+
line = line.replace(',', '')
17+
line = line.split(' ')
18+
fname = line[0]
19+
pts = line[1:-3]
20+
ang = line[-3:]
21+
ang = [float(i) for i in ang]
22+
ang = np.float32(ang)
23+
pts = [float(i) for i in pts]
24+
pts = np.float32(pts)
25+
pts = pts.reshape([-1,3])
26+
return fname, pts, ang
27+
28+
def plot_gaussian_kernel(pos, size=25):
29+
sigma = (size-1) / 6
30+
xx = np.linspace(-3,3,size)
31+
x, y = pos[0], pos[1]
32+
xbias = (x - (size-1)/2) / sigma
33+
x = xx + xbias
34+
ybias = (y - (size-1)/2) / sigma
35+
y = xx + ybias
36+
x = st.norm.pdf(x)
37+
y = st.norm.pdf(y)
38+
exp = np.outer(y,x)
39+
hmap = exp / exp.max()
40+
return hmap
41+
42+
def plot_gaussian(hmap, pos, size, ksize=25):
43+
x, y = pos[0]/(384/size), pos[1]/(384/size)
44+
x1 = int(np.floor(x - ksize//2))
45+
x2 = x1 + ksize
46+
y1 = int(np.floor(y - ksize//2))
47+
y2 = y1 + ksize
48+
x = x - x1
49+
y = y - y1
50+
kernel = plot_gaussian_kernel([x,y], size=ksize)
51+
52+
kernel_x1 = kernel_y1 = 0
53+
kernel_x2 = kernel_y2 = ksize
54+
if x1<0:
55+
kernel_x1 = -x1
56+
x1 = 0
57+
58+
if y1<0:
59+
kernel_y1 = -y1
60+
y1 = 0
61+
62+
if y2>size:
63+
kernel_y2 = ksize - (y2 - size)
64+
y2 = size
65+
66+
if x2 > size:
67+
kernel_x2 = ksize - (x2 - size)
68+
x2 = size
69+
70+
# try:
71+
hmap[y1:y2, x1:x2] = kernel[kernel_y1:kernel_y2, kernel_x1:kernel_x2]
72+
# except Exception as e:
73+
# print(e)
74+
# print(y1,y2,x1,x2, kernel_y1,kernel_y2, kernel_x1, kernel_x2)
75+
76+
def get_hmap(pts, size=128):
77+
hmap = np.zeros([size, size, 68])
78+
for i in range(len(pts)):
79+
plot_gaussian(hmap[:,:,i], pts[i], size=size)
80+
return hmap
81+
82+
# def get_hmap(pts, size=256):
83+
# pos = np.dstack(np.mgrid[0:size:1, 0:size:1])
84+
# hmap = np.zeros([size, size, 68])
85+
# for i, point in enumerate(pts):
86+
# p_resize = point / 256 * size
87+
# hmap[:, :, i] = st.multivariate_normal(mean=[p_resize[1], p_resize[0]], cov=16).pdf(pos)
88+
# return hmap
89+
90+
def Seg2map(seg, size=128, interpolation=cv2.INTER_NEAREST):
91+
seg_new = np.zeros(seg.shape, dtype='float32')
92+
seg_new[seg > 7.5] = 1
93+
seg = np.copy(cv2.resize(seg_new, (size, size), interpolation=interpolation))
94+
return seg
95+
96+
def Cv2tensor(img):
97+
img = img.transpose(2, 0, 1)
98+
img = torch.from_numpy(img.astype(np.float32))
99+
return img
100+
101+
class Reenactset(Data.Dataset):
102+
def __init__(self, pkl_path='', img_path='', max_iter=80000, consistency_iter=2, image_size=128):
103+
super(Reenactset, self).__init__()
104+
self.img_path = img_path
105+
self.data = pickle.load(open(pkl_path, 'rb'))
106+
self.idx = list(self.data.keys())
107+
self.size = max_iter
108+
self.image_size = image_size
109+
self.consistency_iter = consistency_iter
110+
assert self.consistency_iter > 0
111+
112+
def __getitem__(self, index):
113+
ID = random.choice(self.idx)
114+
samples = random.sample(self.data[ID], self.consistency_iter+1)
115+
source = samples[0]
116+
target = samples[1]
117+
mid_samples = samples[2:]
118+
119+
source_name, _, _ = process_pts(source)
120+
target_name, pts, _ = process_pts(target)
121+
122+
m_pts = []
123+
for m_s in mid_samples:
124+
_, m_pt, _ = process_pts(m_s)
125+
m_pts.append(m_pt)
126+
127+
# pts = torch.from_numpy(pts[:, 0:2].astype(np.float32))
128+
# pts = torch.unsqueeze(pts, 0)
129+
m_hmaps = []
130+
hmap = Cv2tensor(get_hmap(pts, size=self.image_size))
131+
for m_pt in m_pts:
132+
m_hmaps.append(Cv2tensor(get_hmap(m_pt, size=self.image_size)))
133+
134+
source_file = self.img_path + f'/img/{ID}/{source_name}'
135+
target_file = self.img_path + f'/img/{ID}/{target_name}'
136+
target_seg_file = self.img_path + f'/seg/{ID}/seg_{target_name}'
137+
138+
source_img = cv2.imread(source_file)
139+
source_img = cv2.resize(source_img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
140+
source_img = source_img / 255
141+
source_img = Cv2tensor(source_img)
142+
143+
target_img = cv2.imread(target_file)
144+
target_img = cv2.resize(target_img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
145+
target_img = target_img / 255
146+
target_img = Cv2tensor(target_img)
147+
148+
# source_seg = cv2.imread(self.img_path + f'/seg/{ID}/seg_{source_name}')
149+
# source_seg = Seg2map(source_seg, size=self.image_size)
150+
# source_seg = Cv2tensor(source_seg)
151+
152+
target_seg = cv2.imread(target_seg_file)
153+
target_seg = Seg2map(target_seg, size=self.image_size)
154+
target_seg = Cv2tensor(target_seg)
155+
156+
return source_img, hmap, target_img, target_seg, m_hmaps
157+
158+
def __len__(self):
159+
return self.size
160+
161+
class Reenactset_author(Data.Dataset):
162+
def __init__(self, img_path='', seg_path='', max_iter=80000, consistency_iter=3, image_size=256):
163+
super(Reenactset_author, self).__init__()
164+
self.img_path = img_path
165+
self.seg_path = seg_path
166+
self.data_list = sorted(glob(self.img_path + '/*.txt'))
167+
self.size = max_iter
168+
self.image_size = image_size
169+
self.consistency_iter = consistency_iter
170+
assert self.consistency_iter > 0
171+
172+
for data in self.data_list:
173+
lineList = [line.rstrip('\n') for line in open(data, 'r')]
174+
if len(lineList)<(self.consistency_iter+1):
175+
self.data_list.remove(data)
176+
177+
def __getitem__(self, index):
178+
while True:
179+
ID = random.choice(self.data_list)
180+
lineList = [line.rstrip('\n') for line in open(ID, 'r')]
181+
samples = random.sample(lineList, self.consistency_iter+1)
182+
source = samples[0]
183+
target = samples[1]
184+
mid_samples = samples[2:]
185+
186+
source_name, _, _ = process_pts(source)
187+
target_name, pts, _ = process_pts(target)
188+
189+
m_pts = []
190+
for m_s in mid_samples:
191+
_, m_pt, _ = process_pts(m_s)
192+
m_pts.append(m_pt)
193+
194+
# pts = torch.from_numpy(pts[:, 0:2].astype(np.float32))
195+
# pts = torch.unsqueeze(pts, 0)
196+
m_hmaps = []
197+
try:
198+
hmap = Cv2tensor(get_hmap(pts, size=self.image_size))
199+
for m_pt in m_pts:
200+
m_hmaps.append(Cv2tensor(get_hmap(m_pt, size=self.image_size)))
201+
except:
202+
continue
203+
204+
source_file = self.img_path + f'/img/{path.basename(path.split(source_name)[0])}/{path.basename(source_name)}'
205+
target_file = self.img_path + f'/img/{path.basename(path.split(target_name)[0])}/{path.basename(target_name)}'
206+
target_seg_file = self.seg_path + f'/seg/{path.basename(path.split(target_name)[0])}/seg_{path.basename(target_name)}'[:-4] + '.png'
207+
208+
source_img = cv2.imread(source_file)
209+
source_img = cv2.resize(source_img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
210+
source_img = source_img / 255
211+
source_img = Cv2tensor(source_img)
212+
213+
target_img = cv2.imread(target_file)
214+
target_img = cv2.resize(target_img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
215+
target_img = target_img / 255
216+
target_img = Cv2tensor(target_img)
217+
218+
# source_seg = cv2.imread(self.img_path + f'/seg/{ID}/seg_{source_name}')
219+
# source_seg = Seg2map(source_seg, size=self.image_size)
220+
# source_seg = Cv2tensor(source_seg)
221+
222+
target_seg = cv2.imread(target_seg_file)
223+
target_seg = Seg2map(target_seg, size=self.image_size)
224+
target_seg = Cv2tensor(target_seg)
225+
break
226+
227+
return source_img, hmap, target_img, target_seg, m_hmaps
228+
229+
def __len__(self):
230+
return self.size
231+
232+
def extract_img_hmap(path, line, augmentation=False, is_target=False):
233+
line = line.split(', ')
234+
_, img_file = os.path.split(line[0])
235+
img = cv2.imread(os.path.join(path, img_file))
236+
img = img / 255
237+
238+
if augmentation:
239+
is_flip = np.random.choice([True, False])
240+
if is_flip:
241+
img = np.fliplr(img)
242+
243+
img = Cv2tensor(img)
244+
245+
output = [img]
246+
247+
if is_target:
248+
seg = cv2.imread(os.path.join(path, 'seg_'+img_file))
249+
seg = Seg2map(seg)
250+
landmark = np.fromstring(line[1], dtype=np.float32, sep=' ').reshape((68, 2))
251+
heatmap = get_hmap(landmark)
252+
253+
if augmentation:
254+
if is_flip:
255+
seg = np.fliplr(seg)
256+
heatmap = np.fliplr(heatmap)
257+
258+
seg = Cv2tensor(seg)
259+
heatmap = Cv2tensor(heatmap)
260+
output.append(seg)
261+
output.append(heatmap)
262+
263+
return output
264+
265+
class Allen_KangHui(Data.Dataset):
266+
def __init__(self, lm_path='', max_iter=80000):
267+
super(Allen_KangHui, self).__init__()
268+
self.img_path, _ = os.path.split(lm_path)
269+
self.img_list = sorted(glob(self.img_path+'/*.jpg'))
270+
self.lm_list = open(lm_path, 'r').read().splitlines()
271+
self.size = max_iter
272+
273+
def __getitem__(self, index):
274+
while True:
275+
source, target = random.sample(self.lm_list, 2)
276+
source_img = extract_img_hmap(self.img_path, source, True, False)[0]
277+
target_img, target_seg, hmap = extract_img_hmap(self.img_path, target, True, True)
278+
279+
break
280+
281+
return source_img, hmap, target_img, target_seg
282+
283+
def __len__(self):
284+
return self.size

0 commit comments

Comments
 (0)