From 14241ab2dff974181e192010d4d15d4a054cd97f Mon Sep 17 00:00:00 2001
From: Akash Kaothalkar <0052v2@linux.vnet.ibm.com>
Date: Wed, 18 Dec 2024 07:21:56 -0500
Subject: [PATCH 1/2] fix: for auto dtype and fp16(half) dtype model casting

Signed-off-by: Akash Kaothalkar <0052v2@linux.vnet.ibm.com>
---
 vllm/config.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 307cf9c8d5b2a..adabf5cfd8601 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import current_platform
+from vllm.platforms import current_platform, interface
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2144,7 +2144,14 @@ def _get_and_verify_dtype(
                     torch_dtype = torch.float16
             else:
                 torch_dtype = config_dtype
-
+
+            if (current_platform.is_cpu() and
+                    current_platform.get_cpu_architecture() == interface.CpuArchEnum.POWERPC):
+                if config_dtype == torch.float16 or config_dtype == torch.float32:
+                    logger.info(
+                        "For POWERPC, we cast models to bfloat16 instead of"
+                        "using float16 by default. Float16 is not currently supported for POWERPC.")
+                    torch_dtype = torch.bfloat16
             if current_platform.is_hpu() and config_dtype == torch.float16:
                 logger.info(
                     "For HPU, we cast models to bfloat16 instead of"

From 616e760543a2c517f0f6cc6c199086f25e9ee37b Mon Sep 17 00:00:00 2001
From: Akash Kaothalkar <0052v2@linux.vnet.ibm.com>
Date: Thu, 19 Dec 2024 04:28:29 -0500
Subject: [PATCH 2/2] chore: ran format.sh

Signed-off-by: Akash Kaothalkar <0052v2@linux.vnet.ibm.com>
---
 vllm/config.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index adabf5cfd8601..5e533a2a7c964 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2144,14 +2144,18 @@ def _get_and_verify_dtype(
                     torch_dtype = torch.float16
             else:
                 torch_dtype = config_dtype
-
-            if (current_platform.is_cpu() and
-                    current_platform.get_cpu_architecture() == interface.CpuArchEnum.POWERPC):
-                if config_dtype == torch.float16 or config_dtype == torch.float32:
-                    logger.info(
-                        "For POWERPC, we cast models to bfloat16 instead of"
-                        "using float16 by default. Float16 is not currently supported for POWERPC.")
-                    torch_dtype = torch.bfloat16
+
+            if (current_platform.is_cpu()
+                    and current_platform.get_cpu_architecture()
+                    == interface.CpuArchEnum.POWERPC
+                    and (config_dtype == torch.float16
+                         or config_dtype == torch.float32)):
+                logger.info(
+                    "For POWERPC, we cast models to bfloat16 instead of "
+                    "using float16 by default. Float16 is not currently "
+                    "supported for POWERPC.")
+                torch_dtype = torch.bfloat16
+
             if current_platform.is_hpu() and config_dtype == torch.float16:
                 logger.info(
                     "For HPU, we cast models to bfloat16 instead of"
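
Note: the snippet below is a minimal, standalone sketch of the dtype="auto"
behavior this series adds, for reviewers who want to see the casting rule in
isolation. The function resolve_auto_dtype and the local CpuArchEnum enum are
illustrative stand-ins, not vLLM's actual API; only the POWERPC-to-bfloat16
rule itself is taken from the diff above.

from enum import Enum

import torch


class CpuArchEnum(Enum):
    """Simplified stand-in for vllm.platforms.interface.CpuArchEnum."""
    X86 = "x86"
    ARM = "arm"
    POWERPC = "powerpc"


def resolve_auto_dtype(config_dtype: torch.dtype,
                       is_cpu: bool,
                       cpu_arch: CpuArchEnum) -> torch.dtype:
    """Mirror the dtype="auto" path touched by this patch.

    float32 checkpoints are downcast to float16 by default; on POWERPC
    CPUs, float16 and float32 are instead cast to bfloat16, since
    float16 is not currently supported there.
    """
    if config_dtype == torch.float32:
        # Common practice: run float32 checkpoints in half precision.
        torch_dtype = torch.float16
    else:
        torch_dtype = config_dtype

    if (is_cpu and cpu_arch == CpuArchEnum.POWERPC
            and config_dtype in (torch.float16, torch.float32)):
        # POWERPC has no float16 support, so fall back to bfloat16.
        torch_dtype = torch.bfloat16
    return torch_dtype


if __name__ == "__main__":
    # On a POWERPC CPU, both fp16 and fp32 checkpoints resolve to bfloat16.
    print(resolve_auto_dtype(torch.float16, True, CpuArchEnum.POWERPC))
    # Elsewhere, a float32 checkpoint is still downcast to float16.
    print(resolve_auto_dtype(torch.float32, False, CpuArchEnum.X86))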