diff --git a/src/scripts/label_studio_annotation_to_ground_truth.py b/src/scripts/label_studio_annotation_to_ground_truth.py deleted file mode 100644 index 81d675f6..00000000 --- a/src/scripts/label_studio_annotation_to_ground_truth.py +++ /dev/null @@ -1,217 +0,0 @@ -"""Script to convert annotations from label studio to a ground truth file.""" - -import contextlib -import json -import logging -from collections import defaultdict -from pathlib import Path -from typing import Any - -import click -import fitz -from stratigraphy.layer.layer import LayerPrediction -from stratigraphy.metadata.coordinate_extraction import Coordinate -from stratigraphy.text.textblock import MaterialDescription -from stratigraphy.util.interval import AnnotatedInterval -from stratigraphy.util.predictions import BoreholeMetaData, FilePredictions - -logger = logging.getLogger(__name__) - - -@click.command() -@click.option("-a", "--annotation-file-path", type=click.Path(path_type=Path), help="The path to the annotation file.") -@click.option("-o", "--output-path", type=click.Path(path_type=Path), help="The output path of the ground truth file.") -def convert_annotations_to_ground_truth(annotation_file_path: Path, output_path: Path): - """Convert the annotation file to the ground truth format. - - Args: - annotation_file_path (Path): The path to the annotation file. - output_path (Path): The output path of the ground truth file. - """ - with open(annotation_file_path) as f: - annotations = json.load(f) - - file_predictions = create_from_label_studio(annotations) - - ground_truth = {} - for prediction in file_predictions: - ground_truth = {**ground_truth, **prediction.convert_to_ground_truth()} - - # check if the output path exists - if not output_path.parent.exists(): - output_path.parent.mkdir(parents=True) - - with open(output_path, "w") as f: - json.dump(ground_truth, f, indent=4) - - -def create_from_label_studio(annotation_results: dict) -> list[FilePredictions]: - """Create predictions class for a file given the annotation results from Label Studio. - - This method is meant to import annotations from label studio. The primary use case is to - use the annotated data for evaluation. For that purpose, there is the convert_to_ground_truth - method, which then converts the predictions to ground truth format. - - NOTE: We may want to adjust this method to return a single instance of the class, - instead of a list of class objects. - - Args: - annotation_results (dict): The annotation results from Label Studio. - The annotation_results can cover multiple files. - - Returns: - list[FilePredictions]: A list of FilePredictions objects, one for each file present in the - annotation_results. - """ - file_predictions = defaultdict(list) - metadata = {} - for annotation in annotation_results: - # get page level information - file_name, _ = _get_file_name_and_page_index(annotation) - page_width = annotation["annotations"][0]["result"][0]["original_width"] - page_height = annotation["annotations"][0]["result"][0]["original_height"] - - # extract all material descriptions and depth intervals and link them together - # Note: we need to loop through the annotations twice, because the order of the annotations is - # not guaranteed. In the first iteration we grasp all IDs, in the second iteration we extract the - # information for each id. - material_descriptions = {} - depth_intervals = {} - coordinates = {} - linking_objects = [] - - # define all the material descriptions and depth intervals with their ids - for annotation_result in annotation["annotations"][0]["result"]: - if annotation_result["type"] == "labels": - if annotation_result["value"]["labels"] == ["Material Description"]: - material_descriptions[annotation_result["id"]] = { - "rect": annotation_result["value"] - } # TODO extract rectangle properly; does not impact the ground truth though. - elif annotation_result["value"]["labels"] == ["Depth Interval"]: - depth_intervals[annotation_result["id"]] = {} - elif annotation_result["value"]["labels"] == ["Coordinates"]: - coordinates[annotation_result["id"]] = {} - if annotation_result["type"] == "relation": - linking_objects.append({"from_id": annotation_result["from_id"], "to_id": annotation_result["to_id"]}) - - # check annotation results for material description or depth interval ids - for annotation_result in annotation["annotations"][0]["result"]: - with contextlib.suppress(KeyError): - id = annotation_result["id"] # relation regions do not have an ID. - if annotation_result["type"] == "textarea": - if id in material_descriptions: - material_descriptions[id]["text"] = annotation_result["value"]["text"][ - 0 - ] # There is always only one element. TO CHECK! - if len(annotation_result["value"]["text"]) > 1: - print(f"More than one text in material description: {annotation_result['value']['text']}") - elif id in depth_intervals: - depth_interval_text = annotation_result["value"]["text"][0] - start, end = _get_start_end_from_text(depth_interval_text) - depth_intervals[id]["start"] = start - depth_intervals[id]["end"] = end - depth_intervals[id]["background_rect"] = annotation_result[ - "value" - ] # TODO extract rectangle properly; does not impact the ground truth though. - elif id in coordinates: - coordinates[id]["text"] = annotation_result["value"]["text"][0] - else: - print(f"Unknown id: {id}") - - # create the layer prediction objects by linking material descriptions with depth intervals - layers = [] - - for link in linking_objects: - from_id = link["from_id"] - to_id = link["to_id"] - material_description_prediction = MaterialDescription(**material_descriptions.pop(from_id)) - depth_interval_prediction = AnnotatedInterval(**depth_intervals.pop(to_id)) - layers.append( - LayerPrediction( - material_description=material_description_prediction, - depth_interval=depth_interval_prediction, - material_is_correct=True, - depth_interval_is_correct=True, - ) - ) - - if material_descriptions or depth_intervals: - # TODO: This should not be acceptable. Raising an error doesnt seem the right way to go either. - # But at least it should be warned. - print("There are material descriptions or depth intervals left over.") - print(material_descriptions) - print(depth_intervals) - - # instantiate metadata object - if coordinates: - coordinate_text = coordinates.popitem()[1]["text"] - # TODO: we could extract the rectangle as well. For conversion to ground truth this does not matter. - metadata[file_name] = BoreholeMetaData(coordinates=_get_coordinates_from_text(coordinate_text)) - - # create the page prediction object - if file_name in file_predictions: - # append the page predictions to the existing file predictions - file_predictions[file_name].layers.extend(layers) - file_predictions[file_name].page_sizes.append({"width": page_width, "height": page_height}) - else: - # create a new file prediction object if it does not exist yet - file_predictions[file_name] = FilePredictions( - layers=layers, - file_name=f"{file_name}.pdf", - language="unknown", - metadata=metadata.get(file_name), - groundwater_entries=[], - depths_materials_columns_pairs=[], - page_sizes=[{"width": page_width, "height": page_height}], - ) - - file_predictions_list = [] - for _, file_prediction in file_predictions.items(): - file_predictions_list.append(file_prediction) # TODO: language should not be required here. - - return file_predictions_list - - -def _get_coordinates_from_text(text: str) -> Coordinate | None: - """Convert a string to a Coordinate object. - - The string has the format: E: 498'561, N: 114'332 or E: 2'498'561, N: 1'114'332. - - Args: - text (str): The input string to be converted to a Coordinate object. - - Returns: - Coordinate: The Coordinate object. - """ - try: - east_text, north_text = text.split(", ") - east = int(east_text.split(": ")[1].replace("'", "")) - north = int(north_text.split(": ")[1].replace("'", "")) - return Coordinate.from_values(east=east, north=north, page=0, rect=fitz.Rect([0, 0, 0, 0])) - except ValueError: # This is likely due to a wrong format of the text. - logger.warning(f"Could not extract coordinates from text: {text}.") - return None - - -def _get_start_end_from_text(text: str) -> tuple[float, float]: - start, end = text.split("end: ") - start = start.split("start: ")[1] - return float(start), float(end) - - -def _get_file_name_and_page_index(annotation: dict[str, Any]) -> tuple[str, int]: - """Extract the file name and page index from the annotation. - - Args: - annotation (dict): The annotation dictionary. Exported from Label Studio. - - Returns: - tuple[str, int]: The file name and the page index (zero-based). - """ - file_name = annotation["data"]["ocr"].split("/")[-1] - file_name = file_name.split(".")[0] - return file_name.split("_") - - -if __name__ == "__main__": - convert_annotations_to_ground_truth() diff --git a/src/stratigraphy/annotations/draw.py b/src/stratigraphy/annotations/draw.py index 6d1c60b1..e387d1e5 100644 --- a/src/stratigraphy/annotations/draw.py +++ b/src/stratigraphy/annotations/draw.py @@ -14,8 +14,6 @@ from stratigraphy.layer.layer import Layer from stratigraphy.metadata.coordinate_extraction import Coordinate from stratigraphy.metadata.elevation_extraction import Elevation -from stratigraphy.text.textblock import TextBlock -from stratigraphy.util.interval import BoundaryInterval from stratigraphy.util.predictions import OverallFilePredictions load_dotenv() @@ -31,7 +29,7 @@ def draw_predictions( predictions: OverallFilePredictions, directory: Path, out_directory: Path, - document_level_metadata_metrics: pd.DataFrame, + document_level_metadata_metrics: None | pd.DataFrame, ) -> None: """Draw predictions on pdf pages. @@ -50,7 +48,7 @@ def draw_predictions( predictions (dict): Content of the predictions.json file. directory (Path): Path to the directory containing the pdf files. out_directory (Path): Path to the output directory where the images are saved. - document_level_metadata_metrics (pd.DataFrame): Document level metadata metrics. + document_level_metadata_metrics (None | pd.DataFrame): Document level metadata metrics. """ if directory.is_file(): # deal with the case when we pass a file instead of a directory directory = directory.parent @@ -62,15 +60,18 @@ def draw_predictions( elevation = file_prediction.metadata.elevation # Assess the correctness of the metadata - if file_prediction.file_name in document_level_metadata_metrics.index: + if ( + document_level_metadata_metrics is not None + and file_prediction.file_name in document_level_metadata_metrics.index + ): is_coordinates_correct = document_level_metadata_metrics.loc[file_prediction.file_name].coordinate is_elevation_correct = document_level_metadata_metrics.loc[file_prediction.file_name].elevation else: logger.warning( "Metrics for file %s not found in document_level_metadata_metrics.", file_prediction.file_name ) - is_coordinates_correct = False - is_elevation_correct = False + is_coordinates_correct = None + is_elevation_correct = None try: with fitz.Document(directory / file_prediction.file_name) as doc: @@ -104,8 +105,8 @@ def draw_predictions( page.derotation_matrix, [ layer - for layer in file_prediction.layers.get_all_layers() - if layer.material_description.page_number == page_number + for layer in file_prediction.layers_in_document.layers + if layer.material_description.page == page_number ], ) shape.commit() # Commit all the drawing operations to the page @@ -133,9 +134,9 @@ def draw_metadata( derotation_matrix: fitz.Matrix, rotation: float, coordinates: Coordinate | None, - is_coordinate_correct: bool, + is_coordinate_correct: bool | None, elevation_info: Elevation | None, - is_elevation_correct: bool, + is_elevation_correct: bool | None, ) -> None: """Draw the extracted metadata on the top of the given PDF page. @@ -147,44 +148,45 @@ def draw_metadata( derotation_matrix (fitz.Matrix): The derotation matrix of the page. rotation (float): The rotation of the page. coordinates (Coordinate | None): The coordinate object to draw. - is_coordinate_correct (Metrics): Whether the coordinate information is correct. - elevation_info (ElevationInformation | None): The elevation information to draw. - is_elevation_correct (Metrics): Whether the elevation information is correct. + is_coordinate_correct (bool | None): Whether the coordinate information is correct. + elevation_info (Elevation | None): The elevation information to draw. + is_elevation_correct (bool | None): Whether the elevation information is correct. """ - # TODO associate correctness with the extracted coordinates in a better way - coordinate_color = "green" if is_coordinate_correct else "red" coordinate_rect = fitz.Rect([5, 5, 250, 30]) - - elevation_color = "green" if is_elevation_correct else "red" elevation_rect = fitz.Rect([5, 30, 250, 55]) shape.draw_rect(coordinate_rect * derotation_matrix) shape.finish(fill=fitz.utils.getColor("gray"), fill_opacity=0.5) shape.insert_textbox(coordinate_rect * derotation_matrix, f"Coordinates: {coordinates}", rotate=rotation) - shape.draw_line( - coordinate_rect.top_left * derotation_matrix, - coordinate_rect.bottom_left * derotation_matrix, - ) - shape.finish( - color=fitz.utils.getColor(coordinate_color), - width=6, - stroke_opacity=0.5, - ) + if is_coordinate_correct is not None: + # TODO associate correctness with the extracted coordinates in a better way + coordinate_color = "green" if is_coordinate_correct else "red" + shape.draw_line( + coordinate_rect.top_left * derotation_matrix, + coordinate_rect.bottom_left * derotation_matrix, + ) + shape.finish( + color=fitz.utils.getColor(coordinate_color), + width=6, + stroke_opacity=0.5, + ) # Draw the bounding box around the elevation information elevation_txt = f"Elevation: {elevation_info.elevation} m" if elevation_info is not None else "Elevation: N/A" shape.draw_rect(elevation_rect * derotation_matrix) shape.finish(fill=fitz.utils.getColor("gray"), fill_opacity=0.5) shape.insert_textbox(elevation_rect * derotation_matrix, elevation_txt, rotate=rotation) - shape.draw_line( - elevation_rect.top_left * derotation_matrix, - elevation_rect.bottom_left * derotation_matrix, - ) - shape.finish( - color=fitz.utils.getColor(elevation_color), - width=6, - stroke_opacity=0.5, - ) + if is_elevation_correct is not None: + elevation_color = "green" if is_elevation_correct else "red" + shape.draw_line( + elevation_rect.top_left * derotation_matrix, + elevation_rect.bottom_left * derotation_matrix, + ) + shape.finish( + color=fitz.utils.getColor(elevation_color), + width=6, + stroke_opacity=0.5, + ) def draw_coordinates(shape: fitz.Shape, coordinates: Coordinate) -> None: @@ -239,15 +241,7 @@ def draw_material_descriptions(shape: fitz.Shape, derotation_matrix: fitz.Matrix fitz.Rect(layer.material_description.rect) * derotation_matrix, ) shape.finish(color=fitz.utils.getColor("orange")) - draw_layer( - shape=shape, - derotation_matrix=derotation_matrix, - interval=layer.depth_interval, # None if no depth interval - layer=layer.material_description, - index=index, - is_correct=layer.material_is_correct, # None if no ground truth - depth_is_correct=layer.depth_interval_is_correct, # None if no ground truth - ) + draw_layer(shape=shape, derotation_matrix=derotation_matrix, layer=layer, index=index) def draw_depth_columns_and_material_rect( @@ -286,15 +280,7 @@ def draw_depth_columns_and_material_rect( shape.finish(color=fitz.utils.getColor("red")) -def draw_layer( - shape: fitz.Shape, - derotation_matrix: fitz.Matrix, - interval: BoundaryInterval | None, - layer: TextBlock, - index: int, - is_correct: bool, - depth_is_correct: bool, -): +def draw_layer(shape: fitz.Shape, derotation_matrix: fitz.Matrix, layer: Layer, index: int): """Draw layers on a pdf page. In particular, this function: @@ -304,18 +290,15 @@ def draw_layer( Args: shape (fitz.Shape): The shape object for drawing. derotation_matrix (fitz.Matrix): The derotation matrix of the page. - interval (BoundaryInterval | None): Depth interval for the layer. - layer (MaterialDescriptionPrediction): Material description block for the layer. + layer (Layer): The layer (depth interval and material description). index (int): Index of the layer. - is_correct (bool): Whether the text block was correctly identified. - depth_is_correct (bool): Whether the depth interval was correctly identified. """ - if layer.lines: - layer_rect = fitz.Rect(layer.rect) + material_description = layer.material_description.feature + if material_description.lines: color = colors[index % len(colors)] # background color for material description - for line in [line for line in layer.lines]: + for line in [line for line in material_description.lines]: shape.draw_rect(line.rect * derotation_matrix) shape.finish( color=fitz.utils.getColor(color), @@ -323,8 +306,8 @@ def draw_layer( fill=fitz.utils.getColor(color), width=0, ) - if is_correct is not None: - correct_color = "green" if is_correct else "red" + if material_description.is_correct is not None: + correct_color = "green" if material_description.is_correct else "red" shape.draw_line( line.rect.top_left * derotation_matrix, line.rect.bottom_left * derotation_matrix, @@ -335,9 +318,9 @@ def draw_layer( stroke_opacity=0.5, ) - if interval: + if layer.depth_interval: # background color for depth interval - background_rect = interval.background_rect + background_rect = layer.depth_interval.background_rect if background_rect is not None: shape.draw_rect( background_rect * derotation_matrix, @@ -350,8 +333,8 @@ def draw_layer( ) # draw green line if depth interval is correct else red line - if depth_is_correct is not None: - depth_is_correct_color = "green" if depth_is_correct else "red" + if layer.is_correct is not None: + depth_is_correct_color = "green" if layer.is_correct else "red" shape.draw_line( background_rect.top_left * derotation_matrix, background_rect.bottom_left * derotation_matrix, @@ -363,11 +346,12 @@ def draw_layer( ) # line from depth interval to material description - line_anchor = interval.line_anchor + line_anchor = layer.depth_interval.line_anchor if line_anchor: + rect = layer.material_description.rect shape.draw_line( line_anchor * derotation_matrix, - fitz.Point(layer_rect.x0, (layer_rect.y0 + layer_rect.y1) / 2) * derotation_matrix, + fitz.Point(rect.x0, (rect.y0 + rect.y1) / 2) * derotation_matrix, ) shape.finish( color=fitz.utils.getColor(color), diff --git a/src/stratigraphy/benchmark/score.py b/src/stratigraphy/benchmark/score.py index 35fcc7cd..ffee3eec 100644 --- a/src/stratigraphy/benchmark/score.py +++ b/src/stratigraphy/benchmark/score.py @@ -9,8 +9,7 @@ import pandas as pd from dotenv import load_dotenv from stratigraphy import DATAPATH -from stratigraphy.annotations.draw import draw_predictions -from stratigraphy.evaluation.evaluation_dataclasses import BoreholeMetadataMetrics +from stratigraphy.benchmark.ground_truth import GroundTruth from stratigraphy.util.predictions import OverallFilePredictions load_dotenv() @@ -21,29 +20,29 @@ def evaluate( - predictions: OverallFilePredictions, - ground_truth_path: Path, - temp_directory: Path, - input_directory: Path | None, - draw_directory: Path | None, -) -> None: + predictions: OverallFilePredictions, ground_truth_path: Path, temp_directory: Path +) -> None | pd.DataFrame: """Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled). Args: predictions (OverallFilePredictions): The predictions objects. - ground_truth_path (Path): The path to the ground truth file. + ground_truth_path (Path | None): The path to the ground truth file. temp_directory (Path): The path to the temporary directory. - input_directory (Path | None): The path to the input directory. - draw_directory (Path | None): The path to the draw directory. Returns: - None + None | pd.DataFrame: the document level metadata metrics """ + if not (ground_truth_path and ground_truth_path.exists()): # for inference no ground truth is available + logger.warning("Ground truth file not found. Skipping evaluation.") + return None + + ground_truth = GroundTruth(ground_truth_path) + ############################# # Evaluate the borehole extraction metadata ############################# - metadata_metrics_list = predictions.evaluate_metadata_extraction(ground_truth_path) - metadata_metrics: BoreholeMetadataMetrics = metadata_metrics_list.get_cumulated_metrics() + metadata_metrics_list = predictions.evaluate_metadata_extraction(ground_truth) + metadata_metrics = metadata_metrics_list.get_cumulated_metrics() document_level_metadata_metrics: pd.DataFrame = metadata_metrics_list.get_document_level_metrics() document_level_metadata_metrics.to_csv( temp_directory / "document_level_metadata_metrics.csv", index_label="document_name" @@ -62,7 +61,7 @@ def evaluate( ############################# # Evaluate the borehole extraction ############################# - metrics = predictions.evaluate_borehole_extraction(ground_truth_path) + metrics = predictions.evaluate_geology(ground_truth) metrics.document_level_metrics_df().to_csv( temp_directory / "document_level_metrics.csv", index_label="document_name" @@ -77,11 +76,7 @@ def evaluate( mlflow.log_metrics(metrics_dict) mlflow.log_artifact(temp_directory / "document_level_metrics.csv") - ############################# - # Draw the prediction - ############################# - if input_directory and draw_directory: - draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics) + return document_level_metadata_metrics def main(): @@ -110,8 +105,7 @@ def main(): predictions = OverallFilePredictions.from_json(predictions) - # Customize these as needed - evaluate(predictions, args.ground_truth_path, args.temp_directory, input_directory=None, draw_directory=None) + evaluate(predictions, args.ground_truth_path, args.temp_directory) def parse_cli() -> argparse.Namespace: diff --git a/src/stratigraphy/data_extractor/data_extractor.py b/src/stratigraphy/data_extractor/data_extractor.py index a1da287d..a7387bc6 100644 --- a/src/stratigraphy/data_extractor/data_extractor.py +++ b/src/stratigraphy/data_extractor/data_extractor.py @@ -17,17 +17,11 @@ logger = logging.getLogger(__name__) +@dataclass class ExtractedFeature(metaclass=ABCMeta): """Class for extracted feature information.""" - @abstractmethod - def is_valid(self) -> bool: - """Checks if the information is valid. - - Returns: - bool: True if the information is valid, otherwise False. - """ - pass + is_correct = None @abstractmethod def to_json(self) -> dict: diff --git a/src/stratigraphy/depthcolumn/depthcolumn.py b/src/stratigraphy/depthcolumn/depthcolumn.py index 5d43c69c..6e6eb97f 100644 --- a/src/stratigraphy/depthcolumn/depthcolumn.py +++ b/src/stratigraphy/depthcolumn/depthcolumn.py @@ -3,14 +3,15 @@ from __future__ import annotations import abc +from dataclasses import dataclass import fitz import numpy as np from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry, LayerDepthColumnEntry -from stratigraphy.layer.layer import IntervalBlockGroup from stratigraphy.layer.layer_identifier_column import LayerIdentifierColumn from stratigraphy.lines.line import TextLine, TextWord from stratigraphy.text.find_description import get_description_blocks +from stratigraphy.text.textblock import TextBlock from stratigraphy.util.dataclasses import Line from stratigraphy.util.interval import BoundaryInterval, Interval, LayerInterval @@ -67,7 +68,7 @@ def noise_count(self, all_words: list[TextWord]) -> int: @abc.abstractmethod def identify_groups( self, description_lines: list[TextLine], geometric_lines: list[Line], material_description_rect: fitz.Rect - ) -> list[dict]: + ) -> list[IntervalBlockGroup]: """Identifies groups of description blocks that correspond to depth intervals. Args: @@ -76,8 +77,7 @@ def identify_groups( material_description_rect (fitz.Rect): The bounding box of the material description. Returns: - list[dict]: A list of groups, where each group is a dictionary - with the keys "depth_intervals" and "blocks". + list[IntervalBlockGroup]: A list of groups, where each group is a IntervalBlockGroup. """ pass @@ -265,11 +265,7 @@ def identify_groups( matched_blocks = interval.matching_blocks(description_lines, line_index, next_interval) line_index += sum([len(block.lines) for block in matched_blocks]) - groups.append( - # TODO: This seems to be the only case where a list is passed and most of the time it is a list of one - # element. Seem to need the function: transform_groups(). - IntervalBlockGroup(depth_interval=[interval], block=matched_blocks) - ) + groups.append(IntervalBlockGroup(depth_intervals=[interval], blocks=matched_blocks)) return groups @@ -559,8 +555,8 @@ def identify_groups( current_blocks.extend(pre) if len(exact): if len(current_intervals) > 0 or len(current_blocks) > 0: - groups.append(IntervalBlockGroup(depth_interval=current_intervals, block=current_blocks)) - groups.append(IntervalBlockGroup(depth_interval=[interval], block=exact)) + groups.append(IntervalBlockGroup(depth_intervals=current_intervals, blocks=current_blocks)) + groups.append(IntervalBlockGroup(depth_intervals=[interval], blocks=exact)) current_blocks = post current_intervals = [] else: @@ -570,6 +566,18 @@ def identify_groups( current_intervals.append(interval) if len(current_intervals) > 0 or len(current_blocks) > 0: - groups.append(IntervalBlockGroup(depth_interval=current_intervals, block=current_blocks)) + groups.append(IntervalBlockGroup(depth_intervals=current_intervals, blocks=current_blocks)) return groups + + +@dataclass +class IntervalBlockGroup: + """Helper class to represent a group of depth intervals and an associated group of text blocks. + + The class is used to simplify the code for obtaining an appropriate one-to-one correspondence between depth + intervals and material descriptions. + """ + + depth_intervals: list[Interval] + blocks: list[TextBlock] diff --git a/src/stratigraphy/evaluation/evaluation_dataclasses.py b/src/stratigraphy/evaluation/evaluation_dataclasses.py index 6625dfe5..c61f181e 100644 --- a/src/stratigraphy/evaluation/evaluation_dataclasses.py +++ b/src/stratigraphy/evaluation/evaluation_dataclasses.py @@ -114,13 +114,13 @@ def get_document_level_metrics(self) -> pd.DataFrame: class OverallBoreholeMetadataMetrics(metaclass=abc.ABCMeta): """Metrics for borehole metadata.""" - borehole_metadata_metrics: list[FileBoreholeMetadataMetrics] = None + borehole_metadata_metrics: list[FileBoreholeMetadataMetrics] def __init__(self): """Initializes the OverallBoreholeMetadataMetrics object.""" self.borehole_metadata_metrics = [] - def get_cumulated_metrics(self) -> dict: + def get_cumulated_metrics(self) -> BoreholeMetadataMetrics: """Evaluate the metadata metrics.""" elevation_metrics = Metrics.micro_average( [metadata.elevation_metrics for metadata in self.borehole_metadata_metrics] diff --git a/src/stratigraphy/evaluation/groundwater_evaluator.py b/src/stratigraphy/evaluation/groundwater_evaluator.py index aa83ecaa..dcd0b344 100644 --- a/src/stratigraphy/evaluation/groundwater_evaluator.py +++ b/src/stratigraphy/evaluation/groundwater_evaluator.py @@ -52,15 +52,14 @@ def groundwater_depth_metrics_to_overall_metrics(self): class GroundwaterEvaluator: """Class for evaluating the extracted groundwater information of a borehole.""" - def __init__(self, groundwater_entries: list[GroundwaterInDocument], ground_truth_path: str): + def __init__(self, groundwater_entries: list[GroundwaterInDocument], ground_truth: GroundTruth): """Initializes the GroundwaterEvaluator object. Args: groundwater_entries (list[GroundwaterInDocument]): The metadata to evaluate. - ground_truth_path (str): The path to the ground truth file. + ground_truth (GroundTruth): The ground truth. """ - # Load the ground truth data for the metadata - self.groundwater_ground_truth = GroundTruth(ground_truth_path) + self.ground_truth = ground_truth self.groundwater_entries: list[GroundwaterInDocument] = groundwater_entries def evaluate(self) -> OverallGroundwaterMetrics: @@ -73,7 +72,7 @@ def evaluate(self) -> OverallGroundwaterMetrics: for groundwater_in_doc in self.groundwater_entries: filename = groundwater_in_doc.filename - ground_truth_data = self.groundwater_ground_truth.for_file(filename) + ground_truth_data = self.ground_truth.for_file(filename) if ground_truth_data is None or ground_truth_data.get("groundwater") is None: ground_truth = [] # If no ground truth is available, set it to an empty list else: diff --git a/src/stratigraphy/evaluation/layer_evaluator.py b/src/stratigraphy/evaluation/layer_evaluator.py new file mode 100644 index 00000000..f7e51d87 --- /dev/null +++ b/src/stratigraphy/evaluation/layer_evaluator.py @@ -0,0 +1,149 @@ +"""Classes for evaluating the groundwater levels of a borehole.""" + +import logging +from collections.abc import Callable + +import Levenshtein +from stratigraphy.benchmark.ground_truth import GroundTruth +from stratigraphy.benchmark.metrics import OverallMetrics +from stratigraphy.evaluation.evaluation_dataclasses import Metrics +from stratigraphy.evaluation.utility import _is_valid_depth_interval +from stratigraphy.layer.layer import Layer, LayersInDocument +from stratigraphy.util.util import parse_text + +logger = logging.getLogger(__name__) + +MATERIAL_DESCRIPTION_SIMILARITY_THRESHOLD = 0.9 + + +class LayerEvaluator: + """Class for evaluating the extracted groundwater information of a borehole.""" + + def __init__(self, layers_entries: list[LayersInDocument], ground_truth: GroundTruth): + """Initializes the LayerEvaluator object. + + Args: + layers_entries (list[LayersInDocument]): The layers to evaluate. + ground_truth (GroundTruth): The ground truth. + """ + self.ground_truth = ground_truth + self.layers_entries: list[LayersInDocument] = layers_entries + + def get_layer_metrics(self) -> OverallMetrics: + """Calculate metrics for layer predictions.""" + + def per_layer_action(layer): + if parse_text(layer.material_description.feature.text) == "": + logger.warning("Empty string found in predictions") + + return self.calculate_metrics( + per_layer_filter=lambda layer: True, + per_layer_condition=lambda layer: layer.material_description.feature.is_correct, + per_layer_action=per_layer_action, + ) + + def get_depth_interval_metrics(self) -> OverallMetrics: + """Calculate metrics for depth interval predictions.""" + return self.calculate_metrics( + per_layer_filter=lambda layer: layer.material_description.feature.is_correct + and layer.is_correct is not None, + per_layer_condition=lambda layer: layer.is_correct, + ) + + def calculate_metrics( + self, + per_layer_filter: Callable[[Layer], bool], + per_layer_condition: Callable[[Layer], bool], + per_layer_action: Callable[[Layer], None] | None = None, + ) -> OverallMetrics: + """Calculate metrics based on a condition per layer, after applying a filter. + + Args: + per_layer_filter (Callable[[LayerPrediction], bool]): Function to filter layers to consider. + per_layer_condition (Callable[[LayerPrediction], bool]): Function that returns True if the layer is a hit. + per_layer_action (Optional[Callable[[LayerPrediction], None]]): Optional action to perform per layer. + + Returns: + OverallMetrics: The calculated metrics. + """ + overall_metrics = OverallMetrics() + + for layers_in_document in self.layers_entries: + ground_truth_for_file = self.ground_truth.for_file(layers_in_document.filename) + number_of_truth_values = len(ground_truth_for_file["layers"]) + hits = 0 + total_predictions = 0 + + for layer in layers_in_document.layers: + if per_layer_action: + per_layer_action(layer) + if per_layer_filter(layer): + total_predictions += 1 + if per_layer_condition(layer): + hits += 1 + + fn = 0 + fn = number_of_truth_values - hits + + if total_predictions > 0: + overall_metrics.metrics[layers_in_document.filename] = Metrics( + tp=hits, + fp=total_predictions - hits, + fn=fn, + ) + + return overall_metrics + + @staticmethod + def evaluate_borehole(predicted_layers: list[Layer], ground_truth_layers: list): + """Evaluate all predicted layers for a borehole against the ground truth. + + Args: + predicted_layers (list[Layer]): The predicted layers for the borehole. + ground_truth_layers (list): The ground truth layers for the borehole. + """ + unmatched_layers = ground_truth_layers.copy() + for layer in predicted_layers: + match, depth_interval_is_correct = LayerEvaluator.find_matching_layer(layer, unmatched_layers) + if match: + layer.material_description.feature.is_correct = True + layer.is_correct = depth_interval_is_correct + else: + layer.material_description.feature.is_correct = False + layer.is_correct = None + + @staticmethod + def find_matching_layer(layer: Layer, unmatched_layers: list[dict]) -> tuple[dict, bool] | tuple[None, None]: + """Find the matching layer in the ground truth, if any, and remove it from the list of unmatched layers. + + Args: + layer (Layer): The layer to match. + unmatched_layers (list[dict]): The layers from the ground truth that were not yet matched during the + current evaluation. + + Returns: + tuple[dict, bool] | tuple[None, None]: The matching layer and a boolean indicating if the depth interval + is correct. None if no match was found. + """ + parsed_text = parse_text(layer.material_description.feature.text) + possible_matches = [ + ground_truth_layer + for ground_truth_layer in unmatched_layers + if Levenshtein.ratio(parsed_text, ground_truth_layer["material_description"]) + > MATERIAL_DESCRIPTION_SIMILARITY_THRESHOLD + ] + + if not possible_matches: + return None, None + + for possible_match in possible_matches: + start = possible_match["depth_interval"]["start"] + end = possible_match["depth_interval"]["end"] + + if _is_valid_depth_interval(layer.depth_interval, start, end): + unmatched_layers.remove(possible_match) + return possible_match, True + + match = max(possible_matches, key=lambda x: Levenshtein.ratio(parsed_text, x["material_description"])) + unmatched_layers.remove(match) + return match, False diff --git a/src/stratigraphy/evaluation/metadata_evaluator.py b/src/stratigraphy/evaluation/metadata_evaluator.py index a2d2c180..8cc738b9 100644 --- a/src/stratigraphy/evaluation/metadata_evaluator.py +++ b/src/stratigraphy/evaluation/metadata_evaluator.py @@ -1,8 +1,6 @@ """Classes for evaluating the metadata of a borehole.""" import math -from pathlib import Path -from typing import Any from stratigraphy.benchmark.ground_truth import GroundTruth from stratigraphy.evaluation.evaluation_dataclasses import ( @@ -16,21 +14,19 @@ class MetadataEvaluator: """Class for evaluating the metadata of a borehole.""" - metadata_list: OverallBoreholeMetadata = None - ground_truth: dict[str, Any] = None + metadata_list: OverallBoreholeMetadata + ground_truth: GroundTruth - def __init__(self, metadata_list: OverallBoreholeMetadata, ground_truth_path: Path) -> None: + def __init__(self, metadata_list: OverallBoreholeMetadata, ground_truth: GroundTruth) -> None: """Initializes the MetadataEvaluator object. Args: metadata_list (OverallBoreholeMetadata): Container for multiple borehole metadata objects to evaluate. Contains metadata_per_file for individual boreholes. - ground_truth_path (Path): The path to the ground truth file. + ground_truth (GroundTruth): The ground truth. """ self.metadata_list: OverallBoreholeMetadata = metadata_list - - # Load the ground truth data for the metadata - self.metadata_ground_truth = GroundTruth(ground_truth_path) + self.ground_truth = ground_truth def evaluate(self) -> OverallBoreholeMetadataMetrics: """Evaluate the metadata of the file against the ground truth.""" @@ -43,7 +39,7 @@ def evaluate(self) -> OverallBoreholeMetadataMetrics: ########################################################################################################### extracted_coordinates = metadata.coordinates ground_truth_coordinates = ( - self.metadata_ground_truth.for_file(metadata.filename.name).get("metadata", {}).get("coordinates") + self.ground_truth.for_file(metadata.filename.name).get("metadata", {}).get("coordinates") ) if extracted_coordinates and ground_truth_coordinates: @@ -83,9 +79,7 @@ def evaluate(self) -> OverallBoreholeMetadataMetrics: ############################################################################################################ extracted_elevation = None if metadata.elevation is None else metadata.elevation.elevation ground_truth_elevation = ( - self.metadata_ground_truth.for_file(metadata.filename.name) - .get("metadata", {}) - .get("reference_elevation") + self.ground_truth.for_file(metadata.filename.name).get("metadata", {}).get("reference_elevation") ) if extracted_elevation is not None and ground_truth_elevation is not None: diff --git a/src/stratigraphy/evaluation/utility.py b/src/stratigraphy/evaluation/utility.py index aa29a103..5862df59 100644 --- a/src/stratigraphy/evaluation/utility.py +++ b/src/stratigraphy/evaluation/utility.py @@ -2,13 +2,8 @@ from collections import Counter -import Levenshtein from stratigraphy.evaluation.evaluation_dataclasses import Metrics -from stratigraphy.layer.layer import Layer from stratigraphy.util.interval import Interval -from stratigraphy.util.util import parse_text - -MATERIAL_DESCRIPTION_SIMILARITY_THRESHOLD = 0.9 def count_against_ground_truth(values: list[str], ground_truth: list[str]) -> Metrics: @@ -55,39 +50,3 @@ def _is_valid_depth_interval(depth_interval: Interval, start: float, end: float) return start == depth_interval.start.value and end == depth_interval.end.value return False - - -def find_matching_layer(layer: Layer, unmatched_layers: list[dict]) -> tuple[dict, bool] | tuple[None, None]: - """Find the matching layer in the ground truth. - - Args: - layer (Layer): The layer to match. - unmatched_layers (list[dict]): The layers from the ground truth that were not yet matched during the - current evaluation. - - Returns: - tuple[dict, bool] | tuple[None, None]: The matching layer and a boolean indicating if the depth interval - is correct. None if no match was found. - """ - parsed_text = parse_text(layer.material_description.text) - possible_matches = [ - ground_truth_layer - for ground_truth_layer in unmatched_layers - if Levenshtein.ratio(parsed_text, ground_truth_layer["material_description"]) - > MATERIAL_DESCRIPTION_SIMILARITY_THRESHOLD - ] - - if not possible_matches: - return None, None - - for possible_match in possible_matches: - start = possible_match["depth_interval"]["start"] - end = possible_match["depth_interval"]["end"] - - if _is_valid_depth_interval(layer.depth_interval, start, end): - unmatched_layers.remove(possible_match) - return possible_match, True - - match = max(possible_matches, key=lambda x: Levenshtein.ratio(parsed_text, x["material_description"])) - unmatched_layers.remove(match) - return match, False diff --git a/src/stratigraphy/extract.py b/src/stratigraphy/extract.py index 653200b0..cdc7dcd8 100644 --- a/src/stratigraphy/extract.py +++ b/src/stratigraphy/extract.py @@ -6,10 +6,11 @@ import fitz +from stratigraphy.data_extractor.data_extractor import FeatureOnPage from stratigraphy.depthcolumn import find_depth_columns from stratigraphy.depthcolumn.depthcolumn import DepthColumn from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs -from stratigraphy.layer.layer import IntervalBlockGroup, Layer, LayersOnPage +from stratigraphy.layer.layer import IntervalBlockPair, Layer from stratigraphy.layer.layer_identifier_column import ( LayerIdentifierColumn, find_layer_identifier_column, @@ -21,7 +22,7 @@ get_description_blocks_from_layer_identifier, get_description_lines, ) -from stratigraphy.text.textblock import TextBlock, block_distance +from stratigraphy.text.textblock import MaterialDescription, MaterialDescriptionLine, TextBlock, block_distance from stratigraphy.util.dataclasses import Line from stratigraphy.util.interval import BoundaryInterval, Interval from stratigraphy.util.util import ( @@ -36,7 +37,7 @@ class ProcessPageResult: """The result of processing a single page of a pdf.""" - predictions: LayersOnPage + predictions: list[Layer] depth_material_pairs: list[DepthsMaterialsColumnPairs] @@ -118,16 +119,16 @@ def process_page( to_delete.append(i) filtered_pairs = [item for index, item in enumerate(pairs) if index not in to_delete] - groups: list[IntervalBlockGroup] = [] # list of matched depth intervals and text blocks + pairs: list[IntervalBlockPair] = [] # list of matched depth intervals and text blocks # groups is of the form: [{"depth_interval": BoundaryInterval, "block": TextBlock}] if filtered_pairs: # match depth column items with material description for depth_column, material_description_rect in filtered_pairs: description_lines = get_description_lines(lines, material_description_rect) if len(description_lines) > 1: - new_groups = match_columns( + new_pairs = match_columns( depth_column, description_lines, geometric_lines, material_description_rect, **params ) - groups.extend(new_groups) + pairs.extend(new_pairs) filtered_depth_material_column_pairs = [ DepthsMaterialsColumnPairs( depth_column=depth_column, material_description_rect=material_description_rect, page=page_number @@ -149,7 +150,7 @@ def process_page( params["block_line_ratio"], params["left_line_length_threshold"], ) - groups.extend([IntervalBlockGroup(block=block, depth_interval=None) for block in description_blocks]) + pairs.extend([IntervalBlockPair(block=block, depth_interval=None) for block in description_blocks]) filtered_depth_material_column_pairs.extend( [ DepthsMaterialsColumnPairs( @@ -158,18 +159,30 @@ def process_page( ] ) - layer_predictions = LayersOnPage( - [ - Layer( - material_description=group.block, - depth_interval=BoundaryInterval(start=group.depth_interval.start, end=group.depth_interval.end) - if group.depth_interval - else None, - ) - for group in groups - ] - ) - layer_predictions.remove_empty_predictions() + layer_predictions = [ + Layer( + material_description=FeatureOnPage( + feature=MaterialDescription( + text=pair.block.text, + lines=[ + FeatureOnPage( + feature=MaterialDescriptionLine(text_line.text), + rect=text_line.rect, + page=text_line.page_number, + ) + for text_line in pair.block.lines + ], + ), + rect=pair.block.rect, + page=page_number, + ), + depth_interval=BoundaryInterval(start=pair.depth_interval.start, end=pair.depth_interval.end) + if pair.depth_interval + else None, + ) + for pair in pairs + ] + layer_predictions = [layer for layer in layer_predictions if layer.description_nonempty()] return ProcessPageResult(layer_predictions, filtered_depth_material_column_pairs) @@ -209,7 +222,7 @@ def match_columns( geometric_lines: list[Line], material_description_rect: fitz.Rect, **params: dict, -) -> list[IntervalBlockGroup]: +) -> list[IntervalBlockPair]: """Match the depth column entries with the description lines. This function identifies groups of depth intervals and text blocks that are likely to match. @@ -224,7 +237,7 @@ def match_columns( **params (dict): Additional parameters for the matching pipeline. Returns: - list[IntervalBlockGroup]: The matched depth intervals and text blocks. + list[IntervalBlockPair]: The matched depth intervals and text blocks. """ if isinstance(depth_column, DepthColumn): return [ @@ -232,18 +245,18 @@ def match_columns( for group in depth_column.identify_groups( description_lines, geometric_lines, material_description_rect, **params ) - for element in transform_groups(group.depth_interval, group.block, **params) + for element in transform_groups(group.depth_intervals, group.blocks, **params) ] elif isinstance(depth_column, LayerIdentifierColumn): blocks = get_description_blocks_from_layer_identifier(depth_column.entries, description_lines) - groups: list[IntervalBlockGroup] = [] + pairs: list[IntervalBlockPair] = [] for block in blocks: depth_interval = find_depth_columns.get_depth_interval_from_textblock(block) if depth_interval: - groups.append(IntervalBlockGroup(depth_interval=depth_interval, block=block)) + pairs.append(IntervalBlockPair(depth_interval=depth_interval, block=block)) else: - groups.append(IntervalBlockGroup(depth_interval=None, block=block)) - return groups + pairs.append(IntervalBlockPair(depth_interval=None, block=block)) + return pairs else: raise ValueError( f"depth_column must be a DepthColumn or a LayerIdentifierColumn object. Got {type(depth_column)}." @@ -252,7 +265,7 @@ def match_columns( def transform_groups( depth_intervals: list[Interval], blocks: list[TextBlock], **params: dict -) -> list[IntervalBlockGroup]: +) -> list[IntervalBlockPair]: """Transforms the text blocks such that their number equals the number of depth intervals. If there are more depth intervals than text blocks, text blocks are splitted. When there @@ -265,7 +278,7 @@ def transform_groups( **params (dict): Additional parameters for the matching pipeline. Returns: - List[IntervalBlockGroup]: Pairing of text blocks and depth intervals. + List[IntervalBlockPair]: Pairing of text blocks and depth intervals. """ if len(depth_intervals) == 0: return [] @@ -273,7 +286,7 @@ def transform_groups( concatenated_block = TextBlock( [line for block in blocks for line in block.lines] ) # concatenate all text lines within a block; line separation flag does not matter here. - return [IntervalBlockGroup(depth_interval=depth_intervals[0], block=concatenated_block)] + return [IntervalBlockPair(depth_interval=depth_intervals[0], block=concatenated_block)] else: if len(blocks) < len(depth_intervals): blocks = split_blocks_by_textline_length(blocks, target_split_count=len(depth_intervals) - len(blocks)) @@ -283,7 +296,7 @@ def transform_groups( depth_intervals.extend([BoundaryInterval(None, None) for _ in range(len(blocks) - len(depth_intervals))]) return [ - IntervalBlockGroup(depth_interval=depth_interval, block=block) + IntervalBlockPair(depth_interval=depth_interval, block=block) for depth_interval, block in zip(depth_intervals, blocks, strict=False) ] @@ -291,7 +304,7 @@ def transform_groups( def merge_blocks_by_vertical_spacing(blocks: list[TextBlock], target_merge_count: int) -> list[TextBlock]: """Merge textblocks without any geometric lines that separates them. - Note: Deprecated. Currently not in use any more. Kept here until we are sure that it is not needed anymore. + Note: Deprecated. Currently not in use anymore. Kept here until we are sure that it is not needed anymore. The logic looks at the distances between the textblocks and merges them if they are closer than a certain cutoff. diff --git a/src/stratigraphy/layer/duplicate_detection.py b/src/stratigraphy/layer/duplicate_detection.py index 9e484997..9d295608 100644 --- a/src/stratigraphy/layer/duplicate_detection.py +++ b/src/stratigraphy/layer/duplicate_detection.py @@ -7,7 +7,7 @@ import Levenshtein import numpy as np from stratigraphy.annotations.plot_utils import convert_page_to_opencv_img -from stratigraphy.layer.layer import LayersInDocument, LayersOnPage +from stratigraphy.layer.layer import Layer, LayersInDocument logger = logging.getLogger(__name__) @@ -16,9 +16,9 @@ def remove_duplicate_layers( previous_page: fitz.Page, current_page: fitz.Page, previous_layers: LayersInDocument, - current_layers: LayersOnPage, + current_layers: list[Layer], img_template_probability_threshold: float, -) -> LayersOnPage: +) -> list[Layer]: """Remove duplicate layers from the current page based on the layers of the previous page. We check if a layer on the current page is present on the previous page. If we have 3 consecutive layers that are @@ -32,13 +32,13 @@ def remove_duplicate_layers( previous_page (fitz.Page): The previous page. current_page (fitz.Page): The current page containing the layers to check for duplicates. previous_layers (LayersInDocument): The layers of the previous page. - current_layers (LayersOnPage): The layers of the current page. + current_layers (list[Layer]): The layers of the current page. img_template_probability_threshold (float): The threshold for the template matching probability Returns: - list[dict]: The layers of the current page without duplicates. + list[Layer]: The layers of the current page without duplicates. """ - sorted_layers = sorted(current_layers.layers_on_page, key=lambda x: x.material_description.rect.y0) + sorted_layers = sorted(current_layers, key=lambda x: x.material_description.rect.y0) first_non_duplicated_layer_index = 0 count_consecutive_non_duplicate_layers = 0 for layer_index, layer in enumerate(sorted_layers): @@ -57,7 +57,7 @@ def remove_duplicate_layers( else: # in this case we compare the depth interval and material description current_material_description = layer.material_description current_depth_interval = layer.depth_interval - for previous_layer in previous_layers.get_all_layers(): + for previous_layer in previous_layers.layers: if previous_layer.depth_interval is None: # It may happen, that a layer on the previous page does not have depth interval assigned. # In this case we skip the comparison. This should only happen in some edge cases, as we @@ -81,7 +81,10 @@ def remove_duplicate_layers( ) # check if material description is the same text_similarity = ( - Levenshtein.ratio(current_material_description.text, previous_material_description.text) > 0.9 + Levenshtein.ratio( + current_material_description.feature.text, previous_material_description.feature.text + ) + > 0.9 ) same_start_depth = current_depth_interval_start == previous_depth_interval_start @@ -98,11 +101,11 @@ def remove_duplicate_layers( count_consecutive_non_duplicate_layers = 0 else: count_consecutive_non_duplicate_layers += 1 - return LayersOnPage(sorted_layers[first_non_duplicated_layer_index:]) + return sorted_layers[first_non_duplicated_layer_index:] def check_duplicate_layer_by_template_matching( - previous_page: fitz.Page, current_page: fitz.Page, current_layer: dict, img_template_probability_threshold: float + previous_page: fitz.Page, current_page: fitz.Page, current_layer: Layer, img_template_probability_threshold: float ) -> bool: """Check if the current layer is a duplicate of a layer on the previous page by using template matching. @@ -113,7 +116,7 @@ def check_duplicate_layer_by_template_matching( Args: previous_page (fitz.Page): The previous page. current_page (fitz.Page): The current page. - current_layer (dict): The current layer that is checked for a duplicate. + current_layer (Layer): The current layer that is checked for a duplicate. img_template_probability_threshold (float): The threshold for the template matching probability to consider a layer a duplicate. diff --git a/src/stratigraphy/layer/layer.py b/src/stratigraphy/layer/layer.py index 8b815302..ca24797a 100644 --- a/src/stratigraphy/layer/layer.py +++ b/src/stratigraphy/layer/layer.py @@ -4,22 +4,19 @@ from dataclasses import dataclass, field import fitz +from stratigraphy.data_extractor.data_extractor import ExtractedFeature, FeatureOnPage from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry -from stratigraphy.lines.line import TextLine, TextWord from stratigraphy.text.textblock import MaterialDescription, TextBlock from stratigraphy.util.interval import AnnotatedInterval, BoundaryInterval, Interval from stratigraphy.util.util import parse_text -# TODO: make this a subclass of ExtractedFeature (cf. ticket LGVISIUM-79) @dataclass -class Layer: +class Layer(ExtractedFeature): """A class to represent predictions for a single layer.""" - material_description: TextBlock | MaterialDescription + material_description: FeatureOnPage[MaterialDescription] depth_interval: BoundaryInterval | AnnotatedInterval | None - material_is_correct: bool = None - depth_interval_is_correct: bool = None id: uuid.UUID = field(default_factory=uuid.uuid4) def __str__(self) -> str: @@ -28,9 +25,10 @@ def __str__(self) -> str: Returns: str: The object as a string. """ - return ( - f"LayerPrediction(material_description={self.material_description}, depth_interval={self.depth_interval})" - ) + return f"Layer(material_description={self.material_description}, depth_interval={self.depth_interval})" + + def description_nonempty(self) -> bool: + return parse_text(self.material_description.feature.text) != "" def to_json(self) -> dict: """Converts the object to a dictionary. @@ -41,118 +39,64 @@ def to_json(self) -> dict: return { "material_description": self.material_description.to_json() if self.material_description else None, "depth_interval": self.depth_interval.to_json() if self.depth_interval else None, - "material_is_correct": self.material_is_correct, - "depth_interval_is_correct": self.depth_interval_is_correct, "id": str(self.id), } - @staticmethod - def from_json(json_layer_list: list[dict]) -> list["Layer"]: + @classmethod + def from_json(cls, data: dict) -> "Layer": """Converts a dictionary to an object. Args: - json_layer_list (list[dict]): A list of dictionaries representing the layers. + data (dict): A dictionarie representing the layer. Returns: list[LayerPrediction]: A list of LayerPrediction objects. """ - page_layer_predictions_list: list[Layer] = [] - - # Extract the layer predictions. - for layer in json_layer_list: - material_prediction = _create_textblock_object(layer["material_description"]["lines"]) - if "depth_interval" in layer and layer["depth_interval"] is not None: - depth_interval = layer.get("depth_interval", {}) - start_data = depth_interval.get("start") - end_data = depth_interval.get("end") - start = ( - DepthColumnEntry( - value=start_data["value"], - rect=fitz.Rect(start_data["rect"]), - page_number=start_data["page"], - ) - if start_data is not None - else None - ) - end = ( - DepthColumnEntry( - value=end_data["value"], - rect=fitz.Rect(end_data["rect"]), - page_number=end_data["page"], - ) - if end_data is not None - else None + material_prediction = FeatureOnPage.from_json(data["material_description"], MaterialDescription) + if "depth_interval" in data and data["depth_interval"] is not None: + depth_interval = data.get("depth_interval", {}) + start_data = depth_interval.get("start") + end_data = depth_interval.get("end") + start = ( + DepthColumnEntry( + value=start_data["value"], + rect=fitz.Rect(start_data["rect"]), + page_number=start_data["page"], ) - - depth_interval_prediction = BoundaryInterval(start=start, end=end) - layer_predictions = Layer( - material_description=material_prediction, depth_interval=depth_interval_prediction + if start_data is not None + else None + ) + end = ( + DepthColumnEntry( + value=end_data["value"], + rect=fitz.Rect(end_data["rect"]), + page_number=end_data["page"], ) - else: - layer_predictions = Layer(material_description=material_prediction, depth_interval=None) - - page_layer_predictions_list.append(layer_predictions) - - return page_layer_predictions_list - + if end_data is not None + else None + ) -def _create_textblock_object(lines: list[dict]) -> TextBlock: - """Creates a TextBlock object from a dictionary. + depth_interval_prediction = BoundaryInterval(start=start, end=end) + else: + depth_interval_prediction = None - Args: - lines (list[dict]): A list of dictionaries representing the lines. - - Returns: - TextBlock: The object. - """ - lines = [TextLine([TextWord(**line)]) for line in lines] - return TextBlock(lines) - - -# TODO: convert to FeatureOnPage[Layer] (cf. ticket LGVISIUM-79) -@dataclass -class LayersOnPage: - """A class to represent predictions for a single page.""" - - layers_on_page: list[Layer] - - def remove_empty_predictions(self) -> None: - """Remove empty predictions from the layers on the page.""" - self.layers_on_page = [ - layer for layer in self.layers_on_page if parse_text(layer.material_description.text) != "" - ] + return Layer(material_description=material_prediction, depth_interval=depth_interval_prediction) @dataclass class LayersInDocument: """A class to represent predictions for a single document.""" - layers_in_document: list[LayersOnPage] + layers: list[Layer] filename: str - def add_layers_on_page(self, layers_on_page: LayersOnPage): - """Add layers on a page to the layers in the document. - - Args: - layers_on_page (LayersOnPage): The layers on a page to add. - """ - self.layers_in_document.append(layers_on_page) - - def get_all_layers(self) -> list[Layer]: - """Get all layers in the document. - - Returns: - list[Layer]: All layers in the document. - """ - all_layers = [] - for layers_on_page in self.layers_in_document: - all_layers.extend(layers_on_page.layers_on_page) - return all_layers - @dataclass -class IntervalBlockGroup: - """A class to represent a group of depth interval blocks.""" +class IntervalBlockPair: + """Represent the data for a single layer in the borehole profile. + + This consist of a material description (represented as a text block) and a depth interval (if available). + """ - depth_interval: Interval | list[Interval] | None - block: TextBlock | list[TextBlock] + depth_interval: Interval | None + block: TextBlock diff --git a/src/stratigraphy/main.py b/src/stratigraphy/main.py index 84aaed72..35a56c10 100644 --- a/src/stratigraphy/main.py +++ b/src/stratigraphy/main.py @@ -11,6 +11,7 @@ from tqdm import tqdm from stratigraphy import DATAPATH +from stratigraphy.annotations.draw import draw_predictions from stratigraphy.annotations.plot_utils import plot_lines from stratigraphy.benchmark.score import evaluate from stratigraphy.extract import process_page @@ -227,17 +228,15 @@ def start_pipeline( # Save the predictions to the overall predictions object # Initialize common variables - groundwater_entries = None - layers = None - depths_materials_columns_pairs = None + groundwater_entries = GroundwaterInDocument(filename=filename, groundwater=[]) + layers_in_document = LayersInDocument([], filename) + depths_materials_columns_pairs = [] if part == "all": # Extract the groundwater levels groundwater_entries = GroundwaterInDocument.from_document(doc, metadata.elevation) # Extract the layers - layers = LayersInDocument([], filename) - depths_materials_columns_pairs = [] for page_index, page in enumerate(doc): page_number = page_index + 1 logger.info("Processing page %s", page_number) @@ -253,7 +252,7 @@ def start_pipeline( layer_predictions = remove_duplicate_layers( previous_page=doc[page_index - 1], current_page=page, - previous_layers=layers, + previous_layers=layers_in_document, current_layers=process_page_results.predictions, img_template_probability_threshold=matching_params[ "img_template_probability_threshold" @@ -262,7 +261,7 @@ def start_pipeline( else: layer_predictions = process_page_results.predictions - layers.add_layers_on_page(layer_predictions) + layers_in_document.layers.extend(layer_predictions) depths_materials_columns_pairs.extend(process_page_results.depth_material_pairs) if draw_lines: # could be changed to if draw_lines and mflow_tracking: @@ -282,7 +281,7 @@ def start_pipeline( file_name=filename, metadata=metadata, groundwater=groundwater_entries, - layers=layers, + layers_in_document=layers_in_document, depths_materials_columns_pairs=depths_materials_columns_pairs, ) ) @@ -296,14 +295,13 @@ def start_pipeline( with open(predictions_path, "w", encoding="utf8") as file: json.dump(predictions.to_json(), file, ensure_ascii=False) - evaluate( - predictions=predictions, - ground_truth_path=ground_truth_path, - temp_directory=temp_directory, - input_directory=input_directory, - draw_directory=draw_directory, + document_level_metadata_metrics = evaluate( + predictions=predictions, ground_truth_path=ground_truth_path, temp_directory=temp_directory ) + if input_directory and draw_directory: + draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics) + if __name__ == "__main__": click_pipeline() diff --git a/src/stratigraphy/metadata/coordinate_extraction.py b/src/stratigraphy/metadata/coordinate_extraction.py index 16ef3c49..3157c5e8 100644 --- a/src/stratigraphy/metadata/coordinate_extraction.py +++ b/src/stratigraphy/metadata/coordinate_extraction.py @@ -2,7 +2,6 @@ from __future__ import annotations -import abc import logging from dataclasses import dataclass @@ -63,10 +62,6 @@ def to_json(self) -> dict: "page": self.page, } - @abc.abstractmethod - def is_valid(self): - pass - @staticmethod def from_values(east: float, north: float, rect: fitz.Rect, page: int) -> Coordinate | None: """Creates a Coordinate object from the given values. diff --git a/src/stratigraphy/text/textblock.py b/src/stratigraphy/text/textblock.py index 4aa00eb7..b21a7638 100644 --- a/src/stratigraphy/text/textblock.py +++ b/src/stratigraphy/text/textblock.py @@ -3,32 +3,47 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any +from typing import Self import fitz import numpy as np +from stratigraphy.data_extractor.data_extractor import ExtractedFeature, FeatureOnPage from stratigraphy.lines.line import TextLine @dataclass -class MaterialDescription: - """Class to represent a material description in a PDF document. +class MaterialDescriptionLine(ExtractedFeature): + """Class to represent a line of a material description in a PDF document.""" - Note: This class is similar to the TextBlock class. As such it has the attributes text and rect. - But it does not have the attribute lines and is missing class methods. TextBlock is used during the extraction - process where more fine-grained information is required. We lose this "fine-grainedness" when we annotate - the boreholes in label-studio. - """ + text: str + + def to_json(self): + """Convert the MaterialDescriptionLine object to a JSON serializable dictionary.""" + return {"text": self.text} + + @classmethod + def from_json(cls, data: dict) -> Self: + """Converts a dictionary to an object.""" + return cls(text=data["text"]) + + +@dataclass +class MaterialDescription(ExtractedFeature): + """Class to represent a material description in a PDF document.""" text: str - rect: fitz.Rect + lines: list[FeatureOnPage[MaterialDescriptionLine]] def to_json(self): """Convert the MaterialDescription object to a JSON serializable dictionary.""" - return { - "text": self.text, - "rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1], - } + return {"text": self.text, "lines": [line.to_json() for line in self.lines]} + + @classmethod + def from_json(cls, data: dict) -> Self: + """Converts a dictionary to an object.""" + return cls( + text=data["text"], lines=[FeatureOnPage.from_json(line, MaterialDescriptionLine) for line in data["lines"]] + ) @dataclass @@ -149,15 +164,6 @@ def _is_legend(self) -> bool: y0_coordinates.append(line.rect.y0) return number_horizontally_close > 1 or number_vertically_close > 2 - def to_json(self) -> dict[str, Any]: - """Convert the TextBlock object to a JSON serializable dictionary.""" - return { - "text": self.text, - "rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1], - "lines": [line.to_json() for line in self.lines], - "page": self.page_number, - } - def _is_close(a: float, b: list, tolerance: float) -> bool: return any(abs(a - c) < tolerance for c in b) diff --git a/src/stratigraphy/util/predictions.py b/src/stratigraphy/util/predictions.py index 02507724..f1d187c3 100644 --- a/src/stratigraphy/util/predictions.py +++ b/src/stratigraphy/util/predictions.py @@ -1,22 +1,18 @@ """This module contains classes for predictions.""" import logging -import os -from collections.abc import Callable -from pathlib import Path from stratigraphy.benchmark.ground_truth import GroundTruth -from stratigraphy.benchmark.metrics import OverallMetrics, OverallMetricsCatalog +from stratigraphy.benchmark.metrics import OverallMetricsCatalog from stratigraphy.data_extractor.data_extractor import FeatureOnPage from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs -from stratigraphy.evaluation.evaluation_dataclasses import Metrics, OverallBoreholeMetadataMetrics +from stratigraphy.evaluation.evaluation_dataclasses import OverallBoreholeMetadataMetrics from stratigraphy.evaluation.groundwater_evaluator import GroundwaterEvaluator +from stratigraphy.evaluation.layer_evaluator import LayerEvaluator from stratigraphy.evaluation.metadata_evaluator import MetadataEvaluator -from stratigraphy.evaluation.utility import find_matching_layer from stratigraphy.groundwater.groundwater_extraction import Groundwater, GroundwaterInDocument -from stratigraphy.layer.layer import Layer, LayersInDocument, LayersOnPage +from stratigraphy.layer.layer import Layer, LayersInDocument from stratigraphy.metadata.metadata import BoreholeMetadata, OverallBoreholeMetadata -from stratigraphy.util.util import parse_text logger = logging.getLogger(__name__) @@ -26,73 +22,18 @@ class FilePredictions: def __init__( self, - layers: LayersInDocument, + layers_in_document: LayersInDocument, file_name: str, metadata: BoreholeMetadata, groundwater: GroundwaterInDocument, depths_materials_columns_pairs: list[DepthsMaterialsColumnPairs], ): - self.layers: LayersInDocument = layers + self.layers_in_document: LayersInDocument = layers_in_document self.depths_materials_columns_pairs: list[DepthsMaterialsColumnPairs] = depths_materials_columns_pairs self.file_name: str = file_name self.metadata: BoreholeMetadata = metadata self.groundwater: GroundwaterInDocument = groundwater - def convert_to_ground_truth(self): - """Convert the predictions to ground truth format. - - This method is meant to be used in combination with the create_from_label_studio method. - It converts the predictions to ground truth format, which can then be used for evaluation. - - NOTE: This method should be tested before using it to create new ground truth. - - Returns: - dict: The predictions in ground truth format. - """ - ground_truth = {self.file_name: {"metadata": self.metadata}} - layers = [] - for layer in self.layers.get_all_layers(): - material_description = layer.material_description.text - depth_interval = { - "start": layer.depth_interval.start.value if layer.depth_interval.start else None, - "end": layer.depth_interval.end.value if layer.depth_interval.end else None, - } - layers.append({"material_description": material_description, "depth_interval": depth_interval}) - ground_truth[self.file_name]["layers"] = layers - if self.metadata is not None and self.metadata.coordinates is not None: - ground_truth[self.file_name]["metadata"] = { - "coordinates": { - "E": self.metadata.coordinates.east.coordinate_value, - "N": self.metadata.coordinates.north.coordinate_value, - } - } - return ground_truth - - def evaluate(self, ground_truth: dict): - """Evaluate the predictions against the ground truth. - - Args: - ground_truth (dict): The ground truth for the file. - """ - # TODO: Call the evaluator for Layers instead - self.evaluate_layers(ground_truth["layers"]) - - def evaluate_layers(self, ground_truth_layers: list): - """Evaluate all layers of the predictions against the ground truth. - - Args: - ground_truth_layers (list): The ground truth layers for the file. - """ - unmatched_layers = ground_truth_layers.copy() - for layer in self.layers.get_all_layers(): - match, depth_interval_is_correct = find_matching_layer(layer, unmatched_layers) - if match: - layer.material_is_correct = True - layer.depth_interval_is_correct = depth_interval_is_correct - else: - layer.material_is_correct = False - layer.depth_interval_is_correct = None - def to_json(self) -> dict: """Converts the object to a dictionary. @@ -101,7 +42,7 @@ def to_json(self) -> dict: """ return { "metadata": self.metadata.to_json(), - "layers": [layer.to_json() for layer in self.layers.get_all_layers()] if self.layers is not None else [], + "layers": [layer.to_json() for layer in self.layers_in_document.layers], "depths_materials_column_pairs": [dmc_pair.to_json() for dmc_pair in self.depths_materials_columns_pairs] if self.depths_materials_columns_pairs is not None else [], @@ -159,11 +100,8 @@ def from_json(cls, prediction_from_file: dict) -> "OverallFilePredictions": for file_name, file_data in prediction_from_file.items(): metadata = BoreholeMetadata.from_json(file_data["metadata"], file_name) - layers = Layer.from_json(file_data["layers"]) - layers_on_page = LayersOnPage(layers_on_page=layers) - layers_in_doc = LayersInDocument( - layers_in_document=[layers_on_page], filename=file_name - ) # TODO: This is a bit of a hack as we do not seem to save the page of the layer + layers = [Layer.from_json(data) for data in file_data["layers"]] + layers_in_doc = LayersInDocument(layers=layers, filename=file_name) depths_materials_columns_pairs = [ DepthsMaterialsColumnPairs.from_json(dmc_pair) @@ -174,7 +112,7 @@ def from_json(cls, prediction_from_file: dict) -> "OverallFilePredictions": groundwater_in_document = GroundwaterInDocument(groundwater=groundwater_entries, filename=file_name) overall_file_predictions.add_file_predictions( FilePredictions( - layers=layers_in_doc, + layers_in_document=layers_in_doc, file_name=file_name, metadata=metadata, depths_materials_columns_pairs=depths_materials_columns_pairs, @@ -187,101 +125,56 @@ def from_json(cls, prediction_from_file: dict) -> "OverallFilePredictions": ### Evaluation methods ############################################################################################################ - def evaluate_metadata_extraction(self, ground_truth_path: Path) -> OverallBoreholeMetadataMetrics: + def evaluate_metadata_extraction(self, ground_truth: GroundTruth) -> OverallBoreholeMetadataMetrics: """Evaluate the metadata extraction of the predictions against the ground truth. Args: - ground_truth_path (Path): The path to the ground truth file. + ground_truth (GroundTruth): The ground truth. """ metadata_per_file: OverallBoreholeMetadata = OverallBoreholeMetadata() for file_prediction in self.file_predictions_list: metadata_per_file.add_metadata(file_prediction.metadata) - return MetadataEvaluator(metadata_per_file, ground_truth_path).evaluate() + return MetadataEvaluator(metadata_per_file, ground_truth).evaluate() - def evaluate_borehole_extraction(self, ground_truth_path: Path) -> OverallMetricsCatalog | None: + def evaluate_geology(self, ground_truth: GroundTruth) -> OverallMetricsCatalog | None: """Evaluate the borehole extraction predictions. Args: - ground_truth_path (Path): The path to the ground truth file. + ground_truth (GroundTruth): The ground truth. Returns: OverallMetricsCatalog: A OverallMetricsCatalog that maps a metrics name to the corresponding OverallMetrics object. If no ground truth is available, None is returned. """ - ############################################################################################################ - ### Load the ground truth data for the borehole extraction - ############################################################################################################ - ground_truth = None - if ground_truth_path and os.path.exists(ground_truth_path): # for inference no ground truth is available - ground_truth = GroundTruth(ground_truth_path) - else: - logger.warning("Ground truth file not found.") - - ############################################################################################################ - ### Evaluate the borehole extraction - ############################################################################################################ - number_of_truth_values = {} for file_predictions in self.file_predictions_list: - if ground_truth: - ground_truth_for_file = ground_truth.for_file(file_predictions.file_name) - if ground_truth_for_file: - file_predictions.evaluate(ground_truth_for_file) - number_of_truth_values[file_predictions.file_name] = len(ground_truth_for_file["layers"]) - - if number_of_truth_values: - all_metrics = self.evaluate_layer_extraction(number_of_truth_values) - - groundwater_entries = [file_prediction.groundwater for file_prediction in self.file_predictions_list] - overall_groundwater_metrics = GroundwaterEvaluator(groundwater_entries, ground_truth_path).evaluate() - all_metrics.groundwater_metrics = overall_groundwater_metrics.groundwater_metrics_to_overall_metrics() - all_metrics.groundwater_depth_metrics = ( - overall_groundwater_metrics.groundwater_depth_metrics_to_overall_metrics() - ) - return all_metrics - else: - logger.warning("Ground truth file not found. Skipping evaluation.") - return None - - def evaluate_layer_extraction(self, number_of_truth_values: dict) -> OverallMetricsCatalog: - """Calculate F1, precision and recall for the predictions. - - Calculate F1, precision and recall for the individual documents as well as overall. - The individual document metrics are returned as a DataFrame. - - Args: - number_of_truth_values (dict): The number of layer ground truth values per file. + ground_truth_for_file = ground_truth.for_file(file_predictions.file_name) + if ground_truth_for_file: + LayerEvaluator.evaluate_borehole( + file_predictions.layers_in_document.layers, ground_truth_for_file["layers"] + ) - Returns: - OverallMetricsCatalog: A dictionary that maps a metrics name to the corresponding OverallMetrics object - """ - # create predictions by language languages = set(fp.metadata.language for fp in self.file_predictions_list) - predictions_by_language = {language: OverallFilePredictions() for language in languages} - all_metrics = OverallMetricsCatalog(languages=languages) - all_metrics.layer_metrics = get_layer_metrics(self, number_of_truth_values) - all_metrics.depth_interval_metrics = get_depth_interval_metrics(self) - for file_predictions in self.file_predictions_list: - language = file_predictions.metadata.language - if language in predictions_by_language: - predictions_by_language[language].add_file_predictions(file_predictions) + evaluator = LayerEvaluator( + [prediction.layers_in_document for prediction in self.file_predictions_list], ground_truth + ) + all_metrics.layer_metrics = evaluator.get_layer_metrics() + all_metrics.depth_interval_metrics = evaluator.get_depth_interval_metrics() - for language, language_predictions in predictions_by_language.items(): - language_number_of_truth_values = { - prediction.file_name: number_of_truth_values[prediction.file_name] - for prediction in language_predictions.file_predictions_list - } + layers_in_doc_by_language = {language: [] for language in languages} + for file_prediction in self.file_predictions_list: + layers_in_doc_by_language[file_prediction.metadata.language].append(file_prediction.layers_in_document) + for language, layers_in_doc_list in layers_in_doc_by_language.items(): + evaluator = LayerEvaluator(layers_in_doc_list, ground_truth) setattr( all_metrics, f"{language}_layer_metrics", - get_layer_metrics(language_predictions, language_number_of_truth_values), - ) - setattr( - all_metrics, f"{language}_depth_interval_metrics", get_depth_interval_metrics(language_predictions) + evaluator.get_layer_metrics(), ) + setattr(all_metrics, f"{language}_depth_interval_metrics", evaluator.get_depth_interval_metrics()) logger.info("Macro avg:") logger.info( @@ -292,77 +185,10 @@ def evaluate_layer_extraction(self, number_of_truth_values: dict) -> OverallMetr all_metrics.depth_interval_metrics.macro_precision() * 100, ) + groundwater_entries = [file_prediction.groundwater for file_prediction in self.file_predictions_list] + overall_groundwater_metrics = GroundwaterEvaluator(groundwater_entries, ground_truth).evaluate() + all_metrics.groundwater_metrics = overall_groundwater_metrics.groundwater_metrics_to_overall_metrics() + all_metrics.groundwater_depth_metrics = ( + overall_groundwater_metrics.groundwater_depth_metrics_to_overall_metrics() + ) return all_metrics - - -def calculate_metrics( - predictions: OverallFilePredictions, - per_layer_filter: Callable[[Layer], bool], - per_layer_condition: Callable[[Layer], bool], - number_of_truth_values: dict[str, int] | None = None, - per_layer_action: Callable[[Layer], None] | None = None, -) -> OverallMetrics: - """Calculate metrics based on a condition per layer, after applying a filter. - - Args: - predictions (OverallFilePredictions): The predictions. - per_layer_filter (Callable[[LayerPrediction], bool]): Function to filter layers to consider. - per_layer_condition (Callable[[LayerPrediction], bool]): Function that returns True if the layer is a hit. - number_of_truth_values (Optional[Dict[str, int]]): Ground truth counts per file (required for 'fn' - calculation). - per_layer_action (Optional[Callable[[LayerPrediction], None]]): Optional action to perform per layer. - - Returns: - OverallMetrics: The calculated metrics. - """ - overall_metrics = OverallMetrics() - - for file_prediction in predictions.file_predictions_list: - hits = 0 - total_predictions = 0 - - for layer in file_prediction.layers.get_all_layers(): - if per_layer_action: - per_layer_action(layer) - if per_layer_filter(layer): - total_predictions += 1 - if per_layer_condition(layer): - hits += 1 - - fn = 0 - if number_of_truth_values is not None: - fn = number_of_truth_values.get(file_prediction.file_name, 0) - hits - - if total_predictions > 0: - overall_metrics.metrics[file_prediction.file_name] = Metrics( - tp=hits, - fp=total_predictions - hits, - fn=fn, - ) - - return overall_metrics - - -def get_layer_metrics(predictions: OverallFilePredictions, number_of_truth_values: dict) -> OverallMetrics: - """Calculate metrics for layer predictions.""" - - def per_layer_action(layer): - if parse_text(layer.material_description.text) == "": - logger.warning("Empty string found in predictions") - - return calculate_metrics( - predictions=predictions, - per_layer_filter=lambda layer: True, - per_layer_condition=lambda layer: layer.material_is_correct, - number_of_truth_values=number_of_truth_values, - per_layer_action=per_layer_action, - ) - - -def get_depth_interval_metrics(predictions: OverallFilePredictions) -> OverallMetrics: - """Calculate metrics for depth interval predictions.""" - return calculate_metrics( - predictions=predictions, - per_layer_filter=lambda layer: layer.material_is_correct and layer.depth_interval_is_correct is not None, - per_layer_condition=lambda layer: layer.depth_interval_is_correct, - ) diff --git a/tests/test_groundwater.py b/tests/test_groundwater.py index 6d5aad35..1e475410 100644 --- a/tests/test_groundwater.py +++ b/tests/test_groundwater.py @@ -1,6 +1,7 @@ """Tests for the groundwater module.""" import pytest +from stratigraphy.benchmark.ground_truth import GroundTruth from stratigraphy.data_extractor.data_extractor import FeatureOnPage from stratigraphy.evaluation.evaluation_dataclasses import Metrics from stratigraphy.evaluation.groundwater_evaluator import ( @@ -18,9 +19,9 @@ def sample_metrics(): @pytest.fixture -def groundtruth_path(): +def groundtruth(): """Path to the ground truth file.""" - return "example/example_gw_groundtruth.json" + return GroundTruth("example/example_gw_groundtruth.json") def test_add_groundwater_metrics(sample_metrics): @@ -62,7 +63,7 @@ def test_groundwater_depth_metrics_to_overall_metrics(sample_metrics): assert overall.metrics["file_depth"] == gw_metrics.groundwater_depth_metrics -def test_evaluate_with_ground_truth(groundtruth_path): +def test_evaluate_with_ground_truth(groundtruth): """Test the evaluate method with available ground truth data.""" # Sample groundwater entries groundwater_entries = [ @@ -77,7 +78,7 @@ def test_evaluate_with_ground_truth(groundtruth_path): ) ] - evaluator = GroundwaterEvaluator(groundwater_entries, groundtruth_path) + evaluator = GroundwaterEvaluator(groundwater_entries, groundtruth) overall_metrics = evaluator.evaluate() # Assertions @@ -87,7 +88,7 @@ def test_evaluate_with_ground_truth(groundtruth_path): assert overall_metrics.groundwater_metrics[0].groundwater_metrics.precision == 1.0 -def test_evaluate_multiple_entries(groundtruth_path): +def test_evaluate_multiple_entries(groundtruth): """Test the evaluate method with multiple groundwater entries.""" # Sample groundwater entries groundwater_entries = [ @@ -115,7 +116,7 @@ def test_evaluate_multiple_entries(groundtruth_path): ), ] - evaluator = GroundwaterEvaluator(groundwater_entries, groundtruth_path) + evaluator = GroundwaterEvaluator(groundwater_entries, groundtruth) overall_metrics = evaluator.evaluate() # Assertions diff --git a/tests/test_predictions.py b/tests/test_predictions.py index d6cf164d..c692f951 100644 --- a/tests/test_predictions.py +++ b/tests/test_predictions.py @@ -6,10 +6,11 @@ import fitz import pytest +from stratigraphy.benchmark.ground_truth import GroundTruth from stratigraphy.data_extractor.data_extractor import FeatureOnPage from stratigraphy.evaluation.utility import count_against_ground_truth from stratigraphy.groundwater.groundwater_extraction import Groundwater, GroundwaterInDocument -from stratigraphy.layer.layer import LayersInDocument, LayersOnPage +from stratigraphy.layer.layer import LayersInDocument from stratigraphy.metadata.coordinate_extraction import CoordinateEntry, LV95Coordinate from stratigraphy.metadata.metadata import BoreholeMetadata from stratigraphy.util.predictions import FilePredictions, OverallFilePredictions @@ -31,8 +32,7 @@ def sample_file_prediction() -> FilePredictions: layer2 = Mock( material_description=Mock(text="Clay"), depth_interval=Mock(start=Mock(value=30), end=Mock(value=50)) ) - layer_on_page = LayersOnPage(layers_on_page=[layer1, layer2]) - layers_in_document = LayersInDocument(layers_in_document=[layer_on_page], filename="test_file") + layers_in_document = LayersInDocument(layers=[layer1, layer2], filename="test_file") dt_date = datetime(2024, 10, 1) groundwater_on_page = FeatureOnPage( @@ -45,24 +45,14 @@ def sample_file_prediction() -> FilePredictions: metadata = BoreholeMetadata(coordinates=coord, page_dimensions=[Mock(width=10, height=20)], language="en") return FilePredictions( - layers=layers_in_document, + layers_in_document=layers_in_document, file_name="test_file", metadata=metadata, groundwater=groundwater_in_doc, - depths_materials_columns_pairs=None, + depths_materials_columns_pairs=[], ) -def test_convert_to_ground_truth(sample_file_prediction: FilePredictions): - """Test the convert_to_ground_truth method.""" - ground_truth = sample_file_prediction.convert_to_ground_truth() - - assert ground_truth["test_file"]["metadata"]["coordinates"]["E"] == 2789456 - assert ground_truth["test_file"]["metadata"]["coordinates"]["N"] == 1123012 - assert len(ground_truth["test_file"]["layers"]) == 2 - assert ground_truth["test_file"]["layers"][0]["material_description"] == "Sand" - - def test_to_json(sample_file_prediction: FilePredictions): """Test the to_json method.""" result = sample_file_prediction.to_json() @@ -91,8 +81,8 @@ def test_evaluate_metadata_extraction(): file_prediction = Mock(metadata=Mock(to_json=lambda: {"coordinates": "some_coordinates"})) overall_predictions.add_file_predictions(file_prediction) - ground_truth_path = Path("example/example_groundtruth.json") - metadata_metrics = overall_predictions.evaluate_metadata_extraction(ground_truth_path) + ground_truth = GroundTruth(Path("example/example_groundtruth.json")) + metadata_metrics = overall_predictions.evaluate_metadata_extraction(ground_truth) assert metadata_metrics is not None # Ensure the evaluation returns a result