Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ssarkar2 committed Aug 19, 2024
1 parent b0112c3 commit 79aae80
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 3 deletions.
61 changes: 60 additions & 1 deletion tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import ApplyToppTopkScalar, Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
Expand Down Expand Up @@ -700,3 +700,62 @@ def test_sampling_params(sampling_params: List[SamplingParams]):

assert tokens1[0] == tokens2[1]
assert tokens1[1] == tokens2[0]


def test_topk_topk_scalar():
obj1 = ApplyToppTopkScalar(2)
assert ApplyToppTopkScalar._padded_k == 0
x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0],
[10, 10, 9, 9, 9, 8, 5, 5, 5]])

retval1 = obj1(x, p=0.9, k=5)
ninf = -float("inf")
expected1 = torch.tensor([[9., 9., 8., 8., 8., 8., ninf, ninf, ninf],
[10., 10., 9., 9., 9., ninf, ninf, ninf, ninf]])
assert torch.all(retval1 == expected1).item()
assert ApplyToppTopkScalar._padded_k == 9

obj2 = ApplyToppTopkScalar(2)
assert obj2._padded_k == 9

x = torch.tensor([[2, 2, 9, 9, 2, 2, 1, 1, 1.0],
[10, 9, 9, 5, 9, 9, 5, 9, 10]])
retval2 = obj2(x, p=0.9, k=5)
expected2 = torch.tensor(
[[ninf, ninf, 9., 9., ninf, ninf, ninf, ninf, ninf],
[10., ninf, 9., ninf, 9., 9., ninf, 9., 10.]])
assert torch.all(retval2 == expected2).item()
assert obj2._padded_k == 9

retval3 = obj2(x, p=1.0, k=5)
expected3 = torch.tensor([[2., 2., 9., 9., 2., 2., ninf, ninf, ninf],
[10., 9., 9., ninf, 9., 9., ninf, 9., 10.]])

assert torch.all(retval3 == expected3).item()

# this should not be done in general, doing it here for testing purposes
ApplyToppTopkScalar._padded_k = 0
x = torch.tensor([[1, 1, 1, 9, 8, 1, 1, 1, 1.0],
[2, 1, 2, 2, 1, 1, 1, 1, 1]])
obj3 = ApplyToppTopkScalar(2)
retval4 = obj3(x, p=0.9, k=2)
expected4 = torch.tensor(
[[ninf, ninf, ninf, 9., 8., ninf, ninf, ninf, ninf],
[2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]])
assert torch.all(retval4 == expected4).item()
assert obj3._padded_k == 4
y = torch.tensor([[8, 8, 8, 9, 8, 1, 1, 1, 1.0],
[2, 1, 2, 2, 1, 1, 1, 1, 1]])
retval5 = obj3(y, p=0.9, k=2)
assert obj3._padded_k == 8
expected5 = torch.tensor([[8., 8., 8., 9., 8., ninf, ninf, ninf, ninf],
[2., ninf, 2., 2., ninf, ninf, ninf, ninf,
ninf]])
assert torch.all(retval5 == expected5).item()
y = torch.tensor([[8, 8, 8, 9, 8, 8, 1, 1, 1.0],
[2, 1, 2, 2, 3, 1, 1, 1, 1]])
retval6 = obj3(y, p=0.9, k=2)
expected6 = torch.tensor([[8., 8., 8., 9., 8., 8., ninf, ninf, ninf],
[2., ninf, 2., 2., 3., ninf, ninf, ninf, ninf]])
assert torch.all(retval6 == expected6).item()
assert obj3._padded_k == 8
116 changes: 114 additions & 2 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import math
from math import inf
from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -77,6 +78,13 @@ def _init_sampling_tensors(
self._do_penalties = do_penalties
self._do_top_p_top_k = do_top_p_top_k
self._do_min_p = do_min_p
self._top_p_scalar = sampling_tensors.top_ps[0].item()
self._top_k_scalar = sampling_tensors.top_ks[0].item()
scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar)
scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar)
self._scalar_p_and_k = (scalar_p and scalar_k).item()
if self._scalar_p_and_k and self._do_top_p_top_k:
self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)

