From 9c2d3f1875de56908d311581f41eb778a76536e8 Mon Sep 17 00:00:00 2001
From: Renato Durrer
Date: Wed, 24 Apr 2024 15:52:37 +0200
Subject: [PATCH] Add create_from_label_studio and convert_to_ground_truth
 methods to FilePredictions.

These methods allow loading annotations from Label Studio and using the
predictions dataclasses to perform operations on them. They also allow
creating the ground truth files from the Label Studio annotations.
---
 src/stratigraphy/util/interval.py    |  15 +++
 src/stratigraphy/util/predictions.py | 187 ++++++++++++++++++++++++++-
 src/stratigraphy/util/textblock.py   |  21 +++
 3 files changed, 219 insertions(+), 4 deletions(-)

diff --git a/src/stratigraphy/util/interval.py b/src/stratigraphy/util/interval.py
index d9c070bd..e137aeda 100644
--- a/src/stratigraphy/util/interval.py
+++ b/src/stratigraphy/util/interval.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import abc
+from dataclasses import dataclass
 
 import fitz
 
@@ -50,6 +51,20 @@ def to_json(self):
         }
 
 
+@dataclass
+class AnnotatedInterval:
+    """Class for annotated intervals."""
+
+    start: float
+    end: float
+    background_rect: fitz.Rect
+
+    @property
+    def line_anchor(self) -> fitz.Point:
+        """Return a fixed anchor point at the origin (0, 0)."""
+        return fitz.Point(0, 0)
+
+
 class BoundaryInterval(Interval):
     """Class for boundary intervals.
 
diff --git a/src/stratigraphy/util/predictions.py b/src/stratigraphy/util/predictions.py
index 77c9dcdd..84f326c5 100644
--- a/src/stratigraphy/util/predictions.py
+++ b/src/stratigraphy/util/predictions.py
@@ -1,15 +1,16 @@
 """This module contains classes for predictions."""
 
 import uuid
+from collections import defaultdict
 from dataclasses import dataclass, field
 
 import fitz
 import Levenshtein
 
 from stratigraphy.util.depthcolumnentry import DepthColumnEntry
-from stratigraphy.util.interval import BoundaryInterval
+from stratigraphy.util.interval import AnnotatedInterval, BoundaryInterval
 from stratigraphy.util.line import TextLine, TextWord
-from stratigraphy.util.textblock import TextBlock
+from stratigraphy.util.textblock import MaterialDescription, TextBlock
 from stratigraphy.util.util import parse_text
 
 
@@ -17,8 +18,8 @@
 class LayerPrediction:
     """A class to represent predictions for a single layer."""
 
-    material_description: TextBlock
-    depth_interval: BoundaryInterval
+    material_description: TextBlock | MaterialDescription
+    depth_interval: BoundaryInterval | AnnotatedInterval
     material_is_correct: bool = None
     depth_interval_is_correct: bool = None
     id: uuid.UUID = field(default_factory=uuid.uuid4)
@@ -105,6 +106,138 @@ def create_from_json(predictions_for_file: dict, file_name: str):
 
         return FilePredictions(pages=page_predictions_class, file_name=file_name, language=file_language)
 
+    @staticmethod
+    def create_from_label_studio(annotation_results: list[dict]) -> list["FilePredictions"]:
+        """Create predictions class for a file given the annotation results from Label Studio.
+
+        NOTE: We may want to adjust this method to return a single instance of the class,
+        instead of a list of class objects.
+
+        Args:
+            annotation_results (list[dict]): The annotation results from Label Studio, one entry per
+                annotated page. The annotation_results can cover multiple files.
+
+        Returns:
+            list[FilePredictions]: A list of FilePredictions objects, one for each file present in the
+                annotation_results.
+        """
+        file_pages = defaultdict(list)
+        for annotation in annotation_results:
+            # get page level information
+            file_name, page_index = _get_file_name_and_page_index(annotation)
+            results = annotation["annotations"][0]["result"]
+            # The order of the results is not guaranteed and not every result is guaranteed to carry the
+            # page size, so take it from the first result that provides it.
+            sized_result = next((result for result in results if "original_width" in result), None)
+            if sized_result is None:
+                print(f"No page size found for page {page_index} of file {file_name}.")
+                continue
+            page_width = sized_result["original_width"]
+            page_height = sized_result["original_height"]
+
+            # extract all material descriptions and depth intervals and link them together
+            # Note: we need to loop through the annotations twice, because the order of the annotations is
+            # not guaranteed. In the first iteration we grasp all IDs, in the second iteration we extract the
+            # information for each id.
+            material_descriptions = {}
+            depth_intervals = {}
+            linking_objects = []
+
+            # define all the material descriptions and depth intervals with their ids
+            for annotation_result in results:
+                if annotation_result["type"] == "labels":
+                    if annotation_result["value"]["labels"] == ["Material Description"]:
+                        material_descriptions[annotation_result["id"]] = {
+                            "rect": annotation_result["value"]
+                        }  # TODO extract rectangle properly
+                    elif annotation_result["value"]["labels"] == ["Depth Interval"]:
+                        depth_intervals[annotation_result["id"]] = {}
+                if annotation_result["type"] == "relation":
+                    linking_objects.append(
+                        {"from_id": annotation_result["from_id"], "to_id": annotation_result["to_id"]}
+                    )
+
+            # check annotation results for material description or depth interval ids
+            for annotation_result in results:
+                # relation regions do not have an ID, and only textarea results carry the annotated text.
+                region_id = annotation_result.get("id")
+                if region_id is None or annotation_result["type"] != "textarea":
+                    continue
+                texts = annotation_result["value"]["text"]
+                if region_id in material_descriptions:
+                    # There is always only one element. TO CHECK!
+                    material_descriptions[region_id]["text"] = texts[0]
+                    if len(texts) > 1:
+                        print(f"More than one text in material description: {texts}")
+                elif region_id in depth_intervals:
+                    start, end = _get_start_end_from_text(texts[0])
+                    depth_intervals[region_id]["start"] = start
+                    depth_intervals[region_id]["end"] = end
+                    depth_intervals[region_id]["background_rect"] = annotation_result[
+                        "value"
+                    ]  # TODO extract rectangle properly
+                else:
+                    print(f"Unknown id: {region_id}")
+
+            # create the layer prediction objects by linking material descriptions with depth intervals
+            layers = []
+
+            for link in linking_objects:
+                from_id = link["from_id"]
+                to_id = link["to_id"]
+                if from_id not in material_descriptions or to_id not in depth_intervals:
+                    # The relation does not connect a known, not yet linked material description with a
+                    # known, not yet linked depth interval, so no layer can be built from it.
+                    print(f"Skipping relation with unexpected ids: {from_id} -> {to_id}")
+                    continue
+                material_description_prediction = MaterialDescription(**material_descriptions.pop(from_id))
+                depth_interval_prediction = AnnotatedInterval(**depth_intervals.pop(to_id))
+                layers.append(
+                    LayerPrediction(
+                        material_description=material_description_prediction,
+                        depth_interval=depth_interval_prediction,
+                        material_is_correct=True,
+                        depth_interval_is_correct=True,
+                    )
+                )
+
+            if material_descriptions or depth_intervals:
+                # TODO: This should not be acceptable. Raising an error doesn't seem the right way to go either.
+                # But at least it should be warned.
+                print("There are material descriptions or depth intervals left over.")
+                print(material_descriptions)
+                print(depth_intervals)
+
+            file_pages[file_name].append(
+                PagePredictions(layers=layers, page_number=page_index, page_width=page_width, page_height=page_height)
+            )
+
+        # TODO: language should not be required here.
+        return [
+            FilePredictions(file_name=file_name, pages=page_predictions, language="unknown")
+            for file_name, page_predictions in file_pages.items()
+        ]
+
+    def convert_to_ground_truth(self) -> dict:
+        """Convert the predictions to ground truth format.
+
+        The layers of all pages are collected into a single list for the file.
+
+        Returns:
+            dict: The predictions in ground truth format.
+        """
+        layers = []
+        for page in self.pages:
+            for layer in page.layers:
+                depth_interval = {
+                    "start": _get_depth_value(layer.depth_interval.start),
+                    "end": _get_depth_value(layer.depth_interval.end),
+                }
+                layers.append(
+                    {"material_description": layer.material_description.text, "depth_interval": depth_interval}
+                )
+        return {self.file_name: {"layers": layers}}
+
     def evaluate(self, ground_truth_layers: list):
         """Evaluate all layers of the predictions against the ground truth.
 
