Incorrect aggregation for classification metrics #2758
Comments
Thanks for flagging this issue, you're absolutely right! The metric calculation at each batch is correct, but the aggregation is inaccurate. There are actually two problems with the way it's currently implemented.
Yeah, preserving the ratio will work for simple metrics such as recall or precision (we simply need to accumulate the numerator and denominator instead of averaging the ratio). However, this is not true for other metrics, such as AUROC, which require the full input and output tensors. For those I'm not sure whether the library should provide an (approximated) update method.
For the metric state, this can be kept in a new state type so we can accumulate the predictions and targets. But you're right that logging an entry will be a bit more complicated: we don't want to serialize all this info. There are a couple of ways to approach it; two come to mind:

1. Add a new event at the end of each epoch to accurately compute the aggregated metric from the state, which would then be logged.
2. Keep the running metric as it is currently done for the numeric state (since it's just a ratio). At the end of the epoch the running value is the aggregated value, though it might be a bit wasteful for more complex metrics that require more compute.
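As a rough illustration of the second option, here is a minimal sketch of a running state that preserves the ratio by accumulating the numerator and denominator separately, so the value it reports after the last batch is the correctly aggregated epoch-level ratio. The `RatioMetricState` name and its methods are hypothetical, not the library's actual API:

```rust
/// Hypothetical metric state that keeps a ratio exact across batches by
/// accumulating the numerator and denominator instead of averaging ratios.
#[derive(Default)]
pub struct RatioMetricState {
    numerator: f64,   // e.g. sum of true positives so far
    denominator: f64, // e.g. sum of (true positives + false negatives) so far
}

impl RatioMetricState {
    /// Add one batch's counts and return the running (epoch-level) ratio.
    pub fn update(&mut self, batch_numerator: f64, batch_denominator: f64) -> f64 {
        self.numerator += batch_numerator;
        self.denominator += batch_denominator;
        self.value()
    }

    /// Current aggregated value; equals the dataset-level ratio at epoch end.
    pub fn value(&self) -> f64 {
        if self.denominator == 0.0 {
            0.0
        } else {
            self.numerator / self.denominator
        }
    }

    /// Reset at the start of a new epoch.
    pub fn reset(&mut self) {
        self.numerator = 0.0;
        self.denominator = 0.0;
    }
}
```

For recall, the batch numerator would be the batch's TP count and the denominator TP + FN; for precision, TP and TP + FP.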
Hi everyone, I've noticed that currently (all?) the metrics in the library are updated using the `update` method of `NumericMetricState` (link). However, when using metrics such as `PrecisionMetric` and `RecallMetric`, this aggregation approach can lead to incorrect results.

Consider the Recall metric with two batches of 150 samples each:

- Batch 1: TP = 40, FN = 50, so batch recall = 40 / (40 + 50) ≈ 0.44
- Batch 2: TP = 20, FN = 10, so batch recall = 20 / (20 + 10) ≈ 0.67
With the current implementation, the final recall at the end of an epoch is computed as a sample-weighted average of the per-batch recalls:

(150 * 0.44 + 150 * 0.67) / (150 + 150) ≈ 0.55

However, if we computed recall over the full dataset, the correct value would be:

(40 + 20) / (40 + 20 + 50 + 10) = 0.5

This discrepancy arises because the current approach averages batch-wise recall scores instead of aggregating the confusion matrix values (TP, FP, TN, FN).
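For concreteness, here is a small standalone snippet (nothing library-specific; the batch counts are taken from the example above) that reproduces both numbers:

```rust
fn main() {
    // Per-batch counts from the example above: (true positives, false negatives, batch size).
    let batches = [(40.0_f64, 50.0_f64, 150.0_f64), (20.0, 10.0, 150.0)];

    let (mut weighted_sum, mut n_total) = (0.0, 0.0);
    let (mut tp_total, mut fn_total) = (0.0, 0.0);

    for &(tp, false_neg, n) in batches.iter() {
        // Current behaviour: accumulate a sample-weighted average of per-batch recalls.
        weighted_sum += n * (tp / (tp + false_neg));
        n_total += n;
        // Correct aggregation: accumulate the confusion-matrix counts themselves.
        tp_total += tp;
        fn_total += false_neg;
    }

    println!("weighted average of batch recalls: {:.3}", weighted_sum / n_total); // ~0.556
    println!("recall over aggregated counts:     {:.3}", tp_total / (tp_total + fn_total)); // 0.500
}
```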
To ensure accurate precision and recall calculations, we should store the confusion matrix as part of the metric state. This approach is used in `torchmetrics` (link).

This is not possible, however, for more complex metrics such as `AurocMetric`, where the input and target tensors are required. I think this is the reason why `torchmetrics` doesn't provide an `update` method for this metric.

To summarize, metrics such as precision and recall should be computed by aggregating confusion matrix values instead of averaging per-batch scores.
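A minimal sketch of what such a confusion-matrix state could look like; the `ConfusionMatrixState` name and its methods are hypothetical and only meant to illustrate the idea of accumulating counts rather than scores:

```rust
/// Hypothetical confusion-matrix state for binary classification metrics.
/// Accumulating raw counts makes epoch-level precision/recall exact,
/// unlike averaging per-batch scores.
#[derive(Default)]
pub struct ConfusionMatrixState {
    tp: u64,
    fp: u64,
    tn: u64,
    fn_: u64,
}

impl ConfusionMatrixState {
    /// Update the counts from one batch of binary predictions and targets.
    pub fn update(&mut self, predictions: &[bool], targets: &[bool]) {
        for (&pred, &target) in predictions.iter().zip(targets) {
            match (pred, target) {
                (true, true) => self.tp += 1,
                (true, false) => self.fp += 1,
                (false, true) => self.fn_ += 1,
                (false, false) => self.tn += 1,
            }
        }
    }

    /// Precision computed from the aggregated counts.
    pub fn precision(&self) -> f64 {
        let denom = (self.tp + self.fp) as f64;
        if denom == 0.0 { 0.0 } else { self.tp as f64 / denom }
    }

    /// Recall computed from the aggregated counts.
    pub fn recall(&self) -> f64 {
        let denom = (self.tp + self.fn_) as f64;
        if denom == 0.0 { 0.0 } else { self.tp as f64 / denom }
    }

    /// Accuracy, as an example of another metric derived from the same counts.
    pub fn accuracy(&self) -> f64 {
        let total = (self.tp + self.fp + self.tn + self.fn_) as f64;
        if total == 0.0 { 0.0 } else { (self.tp + self.tn) as f64 / total }
    }
}
```

Calling `update` once per batch and reading `recall()` at the end of the epoch yields the dataset-level value (0.5 in the example above) rather than the averaged ~0.55.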
For metrics such as AUROC, on the other hand, if we keep an approximation it should be well documented to avoid misunderstandings for users; otherwise, we could remove the `update` method (following `torchmetrics`' practice).
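For the AUROC-style case, a hedged sketch of what an exact state would have to do: buffer every prediction/target pair for the epoch, which is exactly why an incremental per-batch `update` is hard to offer without either the memory cost or an approximation (again, hypothetical names, not the library's API):

```rust
/// Hypothetical state for metrics like AUROC that cannot be reduced to
/// running counts: it must keep every score/target pair seen this epoch.
#[derive(Default)]
pub struct BufferedMetricState {
    scores: Vec<f64>,   // predicted scores/probabilities
    targets: Vec<bool>, // ground-truth labels
}

impl BufferedMetricState {
    /// Append one batch; memory grows with the dataset size.
    pub fn update(&mut self, scores: &[f64], targets: &[bool]) {
        self.scores.extend_from_slice(scores);
        self.targets.extend_from_slice(targets);
    }

    /// Compute the metric once per epoch over the full buffers
    /// (e.g. AUROC from the ranking of `scores` against `targets`).
    pub fn compute<M>(&self, metric_fn: M) -> f64
    where
        M: Fn(&[f64], &[bool]) -> f64,
    {
        metric_fn(&self.scores, &self.targets)
    }

    /// Clear the buffers at the start of a new epoch.
    pub fn reset(&mut self) {
        self.scores.clear();
        self.targets.clear();
    }
}
```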
What are your thoughts on this?

Thanks!