diff --git a/src/stratigraphy/benchmark/metrics.py b/src/stratigraphy/benchmark/metrics.py index edb4bbe2..f81d31d4 100644 --- a/src/stratigraphy/benchmark/metrics.py +++ b/src/stratigraphy/benchmark/metrics.py @@ -61,7 +61,7 @@ def get_metrics_list(self) -> list[Metrics]: class OverallMetricsCatalog: """Keeps track of all different relevant metrics that are computed for a dataset.""" - def __init__(self, languages: list[str]): + def __init__(self, languages: set[str]): self.layer_metrics = OverallMetrics() self.depth_interval_metrics = OverallMetrics() self.groundwater_metrics = OverallMetrics() diff --git a/src/stratigraphy/benchmark/score.py b/src/stratigraphy/benchmark/score.py index 218036e4..35fcc7cd 100644 --- a/src/stratigraphy/benchmark/score.py +++ b/src/stratigraphy/benchmark/score.py @@ -10,7 +10,6 @@ from dotenv import load_dotenv from stratigraphy import DATAPATH from stratigraphy.annotations.draw import draw_predictions -from stratigraphy.benchmark.ground_truth import GroundTruth from stratigraphy.evaluation.evaluation_dataclasses import BoreholeMetadataMetrics from stratigraphy.util.predictions import OverallFilePredictions @@ -21,39 +20,6 @@ logger = logging.getLogger(__name__) -def create_predictions_objects( - predictions: OverallFilePredictions, - ground_truth_path: Path | None, -) -> tuple[OverallFilePredictions, dict]: - """Create predictions objects from the predictions and evaluate them against the ground truth. - - Args: - predictions (OverallFilePredictions): The predictions objects. - ground_truth_path (Path | None): The path to the ground truth file. - - Returns: - tuple[OverallFilePredictions, dict]: The predictions objects and the number of ground truth values per - file. - """ - if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available - ground_truth = GroundTruth(ground_truth_path) - ground_truth_is_present = True - else: - logging.warning("Ground truth file not found.") - ground_truth_is_present = False - return predictions, {} - - number_of_truth_values = {} - for file_predictions in predictions.file_predictions_list: - if ground_truth_is_present: - ground_truth_for_file = ground_truth.for_file(file_predictions.file_name) - if ground_truth_for_file: - file_predictions.evaluate(ground_truth_for_file) - number_of_truth_values[file_predictions.file_name] = len(ground_truth_for_file["layers"]) - - return predictions, number_of_truth_values - - def evaluate( predictions: OverallFilePredictions, ground_truth_path: Path, diff --git a/src/stratigraphy/evaluation/groundwater_evaluator.py b/src/stratigraphy/evaluation/groundwater_evaluator.py index 08ecec07..aa83ecaa 100644 --- a/src/stratigraphy/evaluation/groundwater_evaluator.py +++ b/src/stratigraphy/evaluation/groundwater_evaluator.py @@ -121,7 +121,7 @@ def evaluate(self) -> OverallGroundwaterMetrics: groundwater_elevation_metrics=groundwater_elevation_metrics, groundwater_date_metrics=groundwater_date_metrics, filename=filename, - ) # TODO: This clashes with the OverallMetrics object + ) overall_groundwater_metrics.add_groundwater_metrics(file_groundwater_metrics) diff --git a/src/stratigraphy/main.py b/src/stratigraphy/main.py index 95026c97..84aaed72 100644 --- a/src/stratigraphy/main.py +++ b/src/stratigraphy/main.py @@ -233,11 +233,11 @@ def start_pipeline( if part == "all": # Extract the groundwater levels - groundwater_in_document = GroundwaterInDocument.from_document(doc, metadata.elevation) + groundwater_entries = GroundwaterInDocument.from_document(doc, metadata.elevation) # Extract the layers - layer_predictions_list = LayersInDocument([], filename) - depths_materials_column_pairs_list = [] + layers = LayersInDocument([], filename) + depths_materials_columns_pairs = [] for page_index, page in enumerate(doc): page_number = page_index + 1 logger.info("Processing page %s", page_number) @@ -253,7 +253,7 @@ def start_pipeline( layer_predictions = remove_duplicate_layers( previous_page=doc[page_index - 1], current_page=page, - previous_layers=layer_predictions_list, + previous_layers=layers, current_layers=process_page_results.predictions, img_template_probability_threshold=matching_params[ "img_template_probability_threshold" @@ -262,8 +262,8 @@ def start_pipeline( else: layer_predictions = process_page_results.predictions - layer_predictions_list.add_layers_on_page(layer_predictions) - depths_materials_column_pairs_list.extend(process_page_results.depth_material_pairs) + layers.add_layers_on_page(layer_predictions) + depths_materials_columns_pairs.extend(process_page_results.depth_material_pairs) if draw_lines: # could be changed to if draw_lines and mflow_tracking: if not mlflow_tracking: @@ -276,10 +276,6 @@ def start_pipeline( ) mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png") - groundwater_entries = groundwater_in_document - layers = layer_predictions_list - depths_materials_columns_pairs = depths_materials_column_pairs_list - # Add file predictions predictions.add_file_predictions( FilePredictions(