Skip to content

Commit 03cd281

Browse files
committed
run_HR
1 parent da0903a commit 03cd281

File tree

8 files changed

+718
-31
lines changed

8 files changed

+718
-31
lines changed

Face_Detection/align_warp_back_multiple_dlib_HR.py

Lines changed: 437 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import torch
5+
import numpy as np
6+
import skimage.io as io
7+
8+
# from FaceSDK.face_sdk import FaceDetection
9+
# from face_sdk import FaceDetection
10+
import matplotlib.pyplot as plt
11+
from matplotlib.patches import Rectangle
12+
from skimage.transform import SimilarityTransform
13+
from skimage.transform import warp
14+
from PIL import Image
15+
import torch.nn.functional as F
16+
import torchvision as tv
17+
import torchvision.utils as vutils
18+
import time
19+
import cv2
20+
import os
21+
from skimage import img_as_ubyte
22+
import json
23+
import argparse
24+
import dlib
25+
26+
27+
def _standard_face_pts():
28+
pts = (
29+
np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0
30+
- 1.0
31+
)
32+
33+
return np.reshape(pts, (5, 2))
34+
35+
36+
def _origin_face_pts():
37+
pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32)
38+
39+
return np.reshape(pts, (5, 2))
40+
41+
42+
def get_landmark(face_landmarks, id):
43+
part = face_landmarks.part(id)
44+
x = part.x
45+
y = part.y
46+
47+
return (x, y)
48+
49+
50+
def search(face_landmarks):
51+
52+
x1, y1 = get_landmark(face_landmarks, 36)
53+
x2, y2 = get_landmark(face_landmarks, 39)
54+
x3, y3 = get_landmark(face_landmarks, 42)
55+
x4, y4 = get_landmark(face_landmarks, 45)
56+
57+
x_nose, y_nose = get_landmark(face_landmarks, 30)
58+
59+
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
60+
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
61+
62+
x_left_eye = int((x1 + x2) / 2)
63+
y_left_eye = int((y1 + y2) / 2)
64+
x_right_eye = int((x3 + x4) / 2)
65+
y_right_eye = int((y3 + y4) / 2)
66+
67+
results = np.array(
68+
[
69+
[x_left_eye, y_left_eye],
70+
[x_right_eye, y_right_eye],
71+
[x_nose, y_nose],
72+
[x_left_mouth, y_left_mouth],
73+
[x_right_mouth, y_right_mouth],
74+
]
75+
)
76+
77+
return results
78+
79+
80+
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
81+
82+
std_pts = _standard_face_pts() # [-1,1]
83+
target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
84+
85+
# print(target_pts)
86+
87+
h, w, c = img.shape
88+
if normalize == True:
89+
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
90+
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
91+
92+
# print(landmark)
93+
94+
affine = SimilarityTransform()
95+
96+
affine.estimate(target_pts, landmark)
97+
98+
return affine.params
99+
100+
101+
def show_detection(image, box, landmark):
102+
plt.imshow(image)
103+
print(box[2] - box[0])
104+
plt.gca().add_patch(
105+
Rectangle(
106+
(box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none"
107+
)
108+
)
109+
plt.scatter(landmark[0][0], landmark[0][1])
110+
plt.scatter(landmark[1][0], landmark[1][1])
111+
plt.scatter(landmark[2][0], landmark[2][1])
112+
plt.scatter(landmark[3][0], landmark[3][1])
113+
plt.scatter(landmark[4][0], landmark[4][1])
114+
plt.show()
115+
116+
117+
def affine2theta(affine, input_w, input_h, target_w, target_h):
118+
# param = np.linalg.inv(affine)
119+
param = affine
120+
theta = np.zeros([2, 3])
121+
theta[0, 0] = param[0, 0] * input_h / target_h
122+
theta[0, 1] = param[0, 1] * input_w / target_h
123+
theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1
124+
theta[1, 0] = param[1, 0] * input_h / target_w
125+
theta[1, 1] = param[1, 1] * input_w / target_w
126+
theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1
127+
return theta
128+
129+
130+
if __name__ == "__main__":
131+
132+
parser = argparse.ArgumentParser()
133+
parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input")
134+
parser.add_argument(
135+
"--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output"
136+
)
137+
opts = parser.parse_args()
138+
139+
url = opts.url
140+
save_url = opts.save_url
141+
142+
### If the origin url is None, then we don't need to reid the origin image
143+
144+
os.makedirs(url, exist_ok=True)
145+
os.makedirs(save_url, exist_ok=True)
146+
147+
face_detector = dlib.get_frontal_face_detector()
148+
landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
149+
150+
count = 0
151+
152+
map_id = {}
153+
for x in os.listdir(url):
154+
img_url = os.path.join(url, x)
155+
pil_img = Image.open(img_url).convert("RGB")
156+
157+
image = np.array(pil_img)
158+
159+
start = time.time()
160+
faces = face_detector(image)
161+
done = time.time()
162+
163+
if len(faces) == 0:
164+
print("Warning: There is no face in %s" % (x))
165+
continue
166+
167+
print(len(faces))
168+
169+
if len(faces) > 0:
170+
for face_id in range(len(faces)):
171+
current_face = faces[face_id]
172+
face_landmarks = landmark_locator(image, current_face)
173+
current_fl = search(face_landmarks)
174+
175+
affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3)
176+
aligned_face = warp(image, affine, output_shape=(512, 512, 3))
177+
img_name = x[:-4] + "_" + str(face_id + 1)
178+
io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face))
179+
180+
count += 1
181+
182+
if count % 1000 == 0:
183+
print("%d have finished ..." % (count))
184+

