diff --git a/segment_anything/automatic_mask_generator.py b/segment_anything/automatic_mask_generator.py index d5a8c9692..b8f775eac 100644 --- a/segment_anything/automatic_mask_generator.py +++ b/segment_anything/automatic_mask_generator.py @@ -189,6 +189,7 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: "point_coords": [mask_data["points"][idx].tolist()], "stability_score": mask_data["stability_score"][idx].item(), "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + "low_res": mask_data["low_res"][idx], } curr_anns.append(ann) @@ -276,7 +277,7 @@ def _process_batch( transformed_points = self.predictor.transform.apply_coords(points, im_size) in_points = torch.as_tensor(transformed_points, device=self.predictor.device) in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) - masks, iou_preds, _ = self.predictor.predict_torch( + masks, iou_preds, low_res = self.predictor.predict_torch( in_points[:, None, :], in_labels[:, None], multimask_output=True, @@ -288,6 +289,7 @@ def _process_batch( masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1), points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + low_res=low_res.flatten(0, 1), ) del masks