diff --git a/src/grits.py b/src/grits.py index c9d4828..eba104c 100644 --- a/src/grits.py +++ b/src/grits.py @@ -227,12 +227,9 @@ def iou(bbox1, bbox2): Compute the intersection-over-union of two bounding boxes. """ intersection = Rect(bbox1).intersect(bbox2) - union = Rect(bbox1).include_rect(bbox2) - - union_area = union.get_area() + union_area = Rect(bbox1).get_area() + Rect(bbox2).get_area() - intersection.get_area() if union_area > 0: - return intersection.get_area() / union.get_area() - + return intersection.get_area() / union_area return 0 diff --git a/src/postprocess.py b/src/postprocess.py index 25feaee..3bfd68c 100644 --- a/src/postprocess.py +++ b/src/postprocess.py @@ -5,6 +5,8 @@ from fitz import Rect +from grits import iou + def apply_threshold(objects, threshold): """ @@ -31,20 +33,6 @@ def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds return bboxes, scores, labels -def iou(bbox1, bbox2): - """ - Compute the intersection-over-union of two bounding boxes. - """ - intersection = Rect(bbox1).intersect(bbox2) - union = Rect(bbox1).include_rect(bbox2) - - union_area = union.get_area() - if union_area > 0: - return intersection.get_area() / union.get_area() - - return 0 - - def iob(bbox1, bbox2): """ Compute the intersection area over box area, for bbox1.