Skip to content

Commit f0aec83

Browse files
committed
loss function faster implementation
1 parent 5bf8386 commit f0aec83

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

loss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch.nn.functional import softmax
99

1010
empirical_probs = (0.5*np.load('prior_probs.npy') + (0.5/313))**(-1)
11-
empirical_probs = empirical_probs/np.sum(empirical_probs)
11+
empirical_probs = torch.from_numpy(empirical_probs/np.sum(empirical_probs)).cuda()
1212

1313
def GaussianKernel(v1, v2, sigma=5):
1414
return np.exp(-np.linalg.norm(v1-v2, 2)**2/(2.*sigma**2))
@@ -42,11 +42,11 @@ def loadColorData(filename):
4242
def v(Z):
4343
args = torch.argmax(Z,dim=1)
4444
ant_size = tuple(args.size())
45-
return torch.from_numpy(empirical_probs[args.cpu().reshape(-1)].reshape(ant_size)).cuda()
45+
return empirical_probs[args.reshape(-1)].reshape(ant_size).cuda()
4646

47-
def classificationLoss(Z_hat, Z):
4847

49-
loss = - torch.sum(v(Z) * torch.sum(Z.cuda() * torch.log(softmax(Z_hat, dim=1)),dim=1))/Z.size(0)
48+
def classificationLoss(Z_hat, Z):
49+
loss = - torch.sum(v(Z) * torch.sum(Z * torch.log(softmax(Z_hat, dim=1)),dim=1))/Z.size(0)
5050
return loss
5151

5252
def regressorLoss(Z_hat,Z):

0 commit comments

Comments
 (0)