Add Update Message & Filter Double Tool Calls #345

Open · wants to merge 9 commits into base: main
1 change: 1 addition & 0 deletions vision_agent/agent/types.py
@@ -33,6 +33,7 @@ class AgentMessage(BaseModel):
Literal["interaction_response"],
Literal["conversation"],
Literal["planner"],
Literal["planner_update"],
Literal["coder"],
]
content: str
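The new `Literal["planner_update"]` member means a planner progress message can now be represented as an `AgentMessage` without failing pydantic validation. A minimal sketch, with a made-up content string:

```python
from vision_agent.agent.types import AgentMessage

# Hypothetical intermediate update emitted while the planner is still working.
update = AgentMessage(
    role="planner_update",
    content="Ran countgd_object_detection on frame 0; refining the plan.",
    media=None,
)
print(update.role)  # "planner_update" now passes the role Literal check added above
```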
19 changes: 18 additions & 1 deletion vision_agent/agent/vision_agent_planner_v2.py
@@ -270,9 +270,25 @@ def create_hil_response(
except Exception:
continue

    # There's a chance that the same tool is called multiple times with the same inputs
    # in the interaction. We want to remove duplicates to avoid redundancy by picking
    # the last occurrence of the tool.
    cleaned_content = []
    seen_content = set()
    for c in reversed(content):
        if "request" in c and "function_name" in c["request"] and "files" in c:
            key = (c["request"]["function_name"], hash(c["files"]))
            if key in seen_content:
                continue

            seen_content.add(key)
            cleaned_content.append(c)
        else:
            cleaned_content.append(c)

    return AgentMessage(
        role="interaction",
        content="<interaction>" + json.dumps(content) + "</interaction>",
        content="<interaction>" + json.dumps(cleaned_content) + "</interaction>",
        media=None,
    )
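For reference, the filter above walks the interaction content in reverse and keys each tool call on its function name plus a hash of its files, so only the last call with a given key survives. A standalone sketch of the same idea, with made-up entries and tuples of file names standing in for whatever hashable value `c["files"]` holds:

```python
import json

# Made-up interaction entries; in the planner these are recorded tool calls.
content = [
    {"request": {"function_name": "owlv2_object_detection"}, "files": ("img.png",), "response": "first call"},
    {"request": {"function_name": "sam2"}, "files": ("img.png",), "response": "masks"},
    {"request": {"function_name": "owlv2_object_detection"}, "files": ("img.png",), "response": "second call"},
]

cleaned_content, seen_content = [], set()
for c in reversed(content):  # walk backwards so the last occurrence wins
    if "request" in c and "function_name" in c["request"] and "files" in c:
        key = (c["request"]["function_name"], hash(c["files"]))
        if key in seen_content:
            continue
        seen_content.add(key)
    cleaned_content.append(c)

# The duplicate "first call" entry is dropped; only two entries remain.
print(json.dumps(cleaned_content, indent=2))
```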

@@ -513,6 +529,7 @@ def generate_plan(
code = extract_tag(response, "execute_python")
finalize_plan = extract_tag(response, "finalize_plan")
finished = finalize_plan is not None
self.update_callback({"role": "planner_update", "content": response})

if self.verbose:
_CONSOLE.print(
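Downstream consumers receive these updates through `update_callback`; the call above passes a plain dict with `role` and `content`. A hedged sketch of a callback that surfaces only this new update stream (passing the callback at construction time is an assumption based on the `self.update_callback(...)` usage above):

```python
from vision_agent.agent.vision_agent_planner_v2 import VisionAgentPlannerV2

def on_update(message: dict) -> None:
    # With this change, each raw planner response is reported as it is produced.
    if message.get("role") == "planner_update":
        print("[planner update]", message["content"][:120])

# Assumption: update_callback is accepted as a constructor argument.
planner = VisionAgentPlannerV2(update_callback=on_update)
```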
8 changes: 4 additions & 4 deletions vision_agent/tools/__init__.py
@@ -23,13 +23,17 @@
TOOLS_INFO,
UTIL_TOOLS,
UTILITIES_DOCSTRING,
agentic_object_detection,
agentic_sam2_instance_segmentation,
agentic_sam2_video_tracking,
claude35_text_extraction,
closest_box_distance,
closest_mask_distance,
countgd_object_detection,
countgd_sam2_instance_segmentation,
countgd_sam2_video_tracking,
countgd_visual_prompt_object_detection,
custom_object_detection,
depth_anything_v2,
detr_segmentation,
document_extraction,
@@ -63,10 +67,6 @@
video_temporal_localization,
vit_image_classification,
vit_nsfw_classification,
custom_object_detection,
agentic_object_detection,
agentic_sam2_instance_segmentation,
agentic_sam2_video_tracking,
)

__new_tools__ = [
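This hunk only alphabetizes the re-exports: the agentic and custom detection tools that were previously appended at the end of the import list now sit in sorted order, and they remain importable from the package root as before, e.g.:

```python
from vision_agent.tools import (
    agentic_object_detection,
    agentic_sam2_video_tracking,
    custom_object_detection,
)
```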
22 changes: 14 additions & 8 deletions vision_agent/tools/planner_tools.py
@@ -10,12 +10,7 @@
from PIL import Image

import vision_agent.tools as T
from vision_agent.agent.agent_utils import (
    DefaultImports,
    extract_code,
    extract_json,
    extract_tag,
)
from vision_agent.agent.agent_utils import DefaultImports, extract_json, extract_tag
from vision_agent.agent.vision_agent_planner_prompts_v2 import (
    CATEGORIZE_TOOL_REQUEST,
    FINALIZE_PLAN,
@@ -36,6 +31,9 @@
from vision_agent.utils.sim import get_tool_recommender

TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
LOAD_TOOLS_DOCSTRING = T.get_tool_documentation(
    [T.load_image, T.extract_frames_and_timestamps]
)

CONFIG = Config()
_LOGGER = logging.getLogger(__name__)
@@ -179,6 +177,7 @@ def run_tool_testing(
        cleaned_tool_docs.append(tool_doc)
    tool_docs = cleaned_tool_docs
    tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
    tool_docs_str += "\n" + LOAD_TOOLS_DOCSTRING

    prompt = TEST_TOOLS.format(
        tool_docs=tool_docs_str,
@@ -217,8 +216,15 @@ def run_tool_testing(
        examples=EXAMPLES,
        media=str(image_paths),
    )
    code = extract_code(lmm.generate(prompt, media=image_paths)) # type: ignore
    code = process_code(code)
    response = cast(str, lmm.generate(prompt, media=image_paths))
    code = extract_tag(response, "code")
    if code is None:
        code = response

    try:
        code = process_code(code)
    except Exception as e:
        _LOGGER.error(f"Error processing code: {e}")
    tool_output = code_interpreter.exec_isolation(
        DefaultImports.prepend_imports(code)
    )
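The test-code path now expects the LMM to wrap its answer in a `<code>` tag, falls back to the raw response if the tag is missing, and logs `process_code` failures instead of letting them abort the run. A small sketch of the extract-or-fallback part, using the `extract_tag` helper imported above and a made-up response string:

```python
from vision_agent.agent.agent_utils import extract_tag

response = "<code>print('smoke test for countgd_object_detection')</code>"  # made up

code = extract_tag(response, "code")
if code is None:
    # No <code> tag found: treat the whole reply as code rather than failing.
    code = response

print(code)
```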
47 changes: 29 additions & 18 deletions vision_agent/tools/tools.py
@@ -222,7 +222,7 @@ def sam2(
    ret = _sam2(image, detections, image_size)
    _display_tool_trace(
        sam2.__name__,
        {},
        {"detections": detections},
        ret["display_data"],
        ret["files"],
    )
@@ -314,18 +314,29 @@ def _apply_object_detection( # inner method to avoid circular importing issues.

    # Process each segment and collect detections
    detections_per_segment: List[Any] = []
    for segment_index, segment in enumerate(segments):
        segment_detections = process_segment(
            segment_frames=segment,
            od_model=od_model,
            prompt=prompt,
            fine_tune_id=fine_tune_id,
            chunk_length=chunk_length,
            image_size=image_size,
            segment_index=segment_index,
            object_detection_tool=_apply_object_detection,
        )
        detections_per_segment.append(segment_detections)
    with ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(
                process_segment,
                segment_frames=segment,
                od_model=od_model,
                prompt=prompt,
                fine_tune_id=fine_tune_id,
                chunk_length=chunk_length,
                image_size=image_size,
                segment_index=segment_index,
                object_detection_tool=_apply_object_detection,
            ): segment_index
            for segment_index, segment in enumerate(segments)
        }

        for future in as_completed(futures):
            segment_index = futures[future]
            detections_per_segment.append((segment_index, future.result()))

    detections_per_segment = [
        x[1] for x in sorted(detections_per_segment, key=lambda x: x[0])
    ]

    merged_detections = merge_segments(detections_per_segment)
    post_processed = post_process(merged_detections, image_size)
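Segments are now detected in parallel rather than sequentially. Because `as_completed` yields futures in completion order, each result is tagged with its segment index and re-sorted before merging, so the downstream order is unchanged. A self-contained sketch of that submit/collect/re-order pattern, with a dummy worker standing in for `process_segment`:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
import random
import time

def process_segment(segment: str, segment_index: int) -> str:
    # Stand-in for the real per-segment object detection work.
    time.sleep(random.uniform(0.0, 0.05))
    return f"detections for {segment}"

segments = [f"segment_{i}" for i in range(5)]
indexed_results = []

with ThreadPoolExecutor() as executor:
    futures = {
        executor.submit(process_segment, segment, i): i
        for i, segment in enumerate(segments)
    }
    for future in as_completed(futures):
        indexed_results.append((futures[future], future.result()))

# Completion order is arbitrary, so restore the original segment order.
ordered = [result for _, result in sorted(indexed_results, key=lambda x: x[0])]
print(ordered)
```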
@@ -390,15 +401,15 @@ def _owlv2_object_detection(
        {
            "label": bbox["label"],
            "bbox": normalize_bbox(bbox["bounding_box"], image_size),
            "score": bbox["score"],
            "score": round(bbox["score"], 2),
        }
        for bbox in bboxes
    ]
    display_data = [
        {
            "label": bbox["label"],
            "bbox": bbox["bounding_box"],
            "score": bbox["score"],
            "score": round(bbox["score"], 2),
        }
        for bbox in bboxes
    ]
@@ -582,7 +593,7 @@ def owlv2_sam2_video_tracking(
    )
    _display_tool_trace(
        owlv2_sam2_video_tracking.__name__,
        {},
        {"prompt": prompt, "chunk_length": chunk_length},
        ret["display_data"],
        ret["files"],
    )
@@ -2150,7 +2161,7 @@ def siglip_classification(image: np.ndarray, labels: List[str]) -> Dict[str, Any
return response


# agentic od tools
# Agentic OD Tools


def _agentic_object_detection(
@@ -2646,7 +2657,7 @@ def save_image(image: np.ndarray, file_path: str) -> None:


def save_video(
    frames: List[np.ndarray], output_video_path: Optional[str] = None, fps: float = 1
    frames: List[np.ndarray], output_video_path: Optional[str] = None, fps: float = 5
) -> str:
    """'save_video' is a utility function that saves a list of frames as a mp4 video file on disk.
