Skip to content

Commit 4e3d50e

Browse files
author
David Cleres
committed
Addressed the comments from the PR
1 parent dd559ff commit 4e3d50e

File tree

4 files changed

+8
-46
lines changed

4 files changed

+8
-46
lines changed

src/stratigraphy/benchmark/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_metrics_list(self) -> list[Metrics]:
6161
class OverallMetricsCatalog:
6262
"""Keeps track of all different relevant metrics that are computed for a dataset."""
6363

64-
def __init__(self, languages: list[str]):
64+
def __init__(self, languages: set[str]):
6565
self.layer_metrics = OverallMetrics()
6666
self.depth_interval_metrics = OverallMetrics()
6767
self.groundwater_metrics = OverallMetrics()

src/stratigraphy/benchmark/score.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from dotenv import load_dotenv
1111
from stratigraphy import DATAPATH
1212
from stratigraphy.annotations.draw import draw_predictions
13-
from stratigraphy.benchmark.ground_truth import GroundTruth
1413
from stratigraphy.evaluation.evaluation_dataclasses import BoreholeMetadataMetrics
1514
from stratigraphy.util.predictions import OverallFilePredictions
1615

@@ -21,39 +20,6 @@
2120
logger = logging.getLogger(__name__)
2221

2322

24-
def create_predictions_objects(
25-
predictions: OverallFilePredictions,
26-
ground_truth_path: Path | None,
27-
) -> tuple[OverallFilePredictions, dict]:
28-
"""Create predictions objects from the predictions and evaluate them against the ground truth.
29-
30-
Args:
31-
predictions (OverallFilePredictions): The predictions objects.
32-
ground_truth_path (Path | None): The path to the ground truth file.
33-
34-
Returns:
35-
tuple[OverallFilePredictions, dict]: The predictions objects and the number of ground truth values per
36-
file.
37-
"""
38-
if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available
39-
ground_truth = GroundTruth(ground_truth_path)
40-
ground_truth_is_present = True
41-
else:
42-
logging.warning("Ground truth file not found.")
43-
ground_truth_is_present = False
44-
return predictions, {}
45-
46-
number_of_truth_values = {}
47-
for file_predictions in predictions.file_predictions_list:
48-
if ground_truth_is_present:
49-
ground_truth_for_file = ground_truth.for_file(file_predictions.file_name)
50-
if ground_truth_for_file:
51-
file_predictions.evaluate(ground_truth_for_file)
52-
number_of_truth_values[file_predictions.file_name] = len(ground_truth_for_file["layers"])
53-
54-
return predictions, number_of_truth_values
55-
56-
5723
def evaluate(
5824
predictions: OverallFilePredictions,
5925
ground_truth_path: Path,

src/stratigraphy/evaluation/groundwater_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def evaluate(self) -> OverallGroundwaterMetrics:
121121
groundwater_elevation_metrics=groundwater_elevation_metrics,
122122
groundwater_date_metrics=groundwater_date_metrics,
123123
filename=filename,
124-
) # TODO: This clashes with the OverallMetrics object
124+
)
125125

126126
overall_groundwater_metrics.add_groundwater_metrics(file_groundwater_metrics)
127127

src/stratigraphy/main.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,11 @@ def start_pipeline(
233233

234234
if part == "all":
235235
# Extract the groundwater levels
236-
groundwater_in_document = GroundwaterInDocument.from_document(doc, metadata.elevation)
236+
groundwater_entries = GroundwaterInDocument.from_document(doc, metadata.elevation)
237237

238238
# Extract the layers
239-
layer_predictions_list = LayersInDocument([], filename)
240-
depths_materials_column_pairs_list = []
239+
layers = LayersInDocument([], filename)
240+
depths_materials_columns_pairs = []
241241
for page_index, page in enumerate(doc):
242242
page_number = page_index + 1
243243
logger.info("Processing page %s", page_number)
@@ -253,7 +253,7 @@ def start_pipeline(
253253
layer_predictions = remove_duplicate_layers(
254254
previous_page=doc[page_index - 1],
255255
current_page=page,
256-
previous_layers=layer_predictions_list,
256+
previous_layers=layers,
257257
current_layers=process_page_results.predictions,
258258
img_template_probability_threshold=matching_params[
259259
"img_template_probability_threshold"
@@ -262,8 +262,8 @@ def start_pipeline(
262262
else:
263263
layer_predictions = process_page_results.predictions
264264

265-
layer_predictions_list.add_layers_on_page(layer_predictions)
266-
depths_materials_column_pairs_list.extend(process_page_results.depth_material_pairs)
265+
layers.add_layers_on_page(layer_predictions)
266+
depths_materials_columns_pairs.extend(process_page_results.depth_material_pairs)
267267

268268
if draw_lines: # could be changed to if draw_lines and mflow_tracking:
269269
if not mlflow_tracking:
@@ -276,10 +276,6 @@ def start_pipeline(
276276
)
277277
mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png")
278278

279-
groundwater_entries = groundwater_in_document
280-
layers = layer_predictions_list
281-
depths_materials_columns_pairs = depths_materials_column_pairs_list
282-
283279
# Add file predictions
284280
predictions.add_file_predictions(
285281
FilePredictions(

0 commit comments

Comments
 (0)