From a9db7880c107370c6f2c6a1fded88439f9515312 Mon Sep 17 00:00:00 2001 From: JoaoVital Date: Wed, 20 Mar 2024 18:15:19 +0000 Subject: [PATCH] return low_res logits on SamAutomaticMaskGenerator.generate() --- segment_anything/automatic_mask_generator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/segment_anything/automatic_mask_generator.py b/segment_anything/automatic_mask_generator.py index d5a8c9692..b8f775eac 100644 --- a/segment_anything/automatic_mask_generator.py +++ b/segment_anything/automatic_mask_generator.py @@ -189,6 +189,7 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: "point_coords": [mask_data["points"][idx].tolist()], "stability_score": mask_data["stability_score"][idx].item(), "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + "low_res": mask_data["low_res"][idx], } curr_anns.append(ann) @@ -276,7 +277,7 @@ def _process_batch( transformed_points = self.predictor.transform.apply_coords(points, im_size) in_points = torch.as_tensor(transformed_points, device=self.predictor.device) in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) - masks, iou_preds, _ = self.predictor.predict_torch( + masks, iou_preds, low_res = self.predictor.predict_torch( in_points[:, None, :], in_labels[:, None], multimask_output=True, @@ -288,6 +289,7 @@ def _process_batch( masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1), points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + low_res=low_res.flatten(0, 1), ) del masks