Skip to content

Commit 14dadf1

Browse files
author
Orange15
committed
Update dec.py
fix bug
1 parent 2389cc0 commit 14dadf1

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

example/dec/dec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def cluster(self, X, y=None, update_interval=None):
9797
test_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False,
9898
last_batch_handle='pad')
9999
args = {k: mx.nd.array(v.asnumpy(), ctx=self.xpu) for k, v in self.args.items()}
100-
z = model.extract_feature(self.feature, args, test_iter, N, self.xpu).values()[0]
100+
z = model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values()[0]
101101
kmeans = KMeans(self.num_centers, n_init=20)
102102
kmeans.fit(z)
103103
args['dec_mu'][:] = kmeans.cluster_centers_
@@ -112,7 +112,7 @@ def ce(label, pred):
112112
self.y_pred = np.zeros((X.shape[0]))
113113
def refresh(i):
114114
if i%update_interval == 0:
115-
z = model.extract_feature(self.feature, args, test_iter, N, self.xpu).values()[0]
115+
z = model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values()[0]
116116
p = np.zeros((z.shape[0], self.num_centers))
117117
self.dec_op.forward([z, args['dec_mu'].asnumpy()], [p])
118118
y_pred = p.argmax(axis=1)
@@ -132,7 +132,7 @@ def refresh(i):
132132
solver.set_iter_start_callback(refresh)
133133
solver.set_monitor(Monitor(50))
134134

135-
solver.solve(self.xpu, self.loss, args, self.args_grad,
135+
solver.solve(self.xpu, self.loss, args, self.args_grad, None,
136136
train_iter, 0, 1000000000, {}, False)
137137
self.end_args = args
138138
if y is not None:
@@ -153,4 +153,4 @@ def mnist_exp(xpu):
153153
if __name__ == '__main__':
154154
logging.basicConfig(level=logging.INFO)
155155
mnist_exp(mx.gpu(0))
156-
156+

0 commit comments

Comments
 (0)