Since issue #23, it should be possible to use a y array with more than 2 classes; however, I cannot get it to work.
I have 5 classes, so I tried to use precision for 1 label with:
precision = keras_metrics.precision(label=0)
However, this results in the error:
~/anaconda3/envs/dl/lib/python3.6/site-packages/keras_metrics/metrics.py in _categorical(self, y_true, y_pred, dtype)
46 return self._binary(y_true, y_pred, dtype, label=1)
47 elif labels > 2:
---> 48 raise ValueError("With 2 and more output classes a "
49 "metric label must be specified")
50
ValueError: With 2 and more output classes a metric label must be specified
It seems like it only looks at the y shape, and not whether a label is specified?
Would this be better?
def _categorical(self, y_true, y_pred, dtype, label=None):
    labels = y_pred.shape[-1]
    if labels == 2:
        # With two output classes, the positive class is label 1.
        label = 1
    if labels > 2 and label is None:
        # Label 0 is valid, so test against None rather than truthiness.
        raise ValueError("With 2 and more output classes a metric label must be specified")
    return self._binary(y_true, y_pred, dtype, label=label)
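To make the intended behaviour concrete outside of Keras, here is a minimal sketch of the label check on its own. The helper name `resolve_label` and its shortened error message are hypothetical and not part of keras_metrics; the point it illustrates is that the check has to compare against None, because `label=0` is falsy in Python.

```python
def resolve_label(num_classes, label=None):
    """Return the class index a per-label metric should track.

    Hypothetical helper that mirrors the dispatch discussed in this
    issue; it is not part of keras_metrics.
    """
    if num_classes == 2:
        # Two outputs: treat index 1 as the positive class.
        return 1
    if num_classes > 2 and label is None:
        # Compare against None so that label 0 is accepted.
        raise ValueError("a metric label must be specified")
    return label


print(resolve_label(5, label=0))  # prints 0: label 0 is accepted for 5 classes
print(resolve_label(2))           # prints 1: two classes default to label 1
```

With this check, calling `resolve_label(5)` without a label still raises ValueError, which seems to be what the original error message was meant to guard against.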
Could you please post the model configuration, or at least the last layer (with an example of the output data), so I can understand why this fix is necessary? Thank you in advance.
@ybubnov I can confirm this issue. I tried with
f1 = keras_metrics.f1_score(label=1)
self.model.compile(optimizer="Adam", loss='binary_crossentropy', metrics=[f1])
and got the same error. My last layer currently is
output_layer = keras.layers.Dense(self.n_classes, activation="softmax")(dense2)
The labels are one-hot encoded with 12 classes, so each target has the form np.array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
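For reference, this is roughly what I would expect a per-label metric to compute on data of that shape. It is a plain-Python sketch, not the keras_metrics implementation: the function names are made up for illustration, and it assumes the predicted class is the argmax of each output row, which may differ from how the library thresholds predictions.

```python
def one_hot(index, num_classes=12):
    """Build a one-hot row such as the 12-class targets described above."""
    return [1.0 if i == index else 0.0 for i in range(num_classes)]


def argmax(row):
    """Index of the largest value in a row."""
    return max(range(len(row)), key=row.__getitem__)


def precision_for_label(y_true, y_pred, label):
    """Precision for a single class (hypothetical helper).

    Assumes the predicted class is the argmax of each y_pred row.
    """
    true_positives = 0
    predicted_positives = 0
    for target_row, score_row in zip(y_true, y_pred):
        if argmax(score_row) == label:
            predicted_positives += 1
            if argmax(target_row) == label:
                true_positives += 1
    # Report 0.0 when the label was never predicted, to avoid dividing by zero.
    return true_positives / predicted_positives if predicted_positives else 0.0


y_true = [one_hot(8), one_hot(3), one_hot(3)]
y_pred = [one_hot(8), one_hot(8), one_hot(3)]  # hard predictions for brevity
print(precision_for_label(y_true, y_pred, label=8))  # prints 0.5
```

Here class 8 is predicted twice but is correct only once, so its precision is 0.5, while class 3 is predicted once and correctly.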