Skip to content

Commit

Permalink
fix(tf2): predict_classes deprecated (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
crlotwhite authored Jul 14, 2022
1 parent 2b83f17 commit 60c707e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tf2/tf2-06-1-softmax_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@
print('--------------')
# or use argmax embedded method, predict_classes
c = tf.model.predict(np.array([[1, 1, 0, 1]]))
c_onehot = tf.model.predict_classes(np.array([[1, 1, 0, 1]]))
c_onehot = np.argmax(c, axis=-1)
print(c, c_onehot)

print('--------------')
all = tf.model.predict(np.array([[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]]))
all_onehot = tf.model.predict_classes(np.array([[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]]))
all_onehot = np.argmax(all, axis=-1)
print(all, all_onehot)
4 changes: 2 additions & 2 deletions tf2/tf2-06-2-softmax_zoo_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@

# Single data test
test_data = np.array([[0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0]]) # expected prediction == 3 (feathers)
print(tf.model.predict(test_data), tf.model.predict_classes(test_data))
print(tf.model.predict(test_data), np.argmax(tf.model.predict(test_data), axis=-1))

# Full x_data test
pred = tf.model.predict_classes(x_data)
pred = np.argmax(tf.model.predict(x_data), axis=-1)
for p, y in zip(pred, y_data.flatten()):
print("[{}] Prediction: {} True Y: {}".format(p == int(y), p, int(y)))

0 comments on commit 60c707e

Please sign in to comment.