-
Notifications
You must be signed in to change notification settings - Fork 0
/
1
77 lines (63 loc) · 2.81 KB
/
1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import os
import glob
import time
import torch
from torch.nn.parallel import DistributedDataParallel
from gfpgan.inference import Inferencer
from gfpgan.utils.img_util import img2tensor, tensor2img, save_img
def infer_folder(inferencer, folder, output_suffix="_gfpgan"):
    """Run ``inferencer.predict`` on every JPG/PNG in *folder* and save the results.

    Outputs are written to a sibling directory named ``folder + output_suffix``
    (created if it does not exist), each under the same file name as its input.
    The scan is not recursive.

    Args:
        inferencer: Object exposing ``predict(tensor)``; its result is passed
            to ``tensor2img``.
        folder: Directory containing the input images. A trailing path
            separator is tolerated.
        output_suffix: Text appended to the folder name to build the output
            directory name.
    """
    # Escape the folder so characters such as "[" in its name are matched
    # literally instead of being interpreted as glob wildcards.
    search_root = glob.escape(folder)
    # Character classes make the extension match case-insensitive on every
    # platform; sorting gives a reproducible processing order.
    image_paths = sorted(
        glob.glob(os.path.join(search_root, "*.[jJ][pP][gG]"))
        + glob.glob(os.path.join(search_root, "*.[pP][nN][gG]"))
    )

    # Drop trailing separators before appending the suffix; otherwise
    # "photos/" + "_gfpgan" would create "photos/_gfpgan" *inside* the input
    # folder rather than the sibling "photos_gfpgan".
    separators = os.sep + (os.altsep or "")
    output_folder = (folder.rstrip(separators) or folder) + output_suffix
    os.makedirs(output_folder, exist_ok=True)

    for img_path in image_paths:
        # Load input image
        input_img = img2tensor(img_path, bgr_order=False, expand=False)
        # NOTE(review): the single-file path in main() moves the inferencer and
        # tensor to CUDA before predicting; this path does not -- confirm
        # which device handling the inferencer expects.
        output_tensor = inferencer.predict(input_img)
        # Convert output tensor to image and save under the original file name
        output_img = tensor2img(output_tensor, bgr_order=False, min_max=(-1, 1))
        output_path = os.path.join(output_folder, os.path.basename(img_path))
        save_img(output_img, output_path)
        print(f"Saved {output_path}")
def _str2bool(value):
    """Convert a command-line boolean spelling (``true``/``false``/``1``/``0``...) to ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- parses as ``True``. This
    converter maps the common spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _suffixed_path(path, suffix):
    """Return *path* with *suffix* inserted before its file extension."""
    root, ext = os.path.splitext(path)
    return f"{root}{suffix}{ext}"


def main(argv=None):
    """Command-line entry point: process a single image or a folder of images.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, as before.

    Raises:
        SystemExit: If the arguments cannot be parsed, or if ``input`` is
            neither an existing file nor an existing directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="options/test/test_gfpgan_x2.yml")
    parser.add_argument("--checkpoint_path", type=str, default="experiments/pretrained_models/gfpgan/GFPGANv1.pth")
    parser.add_argument("--upscale", type=int, default=2)
    parser.add_argument("--tile", type=int, default=0)
    parser.add_argument("--tile_overlap", type=int, default=0)
    parser.add_argument("--alpha_upsampler", type=str, default="bicubic")
    # type=bool would treat "--face_enhance False" as True; parse it explicitly.
    parser.add_argument("--face_enhance", type=_str2bool, default=False)
    parser.add_argument("--arch_g", type=str, default="stylegan2")
    parser.add_argument("--output_suffix", type=str, default="_gfpgan")
    parser.add_argument("input", type=str, help="input image or folder")
    args = parser.parse_args(argv)

    # Validate the input before building the inferencer so a bad path is
    # reported immediately and with a non-zero exit status.
    is_file = os.path.isfile(args.input)
    if not is_file and not os.path.isdir(args.input):
        raise SystemExit(f"Error: {args.input} is not a valid file or folder")

    # Initialize inferencer (arguments are forwarded positionally, in this order)
    inferencer = Inferencer(
        args.config,
        args.checkpoint_path,
        args.upscale,
        args.tile,
        args.tile_overlap,
        args.alpha_upsampler,
        args.face_enhance,
        args.arch_g,
    )

    if is_file:
        # Input is a file, infer on single image
        # NOTE(review): only this branch moves the inferencer and tensor to
        # CUDA; infer_folder() does not -- confirm the intended device handling.
        inferencer.cuda()
        input_tensor = img2tensor(args.input, bgr_order=False, expand=False).cuda()
        output_tensor = inferencer.predict(input_tensor)
        output_img = tensor2img(output_tensor, bgr_order=False, min_max=(-1, 1))
        # Insert the suffix before the extension ("a.jpg" -> "a_gfpgan.jpg") so
        # the output keeps a recognisable image extension.
        output_path = _suffixed_path(args.input, args.output_suffix)
        save_img(output_img, output_path)
        print(f"Saved {output_path}")
    else:
        # Input is a folder, infer on all images in the folder
        infer_folder(inferencer, args.input, args.output_suffix)
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()