diff --git a/src/stratigraphy/line_detection.py b/src/stratigraphy/line_detection.py index bc19a15c..411c4bd2 100644 --- a/src/stratigraphy/line_detection.py +++ b/src/stratigraphy/line_detection.py @@ -11,6 +11,7 @@ from stratigraphy.util.dataclasses import Line from stratigraphy.util.geometric_line_utilities import ( + deduplicate_lines, drop_vertical_lines, merge_parallel_lines_approximately, merge_parallel_lines_quadtree, @@ -53,7 +54,9 @@ def detect_lines_lsd(page: fitz.Page, scale_factor=2, lsd_params=None) -> ArrayL # Detect lines in the image lines = lsd.detect(gray)[0] - return [line_from_array(line, scale_factor) for line in lines] + converted_lines = [line_from_array(line, scale_factor) for line in lines] + deduplicated_lines = deduplicate_lines(converted_lines) + return deduplicated_lines def extract_lines(page: fitz.Page, line_detection_params: dict) -> list[Line]: diff --git a/src/stratigraphy/util/dataclasses.py b/src/stratigraphy/util/dataclasses.py index a9b1eb26..c8f0c7ef 100644 --- a/src/stratigraphy/util/dataclasses.py +++ b/src/stratigraphy/util/dataclasses.py @@ -71,13 +71,6 @@ def remove(self, line_index: str): del self.hashmap[line_index] def add(self, line: Line) -> str: - if not self._check_if_present(line): - key = uuid.uuid4().hex - self.hashmap[key] = line - return key - else: - logger.warning("Line already present in IndexedLines.") - return None - - def _check_if_present(self, line: Line) -> bool: - return any(value.start == line.start and value.end == line.end for _key, value in self.hashmap.items()) + key = uuid.uuid4().hex + self.hashmap[key] = line + return key diff --git a/src/stratigraphy/util/geometric_line_utilities.py b/src/stratigraphy/util/geometric_line_utilities.py index 4e5cdb58..4c9bd877 100644 --- a/src/stratigraphy/util/geometric_line_utilities.py +++ b/src/stratigraphy/util/geometric_line_utilities.py @@ -21,6 +21,26 @@ sys.setrecursionlimit(10000) # required for the quadtree +def deduplicate_lines(lines: list[Line]) 
-> list[Line]: + """Remove exact duplicates (identical start and end), keeping the first occurrence. + + Args: + lines (list[Line]): The lines to deduplicate. + + Returns: + list[Line]: The deduplicated lines. + """ + deduplicated_lines = [] + for line in lines: + if not _check_if_present(deduplicated_lines, line): + deduplicated_lines.append(line) + return deduplicated_lines + + +def _check_if_present(lines, line: Line) -> bool: + return any(value.start == line.start and value.end == line.end for value in lines) + + def drop_vertical_lines(lines: list[Line], threshold: float = 0.1) -> ArrayLike: """Given a list of lines, remove the lines that are close to vertical. @@ -456,7 +476,6 @@ def merge_parallel_lines_quadtree(lines: list[Line], tol: int, angle_threshold: merged_any = True continue if merged_any: - print("Starting recursion.") return merge_parallel_lines_quadtree( list(indexed_lines.hashmap.values()), tol=tol, angle_threshold=angle_threshold )