From 7075748e597a55fc631183c5c3a10915dfe12ade Mon Sep 17 00:00:00 2001 From: Stijn Vermeeren Date: Mon, 16 Sep 2024 15:03:01 +0200 Subject: [PATCH] minor refactoring --- src/stratigraphy/extract.py | 25 +-- src/stratigraphy/main.py | 157 +++++++++--------- src/stratigraphy/util/depthcolumn.py | 2 +- src/stratigraphy/util/find_depth_columns.py | 38 +++++ src/stratigraphy/util/find_description.py | 5 +- src/stratigraphy/util/interval.py | 6 +- .../util/layer_identifier_column.py | 39 ----- 7 files changed, 133 insertions(+), 139 deletions(-) diff --git a/src/stratigraphy/extract.py b/src/stratigraphy/extract.py index e6890fe0..0aaa0ad3 100644 --- a/src/stratigraphy/extract.py +++ b/src/stratigraphy/extract.py @@ -8,6 +8,7 @@ from stratigraphy.util import find_depth_columns from stratigraphy.util.dataclasses import Line from stratigraphy.util.depthcolumn import DepthColumn +from stratigraphy.util.find_depth_columns import get_depth_interval_from_textblock from stratigraphy.util.find_description import ( get_description_blocks, get_description_blocks_from_layer_identifier, @@ -19,7 +20,7 @@ find_layer_identifier_column, find_layer_identifier_column_entries, ) -from stratigraphy.util.line import TextLine +from stratigraphy.util.line import TextLine, TextWord from stratigraphy.util.textblock import TextBlock, block_distance from stratigraphy.util.util import ( remove_empty_predictions, @@ -101,10 +102,8 @@ def process_page( to_delete = [] for i, (_depth_column, material_description_rect) in enumerate(pairs): - for _depth_column_2, material_description_rect_2 in pairs[i + 1 :]: - if material_description_rect.intersects(material_description_rect_2): - to_delete.append(i) - continue + if any(material_description_rect.intersects(other_rect) for _, other_rect in pairs[i + 1 :]): + to_delete.append(i) filtered_pairs = [item for index, item in enumerate(pairs) if index not in to_delete] groups = [] # list of matched depth intervals and text blocks @@ -114,7 +113,7 @@ def process_page( description_lines = get_description_lines(lines, material_description_rect) if len(description_lines) > 1: new_groups = match_columns( - depth_column, description_lines, geometric_lines, material_description_rect, page_number, **params + depth_column, description_lines, geometric_lines, material_description_rect, **params ) groups.extend(new_groups) json_filtered_pairs = [ @@ -174,10 +173,7 @@ def process_page( def score_column_match( - depth_column: DepthColumn, - material_description_rect: fitz.Rect, - all_words: list[TextLine] | None = None, - **params: dict, + depth_column: DepthColumn, material_description_rect: fitz.Rect, all_words: list[TextWord] | None = None ) -> float: """Scores the match between a depth column and a material description. @@ -185,7 +181,6 @@ def score_column_match( depth_column (DepthColumn): The depth column. material_description_rect (fitz.Rect): The material description rectangle. all_words (list[TextLine] | None, optional): List of the available textlines. Defaults to None. - **params (dict): Additional parameters for the matching pipeline. Kept for compatibility with the pipeline. Returns: float: The score of the match. @@ -212,7 +207,6 @@ def match_columns( description_lines: list[TextLine], geometric_lines: list[Line], material_description_rect: fitz.Rect, - page_number: int, **params: dict, ) -> list: """Match the depth column entries with the description lines. @@ -226,7 +220,6 @@ def match_columns( description_lines (list[TextLine]): The description lines. geometric_lines (list[Line]): The geometric lines. material_description_rect (fitz.Rect): The material description rectangle. - page_number (int): The page number. **params (dict): Additional parameters for the matching pipeline. Returns: @@ -244,7 +237,7 @@ def match_columns( blocks = get_description_blocks_from_layer_identifier(depth_column.entries, description_lines) groups = [] for block in blocks: - depth_interval = depth_column.get_depth_interval(block) + depth_interval = get_depth_interval_from_textblock(block) if depth_interval: groups.append({"depth_interval": depth_interval, "block": block}) else: @@ -375,13 +368,13 @@ def split_blocks_by_textline_length(blocks: list[TextBlock], target_split_count: def find_material_description_column( - lines: list[TextLine], depth_column: DepthColumn, language: str, **params: dict + lines: list[TextLine], depth_column: DepthColumn | None, language: str, **params: dict ) -> fitz.Rect | None: """Find the material description column given a depth column. Args: lines (list[TextLine]): The text lines of the page. - depth_column (DepthColumn): The depth column. + depth_column (DepthColumn | None): The depth column. language (str): The language of the page. **params (dict): Additional parameters for the matching pipeline. diff --git a/src/stratigraphy/main.py b/src/stratigraphy/main.py index 779d1349..f4f21778 100644 --- a/src/stratigraphy/main.py +++ b/src/stratigraphy/main.py @@ -149,93 +149,92 @@ def start_pipeline( # if a file is specified instead of an input directory, copy the file to a temporary directory and work with that. if input_directory.is_file(): - file_iterator = [(input_directory.parent, None, [input_directory.name])] + root = input_directory.parent + files = [input_directory.name] else: - file_iterator = os.walk(input_directory) + root = input_directory + _, _, files = next(os.walk(input_directory)) # process the individual pdf files predictions = {} - for root, _dirs, files in file_iterator: - for filename in tqdm(files, desc="Processing files", unit="file"): - if filename.endswith(".pdf"): - in_path = os.path.join(root, filename) - logger.info("Processing file: %s", in_path) - predictions[filename] = {} - - with fitz.Document(in_path) as doc: - language = detect_language_of_document( - doc, matching_params["default_language"], matching_params["material_description"].keys() + for filename in tqdm(files, desc="Processing files", unit="file"): + if filename.endswith(".pdf"): + in_path = os.path.join(root, filename) + logger.info("Processing file: %s", in_path) + predictions[filename] = {} + + with fitz.Document(in_path) as doc: + language = detect_language_of_document( + doc, matching_params["default_language"], matching_params["material_description"].keys() + ) + predictions[filename]["language"] = language + + # Extract the coordinates of the borehole + coordinate_extractor = CoordinateExtractor(document=doc) + coordinates = coordinate_extractor.extract_coordinates() + if coordinates: + predictions[filename]["metadata"] = {"coordinates": coordinates.to_json()} + else: + predictions[filename]["metadata"] = {"coordinates": None} + + # Extract the elevation information + elevation_extractor = ElevationExtractor(document=doc) + elevation = elevation_extractor.extract_elevation() + if elevation: + predictions[filename]["metadata"]["elevation"] = elevation.to_dict() + else: + predictions[filename]["metadata"]["elevation"] = None + + # Extract the groundwater levels + groundwater_extractor = GroundwaterLevelExtractor(document=doc) + groundwater = groundwater_extractor.extract_groundwater() + if groundwater: + predictions[filename]["groundwater"] = [ + groundwater_entry.to_dict() for groundwater_entry in groundwater + ] + else: + predictions[filename]["groundwater"] = None + + layer_predictions_list = [] + depths_materials_column_pairs_list = [] + page_dimensions = [] + for page_index, page in enumerate(doc): + page_number = page_index + 1 + logger.info("Processing page %s", page_number) + + text_lines = extract_text_lines(page) + geometric_lines = extract_lines(page, line_detection_params) + layer_predictions, depths_materials_column_pairs = process_page( + text_lines, geometric_lines, language, page_number, **matching_params ) - predictions[filename]["language"] = language - - # Extract the coordinates of the borehole - coordinate_extractor = CoordinateExtractor(document=doc) - coordinates = coordinate_extractor.extract_coordinates() - if coordinates: - predictions[filename]["metadata"] = {"coordinates": coordinates.to_json()} - else: - predictions[filename]["metadata"] = {"coordinates": None} - - # Extract the elevation information - elevation_extractor = ElevationExtractor(document=doc) - elevation = elevation_extractor.extract_elevation() - if elevation: - predictions[filename]["metadata"]["elevation"] = elevation.to_dict() - else: - predictions[filename]["metadata"]["elevation"] = None - - # Extract the groundwater levels - groundwater_extractor = GroundwaterLevelExtractor(document=doc) - groundwater = groundwater_extractor.extract_groundwater() - if groundwater: - predictions[filename]["groundwater"] = [ - groundwater_entry.to_dict() for groundwater_entry in groundwater - ] - else: - predictions[filename]["groundwater"] = None - - layer_predictions_list = [] - depths_materials_column_pairs_list = [] - page_dimensions = [] - for page_index, page in enumerate(doc): - page_number = page_index + 1 - logger.info("Processing page %s", page_number) - - text_lines = extract_text_lines(page) - geometric_lines = extract_lines(page, line_detection_params) - layer_predictions, depths_materials_column_pairs = process_page( - text_lines, geometric_lines, language, page_number, **matching_params + + # TODO: Add remove duplicates here! + if page_index > 0: + layer_predictions = remove_duplicate_layers( + doc[page_index - 1], + page, + layer_predictions_list, + layer_predictions, + matching_params["img_template_probability_threshold"], ) - # TODO: Add remove duplicates here! - if page_index > 0: - layer_predictions = remove_duplicate_layers( - doc[page_index - 1], - page, - layer_predictions_list, - layer_predictions, - matching_params["img_template_probability_threshold"], + layer_predictions_list.extend(layer_predictions) + depths_materials_column_pairs_list.extend(depths_materials_column_pairs) + page_dimensions.append({"height": page.rect.height, "width": page.rect.width}) + + if draw_lines: # could be changed to if draw_lines and mflow_tracking: + if not mlflow_tracking: + logger.warning("MLFlow tracking is not enabled. MLFLow is required to store the images.") + else: + img = plot_lines( + page, geometric_lines, scale_factor=line_detection_params["pdf_scale_factor"] ) + mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png") + + predictions[filename]["layers"] = layer_predictions_list + predictions[filename]["depths_materials_column_pairs"] = depths_materials_column_pairs_list + predictions[filename]["page_dimensions"] = page_dimensions - layer_predictions_list.extend(layer_predictions) - depths_materials_column_pairs_list.extend(depths_materials_column_pairs) - page_dimensions.append({"height": page.rect.height, "width": page.rect.width}) - - if draw_lines: # could be changed to if draw_lines and mflow_tracking: - if not mlflow_tracking: - logger.warning( - "MLFlow tracking is not enabled. MLFLow is required to store the images." - ) - else: - img = plot_lines( - page, geometric_lines, scale_factor=line_detection_params["pdf_scale_factor"] - ) - mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png") - - predictions[filename]["layers"] = layer_predictions_list - predictions[filename]["depths_materials_column_pairs"] = depths_materials_column_pairs_list - predictions[filename]["page_dimensions"] = page_dimensions - - assert len(page_dimensions) == doc.page_count, "Page count mismatch." + assert len(page_dimensions) == doc.page_count, "Page count mismatch." logger.info("Writing predictions to JSON file %s", predictions_path) with open(predictions_path, "w", encoding="utf8") as file: diff --git a/src/stratigraphy/util/depthcolumn.py b/src/stratigraphy/util/depthcolumn.py index 32920c06..4c636fd4 100644 --- a/src/stratigraphy/util/depthcolumn.py +++ b/src/stratigraphy/util/depthcolumn.py @@ -297,7 +297,7 @@ def significant_intersection(other_rect): intersection = fitz.Rect(other_rect).intersect(self.rect()) return intersection.is_valid and intersection.width > 0.25 * self.rect().width - return len([line for line in all_words if significant_intersection(line.rect)]) - len(self.entries) + return len([word for word in all_words if significant_intersection(word.rect)]) - len(self.entries) def pearson_correlation_coef(self) -> float: # We look at the lower y coordinate, because most often the baseline of the depth value text is aligned with diff --git a/src/stratigraphy/util/find_depth_columns.py b/src/stratigraphy/util/find_depth_columns.py index 41d78637..5a73c947 100644 --- a/src/stratigraphy/util/find_depth_columns.py +++ b/src/stratigraphy/util/find_depth_columns.py @@ -8,6 +8,7 @@ from stratigraphy.util.depthcolumn import BoundaryDepthColumn, LayerDepthColumn from stratigraphy.util.depthcolumnentry import DepthColumnEntry, LayerDepthColumnEntry from stratigraphy.util.line import TextWord +from stratigraphy.util.textblock import TextBlock def depth_column_entries(all_words: list[TextWord], include_splits: bool) -> list[DepthColumnEntry]: @@ -206,3 +207,40 @@ def find_depth_columns( [column for column in numeric_columns if column and boundary_depth_column_validator.is_valid(column)], key=lambda column: len(column.entries), ) + + +def get_depth_interval_from_textblock(block: TextBlock) -> LayerDepthColumnEntry | None: + """Extract depth interval from a material description block. + + For borehole profiles in the Deriaz layout, the depth interval is usually found in the text description + of the material. Often, these text descriptions contain a further separation into multiple sub layers. + These sub layers have their own depth intervals. This function extracts the overall depth interval, + spanning across all mentioned sub layers. + + Args: + block (TextBlock): The block to calculate the depth interval for. + + Returns: + LayerDepthColumnEntry | None: The depth interval. + """ + depth_entries = [] + for line in block.lines: + try: + layer_depth_entry = extract_layer_depth_interval( + line.text, line.rect, line.page_number, require_start_of_string=False + ) + # require_start_of_string = False because the depth interval may not always start at the beginning + # of the line e.g. "Remblais Heterogene: 0.00 - 0.5m" + if layer_depth_entry: + depth_entries.append(layer_depth_entry) + except ValueError: + pass + + if depth_entries: + # Merge the sub layers into one depth interval. + start = min([entry.start for entry in depth_entries], key=lambda start_entry: start_entry.value) + end = max([entry.end for entry in depth_entries], key=lambda end_entry: end_entry.value) + + return LayerDepthColumnEntry(start, end) + else: + return None diff --git a/src/stratigraphy/util/find_description.py b/src/stratigraphy/util/find_description.py index f9219e06..ce0f660d 100644 --- a/src/stratigraphy/util/find_description.py +++ b/src/stratigraphy/util/find_description.py @@ -8,6 +8,7 @@ SplitDescriptionBlockByLine, SplitDescriptionBlockByVerticalSpace, ) +from stratigraphy.util.layer_identifier_column import LayerIdentifierEntry from stratigraphy.util.line import TextLine from stratigraphy.util.textblock import TextBlock @@ -36,12 +37,12 @@ def get_description_lines(lines: list[TextLine], material_description_rect: fitz def get_description_blocks_from_layer_identifier( - layer_identifier_entries: list[TextLine], description_lines: list[TextLine] + layer_identifier_entries: list[LayerIdentifierEntry], description_lines: list[TextLine] ) -> list[TextBlock]: """Divide the description lines into blocks based on the layer identifier entries. Args: - layer_identifier_entries (list[TextLine]): The layer identifier entries. + layer_identifier_entries (list[LayerIdentifierEntry]): The layer identifier entries. description_lines (list[TextLine]): All lines constituting the material description. Returns: diff --git a/src/stratigraphy/util/interval.py b/src/stratigraphy/util/interval.py index 16e32389..abc33230 100644 --- a/src/stratigraphy/util/interval.py +++ b/src/stratigraphy/util/interval.py @@ -84,7 +84,9 @@ def background_rect(self) -> fitz.Rect | None: if self.start and self.end: return fitz.Rect(self.start.rect.x0, self.start.rect.y1, self.start.rect.x1, self.end.rect.y0) - def matching_blocks(self, all_blocks: list[TextBlock], block_index: int) -> tuple[list[TextBlock]]: + def matching_blocks( + self, all_blocks: list[TextBlock], block_index: int + ) -> tuple[list[TextBlock], list[TextBlock], list[TextBlock]]: """Calculates pre, exact and post blocks for the boundary interval. Pre contains all the blocks that are supposed to come before the interval. @@ -96,7 +98,7 @@ def matching_blocks(self, all_blocks: list[TextBlock], block_index: int) -> tupl block_index (int): Index of the current block. Returns: - tuple[list[TextBlock]]: Pre, exact and post blocks. + tuple[list[TextBlock], list[TextBlock], list[TextBlock]]: Pre, exact and post blocks. """ pre, exact, post = [], [], [] diff --git a/src/stratigraphy/util/layer_identifier_column.py b/src/stratigraphy/util/layer_identifier_column.py index e50d5242..5dc4ecc3 100644 --- a/src/stratigraphy/util/layer_identifier_column.py +++ b/src/stratigraphy/util/layer_identifier_column.py @@ -4,10 +4,7 @@ import fitz -from stratigraphy.util.depthcolumn import LayerDepthColumnEntry -from stratigraphy.util.find_depth_columns import extract_layer_depth_interval from stratigraphy.util.line import TextLine -from stratigraphy.util.textblock import TextBlock class LayerIdentifierEntry: @@ -115,42 +112,6 @@ def is_contained(self, rect: fitz.Rect) -> bool: and self.rect().y1 <= rect.y1 ) - def get_depth_interval(self, block: TextBlock) -> LayerDepthColumnEntry: - """Extract depth interval from a material description block. - - For borehole profiles in the Deriaz layout, the depth interval is usually found in the text description - of the material. Often, these text descriptions contain a further separation into multiple sub layers. - These sub layers have their own depth intervals. This function extracts the overall depth interval, - spanning across all mentioned sub layers. - - Args: - block (TextBlock): The block to calculate the depth interval for. - - Returns: - LayerDepthColumnEntry: The depth interval. - """ - depth_entries = [] - for line in block.lines: - try: - layer_depth_entry = extract_layer_depth_interval( - line.text, line.rect, line.page_number, require_start_of_string=False - ) - # require_start_of_string = False because the depth interval may not always start at the beginning - # of the line e.g. "Remblais Heterogene: 0.00 - 0.5m" - if layer_depth_entry: - depth_entries.append(layer_depth_entry) - except ValueError: - pass - - if depth_entries: - # Merge the sub layers into one depth interval. - start = min([entry.start for entry in depth_entries], key=lambda start_entry: start_entry.value) - end = max([entry.end for entry in depth_entries], key=lambda end_entry: end_entry.value) - - return LayerDepthColumnEntry(start, end) - else: - return None - def to_json(self): """Convert the layer identifier column to a JSON serializable format.