Skip to content

Commit

Permalink
minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnvermeeren-swisstopo committed Sep 16, 2024
1 parent ee065a4 commit 7075748
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 139 deletions.
25 changes: 9 additions & 16 deletions src/stratigraphy/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from stratigraphy.util import find_depth_columns
from stratigraphy.util.dataclasses import Line
from stratigraphy.util.depthcolumn import DepthColumn
from stratigraphy.util.find_depth_columns import get_depth_interval_from_textblock
from stratigraphy.util.find_description import (
get_description_blocks,
get_description_blocks_from_layer_identifier,
Expand All @@ -19,7 +20,7 @@
find_layer_identifier_column,
find_layer_identifier_column_entries,
)
from stratigraphy.util.line import TextLine
from stratigraphy.util.line import TextLine, TextWord
from stratigraphy.util.textblock import TextBlock, block_distance
from stratigraphy.util.util import (
remove_empty_predictions,
Expand Down Expand Up @@ -101,10 +102,8 @@ def process_page(

to_delete = []
for i, (_depth_column, material_description_rect) in enumerate(pairs):
for _depth_column_2, material_description_rect_2 in pairs[i + 1 :]:
if material_description_rect.intersects(material_description_rect_2):
to_delete.append(i)
continue
if any(material_description_rect.intersects(other_rect) for _, other_rect in pairs[i + 1 :]):
to_delete.append(i)
filtered_pairs = [item for index, item in enumerate(pairs) if index not in to_delete]

groups = [] # list of matched depth intervals and text blocks
Expand All @@ -114,7 +113,7 @@ def process_page(
description_lines = get_description_lines(lines, material_description_rect)
if len(description_lines) > 1:
new_groups = match_columns(
depth_column, description_lines, geometric_lines, material_description_rect, page_number, **params
depth_column, description_lines, geometric_lines, material_description_rect, **params
)
groups.extend(new_groups)
json_filtered_pairs = [
Expand Down Expand Up @@ -174,18 +173,14 @@ def process_page(


def score_column_match(
depth_column: DepthColumn,
material_description_rect: fitz.Rect,
all_words: list[TextLine] | None = None,
**params: dict,
depth_column: DepthColumn, material_description_rect: fitz.Rect, all_words: list[TextWord] | None = None
) -> float:
"""Scores the match between a depth column and a material description.
Args:
depth_column (DepthColumn): The depth column.
material_description_rect (fitz.Rect): The material description rectangle.
all_words (list[TextLine] | None, optional): List of the available textlines. Defaults to None.
**params (dict): Additional parameters for the matching pipeline. Kept for compatibility with the pipeline.
Returns:
float: The score of the match.
Expand All @@ -212,7 +207,6 @@ def match_columns(
description_lines: list[TextLine],
geometric_lines: list[Line],
material_description_rect: fitz.Rect,
page_number: int,
**params: dict,
) -> list:
"""Match the depth column entries with the description lines.
Expand All @@ -226,7 +220,6 @@ def match_columns(
description_lines (list[TextLine]): The description lines.
geometric_lines (list[Line]): The geometric lines.
material_description_rect (fitz.Rect): The material description rectangle.
page_number (int): The page number.
**params (dict): Additional parameters for the matching pipeline.
Returns:
Expand All @@ -244,7 +237,7 @@ def match_columns(
blocks = get_description_blocks_from_layer_identifier(depth_column.entries, description_lines)
groups = []
for block in blocks:
depth_interval = depth_column.get_depth_interval(block)
depth_interval = get_depth_interval_from_textblock(block)
if depth_interval:
groups.append({"depth_interval": depth_interval, "block": block})
else:
Expand Down Expand Up @@ -375,13 +368,13 @@ def split_blocks_by_textline_length(blocks: list[TextBlock], target_split_count:


def find_material_description_column(
lines: list[TextLine], depth_column: DepthColumn, language: str, **params: dict
lines: list[TextLine], depth_column: DepthColumn | None, language: str, **params: dict
) -> fitz.Rect | None:
"""Find the material description column given a depth column.
Args:
lines (list[TextLine]): The text lines of the page.
depth_column (DepthColumn): The depth column.
depth_column (DepthColumn | None): The depth column.
language (str): The language of the page.
**params (dict): Additional parameters for the matching pipeline.
Expand Down
157 changes: 78 additions & 79 deletions src/stratigraphy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,93 +149,92 @@ def start_pipeline(

# if a file is specified instead of an input directory, copy the file to a temporary directory and work with that.
if input_directory.is_file():
file_iterator = [(input_directory.parent, None, [input_directory.name])]
root = input_directory.parent
files = [input_directory.name]
else:
file_iterator = os.walk(input_directory)
root = input_directory
_, _, files = next(os.walk(input_directory))
# process the individual pdf files
predictions = {}
for root, _dirs, files in file_iterator:
for filename in tqdm(files, desc="Processing files", unit="file"):
if filename.endswith(".pdf"):
in_path = os.path.join(root, filename)
logger.info("Processing file: %s", in_path)
predictions[filename] = {}

with fitz.Document(in_path) as doc:
language = detect_language_of_document(
doc, matching_params["default_language"], matching_params["material_description"].keys()
for filename in tqdm(files, desc="Processing files", unit="file"):
if filename.endswith(".pdf"):
in_path = os.path.join(root, filename)
logger.info("Processing file: %s", in_path)
predictions[filename] = {}

with fitz.Document(in_path) as doc:
language = detect_language_of_document(
doc, matching_params["default_language"], matching_params["material_description"].keys()
)
predictions[filename]["language"] = language

# Extract the coordinates of the borehole
coordinate_extractor = CoordinateExtractor(document=doc)
coordinates = coordinate_extractor.extract_coordinates()
if coordinates:
predictions[filename]["metadata"] = {"coordinates": coordinates.to_json()}
else:
predictions[filename]["metadata"] = {"coordinates": None}

# Extract the elevation information
elevation_extractor = ElevationExtractor(document=doc)
elevation = elevation_extractor.extract_elevation()
if elevation:
predictions[filename]["metadata"]["elevation"] = elevation.to_dict()
else:
predictions[filename]["metadata"]["elevation"] = None

# Extract the groundwater levels
groundwater_extractor = GroundwaterLevelExtractor(document=doc)
groundwater = groundwater_extractor.extract_groundwater()
if groundwater:
predictions[filename]["groundwater"] = [
groundwater_entry.to_dict() for groundwater_entry in groundwater
]
else:
predictions[filename]["groundwater"] = None

layer_predictions_list = []
depths_materials_column_pairs_list = []
page_dimensions = []
for page_index, page in enumerate(doc):
page_number = page_index + 1
logger.info("Processing page %s", page_number)

text_lines = extract_text_lines(page)
geometric_lines = extract_lines(page, line_detection_params)
layer_predictions, depths_materials_column_pairs = process_page(
text_lines, geometric_lines, language, page_number, **matching_params
)
predictions[filename]["language"] = language

# Extract the coordinates of the borehole
coordinate_extractor = CoordinateExtractor(document=doc)
coordinates = coordinate_extractor.extract_coordinates()
if coordinates:
predictions[filename]["metadata"] = {"coordinates": coordinates.to_json()}
else:
predictions[filename]["metadata"] = {"coordinates": None}

# Extract the elevation information
elevation_extractor = ElevationExtractor(document=doc)
elevation = elevation_extractor.extract_elevation()
if elevation:
predictions[filename]["metadata"]["elevation"] = elevation.to_dict()
else:
predictions[filename]["metadata"]["elevation"] = None

# Extract the groundwater levels
groundwater_extractor = GroundwaterLevelExtractor(document=doc)
groundwater = groundwater_extractor.extract_groundwater()
if groundwater:
predictions[filename]["groundwater"] = [
groundwater_entry.to_dict() for groundwater_entry in groundwater
]
else:
predictions[filename]["groundwater"] = None

layer_predictions_list = []
depths_materials_column_pairs_list = []
page_dimensions = []
for page_index, page in enumerate(doc):
page_number = page_index + 1
logger.info("Processing page %s", page_number)

text_lines = extract_text_lines(page)
geometric_lines = extract_lines(page, line_detection_params)
layer_predictions, depths_materials_column_pairs = process_page(
text_lines, geometric_lines, language, page_number, **matching_params

# TODO: Add remove duplicates here!
if page_index > 0:
layer_predictions = remove_duplicate_layers(
doc[page_index - 1],
page,
layer_predictions_list,
layer_predictions,
matching_params["img_template_probability_threshold"],
)

# TODO: Add remove duplicates here!
if page_index > 0:
layer_predictions = remove_duplicate_layers(
doc[page_index - 1],
page,
layer_predictions_list,
layer_predictions,
matching_params["img_template_probability_threshold"],
layer_predictions_list.extend(layer_predictions)
depths_materials_column_pairs_list.extend(depths_materials_column_pairs)
page_dimensions.append({"height": page.rect.height, "width": page.rect.width})

if draw_lines: # could be changed to if draw_lines and mflow_tracking:
if not mlflow_tracking:
logger.warning("MLFlow tracking is not enabled. MLFLow is required to store the images.")
else:
img = plot_lines(
page, geometric_lines, scale_factor=line_detection_params["pdf_scale_factor"]
)
mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png")

predictions[filename]["layers"] = layer_predictions_list
predictions[filename]["depths_materials_column_pairs"] = depths_materials_column_pairs_list
predictions[filename]["page_dimensions"] = page_dimensions

layer_predictions_list.extend(layer_predictions)
depths_materials_column_pairs_list.extend(depths_materials_column_pairs)
page_dimensions.append({"height": page.rect.height, "width": page.rect.width})

if draw_lines: # could be changed to if draw_lines and mflow_tracking:
if not mlflow_tracking:
logger.warning(
"MLFlow tracking is not enabled. MLFLow is required to store the images."
)
else:
img = plot_lines(
page, geometric_lines, scale_factor=line_detection_params["pdf_scale_factor"]
)
mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png")

predictions[filename]["layers"] = layer_predictions_list
predictions[filename]["depths_materials_column_pairs"] = depths_materials_column_pairs_list
predictions[filename]["page_dimensions"] = page_dimensions

assert len(page_dimensions) == doc.page_count, "Page count mismatch."
assert len(page_dimensions) == doc.page_count, "Page count mismatch."

logger.info("Writing predictions to JSON file %s", predictions_path)
with open(predictions_path, "w", encoding="utf8") as file:
Expand Down
2 changes: 1 addition & 1 deletion src/stratigraphy/util/depthcolumn.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def significant_intersection(other_rect):
intersection = fitz.Rect(other_rect).intersect(self.rect())
return intersection.is_valid and intersection.width > 0.25 * self.rect().width

return len([line for line in all_words if significant_intersection(line.rect)]) - len(self.entries)
return len([word for word in all_words if significant_intersection(word.rect)]) - len(self.entries)

def pearson_correlation_coef(self) -> float:
# We look at the lower y coordinate, because most often the baseline of the depth value text is aligned with
Expand Down
38 changes: 38 additions & 0 deletions src/stratigraphy/util/find_depth_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from stratigraphy.util.depthcolumn import BoundaryDepthColumn, LayerDepthColumn
from stratigraphy.util.depthcolumnentry import DepthColumnEntry, LayerDepthColumnEntry
from stratigraphy.util.line import TextWord
from stratigraphy.util.textblock import TextBlock


def depth_column_entries(all_words: list[TextWord], include_splits: bool) -> list[DepthColumnEntry]:
Expand Down Expand Up @@ -206,3 +207,40 @@ def find_depth_columns(
[column for column in numeric_columns if column and boundary_depth_column_validator.is_valid(column)],
key=lambda column: len(column.entries),
)


def get_depth_interval_from_textblock(block: TextBlock) -> LayerDepthColumnEntry | None:
"""Extract depth interval from a material description block.
For borehole profiles in the Deriaz layout, the depth interval is usually found in the text description
of the material. Often, these text descriptions contain a further separation into multiple sub layers.
These sub layers have their own depth intervals. This function extracts the overall depth interval,
spanning across all mentioned sub layers.
Args:
block (TextBlock): The block to calculate the depth interval for.
Returns:
LayerDepthColumnEntry | None: The depth interval.
"""
depth_entries = []
for line in block.lines:
try:
layer_depth_entry = extract_layer_depth_interval(
line.text, line.rect, line.page_number, require_start_of_string=False
)
# require_start_of_string = False because the depth interval may not always start at the beginning
# of the line e.g. "Remblais Heterogene: 0.00 - 0.5m"
if layer_depth_entry:
depth_entries.append(layer_depth_entry)
except ValueError:
pass

if depth_entries:
# Merge the sub layers into one depth interval.
start = min([entry.start for entry in depth_entries], key=lambda start_entry: start_entry.value)
end = max([entry.end for entry in depth_entries], key=lambda end_entry: end_entry.value)

return LayerDepthColumnEntry(start, end)
else:
return None
5 changes: 3 additions & 2 deletions src/stratigraphy/util/find_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
SplitDescriptionBlockByLine,
SplitDescriptionBlockByVerticalSpace,
)
from stratigraphy.util.layer_identifier_column import LayerIdentifierEntry
from stratigraphy.util.line import TextLine
from stratigraphy.util.textblock import TextBlock

Expand Down Expand Up @@ -36,12 +37,12 @@ def get_description_lines(lines: list[TextLine], material_description_rect: fitz


def get_description_blocks_from_layer_identifier(
layer_identifier_entries: list[TextLine], description_lines: list[TextLine]
layer_identifier_entries: list[LayerIdentifierEntry], description_lines: list[TextLine]
) -> list[TextBlock]:
"""Divide the description lines into blocks based on the layer identifier entries.
Args:
layer_identifier_entries (list[TextLine]): The layer identifier entries.
layer_identifier_entries (list[LayerIdentifierEntry]): The layer identifier entries.
description_lines (list[TextLine]): All lines constituting the material description.
Returns:
Expand Down
6 changes: 4 additions & 2 deletions src/stratigraphy/util/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def background_rect(self) -> fitz.Rect | None:
if self.start and self.end:
return fitz.Rect(self.start.rect.x0, self.start.rect.y1, self.start.rect.x1, self.end.rect.y0)

def matching_blocks(self, all_blocks: list[TextBlock], block_index: int) -> tuple[list[TextBlock]]:
def matching_blocks(
self, all_blocks: list[TextBlock], block_index: int
) -> tuple[list[TextBlock], list[TextBlock], list[TextBlock]]:
"""Calculates pre, exact and post blocks for the boundary interval.
Pre contains all the blocks that are supposed to come before the interval.
Expand All @@ -96,7 +98,7 @@ def matching_blocks(self, all_blocks: list[TextBlock], block_index: int) -> tupl
block_index (int): Index of the current block.
Returns:
tuple[list[TextBlock]]: Pre, exact and post blocks.
tuple[list[TextBlock], list[TextBlock], list[TextBlock]]: Pre, exact and post blocks.
"""
pre, exact, post = [], [], []

Expand Down
Loading

0 comments on commit 7075748

Please sign in to comment.