Skip to content

Commit

Permalink
Add create_from_label_studio and convert_to_ground_truth methods to F…
Browse files Browse the repository at this point in the history
…ilePredictions.

These methods will allow to load annotations from label studio and use the predictions dataclasses to perform operations with them. It will also allow to create the ground_truth files from the label-studio annotations.
  • Loading branch information
redur committed Apr 24, 2024
1 parent 9a9d898 commit 9c2d3f1
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 4 deletions.
14 changes: 14 additions & 0 deletions src/stratigraphy/util/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import abc
from dataclasses import dataclass

import fitz

Expand Down Expand Up @@ -50,6 +51,19 @@ def to_json(self):
}


@dataclass
class AnnotatedInterval:
"""Class for annotated intervals."""

start: float
end: float
background_rect: fitz.Rect

@property
def line_anchor(self) -> fitz.Point:
return fitz.Point(0, 0)


class BoundaryInterval(Interval):
"""Class for boundary intervals.
Expand Down
142 changes: 138 additions & 4 deletions src/stratigraphy/util/predictions.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
"""This module contains classes for predictions."""

import contextlib
import uuid
from collections import defaultdict
from dataclasses import dataclass, field

import fitz
import Levenshtein

from stratigraphy.util.depthcolumnentry import DepthColumnEntry
from stratigraphy.util.interval import BoundaryInterval
from stratigraphy.util.interval import AnnotatedInterval, BoundaryInterval
from stratigraphy.util.line import TextLine, TextWord
from stratigraphy.util.textblock import TextBlock
from stratigraphy.util.textblock import MaterialDescription, TextBlock
from stratigraphy.util.util import parse_text


@dataclass
class LayerPrediction:
"""A class to represent predictions for a single layer."""

material_description: TextBlock
depth_interval: BoundaryInterval
material_description: TextBlock | MaterialDescription
depth_interval: BoundaryInterval | AnnotatedInterval
material_is_correct: bool = None
depth_interval_is_correct: bool = None
id: uuid.UUID = field(default_factory=uuid.uuid4)
Expand Down Expand Up @@ -105,6 +107,126 @@ def create_from_json(predictions_for_file: dict, file_name: str):

return FilePredictions(pages=page_predictions_class, file_name=file_name, language=file_language)

@staticmethod
def create_from_label_studio(annotation_results: dict):
"""Create predictions class for a file given the annotation results from Label Studio.
NOTE: We may want to adjust this method to return a single instance of the class,
instead of a list of class objects.
Args:
annotation_results (dict): The annotation results from Label Studio.
The annotation_results can cover multiple files.
Returns:
list[FilePredictions]: A list of FilePredictions objects, one for each file present in the
annotation_results.
"""
file_pages = defaultdict(list)
for annotation in annotation_results:
# get page level information
file_name, page_index = _get_file_name_and_page_index(annotation)
page_width = annotation["annotations"][0]["result"][0]["original_width"]
page_height = annotation["annotations"][0]["result"][0]["original_height"]

# extract all material descriptions and depth intervals and link them together
# Note: we need to loop through the annotations twice, because the order of the annotations is
# not guaranteed. In the first iteration we grasp all IDs, in the second iteration we extract the
# information for each id.
material_descriptions = {}
depth_intervals = {}
linking_objects = []

# define all the material descriptions and depth intervals with their ids
for annotation_result in annotation["annotations"][0]["result"]:
if annotation_result["type"] == "labels":
if annotation_result["value"]["labels"] == ["Material Description"]:
material_descriptions[annotation_result["id"]] = {
"rect": annotation_result["value"]
} # TODO extract rectangle properly
elif annotation_result["value"]["labels"] == ["Depth Interval"]:
depth_intervals[annotation_result["id"]] = {}
if annotation_result["type"] == "relation":
linking_objects.append(
{"from_id": annotation_result["from_id"], "to_id": annotation_result["to_id"]}
)

# check annotation results for material description or depth interval ids
for annotation_result in annotation["annotations"][0]["result"]:
with contextlib.suppress(KeyError):
id = annotation_result["id"] # relation regions do not have an ID.
if annotation_result["type"] == "textarea":
if id in material_descriptions:
material_descriptions[id]["text"] = annotation_result["value"]["text"][
0
] # There is always only one element. TO CHECK!
if len(annotation_result["value"]["text"]) > 1:
print(f"More than one text in material description: {annotation_result['value']['text']}")
elif id in depth_intervals:
depth_interval_text = annotation_result["value"]["text"][0]
start, end = _get_start_end_from_text(depth_interval_text)
depth_intervals[id]["start"] = start
depth_intervals[id]["end"] = end
depth_intervals[id]["background_rect"] = annotation_result[
"value"
] # TODO extract rectangle properly
else:
print(f"Unknown id: {id}")

# create the layer prediction objects by linking material descriptions with depth intervals
layers = []

for link in linking_objects:
from_id = link["from_id"]
to_id = link["to_id"]
material_description_prediction = MaterialDescription(**material_descriptions.pop(from_id))
depth_interval_prediction = AnnotatedInterval(**depth_intervals.pop(to_id))
layers.append(
LayerPrediction(
material_description=material_description_prediction,
depth_interval=depth_interval_prediction,
material_is_correct=True,
depth_interval_is_correct=True,
)
)

if material_descriptions or depth_intervals:
# TODO: This should not be acceptable. Raising an error doesnt seem the right way to go either.
# But at least it should be warned.
print("There are material descriptions or depth intervals left over.")
print(material_descriptions)
print(depth_intervals)

file_pages[file_name].append(
PagePredictions(layers=layers, page_number=page_index, page_width=page_width, page_height=page_height)
)

file_predictions = []
for file_name, page_predictions in file_pages.items():
file_predictions.append(
FilePredictions(file_name=file_name, pages=page_predictions, language="unknown")
) # TODO: language should not be required here.

return file_predictions

def convert_to_ground_truth(self):
"""Convert the predictions to ground truth format.
Returns:
dict: The predictions in ground truth format.
"""
for page in self.pages:
layers = []
for layer in page.layers:
material_description = layer.material_description.text
depth_interval = {
"start": layer.depth_interval.start.value if layer.depth_interval.start else None,
"end": layer.depth_interval.end.value if layer.depth_interval.end else None,
}
layers.append({"material_description": material_description, "depth_interval": depth_interval})
ground_truth = {self.file_name: {"layers": layers}}
return ground_truth

def evaluate(self, ground_truth_layers: list):
"""Evaluate all layers of the predictions against the ground truth.
Expand Down Expand Up @@ -169,3 +291,15 @@ def _find_matching_layer(self, layer: LayerPrediction) -> tuple[dict, bool] | tu
def _create_textblock_object(lines: dict) -> TextBlock:
lines = [TextLine([TextWord(**line)]) for line in lines]
return TextBlock(lines)


def _get_start_end_from_text(text: str) -> tuple[float]:
start, end = text.split("end: ")
start = start.split("start: ")[1]
return float(start), float(end)


def _get_file_name_and_page_index(annotation):
file_name = annotation["data"]["ocr"].split("/")[-1]
file_name = file_name.split(".")[0]
return file_name.split("_")
20 changes: 20 additions & 0 deletions src/stratigraphy/util/textblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@
from stratigraphy.util.line import TextLine


@dataclass
class MaterialDescription:
"""Class to represent a material description in a PDF document.
Note: This class is similar to the TextBlock class. As such it has the attributes text and rect.
But it does not have the attribute lines and is missing class methods. TextBlock is used during the extraction
process where more fine-grained information is required. We lose this "fine-grainedness" when we annotate
the boreholes in label-studio.
"""

text: str
rect: fitz.Rect

def to_json(self):
return {
"text": self.text,
"rect": [self.rect.x0, self.rect.y0, self.rect.x1, self.rect.y1],
}


@dataclass
class TextBlock:
"""Class to represent a block of text in a PDF document."""
Expand Down

0 comments on commit 9c2d3f1

Please sign in to comment.