Close #80: Refactoring of the groundwater evaluation and classes
dcleres committed Sep 26, 2024 (1 parent e056302 · commit c2864d0)
Showing 9 changed files with 505 additions and 375 deletions.
13 changes: 7 additions & 6 deletions src/stratigraphy/annotations/draw.py
@@ -9,7 +9,7 @@
from dotenv import load_dotenv
from stratigraphy.depthcolumn.depthcolumn import DepthColumn
from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs
from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformationOnPage
from stratigraphy.groundwater.groundwater_extraction import GroundwaterOnPage
from stratigraphy.layer.layer import LayerPrediction
from stratigraphy.metadata.coordinate_extraction import Coordinate
from stratigraphy.metadata.elevation_extraction import Elevation
@@ -82,9 +82,10 @@ def draw_predictions(
draw_coordinates(shape, coordinates)
if elevation is not None and page_number == elevation.page:
draw_elevation(shape, elevation)
for groundwater_entry in file_prediction.groundwater_entries:
if page_number == groundwater_entry.page:
draw_groundwater(shape, groundwater_entry)
for groundwater_on_page in file_prediction.groundwater.groundwater:
# TODO: Adapt this to the structures above -> List the groundwater in the function
if page_number == groundwater_on_page.page:
draw_groundwater(shape, groundwater_on_page)
draw_depth_columns_and_material_rect(
shape,
page.derotation_matrix,
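
The loop above now reads `file_prediction.groundwater.groundwater` instead of the former flat `groundwater_entries` list: the per-page hits live inside a document-level container. The new classes themselves are not part of this view, so the sketch below is only an assumption about their shape; the container name `GroundwaterInDocument` and anything beyond the `rect` and `page` fields used by the drawing code are guesses.

# Hedged sketch of the nesting the new loop assumes; only `rect` and `page`
# are visible in this diff, everything else (including the container name)
# is a guess.
from dataclasses import dataclass, field

import fitz  # PyMuPDF


@dataclass
class GroundwaterOnPage:
    """One extracted groundwater observation, located on a single page."""

    rect: fitz.Rect  # bounding box the observation was read from
    page: int  # page number the observation appears on


@dataclass
class GroundwaterInDocument:  # hypothetical name for file_prediction.groundwater
    """Document-level container; `.groundwater` lists the per-page entries."""

    groundwater: list[GroundwaterOnPage] = field(default_factory=list)
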
@@ -185,12 +186,12 @@ def draw_coordinates(shape: fitz.Shape, coordinates: Coordinate) -> None:
shape.finish(color=fitz.utils.getColor("purple"))


def draw_groundwater(shape: fitz.Shape, groundwater_entry: GroundwaterInformationOnPage) -> None:
def draw_groundwater(shape: fitz.Shape, groundwater_entry: GroundwaterOnPage) -> None:
"""Draw a bounding box around the area of the page where the coordinates were extracted from.
Args:
shape (fitz.Shape): The shape object for drawing.
groundwater_entry (GroundwaterInformationOnPage): The groundwater information to draw.
groundwater_entry (GroundwaterOnPage): The groundwater information to draw.
"""
shape.draw_rect(groundwater_entry.rect)
shape.finish(color=fitz.utils.getColor("pink"))
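For orientation, here is a minimal end-to-end use of the renamed drawing path with PyMuPDF; the input file name and rectangle coordinates are invented, and `GroundwaterOnPage` is the sketch from above.

# Usage sketch: draw one groundwater bounding box the way draw_groundwater does.
import fitz

doc = fitz.open("borehole_profile.pdf")  # hypothetical input document
page = doc[0]
entry = GroundwaterOnPage(rect=fitz.Rect(50, 400, 250, 440), page=1)

shape = page.new_shape()
shape.draw_rect(entry.rect)  # the same two calls draw_groundwater() makes
shape.finish(color=fitz.utils.getColor("pink"))
shape.commit()  # write the drawing onto the page
doc.save("annotated.pdf")
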
238 changes: 22 additions & 216 deletions src/stratigraphy/benchmark/score.py
@@ -5,13 +5,12 @@
import os
from pathlib import Path

import pandas as pd
from dotenv import load_dotenv
from stratigraphy import DATAPATH
from stratigraphy.annotations.draw import draw_predictions
from stratigraphy.benchmark.ground_truth import GroundTruth
from stratigraphy.benchmark.metrics import DatasetMetrics, DatasetMetricsCatalog, Metrics
from stratigraphy.evaluation.evaluation_dataclasses import BoreholeMetadataMetrics
from stratigraphy.util.predictions import OverallFilePredictions
from stratigraphy.util.util import parse_text

load_dotenv()

@@ -20,192 +19,6 @@
logger = logging.getLogger(__name__)


def get_layer_metrics(predictions: OverallFilePredictions, number_of_truth_values: dict) -> DatasetMetrics:
"""Calculate F1, precision and recall for the layer predictions.
Calculate F1, precision and recall for the individual documents as well as overall.
Args:
predictions (dict): The predictions.
number_of_truth_values (dict): The number of ground truth values per file.
Returns:
DatasetMetrics: the metrics for the layers
"""
layer_metrics = DatasetMetrics()

for file_prediction in predictions.file_predictions_list:
hits = 0
for layer in file_prediction.layers:
if layer.material_is_correct:
hits += 1
if parse_text(layer.material_description.text) == "":
logger.warning("Empty string found in predictions")
layer_metrics.metrics[file_prediction.file_name] = Metrics(
tp=hits, fp=len(file_prediction.layers) - hits, fn=number_of_truth_values[file_prediction.file_name] - hits
)

return layer_metrics


def get_depth_interval_metrics(predictions: OverallFilePredictions) -> DatasetMetrics:
"""Calculate F1, precision and recall for the depth interval predictions.
Calculate F1, precision and recall for the individual documents as well as overall.
Depth interval accuracy is not calculated for layers with incorrect material predictions.
Args:
predictions (dict): The predictions.
Returns:
DatasetMetrics: the metrics for the depth intervals
"""
depth_interval_metrics = DatasetMetrics()

for file_prediction in predictions.file_predictions_list:
depth_interval_hits = 0
depth_interval_occurrences = 0
for layer in file_prediction.layers:
if layer.material_is_correct:
if layer.depth_interval_is_correct is not None:
depth_interval_occurrences += 1
if layer.depth_interval_is_correct:
depth_interval_hits += 1

if depth_interval_occurrences > 0:
depth_interval_metrics.metrics[file_prediction.file_name] = Metrics(
tp=depth_interval_hits, fp=depth_interval_occurrences - depth_interval_hits, fn=0
)

return depth_interval_metrics


def evaluate_borehole_extraction(
predictions: OverallFilePredictions, number_of_truth_values: dict
) -> DatasetMetricsCatalog:
"""Evaluate the borehole extraction predictions.
Args:
predictions (dict): The FilePredictions objects.
number_of_truth_values (dict): The number of layer ground truth values per file.
Returns:
DatasetMetricsCatalogue: A DatasetMetricsCatalogue that maps a metrics name to the corresponding DatasetMetrics
object
"""
all_metrics = evaluate_layer_extraction(predictions, number_of_truth_values)
all_metrics.metrics["groundwater"] = get_metrics(predictions, "groundwater_is_correct", "groundwater")
all_metrics.metrics["groundwater_depth"] = get_metrics(predictions, "groundwater_is_correct", "groundwater_depth")
return all_metrics


def get_metrics(predictions: OverallFilePredictions, field_key: str, field_name: str) -> DatasetMetrics:
"""Get the metrics for a specific field in the predictions.
Args:
predictions (dict): The FilePredictions objects.
field_key (str): The key to access the specific field in the prediction objects.
field_name (str): The name of the field being evaluated.
Returns:
DatasetMetrics: The requested DatasetMetrics object.
"""
dataset_metrics = DatasetMetrics()

for file_prediction in predictions.file_predictions_list:
dataset_metrics.metrics[file_prediction.file_name] = getattr(file_prediction, field_key)[field_name]

return dataset_metrics


def evaluate_layer_extraction(
predictions: OverallFilePredictions, number_of_truth_values: dict
) -> DatasetMetricsCatalog:
"""Calculate F1, precision and recall for the predictions.
Calculate F1, precision and recall for the individual documents as well as overall.
The individual document metrics are returned as a DataFrame.
Args:
predictions (dict): The FilePredictions objects.
number_of_truth_values (dict): The number of layer ground truth values per file.
Returns:
DatasetMetricsCatalogue: A dictionary that maps a metrics name to the corresponding DatasetMetrics object
"""
all_metrics = DatasetMetricsCatalog()
all_metrics.metrics["layer"] = get_layer_metrics(predictions, number_of_truth_values)
all_metrics.metrics["depth_interval"] = get_depth_interval_metrics(predictions)

# create predictions by language
predictions_by_language = {
"de": OverallFilePredictions(),
"fr": OverallFilePredictions(),
} # TODO: make this dynamic and why is this hardcoded?
for file_predictions in predictions.file_predictions_list:
language = file_predictions.metadata.language
if language in predictions_by_language:
predictions_by_language[language].add_file_predictions(file_predictions)

for language, language_predictions in predictions_by_language.items():
language_number_of_truth_values = {
prediction.file_name: number_of_truth_values[prediction.file_name]
for prediction in language_predictions.file_predictions_list
}
all_metrics.metrics[f"{language}_layer"] = get_layer_metrics(
language_predictions, language_number_of_truth_values
)
all_metrics.metrics[f"{language}_depth_interval"] = get_depth_interval_metrics(language_predictions)

logging.info("Macro avg:")
logging.info(
"F1: %.1f%%, precision: %.1f%%, recall: %.1f%%, depth_interval_accuracy: %.1f%%",
all_metrics.metrics["layer"].macro_f1() * 100,
all_metrics.metrics["layer"].macro_precision() * 100,
all_metrics.metrics["layer"].macro_recall() * 100,
all_metrics.metrics["depth_interval"].macro_precision() * 100,
)

return all_metrics


def create_predictions_objects(
predictions: OverallFilePredictions,
ground_truth_path: Path | None,
) -> tuple[OverallFilePredictions, dict]:
"""Create predictions objects from the predictions and evaluate them against the ground truth.
Args:
predictions (dict): The predictions from the predictions.json file.
ground_truth_path (Path | None): The path to the ground truth file.
metadata_per_file (BoreholeMetadataList): The metadata for the files.
Returns:
tuple[dict[str, FilePredictions], dict]: The predictions objects and the number of ground truth values per
file.
"""
if ground_truth_path and os.path.exists(ground_truth_path): # for inference no ground truth is available
ground_truth = GroundTruth(ground_truth_path)
ground_truth_is_present = True
else:
logging.warning("Ground truth file not found.")
ground_truth_is_present = False

number_of_truth_values = {}
for file_predictions in predictions.file_predictions_list:
# prediction_object = FilePredictions.create_from_json(file_predictions, file_predictions.file_name)

# predictions_objects[file_name] = prediction_object
if ground_truth_is_present:
ground_truth_for_file = ground_truth.for_file(file_predictions.file_name)
if ground_truth_for_file:
file_predictions.evaluate(ground_truth_for_file)
number_of_truth_values[file_predictions.file_name] = len(ground_truth_for_file["layers"])

return predictions, number_of_truth_values


def evaluate(
predictions: OverallFilePredictions,
ground_truth_path: Path,
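
Everything above (`get_layer_metrics`, `get_depth_interval_metrics`, `evaluate_borehole_extraction`, `get_metrics`, `evaluate_layer_extraction` and `create_predictions_objects`) is deleted from score.py; judging by the new call `predictions.evaluate_borehole_extraction(ground_truth_path)` further down, the logic moved onto the prediction/evaluation classes, whose files are not rendered in this view. The tp/fp/fn bookkeeping itself is unchanged and worth stating in isolation; the dataclass below is an illustrative stand-in, not the repository's `Metrics`.

# Self-contained sketch of the counting convention get_layer_metrics used:
# tp = correctly matched layers, fp = predicted layers without a match,
# fn = ground-truth layers without a match.
from dataclasses import dataclass


@dataclass
class Metrics:
    tp: int
    fp: int
    fn: int

    @property
    def precision(self) -> float:
        return self.tp / (self.tp + self.fp) if self.tp + self.fp else 0.0

    @property
    def recall(self) -> float:
        return self.tp / (self.tp + self.fn) if self.tp + self.fn else 0.0

    @property
    def f1(self) -> float:
        p, r = self.precision, self.recall
        return 2 * p * r / (p + r) if p + r else 0.0


# A file with 10 predicted layers, 8 of them correct, against 12 ground-truth layers:
m = Metrics(tp=8, fp=10 - 8, fn=12 - 8)
print(f"precision={m.precision:.2f} recall={m.recall:.2f} f1={m.f1:.2f}")
# precision=0.80 recall=0.67 f1=0.73
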
@@ -218,52 +31,45 @@ def evaluate(
# Evaluate the borehole extraction metadata
#############################
metadata_metrics_list = predictions.evaluate_metadata_extraction(ground_truth_path)
metadata_metrics = metadata_metrics_list.get_cumulated_metrics()
document_level_metadata_metrics = metadata_metrics_list.get_document_level_metrics()
metadata_metrics: BoreholeMetadataMetrics = metadata_metrics_list.get_cumulated_metrics()
document_level_metadata_metrics: pd.DataFrame = metadata_metrics_list.get_document_level_metrics()
document_level_metadata_metrics.to_csv(
temp_directory / "document_level_metadata_metrics.csv", index_label="document_name"
) # mlflow.log_artifact expects a file

# print the metrics
logger.info("Metadata Performance metrics:")
logger.info(metadata_metrics)
logger.info(metadata_metrics.to_json())

if mlflow_tracking:
import mlflow

mlflow.log_metrics(metadata_metrics)
mlflow.log_metrics(metadata_metrics.to_json())
mlflow.log_artifact(temp_directory / "document_level_metadata_metrics.csv")

#############################
# Evaluate the borehole extraction
#############################
if predictions:
predictions, number_of_truth_values = create_predictions_objects(predictions, ground_truth_path)
metrics = predictions.evaluate_borehole_extraction(ground_truth_path)

if number_of_truth_values: # only evaluate if ground truth is available
metrics = evaluate_borehole_extraction(predictions, number_of_truth_values)

metrics.document_level_metrics_df().to_csv(
temp_directory / "document_level_metrics.csv", index_label="document_name"
) # mlflow.log_artifact expects a file
metrics_dict = metrics.metrics_dict()

# Format the metrics dictionary to limit to three digits
formatted_metrics = {k: f"{v:.3f}" for k, v in metrics_dict.items()}
logger.info("Performance metrics: %s", formatted_metrics)
metrics.document_level_metrics_df().to_csv(
temp_directory / "document_level_metrics.csv", index_label="document_name"
) # mlflow.log_artifact expects a file
metrics_dict = metrics.metrics_dict()

if mlflow_tracking:
mlflow.log_metrics(metrics_dict)
mlflow.log_artifact(temp_directory / "document_level_metrics.csv")
# Format the metrics dictionary to limit to three digits
formatted_metrics = {k: f"{v:.3f}" for k, v in metrics_dict.items()}
logger.info("Performance metrics: %s", formatted_metrics)

else:
logger.warning("Ground truth file not found. Skipping evaluation.")
if mlflow_tracking:
mlflow.log_metrics(metrics_dict)
mlflow.log_artifact(temp_directory / "document_level_metrics.csv")

#############################
# Draw the prediction
#############################
if input_directory and draw_directory:
draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics)
#############################
# Draw the prediction
#############################
if input_directory and draw_directory:
draw_predictions(predictions, input_directory, draw_directory, document_level_metadata_metrics)


if __name__ == "__main__":
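Both metric groups in `evaluate()` go through the same two MLflow calls, with the per-document table written to CSV first because `mlflow.log_artifact` expects a file on disk. A self-contained sketch of that pattern follows; the experiment name and metric values are invented.

import mlflow
import pandas as pd

# Write the per-document table first: mlflow.log_artifact needs a real file.
pd.DataFrame({"F1": [0.73]}, index=["doc_1"]).to_csv(
    "document_level_metrics.csv", index_label="document_name"
)

mlflow.set_experiment("boreholes-extraction")  # hypothetical experiment name
with mlflow.start_run():
    mlflow.log_metrics({"layer_f1": 0.73, "layer_precision": 0.80})  # invented values
    mlflow.log_artifact("document_level_metrics.csv")
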
5 changes: 2 additions & 3 deletions src/stratigraphy/evaluation/evaluation_dataclasses.py
@@ -128,11 +128,10 @@ def get_cumulated_metrics(self) -> dict:
coordinates_metrics = Metrics.micro_average(
[metadata.coordinates_metrics for metadata in self.borehole_metadata_metrics]
)
return BoreholeMetadataMetrics(
elevation_metrics=elevation_metrics, coordinates_metrics=coordinates_metrics
).to_json()
return BoreholeMetadataMetrics(elevation_metrics=elevation_metrics, coordinates_metrics=coordinates_metrics)

def get_document_level_metrics(self) -> pd.DataFrame:
"""Get the document level metrics."""
# Get a dataframe per document, concatenate, and sort by index (document name)
return pd.concat(
[metadata.get_document_level_metrics() for metadata in self.borehole_metadata_metrics]
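The behavioural change in this file: `get_cumulated_metrics` now returns the `BoreholeMetadataMetrics` object itself and leaves serialization to the caller, which is why score.py above switches to `metadata_metrics.to_json()` for logging. A sketch of the resulting convention; the dataclasses mirror the diff, but the `to_json` body and its key names are assumptions.

from dataclasses import dataclass


@dataclass
class Metrics:  # illustrative stand-in for the repository's Metrics
    tp: int
    fp: int
    fn: int


@dataclass
class BoreholeMetadataMetrics:
    elevation_metrics: Metrics
    coordinates_metrics: Metrics

    def to_json(self) -> dict:
        # Flat numeric dict so callers can hand it straight to mlflow.log_metrics();
        # the key names here are assumptions, not the repository's actual keys.
        return {
            "elevation_tp": self.elevation_metrics.tp,
            "elevation_fp": self.elevation_metrics.fp,
            "elevation_fn": self.elevation_metrics.fn,
            "coordinates_tp": self.coordinates_metrics.tp,
            "coordinates_fp": self.coordinates_metrics.fp,
            "coordinates_fn": self.coordinates_metrics.fn,
        }


# New convention: the dataclass is returned as-is, serialization is explicit.
cumulated = BoreholeMetadataMetrics(Metrics(8, 2, 1), Metrics(7, 3, 2))
print(cumulated.to_json())
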
(The remaining 6 changed files are not rendered in this view.)
