diff --git a/src/deepsparse/transformers/eval_downstream.py b/src/deepsparse/transformers/eval_downstream.py
index 8f928c33be..f9835aa58e 100644
--- a/src/deepsparse/transformers/eval_downstream.py
+++ b/src/deepsparse/transformers/eval_downstream.py
@@ -62,49 +62,112 @@
 import argparse
 import json
+import logging
 from cProfile import Profile
 from pstats import Stats
 
 import numpy
 from tqdm.auto import tqdm
 
+from datasets import load_dataset, load_metric
 from deepsparse import DEEPSPARSE_ENGINE, ORT_ENGINE, Pipeline
 from deepsparse.transformers.metrics import Perplexity, PrecisionRecallF1
+from deepsparse.transformers.utils.eval_helpers import process_concatenated_datasets
 
-from datasets import load_dataset, load_metric  # isort: skip
+_LOGGER = logging.getLogger(__name__)
 
-def perplexity_eval(args, batch_size=16, dataset_name="openai_humaneval"):
-    if args.max_samples:
-        batch_size = min(batch_size, args.max_samples)
+PPL_DATASETS = ["wikitext2", "c4", "openai_humaneval"]
 
-    dataset = load_dataset(dataset_name)["test"]
 
+def perplexity_eval(args, dataset_name="openai_humaneval"):
+    if dataset_name in ["wikitext2", "c4"]:
+        if args.kwargs is None:
+            kwargs = {}
+        else:
+            kwargs = json.loads(args.kwargs)
+        dataset = process_concatenated_datasets(
+            dataset_name,
+            args.model_path,
+            args.max_sequence_length,
+            kwargs,
+        )
+        # Set perplexity computation to accumulate negative log-likelihood across
+        # sections
+        accumulate = True
+    else:
+        dataset = load_dataset(dataset_name, split="test")
+        accumulate = False
+
+    # We'll use the text generation pipeline to generate a single token.
+    # Along with the token, it returns the logits for the input sequence
     text_generation = Pipeline.create(
         task="text-generation",
         model_path=args.model_path,
         engine_type=args.engine,
         num_cores=args.num_cores,
         sequence_length=args.max_sequence_length,
-        max_generated_tokens=1,
+        trust_remote_code=args.trust_remote_code,
     )
-    perplexity_metrics = Perplexity(pipeline=text_generation, batch_size=batch_size)
-    active_engines = [
-        engine
-        for engine in [text_generation.engine, text_generation.multitoken_engine]
-        if engine
-    ]
-    print("Engine info: ")
-    [print(f"{engine}\n") for engine in active_engines]
-    predictions = []
+
+    # Instantiate perplexity metric
+    perplexity_metrics = Perplexity(accumulate=accumulate)
+
+    # Loop through samples
+    batch_samples = []
+    run_inference = False
+    end_evaluation = False
+    dataset_length = len(dataset)
     for idx, sample in _enumerate_progress(dataset, args.max_samples):
-        predictions.append(sample["prompt"] + sample["canonical_solution"])
-        if len(predictions) == batch_size:
-            perplexity_metrics.add_batch(predictions)
-            predictions = []
-        if args.max_samples and idx >= args.max_samples:
+        # Collect input sequence
+        if dataset_name == "openai_humaneval":
+            sample = sample["prompt"] + sample["canonical_solution"]
+        batch_samples.append(sample)
+
+        if args.max_samples and idx == args.max_samples - 1:
+            run_inference = True
+            end_evaluation = True
+
+        if (idx + 1) % args.batch_size == 0 or idx == dataset_length - 1:
+            run_inference = True
+
+        if run_inference:
+            # Perform single token generation
+            prediction = text_generation(
+                sequences=batch_samples,
+                output_scores=True,
+                return_input_tokens=True,
+                fixed_sequences_length=True,
+                include_prompt_logits=True,
+                max_length=1,
+            )
+
+            # Handle one sample at a time to make it simpler for masking
+            for s in range(len(batch_samples)):
+                # Need to remove tokens that were masked
+                input_ids = prediction.input_tokens["input_ids"][s].flatten()
+                logits = prediction.generations[s].score
+                attention_mask = prediction.input_tokens["attention_mask"][s].flatten()
+
+                effective_sequence_length = logits.shape[0]
+
+                input_ids = input_ids[-effective_sequence_length:]
+                attention_mask = attention_mask[-effective_sequence_length:]
+
+                logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]
+                input_ids = numpy.compress(attention_mask, input_ids)[1:]
+
+                # Add predictions (logits) and targets (input_ids) to metric
+                perplexity_metrics.add_batch(logits, input_ids)
+
+            # Reset batch
+            batch_samples.clear()
+            run_inference = False
+
+        if end_evaluation:
             break
+
     return perplexity_metrics
@@ -473,7 +536,18 @@ def _split_train_val(train_dataset, val_ratio, seed=42):
     "imdb": imdb_eval,
     "conll2003": conll2003_eval,
     "go_emotions": go_emotions_eval,
-    "openai_humaneval": perplexity_eval,
+    "openai_humaneval": lambda args: perplexity_eval(
+        args,
+        dataset_name="openai_humaneval",
+    ),
+    "wikitext2": lambda args: perplexity_eval(
+        args,
+        dataset_name="wikitext2",
+    ),
+    "c4": lambda args: perplexity_eval(
+        args,
+        dataset_name="c4",
+    ),
 }
 
@@ -604,7 +678,24 @@ def parse_args():
         type=bool,
         default=False,
     )
-
+    parser.add_argument(
+        "--batch-size",
+        help="Batch size with which to evaluate model. Default is 1",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--trust-remote-code",
+        help="Whether to allow for remote code execution in transformers.",
+        type=bool,
+        default=False,
+    )
+    parser.add_argument(
+        "--kwargs",
+        help="Additional arguments specific to each dataset",
+        type=str,
+        default=None,
+    )
 
     return parser.parse_args()
 
@@ -619,6 +710,12 @@ def _main(args):
             f"available datasets are {list(SUPPORTED_DATASETS.keys())}"
         )
 
+    if dataset not in PPL_DATASETS:
+        _LOGGER.warning(
+            "Batch-size argument is not supported for this dataset. "
+            "Will use default value of 1."
+        )
+
     if dataset == "mnli":
         mnli_metrics_matched, mnli_metrics_mismatched = mnli_eval(args)
         mnli_metrics_matched = mnli_metrics_matched.compute()
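
For reference, the alignment the eval loop above performs before handing data to Perplexity.add_batch (trim to the scored length, drop masked positions, then pair logits[:-1] with input_ids[1:] so each logit row is scored against the token that follows it) can be sketched on toy arrays; the logits and token ids below are made up, not real pipeline output:

    import numpy

    vocab_size = 8
    logits = numpy.random.rand(6, vocab_size)          # fake scores for a 6-token padded sequence
    input_ids = numpy.array([0, 0, 11, 12, 13, 14])    # left-padded token ids (0 = pad)
    attention_mask = numpy.array([0, 0, 1, 1, 1, 1])   # zeros mark padding

    # Drop padded positions, then shift by one for next-token prediction
    logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]
    targets = numpy.compress(attention_mask, input_ids)[1:]

    assert logits.shape[0] == targets.shape[0]  # one row of scores per target token
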
diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py
index f2e717a08f..1952ec2155 100644
--- a/src/deepsparse/transformers/metrics.py
+++ b/src/deepsparse/transformers/metrics.py
@@ -16,16 +16,11 @@
 Utilities for evaluation metric computation
 """
-
-from itertools import compress
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional
 
 import numpy
-from tqdm import tqdm
 
-from deepsparse import Pipeline
-from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
-from deepsparse.transformers.utils.helpers import pad_to_fixed_length
+from scipy.special import log_softmax
 from sklearn.metrics import precision_recall_fscore_support
 
 
@@ -36,139 +31,107 @@ class Perplexity:
-    def __init__(self, pipeline: Pipeline, batch_size: int = 16):
+    def __init__(self, accumulate: bool = False):
         """
-        Given the pipeline, compute the perplexity of the model
-        on the given text input.
+        Class for computing perplexity.
 
-        Code adapted from:
-        https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py # noqa: E501
+        Each batch is processed via the "add_batch" method.
+        At the end the data is reduced to a single perplexity
+        metric via the "compute" method.
 
-        :param pipeline: The pipeline to use for text generation
-        :param batch_size: The batch size to split the input text into
-            non-overlapping batches
-        """
-        torch = _import_torch()
-        if not isinstance(pipeline, TextGenerationPipeline):
-            raise ValueError(
-                "Perplexity can only be computed for text generation pipelines"
-            )
-        self._pipeline = pipeline
-        self._batch_size = batch_size
-        self._sequence_length = pipeline.sequence_length
-        self._loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
-
-        self.perplexities = []
-
-    def add_batch(self, predictions: List[str]):
+        Example:
+            metric = Perplexity()
+            for prediction, target in samples:
+                metric.add_batch(prediction, target)
+            perplexity_value = metric.compute()
+
+        :param accumulate: If True, accumulate negative log-likelihood
+            over samples. If False, perplexity is computed separately
+            for each sample and then averaged at the end.
         """
-        Run the model on the given input sequences and compute the perplexity.
-        The resulting perplexity is appended to the list of perplexities.
+        self._predictions = None
+        self._targets = None
+        self._accumulate = accumulate
+        if accumulate:
+            self._neg_log_likelihood = 0.0
+            self._number_tokens = 0
+        else:
+            self._perplexities = None
 
-        :param predictions: The predictions to compute perplexity on
+    def add_batch(self, predictions: numpy.ndarray, targets: numpy.ndarray):
+        """
+        Computes perplexity or negative log-likelihood for each batch
+        (depending on the accumulate argument)
+        and tracks the results.
+
+        Tracks perplexity or negative log-likelihood since storing
+        predictions may require a lot of memory.
+
+        :param predictions: predicted scores.
+            Accepted shapes:
+              - [batch_size, sequence_length, vocab_size]
+              - [sequence_length, vocab_size] (batch size = 1)
+            Note: sequence length has to be uniform within a batch, but not all
+            batches require the same sequence length
+        :param targets: target values - index of correct vocabulary entry
         """
-        torch = _import_torch()
-        # tokenize the input text
-        encodings = self._pipeline.tokenizer(
-            predictions,
-            return_attention_mask=True,
-            max_length=self._sequence_length,
-            truncation=True,
-            padding="max_length",
-        )
-        encoded_texts = encodings["input_ids"]
-        attention_masks = encodings["attention_mask"]
-
-        for start_index in tqdm(range(0, len(encoded_texts), self._batch_size)):
-            end_index = min(start_index + self._batch_size, len(encoded_texts))
-            encoded_batch = encoded_texts[start_index:end_index]
-            attention_mask = attention_masks[start_index:end_index]
-
-            # Computing the ground truth labels
-
-            # `encoded_batch` contains sequences of tokens padded
-            # with tokens from the left side. We need to remove
-            # them and zero-pad from the right side up to the length
-            # of the longest sequence in the batch
-
-            encoded_batch = [
-                list(compress(sequence, attn_mask))
-                for (sequence, attn_mask) in zip(encoded_batch, attention_mask)
-            ]
-            max_sequence_len = max(len(sequence) for sequence in encoded_batch)
-
-            encoded_batch = [
-                pad_to_fixed_length(numpy.array(sequence), max_sequence_len)
-                for sequence in encoded_batch
-            ]
-            encoded_batch = numpy.stack(encoded_batch)
-
-            # We need to apply the analogous transformation to the attention mask
-            attention_mask = numpy.array(attention_mask)
-            attention_mask = [
-                list(filter(lambda num: num != 0, mask)) for mask in attention_mask
-            ]
-            attention_mask = [
-                pad_to_fixed_length(numpy.array(mask), max_sequence_len)
-                for mask in attention_mask
-            ]
-            attention_mask = numpy.stack(attention_mask)
-
-            labels = encoded_batch
-
-            out = self._pipeline(
-                sequences=predictions,
-                return_logits=True,
-                fixed_sequences_length=True,
-                include_prompt_logits=True,
-            )
-
-            logits = out.logits
-
-            if not self._pipeline.cache_support_enabled:
-                # when running inference without cache, we need to apply
-                # analogous transformations to the logits as we did to the labels
-                # and attention mask
-
-                # remove "nonsensical" logits for tokens
-                logits = [
-                    logit[-attn_mask.sum() :, :]
-                    for (logit, attn_mask) in zip(logits, attention_mask)
-                ]
-                # pad logits to max length
-                logits = [
-                    pad_to_fixed_length(logit, max_sequence_len) for logit in logits
-                ]
-                logits = numpy.stack(logits)
-
-            # shift logits and labels create the input and target for the loss function
-            shift_logits = logits[:, :-1, :]
-            shift_labels = labels[:, 1:]
-            shift_attention_mask_batch = attention_mask[:, 1:]
-
-            # compute perplexity for this batch
-            perplexity_batch = torch.exp(
-                (
-                    self._loss_fct(
-                        torch.tensor(shift_logits.transpose(0, 2, 1)),
-                        torch.tensor(shift_labels),
-                    )
-                    * torch.tensor(shift_attention_mask_batch)
-                ).sum(1)
-                / torch.tensor(shift_attention_mask_batch).sum(1)
-            )
-            self.perplexities.extend(perplexity_batch.numpy().tolist())
+        if self._accumulate:
+            # If accumulate is True, every token from the batch contributes
+            # equally to the negative log-likelihood.
+            # Thus, merge batch and sequence length dimensions and compute negative
+            # log-likelihood for all tokens, and accumulate to total
+            predictions = numpy.reshape(predictions, (-1, predictions.shape[-1]))
+            targets = targets.flatten()
+
+            # Compute negative log-likelihood and accumulate
+            self._neg_log_likelihood += _cross_entropy(
+                predictions, targets, reduction="sum"
+            ).sum()
+
+            # Track number of tokens processed
+            self._number_tokens += predictions.shape[0]
+        else:
+            # If accumulate is False, compute perplexity for
+            # each sample individually.
+            # We assume that sequence length is uniform within a batch,
+            # but may vary from batch to batch.
+
+            # Create batch dimension if it doesn't exist
+            if targets.ndim == 1:
+                predictions = numpy.expand_dims(predictions, axis=0)
+                targets = numpy.expand_dims(targets, axis=0)
+
+            # Compute negative log-likelihoods for batch
+            neg_log_likelihoods = _cross_entropy(predictions, targets)
+
+            # Compute perplexities for batch
+            perplexities = numpy.exp(neg_log_likelihoods)
+
+            # Store perplexities
+            if self._perplexities is None:
+                self._perplexities = perplexities
+            else:
+                self._perplexities = numpy.concatenate(
+                    (self._perplexities, perplexities)
+                )
 
     def compute(self) -> Dict[str, Any]:
         """
-        :return: A dictionary containing the mean perplexity
-            and the list of perplexities
+        :return: A dictionary containing the final results.
+            If accumulate is True, returns a single perplexity.
+            Else, returns the list of perplexities (one for each sample)
+            and the mean perplexity.
         """
-        return {
-            "mean_perplexity": numpy.mean(self.perplexities),
-            "perplexities": self.perplexities,
-        }
+
+        if self._accumulate:
+            perplexity = numpy.exp(self._neg_log_likelihood / self._number_tokens)
+            return {"perplexity": perplexity}
+        else:
+            return {
+                "perplexities": self._perplexities,
+                "mean_perplexity": numpy.mean(self._perplexities),
+            }
 
 
 class PrecisionRecallF1:
@@ -231,19 +194,33 @@ def compute(self) -> Dict[str, float]:
         return results
 
 
-def _import_torch():
+def _cross_entropy(
+    predictions: numpy.ndarray,
+    targets: numpy.ndarray,
+    reduction: str = "mean",
+) -> numpy.ndarray:
     """
-    Import and return the required torch module. Raises an ImportError if torch is not
-    installed.
+    Calculate the cross-entropy loss between predicted probabilities and target labels.
+
+    Args:
+        predictions (numpy.ndarray): Predicted logits.
+        targets (numpy.ndarray): Target class labels.
+        reduction (str, optional): Specifies the reduction method for the loss.
+            - "mean" (default): Computes the mean loss over all samples.
+            - "sum": Computes the sum of losses over all samples.
 
-    :raises ImportError: if torch is not installed
-    :return: torch module
+    Returns:
+        numpy.ndarray: The computed cross-entropy loss.
     """
-    try:
-        import torch
-
-        return torch
-    except ImportError as import_error:
-        raise ImportError(
-            "Please install `deepsparse[torch]` to use this pathway"
-        ) from import_error
+
+    logp = log_softmax(predictions, axis=-1)
+    neg_log_likelihoods = -1.0 * numpy.take_along_axis(
+        logp, numpy.expand_dims(targets, axis=-1), axis=-1
+    )
+    neg_log_likelihoods = numpy.squeeze(neg_log_likelihoods, axis=-1)
+    if reduction == "mean":
+        neg_log_likelihoods = neg_log_likelihoods.mean(axis=-1)
+    elif reduction == "sum":
+        neg_log_likelihoods = neg_log_likelihoods.sum(axis=-1)
+
+    return neg_log_likelihoods
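
The reworked Perplexity metric can also be exercised on its own, which may help when reviewing the two accumulation modes; a minimal sketch with random arrays standing in for pipeline logits and token ids:

    import numpy

    from deepsparse.transformers.metrics import Perplexity

    vocab_size, seq_len = 16, 10
    logits = numpy.random.rand(seq_len, vocab_size)           # fake per-token scores
    targets = numpy.random.randint(vocab_size, size=seq_len)  # fake token ids

    # Per-sample perplexities plus their mean (openai_humaneval path)
    metric = Perplexity()
    metric.add_batch(logits, targets)
    print(metric.compute())  # {"perplexities": ..., "mean_perplexity": ...}

    # Single corpus-level perplexity accumulated over sections (wikitext2 / c4 path)
    metric = Perplexity(accumulate=True)
    metric.add_batch(logits, targets)
    print(metric.compute())  # {"perplexity": ...}
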
diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py
index bb87dbe18d..4a8fe335ea 100644
--- a/src/deepsparse/transformers/pipelines/text_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation.py
@@ -96,6 +96,10 @@ class Config:
         alias="prompt",
         description="The input sequences to generate the text from.",
     )
+    return_input_tokens: bool = Field(
+        default=False,
+        description="A flag that indicates whether to return " "the input_tokens. ",
+    )
     include_prompt_logits: bool = Field(
         default=False,
         description="A flag that indicates whether to return "
@@ -187,6 +191,15 @@ class TextGenerationOutput(BaseModel):
         "prompt provided. If streamng is enabled, the next generated token is returned."
         "Otherwise, the full generated sequence is returned."
     )
+    input_tokens: Optional[
+        Any
+    ] = Field(  # dictionary mapping "input_ids" and "attention_mask" to numpy arrays
+        default=None,
+        description="The output of the tokenizer. "
+        "Dictionary containing input_ids and attention_mask, "
+        "both mapping to arrays of size "
+        "[batch_size, sequence_length]",
+    )
 
     class Config:
         arbitrary_types_allowed = True
 
@@ -528,6 +541,8 @@ def process_inputs(
         context = dict(
             prompts=original_inputs,
             streaming=inputs.streaming,
+            return_input_tokens=inputs.return_input_tokens,
+            input_tokens=input_tokens,
            generation_config=generation_config,
             include_prompt_logits=inputs.include_prompt_logits,
             callback=inputs.callback,
@@ -649,8 +664,15 @@ def process_engine_outputs(
             ]
             generations = grouped_generations
 
+        input_tokens = (
+            kwargs.get("input_tokens") if kwargs.get("return_input_tokens") else None
+        )
+
         outputs = dict(
-            created=datetime.datetime.now(), prompts=prompts, generations=generations
+            created=datetime.datetime.now(),
+            prompts=prompts,
+            generations=generations,
+            input_tokens=input_tokens,
         )
 
         if "session_ids" in kwargs:
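
Taken together, return_input_tokens and output_scores let a caller retrieve both the tokenized inputs and the logits that the perplexity eval consumes from a single pipeline call; a hedged usage sketch (the model path is a placeholder, any text-generation model should behave the same way):

    from deepsparse import Pipeline

    pipeline = Pipeline.create(
        task="text-generation",
        model_path="/path/to/text-generation-model",  # placeholder path
    )
    prediction = pipeline(
        sequences=["def fibonacci(n):"],
        output_scores=True,
        return_input_tokens=True,
        fixed_sequences_length=True,
        include_prompt_logits=True,
        max_length=1,
    )
    input_ids = prediction.input_tokens["input_ids"][0].flatten()
    attention_mask = prediction.input_tokens["attention_mask"][0].flatten()
    logits = prediction.generations[0].score  # [sequence_length, vocab_size]
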
+ """ + + if dataset_name not in CONCATENATED_DATSETS: + raise KeyError( + f"dataset {dataset_name} not supported for concatenated processing, " + f"available datasets are {list(CONCATENATED_DATSETS.keys())}" + ) + + if dataset_name == "wikitext2": + eos = kwargs.get("eos", "\n\n") + bos = kwargs.get("bos", "") + + raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + raw_text = raw_dataset["text"] + elif dataset_name == "c4": + eos = kwargs.get("eos", "<|endoftext|>") + bos = kwargs.get("bos", "") + raw_samples = kwargs.get("raw_samples", None) + data_file = kwargs.get("data_file", 0) + if data_file is not None: + raw_dataset = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={ + "validation": f"en/c4-validation.{data_file:05d}-of-00008.json.gz" + }, + split="validation", + ) + else: + raw_dataset = load_dataset( + "allenai/c4", + "allenai--c4", + split="validation", + ) + if raw_samples is not None: + raw_dataset = raw_dataset[:raw_samples] + raw_text = raw_dataset["text"] + + # Dataset is split into sections that contain "max_sequence_length" tokens. + # To split the dataset, first tokenize text + tokenizer = AutoTokenizer.from_pretrained(model_path) + return _split_text_by_tokens( + raw_text, + eos, + bos, + tokenizer, + max_sequence_length, + kwargs.get("max_text_length", None), + ) + + +def _split_text_by_tokens( + text: List[str], + eos: str, + bos: str, + tokenizer: PreTrainedTokenizerFast, + sequence_length: int, + max_text_length: Union[None, int], +) -> List[str]: + """ + Tokenizes and splits a list of concatenated text samples into + sections of specified maximum token length. + + Args: + text (List[str]): List of concatenated text samples to be tokenized and split. + eos (str): The end-of-sentence token. + bos (str): The beginning-of-sentence token. + tokenizer (PreTrainedTokenizerFast): Tokenizer for tokenizing the text. + sequence_length (int): The maximum number of tokens in each section. + max_text_length (Union[None, int]): The maximum length of text to consider. + - If None, the entire text is tokenized and split. + - If -1, each sample is tokenized separately. + - If a positive integer, the text is split into sections of this + length before tokenization. + + Returns: + List[str]: A list of sections where each section contains a + maximum of "sequence_length" tokens. 
+ """ + + text = [bos + sample + eos for sample in text] + + if max_text_length is None: + text = "".join(text) + input_tokens = tokenizer(text, return_tensors="np")["input_ids"][0] + elif max_text_length == -1: # per sample tokenization + input_tokens = [] + for slice in text: + input_tokens.append(tokenizer(slice, return_tensors="np")["input_ids"][0]) + input_tokens = numpy.concatenate(input_tokens) + else: + text = "".join(text) + text_slices = len(text) // max_text_length + sliced_text = [ + text[i * max_text_length : (i + 1) * max_text_length] + for i in range(text_slices) + ] + sliced_text.append(text[text_slices * max_text_length :]) + input_tokens = [] + for slice in sliced_text: + input_tokens.append(tokenizer(slice, return_tensors="np")["input_ids"][0]) + input_tokens = numpy.concatenate(input_tokens) + + # Then split the tokenized text into sections of size "max_sequence_length" and + # decode each section back into text format + split_text = [] + for i in range(len(input_tokens) // sequence_length): + start = i * sequence_length + end = (i + 1) * sequence_length + split_text.append( + tokenizer.decode( + input_tokens[start:end], + clean_up_tokenization_spaces=False, + ) + ) + + # Handle any leftover tokens + if (i + 1) * sequence_length < len(input_tokens): + start = (i + 1) * sequence_length + end = len(input_tokens) + split_text.append( + tokenizer.decode( + input_tokens[start:end], + clean_up_tokenization_spaces=False, + ) + ) + + return split_text diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index 38e3ec4a4c..23b8244b71 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -322,7 +322,7 @@ def pad_to_fixed_length( ) -> numpy.ndarray: """ Pads the array to a fixed length along the given axis. - The padding is done on the right side of the array. + The padding is done on the left side of the array. :param array: array to pad :param max_len: maximum length to pad to @@ -334,7 +334,7 @@ def pad_to_fixed_length( padding = [(0, 0)] * len(array.shape) # for the specified axis, pad to the max length # (from the right side of the array) - padding[axis] = (0, max_len - array.shape[axis]) + padding[axis] = (max_len - array.shape[axis], 0) return numpy.pad(array, padding, mode="constant", constant_values=value) diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index c70c50a5ef..fb25a33883 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -104,6 +104,22 @@ def test_token_generation_non_deterministic(pipeline, prompt): assert len(set(text_outputs)) == 3 +def test_pipeline_for_ppl_eval(pipeline, prompt): + predictions = pipeline( + prompt, + output_scores=True, + return_input_tokens=True, + fixed_sequences_length=True, + include_prompt_logits=True, + max_length=1, + ) + assert hasattr(predictions, "generations") + assert hasattr(predictions, "input_tokens") + assert hasattr(predictions.generations[0], "score") + assert "input_ids" in predictions.input_tokens + assert "attention_mask" in predictions.input_tokens + + def test_streaming_mode_returns_generator(pipeline, prompt): response_generator = pipeline(prompt, streaming=True) assert inspect.isgenerator(