diff --git a/src/model.py b/src/model.py
index 0f09c3d3..7211b10b 100644
--- a/src/model.py
+++ b/src/model.py
@@ -55,9 +55,9 @@ def auto_complete_config(auto_complete_model_config):
         inputs = [
             {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
             {
-                "name": "multi_modal_data",
+                "name": "image",
                 "data_type": "TYPE_STRING",
-                "dims": [1],
+                "dims": [-1],  # can be multiple images as separate elements
                 "optional": True,
             },
             {
@@ -394,20 +394,16 @@ async def generate(self, request):
         if isinstance(prompt, bytes):
             prompt = prompt.decode("utf-8")

-        multi_modal_data_input_tensor = pb_utils.get_input_tensor_by_name(
-            request, "multi_modal_data"
+        image_input_tensor = pb_utils.get_input_tensor_by_name(
+            request, "image"
         )
-        if multi_modal_data_input_tensor:
-            multi_modal_data = multi_modal_data_input_tensor.as_numpy()[0].decode("utf-8")
-            multi_modal_data = json.loads(multi_modal_data)
-            if "image" in multi_modal_data:
-                image_list = []
-                for image_base64_string in multi_modal_data["image"]:
-                    if "base64," in image_base64_string:
-                        image_base64_string = image_base64_string.split("base64,")[-1]
-                    image_data = base64.b64decode(image_base64_string)
-                    image = Image.open(BytesIO(image_data)).convert("RGB")
-                    image_list.append(image)
+        if image_input_tensor:
+            image_list = []
+            for image_raw in image_input_tensor.as_numpy():
+                image_data = base64.b64decode(image_raw.decode("utf-8"))
+                image = Image.open(BytesIO(image_data)).convert("RGB")
+                image_list.append(image)
+            if len(image_list) > 0:
                 prompt = {
                     "prompt": prompt,
                     "multi_modal_data": {
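
With this change, clients send each base64-encoded image as its own element of a ragged `image` input (dims `[-1]`) instead of packing everything into a single JSON-encoded `multi_modal_data` string. A minimal client-side sketch using `tritonclient` is shown below; the model name `vllm_model`, the server URL, the output name `text_output`, and the image file are assumptions and should be adjusted to your deployment.

```python
import base64

import numpy as np
import tritonclient.http as httpclient

# Assumption: Triton is reachable on localhost:8000 and serves this model as "vllm_model".
client = httpclient.InferenceServerClient(url="localhost:8000")

# Assumption: an example image file; each image becomes one plain base64 string.
with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# TYPE_STRING inputs are sent as BYTES tensors from the client.
text_input = httpclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(np.array(["Describe this image."], dtype=object))

# One element per image; the "data:...;base64," prefix is no longer expected.
image_input = httpclient.InferInput("image", [1], "BYTES")
image_input.set_data_from_numpy(np.array([image_b64], dtype=object))

result = client.infer("vllm_model", inputs=[text_input, image_input])
print(result.as_numpy("text_output"))  # assumed output tensor name
```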