Skip to content

Commit 60c707e

Browse files
authored
fix(tf2): predict_classes deprecated (#271)
1 parent 2b83f17 commit 60c707e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

tf2/tf2-06-1-softmax_classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@
4848
print('--------------')
4949
# or use argmax embedded method, predict_classes
5050
c = tf.model.predict(np.array([[1, 1, 0, 1]]))
51-
c_onehot = tf.model.predict_classes(np.array([[1, 1, 0, 1]]))
51+
c_onehot = np.argmax(c, axis=-1)
5252
print(c, c_onehot)
5353

5454
print('--------------')
5555
all = tf.model.predict(np.array([[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]]))
56-
all_onehot = tf.model.predict_classes(np.array([[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]]))
56+
all_onehot = np.argmax(all, axis=-1)
5757
print(all, all_onehot)

tf2/tf2-06-2-softmax_zoo_classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828

2929
# Single data test
3030
test_data = np.array([[0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0]]) # expected prediction == 3 (feathers)
31-
print(tf.model.predict(test_data), tf.model.predict_classes(test_data))
31+
print(tf.model.predict(test_data), np.argmax(tf.model.predict(test_data), axis=-1))
3232

3333
# Full x_data test
34-
pred = tf.model.predict_classes(x_data)
34+
pred = np.argmax(tf.model.predict(x_data), axis=-1)
3535
for p, y in zip(pred, y_data.flatten()):
3636
print("[{}] Prediction: {} True Y: {}".format(p == int(y), p, int(y)))

0 commit comments

Comments
 (0)