From cc83d526179c905cbee1dbe6dc6f7b55271833a9 Mon Sep 17 00:00:00 2001 From: Camilo Iral Date: Tue, 14 Jan 2025 13:57:49 -0500 Subject: [PATCH] feat: add finetuned_object_detection tool (#340) --- tests/integ/test_tools.py | 13 +++ vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 140 +++++++++++++++++++++++++++ vision_agent/utils/video_tracking.py | 1 + 4 files changed, 155 insertions(+) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index f6b1511c..1d1fbb6e 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -28,6 +28,7 @@ video_temporal_localization, vit_image_classification, vit_nsfw_classification, + custom_object_detection, ) FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da" @@ -55,6 +56,7 @@ def test_owlv2_sam2_instance_segmentation(): assert [res["label"] for res in result] == ["coin"] * len(result) assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) + def test_owlv2_object_detection_empty(): result = owlv2_object_detection( prompt="coin", @@ -151,6 +153,7 @@ def test_florence2_phrase_grounding_video(): assert 2 <= len([res["label"] for res in result[0]]) <= 26 assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]]) + def test_template_match(): img = ski.data.coins() result = template_match( @@ -512,3 +515,13 @@ def test_video_tracking_by_given_model(): assert len(result) == 10 assert len([res["label"] for res in result[0]]) == 24 assert len([res["mask"] for res in result[0]]) == 24 + + +def test_finetuned_object_detection_empty(): + img = ski.data.coins() + + result = custom_object_detection( + deployment_id="5015ec65-b99b-4d62-bef1-fb6acb87bb9c", + image=img, + ) + assert len(result) == 0 # no coin objects detected on the finetuned model diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index a7974151..4b6ddffd 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -63,6 +63,7 @@ 
video_temporal_localization, vit_image_classification, vit_nsfw_classification, + custom_object_detection, ) __new_tools__ = [ diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 50893fca..b90d0bd4 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -290,6 +290,13 @@ def _apply_object_detection( # inner method to avoid circular importing issues. ) function_name = "florence2_object_detection" + elif od_model == ODModels.CUSTOM: + segment_results = custom_object_detection( + deployment_id=fine_tune_id, + image=segment_frames[frame_number], + ) + function_name = "custom_object_detection" + else: raise NotImplementedError( f"Object detection model '{od_model}' is not implemented." @@ -1217,6 +1224,139 @@ def countgd_visual_prompt_object_detection( return bboxes_formatted +def custom_object_detection( + deployment_id: str, + image: np.ndarray, + box_threshold: float = 0.1, +) -> List[Dict[str, Any]]: + """'custom_object_detection' is a tool that can detect instances of an + object given a deployment_id of a previously finetuned object detection model. + It is particularly useful when trying to detect objects that are not well detected by generalist models. + It returns a list of bounding boxes with normalized + coordinates, label names and associated confidence scores. + + Parameters: + deployment_id (str): The id of the finetuned model. + image (np.ndarray): The image that contains instances of the object. + box_threshold (float, optional): The threshold for detection. Defaults + to 0.1. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score, label, and + bounding box of the detected objects with normalized coordinates between 0 + and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the + top-left and xmax and ymax are the coordinates of the bottom-right of the + bounding box. 
+ + Example + ------- + >>> custom_object_detection("abcd1234-5678efg", image) + [ + {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]}, + {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]}, + {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]}, + {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]}, + ] + """ + image_size = image.shape[:2] + if image_size[0] < 1 or image_size[1] < 1: + return [] + + files = [("image", numpy_to_bytes(image))] + payload = { + "deployment_id": deployment_id, + "confidence": box_threshold, + } + detections: List[List[Dict[str, Any]]] = send_inference_request( + payload, "custom-object-detection", files=files, v2=True + ) + + bboxes = detections[0] + bboxes_formatted = [ + { + "label": bbox["label"], + "bbox": normalize_bbox(bbox["bounding_box"], image_size), + "score": bbox["score"], + } + for bbox in bboxes + ] + display_data = [ + { + "label": bbox["label"], + "bbox": bbox["bounding_box"], + "score": bbox["score"], + } + for bbox in bboxes + ] + + _display_tool_trace( + custom_object_detection.__name__, + payload, + display_data, + files, + ) + return bboxes_formatted + + + def custom_od_sam2_video_tracking( + deployment_id: str, + frames: List[np.ndarray], + chunk_length: Optional[int] = 10, + ) -> List[List[Dict[str, Any]]]: + """'custom_od_sam2_video_tracking' is a tool that can segment multiple objects given a + custom model with predefined category names. + It returns a list of bounding boxes, label names, + mask file names and associated probability scores. + + Parameters: + deployment_id (str): The id of the deployed custom model. + frames (List[np.ndarray]): The list of frames to ground the prompt to. + chunk_length (Optional[int]): The number of frames to re-run the custom model to find + new objects. 
+ + Returns: + List[List[Dict[str, Any]]]: A list of lists of dictionaries, one list per frame, containing the score, label, + bounding box, and mask of the detected objects with normalized coordinates + (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left + and xmax and ymax are the coordinates of the bottom-right of the bounding box. + The mask is binary 2D numpy array where 1 indicates the object and 0 indicates + the background. + + Example + ------- + >>> custom_od_sam2_video_tracking("abcd1234-5678efg", frames) + [ + [ + { + 'label': '0: dinosaur', + 'bbox': [0.1, 0.11, 0.35, 0.4], + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ], + ... + ] + """ + + ret = od_sam2_video_tracking( + ODModels.CUSTOM, + prompt="", + frames=frames, + chunk_length=chunk_length, + fine_tune_id=deployment_id, + ) + _display_tool_trace( + custom_od_sam2_video_tracking.__name__, + {}, + ret["display_data"], + ret["files"], + ) + return ret["return_data"] # type: ignore + + def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str: """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary images including regular images or images of documents or presentations. It can be diff --git a/vision_agent/utils/video_tracking.py b/vision_agent/utils/video_tracking.py index 3f636f5e..6fdec5ee 100644 --- a/vision_agent/utils/video_tracking.py +++ b/vision_agent/utils/video_tracking.py @@ -17,6 +17,7 @@ class ODModels(str, Enum): COUNTGD = "countgd" FLORENCE2 = "florence2" OWLV2 = "owlv2" + CUSTOM = "custom" def split_frames_into_segments(