From 67327cfd3a27c201f5f75d1dbca2ac79424de558 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Mon, 23 Oct 2023 22:40:27 +0200 Subject: [PATCH] [LLM Tests] Config support for LLM test suite / enable test suite in GHA (#1324) * initial commit * ready for review * Update tests/deepsparse/transformers/pipelines/test_text_generation.py * revamped after design review * ready for rereview * avoid downloading files multiple times --- src/deepsparse/transformers/helpers.py | 2 +- .../transformers/pipelines/text_generation.py | 3 + .../transformers/utils/token_generator.py | 10 +- .../pipelines/integration_tests/__init__.py | 13 + .../integration_tests/configs/codegen.yaml | 8 + .../integration_tests/configs/opt.yaml | 8 + .../{ => integration_tests}/helpers.py | 66 +- .../pipelines/integration_tests/test_llms.py | 364 +++++++++ .../transformers/pipelines/test_chat.py | 130 ++- .../pipelines/test_text_generation.py | 761 +++--------------- 10 files changed, 669 insertions(+), 696 deletions(-) create mode 100644 tests/deepsparse/transformers/pipelines/integration_tests/__init__.py create mode 100644 tests/deepsparse/transformers/pipelines/integration_tests/configs/codegen.yaml create mode 100644 tests/deepsparse/transformers/pipelines/integration_tests/configs/opt.yaml rename tests/deepsparse/transformers/pipelines/{ => integration_tests}/helpers.py (57%) create mode 100644 tests/deepsparse/transformers/pipelines/integration_tests/test_llms.py diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index e527ab0b22..3f24749ccb 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -73,7 +73,7 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: elif model_path.startswith("zoo:"): zoo_model = Model(model_path) - deployment_path = zoo_model.deployment_directory_path + deployment_path = zoo_model.deployment.path return deployment_path, os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME) elif model_path.startswith("hf:"): from huggingface_hub import snapshot_download diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 694d11d664..d374047c95 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -648,6 +648,9 @@ def process_engine_outputs( created=datetime.datetime.now(), prompts=prompts, generations=generations ) + if "session_ids" in kwargs: + outputs["session_ids"] = kwargs["session_ids"] + if self._debug: debug_params = dict( kv_cache_state=kv_cache_state, diff --git a/src/deepsparse/transformers/utils/token_generator.py b/src/deepsparse/transformers/utils/token_generator.py index 817ae43552..4d7004f9c5 100644 --- a/src/deepsparse/transformers/utils/token_generator.py +++ b/src/deepsparse/transformers/utils/token_generator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List +from typing import List, Optional import numpy @@ -29,7 +29,7 @@ class TokenGenerator: def __init__( self, logits_shape: int, - tokens: List[int] = [], + tokens: Optional[List[int]] = None, deterministic: bool = True, sampling_temperature: float = 1.0, top_k: int = 0, @@ -61,7 +61,7 @@ def __init__( self.top_p = top_p self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty - self.tokens = tokens + self.tokens = [] if tokens is None else tokens self._initialize_token_frequencies() @@ -168,5 +168,5 @@ def _update_frequencies(self, token: numpy.ndarray): def _initialize_token_frequencies(self): unique_tokens, frequencies = numpy.unique(self.tokens, return_counts=True) - for token, frequnecies in zip(unique_tokens, frequencies): - self.token_frequencies[token] += frequnecies + for token, freq in zip(unique_tokens, frequencies): + self.token_frequencies[token] += freq diff --git a/tests/deepsparse/transformers/pipelines/integration_tests/__init__.py b/tests/deepsparse/transformers/pipelines/integration_tests/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/tests/deepsparse/transformers/pipelines/integration_tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/deepsparse/transformers/pipelines/integration_tests/configs/codegen.yaml b/tests/deepsparse/transformers/pipelines/integration_tests/configs/codegen.yaml new file mode 100644 index 0000000000..62aac94a6b --- /dev/null +++ b/tests/deepsparse/transformers/pipelines/integration_tests/configs/codegen.yaml @@ -0,0 +1,8 @@ +cadence: "nightly" +model_path: "zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none" +torch_model_name: "salesforce/codegen-350m-mono" +task: ["text-generation"]#, "chat"] +prompt: "\ndef Fibonacci(n):\n # Check if input is 0 then it will\n # print incorrect input" +has_bos_token: False +precision: 0.0001 +internal_kv_cache: [True, False] \ No newline at end of file diff --git a/tests/deepsparse/transformers/pipelines/integration_tests/configs/opt.yaml b/tests/deepsparse/transformers/pipelines/integration_tests/configs/opt.yaml new file mode 100644 index 0000000000..2dfed87fd6 --- /dev/null +++ b/tests/deepsparse/transformers/pipelines/integration_tests/configs/opt.yaml @@ -0,0 +1,8 @@ +cadence: "nightly" +model_path: "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/base-none" +torch_model_name: "facebook/opt-1.3b" +task: ["text-generation"] +prompt: "Didn't know what time it was, the lights were low\n I leaned back on my radio" +has_bos_token: True +precision: 0.0001 +internal_kv_cache: [True, False] \ No newline at end of file diff --git a/tests/deepsparse/transformers/pipelines/helpers.py b/tests/deepsparse/transformers/pipelines/integration_tests/helpers.py similarity index 57% rename from tests/deepsparse/transformers/pipelines/helpers.py rename to tests/deepsparse/transformers/pipelines/integration_tests/helpers.py index 0bb962a8e3..e51ac7947a 100644 --- a/tests/deepsparse/transformers/pipelines/helpers.py +++ b/tests/deepsparse/transformers/pipelines/integration_tests/helpers.py @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Tuple
+import logging
+import os
+from typing import Any, Dict, List, Tuple, Union
 
 import numpy
+import yaml
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import pytest
+
 
 class TorchGroundTruthSource:
     """
@@ -36,7 +41,6 @@ def __init__(self, num_tokens_to_generate: int, model_name: str):
 
         self.tokenizer = self._create_tokenizer(model_name)
         self.num_tokens_to_generate = num_tokens_to_generate
-        self.model_name = model_name
 
     def tokenize(self, prompt: str):
         return self.tokenizer(prompt, return_tensors="pt")
@@ -82,3 +86,61 @@ def _create_tokenizer(model_name):
         tokenizer.pad_token = tokenizer.eos_token
 
         return tokenizer
+
+
+def parse_params(configs_directory: str) -> List[Dict[str, Any]]:
+    # parses the config files provided in the directory
+    assert os.path.isdir(
+        configs_directory
+    ), f"Config directory {configs_directory} is not a directory"
+
+    config_dicts = []
+    for file in os.listdir(configs_directory):
+        if file.endswith(".yaml"):
+            config_path = os.path.join(configs_directory, file)
+            # reads the yaml file
+            with open(config_path, "r") as f:
+                config = yaml.safe_load(f)
+
+            cadence = os.environ.get("CADENCE", "commit")
+            expected_cadence = config["cadence"]
+
+            if not isinstance(expected_cadence, list):
+                expected_cadence = [expected_cadence]
+            if cadence in expected_cadence:
+                config_dicts.append(config)
+            else:
+                logging.info(
+                    f"Skipping testing model: {config['model_path']} "
+                    f"for cadence: {config['cadence']}"
+                )
+        else:
+            raise FileNotFoundError(
+                f"Could not find a yaml file in {configs_directory}"
+            )
+    return config_dicts
+
+
+def validate_internal_kv_cache(
+    internal_kv_cache, available_kv_cache_types: Union[str, List[str]]
+) -> bool:
+    if internal_kv_cache and True not in available_kv_cache_types:
+        pytest.skip(
+            "The tests for running the pipeline with "
+            "internal kv cache management are disabled."
+        )
+    if not internal_kv_cache and False not in available_kv_cache_types:
+        pytest.skip(
+            "The tests for running the pipeline with "
+            "external kv cache management are disabled."
+        )
+    return internal_kv_cache
+
+
+def validate_task(task: str, available_tasks: Union[str, List[str]]) -> bool:
+    if task not in available_tasks:
+        pytest.skip(
+            f"The tests for running the pipeline with task: {task} are disabled. "
+            f"The available tasks, as specified in the config are: {available_tasks}"
+        )
+    return task
diff --git a/tests/deepsparse/transformers/pipelines/integration_tests/test_llms.py b/tests/deepsparse/transformers/pipelines/integration_tests/test_llms.py
new file mode 100644
index 0000000000..33dca47bfa
--- /dev/null
+++ b/tests/deepsparse/transformers/pipelines/integration_tests/test_llms.py
@@ -0,0 +1,364 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This test suite consumes config files to test the text generation pipeline
+for various scenarios.
+
+A sample config file is a yaml that requires the following fields:
+    cadence: The cadence of the tests. The available options are:
+        "nightly", "weekly" and "commit". By default, only
+        the tests that have cadence "commit" will be run
+        in GHA. This parameter can be either a string or a
+        list of strings.
+    model_path: The path to the model to be tested
+        (sparsezoo stub/hf model path/local_path)
+    torch_model_name: The name of the torch model
+        (to generate ground truth info)
+    task: The task to be tested
+        (e.g. text-generation)
+    prompt: The prompt to use for testing
+    has_bos_token: Whether the model has a bos token
+    precision: The precision for the logits/kv_cache entries
+        comparison
+    internal_kv_cache: The type of the internal KV cache
+        management. A list that can contain the following
+        values: [True], [False] or [True, False] (to test both
+        external and internal KV cache management)
+"""
+import os
+from typing import List, Tuple
+
+import numpy
+
+import pytest
+from deepsparse import Pipeline
+from deepsparse.transformers.pipelines.text_generation import TextGenerationOutput
+from sparsezoo import Model
+from tests.deepsparse.transformers.pipelines.integration_tests.helpers import (
+    TorchGroundTruthSource,
+    parse_params,
+    validate_internal_kv_cache,
+    validate_task,
+)
+
+
+CONFIGS_DIRECTORY = "tests/deepsparse/transformers/pipelines/integration_tests/configs"
+
+
+@pytest.fixture()
+def max_new_tokens() -> int:
+    return 64
+
+
+@pytest.mark.parametrize("params_dict", parse_params(CONFIGS_DIRECTORY))
+@pytest.mark.parametrize(
+    "internal_kv_cache",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "task",
+    ["text-generation", "chat"],
+)
+class TestsIntegrationLLMsPipelines:
+    """
+    This test suite is meant to test the main scenarios of
+    the text generation pipeline.
+    """
+
+    def get_pipeline(self, **kwargs) -> Pipeline:
+        """
+        If no kwargs are provided, returns the cached "default"
+        pipeline that is used for most of the tests.
+ Otherwise, returns a pipeline with the given kwargs + (the default pipeline kwargs are updated with the + user-provided kwargs) + + :param kwargs: the optional kwargs to be used to + create the pipeline (if not provided, the cached + "default" pipeline is returned) + :return: the appropriate pipeline + """ + if not kwargs: + if self.default_pipeline is None: + self.default_pipeline = Pipeline.create(**self.default_pipeline_kwargs) + return self.default_pipeline + + # return a pipeline with the updated default kwargs + updated_kwargs = self.default_pipeline_kwargs.copy() + updated_kwargs.update(kwargs) + return Pipeline.create(**updated_kwargs) + + @pytest.fixture + def setup(self, params_dict, max_new_tokens, internal_kv_cache, task): + # set the params_dict as the class attributes + for key, value in params_dict.items(): + setattr(self, key, value) + # check whether the specified cache management type + # is supported for testing (skip if not supported) + self.internal_kv_cache: bool = validate_internal_kv_cache( + internal_kv_cache, self.internal_kv_cache + ) + self.task: str = validate_task(task, self.task) + # create torch ground source + torch_source = TorchGroundTruthSource( + num_tokens_to_generate=max_new_tokens + 1, + model_name=self.torch_model_name, + ) + # create torch ground truth + self.torch_ground_truth = torch_source(self.prompt) + + # specify the default pipeline kwargs + self.default_pipeline_kwargs = dict( + task=self.task, + model_path=self.model_path, + internal_kv_cache=self.internal_kv_cache, + ) + self.default_pipeline = None + self.max_new_tokens = max_new_tokens + + def test_ort_single_token_prefill(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by single-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally + + if self.internal_kv_cache: + pytest.skip( + "Cannot run ORT pipeline with the internal deepsparse cache enabled." + ) + + pipeline = self.get_pipeline( + prompt_sequence_length=1, + engine_type="onnxruntime", + ) + pipeline._debug = True + output = pipeline( + self.prompt, + max_new_tokens=self.max_new_tokens, + output_scores=True, + include_prompt_logits=True, + ) + + self._test_output( + output=output, + torch_ground_truth=self.torch_ground_truth, + ) + + def test_ort_multi_token_prefill(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally + + if self.internal_kv_cache: + pytest.skip( + "Cannot run ORT pipeline with the internal deepsparse cache enabled." + ) + pipeline = self.get_pipeline( + engine_type="onnxruntime", + ) + pipeline._debug = True + output = pipeline( + self.prompt, + max_new_tokens=self.max_new_tokens, + output_scores=True, + include_prompt_logits=True, + ) + + self._test_output( + output=output, + torch_ground_truth=self.torch_ground_truth, + ) + + def test_deepsparse_single_token_prefill(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by single-token engine + # 2. The KV Cache is never filled up + # 3. 
KV Cache managed externally or internally + + pipeline = self.get_pipeline( + prompt_sequence_length=1, + ) + pipeline._debug = True + output = pipeline( + self.prompt, + max_new_tokens=self.max_new_tokens, + output_scores=True, + include_prompt_logits=True, + ) + + self._test_output( + output=output, + torch_ground_truth=self.torch_ground_truth, + # disable kv cache validation if using internal kv cache + run_kv_cache_validation=not self.internal_kv_cache, + ) + + def test_deepsparse_multi_token_prefill(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed internally or externally + + pipeline = self.get_pipeline() + pipeline._debug = True + output = pipeline( + self.prompt, + max_new_tokens=self.max_new_tokens, + output_scores=True, + include_prompt_logits=True, + ) + + self._test_output( + output=output, + torch_ground_truth=self.torch_ground_truth, + # disable kv cache validation if using internal kv cache + run_kv_cache_validation=not self.internal_kv_cache, + ) + + def test_inference_no_kv_cache_deepsparse(self, setup): + self._test_inference_no_kv_cache(engine_type="deepsparse") + + def test_inference_no_kv_cache_ort(self, setup): + self._test_inference_no_kv_cache(engine_type="onnxruntime") + + def _test_inference_no_kv_cache(self, engine_type): + model_path_no_cache = self._get_model_path_no_cache() + pipeline = self.get_pipeline( + model_path=model_path_no_cache, engine_type=engine_type + ) + assert not pipeline.cache_support_enabled, ( + "This pipeline test inference using non-kv cache " + "model and thus should not support kv cache" + ) + + output = pipeline( + self.prompt, max_length=1, output_scores=True, include_prompt_logits=True + ) + prompt_length = self.torch_ground_truth[1].shape[1] + # prompt logits + one logit for the new generated token + logits = output.generations[0].score[-(prompt_length + 1) :, :] + # compute ground truth logits analogously + generated_logits, prompt_logits, *_ = self.torch_ground_truth + logits_gt = numpy.concatenate( + [prompt_logits[0], generated_logits[0, :1, :]], axis=0 + ) + assert numpy.allclose(logits, logits_gt, atol=self.precision) + + def _test_output( + self, + output: TextGenerationOutput, + torch_ground_truth: Tuple[numpy.ndarray, ...], + run_kv_cache_validation: bool = True, + ): + + ( + generated_logits, + prompt_logits, + prompt_kv_cache, + generated_text, + ) = torch_ground_truth + + # concatenate target prompt_logits and generated_logits + target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1) + # get the logits of the generated sequence + score = output.generations[0].score + + # we expect the logits to be exactly the same + # as the target logits; the generated sequence should + # also be the same as the target sequence + assert numpy.allclose(score, target_logits[0], atol=self.precision) + assert self.prompt + output.generations[0].text == generated_text + + if hasattr(output, "kv_cache_state") and run_kv_cache_validation: + # (if applicable) the kv cache should be the same as the + # target kv cache + expected_cache = list(output.kv_cache_state[0].values()) + total_num_processed_tokens = output.total_num_processed_tokens[0] + self._test_kv_cache_state( + expected_cache=expected_cache, + target_cache=prompt_kv_cache, + total_num_processed_tokens=total_num_processed_tokens, + ) + + def _test_kv_cache_state( + self, + 
expected_cache: List[numpy.ndarray], + target_cache: List[numpy.ndarray], + total_num_processed_tokens: int, + ): + for x, y in zip(expected_cache, target_cache): + start_index = total_num_processed_tokens + end_index = total_num_processed_tokens - y.shape[2] + # x is (in general) composed of three arrays: + # - padding cache entries (from 0 to -start_index) + # - prompt cache entries (from -start_index to -end_index) + # - generated cache entries (from -end_index to -1) + # as target_cache only pertains to prompt cache entries, we need to + # compare only the prompt cache entries in x with y + assert numpy.allclose( + x[:, :, -start_index:-end_index, :], y, atol=self.precision + ) + + def _get_model_path_no_cache(self): + if not self.model_path.startswith("zoo:"): + pytest.skip("For this test, for now only the zoo model is supported") + model = Model(self.model_path) + # fetch the necessary file names for pipeline creation + required_file_names = [ + os.path.basename(file.name) for file in model.deployment.files + ] + training_directory = model.training + onnx_model_name_no_cache = [ + os.path.basename(file.name) + for file in model.training.files + if file.name.endswith(".onnx") + ][0] + + # check if 'training' exists, + # if not, download the files + if "training" not in os.listdir(model._path): + for filename in required_file_names: + # download the files to a training directory + if filename.endswith(".data"): + # data files are typically stored in a deployment directory + # download them to training + file = model.deployment.get_file(filename) + assert ( + file is not None + ), f"Unable to find file {filename} in model {model}" + file.name = file.name.replace("deployment", "training") + file.download() + continue + + if filename.endswith(".onnx"): + # instead of `model.onnx` the onnx_model_name_no_cache + # should be downloaded + filename = filename.replace("model.onnx", onnx_model_name_no_cache) + + file = training_directory.get_file(filename) + assert ( + file is not None + ), f"Unable to find file {filename} in model {model}" + file.download() + # rename the model file to `model.onnx` + os.rename( + os.path.join(training_directory.path, onnx_model_name_no_cache), + os.path.join(training_directory.path, "model.onnx"), + ) + return training_directory._path diff --git a/tests/deepsparse/transformers/pipelines/test_chat.py b/tests/deepsparse/transformers/pipelines/test_chat.py index 2a6b5d1ebf..a7dd69d290 100644 --- a/tests/deepsparse/transformers/pipelines/test_chat.py +++ b/tests/deepsparse/transformers/pipelines/test_chat.py @@ -12,34 +12,120 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import numpy as np + import pytest from deepsparse import Pipeline -@pytest.mark.parametrize( - "pipeline_kwargs", - [ - dict( - model_path="zoo:nlg/text_generation/codegen_mono-350m/pytorch/" - "huggingface/bigpython_bigquery_thepile/base-none", - engine_type="onnxruntime", - ), - ], -) -@pytest.mark.skip(reason="too heavy for now to run in gha") -def test_chat_pipeline_session_manager(pipeline_kwargs): - chat_pipeline = Pipeline.create(task="chat", **pipeline_kwargs) - - with chat_pipeline.session(): - output_1 = chat_pipeline( - prompt="first", generation_config=dict(max_new_tokens=1) - ) - output_2 = chat_pipeline( - prompt="second", generation_config=dict(max_new_tokens=1) - ) +@pytest.fixture +def pipeline(): + return Pipeline.create( + task="chat", + model_path="hf:mgoin/TinyStories-1M-deepsparse", + engine_type="onnxruntime", + ) + + +@pytest.fixture +def prompt(): + return "Never gonna give you up, never gonna let you down" + + +def test_chat_pipeline_session_manager(pipeline, prompt): + with pipeline.session(): + output_1 = pipeline(prompt) + output_2 = pipeline(prompt) # assert inferences in the same context share a session id assert output_1.session_ids == output_2.session_ids # test that follow-up inference has a different session id - output_3 = chat_pipeline(prompt="third", generation_config=dict(max_new_tokens=1)) + output_3 = pipeline(prompt) assert output_3.session_ids != output_1.session_ids + + +def test_run_with_same_session_ids(pipeline): + # Test the scenario where the same session ids are used for multiple + # inference runs. There are two conditions that must be fulfilled: + # 1. The information regarding the prompt does not leak between sessions + # 2. Running two prompts one after another is identical to running + # a composition of those prompts i.e. + # generated_text = pipeline(prompt_1) + # generated_text_2 = pipeline(prompt_2) + # generated_text_2 == pipeline(prompt_1 + generated_text + prompt_2) + + prompt_1 = "This prompt is used for testing purposes. 
To this to make sure that" + prompt_2 = "still this prompt should not" + + # make sure information does not leak between sessions + _test_composition_same_session_ids( + prompt_1=prompt_1, + prompt_2=prompt_2, + num_generated_tokens=32, + pipeline=pipeline, + session_id_1="test_1", + session_id_2="test_2", + ) + + _test_composition_same_session_ids( + prompt_1=prompt_1, + prompt_2=prompt_2, + num_generated_tokens=32, + pipeline=pipeline, + session_id_1="test_3", + session_id_2="test_4", + ) + + +def _test_composition_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + pipeline, + session_id_1, + session_id_2, +): + + tokenizer = pipeline.tokenizer + + # make sure that running two prompts one after another + # is identical to running a composition of those prompts + out_1_ = pipeline( + sequences=prompt_1, + force_max_tokens=True, + session_ids=session_id_1, + max_new_tokens=num_generated_tokens, + include_prompt_logits=True, + ) + prompt_1_ = out_1_.generations[0].text + out_1 = pipeline( + sequences=prompt_2, + force_max_tokens=True, + session_ids=session_id_1, + max_new_tokens=num_generated_tokens, + include_prompt_logits=True, + ) + cache_state_1 = pipeline.storage_kv_cache.get(session_id_1).cached_inputs[ + "past_key_values.0.key" + ] + + prompt_composition = tokenizer.decode( + tokenizer(prompt_1).input_ids + + tokenizer(prompt_1_).input_ids + + tokenizer(prompt_2).input_ids, + skip_special_tokens=True, + ) + out_2 = pipeline( + sequences=prompt_composition, + max_new_tokens=num_generated_tokens, + session_ids=session_id_2, + include_prompt_logits=True, + ) + cache_state_2 = pipeline.storage_kv_cache.get(session_id_2).cached_inputs[ + "past_key_values.0.key" + ] + if cache_state_1.shape[0]: + # if cache state is not empty, i.e. we are managing kv cache + # externally, make sure that the cache state is the same + np.allclose(cache_state_1, cache_state_2, atol=0.001) + assert out_1.generations[0].text == out_2.generations[0].text diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 5298c2f1dd..c70c50a5ef 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -13,675 +13,104 @@ # limitations under the License. 
import inspect -from typing import List, Optional, Tuple import numpy -from transformers import GenerationConfig import pytest from deepsparse import Pipeline from deepsparse.transformers.utils.helpers import prepends_bos_token -from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource - - -_PRECISION = 1e-3 - -NATURAL_LANGUAGE_PROMPT = """ -Didn't know what time it was, the lights were low -I leaned back on my radio -Some cat was layin' down some rock 'n' roll -"Lotta soul," he said -Then the loud sound did seem to fade -Came back like a slow voice on a wave of phase -That weren't no DJ, that was hazy cosmic jive -""" - -CODE_LANGUAGE_PROMPT = """ -def Fibonacci(n): - # Check if input is 0 then it will - # print incorrect input - if n < 0: - print("Incorrect input") - # Check if n is 0 - # then it will return 0 - elif n == 0: - return 0 -""" - - -@pytest.mark.parametrize( - "internal_kv_cache", - [ - True, - False, - ], -) -@pytest.mark.parametrize( - "pipeline_type", - ["text_generation", "chat"], -) -@pytest.mark.parametrize( - "model_stub, " - "model_name, " - "uses_bos_token, " - "prompt, " - "logits_max_diff_kv_cache_has_been_filled", - [ - ( - "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" - "huggingface/bigpython_bigquery_thepile/base-none", - "salesforce/codegen-350m-mono", - False, - CODE_LANGUAGE_PROMPT, - 13, - ), - ( - "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" - "opt_pretrain/base-none", - "facebook/opt-1.3b", - True, - NATURAL_LANGUAGE_PROMPT, - 3.9, - ), - ], - scope="class", -) -@pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") -class TestTextGenerationPipeline: - """ - This test suite is meant to test the main scenarios of - the text generation pipeline. 
- """ - - def get_pipeline(self, **kwargs): - if not kwargs: - # return the default pipeline - if self.default_pipeline: - return self.default_pipeline - else: - self.default_pipeline = Pipeline.create( - task=self.pipeline_type, - model_path=self.model_stub, - internal_kv_cache=self.internal_kv_cache, - prompt_sequence_length=self.prompt_sequence_length, - sequence_length=self.sequence_length, - ) - return self.default_pipeline - # return a pipeline with the given kwargs - return Pipeline.create(**kwargs) - - @pytest.fixture - def setup( - self, - model_stub, - model_name, - uses_bos_token, - prompt, - logits_max_diff_kv_cache_has_been_filled, - internal_kv_cache, - pipeline_type, - ): - self.num_tokens_generate = 216 - self.model_stub = model_stub - self.prompt = prompt - self.pipeline_type = pipeline_type - # create torch ground source - torch_source = TorchGroundTruthSource( - num_tokens_to_generate=self.num_tokens_generate, model_name=model_name - ) - torch_ground_truth = torch_source(self.prompt) - - # prompt length is expressed in number of prompt tokens - prompt_length = torch_ground_truth[1].shape[1] - - # sequence_length that assures that the KV cache will not be filled up - self.sequence_length = 2 * prompt_length + self.num_tokens_generate - # sequence_length that assures that the KV cache will be filled up - self.sequence_length_short = self.num_tokens_generate - - # prompt_sequence_length used for the multitoken prefill scenario - self.prompt_sequence_length = prompt_length // 2 - - # the maximum threshold for the difference between the logits - # when running a scenario where KV Cache buffer has been filled - self.logits_max_diff_kv_cache_has_been_filled = ( - logits_max_diff_kv_cache_has_been_filled - ) - self.internal_kv_cache = internal_kv_cache - - self.default_pipeline = None - - assert self.prompt_sequence_length < prompt_length, ( - "The prompt processing sequence length " - "must be smaller than the prompt length" - ) - - yield model_name, uses_bos_token, torch_ground_truth - - def test_freeze_first_position(self, setup): - # Test whether we should be "freezing" the first token after - # the kv cache is full - _, uses_bos_token, _ = setup - pipeline = self.get_pipeline() - assert prepends_bos_token(pipeline.tokenizer) == uses_bos_token - - def test_ort_single_token_prefill(self, setup): - # Test the pipeline that uses ORT engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by single-token engine - # 2. The KV Cache is never filled up - # 3. KV Cache managed externally - - if self.internal_kv_cache: - pytest.skip( - "Cannot run ORT pipeline with the internal deepsparse cache enabled." - ) - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=1, - engine_type="onnxruntime", - ) - pipeline._debug = True - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - assert output.total_num_processed_tokens[0] < self.sequence_length - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - ) - - def test_ort_multi_token_prefill(self, setup): - # Test the pipeline that uses ORT engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by multi-token engine - # 2. 
The KV Cache is never filled up - # 3. KV Cache managed externally - - if self.internal_kv_cache: - pytest.skip( - "Cannot run ORT pipeline with the internal deepsparse cache enabled." - ) - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=self.prompt_sequence_length, - engine_type="onnxruntime", - ) - pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] < self.sequence_length - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - ) - - def test_ort_generation_after_kv_cache_has_been_filled(self, setup): - # Test the pipeline that uses ORT engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by multi-token engine - # 2. The KV Cache is filled up (old entries are removed) - # 3. KV Cache managed externally - - if self.internal_kv_cache: - pytest.skip( - "Cannot run ORT pipeline with the internal deepsparse cache enabled." - ) - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length_short, - prompt_sequence_length=self.prompt_sequence_length, - engine_type="onnxruntime", - ) - pipeline._debug = True - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( - "for this scenario, the kv cache should be full: " - "the total number of processed tokens should be " - "greater than the sequence length" - ) - - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 - ) - - def test_deepsparse_single_token_prefill(self, setup): - # Test the pipeline that uses deepsparse engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by single-token engine - # 2. The KV Cache is never filled up - # 3. KV Cache managed externally or internally - - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=1, - internal_kv_cache=self.internal_kv_cache, - ) - pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] < self.sequence_length - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, - ) - - def test_deepsparse_multi_token_prefill(self, setup): - # Test the pipeline that uses deepsparse engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by multi-token engine - # 2. The KV Cache is never filled up - # 3. 
KV Cache managed externally or internally - - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=self.prompt_sequence_length, - internal_kv_cache=self.internal_kv_cache, - ) - pipeline._debug = True - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] < self.sequence_length - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, - ) - - def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): - # Test the pipeline that uses deepsparse engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by multi-token engine - # 2. The KV Cache is filled up (old entries are removed) - # 3. KV Cache managed externally or internally - - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length_short, - prompt_sequence_length=self.prompt_sequence_length, - internal_kv_cache=self.internal_kv_cache, - ) - pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( - "for this scenario, the kv cache should be full: " - "the total number of processed tokens should be " - "greater than the sequence length" - ) - - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, - max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 - ) - - def test_run_same_prompt_multiple_times(self, setup): - # Test the scenario, where the same prompt is run multiple times - # Every run should produce the same output - pipeline = self.get_pipeline() - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - - output_1 = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - output_2 = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output_1.generations[0].text == output_2.generations[0].text - assert numpy.allclose( - output_1.generations[0].score, - output_2.generations[0].score, - atol=_PRECISION, - ) - - def test_run_multiple_prompts_in_parallel(self, setup): - # Test the scenario, where multiple prompts are run in parallel - # Same two prompts should produce the same output - pipeline = self.get_pipeline() - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=[self.prompt, self.prompt], - generation_config=config, - include_prompt_logits=True, - ) - - logits_0 = output.generations[0].score - sequence_0 = output.generations[0].text - - logits_1 = output.generations[1].score - sequence_1 = output.generations[1].text - - assert numpy.allclose(logits_0, logits_1, atol=_PRECISION) - assert sequence_0 == sequence_1 - - def test_num_generated_predictions(self, setup): - # 
Test the scenario, where multiple predictions are generated - # from the same prompt - pipeline = self.get_pipeline() - - config = GenerationConfig( - num_return_sequences=2, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - ) - - output_sequences = pipeline(sequences=[self.prompt], generation_config=config) - assert len(output_sequences.generations) == 1 - assert len(output_sequences.generations[0]) == 2 - - output_sequences = pipeline( - sequences=[self.prompt, self.prompt], generation_config=config - ) - assert len(output_sequences.generations) == 2 - - for generation in output_sequences.generations: - assert len(generation) == 2 - - def test_token_generation_deterministic(self, setup): - pipeline_kwargs = { - "task": "text_generation", - "model_path": self.model_stub, - } - config = GenerationConfig( - output_scores=True, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - num_return_sequences=3, - do_sample=False, - ) - pipeline = self.get_pipeline(**pipeline_kwargs) - inference = pipeline(sequences=["hello?"], generation_config=config) - generations = inference.generations - text_outputs = [x.text for x in generations[0]] - assert len(set(text_outputs)) == 1 - - def test_token_generation_non_deterministic(self, setup): - pipeline_kwargs = { - "task": "text_generation", - "model_path": self.model_stub, - } - pipeline = self.get_pipeline(**pipeline_kwargs) - config = GenerationConfig( - output_scores=True, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - num_return_sequences=3, - do_sample=True, - ) - inference = pipeline(sequences=["hello?"], generation_config=config) - generations = inference.generations - # Output should be the same from one another - text_outputs = [x.text for x in generations[0]] - assert len(set(text_outputs)) == 3 - - def test_run_with_same_session_ids(self, setup): - # Test the scenario where the same session ids are used for multiple - # inference runs. There are two conditions that must be fulfilled: - # 1. The information regarding the prompt does not leak between sessions - # 2. Running two prompts one after another is identical to running - # a composition of those prompts i.e. - # generated_text = pipeline(prompt_1) - # generated_text_2 = pipeline(prompt_2) - # generated_text_2 == pipeline(prompt_1 + generated_text + prompt_2) - - if self.pipeline_type not in ["chatbot", "chat"]: - pytest.skip("This test is only applicable to chatbot pipeline") - - prompt_1 = "This prompt is used for testing purposes. 
To this to make sure that" - prompt_2 = "still this prompt should not" - num_generated_tokens = 32 - - self._test_run_with_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill=False, - ) - self._test_run_with_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill=True, - ) - - def _test_run_with_same_session_ids( - self, - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill, - ): - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - prompt_sequence_length=self.prompt_sequence_length - if multi_token_prefill - else 1, - force_max_tokens=True, - internal_kv_cache=self.internal_kv_cache, - ) - - # make sure information does not leak between sessions - - self._test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1="test_1", - session_id_2="test_2", - ) - - self._test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1="test_3", - session_id_2="test_4", - ) - - @staticmethod - def _test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1, - session_id_2, - ): - - tokenizer = pipeline.tokenizer - config = GenerationConfig( - output_scores=True, max_length=num_generated_tokens, top_k=0, top_p=0.0 - ) - - # make sure that running two prompts one after another - # is identical to running a composition of those prompts - out_1_ = pipeline( - sequences=prompt_1, - session_ids=session_id_1, - generation_config=config, - include_prompt_logits=True, - ) - prompt_1_ = out_1_.generations[0].text - out_1 = pipeline( - sequences=prompt_2, - session_ids=session_id_1, - generation_config=config, - include_prompt_logits=True, - ) - cache_state_1 = pipeline.storage_kv_cache.get(session_id_1).cached_inputs[ - "past_key_values.0.key" - ] - - prompt_composition = tokenizer.decode( - tokenizer(prompt_1).input_ids - + tokenizer(prompt_1_).input_ids - + tokenizer(prompt_2).input_ids, - skip_special_tokens=True, - ) - out_2 = pipeline( - sequences=prompt_composition, - session_ids=session_id_2, - generation_config=config, - include_prompt_logits=True, - ) - cache_state_2 = pipeline.storage_kv_cache.get(session_id_2).cached_inputs[ - "past_key_values.0.key" - ] - if cache_state_1.shape[0]: - # if cache state is not empty, i.e. 
we are managing kv cache - # externally, make sure that the cache state is the same - numpy.allclose(cache_state_1, cache_state_2, atol=_PRECISION) - assert out_1.generations[0].text == out_2.generations[0].text - - def _test_output( - self, - output: "TextGenerationOutput", # noqa F821 - torch_ground_truth: Tuple[numpy.ndarray, ...], - max_logits_difference_threshold: Optional[float] = None, - run_cache_validation: bool = True, - ): - - ( - generated_logits, - prompt_logits, - prompt_kv_cache, - generated_text, - ) = torch_ground_truth - - # concatenate target prompt_logits and generated_logits and check - target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1) - score = output.generations[0].score - - if max_logits_difference_threshold: - # if comparing the output from the model where - # the kv cache has been filled, we expect the - # maximum absolute difference between the logits - # to be less than the threshold - # (the threshold is established by running the - # ONNX model in ONNXRuntime) - assert abs(score - target_logits[0]).max() < max_logits_difference_threshold - else: - # otherwise, we expect the logits to be exactly the same - # as the target logits; the generated sequence should - # also be the same as the target sequence, and finally - # (if applicable) the kv cache should be the same as the - # target kv cache - - assert numpy.allclose(score, target_logits[0], atol=_PRECISION) - assert self.prompt + output.generations[0].text == generated_text - - if run_cache_validation: - # extract numpy arrays from cached_inputs - kv_cache_array = list(output.kv_cache_state[0].values()) - total_num_processed_tokens = output.total_num_processed_tokens[0] - self._test_kv_cache_state( - expected_cache=kv_cache_array, - target_cache=torch_ground_truth[2], - total_num_processed_tokens=total_num_processed_tokens, - ) - - @staticmethod - def _test_kv_cache_state( - expected_cache: List[numpy.ndarray], - target_cache: List[numpy.ndarray], - total_num_processed_tokens: int, - ): - for x, y in zip(expected_cache, target_cache): - start_index = total_num_processed_tokens - end_index = total_num_processed_tokens - y.shape[2] - # x is (in general) composed of three arrays: - # - padding cache entries (from 0 to -start_index) - # - prompt cache entries (from -start_index to -end_index) - # - generated cache entries (from -end_index to -1) - # as target_cache only pertains to prompt cache entries, we need to - # compare only the prompt cache entries in x with y - assert numpy.allclose( - x[:, :, -start_index:-end_index, :], y, atol=_PRECISION - ) - - def test_streaming_mode_returns_generator(self, setup): - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=1, - ) - inputs = dict(prompt=self.prompt, streaming=True) - response_generator = pipeline(**inputs) - - assert inspect.isgenerator( - response_generator - ), "Pipeline should return a generator in streaming mode" - - assert all( - isinstance(response, pipeline.output_schema) - for response in response_generator - ), "Pipeline should return a generator of output_schema \ - objects in streaming mode" + + +@pytest.fixture +def pipeline(): + return Pipeline.create( + task="text_generation", + model_path="hf:mgoin/TinyStories-1M-deepsparse", + engine_type="onnxruntime", + ) + + +@pytest.fixture +def prompt(): + return "Never gonna give you up, never gonna let you down" + + +def test_freeze_first_position(pipeline): + # Test whether 
we should be "freezing" the first token after + # the kv cache is full + assert not prepends_bos_token(pipeline.tokenizer) + + +def test_run_same_prompt_multiple_times(pipeline, prompt): + # Test the scenario, where the same prompt is run multiple times + # Every run should produce the same output + output_1 = pipeline(prompt, output_scores=True) + output_2 = pipeline(prompt, output_scores=True) + + assert output_1.generations[0].text == output_2.generations[0].text + assert numpy.allclose( + output_1.generations[0].score, + output_2.generations[0].score, + atol=1e-3, + ) + + +def test_run_multiple_prompts_in_parallel(pipeline, prompt): + # Test the scenario, where multiple prompts are run in parallel + # Same two prompts should produce the same output + + output = pipeline([prompt, prompt], output_scores=True) + + logits_0 = output.generations[0].score + sequence_0 = output.generations[0].text + + logits_1 = output.generations[1].score + sequence_1 = output.generations[1].text + + assert numpy.allclose(logits_0, logits_1, atol=1e-3) + assert sequence_0 == sequence_1 + + +def test_num_generated_predictions(pipeline, prompt): + # Test the scenario, where multiple predictions are generated + # from the same prompt + + output_sequences = pipeline(prompt, num_return_sequences=2) + + assert len(output_sequences.generations) == 1 + assert len(output_sequences.generations[0]) == 2 + + output_sequences = pipeline([prompt, prompt], num_return_sequences=2) + assert len(output_sequences.generations) == 2 + + for generation in output_sequences.generations: + assert len(generation) == 2 + + +def test_token_generation_deterministic(pipeline, prompt): + inference = pipeline(prompt, num_return_sequences=3, do_sample=False) + generations = inference.generations + # Output should be the same from one another + text_outputs = [x.text for x in generations[0]] + assert len(set(text_outputs)) == 1 + + +def test_token_generation_non_deterministic(pipeline, prompt): + + inference = pipeline(prompt, num_return_sequences=3, do_sample=True) + generations = inference.generations + # Output should be different from one another + text_outputs = [x.text for x in generations[0]] + assert len(set(text_outputs)) == 3 + + +def test_streaming_mode_returns_generator(pipeline, prompt): + response_generator = pipeline(prompt, streaming=True) + assert inspect.isgenerator( + response_generator + ), "Pipeline should return a generator in streaming mode" + + assert all( + isinstance(response, pipeline.output_schema) for response in response_generator + ), "Pipeline should return a generator of output_schema \ + objects in streaming mode"