This is quite a simple issue, but `nnx.metrics.Accuracy` only works for a network that returns more than one logit per example, because it takes the argmax over the last axis: `super().update(values=(logits.argmax(axis=-1) == labels))`. I'm not sure whether this is by design.
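I was playing around with a simple binary classifier that only returned a single logit per observation, and in that case the Accuracy metric gives meaningless results. If the logits have shape `(batch, 1)`, `argmax(axis=-1)` is always 0, so the reported value is just the fraction of examples whose label is 0.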
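A minimal sketch of the failure mode, using NumPy as a stand-in for `jax.numpy` (both give the same `argmax` and comparison results here). The helpers `argmax_accuracy` and `binary_accuracy` and the example batch are hypothetical; `argmax_accuracy` only mimics the quoted `argmax(axis=-1) == labels` line, not Flax's full implementation.

```python
import numpy as np


def argmax_accuracy(logits, labels):
    # Mimics the multi-class rule: predicted class = argmax over the last axis.
    # With a trailing axis of length 1, argmax is always 0.
    return float(np.mean(logits.argmax(axis=-1) == labels))


def binary_accuracy(logits, labels):
    # Single-logit binary rule: logit > 0 is equivalent to sigmoid(logit) > 0.5.
    preds = (logits.squeeze(-1) > 0).astype(labels.dtype)
    return float(np.mean(preds == labels))


# Hypothetical batch: 4 examples, one logit each (shape (4, 1)).
logits = np.array([[2.0], [-1.5], [0.3], [0.8]])
labels = np.array([1, 0, 1, 1])

# Workaround: stack [-z, z] so argmax picks class 1 exactly when z > 0.
two_col = np.concatenate([-logits, logits], axis=-1)

print(argmax_accuracy(logits, labels))   # 0.25 (just the fraction of label-0 examples)
print(binary_accuracy(logits, labels))   # 1.0
print(argmax_accuracy(two_col, labels))  # 1.0
```

Stacking `[-logits, logits]` into two columns makes the argmax rule agree with thresholding the logit at 0 (a tie at exactly 0 goes to class 0). So, assuming the metric's `update` takes `logits=` and `labels=` keyword arguments, passing `jnp.concatenate([-logits, logits], axis=-1)` as the logits may work around the issue without changing the metric.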
Anyway, it might be worth noting in the docstring what shape of logits the function expects.