forked from EdgarLefevre/SalienceNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGradLoss.py
More file actions
68 lines (53 loc) · 2.13 KB
/
GradLoss.py
File metadata and controls
68 lines (53 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GradientLoss(nn.Module):
    """Sobel-gradient matching loss between a prediction and a target.

    Applies the vertical and horizontal 3x3 Sobel operators to both
    inputs (per channel) and penalizes the mean squared difference of
    the resulting gradient maps:

        loss = sqrt(loss_v^2 + loss_h^2) * 2

    where each term is ``mean((grad(x) - grad(y))**2) / 100`` (the /100
    scaling and the final combination are kept from the original
    formulation).

    Fixes over the original implementation:
      * all input channels are used (the original read channel 0 three
        times and ignored channels 1 and 2);
      * the horizontal-gradient pass runs on the raw inputs, not on the
        output of the vertical pass;
      * ``padding=1`` gives a "same"-sized response for a 3x3 kernel
        (``padding=2`` enlarged the maps with spurious border values);
      * kernels are registered as buffers, so the module follows
        ``.to(device)`` / ``.cuda()`` instead of hard-coding CUDA.
    """

    def __init__(self, num_channels: int = 3):
        """
        Args:
            num_channels: number of channels in the images passed to
                ``forward`` (default 3, matching the original RGB-only
                implementation).
        """
        super().__init__()
        self.num_channels = num_channels
        # Sobel kernel for vertical gradients (responds to d/dy).
        sobel_v = torch.tensor([[-1., -2., -1.],
                                [0., 0., 0.],
                                [1., 2., 1.]])
        # Sobel kernel for horizontal gradients (responds to d/dx).
        sobel_h = torch.tensor([[-1., 0., 1.],
                                [-2., 0., 2.],
                                [-1., 0., 1.]])
        # Shape (C, 1, 3, 3) for a depthwise (groups=C) convolution: the
        # same kernel is applied independently to every channel.
        # Buffers are not trained and move with the module's device/dtype.
        self.register_buffer(
            "weight1", sobel_v.expand(num_channels, 1, 3, 3).contiguous())
        self.register_buffer(
            "weight2", sobel_h.expand(num_channels, 1, 3, 3).contiguous())

    def _grad(self, img, weight):
        """Per-channel Sobel response of ``img`` with 'same' padding."""
        return F.conv2d(img, weight, padding=1, groups=self.num_channels)

    def forward(self, x, y):
        """Compute the gradient loss.

        Args:
            x: prediction, shape (B, C, H, W) with C == ``num_channels``.
            y: target, same shape as ``x``.

        Returns:
            Scalar tensor (non-negative).
        """
        loss1 = torch.mean(
            (self._grad(x, self.weight1) - self._grad(y, self.weight1)) ** 2
        ) / 100.0
        loss2 = torch.mean(
            (self._grad(x, self.weight2) - self._grad(y, self.weight2)) ** 2
        ) / 100.0
        return torch.sqrt(loss1 * loss1 + loss2 * loss2) * 2