LoRA Support for Ultravox model #11253

Status: Open. Wants to merge 18 commits into base: main.

Changes from 9 commits
16 changes: 12 additions & 4 deletions tests/conftest.py
@@ -734,14 +734,16 @@ def generate(
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)

outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
@@ -779,6 +781,7 @@ def generate_w_logprobs(
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
inputs = self.get_inputs(prompts,
@@ -787,7 +790,8 @@
audios=audios)

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)

toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
@@ -823,13 +827,15 @@ def generate_greedy(
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts,
greedy_params,
images=images,
videos=videos,
audios=audios)
audios=audios,
**kwargs)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

@@ -844,6 +850,7 @@ def generate_greedy_logprobs(
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
@@ -858,7 +865,8 @@
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)
videos=videos,
**kwargs)

def generate_encoder_decoder_greedy_logprobs(
self,
Collaborator:
The changes to this file are not related to this PR, please revert.

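For context, the `**kwargs` threading above lets tests forward engine-level keywords (notably `lora_request`) from the fixture helpers down to `LLM.generate`. A condensed sketch of the pattern, assuming vLLM's offline API (the `RunnerSketch` class is an illustration, not the real fixture):

```python
# Stripped-down illustration of the forwarding pattern added above:
# extra keywords (e.g. lora_request) pass through the helper untouched.
from typing import Any, List

from vllm import LLM, SamplingParams


class RunnerSketch:

    def __init__(self, model_name: str, **engine_kwargs: Any):
        self.model = LLM(model=model_name, **engine_kwargs)

    def generate_greedy(self, prompts: List[str], max_tokens: int,
                        **kwargs: Any):
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        # kwargs may carry lora_request=LoRARequest(...); it is forwarded
        # as-is to LLM.generate(), exactly as in the conftest changes.
        return self.model.generate(prompts,
                                   sampling_params=greedy_params,
                                   **kwargs)
```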
133 changes: 133 additions & 0 deletions tests/lora/test_ultravox.py
@@ -0,0 +1,133 @@
import shutil
from os import path
from tempfile import TemporaryDirectory
from typing import List, Tuple

from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest

from ..models.utils import check_outputs_equal

ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"

PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"


def llama3_1_8b_chess_lora_path():
return snapshot_download(
repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")


# Can't use the Llama LoRA adapter as-is: Ultravox nests the language
# model, so the module names must be transformed first.
def transform_module_names_for_ultravox(state_dict):
transformed_state_dict = {}
for key, value in state_dict.items():
new_key = key.replace("base_model.model",
"base_model.model.language_model")
transformed_state_dict[new_key] = value
return transformed_state_dict


def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
tensor_file = "adapter_model.safetensors"
state_dict = load_file(path.join(source_repo, tensor_file))
transformed_state_dict = transform_module_names_for_ultravox(state_dict)

save_file(transformed_state_dict, path.join(target_path, tensor_file))

config_file = "adapter_config.json"
shutil.copyfile(path.join(source_repo, config_file),
path.join(target_path, config_file))
return target_path


def _get_prompt(audio_count, question, placeholder, model_name) -> str:
tokenizer = AutoTokenizer.from_pretrained(model_name)
placeholder = f"{placeholder}\n" * audio_count

return tokenizer.apply_chat_template([{
'role': 'user',
'content': f"{placeholder}{question}"
}],
tokenize=False,
add_generation_prompt=True)


def test_ultravox_lora(vllm_runner):
llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
with TemporaryDirectory() as temp_ultravox_lora_dir:
llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
with vllm_runner(
ULTRAVOX_MODEL_NAME,
enforce_eager=True,
max_num_seqs=128,
enable_lora=True,
max_loras=4,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=4096,
) as vllm_model:
ultravox_outputs: List[Tuple[
List[int], str]] = vllm_model.generate_greedy(
[
_get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
ULTRAVOX_MODEL_NAME)
],
256,
lora_request=LoRARequest(str(1), 1,
llama3_1_8b_ultravox_chess_lora),
)

# Run Llama with and without LoRA to compare its outputs with the above.
with vllm_runner(
LLMA_MODEL_NAME,
enforce_eager=True,
max_num_seqs=128,
enable_lora=True,
max_loras=4,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=4096,
) as vllm_model:
llama_outputs_no_lora: List[Tuple[List[int],
str]] = vllm_model.generate_greedy(
[
_get_prompt(
0, PROMPT,
VLLM_PLACEHOLDER,
LLMA_MODEL_NAME)
],
256,
)
llama_outputs: List[Tuple[List[int],
str]] = vllm_model.generate_greedy(
[
_get_prompt(0, PROMPT,
VLLM_PLACEHOLDER,
LLMA_MODEL_NAME)
],
256,
lora_request=LoRARequest(
str(1), 1, llama3_1_8b_chess_lora),
)

check_outputs_equal(
outputs_0_lst=ultravox_outputs,
outputs_1_lst=llama_outputs,
name_0="ultravox",
name_1="llama",
)

_, llama_no_lora_str = llama_outputs_no_lora[0]
_, ultravox_str = ultravox_outputs[0]

# Verify that the no-LoRA output does not match the LoRA output.
assert llama_no_lora_str != ultravox_str
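To make the module-name transformation above concrete, here is a before/after on a single key; the sample key is hypothetical but follows the PEFT naming scheme for Llama adapters:

```python
# Hypothetical PEFT key from a Llama adapter; Ultravox nests the language
# model, so "base_model.model" gains a "language_model" segment.
before = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
after = before.replace("base_model.model",
                       "base_model.model.language_model")
assert after == ("base_model.model.language_model"
                 ".model.layers.0.self_attn.q_proj.lora_A.weight")
```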
32 changes: 30 additions & 2 deletions vllm/model_executor/models/ultravox.py
@@ -20,6 +20,7 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
@@ -31,7 +32,7 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings,
@@ -349,7 +350,21 @@ def forward(
info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

# LoRA-specific attributes. LoRA currently covers only the language model.
# TODO: Add LoRA to the audio tower and projector.
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "lm_head"
]
embedding_modules = {}
embedding_padding_modules = []

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
@@ -379,6 +394,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
# logits_processor is added here to support the 'lm_head' LoRA module
# for the language model
self.logits_processor = self.language_model.logits_processor
Author (@thedebugger), Jan 11, 2025:
@jeejeelee I had to add this because vLLM fails here otherwise. Do you know why lm_head is special? Can you think of any negative impact this change could have on Ultravox?

cc @Yard1, who has a related comment about making this code more robust.

if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
@@ -397,6 +415,16 @@ def sampler(self):

return get_sampler()

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model.",
connector="multi_modal_projector.",
tower_model="audio_tower.",
)

def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor:
audio_input = input_features.to(self.audio_tower.dtype)
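With these changes, Ultravox can be used with a LoRA adapter like any other `SupportsLoRA` model. A minimal offline sketch, mirroring the test above (the adapter path is hypothetical and must already use the `language_model`-prefixed module names produced by the test's transformation):

```python
# Minimal sketch of Ultravox + LoRA after this PR, using vLLM's offline
# API. The adapter path is hypothetical; its module names must carry the
# "language_model" prefix (see transform_module_names_for_ultravox above).
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="fixie-ai/ultravox-v0_3",
    enable_lora=True,
    max_loras=4,
    max_lora_rank=128,
    dtype="bfloat16",
    max_model_len=4096,
)

outputs = llm.generate(
    ["Tell me about a Fool's mate move in 20 words. Provide the moves!"],
    SamplingParams(temperature=0.0, max_tokens=256),
    lora_request=LoRARequest("chess", 1, "/path/to/ultravox-chess-lora"),
)
print(outputs[0].outputs[0].text)
```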