diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index af8397c235f48..a3f0464e79675 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -50,3 +50,49 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): with pytest.raises(ValueError, match="cannot be larger than"): get_output_from_llm_generator(test_llm_generator, prompts, sampling_params) + + +@pytest.mark.parametrize("common_llm_kwargs", + [{ + "model": "meta-llama/Llama-2-7b-chat-hf", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "enable_chunked_prefill": "True", + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "tensor_parallel_size": 2, + "speculative_draft_tensor_parallel_size": 2, + }, + { + "tensor_parallel_size": 4, + "speculative_draft_tensor_parallel_size": 4, + }, + { + "tensor_parallel_size": 8, + "speculative_draft_tensor_parallel_size": 8, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one( + test_llm_generator): + """Verify that speculative decoding fails if chunked prefill is enabled for + draft model with tensor parallelism of more than 1. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, match="with tensor parallel size 1"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) diff --git a/vllm/config.py b/vllm/config.py index 9721925987cab..bed58fcecb5cb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1388,6 +1388,23 @@ def maybe_create_spec_config( "Chunked prefill and hidden-state based draft models are " "not compatible.") + speculative_draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( + target_parallel_config, + speculative_draft_tensor_parallel_size, + draft_hf_config + ) + + if (enable_chunked_prefill and \ + speculative_draft_tensor_parallel_size != 1): + # TODO - Investigate why the error reported in + # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258 + # is happening and re-enable it. + raise ValueError( + "Chunked prefill and speculative decoding can be enabled " + "simultaneously only for draft models with tensor " + "parallel size 1.") + draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( speculative_max_model_len, @@ -1466,15 +1483,16 @@ def _maybe_override_draft_max_model_len( ) @staticmethod - def create_draft_parallel_config( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig, - ) -> ParallelConfig: - """Create a parallel config for use by the draft worker. - - This is mostly a copy of the target parallel config, except the tp_size. + def _verify_and_get_draft_model_tensor_parallel_size( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig) -> int: + """ + Verifies and adjusts the tensor parallel size for a draft model + specified using speculative_draft_tensor_parallel_size. """ + # If speculative_draft_tensor_parallel_size is unset then set it + # appropriately else verify that it is set correctly. if speculative_draft_tensor_parallel_size is None: if draft_hf_config.model_type == "mlp_speculator": speculative_draft_tensor_parallel_size = 1 @@ -1490,7 +1508,18 @@ def create_draft_parallel_config( raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " f"other value than 1 or target model tensor_parallel_size") + return speculative_draft_tensor_parallel_size + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: int, + draft_hf_config: PretrainedConfig, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. + """ draft_parallel_config = ParallelConfig( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size,