From 6c066f6cb76d94eddc848c17da37cfd1ebe6806d Mon Sep 17 00:00:00 2001
From: xiejibing <33129072+xiejibing@users.noreply.github.com>
Date: Tue, 26 Nov 2024 04:06:54 +0800
Subject: [PATCH] Support input for llama3.2 multi-modal model (#69)

Co-authored-by: jibxie
---
 src/model.py         | 32 +++++++++++++++++++++++++++++++-
 src/utils/metrics.py | 28 ++++++++++++++++------------
 2 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/src/model.py b/src/model.py
index 3f6e23bb..0fdbe0ce 100644
--- a/src/model.py
+++ b/src/model.py
@@ -31,7 +31,9 @@
 import queue
 import threading
 from typing import Dict, List
-
+import base64
+from PIL import Image
+from io import BytesIO
 import numpy as np
 import torch
 import triton_python_backend_utils as pb_utils
@@ -40,6 +42,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm.version import __version__ as _VLLM_VERSION
 
 from utils.metrics import VllmStatLogger
 
@@ -71,6 +74,14 @@ def auto_complete_config(auto_complete_model_config):
                 "optional": True,
             },
         ]
+        if _VLLM_VERSION >= "0.6.3.post1":
+            inputs.append({
+                "name": "image",
+                "data_type": "TYPE_STRING",
+                "dims": [-1],  # can be multiple images as separate elements
+                "optional": True,
+            })
+
         outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
 
         # Store the model configuration as a dictionary.
@@ -385,6 +396,25 @@ async def generate(self, request):
             ).as_numpy()[0]
             if isinstance(prompt, bytes):
                 prompt = prompt.decode("utf-8")
+
+            if _VLLM_VERSION >= "0.6.3.post1":
+                image_input_tensor = pb_utils.get_input_tensor_by_name(
+                    request, "image"
+                )
+                if image_input_tensor:
+                    image_list = []
+                    for image_raw in image_input_tensor.as_numpy():
+                        image_data = base64.b64decode(image_raw.decode("utf-8"))
+                        image = Image.open(BytesIO(image_data)).convert("RGB")
+                        image_list.append(image)
+                    if len(image_list) > 0:
+                        prompt = {
+                            "prompt": prompt,
+                            "multi_modal_data": {
+                                "image": image_list
+                            }
+                        }
+
             stream = pb_utils.get_input_tensor_by_name(request, "stream")
             if stream:
                 stream = stream.as_numpy()[0]
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index 75c097dc..0504eef9 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -32,7 +32,7 @@
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
 from vllm.engine.metrics import Stats as VllmStats
 from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
-
+from vllm.version import __version__ as _VLLM_VERSION
 
 class TritonMetrics:
     def __init__(self, labels: List[str], max_model_len: int):
@@ -76,11 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.HISTOGRAM,
         )
-        self.histogram_best_of_request_family = pb_utils.MetricFamily(
-            name="vllm:request_params_best_of",
-            description="Histogram of the best_of request parameter.",
-            kind=pb_utils.MetricFamily.HISTOGRAM,
-        )
+        # 'best_of' metric has been hidden since vllm 0.6.3
+        # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request_family = pb_utils.MetricFamily(
+                name="vllm:request_params_best_of",
+                description="Histogram of the best_of request parameter.",
+                kind=pb_utils.MetricFamily.HISTOGRAM,
+            )
         self.histogram_n_request_family = pb_utils.MetricFamily(
             name="vllm:request_params_n",
             description="Histogram of the n request parameter.",
             kind=pb_utils.MetricFamily.HISTOGRAM,
         )
@@ -159,10 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
                 buckets=build_1_2_5_buckets(max_model_len),
             )
         )
-        self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
-            labels=labels,
-            buckets=[1, 2, 5, 10, 20],
-        )
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
+                labels=labels,
+                buckets=[1, 2, 5, 10, 20],
+            )
         self.histogram_n_request = self.histogram_n_request_family.Metric(
             labels=labels,
             buckets=[1, 2, 5, 10, 20],
         )
@@ -247,10 +251,10 @@ def log(self, stats: VllmStats) -> None:
                 self.metrics.histogram_num_generation_tokens_request,
                 stats.num_generation_tokens_requests,
             ),
-            (self.metrics.histogram_best_of_request, stats.best_of_requests),
            (self.metrics.histogram_n_request, stats.n_requests),
         ]
-
+        if _VLLM_VERSION < "0.6.3":
+            histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
         for metric, data in counter_metrics:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics:
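
For context beyond the patch itself, below is a minimal client-side sketch of how the new optional "image" input could be exercised with Triton's Python gRPC client once the model is deployed. It is not part of the change: the model name ("vllm_model"), server address, sample image path, and prompt format are assumptions, and the stream-based call is used because the vLLM backend normally runs with a decoupled transaction policy.

# Hypothetical client for the "image" input added by this patch; names marked
# above as assumptions are illustrative only.
import base64
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def build_inputs(prompt, image_paths):
    # "text_input" already exists in the backend; "image" is the optional
    # TYPE_STRING input added by this patch. Each element carries one
    # base64-encoded image, which model.py decodes and opens with PIL.
    text = grpcclient.InferInput("text_input", [1], "BYTES")
    text.set_data_from_numpy(np.array([prompt.encode("utf-8")], dtype=np.object_))

    encoded = [base64.b64encode(open(p, "rb").read()) for p in image_paths]
    images = grpcclient.InferInput("image", [len(encoded)], "BYTES")
    images.set_data_from_numpy(np.array(encoded, dtype=np.object_))
    return [text, images]


def collect(responses, result, error):
    # Stream callback: responses arrive through the gRPC stream rather than a
    # plain infer() call because the model is decoupled.
    responses.append(error if error else result.as_numpy("text_output"))


if __name__ == "__main__":
    client = grpcclient.InferenceServerClient("localhost:8001")  # assumed gRPC port
    responses = []
    client.start_stream(callback=partial(collect, responses))
    client.async_stream_infer(
        model_name="vllm_model",
        inputs=build_inputs("<|image|>Describe this image.", ["sample.jpg"]),
    )
    client.stop_stream()
    print(responses)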