-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Group metrics by labels #245
base: master
Are you sure you want to change the base?
Changes from 5 commits
76147df
6387817
7e9f48e
4243f6f
12886e1
860601d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Any, List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
class LabelsGrouper:
    """Groups annotations or predictions by their ``label`` attribute.

    Iterating over an instance yields ``(label, items_with_that_label)``
    pairs, one per distinct label, in first-appearance order (the order
    produced by ``pd.factorize``).
    """

    def __init__(self, annotations_or_predictions_list: List[Any]):
        """Builds label groups from a list of items exposing ``.label``.

        Args:
            annotations_or_predictions_list: Annotation or prediction
                objects; each must have a ``label`` attribute.
        """
        self.items = annotations_or_predictions_list
        # Always define codes/labels so iteration and label_group() work on
        # an empty grouper. Previously these were only assigned when items
        # existed, so __next__ raised AttributeError for empty input.
        self.codes = np.array([], dtype=int)
        self.labels = np.array([], dtype=object)
        if len(self.items) > 0:
            assert hasattr(
                self.items[0], "label"
            ), f"Expected items to have attribute 'label' found none on {repr(self.items[0])}"
            # codes[i] is the integer group id of items[i];
            # labels[g] is the label value of group g.
            self.codes, self.labels = pd.factorize(
                [item.label for item in self.items]
            )
        self.group_idx = 0

    def __iter__(self):
        # Restart iteration from the first label group.
        self.group_idx = 0
        return self

    def __next__(self):
        """Returns the next ``(label, items_with_that_label)`` pair."""
        if self.group_idx >= len(self.labels):
            raise StopIteration
        label = self.labels[self.group_idx]
        label_items = list(
            np.take(self.items, np.where(self.codes == self.group_idx)[0])
        )
        self.group_idx += 1
        return label, label_items

    def label_group(self, label: str) -> List[Any]:
        """Returns all items carrying ``label``; empty list if absent."""
        if len(self.items) == 0:
            return []
        matches = np.where(self.labels == label)[0]
        # Bug fix: the old code did `if idx >= 0:` on the raw np.where
        # array, which is ambiguous (and raises on modern numpy) when the
        # label is not present. Check emptiness explicitly instead.
        if matches.size == 0:
            return []
        group_code = matches[0]
        return list(
            np.take(self.items, np.where(self.codes == group_code)[0])
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,16 @@ | ||
import sys | ||
from abc import abstractmethod | ||
from typing import List, Union | ||
from collections import defaultdict | ||
from typing import Dict, List | ||
|
||
import numpy as np | ||
|
||
from nucleus.annotation import AnnotationList, BoxAnnotation, PolygonAnnotation | ||
from nucleus.prediction import BoxPrediction, PolygonPrediction, PredictionList | ||
from nucleus.annotation import AnnotationList | ||
from nucleus.prediction import PredictionList | ||
|
||
from .base import Metric, ScalarResult | ||
from .base import GroupedScalarResult, Metric, ScalarResult | ||
from .filters import confidence_filter, polygon_label_filter | ||
from .label_grouper import LabelsGrouper | ||
from .metric_utils import compute_average_precision | ||
from .polygon_utils import ( | ||
BoxOrPolygonAnnotation, | ||
|
@@ -80,7 +82,7 @@ def eval( | |
|
||
def __init__( | ||
self, | ||
enforce_label_match: bool = False, | ||
enforce_label_match: bool = True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. :nit: please adjust comment below |
||
confidence_threshold: float = 0.0, | ||
): | ||
"""Initializes PolygonMetric abstract object. | ||
|
@@ -93,6 +95,31 @@ def __init__( | |
assert 0 <= confidence_threshold <= 1 | ||
self.confidence_threshold = confidence_threshold | ||
|
||
def eval_grouped(
    self,
    annotations: List[BoxOrPolygonAnnotation],
    predictions: List[BoxOrPolygonPrediction],
) -> GroupedScalarResult:
    """Evaluates the metric once per annotation label group.

    Returns a GroupedScalarResult mapping each label to the scalar
    produced by ``self.eval`` for that label's annotations.
    """
    annotation_groups = LabelsGrouper(annotations)
    prediction_groups = LabelsGrouper(predictions)
    # TODO(gunnar): Enforce label match -> Why is that a parameter? Should we
    # generally allow IOU matches between different labels?!?
    wrapped_eval = label_match_wrapper(self.eval)
    label_results = {}
    for label, grouped_annotations in annotation_groups:
        # When label matching is enforced, only same-label predictions
        # compete; otherwise every prediction is a candidate.
        if self.enforce_label_match:
            candidate_predictions = prediction_groups.label_group(label)
        else:
            candidate_predictions = predictions
        label_results[label] = wrapped_eval(
            grouped_annotations,
            candidate_predictions,
            enforce_label_match=self.enforce_label_match,
        )
    return GroupedScalarResult(group_to_scalar=label_results)
|
||
@abstractmethod | ||
def eval( | ||
self, | ||
|
@@ -102,28 +129,34 @@ def eval( | |
# Main evaluation function that subclasses must override. | ||
pass | ||
|
||
def aggregate_score(self, results: List[GroupedScalarResult]) -> Dict[str, ScalarResult]:  # type: ignore[override]
    """Aggregates per-item grouped results into one ScalarResult per label."""
    per_label = defaultdict(list)
    # Collect every item's scalar under its label before aggregating.
    for grouped_result in results:
        for label, scalar in grouped_result.group_to_scalar.items():
            per_label[label].append(scalar)
    return {
        label: ScalarResult.aggregate(scalars)
        for label, scalars in per_label.items()
    }
|
||
def __call__(
    self, annotations: AnnotationList, predictions: PredictionList
) -> Dict[str, ScalarResult]:
    """Runs grouped evaluation over all box and polygon geometries."""
    # Drop low-confidence predictions up front when a threshold is set.
    if self.confidence_threshold > 0:
        predictions = confidence_filter(
            predictions, self.confidence_threshold
        )
    polygon_annotations: List[BoxOrPolygonAnnotation] = [
        *annotations.box_annotations,
        *annotations.polygon_annotations,
    ]
    polygon_predictions: List[BoxOrPolygonPrediction] = [
        *predictions.box_predictions,
        *predictions.polygon_predictions,
    ]
    return self.eval_grouped(polygon_annotations, polygon_predictions)
|
||
|
@@ -166,7 +199,7 @@ class PolygonIOU(PolygonMetric): | |
# TODO: Remove defaults once these are surfaced more cleanly to users. | ||
def __init__( | ||
self, | ||
enforce_label_match: bool = False, | ||
enforce_label_match: bool = True, | ||
iou_threshold: float = 0.0, | ||
confidence_threshold: float = 0.0, | ||
): | ||
|
@@ -234,7 +267,7 @@ class PolygonPrecision(PolygonMetric): | |
# TODO: Remove defaults once these are surfaced more cleanly to users. | ||
def __init__( | ||
self, | ||
enforce_label_match: bool = False, | ||
enforce_label_match: bool = True, | ||
iou_threshold: float = 0.5, | ||
confidence_threshold: float = 0.0, | ||
): | ||
|
@@ -303,7 +336,7 @@ class PolygonRecall(PolygonMetric): | |
# TODO: Remove defaults once these are surfaced more cleanly to users. | ||
def __init__( | ||
self, | ||
enforce_label_match: bool = False, | ||
enforce_label_match: bool = True, | ||
iou_threshold: float = 0.5, | ||
confidence_threshold: float = 0.0, | ||
): | ||
|
@@ -460,7 +493,7 @@ def __init__( | |
0 <= iou_threshold <= 1 | ||
), "IoU threshold must be between 0 and 1." | ||
self.iou_threshold = iou_threshold | ||
super().__init__(enforce_label_match=False, confidence_threshold=0) | ||
super().__init__(enforce_label_match=True, confidence_threshold=0) | ||
|
||
def eval( | ||
self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:nit: avoid hardcoded strings and instead make them constants