Skip to content

Commit

Permalink
Reduce model len
Browse files Browse the repository at this point in the history
Signed-off-by: Sumit Vij <[email protected]>
  • Loading branch information
thedebugger committed Jan 22, 2025
1 parent dd9a03f commit d247036
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/lora/test_ultravox.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import shutil
import torch
from os import path
from tempfile import TemporaryDirectory
from typing import List, Tuple

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer
Expand Down Expand Up @@ -63,8 +63,8 @@ def test_ultravox_lora(vllm_runner):
"""
TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
"""
#Check if set_default_device fixes the CI failure. Other lora tests set device to cuda
#which might be causing device mismatch in CI
# Check if set_default_device fixes the CI failure. Other lora tests set
# device to cuda which might be causing device mismatch in CI
torch.set_default_device("cpu")

llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
Expand All @@ -79,6 +79,7 @@ def test_ultravox_lora(vllm_runner):
max_loras=4,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=256,
) as vllm_model:
ultravox_outputs: List[Tuple[
List[int], str]] = vllm_model.generate_greedy(
Expand All @@ -100,6 +101,7 @@ def test_ultravox_lora(vllm_runner):
max_loras=4,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=256,
) as vllm_model:
llama_outputs_no_lora: List[Tuple[List[int], str]] = (
vllm_model.generate_greedy(
Expand Down

0 comments on commit d247036

Please sign in to comment.