Skip to content

Commit b2998f0

Browse files
authored
Merge pull request #87 from swisstopo/LGVISIUM-80-Refactor-the-groundwater-object-and-its-evaluation
Close #LGVISIUM-80: Refactor the groundwater object and its evaluation
2 parents a5d0947 + 4e3d50e commit b2998f0

30 files changed

+1242
-653
lines changed

.github/workflows/pipeline_run.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
- uses: actions/checkout@v4
1313
- uses: actions/setup-python@v5
1414
with:
15-
python-version: '3.10'
15+
python-version: '3.11'
1616
- name: Create Environment and run pipeline
1717
shell: bash
1818
run: |

.github/workflows/pre-commit.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ jobs:
1212
- uses: actions/checkout@v3
1313
- uses: actions/setup-python@v3
1414
with:
15-
python-version: 3.10.14
15+
python-version: '3.11'
1616
- uses: pre-commit/[email protected]

.github/workflows/pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
- uses: actions/checkout@v4
1313
- uses: actions/setup-python@v5
1414
with:
15-
python-version: '3.10'
15+
python-version: '3.11'
1616
- name: Create Environment and run tests
1717
shell: bash
1818
run: |

example/example_gw_groundtruth.json

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
{
2+
"example_borehole_profile.pdf": {
3+
"groundwater": [
4+
{
5+
"date": "2016-04-18",
6+
"depth": 2.22,
7+
"elevation": 448.07
8+
},
9+
{
10+
"date": "2016-04-20",
11+
"depth": 3.22,
12+
"elevation": 447.07
13+
}
14+
],
15+
"layers": [],
16+
"metadata": {
17+
"coordinates": {
18+
"E": 615790,
19+
"N": 157500
20+
},
21+
"drilling_date": "1995-09-03",
22+
"drilling_methods": null,
23+
"original_name": "",
24+
"project_name": "",
25+
"reference_elevation": 788.6,
26+
"total_depth": null
27+
}
28+
},
29+
"example_borehole_profile_2.pdf": {
30+
"groundwater": [
31+
{
32+
"date": "2016-04-18",
33+
"depth": 2.22,
34+
"elevation": 448.07
35+
},
36+
{
37+
"date": "2016-04-20",
38+
"depth": 3.22,
39+
"elevation": 447.07
40+
}
41+
],
42+
"layers": [],
43+
"metadata": {
44+
"coordinates": {
45+
"E": 615790,
46+
"N": 157500
47+
},
48+
"drilling_date": "1995-09-03",
49+
"drilling_methods": null,
50+
"original_name": "",
51+
"project_name": "",
52+
"reference_elevation": 788.6,
53+
"total_depth": null
54+
}
55+
}
56+
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "swissgeol-boreholes-dataextraction"
77
version = "0.0.1-dev"
88
description = "Python project to analyse borehole profiles."
99
readme = "README.md"
10-
requires-python = ">=3.10"
10+
requires-python = ">=3.11"
1111
dependencies = [
1212
"boto3",
1313
"pandas",

src/stratigraphy/annotations/draw.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import fitz
88
import pandas as pd
99
from dotenv import load_dotenv
10+
from stratigraphy.data_extractor.data_extractor import FeatureOnPage
1011
from stratigraphy.depthcolumn.depthcolumn import DepthColumn
1112
from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs
12-
from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformationOnPage
13-
from stratigraphy.layer.layer import LayerPrediction
13+
from stratigraphy.groundwater.groundwater_extraction import Groundwater
14+
from stratigraphy.layer.layer import Layer
1415
from stratigraphy.metadata.coordinate_extraction import Coordinate
1516
from stratigraphy.metadata.elevation_extraction import Elevation
1617
from stratigraphy.text.textblock import TextBlock
@@ -90,7 +91,7 @@ def draw_predictions(
9091
draw_coordinates(shape, coordinates)
9192
if elevation is not None and page_number == elevation.page:
9293
draw_elevation(shape, elevation)
93-
for groundwater_entry in file_prediction.groundwater_entries:
94+
for groundwater_entry in file_prediction.groundwater.groundwater:
9495
if page_number == groundwater_entry.page:
9596
draw_groundwater(shape, groundwater_entry)
9697
draw_depth_columns_and_material_rect(
@@ -103,7 +104,7 @@ def draw_predictions(
103104
page.derotation_matrix,
104105
[
105106
layer
106-
for layer in file_prediction.layers
107+
for layer in file_prediction.layers.get_all_layers()
107108
if layer.material_description.page_number == page_number
108109
],
109110
)
@@ -197,19 +198,19 @@ def draw_coordinates(shape: fitz.Shape, coordinates: Coordinate) -> None:
197198
shape.finish(color=fitz.utils.getColor("purple"))
198199

199200

200-
def draw_groundwater(shape: fitz.Shape, groundwater_entry: GroundwaterInformationOnPage) -> None:
201-
"""Draw a bounding box around the area of the page where the coordinates were extracted from.
201+
def draw_groundwater(shape: fitz.Shape, groundwater_entry: FeatureOnPage[Groundwater]) -> None:
202+
"""Draw a bounding box around the area of the page where the groundwater information was extracted from.
202203

203204
Args:
204205
shape (fitz.Shape): The shape object for drawing.
205-
groundwater_entry (GroundwaterInformationOnPage): The groundwater information to draw.
206+
groundwater_entry (FeatureOnPage[Groundwater]): The groundwater information to draw.
206207
"""
207208
shape.draw_rect(groundwater_entry.rect)
208209
shape.finish(color=fitz.utils.getColor("pink"))
209210

210211

211212
def draw_elevation(shape: fitz.Shape, elevation: Elevation) -> None:
212-
"""Draw a bounding box around the area of the page where the coordinates were extracted from.
213+
"""Draw a bounding box around the area of the page where the elevation was extracted from.
213214

214215
Args:
215216
shape (fitz.Shape): The shape object for drawing.
@@ -219,9 +220,7 @@ def draw_elevation(shape: fitz.Shape, elevation: Elevation) -> None:
219220
shape.finish(color=fitz.utils.getColor("blue"))
220221

221222

222-
def draw_material_descriptions(
223-
shape: fitz.Shape, derotation_matrix: fitz.Matrix, layers: list[LayerPrediction]
224-
) -> None:
223+
def draw_material_descriptions(shape: fitz.Shape, derotation_matrix: fitz.Matrix, layers: list[Layer]) -> None:
225224
"""Draw information about material descriptions on a pdf page.
226225

227226
In particular, this function:

src/stratigraphy/benchmark/ground_truth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class GroundTruth:
1414
"""Ground truth data for the stratigraphy benchmark."""
1515

16-
def __init__(self, path: Path):
16+
def __init__(self, path: Path) -> None:
1717
self.ground_truth = defaultdict(dict)
1818

1919
# Load the ground truth data

src/stratigraphy/benchmark/metrics.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from stratigraphy.evaluation.evaluation_dataclasses import Metrics
88

99

10-
class DatasetMetrics:
10+
class OverallMetrics:
1111
"""Keeps track of a particular metric for all documents in a dataset."""
1212

1313
# TODO: Currently, some methods for averaging metrics are in the Metrics class.
@@ -58,23 +58,32 @@ def get_metrics_list(self) -> list[Metrics]:
5858
return list(self.metrics.values())
5959

6060

61-
class DatasetMetricsCatalog:
61+
class OverallMetricsCatalog:
6262
"""Keeps track of all different relevant metrics that are computed for a dataset."""
6363

64-
def __init__(self):
65-
self.metrics: dict[str, DatasetMetrics] = {}
64+
def __init__(self, languages: set[str]):
65+
self.layer_metrics = OverallMetrics()
66+
self.depth_interval_metrics = OverallMetrics()
67+
self.groundwater_metrics = OverallMetrics()
68+
self.groundwater_depth_metrics = OverallMetrics()
69+
self.languages = languages
70+
71+
# Initialize language-specific metrics
72+
for lang in languages:
73+
setattr(self, f"{lang}_layer_metrics", OverallMetrics())
74+
setattr(self, f"{lang}_depth_interval_metrics", OverallMetrics())
6675

6776
def document_level_metrics_df(self) -> pd.DataFrame:
6877
"""Return a DataFrame with all the document level metrics."""
6978
all_series = [
70-
self.metrics["layer"].to_dataframe("F1", lambda metric: metric.f1),
71-
self.metrics["layer"].to_dataframe("precision", lambda metric: metric.precision),
72-
self.metrics["layer"].to_dataframe("recall", lambda metric: metric.recall),
73-
self.metrics["depth_interval"].to_dataframe("Depth_interval_accuracy", lambda metric: metric.precision),
74-
self.metrics["layer"].to_dataframe("Number Elements", lambda metric: metric.tp + metric.fn),
75-
self.metrics["layer"].to_dataframe("Number wrong elements", lambda metric: metric.fp + metric.fn),
76-
self.metrics["groundwater"].to_dataframe("groundwater", lambda metric: metric.f1),
77-
self.metrics["groundwater_depth"].to_dataframe("groundwater_depth", lambda metric: metric.f1),
79+
self.layer_metrics.to_dataframe("F1", lambda metric: metric.f1),
80+
self.layer_metrics.to_dataframe("precision", lambda metric: metric.precision),
81+
self.layer_metrics.to_dataframe("recall", lambda metric: metric.recall),
82+
self.depth_interval_metrics.to_dataframe("Depth_interval_accuracy", lambda metric: metric.precision),
83+
self.layer_metrics.to_dataframe("Number Elements", lambda metric: metric.tp + metric.fn),
84+
self.layer_metrics.to_dataframe("Number wrong elements", lambda metric: metric.fp + metric.fn),
85+
self.groundwater_metrics.to_dataframe("groundwater", lambda metric: metric.f1),
86+
self.groundwater_depth_metrics.to_dataframe("groundwater_depth", lambda metric: metric.f1),
7887
]
7988
document_level_metrics = pd.DataFrame()
8089
for series in all_series:
@@ -86,21 +95,19 @@ def metrics_dict(self) -> dict[str, float]:
8695
# Initialize a defaultdict to automatically return None for missing keys
8796
result = defaultdict(lambda: None)
8897

89-
# Safely compute groundwater metrics using .get() to avoid KeyErrors
90-
groundwater_metrics = Metrics.micro_average(
91-
self.metrics.get("groundwater", DatasetMetrics()).get_metrics_list()
92-
)
93-
groundwater_depth_metrics = Metrics.micro_average(
94-
self.metrics.get("groundwater_depth", DatasetMetrics()).get_metrics_list()
95-
)
98+
# Compute the micro-average metrics for the groundwater and groundwater depth metrics
99+
groundwater_metrics = Metrics.micro_average(self.groundwater_metrics.metrics.values())
100+
groundwater_depth_metrics = Metrics.micro_average(self.groundwater_depth_metrics.metrics.values())
96101

97102
# Populate the basic metrics
98103
result.update(
99104
{
100-
"F1": self.metrics.get("layer", DatasetMetrics()).pseudo_macro_f1(),
101-
"recall": self.metrics.get("layer", DatasetMetrics()).macro_recall(),
102-
"precision": self.metrics.get("layer", DatasetMetrics()).macro_precision(),
103-
"depth_interval_accuracy": self.metrics.get("depth_interval", DatasetMetrics()).macro_precision(),
105+
"F1": self.layer_metrics.pseudo_macro_f1() if self.layer_metrics else None,
106+
"recall": self.layer_metrics.macro_recall() if self.layer_metrics else None,
107+
"precision": self.layer_metrics.macro_precision() if self.layer_metrics else None,
108+
"depth_interval_accuracy": self.depth_interval_metrics.macro_precision()
109+
if self.depth_interval_metrics
110+
else None,
104111
"groundwater_f1": groundwater_metrics.f1,
105112
"groundwater_recall": groundwater_metrics.recall,
106113
"groundwater_precision": groundwater_metrics.precision,
@@ -111,16 +118,16 @@ def metrics_dict(self) -> dict[str, float]:
111118
)
112119

113120
# Add dynamic language-specific metrics only if they exist
114-
for lang in ["de", "fr"]:
115-
layer_key = f"{lang}_layer"
116-
depth_key = f"{lang}_depth_interval"
121+
for lang in self.languages:
122+
layer_key = f"{lang}_layer_metrics"
123+
depth_key = f"{lang}_depth_interval_metrics"
117124

118-
if layer_key in self.metrics:
119-
result[f"{lang}_F1"] = self.metrics[layer_key].pseudo_macro_f1()
120-
result[f"{lang}_recall"] = self.metrics[layer_key].macro_recall()
121-
result[f"{lang}_precision"] = self.metrics[layer_key].macro_precision()
125+
if getattr(self, layer_key) and getattr(self, layer_key).metrics:
126+
result[f"{lang}_F1"] = getattr(self, layer_key).pseudo_macro_f1()
127+
result[f"{lang}_recall"] = getattr(self, layer_key).macro_recall()
128+
result[f"{lang}_precision"] = getattr(self, layer_key).macro_precision()
122129

123-
if depth_key in self.metrics:
124-
result[f"{lang}_depth_interval_accuracy"] = self.metrics[depth_key].macro_precision()
130+
if getattr(self, depth_key) and getattr(self, depth_key).metrics:
131+
result[f"{lang}_depth_interval_accuracy"] = getattr(self, depth_key).macro_precision()
125132

126133
return dict(result) # Convert defaultdict back to a regular dict

0 commit comments

Comments
 (0)