diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 88c2c155b..064d7e6f0 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -201,6 +201,7 @@ def _compile( specializations: Optional[List[Dict[str, int]]] = None, custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, + num_speculative_tokens: Optional[int] = None, **compiler_options, ) -> str: """ @@ -212,6 +213,7 @@ def _compile( :specializations (list): List of specializations to compile for :custom_io (dict): Custom IO to specify the input and outputs in different formats than default :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. + :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. :compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below: - aic_num_cores=16 -> -aic-num-cores=16 - convert_to_fp16=True -> -convert-to-fp16 @@ -244,6 +246,9 @@ def _compile( if mdp_ts_num_devices > 1: compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices})) + if num_speculative_tokens: + compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens})) + # Check if already compiled compile_hash = compile_hash.hexdigest()[:16] qpc_path = qpc_path.with_name(qpc_path.name + "-" + compile_hash) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index cc9880a2e..4ddd57ada 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -274,6 +274,7 @@ def cloud_ai_100_exec_kv( write_io_dir: Optional[str] = None, automation=False, prompt_to_lora_id_mapping: Optional[List[int]] = None, + is_tlm: bool = False, ): """ This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. @@ -319,6 +320,7 @@ def cloud_ai_100_exec_kv( enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, full_batch_size=full_batch_size, + is_tlm=is_tlm, ) if full_batch_size is None: exec_info = [ @@ -355,9 +357,11 @@ def __init__( device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, + is_tlm: Optional[int] = None, ) -> None: self._ctx_len = ctx_len self._write_io_dir = write_io_dir + self.is_tlm = is_tlm # Load QPC self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) @@ -365,6 +369,7 @@ def __init__( # Fetch the variables from the QPC self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len() + self._decode_seq_len = self._fetch_decode_seq_len() self.full_batch_size = ( full_batch_size if full_batch_size else self._fetch_full_batch_size() ) # Check and fetch full batch size if CB is enabled @@ -441,6 +446,22 @@ def _fetch_batch_size_prefill_seq_len( batch_size, prefill_seq_len = self._session.bindings[self._session.binding_index_map["input_ids"]].dims return batch_size, prefill_seq_len + def _fetch_decode_seq_len( + self, + ): + """ + Fetches the decode sequence length from the session's bindings or allowed shapes. + + Returns: + decode_seq_len: The decode sequence length fetched from the session's bindings or allowed shapes. 
+ """ + decode_seq_len = None + if self._session.allowed_shapes: + decode_seq_len = min( + [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] + ) + return decode_seq_len + def _fetch_vocab_size( self, ): @@ -485,9 +506,19 @@ def prepare_decode_inputs(self): Returns: dict: The decode inputs. """ + batch_size = self.full_batch_size if self.full_batch_size is not None else self.batch_size decode_inputs = {} - decode_inputs["input_ids"] = self.decode_input_ids - decode_inputs["position_ids"] = self.decode_pos_ids + if self.is_tlm: + position_ids = np.full((batch_size, self._decode_seq_len), -1, dtype=np.int64) + position_ids[:, -1] = self.decode_pos_ids.flatten() + input_ids = np.zeros((batch_size, self._decode_seq_len), dtype=np.int64) + input_ids[:, -1] = self.decode_input_ids.flatten() + decode_inputs["input_ids"] = input_ids + decode_inputs["position_ids"] = position_ids + decode_inputs["num_logits_to_keep"] = np.zeros((self._decode_seq_len, 1)) + else: + decode_inputs["input_ids"] = self.decode_input_ids + decode_inputs["position_ids"] = self.decode_pos_ids if self.batch_index is not None: decode_inputs["batch_index"] = self.batch_index @@ -628,6 +659,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id + if self.is_tlm: + inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self._prompt_to_lora_id_mapping_prefill: if self.full_batch_size: @@ -668,7 +701,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): """ # Set logits placeholder for decode - logits_out_placeholder = np.zeros((self.full_batch_size, 1, self._vocab_size), dtype=np.float32) + logits_out_placeholder = np.zeros( + (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 + ) self._session.set_buffers({"logits": logits_out_placeholder}) # Generate flag for tracking progress for each batch ID current_decode_ongoing = np.full((self.full_batch_size, 1), True) @@ -694,7 +729,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): for decode_batch_id in range(self.full_batch_size): if ( - next_token_id[decode_batch_id] == self.tokenizer.eos_token_id + next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id] ): if prompt_queue: @@ -724,10 +759,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): current_decode_ongoing[decode_batch_id] = False else: # If the generated sequence is valid and within generation len prepare for next decode - decode_inputs["input_ids"][decode_batch_id] = next_token_id[decode_batch_id] - decode_inputs["position_ids"][decode_batch_id] += 1 + decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1] + decode_inputs["position_ids"][decode_batch_id, -1] += 1 self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = ( - next_token_id[decode_batch_id] + next_token_id[decode_batch_id, -1] ) generated_id_current_index[decode_batch_id] += 1 @@ -747,6 +782,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform Returns: num_token (int): The number of tokens processed in the decoding process. 
""" + if self.is_tlm: + logits_out_placeholder = np.zeros( + (self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 + ) + self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 for num_token in range(1, generation_len): @@ -760,8 +800,8 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform # Prepare inputs for next iteration decode_inputs["input_ids"] = outputs["logits"].argmax(2) - decode_inputs["position_ids"] += 1 - self.generated_ids[:, num_token] = decode_inputs["input_ids"].squeeze(1) + decode_inputs["position_ids"][:, -1] += 1 + self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id if finished_sequences.all(): @@ -811,9 +851,10 @@ def __init__( device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, + is_tlm: bool = False, ) -> None: self._qaic_model = QEffTextGenerationBase( - tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir + tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 76c227862..377caa3e7 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -24,7 +24,7 @@ from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform -from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform from QEfficient.utils import constants from QEfficient.utils._utils import get_padding_shape_from_config from QEfficient.utils.cache import to_hashable diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9e887a673..d0b58ffb7 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -18,7 +18,7 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform -from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform from QEfficient.utils import constants, get_padding_shape_from_config @@ -51,7 +51,7 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs): if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') @@ -61,7 +61,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): 
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - return cls(model) + return cls(model, is_tlm=is_tlm) @property def model_name(self) -> str: @@ -82,6 +82,7 @@ def model_hash(self) -> str: mhash = hashlib.sha256() mhash.update(to_hashable(self.model.config.to_diff_dict())) mhash.update(to_hashable(self._transform_names())) + mhash.update(to_hashable({"is_tlm": self.is_tlm})) mhash = mhash.hexdigest()[:16] return mhash @@ -94,6 +95,7 @@ class QEFFAutoModelForCausalLM(QEFFTransformersBase): ``Mandatory`` Args: :model (nn.Module): PyTorch model :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. + :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. .. code-block:: python @@ -110,7 +112,14 @@ class QEFFAutoModelForCausalLM(QEFFTransformersBase): _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs): + def __init__( + self, + model: nn.Module, + continuous_batching: bool = False, + is_tlm: bool = False, + **kwargs, + ): + # TODO: remove from version 1.20 if kwargs.pop("full_batch_size", None): continuous_batching = True warnings.warn( @@ -124,15 +133,23 @@ def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching + if is_tlm: + # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch + self.model, transformed = SpDTransform.apply(self.model) + self.is_tlm = is_tlm + @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: bool = False, *args, **kwargs): + def from_pretrained( + cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs + ): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. Once the model is initialized, you can use other methods such as export, compile, and generate on the same object. Args: :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory. - :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. + :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. + :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM. .. 
code-block:: python @@ -155,7 +172,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 ) - self = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) self.continuous_batching = continuous_batching return self @@ -165,6 +182,7 @@ def model_hash(self) -> str: mhash = hashlib.sha256() mhash.update(to_hashable(self.model.config.to_diff_dict())) mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + mhash.update(to_hashable({"is_tlm": self.is_tlm})) mhash.update(to_hashable(self._transform_names())) mhash = mhash.hexdigest()[:16] return mhash @@ -175,20 +193,20 @@ def export(self, export_dir: Optional[str] = None) -> str: We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this." ``Optional`` Args: - does not any arguments. + :export_dir (str, optional): The directory path to store ONNX-graph. Returns: :str: Path of the generated ``ONNX`` graph. """ - bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), - "position_ids": torch.arange(seq_len, dtype=torch.int64).view(bs, seq_len), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), "past_key_values": [[] for _ in range(self.num_layers)], } dynamic_axes = { @@ -216,6 +234,11 @@ def export(self, export_dir: Optional[str] = None) -> str: example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) dynamic_axes["batch_index"] = {0: "batch_size"} + if self.is_tlm: + nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep + example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1) + dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} + return self._export( example_inputs, output_names, @@ -236,6 +259,7 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, + num_speculative_tokens: Optional[int] = None, **compiler_options, ) -> str: """ @@ -254,27 +278,53 @@ def compile( :full_batch_size (int, optional): Continuous batching batch size. :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to True``. :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``. + :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``. :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. Returns: :str: Path of the compiled ``qpc`` package. 
""" - # Specializations - if self.continuous_batching: - if full_batch_size is None: - raise TypeError("missing required argument: 'full_batch_size'") - - specializations = [ - {"full_batch_size": full_batch_size, "batch_size": 1, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, - {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": 1, "ctx_len": ctx_len}, - ] - else: - specializations = [ - {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, - ] - if prefill_seq_len != 1: - specializations.append({"batch_size": batch_size, "seq_len": 1, "ctx_len": ctx_len}) + if self.is_tlm: + # assert num_speculative_tokens cfg is acceptable if defined + if num_speculative_tokens is None: + raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.") + if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2: + ValueError( + f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}" + ) + num_logits_to_keep = num_speculative_tokens + 1 + if prefill_seq_len < num_logits_to_keep: + raise ValueError( + f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" + ) + + if self.continuous_batching and full_batch_size is None: + raise TypeError("missing required argument: 'full_batch_size'") + + # Define prefill specialization + prefill_specialization = { + # Prefill is always run with single BS for continuous batching. + "batch_size": 1 if self.continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + } + prefill_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None + prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else None + specializations = [ + prefill_specialization, + ] + + # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization + if prefill_seq_len != 1 or self.continuous_batching: + decode_specialization = { + "batch_size": full_batch_size if self.continuous_batching else batch_size, + "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1, + "ctx_len": ctx_len, + } + decode_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None + decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else None + specializations.append(decode_specialization) # Custom IO custom_io = {} @@ -294,6 +344,7 @@ def compile( mxfp6_matmul=mxfp6_matmul, custom_io=custom_io, mdp_ts_num_devices=num_devices, + num_speculative_tokens=num_speculative_tokens, aic_num_cores=num_cores, **compiler_options, ) @@ -329,6 +380,7 @@ def generate( prompt=prompts, device_id=device_id, generation_len=generation_len, + is_tlm=self.is_tlm, ) diff --git a/QEfficient/transformers/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py similarity index 87% rename from QEfficient/transformers/pytorch_transforms.py rename to QEfficient/transformers/models/pytorch_transforms.py index 9c58bf030..6b8d00689 100644 --- a/QEfficient/transformers/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from types import MethodType from typing import Tuple import transformers @@ -199,6 +200,7 @@ QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) +from 
QEfficient.transformers.spd.causal_lm_forward import tlm_forward class CustomOpsTransform(ModuleMappingTransform): @@ -307,3 +309,38 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: # FIXME: see if we can merge into _module_mapping dict transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update return model, transformed + + +class SpDTransform: + """ + Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. + This is only needed if user is exporting Target Language Model (TLM) for Speculative Decoding to validate output logits + against the speculated tokens from a smaller model. + Other than the computed logits, there should be no difference between the SpD Transformed model and its corresponding cunterpart. + + ``Mandatory`` Args: + :model (nn.Module): PyTorch model. + + Returns: + :model (nn.Module): PyTorch model. + :transformed (bool): whether transformation was applied successfully. + """ + + # supported architectures + _module_mapping = { + # Llama + QEffLlamaForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + transformed = False + if (model_class := model.__class__) in cls._module_mapping: + model.forward = MethodType(tlm_forward, model) + transformed = True + else: + raise NotImplementedError( + f"model class {model_class} does not yet support returning multiple logits to keep." + ) + + return model, transformed diff --git a/QEfficient/transformers/spd/__init__.py b/QEfficient/transformers/spd/__init__.py new file mode 100644 index 000000000..da26921c5 --- /dev/null +++ b/QEfficient/transformers/spd/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + diff --git a/QEfficient/transformers/spd/causal_lm_forward.py b/QEfficient/transformers/spd/causal_lm_forward.py new file mode 100644 index 000000000..46601c0c9 --- /dev/null +++ b/QEfficient/transformers/spd/causal_lm_forward.py @@ -0,0 +1,130 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + + +def filter_hidden_states( + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + num_logits_to_keep: Optional[int] = None, +) -> torch.Tensor: + """ + Filter hidden states based on whether this is a TLM SpD model + + ``Mandatory`` Args: + :hidden_states (torch.Tensor): Hidden states tensor. + :position_ids (torch.Tensor): Position ids tensor. + ``Optional`` Args: + :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model + + Returns: + :torch.Tensor: Filtered hidden states. 
+ """ + batch_size = position_ids.size(0) + batch_indices = torch.arange(batch_size) + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + + if num_logits_to_keep is None: + # return the last logit + return hidden_states[batch_indices.view(-1, 1), logit_index] + + # gather approach + num_logits_to_keep = num_logits_to_keep.shape[0] + lower_idx = torch.where(logit_index < num_logits_to_keep, 0, logit_index + 1 - num_logits_to_keep).view( + -1, 1 + ) # shape: [bsz, 1] + spec_idx = torch.arange(num_logits_to_keep).view(1, -1) # shape: [1, k] + indices = torch.add(lower_idx, spec_idx).unsqueeze(2) # shape: [bsz, k, 1] + indices = indices.repeat(1, 1, hidden_states.size(-1)) # shape: [bsz, ,k, d_model] + hidden_states = torch.gather(hidden_states, dim=1, index=indices) # shape: [bsz, k, d_model] + return hidden_states + + +def tlm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[torch.LongTensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = filter_hidden_states(outputs[0], position_ids, num_logits_to_keep) + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 510e7ab8c..29384d008 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -185,12 +185,23 @@ def load_hf_tokenizer( def get_qpc_dir_path( - model_card_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size + model_card_name, + num_cores, + mos, + batch_size, + prompt_len, + ctx_len, + mxfp6, + mxint8, + device_group, + full_batch_size, + num_speculative_tokens: Optional[int] = None, ): # Create a unique directory name for the QPC model based on all parameters qpc_base_dir_name = ( f"qpc_{num_cores}cores_{batch_size}bs_{prompt_len}pl_{ctx_len}cl_{mos}mos" + f"{f'_{full_batch_size}fbs_' if full_batch_size is not None else '_'}" + + f"{f'_{num_speculative_tokens}nst_' if num_speculative_tokens is not None else ''}" + f"{len(device_group) if device_group is not None else 1}" + "devices" + ("_mxfp6_mxint8" if (mxfp6 and mxint8) else "_mxfp6" if mxfp6 else "_fp16_mxint8" if mxint8 else "_fp16") diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index c660b1897..4a3ba3ff3 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -46,6 +46,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 +ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_OPSET = 13 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] @@ -60,3 +61,4 @@ class Constants: GB = 2**30 MAX_QPC_LIMIT = 30 MAX_RETRIES = 5 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download + NUM_SPECULATIVE_TOKENS = 2 diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index ba1c1fa48..267b2bb9e 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -167,7 +167,7 @@ def run_ort_session(self, inputs, session) -> dict: ort_outputs = dict(zip(output_names, outputs_data)) return ort_outputs - def run_kv_model_on_ort(self, model_path): + def 
run_kv_model_on_ort(self, model_path, is_tlm=False): """ Function responsible for running ``ONNX`` model on onnxruntime and return the output tokens @@ -197,12 +197,17 @@ def run_kv_model_on_ort(self, model_path): generated_ids = [] inputs = self.input_handler.prepare_ort_inputs() + if is_tlm: + nltk = np.zeros((1, 1), dtype=np.int64) + inputs["num_logits_to_keep"] = nltk ort_outputs = self.run_ort_session(inputs, session) ort_outputs = self.input_handler.update_ort_outputs(ort_outputs) for _ in range(1, self.gen_len): generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1)) inputs = self.input_handler.update_ort_inputs(inputs, ort_outputs) + if is_tlm: + inputs["num_logits_to_keep"] = nltk ort_outputs = self.run_ort_session(inputs, session) ort_outputs = self.input_handler.update_ort_outputs(ort_outputs) diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 1ece48368..470446a98 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -149,3 +149,20 @@ Benchmark the model on Cloud AI 100, run the infer API to print tokens and tok/s qeff_model.generate(prompts=["My name is"]) ``` End to End demo examples for various models are available in **notebooks** directory. Please check them out. + +### Draft-Based Speculative Decoding +Draft-based speculative decoding is a technique where a small Draft Language Model (DLM) makes `num_speculative_tokens` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM as an input to TLM to validate the speculations. + +To export and compile both DLM/TLM, add corresponding `is_tlm` and `num_speculative_tokens` for TLM and export DLM as you would any other QEfficient LLM model: + +```Python +tlm_name = "meta-llama/Llama-2-70b-chat-hf" +dlm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +k = 3 # DLM will make `k` speculations +tlm = AutoModelForCausalLM.from_pretrained(tlm_name, is_tlm=True) +dlm = AutoModelForCausalLM.from_pretrained(dlm_name) +tlm.compile(num_speculative_tokens=k) +dlm.compile() +``` + +The `is_tlm` flag is fed during the instantiation of the model because slight changes to the ONNX graph are required. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase. As for the DLM, no new changes are required at the ONNX or compile level. 
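The documentation snippet above covers instantiation and compilation only; the following is a minimal end-to-end sketch of generating with the compiled TLM. It is hedged: a Cloud AI 100 device is assumed to be available, the import path is the one used by the tests in this PR, and the model card and prompt are illustrative only.

```Python
# Minimal usage sketch (assumptions: Cloud AI 100 device available; import path
# taken from the tests in this PR; model card and prompt are illustrative).
from transformers import AutoTokenizer

from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

tlm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(tlm_name)

# is_tlm=True applies SpDTransform so the exported graph accepts `num_logits_to_keep`.
tlm = QEFFAutoModelForCausalLM.from_pretrained(tlm_name, is_tlm=True)
tlm.compile(num_speculative_tokens=3)  # decode specialization uses seq_len = 3 + 1

# generate() threads is_tlm through to the text-generation runtime, which feeds
# the `num_logits_to_keep` placeholder during prefill and decode.
exec_info = tlm.generate(tokenizer, prompts=["My name is"])
```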
\ No newline at end of file diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 6f0402c1b..6e91711e0 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import Optional + import numpy as np import pytest from transformers import AutoModelForCausalLM @@ -38,6 +40,10 @@ "ibm-granite/granite-20b-code-base", ] +spd_test_models = [ + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", +] + def load_causal_lm_model(model_config): """ @@ -69,6 +75,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( prompt_len: int = Constants.PROMPT_LEN, ctx_len: int = Constants.CTX_LEN, n_layer: int = 1, + num_speculative_tokens: Optional[int] = None, ): """ Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. @@ -98,7 +105,8 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - qeff_model = QEFFAutoModelForCausalLM(model_hf) + is_tlm = False if num_speculative_tokens is None else True + qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm) pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) @@ -107,7 +115,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" onnx_model_path = qeff_model.export() - ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path) + ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." @@ -120,6 +128,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( num_cores=14, mxfp6=False, aic_enable_depth_first=False, + num_speculative_tokens=num_speculative_tokens, ) exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size @@ -145,7 +154,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) - qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True) + qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm) onnx_model_path = qeff_model.export() if not get_available_device_id(): @@ -158,6 +167,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( mxfp6=False, aic_enable_depth_first=False, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, ) exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) @@ -215,6 +225,24 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", spd_test_models) +def test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. 
+ ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + if model_name == "microsoft/Phi-3-mini-4k-instruct": + n_layer = 2 # test only 2 layer models + else: + n_layer = 1 + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, n_layer=n_layer, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS + ) + + @pytest.mark.on_qaic def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): """ diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index d57a72e6c..e6a7d4588 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -13,7 +13,8 @@ from transformers.cache_utils import HybridCache from QEfficient.customop.matmulnbits import QuantLinearORT -from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform @@ -65,6 +66,47 @@ ("gemma", 1, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), ] +SpDTransformTestConfigs = [ + ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +] + + +def create_qaic_model_inputs( + input_len: int, vocab_size: int, padding_shape: tuple, num_hidden_layers: int, is_tlm: bool = False +) -> dict: + """create pytorch QEff model inputs + + ``Mandatory`` Args: + :input_len (int): input length. + :vocab_size (int): vocab size. + :padding_shape (tuple): padding shape of KV$. + :num_hidden_layers (int): number of hidden layers. + ``Optional`` Args: + :is_tlm (bool, optional): whether this is an SpD TLM model. Defaults to False. 
+ + Returns: + :dict: pytorch QEff model inputs + """ + input_ids = torch.randint(0, vocab_size, size=(1, input_len)) + past_key_values = [] + for _ in range(num_hidden_layers): + past_key = torch.zeros((padding_shape), dtype=torch.float32) + past_value = torch.zeros((padding_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs = dict( + input_ids=input_ids, + position_ids=torch.Tensor([range(input_ids.shape[1])]).long(), + past_key_values=tuple(past_key_values), + output_hidden_states=True, + ) + if is_tlm: + inputs["num_logits_to_keep"] = torch.zeros((input_len, 1)) + return inputs + def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6) -> bool: # Base case @@ -90,11 +132,15 @@ def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6 def run_kv_cache_transform_and_test( - hf_model, num_hidden_layers, padding_shape, vocab_size, input_len, logits_tolerance=0.8, kv_cache=None + hf_model, + qaic_model_inputs, + logits_tolerance=0.8, + kv_cache=None, ): hf_model.eval() # Run original model - input_ids = torch.randint(0, vocab_size, size=(1, input_len)) + input_ids = qaic_model_inputs["input_ids"] + input_len = input_ids.shape[1] with torch.inference_mode(): if isinstance(kv_cache, type(None)): original_model_outputs = hf_model( @@ -115,26 +161,13 @@ def run_kv_cache_transform_and_test( else: original_model_outputs = hf_model(input_ids=input_ids, output_hidden_states=True) - # Apply transform - hf_model, transformed = KVCacheTransform.apply(hf_model) - assert transformed - - # Prepare KV model inputs - past_key_values = [] - for _ in range(num_hidden_layers): - past_key = torch.zeros((padding_shape), dtype=torch.float32) - past_value = torch.zeros((padding_shape), dtype=torch.float32) - pkv = (past_key, past_value) - past_key_values.append(pkv) + # Apply transforms + is_tlm = "num_logits_to_keep" in qaic_model_inputs + hf_model = QEFFAutoModelForCausalLM(hf_model, is_tlm=is_tlm).model # Run KV model with torch.inference_mode(): - transformed_model_outputs = hf_model( - input_ids=input_ids, - position_ids=torch.Tensor([range(input_ids.shape[1])]).long(), - past_key_values=tuple(past_key_values), - output_hidden_states=True, - ) + transformed_model_outputs = hf_model(**qaic_model_inputs) assert original_model_outputs.keys() == transformed_model_outputs.keys(), "Model output keys do not match!" 
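Because the ONNX-exportable tensor ops in `filter_hidden_states` (added earlier in this diff) are a little opaque, here is a small standalone sketch of the same gather with plain integers. It is a simplification under stated assumptions: `num_logits_to_keep` is treated as an int `k` rather than the `(k, 1)` placeholder tensor, and the shapes are made up for illustration. Per batch row it keeps the `k` hidden states ending at that row's last valid position.

```Python
# Standalone sketch of the gather in filter_hidden_states (assumed shapes; k is
# an int here, whereas the real input is a (k, 1) tensor whose shape[0] is used).
import torch

bsz, seq_len, d_model, k = 2, 8, 4, 3
hidden_states = torch.randn(bsz, seq_len, d_model)
position_ids = torch.tensor([[0, 1, 2, 3, 4, -1, -1, -1],
                             [0, 1, 2, 3, 4, 5, 6, -1]])

logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)       # last valid position per row
lower_idx = torch.where(logit_index < k, 0, logit_index + 1 - k).view(-1, 1)
indices = (lower_idx + torch.arange(k).view(1, -1)).unsqueeze(2)          # [bsz, k, 1]
indices = indices.repeat(1, 1, d_model)                                   # [bsz, k, d_model]
kept = torch.gather(hidden_states, dim=1, index=indices)                  # [bsz, k, d_model]

assert kept.shape == (bsz, k, d_model)
# Row 0 keeps positions [2, 3, 4]; row 1 keeps positions [4, 5, 6].
```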
@@ -211,12 +244,57 @@ def test_kv_cache_transform( padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) + # Prepare KV model inputs + qaic_model_inputs = create_qaic_model_inputs( + input_len=8, vocab_size=config.vocab_size, padding_shape=padding_shape, num_hidden_layers=num_hidden_layers + ) + run_kv_cache_transform_and_test( hf_model, + qaic_model_inputs=qaic_model_inputs, + logits_tolerance=logits_tolerance, + kv_cache=kv_cache, + ) + + +@pytest.mark.parametrize( + "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance", + SpDTransformTestConfigs, +) +def test_spd_transform(config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance): + config = AutoConfig.for_model( + config_class, + **kwargs, num_hidden_layers=num_hidden_layers, - padding_shape=padding_shape, - vocab_size=config.vocab_size, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + use_cache=True, + cache_position=None, + position_embeddings=None, + ) + hf_model = AutoModelForCausalLM.from_config(config=config, attn_implementation="eager") + + kv_cache = None + if hasattr(config, "cache_implementation") and config.cache_implementation == "hybrid": + # Create a KV Cache from HybridCache class to pass as an object for models which use Hybrid KV Cache + # Refer https://github.com/huggingface/transformers/issues/32896 for more info + # This requires torch._dynamo present in torch>=2.3.0 + kv_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=32) + + padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) + + # Prepare KV model inputs + qaic_model_inputs = create_qaic_model_inputs( input_len=8, + vocab_size=config.vocab_size, + padding_shape=padding_shape, + num_hidden_layers=num_hidden_layers, + is_tlm=True, + ) + + run_kv_cache_transform_and_test( + hf_model, + qaic_model_inputs=qaic_model_inputs, logits_tolerance=logits_tolerance, kv_cache=kv_cache, )
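For reference, the TLM decode-step inputs assembled by `prepare_decode_inputs` in text_generation_inference.py (earlier in this diff) place the newly sampled token in the last slot of a `(batch_size, decode_seq_len)` window, with unused slots carrying position id -1. The sketch below mirrors that layout with made-up token and position values; it illustrates the assumed shapes only and is not runtime code.

```Python
# Sketch of the TLM decode-step inputs (values are hypothetical; decode_seq_len
# equals num_speculative_tokens + 1, matching the decode specialization).
import numpy as np

batch_size, decode_seq_len = 2, 4
next_token_id = np.array([11, 27])   # tokens sampled at the previous step (made up)
next_position = np.array([5, 9])     # next position id per sequence (made up)

input_ids = np.zeros((batch_size, decode_seq_len), dtype=np.int64)
input_ids[:, -1] = next_token_id
position_ids = np.full((batch_size, decode_seq_len), -1, dtype=np.int64)
position_ids[:, -1] = next_position

decode_inputs = {
    "input_ids": input_ids,
    "position_ids": position_ids,
    # Placeholder consumed by the SpD-transformed graph; shape (decode_seq_len, 1).
    "num_logits_to_keep": np.zeros((decode_seq_len, 1)),
}
print(decode_inputs["input_ids"])  # [[ 0  0  0 11]
                                   #  [ 0  0  0 27]]
```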