Skip to content

Commit

Permalink
LGVISIUM-79: create LayerEvaluator
Browse files (browse the repository at this point in the history)
  • Loading branch information
stijnvermeeren-swisstopo committed Nov 5, 2024
1 parent d802102 commit a3429a8
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 301 deletions.
68 changes: 36 additions & 32 deletions src/stratigraphy/annotations/draw.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,7 +29,7 @@ def draw_predictions(
predictions: OverallFilePredictions,
directory: Path,
out_directory: Path,
document_level_metadata_metrics: pd.DataFrame,
document_level_metadata_metrics: None | pd.DataFrame,
) -> None:
"""Draw predictions on pdf pages.
Expand All @@ -48,7 +48,7 @@ def draw_predictions(
predictions (dict): Content of the predictions.json file.
directory (Path): Path to the directory containing the pdf files.
out_directory (Path): Path to the output directory where the images are saved.
document_level_metadata_metrics (pd.DataFrame): Document level metadata metrics.
document_level_metadata_metrics (None | pd.DataFrame): Document level metadata metrics.
"""
if directory.is_file(): # deal with the case when we pass a file instead of a directory
directory = directory.parent
Expand All @@ -60,15 +60,18 @@ def draw_predictions(
elevation = file_prediction.metadata.elevation

# Assess the correctness of the metadata
if file_prediction.file_name in document_level_metadata_metrics.index:
if (
document_level_metadata_metrics is not None
and file_prediction.file_name in document_level_metadata_metrics.index
):
is_coordinates_correct = document_level_metadata_metrics.loc[file_prediction.file_name].coordinate
is_elevation_correct = document_level_metadata_metrics.loc[file_prediction.file_name].elevation
else:
logger.warning(
"Metrics for file %s not found in document_level_metadata_metrics.", file_prediction.file_name
)
is_coordinates_correct = False
is_elevation_correct = False
is_coordinates_correct = None
is_elevation_correct = None

try:
with fitz.Document(directory / file_prediction.file_name) as doc:
Expand Down Expand Up @@ -131,9 +134,9 @@ def draw_metadata(
derotation_matrix: fitz.Matrix,
rotation: float,
coordinates: Coordinate | None,
is_coordinate_correct: bool,
is_coordinate_correct: bool | None,
elevation_info: Elevation | None,
is_elevation_correct: bool,
is_elevation_correct: bool | None,
) -> None:
"""Draw the extracted metadata on the top of the given PDF page.
Expand All @@ -145,44 +148,45 @@ def draw_metadata(
derotation_matrix (fitz.Matrix): The derotation matrix of the page.
rotation (float): The rotation of the page.
coordinates (Coordinate | None): The coordinate object to draw.
is_coordinate_correct (Metrics): Whether the coordinate information is correct.
elevation_info (ElevationInformation | None): The elevation information to draw.
is_elevation_correct (Metrics): Whether the elevation information is correct.
is_coordinate_correct (bool | None): Whether the coordinate information is correct.
elevation_info (Elevation | None): The elevation information to draw.
is_elevation_correct (bool | None): Whether the elevation information is correct.
"""
# TODO associate correctness with the extracted coordinates in a better way
coordinate_color = "green" if is_coordinate_correct else "red"
coordinate_rect = fitz.Rect([5, 5, 250, 30])

elevation_color = "green" if is_elevation_correct else "red"
elevation_rect = fitz.Rect([5, 30, 250, 55])

shape.draw_rect(coordinate_rect * derotation_matrix)
shape.finish(fill=fitz.utils.getColor("gray"), fill_opacity=0.5)
shape.insert_textbox(coordinate_rect * derotation_matrix, f"Coordinates: {coordinates}", rotate=rotation)
shape.draw_line(
coordinate_rect.top_left * derotation_matrix,
coordinate_rect.bottom_left * derotation_matrix,
)
shape.finish(
color=fitz.utils.getColor(coordinate_color),
width=6,
stroke_opacity=0.5,
)
if is_coordinate_correct is not None:
# TODO associate correctness with the extracted coordinates in a better way
coordinate_color = "green" if is_coordinate_correct else "red"
shape.draw_line(
coordinate_rect.top_left * derotation_matrix,
coordinate_rect.bottom_left * derotation_matrix,
)
shape.finish(
color=fitz.utils.getColor(coordinate_color),
width=6,
stroke_opacity=0.5,
)

# Draw the bounding box around the elevation information
elevation_txt = f"Elevation: {elevation_info.elevation} m" if elevation_info is not None else "Elevation: N/A"
shape.draw_rect(elevation_rect * derotation_matrix)
shape.finish(fill=fitz.utils.getColor("gray"), fill_opacity=0.5)
shape.insert_textbox(elevation_rect * derotation_matrix, elevation_txt, rotate=rotation)
shape.draw_line(
elevation_rect.top_left * derotation_matrix,
elevation_rect.bottom_left * derotation_matrix,
)
shape.finish(
color=fitz.utils.getColor(elevation_color),
width=6,
stroke_opacity=0.5,
)
if is_elevation_correct is not None:
elevation_color = "green" if is_elevation_correct else "red"
shape.draw_line(
elevation_rect.top_left * derotation_matrix,
elevation_rect.bottom_left * derotation_matrix,
)
shape.finish(
color=fitz.utils.getColor(elevation_color),
width=6,
stroke_opacity=0.5,
)


def draw_coordinates(shape: fitz.Shape, coordinates: Coordinate) -> None:
Expand Down
37 changes: 15 additions & 22 deletions src/stratigraphy/benchmark/score.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,9 +9,7 @@
import pandas as pd
from dotenv import load_dotenv
from stratigraphy import DATAPATH
from stratigraphy.annotations.draw import draw_predictions
from stratigraphy.benchmark.ground_truth import GroundTruth
from stratigraphy.evaluation.evaluation_dataclasses import BoreholeMetadataMetrics
from stratigraphy.util.predictions import OverallFilePredictions

load_dotenv()
Expand Down Expand Up @@ -52,29 +50,29 @@ def create_predictions_objects(


def evaluate(
predictions: OverallFilePredictions,
ground_truth_path: Path,
temp_directory: Path,
input_directory: Path | None,
draw_directory: Path | None,
) -> None:
predictions: OverallFilePredictions, ground_truth_path: Path, temp_directory: Path
) -> None | pd.DataFrame:
"""Computes all the metrics, logs them, and creates corresponding MLFlow artifacts (when enabled).
Args:
predictions (OverallFilePredictions): The predictions objects.
ground_truth_path (Path): The path to the ground truth file.
ground_truth_path (Path | None): The path to the ground truth file.
temp_directory (Path): The path to the temporary directory.
input_directory (Path | None): The path to the input directory.
draw_directory (Path | None): The path to the draw directory.
Returns:
None
None | pd.DataFrame: the document level metadata metrics
"""
if not (ground_truth_path and ground_truth_path.exists()): # for inference no ground truth is available
logger.warning("Ground truth file not found. Skipping evaluation.")
return None

ground_truth = GroundTruth(ground_truth_path)

#############################
# Evaluate the borehole extraction metadata
#############################
metadata_metrics_list = predictions.evaluate_metadata_extraction(ground_truth_path)
metadata_metrics: BoreholeMetadataMetrics = metadata_metrics_list.get_cumulated_metrics()
metadata_metrics_list = predictions.evaluate_metadata_extraction(ground_truth)
metadata_metrics = metadata_metrics_list.get_cumulated_metrics()
document_level_metadata_metrics: pd.DataFrame = metadata_metrics_list.get_document_level_metrics()
document_level_metadata_metrics.to_csv(
temp_directory / "document_level_metadata_metrics.csv", index_label="document_name"
Expand All @@ -93,7 +91,7 @@ def evaluate(
#############################
# Evaluate the borehole extraction
#############################
metrics = predictions.evaluate_borehole_extraction(ground_truth_path)
metrics = predictions.evaluate_geology(ground_truth)

metrics.document_level_metrics_df().to_csv(
temp_directory / "document_level_metrics.csv", index_label="document_name"
Expand All @@ -108,11 +106,7 @@ def evaluate(
mlflow.log_metrics(metrics_dict)
mlflow.log_artifact(temp_directory / "document_level_metrics.csv")

#############################
# Draw the prediction
#############################
if input_directory and draw_directory:
draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics)
return document_level_metadata_metrics


def main():
Expand Down Expand Up @@ -141,8 +135,7 @@ def main():

predictions = OverallFilePredictions.from_json(predictions)

# Customize these as needed
evaluate(predictions, args.ground_truth_path, args.temp_directory, input_directory=None, draw_directory=None)
evaluate(predictions, args.ground_truth_path, args.temp_directory)


def parse_cli() -> argparse.Namespace:
Expand Down
4 changes: 2 additions & 2 deletions src/stratigraphy/evaluation/evaluation_dataclasses.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -114,13 +114,13 @@ def get_document_level_metrics(self) -> pd.DataFrame:
class OverallBoreholeMetadataMetrics(metaclass=abc.ABCMeta):
"""Metrics for borehole metadata."""

borehole_metadata_metrics: list[FileBoreholeMetadataMetrics] = None
borehole_metadata_metrics: list[FileBoreholeMetadataMetrics]

def __init__(self):
"""Initializes the OverallBoreholeMetadataMetrics object."""
self.borehole_metadata_metrics = []

def get_cumulated_metrics(self) -> dict:
def get_cumulated_metrics(self) -> BoreholeMetadataMetrics:
"""Evaluate the metadata metrics."""
elevation_metrics = Metrics.micro_average(
[metadata.elevation_metrics for metadata in self.borehole_metadata_metrics]
Expand Down
9 changes: 4 additions & 5 deletions src/stratigraphy/evaluation/groundwater_evaluator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -52,15 +52,14 @@ def groundwater_depth_metrics_to_overall_metrics(self):
class GroundwaterEvaluator:
"""Class for evaluating the extracted groundwater information of a borehole."""

def __init__(self, groundwater_entries: list[GroundwaterInDocument], ground_truth_path: str):
def __init__(self, groundwater_entries: list[GroundwaterInDocument], ground_truth: GroundTruth):
"""Initializes the GroundwaterEvaluator object.
Args:
groundwater_entries (list[GroundwaterInDocument]): The metadata to evaluate.
ground_truth_path (str): The path to the ground truth file.
ground_truth (GroundTruth): The ground truth.
"""
# Load the ground truth data for the metadata
self.groundwater_ground_truth = GroundTruth(ground_truth_path)
self.ground_truth = ground_truth
self.groundwater_entries: list[GroundwaterInDocument] = groundwater_entries

def evaluate(self) -> OverallGroundwaterMetrics:
Expand All @@ -73,7 +72,7 @@ def evaluate(self) -> OverallGroundwaterMetrics:

for groundwater_in_doc in self.groundwater_entries:
filename = groundwater_in_doc.filename
ground_truth_data = self.groundwater_ground_truth.for_file(filename)
ground_truth_data = self.ground_truth.for_file(filename)
if ground_truth_data is None or ground_truth_data.get("groundwater") is None:
ground_truth = [] # If no ground truth is available, set it to an empty list
else:
Expand Down
Loading

0 comments on commit a3429a8

Please sign in to comment.