diff --git a/.github/workflows/pipeline_run.yml b/.github/workflows/pipeline_run.yml index 955ce04b..c424d046 100644 --- a/.github/workflows/pipeline_run.yml +++ b/.github/workflows/pipeline_run.yml @@ -21,4 +21,4 @@ jobs: source env/bin/activate pip install -e . echo "Running pipeline" - boreholes-extract-all -l -i example/example_borehole_profile.pdf -o example/ -p example/predictions.json \ No newline at end of file + boreholes-extract-all -l -i example/example_borehole_profile.pdf -o example/ -p example/predictions.json -m example/metadata.json -g example/example_groundtruth.json -pa all \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index a6dcaa6d..8fa809ca 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { "cSpell.words": [ + "dataframe", "DATAPATH", "depthcolumn", "depthcolumnentry", diff --git a/example/example_groundtruth.json b/example/example_groundtruth.json new file mode 100644 index 00000000..53cca8bd --- /dev/null +++ b/example/example_groundtruth.json @@ -0,0 +1,18 @@ +{ + "example_borehole_profile.pdf": { + "groundwater": [], + "layers": [], + "metadata": { + "coordinates": { + "E": 615790, + "N": 157500 + }, + "drilling_date": "1995-09-03", + "drilling_methods": null, + "original_name": "", + "project_name": "", + "reference_elevation": 788.6, + "total_depth": null + } + } +} diff --git a/pyproject.toml b/pyproject.toml index b08c22c5..b979c784 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ all = ["swissgeol-boreholes-dataextraction[test, lint, experiment-tracking, visu [project.scripts] boreholes-extract-all = "stratigraphy.main:click_pipeline" +boreholes-extract-metadata = "stratigraphy.main:click_pipeline_metadata" boreholes-download-profiles = "stratigraphy.get_files:download_directory_froms3" [tool.ruff.lint] diff --git a/src/app/api/v1/endpoints/extract_data.py b/src/app/api/v1/endpoints/extract_data.py index f0e2ec88..910a6643 100644 --- 
a/src/app/api/v1/endpoints/extract_data.py +++ b/src/app/api/v1/endpoints/extract_data.py @@ -16,8 +16,8 @@ FormatTypes, NotFoundResponse, ) -from stratigraphy.coordinates.coordinate_extraction import CoordinateExtractor, LV03Coordinate, LV95Coordinate -from stratigraphy.util.extract_text import extract_text_lines_from_bbox +from stratigraphy.metadata.coordinate_extraction import CoordinateExtractor, LV03Coordinate, LV95Coordinate +from stratigraphy.text.extract_text import extract_text_lines_from_bbox def extract_data(extract_data_request: ExtractDataRequest) -> ExtractDataResponse: diff --git a/src/app/common/schemas.py b/src/app/common/schemas.py index 76865547..8f124561 100644 --- a/src/app/common/schemas.py +++ b/src/app/common/schemas.py @@ -106,6 +106,7 @@ def to_fitz_rect(self) -> fitz.Rect: """ return fitz.Rect(self.x0, self.y0, self.x1, self.y1) + @staticmethod def load_from_fitz_rect(rect: fitz.Rect) -> "BoundingBox": """Load the bounding box from a PyMuPDF rectangle. diff --git a/src/scripts/label_studio_annotation_to_ground_truth.py b/src/scripts/label_studio_annotation_to_ground_truth.py index 626922c6..81d675f6 100644 --- a/src/scripts/label_studio_annotation_to_ground_truth.py +++ b/src/scripts/label_studio_annotation_to_ground_truth.py @@ -9,10 +9,11 @@ import click import fitz -from stratigraphy.coordinates.coordinate_extraction import Coordinate +from stratigraphy.layer.layer import LayerPrediction +from stratigraphy.metadata.coordinate_extraction import Coordinate +from stratigraphy.text.textblock import MaterialDescription from stratigraphy.util.interval import AnnotatedInterval -from stratigraphy.util.predictions import BoreholeMetaData, FilePredictions, LayerPrediction -from stratigraphy.util.textblock import MaterialDescription +from stratigraphy.util.predictions import BoreholeMetaData, FilePredictions logger = logging.getLogger(__name__) diff --git a/src/stratigraphy/util/draw.py b/src/stratigraphy/annotations/draw.py similarity index 90% 
rename from src/stratigraphy/util/draw.py rename to src/stratigraphy/annotations/draw.py index 981a5952..57dbf127 100644 --- a/src/stratigraphy/util/draw.py +++ b/src/stratigraphy/annotations/draw.py @@ -5,15 +5,15 @@ from pathlib import Path import fitz +import pandas as pd from dotenv import load_dotenv - -from stratigraphy.benchmark.metrics import Metrics -from stratigraphy.coordinates.coordinate_extraction import Coordinate -from stratigraphy.elevation.elevation_extraction import ElevationInformation from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformationOnPage +from stratigraphy.layer.layer import LayerPrediction +from stratigraphy.metadata.coordinate_extraction import Coordinate +from stratigraphy.metadata.elevation_extraction import Elevation +from stratigraphy.text.textblock import TextBlock from stratigraphy.util.interval import BoundaryInterval -from stratigraphy.util.predictions import FilePredictions, LayerPrediction -from stratigraphy.util.textblock import TextBlock +from stratigraphy.util.predictions import FilePredictions load_dotenv() @@ -24,7 +24,12 @@ logger = logging.getLogger(__name__) -def draw_predictions(predictions: dict[str, FilePredictions], directory: Path, out_directory: Path) -> None: +def draw_predictions( + predictions: dict[str, FilePredictions], + directory: Path, + out_directory: Path, + document_level_metadata_metrics: pd.DataFrame, +) -> None: """Draw predictions on pdf pages. Draws various recognized information on the pdf pages present at directory and saves @@ -42,6 +47,7 @@ def draw_predictions(predictions: dict[str, FilePredictions], directory: Path, o predictions (dict): Content of the predictions.json file. directory (Path): Path to the directory containing the pdf files. out_directory (Path): Path to the output directory where the images are saved. + document_level_metadata_metrics (pd.DataFrame): Document level metadata metrics. 
""" if directory.is_file(): # deal with the case when we pass a file instead of a directory directory = directory.parent @@ -51,6 +57,11 @@ def draw_predictions(predictions: dict[str, FilePredictions], directory: Path, o depths_materials_column_pairs = file_prediction.depths_materials_columns_pairs coordinates = file_prediction.metadata.coordinates elevation = file_prediction.metadata.elevation + + # Assess the correctness of the metadata + is_coordinates_correct = document_level_metadata_metrics.loc[file_name].coordinate + is_elevation_correct = document_level_metadata_metrics.loc[file_name].elevation + with fitz.Document(directory / file_name) as doc: for page_index, page in enumerate(doc): page_number = page_index + 1 @@ -61,9 +72,9 @@ def draw_predictions(predictions: dict[str, FilePredictions], directory: Path, o page.derotation_matrix, page.rotation, coordinates, - file_prediction.metadata_is_correct.get("coordinates"), + is_coordinates_correct, elevation, - file_prediction.metadata_is_correct.get("elevation"), + is_elevation_correct, ) if coordinates is not None and page_number == coordinates.page: draw_coordinates(shape, coordinates) @@ -107,9 +118,9 @@ def draw_metadata( derotation_matrix: fitz.Matrix, rotation: float, coordinates: Coordinate | None, - coordinates_is_correct: Metrics, - elevation_info: ElevationInformation | None, - elevation_is_correct: Metrics, + is_coordinate_correct: bool, + elevation_info: Elevation | None, + is_elevation_correct: bool, ) -> None: """Draw the extracted metadata on the top of the given PDF page. @@ -121,17 +132,15 @@ def draw_metadata( derotation_matrix (fitz.Matrix): The derotation matrix of the page. rotation (float): The rotation of the page. coordinates (Coordinate | None): The coordinate object to draw. - coordinates_is_correct (Metrics): Whether the coordinates are correct. + is_coordinate_correct (Metrics): Whether the coordinate information is correct. 
elevation_info (ElevationInformation | None): The elevation information to draw. - elevation_is_correct (Metrics): Whether the elevation information is correct. + is_elevation_correct (bool): Whether the elevation information is correct. """ # TODO associate correctness with the extracted coordinates in a better way - coordinate_correct = coordinates_is_correct is not None and coordinates_is_correct.tp > 0 - coordinate_color = "green" if coordinate_correct else "red" + coordinate_color = "green" if is_coordinate_correct else "red" coordinate_rect = fitz.Rect([5, 5, 200, 25]) - elevation_correct = elevation_is_correct is not None and elevation_is_correct.tp > 0 - elevation_color = "green" if elevation_correct else "red" + elevation_color = "green" if is_elevation_correct else "red" elevation_rect = fitz.Rect([5, 25, 200, 45]) shape.draw_rect(coordinate_rect * derotation_matrix) @@ -185,12 +194,12 @@ def draw_groundwater(shape: fitz.Shape, groundwater_entry: GroundwaterInformatio shape.finish(color=fitz.utils.getColor("pink")) -def draw_elevation(shape: fitz.Shape, elevation: ElevationInformation) -> None: +def draw_elevation(shape: fitz.Shape, elevation: Elevation) -> None: """Draw a bounding box around the area of the page where the coordinates were extracted from. Args: shape (fitz.Shape): The shape object for drawing. - elevation (ElevationInformation): The elevation information to draw. + elevation (Elevation): The elevation information to draw.
""" shape.draw_rect(elevation.rect) shape.finish(color=fitz.utils.getColor("blue")) diff --git a/src/stratigraphy/util/plot_utils.py b/src/stratigraphy/annotations/plot_utils.py similarity index 98% rename from src/stratigraphy/util/plot_utils.py rename to src/stratigraphy/annotations/plot_utils.py index 5e43bebc..4a4990ea 100644 --- a/src/stratigraphy/util/plot_utils.py +++ b/src/stratigraphy/annotations/plot_utils.py @@ -5,9 +5,8 @@ import cv2 import fitz import numpy as np - +from stratigraphy.text.textblock import TextBlock from stratigraphy.util.dataclasses import Line -from stratigraphy.util.textblock import TextBlock logger = logging.getLogger(__name__) diff --git a/src/stratigraphy/benchmark/ground_truth.py b/src/stratigraphy/benchmark/ground_truth.py index dfe70427..3bf568d8 100644 --- a/src/stratigraphy/benchmark/ground_truth.py +++ b/src/stratigraphy/benchmark/ground_truth.py @@ -16,8 +16,11 @@ class GroundTruth: def __init__(self, path: Path): self.ground_truth = defaultdict(dict) - with open(path) as in_file: + # Load the ground truth data + with open(path, encoding="utf-8") as in_file: ground_truth = json.load(in_file) + + # Parse the ground truth data for borehole_profile, ground_truth_item in ground_truth.items(): layers = ground_truth_item["layers"] self.ground_truth[borehole_profile]["layers"] = [ @@ -42,6 +45,6 @@ def for_file(self, file_name: str) -> dict: """ if file_name in self.ground_truth: return self.ground_truth[file_name] - else: - logger.warning(f"No ground truth data found for {file_name}.") - return {} + + logger.warning("No ground truth data found for %s.", file_name) + return {} diff --git a/src/stratigraphy/benchmark/metrics.py b/src/stratigraphy/benchmark/metrics.py index 45688d95..e6409c67 100644 --- a/src/stratigraphy/benchmark/metrics.py +++ b/src/stratigraphy/benchmark/metrics.py @@ -1,66 +1,21 @@ """Classes for keeping track of metrics such as the F1-score, precision and recall.""" from collections.abc import Callable -from 
dataclasses import dataclass import pandas as pd - - -@dataclass -class Metrics: - """Computes F-score metrics. - - See also https://en.wikipedia.org/wiki/F-score - - Args: - tp (int): The true positive count - fp (int): The false positive count - fn (int): The false negative count - """ - - tp: int - fp: int - fn: int - - @property - def precision(self) -> float: - """Calculate the precision.""" - if self.tp + self.fp > 0: - return self.tp / (self.tp + self.fp) - else: - return 0 - - @property - def recall(self) -> float: - """Calculate the recall.""" - if self.tp + self.fn > 0: - return self.tp / (self.tp + self.fn) - else: - return 0 - - @property - def f1(self) -> float: - """Calculate the F1 score.""" - if self.precision + self.recall > 0: - return 2 * self.precision * self.recall / (self.precision + self.recall) - else: - return 0 +from stratigraphy.evaluation.evaluation_dataclasses import Metrics class DatasetMetrics: """Keeps track of a particular metrics for all documents in a dataset.""" + # TODO: Currently, some methods for averaging metrics are in the Metrics class. + # (see micro_average(metric_list: list["Metrics"])). In the long run, we should refactor
+ def __init__(self): self.metrics: dict[str, Metrics] = {} - def overall_metrics(self) -> Metrics: - """Can be used to compute micro averages.""" - return Metrics( - tp=sum(metric.tp for metric in self.metrics.values()), - fp=sum(metric.fp for metric in self.metrics.values()), - fn=sum(metric.fn for metric in self.metrics.values()), - ) - def macro_f1(self) -> float: """Compute the macro F1 score.""" if self.metrics: @@ -93,6 +48,7 @@ def pseudo_macro_f1(self) -> float: return 0 def to_dataframe(self, name: str, fn: Callable[[Metrics], float]) -> pd.DataFrame: + """Convert the metrics to a DataFrame.""" series = pd.Series({filename: fn(metric) for filename, metric in self.metrics.items()}) return series.to_frame(name=name) @@ -104,6 +60,7 @@ def __init__(self): self.metrics: dict[str, DatasetMetrics] = {} def document_level_metrics_df(self) -> pd.DataFrame: + """Return a DataFrame with all the document level metrics.""" all_series = [ self.metrics["layer"].to_dataframe("F1", lambda metric: metric.f1), self.metrics["layer"].to_dataframe("precision", lambda metric: metric.precision), @@ -111,8 +68,6 @@ def document_level_metrics_df(self) -> pd.DataFrame: self.metrics["depth_interval"].to_dataframe("Depth_interval_accuracy", lambda metric: metric.precision), self.metrics["layer"].to_dataframe("Number Elements", lambda metric: metric.tp + metric.fn), self.metrics["layer"].to_dataframe("Number wrong elements", lambda metric: metric.fp + metric.fn), - self.metrics["coordinates"].to_dataframe("coordinates", lambda metric: metric.f1), - self.metrics["elevation"].to_dataframe("elevation", lambda metric: metric.f1), self.metrics["groundwater"].to_dataframe("groundwater", lambda metric: metric.f1), self.metrics["groundwater_depth"].to_dataframe("groundwater_depth", lambda metric: metric.f1), ] @@ -122,10 +77,9 @@ def document_level_metrics_df(self) -> pd.DataFrame: return document_level_metrics def metrics_dict(self) -> dict[str, float]: - coordinates_metrics = 
self.metrics["coordinates"].overall_metrics() - groundwater_metrics = self.metrics["groundwater"].overall_metrics() - groundwater_depth_metrics = self.metrics["groundwater_depth"].overall_metrics() - elevation_metrics = self.metrics["elevation"].overall_metrics() + """Return a dictionary with the overall metrics.""" + groundwater_metrics = Metrics.micro_average(self.metrics["groundwater"].metrics.values()) + groundwater_depth_metrics = Metrics.micro_average(self.metrics["groundwater_depth"].metrics.values()) return { "F1": self.metrics["layer"].pseudo_macro_f1(), @@ -140,16 +94,10 @@ def metrics_dict(self) -> dict[str, float]: "fr_recall": self.metrics["fr_layer"].macro_recall(), "fr_precision": self.metrics["fr_layer"].macro_precision(), "fr_depth_interval_accuracy": self.metrics["fr_depth_interval"].macro_precision(), - "coordinate_f1": coordinates_metrics.f1, - "coordinate_recall": coordinates_metrics.recall, - "coordinate_precision": coordinates_metrics.precision, "groundwater_f1": groundwater_metrics.f1, "groundwater_recall": groundwater_metrics.recall, "groundwater_precision": groundwater_metrics.precision, "groundwater_depth_f1": groundwater_depth_metrics.f1, "groundwater_depth_recall": groundwater_depth_metrics.recall, "groundwater_depth_precision": groundwater_depth_metrics.precision, - "elevation_f1": elevation_metrics.f1, - "elevation_recall": elevation_metrics.recall, - "elevation_precision": elevation_metrics.precision, } diff --git a/src/stratigraphy/benchmark/score.py b/src/stratigraphy/benchmark/score.py index 490f4b56..9a6cfd08 100644 --- a/src/stratigraphy/benchmark/score.py +++ b/src/stratigraphy/benchmark/score.py @@ -7,9 +7,12 @@ from dotenv import load_dotenv from stratigraphy import DATAPATH +from stratigraphy.annotations.draw import draw_predictions from stratigraphy.benchmark.ground_truth import GroundTruth from stratigraphy.benchmark.metrics import DatasetMetrics, DatasetMetricsCatalog, Metrics -from stratigraphy.util.draw import 
draw_predictions +from stratigraphy.evaluation.evaluation_dataclasses import OverallBoreholeMetadataMetrics +from stratigraphy.evaluation.metadata_evaluator import MetadataEvaluator +from stratigraphy.metadata.metadata import BoreholeMetadataList from stratigraphy.util.predictions import FilePredictions from stratigraphy.util.util import parse_text @@ -65,22 +68,29 @@ def get_depth_interval_metrics(predictions: dict) -> DatasetMetrics: for filename, file_prediction in predictions.items(): depth_interval_hits = 0 - depth_interval_occurences = 0 + depth_interval_occurrences = 0 for layer in file_prediction.layers: if layer.material_is_correct: if layer.depth_interval_is_correct is not None: - depth_interval_occurences += 1 + depth_interval_occurrences += 1 if layer.depth_interval_is_correct: depth_interval_hits += 1 - if depth_interval_occurences > 0: + if depth_interval_occurrences > 0: depth_interval_metrics.metrics[filename] = Metrics( - tp=depth_interval_hits, fp=depth_interval_occurences - depth_interval_hits, fn=0 + tp=depth_interval_hits, fp=depth_interval_occurrences - depth_interval_hits, fn=0 ) return depth_interval_metrics +def evaluate_metadata_extraction( + borehole_metadata: BoreholeMetadataList, ground_truth_path: Path +) -> OverallBoreholeMetadataMetrics: + """Evaluate the metadata extraction.""" + return MetadataEvaluator(borehole_metadata, ground_truth_path).evaluate() + + def evaluate_borehole_extraction( predictions: dict[str, FilePredictions], number_of_truth_values: dict ) -> DatasetMetricsCatalog: @@ -95,8 +105,6 @@ def evaluate_borehole_extraction( object """ all_metrics = evaluate_layer_extraction(predictions, number_of_truth_values) - all_metrics.metrics["coordinates"] = get_metrics(predictions, "metadata_is_correct", "coordinates") - all_metrics.metrics["elevation"] = get_metrics(predictions, "metadata_is_correct", "elevation") all_metrics.metrics["groundwater"] = get_metrics(predictions, "groundwater_is_correct", "groundwater") 
all_metrics.metrics["groundwater_depth"] = get_metrics(predictions, "groundwater_is_correct", "groundwater_depth") return all_metrics @@ -156,10 +164,11 @@ def evaluate_layer_extraction(predictions: dict, number_of_truth_values: dict) - logging.info("Macro avg:") logging.info( - f"F1: {all_metrics.metrics['layer'].macro_f1():.1%}, " - f"precision: {all_metrics.metrics['layer'].macro_precision():.1%}, " - f"recall: {all_metrics.metrics['layer'].macro_recall():.1%}, " - f"depth_interval_accuracy: {all_metrics.metrics['depth_interval'].macro_precision():.1%}" + "F1: %.1f%%, precision: %.1f%%, recall: %.1f%%, depth_interval_accuracy: %.1f%%", + all_metrics.metrics["layer"].macro_f1() * 100, + all_metrics.metrics["layer"].macro_precision() * 100, + all_metrics.metrics["layer"].macro_recall() * 100, + all_metrics.metrics["depth_interval"].macro_precision() * 100, ) return all_metrics @@ -167,13 +176,15 @@ def evaluate_layer_extraction(predictions: dict, number_of_truth_values: dict) - def create_predictions_objects( predictions: dict, + metadata_per_file: BoreholeMetadataList, ground_truth_path: Path | None, ) -> tuple[dict[str, FilePredictions], dict]: """Create predictions objects from the predictions and evaluate them against the ground truth. Args: predictions (dict): The predictions from the predictions.json file. - ground_truth_path (Path | None): The path to the ground truth file. + ground_truth_path (Path | None): The path to the ground truth file. + metadata_per_file (BoreholeMetadataList): The metadata for the files. 
Returns: tuple[dict[str, FilePredictions], dict]: The predictions objects and the number of ground truth values per @@ -189,7 +200,11 @@ def create_predictions_objects( number_of_truth_values = {} predictions_objects = {} for file_name, file_predictions in predictions.items(): - prediction_object = FilePredictions.create_from_json(file_predictions, file_name) + metadata = metadata_per_file.get_metadata(file_name) + if not metadata: + raise ValueError(f"Metadata for file {file_name} not found.") + + prediction_object = FilePredictions.create_from_json(file_predictions, metadata, file_name) predictions_objects[file_name] = prediction_object if ground_truth_is_present: @@ -203,34 +218,64 @@ def create_predictions_objects( def evaluate( predictions, + metadata_per_file: BoreholeMetadataList, ground_truth_path: Path, temp_directory: Path, input_directory: Path | None, draw_directory: Path | None, ): """Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled).""" - predictions, number_of_truth_values = create_predictions_objects(predictions, ground_truth_path) + ############################# + # Evaluate the borehole extraction metadata + ############################# + metadata_metrics_list = evaluate_metadata_extraction(metadata_per_file, ground_truth_path) + metadata_metrics = metadata_metrics_list.get_cumulated_metrics() + document_level_metadata_metrics = metadata_metrics_list.get_document_level_metrics() + + document_level_metadata_metrics.to_csv( + temp_directory / "document_level_metadata_metrics.csv", index_label="document_name" + ) # mlflow.log_artifact expects a file + + # print the metrics + logger.info("Metadata Performance metrics:") + logger.info(metadata_metrics) + + if mlflow_tracking: + import mlflow + + mlflow.log_metrics(metadata_metrics) + mlflow.log_artifact(temp_directory / "document_level_metadata_metrics.csv") - if input_directory and draw_directory: - draw_predictions(predictions, input_directory, 
draw_directory) + ############################# + # Evaluate the borehole extraction + ############################# + if predictions: + predictions, number_of_truth_values = create_predictions_objects( + predictions, metadata_per_file, ground_truth_path + ) - if number_of_truth_values: # only evaluate if ground truth is available - metrics = evaluate_borehole_extraction(predictions, number_of_truth_values) + if input_directory and draw_directory: + draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics) - metrics.document_level_metrics_df().to_csv( - temp_directory / "document_level_metrics.csv", index_label="document_name" - ) # mlflow.log_artifact expects a file + # evaluate the borehole extraction + if number_of_truth_values: # only evaluate if ground truth is available + metrics = evaluate_borehole_extraction(predictions, number_of_truth_values) - metrics_dict = metrics.metrics_dict() - logger.info("Performance metrics: %s", metrics_dict) + metrics.document_level_metrics_df().to_csv( + temp_directory / "document_level_metrics.csv", index_label="document_name" + ) # mlflow.log_artifact expects a file + metrics_dict = metrics.metrics_dict() - if mlflow_tracking: - import mlflow + # Format the metrics dictionary to limit to three digits + formatted_metrics = {k: f"{v:.3f}" for k, v in metrics_dict.items()} + logger.info("Performance metrics: %s", formatted_metrics) - mlflow.log_metrics(metrics_dict) - mlflow.log_artifact(temp_directory / "document_level_metrics.csv") - else: - logger.warning("Ground truth file not found. Skipping evaluation.") + if mlflow_tracking: + mlflow.log_metrics(metrics_dict) + mlflow.log_artifact(temp_directory / "document_level_metrics.csv") + + else: + logger.warning("Ground truth file not found. 
Skipping evaluation.") if __name__ == "__main__": @@ -251,4 +296,5 @@ def evaluate( with open(predictions_path, encoding="utf8") as file: predictions = json.load(file) + # TODO read BoreholeMetadataList from JSON file and pass to the evaluate method evaluate(predictions, ground_truth_path, temp_directory, input_directory=None, draw_directory=None) diff --git a/src/stratigraphy/data_extractor/data_extractor.py b/src/stratigraphy/data_extractor/data_extractor.py index 2889df48..b1a11784 100644 --- a/src/stratigraphy/data_extractor/data_extractor.py +++ b/src/stratigraphy/data_extractor/data_extractor.py @@ -9,7 +9,7 @@ import fitz import regex -from stratigraphy.util.line import TextLine +from stratigraphy.lines.line import TextLine from stratigraphy.util.util import read_params logger = logging.getLogger(__name__) diff --git a/src/stratigraphy/util/boundarydepthcolumnvalidator.py b/src/stratigraphy/depthcolumn/boundarydepthcolumnvalidator.py similarity index 97% rename from src/stratigraphy/util/boundarydepthcolumnvalidator.py rename to src/stratigraphy/depthcolumn/boundarydepthcolumnvalidator.py index c1178d03..477007fa 100644 --- a/src/stratigraphy/util/boundarydepthcolumnvalidator.py +++ b/src/stratigraphy/depthcolumn/boundarydepthcolumnvalidator.py @@ -2,9 +2,9 @@ import dataclasses -from stratigraphy.util.depthcolumn import BoundaryDepthColumn -from stratigraphy.util.depthcolumnentry import DepthColumnEntry -from stratigraphy.util.line import TextWord +from stratigraphy.depthcolumn.depthcolumn import BoundaryDepthColumn +from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry +from stratigraphy.lines.line import TextWord @dataclasses.dataclass diff --git a/src/stratigraphy/util/depthcolumn.py b/src/stratigraphy/depthcolumn/depthcolumn.py similarity index 98% rename from src/stratigraphy/util/depthcolumn.py rename to src/stratigraphy/depthcolumn/depthcolumn.py index 4c636fd4..c9169ed2 100644 --- a/src/stratigraphy/util/depthcolumn.py +++ 
b/src/stratigraphy/depthcolumn/depthcolumn.py @@ -6,12 +6,11 @@ import fitz import numpy as np - +from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry, LayerDepthColumnEntry +from stratigraphy.lines.line import TextLine, TextWord +from stratigraphy.text.find_description import get_description_blocks from stratigraphy.util.dataclasses import Line -from stratigraphy.util.depthcolumnentry import DepthColumnEntry, LayerDepthColumnEntry -from stratigraphy.util.find_description import get_description_blocks from stratigraphy.util.interval import BoundaryInterval, Interval, LayerInterval -from stratigraphy.util.line import TextLine, TextWord class DepthColumn(metaclass=abc.ABCMeta): diff --git a/src/stratigraphy/util/depthcolumnentry.py b/src/stratigraphy/depthcolumn/depthcolumnentry.py similarity index 100% rename from src/stratigraphy/util/depthcolumnentry.py rename to src/stratigraphy/depthcolumn/depthcolumnentry.py diff --git a/src/stratigraphy/util/find_depth_columns.py b/src/stratigraphy/depthcolumn/find_depth_columns.py similarity index 96% rename from src/stratigraphy/util/find_depth_columns.py rename to src/stratigraphy/depthcolumn/find_depth_columns.py index 5a73c947..d7aa8321 100644 --- a/src/stratigraphy/util/find_depth_columns.py +++ b/src/stratigraphy/depthcolumn/find_depth_columns.py @@ -3,12 +3,11 @@ import re import fitz - -from stratigraphy.util.boundarydepthcolumnvalidator import BoundaryDepthColumnValidator -from stratigraphy.util.depthcolumn import BoundaryDepthColumn, LayerDepthColumn -from stratigraphy.util.depthcolumnentry import DepthColumnEntry, LayerDepthColumnEntry -from stratigraphy.util.line import TextWord -from stratigraphy.util.textblock import TextBlock +from stratigraphy.depthcolumn.boundarydepthcolumnvalidator import BoundaryDepthColumnValidator +from stratigraphy.depthcolumn.depthcolumn import BoundaryDepthColumn, LayerDepthColumn +from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry, 
LayerDepthColumnEntry +from stratigraphy.lines.line import TextWord +from stratigraphy.text.textblock import TextBlock def depth_column_entries(all_words: list[TextWord], include_splits: bool) -> list[DepthColumnEntry]: diff --git a/src/stratigraphy/evaluation/evaluation_dataclasses.py b/src/stratigraphy/evaluation/evaluation_dataclasses.py new file mode 100644 index 00000000..dec4a544 --- /dev/null +++ b/src/stratigraphy/evaluation/evaluation_dataclasses.py @@ -0,0 +1,139 @@ +"""Evaluation utilities.""" + +import abc +from dataclasses import dataclass + +import pandas as pd + + +@dataclass +class Metrics(metaclass=abc.ABCMeta): + """Metrics for the evaluation of extracted features (e.g., Groundwater, Elevation, Coordinates).""" + + tp: int + fp: int + fn: int + + @property + def precision(self) -> float: + """Calculates the precision. + + Returns: + float: The precision. + """ + return self.tp / (self.tp + self.fp) if self.tp + self.fp > 0 else 0 + + @property + def recall(self) -> float: + """Calculates the recall. + + Returns: + float: The recall. + """ + return self.tp / (self.tp + self.fn) if self.tp + self.fn > 0 else 0 + + @property + def f1(self) -> float: + """Calculates the F1 score. + + Returns: + float: The F1 score. + """ + precision = self.precision + recall = self.recall + return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0 + + def to_json(self, feature_name) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + return { + f"{feature_name}_precision": self.precision, + f"{feature_name}_recall": self.recall, + f"{feature_name}_f1": self.f1, + } + + # TODO: Currently, some other methods for averaging metrics are in the DatasetMetrics class. + # In the long run, we should refactor this to have a single place where these averaging computations are + # implemented.
+ @staticmethod + def micro_average(metric_list: list["Metrics"]) -> "Metrics": + """Converts a list of metrics to a metric. + + Args: + metric_list (list): The list of metrics. + + Returns: + Metrics: Combined metrics. + """ + tp = sum([metric.tp for metric in metric_list]) + fp = sum([metric.fp for metric in metric_list]) + fn = sum([metric.fn for metric in metric_list]) + return Metrics(tp=tp, fp=fp, fn=fn) + + +@dataclass +class BoreholeMetadataMetrics(metaclass=abc.ABCMeta): + """Metrics for metadata.""" + + elevation_metrics: Metrics + coordinates_metrics: Metrics + + def to_json(self) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + return { + **self.elevation_metrics.to_json("elevation"), + **self.coordinates_metrics.to_json("coordinate"), + } + + +@dataclass +class FileBoreholeMetadataMetrics(BoreholeMetadataMetrics): + """Single file Metrics for borehole metadata.""" + + filename: str + + def get_document_level_metrics(self) -> pd.DataFrame: + """Get the document level metrics.""" + return pd.DataFrame( + data={ + "elevation": [self.elevation_metrics.f1], + "coordinate": [self.coordinates_metrics.f1], + }, + index=[self.filename], + ) + + +@dataclass +class OverallBoreholeMetadataMetrics(metaclass=abc.ABCMeta): + """Metrics for borehole metadata.""" + + borehole_metadata_metrics: list[FileBoreholeMetadataMetrics] = None + + def __init__(self): + """Initializes the OverallBoreholeMetadataMetrics object.""" + self.borehole_metadata_metrics = [] + + def get_cumulated_metrics(self) -> dict: + """Evaluate the metadata metrics.""" + elevation_metrics = Metrics.micro_average( + [metadata.elevation_metrics for metadata in self.borehole_metadata_metrics] + ) + coordinates_metrics = Metrics.micro_average( + [metadata.coordinates_metrics for metadata in self.borehole_metadata_metrics] + ) + return BoreholeMetadataMetrics( + elevation_metrics=elevation_metrics, coordinates_metrics=coordinates_metrics + 
).to_json() + + def get_document_level_metrics(self) -> pd.DataFrame: + # Get a dataframe per document, concatenate, and sort by index (document name) + return pd.concat( + [metadata.get_document_level_metrics() for metadata in self.borehole_metadata_metrics] + ).sort_index() diff --git a/src/stratigraphy/evaluation/metadata_evaluator.py b/src/stratigraphy/evaluation/metadata_evaluator.py new file mode 100644 index 00000000..1c922d8b --- /dev/null +++ b/src/stratigraphy/evaluation/metadata_evaluator.py @@ -0,0 +1,121 @@ +"""Classes for evaluating the metadata of a borehole.""" + +import math +from typing import Any + +from stratigraphy.benchmark.ground_truth import GroundTruth +from stratigraphy.evaluation.evaluation_dataclasses import ( + FileBoreholeMetadataMetrics, + Metrics, + OverallBoreholeMetadataMetrics, +) +from stratigraphy.metadata.metadata import BoreholeMetadataList + + +class MetadataEvaluator: + """Class for evaluating the metadata of a borehole.""" + + metadata_list: BoreholeMetadataList = None + ground_truth: dict[str, Any] = None + + def __init__(self, metadata_list: BoreholeMetadataList, ground_truth_path: str): + """Initializes the MetadataEvaluator object. + + Args: + metadata_list (BoreholeMetadataList): The metadata to evaluate. + ground_truth_path (str): The path to the ground truth file. + """ + self.metadata_list = metadata_list + + # Load the ground truth data for the metadata + self.metadata_ground_truth = GroundTruth(ground_truth_path) + + def evaluate(self) -> OverallBoreholeMetadataMetrics: + """Evaluate the metadata of the file against the ground truth. + + Returns: + OverallBoreholeMetadataMetrics: The evaluation metrics for each file in the metadata list.
+ """ + # Initialize the metadata correctness metrics + metadata_metrics_list = OverallBoreholeMetadataMetrics() + + for metadata in self.metadata_list.metadata_per_file: + ########################################################################################################### + ### Compute the metadata correctness for the coordinates. + ########################################################################################################### + extracted_coordinates = metadata.coordinates + ground_truth_coordinates = ( + self.metadata_ground_truth.for_file(metadata.filename.name).get("metadata", {}).get("coordinates") + ) + + if extracted_coordinates and ground_truth_coordinates: + if extracted_coordinates.east.coordinate_value > 2e6 and ground_truth_coordinates["E"] < 2e6: + ground_truth_east = int(ground_truth_coordinates["E"]) + 2e6 + ground_truth_north = int(ground_truth_coordinates["N"]) + 1e6 + elif extracted_coordinates.east.coordinate_value < 2e6 and ground_truth_coordinates["E"] > 2e6: + ground_truth_east = int(ground_truth_coordinates["E"]) - 2e6 + ground_truth_north = int(ground_truth_coordinates["N"]) - 1e6 + else: + ground_truth_east = int(ground_truth_coordinates["E"]) + ground_truth_north = int(ground_truth_coordinates["N"]) + + if (math.isclose(int(extracted_coordinates.east.coordinate_value), ground_truth_east, abs_tol=2)) and ( + math.isclose(int(extracted_coordinates.north.coordinate_value), ground_truth_north, abs_tol=2) + ): + coordinate_metrics = Metrics( + tp=1, + fp=0, + fn=0, + ) + else: + coordinate_metrics = Metrics( + tp=0, + fp=1, + fn=1, + ) + else: + coordinate_metrics = Metrics( + tp=0, + fp=1 if extracted_coordinates is not None else 0, + fn=1 if ground_truth_coordinates is not None else 0, + ) + + ############################################################################################################ + ### Compute the metadata correctness for the elevation. 
+ ############################################################################################################ + extracted_elevation = None if metadata.elevation is None else metadata.elevation.elevation + ground_truth_elevation = ( + self.metadata_ground_truth.for_file(metadata.filename.name) + .get("metadata", {}) + .get("reference_elevation") + ) + + if extracted_elevation is not None and ground_truth_elevation is not None: + if math.isclose(extracted_elevation, ground_truth_elevation, abs_tol=0.1): + elevation_metrics = Metrics( + tp=1, + fp=0, + fn=0, + ) + else: + elevation_metrics = Metrics( + tp=0, + fp=1, + fn=1, + ) + else: + elevation_metrics = Metrics( + tp=0, + fp=1 if extracted_elevation is not None else 0, + fn=1 if ground_truth_elevation is not None else 0, + ) + + metadata_metrics_list.borehole_metadata_metrics.append( + FileBoreholeMetadataMetrics( + elevation_metrics=elevation_metrics, + coordinates_metrics=coordinate_metrics, + filename=metadata.filename.name, + ) + ) + + return metadata_metrics_list diff --git a/src/stratigraphy/extract.py b/src/stratigraphy/extract.py index 0aaa0ad3..372beeba 100644 --- a/src/stratigraphy/extract.py +++ b/src/stratigraphy/extract.py @@ -5,23 +5,22 @@ import fitz -from stratigraphy.util import find_depth_columns -from stratigraphy.util.dataclasses import Line -from stratigraphy.util.depthcolumn import DepthColumn -from stratigraphy.util.find_depth_columns import get_depth_interval_from_textblock -from stratigraphy.util.find_description import ( +from stratigraphy.depthcolumn import find_depth_columns +from stratigraphy.depthcolumn.depthcolumn import DepthColumn +from stratigraphy.layer.layer_identifier_column import ( + LayerIdentifierColumn, + find_layer_identifier_column, + find_layer_identifier_column_entries, +) +from stratigraphy.lines.line import TextLine, TextWord +from stratigraphy.text.find_description import ( get_description_blocks, get_description_blocks_from_layer_identifier, get_description_lines, 
) +from stratigraphy.text.textblock import TextBlock, block_distance +from stratigraphy.util.dataclasses import Line from stratigraphy.util.interval import BoundaryInterval, Interval -from stratigraphy.util.layer_identifier_column import ( - LayerIdentifierColumn, - find_layer_identifier_column, - find_layer_identifier_column_entries, -) -from stratigraphy.util.line import TextLine, TextWord -from stratigraphy.util.textblock import TextBlock, block_distance from stratigraphy.util.util import ( remove_empty_predictions, x_overlap, @@ -237,7 +236,7 @@ def match_columns( blocks = get_description_blocks_from_layer_identifier(depth_column.entries, description_lines) groups = [] for block in blocks: - depth_interval = get_depth_interval_from_textblock(block) + depth_interval = find_depth_columns.get_depth_interval_from_textblock(block) if depth_interval: groups.append({"depth_interval": depth_interval, "block": block}) else: diff --git a/src/stratigraphy/groundwater/groundwater_extraction.py b/src/stratigraphy/groundwater/groundwater_extraction.py index 4e7006e9..9b533878 100644 --- a/src/stratigraphy/groundwater/groundwater_extraction.py +++ b/src/stratigraphy/groundwater/groundwater_extraction.py @@ -9,10 +9,10 @@ import fitz import numpy as np from stratigraphy.data_extractor.data_extractor import DataExtractor, ExtractedFeature -from stratigraphy.elevation.elevation_extraction import ElevationInformation from stratigraphy.groundwater.utility import extract_date, extract_depth, extract_elevation -from stratigraphy.util.extract_text import extract_text_lines -from stratigraphy.util.line import TextLine +from stratigraphy.lines.line import TextLine +from stratigraphy.metadata.elevation_extraction import Elevation +from stratigraphy.text.extract_text import extract_text_lines logger = logging.getLogger(__name__) @@ -96,7 +96,7 @@ def is_valid(self) -> bool: """ return self.groundwater > 0 - def to_dict(self) -> dict: + def to_json(self) -> dict: """Converts the object to 
a dictionary. Returns: @@ -269,9 +269,7 @@ def get_groundwater_info_from_lines(self, lines: list[TextLine], page: int) -> G else: raise ValueError("Could not extract all required information from the lines provided.") - def extract_groundwater( - self, terrain_elevation: ElevationInformation | None - ) -> list[GroundwaterInformationOnPage]: + def extract_groundwater(self, terrain_elevation: Elevation | None) -> list[GroundwaterInformationOnPage]: """Extracts the groundwater information from a borehole profile. Processes the borehole profile page by page and tries to find the coordinates in the respective text of the diff --git a/src/stratigraphy/util/duplicate_detection.py b/src/stratigraphy/layer/duplicate_detection.py similarity index 99% rename from src/stratigraphy/util/duplicate_detection.py rename to src/stratigraphy/layer/duplicate_detection.py index 2cc5f900..9260895c 100644 --- a/src/stratigraphy/util/duplicate_detection.py +++ b/src/stratigraphy/layer/duplicate_detection.py @@ -6,8 +6,7 @@ import fitz import Levenshtein import numpy as np - -from stratigraphy.util.plot_utils import convert_page_to_opencv_img +from stratigraphy.annotations.plot_utils import convert_page_to_opencv_img logger = logging.getLogger(__name__) diff --git a/src/stratigraphy/layer/layer.py b/src/stratigraphy/layer/layer.py new file mode 100644 index 00000000..9a5d3d99 --- /dev/null +++ b/src/stratigraphy/layer/layer.py @@ -0,0 +1,28 @@ +"""Layer class definition.""" + +import uuid +from dataclasses import dataclass, field + +from stratigraphy.text.textblock import MaterialDescription, TextBlock +from stratigraphy.util.interval import AnnotatedInterval, BoundaryInterval + + +@dataclass +class LayerPrediction: + """A class to represent predictions for a single layer.""" + + material_description: TextBlock | MaterialDescription + depth_interval: BoundaryInterval | AnnotatedInterval | None + material_is_correct: bool = None + depth_interval_is_correct: bool = None + id: uuid.UUID = 
field(default_factory=uuid.uuid4) + + def __str__(self) -> str: + """Converts the object to a string. + + Returns: + str: The object as a string. + """ + return ( + f"LayerPrediction(material_description={self.material_description}, depth_interval={self.depth_interval})" + ) diff --git a/src/stratigraphy/util/layer_identifier_column.py b/src/stratigraphy/layer/layer_identifier_column.py similarity index 97% rename from src/stratigraphy/util/layer_identifier_column.py rename to src/stratigraphy/layer/layer_identifier_column.py index 5dc4ecc3..cc5d030d 100644 --- a/src/stratigraphy/util/layer_identifier_column.py +++ b/src/stratigraphy/layer/layer_identifier_column.py @@ -3,8 +3,7 @@ import re import fitz - -from stratigraphy.util.line import TextLine +from stratigraphy.lines.line import TextLine class LayerIdentifierEntry: @@ -21,6 +20,11 @@ def __repr__(self): return str(self.text) def to_json(self): + """Convert the layer identifier entry to a JSON serializable format. + + Returns: + dict: The JSON serializable format of the layer identifier entry. 
+ """ return { "text": self.text, "rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1], diff --git a/src/stratigraphy/util/geometric_line_utilities.py b/src/stratigraphy/lines/geometric_line_utilities.py similarity index 99% rename from src/stratigraphy/util/geometric_line_utilities.py rename to src/stratigraphy/lines/geometric_line_utilities.py index 627a1598..b52aae47 100644 --- a/src/stratigraphy/util/geometric_line_utilities.py +++ b/src/stratigraphy/lines/geometric_line_utilities.py @@ -7,9 +7,8 @@ import numpy as np from numpy.typing import ArrayLike - +from stratigraphy.lines.linesquadtree import LinesQuadTree from stratigraphy.util.dataclasses import Line, Point -from stratigraphy.util.linesquadtree import LinesQuadTree logger = logging.getLogger(__name__) diff --git a/src/stratigraphy/util/line.py b/src/stratigraphy/lines/line.py similarity index 99% rename from src/stratigraphy/util/line.py rename to src/stratigraphy/lines/line.py index 05526413..8523828b 100644 --- a/src/stratigraphy/util/line.py +++ b/src/stratigraphy/lines/line.py @@ -3,7 +3,6 @@ from __future__ import annotations import fitz - from stratigraphy.util.util import read_params, x_overlap_significant_largest material_description = read_params("matching_params.yml")["material_description"] diff --git a/src/stratigraphy/line_detection.py b/src/stratigraphy/lines/line_detection.py similarity index 98% rename from src/stratigraphy/line_detection.py rename to src/stratigraphy/lines/line_detection.py index aa49cc06..0d392a05 100644 --- a/src/stratigraphy/line_detection.py +++ b/src/stratigraphy/lines/line_detection.py @@ -8,12 +8,11 @@ import numpy as np from dotenv import load_dotenv from numpy.typing import ArrayLike - -from stratigraphy.util.dataclasses import Line -from stratigraphy.util.geometric_line_utilities import ( +from stratigraphy.lines.geometric_line_utilities import ( drop_vertical_lines, merge_parallel_lines_quadtree, ) +from stratigraphy.util.dataclasses import Line 
from stratigraphy.util.util import line_from_array, read_params load_dotenv() diff --git a/src/stratigraphy/util/linesquadtree.py b/src/stratigraphy/lines/linesquadtree.py similarity index 99% rename from src/stratigraphy/util/linesquadtree.py rename to src/stratigraphy/lines/linesquadtree.py index 545b4220..3561beea 100644 --- a/src/stratigraphy/util/linesquadtree.py +++ b/src/stratigraphy/lines/linesquadtree.py @@ -3,7 +3,6 @@ import uuid import quads - from stratigraphy.util.dataclasses import Line, Point diff --git a/src/stratigraphy/main.py b/src/stratigraphy/main.py index fd35a9da..fe1870e8 100644 --- a/src/stratigraphy/main.py +++ b/src/stratigraphy/main.py @@ -11,105 +11,165 @@ from tqdm import tqdm from stratigraphy import DATAPATH +from stratigraphy.annotations.plot_utils import plot_lines from stratigraphy.benchmark.score import evaluate -from stratigraphy.coordinates.coordinate_extraction import CoordinateExtractor -from stratigraphy.elevation.elevation_extraction import ElevationExtractor from stratigraphy.extract import process_page from stratigraphy.groundwater.groundwater_extraction import GroundwaterLevelExtractor -from stratigraphy.line_detection import extract_lines, line_detection_params -from stratigraphy.util.duplicate_detection import remove_duplicate_layers -from stratigraphy.util.extract_text import extract_text_lines -from stratigraphy.util.language_detection import detect_language_of_document -from stratigraphy.util.plot_utils import plot_lines +from stratigraphy.layer.duplicate_detection import remove_duplicate_layers +from stratigraphy.lines.line_detection import extract_lines, line_detection_params +from stratigraphy.metadata.metadata import BoreholeMetadata, BoreholeMetadataList +from stratigraphy.text.extract_text import extract_text_lines from stratigraphy.util.util import flatten, read_params load_dotenv() mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True" # Checks whether MLFlow tracking is enabled +if mlflow_tracking: + 
import mlflow + logging.basicConfig(format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S") logger = logging.getLogger(__name__) matching_params = read_params("matching_params.yml") +def common_options(f): + """Decorator to add common options to both commands.""" + f = click.option( + "-i", + "--input-directory", + required=True, + type=click.Path(exists=True, path_type=Path), + help="Path to the input directory, or path to a single pdf file.", + )(f) + f = click.option( + "-g", + "--ground-truth-path", + type=click.Path(exists=True, path_type=Path), + help="Path to the ground truth file (optional).", + )(f) + f = click.option( + "-o", + "--out-directory", + type=click.Path(path_type=Path), + default=DATAPATH / "output", + help="Path to the output directory.", + )(f) + f = click.option( + "-p", + "--predictions-path", + type=click.Path(path_type=Path), + default=DATAPATH / "output" / "predictions.json", + help="Path to the predictions file.", + )(f) + f = click.option( + "-m", + "--metadata-path", + type=click.Path(path_type=Path), + default=DATAPATH / "output" / "metadata.json", + help="Path to the metadata file.", + )(f) + f = click.option( + "-s", + "--skip-draw-predictions", + is_flag=True, + default=False, + help="Whether to skip drawing the predictions on pdf pages. Defaults to False.", + )(f) + f = click.option( + "-l", + "--draw-lines", + is_flag=True, + default=False, + help="Whether to draw lines on pdf pages. 
Defaults to False.", + )(f) + return f + + @click.command() +@common_options @click.option( - "-i", - "--input-directory", - required=True, - type=click.Path(exists=True, path_type=Path), - help="Path to the input directory, or path to a single pdf file.", -) -@click.option( - "-g", - "--ground-truth-path", - type=click.Path(exists=True, path_type=Path), - help="Path to the ground truth file (optional).", -) -@click.option( - "-o", - "--out-directory", - type=click.Path(path_type=Path), - default=DATAPATH / "output", - help="Path to the output directory.", -) -@click.option( - "-p", - "--predictions-path", - type=click.Path(path_type=Path), - default=DATAPATH / "output" / "predictions.json", - help="Path to the predictions file.", -) -@click.option( - "-s", - "--skip-draw-predictions", - is_flag=True, - default=False, - help="Whether to skip drawing the predictions on pdf pages. Defaults to False.", -) -@click.option( - "-l", "--draw-lines", is_flag=True, default=False, help="Whether to draw lines on pdf pages. Defaults to False." + "-pa", "--part", type=click.Choice(["all", "metadata"]), default="all", help="The part of the pipeline to run." ) def click_pipeline( input_directory: Path, ground_truth_path: Path | None, out_directory: Path, predictions_path: Path, + metadata_path: Path, skip_draw_predictions: bool = False, draw_lines: bool = False, + part: str = "all", ): - """Run the boreholes data extraction pipeline. + """Run the boreholes data extraction pipeline.""" + start_pipeline( + input_directory=input_directory, + ground_truth_path=ground_truth_path, + out_directory=out_directory, + predictions_path=predictions_path, + metadata_path=metadata_path, + skip_draw_predictions=skip_draw_predictions, + draw_lines=draw_lines, + part=part, + ) - The pipeline will extract material description of all found layers and assign them to the corresponding - depth intervals. The input directory should contain pdf files with boreholes data. 
The algorithm can deal - with borehole profiles of multiple pages. - \f - Args: - input_directory (Path): The directory containing the pdf files. Can also be the path to a single pdf file. - ground_truth_path (Path | None): The path to the ground truth file json file. - out_directory (Path): The directory to store the evaluation results. - predictions_path (Path): The path to the predictions file. - skip_draw_predictions (bool, optional): Whether to skip drawing predictions on pdf pages. Defaults to False. - draw_lines (bool, optional): Whether to draw lines on pdf pages. Defaults to False. - """ # noqa: D301 +@click.command() +@common_options +def click_pipeline_metadata( + input_directory: Path, + ground_truth_path: Path | None, + out_directory: Path, + predictions_path: Path, + metadata_path: Path, + skip_draw_predictions: bool = False, + draw_lines: bool = False, +): + """Run only the metadata part of the pipeline.""" start_pipeline( input_directory=input_directory, ground_truth_path=ground_truth_path, out_directory=out_directory, predictions_path=predictions_path, + metadata_path=metadata_path, skip_draw_predictions=skip_draw_predictions, draw_lines=draw_lines, + part="metadata", ) +def setup_mlflow_tracking( + input_directory: Path, + ground_truth_path: Path, + out_directory: Path = None, + predictions_path: Path = None, + metadata_path: Path = None, + experiment_name: str = "Boreholes Stratigraphy", +): + """Set up MLFlow tracking.""" + mlflow.set_experiment(experiment_name) + mlflow.start_run() + mlflow.set_tag("input_directory", str(input_directory)) + mlflow.set_tag("ground_truth_path", str(ground_truth_path)) + if out_directory: + mlflow.set_tag("out_directory", str(out_directory)) + if predictions_path: + mlflow.set_tag("predictions_path", str(predictions_path)) + if metadata_path: + mlflow.set_tag("metadata_path", str(metadata_path)) + mlflow.log_params(flatten(line_detection_params)) + mlflow.log_params(flatten(matching_params)) + + def 
start_pipeline( input_directory: Path, ground_truth_path: Path, out_directory: Path, predictions_path: Path, + metadata_path: Path, skip_draw_predictions: bool = False, draw_lines: bool = False, + part: str = "all", ): """Run the boreholes data extraction pipeline. @@ -127,26 +187,22 @@ def start_pipeline( predictions_path (Path): The path to the predictions file. skip_draw_predictions (bool, optional): Whether to skip drawing predictions on pdf pages. Defaults to False. draw_lines (bool, optional): Whether to draw lines on pdf pages. Defaults to False. + metadata_path (Path): The path to the metadata file. + part (str, optional): The part of the pipeline to run. Defaults to "all". """ # noqa: D301 if mlflow_tracking: - import mlflow - - mlflow.set_experiment("Boreholes Stratigraphy") - mlflow.start_run() - mlflow.set_tag("input_directory", str(input_directory)) - mlflow.set_tag("ground_truth_path", str(ground_truth_path)) - mlflow.set_tag("out_directory", str(out_directory)) - mlflow.set_tag("predictions_path", str(predictions_path)) - mlflow.log_params(flatten(line_detection_params)) - mlflow.log_params(flatten(matching_params)) + setup_mlflow_tracking(input_directory, ground_truth_path, out_directory, predictions_path, metadata_path) temp_directory = DATAPATH / "_temp" # temporary directory to dump files for mlflow artifact logging - - # check if directories exist and create them when necessary - draw_directory = out_directory / "draw" - draw_directory.mkdir(parents=True, exist_ok=True) temp_directory.mkdir(parents=True, exist_ok=True) + if skip_draw_predictions: + draw_directory = None + else: + # check if directories exist and create them when necessary + draw_directory = out_directory / "draw" + draw_directory.mkdir(parents=True, exist_ok=True) + # if a file is specified instead of an input directory, copy the file to a temporary directory and work with that. 
if input_directory.is_file(): root = input_directory.parent @@ -154,95 +210,100 @@ def start_pipeline( else: root = input_directory _, _, files = next(os.walk(input_directory)) + # process the individual pdf files predictions = {} + + # process the individual pdf files + metadata_per_file = BoreholeMetadataList() + for filename in tqdm(files, desc="Processing files", unit="file"): if filename.endswith(".pdf"): in_path = os.path.join(root, filename) logger.info("Processing file: %s", in_path) - predictions[filename] = {} with fitz.Document(in_path) as doc: - language = detect_language_of_document( - doc, matching_params["default_language"], matching_params["material_description"].keys() - ) - predictions[filename]["language"] = language - - # Extract the coordinates of the borehole - coordinate_extractor = CoordinateExtractor(document=doc) - coordinates = coordinate_extractor.extract_coordinates() - if coordinates: - predictions[filename]["metadata"] = {"coordinates": coordinates.to_json()} - else: - predictions[filename]["metadata"] = {"coordinates": None} - - # Extract the elevation information - elevation_extractor = ElevationExtractor(document=doc) - elevation = elevation_extractor.extract_elevation() - if elevation: - predictions[filename]["metadata"]["elevation"] = elevation.to_dict() - else: - predictions[filename]["metadata"]["elevation"] = None - - # Extract the groundwater levels - groundwater_extractor = GroundwaterLevelExtractor(document=doc) - groundwater = groundwater_extractor.extract_groundwater(terrain_elevation=elevation) - if groundwater: - predictions[filename]["groundwater"] = [ - groundwater_entry.to_dict() for groundwater_entry in groundwater - ] - else: - predictions[filename]["groundwater"] = None - - layer_predictions_list = [] - depths_materials_column_pairs_list = [] - page_dimensions = [] - for page_index, page in enumerate(doc): - page_number = page_index + 1 - logger.info("Processing page %s", page_number) - - text_lines = 
extract_text_lines(page) - geometric_lines = extract_lines(page, line_detection_params) - layer_predictions, depths_materials_column_pairs = process_page( - text_lines, geometric_lines, language, page_number, **matching_params - ) - - # TODO: Add remove duplicates here! - if page_index > 0: - layer_predictions = remove_duplicate_layers( - doc[page_index - 1], - page, - layer_predictions_list, - layer_predictions, - matching_params["img_template_probability_threshold"], + # Extract metadata + metadata = BoreholeMetadata(doc) + + # Add metadata to the metadata list + metadata_per_file.metadata_per_file.append(metadata) + + if part == "all": + predictions[filename] = {} + + # Extract the groundwater levels + groundwater_extractor = GroundwaterLevelExtractor(document=doc) + groundwater = groundwater_extractor.extract_groundwater(terrain_elevation=metadata.elevation) + if groundwater: + predictions[filename]["groundwater"] = [ + groundwater_entry.to_json() for groundwater_entry in groundwater + ] + else: + predictions[filename]["groundwater"] = None + + layer_predictions_list = [] + depths_materials_column_pairs_list = [] + page_dimensions = [] + for page_index, page in enumerate(doc): + page_number = page_index + 1 + logger.info("Processing page %s", page_number) + + text_lines = extract_text_lines(page) + geometric_lines = extract_lines(page, line_detection_params) + layer_predictions, depths_materials_column_pairs = process_page( + text_lines, geometric_lines, metadata.language, page_number, **matching_params ) - layer_predictions_list.extend(layer_predictions) - depths_materials_column_pairs_list.extend(depths_materials_column_pairs) - page_dimensions.append({"height": page.rect.height, "width": page.rect.width}) - - if draw_lines: # could be changed to if draw_lines and mflow_tracking: - if not mlflow_tracking: - logger.warning("MLFlow tracking is not enabled. 
MLFLow is required to store the images.") - else: - img = plot_lines( - page, geometric_lines, scale_factor=line_detection_params["pdf_scale_factor"] + # TODO: Add remove duplicates here! + if page_index > 0: + layer_predictions = remove_duplicate_layers( + doc[page_index - 1], + page, + layer_predictions_list, + layer_predictions, + matching_params["img_template_probability_threshold"], ) - mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png") - - predictions[filename]["layers"] = layer_predictions_list - predictions[filename]["depths_materials_column_pairs"] = depths_materials_column_pairs_list - predictions[filename]["page_dimensions"] = page_dimensions - assert len(page_dimensions) == doc.page_count, "Page count mismatch." - - logger.info("Writing predictions to JSON file %s", predictions_path) - with open(predictions_path, "w", encoding="utf8") as file: - json.dump(predictions, file, ensure_ascii=False) - - if skip_draw_predictions: - draw_directory = None - evaluate(predictions, ground_truth_path, temp_directory, input_directory, draw_directory) + layer_predictions_list.extend(layer_predictions) + depths_materials_column_pairs_list.extend(depths_materials_column_pairs) + page_dimensions.append({"height": page.rect.height, "width": page.rect.width}) + + if draw_lines: # could be changed to if draw_lines and mflow_tracking: + if not mlflow_tracking: + logger.warning( + "MLFlow tracking is not enabled. MLFLow is required to store the images." 
+ ) + else: + img = plot_lines( + page, geometric_lines, scale_factor=line_detection_params["pdf_scale_factor"] + ) + mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png") + + if part == "all": + predictions[filename]["layers"] = layer_predictions_list + predictions[filename]["depths_materials_column_pairs"] = depths_materials_column_pairs_list + predictions[filename]["page_dimensions"] = ( + metadata.page_dimensions + ) # TODO: Remove this as it is already stored in the metadata + + logger.info("Metadata written to %s", metadata_path) + with open(metadata_path, "w", encoding="utf8") as file: + json.dump(metadata_per_file.to_json(), file, ensure_ascii=False) + + if part == "all": + logger.info("Writing predictions to JSON file %s", predictions_path) + with open(predictions_path, "w", encoding="utf8") as file: + json.dump(predictions, file, ensure_ascii=False) + + evaluate( + predictions=predictions, + metadata_per_file=metadata_per_file, + ground_truth_path=ground_truth_path, + temp_directory=temp_directory, + input_directory=input_directory, + draw_directory=draw_directory, + ) if __name__ == "__main__": diff --git a/src/stratigraphy/coordinates/coordinate_extraction.py b/src/stratigraphy/metadata/coordinate_extraction.py similarity index 99% rename from src/stratigraphy/coordinates/coordinate_extraction.py rename to src/stratigraphy/metadata/coordinate_extraction.py index f04085a0..f39e134e 100644 --- a/src/stratigraphy/coordinates/coordinate_extraction.py +++ b/src/stratigraphy/metadata/coordinate_extraction.py @@ -9,8 +9,8 @@ import fitz import regex from stratigraphy.data_extractor.data_extractor import DataExtractor, ExtractedFeature -from stratigraphy.util.extract_text import extract_text_lines_from_bbox -from stratigraphy.util.line import TextLine +from stratigraphy.lines.line import TextLine +from stratigraphy.text.extract_text import extract_text_lines_from_bbox logger = logging.getLogger(__name__) diff --git 
a/src/stratigraphy/elevation/elevation_extraction.py b/src/stratigraphy/metadata/elevation_extraction.py similarity index 86% rename from src/stratigraphy/elevation/elevation_extraction.py rename to src/stratigraphy/metadata/elevation_extraction.py index ffbb5868..f0f1821b 100644 --- a/src/stratigraphy/elevation/elevation_extraction.py +++ b/src/stratigraphy/metadata/elevation_extraction.py @@ -13,14 +13,14 @@ import numpy as np from stratigraphy.data_extractor.data_extractor import DataExtractor, ExtractedFeature from stratigraphy.groundwater.utility import extract_elevation -from stratigraphy.util.extract_text import extract_text_lines_from_bbox -from stratigraphy.util.line import TextLine +from stratigraphy.lines.line import TextLine +from stratigraphy.text.extract_text import extract_text_lines_from_bbox logger = logging.getLogger(__name__) @dataclass -class ElevationInformation(ExtractedFeature): +class Elevation(ExtractedFeature): """Abstract class for Elevation Information.""" elevation: float | None = None # Elevation relative to the mean sea level @@ -39,9 +39,9 @@ def __str__(self) -> str: Returns: str: The object as a string. """ - return f"ElevationInformation(" f"elevation={self.elevation}, " f"page={self.page})" + return f"Elevation(" f"elevation={self.elevation}, " f"page={self.page})" - def to_dict(self) -> dict: + def to_json(self) -> dict: """Converts the object to a dictionary. Returns: @@ -74,7 +74,7 @@ class ElevationExtractor(DataExtractor): preprocess_replacements = {",": ".", "'": ".", "o": "0", "\n": " ", "ate": "ote"} - def get_elevation_near_key(self, lines: list[TextLine], page: int) -> ElevationInformation | None: + def get_elevation_near_key(self, lines: list[TextLine], page: int) -> Elevation | None: """Find elevation from text lines that are close to an explicit "elevation" label. Also apply some preprocessing to the text of those text lines, to deal with some common (OCR) errors. 
@@ -85,7 +85,7 @@ def get_elevation_near_key(self, lines: list[TextLine], page: int) -> ElevationI page_width (float): the width of the current page (in points / PyMuPDF coordinates) Returns: - ElevationInformation | None: the found elevation + Elevation | None: the found elevation """ # find the key that indicates the elevation information elevation_key_lines = self.find_feature_key(lines) @@ -104,16 +104,14 @@ def get_elevation_near_key(self, lines: list[TextLine], page: int) -> ElevationI return self.select_best_elevation_information(extracted_elevation_informations) - def select_best_elevation_information( - self, extracted_elevation_informations: list[ElevationInformation] - ) -> ElevationInformation | None: + def select_best_elevation_information(self, extracted_elevation_informations: list[Elevation]) -> Elevation | None: """Select the best elevation information from a list of extracted elevation information. Args: - extracted_elevation_informations (list[ElevationInformation]): A list of extracted elevation information. + extracted_elevation_informations (list[Elevation]): A list of extracted elevation information. Returns: - ElevationInformation | None: The best extracted elevation information. + Elevation | None: The best extracted elevation information. """ # Sort the extracted elevation information by elevation with the highest elevation first extracted_elevation_informations.sort(key=lambda x: x.elevation, reverse=True) @@ -121,7 +119,7 @@ def select_best_elevation_information( # Return the first element of the sorted list return extracted_elevation_informations[0] if extracted_elevation_informations else None - def get_elevation_from_lines(self, lines: list[TextLine], page: int) -> ElevationInformation: + def get_elevation_from_lines(self, lines: list[TextLine], page: int) -> Elevation: r"""Matches the elevation in a string of text. 
Args: @@ -129,7 +127,7 @@ def get_elevation_from_lines(self, lines: list[TextLine], page: int) -> Elevatio page (int): the page number (1-based) of the PDF document Returns: - ElevationInformation: A list of potential elevation + Elevation: A list of potential elevation """ matched_lines_rect = [] @@ -161,13 +159,13 @@ def get_elevation_from_lines(self, lines: list[TextLine], page: int) -> Elevatio rect_union = None if elevation: - return ElevationInformation(elevation=elevation, rect=rect_union, page=page) + return Elevation(elevation=elevation, rect=rect_union, page=page) else: raise ValueError("Could not extract all required information from the lines provided.") def extract_elevation_from_bbox( self, pdf_page: fitz.Page, page_number: int, bbox: fitz.Rect | None = None - ) -> ElevationInformation | None: + ) -> Elevation | None: """Extract the elevation information from a bounding box. Args: @@ -176,7 +174,7 @@ def extract_elevation_from_bbox( page_number (int): The page number. Returns: - ElevationInformation | None: The extracted elevation information. + Elevation | None: The extracted elevation information. """ lines = extract_text_lines_from_bbox(pdf_page, bbox) @@ -188,7 +186,7 @@ def extract_elevation_from_bbox( logger.info("No elevation found in the bounding box.") - def extract_elevation(self) -> ElevationInformation | None: + def extract_elevation(self) -> Elevation | None: """Extracts the elevation information from a borehole profile. Processes the borehole profile page by page and tries to find the feature key in the respective text of the @@ -197,4 +195,5 @@ def extract_elevation(self) -> ElevationInformation | None: for page in self.doc: page_number = page.number + 1 # page.number is 0-based + # TODO: This returns the first found elevation, but we might want to check all pages. 
return self.extract_elevation_from_bbox(page, page_number) diff --git a/src/stratigraphy/util/language_detection.py b/src/stratigraphy/metadata/language_detection.py similarity index 100% rename from src/stratigraphy/util/language_detection.py rename to src/stratigraphy/metadata/language_detection.py diff --git a/src/stratigraphy/metadata/metadata.py b/src/stratigraphy/metadata/metadata.py new file mode 100644 index 00000000..b980f613 --- /dev/null +++ b/src/stratigraphy/metadata/metadata.py @@ -0,0 +1,130 @@ +"""Metadata for stratigraphy data.""" + +import abc +from dataclasses import dataclass +from pathlib import Path +from typing import NamedTuple + +import fitz +from stratigraphy.metadata.coordinate_extraction import Coordinate, CoordinateExtractor +from stratigraphy.metadata.elevation_extraction import Elevation, ElevationExtractor +from stratigraphy.metadata.language_detection import detect_language_of_document +from stratigraphy.util.util import read_params + + +class PageDimensions(NamedTuple): + """Class for page dimensions.""" + + width: float + height: float + + def to_json(self) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + return {"width": self.width, "height": self.height} + + +@dataclass +class BoreholeMetadata(metaclass=abc.ABCMeta): + """Metadata for stratigraphy data.""" + + elevation: Elevation | None = None + coordinates: Coordinate | None = None + language: str | None = None # TODO: Change to Enum for the supported languages + filename: Path = None + page_dimensions: list[PageDimensions] = None + + def __init__(self, document: fitz.Document): + """Initializes the BoreholeMetadata object. + + Args: + document (fitz.Document): A PDF document. 
+ """ + matching_params = read_params("matching_params.yml") + + # Detect the language of the document + self.language = detect_language_of_document( + document, matching_params["default_language"], matching_params["material_description"].keys() + ) + + # Extract the coordinates of the borehole + coordinate_extractor = CoordinateExtractor(document=document) + self.coordinates = coordinate_extractor.extract_coordinates() + + # Extract the elevation information + elevation_extractor = ElevationExtractor(document=document) + self.elevation = elevation_extractor.extract_elevation() + + # Get the name of the document + self.filename = Path(document.name) + + # Get the dimensions of the document's pages + self.page_dimensions = [] + for page in document: + self.page_dimensions.append(PageDimensions(width=page.rect.width, height=page.rect.height)) + + # Sanity check + assert len(self.page_dimensions) == document.page_count, "Page count mismatch." + + def to_json(self) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + return { + "elevation": self.elevation.to_json() if self.elevation else None, + "coordinates": self.coordinates.to_json() if self.coordinates else None, + "language": self.language, + "page_dimensions": [page_dimensions.to_json() for page_dimensions in self.page_dimensions], + } + + def __str__(self) -> str: + """Converts the object to a string. + + Returns: + str: The object as a string. 
+ """ + return ( + f"StratigraphyMetadata(" + f"elevation={self.elevation}, " + f"coordinates={self.coordinates} " + f"language={self.language}, " + f"page_dimensions={self.page_dimensions})" + ) + + +@dataclass +class BoreholeMetadataList(metaclass=abc.ABCMeta): + """List of borehole metadata, one entry per file.""" + + metadata_per_file: list[BoreholeMetadata] = None + + def __init__(self): + """Initializes the BoreholeMetadataList object.""" + self.metadata_per_file = [] + + def get_metadata(self, filename: str) -> BoreholeMetadata: + """Get the metadata for a specific file. + + Args: + filename (str): The name of the file. + + Returns: + BoreholeMetadata: The metadata for the file, or None if the file is not in the list. + """ + for metadata in self.metadata_per_file: + if metadata.filename.name == filename: + return metadata + return None + + def to_json(self) -> dict: + """Converts the object to a dictionary. + + Returns: + dict: The object as a dictionary. + """ + return {metadata.filename.name: metadata.to_json() for metadata in self.metadata_per_file} diff --git a/src/stratigraphy/util/description_block_splitter.py b/src/stratigraphy/text/description_block_splitter.py similarity index 98% rename from src/stratigraphy/util/description_block_splitter.py rename to src/stratigraphy/text/description_block_splitter.py index 16cf4586..0f712e11 100644 --- a/src/stratigraphy/util/description_block_splitter.py +++ b/src/stratigraphy/text/description_block_splitter.py @@ -4,10 +4,9 @@ import fitz import numpy as np - +from stratigraphy.lines.line import TextLine +from stratigraphy.text.textblock import TextBlock from stratigraphy.util.dataclasses import Line -from stratigraphy.util.line import TextLine -from stratigraphy.util.textblock import TextBlock class DescriptionBlockSplitter(metaclass=abc.ABCMeta): diff --git a/src/stratigraphy/util/extract_text.py b/src/stratigraphy/text/extract_text.py similarity index 97% rename from src/stratigraphy/util/extract_text.py rename to src/stratigraphy/text/extract_text.py index 
abea091c..bb58d41b 100644 --- a/src/stratigraphy/util/extract_text.py +++ b/src/stratigraphy/text/extract_text.py @@ -1,8 +1,7 @@ """Methods for extracting plain text from a PDF document.""" import fitz - -from stratigraphy.util.line import TextLine, TextWord +from stratigraphy.lines.line import TextLine, TextWord def extract_text_lines(page: fitz.Page) -> list[TextLine]: diff --git a/src/stratigraphy/util/find_description.py b/src/stratigraphy/text/find_description.py similarity index 96% rename from src/stratigraphy/util/find_description.py rename to src/stratigraphy/text/find_description.py index ce0f660d..48902ab4 100644 --- a/src/stratigraphy/util/find_description.py +++ b/src/stratigraphy/text/find_description.py @@ -1,16 +1,15 @@ """This module contains functions to find the description (blocks) of a material in a pdf page.""" import fitz - -from stratigraphy.util.dataclasses import Line -from stratigraphy.util.description_block_splitter import ( +from stratigraphy.layer.layer_identifier_column import LayerIdentifierEntry +from stratigraphy.lines.line import TextLine +from stratigraphy.text.description_block_splitter import ( SplitDescriptionBlockByLeftHandSideSeparator, SplitDescriptionBlockByLine, SplitDescriptionBlockByVerticalSpace, ) -from stratigraphy.util.layer_identifier_column import LayerIdentifierEntry -from stratigraphy.util.line import TextLine -from stratigraphy.util.textblock import TextBlock +from stratigraphy.text.textblock import TextBlock +from stratigraphy.util.dataclasses import Line def get_description_lines(lines: list[TextLine], material_description_rect: fitz.Rect) -> list[TextLine]: diff --git a/src/stratigraphy/util/textblock.py b/src/stratigraphy/text/textblock.py similarity index 99% rename from src/stratigraphy/util/textblock.py rename to src/stratigraphy/text/textblock.py index 98b82b62..4aa00eb7 100644 --- a/src/stratigraphy/util/textblock.py +++ b/src/stratigraphy/text/textblock.py @@ -7,8 +7,7 @@ import fitz import numpy as 
np - -from stratigraphy.util.line import TextLine +from stratigraphy.lines.line import TextLine @dataclass diff --git a/src/stratigraphy/util/dataclasses.py b/src/stratigraphy/util/dataclasses.py index 7b5e1b47..c9a777ec 100644 --- a/src/stratigraphy/util/dataclasses.py +++ b/src/stratigraphy/util/dataclasses.py @@ -38,11 +38,17 @@ def __post_init__(self): self.start = self.end self.end = end - self.slope = self.slope() - self.intercept = self.intercept() self.length = self.start.distance_to(self.end) def distance_to(self, point: Point) -> float: + """Calculate the distance of a point to the line. + + Args: + point (Point): The point to calculate the distance to. + + Returns: + float: The distance of the point to the line. + """ # Calculate the distance of the point to the line: # Taken from https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line#Line_defined_by_two_points return np.abs( @@ -50,8 +56,12 @@ def distance_to(self, point: Point) -> float: - (self.start.x - point.x) * (self.end.y - self.start.y) ) / np.sqrt((self.end.x - self.start.x) ** 2 + (self.end.y - self.start.y) ** 2) + @property def slope(self) -> float: + """Calculate the slope of the line.""" return (self.end.y - self.start.y) / (self.end.x - self.start.x) if self.end.x - self.start.x != 0 else np.inf + @property def intercept(self) -> float: + """Calculate the y-intercept of the line.""" return self.start.y - self.slope * self.start.x diff --git a/src/stratigraphy/util/interval.py b/src/stratigraphy/util/interval.py index abc33230..793193c5 100644 --- a/src/stratigraphy/util/interval.py +++ b/src/stratigraphy/util/interval.py @@ -6,9 +6,13 @@ import fitz -from stratigraphy.util.depthcolumnentry import AnnotatedDepthColumnEntry, DepthColumnEntry, LayerDepthColumnEntry -from stratigraphy.util.line import TextLine -from stratigraphy.util.textblock import TextBlock +from stratigraphy.depthcolumn.depthcolumnentry import ( + AnnotatedDepthColumnEntry, + DepthColumnEntry, + 
LayerDepthColumnEntry, +) +from stratigraphy.lines.line import TextLine +from stratigraphy.text.textblock import TextBlock class Interval(metaclass=abc.ABCMeta): diff --git a/src/stratigraphy/util/predictions.py b/src/stratigraphy/util/predictions.py index ad5dea0a..005a65da 100644 --- a/src/stratigraphy/util/predictions.py +++ b/src/stratigraphy/util/predictions.py @@ -1,46 +1,24 @@ """This module contains classes for predictions.""" import logging -import math -import uuid from collections import Counter -from dataclasses import dataclass, field import fitz import Levenshtein -from stratigraphy.benchmark.metrics import Metrics -from stratigraphy.coordinates.coordinate_extraction import Coordinate -from stratigraphy.elevation.elevation_extraction import ElevationInformation +from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry +from stratigraphy.evaluation.evaluation_dataclasses import Metrics from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformation, GroundwaterInformationOnPage -from stratigraphy.util.depthcolumnentry import DepthColumnEntry -from stratigraphy.util.interval import AnnotatedInterval, BoundaryInterval -from stratigraphy.util.line import TextLine, TextWord -from stratigraphy.util.textblock import MaterialDescription, TextBlock +from stratigraphy.layer.layer import LayerPrediction +from stratigraphy.lines.line import TextLine, TextWord +from stratigraphy.metadata.metadata import BoreholeMetadata +from stratigraphy.text.textblock import TextBlock +from stratigraphy.util.interval import BoundaryInterval from stratigraphy.util.util import parse_text logger = logging.getLogger(__name__) -@dataclass -class BoreholeMetaData: - """Class to represent metadata of a borehole profile.""" - - coordinates: Coordinate | None - elevation: ElevationInformation | None - - -@dataclass -class LayerPrediction: - """A class to represent predictions for a single layer.""" - - material_description: TextBlock | MaterialDescription 
- depth_interval: BoundaryInterval | AnnotatedInterval | None - material_is_correct: bool = None - depth_interval_is_correct: bool = None - id: uuid.UUID = field(default_factory=uuid.uuid4) - - class FilePredictions: """A class to represent predictions for a single file.""" @@ -49,7 +27,7 @@ def __init__( layers: list[LayerPrediction], file_name: str, language: str, - metadata: BoreholeMetaData, + metadata: BoreholeMetadata, groundwater_entries: list[GroundwaterInformationOnPage], depths_materials_columns_pairs: list[dict], page_sizes: list[dict[str, float]], @@ -59,36 +37,23 @@ def __init__( self.file_name = file_name self.language = language self.metadata = metadata - self.metadata_is_correct: dict = {} self.page_sizes: list[dict[str, float]] = page_sizes self.groundwater_entries = groundwater_entries self.groundwater_is_correct: dict = {} @staticmethod - def create_from_json(predictions_for_file: dict, file_name: str): + def create_from_json(predictions_for_file: dict, metadata: BoreholeMetadata, file_name: str): """Create predictions class for a file given the predictions.json file. Args: predictions_for_file (dict): The predictions for the file in json format. + metadata (BoreholeMetadata): The metadata for the file. file_name (str): The name of the file. """ page_layer_predictions_list: list[LayerPrediction] = [] pages_dimensions_list: list[dict[str, float]] = [] depths_materials_columns_pairs_list: list[dict] = [] - file_language = predictions_for_file["language"] - - # Extract metadata. 
- metadata = predictions_for_file["metadata"] - coordinates = None - elevation = None - if "coordinates" in metadata and metadata["coordinates"] is not None: - coordinates = Coordinate.from_json(metadata["coordinates"]) - if "elevation" in metadata and metadata["elevation"] is not None: - elevation = ElevationInformation(**metadata["elevation"]) if metadata["elevation"] is not None else None - file_metadata = BoreholeMetaData(coordinates=coordinates, elevation=elevation) - # TODO: Add additional metadata here. - # Extract groundwater information if available. if "groundwater" in predictions_for_file and predictions_for_file["groundwater"] is not None: groundwater_entries = [ @@ -137,8 +102,8 @@ def create_from_json(predictions_for_file: dict, file_name: str): return FilePredictions( layers=page_layer_predictions_list, file_name=file_name, - language=file_language, - metadata=file_metadata, + language=metadata.language, + metadata=metadata, depths_materials_columns_pairs=depths_materials_columns_pairs_list, page_sizes=pages_dimensions_list, groundwater_entries=groundwater_entries, @@ -181,7 +146,6 @@ def evaluate(self, ground_truth: dict): ground_truth (dict): The ground truth for the file. """ self.evaluate_layers(ground_truth["layers"]) - self.evaluate_metadata(ground_truth.get("metadata", {})) groundwater_ground_truth = ground_truth.get("groundwater", []) if groundwater_ground_truth is None: groundwater_ground_truth = [] @@ -203,60 +167,6 @@ def evaluate_layers(self, ground_truth_layers: list): layer.material_is_correct = False layer.depth_interval_is_correct = None - def evaluate_metadata(self, metadata_ground_truth: dict): - """Evaluate the metadata of the file against the ground truth. - - Note: For now coordinates is the only metadata extracted and evaluated for. - - Args: - metadata_ground_truth (dict): The ground truth for the file. 
- """ - ############################################################################################################ - ### Compute the metadata correctness for the coordinates. - ############################################################################################################ - extracted_coordinates = self.metadata.coordinates - ground_truth_coordinates = metadata_ground_truth.get("coordinates") - - if extracted_coordinates is not None and ground_truth_coordinates is not None: - if extracted_coordinates.east.coordinate_value > 2e6 and ground_truth_coordinates["E"] < 2e6: - ground_truth_east = int(ground_truth_coordinates["E"]) + 2e6 - ground_truth_north = int(ground_truth_coordinates["N"]) + 1e6 - elif extracted_coordinates.east.coordinate_value < 2e6 and ground_truth_coordinates["E"] > 2e6: - ground_truth_east = int(ground_truth_coordinates["E"]) - 2e6 - ground_truth_north = int(ground_truth_coordinates["N"]) - 1e6 - else: - ground_truth_east = int(ground_truth_coordinates["E"]) - ground_truth_north = int(ground_truth_coordinates["N"]) - - if (math.isclose(int(extracted_coordinates.east.coordinate_value), ground_truth_east, abs_tol=2)) and ( - math.isclose(int(extracted_coordinates.north.coordinate_value), ground_truth_north, abs_tol=2) - ): - self.metadata_is_correct["coordinates"] = Metrics(tp=1, fp=0, fn=0) - else: - self.metadata_is_correct["coordinates"] = Metrics(tp=0, fp=1, fn=1) - else: - self.metadata_is_correct["coordinates"] = Metrics( - tp=0, - fp=1 if extracted_coordinates is not None else 0, - fn=1 if ground_truth_coordinates is not None else 0, - ) - - ############################################################################################################ - ### Compute the metadata correctness for the elevation. 
- ############################################################################################################ - extracted_elevation = None if self.metadata.elevation is None else self.metadata.elevation.elevation - ground_truth_elevation = metadata_ground_truth.get("reference_elevation") - - if extracted_elevation is not None and ground_truth_elevation is not None: - if math.isclose(extracted_elevation, ground_truth_elevation, abs_tol=0.1): - self.metadata_is_correct["elevation"] = Metrics(tp=1, fp=0, fn=0) - else: - self.metadata_is_correct["elevation"] = Metrics(tp=0, fp=1, fn=1) - else: - self.metadata_is_correct["elevation"] = Metrics( - tp=0, fp=1 if extracted_elevation is not None else 0, fn=1 if ground_truth_elevation is not None else 0 - ) - @staticmethod def count_against_ground_truth(values: list, ground_truth: list) -> Metrics: """Count the number of true positives, false positives and false negatives. @@ -266,7 +176,7 @@ def count_against_ground_truth(values: list, ground_truth: list) -> Metrics: ground_truth (list): The ground truth values. Returns: - dict: The number of true positives, false positives and false negatives. + Metrics: The metrics for the values. """ # Counter deals with duplicates when doing intersection values_counter = Counter(values) diff --git a/src/stratigraphy/util/util.py b/src/stratigraphy/util/util.py index 6c5d8bcb..f591b1ed 100644 --- a/src/stratigraphy/util/util.py +++ b/src/stratigraphy/util/util.py @@ -12,6 +12,15 @@ def x_overlap(rect1: fitz.Rect, rect2: fitz.Rect) -> float: # noqa: D103 + """Calculate the x overlap between two rectangles. + + Args: + rect1 (fitz.Rect): First rectangle. + rect2 (fitz.Rect): Second rectangle. + + Returns: + float: The x overlap between the two rectangles. 
+ """ if (rect1.x0 < rect2.x1) and (rect2.x0 < rect1.x1): return min(rect1.x1, rect2.x1) - max(rect1.x0, rect2.x0) else: @@ -19,10 +28,30 @@ def x_overlap(rect1: fitz.Rect, rect2: fitz.Rect) -> float: # noqa: D103 def x_overlap_significant_smallest(rect1: fitz.Rect, rect2: fitz.Rect, level: float) -> bool: # noqa: D103 + """Check if the x overlap between two rectangles is significant relative to the width of the narrowest one. + + Args: + rect1 (fitz.Rect): First rectangle. + rect2 (fitz.Rect): Second rectangle. + level (float): Level of significance. + + Returns: + bool: True if the x overlap is significant, otherwise False. + """ return x_overlap(rect1, rect2) > level * min(rect1.width, rect2.width) def x_overlap_significant_largest(rect1: fitz.Rect, rect2: fitz.Rect, level: float) -> bool: # noqa: D103 + """Check if the x overlap between two rectangles is significant relative to the width of the widest one. + + Args: + rect1 (fitz.Rect): First rectangle. + rect2 (fitz.Rect): Second rectangle. + level (float): Level of significance. + + Returns: + bool: True if the x overlap is significant, otherwise False. 
+ """ return x_overlap(rect1, rect2) > level * max(rect1.width, rect2.width) diff --git a/tests/test_coordinate_extraction.py b/tests/test_coordinate_extraction.py index 96cc0f71..613d3a80 100644 --- a/tests/test_coordinate_extraction.py +++ b/tests/test_coordinate_extraction.py @@ -3,14 +3,14 @@ import fitz import pytest from stratigraphy import DATAPATH -from stratigraphy.coordinates.coordinate_extraction import ( +from stratigraphy.lines.line import TextLine, TextWord +from stratigraphy.metadata.coordinate_extraction import ( Coordinate, CoordinateEntry, CoordinateExtractor, LV03Coordinate, LV95Coordinate, ) -from stratigraphy.util.line import TextLine, TextWord def test_strLV95(): # noqa: D103 diff --git a/tests/test_depthcolumn.py b/tests/test_depthcolumn.py index 66d53fb8..484a8e7a 100644 --- a/tests/test_depthcolumn.py +++ b/tests/test_depthcolumn.py @@ -1,8 +1,8 @@ """Test suite for the find_depth_columns module.""" import fitz -from stratigraphy.util.depthcolumn import BoundaryDepthColumn -from stratigraphy.util.depthcolumnentry import DepthColumnEntry +from stratigraphy.depthcolumn.depthcolumn import BoundaryDepthColumn +from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry def test_boundarydepthcolumn_isarithmeticprogression(): # noqa: D103 diff --git a/tests/test_find_depth_columns.py b/tests/test_find_depth_columns.py index 049ca591..d8fd9294 100644 --- a/tests/test_find_depth_columns.py +++ b/tests/test_find_depth_columns.py @@ -2,9 +2,13 @@ import fitz import pytest -from stratigraphy.util.depthcolumnentry import DepthColumnEntry -from stratigraphy.util.find_depth_columns import depth_column_entries, find_depth_columns, find_layer_depth_columns -from stratigraphy.util.line import TextLine, TextWord +from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry +from stratigraphy.depthcolumn.find_depth_columns import ( + depth_column_entries, + find_depth_columns, + find_layer_depth_columns, +) +from stratigraphy.lines.line 
import TextLine, TextWord PAGE_NUMBER = 1 ALL_WORDS_FIND_DEPTH_COLUMN = [ diff --git a/tests/test_find_descripton.py b/tests/test_find_descripton.py index 721c10a9..31be076f 100644 --- a/tests/test_find_descripton.py +++ b/tests/test_find_descripton.py @@ -1,9 +1,9 @@ """Test suite for the find_description module.""" import fitz +from stratigraphy.lines.line import TextLine, TextWord +from stratigraphy.text.find_description import get_description_blocks from stratigraphy.util.dataclasses import Line, Point -from stratigraphy.util.find_description import get_description_blocks -from stratigraphy.util.line import TextLine, TextWord page_number = 1 textline1 = TextLine([TextWord(fitz.Rect([0, 0, 10, 10]), "Hello", page_number)]) diff --git a/tests/test_geometric_line_utilities.py b/tests/test_geometric_line_utilities.py index 688b0eea..bfbfb06a 100644 --- a/tests/test_geometric_line_utilities.py +++ b/tests/test_geometric_line_utilities.py @@ -2,8 +2,7 @@ import numpy as np import pytest -from stratigraphy.util.dataclasses import Line, Point -from stratigraphy.util.geometric_line_utilities import ( +from stratigraphy.lines.geometric_line_utilities import ( _get_orthogonal_projection_to_line, _merge_lines, _odr_regression, @@ -11,6 +10,7 @@ is_point_on_line, merge_parallel_lines_quadtree, ) +from stratigraphy.util.dataclasses import Line, Point # Remember, phi is orthogonal to the line we are to parameterize diff --git a/tests/test_interval.py b/tests/test_interval.py index d88bbee9..0d90cded 100644 --- a/tests/test_interval.py +++ b/tests/test_interval.py @@ -1,7 +1,7 @@ """Test suite for the interval module.""" import fitz -from stratigraphy.util.depthcolumnentry import DepthColumnEntry, LayerDepthColumnEntry +from stratigraphy.depthcolumn.depthcolumnentry import DepthColumnEntry, LayerDepthColumnEntry from stratigraphy.util.interval import BoundaryInterval, LayerInterval diff --git a/tests/test_textblock.py b/tests/test_textblock.py index 69536ee8..f32cb93d 100644 
--- a/tests/test_textblock.py +++ b/tests/test_textblock.py @@ -1,8 +1,8 @@ """Test suite for the textblock module.""" import fitz -from stratigraphy.util.line import TextLine, TextWord -from stratigraphy.util.textblock import TextBlock, block_distance +from stratigraphy.lines.line import TextLine, TextWord +from stratigraphy.text.textblock import TextBlock, block_distance def test_concatenate(): # noqa: D103