
Incorrect aggregation for classification metrics #2758

Open
salvomcl opened this issue Jan 30, 2025 · 3 comments
Labels
bug Something isn't working

Comments

@salvomcl
Contributor

salvomcl commented Jan 30, 2025

Hi everyone, I've noticed that currently (all?) the metrics in the library are updated through the update method of NumericMetricState (link).
However, for metrics such as PrecisionMetric and RecallMetric, this aggregation approach can lead to incorrect results.

Consider the Recall metric with two batches of data:

Batch | True Positives (TP) | True Negatives (TN) | False Positives (FP) | False Negatives (FN) | Recall
------|---------------------|---------------------|----------------------|----------------------|-------------
1     | 40                  | 50                  | 10                   | 50                   | 0.44 (40/90)
2     | 20                  | 90                  | 30                   | 10                   | 0.67 (20/30)

With the current implementation, the final recall at the end of the epoch is computed as a batch-size-weighted average of the per-batch scores: (150 × 0.44 + 150 × 0.67) / (150 + 150) ≈ 0.56.
However, if we computed recall over the full dataset, the correct value would be TP / (TP + FN) = (40 + 20) / (40 + 20 + 50 + 10) = 0.5.

This discrepancy arises because the current approach averages batch-wise recall scores instead of aggregating the confusion matrix values (TP, FP, TN, FN).

To ensure accurate precision and recall calculations, we should store the confusion matrix as part of the metric state. This is the approach used in torchmetrics (link).
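
For illustration, a minimal sketch of what such a confusion-matrix-backed state could look like; the type and method names here are hypothetical, not the library's API:

    /// Hypothetical state that accumulates confusion-matrix counts across batches.
    #[derive(Default)]
    struct ConfusionStatsState {
        true_positives: usize,
        false_positives: usize,
        false_negatives: usize,
        true_negatives: usize,
    }

    impl ConfusionStatsState {
        /// Add the counts from one batch.
        fn update(&mut self, tp: usize, fp: usize, fn_count: usize, tn: usize) {
            self.true_positives += tp;
            self.false_positives += fp;
            self.false_negatives += fn_count;
            self.true_negatives += tn;
        }

        /// Recall over everything seen so far: TP / (TP + FN).
        fn recall(&self) -> f64 {
            let denom = self.true_positives + self.false_negatives;
            if denom == 0 { 0.0 } else { self.true_positives as f64 / denom as f64 }
        }

        /// Precision over everything seen so far: TP / (TP + FP).
        fn precision(&self) -> f64 {
            let denom = self.true_positives + self.false_positives;
            if denom == 0 { 0.0 } else { self.true_positives as f64 / denom as f64 }
        }
    }

Feeding it the two batches from the table above and calling recall() gives 0.5, matching the full-dataset value.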

This, however, is not possible for more complex metrics such as AurocMetric, where the full input and target tensors are required. I think this is the reason why torchmetrics doesn't provide an update method for this metric.

To summarize, metrics such as precision and recall should be computed by aggregating confusion matrix values instead of averaging per-batch scores.

For metrics such as AUROC, if we maintain an approximation, it should be well documented to avoid misunderstandings among users; otherwise, we could remove the update method (following torchmetrics' practice).

What are your thoughts on this?

Thanks!

@laggui laggui added the bug Something isn't working label Jan 31, 2025
@laggui
Member

laggui commented Jan 31, 2025

Thanks for flagging this issue, you're absolutely right! The metric calculation at each batch is correct, but the aggregation is inaccurate.

There are actually two problems with the way it's currently implemented:

  1. As you pointed out, the state of these metrics is currently a plain numeric value:

    self.state.update(
        100.0 * metric,
        sample_size,
        FormatOptions::new(Self::NAME).unit("%").precision(2),
    )

    But this is incorrect, because the metric cannot be aggregated like other numeric ratio metrics (e.g., accuracy) as it is here:
    /// Update the state.
    pub fn update(&mut self, value: f64, batch_size: usize, format: FormatOptions) -> MetricEntry {
        self.sum += value * batch_size as f64;
        self.count += batch_size;
        self.current = value;
        let value_current = value;
        let value_running = self.sum / self.count as f64;
        // Numeric metric state is an aggregated value
        let serialized = NumericEntry::Aggregated(value_current, batch_size).serialize();

    The aggregation should not be weighted by the number of elements in the batch, but by the relevant confusion matrix counts (TP, FP, TN, FN); for recall, that means the denominator TP + FN. This could be fixed by introducing a new state instead of NumericMetricState for these metrics.

  2. The logger serializes the state for each entry (batch). The entries for a specific epoch are then aggregated based on the serialized values:

    // Accurately compute the aggregated value based on the *actual* number of points
    // since not all mini-batches are guaranteed to have the specified batch size
    let (sum, num_points) = points
        .into_iter()
        .map(|entry| match entry {
            NumericEntry::Value(v) => (v, 1),
            // Right now the mean is the only aggregate available, so we can assume that the sum
            // of an entry corresponds to (value * number of elements)
            NumericEntry::Aggregated(v, n) => (v * n as f64, n),
        })
        .reduce(|(acc_v, acc_n), (v, n)| (acc_v + v, acc_n + n))
        .unwrap();
    let value = match aggregate {
        Aggregate::Mean => sum / num_points as f64,
    };

    We probably need to add a new variant to the numeric entry type. It doesn't need to actually capture all of (TP, TN, FP, FN); we just need to correctly preserve the ratio (i.e., the correct numerator and denominator, which we can then accurately aggregate). Now that I think about it, this general case could probably replace the NumericEntry::Aggregated(value, num_elems) variant, as sketched below.
    /// Numeric metric entry.
    pub enum NumericEntry {
        /// Single numeric value.
        Value(f64),
        /// Aggregated numeric (value, number of elements).
        Aggregated(f64, usize),
    }

@salvomcl
Contributor Author

salvomcl commented Feb 4, 2025

Yeah, preserving the ratio will work for simple metrics such as recall or precision (we simply need to replace batch_size with the right denominator when updating).
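
As a minimal standalone sketch of that idea (hypothetical names, not the library's code), mirroring the sum/count bookkeeping of NumericMetricState but weighting by the recall denominator (TP + FN) instead of the batch size:

    /// Running mean that reproduces the sum/count logic quoted above.
    #[derive(Default)]
    struct RunningMean {
        sum: f64,
        count: usize,
    }

    impl RunningMean {
        /// Same weighting scheme as update(value, batch_size), except the caller
        /// passes the ratio's denominator as the weight.
        fn update(&mut self, value: f64, weight: usize) {
            self.sum += value * weight as f64;
            self.count += weight;
        }

        fn running(&self) -> f64 {
            self.sum / self.count as f64
        }
    }

    fn main() {
        let mut recall = RunningMean::default();
        // Batch 1: TP = 40, FN = 50 -> recall 40/90, weighted by the 90 positives
        recall.update(40.0 / 90.0, 90);
        // Batch 2: TP = 20, FN = 10 -> recall 20/30, weighted by the 30 positives
        recall.update(20.0 / 30.0, 30);
        // (40 + 20) / (90 + 30) = 0.5, matching the full-dataset recall
        println!("running recall: {:.3}", recall.running());
    }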

However, this is not true for other metrics, such as AUROC, which require the full input and target tensors. For those, I'm not sure whether the library should provide an (approximated) update method.

@laggui
Member

laggui commented Feb 4, 2025

For the metric state, this can be kept in a new state type so we can accumulate the predictions and targets.

But you're right that logging an entry will be a bit more complicated. We don't want to serialize all this info, but there are a couple of ways to approach this. Two come to mind:

We could add a new event at the end of each epoch to accurately compute the aggregated metric from the state, which would be logged.

Another option is to keep the running metric as it is currently done for the numeric state (since it's just a ratio). At the end of the epoch the running value is the aggregated value, though it might be a bit wasteful for more complex metrics that require more compute.
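
To illustrate the first option, here is a standalone sketch (hypothetical names, not library code) of a state that accumulates predictions and targets so an exact AUROC can be computed once from an end-of-epoch event:

    /// Hypothetical state for metrics like AUROC that need all predictions/targets.
    #[derive(Default)]
    struct PredictionState {
        scores: Vec<f64>,
        targets: Vec<bool>,
    }

    impl PredictionState {
        /// Accumulate one batch of predicted scores and binary targets.
        fn update(&mut self, scores: &[f64], targets: &[bool]) {
            self.scores.extend_from_slice(scores);
            self.targets.extend_from_slice(targets);
        }

        /// Exact AUROC over everything accumulated so far, via the rank-based
        /// (Mann-Whitney U) formulation. Assumes both classes are present and
        /// ignores tie handling for brevity.
        fn auroc(&self) -> f64 {
            let mut indexed: Vec<usize> = (0..self.scores.len()).collect();
            indexed.sort_by(|&a, &b| self.scores[a].partial_cmp(&self.scores[b]).unwrap());

            let (mut rank_sum, mut positives) = (0.0_f64, 0usize);
            for (rank, &i) in indexed.iter().enumerate() {
                if self.targets[i] {
                    rank_sum += (rank + 1) as f64; // 1-based rank
                    positives += 1;
                }
            }
            let negatives = self.targets.len() - positives;
            (rank_sum - (positives * (positives + 1)) as f64 / 2.0)
                / (positives * negatives) as f64
        }
    }

The second option would instead call something like auroc() on every update to keep a running value, which is where the extra compute for the more complex metrics comes from.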
