Skip to content

Commit

Permalink
[Feature] [Spec decode]: Combine chunked prefill with speculative dec…
Browse files Browse the repository at this point in the history
…oding (vllm-project#9291)

Signed-off-by: NickLucche <[email protected]>
  • Loading branch information
NickLucche authored Nov 7, 2024
1 parent 0eef957 commit aa35e73
Show file tree
Hide file tree
Showing 17 changed files with 477 additions and 147 deletions.
34 changes: 0 additions & 34 deletions tests/spec_decode/e2e/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,6 @@
from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)


@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
Expand Down
105 changes: 102 additions & 3 deletions tests/spec_decode/e2e/test_multistep_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
{
# Verify the detokenizer assertions in the test work when spec
Expand Down Expand Up @@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("max_output_len", [
Expand Down Expand Up @@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [1])
Expand Down Expand Up @@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [32])
Expand Down Expand Up @@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [2])
Expand Down Expand Up @@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
Expand Down Expand Up @@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize("batch_size", [8])
Expand Down Expand Up @@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
])
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
Expand Down Expand Up @@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
])
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
} for k in [1, 2, 3]])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",
Expand Down
36 changes: 35 additions & 1 deletion tests/spec_decode/e2e/test_ngram_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,33 @@
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
if prefill_chunk_size > 0:
common_llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
common_llm_kwargs["enable_chunked_prefill"] = False
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
Expand Down Expand Up @@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}, {
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
Expand Down
9 changes: 7 additions & 2 deletions tests/spec_decode/test_ngram_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
Expand Down Expand Up @@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
def test_ngram_algo_correctness_for_batches_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batchs
For the scenario find candidate in all batches
"""

block_size = 32
Expand Down Expand Up @@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
block_size,
final_prompt_lens=final_prompt_lens)

# Normally drafter is run on decode requests only; here we check the output
# of the ngram worker as it is the sole proposer that has no forward.
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
Expand Down
31 changes: 27 additions & 4 deletions tests/spec_decode/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
@pytest.mark.parametrize('mixed_propose_len', [True])
@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('prefill_chunking', [False, True])
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
mixed_propose_len: bool, device: str) -> None:
mixed_propose_len: bool, device: str,
prefill_chunking: bool) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score.
We test for both queries with the same propose length and different
propose length.
propose length, as well as mixed prefill-decode batches.
"""
seed = 0
block_size = 32
Expand All @@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
if not mixed_propose_len:
propose_lens = [max_propose_len] * batch_size
else:
non_zero_cnt = random.randint(0, batch_size)
# There must be at least 1 decode request, otherwise
# we have nothing to score (`_run_no_spec`).
non_zero_cnt = random.randint(1, batch_size)
propose_lens = [max_propose_len
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
random.shuffle(propose_lens)

proposals = create_proposal(propose_lens, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
max_propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)

if mixed_propose_len and prefill_chunking and (n_prefills :=
batch_size - non_zero_cnt):
prefill, _, _ = create_batch(n_prefills,
None,
prefill_chunk_size=4,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
seq_ids=list(
range(batch_size,
batch_size + n_prefills)))
# re-order to guarantee prefill|decode order
target_group_metadatalist = [
seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
if p > 0
]
seq_group_metadatalist = prefill + target_group_metadatalist
propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]

proposals = create_proposal(propose_lens, vocab_size, device)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=max_propose_len)

Expand Down
Loading

0 comments on commit aa35e73

Please sign in to comment.