Skip to content

Commit

Permalink
Remove GroundTruthForFile class.
Browse files: browse the repository at this point in the history
  • Loading branch information
redur committed Apr 11, 2024
1 parent 06a0c98 commit bdd7551
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 16 deletions.
14 changes: 3 additions & 11 deletions src/stratigraphy/benchmark/ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,6 @@
logger = logging.getLogger(__name__)


class GroundTruthForFile:
"""Ground truth data for a single file."""

def __init__(self, ground_truth_layers: list):
self.layers = ground_truth_layers
self.num_layers = len(ground_truth_layers)


class GroundTruth:
"""Ground truth data for the stratigraphy benchmark."""

Expand All @@ -37,9 +29,9 @@ def __init__(self, path: Path):
if parse_text(layer["material_description"]) != ""
]

def for_file(self, file_name: str) -> GroundTruthForFile:
def for_file(self, file_name: str) -> list:
if file_name in self.ground_truth:
return GroundTruthForFile(self.ground_truth[file_name]["layers"])
return self.ground_truth[file_name]["layers"]
else:
logger.warning(f"No ground truth data found for {file_name}.")
return GroundTruthForFile([])
return []
2 changes: 1 addition & 1 deletion src/stratigraphy/benchmark/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def create_predictions_objects(predictions: dict, ground_truth_path: Path) -> tu
predictions_objects[file_name] = prediction_object
if ground_truth_is_present:
predictions_objects[file_name].evaluate(ground_truth.for_file(file_name))
number_of_truth_values[file_name] = ground_truth.for_file(file_name).num_layers
number_of_truth_values[file_name] = len(ground_truth.for_file(file_name))

return predictions_objects, number_of_truth_values

Expand Down
7 changes: 3 additions & 4 deletions src/stratigraphy/util/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import fitz
import Levenshtein

from stratigraphy.benchmark.ground_truth import GroundTruthForFile
from stratigraphy.util.depthcolumnentry import DepthColumnEntry
from stratigraphy.util.interval import BoundaryInterval
from stratigraphy.util.line import TextLine, TextWord
Expand Down Expand Up @@ -96,13 +95,13 @@ def create_from_json(predictions_for_file: dict, file_name: str):

return FilePredictions(pages=page_predictions_class, file_name=file_name)

def evaluate(self, ground_truth: GroundTruthForFile):
def evaluate(self, ground_truth_layers: list):
"""Evaluate all layers of the predictions against the ground truth.
Args:
ground_truth (GroundTruthForFile): The ground truth for the file.
ground_truth_layers (list): The ground truth layers for the file.
"""
self.unmatched_layers = ground_truth.layers.copy()
self.unmatched_layers = ground_truth_layers.copy()
for layer in self.layers:
match, depth_interval_is_correct = self._find_matching_layer(layer)
if match:
Expand Down

0 comments on commit bdd7551

Please sign in to comment.