Skip to content

Commit

Permalink
feat: florencev2 fine-tuning meta tool (#190)
Browse files Browse the repository at this point in the history
* add florencev2 fine tuning

* add pre-commit

* add task customization

* tools to meta tools
Dayof authored Aug 7, 2024

Verified

This commit was signed with the committer’s verified signature.
srtfisher Sean Fisher
1 parent 87a7b65 commit ae97907
Showing 12 changed files with 652 additions and 371 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
language_version: python3.9
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
844 changes: 484 additions & 360 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ packages = [{include = "vision_agent"}]

[tool.poetry.dependencies] # main dependency group
python = ">=3.9,<4.0"

numpy = ">=1.21.0,<2.0.0"
pillow = "10.*"
requests = "2.*"
@@ -60,6 +61,7 @@ mkdocstrings = {extras = ["python"], version = "^0.23.0"}
mkdocs-material = "^9.4.2"
types-tabulate = "^0.9.0.20240106"
scikit-image = "<0.23.1"
pre-commit = "^3.8.0"

[tool.pytest.ini_options]
log_cli = true
@@ -90,7 +92,6 @@ warn_unused_configs = true
warn_unused_ignores = true
warn_return_any = true
show_error_codes = true
disallow_any_unimported = true

[[tool.mypy.overrides]]
ignore_missing_imports = true
@@ -101,5 +102,5 @@ module = [
"sentence_transformers.*",
"moviepy.*",
"e2b_code_interpreter.*",
"e2b.*",
"e2b.*"
]
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ class DefaultImports:
code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions, florencev2_fine_tuning",
]

@staticmethod
Empty file.
46 changes: 46 additions & 0 deletions vision_agent/clients/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import json
import logging
from typing import Any, Dict, Optional

from requests import Session
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, RequestException, Timeout

_LOGGER = logging.getLogger(__name__)


class BaseHTTP:
_TIMEOUT = 30 # seconds
_MAX_RETRIES = 3

def __init__(
self, base_endpoint: str, *, headers: Optional[Dict[str, Any]] = None
) -> None:
self._headers = headers
if headers is None:
self._headers = {
"Content-Type": "application/json",
}
self._base_endpoint = base_endpoint
self._session = Session()
self._session.headers.update(self._headers) # type: ignore
self._session.mount(
self._base_endpoint, HTTPAdapter(max_retries=self._MAX_RETRIES)
)

