File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change 8
8
from torch .nn .functional import softmax
9
9
10
10
empirical_probs = (0.5 * np .load ('prior_probs.npy' ) + (0.5 / 313 ))** (- 1 )
11
- empirical_probs = empirical_probs / np .sum (empirical_probs )
11
+ empirical_probs = torch . from_numpy ( empirical_probs / np .sum (empirical_probs )). cuda ( )
12
12
13
13
def GaussianKernel (v1 , v2 , sigma = 5 ):
14
14
return np .exp (- np .linalg .norm (v1 - v2 , 2 )** 2 / (2. * sigma ** 2 ))
@@ -42,11 +42,11 @@ def loadColorData(filename):
42
42
def v (Z ):
43
43
args = torch .argmax (Z ,dim = 1 )
44
44
ant_size = tuple (args .size ())
45
- return torch . from_numpy ( empirical_probs [args .cpu (). reshape (- 1 )].reshape (ant_size ) ).cuda ()
45
+ return empirical_probs [args .reshape (- 1 )].reshape (ant_size ).cuda ()
46
46
47
- def classificationLoss (Z_hat , Z ):
48
47
49
- loss = - torch .sum (v (Z ) * torch .sum (Z .cuda () * torch .log (softmax (Z_hat , dim = 1 )),dim = 1 ))/ Z .size (0 )
48
+ def classificationLoss (Z_hat , Z ):
49
+ loss = - torch .sum (v (Z ) * torch .sum (Z * torch .log (softmax (Z_hat , dim = 1 )),dim = 1 ))/ Z .size (0 )
50
50
return loss
51
51
52
52
def regressorLoss (Z_hat ,Z ):
You can’t perform that action at this time.
0 commit comments