From 965f569e811b96672789639456d61ee5b6b313f8 Mon Sep 17 00:00:00 2001 From: aspaul20 <87422803+aspaul20@users.noreply.github.com> Date: Fri, 24 May 2024 13:16:37 +0500 Subject: [PATCH] added sliding window for large image inference (#12152) added sliding window for large image inference --- doc/doc_en/slice_en.md | 16 ++++++ paddleocr.py | 4 +- tools/infer/predict_system.py | 35 ++++++++++++- tools/infer/utility.py | 98 +++++++++++++++++++++++++++++++++++ 4 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 doc/doc_en/slice_en.md diff --git a/doc/doc_en/slice_en.md b/doc/doc_en/slice_en.md new file mode 100644 index 0000000000..e48f7e836c --- /dev/null +++ b/doc/doc_en/slice_en.md @@ -0,0 +1,16 @@ +# Slice Operator +If you have a very large image/document that you would like to run PaddleOCR (detection and recognition) on, you can use the slice operation as follows: + +`ocr_inst = PaddleOCR(**ocr_settings)` +`results = ocr_inst.ocr(img, det=True,rec=True, slice=slice, cls=False,bin=False,inv=False,alpha_color=False)` + +where +`slice = {'horizontal_stride': h_stride, 'vertical_stride':v_stride, 'merge_x_thres':x_thres, 'merge_y_thres': y_thres}` + +Here, `h_stride`, `v_stride`, `x_thres`, and `y_thres` are user-configurable values and need to be set manually. The way the `slice` operator works is that it runs a sliding window across the large input image, creating slices of it and runs the OCR algorithms on it. + +The fragmented slice-level results are then merged together to output image-level detection and recognition results. The horizontal and vertical strides cannot be lower than a certain limit (as too low values would create so many slices it would be very computationally expensive to get results for each of them). However, as an example the recommended values for an image with dimensions 6616x14886 would be as follows. + +`slice = {'horizontal_stride': 300, 'vertical_stride':500, 'merge_x_thres':50, 'merge_y_thres': 35}` + +All slice-level detections with bounding boxes as close as `merge_x_thres` and `merge_y_thres` will be merged together. diff --git a/paddleocr.py b/paddleocr.py index 6e8c66ed3c..65b4c850a1 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -679,6 +679,7 @@ def ocr( bin=False, inv=False, alpha_color=(255, 255, 255), + slice={}, ): """ OCR with PaddleOCR @@ -691,6 +692,7 @@ def ocr( bin: binarize image to black and white. Default is False. inv: invert image colors. Default is False. alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white. + slice: use sliding window inference for large images, det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres] (See doc/doc_en/slice_en.md). Default is {}. """ assert isinstance(img, (np.ndarray, list, str, bytes)) if isinstance(img, list) and det == True: @@ -723,7 +725,7 @@ def preprocess_image(_image): ocr_res = [] for idx, img in enumerate(imgs): img = preprocess_image(img) - dt_boxes, rec_res, _ = self.__call__(img, cls) + dt_boxes, rec_res, _ = self.__call__(img, cls, slice) if not dt_boxes and not rec_res: ocr_res.append(None) continue diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 95b199b2a0..aaf63922c5 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -38,6 +38,8 @@ draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop, + slice_generator, + merge_fragmented, ) logger = get_logger() @@ -71,7 +73,7 @@ def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res): logger.debug(f"{bno}, {rec_res[bno]}") self.crop_image_res_index += bbox_num - def __call__(self, img, cls=True): + def __call__(self, img, cls=True, slice={}): time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0} if img is None: @@ -80,7 +82,32 @@ def __call__(self, img, cls=True): start = time.time() ori_im = img.copy() - dt_boxes, elapse = self.text_detector(img) + if slice: + slice_gen = slice_generator( + img, + horizontal_stride=slice["horizontal_stride"], + vertical_stride=slice["vertical_stride"], + ) + elapsed = [] + dt_slice_boxes = [] + for slice_crop, v_start, h_start in slice_gen: + dt_boxes, elapse = self.text_detector(slice_crop) + if dt_boxes.size: + dt_boxes[:, :, 0] += h_start + dt_boxes[:, :, 1] += v_start + dt_slice_boxes.append(dt_boxes) + elapsed.append(elapse) + dt_boxes = np.concatenate(dt_slice_boxes) + + dt_boxes = merge_fragmented( + boxes=dt_boxes, + x_threshold=slice["merge_x_thres"], + y_threshold=slice["merge_y_thres"], + ) + elapse = sum(elapsed) + else: + dt_boxes, elapse = self.text_detector(img) + time_dict["det"] = elapse if dt_boxes is None: @@ -109,6 +136,10 @@ def __call__(self, img, cls=True): logger.debug( "cls num : {}, elapsed : {}".format(len(img_crop_list), elapse) ) + if len(img_crop_list) > 1000: + logger.debug( + f"rec crops num: {len(img_crop_list)}, time and memory cost may be large." + ) rec_res, elapse = self.text_recognizer(img_crop_list) time_dict["rec"] = elapse diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 61f4ffacc9..4a734683de 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -692,6 +692,104 @@ def get_minarea_rect_crop(img, points): return crop_img +def slice_generator(image, horizontal_stride, vertical_stride, maximum_slices=500): + if not isinstance(image, np.ndarray): + image = np.array(image) + + image_h, image_w = image.shape[:2] + vertical_num_slices = (image_h + vertical_stride - 1) // vertical_stride + horizontal_num_slices = (image_w + horizontal_stride - 1) // horizontal_stride + + assert ( + vertical_num_slices > 0 + ), f"Invalid number ({vertical_num_slices}) of vertical slices" + + assert ( + horizontal_num_slices > 0 + ), f"Invalid number ({horizontal_num_slices}) of horizontal slices" + + if vertical_num_slices >= maximum_slices: + recommended_vertical_stride = max(1, image_h // maximum_slices) + 1 + assert ( + False + ), f"Too computationally expensive with {vertical_num_slices} slices, try a higher vertical stride (recommended minimum: {recommended_vertical_stride})" + + if horizontal_num_slices >= maximum_slices: + recommended_horizontal_stride = max(1, image_w // maximum_slices) + 1 + assert ( + False + ), f"Too computationally expensive with {horizontal_num_slices} slices, try a higher horizontal stride (recommended minimum: {recommended_horizontal_stride})" + + for v_slice_idx in range(vertical_num_slices): + v_start = max(0, (v_slice_idx * vertical_stride)) + v_end = min(((v_slice_idx + 1) * vertical_stride), image_h) + vertical_slice = image[v_start:v_end, :] + for h_slice_idx in range(horizontal_num_slices): + h_start = max(0, (h_slice_idx * horizontal_stride)) + h_end = min(((h_slice_idx + 1) * horizontal_stride), image_w) + horizontal_slice = vertical_slice[:, h_start:h_end] + + yield (horizontal_slice, v_start, h_start) + + +def calculate_box_extents(box): + min_x = box[0][0] + max_x = box[1][0] + min_y = box[0][1] + max_y = box[2][1] + return min_x, max_x, min_y, max_y + + +def merge_boxes(box1, box2, x_threshold, y_threshold): + min_x1, max_x1, min_y1, max_y1 = calculate_box_extents(box1) + min_x2, max_x2, min_y2, max_y2 = calculate_box_extents(box2) + + if ( + abs(min_y1 - min_y2) <= y_threshold + and abs(max_y1 - max_y2) <= y_threshold + and abs(max_x1 - min_x2) <= x_threshold + ): + new_xmin = min(min_x1, min_x2) + new_xmax = max(max_x1, max_x2) + new_ymin = min(min_y1, min_y2) + new_ymax = max(max_y1, max_y2) + return [ + [new_xmin, new_ymin], + [new_xmax, new_ymin], + [new_xmax, new_ymax], + [new_xmin, new_ymax], + ] + else: + return None + + +def merge_fragmented(boxes, x_threshold=10, y_threshold=10): + merged_boxes = [] + visited = set() + + for i, box1 in enumerate(boxes): + if i in visited: + continue + + merged_box = [point[:] for point in box1] + + for j, box2 in enumerate(boxes[i + 1 :], start=i + 1): + if j not in visited: + merged_result = merge_boxes( + merged_box, box2, x_threshold=x_threshold, y_threshold=y_threshold + ) + if merged_result: + merged_box = merged_result + visited.add(j) + + merged_boxes.append(merged_box) + + if len(merged_boxes) == len(boxes): + return np.array(merged_boxes) + else: + return merge_fragmented(merged_boxes, x_threshold, y_threshold) + + def check_gpu(use_gpu): if use_gpu and ( not paddle.is_compiled_with_cuda() or paddle.device.get_device() == "cpu"