[text generation] fix issues with top_p (#1354)
* [text generation] fix issues with top_p

* remove reverse [::-1] of logits ordering prior to cumsum
* revert change of min values (see the note below)
* update unit tests
bfineran authored Oct 26, 2023
1 parent a3ae6aa commit e699c8f
Showing 2 changed files with 31 additions and 16 deletions.
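Note on the "min values" revert: the commit swaps the -float("Inf") filter value back to the finite float32 minimum (_MIN_FLOAT in the diff below). One plausible motivation, shown in this minimal sketch (not part of the commit), is that a row whose logits are all masked with -inf collapses to NaN under a max-shifted softmax, while the finite minimum merely degrades to a uniform distribution:

```python
import numpy

_MIN_FLOAT = numpy.finfo(numpy.float32).min


def softmax(x: numpy.ndarray) -> numpy.ndarray:
    # Max-shifted softmax, a stand-in for deepsparse's numpy_softmax helper.
    exps = numpy.exp(x - numpy.max(x))
    return exps / exps.sum()


# Every logit masked with -inf: the shift computes -inf - (-inf) = nan,
# so the whole distribution collapses to NaN (with a RuntimeWarning).
print(softmax(numpy.array([-numpy.inf, -numpy.inf, -numpy.inf])))  # [nan nan nan]

# The same row masked with the finite float32 minimum stays well defined.
print(softmax(numpy.full(3, _MIN_FLOAT)))  # [0.3333... 0.3333... 0.3333...]
```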
14 changes: 9 additions & 5 deletions src/deepsparse/transformers/utils/token_generator.py
@@ -19,6 +19,9 @@
 from deepsparse.utils.data import numpy_softmax
 
 
+_MIN_FLOAT = numpy.finfo(numpy.float32).min
+
+
 class TokenGenerator:
     """
     Responsible for generating tokens, and contains functions that
@@ -115,7 +118,7 @@ def apply_presence_penalty(self, logits: numpy.ndarray) -> numpy.ndarray:

     # from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf31
     def apply_top_k(
-        self, logits: numpy.ndarray, filter_value=-float("Inf")
+        self, logits: numpy.ndarray, filter_value=_MIN_FLOAT
     ) -> numpy.ndarray:
         """
         Keep top_k logits based on its value. All other values
@@ -134,7 +137,7 @@ def apply_top_k(
     def apply_top_p(
         self,
         logits: numpy.ndarray,
-        filter_value=-float("Inf"),
+        filter_value=_MIN_FLOAT,
         min_tokens_to_keep: int = 1,
     ) -> numpy.ndarray:
         """
@@ -148,15 +151,16 @@ def apply_top_p(
         logits_shape = logits.shape
         logits = logits.reshape(logits.shape[-1])
 
-        sorted_indices = numpy.argsort(logits)[::-1]
+        sorted_indices = numpy.argsort(logits)
         sorted_logits = logits[sorted_indices]
         logit_cumulative_probs = numpy.cumsum(numpy_softmax(sorted_logits))
 
         # Remove tokens with cumulative top_p above the threshold
         # (token with 0 are kept)
-        sorted_indices_to_remove = logit_cumulative_probs > self.top_p
+        sorted_indices_to_remove = logit_cumulative_probs <= (1 - self.top_p)
         # Keep at least min_tokens_to_keep
-        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+        if min_tokens_to_keep:
+            sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
 
         # scatter sorted tensors to original indexing
         indices_to_remove = sorted_indices[sorted_indices_to_remove]
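Why the rewritten mask is nucleus sampling: numpy.argsort sorts ascending, so the cumulative probabilities above run from the rarest token up to 1, and the nucleus (the smallest set of most-probable tokens with total mass >= top_p) is the suffix of that ordering. Tokens whose cumulative mass stays at or below 1 - top_p are exactly the ones outside the nucleus. Dropping the [::-1] also lets the min_tokens_to_keep guard, which clears the last entries of the mask, protect the most probable tokens; under the old descending sort those last entries were the least probable ones. A standalone sketch of the patched logic (top_p passed as an argument, a local softmax standing in for numpy_softmax), not the shipped method:

```python
import numpy

_MIN_FLOAT = numpy.finfo(numpy.float32).min


def top_p_filter(
    logits: numpy.ndarray,
    top_p: float,
    filter_value: float = _MIN_FLOAT,
    min_tokens_to_keep: int = 1,
) -> numpy.ndarray:
    """Standalone sketch of the patched apply_top_p logic."""
    logits = logits.reshape(-1).astype(numpy.float32)

    # Ascending sort: rarest tokens first, the nucleus at the end.
    sorted_indices = numpy.argsort(logits)
    sorted_logits = logits[sorted_indices]

    # Local softmax in place of deepsparse's numpy_softmax.
    exps = numpy.exp(sorted_logits - sorted_logits.max())
    cumulative_probs = numpy.cumsum(exps / exps.sum())

    # Mask everything whose cumulative mass stays at or below 1 - top_p.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)

    # The most probable tokens sit at the end of the ascending order, so
    # clearing the last entries keeps at least min_tokens_to_keep alive.
    if min_tokens_to_keep:
        sorted_indices_to_remove[-min_tokens_to_keep:] = False

    logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    return logits
```

On the test vector from the diff below, numpy.array([1.0, -3.1, 2.0, 3.1, -1.0, -2.0, 1.2, -1.2]) with top_p=0.9, this masks exactly five entries, matching the updated expectation.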
33 changes: 22 additions & 11 deletions tests/deepsparse/transformers/utils/test_token_generator.py
@@ -21,6 +21,9 @@
 from deepsparse.transformers.utils.token_generator import TokenGenerator
 
 
+_MIN_FLOAT = numpy.finfo(numpy.float32).min
+
+
 @pytest.fixture(scope="function")
 def logits_fixture() -> numpy.array:
     def get(shape: Tuple = (1, 1, 51200), token_max_thresh: int = 30, low: int = -30):
@@ -128,45 +131,53 @@ def test_apply_topk(
         assert numpy.all(new_logits == filter_value)
 
     @pytest.mark.parametrize(
-        ("logits", "top_p", "expected_non_inf_counts"),
+        ("logits", "top_p", "min_tokens_to_keep", "expected_filtered_values"),
         [
             (
                 0.1 * numpy.ones(10).reshape((1, 1, 10)),
-                0.89,
-                9,
+                0.79,
+                0,
+                2,
             ),
             (
                 0.1 * numpy.ones(10).reshape((1, 1, 10)),
-                0.9,
-                10,  # one token should have cumsum > 0.9
+                0.899,
+                0,
+                1,  # one token should have cumsum > 0.9
             ),
-            (0.1 * numpy.ones(10).reshape((1, 1, 10)), 0, 1),  # keep at least one token
+            (0.1 * numpy.ones(10).reshape((1, 1, 10)), 0, 1, 9),  # keep all toks but 1
             (
                 numpy.array([1.0, -3.1, 2.0, 3.1, -1.0, -2.0, 1.2, -1.2]).reshape(
                     1, 1, -1
                 ),
+                # expected distribution:
+                # [0.0012, 0.0049, 0.0132, 0.023, 0.097, 0.188, 0.3914, 1]
                 0.9,
-                3,
+                0,
+                5,
             ),
         ],
     )
     def test_apply_top_p(
         self,
         logits,
         top_p,
-        expected_non_inf_counts,
+        min_tokens_to_keep,
+        expected_filtered_values,
     ):
 
         token_generator = TokenGenerator(
             logits_shape=logits[-1].shape[-1],
             top_p=top_p,
         )
 
-        filter_value = -float("Inf")
+        filter_value = _MIN_FLOAT
         new_logits = token_generator.apply_top_p(
-            logits.copy(), filter_value=filter_value
+            logits.copy(),
+            filter_value=filter_value,
+            min_tokens_to_keep=min_tokens_to_keep,
         )
-        assert numpy.isfinite(new_logits[-1]).sum(axis=1) == expected_non_inf_counts
+        assert (new_logits[-1] == filter_value).sum(axis=1) == expected_filtered_values
 
     def test_generate_token(
         self,
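The expected_filtered_values entries can be reproduced by hand; a short check of the last parametrized case, using only numpy (the rounded cumulative values are the test's own):

```python
import numpy

# Logits from the last parametrized case above.
logits = numpy.array([1.0, -3.1, 2.0, 3.1, -1.0, -2.0, 1.2, -1.2])
top_p = 0.9

sorted_logits = numpy.sort(logits)  # ascending
probs = numpy.exp(sorted_logits - sorted_logits.max())
probs /= probs.sum()
cumulative = numpy.cumsum(probs)
# cumulative ~= [0.0012, 0.0049, 0.0132, 0.023, 0.097, 0.188, 0.3914, 1.0],
# matching the "expected distribution" comment in the test.

# With top_p = 0.9 every token whose cumulative mass is <= 0.1 is filtered:
print((cumulative <= (1 - top_p)).sum())  # 5
```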
