From c2864d0b342f0ad7f09f16a38ef681aba1b9a7eb Mon Sep 17 00:00:00 2001 From: dcleres Date: Thu, 26 Sep 2024 18:07:45 +0200 Subject: [PATCH] Close #80: Refactoring of the groundwater evaluation and classes --- src/stratigraphy/annotations/draw.py | 13 +- src/stratigraphy/benchmark/score.py | 238 ++----------- .../evaluation/evaluation_dataclasses.py | 5 +- .../evaluation/groundwater_evaluator.py | 135 ++++++++ .../evaluation/metadata_evaluator.py | 2 +- src/stratigraphy/evaluation/utility.py | 73 ++++ .../groundwater/groundwater_extraction.py | 75 ++-- src/stratigraphy/main.py | 13 +- src/stratigraphy/util/predictions.py | 326 +++++++++++------- 9 files changed, 505 insertions(+), 375 deletions(-) create mode 100644 src/stratigraphy/evaluation/groundwater_evaluator.py create mode 100644 src/stratigraphy/evaluation/utility.py diff --git a/src/stratigraphy/annotations/draw.py b/src/stratigraphy/annotations/draw.py index da8c631c..029298c7 100644 --- a/src/stratigraphy/annotations/draw.py +++ b/src/stratigraphy/annotations/draw.py @@ -9,7 +9,7 @@ from dotenv import load_dotenv from stratigraphy.depthcolumn.depthcolumn import DepthColumn from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs -from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformationOnPage +from stratigraphy.groundwater.groundwater_extraction import GroundwaterOnPage from stratigraphy.layer.layer import LayerPrediction from stratigraphy.metadata.coordinate_extraction import Coordinate from stratigraphy.metadata.elevation_extraction import Elevation @@ -82,9 +82,10 @@ def draw_predictions( draw_coordinates(shape, coordinates) if elevation is not None and page_number == elevation.page: draw_elevation(shape, elevation) - for groundwater_entry in file_prediction.groundwater_entries: - if page_number == groundwater_entry.page: - draw_groundwater(shape, groundwater_entry) + for groundwater_on_page in file_prediction.groundwater.groundwater: + # TODO: Adapt this to the structures above -> List the groundwater in the function + if page_number == groundwater_on_page.page: + draw_groundwater(shape, groundwater_on_page) draw_depth_columns_and_material_rect( shape, page.derotation_matrix, @@ -185,12 +186,12 @@ def draw_coordinates(shape: fitz.Shape, coordinates: Coordinate) -> None: shape.finish(color=fitz.utils.getColor("purple")) -def draw_groundwater(shape: fitz.Shape, groundwater_entry: GroundwaterInformationOnPage) -> None: +def draw_groundwater(shape: fitz.Shape, groundwater_entry: GroundwaterOnPage) -> None: """Draw a bounding box around the area of the page where the coordinates were extracted from. Args: shape (fitz.Shape): The shape object for drawing. - groundwater_entry (GroundwaterInformationOnPage): The groundwater information to draw. + groundwater_entry (GroundwaterOnPage): The groundwater information to draw. """ shape.draw_rect(groundwater_entry.rect) shape.finish(color=fitz.utils.getColor("pink")) diff --git a/src/stratigraphy/benchmark/score.py b/src/stratigraphy/benchmark/score.py index 8ac74e4f..1f527fab 100644 --- a/src/stratigraphy/benchmark/score.py +++ b/src/stratigraphy/benchmark/score.py @@ -5,13 +5,12 @@ import os from pathlib import Path +import pandas as pd from dotenv import load_dotenv from stratigraphy import DATAPATH from stratigraphy.annotations.draw import draw_predictions -from stratigraphy.benchmark.ground_truth import GroundTruth -from stratigraphy.benchmark.metrics import DatasetMetrics, DatasetMetricsCatalog, Metrics +from stratigraphy.evaluation.evaluation_dataclasses import BoreholeMetadataMetrics from stratigraphy.util.predictions import OverallFilePredictions -from stratigraphy.util.util import parse_text load_dotenv() @@ -20,192 +19,6 @@ logger = logging.getLogger(__name__) -def get_layer_metrics(predictions: OverallFilePredictions, number_of_truth_values: dict) -> DatasetMetrics: - """Calculate F1, precision and recall for the layer predictions. - - Calculate F1, precision and recall for the individual documents as well as overall. - - Args: - predictions (dict): The predictions. - number_of_truth_values (dict): The number of ground truth values per file. - - Returns: - DatasetMetrics: the metrics for the layers - """ - layer_metrics = DatasetMetrics() - - for file_prediction in predictions.file_predictions_list: - hits = 0 - for layer in file_prediction.layers: - if layer.material_is_correct: - hits += 1 - if parse_text(layer.material_description.text) == "": - logger.warning("Empty string found in predictions") - layer_metrics.metrics[file_prediction.file_name] = Metrics( - tp=hits, fp=len(file_prediction.layers) - hits, fn=number_of_truth_values[file_prediction.file_name] - hits - ) - - return layer_metrics - - -def get_depth_interval_metrics(predictions: OverallFilePredictions) -> DatasetMetrics: - """Calculate F1, precision and recall for the depth interval predictions. - - Calculate F1, precision and recall for the individual documents as well as overall. - - Depth interval accuracy is not calculated for layers with incorrect material predictions. - - Args: - predictions (dict): The predictions. - - Returns: - DatasetMetrics: the metrics for the depth intervals - """ - depth_interval_metrics = DatasetMetrics() - - for file_prediction in predictions.file_predictions_list: - depth_interval_hits = 0 - depth_interval_occurrences = 0 - for layer in file_prediction.layers: - if layer.material_is_correct: - if layer.depth_interval_is_correct is not None: - depth_interval_occurrences += 1 - if layer.depth_interval_is_correct: - depth_interval_hits += 1 - - if depth_interval_occurrences > 0: - depth_interval_metrics.metrics[file_prediction.file_name] = Metrics( - tp=depth_interval_hits, fp=depth_interval_occurrences - depth_interval_hits, fn=0 - ) - - return depth_interval_metrics - - -def evaluate_borehole_extraction( - predictions: OverallFilePredictions, number_of_truth_values: dict -) -> DatasetMetricsCatalog: - """Evaluate the borehole extraction predictions. - - Args: - predictions (dict): The FilePredictions objects. - number_of_truth_values (dict): The number of layer ground truth values per file. - - Returns: - DatasetMetricsCatalogue: A DatasetMetricsCatalogue that maps a metrics name to the corresponding DatasetMetrics - object - """ - all_metrics = evaluate_layer_extraction(predictions, number_of_truth_values) - all_metrics.metrics["groundwater"] = get_metrics(predictions, "groundwater_is_correct", "groundwater") - all_metrics.metrics["groundwater_depth"] = get_metrics(predictions, "groundwater_is_correct", "groundwater_depth") - return all_metrics - - -def get_metrics(predictions: OverallFilePredictions, field_key: str, field_name: str) -> DatasetMetrics: - """Get the metrics for a specific field in the predictions. - - Args: - predictions (dict): The FilePredictions objects. - field_key (str): The key to access the specific field in the prediction objects. - field_name (str): The name of the field being evaluated. - - Returns: - DatasetMetrics: The requested DatasetMetrics object. - """ - dataset_metrics = DatasetMetrics() - - for file_prediction in predictions.file_predictions_list: - dataset_metrics.metrics[file_prediction.file_name] = getattr(file_prediction, field_key)[field_name] - - return dataset_metrics - - -def evaluate_layer_extraction( - predictions: OverallFilePredictions, number_of_truth_values: dict -) -> DatasetMetricsCatalog: - """Calculate F1, precision and recall for the predictions. - - Calculate F1, precision and recall for the individual documents as well as overall. - The individual document metrics are returned as a DataFrame. - - Args: - predictions (dict): The FilePredictions objects. - number_of_truth_values (dict): The number of layer ground truth values per file. - - Returns: - DatasetMetricsCatalogue: A dictionary that maps a metrics name to the corresponding DatasetMetrics object - """ - all_metrics = DatasetMetricsCatalog() - all_metrics.metrics["layer"] = get_layer_metrics(predictions, number_of_truth_values) - all_metrics.metrics["depth_interval"] = get_depth_interval_metrics(predictions) - - # create predictions by language - predictions_by_language = { - "de": OverallFilePredictions(), - "fr": OverallFilePredictions(), - } # TODO: make this dynamic and why is this hardcoded? - for file_predictions in predictions.file_predictions_list: - language = file_predictions.metadata.language - if language in predictions_by_language: - predictions_by_language[language].add_file_predictions(file_predictions) - - for language, language_predictions in predictions_by_language.items(): - language_number_of_truth_values = { - prediction.file_name: number_of_truth_values[prediction.file_name] - for prediction in language_predictions.file_predictions_list - } - all_metrics.metrics[f"{language}_layer"] = get_layer_metrics( - language_predictions, language_number_of_truth_values - ) - all_metrics.metrics[f"{language}_depth_interval"] = get_depth_interval_metrics(language_predictions) - - logging.info("Macro avg:") - logging.info( - "F1: %.1f%%, precision: %.1f%%, recall: %.1f%%, depth_interval_accuracy: %.1f%%", - all_metrics.metrics["layer"].macro_f1() * 100, - all_metrics.metrics["layer"].macro_precision() * 100, - all_metrics.metrics["layer"].macro_recall() * 100, - all_metrics.metrics["depth_interval"].macro_precision() * 100, - ) - - return all_metrics - - -def create_predictions_objects( - predictions: OverallFilePredictions, - ground_truth_path: Path | None, -) -> tuple[OverallFilePredictions, dict]: - """Create predictions objects from the predictions and evaluate them against the ground truth. - - Args: - predictions (dict): The predictions from the predictions.json file. - ground_truth_path (Path | None): The path to the ground truth file. - metadata_per_file (BoreholeMetadataList): The metadata for the files. - - Returns: - tuple[dict[str, FilePredictions], dict]: The predictions objects and the number of ground truth values per - file. - """ - if ground_truth_path and os.path.exists(ground_truth_path): # for inference no ground truth is available - ground_truth = GroundTruth(ground_truth_path) - ground_truth_is_present = True - else: - logging.warning("Ground truth file not found.") - ground_truth_is_present = False - - number_of_truth_values = {} - for file_predictions in predictions.file_predictions_list: - # prediction_object = FilePredictions.create_from_json(file_predictions, file_predictions.file_name) - - # predictions_objects[file_name] = prediction_object - if ground_truth_is_present: - ground_truth_for_file = ground_truth.for_file(file_predictions.file_name) - if ground_truth_for_file: - file_predictions.evaluate(ground_truth_for_file) - number_of_truth_values[file_predictions.file_name] = len(ground_truth_for_file["layers"]) - - return predictions, number_of_truth_values - - def evaluate( predictions: OverallFilePredictions, ground_truth_path: Path, @@ -218,52 +31,45 @@ def evaluate( # Evaluate the borehole extraction metadata ############################# metadata_metrics_list = predictions.evaluate_metadata_extraction(ground_truth_path) - metadata_metrics = metadata_metrics_list.get_cumulated_metrics() - document_level_metadata_metrics = metadata_metrics_list.get_document_level_metrics() + metadata_metrics: BoreholeMetadataMetrics = metadata_metrics_list.get_cumulated_metrics() + document_level_metadata_metrics: pd.DataFrame = metadata_metrics_list.get_document_level_metrics() document_level_metadata_metrics.to_csv( temp_directory / "document_level_metadata_metrics.csv", index_label="document_name" ) # mlflow.log_artifact expects a file # print the metrics logger.info("Metadata Performance metrics:") - logger.info(metadata_metrics) + logger.info(metadata_metrics.to_json()) if mlflow_tracking: import mlflow - mlflow.log_metrics(metadata_metrics) + mlflow.log_metrics(metadata_metrics.to_json()) mlflow.log_artifact(temp_directory / "document_level_metadata_metrics.csv") ############################# # Evaluate the borehole extraction ############################# - if predictions: - predictions, number_of_truth_values = create_predictions_objects(predictions, ground_truth_path) + metrics = predictions.evaluate_borehole_extraction(ground_truth_path) - if number_of_truth_values: # only evaluate if ground truth is available - metrics = evaluate_borehole_extraction(predictions, number_of_truth_values) - - metrics.document_level_metrics_df().to_csv( - temp_directory / "document_level_metrics.csv", index_label="document_name" - ) # mlflow.log_artifact expects a file - metrics_dict = metrics.metrics_dict() - - # Format the metrics dictionary to limit to three digits - formatted_metrics = {k: f"{v:.3f}" for k, v in metrics_dict.items()} - logger.info("Performance metrics: %s", formatted_metrics) + metrics.document_level_metrics_df().to_csv( + temp_directory / "document_level_metrics.csv", index_label="document_name" + ) # mlflow.log_artifact expects a file + metrics_dict = metrics.metrics_dict() - if mlflow_tracking: - mlflow.log_metrics(metrics_dict) - mlflow.log_artifact(temp_directory / "document_level_metrics.csv") + # Format the metrics dictionary to limit to three digits + formatted_metrics = {k: f"{v:.3f}" for k, v in metrics_dict.items()} + logger.info("Performance metrics: %s", formatted_metrics) - else: - logger.warning("Ground truth file not found. Skipping evaluation.") + if mlflow_tracking: + mlflow.log_metrics(metrics_dict) + mlflow.log_artifact(temp_directory / "document_level_metrics.csv") - ############################# - # Draw the prediction - ############################# - if input_directory and draw_directory: - draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics) + ############################# + # Draw the prediction + ############################# + if input_directory and draw_directory: + draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics) if __name__ == "__main__": diff --git a/src/stratigraphy/evaluation/evaluation_dataclasses.py b/src/stratigraphy/evaluation/evaluation_dataclasses.py index dec4a544..e96316a1 100644 --- a/src/stratigraphy/evaluation/evaluation_dataclasses.py +++ b/src/stratigraphy/evaluation/evaluation_dataclasses.py @@ -128,11 +128,10 @@ def get_cumulated_metrics(self) -> dict: coordinates_metrics = Metrics.micro_average( [metadata.coordinates_metrics for metadata in self.borehole_metadata_metrics] ) - return BoreholeMetadataMetrics( - elevation_metrics=elevation_metrics, coordinates_metrics=coordinates_metrics - ).to_json() + return BoreholeMetadataMetrics(elevation_metrics=elevation_metrics, coordinates_metrics=coordinates_metrics) def get_document_level_metrics(self) -> pd.DataFrame: + """Get the document level metrics.""" # Get a dataframe per document, concatenate, and sort by index (document name) return pd.concat( [metadata.get_document_level_metrics() for metadata in self.borehole_metadata_metrics] diff --git a/src/stratigraphy/evaluation/groundwater_evaluator.py b/src/stratigraphy/evaluation/groundwater_evaluator.py new file mode 100644 index 00000000..5d029526 --- /dev/null +++ b/src/stratigraphy/evaluation/groundwater_evaluator.py @@ -0,0 +1,135 @@ +"""Classes for evaluating the groundwater levels of a borehole.""" + +from dataclasses import dataclass +from typing import Any + +from stratigraphy.benchmark.ground_truth import GroundTruth +from stratigraphy.benchmark.metrics import DatasetMetrics +from stratigraphy.evaluation.evaluation_dataclasses import Metrics +from stratigraphy.evaluation.utility import count_against_ground_truth +from stratigraphy.groundwater.groundwater_extraction import Groundwater, GroundwaterInDocument + + +@dataclass +class GroundwaterMetrics: + """Class for storing the metrics of the groundwater information.""" + + groundwater_metrics: Metrics = None + groundwater_depth_metrics: Metrics = None + groundwater_elevation_metrics: Metrics = None + groundwater_date_metrics: Metrics = None + filename: str = None + + +class OverallGroundwaterMetrics: + """Class for storing the overall metrics of the groundwater information.""" + + groundwater_metrics: list[GroundwaterMetrics] = None + + def __init__(self): + self.groundwater_metrics = [] + + def add_groundwater_metrics(self, groundwater_metrics: GroundwaterMetrics): + """Add groundwater metrics to the list. + + Args: + groundwater_metrics (GroundwaterMetrics): The groundwater metrics to add. + """ + self.groundwater_metrics.append(groundwater_metrics) + + def groundwater_metrics_to_dataset_metrics(self): + """Convert the overall groundwater metrics to a DatasetMetrics object.""" + dataset_metrics = DatasetMetrics() + for groundwater_metrics in self.groundwater_metrics: + dataset_metrics.metrics[groundwater_metrics.filename] = groundwater_metrics.groundwater_metrics + return dataset_metrics + + def groundwater_depth_metrics_to_dataset_metrics(self): + """Convert the overall groundwater depth metrics to a DatasetMetrics object.""" + dataset_metrics = DatasetMetrics() + for groundwater_metrics in self.groundwater_metrics: + dataset_metrics.metrics[groundwater_metrics.filename] = groundwater_metrics.groundwater_depth_metrics + return dataset_metrics + + +class GroundwaterEvaluator: + """Class for evaluating the extracted groundwater information of a borehole.""" + + groundwater_entries: list[GroundwaterInDocument] = None + groundwater_ground_truth: dict[str, Any] = None + + def __init__(self, groundwater_entries: list[GroundwaterInDocument], ground_truth_path: str): + """Initializes the GroundwaterEvaluator object. + + Args: + groundwater_entries (list[GroundwaterInDocument]): The metadata to evaluate. + ground_truth_path (str): The path to the ground truth file. + """ + # Load the ground truth data for the metadata + self.groundwater_ground_truth = GroundTruth(ground_truth_path) + if self.groundwater_ground_truth is None: + self.groundwater_ground_truth = [] + + self.groundwater_entries = groundwater_entries + + def evaluate(self): + """Evaluate the groundwater information of the file against the ground truth. + + Args: + groundwater_ground_truth (list): The ground truth for the file. + """ + overall_groundwater_metrics = OverallGroundwaterMetrics() + + for groundwater_in_doc in self.groundwater_entries: + filename = groundwater_in_doc.filename + ground_truth = self.groundwater_ground_truth.for_file(filename).get("groundwater", []) + if ground_truth is None: + ground_truth = [] # If no ground truth is available, set it to an empty list + + ############################################################################################################ + ### Compute the metadata correctness for the groundwater information. + ############################################################################################################ + gt_groundwater = [ + Groundwater.from_json_values( + depth=json_gt_data["depth"], + date=json_gt_data["date"], + elevation=json_gt_data["elevation"], + ) + for json_gt_data in ground_truth + ] + + groundwater_metrics = count_against_ground_truth( + [ + ( + entry.groundwater.depth, + entry.groundwater.format_date(), + entry.groundwater.elevation, + ) + for entry in groundwater_in_doc.groundwater + ], + [(entry.depth, entry.format_date(), entry.elevation) for entry in gt_groundwater], + ) + groundwater_depth_metrics = count_against_ground_truth( + [entry.groundwater.depth for entry in groundwater_in_doc.groundwater], + [entry.depth for entry in gt_groundwater], + ) + groundwater_elevation_metrics = count_against_ground_truth( + [entry.groundwater.elevation for entry in groundwater_in_doc.groundwater], + [entry.elevation for entry in gt_groundwater], + ) + groundwater_date_metrics = count_against_ground_truth( + [entry.groundwater.date for entry in groundwater_in_doc.groundwater], + [entry.date for entry in gt_groundwater], + ) + + file_groundwater_metrics = GroundwaterMetrics( + groundwater_metrics=groundwater_metrics, + groundwater_depth_metrics=groundwater_depth_metrics, + groundwater_elevation_metrics=groundwater_elevation_metrics, + groundwater_date_metrics=groundwater_date_metrics, + filename=filename, + ) # TODO: This clashes with the DatasetMetrics object + + overall_groundwater_metrics.add_groundwater_metrics(file_groundwater_metrics) + + return overall_groundwater_metrics diff --git a/src/stratigraphy/evaluation/metadata_evaluator.py b/src/stratigraphy/evaluation/metadata_evaluator.py index 1c922d8b..3dc1f4e5 100644 --- a/src/stratigraphy/evaluation/metadata_evaluator.py +++ b/src/stratigraphy/evaluation/metadata_evaluator.py @@ -25,7 +25,7 @@ def __init__(self, metadata_list: BoreholeMetadataList, ground_truth_path: str): metadata_list (BoreholeMetadataList): The metadata to evaluate. ground_truth_path (str): The path to the ground truth file. """ - self.metadata_list = metadata_list + self.metadata_list: BoreholeMetadataList = metadata_list # Load the ground truth data for the metadata self.metadata_ground_truth = GroundTruth(ground_truth_path) diff --git a/src/stratigraphy/evaluation/utility.py b/src/stratigraphy/evaluation/utility.py new file mode 100644 index 00000000..16a81dfd --- /dev/null +++ b/src/stratigraphy/evaluation/utility.py @@ -0,0 +1,73 @@ +"""Utility functions for evaluation.""" + +from collections import Counter + +import Levenshtein +from stratigraphy.evaluation.evaluation_dataclasses import Metrics +from stratigraphy.layer.layer import LayerPrediction +from stratigraphy.util.util import parse_text + + +def count_against_ground_truth(values: list, ground_truth: list) -> Metrics: + """Count the number of true positives, false positives and false negatives. + + Args: + values (list): The values to count. + ground_truth (list): The ground truth values. + + Returns: + Metrics: The metrics for the values. + """ + # Counter deals with duplicates when doing intersection + values_counter = Counter(values) + ground_truth_counter = Counter(ground_truth) + + tp = (values_counter & ground_truth_counter).total() # size of intersection + return Metrics(tp=tp, fp=len(values) - tp, fn=len(ground_truth) - tp) + + +def find_matching_layer(layer: LayerPrediction, unmatched_layers: list[dict]) -> tuple[dict, bool] | tuple[None, None]: + """Find the matching layer in the ground truth. + + Args: + layer (LayerPrediction): The layer to match. + unmatched_layers (list[dict]): The layers from the ground truth that were not yet matched during the + current evaluation. + + Returns: + tuple[dict, bool] | tuple[None, None]: The matching layer and a boolean indicating if the depth interval + is correct. None if no match was found. + """ + parsed_text = parse_text(layer.material_description.text) + possible_matches = [ + ground_truth_layer + for ground_truth_layer in unmatched_layers + if Levenshtein.ratio(parsed_text, ground_truth_layer["material_description"]) > 0.9 + ] + + if not possible_matches: + return None, None + + for possible_match in possible_matches: + start = possible_match["depth_interval"]["start"] + end = possible_match["depth_interval"]["end"] + + if layer.depth_interval is None: + pass + + elif ( + start == 0 and layer.depth_interval.start is None and end == layer.depth_interval.end.value + ): # If not specified differently, we start at 0. + unmatched_layers.remove(possible_match) + return possible_match, True + + elif ( # noqa: SIM102 + layer.depth_interval.start is not None and layer.depth_interval.end is not None + ): # In all other cases we do not allow a None value. + if start == layer.depth_interval.start.value and end == layer.depth_interval.end.value: + unmatched_layers.remove(possible_match) + return possible_match, True + + match = max(possible_matches, key=lambda x: Levenshtein.ratio(parsed_text, x["material_description"])) + unmatched_layers.remove(match) + return match, False diff --git a/src/stratigraphy/groundwater/groundwater_extraction.py b/src/stratigraphy/groundwater/groundwater_extraction.py index 9b533878..d5fc101d 100644 --- a/src/stratigraphy/groundwater/groundwater_extraction.py +++ b/src/stratigraphy/groundwater/groundwater_extraction.py @@ -24,7 +24,7 @@ @dataclass -class GroundwaterInformation(metaclass=abc.ABCMeta): +class Groundwater(metaclass=abc.ABCMeta): """Abstract class for Groundwater Information.""" depth: float # Depth of the groundwater relative to the surface @@ -48,12 +48,7 @@ def __str__(self) -> str: Returns: str: The object as a string. """ - return ( - f"GroundwaterInformation(" - f"date={self.format_date()}, " - f"depth={self.depth}, " - f"elevation={self.elevation})" - ) + return f"Groundwater(" f"date={self.format_date()}, " f"depth={self.depth}, " f"elevation={self.elevation})" @staticmethod def from_json_values(depth: float | None, date: str | None, elevation: float | None): @@ -65,10 +60,10 @@ def from_json_values(depth: float | None, date: str | None, elevation: float | N elevation (float | None): The elevation of the groundwater. Returns: - GroundwaterInformation: The object created from the dictionary. + Groundwater: The object created from the dictionary. """ date = datetime.strptime(date, DATE_FORMAT).date() if date is not None and date != "" else None - return GroundwaterInformation(depth=depth, date=date, elevation=elevation) + return Groundwater(depth=depth, date=date, elevation=elevation) def format_date(self) -> str | None: """Formats the date of the groundwater measurement. @@ -83,10 +78,10 @@ def format_date(self) -> str | None: @dataclass(kw_only=True) -class GroundwaterInformationOnPage(ExtractedFeature): +class GroundwaterOnPage(ExtractedFeature): """Abstract class for Groundwater Information.""" - groundwater: GroundwaterInformation + groundwater: Groundwater def is_valid(self) -> bool: """Checks if the information is valid. @@ -122,15 +117,51 @@ def from_json_values(date: str | None, depth: float | None, elevation: float | N rect (list[float]): The rectangle that contains the extracted information. Returns: - GroundwaterInformationOnPage: The object created from the dictionary. + GroundwaterOnPage: The object created from the dictionary. """ - return GroundwaterInformationOnPage( - groundwater=GroundwaterInformation.from_json_values(depth=depth, date=date, elevation=elevation), + return GroundwaterOnPage( + groundwater=Groundwater.from_json_values(depth=depth, date=date, elevation=elevation), page=page, rect=fitz.Rect(rect), ) +@dataclass +class GroundwaterInDocument: + """Class for extracted groundwater information from a document.""" + + groundwater: list[GroundwaterOnPage] + filename: str + + def __init__(self, filename: str): + """Initializes the GroundwaterInDocument object. + + Args: + filename (str): The name of the document. + """ + self.groundwater = [] + self.filename = filename + + def add_groundwater_from_page(self, groundwater_on_page: GroundwaterOnPage | list[GroundwaterOnPage]): + """Adds groundwater information from a page to the groundwater list. + + Args: + groundwater_on_page (GroundwaterOnPage): The groundwater information from a page. + """ + if isinstance(groundwater_on_page, list): + self.groundwater.extend(groundwater_on_page) + else: + self.groundwater.append(groundwater_on_page) + + def get_groundwater_in_doc(self) -> list[Groundwater]: + """Returns the groundwater information in the document. + + Returns: + list[Groundwater]: The groundwater information in the document. + """ + return [entry.groundwater for entry in self.groundwater] + + class GroundwaterLevelExtractor(DataExtractor): """Extracts coordinates from a PDF document.""" @@ -143,7 +174,7 @@ class GroundwaterLevelExtractor(DataExtractor): preprocess_replacements = {",": ".", "'": ".", "o": "0", "\n": " ", "ΓΌ": "u"} - def get_groundwater_near_key(self, lines: list[TextLine], page: int) -> list[GroundwaterInformationOnPage]: + def get_groundwater_near_key(self, lines: list[TextLine], page: int) -> list[GroundwaterOnPage]: """Find groundwater information from text lines that are close to an explicit "groundwater" label. Also apply some preprocessing to the text of those text lines, to deal with some common (OCR) errors. @@ -153,7 +184,7 @@ def get_groundwater_near_key(self, lines: list[TextLine], page: int) -> list[Gro page (int): the page number (1-based) of the PDF document Returns: - list[GroundwaterInformationOnPage]: all found groundwater information + list[GroundwaterOnPage]: all found groundwater information """ # find the key that indicates the groundwater information groundwater_key_lines = self.find_feature_key(lines) @@ -178,14 +209,14 @@ def get_groundwater_near_key(self, lines: list[TextLine], page: int) -> list[Gro return extracted_groundwater_list - def get_groundwater_info_from_lines(self, lines: list[TextLine], page: int) -> GroundwaterInformationOnPage: + def get_groundwater_info_from_lines(self, lines: list[TextLine], page: int) -> GroundwaterOnPage: """Extracts the groundwater information from a list of text lines. Args: lines (list[TextLine]): the lines of text to extract the groundwater information from page (int): the page number (1-based) of the PDF document Returns: - GroundwaterInformationOnPage: the extracted groundwater information + GroundwaterOnPage: the extracted groundwater information """ date: dt | None = None depth: float | None = None @@ -261,15 +292,15 @@ def get_groundwater_info_from_lines(self, lines: list[TextLine], page: int) -> G # # TODO: IF the date is not provided for the groundwater (most of the time because there was only one # drilling date - chose the date of the document. Date needs to be extracted from the document separately) if depth or elevation: - return GroundwaterInformationOnPage( - groundwater=GroundwaterInformation(depth=depth, date=date, elevation=elevation), + return GroundwaterOnPage( + groundwater=Groundwater(depth=depth, date=date, elevation=elevation), rect=rect_union, page=page, ) else: raise ValueError("Could not extract all required information from the lines provided.") - def extract_groundwater(self, terrain_elevation: Elevation | None) -> list[GroundwaterInformationOnPage]: + def extract_groundwater(self, terrain_elevation: Elevation | None) -> list[GroundwaterOnPage]: """Extracts the groundwater information from a borehole profile. Processes the borehole profile page by page and tries to find the coordinates in the respective text of the @@ -281,7 +312,7 @@ def extract_groundwater(self, terrain_elevation: Elevation | None) -> list[Groun terrain_elevation (ElevationInformation | None): The elevation of the borehole. Returns: - list[GroundwaterInformationOnPage]: the extracted coordinates (if any) + list[GroundwaterOnPage]: the extracted coordinates (if any) """ for page in self.doc: lines = extract_text_lines(page) diff --git a/src/stratigraphy/main.py b/src/stratigraphy/main.py index 5e528213..81b8650a 100644 --- a/src/stratigraphy/main.py +++ b/src/stratigraphy/main.py @@ -14,7 +14,7 @@ from stratigraphy.annotations.plot_utils import plot_lines from stratigraphy.benchmark.score import evaluate from stratigraphy.extract import process_page -from stratigraphy.groundwater.groundwater_extraction import GroundwaterLevelExtractor +from stratigraphy.groundwater.groundwater_extraction import GroundwaterInDocument, GroundwaterLevelExtractor from stratigraphy.layer.duplicate_detection import remove_duplicate_layers from stratigraphy.layer.layer import LayerPrediction from stratigraphy.lines.line_detection import extract_lines, line_detection_params @@ -227,8 +227,11 @@ def start_pipeline( if part == "all": # Extract the groundwater levels + groundwater_in_document = GroundwaterInDocument(filename) groundwater_extractor = GroundwaterLevelExtractor(document=doc) - groundwater = groundwater_extractor.extract_groundwater(terrain_elevation=metadata.elevation) + groundwater_in_document.add_groundwater_from_page( + groundwater_extractor.extract_groundwater(terrain_elevation=metadata.elevation) + ) layer_predictions_list = [] depths_materials_column_pairs_list = [] @@ -277,17 +280,17 @@ def start_pipeline( layers=page_layer_predictions_list, depths_materials_columns_pairs=depths_materials_column_pairs_list, metadata=metadata, - groundwater_entries=groundwater, + groundwater=groundwater_in_document, ) ) else: predictions.add_file_predictions( FilePredictions( file_name=filename, - metadata=metadata, - groundwater_entries=None, layers=None, depths_materials_columns_pairs=None, + metadata=metadata, + groundwater=None, ) ) diff --git a/src/stratigraphy/util/predictions.py b/src/stratigraphy/util/predictions.py index 8e467888..111a9eeb 100644 --- a/src/stratigraphy/util/predictions.py +++ b/src/stratigraphy/util/predictions.py @@ -1,15 +1,17 @@ """This module contains classes for predictions.""" import logging -from collections import Counter +import os from pathlib import Path -import Levenshtein - +from stratigraphy.benchmark.ground_truth import GroundTruth +from stratigraphy.benchmark.metrics import DatasetMetrics, DatasetMetricsCatalog from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs from stratigraphy.evaluation.evaluation_dataclasses import Metrics, OverallBoreholeMetadataMetrics +from stratigraphy.evaluation.groundwater_evaluator import GroundwaterEvaluator from stratigraphy.evaluation.metadata_evaluator import MetadataEvaluator -from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformation, GroundwaterInformationOnPage +from stratigraphy.evaluation.utility import find_matching_layer +from stratigraphy.groundwater.groundwater_extraction import GroundwaterInDocument from stratigraphy.layer.layer import LayerPrediction from stratigraphy.metadata.metadata import BoreholeMetadata, BoreholeMetadataList from stratigraphy.util.util import parse_text @@ -25,15 +27,14 @@ def __init__( layers: list[LayerPrediction], file_name: str, metadata: BoreholeMetadata, - groundwater_entries: list[GroundwaterInformationOnPage], + groundwater: GroundwaterInDocument, depths_materials_columns_pairs: list[DepthsMaterialsColumnPairs], ): self.layers: list[LayerPrediction] = layers self.depths_materials_columns_pairs: list[DepthsMaterialsColumnPairs] = depths_materials_columns_pairs self.file_name = file_name self.metadata = metadata - self.groundwater_entries = groundwater_entries - self.groundwater_is_correct: dict = {} + self.groundwater = groundwater def convert_to_ground_truth(self): """Convert the predictions to ground truth format. @@ -71,11 +72,8 @@ def evaluate(self, ground_truth: dict): Args: ground_truth (dict): The ground truth for the file. """ + # TODO: Call the evaluator for Layers instead self.evaluate_layers(ground_truth["layers"]) - groundwater_ground_truth = ground_truth.get("groundwater", []) - if groundwater_ground_truth is None: - groundwater_ground_truth = [] - self.evaluate_groundwater(groundwater_ground_truth) def evaluate_layers(self, ground_truth_layers: list): """Evaluate all layers of the predictions against the ground truth. @@ -85,7 +83,7 @@ def evaluate_layers(self, ground_truth_layers: list): """ unmatched_layers = ground_truth_layers.copy() for layer in self.layers: - match, depth_interval_is_correct = self._find_matching_layer(layer, unmatched_layers) + match, depth_interval_is_correct = find_matching_layer(layer, unmatched_layers) if match: layer.material_is_correct = True layer.depth_interval_is_correct = depth_interval_is_correct @@ -93,115 +91,6 @@ def evaluate_layers(self, ground_truth_layers: list): layer.material_is_correct = False layer.depth_interval_is_correct = None - @staticmethod - def count_against_ground_truth(values: list, ground_truth: list) -> Metrics: - """Count the number of true positives, false positives and false negatives. - - Args: - values (list): The values to count. - ground_truth (list): The ground truth values. - - Returns: - Metrics: The metrics for the values. - """ - # Counter deals with duplicates when doing intersection - values_counter = Counter(values) - ground_truth_counter = Counter(ground_truth) - - tp = (values_counter & ground_truth_counter).total() # size of intersection - return Metrics(tp=tp, fp=len(values) - tp, fn=len(ground_truth) - tp) - - def evaluate_groundwater(self, groundwater_ground_truth: list): - """Evaluate the groundwater information of the file against the ground truth. - - Args: - groundwater_ground_truth (list): The ground truth for the file. - """ - ############################################################################################################ - ### Compute the metadata correctness for the groundwater information. - ############################################################################################################ - gt_groundwater = [ - GroundwaterInformation.from_json_values( - depth=json_gt_data["depth"], - date=json_gt_data["date"], - elevation=json_gt_data["elevation"], - ) - for json_gt_data in groundwater_ground_truth - ] - - self.groundwater_is_correct["groundwater"] = self.count_against_ground_truth( - [ - ( - entry.groundwater.depth, - entry.groundwater.format_date(), - entry.groundwater.elevation, - ) - for entry in self.groundwater_entries - ], - [(entry.depth, entry.format_date(), entry.elevation) for entry in gt_groundwater], - ) - self.groundwater_is_correct["groundwater_depth"] = self.count_against_ground_truth( - [entry.groundwater.depth for entry in self.groundwater_entries], - [entry.depth for entry in gt_groundwater], - ) - self.groundwater_is_correct["groundwater_elevation"] = self.count_against_ground_truth( - [entry.groundwater.elevation for entry in self.groundwater_entries], - [entry.elevation for entry in gt_groundwater], - ) - self.groundwater_is_correct["groundwater_date"] = self.count_against_ground_truth( - [entry.groundwater.date for entry in self.groundwater_entries], - [entry.date for entry in gt_groundwater], - ) - - @staticmethod - def _find_matching_layer( - layer: LayerPrediction, unmatched_layers: list[dict] - ) -> tuple[dict, bool] | tuple[None, None]: - """Find the matching layer in the ground truth. - - Args: - layer (LayerPrediction): The layer to match. - unmatched_layers (list[dict]): The layers from the ground truth that were not yet matched during the - current evaluation. - - Returns: - tuple[dict, bool] | tuple[None, None]: The matching layer and a boolean indicating if the depth interval - is correct. None if no match was found. - """ - parsed_text = parse_text(layer.material_description.text) - possible_matches = [ - ground_truth_layer - for ground_truth_layer in unmatched_layers - if Levenshtein.ratio(parsed_text, ground_truth_layer["material_description"]) > 0.9 - ] - - if not possible_matches: - return None, None - - for possible_match in possible_matches: - start = possible_match["depth_interval"]["start"] - end = possible_match["depth_interval"]["end"] - - if layer.depth_interval is None: - pass - - elif ( - start == 0 and layer.depth_interval.start is None and end == layer.depth_interval.end.value - ): # If not specified differently, we start at 0. - unmatched_layers.remove(possible_match) - return possible_match, True - - elif ( # noqa: SIM102 - layer.depth_interval.start is not None and layer.depth_interval.end is not None - ): # In all other cases we do not allow a None value. - if start == layer.depth_interval.start.value and end == layer.depth_interval.end.value: - unmatched_layers.remove(possible_match) - return possible_match, True - - match = max(possible_matches, key=lambda x: Levenshtein.ratio(parsed_text, x["material_description"])) - unmatched_layers.remove(match) - return match, False - def to_json(self) -> dict: """Converts the object to a dictionary. @@ -218,7 +107,7 @@ def to_json(self) -> dict: ], "page_dimensions": self.metadata.page_dimensions, # TODO: This should be removed. As already in metadata. - "groundwater": [entry.to_json() for entry in self.groundwater_entries], + "groundwater": [entry.to_json() for entry in self.groundwater.groundwater], "file_name": self.file_name, } } @@ -260,9 +149,23 @@ def to_json(self): """ return {file_prediction.file_name: file_prediction.to_json() for file_prediction in self.file_predictions_list} + def get_groundwater_entries(self) -> list[GroundwaterInDocument]: + """Get the groundwater extractions from the predictions. + + Returns: + List[GroundwaterInDocument]: The groundwater extractions. + """ + return [file_prediction.groundwater for file_prediction in self.file_predictions_list] + + ############################################################################################################ + ### Evaluation methods + ############################################################################################################ + def evaluate_metadata_extraction(self, ground_truth_path: Path) -> OverallBoreholeMetadataMetrics: """Evaluate the metadata extraction of the predictions against the ground truth. + # TODO: Move to evaluator class + Args: ground_truth_path (Path): The path to the ground truth file. """ @@ -271,3 +174,182 @@ def evaluate_metadata_extraction(self, ground_truth_path: Path) -> OverallBoreho for file_prediction in self.file_predictions_list: metadata_per_file.add_metadata(file_prediction.metadata) return MetadataEvaluator(metadata_per_file, ground_truth_path).evaluate() + + def evaluate_borehole_extraction(self, ground_truth_path: str) -> DatasetMetricsCatalog: + """Evaluate the borehole extraction predictions. + + Args: + ground_truth_path (str): The path to the ground truth file. + + Returns: + DatasetMetricsCatalogue: A DatasetMetricsCatalogue that maps a metrics name to the corresponding + DatasetMetrics object + """ + ############################################################################################################ + ### Load the ground truth data for the borehole extraction + ############################################################################################################ + ground_truth = None + if ground_truth_path and os.path.exists(ground_truth_path): # for inference no ground truth is available + ground_truth = GroundTruth(ground_truth_path) + else: + logging.warning("Ground truth file not found.") + + ############################################################################################################ + ### Evaluate the borehole extraction + ############################################################################################################ + number_of_truth_values = {} + for file_predictions in self.file_predictions_list: + # prediction_object = FilePredictions.create_from_json(file_predictions, file_predictions.file_name) + + # predictions_objects[file_name] = prediction_object + if ground_truth: + ground_truth_for_file = ground_truth.for_file(file_predictions.file_name) + if ground_truth_for_file: + file_predictions.evaluate(ground_truth_for_file) + number_of_truth_values[file_predictions.file_name] = len(ground_truth_for_file["layers"]) + + if number_of_truth_values: + all_metrics = self.evaluate_layer_extraction(number_of_truth_values) + + overall_groundwater_metrics = GroundwaterEvaluator( + self.get_groundwater_entries(), ground_truth_path + ).evaluate() + all_metrics.metrics["groundwater"] = overall_groundwater_metrics.groundwater_metrics_to_dataset_metrics() + all_metrics.metrics["groundwater_depth"] = ( + overall_groundwater_metrics.groundwater_depth_metrics_to_dataset_metrics() + ) + return all_metrics + else: + logger.warning("Ground truth file not found. Skipping evaluation.") + return None + + def evaluate_layer_extraction(self, number_of_truth_values: dict) -> DatasetMetricsCatalog: + """Calculate F1, precision and recall for the predictions. + + Calculate F1, precision and recall for the individual documents as well as overall. + The individual document metrics are returned as a DataFrame. + + Args: + predictions (dict): The FilePredictions objects. + number_of_truth_values (dict): The number of layer ground truth values per file. + + Returns: + DatasetMetricsCatalogue: A dictionary that maps a metrics name to the corresponding DatasetMetrics object + """ + all_metrics = DatasetMetricsCatalog() + all_metrics.metrics["layer"] = get_layer_metrics(self, number_of_truth_values) + all_metrics.metrics["depth_interval"] = get_depth_interval_metrics(self) + + # create predictions by language + predictions_by_language = { + "de": OverallFilePredictions(), + "fr": OverallFilePredictions(), + } # TODO: make this dynamic and why is this hardcoded? + for file_predictions in self.file_predictions_list: + language = file_predictions.metadata.language + if language in predictions_by_language: + predictions_by_language[language].add_file_predictions(file_predictions) + + for language, language_predictions in predictions_by_language.items(): + language_number_of_truth_values = { + prediction.file_name: number_of_truth_values[prediction.file_name] + for prediction in language_predictions.file_predictions_list + } + all_metrics.metrics[f"{language}_layer"] = get_layer_metrics( + language_predictions, language_number_of_truth_values + ) + all_metrics.metrics[f"{language}_depth_interval"] = get_depth_interval_metrics(language_predictions) + + logging.info("Macro avg:") + logging.info( + "F1: %.1f%%, precision: %.1f%%, recall: %.1f%%, depth_interval_accuracy: %.1f%%", + all_metrics.metrics["layer"].macro_f1() * 100, + all_metrics.metrics["layer"].macro_precision() * 100, + all_metrics.metrics["layer"].macro_recall() * 100, + all_metrics.metrics["depth_interval"].macro_precision() * 100, + ) + + return all_metrics + + def get_metrics(self, field_key: str, field_name: str) -> DatasetMetrics: + """Get the metrics for a specific field in the predictions. + + Args: + predictions (dict): The FilePredictions objects. + field_key (str): The key to access the specific field in the prediction objects. + field_name (str): The name of the field being evaluated. + + Returns: + DatasetMetrics: The requested DatasetMetrics object. + """ + dataset_metrics = DatasetMetrics() + + for file_prediction in self.file_predictions_list: + dataset_metrics.metrics[file_prediction.file_name] = getattr(file_prediction, field_key)[field_name] + + return dataset_metrics + + +def get_layer_metrics(predictions: OverallFilePredictions, number_of_truth_values: dict) -> DatasetMetrics: + """Calculate F1, precision and recall for the layer predictions. + + Calculate F1, precision and recall for the individual documents as well as overall. + + # TODO: Try to mode this to the LayerPrediction class + + Args: + predictions (dict): The predictions. + number_of_truth_values (dict): The number of ground truth values per file. + + Returns: + DatasetMetrics: the metrics for the layers + """ + layer_metrics = DatasetMetrics() + + for file_prediction in predictions.file_predictions_list: + hits = 0 + for layer in file_prediction.layers: + if layer.material_is_correct: + hits += 1 + if parse_text(layer.material_description.text) == "": + logger.warning("Empty string found in predictions") + layer_metrics.metrics[file_prediction.file_name] = Metrics( + tp=hits, fp=len(file_prediction.layers) - hits, fn=number_of_truth_values[file_prediction.file_name] - hits + ) + + return layer_metrics + + +def get_depth_interval_metrics(predictions: OverallFilePredictions) -> DatasetMetrics: + """Calculate F1, precision and recall for the depth interval predictions. + + # TODO: Try to mode this to the LayerPrediction class + + Calculate F1, precision and recall for the individual documents as well as overall. + + Depth interval accuracy is not calculated for layers with incorrect material predictions. + + Args: + predictions (dict): The predictions. + + Returns: + DatasetMetrics: the metrics for the depth intervals + """ + depth_interval_metrics = DatasetMetrics() + + for file_prediction in predictions.file_predictions_list: + depth_interval_hits = 0 + depth_interval_occurrences = 0 + for layer in file_prediction.layers: + if layer.material_is_correct: + if layer.depth_interval_is_correct is not None: + depth_interval_occurrences += 1 + if layer.depth_interval_is_correct: + depth_interval_hits += 1 + + if depth_interval_occurrences > 0: + depth_interval_metrics.metrics[file_prediction.file_name] = Metrics( + tp=depth_interval_hits, fp=depth_interval_occurrences - depth_interval_hits, fn=0 + ) + + return depth_interval_metrics