7
7
from stratigraphy.evaluation.evaluation_dataclasses import Metrics
8
8
9
9
10
- class DatasetMetrics :
10
+ class OverallMetrics :
11
11
"""Keeps track of a particular metrics for all documents in a dataset."""
12
12
13
13
# TODO: Currently, some methods for averaging metrics are in the Metrics class.
@@ -58,23 +58,32 @@ def get_metrics_list(self) -> list[Metrics]:
58
58
return list(self.metrics.values())
59
59
60
60
61
- class DatasetMetricsCatalog :
61
+ class OverallMetricsCatalog :
62
62
"""Keeps track of all different relevant metrics that are computed for a dataset."""
63
63
64
- def __init__(self):
65
- self.metrics: dict[str, DatasetMetrics] = {}
64
+ def __init__(self, languages: set[str]):
65
+ self.layer_metrics = OverallMetrics()
66
+ self.depth_interval_metrics = OverallMetrics()
67
+ self.groundwater_metrics = OverallMetrics()
68
+ self.groundwater_depth_metrics = OverallMetrics()
69
+ self.languages = languages
70
+
71
+ # Initialize language-specific metrics
72
+ for lang in languages:
73
+ setattr(self, f"{lang}_layer_metrics", OverallMetrics())
74
+ setattr(self, f"{lang}_depth_interval_metrics", OverallMetrics())
66
75
67
76
def document_level_metrics_df(self) -> pd.DataFrame:
68
77
"""Return a DataFrame with all the document level metrics."""
69
78
all_series = [
70
- self.metrics["layer"] .to_dataframe("F1", lambda metric: metric.f1),
71
- self.metrics["layer"] .to_dataframe("precision", lambda metric: metric.precision),
72
- self.metrics["layer"] .to_dataframe("recall", lambda metric: metric.recall),
73
- self.metrics["depth_interval"] .to_dataframe("Depth_interval_accuracy", lambda metric: metric.precision),
74
- self.metrics["layer"] .to_dataframe("Number Elements", lambda metric: metric.tp + metric.fn),
75
- self.metrics["layer"] .to_dataframe("Number wrong elements", lambda metric: metric.fp + metric.fn),
76
- self.metrics["groundwater"] .to_dataframe("groundwater", lambda metric: metric.f1),
77
- self.metrics["groundwater_depth"] .to_dataframe("groundwater_depth", lambda metric: metric.f1),
79
+ self.layer_metrics .to_dataframe("F1", lambda metric: metric.f1),
80
+ self.layer_metrics .to_dataframe("precision", lambda metric: metric.precision),
81
+ self.layer_metrics .to_dataframe("recall", lambda metric: metric.recall),
82
+ self.depth_interval_metrics .to_dataframe("Depth_interval_accuracy", lambda metric: metric.precision),
83
+ self.layer_metrics .to_dataframe("Number Elements", lambda metric: metric.tp + metric.fn),
84
+ self.layer_metrics .to_dataframe("Number wrong elements", lambda metric: metric.fp + metric.fn),
85
+ self.groundwater_metrics .to_dataframe("groundwater", lambda metric: metric.f1),
86
+ self.groundwater_depth_metrics .to_dataframe("groundwater_depth", lambda metric: metric.f1),
78
87
]
79
88
document_level_metrics = pd.DataFrame()
80
89
for series in all_series:
@@ -86,21 +95,19 @@ def metrics_dict(self) -> dict[str, float]:
86
95
# Initialize a defaultdict to automatically return 0.0 for missing keys
87
96
result = defaultdict(lambda: None)
88
97
89
- # Safely compute groundwater metrics using .get() to avoid KeyErrors
90
- groundwater_metrics = Metrics.micro_average(
91
- self.metrics.get("groundwater", DatasetMetrics()).get_metrics_list()
92
- )
93
- groundwater_depth_metrics = Metrics.micro_average(
94
- self.metrics.get("groundwater_depth", DatasetMetrics()).get_metrics_list()
95
- )
98
+ # Compute the micro-average metrics for the groundwater and groundwater depth metrics
99
+ groundwater_metrics = Metrics.micro_average(self.groundwater_metrics.metrics.values())
100
+ groundwater_depth_metrics = Metrics.micro_average(self.groundwater_depth_metrics.metrics.values())
96
101
97
102
# Populate the basic metrics
98
103
result.update(
99
104
{
100
- "F1": self.metrics.get("layer", DatasetMetrics()).pseudo_macro_f1(),
101
- "recall": self.metrics.get("layer", DatasetMetrics()).macro_recall(),
102
- "precision": self.metrics.get("layer", DatasetMetrics()).macro_precision(),
103
- "depth_interval_accuracy": self.metrics.get("depth_interval", DatasetMetrics()).macro_precision(),
105
+ "F1": self.layer_metrics.pseudo_macro_f1() if self.layer_metrics else None,
106
+ "recall": self.layer_metrics.macro_recall() if self.layer_metrics else None,
107
+ "precision": self.layer_metrics.macro_precision() if self.layer_metrics else None,
108
+ "depth_interval_accuracy": self.depth_interval_metrics.macro_precision()
109
+ if self.depth_interval_metrics
110
+ else None,
104
111
"groundwater_f1": groundwater_metrics.f1,
105
112
"groundwater_recall": groundwater_metrics.recall,
106
113
"groundwater_precision": groundwater_metrics.precision,
@@ -111,16 +118,16 @@ def metrics_dict(self) -> dict[str, float]:
111
118
)
112
119
113
120
# Add dynamic language-specific metrics only if they exist
114
- for lang in ["de", "fr"] :
115
- layer_key = f"{lang}_layer "
116
- depth_key = f"{lang}_depth_interval "
121
+ for lang in self.languages :
122
+ layer_key = f"{lang}_layer_metrics "
123
+ depth_key = f"{lang}_depth_interval_metrics "
117
124
118
- if layer_key in self.metrics:
119
- result[f"{lang}_F1"] = self.metrics[ layer_key] .pseudo_macro_f1()
120
- result[f"{lang}_recall"] = self.metrics[ layer_key] .macro_recall()
121
- result[f"{lang}_precision"] = self.metrics[ layer_key] .macro_precision()
125
+ if getattr(self, layer_key) and getattr( self, layer_key) .metrics:
126
+ result[f"{lang}_F1"] = getattr( self, layer_key) .pseudo_macro_f1()
127
+ result[f"{lang}_recall"] = getattr( self, layer_key) .macro_recall()
128
+ result[f"{lang}_precision"] = getattr( self, layer_key) .macro_precision()
122
129
123
- if depth_key in self.metrics:
124
- result[f"{lang}_depth_interval_accuracy"] = self.metrics[ depth_key] .macro_precision()
130
+ if getattr(self, depth_key) and getattr( self, depth_key) .metrics:
131
+ result[f"{lang}_depth_interval_accuracy"] = getattr( self, depth_key) .macro_precision()
125
132
126
133
return dict(result) # Convert defaultdict back to a regular dict
0 commit comments