diff --git a/src/stratigraphy/benchmark/ground_truth.py b/src/stratigraphy/benchmark/ground_truth.py index b7f3497c..904a1010 100644 --- a/src/stratigraphy/benchmark/ground_truth.py +++ b/src/stratigraphy/benchmark/ground_truth.py @@ -31,9 +31,9 @@ def __init__(self, path: Path): metadata = ground_truth_item["metadata"] self.ground_truth[borehole_profile]["metadata"] = metadata - def for_file(self, file_name: str) -> list: + def for_file(self, file_name: str) -> dict: if file_name in self.ground_truth: return self.ground_truth[file_name] else: logger.warning(f"No ground truth data found for {file_name}.") - return [] + return {} diff --git a/src/stratigraphy/benchmark/score.py b/src/stratigraphy/benchmark/score.py index 9232f1d2..bc5a20b4 100644 --- a/src/stratigraphy/benchmark/score.py +++ b/src/stratigraphy/benchmark/score.py @@ -275,8 +275,10 @@ def create_predictions_objects(predictions: dict, ground_truth_path: Path) -> tu predictions_objects[file_name] = prediction_object if ground_truth_is_present: - predictions_objects[file_name].evaluate(ground_truth.for_file(file_name)) - number_of_truth_values[file_name] = len(ground_truth.for_file(file_name)["layers"]) + ground_truth_for_file = ground_truth.for_file(file_name) + if ground_truth_for_file: + predictions_objects[file_name].evaluate(ground_truth_for_file) + number_of_truth_values[file_name] = len(ground_truth_for_file["layers"]) return predictions_objects, number_of_truth_values