-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathwls.py
69 lines (54 loc) · 1.96 KB
/
wls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import cv2
import numpy as np
import scipy.sparse as sparse
from scipy.sparse.linalg import spsolve
def wls_filter(in_, data_term_weight, guidance, lambda_=0.05, small_num=0.00001):
    """Smooth ``in_`` with an edge-aware weighted-least-squares (WLS) solver.

    Returns the image ``out`` that solves ``(W + L) out = W in_``, i.e. the
    minimiser of::

        sum_p   w_p  * (out_p - in_p) ** 2
      + sum_p~q a_pq * (out_p - out_q) ** 2

    over 4-connected neighbour pairs ``p~q``, where

    * ``w`` is ``data_term_weight`` min/max-normalised to ``[0, 1)``, and
    * ``a_pq = lambda_ / ((g_p - g_q) ** 2 + small_num)`` with ``g`` the
      grayscale guidance, so smoothing is weak across strong guidance edges.

    Boundary condition: in the first row, every pixel whose normalised weight
    is below 0.6 gets weight 0.8 and the target value ``min(in_[:, col])``
    (the minimum of its column).

    Parameters
    ----------
    in_ : array_like, shape (h, w)
        Values to smooth. The caller's array is NOT modified.
    data_term_weight : array_like, shape (h, w)
        Per-pixel confidence in ``in_``; larger values keep ``out`` closer
        to ``in_`` at that pixel.
    guidance : ndarray, shape (h, w, 3) or (h, w)
        Image whose gradients steer the smoothing. A 3-channel image is
        converted with ``cv2.COLOR_RGB2GRAY`` (float64 input is first cast to
        float32, which cvtColor accepts); a 2-D image is used as-is.
    lambda_ : float
        Smoothness strength; larger values give a smoother result.
    small_num : float
        Regulariser that avoids division by zero in the affinities and in the
        weight normalisation.

    Returns
    -------
    ndarray, shape (h, w)
        The smoothed image.

    Raises
    ------
    ValueError
        If ``in_`` or ``data_term_weight`` cannot be reshaped to ``(h, w)``.
    """
    guidance = np.asarray(guidance)
    h, w = guidance.shape[:2]
    k = h * w

    if guidance.ndim == 2:
        gray = guidance
    else:
        if guidance.dtype == np.float64:
            # cv2.cvtColor supports 8U/16U/32F depths but not 64F.
            guidance = guidance.astype(np.float32)
        gray = cv2.cvtColor(guidance, cv2.COLOR_RGB2GRAY)
    # Work in float64 so differences of unsigned (e.g. uint8) pixels cannot
    # wrap around.
    gray = np.asarray(gray, dtype=np.float64)

    # Negative affinities between vertical (dy) and horizontal (dx)
    # neighbours. Pixels are indexed in column-major (Fortran) order,
    # i = col * h + row, everywhere in this function. The zero padding on the
    # last row/column removes coupling across the image border, including the
    # wrap from the bottom of one column to the top of the next.
    dy = -lambda_ / (np.diff(gray, axis=0) ** 2 + small_num)
    dy = np.pad(dy, ((0, 1), (0, 0)), mode="constant").ravel(order="F")
    dx = -lambda_ / (np.diff(gray, axis=1) ** 2 + small_num)
    dx = np.pad(dx, ((0, 0), (0, 1)), mode="constant").ravel(order="F")

    # Strictly lower triangle: pixel i couples to i - h (left neighbour) and
    # i - 1 (upper neighbour). spdiags indexes each diagonal by column, so
    # entry (i, i - h) holds dx[i - h], the affinity of that pair.
    lower = sparse.spdiags(np.vstack((dx, dy)), [-h, -1], k, k)

    # Diagonal of the Laplacian: minus the sum of the affinities to the
    # east (dx[i]), west (dx[i - h]), south (dy[i]) and north (dy[i - 1]).
    west = np.concatenate((np.zeros(h), dx[:k - h]))
    north = np.concatenate((np.zeros(1), dy[:k - 1]))
    degree = -(dx + west + dy + north)
    smoothness = lower + lower.T + sparse.spdiags(degree, 0, k, k)

    # Normalise the data weight to [0, 1). A fresh float copy keeps the
    # caller's array intact.
    weight = np.array(data_term_weight, dtype=np.float64).reshape(h, w)
    weight -= weight.min()
    weight /= weight.max() + small_num

    # Private copy: the boundary condition below rewrites the top row.
    target = np.array(in_, dtype=np.float64).reshape(h, w)

    # Boundary condition for the top row: where the row's weight is low,
    # anchor it to the column minimum with weight 0.8. The column minima are
    # taken before the top row is overwritten.
    unreliable = weight[0] < 0.6
    column_min = target.min(axis=0)
    weight[0, unreliable] = 0.8
    target[0, unreliable] = column_min[unreliable]

    # The diagonal must use the same column-major order as every other
    # vector in the system.
    data = sparse.spdiags(weight.ravel(order="F"), 0, k, k)
    system = (smoothness + data).tocsc()
    rhs = data @ target.ravel(order="F")
    solution = spsolve(system, rhs)
    return np.reshape(solution, (h, w), order="F")