Skip to content

Commit

Permalink
Bugfixes for the LinesQuadTree class + remove inefficient check for duplicate…
Browse files Browse the repository at this point in the history
… lines
  • Loading branch information
stijnvermeeren-swisstopo committed May 14, 2024
1 parent 8b53d0f commit 872a002
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 52 deletions.
2 changes: 0 additions & 2 deletions src/stratigraphy/line_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from stratigraphy.util.dataclasses import Line
from stratigraphy.util.geometric_line_utilities import (
deduplicate_lines,
drop_vertical_lines,
merge_parallel_lines_quadtree,
)
Expand Down Expand Up @@ -72,7 +71,6 @@ def extract_lines(page: fitz.Page, line_detection_params: dict) -> list[Line]:
scale_factor=line_detection_params["pdf_scale_factor"],
)
lines = drop_vertical_lines(lines, threshold=line_detection_params["vertical_lines_threshold"])
lines = deduplicate_lines(lines)
merging_params = line_detection_params["line_merging_params"]

return merge_parallel_lines_quadtree(
Expand Down
55 changes: 18 additions & 37 deletions src/stratigraphy/util/geometric_line_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,6 @@
logger = logging.getLogger(__name__)


def deduplicate_lines(lines: list[Line]) -> list[Line]:
    """Drop every line that near-duplicates a line seen earlier in the input.

    Lines are visited in input order. A line is kept only when
    ``_check_if_present`` finds no match among the lines kept so far, so the
    first line of each group of matching lines is returned unchanged and the
    later ones are discarded. No lines are merged.

    Args:
        lines (list[Line]): The lines to deduplicate.

    Returns:
        list[Line]: The retained lines, in their original relative order.
    """
    unique_lines = []
    for candidate in lines:
        # Skip anything that matches a line we have already decided to keep.
        if _check_if_present(unique_lines, candidate):
            continue
        unique_lines.append(candidate)
    return unique_lines


def _check_if_present(lines: list[Line], line: Line) -> bool:
    """Tell whether ``lines`` already contains a near-copy of ``line``.

    Start points are compared with start points and end points with end
    points, so an entry only matches when it runs in the same direction.

    Args:
        lines (list[Line]): The lines to search.
        line (Line): The line to look for.

    Returns:
        bool: True if some entry has its start point less than 0.1 away from
            ``line.start`` and its end point less than 0.1 away from
            ``line.end``; False otherwise (including for an empty ``lines``).
    """
    # we are on a pixel grid and 0.1 is a reasonable threshold
    threshold = 0.1
    for existing in lines:
        starts_match = existing.start.distance_to(line.start) < threshold
        # The end points are only measured once the start points agree.
        if starts_match and existing.end.distance_to(line.end) < threshold:
            return True
    return False


def drop_vertical_lines(lines: list[Line], threshold: float = 0.1) -> ArrayLike:
"""Given a list of lines, remove the lines that are close to vertical.
Expand Down Expand Up @@ -248,25 +226,28 @@ def merge_parallel_lines_quadtree(lines: list[Line], tol: int, angle_threshold:
height = max(max_end_y, max_start_y)
lines_quad_tree = LinesQuadTree(width, height)

keys_queue = queue.Queue()
print("lines", len(lines))
for line in lines:
lines_quad_tree.add(line)
line_key = lines_quad_tree.add(line)
keys_queue.put(line_key)

keys_queue = queue.Queue()
for key in lines_quad_tree.hashmap:
keys_queue.put(key)
while not keys_queue.empty():
line_key = keys_queue.get()

for neighbour_key, neighbour_line in lines_quad_tree.neighbouring_lines(line_key, tol):
if _are_parallel(line, neighbour_line, angle_threshold=angle_threshold) and _are_close(
line, neighbour_line, tol=tol
):
new_line = _merge_lines(line, neighbour_line)
if new_line is not None:
lines_quad_tree.remove(neighbour_key)
lines_quad_tree.remove(line_key)
new_key = lines_quad_tree.add(new_line)
keys_queue.put(new_key)
break
if line_key in lines_quad_tree.hashmap:
line = lines_quad_tree.hashmap[line_key]

for neighbour_key, neighbour_line in lines_quad_tree.neighbouring_lines(line_key, tol).items():
if _are_parallel(line, neighbour_line, angle_threshold=angle_threshold) and _are_close(
line, neighbour_line, tol=tol
):
new_line = _merge_lines(line, neighbour_line)
if new_line is not None:
lines_quad_tree.remove(neighbour_key)
lines_quad_tree.remove(line_key)
new_key = lines_quad_tree.add(new_line)
keys_queue.put(new_key)
break

return list(lines_quad_tree.hashmap.values())
30 changes: 17 additions & 13 deletions src/stratigraphy/util/linesquadtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@ def __init__(self, width: float, height: float):
def remove(self, line_key: str):
if line_key in self.hashmap:
line = self.hashmap[line_key]
self._qtree_delete((line.start.x, line.start.y), line_key)
self._qtree_delete((line.end.x, line.end.y), line_key)
self._qtree_delete(line.start, line_key)
self._qtree_delete(line.end, line_key)
del self.hashmap[line_key]

def add(self, line: Line) -> str:
line_key = uuid.uuid4().hex
self.hashmap[line_key] = line

self._qtree_insert((line.start.x, line.start.y), line_key)
self._qtree_insert((line.end.x, line.end.y), line_key)
# We round the coordinates, as we don't require infinite precision anyway, and like this we avoid excessive
# recursion within the quad tree in the case of floating point values that are very close to each other.
self._qtree_insert(line.start, line_key)
self._qtree_insert(line.end, line_key)
return line_key

def neighbouring_lines(self, line_key: str, tol: float) -> list[(str, Line)]:
def neighbouring_lines(self, line_key: str, tol: float) -> dict[str, Line]:
"""Efficiently search for all the lines that have a start or end point close to the given line.
Args:
Expand All @@ -43,7 +45,7 @@ def neighbouring_lines(self, line_key: str, tol: float) -> list[(str, Line)]:
from the bounding box formed by the start and end points of the given line.
Returns:
list[(str, Line)]: The lines that are close to the given line, returned as a tuples (line_key, line).
dict[str, Line]: The lines that are close to the given line, returned as a dict of (line_key, line) pairs.
"""
if line_key not in self.hashmap:
return []
Expand All @@ -56,21 +58,23 @@ def neighbouring_lines(self, line_key: str, tol: float) -> list[(str, Line)]:
bb = quads.BoundingBox(min_x - tol, min_y - tol, max_x + tol, max_y + tol)
points = self.qtree.within_bb(bb)

neighbouring_lines = []
neighbouring_lines = {}
for point in points:
for neighbour_key in point.data:
if neighbour_key != line_key and neighbour_key in self.hashmap:
neighbouring_lines.append((neighbour_key, self.hashmap[neighbour_key]))
neighbouring_lines[neighbour_key] = self.hashmap[neighbour_key]
return neighbouring_lines

def _qtree_insert(self, point: Point | tuple, line_key: str):
qtree_point = self.qtree.find(point)
def _qtree_insert(self, point: Point, line_key: str):
coordinates = (round(point.x), round(point.y))
qtree_point = self.qtree.find(coordinates)
if qtree_point:
qtree_point.data.add(line_key)
else:
self.qtree.insert(point, data={line_key})
self.qtree.insert(coordinates, data={line_key})

def _qtree_delete(self, point: Point | tuple, line_key: str):
qtree_point = self.qtree.find(point)
def _qtree_delete(self, point: Point, line_key: str):
coordinates = (round(point.x), round(point.y))
qtree_point = self.qtree.find(coordinates)
if qtree_point:
qtree_point.data.remove(line_key)

0 comments on commit 872a002

Please sign in to comment.