Skip to content

Commit

Permalink
feat: add finetuned_object_detection tool (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
CamiloInx authored Jan 14, 2025
1 parent 64a800b commit cc83d52
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
video_temporal_localization,
vit_image_classification,
vit_nsfw_classification,
custom_object_detection,
)

FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"
Expand Down Expand Up @@ -55,6 +56,7 @@ def test_owlv2_sam2_instance_segmentation():
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_owlv2_object_detection_empty():
result = owlv2_object_detection(
prompt="coin",
Expand Down Expand Up @@ -151,6 +153,7 @@ def test_florence2_phrase_grounding_video():
assert 2 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_template_match():
img = ski.data.coins()
result = template_match(
Expand Down Expand Up @@ -512,3 +515,13 @@ def test_video_tracking_by_given_model():
assert len(result) == 10
assert len([res["label"] for res in result[0]]) == 24
assert len([res["mask"] for res in result[0]]) == 24


def test_finetuned_object_detection_empty():
    """Running the finetuned detector on skimage's coins sample should yield nothing."""
    coins_image = ski.data.coins()

    detections = custom_object_detection(
        deployment_id="5015ec65-b99b-4d62-bef1-fb6acb87bb9c",
        image=coins_image,
    )
    # The deployed model was not trained on coins, so we expect zero detections.
    assert len(detections) == 0
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
video_temporal_localization,
vit_image_classification,
vit_nsfw_classification,
custom_object_detection,
)

__new_tools__ = [
Expand Down
140 changes: 140 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,13 @@ def _apply_object_detection( # inner method to avoid circular importing issues.
)
function_name = "florence2_object_detection"

elif od_model == ODModels.CUSTOM:
segment_results = custom_object_detection(
deployment_id=fine_tune_id,
image=segment_frames[frame_number],
)
function_name = "custom_object_detection"

else:
raise NotImplementedError(
f"Object detection model '{od_model}' is not implemented."
Expand Down Expand Up @@ -1217,6 +1224,139 @@ def countgd_visual_prompt_object_detection(
return bboxes_formatted


def custom_object_detection(
    deployment_id: str,
    image: np.ndarray,
    box_threshold: float = 0.1,
) -> List[Dict[str, Any]]:
    """'custom_object_detection' is a tool that can detect instances of an
    object given a deployment_id of a previously finetuned object detection model.
    It is particularly useful when trying to detect objects that are not well detected by generalist models.
    It returns a list of bounding boxes with normalized
    coordinates, label names and associated confidence scores.

    Parameters:
        deployment_id (str): The id of the finetuned model.
        image (np.ndarray): The image that contains instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.1.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> custom_object_detection("abcd1234-5678efg", image)
        [
            {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    # Degenerate (zero-area) images cannot contain detections; bail out early.
    height_width = image.shape[:2]
    if height_width[0] < 1 or height_width[1] < 1:
        return []

    files = [("image", numpy_to_bytes(image))]
    payload = {
        "deployment_id": deployment_id,
        "confidence": box_threshold,
    }
    detections: List[List[Dict[str, Any]]] = send_inference_request(
        payload, "custom-object-detection", files=files, v2=True
    )

    # The service returns one list of detections per image; we sent one image.
    raw_boxes = detections[0]
    normalized_boxes: List[Dict[str, Any]] = []
    pixel_boxes: List[Dict[str, Any]] = []
    for det in raw_boxes:
        # Callers get normalized coordinates; the trace keeps raw pixel boxes.
        normalized_boxes.append(
            {
                "label": det["label"],
                "bbox": normalize_bbox(det["bounding_box"], height_width),
                "score": det["score"],
            }
        )
        pixel_boxes.append(
            {
                "label": det["label"],
                "bbox": det["bounding_box"],
                "score": det["score"],
            }
        )

    _display_tool_trace(
        custom_object_detection.__name__,
        payload,
        pixel_boxes,
        files,
    )
    return normalized_boxes


def custom_od_sam2_video_tracking(
    deployment_id: str,
    frames: List[np.ndarray],
    chunk_length: Optional[int] = 10,
) -> List[List[Dict[str, Any]]]:
    """'custom_od_sam2_video_tracking' is a tool that can segment and track multiple
    objects in a video using a previously finetuned custom object detection model
    with predefined category names. It returns a list of bounding boxes, label
    names, masks and associated probability scores per frame.

    Parameters:
        deployment_id (str): The id of the deployed custom model.
        frames (List[np.ndarray]): The list of video frames to run detection and
            tracking on.
        chunk_length (Optional[int]): The number of frames after which the custom
            detector is re-run to find new objects. Defaults to 10.

    Returns:
        List[List[Dict[str, Any]]]: One list per frame, each containing dictionaries
            with the score, label, bounding box, and mask of the detected objects
            with normalized coordinates (xmin, ymin, xmax, ymax). xmin and ymin are
            the coordinates of the top-left and xmax and ymax are the coordinates of
            the bottom-right of the bounding box. The mask is a binary 2D numpy array
            where 1 indicates the object and 0 indicates the background.

    Example
    -------
        >>> custom_od_sam2_video_tracking("abcd1234-5678efg", frames)
        [
            [
                {
                    'label': '0: dinosaur',
                    'bbox': [0.1, 0.11, 0.35, 0.4],
                    'mask': array([[0, 0, 0, ..., 0, 0, 0],
                        [0, 0, 0, ..., 0, 0, 0],
                        ...,
                        [0, 0, 0, ..., 0, 0, 0],
                        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
                },
            ],
            ...
        ]
    """
    # Delegate to the shared tracking pipeline; the prompt is empty because the
    # custom model already encodes its own category names, and the deployment id
    # is threaded through as the fine_tune_id of the ODModels.CUSTOM branch.
    ret = od_sam2_video_tracking(
        ODModels.CUSTOM,
        prompt="",
        frames=frames,
        chunk_length=chunk_length,
        fine_tune_id=deployment_id,
    )
    _display_tool_trace(
        custom_od_sam2_video_tracking.__name__,
        {},
        ret["display_data"],
        ret["files"],
    )
    return ret["return_data"]  # type: ignore


def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
"""'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary
images including regular images or images of documents or presentations. It can be
Expand Down
1 change: 1 addition & 0 deletions vision_agent/utils/video_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ODModels(str, Enum):
COUNTGD = "countgd"
FLORENCE2 = "florence2"
OWLV2 = "owlv2"
CUSTOM = "custom"


def split_frames_into_segments(
Expand Down

0 comments on commit cc83d52

Please sign in to comment.