diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 3c2e3db0eed43..8544fd86a7010 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -140,9 +140,8 @@ def get_logprobs( topk_indices = topk_indices.to(torch.int32) # Concatenate with the token_ids - sampled_logprobs = logprobs[torch.arange(logprobs.size(0)), - token_ids].unsqueeze(-1) token_ids = token_ids.unsqueeze(-1) + sampled_logprobs = logprobs.gather(-1, token_ids) topk_indices = torch.cat([token_ids, topk_indices], dim=1) topk_logprobs = torch.cat([sampled_logprobs, topk_logprobs], dim=1)