From e699c8f4c13e0cccf775123fbb16486299468e65 Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Wed, 25 Oct 2023 20:26:26 -0400
Subject: [PATCH] [text generation] fix issues with top_p (#1354)

* [text generation] fix issues with top_p

* * remove reverse [::-1] of logits ordering prior to cumsum
* revert change of min values

* update unit tests
---
 .../transformers/utils/token_generator.py    | 14 +++++---
 .../utils/test_token_generator.py            | 33 ++++++++++++-------
 2 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/src/deepsparse/transformers/utils/token_generator.py b/src/deepsparse/transformers/utils/token_generator.py
index 4d7004f9c5..5fa82b7bc4 100644
--- a/src/deepsparse/transformers/utils/token_generator.py
+++ b/src/deepsparse/transformers/utils/token_generator.py
@@ -19,6 +19,9 @@
 from deepsparse.utils.data import numpy_softmax
 
 
+_MIN_FLOAT = numpy.finfo(numpy.float32).min
+
+
 class TokenGenerator:
     """
     Responsible for generating tokens, and contains functions that
@@ -115,7 +118,7 @@ def apply_presence_penalty(self, logits: numpy.ndarray) -> numpy.ndarray:
 
     # from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf31
     def apply_top_k(
-        self, logits: numpy.ndarray, filter_value=-float("Inf")
+        self, logits: numpy.ndarray, filter_value=_MIN_FLOAT
     ) -> numpy.ndarray:
         """
         Keep top_k logits based on its value. All other values
@@ -134,7 +137,7 @@ def apply_top_k(
     def apply_top_p(
         self,
         logits: numpy.ndarray,
-        filter_value=-float("Inf"),
+        filter_value=_MIN_FLOAT,
         min_tokens_to_keep: int = 1,
     ) -> numpy.ndarray:
         """
@@ -148,15 +151,16 @@ def apply_top_p(
         logits_shape = logits.shape
         logits = logits.reshape(logits.shape[-1])
 
-        sorted_indices = numpy.argsort(logits)[::-1]
+        sorted_indices = numpy.argsort(logits)
         sorted_logits = logits[sorted_indices]
         logit_cumulative_probs = numpy.cumsum(numpy_softmax(sorted_logits))
 
         # Remove tokens with cumulative top_p above the threshold
         # (token with 0 are kept)
-        sorted_indices_to_remove = logit_cumulative_probs > self.top_p
+        sorted_indices_to_remove = logit_cumulative_probs <= (1 - self.top_p)
         # Keep at least min_tokens_to_keep
-        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+        if min_tokens_to_keep:
+            sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
 
         # scatter sorted tensors to original indexing
         indices_to_remove = sorted_indices[sorted_indices_to_remove]
diff --git a/tests/deepsparse/transformers/utils/test_token_generator.py b/tests/deepsparse/transformers/utils/test_token_generator.py
index a94a5651df..31b561dc79 100644
--- a/tests/deepsparse/transformers/utils/test_token_generator.py
+++ b/tests/deepsparse/transformers/utils/test_token_generator.py
@@ -21,6 +21,9 @@
 from deepsparse.transformers.utils.token_generator import TokenGenerator
 
 
+_MIN_FLOAT = numpy.finfo(numpy.float32).min
+
+
 @pytest.fixture(scope="function")
 def logits_fixture() -> numpy.array:
     def get(shape: Tuple = (1, 1, 51200), token_max_thresh: int = 30, low: int = -30):
@@ -128,25 +131,30 @@ def test_apply_topk(
         assert numpy.all(new_logits == filter_value)
 
     @pytest.mark.parametrize(
-        ("logits", "top_p", "expected_non_inf_counts"),
+        ("logits", "top_p", "min_tokens_to_keep", "expected_filtered_values"),
         [
             (
                 0.1 * numpy.ones(10).reshape((1, 1, 10)),
-                0.89,
-                9,
+                0.79,
+                0,
+                2,
             ),
             (
                 0.1 * numpy.ones(10).reshape((1, 1, 10)),
-                0.9,
-                10,  # one token should have cumsum > 0.9
+                0.899,
+                0,
+                1,  # one token should have cumsum > 0.9
             ),
-            (0.1 * numpy.ones(10).reshape((1, 1, 10)), 0, 1),  # keep at least one token
+            (0.1 * numpy.ones(10).reshape((1, 1, 10)), 0, 1, 9),  # keep all toks but 1
             (
                 numpy.array([1.0, -3.1, 2.0, 3.1, -1.0, -2.0, 1.2, -1.2]).reshape(
                     1, 1, -1
                 ),
+                # expected distribution:
+                # [0.0012, 0.0049, 0.0132, 0.023, 0.097, 0.188, 0.3914, 1]
                 0.9,
-                3,
+                0,
+                5,
             ),
         ],
     )
@@ -154,7 +162,8 @@ def test_apply_top_p(
         self,
         logits,
         top_p,
-        expected_non_inf_counts,
+        min_tokens_to_keep,
+        expected_filtered_values,
     ):
 
         token_generator = TokenGenerator(
@@ -162,11 +171,13 @@ def test_apply_top_p(
             top_p=top_p,
         )
 
-        filter_value = -float("Inf")
+        filter_value = _MIN_FLOAT
         new_logits = token_generator.apply_top_p(
-            logits.copy(), filter_value=filter_value
+            logits.copy(),
+            filter_value=filter_value,
+            min_tokens_to_keep=min_tokens_to_keep,
         )
-        assert numpy.isfinite(new_logits[-1]).sum(axis=1) == expected_non_inf_counts
+        assert (new_logits[-1] == filter_value).sum(axis=1) == expected_filtered_values
 
     def test_generate_token(
         self,
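
The semantics change is easier to see outside the hunks: apply_top_p now sorts
logits in ascending order (the [::-1] reversal is gone) and masks every token
whose ascending cumulative probability is at most 1 - top_p, which keeps the
smallest high-probability set carrying at least top_p of the mass. Below is a
self-contained sketch of that logic, not the library code itself: numpy_softmax
here is a local stand-in for deepsparse.utils.data.numpy_softmax, and the
example probabilities are invented for illustration.

    import numpy

    _MIN_FLOAT = numpy.finfo(numpy.float32).min


    def numpy_softmax(x: numpy.ndarray) -> numpy.ndarray:
        # local stand-in for deepsparse.utils.data.numpy_softmax
        exp = numpy.exp(x - numpy.max(x))
        return exp / exp.sum()


    def apply_top_p(
        logits: numpy.ndarray,
        top_p: float,
        filter_value: float = _MIN_FLOAT,
        min_tokens_to_keep: int = 1,
    ) -> numpy.ndarray:
        # flatten to the vocab dimension, mirroring the patched method
        shape = logits.shape
        logits = logits.reshape(logits.shape[-1])

        # ascending sort: cumulative probabilities run from the least
        # likely token to the most likely token
        sorted_indices = numpy.argsort(logits)
        cumulative_probs = numpy.cumsum(numpy_softmax(logits[sorted_indices]))

        # mask the low-probability tail whose total mass is <= 1 - top_p;
        # what survives therefore carries at least top_p of the mass
        to_remove = cumulative_probs <= (1 - top_p)
        if min_tokens_to_keep:
            # the last sorted entries are the most likely tokens
            to_remove[-min_tokens_to_keep:] = False

        logits[sorted_indices[to_remove]] = filter_value
        return logits.reshape(shape)


    # tokens with probabilities 0.5, 0.25, 0.15, 0.07, 0.03
    logits = numpy.log(numpy.array([0.5, 0.25, 0.15, 0.07, 0.03]))
    filtered = apply_top_p(logits.copy(), top_p=0.85)
    # 0.5 + 0.25 + 0.15 = 0.90 >= 0.85 is kept; the tail 0.07 + 0.03 = 0.10
    # is <= 1 - 0.85, so exactly two logits are masked
    assert (filtered == _MIN_FLOAT).sum() == 2

The same arithmetic explains the updated test expectations: with ten uniform
0.1-probability tokens, top_p=0.79 leaves a removable tail of mass 0.21 (two
tokens filtered) and top_p=0.899 a tail of 0.101 (one token filtered); with
top_p=0 every token qualifies for removal, so min_tokens_to_keep=1 is what
leaves nine filtered and one kept.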