-
Notifications
You must be signed in to change notification settings - Fork 43
/
utils.py
104 lines (86 loc) · 3.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from pathlib import Path
import cv2
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
def resize_like(x, target, mode="bilinear"):
    """Resize ``x`` so its last two (spatial) dimensions match ``target``'s.

    Args:
        x: Input tensor for ``F.interpolate``. The output size has two
            entries, so ``F.interpolate`` requires a 4-D ``(N, C, H, W)`` input.
        target: Tensor whose trailing two dimensions give the output size.
        mode: Interpolation mode forwarded to ``F.interpolate``.

    Returns:
        ``x`` interpolated to ``target.shape[-2:]``.
    """
    # F.interpolate only accepts align_corners for the interpolating modes;
    # passing it (even as False) with "nearest"/"area"/"nearest-exact" raises.
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = False
    else:
        align_corners = None
    return F.interpolate(
        x, size=target.shape[-2:], mode=mode, align_corners=align_corners
    )
def list2nparray(lst, dtype=None):
    """Fast conversion from a nested list to an ndarray by pre-allocating space.

    Args:
        lst: An ndarray (returned as-is), or a non-empty list/tuple whose
            elements are one of:
            - ndarrays that all share one shape;
            - scalars (int/float/complex/numpy number);
            - nested lists/tuples that all convert to the same shape.
        dtype: Result dtype. If None it is taken from the first element; for
            ndarray elements every element must then share that dtype.

    Returns:
        np.ndarray of shape ``(len(lst),) + element_shape``.

    Raises:
        AssertionError: ``lst`` is not a list/tuple/ndarray, is empty, or its
            ndarray elements disagree in shape or (when dtype is None) dtype.
        ValueError: A nested (non-ndarray) element does not convert to the
            same shape as the first element.
    """
    if isinstance(lst, np.ndarray):
        return lst
    assert isinstance(lst, (list, tuple)), "bad type: {}".format(type(lst))
    assert lst, "attempt to convert empty list to np array"
    first = lst[0]
    if isinstance(first, np.ndarray):
        dim1 = first.shape
        assert all(i.shape == dim1 for i in lst), "bad shape: {} {}".format(
            dim1, set(i.shape for i in lst)
        )
        if dtype is None:
            dtype = first.dtype
            assert all(i.dtype == dtype for i in lst), "bad dtype: {} {}".format(
                dtype, set(i.dtype for i in lst)
            )
        needs_check = False  # shapes were already validated above
    elif isinstance(first, (int, float, complex, np.number)):
        return np.array(lst, dtype=dtype)
    else:
        sub = list2nparray(first)
        if dtype is None:
            dtype = sub.dtype
        dim1 = sub.shape
        needs_check = True
    rst = np.empty([len(lst)] + list(dim1), dtype=dtype)
    for idx, item in enumerate(lst):
        if needs_check:
            # Convert first so the shape can be checked: assigning a shorter
            # sequence (or a scalar) straight into rst[idx] would be broadcast
            # silently, e.g. [[1, 2], [3]] -> [[1, 2], [3, 3]].
            item = np.asarray(item, dtype=dtype)
            if item.shape != dim1:
                raise ValueError(
                    "inconsistent element shape at index {}: {} != {}".format(
                        idx, item.shape, dim1
                    )
                )
        rst[idx] = item
    return rst
def get_img_list(path):
    """Collect image paths from ``path``.

    If ``path`` is an existing file, it is returned as a one-element list.
    Otherwise the directory's direct children matching ``*.png``, ``*.jpg``
    and ``*.jpeg`` are returned. Each pattern's matches are sorted on their
    own, and the groups are concatenated in that order.
    """
    root = Path(path)
    if root.is_file():
        return [root]
    collected = []
    for pattern in ("*.png", "*.jpg", "*.jpeg"):
        collected.extend(sorted(root.glob(pattern)))
    return collected
