diff --git a/src/stratigraphy/util/interval.py b/src/stratigraphy/util/interval.py index d9c070bd..e137aeda 100644 --- a/src/stratigraphy/util/interval.py +++ b/src/stratigraphy/util/interval.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +from dataclasses import dataclass import fitz @@ -50,6 +51,19 @@ def to_json(self): } +@dataclass +class AnnotatedInterval: + """Class for annotated intervals.""" + + start: float + end: float + background_rect: fitz.Rect + + @property + def line_anchor(self) -> fitz.Point: + return fitz.Point(0, 0) + + class BoundaryInterval(Interval): """Class for boundary intervals. diff --git a/src/stratigraphy/util/predictions.py b/src/stratigraphy/util/predictions.py index 77c9dcdd..84f326c5 100644 --- a/src/stratigraphy/util/predictions.py +++ b/src/stratigraphy/util/predictions.py @@ -1,15 +1,17 @@ """This module contains classes for predictions.""" +import contextlib import uuid +from collections import defaultdict from dataclasses import dataclass, field import fitz import Levenshtein from stratigraphy.util.depthcolumnentry import DepthColumnEntry -from stratigraphy.util.interval import BoundaryInterval +from stratigraphy.util.interval import AnnotatedInterval, BoundaryInterval from stratigraphy.util.line import TextLine, TextWord -from stratigraphy.util.textblock import TextBlock +from stratigraphy.util.textblock import MaterialDescription, TextBlock from stratigraphy.util.util import parse_text @@ -17,8 +19,8 @@ class LayerPrediction: """A class to represent predictions for a single layer.""" - material_description: TextBlock - depth_interval: BoundaryInterval + material_description: TextBlock | MaterialDescription + depth_interval: BoundaryInterval | AnnotatedInterval material_is_correct: bool = None depth_interval_is_correct: bool = None id: uuid.UUID = field(default_factory=uuid.uuid4) @@ -105,6 +107,126 @@ def create_from_json(predictions_for_file: dict, file_name: str): return FilePredictions(pages=page_predictions_class, file_name=file_name, language=file_language) + @staticmethod + def create_from_label_studio(annotation_results: dict): + """Create predictions class for a file given the annotation results from Label Studio. + + NOTE: We may want to adjust this method to return a single instance of the class, + instead of a list of class objects. + + Args: + annotation_results (dict): The annotation results from Label Studio. + The annotation_results can cover multiple files. + + Returns: + list[FilePredictions]: A list of FilePredictions objects, one for each file present in the + annotation_results. + """ + file_pages = defaultdict(list) + for annotation in annotation_results: + # get page level information + file_name, page_index = _get_file_name_and_page_index(annotation) + page_width = annotation["annotations"][0]["result"][0]["original_width"] + page_height = annotation["annotations"][0]["result"][0]["original_height"] + + # extract all material descriptions and depth intervals and link them together + # Note: we need to loop through the annotations twice, because the order of the annotations is + # not guaranteed. In the first iteration we grasp all IDs, in the second iteration we extract the + # information for each id. + material_descriptions = {} + depth_intervals = {} + linking_objects = [] + + # define all the material descriptions and depth intervals with their ids + for annotation_result in annotation["annotations"][0]["result"]: + if annotation_result["type"] == "labels": + if annotation_result["value"]["labels"] == ["Material Description"]: + material_descriptions[annotation_result["id"]] = { + "rect": annotation_result["value"] + } # TODO extract rectangle properly + elif annotation_result["value"]["labels"] == ["Depth Interval"]: + depth_intervals[annotation_result["id"]] = {} + if annotation_result["type"] == "relation": + linking_objects.append( + {"from_id": annotation_result["from_id"], "to_id": annotation_result["to_id"]} + ) + + # check annotation results for material description or depth interval ids + for annotation_result in annotation["annotations"][0]["result"]: + with contextlib.suppress(KeyError): + id = annotation_result["id"] # relation regions do not have an ID. + if annotation_result["type"] == "textarea": + if id in material_descriptions: + material_descriptions[id]["text"] = annotation_result["value"]["text"][ + 0 + ] # There is always only one element. TO CHECK! + if len(annotation_result["value"]["text"]) > 1: + print(f"More than one text in material description: {annotation_result['value']['text']}") + elif id in depth_intervals: + depth_interval_text = annotation_result["value"]["text"][0] + start, end = _get_start_end_from_text(depth_interval_text) + depth_intervals[id]["start"] = start + depth_intervals[id]["end"] = end + depth_intervals[id]["background_rect"] = annotation_result[ + "value" + ] # TODO extract rectangle properly + else: + print(f"Unknown id: {id}") + + # create the layer prediction objects by linking material descriptions with depth intervals + layers = [] + + for link in linking_objects: + from_id = link["from_id"] + to_id = link["to_id"] + material_description_prediction = MaterialDescription(**material_descriptions.pop(from_id)) + depth_interval_prediction = AnnotatedInterval(**depth_intervals.pop(to_id)) + layers.append( + LayerPrediction( + material_description=material_description_prediction, + depth_interval=depth_interval_prediction, + material_is_correct=True, + depth_interval_is_correct=True, + ) + ) + + if material_descriptions or depth_intervals: + # TODO: This should not be acceptable. Raising an error doesnt seem the right way to go either. + # But at least it should be warned. + print("There are material descriptions or depth intervals left over.") + print(material_descriptions) + print(depth_intervals) + + file_pages[file_name].append( + PagePredictions(layers=layers, page_number=page_index, page_width=page_width, page_height=page_height) + ) + + file_predictions = [] + for file_name, page_predictions in file_pages.items(): + file_predictions.append( + FilePredictions(file_name=file_name, pages=page_predictions, language="unknown") + ) # TODO: language should not be required here. + + return file_predictions + + def convert_to_ground_truth(self): + """Convert the predictions to ground truth format. + + Returns: + dict: The predictions in ground truth format. + """ + for page in self.pages: + layers = [] + for layer in page.layers: + material_description = layer.material_description.text + depth_interval = { + "start": layer.depth_interval.start.value if layer.depth_interval.start else None, + "end": layer.depth_interval.end.value if layer.depth_interval.end else None, + } + layers.append({"material_description": material_description, "depth_interval": depth_interval}) + ground_truth = {self.file_name: {"layers": layers}} + return ground_truth + def evaluate(self, ground_truth_layers: list): """Evaluate all layers of the predictions against the ground truth. @@ -169,3 +291,15 @@ def _find_matching_layer(self, layer: LayerPrediction) -> tuple[dict, bool] | tu def _create_textblock_object(lines: dict) -> TextBlock: lines = [TextLine([TextWord(**line)]) for line in lines] return TextBlock(lines) + + +def _get_start_end_from_text(text: str) -> tuple[float]: + start, end = text.split("end: ") + start = start.split("start: ")[1] + return float(start), float(end) + + +def _get_file_name_and_page_index(annotation): + file_name = annotation["data"]["ocr"].split("/")[-1] + file_name = file_name.split(".")[0] + return file_name.split("_") diff --git a/src/stratigraphy/util/textblock.py b/src/stratigraphy/util/textblock.py index 93953b6c..15da4aa9 100644 --- a/src/stratigraphy/util/textblock.py +++ b/src/stratigraphy/util/textblock.py @@ -10,6 +10,26 @@ from stratigraphy.util.line import TextLine +@dataclass +class MaterialDescription: + """Class to represent a material description in a PDF document. + + Note: This class is similar to the TextBlock class. As such it has the attributes text and rect. + But it does not have the attribute lines and is missing class methods. TextBlock is used during the extraction + process where more fine-grained information is required. We lose this "fine-grainedness" when we annotate + the boreholes in label-studio. + """ + + text: str + rect: fitz.Rect + + def to_json(self): + return { + "text": self.text, + "rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1], + } + + @dataclass class TextBlock: """Class to represent a block of text in a PDF document."""