From 872a002edf60d9f0ab20b3e89b40fc0c4ed73d9a Mon Sep 17 00:00:00 2001 From: Stijn Vermeeren Date: Tue, 14 May 2024 10:23:46 +0200 Subject: [PATCH] bugfixing LineQuadTree class + remove inefficient check for duplicate lines --- src/stratigraphy/line_detection.py | 2 - .../util/geometric_line_utilities.py | 55 ++++++------------- src/stratigraphy/util/linesquadtree.py | 30 +++++----- 3 files changed, 35 insertions(+), 52 deletions(-) diff --git a/src/stratigraphy/line_detection.py b/src/stratigraphy/line_detection.py index d297e1be..aa49cc06 100644 --- a/src/stratigraphy/line_detection.py +++ b/src/stratigraphy/line_detection.py @@ -11,7 +11,6 @@ from stratigraphy.util.dataclasses import Line from stratigraphy.util.geometric_line_utilities import ( - deduplicate_lines, drop_vertical_lines, merge_parallel_lines_quadtree, ) @@ -72,7 +71,6 @@ def extract_lines(page: fitz.Page, line_detection_params: dict) -> list[Line]: scale_factor=line_detection_params["pdf_scale_factor"], ) lines = drop_vertical_lines(lines, threshold=line_detection_params["vertical_lines_threshold"]) - lines = deduplicate_lines(lines) merging_params = line_detection_params["line_merging_params"] return merge_parallel_lines_quadtree( diff --git a/src/stratigraphy/util/geometric_line_utilities.py b/src/stratigraphy/util/geometric_line_utilities.py index b19a4569..88d927e8 100644 --- a/src/stratigraphy/util/geometric_line_utilities.py +++ b/src/stratigraphy/util/geometric_line_utilities.py @@ -15,28 +15,6 @@ logger = logging.getLogger(__name__) -def deduplicate_lines(lines: list[Line]) -> list[Line]: - """Deduplicate lines by merging lines that are close to each other. - - Args: - lines (list[Line]): The lines to deduplicate. - - Returns: - list[Line]: The deduplicated lines. - """ - deduplicated_lines = [] - for line in lines: - if not _check_if_present(deduplicated_lines, line): - deduplicated_lines.append(line) - return deduplicated_lines - - -def _check_if_present(lines: list[Line], line: Line) -> bool: - return any( - value.start.distance_to(line.start) < 0.1 and value.end.distance_to(line.end) < 0.1 for value in lines - ) # we are on a pixel grid and 0.1 is a reasonable threshold - - def drop_vertical_lines(lines: list[Line], threshold: float = 0.1) -> ArrayLike: """Given a list of lines, remove the lines that are close to vertical. @@ -248,25 +226,28 @@ def merge_parallel_lines_quadtree(lines: list[Line], tol: int, angle_threshold: height = max(max_end_y, max_start_y) lines_quad_tree = LinesQuadTree(width, height) + keys_queue = queue.Queue() + print("lines", len(lines)) for line in lines: - lines_quad_tree.add(line) + line_key = lines_quad_tree.add(line) + keys_queue.put(line_key) - keys_queue = queue.Queue() - for key in lines_quad_tree.hashmap: - keys_queue.put(key) while not keys_queue.empty(): line_key = keys_queue.get() - for neighbour_key, neighbour_line in lines_quad_tree.neighbouring_lines(line_key, tol): - if _are_parallel(line, neighbour_line, angle_threshold=angle_threshold) and _are_close( - line, neighbour_line, tol=tol - ): - new_line = _merge_lines(line, neighbour_line) - if new_line is not None: - lines_quad_tree.remove(neighbour_key) - lines_quad_tree.remove(line_key) - new_key = lines_quad_tree.add(new_line) - keys_queue.put(new_key) - break + if line_key in lines_quad_tree.hashmap: + line = lines_quad_tree.hashmap[line_key] + + for neighbour_key, neighbour_line in lines_quad_tree.neighbouring_lines(line_key, tol).items(): + if _are_parallel(line, neighbour_line, angle_threshold=angle_threshold) and _are_close( + line, neighbour_line, tol=tol + ): + new_line = _merge_lines(line, neighbour_line) + if new_line is not None: + lines_quad_tree.remove(neighbour_key) + lines_quad_tree.remove(line_key) + new_key = lines_quad_tree.add(new_line) + keys_queue.put(new_key) + break return list(lines_quad_tree.hashmap.values()) diff --git a/src/stratigraphy/util/linesquadtree.py b/src/stratigraphy/util/linesquadtree.py index b6aa6b2a..5f5c9bbb 100644 --- a/src/stratigraphy/util/linesquadtree.py +++ b/src/stratigraphy/util/linesquadtree.py @@ -22,19 +22,21 @@ def __init__(self, width: float, height: float): def remove(self, line_key: str): if line_key in self.hashmap: line = self.hashmap[line_key] - self._qtree_delete((line.start.x, line.start.y), line_key) - self._qtree_delete((line.end.x, line.end.y), line_key) + self._qtree_delete(line.start, line_key) + self._qtree_delete(line.end, line_key) del self.hashmap[line_key] def add(self, line: Line) -> str: line_key = uuid.uuid4().hex self.hashmap[line_key] = line - self._qtree_insert((line.start.x, line.start.y), line_key) - self._qtree_insert((line.end.x, line.end.y), line_key) + # We round the coordinates, as we don't require infinite precision anyway, and like this we avoid excessive + # recursion within the quad tree in the case of floating point values that are very close to each other. + self._qtree_insert(line.start, line_key) + self._qtree_insert(line.end, line_key) return line_key - def neighbouring_lines(self, line_key: str, tol: float) -> list[(str, Line)]: + def neighbouring_lines(self, line_key: str, tol: float) -> dict[str, Line]: """Efficiently search for all the lines that have a start or end point close to the given line. Args: @@ -43,7 +45,7 @@ def neighbouring_lines(self, line_key: str, tol: float) -> list[(str, Line)]: from the bounding box formed by the start and end points of the given line. Returns: - list[(str, Line)]: The lines that are close to the given line, returned as a tuples (line_key, line). + dict[str, Line]: The lines that are close to the given line, returned as a dict of (line_key, line) pairs. """ if line_key not in self.hashmap: return [] @@ -56,21 +58,23 @@ def neighbouring_lines(self, line_key: str, tol: float) -> list[(str, Line)]: bb = quads.BoundingBox(min_x - tol, min_y - tol, max_x + tol, max_y + tol) points = self.qtree.within_bb(bb) - neighbouring_lines = [] + neighbouring_lines = {} for point in points: for neighbour_key in point.data: if neighbour_key != line_key and neighbour_key in self.hashmap: - neighbouring_lines.append((neighbour_key, self.hashmap[neighbour_key])) + neighbouring_lines[neighbour_key] = self.hashmap[neighbour_key] return neighbouring_lines - def _qtree_insert(self, point: Point | tuple, line_key: str): - qtree_point = self.qtree.find(point) + def _qtree_insert(self, point: Point, line_key: str): + coordinates = (round(point.x), round(point.y)) + qtree_point = self.qtree.find(coordinates) if qtree_point: qtree_point.data.add(line_key) else: - self.qtree.insert(point, data={line_key}) + self.qtree.insert(coordinates, data={line_key}) - def _qtree_delete(self, point: Point | tuple, line_key: str): - qtree_point = self.qtree.find(point) + def _qtree_delete(self, point: Point, line_key: str): + coordinates = (round(point.x), round(point.y)) + qtree_point = self.qtree.find(coordinates) if qtree_point: qtree_point.data.remove(line_key)