From 1c740dbbfbb45ccb795c3e8628bee2861e0b89d4 Mon Sep 17 00:00:00 2001
From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Date: Tue, 5 Nov 2024 17:38:20 -0600
Subject: [PATCH] Modifying the sampler to allow FORCED type of sampling.
 (#265)

---
 vllm/model_executor/layers/sampler.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 93d47c0d05a00..2b28598975c47 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -753,11 +753,11 @@ def get_pythonized_sample_results(
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
+        elif sampling_type == SamplingType.FORCED:
+            sample_results = _forced_sample(seq_groups, forced_samples)
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups,
                                                  beam_search_logprobs)
-        elif sampling_type == SamplingType.FORCED:
-            sample_results = _forced_sample(seq_groups, forced_samples)
         sample_results_dict.update(zip(seq_group_id, sample_results))
 
     return [
@@ -869,9 +869,6 @@ def _sample_with_torch(
                 # Store sampled tokens in output tensor.
                 sampled_token_ids_tensor[long_sample_indices] = \
                     multinomial_samples[sampling_type].to(torch.long)
-
-        elif sampling_type == SamplingType.BEAM:
-            beam_search_logprobs = logprobs[sample_indices]
         elif sampling_type == SamplingType.FORCED:
             if (seq_groups[0].sampling_params.future_context is not None):
                 forced_samples = torch.tensor([
@@ -884,6 +881,8 @@
             else:
                 forced_samples = torch.argmax(logprobs[long_sample_indices],
                                               dim=-1)
+        elif sampling_type == SamplingType.BEAM:
+            beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
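
Note: for readers of the patch, here is a minimal, self-contained sketch of
the branch this diff reorders ahead of BEAM. It only illustrates the intended
behavior: the FORCED path consumes pre-specified token ids from a per-group
future_context queue when one is present, and otherwise falls back to greedy
argmax over the logprobs, as the quoted hunks show. The SamplingType enum
values, the future_context field, and the argmax fallback mirror the diff;
everything else (the SeqGroup/SamplingParams stubs, the forced_sample helper,
and the body of the torch.tensor([...]) comprehension, which falls between
the quoted hunks) is hypothetical and not the fork's actual implementation.

    # Illustrative sketch only; not the code added by this patch.
    from dataclasses import dataclass, field
    from enum import Enum, auto
    from typing import List, Optional

    import torch


    class SamplingType(Enum):
        GREEDY = auto()
        RANDOM = auto()
        RANDOM_SEED = auto()
        FORCED = auto()
        BEAM = auto()


    @dataclass
    class SamplingParams:
        # Queue(s) of token ids to force; None means nothing is forced.
        # (Hypothetical stub; the real class lives in vllm/sampling_params.py.)
        future_context: Optional[List[List[int]]] = None


    @dataclass
    class SeqGroup:
        sampling_params: SamplingParams = field(default_factory=SamplingParams)


    def forced_sample(seq_groups: List[SeqGroup],
                      logprobs: torch.Tensor) -> torch.Tensor:
        """Return one token id per sequence group, preferring forced tokens."""
        # The diff gates on seq_groups[0], so this sketch does the same.
        if seq_groups[0].sampling_params.future_context is not None:
            # Pop the next pre-specified token id for each group. The exact
            # comprehension in the patch is not visible between the hunks,
            # so this body is a guess at its intent.
            return torch.tensor([
                g.sampling_params.future_context[0].pop(0) for g in seq_groups
            ])
        # No forced tokens available: degenerate to greedy argmax, as in the
        # else-branch of the FORCED case in the diff.
        return torch.argmax(logprobs, dim=-1)


    if __name__ == "__main__":
        logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)
        forced = [SeqGroup(SamplingParams(future_context=[[5]])),
                  SeqGroup(SamplingParams(future_context=[[7]]))]
        print(forced_sample(forced, logprobs))  # tensor([5, 7])
        free = [SeqGroup(), SeqGroup()]
        print(forced_sample(free, logprobs))    # argmax fallback per row

One design point the hunk order makes visible: FORCED must be dispatched
before BEAM in the elif chain of get_pythonized_sample_results so that forced
tokens take precedence, which is exactly what the first hunk rearranges.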