Face_Enhancement/models/networks/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def compute_latent_vector_size(self, opt):
9797
else:
9898
raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)
9999

100-
sw = opt.crop_size // (2 ** num_up_layers)
100+
sw = opt.load_size // (2 ** num_up_layers)
101101
sh = round(sw / opt.aspect_ratio)
102102

103103
return sw, sh

Global/options/test_options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,4 @@ def initialize(self):
9797
self.parser.add_argument(
9898
"--Scratch_and_Quality_restore", action="store_true", help="For scratched images"
9999
)
100-
100+
self.parser.add_argument("--HR", action='store_true',help='Large input size with scratches')

Global/test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ def parameter_set(opt):
8686
opt.name = "mapping_scratch"
8787
opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
8888
opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch")
89+
if opt.HR:
90+
opt.mapping_exp = 1
91+
opt.inference_optimize = True
92+
opt.mask_dilation = 3
93+
opt.name = "mapping_Patch_Attention"
8994

9095

9196
if __name__ == "__main__":
@@ -135,6 +140,11 @@ def parameter_set(opt):
135140
if opt.NL_use_mask:
136141
mask_name = mask_loader[i]
137142
mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert("RGB")
143+
if opt.mask_dilation!=0:
144+
kernel=np.ones((3,3),np.uint8)
145+
mask=np.array(mask)
146+
mask=cv2.dilate(mask,kernel,iterations=opt.mask_dilation)
147+
mask=Image.fromarray(mask.astype('uint8'))
138148
origin = input
139149
input = irregular_hole_synthesize(input, mask)
140150
mask = mask_transform(mask)

README.md

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ The code originates from our research project and the aim is to demonstrate the
2323
**We are improving the algorithm so as to process high resolution photos. It takes time and please stay tuned.**
2424

2525
## News
26+
The framework now supports the restoration of high-resolution input.
27+
28+
<img src='imgs/HR.png'>
29+
2630
Training code is available and welcome to have a try and learn the training details.
2731

2832
You can now play with our [Colab](https://colab.research.google.com/drive/1NEm6AsybIiC5TwTU_4DqDkQO0nFRB-uA?usp=sharing) and try it on your photos.
@@ -101,6 +105,16 @@ python run.py --input_folder [test_image_folder_path] \
101105
--with_scratch
102106
```
103107

108+
For high-resolution images with scratches:
109+
110+
```
111+
python run.py --input_folder [test_image_folder_path] \
112+
--output_folder [output_path] \
113+
--GPU 0 \
114+
--with_scratch \
115+
--HR
116+
```
117+
104118
Note: Please try to use the absolute path. The final results will be saved in `./output_path/final_output/`. You could also check the produced results of different steps in `output_path`.
105119

106120
### 2) Scratch Detection
@@ -132,8 +146,8 @@ python test.py --Scratch_and_Quality_restore \
132146
--outputs_dir [output_path]
133147
134148
python test.py --Quality_restore \
135-
--test_input [test_image_folder_path] \
136-
--outputs_dir [output_path]
149+
--test_input [test_image_folder_path] \
150+
--outputs_dir [output_path]
137151
```
138152

139153
<img src='imgs/global.png'>
@@ -203,14 +217,17 @@ Traing the mapping with scraches:
203217
python train_mapping.py --no_TTUR --NL_res --random_hole --use_SN --correlation_renormalize --training_dataset mapping --NL_use_mask --NL_fusion_method combine --non_local Setting_42 --use_v2_degradation --use_vae_which_epoch 200 --continue_train --name mapping_scratch --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 36 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] --irregular_mask [absolute_path_of_mask_file]
204218
```
205219

206-
220+
Traing the mapping with scraches (Multi-Scale Patch Attention for HR input):
221+
```
222+
python train_mapping.py --no_TTUR --NL_res --random_hole --use_SN --correlation_renormalize --training_dataset mapping --NL_use_mask --NL_fusion_method combine --non_local Setting_42 --use_v2_degradation --use_vae_which_epoch 200 --continue_train --name mapping_Pathc_Attention --label_nc 0 --loadSize 256 --fineSize 256 --dataroot [your_data_folder] --no_instance --resize_or_crop crop_only --batchSize 36 --no_html --gpu_ids 0,1,2,3 --nThreads 8 --load_pretrainA [ckpt_of_domainA_SR_old_photos] --load_pretrainB [ckpt_of_domainB_old_photos] --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --niter 150 --niter_decay 100 --outputs_dir [your_output_folder] --checkpoints_dir [your_ckpt_folder] --irregular_mask [absolute_path_of_mask_file] --mapping_exp 1
223+
```
207224

208225
## To Do
209226
- [x] Clean testing code
210227
- [x] Release pretrained model
211228
- [x] Collab demo
212-
- [ ] Replace face detection module (dlib) with RetinaFace
213229
- [x] Release training code
230+
- [x] Processing of high-resolution input
214231

215232

216233
## Citation

imgs/HR.png

3.28 MB
Loading

0 commit comments

Comments
 (0)