From 85362f028c0324d8d00b0438f29c3d9f64737b9a Mon Sep 17 00:00:00 2001 From: Jiaxin Shan Date: Thu, 12 Dec 2024 01:25:16 -0800 Subject: [PATCH] [Misc][LoRA] Ensure Lora Adapter requests return adapter name (#11094) Signed-off-by: Jiaxin Shan Signed-off-by: Jee Jee Li Co-authored-by: Jee Jee Li --- tests/entrypoints/openai/test_serving_engine.py | 11 +++++++++++ vllm/entrypoints/openai/serving_chat.py | 14 ++++++++------ vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 13 +++++++++++++ 4 files changed, 33 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 6199a75b5b4f8..096ab6fa0ac09 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -9,6 +9,7 @@ LoadLoraAdapterRequest, UnloadLoraAdapterRequest) from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.lora.request import LoRARequest MODEL_NAME = "meta-llama/Llama-2-7b" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] @@ -33,6 +34,16 @@ async def _async_serving_engine_init(): return serving_engine +@pytest.mark.asyncio +async def test_serving_model_name(): + serving_engine = await _async_serving_engine_init() + assert serving_engine._get_model_name(None) == MODEL_NAME + request = LoRARequest(lora_name="adapter", + lora_path="/path/to/adapter2", + lora_int_id=1) + assert serving_engine._get_model_name(request) == request.lora_name + + @pytest.mark.asyncio async def test_load_lora_adapter_success(): serving_engine = await _async_serving_engine_init() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0738210e27cb6..a5e7b4ac3bb30 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -123,6 +123,8 @@ async def create_chat_completion( prompt_adapter_request, ) 
= self._maybe_get_adapters(request) + model_name = self._get_model_name(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) tool_parser = self.tool_parser @@ -238,13 +240,13 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id, conversation, tokenizer, - request_metadata) + request, result_generator, request_id, model_name, + conversation, tokenizer, request_metadata) try: return await self.chat_completion_full_generator( - request, result_generator, request_id, conversation, tokenizer, - request_metadata) + request, result_generator, request_id, model_name, + conversation, tokenizer, request_metadata) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -259,11 +261,11 @@ async def chat_completion_stream_generator( request: ChatCompletionRequest, result_generator: AsyncIterator[RequestOutput], request_id: str, + model_name: str, conversation: List[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: - model_name = self.base_model_paths[0].name created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" first_iteration = True @@ -604,12 +606,12 @@ async def chat_completion_full_generator( request: ChatCompletionRequest, result_generator: AsyncIterator[RequestOutput], request_id: str, + model_name: str, conversation: List[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: - model_name = self.base_model_paths[0].name created_time = int(time.time()) final_res: Optional[RequestOutput] = None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index ee97d35f2b087..b3436773062f3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ 
b/vllm/entrypoints/openai/serving_completion.py @@ -85,7 +85,6 @@ async def create_completion( return self.create_error_response( "suffix is not currently supported") - model_name = self.base_model_paths[0].name request_id = f"cmpl-{self._base_request_id(raw_request)}" created_time = int(time.time()) @@ -162,6 +161,7 @@ async def create_completion( result_generator = merge_async_iterators( *generators, is_cancelled=raw_request.is_disconnected) + model_name = self._get_model_name(lora_request) num_prompts = len(engine_prompts) # Similar to the OpenAI API, when n != best_of, we do not stream the diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 63f27b955461e..d5ad4354c78be 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -661,3 +661,16 @@ async def unload_lora_adapter( def _is_model_supported(self, model_name): return any(model.name == model_name for model in self.base_model_paths) + + def _get_model_name(self, lora: Optional[LoRARequest]): + """ + Returns the model name to report for a request: the LoRA adapter + name when a LoRA request is given, otherwise the base model name. + Parameters: + - lora: Optional LoRARequest; its lora_name is used when provided. + Returns: + - str: The LoRA adapter name, or the first base model name if lora is None. + """ + if lora is not None: + return lora.lora_name + return self.base_model_paths[0].name