diff --git a/.github/workflows/pipeline_run.yml b/.github/workflows/pipeline_run.yml index c424d046..746d644a 100644 --- a/.github/workflows/pipeline_run.yml +++ b/.github/workflows/pipeline_run.yml @@ -21,4 +21,7 @@ jobs: source env/bin/activate pip install -e . echo "Running pipeline" - boreholes-extract-all -l -i example/example_borehole_profile.pdf -o example/ -p example/predictions.json -m example/metadata.json -g example/example_groundtruth.json -pa all \ No newline at end of file + boreholes-extract-all -l -i example/example_borehole_profile.pdf -o example/ -p example/predictions.json -m example/metadata.json -g example/example_groundtruth.json -pa all + + echo "Running scoring script" + boreholes-score --ground-truth-path example/example_groundtruth.json --predictions-path example/predictions.json --no-mlflow-tracking diff --git a/.vscode/launch.json b/.vscode/launch.json index e3e5da10..11b61972 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -19,6 +19,16 @@ "justMyCode": true, "python": "${workspaceFolder}/swisstopo/bin/python3", }, + { + "name": "Python: Run scoring", + "type": "debugpy", + "request": "launch", + "module": "src.stratigraphy.benchmark.score", + "args": [], + "cwd": "${workspaceFolder}", + "justMyCode": true, + "python": "./swisstopo/bin/python3", + }, { "name": "Python: Run label studio to GT", "type": "debugpy", diff --git a/pyproject.toml b/pyproject.toml index 684fab22..5a2e6fb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ all = ["swissgeol-boreholes-dataextraction[test, lint, experiment-tracking, visu boreholes-extract-all = "stratigraphy.main:click_pipeline" boreholes-extract-metadata = "stratigraphy.main:click_pipeline_metadata" boreholes-download-profiles = "stratigraphy.get_files:download_directory_froms3" +boreholes-score = "stratigraphy.benchmark.score:main" [tool.ruff.lint] select = [ diff --git a/src/stratigraphy/annotations/draw.py b/src/stratigraphy/annotations/draw.py index 889f1ede..70c56228 100644 --- a/src/stratigraphy/annotations/draw.py +++ b/src/stratigraphy/annotations/draw.py @@ -7,13 +7,15 @@ import fitz import pandas as pd from dotenv import load_dotenv +from stratigraphy.depthcolumn.depthcolumn import DepthColumn +from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformationOnPage from stratigraphy.layer.layer import LayerPrediction from stratigraphy.metadata.coordinate_extraction import Coordinate from stratigraphy.metadata.elevation_extraction import Elevation from stratigraphy.text.textblock import TextBlock from stratigraphy.util.interval import BoundaryInterval -from stratigraphy.util.predictions import FilePredictions +from stratigraphy.util.predictions import OverallFilePredictions load_dotenv() @@ -25,7 +27,7 @@ def draw_predictions( - predictions: dict[str, FilePredictions], + predictions: OverallFilePredictions, directory: Path, out_directory: Path, document_level_metadata_metrics: pd.DataFrame, @@ -51,66 +53,78 @@ def draw_predictions( """ if directory.is_file(): # deal with the case when we pass a file instead of a directory directory = directory.parent - for file_name, file_prediction in predictions.items(): - logger.info("Drawing predictions for file %s", file_name) + for file_prediction in predictions.file_predictions_list: + logger.info("Drawing predictions for file %s", file_prediction.file_name) depths_materials_column_pairs = 
file_prediction.depths_materials_columns_pairs coordinates = file_prediction.metadata.coordinates elevation = file_prediction.metadata.elevation # Assess the correctness of the metadata - is_coordinates_correct = document_level_metadata_metrics.loc[file_name].coordinate - is_elevation_correct = document_level_metadata_metrics.loc[file_name].elevation - - with fitz.Document(directory / file_name) as doc: - for page_index, page in enumerate(doc): - page_number = page_index + 1 - shape = page.new_shape() # Create a shape object for drawing - if page_number == 1: - draw_metadata( + if file_prediction.file_name in document_level_metadata_metrics.index: + is_coordinates_correct = document_level_metadata_metrics.loc[file_prediction.file_name].coordinate + is_elevation_correct = document_level_metadata_metrics.loc[file_prediction.file_name].elevation + else: + logger.warning( + "Metrics for file %s not found in document_level_metadata_metrics.", file_prediction.file_name + ) + is_coordinates_correct = False + is_elevation_correct = False + + try: + with fitz.Document(directory / file_prediction.file_name) as doc: + for page_index, page in enumerate(doc): + page_number = page_index + 1 + shape = page.new_shape() # Create a shape object for drawing + if page_number == 1: + draw_metadata( + shape, + page.derotation_matrix, + page.rotation, + coordinates, + is_coordinates_correct, + elevation, + is_elevation_correct, + ) + if coordinates is not None and page_number == coordinates.page: + draw_coordinates(shape, coordinates) + if elevation is not None and page_number == elevation.page: + draw_elevation(shape, elevation) + for groundwater_entry in file_prediction.groundwater_entries: + if page_number == groundwater_entry.page: + draw_groundwater(shape, groundwater_entry) + draw_depth_columns_and_material_rect( shape, page.derotation_matrix, - page.rotation, - coordinates, - is_coordinates_correct, - elevation, - is_elevation_correct, + [pair for pair in depths_materials_column_pairs if pair.page == page_number], ) - if coordinates is not None and page_number == coordinates.page: - draw_coordinates(shape, coordinates) - if elevation is not None and page_number == elevation.page: - draw_elevation(shape, elevation) - for groundwater_entry in file_prediction.groundwater_entries: - if page_number == groundwater_entry.page: - draw_groundwater(shape, groundwater_entry) - draw_depth_columns_and_material_rect( - shape, - page.derotation_matrix, - [pair for pair in depths_materials_column_pairs if pair["page"] == page_number], - ) - draw_material_descriptions( - shape, - page.derotation_matrix, - [ - layer - for layer in file_prediction.layers - if layer.material_description.page_number == page_number - ], - ) - shape.commit() # Commit all the drawing operations to the page + draw_material_descriptions( + shape, + page.derotation_matrix, + [ + layer + for layer in file_prediction.layers + if layer.material_description.page_number == page_number + ], + ) + shape.commit() # Commit all the drawing operations to the page + + tmp_file_path = out_directory / f"{file_prediction.file_name}_page{page_number}.png" + fitz.utils.get_pixmap(page, matrix=fitz.Matrix(2, 2), clip=page.rect).save(tmp_file_path) - tmp_file_path = out_directory / f"{file_name}_page{page_number}.png" - fitz.utils.get_pixmap(page, matrix=fitz.Matrix(2, 2), clip=page.rect).save(tmp_file_path) + if mlflow_tracking: # This is only executed if MLFlow tracking is enabled + try: + import mlflow - if mlflow_tracking: # This is only executed if MLFlow 
tracking is enabled - try: - import mlflow + mlflow.log_artifact(tmp_file_path, artifact_path="pages") + except NameError: + logger.warning("MLFlow could not be imported. Skipping logging of artifact.") - mlflow.log_artifact(tmp_file_path, artifact_path="pages") - except NameError: - logger.warning("MLFlow could not be imported. Skipping logging of artifact.") + except (FileNotFoundError, fitz.FileDataError) as e: + logger.error("Error opening file %s: %s", file_prediction.file_name, e) + continue - logger.info("Finished drawing predictions for file %s", file_name) + logger.info("Finished drawing predictions for file %s", file_prediction.file_name) def draw_metadata( @@ -238,7 +252,7 @@ def draw_material_descriptions( def draw_depth_columns_and_material_rect( - shape: fitz.Shape, derotation_matrix: fitz.Matrix, depths_materials_column_pairs: list + shape: fitz.Shape, derotation_matrix: fitz.Matrix, depths_materials_column_pairs: list[DepthsMaterialsColumnPairs] ): """Draw depth columns as well as the material rects on a pdf page. @@ -253,17 +267,17 @@ def draw_depth_columns_and_material_rect( depths_materials_column_pairs (list): List of depth column entries. """ for pair in depths_materials_column_pairs: - depth_column = pair["depth_column"] - material_description_rect = pair["material_description_rect"] + depth_column: DepthColumn = pair.depth_column + material_description_rect = pair.material_description_rect - if depth_column is not None: # Draw rectangle for depth columns + if depth_column: # Draw rectangle for depth columns shape.draw_rect( - fitz.Rect(depth_column["rect"]) * derotation_matrix, + fitz.Rect(depth_column.rect()) * derotation_matrix, ) shape.finish(color=fitz.utils.getColor("green")) - for depth_column_entry in depth_column["entries"]: # Draw rectangle for depth column entries + for depth_column_entry in depth_column.entries: # Draw rectangle for depth column entries shape.draw_rect( - fitz.Rect(depth_column_entry["rect"]) * derotation_matrix, + fitz.Rect(depth_column_entry.rect) * derotation_matrix, ) shape.finish(color=fitz.utils.getColor("purple")) diff --git a/src/stratigraphy/benchmark/metrics.py b/src/stratigraphy/benchmark/metrics.py index e6409c67..60888726 100644 --- a/src/stratigraphy/benchmark/metrics.py +++ b/src/stratigraphy/benchmark/metrics.py @@ -1,5 +1,6 @@ """Classes for keeping track of metrics such as the F1-score, precision and recall.""" +from collections import defaultdict from collections.abc import Callable import pandas as pd @@ -52,6 +53,10 @@ def to_dataframe(self, name: str, fn: Callable[[Metrics], float]) -> pd.DataFram series = pd.Series({filename: fn(metric) for filename, metric in self.metrics.items()}) return series.to_frame(name=name) + def get_metrics_list(self) -> list[Metrics]: + """Return a list of all metrics.""" + return list(self.metrics.values()) + class DatasetMetricsCatalog: """Keeps track of all different relevant metrics that are computed for a dataset.""" @@ -78,26 +83,44 @@ def document_level_metrics_df(self) -> pd.DataFrame: def metrics_dict(self) -> dict[str, float]: """Return a dictionary with the overall metrics.""" - groundwater_metrics = Metrics.micro_average(self.metrics["groundwater"].metrics.values()) - groundwater_depth_metrics = Metrics.micro_average(self.metrics["groundwater_depth"].metrics.values()) - - return { - "F1": self.metrics["layer"].pseudo_macro_f1(), - "recall": self.metrics["layer"].macro_recall(), - "precision": self.metrics["layer"].macro_precision(), - "depth_interval_accuracy": 
self.metrics["depth_interval"].macro_precision(), - "de_F1": self.metrics["de_layer"].pseudo_macro_f1(), - "de_recall": self.metrics["de_layer"].macro_recall(), - "de_precision": self.metrics["de_layer"].macro_precision(), - "de_depth_interval_accuracy": self.metrics["de_depth_interval"].macro_precision(), - "fr_F1": self.metrics["fr_layer"].pseudo_macro_f1(), - "fr_recall": self.metrics["fr_layer"].macro_recall(), - "fr_precision": self.metrics["fr_layer"].macro_precision(), - "fr_depth_interval_accuracy": self.metrics["fr_depth_interval"].macro_precision(), - "groundwater_f1": groundwater_metrics.f1, - "groundwater_recall": groundwater_metrics.recall, - "groundwater_precision": groundwater_metrics.precision, - "groundwater_depth_f1": groundwater_depth_metrics.f1, - "groundwater_depth_recall": groundwater_depth_metrics.recall, - "groundwater_depth_precision": groundwater_depth_metrics.precision, - } + # Initialize a defaultdict to automatically return 0.0 for missing keys + result = defaultdict(lambda: None) + + # Safely compute groundwater metrics using .get() to avoid KeyErrors + groundwater_metrics = Metrics.micro_average( + self.metrics.get("groundwater", DatasetMetrics()).get_metrics_list() + ) + groundwater_depth_metrics = Metrics.micro_average( + self.metrics.get("groundwater_depth", DatasetMetrics()).get_metrics_list() + ) + + # Populate the basic metrics + result.update( + { + "F1": self.metrics.get("layer", DatasetMetrics()).pseudo_macro_f1(), + "recall": self.metrics.get("layer", DatasetMetrics()).macro_recall(), + "precision": self.metrics.get("layer", DatasetMetrics()).macro_precision(), + "depth_interval_accuracy": self.metrics.get("depth_interval", DatasetMetrics()).macro_precision(), + "groundwater_f1": groundwater_metrics.f1, + "groundwater_recall": groundwater_metrics.recall, + "groundwater_precision": groundwater_metrics.precision, + "groundwater_depth_f1": groundwater_depth_metrics.f1, + "groundwater_depth_recall": groundwater_depth_metrics.recall, + "groundwater_depth_precision": groundwater_depth_metrics.precision, + } + ) + + # Add dynamic language-specific metrics only if they exist + for lang in ["de", "fr"]: + layer_key = f"{lang}_layer" + depth_key = f"{lang}_depth_interval" + + if layer_key in self.metrics: + result[f"{lang}_F1"] = self.metrics[layer_key].pseudo_macro_f1() + result[f"{lang}_recall"] = self.metrics[layer_key].macro_recall() + result[f"{lang}_precision"] = self.metrics[layer_key].macro_precision() + + if depth_key in self.metrics: + result[f"{lang}_depth_interval_accuracy"] = self.metrics[depth_key].macro_precision() + + return dict(result) # Convert defaultdict back to a regular dict diff --git a/src/stratigraphy/benchmark/score.py b/src/stratigraphy/benchmark/score.py index 9a6cfd08..2493c073 100644 --- a/src/stratigraphy/benchmark/score.py +++ b/src/stratigraphy/benchmark/score.py @@ -1,5 +1,6 @@ """Evaluate the predictions against the ground truth.""" +import argparse import json import logging import os @@ -10,10 +11,7 @@ from stratigraphy.annotations.draw import draw_predictions from stratigraphy.benchmark.ground_truth import GroundTruth from stratigraphy.benchmark.metrics import DatasetMetrics, DatasetMetricsCatalog, Metrics -from stratigraphy.evaluation.evaluation_dataclasses import OverallBoreholeMetadataMetrics -from stratigraphy.evaluation.metadata_evaluator import MetadataEvaluator -from stratigraphy.metadata.metadata import BoreholeMetadataList -from stratigraphy.util.predictions import FilePredictions +from 
stratigraphy.util.predictions import OverallFilePredictions from stratigraphy.util.util import parse_text load_dotenv() @@ -23,13 +21,13 @@ logger = logging.getLogger(__name__) -def get_layer_metrics(predictions: dict, number_of_truth_values: dict) -> DatasetMetrics: +def get_layer_metrics(predictions: OverallFilePredictions, number_of_truth_values: dict) -> DatasetMetrics: """Calculate F1, precision and recall for the layer predictions. Calculate F1, precision and recall for the individual documents as well as overall. Args: - predictions (dict): The predictions. + predictions (OverallFilePredictions): The predictions. number_of_truth_values (dict): The number of ground truth values per file. Returns: @@ -37,21 +35,23 @@ def get_layer_metrics(predictions: dict, number_of_truth_values: dict) -> Datase """ layer_metrics = DatasetMetrics() - for filename, file_prediction in predictions.items(): + for file_prediction in predictions.file_predictions_list: hits = 0 for layer in file_prediction.layers: if layer.material_is_correct: hits += 1 if parse_text(layer.material_description.text) == "": logger.warning("Empty string found in predictions") - layer_metrics.metrics[filename] = Metrics( - tp=hits, fp=len(file_prediction.layers) - hits, fn=number_of_truth_values[filename] - hits + layer_metrics.metrics[file_prediction.file_name] = Metrics( + tp=hits, + fp=len(file_prediction.layers) - hits, + fn=number_of_truth_values.get(file_prediction.file_name, 0) - hits, ) return layer_metrics -def get_depth_interval_metrics(predictions: dict) -> DatasetMetrics: +def get_depth_interval_metrics(predictions: OverallFilePredictions) -> DatasetMetrics: """Calculate F1, precision and recall for the depth interval predictions. Calculate F1, precision and recall for the individual documents as well as overall. @@ -59,14 +59,14 @@ def get_depth_interval_metrics(predictions: dict) -> DatasetMetrics: Depth interval accuracy is not calculated for layers with incorrect material predictions. Args: - predictions (dict): The predictions. + predictions (OverallFilePredictions): The predictions. Returns: DatasetMetrics: the metrics for the depth intervals """ depth_interval_metrics = DatasetMetrics() - for filename, file_prediction in predictions.items(): + for file_prediction in predictions.file_predictions_list: depth_interval_hits = 0 depth_interval_occurrences = 0 for layer in file_prediction.layers: @@ -77,27 +77,20 @@ def get_depth_interval_metrics(predictions: dict) -> DatasetMetrics: depth_interval_hits += 1 if depth_interval_occurrences > 0: - depth_interval_metrics.metrics[filename] = Metrics( + depth_interval_metrics.metrics[file_prediction.file_name] = Metrics( tp=depth_interval_hits, fp=depth_interval_occurrences - depth_interval_hits, fn=0 ) return depth_interval_metrics -def evaluate_metadata_extraction( - borehole_metadata: BoreholeMetadataList, ground_truth_path: Path -) -> OverallBoreholeMetadataMetrics: - """Evaluate the metadata extraction.""" - return MetadataEvaluator(borehole_metadata, ground_truth_path).evaluate() - - def evaluate_borehole_extraction( - predictions: dict[str, FilePredictions], number_of_truth_values: dict + predictions: OverallFilePredictions, number_of_truth_values: dict ) -> DatasetMetricsCatalog: """Evaluate the borehole extraction predictions. Args: - predictions (dict): The FilePredictions objects. + predictions (OverallFilePredictions): The FilePredictions objects. number_of_truth_values (dict): The number of layer ground truth values per file. 
Returns: @@ -110,11 +103,11 @@ def evaluate_borehole_extraction( return all_metrics -def get_metrics(predictions: dict[str, FilePredictions], field_key: str, field_name: str) -> DatasetMetrics: +def get_metrics(predictions: OverallFilePredictions, field_key: str, field_name: str) -> DatasetMetrics: """Get the metrics for a specific field in the predictions. Args: - predictions (dict): The FilePredictions objects. + predictions (OverallFilePredictions): The FilePredictions objects. field_key (str): The key to access the specific field in the prediction objects. field_name (str): The name of the field being evaluated. @@ -123,20 +116,28 @@ def get_metrics(predictions: dict[str, FilePredictions], field_key: str, field_n """ dataset_metrics = DatasetMetrics() - for file_name, file_prediction in predictions.items(): - dataset_metrics.metrics[file_name] = getattr(file_prediction, field_key)[field_name] + for file_prediction in predictions.file_predictions_list: + attribute = getattr(file_prediction, field_key, None) + if attribute and field_name in attribute: + dataset_metrics.metrics[file_prediction.file_name] = attribute[field_name] + else: + logger.warning( + "Missing attribute '%s' or key '%s' in file '%s'", field_key, field_name, file_prediction.file_name + ) return dataset_metrics -def evaluate_layer_extraction(predictions: dict, number_of_truth_values: dict) -> DatasetMetricsCatalog: +def evaluate_layer_extraction( + predictions: OverallFilePredictions, number_of_truth_values: dict +) -> DatasetMetricsCatalog: """Calculate F1, precision and recall for the predictions. Calculate F1, precision and recall for the individual documents as well as overall. The individual document metrics are returned as a DataFrame. Args: - predictions (dict): The FilePredictions objects. + predictions (OverallFilePredictions): The OverallFilePredictions objects. number_of_truth_values (dict): The number of layer ground truth values per file. 
Returns: @@ -147,15 +148,17 @@ def evaluate_layer_extraction(predictions: dict, number_of_truth_values: dict) - all_metrics.metrics["depth_interval"] = get_depth_interval_metrics(predictions) # create predictions by language - predictions_by_language = {"de": {}, "fr": {}} - for file_name, file_predictions in predictions.items(): - language = file_predictions.language - if language in predictions_by_language: - predictions_by_language[language][file_name] = file_predictions + languages = set(fp.metadata.language for fp in predictions.file_predictions_list) + predictions_by_language = {language: OverallFilePredictions() for language in languages} + + for file_predictions in predictions.file_predictions_list: + language = file_predictions.metadata.language + predictions_by_language[language].add_file_predictions(file_predictions) for language, language_predictions in predictions_by_language.items(): language_number_of_truth_values = { - file_name: number_of_truth_values[file_name] for file_name in language_predictions + prediction.file_name: number_of_truth_values[prediction.file_name] + for prediction in language_predictions.file_predictions_list } all_metrics.metrics[f"{language}_layer"] = get_layer_metrics( language_predictions, language_number_of_truth_values @@ -175,19 +178,17 @@ def evaluate_layer_extraction(predictions: dict, number_of_truth_values: dict) - def create_predictions_objects( - predictions: dict, - metadata_per_file: BoreholeMetadataList, + predictions: OverallFilePredictions, ground_truth_path: Path | None, -) -> tuple[dict[str, FilePredictions], dict]: +) -> tuple[OverallFilePredictions, dict]: """Create predictions objects from the predictions and evaluate them against the ground truth. Args: - predictions (dict): The predictions from the predictions.json file. + predictions (OverallFilePredictions): The predictions objects. ground_truth_path (Path | None): The path to the ground truth file. - metadata_per_file (BoreholeMetadataList): The metadata for the files. Returns: - tuple[dict[str, FilePredictions], dict]: The predictions objects and the number of ground truth values per + tuple[OverallFilePredictions, dict]: The predictions objects and the number of ground truth values per file. 
""" if ground_truth_path and os.path.exists(ground_truth_path): # for inference no ground truth is available @@ -198,40 +199,41 @@ def create_predictions_objects( ground_truth_is_present = False number_of_truth_values = {} - predictions_objects = {} - for file_name, file_predictions in predictions.items(): - metadata = metadata_per_file.get_metadata(file_name) - if not metadata: - raise ValueError(f"Metadata for file {file_name} not found.") - - prediction_object = FilePredictions.create_from_json(file_predictions, metadata, file_name) - - predictions_objects[file_name] = prediction_object + for file_predictions in predictions.file_predictions_list: if ground_truth_is_present: - ground_truth_for_file = ground_truth.for_file(file_name) + ground_truth_for_file = ground_truth.for_file(file_predictions.file_name) if ground_truth_for_file: - predictions_objects[file_name].evaluate(ground_truth_for_file) - number_of_truth_values[file_name] = len(ground_truth_for_file["layers"]) + file_predictions.evaluate(ground_truth_for_file) + number_of_truth_values[file_predictions.file_name] = len(ground_truth_for_file["layers"]) - return predictions_objects, number_of_truth_values + return predictions, number_of_truth_values def evaluate( - predictions, - metadata_per_file: BoreholeMetadataList, + predictions: OverallFilePredictions, ground_truth_path: Path, temp_directory: Path, input_directory: Path | None, draw_directory: Path | None, -): - """Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled).""" +) -> None: + """Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled). + + Args: + predictions (OverallFilePredictions): The predictions objects. + ground_truth_path (Path): The path to the ground truth file. + temp_directory (Path): The path to the temporary directory. + input_directory (Path | None): The path to the input directory. + draw_directory (Path | None): The path to the draw directory. + + Returns: + None + """ ############################# # Evaluate the borehole extraction metadata ############################# - metadata_metrics_list = evaluate_metadata_extraction(metadata_per_file, ground_truth_path) + metadata_metrics_list = predictions.evaluate_metadata_extraction(ground_truth_path) metadata_metrics = metadata_metrics_list.get_cumulated_metrics() document_level_metadata_metrics = metadata_metrics_list.get_document_level_metrics() - document_level_metadata_metrics.to_csv( temp_directory / "document_level_metadata_metrics.csv", index_label="document_name" ) # mlflow.log_artifact expects a file @@ -250,14 +252,8 @@ def evaluate( # Evaluate the borehole extraction ############################# if predictions: - predictions, number_of_truth_values = create_predictions_objects( - predictions, metadata_per_file, ground_truth_path - ) + predictions, number_of_truth_values = create_predictions_objects(predictions, ground_truth_path) - if input_directory and draw_directory: - draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics) - - # evaluate the borehole extraction if number_of_truth_values: # only evaluate if ground truth is available metrics = evaluate_borehole_extraction(predictions, number_of_truth_values) @@ -277,24 +273,77 @@ def evaluate( else: logger.warning("Ground truth file not found. 
Skipping evaluation.") + ############################# + # Draw the prediction + ############################# + if input_directory and draw_directory: + draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics) + + +def main(): + """Main function to evaluate the predictions against the ground truth.""" + args = parse_cli() -if __name__ == "__main__": # setup mlflow tracking; should be started before any other code # such that tracking is enabled in other parts of the code. # This does not create any scores, but will log all the created images to mlflow. - if mlflow_tracking: + if args.mlflow_tracking: import mlflow mlflow.set_experiment("Boreholes Stratigraphy") mlflow.start_run() - # TODO: make configurable - ground_truth_path = DATAPATH.parent.parent / "data" / "zurich_ground_truth.json" - predictions_path = DATAPATH / "output" / "predictions.json" - temp_directory = DATAPATH / "_temp" + # Load the predictions + try: + with open(args.predictions_path, encoding="utf8") as file: + predictions = json.load(file) + except FileNotFoundError: + logger.error("Predictions file not found: %s", args.predictions_path) + return + except json.JSONDecodeError as e: + logger.error("Error decoding JSON from predictions file: %s", e) + return + + predictions = OverallFilePredictions.from_json(predictions) + + # Customize these as needed + evaluate(predictions, args.ground_truth_path, args.temp_directory, input_directory=None, draw_directory=None) + + +def parse_cli() -> argparse.Namespace: + """Parse the command line arguments and pass them to the main function.""" + parser = argparse.ArgumentParser(description="Borehole Stratigraphy Evaluation Script") + + # Add arguments with defaults + parser.add_argument( + "--ground-truth-path", + type=Path, + default=DATAPATH.parent / "data" / "zurich_ground_truth.json", + help="Path to the ground truth JSON file (default: '../data/zurich_ground_truth.json').", + ) + parser.add_argument( + "--predictions-path", + type=Path, + default=DATAPATH / "output" / "predictions.json", + help="Path to the predictions JSON file (default: './output/predictions.json').", + ) + parser.add_argument( + "--temp-directory", + type=Path, + default=DATAPATH / "_temp", + help="Directory for storing temporary data (default: './_temp').", + ) + parser.add_argument( + "--no-mlflow-tracking", + action="store_false", + dest="mlflow_tracking", + help="Disable MLflow tracking (enabled by default).", + ) + + # Parse arguments and pass to main + return parser.parse_args() - with open(predictions_path, encoding="utf8") as file: - predictions = json.load(file) - # TODO read BoreholeMetadataList from JSON file and pass to the evaluate method - evaluate(predictions, ground_truth_path, temp_directory, input_directory=None, draw_directory=None) +if __name__ == "__main__": + # Parse arguments and pass to main + main() diff --git a/src/stratigraphy/depthcolumn/depthcolumn.py b/src/stratigraphy/depthcolumn/depthcolumn.py index c9169ed2..ce45f148 100644 --- a/src/stratigraphy/depthcolumn/depthcolumn.py +++ b/src/stratigraphy/depthcolumn/depthcolumn.py @@ -53,14 +53,43 @@ def noise_count(self, all_words: list[TextWord]) -> int: def identify_groups( self, description_lines: list[TextLine], geometric_lines: list[Line], material_description_rect: fitz.Rect ) -> list[dict]: + """Identifies groups of description blocks that correspond to depth intervals. + + Args: + description_lines (list[TextLine]): A list of text lines that are part of the description. 
+ geometric_lines (list[Line]): A list of geometric lines that are part of the description. + material_description_rect (fitz.Rect): The bounding box of the material description. + + Returns: + list[dict]: A list of groups, where each group is a dictionary + with the keys "depth_intervals" and "blocks". + """ pass + @abc.abstractmethod def to_json(self): - rect = self.rect() - return { - "rect": [rect.x0, rect.y0, rect.x1, rect.y1], - "entries": [entry.to_json() for entry in self.entries], - } + """Converts the object to a dictionary.""" + pass + + @classmethod + @abc.abstractmethod + def from_json(cls, json_depth_column: dict) -> DepthColumn: + """Converts a dictionary to an object.""" + pass + + +class DepthColumnFactory: + """Factory class for creating DepthColumn objects.""" + + @staticmethod + def create(data: dict) -> DepthColumn: + column_type = data.get("type") + if column_type == "BoundaryDepthColumn": + return BoundaryDepthColumn.from_json(data) + elif column_type == "LayerDepthColumn": + return LayerDepthColumn.from_json(data) + else: + raise ValueError(f"Unknown depth column type: {column_type}") class LayerDepthColumn(DepthColumn): @@ -86,6 +115,33 @@ def __init__(self, entries=None): def __repr__(self): return "LayerDepthColumn({})".format(", ".join([str(entry) for entry in self.entries])) + def to_json(self) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + rect = self.rect() + return { + "rect": [rect.x0, rect.y0, rect.x1, rect.y1], + "entries": [entry.to_json() for entry in self.entries], + "type": "LayerDepthColumn", + } + + @classmethod + def from_json(cls, json_depth_column: dict) -> LayerDepthColumn: + """Converts a dictionary to an object. + + Args: + json_depth_column (dict): A dictionary representing the depth column. + + Returns: + LayerDepthColumn: The depth column object. + """ + entries_data = json_depth_column.get("entries", []) + entries = [LayerDepthColumnEntry.from_json(entry) for entry in entries_data] + return LayerDepthColumn(entries) + def add_entry(self, entry: LayerDepthColumnEntry) -> LayerDepthColumn: self.entries.append(entry) return self @@ -190,7 +246,35 @@ def rects(self) -> list[fitz.Rect]: def __repr__(self): return "DepthColumn({})".format(", ".join([str(entry) for entry in self.entries])) + def to_json(self) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + rect = self.rect() + return { + "rect": [rect.x0, rect.y0, rect.x1, rect.y1], + "entries": [entry.to_json() for entry in self.entries], + "type": "BoundaryDepthColumn", + } + + @classmethod + def from_json(cls, json_depth_column: dict) -> BoundaryDepthColumn: + """Converts a dictionary to an object. + + Args: + json_depth_column (dict): A dictionary representing the depth column. + + Returns: + BoundaryDepthColumn: The depth column object. 
+ """ + entries_data = json_depth_column.get("entries", []) + entries = [DepthColumnEntry.from_json(entry) for entry in entries_data] + return BoundaryDepthColumn(entries) + def add_entry(self, entry: DepthColumnEntry) -> BoundaryDepthColumn: + """Adds a depth column entry to the depth column.""" self.entries.append(entry) return self diff --git a/src/stratigraphy/depthcolumn/depthcolumnentry.py b/src/stratigraphy/depthcolumn/depthcolumnentry.py index a0dbb64c..0a9faac4 100644 --- a/src/stratigraphy/depthcolumn/depthcolumnentry.py +++ b/src/stratigraphy/depthcolumn/depthcolumnentry.py @@ -24,6 +24,22 @@ def to_json(self) -> dict[str, Any]: "page": self.page_number, } + @classmethod + def from_json(cls, json_depth_column_entry: dict) -> "DepthColumnEntry": + """Converts a dictionary to an object. + + Args: + json_depth_column_entry (dict): A dictionary representing the depth column entry. + + Returns: + DepthColumnEntry: The depth column entry object. + """ + return cls( + rect=fitz.Rect(json_depth_column_entry["rect"]), + value=json_depth_column_entry["value"], + page_number=json_depth_column_entry["page"], + ) + class AnnotatedDepthColumnEntry(DepthColumnEntry): # noqa: D101 """Class to represent a depth column entry obtained from LabelStudio. @@ -68,3 +84,17 @@ def to_json(self) -> dict[str, Any]: "rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1], "page": self.start.page_number, } + + @classmethod + def from_json(cls, json_layer_depth_column_entry: dict) -> "LayerDepthColumnEntry": + """Converts a dictionary to an object. + + Args: + json_layer_depth_column_entry (dict): A dictionary representing the layer depth column entry. + + Returns: + LayerDepthColumnEntry: The layer depth column entry object. + """ + start = DepthColumnEntry.from_json(json_layer_depth_column_entry["start"]) + end = DepthColumnEntry.from_json(json_layer_depth_column_entry["end"]) + return cls(start, end) diff --git a/src/stratigraphy/depths_materials_column_pairs/depths_materials_column_pairs.py b/src/stratigraphy/depths_materials_column_pairs/depths_materials_column_pairs.py new file mode 100644 index 00000000..d35ef9a5 --- /dev/null +++ b/src/stratigraphy/depths_materials_column_pairs/depths_materials_column_pairs.py @@ -0,0 +1,59 @@ +"""Definition of the DepthsMaterialsColumnPairs class.""" + +from dataclasses import dataclass + +import fitz +from stratigraphy.depthcolumn.depthcolumn import DepthColumn, DepthColumnFactory + + +@dataclass +class DepthsMaterialsColumnPairs: + """A class to represent pairs of depth columns and material descriptions.""" + + depth_column: DepthColumn | None + material_description_rect: fitz.Rect + page: int + + def __str__(self) -> str: + """Converts the object to a string. + + Returns: + str: The object as a string. + """ + return ( + f"DepthsMaterialsColumnPairs(depth_column={self.depth_column}," + f"material_description_rect={self.material_description_rect}, page={self.page})" + ) + + def to_json(self) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + return { + "depth_column": self.depth_column.to_json() if self.depth_column else None, + "material_description_rect": [ + self.material_description_rect.x0, + self.material_description_rect.y0, + self.material_description_rect.x1, + self.material_description_rect.y1, + ], + "page": self.page, + } + + @classmethod + def from_json(cls, json_depths_materials_column_pairs: dict) -> "DepthsMaterialsColumnPairs": + """Converts a dictionary to an object. 
+
+        Args:
+            json_depths_materials_column_pairs (dict): A dictionary representing the depths materials column pairs.
+
+        Returns:
+            DepthsMaterialsColumnPairs: The depths materials column pairs object.
+        """
+        json_depth_column = json_depths_materials_column_pairs["depth_column"]
+        # The serialized depth column is None for fallback pairs, so only invoke the factory on a real dict
+        depth_column = DepthColumnFactory.create(json_depth_column) if json_depth_column is not None else None
+        material_description_rect = fitz.Rect(json_depths_materials_column_pairs["material_description_rect"])
+        page = json_depths_materials_column_pairs["page"]
+
+        return cls(depth_column, material_description_rect, page)
diff --git a/src/stratigraphy/extract.py b/src/stratigraphy/extract.py
index 372beeba..50c5c4d4 100644
--- a/src/stratigraphy/extract.py
+++ b/src/stratigraphy/extract.py
@@ -2,11 +2,13 @@
 
 import logging
 import math
+from dataclasses import dataclass
 
 import fitz
 from stratigraphy.depthcolumn import find_depth_columns
 from stratigraphy.depthcolumn.depthcolumn import DepthColumn
+from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs
 from stratigraphy.layer.layer_identifier_column import (
     LayerIdentifierColumn,
     find_layer_identifier_column,
@@ -30,9 +32,17 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class ProcessPageResult:
+    """The result of processing a single page of a pdf."""
+
+    predictions: list[dict]
+    depth_material_pairs: list[DepthsMaterialsColumnPairs]
+
+
 def process_page(
     lines: list[TextLine], geometric_lines, language: str, page_number: int, **params: dict
-) -> list[dict]:
+) -> ProcessPageResult:
     """Process a single page of a pdf.
 
     Finds all descriptions and depth intervals on the page and matches them.
@@ -115,22 +125,15 @@
             depth_column, description_lines, geometric_lines, material_description_rect, **params
         )
         groups.extend(new_groups)
-        json_filtered_pairs = [
-            {
-                "depth_column": depth_column.to_json(),
-                "material_description_rect": [
-                    material_description_rect.x0,
-                    material_description_rect.y0,
-                    material_description_rect.x1,
-                    material_description_rect.y1,
-                ],
-                "page": page_number,
-            }
+        filtered_depth_material_column_pairs = [
+            DepthsMaterialsColumnPairs(
+                depth_column=depth_column, material_description_rect=material_description_rect, page=page_number
+            )
             for depth_column, material_description_rect in filtered_pairs
         ]
     else:
-        json_filtered_pairs = []
+        filtered_depth_material_column_pairs = []
         # Fallback when no depth column was found
         material_description_rect = find_material_description_column(
             lines, depth_column=None, language=language, **params["material_description"]
@@ -145,19 +148,10 @@
             params["left_line_length_threshold"],
         )
         groups.extend([{"block": block} for block in description_blocks])
-        json_filtered_pairs.extend(
-            [
-                {
-                    "depth_column": None,
-                    "material_description_rect": [
-                        material_description_rect.x0,
-                        material_description_rect.y0,
-                        material_description_rect.x1,
-                        material_description_rect.y1,
-                    ],
-                    "page": page_number,
-                }
-            ]
+        filtered_depth_material_column_pairs.append(
+            DepthsMaterialsColumnPairs(
+                depth_column=None, material_description_rect=material_description_rect, page=page_number
+            )
         )
     predictions = [
         (
@@ -168,7 +162,7 @@
         for group in groups
     ]
     predictions = remove_empty_predictions(predictions)
-    return predictions, json_filtered_pairs
+    return ProcessPageResult(predictions, filtered_depth_material_column_pairs)
 
 
 def score_column_match(
diff --git a/src/stratigraphy/groundwater/groundwater_extraction.py b/src/stratigraphy/groundwater/groundwater_extraction.py
index d548c77f..568a346e 100644
--- 
a/src/stratigraphy/groundwater/groundwater_extraction.py +++ b/src/stratigraphy/groundwater/groundwater_extraction.py @@ -111,24 +111,25 @@ def to_json(self) -> dict: "rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1] if self.rect else None, } - @staticmethod - def from_json_values(date: str | None, depth: float | None, elevation: float | None, page: int, rect: list[float]): - """Converts the object from a dictionary. + @classmethod + def from_json(cls, json_groundwater_information_on_page: dict) -> "GroundwaterInformationOnPage": + """Converts a dictionary to an object. Args: - date (str | None): The measurement date of the groundwater. - depth (float | None): The depth of the groundwater. - elevation (float | None): The elevation of the groundwater. - page (int): The page number of the PDF document. - rect (list[float]): The rectangle that contains the extracted information. + json_groundwater_information_on_page (dict): A dictionary representing the groundwater information on a + page. Returns: - GroundwaterInformationOnPage: The object created from the dictionary. + GroundwaterInformationOnPage: The groundwater information on a page object. """ return GroundwaterInformationOnPage( - groundwater=GroundwaterInformation.from_json_values(depth=depth, date=date, elevation=elevation), - page=page, - rect=fitz.Rect(rect), + groundwater=GroundwaterInformation.from_json_values( + depth=json_groundwater_information_on_page["depth"], + date=json_groundwater_information_on_page["date"], + elevation=json_groundwater_information_on_page["elevation"], + ), + page=json_groundwater_information_on_page["page"], + rect=fitz.Rect(json_groundwater_information_on_page["rect"]), ) diff --git a/src/stratigraphy/layer/layer.py b/src/stratigraphy/layer/layer.py index 9a5d3d99..e1bbca9b 100644 --- a/src/stratigraphy/layer/layer.py +++ b/src/stratigraphy/layer/layer.py @@ -3,6 +3,9 @@ import uuid from dataclasses import dataclass, field +import fitz +from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry +from stratigraphy.lines.line import TextLine, TextWord from stratigraphy.text.textblock import MaterialDescription, TextBlock from stratigraphy.util.interval import AnnotatedInterval, BoundaryInterval @@ -26,3 +29,79 @@ def __str__(self) -> str: return ( f"LayerPrediction(material_description={self.material_description}, depth_interval={self.depth_interval})" ) + + def to_json(self) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + return { + "material_description": self.material_description.to_json() if self.material_description else None, + "depth_interval": self.depth_interval.to_json() if self.depth_interval else None, + "material_is_correct": self.material_is_correct, + "depth_interval_is_correct": self.depth_interval_is_correct, + "id": str(self.id), + } + + @staticmethod + def from_json(json_layer_list: list[dict]) -> list["LayerPrediction"]: + """Converts a dictionary to an object. + + Args: + json_layer_list (list[dict]): A list of dictionaries representing the layers. + + Returns: + list[LayerPrediction]: A list of LayerPrediction objects. + """ + page_layer_predictions_list: list[LayerPrediction] = [] + + # Extract the layer predictions. 
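+        # Each dict in json_layer_list mirrors LayerPrediction.to_json(): a
+        # "material_description" dict carrying a "lines" list, plus an optional
+        # "depth_interval" dict whose "start"/"end" entries may each be None.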
+ for layer in json_layer_list: + material_prediction = _create_textblock_object(layer["material_description"]["lines"]) + if "depth_interval" in layer and layer["depth_interval"] is not None: + depth_interval = layer.get("depth_interval", {}) + start_data = depth_interval.get("start") + end_data = depth_interval.get("end") + start = ( + DepthColumnEntry( + value=start_data["value"], + rect=fitz.Rect(start_data["rect"]), + page_number=start_data["page"], + ) + if start_data is not None + else None + ) + end = ( + DepthColumnEntry( + value=end_data["value"], + rect=fitz.Rect(end_data["rect"]), + page_number=end_data["page"], + ) + if end_data is not None + else None + ) + + depth_interval_prediction = BoundaryInterval(start=start, end=end) + layer_predictions = LayerPrediction( + material_description=material_prediction, depth_interval=depth_interval_prediction + ) + else: + layer_predictions = LayerPrediction(material_description=material_prediction, depth_interval=None) + + page_layer_predictions_list.append(layer_predictions) + + return page_layer_predictions_list + + +def _create_textblock_object(lines: list[dict]) -> TextBlock: + """Creates a TextBlock object from a dictionary. + + Args: + lines (list[dict]): A list of dictionaries representing the lines. + + Returns: + TextBlock: The object. + """ + text_lines = [TextLine([TextWord(**line)]) for line in lines] + return TextBlock(text_lines) diff --git a/src/stratigraphy/main.py b/src/stratigraphy/main.py index b420e875..764ccc08 100644 --- a/src/stratigraphy/main.py +++ b/src/stratigraphy/main.py @@ -16,9 +16,11 @@ from stratigraphy.extract import process_page from stratigraphy.groundwater.groundwater_extraction import GroundwaterLevelExtractor from stratigraphy.layer.duplicate_detection import remove_duplicate_layers +from stratigraphy.layer.layer import LayerPrediction from stratigraphy.lines.line_detection import extract_lines, line_detection_params -from stratigraphy.metadata.metadata import BoreholeMetadata, BoreholeMetadataList +from stratigraphy.metadata.metadata import BoreholeMetadata from stratigraphy.text.extract_text import extract_text_lines +from stratigraphy.util.predictions import FilePredictions, OverallFilePredictions from stratigraphy.util.util import flatten, read_params load_dotenv() @@ -212,10 +214,7 @@ def start_pipeline( _, _, files = next(os.walk(input_directory)) # process the individual pdf files - predictions = {} - - # process the individual pdf files - metadata_per_file = BoreholeMetadataList() + predictions = OverallFilePredictions() for filename in tqdm(files, desc="Processing files", unit="file"): if filename.endswith(".pdf"): @@ -224,23 +223,12 @@ def start_pipeline( with fitz.Document(in_path) as doc: # Extract metadata - metadata = BoreholeMetadata(doc) - - # Add metadata to the metadata list - metadata_per_file.metadata_per_file.append(metadata) + metadata = BoreholeMetadata.from_document(doc) if part == "all": - predictions[filename] = {} - # Extract the groundwater levels groundwater_extractor = GroundwaterLevelExtractor(document=doc) groundwater = groundwater_extractor.extract_groundwater(terrain_elevation=metadata.elevation) - if groundwater: - predictions[filename]["groundwater"] = [ - groundwater_entry.to_json() for groundwater_entry in groundwater - ] - else: - predictions[filename]["groundwater"] = None layer_predictions_list = [] depths_materials_column_pairs_list = [] @@ -250,7 +238,7 @@ def start_pipeline( text_lines = extract_text_lines(page) geometric_lines = extract_lines(page, 
line_detection_params) - layer_predictions, depths_materials_column_pairs = process_page( + process_page_results = process_page( text_lines, geometric_lines, metadata.language, page_number, **matching_params ) @@ -260,12 +248,14 @@ def start_pipeline( doc[page_index - 1], page, layer_predictions_list, - layer_predictions, + process_page_results.predictions, matching_params["img_template_probability_threshold"], ) + else: + layer_predictions = process_page_results.predictions layer_predictions_list.extend(layer_predictions) - depths_materials_column_pairs_list.extend(depths_materials_column_pairs) + depths_materials_column_pairs_list.extend(process_page_results.depth_material_pairs) if draw_lines: # could be changed to if draw_lines and mflow_tracking: if not mlflow_tracking: @@ -278,25 +268,41 @@ def start_pipeline( ) mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png") + # Save the predictions to the overall predictions object + # Initialize common variables + groundwater_entries = None + layers = None + depths_materials_columns_pairs = None + if part == "all": - predictions[filename]["layers"] = layer_predictions_list - predictions[filename]["depths_materials_column_pairs"] = depths_materials_column_pairs_list - predictions[filename]["page_dimensions"] = ( - metadata.page_dimensions - ) # TODO: Remove this as it is already stored in the metadata + # Convert the layer dicts to LayerPrediction objects + page_layer_predictions_list = LayerPrediction.from_json(layer_predictions_list) + groundwater_entries = groundwater + layers = page_layer_predictions_list + depths_materials_columns_pairs = depths_materials_column_pairs_list + + # Add file predictions + predictions.add_file_predictions( + FilePredictions( + file_name=filename, + metadata=metadata, + groundwater_entries=groundwater_entries, + layers=layers, + depths_materials_columns_pairs=depths_materials_columns_pairs, + ) + ) logger.info("Metadata written to %s", metadata_path) with open(metadata_path, "w", encoding="utf8") as file: - json.dump(metadata_per_file.to_json(), file, ensure_ascii=False) + json.dump(predictions.get_metadata_as_dict(), file, ensure_ascii=False) if part == "all": logger.info("Writing predictions to JSON file %s", predictions_path) with open(predictions_path, "w", encoding="utf8") as file: - json.dump(predictions, file, ensure_ascii=False) + json.dump(predictions.to_json(), file, ensure_ascii=False) evaluate( predictions=predictions, - metadata_per_file=metadata_per_file, ground_truth_path=ground_truth_path, temp_directory=temp_directory, input_directory=input_directory, diff --git a/src/stratigraphy/metadata/coordinate_extraction.py b/src/stratigraphy/metadata/coordinate_extraction.py index 63961ba5..c304980b 100644 --- a/src/stratigraphy/metadata/coordinate_extraction.py +++ b/src/stratigraphy/metadata/coordinate_extraction.py @@ -94,8 +94,8 @@ def from_values(east: float, north: float, rect: fitz.Rect, page: int) -> Coordi logger.warning("Invalid coordinates format. Got E: %s, N: %s", east, north) return None - @staticmethod - def from_json(input: dict): + @classmethod + def from_json(cls, input: dict) -> Coordinate: """Converts a dictionary to a Coordinate object. 
        Args:
diff --git a/src/stratigraphy/metadata/elevation_extraction.py b/src/stratigraphy/metadata/elevation_extraction.py
index c3a2e333..5b9514ec 100644
--- a/src/stratigraphy/metadata/elevation_extraction.py
+++ b/src/stratigraphy/metadata/elevation_extraction.py
@@ -53,6 +53,21 @@
         "rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1] if self.rect else None,
     }
 
+    @classmethod
+    def from_json(cls, json_elevation: dict) -> "Elevation":
+        """Converts a dictionary to an object.
+
+        Args:
+            json_elevation (dict): A dictionary representing the elevation information.
+
+        Returns:
+            Elevation: The elevation information object.
+        """
+        elevation = json_elevation["elevation"]
+        page = json_elevation["page"]
+        # Restore the fitz.Rect from its serialized [x0, y0, x1, y1] form (it may be None)
+        rect = fitz.Rect(json_elevation["rect"]) if json_elevation["rect"] is not None else None
+        return cls(elevation=elevation, page=page, rect=rect)
+
 
 class ElevationExtractor(DataExtractor):
     """Class for extracting elevation data from text.
diff --git a/src/stratigraphy/metadata/metadata.py b/src/stratigraphy/metadata/metadata.py
index b980f613..2830a340 100644
--- a/src/stratigraphy/metadata/metadata.py
+++ b/src/stratigraphy/metadata/metadata.py
@@ -37,37 +37,29 @@
     filename: Path = None
     page_dimensions: list[PageDimensions] = None
 
-    def __init__(self, document: fitz.Document):
+    def __init__(
+        self,
+        language: str | None = None,
+        elevation: Elevation | None = None,
+        coordinates: Coordinate | None = None,
+        page_dimensions: list[PageDimensions] | None = None,
+        filename: Path | None = None,
+    ):
         """Initializes the BoreholeMetadata object.
 
         Args:
-            document (fitz.Document): A PDF document.
+            language (str | None): The language of the document.
+            elevation (Elevation | None): The elevation information.
+            coordinates (Coordinate | None): The coordinates of the borehole.
+            page_dimensions (list[PageDimensions] | None): The dimensions of the pages in the document.
+            filename (Path | None): The name of the file.
         """
-        matching_params = read_params("matching_params.yml")
-
-        # Detect the language of the document
-        self.language = detect_language_of_document(
-            document, matching_params["default_language"], matching_params["material_description"].keys()
-        )
-
-        # Extract the coordinates of the borehole
-        coordinate_extractor = CoordinateExtractor(document=document)
-        self.coordinates = coordinate_extractor.extract_coordinates()
-
-        # Extract the elevation information
-        elevation_extractor = ElevationExtractor(document=document)
-        self.elevation = elevation_extractor.extract_elevation()
-
-        # Get the name of the document
-        self.filename = Path(document.name)
-
-        # Get the dimensions of the document's pages
-        self.page_dimensions = []
-        for page in document:
-            self.page_dimensions.append(PageDimensions(width=page.rect.width, height=page.rect.height))
-
-        # Sanity check
-        assert len(self.page_dimensions) == document.page_count, "Page count mismatch."
+        self.language = language
+        self.elevation = elevation
+        self.coordinates = coordinates
+        self.page_dimensions = page_dimensions
+        self.filename = filename
 
     def to_json(self) -> dict:
         """Converts the object to a dictionary.
@@ -96,17 +88,101 @@
             f"page_dimensions={self.page_dimensions})"
         )
 
+    @classmethod
+    def from_document(cls, document: fitz.Document) -> "BoreholeMetadata":
+        """Create a BoreholeMetadata object from a document.
+
+        Args:
+            document (fitz.Document): The document.
+
+        Returns:
+            BoreholeMetadata: The metadata object. 
+ """ + matching_params = read_params("matching_params.yml") + + # Detect the language of the document + language = detect_language_of_document( + document, matching_params["default_language"], matching_params["material_description"].keys() + ) + + # Extract the coordinates of the borehole + coordinate_extractor = CoordinateExtractor(document=document) + coordinates = coordinate_extractor.extract_coordinates() + + # Extract the elevation information + elevation_extractor = ElevationExtractor(document=document) + elevation = elevation_extractor.extract_elevation() + + # Get the name of the document + filename = Path(document.name) + + # Get the dimensions of the document's pages + page_dimensions = [] + for page in document: + page_dimensions.append(PageDimensions(width=page.rect.width, height=page.rect.height)) + + # Sanity check + assert len(page_dimensions) == document.page_count, "Page count mismatch." + + return cls( + language=language, + elevation=elevation, + coordinates=coordinates, + filename=filename, + page_dimensions=page_dimensions, + ) + + @classmethod + def from_json(cls, json_metadata: dict, filename: str) -> "BoreholeMetadata": + """Converts a dictionary to an object. + + Args: + json_metadata (dict): A dictionary representing the metadata. + filename (str): The name of the file. + + Returns: + BoreholeMetadata: The metadata object. + """ + elevation = Elevation.from_json(json_metadata["elevation"]) if json_metadata["elevation"] is not None else None + coordinates = ( + Coordinate.from_json(json_metadata["coordinates"]) if json_metadata["coordinates"] is not None else None + ) + language = json_metadata["language"] + page_dimensions = [ + PageDimensions(width=page["width"], height=page["height"]) for page in json_metadata["page_dimensions"] + ] + + return cls( + elevation=elevation, + coordinates=coordinates, + language=language, + page_dimensions=page_dimensions, + filename=Path(filename), + ) + @dataclass class BoreholeMetadataList(metaclass=abc.ABCMeta): - """Metadata for stratigraphy data.""" + """Metadata for stratigraphy data. + + This class is a list of BoreholeMetadata objects. Each object corresponds to a + single file. + """ metadata_per_file: list[BoreholeMetadata] = None def __init__(self): - """Initializes the StratigraphyMetadata object.""" + """Initializes the BoreholeMetadataList object.""" self.metadata_per_file = [] + def add_metadata(self, metadata: BoreholeMetadata) -> None: + """Add metadata to the list. + + Args: + metadata (BoreholeMetadata): The metadata to add. + """ + self.metadata_per_file.append(metadata) + def get_metadata(self, filename: str) -> BoreholeMetadata: """Get the metadata for a specific file. 
diff --git a/src/stratigraphy/util/predictions.py b/src/stratigraphy/util/predictions.py index 005a65da..c9ae5b21 100644 --- a/src/stratigraphy/util/predictions.py +++ b/src/stratigraphy/util/predictions.py @@ -2,18 +2,16 @@ import logging from collections import Counter +from pathlib import Path -import fitz import Levenshtein -from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry -from stratigraphy.evaluation.evaluation_dataclasses import Metrics +from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs +from stratigraphy.evaluation.evaluation_dataclasses import Metrics, OverallBoreholeMetadataMetrics +from stratigraphy.evaluation.metadata_evaluator import MetadataEvaluator from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformation, GroundwaterInformationOnPage from stratigraphy.layer.layer import LayerPrediction -from stratigraphy.lines.line import TextLine, TextWord -from stratigraphy.metadata.metadata import BoreholeMetadata -from stratigraphy.text.textblock import TextBlock -from stratigraphy.util.interval import BoundaryInterval +from stratigraphy.metadata.metadata import BoreholeMetadata, BoreholeMetadataList from stratigraphy.util.util import parse_text logger = logging.getLogger(__name__) @@ -24,91 +22,19 @@ class FilePredictions: def __init__( self, - layers: list[LayerPrediction], + layers: list[LayerPrediction] | None, file_name: str, - language: str, metadata: BoreholeMetadata, - groundwater_entries: list[GroundwaterInformationOnPage], - depths_materials_columns_pairs: list[dict], - page_sizes: list[dict[str, float]], + groundwater_entries: list[GroundwaterInformationOnPage] | None, + depths_materials_columns_pairs: list[DepthsMaterialsColumnPairs] | None, ): - self.layers: list[LayerPrediction] = layers - self.depths_materials_columns_pairs: list[dict] = depths_materials_columns_pairs - self.file_name = file_name - self.language = language - self.metadata = metadata - self.page_sizes: list[dict[str, float]] = page_sizes - self.groundwater_entries = groundwater_entries + self.layers: list[LayerPrediction] | None = layers + self.depths_materials_columns_pairs: list[DepthsMaterialsColumnPairs] | None = depths_materials_columns_pairs + self.file_name: str = file_name + self.metadata: BoreholeMetadata = metadata + self.groundwater_entries: list[GroundwaterInformationOnPage] | None = groundwater_entries self.groundwater_is_correct: dict = {} - @staticmethod - def create_from_json(predictions_for_file: dict, metadata: BoreholeMetadata, file_name: str): - """Create predictions class for a file given the predictions.json file. - - Args: - predictions_for_file (dict): The predictions for the file in json format. - metadata (BoreholeMetadata): The metadata for the file. - file_name (str): The name of the file. - """ - page_layer_predictions_list: list[LayerPrediction] = [] - pages_dimensions_list: list[dict[str, float]] = [] - depths_materials_columns_pairs_list: list[dict] = [] - - # Extract groundwater information if available. - if "groundwater" in predictions_for_file and predictions_for_file["groundwater"] is not None: - groundwater_entries = [ - GroundwaterInformationOnPage.from_json_values(**entry) for entry in predictions_for_file["groundwater"] - ] - else: - groundwater_entries = [] - - # Extract the layer predictions. 
-        for layer in predictions_for_file["layers"]:
-            material_prediction = _create_textblock_object(layer["material_description"]["lines"])
-            if "depth_interval" in layer:
-                start = (
-                    DepthColumnEntry(
-                        value=layer["depth_interval"]["start"]["value"],
-                        rect=fitz.Rect(layer["depth_interval"]["start"]["rect"]),
-                        page_number=layer["depth_interval"]["start"]["page"],
-                    )
-                    if layer["depth_interval"]["start"] is not None
-                    else None
-                )
-                end = (
-                    DepthColumnEntry(
-                        value=layer["depth_interval"]["end"]["value"],
-                        rect=fitz.Rect(layer["depth_interval"]["end"]["rect"]),
-                        page_number=layer["depth_interval"]["end"]["page"],
-                    )
-                    if layer["depth_interval"]["end"] is not None
-                    else None
-                )
-
-                depth_interval_prediction = BoundaryInterval(start=start, end=end)
-                layer_predictions = LayerPrediction(
-                    material_description=material_prediction, depth_interval=depth_interval_prediction
-                )
-            else:
-                layer_predictions = LayerPrediction(material_description=material_prediction, depth_interval=None)
-
-            page_layer_predictions_list.append(layer_predictions)
-
-        if "depths_materials_column_pairs" in predictions_for_file:
-            depths_materials_columns_pairs_list.extend(predictions_for_file["depths_materials_column_pairs"])
-
-        pages_dimensions_list.extend(predictions_for_file["page_dimensions"])
-
-        return FilePredictions(
-            layers=page_layer_predictions_list,
-            file_name=file_name,
-            language=metadata.language,
-            metadata=metadata,
-            depths_materials_columns_pairs=depths_materials_columns_pairs_list,
-            page_sizes=pages_dimensions_list,
-            groundwater_entries=groundwater_entries,
-        )
-
     def convert_to_ground_truth(self):
         """Convert the predictions to ground truth format.
 
@@ -276,7 +202,94 @@ def _find_matching_layer(
             unmatched_layers.remove(match)
         return match, False
 
+    def to_json(self) -> dict:
+        """Converts the object to a dictionary.
+
+        Returns:
+            dict: The object as a dictionary.
+        """
+        return {
+            "metadata": self.metadata.to_json(),
+            "layers": [layer.to_json() for layer in self.layers] if self.layers is not None else [],
+            "depths_materials_column_pairs": [dmc_pair.to_json() for dmc_pair in self.depths_materials_columns_pairs]
+            if self.depths_materials_columns_pairs is not None
+            else [],
+            "page_dimensions": self.metadata.page_dimensions,  # TODO: Remove, already in metadata
+            "groundwater": [entry.to_json() for entry in self.groundwater_entries]
+            if self.groundwater_entries is not None
+            else [],
+            "file_name": self.file_name,
+        }
+
+
+class OverallFilePredictions:
+    """A class to represent predictions for all files."""
+
+    def __init__(self):
+        """Initializes the OverallFilePredictions object."""
+        self.file_predictions_list: list[FilePredictions] = []
+
+    def add_file_predictions(self, file_predictions: FilePredictions):
+        """Add file predictions to the list of file predictions.
+
+        Args:
+            file_predictions (FilePredictions): The file predictions to add.
+        """
+        self.file_predictions_list.append(file_predictions)
+
+    def get_metadata_as_dict(self):
+        """Returns the metadata of the predictions as a dictionary."""
+        return {
+            file_prediction.file_name: file_prediction.metadata.to_json()
+            for file_prediction in self.file_predictions_list
+        }
+
+    def to_json(self) -> dict:
+        """Converts the object to a dictionary by merging individual file predictions.
+
+        Returns:
+            dict: A dictionary representation of the object.
+ """ + return {fp.file_name: fp.to_json() for fp in self.file_predictions_list} + + @classmethod + def from_json(cls, prediction_from_file: dict) -> "OverallFilePredictions": + """Converts a dictionary to an object. + + Args: + prediction_from_file (dict): A dictionary representing the predictions. + + Returns: + OverallFilePredictions: The object. + """ + overall_file_predictions = OverallFilePredictions() + for file_name, file_data in prediction_from_file.items(): + metadata = BoreholeMetadata.from_json(file_data["metadata"], file_name) + layers = LayerPrediction.from_json(file_data["layers"]) + depths_materials_columns_pairs = [ + DepthsMaterialsColumnPairs.from_json(dmc_pair) + for dmc_pair in file_data["depths_materials_column_pairs"] + ] + groundwater_entries = [GroundwaterInformationOnPage.from_json(entry) for entry in file_data["groundwater"]] + overall_file_predictions.add_file_predictions( + FilePredictions( + layers=layers, + file_name=file_name, + metadata=metadata, + depths_materials_columns_pairs=depths_materials_columns_pairs, + groundwater_entries=groundwater_entries, + ) + ) + return overall_file_predictions + + def evaluate_metadata_extraction(self, ground_truth_path: Path) -> OverallBoreholeMetadataMetrics: + """Evaluate the metadata extraction of the predictions against the ground truth. + + Args: + ground_truth_path (Path): The path to the ground truth file. + """ + metadata_per_file: BoreholeMetadataList = BoreholeMetadataList() -def _create_textblock_object(lines: dict) -> TextBlock: - lines = [TextLine([TextWord(**line)]) for line in lines] - return TextBlock(lines) + for file_prediction in self.file_predictions_list: + metadata_per_file.add_metadata(file_prediction.metadata) + return MetadataEvaluator(metadata_per_file, ground_truth_path).evaluate() diff --git a/tests/test_predictions.py b/tests/test_predictions.py new file mode 100644 index 00000000..3813558e --- /dev/null +++ b/tests/test_predictions.py @@ -0,0 +1,110 @@ +"""Test suite for the prediction module.""" + +from pathlib import Path +from unittest.mock import Mock + +import fitz +import pytest +from stratigraphy.metadata.coordinate_extraction import CoordinateEntry, LV95Coordinate +from stratigraphy.metadata.metadata import BoreholeMetadata +from stratigraphy.util.predictions import FilePredictions, OverallFilePredictions + +# Mock classes used in the FilePredictions constructor +LayerPrediction = Mock() +GroundwaterInformationOnPage = Mock() +DepthsMaterialsColumnPairs = Mock() + + +@pytest.fixture +def sample_file_prediction(): + """Fixture to create a sample FilePredictions object.""" + coord = LV95Coordinate( + east=CoordinateEntry(coordinate_value=2789456), + north=CoordinateEntry(coordinate_value=1123012), + rect=fitz.Rect(), + page=1, + ) + + layer1 = Mock( + material_description=Mock(text="Sand"), depth_interval=Mock(start=Mock(value=10), end=Mock(value=20)) + ) + layer2 = Mock( + material_description=Mock(text="Clay"), depth_interval=Mock(start=Mock(value=30), end=Mock(value=50)) + ) + metadata = BoreholeMetadata(coordinates=coord, page_dimensions=[Mock(width=10, height=20)], language="en") + + return FilePredictions( + layers=[layer1, layer2], + file_name="test_file", + metadata=metadata, + groundwater_entries=None, + depths_materials_columns_pairs=None, + ) + + +def test_convert_to_ground_truth(sample_file_prediction): + """Test the convert_to_ground_truth method.""" + ground_truth = sample_file_prediction.convert_to_ground_truth() + + assert 
ground_truth["test_file"]["metadata"]["coordinates"]["E"] == 2789456 + assert ground_truth["test_file"]["metadata"]["coordinates"]["N"] == 1123012 + assert len(ground_truth["test_file"]["layers"]) == 2 + assert ground_truth["test_file"]["layers"][0]["material_description"] == "Sand" + + +def test_to_json(sample_file_prediction): + """Test the to_json method.""" + result = sample_file_prediction.to_json() + + assert isinstance(result, dict) + assert result["file_name"] == "test_file" + assert len(result["layers"]) == 2 + assert result["metadata"]["coordinates"]["E"] == 2789456 + + +def test_count_against_ground_truth(): + """Test the count_against_ground_truth static method.""" + values = [1, 2, 2, 3] + ground_truth = [2, 3, 4] + metrics = FilePredictions.count_against_ground_truth(values, ground_truth) + + assert metrics.tp == 2 + assert metrics.fp == 2 + assert metrics.fn == 1 + + +def test_overall_file_predictions(): + """Test OverallFilePredictions class functionality.""" + overall_predictions = OverallFilePredictions() + file_prediction = Mock(to_json=lambda: {"some_data": "test"}, file_name="test_file") + + overall_predictions.add_file_predictions(file_prediction) + result = overall_predictions.to_json() + + assert len(result) == 1 + assert result == {"test_file": {"some_data": "test"}} + + +def test_evaluate_groundwater(sample_file_prediction): + """Test the evaluate_groundwater method.""" + sample_file_prediction.groundwater_entries = [ + Mock(groundwater=Mock(depth=100, format_date=lambda: "2024-10-01", elevation=20)) + ] + groundwater_gt = [{"depth": 100, "date": "2024-10-01", "elevation": 20}] + + sample_file_prediction.evaluate_groundwater(groundwater_gt) + + assert sample_file_prediction.groundwater_is_correct["groundwater"].tp == 1 + assert sample_file_prediction.groundwater_is_correct["groundwater_depth"].tp == 1 + + +def test_evaluate_metadata_extraction(): + """Test evaluate_metadata_extraction method of OverallFilePredictions.""" + overall_predictions = OverallFilePredictions() + file_prediction = Mock(metadata=Mock(to_json=lambda: {"coordinates": "some_coordinates"})) + overall_predictions.add_file_predictions(file_prediction) + + ground_truth_path = Path("example/example_groundtruth.json") + metadata_metrics = overall_predictions.evaluate_metadata_extraction(ground_truth_path) + + assert metadata_metrics is not None # Ensure the evaluation returns a result