diff --git a/src/stratigraphy/benchmark/metrics.py b/src/stratigraphy/benchmark/metrics.py new file mode 100644 index 00000000..45688d95 --- /dev/null +++ b/src/stratigraphy/benchmark/metrics.py @@ -0,0 +1,155 @@ +"""Classes for keeping track of metrics such as the F1-score, precision and recall.""" + +from collections.abc import Callable +from dataclasses import dataclass + +import pandas as pd + + +@dataclass +class Metrics: + """Computes F-score metrics. + + See also https://en.wikipedia.org/wiki/F-score + + Args: + tp (int): The true positive count + fp (int): The false positive count + fn (int): The false negative count + """ + + tp: int + fp: int + fn: int + + @property + def precision(self) -> float: + """Calculate the precision.""" + if self.tp + self.fp > 0: + return self.tp / (self.tp + self.fp) + else: + return 0 + + @property + def recall(self) -> float: + """Calculate the recall.""" + if self.tp + self.fn > 0: + return self.tp / (self.tp + self.fn) + else: + return 0 + + @property + def f1(self) -> float: + """Calculate the F1 score.""" + if self.precision + self.recall > 0: + return 2 * self.precision * self.recall / (self.precision + self.recall) + else: + return 0 + + +class DatasetMetrics: + """Keeps track of a particular metrics for all documents in a dataset.""" + + def __init__(self): + self.metrics: dict[str, Metrics] = {} + + def overall_metrics(self) -> Metrics: + """Can be used to compute micro averages.""" + return Metrics( + tp=sum(metric.tp for metric in self.metrics.values()), + fp=sum(metric.fp for metric in self.metrics.values()), + fn=sum(metric.fn for metric in self.metrics.values()), + ) + + def macro_f1(self) -> float: + """Compute the macro F1 score.""" + if self.metrics: + return sum([metric.f1 for metric in self.metrics.values()]) / len(self.metrics) + else: + return 0 + + def macro_precision(self) -> float: + """Compute the macro precision score.""" + if self.metrics: + return sum([metric.precision for metric in self.metrics.values()]) / len(self.metrics) + else: + return 0 + + def macro_recall(self) -> float: + """Compute the macro recall score.""" + if self.metrics: + return sum([metric.recall for metric in self.metrics.values()]) / len(self.metrics) + else: + return 0 + + def pseudo_macro_f1(self) -> float: + """Compute a "pseudo" macro F1 score by using the values of the macro precision and macro recall. + + TODO: we probably should not use this metric, and use the proper macro F1 score instead. + """ + if self.metrics and self.macro_precision() + self.macro_recall() > 0: + return 2 * self.macro_precision() * self.macro_recall() / (self.macro_precision() + self.macro_recall()) + else: + return 0 + + def to_dataframe(self, name: str, fn: Callable[[Metrics], float]) -> pd.DataFrame: + series = pd.Series({filename: fn(metric) for filename, metric in self.metrics.items()}) + return series.to_frame(name=name) + + +class DatasetMetricsCatalog: + """Keeps track of all different relevant metrics that are computed for a dataset.""" + + def __init__(self): + self.metrics: dict[str, DatasetMetrics] = {} + + def document_level_metrics_df(self) -> pd.DataFrame: + all_series = [ + self.metrics["layer"].to_dataframe("F1", lambda metric: metric.f1), + self.metrics["layer"].to_dataframe("precision", lambda metric: metric.precision), + self.metrics["layer"].to_dataframe("recall", lambda metric: metric.recall), + self.metrics["depth_interval"].to_dataframe("Depth_interval_accuracy", lambda metric: metric.precision), + self.metrics["layer"].to_dataframe("Number Elements", lambda metric: metric.tp + metric.fn), + self.metrics["layer"].to_dataframe("Number wrong elements", lambda metric: metric.fp + metric.fn), + self.metrics["coordinates"].to_dataframe("coordinates", lambda metric: metric.f1), + self.metrics["elevation"].to_dataframe("elevation", lambda metric: metric.f1), + self.metrics["groundwater"].to_dataframe("groundwater", lambda metric: metric.f1), + self.metrics["groundwater_depth"].to_dataframe("groundwater_depth", lambda metric: metric.f1), + ] + document_level_metrics = pd.DataFrame() + for series in all_series: + document_level_metrics = document_level_metrics.join(series, how="outer") + return document_level_metrics + + def metrics_dict(self) -> dict[str, float]: + coordinates_metrics = self.metrics["coordinates"].overall_metrics() + groundwater_metrics = self.metrics["groundwater"].overall_metrics() + groundwater_depth_metrics = self.metrics["groundwater_depth"].overall_metrics() + elevation_metrics = self.metrics["elevation"].overall_metrics() + + return { + "F1": self.metrics["layer"].pseudo_macro_f1(), + "recall": self.metrics["layer"].macro_recall(), + "precision": self.metrics["layer"].macro_precision(), + "depth_interval_accuracy": self.metrics["depth_interval"].macro_precision(), + "de_F1": self.metrics["de_layer"].pseudo_macro_f1(), + "de_recall": self.metrics["de_layer"].macro_recall(), + "de_precision": self.metrics["de_layer"].macro_precision(), + "de_depth_interval_accuracy": self.metrics["de_depth_interval"].macro_precision(), + "fr_F1": self.metrics["fr_layer"].pseudo_macro_f1(), + "fr_recall": self.metrics["fr_layer"].macro_recall(), + "fr_precision": self.metrics["fr_layer"].macro_precision(), + "fr_depth_interval_accuracy": self.metrics["fr_depth_interval"].macro_precision(), + "coordinate_f1": coordinates_metrics.f1, + "coordinate_recall": coordinates_metrics.recall, + "coordinate_precision": coordinates_metrics.precision, + "groundwater_f1": groundwater_metrics.f1, + "groundwater_recall": groundwater_metrics.recall, + "groundwater_precision": groundwater_metrics.precision, + "groundwater_depth_f1": groundwater_depth_metrics.f1, + "groundwater_depth_recall": groundwater_depth_metrics.recall, + "groundwater_depth_precision": groundwater_depth_metrics.precision, + "elevation_f1": elevation_metrics.f1, + "elevation_recall": elevation_metrics.recall, + "elevation_precision": elevation_metrics.precision, + } diff --git a/src/stratigraphy/benchmark/score.py b/src/stratigraphy/benchmark/score.py index 584024db..490f4b56 100644 --- a/src/stratigraphy/benchmark/score.py +++ b/src/stratigraphy/benchmark/score.py @@ -1,140 +1,89 @@ """Evaluate the predictions against the ground truth.""" +import json import logging import os -from collections import defaultdict from pathlib import Path -import pandas as pd from dotenv import load_dotenv from stratigraphy import DATAPATH from stratigraphy.benchmark.ground_truth import GroundTruth +from stratigraphy.benchmark.metrics import DatasetMetrics, DatasetMetricsCatalog, Metrics +from stratigraphy.util.draw import draw_predictions from stratigraphy.util.predictions import FilePredictions from stratigraphy.util.util import parse_text load_dotenv() mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True" # Checks whether MLFlow tracking is enabled - +logging.basicConfig(format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S") logger = logging.getLogger(__name__) -def f1(precision: float, recall: float) -> float: - """Calculate the F1 score. +def get_layer_metrics(predictions: dict, number_of_truth_values: dict) -> DatasetMetrics: + """Calculate F1, precision and recall for the layer predictions. + + Calculate F1, precision and recall for the individual documents as well as overall. Args: - precision (float): Precision. - recall (float): Recall. + predictions (dict): The predictions. + number_of_truth_values (dict): The number of ground truth values per file. Returns: - float: The F1 score. + DatasetMetrics: the metrics for the layers """ - if precision + recall > 0: - return 2 * precision * recall / (precision + recall) - else: - return 0 + layer_metrics = DatasetMetrics() + for filename, file_prediction in predictions.items(): + hits = 0 + for layer in file_prediction.layers: + if layer.material_is_correct: + hits += 1 + if parse_text(layer.material_description.text) == "": + logger.warning("Empty string found in predictions") + layer_metrics.metrics[filename] = Metrics( + tp=hits, fp=len(file_prediction.layers) - hits, fn=number_of_truth_values[filename] - hits + ) -def get_scores( - predictions: dict, number_of_truth_values: dict, return_document_level_metrics: bool -) -> dict | tuple[dict, pd.DataFrame]: - """Calculate F1, precision and recall for the predictions. + return layer_metrics + + +def get_depth_interval_metrics(predictions: dict) -> DatasetMetrics: + """Calculate F1, precision and recall for the depth interval predictions. Calculate F1, precision and recall for the individual documents as well as overall. - The individual document metrics are returned as a DataFrame. + + Depth interval accuracy is not calculated for layers with incorrect material predictions. Args: predictions (dict): The predictions. - number_of_truth_values (dict): The number of ground truth values per file. - return_document_level_metrics (bool): Whether to return the document level metrics. Returns: - tuple[dict, pd.DataFrame]: A tuple containing the overall F1, precision and recall as a dictionary and the - individual document metrics as a DataFrame. + DatasetMetrics: the metrics for the depth intervals """ - document_level_metrics = { - "document_name": [], - "F1": [], - "precision": [], - "recall": [], - "Depth_interval_accuracy": [], - "Number Elements": [], - "Number wrong elements": [], - } - # separate list to calculate the overall depth interval accuracy is required, - # as the depth interval accuracy is not calculated for documents with no correct - # material predictions. - depth_interval_accuracies = [] + depth_interval_metrics = DatasetMetrics() + for filename, file_prediction in predictions.items(): - hits = 0 depth_interval_hits = 0 depth_interval_occurences = 0 for layer in file_prediction.layers: if layer.material_is_correct: - hits += 1 + if layer.depth_interval_is_correct is not None: + depth_interval_occurences += 1 if layer.depth_interval_is_correct: depth_interval_hits += 1 - depth_interval_occurences += 1 - elif layer.depth_interval_is_correct is not None: - depth_interval_occurences += 1 - if parse_text(layer.material_description.text) == "": - logger.warning("Empty string found in predictions") - tp = hits - fp = len(file_prediction.layers) - tp - fn = number_of_truth_values[filename] - tp - - if tp: - precision = tp / (tp + fp) - recall = tp / (tp + fn) - else: - precision = 0 - recall = 0 - document_level_metrics["document_name"].append(filename) - document_level_metrics["precision"].append(precision) - document_level_metrics["recall"].append(recall) - document_level_metrics["F1"].append(f1(precision, recall)) - document_level_metrics["Number Elements"].append(number_of_truth_values[filename]) - document_level_metrics["Number wrong elements"].append(fn + fp) - try: - document_level_metrics["Depth_interval_accuracy"].append(depth_interval_hits / depth_interval_occurences) - depth_interval_accuracies.append(depth_interval_hits / depth_interval_occurences) - except ZeroDivisionError: - document_level_metrics["Depth_interval_accuracy"].append(None) - - if len(document_level_metrics["precision"]): - overall_precision = sum(document_level_metrics["precision"]) / len(document_level_metrics["precision"]) - overall_recall = sum(document_level_metrics["recall"]) / len(document_level_metrics["recall"]) - try: - overall_depth_interval_accuracy = sum(depth_interval_accuracies) / len(depth_interval_accuracies) - except ZeroDivisionError: - overall_depth_interval_accuracy = None - else: - overall_precision = 0 - overall_recall = 0 - - if overall_depth_interval_accuracy is None: - overall_depth_interval_accuracy = 0 - - if return_document_level_metrics: - return { - "F1": f1(overall_precision, overall_recall), - "precision": overall_precision, - "recall": overall_recall, - "depth_interval_accuracy": overall_depth_interval_accuracy, - }, pd.DataFrame(document_level_metrics) - else: - return { - "F1": f1(overall_precision, overall_recall), - "precision": overall_precision, - "recall": overall_recall, - "depth_interval_accuracy": overall_depth_interval_accuracy, - } + if depth_interval_occurences > 0: + depth_interval_metrics.metrics[filename] = Metrics( + tp=depth_interval_hits, fp=depth_interval_occurences - depth_interval_hits, fn=0 + ) + + return depth_interval_metrics def evaluate_borehole_extraction( predictions: dict[str, FilePredictions], number_of_truth_values: dict -) -> tuple[dict, pd.DataFrame]: +) -> DatasetMetricsCatalog: """Evaluate the borehole extraction predictions. Args: @@ -142,33 +91,18 @@ def evaluate_borehole_extraction( number_of_truth_values (dict): The number of layer ground truth values per file. Returns: - tuple[dict, pd.DataFrame]: A tuple containing the overall metrics as a dictionary and the - individual document metrics as a DataFrame. + DatasetMetricsCatalogue: A DatasetMetricsCatalogue that maps a metrics name to the corresponding DatasetMetrics + object """ - layer_metrics, layer_document_level_metrics = evaluate_layer_extraction(predictions, number_of_truth_values) - ( - metadata_metrics, - document_level_metrics_metadata, - ) = evaluate_metadata(predictions) - ( - metrics_groundwater, - document_level_metrics_groundwater, - document_level_metrics_groundwater_depth, - ) = evaluate_groundwater(predictions) - metrics = {**layer_metrics, **metadata_metrics, **metrics_groundwater} - document_level_metrics = pd.merge( - layer_document_level_metrics, document_level_metrics_metadata, on="document_name", how="outer" - ) - document_level_metrics = pd.merge( - document_level_metrics, document_level_metrics_groundwater, on="document_name", how="outer" - ) - document_level_metrics = pd.merge( - document_level_metrics, document_level_metrics_groundwater_depth, on="document_name", how="outer" - ) - return metrics, document_level_metrics + all_metrics = evaluate_layer_extraction(predictions, number_of_truth_values) + all_metrics.metrics["coordinates"] = get_metrics(predictions, "metadata_is_correct", "coordinates") + all_metrics.metrics["elevation"] = get_metrics(predictions, "metadata_is_correct", "elevation") + all_metrics.metrics["groundwater"] = get_metrics(predictions, "groundwater_is_correct", "groundwater") + all_metrics.metrics["groundwater_depth"] = get_metrics(predictions, "groundwater_is_correct", "groundwater_depth") + return all_metrics -def get_metrics(predictions: dict[str, FilePredictions], field_key: str, field_name: str) -> dict: +def get_metrics(predictions: dict[str, FilePredictions], field_key: str, field_name: str) -> DatasetMetrics: """Get the metrics for a specific field in the predictions. Args: @@ -177,125 +111,17 @@ def get_metrics(predictions: dict[str, FilePredictions], field_key: str, field_n field_name (str): The name of the field being evaluated. Returns: - dict: The document level metrics and overall metrics. + DatasetMetrics: The requested DatasetMetrics object. """ - document_level_metrics = { - "document_name": [], - field_name: [], - } - - tp = 0 # correct prediction - fn = 0 # no predictions, i.e. None - fp = 0 # wrong prediction + dataset_metrics = DatasetMetrics() for file_name, file_prediction in predictions.items(): - is_correct = getattr(file_prediction, field_key)[field_name] - tp += is_correct["tp"] - fp += is_correct["fp"] - fn += is_correct["fn"] - document_level_metrics["document_name"].append(file_name) - - try: - precision = is_correct["tp"] / (is_correct["tp"] + is_correct["fp"]) - except ZeroDivisionError: - precision = 0 - try: - recall = is_correct["tp"] / (is_correct["tp"] + is_correct["fn"]) - except ZeroDivisionError: - recall = 0 - document_level_metrics[field_name].append(f1(precision, recall)) - - try: - precision = tp / (tp + fp) - except ZeroDivisionError: - precision = 0 - try: - recall = tp / (tp + fn) - except ZeroDivisionError: - recall = 0 - - metrics = { - f"{field_name}_precision": precision, - f"{field_name}_recall": recall, - f"{field_name}_f1": f1(precision, recall), - f"{field_name}_tp": tp, - f"{field_name}_fp": fp, - f"{field_name}_fn": fn, - } - - return document_level_metrics, metrics - - -def get_metadata_metrics(predictions: dict[str, FilePredictions], metadata_field: str) -> dict: - """Get the metadata metrics.""" - return get_metrics(predictions, "metadata_is_correct", metadata_field) - - -def get_groundwater_metrics(predictions: dict[str, FilePredictions], metadata_field: str) -> dict: - """Get the groundwater information metrics.""" - return get_metrics(predictions, "groundwater_is_correct", metadata_field) - - -def evaluate_groundwater(predictions: dict[str, FilePredictions]) -> tuple[dict, pd.DataFrame]: - """Evaluate the groundwater information predictions. - - Args: - predictions (dict): The FilePredictions objects. - - Returns: - tuple[dict, pd.DataFrame]: The overall groundwater information accuracy and the individual document metrics as - a DataFrame. - """ - document_level_metrics_groundwater, metrics_groundwater = get_groundwater_metrics(predictions, "groundwater") - document_level_metrics_groundwater_depth, metrics_groundwater_depth = get_groundwater_metrics( - predictions, "groundwater_depth" - ) - - metrics_groundwater.update(metrics_groundwater_depth) - - return ( - metrics_groundwater, - pd.DataFrame(document_level_metrics_groundwater), - pd.DataFrame(document_level_metrics_groundwater_depth), - ) - - -def evaluate_metadata(predictions: dict[str, FilePredictions]) -> tuple[dict, pd.DataFrame]: - """Evaluate the metadata predictions. - - Args: - predictions (dict): The FilePredictions objects. + dataset_metrics.metrics[file_name] = getattr(file_prediction, field_key)[field_name] - Returns: - tuple[dict, pd.DataFrame]: The overall coordinate metrics as a DataFrame. - """ - document_level_metrics_coordinates, metrics_coordinates = get_metadata_metrics(predictions, "coordinates") - document_level_metrics_elevation, metrics_elevation = get_metadata_metrics(predictions, "elevation") - metrics = { - "coordinate_precision": metrics_coordinates["coordinates_precision"], - "coordinate_recall": metrics_coordinates["coordinates_recall"], - "coordinate_f1": metrics_coordinates["coordinates_f1"], - "coordinates_tp": metrics_coordinates["coordinates_tp"], - "coordinates_fp": metrics_coordinates["coordinates_fp"], - "coordinates_fn": metrics_coordinates["coordinates_fn"], - "elevation_precision": metrics_elevation["elevation_precision"], - "elevation_recall": metrics_elevation["elevation_recall"], - "elevation_f1": metrics_elevation["elevation_f1"], - "elevation_tp": metrics_elevation["elevation_tp"], - "elevation_fp": metrics_elevation["elevation_fp"], - "elevation_fn": metrics_elevation["elevation_fn"], - } - document_level_metrics_metadata = pd.merge( - pd.DataFrame(document_level_metrics_coordinates), - pd.DataFrame(document_level_metrics_elevation), - on="document_name", - how="outer", - ) + return dataset_metrics - return (metrics, document_level_metrics_metadata) - -def evaluate_layer_extraction(predictions: dict, number_of_truth_values: dict) -> tuple[dict, pd.DataFrame]: +def evaluate_layer_extraction(predictions: dict, number_of_truth_values: dict) -> DatasetMetricsCatalog: """Calculate F1, precision and recall for the predictions. Calculate F1, precision and recall for the individual documents as well as overall. @@ -306,48 +132,43 @@ def evaluate_layer_extraction(predictions: dict, number_of_truth_values: dict) - number_of_truth_values (dict): The number of layer ground truth values per file. Returns: - tuple[dict, pd.DataFrame]: A tuple containing the overall F1, precision and recall as a dictionary and the - individual document metrics as a DataFrame. + DatasetMetricsCatalogue: A dictionary that maps a metrics name to the corresponding DatasetMetrics object """ - metrics = {} - metrics["all"], document_level_metrics = get_scores( - predictions, number_of_truth_values, return_document_level_metrics=True - ) + all_metrics = DatasetMetricsCatalog() + all_metrics.metrics["layer"] = get_layer_metrics(predictions, number_of_truth_values) + all_metrics.metrics["depth_interval"] = get_depth_interval_metrics(predictions) + # create predictions by language - predictions_by_language = defaultdict(dict) + predictions_by_language = {"de": {}, "fr": {}} for file_name, file_predictions in predictions.items(): language = file_predictions.language - predictions_by_language[language][file_name] = file_predictions + if language in predictions_by_language: + predictions_by_language[language][file_name] = file_predictions for language, language_predictions in predictions_by_language.items(): language_number_of_truth_values = { file_name: number_of_truth_values[file_name] for file_name in language_predictions } - metrics[language] = get_scores( - language_predictions, language_number_of_truth_values, return_document_level_metrics=False + all_metrics.metrics[f"{language}_layer"] = get_layer_metrics( + language_predictions, language_number_of_truth_values ) + all_metrics.metrics[f"{language}_depth_interval"] = get_depth_interval_metrics(language_predictions) logging.info("Macro avg:") logging.info( - f"F1: {metrics['all']['F1']:.1%}, " - f"precision: {metrics['all']['precision']:.1%}, recall: {metrics['all']['recall']:.1%}, " - f"depth_interval_accuracy: {metrics['all']['depth_interval_accuracy']:.1%}" + f"F1: {all_metrics.metrics['layer'].macro_f1():.1%}, " + f"precision: {all_metrics.metrics['layer'].macro_precision():.1%}, " + f"recall: {all_metrics.metrics['layer'].macro_recall():.1%}, " + f"depth_interval_accuracy: {all_metrics.metrics['depth_interval'].macro_precision():.1%}" ) - _metrics = {} - for language, language_metrics in metrics.items(): - for metric_type, value in language_metrics.items(): - if language == "all": - _metrics[metric_type] = value - else: - _metrics[f"{language}_{metric_type}"] = value - return _metrics, document_level_metrics + return all_metrics def create_predictions_objects( predictions: dict, ground_truth_path: Path | None, -) -> tuple[dict[FilePredictions], dict]: +) -> tuple[dict[str, FilePredictions], dict]: """Create predictions objects from the predictions and evaluate them against the ground truth. Args: @@ -355,7 +176,8 @@ def create_predictions_objects( ground_truth_path (Path | None): The path to the ground truth file. Returns: - tuple[dict[FilePredictions], dict]: The predictions objects and the number of ground truth values per file. + tuple[dict[str, FilePredictions], dict]: The predictions objects and the number of ground truth values per + file. """ if ground_truth_path and os.path.exists(ground_truth_path): # for inference no ground truth is available ground_truth = GroundTruth(ground_truth_path) @@ -379,23 +201,54 @@ def create_predictions_objects( return predictions_objects, number_of_truth_values +def evaluate( + predictions, + ground_truth_path: Path, + temp_directory: Path, + input_directory: Path | None, + draw_directory: Path | None, +): + """Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled).""" + predictions, number_of_truth_values = create_predictions_objects(predictions, ground_truth_path) + + if input_directory and draw_directory: + draw_predictions(predictions, input_directory, draw_directory) + + if number_of_truth_values: # only evaluate if ground truth is available + metrics = evaluate_borehole_extraction(predictions, number_of_truth_values) + + metrics.document_level_metrics_df().to_csv( + temp_directory / "document_level_metrics.csv", index_label="document_name" + ) # mlflow.log_artifact expects a file + + metrics_dict = metrics.metrics_dict() + logger.info("Performance metrics: %s", metrics_dict) + + if mlflow_tracking: + import mlflow + + mlflow.log_metrics(metrics_dict) + mlflow.log_artifact(temp_directory / "document_level_metrics.csv") + else: + logger.warning("Ground truth file not found. Skipping evaluation.") + + if __name__ == "__main__": # setup mlflow tracking; should be started before any other code # such that tracking is enabled in other parts of the code. - # This does not create any scores, but will logg all the created images to mlflow. + # This does not create any scores, but will log all the created images to mlflow. if mlflow_tracking: import mlflow mlflow.set_experiment("Boreholes Stratigraphy") mlflow.start_run() - # instantiate all paths - input_directory = DATAPATH / "Benchmark" - ground_truth_path = input_directory / "ground_truth.json" - out_directory = input_directory / "evaluation" - predictions_path = input_directory / "extract" / "predictions.json" + # TODO: make configurable + ground_truth_path = DATAPATH.parent.parent / "data" / "zurich_ground_truth.json" + predictions_path = DATAPATH / "output" / "predictions.json" + temp_directory = DATAPATH / "_temp" - # evaluate the predictions - metrics, document_level_metrics = evaluate_borehole_extraction( - predictions_path, ground_truth_path, input_directory, out_directory - ) + with open(predictions_path, encoding="utf8") as file: + predictions = json.load(file) + + evaluate(predictions, ground_truth_path, temp_directory, input_directory=None, draw_directory=None) diff --git a/src/stratigraphy/coordinates/coordinate_extraction.py b/src/stratigraphy/coordinates/coordinate_extraction.py index 30c04d75..f04085a0 100644 --- a/src/stratigraphy/coordinates/coordinate_extraction.py +++ b/src/stratigraphy/coordinates/coordinate_extraction.py @@ -266,7 +266,9 @@ def _match_text_with_rect( results.append((match, rect)) return results - def extract_coordinates_from_bbox(self, page: fitz.Page, page_number: int, bbox: fitz.Rect) -> Coordinate | None: + def extract_coordinates_from_bbox( + self, page: fitz.Page, page_number: int, bbox: fitz.Rect | None = None + ) -> Coordinate | None: """Extracts the coordinates from a borehole profile. Processes the borehole profile page by page and tries to find the coordinates in the respective text of the @@ -310,4 +312,4 @@ def extract_coordinates(self) -> Coordinate | None: for page in self.doc: page_number = page.number + 1 # page.number is 0-based - return self.extract_coordinates_from_bbox(page, page_number, page.rect) + return self.extract_coordinates_from_bbox(page, page_number) diff --git a/src/stratigraphy/elevation/elevation_extraction.py b/src/stratigraphy/elevation/elevation_extraction.py index 1dae5a1a..1c51c8fc 100644 --- a/src/stratigraphy/elevation/elevation_extraction.py +++ b/src/stratigraphy/elevation/elevation_extraction.py @@ -166,13 +166,13 @@ def get_elevation_from_lines(self, lines: list[TextLine], page: int) -> Elevatio raise ValueError("Could not extract all required information from the lines provided.") def extract_elevation_from_bbox( - self, pdf_page: fitz.Page, page_number: int, bbox: fitz.Rect + self, pdf_page: fitz.Page, page_number: int, bbox: fitz.Rect | None = None ) -> ElevationInformation | None: """Extract the elevation information from a bounding box. Args: pdf_page (fitz.Page): The PDF page. - bbox (fitz.Rect): The bounding box. + bbox (fitz.Rect | None): The bounding box. page_number (int): The page number. Returns: @@ -197,4 +197,4 @@ def extract_elevation(self) -> ElevationInformation | None: for page in self.doc: page_number = page.number + 1 # page.number is 0-based - return self.extract_elevation_from_bbox(page, page_number, page.rect) + return self.extract_elevation_from_bbox(page, page_number) diff --git a/src/stratigraphy/main.py b/src/stratigraphy/main.py index f41662cd..779d1349 100644 --- a/src/stratigraphy/main.py +++ b/src/stratigraphy/main.py @@ -11,13 +11,12 @@ from tqdm import tqdm from stratigraphy import DATAPATH -from stratigraphy.benchmark.score import create_predictions_objects, evaluate_borehole_extraction +from stratigraphy.benchmark.score import evaluate from stratigraphy.coordinates.coordinate_extraction import CoordinateExtractor from stratigraphy.elevation.elevation_extraction import ElevationExtractor from stratigraphy.extract import process_page from stratigraphy.groundwater.groundwater_extraction import GroundwaterLevelExtractor from stratigraphy.line_detection import extract_lines, line_detection_params -from stratigraphy.util.draw import draw_predictions from stratigraphy.util.duplicate_detection import remove_duplicate_layers from stratigraphy.util.extract_text import extract_text_lines from stratigraphy.util.language_detection import detect_language_of_document @@ -111,7 +110,7 @@ def start_pipeline( predictions_path: Path, skip_draw_predictions: bool = False, draw_lines: bool = False, -) -> list[dict]: +): """Run the boreholes data extraction pipeline. The pipeline will extract material description of all found layers and assign them to the corresponding @@ -128,9 +127,6 @@ def start_pipeline( predictions_path (Path): The path to the predictions file. skip_draw_predictions (bool, optional): Whether to skip drawing predictions on pdf pages. Defaults to False. draw_lines (bool, optional): Whether to draw lines on pdf pages. Defaults to False. - - Returns: - list[dict]: The predictions of the pipeline. """ # noqa: D301 if mlflow_tracking: import mlflow @@ -245,29 +241,9 @@ def start_pipeline( with open(predictions_path, "w", encoding="utf8") as file: json.dump(predictions, file, ensure_ascii=False) - # evaluate the predictions; if file does not exist, the predictions are not changed. - predictions, number_of_truth_values = create_predictions_objects(predictions, ground_truth_path) - - if not skip_draw_predictions: - draw_predictions(predictions, input_directory, draw_directory) - - if number_of_truth_values: # only evaluate if ground truth is available - metrics, document_level_metrics = evaluate_borehole_extraction(predictions, number_of_truth_values) - document_level_metrics.to_csv( - temp_directory / "document_level_metrics.csv" - ) # mlflow.log_artifact expects a file - - # print the metrics - logger.info("Performance metrics:") - logger.info(metrics) - - if mlflow_tracking: - mlflow.log_metrics(metrics) - mlflow.log_artifact(temp_directory / "document_level_metrics.csv") - else: - logger.warning("Ground truth file not found. Skipping evaluation.") - - return predictions + if skip_draw_predictions: + draw_directory = None + evaluate(predictions, ground_truth_path, temp_directory, input_directory, draw_directory) if __name__ == "__main__": diff --git a/src/stratigraphy/util/draw.py b/src/stratigraphy/util/draw.py index da8d56f4..981a5952 100644 --- a/src/stratigraphy/util/draw.py +++ b/src/stratigraphy/util/draw.py @@ -7,6 +7,7 @@ import fitz from dotenv import load_dotenv +from stratigraphy.benchmark.metrics import Metrics from stratigraphy.coordinates.coordinate_extraction import Coordinate from stratigraphy.elevation.elevation_extraction import ElevationInformation from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformationOnPage @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) -def draw_predictions(predictions: list[FilePredictions], directory: Path, out_directory: Path) -> None: +def draw_predictions(predictions: dict[str, FilePredictions], directory: Path, out_directory: Path) -> None: """Draw predictions on pdf pages. Draws various recognized information on the pdf pages present at directory and saves @@ -106,9 +107,9 @@ def draw_metadata( derotation_matrix: fitz.Matrix, rotation: float, coordinates: Coordinate | None, - coordinates_is_correct: bool, + coordinates_is_correct: Metrics, elevation_info: ElevationInformation | None, - elevation_is_correct: bool, + elevation_is_correct: Metrics, ) -> None: """Draw the extracted metadata on the top of the given PDF page. @@ -120,16 +121,16 @@ def draw_metadata( derotation_matrix (fitz.Matrix): The derotation matrix of the page. rotation (float): The rotation of the page. coordinates (Coordinate | None): The coordinate object to draw. - coordinates_is_correct (bool): Whether the coordinates are correct. + coordinates_is_correct (Metrics): Whether the coordinates are correct. elevation_info (ElevationInformation | None): The elevation information to draw. - elevation_is_correct (bool): Whether the elevation information is correct. + elevation_is_correct (Metrics): Whether the elevation information is correct. """ # TODO associate correctness with the extracted coordinates in a better way - coordinate_correct = coordinates_is_correct is not None and coordinates_is_correct["tp"] > 0 + coordinate_correct = coordinates_is_correct is not None and coordinates_is_correct.tp > 0 coordinate_color = "green" if coordinate_correct else "red" coordinate_rect = fitz.Rect([5, 5, 200, 25]) - elevation_correct = elevation_is_correct is not None and elevation_is_correct["tp"] > 0 + elevation_correct = elevation_is_correct is not None and elevation_is_correct.tp > 0 elevation_color = "green" if elevation_correct else "red" elevation_rect = fitz.Rect([5, 25, 200, 45]) diff --git a/src/stratigraphy/util/extract_text.py b/src/stratigraphy/util/extract_text.py index 52e6dde8..abea091c 100644 --- a/src/stratigraphy/util/extract_text.py +++ b/src/stratigraphy/util/extract_text.py @@ -17,24 +17,21 @@ def extract_text_lines(page: fitz.Page) -> list[TextLine]: Returns: list[TextLine]: A list of text lines. """ - return extract_text_lines_from_bbox(page, fitz.Rect(0, 0, page.rect.width, page.rect.height)) + return extract_text_lines_from_bbox(page, bbox=None) -def extract_text_lines_from_bbox(page: fitz.Page, bbox: fitz.Rect) -> list[TextLine]: +def extract_text_lines_from_bbox(page: fitz.Page, bbox: fitz.Rect | None) -> list[TextLine]: """Extract all text lines from the page. Sometimes, a single lines as identified by PyMuPDF, is still split into separate lines. Args: page (fitz.page): the page to extract text from - bbox (BoundingBox): the bounding box to extract text from + bbox (fitz.Rect | None): the bounding box to extract text from Returns: list[TextLine]: A list of text lines. """ - if not isinstance(bbox, fitz.Rect): - raise ValueError("The bbox parameter must be a fitz.Rect object.") - words = [] words_by_line = {} for x0, y0, x1, y1, word, block_no, line_no, _word_no in page.get_text("words", clip=bbox): diff --git a/src/stratigraphy/util/predictions.py b/src/stratigraphy/util/predictions.py index 59f15f87..6fa2f52c 100644 --- a/src/stratigraphy/util/predictions.py +++ b/src/stratigraphy/util/predictions.py @@ -9,6 +9,7 @@ import fitz import Levenshtein +from stratigraphy.benchmark.metrics import Metrics from stratigraphy.coordinates.coordinate_extraction import Coordinate from stratigraphy.elevation.elevation_extraction import ElevationInformation from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformation, GroundwaterInformationOnPage @@ -230,15 +231,15 @@ def evaluate_metadata(self, metadata_ground_truth: dict): if (math.isclose(int(extracted_coordinates.east.coordinate_value), ground_truth_east, abs_tol=2)) and ( math.isclose(int(extracted_coordinates.north.coordinate_value), ground_truth_north, abs_tol=2) ): - self.metadata_is_correct["coordinates"] = {"tp": 1, "fp": 0, "fn": 0} + self.metadata_is_correct["coordinates"] = Metrics(tp=1, fp=0, fn=0) else: - self.metadata_is_correct["coordinates"] = {"tp": 0, "fp": 1, "fn": 1} + self.metadata_is_correct["coordinates"] = Metrics(tp=0, fp=1, fn=1) else: - self.metadata_is_correct["coordinates"] = { - "tp": 0, - "fp": 1 if extracted_coordinates is not None else 0, - "fn": 1 if ground_truth_coordinates is not None else 0, - } + self.metadata_is_correct["coordinates"] = Metrics( + tp=0, + fp=1 if extracted_coordinates is not None else 0, + fn=1 if ground_truth_coordinates is not None else 0, + ) ############################################################################################################ ### Compute the metadata correctness for the elevation. @@ -248,24 +249,22 @@ def evaluate_metadata(self, metadata_ground_truth: dict): if extracted_elevation is not None and ground_truth_elevation is not None: if math.isclose(extracted_elevation, ground_truth_elevation, abs_tol=0.1): - self.metadata_is_correct["elevation"] = {"tp": 1, "fp": 0, "fn": 0} + self.metadata_is_correct["elevation"] = Metrics(tp=1, fp=0, fn=0) else: - self.metadata_is_correct["elevation"] = {"tp": 0, "fp": 1, "fn": 1} + self.metadata_is_correct["elevation"] = Metrics(tp=0, fp=1, fn=1) else: - self.metadata_is_correct["elevation"] = { - "tp": 0, - "fp": 1 if extracted_elevation is not None else 0, - "fn": 1 if ground_truth_elevation is not None else 0, - } + self.metadata_is_correct["elevation"] = Metrics( + tp=0, fp=1 if extracted_elevation is not None else 0, fn=1 if ground_truth_elevation is not None else 0 + ) @staticmethod - def count_against_ground_truth(values: list, ground_truth: list) -> dict: + def count_against_ground_truth(values: list, ground_truth: list) -> Metrics: # Counter deals with duplicates when doing intersection values_counter = Counter(values) ground_truth_counter = Counter(ground_truth) tp = (values_counter & ground_truth_counter).total() # size of intersection - return {"tp": tp, "fp": len(values) - tp, "fn": len(ground_truth) - tp} + return Metrics(tp=tp, fp=len(values) - tp, fn=len(ground_truth) - tp) def evaluate_groundwater(self, groundwater_ground_truth: list): """Evaluate the groundwater information of the file against the ground truth.