@@ -169,3 +302,49 @@ def _find_matching_layer(self, layer: LayerPrediction) -> tuple[dict, bool] | tu
 def _create_textblock_object(lines: dict) -> TextBlock:
     lines = [TextLine([TextWord(**line)]) for line in lines]
     return TextBlock(lines)
+
+
+def _get_depth_value(depth: DepthColumnEntry | float | None) -> float | None:
+    """Return the numeric value of an interval boundary.
+
+    Args:
+        depth (DepthColumnEntry | float | None): The boundary of a BoundaryInterval (an entry with a value
+            attribute), the boundary of an AnnotatedInterval (a plain number), or None if it is missing.
+
+    Returns:
+        float | None: The depth value, or None if there is no boundary.
+    """
+    if depth is None:
+        return None
+    return getattr(depth, "value", depth)
+
+
+def _get_start_end_from_text(text: str) -> tuple[float, float]:
+    """Parse the start and end depth from the text of a depth interval annotation.
+
+    Args:
+        text (str): The annotated text, in the form "start: <number> end: <number>".
+
+    Returns:
+        tuple[float, float]: The start and end depth.
+    """
+    start, end = text.split("end: ")
+    start = start.split("start: ")[1]
+    return float(start), float(end)
+
+
+def _get_file_name_and_page_index(annotation: dict) -> tuple[str, int]:
+    """Extract the file name and the page index from the image path of a Label Studio task.
+
+    Args:
+        annotation (dict): A Label Studio task whose image path annotation["data"]["ocr"] ends in
+            "<file_name>_<page_index>.<extension>".
+
+    Returns:
+        tuple[str, int]: The file name and the page index.
+    """
+    image_name = annotation["data"]["ocr"].split("/")[-1]
+    image_stem = image_name.rsplit(".", 1)[0]
+    # Split at the last underscore only, because the file name itself may contain underscores.
+    file_name, page_index = image_stem.rsplit("_", 1)
+    return file_name, int(page_index)
diff --git a/src/stratigraphy/util/textblock.py b/src/stratigraphy/util/textblock.py
index 93953b6c..15da4aa9 100644
--- a/src/stratigraphy/util/textblock.py
+++ b/src/stratigraphy/util/textblock.py
@@ -10,6 +10,27 @@
 from stratigraphy.util.line import TextLine
 
 
+@dataclass
+class MaterialDescription:
+    """Class to represent a material description in a PDF document.
+
+    Note: This class is similar to the TextBlock class. As such it has the attributes text and rect.
+    But it does not have the attribute lines and is missing class methods. TextBlock is used during the extraction
+    process where more fine-grained information is required. We lose this "fine-grainedness" when we annotate
+    the boreholes in label-studio.
+    """
+
+    text: str
+    rect: fitz.Rect
+
+    def to_json(self):
+        """Convert the material description to a JSON serializable dictionary."""
+        return {
+            "text": self.text,
+            "rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1],
+        }
+
+
 @dataclass
 class TextBlock:
     """Class to represent a block of text in a PDF document."""