def post(self, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
formatted_url = f"{self._base_endpoint}/{url}"
_LOGGER.info(f"Sending data to {formatted_url}")
try:
response = self._session.post(
url=formatted_url, json=payload, timeout=self._TIMEOUT
)
response.raise_for_status()
result: Dict[str, Any] = response.json()
_LOGGER.info(json.dumps(result))
except (ConnectionError, Timeout, RequestException) as err:
_LOGGER.warning(f"Error: {err}.")
except json.JSONDecodeError:
resp_text = response.text
_LOGGER.warning(f"Response seems incorrect: '{resp_text}'.")
return result
26 changes: 26 additions & 0 deletions vision_agent/clients/landing_public_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
from uuid import UUID
from typing import List

from vision_agent.clients.http import BaseHTTP
from vision_agent.utils.type_defs import LandingaiAPIKey
from vision_agent.tools.meta_tools_types import BboxInputBase64, PromptTask


class LandingPublicAPI(BaseHTTP):
def __init__(self) -> None:
landing_url = os.environ.get("LANDINGAI_URL", "https://api.dev.landing.ai")
landing_api_key = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
headers = {"Content-Type": "application/json", "apikey": landing_api_key}
super().__init__(base_endpoint=landing_url, headers=headers)

def launch_fine_tuning_job(
self, model_name: str, task: PromptTask, bboxes: List[BboxInputBase64]
) -> UUID:
url = "v1/agent/jobs/fine-tuning"
data = {
"model": {"name": model_name, "task": task.value},
"bboxes": [bbox.model_dump(by_alias=True) for bbox in bboxes],
}
response = self.post(url, payload=data)
return UUID(response["jobId"])
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, List, Optional

from .meta_tools import META_TOOL_DOCSTRING
from .meta_tools import META_TOOL_DOCSTRING, florencev2_fine_tuning
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import (
TOOL_DESCRIPTIONS,
45 changes: 45 additions & 0 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import subprocess
from uuid import UUID
from pathlib import Path
from typing import Any, Dict, List, Union

import vision_agent as va
from vision_agent.lmm.types import Message
from vision_agent.tools.tool_utils import get_tool_documentation
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.utils.image_utils import convert_to_b64
from vision_agent.clients.landing_public_api import LandingPublicAPI
from vision_agent.tools.meta_tools_types import BboxInput, BboxInputBase64, PromptTask

# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

@@ -385,6 +389,46 @@ def get_tool_descriptions() -> str:
return TOOL_DESCRIPTIONS


def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
"""'florencev2_fine_tuning' is a tool that fine-tune florencev2 to be able
to detect objects in an image based on a given dataset. It returns the fine
tuning job id.
Parameters:
bboxes (List[BboxInput]): A list of BboxInput containing the
image path, labels and bounding boxes.
task (PromptTask): The florencev2 fine-tuning task. The options are
CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
Returns:
UUID: The fine tuning job id, this id will used to retrieve the fine
tuned model.
Example
-------
>>> fine_tuning_job_id = florencev2_fine_tuning(
[{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
"OBJECT_DETECTION"
)
"""
bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
task_input = PromptTask[task]
fine_tuning_request = [
BboxInputBase64(
image=convert_to_b64(bbox_input.image_path),
filename=bbox_input.image_path.split("/")[-1],
labels=bbox_input.labels,
bboxes=bbox_input.bboxes,
)
for bbox_input in bboxes_input
]
landing_api = LandingPublicAPI()
return landing_api.launch_fine_tuning_job(
"florencev2", task_input, fine_tuning_request
)


META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
@@ -398,5 +442,6 @@ def get_tool_descriptions() -> str:
search_dir,
search_file,
find_file,
florencev2_fine_tuning,
]
)
30 changes: 30 additions & 0 deletions vision_agent/tools/meta_tools_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from enum import Enum
from typing import List, Tuple

from pydantic import BaseModel


class BboxInput(BaseModel):
image_path: str
labels: List[str]
bboxes: List[Tuple[int, int, int, int]]


class BboxInputBase64(BaseModel):
image: str
filename: str
labels: List[str]
bboxes: List[Tuple[int, int, int, int]]


class PromptTask(str, Enum):
"""
Valid task prompts options for the Florencev2 model.
"""

CAPTION = "<CAPTION>"
""""""
CAPTION_TO_PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
""""""
OBJECT_DETECTION = "<OD>"
""""""
9 changes: 4 additions & 5 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -2,23 +2,23 @@
import json
import logging
import tempfile
from importlib import resources
from pathlib import Path
from importlib import resources
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import cv2
import numpy as np
import requests
import numpy as np
from pytube import YouTube # type: ignore
from moviepy.editor import ImageSequenceClip
from PIL import Image, ImageDraw, ImageFont
from pillow_heif import register_heif_opener # type: ignore
from pytube import YouTube # type: ignore

from vision_agent.tools.tool_utils import (
send_inference_request,
get_tool_descriptions,
get_tool_documentation,
get_tools_df,
send_inference_request,
)
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.execute import FileSerializer, MimeType
@@ -1063,7 +1063,6 @@ def save_video(
if fps <= 0:
_LOGGER.warning(f"Invalid fps value: {fps}. Setting fps to 4 (default value).")
fps = 4

with ImageSequenceClip(frames, fps=fps) as video:
if output_video_path:
f = open(output_video_path, "wb")
4 changes: 2 additions & 2 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
@@ -209,7 +209,7 @@ def formats(self) -> Iterable[str]:
return formats

@staticmethod
def from_e2b_result(result: E2BResult) -> "Result": # type: ignore
def from_e2b_result(result: E2BResult) -> "Result":
"""
Creates a Result object from an E2BResult object.
"""
@@ -361,7 +361,7 @@ def from_exception(exec: Exception, traceback_raw: List[str]) -> "Execution":
)

@staticmethod
def from_e2b_execution(exec: E2BExecution) -> "Execution": # type: ignore
def from_e2b_execution(exec: E2BExecution) -> "Execution":
"""Creates an Execution object from an E2BResult object."""
return Execution(
results=[Result.from_e2b_result(res) for res in exec.results],

0 comments on commit ae97907

Please sign in to comment.