-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
feat: florencev2 fine-tuning meta tool (#190)
* add florencev2 fine tuning * add pre-commit * add task customization * tools to meta tools
Showing
12 changed files
with
652 additions
and
371 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
repos: | ||
- repo: https://github.com/psf/black | ||
rev: 24.4.2 | ||
hooks: | ||
- id: black | ||
language_version: python3.9 | ||
- repo: https://github.com/pycqa/flake8 | ||
rev: 7.0.0 | ||
hooks: | ||
- id: flake8 |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import json | ||
import logging | ||
from typing import Any, Dict, Optional | ||
|
||
from requests import Session | ||
from requests.adapters import HTTPAdapter | ||
from requests.exceptions import ConnectionError, RequestException, Timeout | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class BaseHTTP: | ||
_TIMEOUT = 30 # seconds | ||
_MAX_RETRIES = 3 | ||
|
||
def __init__( | ||
self, base_endpoint: str, *, headers: Optional[Dict[str, Any]] = None | ||
) -> None: | ||
self._headers = headers | ||
if headers is None: | ||
self._headers = { | ||
"Content-Type": "application/json", | ||
} | ||
self._base_endpoint = base_endpoint | ||
self._session = Session() | ||
self._session.headers.update(self._headers) # type: ignore | ||
self._session.mount( | ||
self._base_endpoint, HTTPAdapter(max_retries=self._MAX_RETRIES) | ||
) | ||
|
||
def post(self, url: str, payload: Dict[str, Any]) -> Dict[str, Any]: | ||
formatted_url = f"{self._base_endpoint}/{url}" | ||
_LOGGER.info(f"Sending data to {formatted_url}") | ||
try: | ||
response = self._session.post( | ||
url=formatted_url, json=payload, timeout=self._TIMEOUT | ||
) | ||
response.raise_for_status() | ||
result: Dict[str, Any] = response.json() | ||
_LOGGER.info(json.dumps(result)) | ||
except (ConnectionError, Timeout, RequestException) as err: | ||
_LOGGER.warning(f"Error: {err}.") | ||
except json.JSONDecodeError: | ||
resp_text = response.text | ||
_LOGGER.warning(f"Response seems incorrect: '{resp_text}'.") | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
from uuid import UUID | ||
from typing import List | ||
|
||
from vision_agent.clients.http import BaseHTTP | ||
from vision_agent.utils.type_defs import LandingaiAPIKey | ||
from vision_agent.tools.meta_tools_types import BboxInputBase64, PromptTask | ||
|
||
|
||
class LandingPublicAPI(BaseHTTP): | ||
def __init__(self) -> None: | ||
landing_url = os.environ.get("LANDINGAI_URL", "https://api.dev.landing.ai") | ||
landing_api_key = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key) | ||
headers = {"Content-Type": "application/json", "apikey": landing_api_key} | ||
super().__init__(base_endpoint=landing_url, headers=headers) | ||
|
||
def launch_fine_tuning_job( | ||
self, model_name: str, task: PromptTask, bboxes: List[BboxInputBase64] | ||
) -> UUID: | ||
url = "v1/agent/jobs/fine-tuning" | ||
data = { | ||
"model": {"name": model_name, "task": task.value}, | ||
"bboxes": [bbox.model_dump(by_alias=True) for bbox in bboxes], | ||
} | ||
response = self.post(url, payload=data) | ||
return UUID(response["jobId"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from enum import Enum | ||
from typing import List, Tuple | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class BboxInput(BaseModel): | ||
image_path: str | ||
labels: List[str] | ||
bboxes: List[Tuple[int, int, int, int]] | ||
|
||
|
||
class BboxInputBase64(BaseModel): | ||
image: str | ||
filename: str | ||
labels: List[str] | ||
bboxes: List[Tuple[int, int, int, int]] | ||
|
||
|
||
class PromptTask(str, Enum): | ||
""" | ||
Valid task prompts options for the Florencev2 model. | ||
""" | ||
|
||
CAPTION = "<CAPTION>" | ||
"""""" | ||
CAPTION_TO_PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>" | ||
"""""" | ||
OBJECT_DETECTION = "<OD>" | ||
"""""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters