
[Feature] onnx runtime for label anything #100

Open · wants to merge 18 commits into main
30 changes: 30 additions & 0 deletions label_anything/readme.md
@@ -378,3 +378,33 @@ When finished, we can get the model test visualization. On the left is the annot
With the semi-automated annotation function of Label-Studio, users can complete object segmentation and detection by simply clicking the mouse during the annotation process, greatly improving the efficiency of annotation.

Some of the code was borrowed from Pull Request ID 253 of label-studio-ml-backend. Thanks to the author for their contribution. Thanks also to community member [ATang0729](https://github.com/ATang0729) for re-labeling the meow dataset for script testing, and to [JimmyMa99](https://github.com/JimmyMa99) for the conversion script, config template, and documentation optimization.

## (Beta) 🚀 SAM backend inference with ONNX Runtime 🚀 (optional)

We use ONNX Runtime for the SAM back end to speed up inference. Tested on an RTX 3090, a single inference takes about 4.6 s with PyTorch and about 0.24 s with ONNX Runtime.
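If you want to reproduce a comparison like this on your own machine, a minimal timing sketch is shown below (the `run_once` callable is a placeholder for either the PyTorch `SamPredictor.predict` call or the ONNX pipeline; this is not the exact benchmark script used for the numbers above):

```python
import time

def average_latency(run_once, n_warmup=2, n_runs=10):
    """Average wall-clock latency of a single-prompt SAM inference."""
    for _ in range(n_warmup):       # warm-up: CUDA context, kernel caches, etc.
        run_once()
    start = time.perf_counter()
    for _ in range(n_runs):
        run_once()
    return (time.perf_counter() - start) / n_runs

# Example:
# average_latency(lambda: predictor.predict(point_coords=coords,
#                                           point_labels=labels,
#                                           multimask_output=False))
```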

First, download the pre-converted ONNX models from Hugging Face.

```shell
cd path/to/playground/label_anything
wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx
wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx
```
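
Optionally, you can sanity-check the downloaded files by loading them with ONNX Runtime (a minimal sketch; it only assumes the two `.onnx` files above are in the current directory):

```python
import onnxruntime

# Use whatever execution providers are available (CUDA if installed, otherwise CPU).
providers = onnxruntime.get_available_providers()

encoder = onnxruntime.InferenceSession("encoder.onnx", providers=providers)
decoder = onnxruntime.InferenceSession("decoder.onnx", providers=providers)

# The encoder expects a single preprocessed image tensor ("x"); the decoder expects
# the image embedding plus the prompt tensors (point_coords, point_labels, ...).
print([i.name for i in encoder.get_inputs()])
print([i.name for i in decoder.get_inputs()])
```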

Then start the backend with ONNX inference enabled.

```shell
cd path/to/playground/label_anything

label-studio-ml start sam --port 8003 --with \
sam_config=vit_b \
sam_checkpoint_file=./sam_vit_b_01ec64.pth \
out_mask=True \
out_bbox=True \
device=cuda:0 \
onnx=True \
# device=cuda:0 selects GPU inference; replace cuda:0 with cpu for CPU-only inference
# out_poly=True additionally returns the enclosing polygon annotation
```

⚠ Currently only sam_vit_b is supported.
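
Under the hood, the `onnx=True` flag only changes which predictor the ML backend builds; without it the backend keeps using the regular PyTorch `SamPredictor`. A simplified sketch of that selection logic (condensed from `label_anything/sam/mmdetection.py` in this PR):

```python
import onnxruntime
from segment_anything import SamPredictor, sam_model_registry

def build_predictor(onnx=False, device="cuda:0", sam_config="vit_b",
                    sam_checkpoint_file="./sam_vit_b_01ec64.pth"):
    if onnx:
        # ONNX path: two ONNX Runtime sessions (image encoder + prompt decoder).
        providers = onnxruntime.get_available_providers()
        encoder = onnxruntime.InferenceSession("encoder.onnx", providers=providers)
        decoder = onnxruntime.InferenceSession("decoder.onnx", providers=providers)
        return encoder, decoder
    # PyTorch path: the standard SamPredictor used by the existing backend.
    sam = sam_model_registry[sam_config](checkpoint=sam_checkpoint_file)
    sam.to(device)
    return SamPredictor(sam)
```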
30 changes: 30 additions & 0 deletions label_anything/readme_zh.md
@@ -384,5 +384,35 @@ python tools/test.py data/my_set/mask-rcnn_r50_fpn.py path/of/your/checkpoint --

This completes the semi-automated annotation. With Label-Studio's semi-automated annotation feature, users can segment and detect objects with a single mouse click during annotation, which greatly improves annotation efficiency. Some of the code was borrowed from Pull Request ID 253 of label-studio-ml-backend; thanks to the author for their contribution. Thanks also to community member [ATang0729](https://github.com/ATang0729) for re-labeling the meow dataset for script testing, and to [JimmyMa99](https://github.com/JimmyMa99) for the conversion script, config template, and documentation optimization.

## (Beta) 🚀 SAM backend inference with ONNX Runtime 🚀 (optional)

We use ONNX Runtime for the SAM back end to speed up inference. Tested on an RTX 3090, a single inference takes about 4.6 s with PyTorch and about 0.24 s with ONNX Runtime.

First, download the pre-converted ONNX models from Hugging Face.

```shell
cd path/to/playground/label_anything
wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx
wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx
```

Then start the backend with ONNX inference enabled.

```shell
cd path/to/playground/label_anything

label-studio-ml start sam --port 8003 --with \
sam_config=vit_b \
sam_checkpoint_file=./sam_vit_b_01ec64.pth \
out_mask=True \
out_bbox=True \
device=cuda:0 \
onnx=True \
# device=cuda:0 selects GPU inference; replace cuda:0 with cpu for CPU-only inference
# out_poly=True additionally returns the enclosing polygon annotation
```

Collaborator
So if `onnx=True` is not passed, it falls back to PyTorch inference, right?

Contributor Author
Yes.

⚠ Currently only sam_vit_b is supported.



209 changes: 171 additions & 38 deletions label_anything/sam/mmdetection.py
@@ -7,6 +7,7 @@
import numpy as np
from label_studio_converter import brush
import torch
from torch.nn import functional as F

import cv2


# from mmdet.apis import inference_detector, init_detector
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.utils.transforms import ResizeLongestSide
import random
import string
import time
import onnxruntime


logger = logging.getLogger(__name__)

def load_my_model(device="cuda:0",sam_config="vit_b",sam_checkpoint_file="sam_vit_b_01ec64.pth"):
@@ -34,6 +40,28 @@ def load_my_model(device="cuda:0",sam_config="vit_b",sam_checkpoint_file="sam_vi
return predictor


def load_my_onnx(onnx_config:dict):
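# Build ONNX Runtime sessions for the exported SAM image encoder and prompt decoder.
# The .onnx files are expected in the working directory (see the wget commands in the readme).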
# !wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx
# !wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx
encoder_model_abs_path = "./encoder.onnx"
decoder_model_abs_path = "./decoder.onnx"


providers = onnxruntime.get_available_providers()
if providers:
logging.info(
"Available providers for ONNXRuntime: %s", ", ".join(providers)
)
else:
logging.warning("No available providers for ONNXRuntime")
encoder_session = onnxruntime.InferenceSession(
encoder_model_abs_path, providers=providers
)
decoder_session = onnxruntime.InferenceSession(
decoder_model_abs_path, providers=providers
)

return encoder_session,decoder_session

class MMDetection(LabelStudioMLBase):
"""Object detector based on https://github.com/open-mmlab/mmdetection."""
@@ -50,21 +78,23 @@ def __init__(self,
out_poly=False,
score_threshold=0.5,
device='cpu',
onnx=False,
**kwargs):

super(MMDetection, self).__init__(**kwargs)
self.onnx=onnx
if self.onnx:
PREDICTOR=load_my_onnx(device)
else:
PREDICTOR=load_my_model(device,sam_config,sam_checkpoint_file)

PREDICTOR=load_my_model(device,sam_config,sam_checkpoint_file)

self.PREDICTOR = PREDICTOR

self.out_mask = out_mask
self.out_bbox = out_bbox
self.out_poly = out_poly

# config_file = config_file or os.environ['config_file']
# checkpoint_file = checkpoint_file or os.environ['checkpoint_file']
# self.config_file = config_file
# self.checkpoint_file = checkpoint_file
self.labels_file = labels_file
# default Label Studio image upload folder
upload_dir = os.path.join(get_data_dir(), 'media', 'upload')
@@ -76,8 +106,6 @@ def __init__(self,
else:
self.label_map = {}

# self.from_name, self.to_name, self.value, self.labels_in_config = get_single_tag_keys( # noqa E501
# self.parsed_label_config, 'RectangleLabels', 'Image')

self.labels_in_config = dict(
label=self.parsed_label_config['KeyPointLabels']
@@ -132,6 +160,78 @@ def __init__(self,
# self.model = init_detector(config_file, checkpoint_file, device=device)
self.score_thresh = score_threshold


def pre_process(self, image):
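# Mirror SamPredictor's preprocessing for the ONNX encoder: resize the longest side
# to 1024, normalize with SAM's pixel mean/std, and zero-pad to a 1024x1024 input.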
image_size = 1024
transform = ResizeLongestSide(image_size)

input_image = transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device="cpu")
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
x = (input_image_torch - pixel_mean) / pixel_std
h, w = x.shape[-2:]
padh = image_size - h
padw = image_size - w
x = F.pad(x, (0, padw, 0, padh))
x = x.numpy()

encoder_inputs = {
"x": x,
}
return encoder_inputs, image.shape[:2]

def run_encoder(self, encoder_inputs):
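# Run the ONNX image encoder once per image; the returned embedding is reused for every prompt.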
output = self.encoder_session.run(None, encoder_inputs)
image_embedding = output[0]
return image_embedding



def run_decoder(
self, image_embedding, input_prompt,img_size):
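# Convert the point/box prompt into the tensors expected by the exported SAM decoder
# (coordinates mapped into the 1024-resized frame, plus an empty mask input), then run it.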
(original_height,original_width)=img_size
points=input_prompt['points']
masks=input_prompt['mask']
boxes=input_prompt['boxes']
labels=input_prompt['label']

image_size = 1024
transform = ResizeLongestSide(image_size)
if boxes is not None:
onnx_box_coords = boxes.reshape(2, 2)
input_labels = np.array([2,3])

onnx_coord = np.concatenate([onnx_box_coords, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[None, :].astype(np.float32)
elif points is not None:
input_point=points
input_label = np.array([1])
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

onnx_coord = transform.apply_coords(onnx_coord, img_size).astype(np.float32)

onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)


decoder_inputs = {
"image_embeddings": image_embedding,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(
img_size, dtype=np.float32
),
}
masks, _, _ = self.decoder_session.run(None, decoder_inputs)
masks = masks > 0.0

return masks
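
# Usage sketch (for illustration): predict() below chains these steps for the ONNX path:
#   encoder_inputs, _ = self.pre_process(image)             # HWC RGB numpy array
#   image_embedding = self.run_encoder(encoder_inputs)      # one encoder pass per image
#   masks = self.run_decoder(image_embedding, input_prompt, (original_height, original_width))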

def _get_image_url(self, task):
image_url = task['data'].get(
self.value) or task['data'].get(DATA_UNDEFINED_NAME)
@@ -156,8 +256,7 @@ def _get_image_url(self, task):

def predict(self, tasks, **kwargs):

predictor = self.PREDICTOR

start = time.time()
results = []
assert len(tasks) == 1
task = tasks[0]
@@ -167,61 +266,95 @@ def predict(self, tasks, **kwargs):
if kwargs.get('context') is None:
return []

# image = cv2.imread(f"./{split}")
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)

prompt_type = kwargs['context']['result'][0]['type']
original_height = kwargs['context']['result'][0]['original_height']
original_width = kwargs['context']['result'][0]['original_width']

if self.onnx:
self.encoder_session,self.decoder_session=self.PREDICTOR
encoder_inputs,_ = self.pre_process(image)

if prompt_type == 'keypointlabels':
# getting x and y coordinates of the keypoint
x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
output_label = kwargs['context']['result'][0]['value']['labels'][0]
input_prompt={}

input_prompt['boxes']=input_prompt['mask']=input_prompt['points']=input_prompt['label']=None
if prompt_type == 'keypointlabels':
# getting x and y coordinates of the keypoint
x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
output_label = kwargs['context']['result'][0]['value']['labels'][0]

masks, scores, logits = predictor.predict(
point_coords=np.array([[x, y]]),
# box=np.array([x.cpu() for x in bbox[:4]]),
point_labels=np.array([1]),
multimask_output=False,
)
input_prompt['points']=np.array([[x, y]])
input_prompt['label']=np.array([1])


if prompt_type == 'rectanglelabels':

x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
w = kwargs['context']['result'][0]['value']['width'] * original_width / 100
h = kwargs['context']['result'][0]['value']['height'] * original_height / 100

if prompt_type == 'rectanglelabels':
output_label = kwargs['context']['result'][0]['value']['rectanglelabels'][0]

input_prompt['boxes']=np.array([x, y, x+w, y+h])

input_prompt['label'] = np.array([2,3])


#encoder
image_embedding = self.run_encoder(encoder_inputs)
masks = self.run_decoder(image_embedding,input_prompt,\
(original_height,original_width))
masks = masks[0].astype(np.uint8)

x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
w = kwargs['context']['result'][0]['value']['width'] * original_width / 100
h = kwargs['context']['result'][0]['value']['height'] * original_height / 100
else:
predictor = self.PREDICTOR
predictor.set_image(image)

output_label = kwargs['context']['result'][0]['value']['rectanglelabels'][0]
if prompt_type == 'keypointlabels':
# getting x and y coordinates of the keypoint
x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
output_label = kwargs['context']['result'][0]['value']['labels'][0]


masks, scores, logits = predictor.predict(
point_coords=np.array([[x, y]]),
# box=np.array([x.cpu() for x in bbox[:4]]),
point_labels=np.array([1]),
multimask_output=False,
)

masks, scores, logits = predictor.predict(
# point_coords=np.array([[x, y]]),
box=np.array([x, y, x+w, y+h]),
point_labels=np.array([1]),
multimask_output=False,
)

if prompt_type == 'rectanglelabels':

x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
w = kwargs['context']['result'][0]['value']['width'] * original_width / 100
h = kwargs['context']['result'][0]['value']['height'] * original_height / 100

output_label = kwargs['context']['result'][0]['value']['rectanglelabels'][0]

masks, scores, logits = predictor.predict(
# point_coords=np.array([[x, y]]),
box=np.array([x, y, x+w, y+h]),
point_labels=np.array([1]),
multimask_output=False,
)

mask = masks[0].astype(np.uint8) # each mask has shape [H, W]
# converting the mask from the model to RLE format which is usable in Label Studio

# find the contours of the mask
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)


end = time.time()
print(end-start)


# compute the bounding rectangles


if self.out_bbox:
new_contours = []
for contour in contours: