From e4497778dcb09261b7cea75e5a1b2f7577c3e124 Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Wed, 4 Dec 2024 09:40:16 -0800 Subject: [PATCH] [LoRA] Change lora_tokenizers capacity (#10796) Signed-off-by: Xin Yang --- tests/lora/test_tokenizer_group.py | 20 +++++++++++++++++++ vllm/engine/llm_engine.py | 2 +- vllm/engine/multiprocessing/client.py | 3 +-- .../tokenizer_group/__init__.py | 9 +++++---- .../tokenizer_group/tokenizer_group.py | 3 ++- vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/llm_engine.py | 2 +- 7 files changed, 31 insertions(+), 10 deletions(-) diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py index daa39b2a3dba1..d225a3f7d6c06 100644 --- a/tests/lora/test_tokenizer_group.py +++ b/tests/lora/test_tokenizer_group.py @@ -17,6 +17,7 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): tokenizer_id="gpt2", enable_lora=True, max_num_seqs=1, + max_loras=1, max_input_length=None, ) lora_request = LoRARequest("1", 1, sql_lora_files) @@ -53,3 +54,22 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path): lora_request = LoRARequest("1", 1, str(tmp_path)) tokenizer = get_lora_tokenizer(lora_request) assert not tokenizer + + +@pytest.mark.parametrize("enable_lora", [True, False]) +@pytest.mark.parametrize("max_num_seqs", [1, 2]) +@pytest.mark.parametrize("max_loras", [1, 2]) +def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras): + tokenizer_group = get_tokenizer_group( + get_tokenizer_pool_config(None), + tokenizer_id="gpt2", + enable_lora=enable_lora, + max_num_seqs=max_num_seqs, + max_loras=max_loras, + max_input_length=None, + ) + if enable_lora: + assert tokenizer_group.lora_tokenizers.capacity == max( + max_num_seqs, max_loras) + else: + assert tokenizer_group.lora_tokenizers.capacity == 0 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index af66b307028cf..1f3c6197ba1a8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -620,7 +620,7 @@ def _init_tokenizer(self) -> BaseTokenizerGroup: model_config=self.model_config, scheduler_config=self.scheduler_config, parallel_config=self.parallel_config, - enable_lora=bool(self.lora_config)) + lora_config=self.lora_config) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index d21136c03d7d2..7e4f81b2cf8e2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -94,8 +94,7 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, model_config=self.model_config, scheduler_config=engine_config.scheduler_config, parallel_config=engine_config.parallel_config, - enable_lora=bool(engine_config.lora_config), - ) + lora_config=engine_config.lora_config) self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 6a114b513f382..c0b3d2585a962 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,7 +1,7 @@ from typing import Optional, Type -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - TokenizerPoolConfig) +from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, TokenizerPoolConfig) from vllm.executor.ray_utils import ray from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup @@ -16,10 +16,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig, scheduler_config: SchedulerConfig, parallel_config: ParallelConfig, - enable_lora: bool): + lora_config: LoRAConfig): init_kwargs = dict(tokenizer_id=model_config.tokenizer, - enable_lora=enable_lora, + enable_lora=bool(lora_config), max_num_seqs=scheduler_config.max_num_seqs, + max_loras=lora_config.max_loras if lora_config else 0, max_input_length=None, tokenizer_mode=model_config.tokenizer_mode, trust_remote_code=model_config.trust_remote_code, diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index e516eeabaadef..761b07f34d2f9 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -21,8 +21,9 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.enable_lora = enable_lora self.max_input_length = max_input_length self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + max_loras = tokenizer_config.get("max_loras", 0) self.lora_tokenizers = LRUCache[AnyTokenizer]( - capacity=max_num_seqs if enable_lora else 0) + capacity=max(max_loras, max_num_seqs) if enable_lora else 0) @classmethod def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 7335c637f0f79..4ef372fd8464b 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -51,7 +51,7 @@ def __init__( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, parallel_config=vllm_config.parallel_config, - enable_lora=bool(vllm_config.lora_config)) + lora_config=vllm_config.lora_config) self.tokenizer.ping() # Request streams (map of request_id -> AsyncStream). diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index bd19d998a4adb..312c0242a45dd 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -46,7 +46,7 @@ def __init__( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, parallel_config=vllm_config.parallel_config, - enable_lora=bool(vllm_config.lora_config)) + lora_config=vllm_config.lora_config) self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests)