diff --git a/src/model.py b/src/model.py
index d37ab689..d7b550c6 100644
--- a/src/model.py
+++ b/src/model.py
@@ -25,21 +25,25 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import asyncio
+import base64
 import gc
 import json
 import os
 import queue
 import threading
+from io import BytesIO
 from typing import Dict, List
 
 import numpy as np
 import torch
 import triton_python_backend_utils as pb_utils
+from PIL import Image
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm.version import __version__ as _VLLM_VERSION
 
 from utils.metrics import VllmStatLogger
 
@@ -67,7 +71,7 @@ def auto_complete_config(cls, auto_complete_model_config):
 
     @staticmethod
     def _auto_complete_inputs_and_outputs(auto_complete_model_config):
-        # Inputs/Outputs expected by the backend.
+        # Inputs expected by the backend.
         inputs = [
             {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
             {
@@ -107,6 +111,16 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
                 "optional": True,
             },
         ]
+        if _VLLM_VERSION >= "0.6.3.post1":
+            inputs.append(
+                {
+                    "name": "image",
+                    "data_type": "TYPE_STRING",
+                    "dims": [-1],  # can be multiple images as separate elements
+                    "optional": True,
+                }
+            )
+        # Outputs expected by the backend.
         outputs = [
             {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
             {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
@@ -313,6 +327,21 @@ def _get_input_tensors(self, request):
         if isinstance(prompt, bytes):
             prompt = prompt.decode("utf-8")
 
+        # image
+        if _VLLM_VERSION >= "0.6.3.post1":
+            images = pb_utils.get_input_tensor_by_name(request, "image")
+            if images:
+                images_vllm = []
+                for image_np in images.as_numpy():
+                    image_b = base64.b64decode(image_np.decode("utf-8"))
+                    image_rgb = Image.open(BytesIO(image_b)).convert("RGB")
+                    images_vllm.append(image_rgb)
+                if len(images_vllm) > 0:
+                    prompt = {
+                        "prompt": prompt,
+                        "multi_modal_data": {"image": images_vllm},
+                    }
+
         # stream
         stream = pb_utils.get_input_tensor_by_name(request, "stream")
         if stream:
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index 75c097dc..0504eef9 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -32,7 +32,7 @@
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
 from vllm.engine.metrics import Stats as VllmStats
 from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
-
+from vllm.version import __version__ as _VLLM_VERSION
 
 class TritonMetrics:
     def __init__(self, labels: List[str], max_model_len: int):
@@ -76,11 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.HISTOGRAM,
         )
-        self.histogram_best_of_request_family = pb_utils.MetricFamily(
-            name="vllm:request_params_best_of",
-            description="Histogram of the best_of request parameter.",
-            kind=pb_utils.MetricFamily.HISTOGRAM,
-        )
+        # 'best_of' metric has been hidden since vllm 0.6.3
+        # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request_family = pb_utils.MetricFamily(
+                name="vllm:request_params_best_of",
+                description="Histogram of the best_of request parameter.",
+                kind=pb_utils.MetricFamily.HISTOGRAM,
+            )
         self.histogram_n_request_family = pb_utils.MetricFamily(
             name="vllm:request_params_n",
             description="Histogram of the n request parameter.",
@@ -159,10 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
                 buckets=build_1_2_5_buckets(max_model_len),
             )
         )
-        self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
-            labels=labels,
-            buckets=[1, 2, 5, 10, 20],
-        )
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
+                labels=labels,
+                buckets=[1, 2, 5, 10, 20],
+            )
         self.histogram_n_request = self.histogram_n_request_family.Metric(
             labels=labels,
             buckets=[1, 2, 5, 10, 20],
@@ -247,10 +251,10 @@ def log(self, stats: VllmStats) -> None:
                 self.metrics.histogram_num_generation_tokens_request,
                 stats.num_generation_tokens_requests,
             ),
-            (self.metrics.histogram_best_of_request, stats.best_of_requests),
             (self.metrics.histogram_n_request, stats.n_requests),
         ]
-
+        if _VLLM_VERSION < "0.6.3":
+            histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
         for metric, data in counter_metrics:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics: