diff --git a/src/stratigraphy/util/boundarydepthcolumnvalidator.py b/src/stratigraphy/util/boundarydepthcolumnvalidator.py index e2af86b6..e774eead 100644 --- a/src/stratigraphy/util/boundarydepthcolumnvalidator.py +++ b/src/stratigraphy/util/boundarydepthcolumnvalidator.py @@ -1,7 +1,6 @@ """This module contains logic to validate BoundaryDepthColumn instances.""" import dataclasses -import re from stratigraphy.util.depthcolumn import BoundaryDepthColumn from stratigraphy.util.depthcolumnentry import DepthColumnEntry @@ -57,7 +56,7 @@ def is_valid(self, column: BoundaryDepthColumn, corr_coef_threshold: float = 0.9 ): return False # Check if the entries are strictly increasing. - if not all(i.value < j.value for i, j in zip(column.entries, column.entries[1:], strict=False)): + if not column.is_strictly_increasing(): return False corr_coef = column.pearson_correlation_coef() @@ -100,41 +99,43 @@ def correct_OCR_mistakes(self, column: BoundaryDepthColumn) -> BoundaryDepthColu Returns: BoundaryDepthColumn | None: The corrected depth column, or None if no correction was possible. """ - new_columns = [] - for remove_index in range(len(column.entries)): - new_columns.append( - BoundaryDepthColumn( - [ - entry if index != remove_index else _correct_entry(entry) - for index, entry in enumerate(column.entries) - ], - ), - ) - best_column = max(new_columns, key=lambda column: column.pearson_correlation_coef()) - - # We require a higher correlation coefficient when we've already corrected a mistake. - if self.is_valid(best_column, corr_coef_threshold=0.999): - return best_column - else: - return None - - -def _correct_entry(entry: DepthColumnEntry) -> DepthColumnEntry: + new_columns = [BoundaryDepthColumn()] + for entry in column.entries: + new_columns = [ + BoundaryDepthColumn([*column.entries, DepthColumnEntry(entry.rect, new_value)]) + for column in new_columns + for new_value in _value_alternatives(entry.value) + ] + # Immediately require strictly increasing values, to avoid exponential complexity when many implausible + # alternative values are suggested + new_columns = [column for column in new_columns if column.is_strictly_increasing()] + + if len(new_columns): + best_column = max(new_columns, key=lambda column: column.pearson_correlation_coef()) + + # We require a higher correlation coefficient when we've already corrected a mistake. + if self.is_valid(best_column, corr_coef_threshold=0.999): + return best_column + + return None + + +def _value_alternatives(value: float) -> set[float]: """Corrects frequent OCR errors in depth column entries. Args: - entry (DepthColumnEntry): The depth column entry to correct. + value (float): The depth values to find plausible alternatives for Returns: - DepthColumnEntry: The corrected depth column entry. + set(float): all plausible values (including the original one) """ - text_value = str(entry.value) - text_value = text_value.replace("4", "1") # In older documents, OCR sometimes mistakes 1 for 4 + alternatives = {value} + # In older documents, OCR sometimes mistakes 1 for 4 + alternatives.add(float(str(value).replace("4", "1"))) # replace a pattern such as '.80' with '0.80'. These cases are already converted - # to '80.0' when depth entries are recognized. Whe therefore look at patterns such as '80.0' - # that start with a digit, followed by a '0.0'. We then replace it with a pattern such as '0.80'. - if re.match(r"^[0-9]0\.0$", text_value): - text_value = text_value.replace(".", "") - text_value = "0." + text_value - return DepthColumnEntry(entry.rect, float(text_value)) + # to '80.0' when depth entries are recognized. + if value.is_integer(): + alternatives.add(value / 100) + + return alternatives diff --git a/src/stratigraphy/util/depthcolumn.py b/src/stratigraphy/util/depthcolumn.py index 9f5a477b..d3d7a0a5 100644 --- a/src/stratigraphy/util/depthcolumn.py +++ b/src/stratigraphy/util/depthcolumn.py @@ -229,11 +229,14 @@ def valid_initial_segment(self, rect: fitz.Rect) -> BoundaryDepthColumn: return initial_segment return BoundaryDepthColumn() - def strictly_contains(self, other: BoundaryDepthColumn): + def strictly_contains(self, other: BoundaryDepthColumn) -> bool: return len(other.entries) < len(self.entries) and all( other_entry in self.entries for other_entry in other.entries ) + def is_strictly_increasing(self) -> bool: + return all(i.value < j.value for i, j in zip(self.entries, self.entries[1:], strict=False)) + def depth_intervals(self) -> list[BoundaryInterval]: """Creates a list of depth intervals from the depth column entries.