Rollback unwanted changes and format fixes
thedebugger committed Dec 31, 2024
1 parent e2466f2 commit 771484d
Showing 6 changed files with 59 additions and 54 deletions.
2 changes: 1 addition & 1 deletion tests/lora/conftest.py
@@ -160,7 +160,7 @@ def llama3_1_8b_chess_lora():

@pytest.fixture(scope="session")
def llama3_1_8b_ultravox_chess_lora():
# ultravox chess lora is result of transformation of above chess lora for llama
# ultravox chess lora is result of transformation of above chess llama lora
return snapshot_download(repo_id="thedebugger11/ultravox-chess-lora")


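The fixture above notes that the Ultravox chess LoRA was produced by transforming the Llama chess LoRA. A hedged sketch of how such a transformation could be done, assuming the adapter is stored as safetensors and that Ultravox nests the Llama decoder under a language_model prefix (the exact key mapping is an assumption, not shown in this repo):

from safetensors.torch import load_file, save_file

def remap_llama_lora_for_ultravox(src_path: str, dst_path: str) -> None:
    # Prefix every LoRA tensor key so it resolves against Ultravox's
    # wrapped language model instead of a bare Llama model.
    tensors = load_file(src_path)
    remapped = {
        key.replace("base_model.model.", "base_model.model.language_model.", 1): value
        for key, value in tensors.items()
    }
    save_file(remapped, dst_path)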
52 changes: 36 additions & 16 deletions tests/lora/test_ultravox.py
@@ -1,7 +1,9 @@
from typing import List, Tuple

from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest

from ..models.utils import check_outputs_equal

ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
@@ -12,7 +14,7 @@
PROMPT = "Tell me about a silly chess move in 20 words"


def _get_prompt(audio_count, question, placeholder, model_name)->str:
def _get_prompt(audio_count, question, placeholder, model_name) -> str:
tokenizer = AutoTokenizer.from_pretrained(model_name)
placeholder = f"{placeholder}\n" * audio_count

@@ -36,12 +38,18 @@ def test_ultravox_lora(vllm_runner, llama3_1_8b_chess_lora,
dtype="bfloat16",
max_model_len=4096,
) as vllm_model:
ultravox_outputs: List[Tuple[List[int], str]] = vllm_model.generate_greedy(
[_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, ULTRAVOX_MODEL_NAME)],
256,
lora_request=LoRARequest(str(1), 1,
llama3_1_8b_ultravox_chess_lora),
)
ultravox_outputs: List[Tuple[List[int],
str]] = vllm_model.generate_greedy(
[
_get_prompt(
0, PROMPT, VLLM_PLACEHOLDER,
ULTRAVOX_MODEL_NAME)
],
256,
lora_request=LoRARequest(
str(1), 1,
llama3_1_8b_ultravox_chess_lora),
)

# run llama with and without lora to compare outputs with above
with vllm_runner(
@@ -54,15 +62,27 @@ def test_ultravox_lora(vllm_runner, llama3_1_8b_chess_lora,
dtype="bfloat16",
max_model_len=4096,
) as vllm_model:
llama_outputs_no_lora: List[Tuple[List[int], str]] = vllm_model.generate_greedy(
[_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
256,
)
llama_outputs: List[Tuple[List[int], str]] = vllm_model.generate_greedy(
[_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
256,
lora_request=LoRARequest(str(1), 1, llama3_1_8b_chess_lora),
)
llama_outputs_no_lora: List[Tuple[List[int],
str]] = vllm_model.generate_greedy(
[
_get_prompt(
0, PROMPT,
VLLM_PLACEHOLDER,
LLMA_MODEL_NAME)
],
256,
)
llama_outputs: List[Tuple[List[int],
str]] = vllm_model.generate_greedy(
[
_get_prompt(0, PROMPT,
VLLM_PLACEHOLDER,
LLMA_MODEL_NAME)
],
256,
lora_request=LoRARequest(
str(1), 1, llama3_1_8b_chess_lora),
)

check_outputs_equal(
outputs_0_lst=ultravox_outputs,
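The body of _get_prompt is truncated in the hunk above. A sketch of how the helper plausibly completes, assuming it renders a single user turn through the model's chat template (the exact message layout is an assumption):

def _get_prompt(audio_count, question, placeholder, model_name) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    placeholder = f"{placeholder}\n" * audio_count

    # Return a chat-formatted string prompt ready for generate_greedy.
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": f"{placeholder}{question}"}],
        tokenize=False,
        add_generation_prompt=True)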
12 changes: 4 additions & 8 deletions vllm/assets/audio.py
@@ -10,19 +10,15 @@
ASSET_DIR = "multimodal_asset"


@dataclass
@dataclass(frozen=True)
class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]

def __init__(self, audio_path=None):
if audio_path is None:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
self._audio_path = audio_path

@property
def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:
y, sr = librosa.load(self._audio_path, sr=None)
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
y, sr = librosa.load(audio_path, sr=None)
assert isinstance(sr, int)
return y, sr

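For reference, a short usage sketch of the restored frozen-dataclass form above (asset name taken from the Literal annotation):

from vllm.assets.audio import AudioAsset

# Fetches mary_had_lamb.ogg from the public asset bucket on first access and
# decodes it with librosa at the file's native sample rate.
audio, sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate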
18 changes: 4 additions & 14 deletions vllm/lora/models.py
@@ -387,8 +387,6 @@ def activate_adapter(
for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name)
if module_lora:
logger.debug("Setting LoRA. int id: %d, module: %s",
lora_model.id, module_name)
module_lora.optimize()
# Bias is not explicitly enabled with the flag enable_lora_bias.
bias = module_lora.bias
@@ -409,7 +407,8 @@ def activate_adapter(

if len(missing_modules) > 0:
logger.warning(
"Lora adapter int id %d is activated but is missing base model modules %s",
"Lora adapter int id %d is activated but is missing \
base model modules %s which could impact output",
lora_model.id, missing_modules)
return True

@@ -467,10 +466,6 @@ def _create_lora_modules(self):
for module_name, module in self.model.named_modules(
remove_duplicate=False):

logger.debug(
"Create lora module if applicable %s",
module_name,
)
if isinstance(module, PPMissingLayer):
continue
if not self._match_target_modules(module_name):
@@ -517,15 +512,12 @@ def _create_lora_modules(self):
if self.supports_mm and not isinstance(new_module,
BaseLayerWithLoRA):
logger.warning(
"%s module will be ignored because it isn't of type BaseLayerWithLoRA",
"%s module will be ignored because it isn't of type \
BaseLayerWithLoRA",
module_name,
)
continue

logger.debug(
"Going to apply lora on %s module",
module_name,
)
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
@@ -541,7 +533,6 @@ def create_dummy_lora(
rank: int,
scaling_factor: Optional[float],
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
logger.debug(f"Creating a dummy lora with id: {lora_id}")
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
@@ -654,7 +645,6 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
if replacement_loras[i]:
continue
replacement_loras[i] = None

lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras)

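The warning strings above use backslash continuations inside string literals, which either embed the next line's indentation into the logged message or force the continuation to column zero. A minimal sketch of the same warning using implicit string concatenation instead (variable names as in the hunk above):

logger.warning(
    "Lora adapter int id %d is activated but is missing "
    "base model modules %s which could impact output",
    lora_model.id, missing_modules)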
14 changes: 7 additions & 7 deletions vllm/model_executor/models/llama.py
@@ -467,14 +467,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {}
embedding_padding_modules = []
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings"
}
embedding_padding_modules = ["lm_head"]

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
15 changes: 7 additions & 8 deletions vllm/model_executor/models/ultravox.py
@@ -22,17 +22,17 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

from .interfaces import SupportsMultiModal, SupportsPP, SupportsLoRA
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings_from_map)
@@ -303,14 +303,16 @@ def forward(
"audio", get_ultravox_max_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
# same as llamaforcasuallm (language model) minus embedding and other modules
# embedding modules haven't been added as a caution since it could impact text
# but not audio
# same as LlamaForCausalLM (language model) minus embedding and other
# modules. embedding modules haven't been added, as a caution,
# since they could affect text but not audio
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

# lm_head is not added for now since it requires logits_processor,
# which is missing from ultravox
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj"
]
@@ -325,9 +327,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.multi_modal_config = multimodal_config
assert self.multi_modal_config

#TODO: figure out if these prefixes need tweaking to support LoRA and/or
#use LLMWrapper or not like this https://github.com/vllm-project/vllm/pull/7199/files#diff-7b8a4e258637b7c94389c745c449c52137d33cf92957f3e5bcb18a0ee204b21bR807

self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None:
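The new MultiModelKeys import suggests the model also reports which submodules belong to the language model, the connector, and the audio tower for LoRA purposes. A sketch of how that typically looks in vLLM multimodal models (the method and attribute names below are assumptions, not part of this diff):

def get_mm_mapping(self) -> MultiModelKeys:
    # Group submodules so the LoRA manager applies adapters only to the
    # language-model portion of Ultravox.
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="multi_modal_projector",
        tower_model="audio_tower")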