def gen_miss(img, mask, output):
    """Write masked ("missing region") copies of images into ``output``.

    Images and masks are gathered with ``get_img_list`` and paired in order.
    Each mask is read as grayscale and resized to its image. Pixels where the
    mask is > 127 keep the image value, and all other pixels become white
    (255). Results are saved as ``miss_0001.png``, ``miss_0002.png``, ...

    Args:
        img: An image file or a directory of images.
        mask: A mask file or a directory of masks.
        output: Output directory; created (with parents) if missing.

    Raises:
        OSError: If an image or mask cannot be read, or a result cannot be
            written, by OpenCV.
    """
    imgs = get_img_list(img)
    masks = get_img_list(mask)
    print("Total images:", len(imgs), len(masks))
    n_pairs = min(len(imgs), len(masks))
    if len(imgs) != len(masks):
        # zip() below stops at the shorter list; make the truncation visible.
        print("Warning: image/mask counts differ; processing", n_pairs, "pairs.")
    out = Path(output)
    out.mkdir(parents=True, exist_ok=True)
    pairs = enumerate(zip(imgs, masks))
    for i, (img_path, mask_path) in tqdm(pairs, total=n_pairs):
        image = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if image is None:
            raise OSError("Failed to read image: {}".format(img_path))
        mask_img = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            raise OSError("Failed to read mask: {}".format(mask_path))
        # cv2.resize takes (width, height), hence the reversed (h, w).
        mask_img = cv2.resize(mask_img, image.shape[:2][::-1])
        keep = mask_img[..., np.newaxis] > 127
        # np.where keeps the uint8 dtype; arithmetic with a bool mask and a
        # Python int would promote the result to a wider integer type.
        miss = np.where(keep, image, np.uint8(255))
        path = out.joinpath("miss_%04d.png" % (i + 1))
        if not cv2.imwrite(str(path), miss):
            raise OSError("Failed to write image: {}".format(path))
def merge_imgs(dirs, output, row=1, gap=2, res=512):
    """Tile corresponding images from several sources into grid images.

    The k-th image of every source in ``dirs`` is placed on a white canvas
    with ``row`` rows and ``ceil(len(dirs) / row)`` columns, filled row by
    row. There are ``gap`` pixels between tiles and around the border. Each
    tile is resized to ``res x res`` when needed. Outputs are written as
    ``merge_0000.png``, ``merge_0001.png``, ... and stop at the shortest
    source.

    Args:
        dirs: Sequence of image files or directories, one per grid cell
            (each resolved with ``get_img_list``).
        output: Output directory; created (with parents) if missing.
        row: Number of grid rows.
        gap: Spacing/border width in pixels.
        res: Side length in pixels of each square tile.

    Raises:
        AssertionError: If any source yields no images.
        OSError: If an image cannot be read or a merged image cannot be
            written by OpenCV.
    """
    image_list = [get_img_list(path) for path in dirs]
    img_count = [len(image) for image in image_list]
    print("Total images:", img_count)
    assert min(img_count) > 0, "Please check the path of empty folder."
    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)
    n_img = len(dirs)
    column = (n_img - 1) // row + 1  # ceil(n_img / row)
    print("Row:", row)
    print("Column:", column)
    canvas_shape = (
        res * row + (row + 1) * gap,
        res * column + (column + 1) * gap,
        3,
    )
    for i, unit in tqdm(enumerate(zip(*image_list)), total=min(img_count)):
        merge = np.full(canvas_shape, 255, dtype=np.uint8)
        for j, img_path in enumerate(unit):
            r, c = divmod(j, column)
            img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
            if img is None:
                raise OSError("Failed to read image: {}".format(img_path))
            if img.shape[:2] != (res, res):
                img = cv2.resize(img, (res, res))
            start_h = (r + 1) * gap + r * res
            start_w = (c + 1) * gap + c * res
            merge[start_h : start_h + res, start_w : start_w + res] = img
        name = output_dir.joinpath("merge_%04d.png" % i)
        if not cv2.imwrite(str(name), merge):
            raise OSError("Failed to write image: {}".format(name))