Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Platform] Move get_punica_wrapper() function to Platform #11516

Merged
merged 3 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions vllm/lora/punica_wrapper/punica_selector.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname

from .punica_base import PunicaWrapperBase

logger = init_logger(__name__)


def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
if current_platform.is_cuda_alike():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
logger.info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)
elif current_platform.is_cpu():
shen-shanshan marked this conversation as resolved.
Show resolved Hide resolved
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
logger.info_once("Using PunicaWrapperCPU.")
return PunicaWrapperCPU(*args, **kwargs)
elif current_platform.is_hpu():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
logger.info_once("Using PunicaWrapperHPU.")
return PunicaWrapperHPU(*args, **kwargs)
else:
raise NotImplementedError
punica_wrapper_qualname = current_platform.get_punica_wrapper()
punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
punica_wrapper = punica_wrapper_cls(*args, **kwargs)
assert punica_wrapper is not None, \
"the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
".")
return punica_wrapper
4 changes: 4 additions & 0 deletions vllm/platforms/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on CPU.")
return False

@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
logger.info("Using Flash Attention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/hpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on HPU.")
return False

@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
7 changes: 7 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,13 @@ def is_pin_memory_available(cls) -> bool:
return False
return True

@classmethod
def get_punica_wrapper(cls) -> str:
"""
Return the punica wrapper for current platform.
"""
raise NotImplementedError


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,7 @@ def verify_quantization(cls, quant: str) -> None:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True

@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
Loading