Skip to content

Commit

Permalink
[LoRA] Change lora_tokenizers capacity (#10796)
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Yang <[email protected]>
  • Loading branch information
xyang16 authored Dec 4, 2024
1 parent c92acb9 commit 01d079f
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 10 deletions.
20 changes: 20 additions & 0 deletions tests/lora/test_tokenizer_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
tokenizer_id="gpt2",
enable_lora=True,
max_num_seqs=1,
max_loras=1,
max_input_length=None,
)
lora_request = LoRARequest("1", 1, sql_lora_files)
Expand Down Expand Up @@ -53,3 +54,22 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path):
lora_request = LoRARequest("1", 1, str(tmp_path))
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer


@pytest.mark.parametrize("enable_lora", [True, False])
@pytest.mark.parametrize("max_num_seqs", [1, 2])
@pytest.mark.parametrize("max_loras", [1, 2])
def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(None),
tokenizer_id="gpt2",
enable_lora=enable_lora,
max_num_seqs=max_num_seqs,
max_loras=max_loras,
max_input_length=None,
)
if enable_lora:
assert tokenizer_group.lora_tokenizers.capacity == max(
max_num_seqs, max_loras)
else:
assert tokenizer_group.lora_tokenizers.capacity == 0
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def _init_tokenizer(self) -> BaseTokenizerGroup:
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
enable_lora=bool(self.lora_config))
lora_config=self.lora_config)

def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
Expand Down
3 changes: 1 addition & 2 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig,
model_config=self.model_config,
scheduler_config=engine_config.scheduler_config,
parallel_config=engine_config.parallel_config,
enable_lora=bool(engine_config.lora_config),
)
lora_config=engine_config.lora_config)
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)

Expand Down
9 changes: 5 additions & 4 deletions vllm/transformers_utils/tokenizer_group/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Type

from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
TokenizerPoolConfig)
from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, TokenizerPoolConfig)
from vllm.executor.ray_utils import ray

from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
Expand All @@ -16,10 +16,11 @@
def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
enable_lora: bool):
lora_config: LoRAConfig):
init_kwargs = dict(tokenizer_id=model_config.tokenizer,
enable_lora=enable_lora,
enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs,
max_loras=lora_config.max_loras if lora_config else 0,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
Expand Down
3 changes: 2 additions & 1 deletion vllm/transformers_utils/tokenizer_group/tokenizer_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[AnyTokenizer](
capacity=max_num_seqs if enable_lora else 0)
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)

@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
enable_lora=bool(vllm_config.lora_config))
lora_config=vllm_config.lora_config)
self.tokenizer.ping()

# Request streams (map of request_id -> AsyncStream).
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
enable_lora=bool(vllm_config.lora_config))
lora_config=vllm_config.lora_config)
self.tokenizer.ping()

# Processor (convert Inputs --> EngineCoreRequests)
Expand Down

0 comments on commit 01d079f

Please sign in to comment.