Commit 50a12bc

Merge branch 'main' into tms/v1_tp: instance id

Signed-off-by: Tyler Michael Smith <[email protected]>
tlrmchlsmth committed Dec 7, 2024
2 parents 5271ec6 + 1b62745 commit 50a12bc
Showing 53 changed files with 1,054 additions and 1,023 deletions.
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -237,7 +237,7 @@ steps:
source_file_dependencies:
- vllm/lora
- tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore lora/test_long_context.py lora/test_chatglm3_tp.py lora/test_llama_tp.py
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
parallelism: 4

- label: "PyTorch Fullgraph Smoke Test" # 9min
@@ -362,6 +362,7 @@ steps:
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
@@ -377,6 +378,7 @@ steps:
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
# HACK - run phi3v tests separately to sidestep this transformers bug
# https://github.com/huggingface/transformers/issues/34307
11 changes: 4 additions & 7 deletions csrc/attention/paged_attention_v1.cu
@@ -140,13 +140,10 @@ void paged_attention_v1_launcher(
blocksparse_block_size, blocksparse_head_sliding_step);

#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
if (is_block_sparse) { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
11 changes: 4 additions & 7 deletions csrc/attention/paged_attention_v2.cu
@@ -147,13 +147,10 @@ void paged_attention_v2_launcher(
blocksparse_head_sliding_step);

#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
if (is_block_sparse) { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
2 changes: 1 addition & 1 deletion csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -424,7 +424,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
// (which occurs when `final_state_position` is a non-positive index)
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
if (final_state_position < 0 && seqlen > kWidth){
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
input_t vals_load[kNElts] = {0};
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
// chunk = n_chunks - 2, a segment of the final state sits in the last index
10 changes: 7 additions & 3 deletions docs/source/models/supported_models.rst
@@ -547,15 +547,15 @@ Text Generation
- ✅︎
-
* - :code:`InternVLChatModel`
- InternVL2
- InternVL 2.5, Mono-InternVL, InternVL 2.0
- T + I\ :sup:`E+`
- :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
- :code:`OpenGVLab/InternVL2_5-4B`, :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, etc.
-
- ✅︎
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- T + I\ :sup:`E+`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc.
-
- ✅︎
* - :code:`LlavaNextForConditionalGeneration`
@@ -664,6 +664,10 @@ Text Generation
.. note::
vLLM currently only supports adding LoRA to the language backbone of multimodal models.

.. note::
To use :code:`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo (:code:`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
and pass :code:`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
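
   For reference, a minimal offline-inference sketch of this setup (a sketch only, mirroring the :code:`run_mantis` example added in this commit; the chat template and stop token id are taken from that example, and :code:`ImageAsset` merely supplies a bundled sample image):

   .. code-block:: python

      from vllm import LLM, SamplingParams
      from vllm.assets.image import ImageAsset

      llm = LLM(
          model="TIGER-Lab/Mantis-8B-siglip-llama3",
          max_model_len=4096,
          hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
      )

      # Llama-3-style chat template with a single <image> placeholder.
      prompt = ("<|start_header_id|>user<|end_header_id|>\n\n"
                "What is in this image?\n<image><|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n")

      # Any PIL image works here; ImageAsset just provides a sample.
      image = ImageAsset("cherry_blossom").pil_image

      outputs = llm.generate(
          {"prompt": prompt, "multi_modal_data": {"image": image}},
          SamplingParams(max_tokens=64, stop_token_ids=[128009]),
      )
      print(outputs[0].outputs[0].text)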

.. note::
The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
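
   A minimal sketch of loading the fork (illustrative only; :code:`trust_remote_code=True` is assumed here because the model relies on custom modeling code):

   .. code-block:: python

      from vllm import LLM

      # Load the community fork instead of the official checkpoint.
      llm = LLM(model="HwwwH/MiniCPM-V-2", trust_remote_code=True)
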
3 changes: 3 additions & 0 deletions docs/source/usage/spec_decode.rst
@@ -8,6 +8,9 @@ Speculative decoding
not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. The work
to optimize it is ongoing and can be followed in `this issue. <https://github.com/vllm-project/vllm/issues/4630>`_

.. warning::
Currently, speculative decoding in vLLM is not compatible with pipeline parallelism.

This document shows how to use `Speculative Decoding <https://x.com/karpathy/status/1697318534555336961>`_ with vLLM.
Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference.
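
As a quick illustration, a minimal offline sketch of draft-model speculative decoding (a sketch only: the model names are placeholders, and the :code:`speculative_model` / :code:`num_speculative_tokens` engine arguments are assumed from vLLM's draft-model configuration):

.. code-block:: python

   from vllm import LLM, SamplingParams

   # Target model plus a small draft model that proposes 5 tokens per step.
   llm = LLM(
       model="facebook/opt-6.7b",
       speculative_model="facebook/opt-125m",
       num_speculative_tokens=5,
   )

   outputs = llm.generate(
       ["The future of AI is"],
       SamplingParams(temperature=0.8, max_tokens=64),
   )
   print(outputs[0].outputs[0].text)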

19 changes: 18 additions & 1 deletion examples/offline_inference_vision_language.py
@@ -223,7 +223,7 @@ def run_internvl(question: str, modality: str):
# Stop tokens for InternVL
# model variants may have different stop tokens
# please refer to the model card for the correct "stop words":
# https://huggingface.co/OpenGVLab/InternVL2-2B#service
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompt, stop_token_ids
@@ -419,6 +419,22 @@ def run_aria(question: str, modality: str):
return llm, prompt, stop_token_ids


# Mantis
def run_mantis(question: str, modality: str):
assert modality == "image"

llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
prompt = llama3_template.format(f"{question}\n<image>")

llm = LLM(
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)
stop_token_ids = [128009]
return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -441,6 +457,7 @@ def run_aria(question: str, modality: str):
"glm4v": run_glm4v,
"idefics3": run_idefics3,
"aria": run_aria,
"mantis": run_mantis,
}


2 changes: 1 addition & 1 deletion examples/offline_inference_vision_language_multi_image.py
@@ -165,7 +165,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
# Stop tokens for InternVL
# model variants may have different stop tokens
# please refer to the model card for the correct "stop words":
# https://huggingface.co/OpenGVLab/InternVL2-2B#service
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

3 changes: 0 additions & 3 deletions requirements-test.in
@@ -24,9 +24,6 @@ mistral_common[opencv] >= 1.5.0 # required for pixtral test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.4 # required for model evaluation test

# TODO: Add this after fully implementing llava(mantis)
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test

# quantization
bitsandbytes>=0.44.0
buildkite-test-collector==0.1.9
16 changes: 13 additions & 3 deletions tests/distributed/test_pipeline_parallel.py
@@ -247,9 +247,19 @@ def _compare_tp(
*,
method: Literal["generate", "encode"],
):
tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
multi_node_only, trust_remote_code, tokenizer_mode, \
load_format, hf_overrides = test_options
(
tp_size,
pp_size,
eager_mode,
chunked_prefill,
) = parallel_setup
(
multi_node_only,
trust_remote_code,
tokenizer_mode,
load_format,
hf_overrides,
) = test_options

if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
39 changes: 22 additions & 17 deletions tests/kernels/test_causal_conv1d.py
@@ -149,13 +149,14 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
itype):
has_initial_state, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
@@ -167,11 +168,18 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,

weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
if has_initial_state:
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
has_initial_state_tensor = torch.ones(batch,
dtype=torch.bool,
device=x.device)
else:
initial_states = None
has_initial_state_tensor = None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
Expand All @@ -183,31 +191,28 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
has_initial_state=has_initial_state_tensor)
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=True,
activation=activation)
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
if has_initial_state:
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
has_initial_state=has_initial_state_tensor)


@pytest.mark.parametrize("itype", [torch.bfloat16])
Expand Down
30 changes: 22 additions & 8 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -34,7 +34,7 @@
"dtype": "half",
"max_tokens": 5,
"tensor_parallel_size": 2,
"model_kwargs": {"device_map": "auto"},
"hf_model_kwargs": {"device_map": "auto"},
"image_size_factors": [(.25, 0.5, 1.0)],
"distributed_executor_backend": (
"ray",
@@ -108,7 +108,7 @@
"cherry_blossom": "What is in the picture?",
}),
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
@@ -151,7 +151,7 @@
"cherry_blossom": "<vlm_image>Please infer the season with reason.",
}),
multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"),
postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"),
stop_str=["<|im_end|>"],
image_size_factors=[(0.10, 0.15)],
max_tokens=64,
@@ -177,7 +177,7 @@
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
# For chameleon, we only compare the sequences
@@ -281,7 +281,7 @@
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
num_video_frames=16,
max_model_len=16384,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values_videos"
),
auto_cls=AutoModelForVision2Seq,
@@ -306,6 +306,20 @@
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
),
"mantis": VLMTestInfo(
models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
max_model_len=4096,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501
get_stop_token_ids=lambda tok: [128009],
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
patch_hf_runner=model_utils.mantis_patch_hf_runner,
),
"minicpmv_25": VLMTestInfo(
models=["openbmb/MiniCPM-Llama3-V-2_5"],
test_type=VLMTestType.IMAGE,
@@ -342,7 +356,7 @@
# max_num_seqs=2,
# task="generate",
# # use eager mode for hf runner since phi3v didn't work with flash_attn
# model_kwargs={"_attn_implementation": "eager"},
# hf_model_kwargs={"_attn_implementation": "eager"},
# use_tokenizer_eos=True,
# vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
# num_logprobs=10,
@@ -373,7 +387,7 @@
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
@@ -438,7 +452,7 @@
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=16384,
max_num_seqs=2,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
auto_cls=AutoModelForVision2Seq,