diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 629074188a6c1..af8397c235f48 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -5,40 +5,6 @@ from .conftest import get_output_from_llm_generator -@pytest.mark.parametrize("common_llm_kwargs", [{ - "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, -}]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "enable_chunked_prefill": True, - }, -]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_chunked_prefill(test_llm_generator): - """Verify that speculative decoding with chunked prefill fails. - """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - with pytest.raises(ValueError, - match="Speculative decoding and chunked prefill"): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) - - @pytest.mark.parametrize("common_llm_kwargs", [{ "model": "meta-llama/Llama-2-7b-chat-hf", "speculative_model": "JackFram/llama-68m", diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 5f240d42d9e09..a13cca41f99e5 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -62,6 +62,16 @@ { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + }, + { + # Chunked prefill enabled with small value + # to make sure we get mixed batches. + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }, { # Verify the detokenizer assertions in the test work when spec @@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4, }, ]) @pytest.mark.parametrize( @@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }, ]) @pytest.mark.parametrize( @@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }, ]) @pytest.mark.parametrize("max_output_len", [ @@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }, ]) @pytest.mark.parametrize("batch_size", [1]) @@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }, ]) @pytest.mark.parametrize("batch_size", [32]) @@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }, ]) @pytest.mark.parametrize( @@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }, ]) @pytest.mark.parametrize("batch_size", [2]) @@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. "speculative_max_model_len": 32, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4, + "speculative_max_model_len": 32, }, ]) @pytest.mark.parametrize("batch_size", [8]) @@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs, "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, "speculative_disable_by_batch_size": 2, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_disable_by_batch_size": 2, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4, }, ]) @pytest.mark.parametrize("batch_size", [8]) @@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": k, + "enable_chunked_prefill": False, } # Try a range of common k, as well as large speculation. for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] - ]) + ] + [{ + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": k, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4, + } for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize( "output_len", @@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": k, - "spec_decoding_acceptance_method": "typical_acceptance_sampler" + "spec_decoding_acceptance_method": "typical_acceptance_sampler", + "enable_chunked_prefill": False } # Try a range of common k. for k in [1, 2, 3] - ]) + ] + [{ + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": k, + "spec_decoding_acceptance_method": "typical_acceptance_sampler", + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 + } for k in [1, 2, 3]]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize( "output_len", diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 31bedad480283..e53d169a8fcc3 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -50,18 +50,33 @@ "num_speculative_tokens": 5, "ngram_prompt_lookup_max": 3, }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, ]) @pytest.mark.parametrize("output_len", [ 256, ]) @pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) @pytest.mark.parametrize("seed", [1]) def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): + prefill_chunk_size: int, seed: int): """Verify greedy equality on a tiny model with different batch size.""" + if prefill_chunk_size > 0: + common_llm_kwargs.update( + **{ + "enable_chunked_prefill": True, + "max_num_batched_tokens": prefill_chunk_size, + "max_num_seqs": prefill_chunk_size + }) + else: + common_llm_kwargs["enable_chunked_prefill"] = False run_equality_correctness_test(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, @@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, "speculative_model": "[ngram]", "num_speculative_tokens": 5, "ngram_prompt_lookup_max": 3, + "enable_chunked_prefill": False, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "enable_chunked_prefill": True, + "speculative_disable_mqa_scorer": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }, ]) @pytest.mark.parametrize( @@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, "num_speculative_tokens": 5, "ngram_prompt_lookup_max": 3, "speculative_disable_by_batch_size": 4 + }, { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "speculative_disable_by_batch_size": 4, + "enable_chunked_prefill": True, + "speculative_disable_mqa_scorer": True, + "max_num_batched_tokens": 4, + "max_num_seqs": 4 }]) @pytest.mark.parametrize("batch_size", [1, 5]) @pytest.mark.parametrize( diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index 3995f87898afb..f66e957186604 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): num_gpu_blocks, block_size, final_prompt_lens=final_prompt_lens) - + for sg in seq_group_metadata_list: + sg.is_prompt = False proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, @@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): def test_ngram_algo_correctness_for_batches_match_all(): """Verify our ngram algo find the right candidate in the prompt - For the scenario find candidate in all batchs + For the scenario find candidate in all batches """ block_size = 32 @@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all(): block_size, final_prompt_lens=final_prompt_lens) + # Normally drafter is run on decode requests only; here we check the output + # of the ngram worker as it is the sole proposer that has no forward. + for sg in seq_group_metadata_list: + sg.is_prompt = False proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index e579c8b38db91..0b1509d8b7785 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores, @pytest.mark.parametrize('max_propose_len', [1, 3, 5]) @pytest.mark.parametrize('mixed_propose_len', [True]) @pytest.mark.parametrize('device', ['cuda']) +@pytest.mark.parametrize('prefill_chunking', [False, True]) def test_scorer(model_name: str, batch_size: int, max_propose_len: int, - mixed_propose_len: bool, device: str) -> None: + mixed_propose_len: bool, device: str, + prefill_chunking: bool) -> None: """ Compare the batch expansion scorer and mqa scorer return the same score. We test for both queries with the same propose length and different - propose length. + propose length, as well as mixed prefill-decode batches. """ seed = 0 block_size = 32 @@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int, if not mixed_propose_len: propose_lens = [max_propose_len] * batch_size else: - non_zero_cnt = random.randint(0, batch_size) + # There must be at least 1 decode request, otherwise + # we have nothing to score (`_run_no_spec`). + non_zero_cnt = random.randint(1, batch_size) propose_lens = [max_propose_len ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt) random.shuffle(propose_lens) - proposals = create_proposal(propose_lens, vocab_size, device) seq_group_metadatalist, _, _ = create_batch(batch_size, max_propose_len, block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + if mixed_propose_len and prefill_chunking and (n_prefills := + batch_size - non_zero_cnt): + prefill, _, _ = create_batch(n_prefills, + None, + prefill_chunk_size=4, + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + seq_ids=list( + range(batch_size, + batch_size + n_prefills))) + # re-order to guarantee prefill|decode order + target_group_metadatalist = [ + seq_group_metadatalist[i] for i, p in enumerate(propose_lens) + if p > 0 + ] + seq_group_metadatalist = prefill + target_group_metadatalist + propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0] + + proposals = create_proposal(propose_lens, vocab_size, device) requests = ExecuteModelRequest(seq_group_metadatalist, num_lookahead_slots=max_propose_len) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index e0b7b7d47f1f1..8df143104c279 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sequence import ExecuteModelRequest, SequenceOutput +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) @@ -819,3 +820,84 @@ def test_handle_finished_requests(): # and 'request-3' are removed from seq_with_bonus_token_in_last_step. assert worker._seq_with_bonus_token_in_last_step == \ {4,5,10} + + +@pytest.mark.parametrize('k', [3]) +@pytest.mark.parametrize('batch_size', [2, 32]) +@pytest.mark.parametrize("batch_composition", + ["prefill_only", "decode_only", "mixed"]) +@torch.inference_mode() +def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str): + """ + Verify SpecDecodeWorker calls match the expected flow. + """ + vocab_size = 32_000 + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + mock_spec_decode_sampler("rejection_sampler"), + disable_logprobs=False, + metrics_collector=metrics_collector) + exception_secret = 'artificial stop' + worker.scorer = mock_worker(BatchExpansionTop1Scorer) + worker.scorer.score_proposals.side_effect = ValueError(exception_secret) + + # Create batch with combination of terminal/non-terminal prefill chunks + # and decodes (different seq_ids). + decodes, _, _ = create_batch(batch_size, k) + # Pre-chunking here, get 'batch_size' chunks. + prefill, _, _ = create_batch(batch_size, + k, + prefill_chunk_size=4, + seq_ids=list(range(batch_size, + batch_size * 2))) + + if batch_composition == "prefill_only": + n_prefills = batch_size + elif batch_composition == "decode_only": + n_prefills = 0 + else: + n_prefills = random.randint(1, batch_size - 1) + n_decodes = batch_size - n_prefills + + prefill = random.sample(prefill, n_prefills) + decodes = random.sample(decodes, n_decodes) + target_group_metadata_list = prefill + decodes + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=target_group_metadata_list, + num_lookahead_slots=k) + + target_token_ids = torch.randint(low=0, + high=vocab_size, + size=(1, batch_size * (k + 1)), + dtype=torch.int64, + device='cuda') + target_token_probs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + target_output = create_sampler_output_list(target_token_ids, + target_token_probs, + target_token_logprobs) + + target_worker.execute_model.return_value = [target_output[0]] + + if not len(decodes): + worker.execute_model(execute_model_req=execute_model_req) + # no spec run (prefill only) + draft_worker.execute_model.assert_called_once_with(execute_model_req) + target_worker.execute_model.assert_called_once_with(execute_model_req) + else: + # Decode-only run OR mixed batch, scorer call fails (it's mocked) + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + # but first draft still counted + assert draft_worker.get_spec_proposals.call_count == 1 diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index e5cb0530f9961..a4bfa6b2f384b 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts( return seq_grou_metadata_list +def create_chunked_seq_group_metadata_from_prompt( + prompt: List[int], + num_gpu_blocks: int, + chunk_size: int, + block_size: int, + seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]: + + if seq_id is None: + seq_id = 0 + + free_gpu_blocks = list(range(num_gpu_blocks)) + + block_allocations = [ + free_gpu_blocks.pop() + for _ in range(round_up_to_next_block(len(prompt), block_size)) + ] + + seq_group_metadata_list = [] + for i, idx in enumerate(range(0, len(prompt), chunk_size)): + chunk_ids = prompt[idx:idx + chunk_size] + data = SequenceData.from_seqs(prompt) + data.update_num_computed_tokens(idx) + seq_data = {i: data} + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=str(seq_id), + is_prompt=True, + do_sample=idx + chunk_size >= len(prompt), # terminal chunk + seq_data=seq_data, + sampling_params=SamplingParams(temperature=0.0), + block_tables={i: block_allocations}, + token_chunk_size=len(chunk_ids))) + return seq_group_metadata_list + + def assert_logprobs_dict_allclose( actual_logprobs: List[Dict[int, Logprob]], expected_logprobs: List[Dict[int, Logprob]]) -> None: @@ -198,7 +233,8 @@ def create_batch(batch_size, prev_output_token_len: int = 10, seq_ids: Optional[List[int]] = None, num_gpu_blocks: Optional[int] = None, - block_size: Optional[int] = None): + block_size: Optional[int] = None, + prefill_chunk_size: Optional[int] = None): if block_size is None: block_size = 8 @@ -213,15 +249,28 @@ def create_batch(batch_size, prompt_lens = prompt_len prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens] - prev_output_tokens = [[ - next(iterator) for _ in range(prev_output_token_len) - ] for _ in range(batch_size)] - final_prompt_lens = [ - len(prompt) + len(prev_output_token) + k + 1 - for prompt, prev_output_token in zip(prompts, prev_output_tokens) - ] - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, final_prompt_lens, - prev_output_tokens, seq_ids) + if prefill_chunk_size: + # Create a batch of chunked prompts. + if not seq_ids: + seq_ids = list(range(len(prompts))) + seq_group_metadata_list = [] + for p, sid in zip(prompts, seq_ids): + seq_group_metadata_list += \ + create_chunked_seq_group_metadata_from_prompt( + p, num_gpu_blocks, prefill_chunk_size, block_size, sid) + seq_group_metadata_list = seq_group_metadata_list[:batch_size] + prev_output_tokens = [] + else: + prev_output_tokens = [[ + next(iterator) for _ in range(prev_output_token_len) + ] for _ in range(batch_size)] + final_prompt_lens = [ + len(prompt) + len(prev_output_token) + k + 1 + for prompt, prev_output_token in zip(prompts, prev_output_tokens) + ] + + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, final_prompt_lens, + prev_output_tokens, seq_ids) return seq_group_metadata_list, prompts, prev_output_tokens diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 26da0d89def29..314822b695722 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -276,7 +276,11 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=self.query_start_loc[self.num_prefills:] + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) if self.query_start_loc is not None else None, seq_start_loc=self.seq_start_loc[self.num_prefills:] if self.seq_start_loc is not None else None, @@ -903,7 +907,9 @@ def unified_flash_attention( # Decoding run. # Use flash_attn_varlen_func kernel for speculative decoding # because different queries might have different lengths. + assert decode_meta.max_decode_query_len is not None + # use only for actual varlen decoding if decode_meta.max_decode_query_len > 1: assert attn_type == AttentionType.DECODER, ( "Only decoder-only models support max_decode_query_len > 1") @@ -949,8 +955,6 @@ def unified_flash_attention( assert prefill_output is not None return prefill_output.view(num_prefill_query_tokens, hidden_size) - # Chunked prefill does not work with speculative decoding. - # Therefore, the query length for decode should be 1 in chunked prefill. assert decode_meta is not None decode_output = decode_output.squeeze(1) output = torch.cat([prefill_output, decode_output], dim=0) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b129d0d992f2f..2bae370eaa90f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -192,6 +192,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, ) + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata def advance_step(self, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 4725413baade7..83d03606524dc 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -272,6 +272,13 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) + + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata diff --git a/vllm/config.py b/vllm/config.py index e844a46bf06e6..9721925987cab 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -192,7 +192,6 @@ def __init__( self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init - self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision, rope_scaling, rope_theta, config_format) @@ -1317,13 +1316,6 @@ def maybe_create_spec_config( "speculative decoding is > 1, but got " f"{speculative_disable_by_batch_size=}") - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if enable_chunked_prefill: - raise ValueError( - "Speculative decoding and chunked prefill are " - f"currently mutually exclusive ({enable_chunked_prefill=}).") - # TODO: The user should be able to specify revision/max model len # for the draft model. It is not currently supported. draft_revision = None @@ -1390,6 +1382,12 @@ def maybe_create_spec_config( f"num_speculative_tokens={n_predict}, but " f"{num_speculative_tokens=} was provided.") + if enable_chunked_prefill and draft_hf_config.model_type in ( + "medusa", "mlp_speculator", "eagle"): + raise ValueError( + "Chunked prefill and hidden-state based draft models are " + "not compatible.") + draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( speculative_max_model_len, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index e56d5cddce424..af4671ec29be9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1147,6 +1147,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) + # Put prefills first due to Attention backend ordering assumption. return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + running_scheduled.prefill_seq_groups + diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 223790806ab18..7a6ebb430541f 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -134,10 +134,12 @@ def process_outputs(self, sample for sample in samples if sample.output_token != VLLM_INVALID_TOKEN_ID ] - assert valid_samples - self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) + # When both spec-decode and pre-fill chunking are enabled, we + # don't have guaranteed samples here (e.g. all -1s). + if valid_samples: + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) def _process_decode_and_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 59e71cc8deb48..6a7929d9d8f9c 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -90,7 +90,7 @@ def score_proposals( else: # Batch has a mix of spec decode enabled and disabled seq groups contracted = self._contract_batch( - contracted_bs=len(execute_model_req.seq_group_metadata_list), + execute_model_req.seq_group_metadata_list, target_sampler_output=target_sampler_output, proposals=proposals, num_scoring_tokens=num_scoring_tokens, @@ -126,7 +126,7 @@ def _expand_batch( split_batch_by_proposal_len( seq_group_metadata_list, proposal_lens_list) - target_seq_group_metadata_list = self._create_scoring_model_input( + spec_expanded_seqs = self._create_scoring_model_input( seq_group_metadata_list=spec_seqs, proposal_token_ids=proposal_token_ids_list, # NOTE: We determine the seq ids in the expanded batch using the @@ -135,16 +135,19 @@ def _expand_batch( seq_ids=get_all_seq_ids(seq_group_metadata_list)), ) - num_scoring_tokens = len(target_seq_group_metadata_list) - target_seq_group_metadata_list.extend(non_spec_seqs) + num_scoring_tokens = len(spec_expanded_seqs) + # Batch speculative and non-speculative (e.g. chunked prefill) requests + # but make sure order is prefill|decode due to backend requirement. + target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs return (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) def _contract_batch( - self, contracted_bs: int, target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, num_scoring_tokens: int, - non_spec_indices: List[int], spec_indices: List[int], k: int + self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata], + target_sampler_output: SamplerOutput, proposals: SpeculativeProposals, + num_scoring_tokens: int, non_spec_indices: List[int], + spec_indices: List[int], k: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Contract the expanded batch back into its original size. @@ -154,6 +157,7 @@ def _contract_batch( contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. """ + contracted_bs = len(contracted_seq_group_metadata_list) (target_token_ids, target_probs, target_logprobs, target_hidden_states, non_spec_target_token_ids, non_spec_target_probs, non_spec_target_logprobs, @@ -166,8 +170,8 @@ def _contract_batch( # The number of tokens in the expanded batch used for speculation is # equal to the total expanded batch size minus the number of samples for - # non-speculative sequences. - non_spec_expanded_bs = len(non_spec_target_token_ids) + # non-speculative sequences, prefill chunks with no out tokens included + non_spec_expanded_bs = len(non_spec_indices) spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1) @@ -191,7 +195,12 @@ def _contract_batch( else: all_hidden_states = None - if non_spec_indices: + # Rule out prefills that produce no tokens. + non_spec_indices = [ + idx for idx in non_spec_indices + if contracted_seq_group_metadata_list[idx].do_sample + ] + if len(non_spec_indices): all_tokens[non_spec_indices, :1] = \ non_spec_target_token_ids.unsqueeze(1) all_probs[non_spec_indices, :1, :] = \ @@ -290,9 +299,6 @@ def _create_target_seq_group_metadata( This function creates K+1 target SequenceGroupMetadata to take advantage of the bonus token. """ - assert not input_seq_group_metadata.is_prompt, ( - "Speculating on " - "prompts not yet supported") assert len(input_seq_group_metadata.seq_data) == 1, ( "Beam search " "not supported in speculative decoding") @@ -390,27 +396,22 @@ def _split_scoring_output( # and non spec sequences) and should be removed in the future. It can be # done by supporting per-sequence proposal lens. # - # First samples are from speculative scoring, latter samples are non- - # speculative samples. - split_sizes = (num_scoring_tokens, - sampler_output.sampled_token_ids.numel() - - num_scoring_tokens) - (spec_probs, non_spec_probs - ) = sampler_output.sampled_token_probs.split(split_sizes) - (spec_sampled_tokens, non_spec_sampled_tokens + # First samples are non-speculative, latter samples are from speculative + # scoring (prefill|decode order). + split_sizes = (sampler_output.sampled_token_ids.numel() - + num_scoring_tokens, num_scoring_tokens) + (non_spec_probs, + spec_probs) = sampler_output.sampled_token_probs.split(split_sizes) + (non_spec_sampled_tokens, spec_sampled_tokens ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) - ( - spec_logprobs, - non_spec_logprobs, - ) = sampler_output.logprobs.split(split_sizes) + (non_spec_logprobs, + spec_logprobs) = sampler_output.logprobs.split(split_sizes) if sampler_output.hidden_states is not None: - ( - spec_hidden_states, - non_spec_hidden_states, - ) = sampler_output.hidden_states.split(split_sizes) + (non_spec_hidden_states, spec_hidden_states + ) = sampler_output.hidden_states.split(split_sizes) else: - spec_hidden_states, non_spec_hidden_states = None, None + non_spec_hidden_states, spec_hidden_states = None, None return (spec_sampled_tokens, spec_probs, spec_logprobs, spec_hidden_states, non_spec_sampled_tokens, non_spec_probs, diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py index f35a8a0ab8be3..cbf793e2043e3 100644 --- a/vllm/spec_decode/mqa_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -21,6 +21,11 @@ def score_proposals( all_proposal_lengths = proposals.proposal_lens.tolist() for i, seq_group_metadata in enumerate( execute_model_req.seq_group_metadata_list): + if all_proposal_lengths[i] == 0: + # Keep prompt seqs untouched (keep computed_tokens for chunks). + target_seq_group_metadata_list.append(seq_group_metadata) + continue + seq_data_dict = seq_group_metadata.seq_data assert len(seq_data_dict) == 1 seq_id = next(iter(seq_data_dict.keys())) @@ -40,8 +45,7 @@ def score_proposals( new_seq_data.update_num_computed_tokens( len(prompt_token_ids) + len(output_token_ids) - 1) - # Ensure that the new sequence has at least one token - # because we only use mqa scorer in the decoding stage. + # Ensure that the new decode sequence has at least one token. assert len(output_token_ids) >= 1 new_seq_data_dict = {target_seq_id: new_seq_data} @@ -54,7 +58,6 @@ def score_proposals( target_seq_id: seq_group_metadata.block_tables[seq_id], }, lora_request=None, - token_chunk_size=1, ) target_seq_group_metadata_list.append(new_seq_group_metadata) @@ -77,6 +80,7 @@ def score_proposals( all_probs = target_probs.reshape(bs, k + 1, self._vocab_size) all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size) else: + # We either have decodes with different lens or prefill+decodes. all_tokens = target_token_ids.new_full(size=(bs, k + 1), fill_value=-1) all_probs = target_probs.new_zeros(*all_tokens.shape, @@ -85,15 +89,18 @@ def score_proposals( fill_value=-float("inf")) target_token_ids = target_token_ids.flatten() start_loc = 0 - for i, proposed_len in enumerate(all_proposal_lengths): - output_len = proposed_len + 1 - end_loc = start_loc + output_len - all_tokens[ - i, :output_len] = target_token_ids[start_loc:end_loc] - all_probs[i, :output_len] = target_probs[start_loc:end_loc] - all_logprobs[ - i, :output_len] = target_logprobs[start_loc:end_loc] - start_loc = end_loc + for i, (proposed_len, seq_meta) in enumerate( + zip(all_proposal_lengths, target_seq_group_metadata_list)): + # Skip chunks with no output tokens. + if seq_meta.do_sample: + output_len = proposed_len + 1 + end_loc = start_loc + output_len + all_tokens[ + i, :output_len] = target_token_ids[start_loc:end_loc] + all_probs[i, :output_len] = target_probs[start_loc:end_loc] + all_logprobs[ + i, :output_len] = target_logprobs[start_loc:end_loc] + start_loc = end_loc hidden_states = None if target_sampler_output.hidden_states is not None: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index eb3c2e88e668c..b57742c2ebfdd 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -418,7 +418,10 @@ def execute_model( # none of the requests in the batch have spec decoding enabled. # In any of these cases, the proposer and scorer workers # are called normally. - no_spec = num_lookahead_slots == 0 or disable_all_speculation or all( + # We expect `num_speculative_tokens` to be None for prefills. + no_spec = all( + sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list + ) or num_lookahead_slots == 0 or disable_all_speculation or all( sgm.num_speculative_tokens == 0 for sgm in execute_model_req.seq_group_metadata_list) @@ -484,7 +487,7 @@ def _maybe_disable_speculative_tokens( def _serialize_sampler_output_no_logprobs( self, execute_model_req: ExecuteModelRequest, - sampler_output: SamplerOutput) -> SamplerOutput: + sampler_output: SamplerOutput) -> List[SamplerOutput]: """ Creates and returns a `SamplerOutput` with only the token IDs being serialized to CPU and populated in `CompletionSequenceGroupOutput`. @@ -514,41 +517,56 @@ def _serialize_sampler_output_no_logprobs( if any(seq_output_prompt_logprobs) else \ sampler_output.sampled_token_ids).tolist() - seq_data_entries = ( + seq_data_entries = [ (seq_id, seq_data) for sg in \ execute_model_req.seq_group_metadata_list \ for seq_id, seq_data in sg.seq_data.items() - ) + if sg.do_sample # ignore empty token sequences + ] completion_seq_group_output_list: List[ CompletionSequenceGroupOutput] = [] - for index, ((seq_id, seq_data), needs_prompt_logprobs) in \ - enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)): - if needs_prompt_logprobs: - prompt_token_ids = seq_data.get_prompt_token_ids() - prompt_logprobs = [ - create_logprobs_output( - token_id=p_token_id, + output_index = 0 + # Make sure the non-terminal prefill chunks are still aligned with + # their own empty output. + for seq_group_meta in execute_model_req.seq_group_metadata_list: + # Since we can get chunks here, we dont always have a sampled token + # (only on last chunk) but we still have to provide an output. + if not seq_group_meta.do_sample: + completion_seq_group_output_list.append( + CompletionSequenceGroupOutput(samples=[], + prompt_logprobs=None)) + else: + # Sequence with output. + seq_id, seq_data = seq_data_entries[output_index] + needs_prompt_logprobs = seq_output_prompt_logprobs[ + output_index] + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) + # no prompt logprobs for the first token + for p_token_id in prompt_token_ids[1:] + ] + else: + prompt_logprobs = None + completion_seq_group_output_list.append( + create_sequence_group_output( + token_id=sampled_token_ids_list[output_index][0], token_id_logprob_rank=-1, token_id_logprob=0.0, + seq_id=seq_id, topk_token_ids=[], topk_logprobs=[], - ) - # no prompt logprobs for the first token - for p_token_id in prompt_token_ids[1:] - ] - else: - prompt_logprobs = None - - completion_seq_group_output_list.append( - create_sequence_group_output( - token_id=sampled_token_ids_list[index][0], - token_id_logprob_rank=-1, - token_id_logprob=0.0, - seq_id=seq_id, - topk_token_ids=[], - topk_logprobs=[], - prompt_logprobs=prompt_logprobs)) - return SamplerOutput(outputs=completion_seq_group_output_list) + prompt_logprobs=prompt_logprobs)) + output_index += 1 + + return [SamplerOutput(outputs=completion_seq_group_output_list)] @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec(self, execute_model_req: ExecuteModelRequest, @@ -568,6 +586,9 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, hidden_states = sampler_output.hidden_states if hidden_states is not None: # remove hidden_states for prompt tokens + # TODO Enable `return_hidden_states`: prefill chunks hidden states + # are pruned by the logits processor. Also, they should be arranged + # back into full-prefill latent. Address it to enable MLPSpeculator. if any(seq.is_prompt for seq in execute_model_req.seq_group_metadata_list): hidden_states = hidden_states[ @@ -593,14 +614,14 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) if self._disable_logprobs else - sampler_output) + [sampler_output]) # Clear device tensors from sampler output. This reduces communication # overhead when the engine runs in a different process than the workers. sampler_output.sampled_token_probs = None sampler_output.sampled_token_ids = None sampler_output.logprobs = None - return [sampler_output_to_return] + return sampler_output_to_return def _run_non_driver_rank(self) -> bool: """Run proposer and verifier model in non-driver workers. This is used @@ -644,9 +665,15 @@ def _run_speculative_decoding_step( This invokes the proposer worker to get k speculative tokens for each sequence, then scores each speculative token using the scoring worker. + When `enable_chunked_prefill` is set, scorer will batch decodes and + prefills, while proposer will sync its KV-cache by running an extra + forward on prefills. + Returns a list of SamplerOutput, each containing a single token per sequence. """ + # With prefill chunking, expect requests to have prompts first + # so that backend gets prefill|decode. assert num_lookahead_slots == execute_model_req.num_lookahead_slots # Pass last hidden states from target model to proposer @@ -671,6 +698,25 @@ def _run_speculative_decoding_step( proposals, ) + _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len( + execute_model_req.seq_group_metadata_list, proposals.proposal_lens) + # With prefill chunking enabled, `non_spec_seqs` contains prefills too: + # discard decodes that have already been processed by proposer. + non_spec_indices = [ + idx for idx in non_spec_indices + if execute_model_req.seq_group_metadata_list[idx].is_prompt + ] + if len(non_spec_indices): + all_hidden_states = proposal_scores.hidden_states + # TODO fix `return_hidden_states`, same as in `_run_no_spec` + if all_hidden_states is not None: + prefill_hidden_states = all_hidden_states[non_spec_indices] + execute_model_req.previous_hidden_states = \ + prepare_prefill_hidden_states(prefill_hidden_states) + # Sync proposer KV cache for prefills. + prefill_req = execute_model_req.clone(non_spec_seqs) + self.proposer_worker.execute_model(prefill_req) + with Timer() as verification_timer: accepted_token_ids, target_logprobs = self._verify_tokens( execute_model_req.seq_group_metadata_list, proposal_scores, @@ -769,7 +815,6 @@ def _verify_tokens( self.previous_hidden_states = HiddenStates( hidden_states, seq_group_metadata_list, second_last_token_hidden_states) - return accepted_token_ids, logprobs def _create_output_sampler_list( @@ -819,6 +864,8 @@ def _create_output_sampler_list( accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() # Construct the output on a per-step, per-sequence basis. + # Non-terminal prefill chunks will end up here as rows with just -1s + # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] sampler_output_list: List[SamplerOutput] = [] for step_index in range(num_steps): if all(token_id == -1 @@ -861,7 +908,6 @@ def _create_output_sampler_list( # This is periodic because the rejection sampler emits metrics # periodically. self._maybe_log_stage_times(*stage_times) - return sampler_output_list def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index f6a52a516075d..5a7999a258b2d 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -109,7 +109,6 @@ def get_spec_proposals( proposal_probs=proposal_probs, proposal_lens=proposal_lens, no_proposals=maybe_sampler_output is None) - return proposals def _split_by_proposal_len( @@ -127,9 +126,10 @@ def _split_by_proposal_len( nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] nonzero_proposal_len_indices: List[int] = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): - # The speculative decoding for this request has been disabled - # (e.g. due to high traffic). - if seq_group_metadata.num_speculative_tokens == 0: + # The speculative decoding for this request has either been disabled + # (e.g. due to high traffic) or this is a prompt request. + if (seq_group_metadata.is_prompt + or seq_group_metadata.num_speculative_tokens == 0): proposal_lens.append(0) continue