Skip to content

Commit

Permalink
Addressed the comments from the PR
Browse files Browse the repository at this point in the history
  • Loading branch information
David Cleres committed Nov 7, 2024
1 parent dd559ff commit 4e3d50e
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/stratigraphy/benchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_metrics_list(self) -> list[Metrics]:
class OverallMetricsCatalog:
"""Keeps track of all different relevant metrics that are computed for a dataset."""

def __init__(self, languages: list[str]):
def __init__(self, languages: set[str]):
self.layer_metrics = OverallMetrics()
self.depth_interval_metrics = OverallMetrics()
self.groundwater_metrics = OverallMetrics()
Expand Down
34 changes: 0 additions & 34 deletions src/stratigraphy/benchmark/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dotenv import load_dotenv
from stratigraphy import DATAPATH
from stratigraphy.annotations.draw import draw_predictions
from stratigraphy.benchmark.ground_truth import GroundTruth
from stratigraphy.evaluation.evaluation_dataclasses import BoreholeMetadataMetrics
from stratigraphy.util.predictions import OverallFilePredictions

Expand All @@ -21,39 +20,6 @@
logger = logging.getLogger(__name__)


def create_predictions_objects(
predictions: OverallFilePredictions,
ground_truth_path: Path | None,
) -> tuple[OverallFilePredictions, dict]:
"""Create predictions objects from the predictions and evaluate them against the ground truth.
Args:
predictions (OverallFilePredictions): The predictions objects.
ground_truth_path (Path | None): The path to the ground truth file.
Returns:
tuple[OverallFilePredictions, dict]: The predictions objects and the number of ground truth values per
file.
"""
if ground_truth_path and ground_truth_path.exists(): # for inference no ground truth is available
ground_truth = GroundTruth(ground_truth_path)
ground_truth_is_present = True
else:
logging.warning("Ground truth file not found.")
ground_truth_is_present = False
return predictions, {}

number_of_truth_values = {}
for file_predictions in predictions.file_predictions_list:
if ground_truth_is_present:
ground_truth_for_file = ground_truth.for_file(file_predictions.file_name)
if ground_truth_for_file:
file_predictions.evaluate(ground_truth_for_file)
number_of_truth_values[file_predictions.file_name] = len(ground_truth_for_file["layers"])

return predictions, number_of_truth_values


def evaluate(
predictions: OverallFilePredictions,
ground_truth_path: Path,
Expand Down
2 changes: 1 addition & 1 deletion src/stratigraphy/evaluation/groundwater_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def evaluate(self) -> OverallGroundwaterMetrics:
groundwater_elevation_metrics=groundwater_elevation_metrics,
groundwater_date_metrics=groundwater_date_metrics,
filename=filename,
) # TODO: This clashes with the OverallMetrics object
)

overall_groundwater_metrics.add_groundwater_metrics(file_groundwater_metrics)

Expand Down
16 changes: 6 additions & 10 deletions src/stratigraphy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@ def start_pipeline(

if part == "all":
# Extract the groundwater levels
groundwater_in_document = GroundwaterInDocument.from_document(doc, metadata.elevation)
groundwater_entries = GroundwaterInDocument.from_document(doc, metadata.elevation)

# Extract the layers
layer_predictions_list = LayersInDocument([], filename)
depths_materials_column_pairs_list = []
layers = LayersInDocument([], filename)
depths_materials_columns_pairs = []
for page_index, page in enumerate(doc):
page_number = page_index + 1
logger.info("Processing page %s", page_number)
Expand All @@ -253,7 +253,7 @@ def start_pipeline(
layer_predictions = remove_duplicate_layers(
previous_page=doc[page_index - 1],
current_page=page,
previous_layers=layer_predictions_list,
previous_layers=layers,
current_layers=process_page_results.predictions,
img_template_probability_threshold=matching_params[
"img_template_probability_threshold"
Expand All @@ -262,8 +262,8 @@ def start_pipeline(
else:
layer_predictions = process_page_results.predictions

layer_predictions_list.add_layers_on_page(layer_predictions)
depths_materials_column_pairs_list.extend(process_page_results.depth_material_pairs)
layers.add_layers_on_page(layer_predictions)
depths_materials_columns_pairs.extend(process_page_results.depth_material_pairs)

if draw_lines: # could be changed to if draw_lines and mflow_tracking:
if not mlflow_tracking:
Expand All @@ -276,10 +276,6 @@ def start_pipeline(
)
mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png")

groundwater_entries = groundwater_in_document
layers = layer_predictions_list
depths_materials_columns_pairs = depths_materials_column_pairs_list

# Add file predictions
predictions.add_file_predictions(
FilePredictions(
Expand Down

0 comments on commit 4e3d50e

Please sign in to comment.