def forward(
self,
Expand Down Expand Up @@ -122,8 +130,13 @@ def forward(
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
if self._scalar_p_and_k:
logits = self._apply_top_k_top_p_opt(logits,
self._top_p_scalar,
self._top_k_scalar)
else:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)

if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
Expand Down Expand Up @@ -198,6 +211,105 @@ def _get_bin_counts_and_mask(
return bin_counts, mask


class ApplyToppTopkScalar():
"""
The original implementation of _apply_top_k_top_p is more general
as it uses vector topp, topk
However in a lot of cases, topp and topk is same for all batch elements
For such "scalar" topp, topk cases, we can use this class
The main optimizations in this class is:
Use topk instead of sort, which is much faster especially for small k.
However just using topk might not suffice in cases as shown below
Consider a tensor: 9 9 8 8 8 8 7 7 7
Topk, with k=5, on this yields 9 9 8 8 8
The value "8" is on the boundary, hence the last "8" gets snipped off
However the original implementation accepts all the "8",
so it should output:
9 9 8 8 8 8 (6 values, even though k=5)
To ensure these semantics, we perform topk with __padded_k elements
If we find more boundary elements left over,
then we keep incrementing __padded_k
and in future calls use the expanded value of __padded_k
The increments to __padded_k should be done
with value > 1 to prevent excessive recompilations
due to dynamic shapes (the output shape of the topk)
The main logic of this is in __call__
This is a class instead of a function, just to keep track of
the monotonic non-decreasing state __padded_k
"""
_padded_k = 0

def __init__(self, increment: int):
self._increment = increment

def __call__(self, logits: torch.Tensor, p: float, k: int):
if k > ApplyToppTopkScalar._padded_k:
ApplyToppTopkScalar._padded_k = min(k + self._increment,
logits.shape[1])
#print("Increment padded_k to ", ApplyToppTopkScalar._padded_k)
while (True):
topk_results = torch.topk(logits,
k=ApplyToppTopkScalar._padded_k,
dim=1,
sorted=True)
# TODO, we may not need this flip,
# which might make this slightly faster
idx = torch.fliplr(topk_results.indices)
vals = torch.fliplr(topk_results.values)

if ApplyToppTopkScalar._padded_k == logits.shape[1]:
break

smallest_of_top_k = vals[:, -k]
num_duplicates_of_smallest_of_topk = torch.sum(
logits == smallest_of_top_k.unsqueeze(1), 1)
max_num_duplicates_of_smallest_of_topk = torch.max(
num_duplicates_of_smallest_of_topk).item()

# there are n repeats for a border
# (border meaning the smallest value of the top k).
# we do not know if only 1 or 2 or (n-1)
# of them lie outside the kth border,
# so we choose to conservatively increase by n-1
# when num_duplicates > _padded_k - k
if max_num_duplicates_of_smallest_of_topk - 1 > (
ApplyToppTopkScalar._padded_k - k):
incr = int(
math.ceil((max_num_duplicates_of_smallest_of_topk - 1) /
self._increment) * self._increment)
# this while loop should be traversed at most twice,
# because we dont increment by self._increment and retry
# instead we compute incr in one go
ApplyToppTopkScalar._padded_k = min(
ApplyToppTopkScalar._padded_k + incr, logits.shape[1])
#print("Increment padded_k to ", ApplyToppTopkScalar._padded_k)
else:
break

top_k_mask = vals.size(1) - k

top_k_smallest_val_idx = vals.size(1) - k
top_k_mask = vals[:, top_k_smallest_val_idx].unsqueeze(1)
top_k_mask = vals < top_k_mask
vals.masked_fill_(top_k_mask, -float("inf"))

probs_sort = vals.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= (1 - p)
top_p_mask[:, -1] = False
vals.masked_fill_(top_p_mask, -float("inf"))

new_logits = torch.full(logits.shape,
-float("inf"),
device=logits.device)
new_logits.scatter_(1, idx, vals)

return new_logits


def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
Expand Down

0 comments on commit 79aae80

Please sign in to comment.