From 457d5ae0227732338b70049e2f4ea7645f371182 Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 18 Nov 2024 15:55:43 -0600 Subject: [PATCH 01/30] rebasing with main. previous local gen_spd_models was broken since it was not picking up latest changes from main. as such, I found common ancestor, picked up latest changes from main, and made new commit to contain all unique changes for pr 119 Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 13 +- QEfficient/compile/compile_helper.py | 40 ++-- .../exporter/export_hf_to_cloud_ai_100.py | 27 ++- QEfficient/transformers/modeling_spd_utils.py | 42 +++++ .../transformers/models/modeling_auto.py | 37 +++- .../transformers/models/spd/modeling_tlm.py | 96 ++++++++++ QEfficient/transformers/pytorch_transforms.py | 33 ++++ QEfficient/utils/_utils.py | 3 +- QEfficient/utils/generate_inputs.py | 51 ++++-- docs/source/quick_start.md | 15 ++ tests/spd/test_tlm_dlm_export_and_compile.py | 172 ++++++++++++++++++ .../models/test_causal_lm_models.py | 31 +--- 12 files changed, 496 insertions(+), 64 deletions(-) create mode 100644 QEfficient/transformers/modeling_spd_utils.py create mode 100644 QEfficient/transformers/models/spd/modeling_tlm.py create mode 100644 tests/spd/test_tlm_dlm_export_and_compile.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 88c2c155b..e4327dd41 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -132,7 +132,12 @@ def _export( """ export_dir = Path(export_dir or (QEFF_HOME / self.model_name)) export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash) - onnx_path = export_dir / f"{self.model_name}.onnx" + if self.num_speculative_tokens: + model_name = f"{self.model_name}_{self.num_speculative_tokens+1}nltk.onnx" + else: + model_name = f"{self.model_name}.onnx" + onnx_path = export_dir / model_name + # TODO: need to add hash to onnx if onnx_path.is_file(): self.onnx_path = onnx_path return onnx_path @@ -244,6 +249,12 @@ def _compile( if mdp_ts_num_devices > 1: compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices})) + if self.num_speculative_tokens: + compile_hash.update(to_hashable({"num_speculative_tokens": self.num_speculative_tokens})) + + if self.is_dlm: + compile_hash.update(to_hashable({"is_dlm": self.is_dlm})) + # Check if already compiled compile_hash = compile_hash.hexdigest()[:16] qpc_path = qpc_path.with_name(qpc_path.name + "-" + compile_hash) diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index a94c88d23..a6b0b19f3 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -16,24 +16,38 @@ def create_and_dump_specializations( - batch_size: int, prompt_len: int, ctx_len: int, path: str, full_batch_size: Optional[int] = None + batch_size: int, + prompt_len: int, + ctx_len: int, + path: str, + is_dlm: bool, + full_batch_size: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, ): - # Create specialization file. - specializations = { - "specializations": [ - { - "batch_size": str(batch_size), - "seq_len": str(prompt_len), - "ctx_len": str(ctx_len), - }, - {"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)}, - ] - } + # Create specialization cfgs + decode_seq_len = 1 if num_speculative_tokens is None else num_speculative_tokens+1 + specialization_cfgs = [ + dict(batch_size=str(batch_size), seq_len=str(prompt_len), ctx_len=str(ctx_len)), # prefill + dict(batch_size=str(batch_size), seq_len=str(decode_seq_len), ctx_len=str(ctx_len)) # decode + ] + if num_logits_to_keep is not None: + specialization_cfgs[0]["num_logits_to_keep"] = "1" # return last logit + specialization_cfgs[1]["num_logits_to_keep"] = str(num_logits_to_keep+1) # return all SpD decode logits + elif is_dlm: + specialization_cfgs.append( + dict(batch_size=str(batch_size), seq_len="2", ctx_len=str(ctx_len)) + ) + + specializations = dict(specializations=specialization_cfgs) + # If continuous batching is enabled by proving full_batch_size we need to add FBS to the specialization file and update the batch size of decoder part to FBS if full_batch_size is not None: specializations["specializations"][0]["full_batch_size"] = str(full_batch_size) specializations["specializations"][1]["full_batch_size"] = str(full_batch_size) specializations["specializations"][1]["batch_size"] = str(full_batch_size) + if len(specializations["specializations"]) == 3: + specializations["specializations"][2]["batch_size"] = str(full_batch_size) + specializations["specializations"][2]["full_batch_size"] = str(full_batch_size) # To handle repetative input in specializations when prompt_len is 1 if prompt_len == 1 and full_batch_size is None: @@ -168,6 +182,8 @@ def compile( ctx_len=ctx_len, path=specialization_json_path, full_batch_size=full_batch_size, + is_dlm=kwargs.get("is_dlm", False), + num_speculative_tokens=kwargs.get("num_speculative_tokens", None), ) # Select the customIO config based on the mx flag. diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index c13bb9536..bc5a039f2 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import os +import math import shutil import warnings from typing import Optional, Tuple, Union @@ -189,6 +190,7 @@ def export_kvstyle_transformed_model_to_onnx( onnx_dir_path: str, seq_len: int, full_batch_size: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, ) -> str: # Disabling requires_grad on all parameters for _, p in enumerate(transformed_model.parameters()): @@ -197,6 +199,19 @@ def export_kvstyle_transformed_model_to_onnx( if seq_len <= 0: raise ValueError(f"Need seq_len to be greater than zero, got seq_len={seq_len}") + # Implicitly pass "num_speculative_tokens" if defined and \ + # assert prompt_len >= num_speculative_tokens + prompt_len = Constants.PROMPT_LEN + num_logits_to_keep = None + if num_speculative_tokens is not None: + num_logits_to_keep = num_speculative_tokens+1 + setattr(transformed_model, "num_logits_to_keep", num_logits_to_keep) + if prompt_len < num_logits_to_keep: + prompt_len *= math.ceil((num_logits_to_keep) / prompt_len) + if prompt_len >= seq_len: + seq_len = prompt_len*2 + + # Preprocess inputs # Build inputs for prefill input_handler = InputHandler( @@ -204,9 +219,10 @@ def export_kvstyle_transformed_model_to_onnx( tokenizer=tokenizer, config=transformed_model.config, prompt=Constants.INPUT_STR, - prompt_len=Constants.PROMPT_LEN, + prompt_len=prompt_len, ctx_len=seq_len, full_batch_size=full_batch_size, + num_logits_to_keep=num_logits_to_keep, ) inputs = input_handler.prepare_pytorch_inputs() @@ -223,7 +239,9 @@ def export_kvstyle_transformed_model_to_onnx( # Build inputs for decode inputs = input_handler.update_pytorch_inputs(inputs, pt_outputs) # To avoid issues in onnx export - inputs["position_ids"] = torch.full((full_batch_size if full_batch_size else 1, 1), seq_len - 1) + bsz = full_batch_size if full_batch_size else 1 + pos_len = inputs["position_ids"].size(1) + inputs["position_ids"] = torch.full((bsz, pos_len), seq_len - 1) # Run PyTorch inference with past pt_outputs = transformed_model(**inputs) @@ -314,6 +332,7 @@ def export_for_cloud( onnx_dir_path: str, seq_length: int = Constants.SEQ_LEN, full_batch_size: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, ) -> str: # FIXME: move all this to class instead of here, and just call qeff_model.export here. if AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(qeff_model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM: # type: ignore @@ -324,6 +343,7 @@ def export_for_cloud( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens ) else: raise NotImplementedError( @@ -338,6 +358,7 @@ def export_lm_model_for_cloud( onnx_dir_path: str, seq_length: int, full_batch_size: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, ) -> str: if os.path.exists(onnx_dir_path): logger.warning(f"Overriding {onnx_dir_path}") @@ -366,6 +387,7 @@ def qualcomm_efficient_converter( kv: bool = True, form_factor: str = "cloud", full_batch_size: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, ) -> Tuple[str, str]: """ This method is an alias for ``QEfficient.export``. @@ -441,6 +463,7 @@ def qualcomm_efficient_converter( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, ) return onnx_dir_path, generated_onnx_model_path else: diff --git a/QEfficient/transformers/modeling_spd_utils.py b/QEfficient/transformers/modeling_spd_utils.py new file mode 100644 index 000000000..863adff79 --- /dev/null +++ b/QEfficient/transformers/modeling_spd_utils.py @@ -0,0 +1,42 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import Optional + +import torch + +def filter_hidden_states( + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + num_logits_to_keep: Optional[int] = None, +) -> torch.Tensor: + """ + Filter hidden states based on whether this is a TLM SpD model + + ``Mandatory`` Args: + :hidden_states (torch.Tensor): Hidden states tensor. + :position_ids (torch.Tensor): Position ids tensor. + ``Optional`` Args: + :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model + + Returns: + :torch.Tensor: Filtered hidden states. + """ + batch_size = position_ids.size(0) + batch_indices = torch.arange(batch_size) + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + if num_logits_to_keep is None: + # return the last logit + return hidden_states[batch_indices.view(-1, 1), logit_index] + # gather approach + lower_idx = torch.where(logit_index <= num_logits_to_keep, 0, logit_index - num_logits_to_keep).view(-1,1) # shape: [bsz, 1] + spec_idx = torch.arange(num_logits_to_keep).view(1,-1) # shape: [1, k] + indices = torch.add(lower_idx, spec_idx).unsqueeze(2) # shape: [bsz, k, 1] + indices = indices.repeat(1, 1, hidden_states.size(-1)) # shape: [bsz, ,k, d_model] + hidden_states = torch.gather(hidden_states, dim=1, index=indices) # shape: [bsz, k, d_model] + return hidden_states diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9e887a673..fbe9aacf3 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -5,6 +5,7 @@ # # ---------------------------------------------------------------------------- +import math import hashlib import logging import warnings @@ -18,7 +19,7 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform -from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform from QEfficient.utils import constants, get_padding_shape_from_config @@ -110,7 +111,7 @@ class QEFFAutoModelForCausalLM(QEFFTransformersBase): _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs): + def __init__(self, model: nn.Module, continuous_batching: bool = False, num_speculative_tokens: Optional[int] = None, is_dlm: bool = False, **kwargs): if kwargs.pop("full_batch_size", None): continuous_batching = True warnings.warn( @@ -123,6 +124,8 @@ def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs self.model.config.use_cache = True self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching + self.num_speculative_tokens = num_speculative_tokens + self.is_dlm = is_dlm @classmethod def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: bool = False, *args, **kwargs): @@ -149,6 +152,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo model.generate(prompts=["Hi there!!"]) """ + num_speculative_tokens = kwargs.pop("num_speculative_tokens", None) + is_dlm = kwargs.pop("is_dlm", False) + if num_speculative_tokens is not None: + if not isinstance(num_speculative_tokens, int) or num_speculative_tokens<2: + ValueError("`num_speculative_tokens` arg should be an integer greater than 1.") + if is_dlm: + raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") + cls._pytorch_transforms.append(SpDTransform) if kwargs.pop("full_batch_size", None): continuous_batching = True warnings.warn( @@ -157,6 +168,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo self = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) self.continuous_batching = continuous_batching + self.num_speculative_tokens = num_speculative_tokens + self.is_dlm = is_dlm return self @property @@ -182,13 +195,18 @@ def export(self, export_dir: Optional[str] = None) -> str: """ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + if self.num_speculative_tokens: + num_logits_to_keep = self.num_speculative_tokens+1 + setattr(self.model, "num_logits_to_keep", num_logits_to_keep) + if seq_len < num_logits_to_keep: + seq_len *= math.ceil((num_logits_to_keep) / seq_len) fbs = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), - "position_ids": torch.arange(seq_len, dtype=torch.int64).view(bs, seq_len), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs,1), "past_key_values": [[] for _ in range(self.num_layers)], } dynamic_axes = { @@ -261,20 +279,29 @@ def compile( :str: Path of the compiled ``qpc`` package. """ # Specializations + decode_seq_len = self.num_speculative_tokens+1 if self.num_speculative_tokens else 1 if self.continuous_batching: if full_batch_size is None: raise TypeError("missing required argument: 'full_batch_size'") specializations = [ {"full_batch_size": full_batch_size, "batch_size": 1, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, - {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": 1, "ctx_len": ctx_len}, + {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": decode_seq_len, "ctx_len": ctx_len}, ] + if self.is_dlm: + specializations.append( + {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": 2, "ctx_len": ctx_len}, + ) else: specializations = [ {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, ] if prefill_seq_len != 1: specializations.append({"batch_size": batch_size, "seq_len": 1, "ctx_len": ctx_len}) + if self.is_dlm: + specializations.append( + {"batch_size": batch_size, "seq_len": 2, "ctx_len": ctx_len}, + ) # Custom IO custom_io = {} @@ -341,4 +368,4 @@ def export(self): raise NotImplementedError("Reached too far!!") def compile(self, *args, **kwargs) -> Any: - raise NotImplementedError("Reached too far!!") + raise NotImplementedError("Reached too far!!") \ No newline at end of file diff --git a/QEfficient/transformers/models/spd/modeling_tlm.py b/QEfficient/transformers/models/spd/modeling_tlm.py new file mode 100644 index 000000000..23b867cf9 --- /dev/null +++ b/QEfficient/transformers/models/spd/modeling_tlm.py @@ -0,0 +1,96 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + +from QEfficient.transformers.modeling_spd_utils import filter_hidden_states + +def tlm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + #num_logits_to_keep: Optional[torch.LongTensor] = None, # explicit passing is not currently supported +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + num_logits_to_keep = getattr(self, "num_logits_to_keep", None) + hidden_states = filter_hidden_states(outputs[0], position_ids, num_logits_to_keep) + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/QEfficient/transformers/pytorch_transforms.py b/QEfficient/transformers/pytorch_transforms.py index 9c58bf030..d46e7f817 100644 --- a/QEfficient/transformers/pytorch_transforms.py +++ b/QEfficient/transformers/pytorch_transforms.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- from typing import Tuple +from types import MethodType import transformers from torch import nn @@ -199,6 +200,7 @@ QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) +from QEfficient.transformers.models.spd.modeling_tlm import tlm_forward class CustomOpsTransform(ModuleMappingTransform): @@ -307,3 +309,34 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: # FIXME: see if we can merge into _module_mapping dict transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update return model, transformed + +class SpDTransform: + """ + Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits. + This is only needed if user is exporting Target Language Model (TLM) for Speculative Decoding to validate output logits + against the speculated tokens from a smaller model. + Other than the computed logits, there should be no difference between the SpD Transformed model and its corresponding cunterpart. + + ``Mandatory`` Args: + :model (nn.Module): PyTorch model. + + Returns: + :model (nn.Module): PyTorch model. + :transformed (bool): whether transformation was applied successfully. + """ + # supported architectures + _module_mapping = { + # Llama + QEffLlamaForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + transformed = False + if (model_class:=model.__class__) in cls._module_mapping: + model.forward = MethodType(tlm_forward, model) + transformed = True + else: + raise NotImplementedError(f"model class {model_class} does not yet support returning multiple logits to keep.") + + return model, transformed \ No newline at end of file diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 510e7ab8c..6ef6d63cf 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -185,11 +185,12 @@ def load_hf_tokenizer( def get_qpc_dir_path( - model_card_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size + model_card_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size, num_speculative_tokens: Optional[int] = None ): # Create a unique directory name for the QPC model based on all parameters qpc_base_dir_name = ( f"qpc_{num_cores}cores_{batch_size}bs_{prompt_len}pl_{ctx_len}cl_{mos}mos" + + f"_{num_speculative_tokens}nst" if num_speculative_tokens else '' + f"{f'_{full_batch_size}fbs_' if full_batch_size is not None else '_'}" + f"{len(device_group) if device_group is not None else 1}" + "devices" diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index c45cfec41..ed6e17be1 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import Optional + import numpy as np import torch @@ -12,7 +14,9 @@ class InputHandler: - def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size): + def __init__( + self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_logits_to_keep: Optional[int] + ): """ Initialization @@ -24,6 +28,10 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f :prompt_len (int): Prompt length for the model to compile. :ctx_len (int): Maximum context length to compile the model. :full_batch_size (int): Continuous batching batch size + :num_logits_to_keep (Optional[int]): + Calculate logits for the last valid `num_logits_to_keep` tokens. + Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. """ # check and fix tokenizer viability padding_check_and_fix(tokenizer) @@ -32,6 +40,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f self.prompt_len = prompt_len self.ctx_len = ctx_len self.full_batch_size = full_batch_size + self.num_logits_to_keep = num_logits_to_keep self.n_layer = get_num_layers_from_config(config) self.padding_shape = get_padding_shape_from_config( config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len @@ -72,9 +81,15 @@ def prepare_pytorch_inputs(self): ) if self.full_batch_size: - inputs["input_ids"] = input_ids - inputs["position_ids"] = torch.arange(input_len).view(1, input_len) + # Feed input without padding (CB pt forward pass fails if padding exists in position_ids) inputs["batch_index"] = torch.arange(1).view(-1, 1) + if self.num_logits_to_keep is not None: + # preserve length after padding to assert `num_logits_to_keep<=padded_length` + length = inputs["position_ids"].size(1) + inputs["position_ids"] = torch.arange(length).view(1, -1) + else: + inputs["input_ids"] = input_ids + inputs["position_ids"] = position_ids past_key_values = [] for i in range(self.n_layer): @@ -97,23 +112,31 @@ def update_pytorch_inputs(self, inputs, pt_outputs): Return: :Dict: Updated input_ids, position_ids and past_key_values """ + decode_len = 1 if self.num_logits_to_keep is None else self.num_logits_to_keep updated_inputs = {} if self.full_batch_size: + # Create CB inputs (make 1 batch index have proper inputs for decode pass) batch_index = torch.arange(1).view(-1, 1) - - input_ids = pt_outputs.logits.detach().argmax(2) - updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id) - updated_inputs["input_ids"][batch_index.view(-1)] = input_ids - - position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1 - updated_inputs["position_ids"] = torch.full((self.full_batch_size, 1), 0) - updated_inputs["position_ids"][batch_index.view(-1)] = position_ids - + batch_idx_input_ids = pt_outputs.logits.detach().argmax(2) + input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id) + input_ids[batch_index.view(-1)] = batch_idx_input_ids + position_ids = torch.full((self.full_batch_size, decode_len), 0) + batch_idx_position_ids = torch.arange(decode_len).view(1,-1) + (inputs["position_ids"].max(1, keepdim=True).values + 1) + position_ids[batch_index.view(-1)] = batch_idx_position_ids + updated_inputs["input_ids"] = input_ids + updated_inputs["position_ids"] = position_ids updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) else: - updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1) - updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 + if self.num_logits_to_keep is not None: + input_ids = pt_outputs["logits"].argmax(-1) # shape: [batch_size, num_logits_to_keep] + batch_size = input_ids.size(0) + position_ids = torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1) + else: + input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1) + position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1 + updated_inputs["input_ids"] = input_ids + updated_inputs["position_ids"] = position_ids updated_inputs["past_key_values"] = tuple( [(key.detach(), value.detach()) for key, value in pt_outputs["past_key_values"]] diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 1ece48368..c9885cf13 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -149,3 +149,18 @@ Benchmark the model on Cloud AI 100, run the infer API to print tokens and tok/s qeff_model.generate(prompts=["My name is"]) ``` End to End demo examples for various models are available in **notebooks** directory. Please check them out. + +### Draft-Based Speculative Decoding +Draft-based speculative decoding is the approach where a small Draft Language Model (DLM) makes `num_logits_to_keep` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM. + +To export both DLM/TLM, add below flags to `from_pretrained`: + +```Python + +tlm_name = "meta-llama/Llama-3.1-405B" +dlm_name = "meta-llama/Llama-3.1-8B" +k = 3 # DLM will make `k` speculations +tlm = AutoModelForCausalLM.from_pretrained(tlm_name, num_speculative_tokens=k) +dlm = AutoModelForCausalLM.from_pretrained(dlm_name, is_dlm=True) +``` +Once done, the same high-level python APIs of `export` and `compile` can be used to generate QPC. \ No newline at end of file diff --git a/tests/spd/test_tlm_dlm_export_and_compile.py b/tests/spd/test_tlm_dlm_export_and_compile.py new file mode 100644 index 000000000..cec269ffa --- /dev/null +++ b/tests/spd/test_tlm_dlm_export_and_compile.py @@ -0,0 +1,172 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List, Optional + +import numpy as np +import pytest +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +configs = [ + pytest.param( + [0], # device_group + 2, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + 8, # full_batch_size + "JackFram/llama-68m", # model_name + True, # continuous_batching + id="CB llama", + ), + pytest.param( + [0], # device_group + 2, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + None, # full_batch_size + "JackFram/llama-68m", # model_name + False, # continuous_batching + id="non-CB llama", + ), +] + + +@pytest.mark.parametrize( + "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs +) +def test_llama_tlm_logit_dims( + device_group: List[int], + num_speculative_tokens: int, + prefill_seq_len: int, + ctx_len: int, + prefill_bsz: int, + full_batch_size: Optional[int], + model_name: str, + continuous_batching: bool, +): + # get vocab size + tokenizer = AutoTokenizer.from_pretrained(model_name) + vocab_size = len(tokenizer) + + # export and compile tlm model + qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, num_speculative_tokens=num_speculative_tokens) + qpc_path: str = qeff_model.compile( + num_devices=len(device_group), + num_cores=16, + batch_size=prefill_bsz, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6_matmul=True, + full_batch_size=full_batch_size, + ) + + # init qaic session + session = QAICInferenceSession(qpc_path, device_ids=device_group) + # skip inputs/outputs buffers + session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) + session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) + # prefill dummy inputs + prefill_inputs = dict( + input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), + position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), + ) + # decode dummy inputs + num_logits_to_keep = num_speculative_tokens + 1 + decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz + decode_inputs = dict( + input_ids=np.zeros((decode_bsz, num_logits_to_keep), dtype=np.int64), + position_ids=np.full((decode_bsz, num_logits_to_keep), -1, dtype=np.int64), + ) + if full_batch_size is not None: + prefill_inputs["batch_index"] = np.arange(prefill_bsz, dtype=np.int64).reshape(prefill_bsz, 1) + decode_inputs["batch_index"] = np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1) + # create dummy logits + prefill_logits = dict(logits=np.random.randn(prefill_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) + decode_logits = dict(logits=np.random.randn(decode_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) + # get prefill/decode logits + session.set_buffers(prefill_logits) + prefill_outputs = session.run(prefill_inputs) + session.set_buffers(decode_logits) + decode_outputs = session.run(decode_inputs) + + # assert expected logit dims + assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape + assert decode_logits["logits"].shape == decode_outputs["logits"].shape + + +@pytest.mark.parametrize( + "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs +) +def test_llama_dlm_logit_dims( + device_group: List[int], + num_speculative_tokens: int, + prefill_seq_len: int, + ctx_len: int, + prefill_bsz: int, + full_batch_size: Optional[int], + model_name: str, + continuous_batching: bool, +): + # get vocab size + tokenizer = AutoTokenizer.from_pretrained(model_name) + vocab_size = len(tokenizer) + + # export and compile tlm model + qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, is_dlm=True) + qpc_path: str = qeff_model.compile( + num_devices=len(device_group), + num_cores=16, + batch_size=prefill_bsz, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6_matmul=True, + full_batch_size=full_batch_size, + ) + + # init qaic session + session = QAICInferenceSession(qpc_path, device_ids=device_group) + # skip inputs/outputs buffers + session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) + session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) + # prefill dummy inputs + prefill_inputs = dict( + input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), + position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), + batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(-1, 1), + ) + # decode-1 dummy inputs + decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz + decode1_inputs = dict( + input_ids=np.zeros((decode_bsz, 1), dtype=np.int64), + position_ids=np.full((decode_bsz, 1), -1, dtype=np.int64), + batch_index=np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1), + ) + # decode-2 dummy inputs + decode2_inputs = dict( + input_ids=np.zeros((decode_bsz, 2), dtype=np.int64), + position_ids=np.full((decode_bsz, 2), -1, dtype=np.int64), + batch_index=np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1), + ) + # create dummy logits + prefill_logits = dict(logits=np.random.randn(prefill_bsz, 1, vocab_size).astype(np.float32)) + decode_logits = dict(logits=np.random.randn(decode_bsz, 1, vocab_size).astype(np.float32)) + # get prefill/decode logits + session.set_buffers(prefill_logits) + prefill_outputs = session.run(prefill_inputs) + session.set_buffers(decode_logits) + decode1_outputs = session.run(decode1_inputs) + decode2_outputs = session.run(decode2_inputs) + + # assert expected logit dims + assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape + assert decode_logits["logits"].shape == decode1_outputs["logits"].shape + assert decode_logits["logits"].shape == decode2_outputs["logits"].shape diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 6f0402c1b..63d8bf836 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -152,8 +152,8 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytest.skip("No available devices to run model on Cloud AI 100") _ = qeff_model.compile( - prefill_seq_len=prompt_len, - ctx_len=ctx_len, + prefill_seq_len=8, + ctx_len=32, num_cores=14, mxfp6=False, aic_enable_depth_first=False, @@ -197,30 +197,3 @@ def test_causal_lm_export_with_deprecated_api(model_name): assert ( new_api_ort_tokens == old_api_ort_tokens ).all(), "New API output does not match old API output for ONNX export function" - - -@pytest.mark.on_qaic -@pytest.mark.parametrize("model_name", test_models) -def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): - """ - Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - if model_name == "microsoft/Phi-3-mini-4k-instruct": - n_layer = 2 # test only 2 layer models - else: - n_layer = 1 - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) - - -@pytest.mark.on_qaic -def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): - """ - Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and without continuous batching. - """ - model_name = "gpt2" - prompt_len = 1 - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len) From 6aae2872e243baad2487bfbebfd1fa45bc10cd3e Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 18 Nov 2024 16:02:13 -0600 Subject: [PATCH 02/30] add decode_seq_len to non-continuous batching case Signed-off-by: eplatero --- QEfficient/transformers/models/modeling_auto.py | 2 +- docs/source/quick_start.md | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index fbe9aacf3..b69f9835d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -297,7 +297,7 @@ def compile( {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, ] if prefill_seq_len != 1: - specializations.append({"batch_size": batch_size, "seq_len": 1, "ctx_len": ctx_len}) + specializations.append({"batch_size": batch_size, "seq_len": decode_seq_len, "ctx_len": ctx_len}) if self.is_dlm: specializations.append( {"batch_size": batch_size, "seq_len": 2, "ctx_len": ctx_len}, diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index c9885cf13..34e2a6bc7 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -156,7 +156,6 @@ Draft-based speculative decoding is the approach where a small Draft Language Mo To export both DLM/TLM, add below flags to `from_pretrained`: ```Python - tlm_name = "meta-llama/Llama-3.1-405B" dlm_name = "meta-llama/Llama-3.1-8B" k = 3 # DLM will make `k` speculations From 7472df86b91e0a65461165d200266b0ae77a5dce Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 18 Nov 2024 16:16:00 -0600 Subject: [PATCH 03/30] mirror test_causal_lm_models.py from main Signed-off-by: eplatero --- .../models/test_causal_lm_models.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 63d8bf836..9b926c8c7 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -152,8 +152,8 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytest.skip("No available devices to run model on Cloud AI 100") _ = qeff_model.compile( - prefill_seq_len=8, - ctx_len=32, + prefill_seq_len=prompt_len, + ctx_len=ctx_len, num_cores=14, mxfp6=False, aic_enable_depth_first=False, @@ -197,3 +197,30 @@ def test_causal_lm_export_with_deprecated_api(model_name): assert ( new_api_ort_tokens == old_api_ort_tokens ).all(), "New API output does not match old API output for ONNX export function" + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", test_models) +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + if model_name == "microsoft/Phi-3-mini-4k-instruct": + n_layer = 2 # test only 2 layer models + else: + n_layer = 1 + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) + + +@pytest.mark.on_qaic +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and without continuous batching. + """ + model_name = "gpt2" + prompt_len = 1 + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len) \ No newline at end of file From bc702b052cb55582aacfb9947669eee0206a96b2 Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 18 Nov 2024 16:28:46 -0600 Subject: [PATCH 04/30] add more to the explanation of the model changes Signed-off-by: eplatero --- docs/source/quick_start.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 34e2a6bc7..dd097e6dd 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -162,4 +162,7 @@ k = 3 # DLM will make `k` speculations tlm = AutoModelForCausalLM.from_pretrained(tlm_name, num_speculative_tokens=k) dlm = AutoModelForCausalLM.from_pretrained(dlm_name, is_dlm=True) ``` -Once done, the same high-level python APIs of `export` and `compile` can be used to generate QPC. \ No newline at end of file +Once done, the same high-level python APIs of `export` and `compile` can be used to generate QPC. + +When `num_speculative_tokens` is specified, QEfficient transforms the TLM to always output `num_speculative_tokens+1` logits per batch for both prefill and decode. While only the last logit corresponding to the last autoregressive token is needed in prefill, for decode phase, we take in as batch input the speculations from the DLM. As for the DLM, the only addition of adding the `is_dlm=True` flag is that an extra specialization file with `seq_len=2` is created to account for the "bonus" token that happens when all speculations are correct. +> NOTE: due to some compiler limitations, it is currently not possible to create an onnx-graph that parametrizes `num_speculative_tokens`. Because of this, a unique onnx-graph will be created per unique-specified `num_speculative_tokens`. This is also why `num_speculative_tokens+1` will be returned for both prefill and decode. \ No newline at end of file From e630b8f3acb52cbd992378cca0abf44c405c852e Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 19 Nov 2024 05:29:12 -0600 Subject: [PATCH 05/30] lint fixing Signed-off-by: eplatero --- QEfficient/compile/compile_helper.py | 5 +---- QEfficient/exporter/export_hf_to_cloud_ai_100.py | 2 +- QEfficient/transformers/models/modeling_auto.py | 2 +- QEfficient/transformers/models/spd/modeling_tlm.py | 1 + QEfficient/transformers/pytorch_transforms.py | 2 +- 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index a6b0b19f3..afc1f5b75 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -30,10 +30,7 @@ def create_and_dump_specializations( dict(batch_size=str(batch_size), seq_len=str(prompt_len), ctx_len=str(ctx_len)), # prefill dict(batch_size=str(batch_size), seq_len=str(decode_seq_len), ctx_len=str(ctx_len)) # decode ] - if num_logits_to_keep is not None: - specialization_cfgs[0]["num_logits_to_keep"] = "1" # return last logit - specialization_cfgs[1]["num_logits_to_keep"] = str(num_logits_to_keep+1) # return all SpD decode logits - elif is_dlm: + if is_dlm: specialization_cfgs.append( dict(batch_size=str(batch_size), seq_len="2", ctx_len=str(ctx_len)) ) diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index bc5a039f2..d962a85b4 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -5,8 +5,8 @@ # # ----------------------------------------------------------------------------- -import os import math +import os import shutil import warnings from typing import Optional, Tuple, Union diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b69f9835d..b39438e62 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -5,9 +5,9 @@ # # ---------------------------------------------------------------------------- -import math import hashlib import logging +import math import warnings from pathlib import Path from typing import Any, List, Optional, Union diff --git a/QEfficient/transformers/models/spd/modeling_tlm.py b/QEfficient/transformers/models/spd/modeling_tlm.py index 23b867cf9..f59f92640 100644 --- a/QEfficient/transformers/models/spd/modeling_tlm.py +++ b/QEfficient/transformers/models/spd/modeling_tlm.py @@ -14,6 +14,7 @@ from QEfficient.transformers.modeling_spd_utils import filter_hidden_states + def tlm_forward( self, input_ids: torch.LongTensor = None, diff --git a/QEfficient/transformers/pytorch_transforms.py b/QEfficient/transformers/pytorch_transforms.py index d46e7f817..cab44a3f3 100644 --- a/QEfficient/transformers/pytorch_transforms.py +++ b/QEfficient/transformers/pytorch_transforms.py @@ -194,13 +194,13 @@ QEffQwen2ForCausalLM, QEffQwen2Model, ) +from QEfficient.transformers.models.spd.modeling_tlm import tlm_forward from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import ( QEffStarcoder2Attention, QEFFStarcoder2DecoderLayer, QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) -from QEfficient.transformers.models.spd.modeling_tlm import tlm_forward class CustomOpsTransform(ModuleMappingTransform): From 840cb9fcec86cf2d639f2cba322dca6a35c461b0 Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 19 Nov 2024 05:34:54 -0600 Subject: [PATCH 06/30] alphabetical order imports on pytorch_transforms.py Signed-off-by: eplatero --- QEfficient/transformers/pytorch_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/pytorch_transforms.py b/QEfficient/transformers/pytorch_transforms.py index cab44a3f3..74a779b67 100644 --- a/QEfficient/transformers/pytorch_transforms.py +++ b/QEfficient/transformers/pytorch_transforms.py @@ -5,8 +5,8 @@ # # ----------------------------------------------------------------------------- -from typing import Tuple from types import MethodType +from typing import Tuple import transformers from torch import nn From ccdcfb77b6ca95492c8d93330120d0c9ffd6522a Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 19 Nov 2024 22:42:56 -0600 Subject: [PATCH 07/30] add init to spd directory Signed-off-by: eplatero --- QEfficient/transformers/models/spd/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 QEfficient/transformers/models/spd/__init__.py diff --git a/QEfficient/transformers/models/spd/__init__.py b/QEfficient/transformers/models/spd/__init__.py new file mode 100644 index 000000000..da26921c5 --- /dev/null +++ b/QEfficient/transformers/models/spd/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + From ed57de7a38f26c6560bf125f99f8988d810191cf Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 19 Nov 2024 23:23:46 -0600 Subject: [PATCH 08/30] replace modifying seq_len by letting user define proper config Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 6 +++++- QEfficient/transformers/models/modeling_auto.py | 16 ++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index e4327dd41..a7aac49c7 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -222,7 +222,11 @@ def _compile( - convert_to_fp16=True -> -convert-to-fp16 """ if onnx_path is None and self.onnx_path is None: - self.export() + if self.num_speculative_tokens is not None: + prefill_seq_len = specializations[0]["seq_len"] + self.export(seq_len=prefill_seq_len) + else: + self.export() onnx_path = Path(onnx_path or self.onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b39438e62..083a41e4c 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -182,24 +182,28 @@ def model_hash(self) -> str: mhash = mhash.hexdigest()[:16] return mhash - def export(self, export_dir: Optional[str] = None) -> str: + def export( + self, + export_dir: Optional[str] = None, + seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this." ``Optional`` Args: - does not any arguments. + :export_dir (str, optional): The directory path to store ONNX-graph. + :seq_len (int, optional): The length of the pytorch prompt inputs.. ``Defaults to 32``. Returns: :str: Path of the generated ``ONNX`` graph. """ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - if self.num_speculative_tokens: + if self.num_speculative_tokens is not None: num_logits_to_keep = self.num_speculative_tokens+1 - setattr(self.model, "num_logits_to_keep", num_logits_to_keep) if seq_len < num_logits_to_keep: - seq_len *= math.ceil((num_logits_to_keep) / seq_len) + raise ValueError(f"sequence length ({seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})") + setattr(self.model, "num_logits_to_keep", num_logits_to_keep) fbs = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len From 0a0683d6bb3da3a0be2d5548670388f83d1657ff Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 20 Nov 2024 22:49:29 -0600 Subject: [PATCH 09/30] resolving 1st round comments from Onkar and made fix on gather implementation Signed-off-by: eplatero --- .../transformers/models/modeling_auto.py | 1 - QEfficient/transformers/pytorch_transforms.py | 2 +- .../transformers/{models => }/spd/__init__.py | 0 .../{ => spd}/modeling_spd_utils.py | 3 +- .../{models => }/spd/modeling_tlm.py | 2 +- QEfficient/utils/generate_inputs.py | 2 +- docs/source/quick_start.md | 4 +- .../spd/test_tlm_dlm_export_and_compile.py | 172 ++++++++++++++++++ 8 files changed, 179 insertions(+), 7 deletions(-) rename QEfficient/transformers/{models => }/spd/__init__.py (100%) rename QEfficient/transformers/{ => spd}/modeling_spd_utils.py (92%) rename QEfficient/transformers/{models => }/spd/modeling_tlm.py (98%) create mode 100644 tests/transformers/spd/test_tlm_dlm_export_and_compile.py diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 083a41e4c..4bcc6291b 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -7,7 +7,6 @@ import hashlib import logging -import math import warnings from pathlib import Path from typing import Any, List, Optional, Union diff --git a/QEfficient/transformers/pytorch_transforms.py b/QEfficient/transformers/pytorch_transforms.py index 74a779b67..a68284150 100644 --- a/QEfficient/transformers/pytorch_transforms.py +++ b/QEfficient/transformers/pytorch_transforms.py @@ -194,13 +194,13 @@ QEffQwen2ForCausalLM, QEffQwen2Model, ) -from QEfficient.transformers.models.spd.modeling_tlm import tlm_forward from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import ( QEffStarcoder2Attention, QEFFStarcoder2DecoderLayer, QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) +from QEfficient.transformers.spd.modeling_tlm import tlm_forward class CustomOpsTransform(ModuleMappingTransform): diff --git a/QEfficient/transformers/models/spd/__init__.py b/QEfficient/transformers/spd/__init__.py similarity index 100% rename from QEfficient/transformers/models/spd/__init__.py rename to QEfficient/transformers/spd/__init__.py diff --git a/QEfficient/transformers/modeling_spd_utils.py b/QEfficient/transformers/spd/modeling_spd_utils.py similarity index 92% rename from QEfficient/transformers/modeling_spd_utils.py rename to QEfficient/transformers/spd/modeling_spd_utils.py index 863adff79..eb72388b9 100644 --- a/QEfficient/transformers/modeling_spd_utils.py +++ b/QEfficient/transformers/spd/modeling_spd_utils.py @@ -9,6 +9,7 @@ import torch + def filter_hidden_states( hidden_states: torch.Tensor, position_ids: torch.Tensor, @@ -34,7 +35,7 @@ def filter_hidden_states( # return the last logit return hidden_states[batch_indices.view(-1, 1), logit_index] # gather approach - lower_idx = torch.where(logit_index <= num_logits_to_keep, 0, logit_index - num_logits_to_keep).view(-1,1) # shape: [bsz, 1] + lower_idx = torch.where(logit_index < num_logits_to_keep, 0, logit_index+1 - num_logits_to_keep).view(-1,1) # shape: [bsz, 1] spec_idx = torch.arange(num_logits_to_keep).view(1,-1) # shape: [1, k] indices = torch.add(lower_idx, spec_idx).unsqueeze(2) # shape: [bsz, k, 1] indices = indices.repeat(1, 1, hidden_states.size(-1)) # shape: [bsz, ,k, d_model] diff --git a/QEfficient/transformers/models/spd/modeling_tlm.py b/QEfficient/transformers/spd/modeling_tlm.py similarity index 98% rename from QEfficient/transformers/models/spd/modeling_tlm.py rename to QEfficient/transformers/spd/modeling_tlm.py index f59f92640..67c839b47 100644 --- a/QEfficient/transformers/models/spd/modeling_tlm.py +++ b/QEfficient/transformers/spd/modeling_tlm.py @@ -12,7 +12,7 @@ from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast -from QEfficient.transformers.modeling_spd_utils import filter_hidden_states +from QEfficient.transformers.spd.modeling_spd_utils import filter_hidden_states def tlm_forward( diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index ed6e17be1..a826a3a00 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -15,7 +15,7 @@ class InputHandler: def __init__( - self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_logits_to_keep: Optional[int] + self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_logits_to_keep: Optional[int] = None ): """ Initialization diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index dd097e6dd..dbc6ac8ee 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -156,8 +156,8 @@ Draft-based speculative decoding is the approach where a small Draft Language Mo To export both DLM/TLM, add below flags to `from_pretrained`: ```Python -tlm_name = "meta-llama/Llama-3.1-405B" -dlm_name = "meta-llama/Llama-3.1-8B" +tlm_name = "meta-llama/Llama-2-70b-chat-hf" +dlm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" k = 3 # DLM will make `k` speculations tlm = AutoModelForCausalLM.from_pretrained(tlm_name, num_speculative_tokens=k) dlm = AutoModelForCausalLM.from_pretrained(dlm_name, is_dlm=True) diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py new file mode 100644 index 000000000..cec269ffa --- /dev/null +++ b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py @@ -0,0 +1,172 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List, Optional + +import numpy as np +import pytest +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +configs = [ + pytest.param( + [0], # device_group + 2, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + 8, # full_batch_size + "JackFram/llama-68m", # model_name + True, # continuous_batching + id="CB llama", + ), + pytest.param( + [0], # device_group + 2, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + None, # full_batch_size + "JackFram/llama-68m", # model_name + False, # continuous_batching + id="non-CB llama", + ), +] + + +@pytest.mark.parametrize( + "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs +) +def test_llama_tlm_logit_dims( + device_group: List[int], + num_speculative_tokens: int, + prefill_seq_len: int, + ctx_len: int, + prefill_bsz: int, + full_batch_size: Optional[int], + model_name: str, + continuous_batching: bool, +): + # get vocab size + tokenizer = AutoTokenizer.from_pretrained(model_name) + vocab_size = len(tokenizer) + + # export and compile tlm model + qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, num_speculative_tokens=num_speculative_tokens) + qpc_path: str = qeff_model.compile( + num_devices=len(device_group), + num_cores=16, + batch_size=prefill_bsz, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6_matmul=True, + full_batch_size=full_batch_size, + ) + + # init qaic session + session = QAICInferenceSession(qpc_path, device_ids=device_group) + # skip inputs/outputs buffers + session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) + session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) + # prefill dummy inputs + prefill_inputs = dict( + input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), + position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), + ) + # decode dummy inputs + num_logits_to_keep = num_speculative_tokens + 1 + decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz + decode_inputs = dict( + input_ids=np.zeros((decode_bsz, num_logits_to_keep), dtype=np.int64), + position_ids=np.full((decode_bsz, num_logits_to_keep), -1, dtype=np.int64), + ) + if full_batch_size is not None: + prefill_inputs["batch_index"] = np.arange(prefill_bsz, dtype=np.int64).reshape(prefill_bsz, 1) + decode_inputs["batch_index"] = np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1) + # create dummy logits + prefill_logits = dict(logits=np.random.randn(prefill_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) + decode_logits = dict(logits=np.random.randn(decode_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) + # get prefill/decode logits + session.set_buffers(prefill_logits) + prefill_outputs = session.run(prefill_inputs) + session.set_buffers(decode_logits) + decode_outputs = session.run(decode_inputs) + + # assert expected logit dims + assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape + assert decode_logits["logits"].shape == decode_outputs["logits"].shape + + +@pytest.mark.parametrize( + "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs +) +def test_llama_dlm_logit_dims( + device_group: List[int], + num_speculative_tokens: int, + prefill_seq_len: int, + ctx_len: int, + prefill_bsz: int, + full_batch_size: Optional[int], + model_name: str, + continuous_batching: bool, +): + # get vocab size + tokenizer = AutoTokenizer.from_pretrained(model_name) + vocab_size = len(tokenizer) + + # export and compile tlm model + qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, is_dlm=True) + qpc_path: str = qeff_model.compile( + num_devices=len(device_group), + num_cores=16, + batch_size=prefill_bsz, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + mxfp6_matmul=True, + full_batch_size=full_batch_size, + ) + + # init qaic session + session = QAICInferenceSession(qpc_path, device_ids=device_group) + # skip inputs/outputs buffers + session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) + session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) + # prefill dummy inputs + prefill_inputs = dict( + input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), + position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), + batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(-1, 1), + ) + # decode-1 dummy inputs + decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz + decode1_inputs = dict( + input_ids=np.zeros((decode_bsz, 1), dtype=np.int64), + position_ids=np.full((decode_bsz, 1), -1, dtype=np.int64), + batch_index=np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1), + ) + # decode-2 dummy inputs + decode2_inputs = dict( + input_ids=np.zeros((decode_bsz, 2), dtype=np.int64), + position_ids=np.full((decode_bsz, 2), -1, dtype=np.int64), + batch_index=np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1), + ) + # create dummy logits + prefill_logits = dict(logits=np.random.randn(prefill_bsz, 1, vocab_size).astype(np.float32)) + decode_logits = dict(logits=np.random.randn(decode_bsz, 1, vocab_size).astype(np.float32)) + # get prefill/decode logits + session.set_buffers(prefill_logits) + prefill_outputs = session.run(prefill_inputs) + session.set_buffers(decode_logits) + decode1_outputs = session.run(decode1_inputs) + decode2_outputs = session.run(decode2_inputs) + + # assert expected logit dims + assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape + assert decode_logits["logits"].shape == decode1_outputs["logits"].shape + assert decode_logits["logits"].shape == decode2_outputs["logits"].shape From 37b7b71df1cf044ac4bb2c884d1a0e39fd6f0f27 Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 20 Nov 2024 23:01:30 -0600 Subject: [PATCH 10/30] removing old unit tests Signed-off-by: eplatero --- tests/spd/test_tlm_dlm_export_and_compile.py | 172 ------------------- 1 file changed, 172 deletions(-) delete mode 100644 tests/spd/test_tlm_dlm_export_and_compile.py diff --git a/tests/spd/test_tlm_dlm_export_and_compile.py b/tests/spd/test_tlm_dlm_export_and_compile.py deleted file mode 100644 index cec269ffa..000000000 --- a/tests/spd/test_tlm_dlm_export_and_compile.py +++ /dev/null @@ -1,172 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from typing import List, Optional - -import numpy as np -import pytest -from transformers import AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM -from QEfficient.generation.cloud_infer import QAICInferenceSession - -configs = [ - pytest.param( - [0], # device_group - 2, # num_speculative_tokens - 32, # prefill_seq_len - 128, # ctx_len - 1, # prefill_bsz - 8, # full_batch_size - "JackFram/llama-68m", # model_name - True, # continuous_batching - id="CB llama", - ), - pytest.param( - [0], # device_group - 2, # num_speculative_tokens - 32, # prefill_seq_len - 128, # ctx_len - 1, # prefill_bsz - None, # full_batch_size - "JackFram/llama-68m", # model_name - False, # continuous_batching - id="non-CB llama", - ), -] - - -@pytest.mark.parametrize( - "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs -) -def test_llama_tlm_logit_dims( - device_group: List[int], - num_speculative_tokens: int, - prefill_seq_len: int, - ctx_len: int, - prefill_bsz: int, - full_batch_size: Optional[int], - model_name: str, - continuous_batching: bool, -): - # get vocab size - tokenizer = AutoTokenizer.from_pretrained(model_name) - vocab_size = len(tokenizer) - - # export and compile tlm model - qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, num_speculative_tokens=num_speculative_tokens) - qpc_path: str = qeff_model.compile( - num_devices=len(device_group), - num_cores=16, - batch_size=prefill_bsz, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - mxfp6_matmul=True, - full_batch_size=full_batch_size, - ) - - # init qaic session - session = QAICInferenceSession(qpc_path, device_ids=device_group) - # skip inputs/outputs buffers - session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) - session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) - # prefill dummy inputs - prefill_inputs = dict( - input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), - position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), - ) - # decode dummy inputs - num_logits_to_keep = num_speculative_tokens + 1 - decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz - decode_inputs = dict( - input_ids=np.zeros((decode_bsz, num_logits_to_keep), dtype=np.int64), - position_ids=np.full((decode_bsz, num_logits_to_keep), -1, dtype=np.int64), - ) - if full_batch_size is not None: - prefill_inputs["batch_index"] = np.arange(prefill_bsz, dtype=np.int64).reshape(prefill_bsz, 1) - decode_inputs["batch_index"] = np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1) - # create dummy logits - prefill_logits = dict(logits=np.random.randn(prefill_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) - decode_logits = dict(logits=np.random.randn(decode_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) - # get prefill/decode logits - session.set_buffers(prefill_logits) - prefill_outputs = session.run(prefill_inputs) - session.set_buffers(decode_logits) - decode_outputs = session.run(decode_inputs) - - # assert expected logit dims - assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape - assert decode_logits["logits"].shape == decode_outputs["logits"].shape - - -@pytest.mark.parametrize( - "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs -) -def test_llama_dlm_logit_dims( - device_group: List[int], - num_speculative_tokens: int, - prefill_seq_len: int, - ctx_len: int, - prefill_bsz: int, - full_batch_size: Optional[int], - model_name: str, - continuous_batching: bool, -): - # get vocab size - tokenizer = AutoTokenizer.from_pretrained(model_name) - vocab_size = len(tokenizer) - - # export and compile tlm model - qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, is_dlm=True) - qpc_path: str = qeff_model.compile( - num_devices=len(device_group), - num_cores=16, - batch_size=prefill_bsz, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - mxfp6_matmul=True, - full_batch_size=full_batch_size, - ) - - # init qaic session - session = QAICInferenceSession(qpc_path, device_ids=device_group) - # skip inputs/outputs buffers - session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) - session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) - # prefill dummy inputs - prefill_inputs = dict( - input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), - position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), - batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(-1, 1), - ) - # decode-1 dummy inputs - decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz - decode1_inputs = dict( - input_ids=np.zeros((decode_bsz, 1), dtype=np.int64), - position_ids=np.full((decode_bsz, 1), -1, dtype=np.int64), - batch_index=np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1), - ) - # decode-2 dummy inputs - decode2_inputs = dict( - input_ids=np.zeros((decode_bsz, 2), dtype=np.int64), - position_ids=np.full((decode_bsz, 2), -1, dtype=np.int64), - batch_index=np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1), - ) - # create dummy logits - prefill_logits = dict(logits=np.random.randn(prefill_bsz, 1, vocab_size).astype(np.float32)) - decode_logits = dict(logits=np.random.randn(decode_bsz, 1, vocab_size).astype(np.float32)) - # get prefill/decode logits - session.set_buffers(prefill_logits) - prefill_outputs = session.run(prefill_inputs) - session.set_buffers(decode_logits) - decode1_outputs = session.run(decode1_inputs) - decode2_outputs = session.run(decode2_inputs) - - # assert expected logit dims - assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape - assert decode_logits["logits"].shape == decode1_outputs["logits"].shape - assert decode_logits["logits"].shape == decode2_outputs["logits"].shape From 15f95b3d0213f3acb24d55784c846b42c243048b Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 21 Nov 2024 18:59:26 +0530 Subject: [PATCH 11/30] * Added way to make num_logits_to_keep dynamic in ONNX and removed need to regenerate ONNX for different values of num_logits_to_keep only qpc is recompiled, * ran formatter , * reorganized pytorch transforms Signed-off-by: Onkar Chougule --- QEfficient/compile/compile_helper.py | 10 ++-- .../exporter/export_hf_to_cloud_ai_100.py | 7 +-- QEfficient/peft/auto.py | 2 +- .../transformers/models/modeling_auto.py | 56 +++++++++++++++---- .../{ => models}/pytorch_transforms.py | 14 +++-- .../{modeling_tlm.py => causal_lm_forward.py} | 45 +++++++++++++-- .../transformers/spd/modeling_spd_utils.py | 43 -------------- QEfficient/utils/_utils.py | 17 +++++- QEfficient/utils/generate_inputs.py | 24 ++++++-- .../models/test_causal_lm_models.py | 2 +- .../spd/test_tlm_dlm_export_and_compile.py | 46 ++++++++------- .../test_transformer_pytorch_transforms.py | 2 +- 12 files changed, 160 insertions(+), 108 deletions(-) rename QEfficient/transformers/{ => models}/pytorch_transforms.py (96%) rename QEfficient/transformers/spd/{modeling_tlm.py => causal_lm_forward.py} (71%) delete mode 100644 QEfficient/transformers/spd/modeling_spd_utils.py diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index afc1f5b75..f6d8b0228 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -25,15 +25,13 @@ def create_and_dump_specializations( num_speculative_tokens: Optional[int] = None, ): # Create specialization cfgs - decode_seq_len = 1 if num_speculative_tokens is None else num_speculative_tokens+1 + decode_seq_len = 1 if num_speculative_tokens is None else num_speculative_tokens + 1 specialization_cfgs = [ - dict(batch_size=str(batch_size), seq_len=str(prompt_len), ctx_len=str(ctx_len)), # prefill - dict(batch_size=str(batch_size), seq_len=str(decode_seq_len), ctx_len=str(ctx_len)) # decode + dict(batch_size=str(batch_size), seq_len=str(prompt_len), ctx_len=str(ctx_len)), # prefill + dict(batch_size=str(batch_size), seq_len=str(decode_seq_len), ctx_len=str(ctx_len)), # decode ] if is_dlm: - specialization_cfgs.append( - dict(batch_size=str(batch_size), seq_len="2", ctx_len=str(ctx_len)) - ) + specialization_cfgs.append(dict(batch_size=str(batch_size), seq_len="2", ctx_len=str(ctx_len))) specializations = dict(specializations=specialization_cfgs) diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index d962a85b4..bbfded9f9 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -204,13 +204,12 @@ def export_kvstyle_transformed_model_to_onnx( prompt_len = Constants.PROMPT_LEN num_logits_to_keep = None if num_speculative_tokens is not None: - num_logits_to_keep = num_speculative_tokens+1 + num_logits_to_keep = num_speculative_tokens + 1 setattr(transformed_model, "num_logits_to_keep", num_logits_to_keep) if prompt_len < num_logits_to_keep: prompt_len *= math.ceil((num_logits_to_keep) / prompt_len) if prompt_len >= seq_len: - seq_len = prompt_len*2 - + seq_len = prompt_len * 2 # Preprocess inputs # Build inputs for prefill @@ -343,7 +342,7 @@ def export_for_cloud( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens + num_speculative_tokens=num_speculative_tokens, ) else: raise NotImplementedError( diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 76c227862..377caa3e7 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -24,7 +24,7 @@ from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform -from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform from QEfficient.utils import constants from QEfficient.utils._utils import get_padding_shape_from_config from QEfficient.utils.cache import to_hashable diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4bcc6291b..16ca33dd4 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -18,7 +18,7 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform -from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform from QEfficient.utils import constants, get_padding_shape_from_config @@ -110,7 +110,14 @@ class QEFFAutoModelForCausalLM(QEFFTransformersBase): _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model: nn.Module, continuous_batching: bool = False, num_speculative_tokens: Optional[int] = None, is_dlm: bool = False, **kwargs): + def __init__( + self, + model: nn.Module, + continuous_batching: bool = False, + num_speculative_tokens: Optional[int] = None, + is_dlm: bool = False, + **kwargs, + ): if kwargs.pop("full_batch_size", None): continuous_batching = True warnings.warn( @@ -154,7 +161,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo num_speculative_tokens = kwargs.pop("num_speculative_tokens", None) is_dlm = kwargs.pop("is_dlm", False) if num_speculative_tokens is not None: - if not isinstance(num_speculative_tokens, int) or num_speculative_tokens<2: + if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: ValueError("`num_speculative_tokens` arg should be an integer greater than 1.") if is_dlm: raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") @@ -182,7 +189,7 @@ def model_hash(self) -> str: return mhash def export( - self, + self, export_dir: Optional[str] = None, seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) -> str: @@ -199,17 +206,19 @@ def export( """ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE if self.num_speculative_tokens is not None: - num_logits_to_keep = self.num_speculative_tokens+1 + num_logits_to_keep = self.num_speculative_tokens + 1 if seq_len < num_logits_to_keep: - raise ValueError(f"sequence length ({seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})") - setattr(self.model, "num_logits_to_keep", num_logits_to_keep) + raise ValueError( + f"sequence length ({seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" + ) + fbs = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), - "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs,1), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), "past_key_values": [[] for _ in range(self.num_layers)], } dynamic_axes = { @@ -237,6 +246,12 @@ def export( example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) dynamic_axes["batch_index"] = {0: "batch_size"} + if self.num_speculative_tokens is not None: + example_inputs["num_logits_to_keep"] = torch.arange(self.num_speculative_tokens + 1).view( + self.num_speculative_tokens + 1, 1 + ) + dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} + return self._export( example_inputs, output_names, @@ -282,18 +297,28 @@ def compile( :str: Path of the compiled ``qpc`` package. """ # Specializations - decode_seq_len = self.num_speculative_tokens+1 if self.num_speculative_tokens else 1 + decode_seq_len = self.num_speculative_tokens + 1 if self.num_speculative_tokens else 1 if self.continuous_batching: if full_batch_size is None: raise TypeError("missing required argument: 'full_batch_size'") specializations = [ {"full_batch_size": full_batch_size, "batch_size": 1, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, - {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": decode_seq_len, "ctx_len": ctx_len}, + { + "full_batch_size": full_batch_size, + "batch_size": full_batch_size, + "seq_len": decode_seq_len, + "ctx_len": ctx_len, + }, ] if self.is_dlm: specializations.append( - {"full_batch_size": full_batch_size, "batch_size": full_batch_size, "seq_len": 2, "ctx_len": ctx_len}, + { + "full_batch_size": full_batch_size, + "batch_size": full_batch_size, + "seq_len": 2, + "ctx_len": ctx_len, + }, ) else: specializations = [ @@ -306,6 +331,13 @@ def compile( {"batch_size": batch_size, "seq_len": 2, "ctx_len": ctx_len}, ) + if self.num_speculative_tokens: + for specialization in specializations: + specialization.update({"num_logits_to_keep": self.num_speculative_tokens + 1}) + + import ipdb + + ipdb.set_trace() # Custom IO custom_io = {} kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" @@ -371,4 +403,4 @@ def export(self): raise NotImplementedError("Reached too far!!") def compile(self, *args, **kwargs) -> Any: - raise NotImplementedError("Reached too far!!") \ No newline at end of file + raise NotImplementedError("Reached too far!!") diff --git a/QEfficient/transformers/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py similarity index 96% rename from QEfficient/transformers/pytorch_transforms.py rename to QEfficient/transformers/models/pytorch_transforms.py index a68284150..5a5344e6c 100644 --- a/QEfficient/transformers/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -200,7 +200,7 @@ QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) -from QEfficient.transformers.spd.modeling_tlm import tlm_forward +from QEfficient.transformers.spd.causal_lm_forward import tlm_forward class CustomOpsTransform(ModuleMappingTransform): @@ -310,11 +310,12 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update return model, transformed + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits. This is only needed if user is exporting Target Language Model (TLM) for Speculative Decoding to validate output logits - against the speculated tokens from a smaller model. + against the speculated tokens from a smaller model. Other than the computed logits, there should be no difference between the SpD Transformed model and its corresponding cunterpart. ``Mandatory`` Args: @@ -324,6 +325,7 @@ class SpDTransform: :model (nn.Module): PyTorch model. :transformed (bool): whether transformation was applied successfully. """ + # supported architectures _module_mapping = { # Llama @@ -333,10 +335,12 @@ class SpDTransform: @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: transformed = False - if (model_class:=model.__class__) in cls._module_mapping: + if (model_class := model.__class__) in cls._module_mapping: model.forward = MethodType(tlm_forward, model) transformed = True else: - raise NotImplementedError(f"model class {model_class} does not yet support returning multiple logits to keep.") + raise NotImplementedError( + f"model class {model_class} does not yet support returning multiple logits to keep." + ) - return model, transformed \ No newline at end of file + return model, transformed diff --git a/QEfficient/transformers/spd/modeling_tlm.py b/QEfficient/transformers/spd/causal_lm_forward.py similarity index 71% rename from QEfficient/transformers/spd/modeling_tlm.py rename to QEfficient/transformers/spd/causal_lm_forward.py index 67c839b47..46601c0c9 100644 --- a/QEfficient/transformers/spd/modeling_tlm.py +++ b/QEfficient/transformers/spd/causal_lm_forward.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- - from typing import List, Optional, Tuple, Union import torch @@ -12,7 +11,43 @@ from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast -from QEfficient.transformers.spd.modeling_spd_utils import filter_hidden_states + +def filter_hidden_states( + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + num_logits_to_keep: Optional[int] = None, +) -> torch.Tensor: + """ + Filter hidden states based on whether this is a TLM SpD model + + ``Mandatory`` Args: + :hidden_states (torch.Tensor): Hidden states tensor. + :position_ids (torch.Tensor): Position ids tensor. + ``Optional`` Args: + :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model + + Returns: + :torch.Tensor: Filtered hidden states. + """ + batch_size = position_ids.size(0) + batch_indices = torch.arange(batch_size) + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + + if num_logits_to_keep is None: + # return the last logit + return hidden_states[batch_indices.view(-1, 1), logit_index] + + # gather approach + num_logits_to_keep = num_logits_to_keep.shape[0] + lower_idx = torch.where(logit_index < num_logits_to_keep, 0, logit_index + 1 - num_logits_to_keep).view( + -1, 1 + ) # shape: [bsz, 1] + spec_idx = torch.arange(num_logits_to_keep).view(1, -1) # shape: [1, k] + indices = torch.add(lower_idx, spec_idx).unsqueeze(2) # shape: [bsz, k, 1] + indices = indices.repeat(1, 1, hidden_states.size(-1)) # shape: [bsz, ,k, d_model] + hidden_states = torch.gather(hidden_states, dim=1, index=indices) # shape: [bsz, k, d_model] + return hidden_states def tlm_forward( @@ -29,7 +64,7 @@ def tlm_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - #num_logits_to_keep: Optional[torch.LongTensor] = None, # explicit passing is not currently supported + num_logits_to_keep: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -77,8 +112,6 @@ def tlm_forward( cache_position=cache_position, ) - # Cast to INT32 to avoid issue while running in ONNXRT - num_logits_to_keep = getattr(self, "num_logits_to_keep", None) hidden_states = filter_hidden_states(outputs[0], position_ids, num_logits_to_keep) if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) @@ -94,4 +127,4 @@ def tlm_forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) \ No newline at end of file + ) diff --git a/QEfficient/transformers/spd/modeling_spd_utils.py b/QEfficient/transformers/spd/modeling_spd_utils.py deleted file mode 100644 index eb72388b9..000000000 --- a/QEfficient/transformers/spd/modeling_spd_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from typing import Optional - -import torch - - -def filter_hidden_states( - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - num_logits_to_keep: Optional[int] = None, -) -> torch.Tensor: - """ - Filter hidden states based on whether this is a TLM SpD model - - ``Mandatory`` Args: - :hidden_states (torch.Tensor): Hidden states tensor. - :position_ids (torch.Tensor): Position ids tensor. - ``Optional`` Args: - :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model - - Returns: - :torch.Tensor: Filtered hidden states. - """ - batch_size = position_ids.size(0) - batch_indices = torch.arange(batch_size) - # Cast to INT32 to avoid issue while running in ONNXRT - logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - if num_logits_to_keep is None: - # return the last logit - return hidden_states[batch_indices.view(-1, 1), logit_index] - # gather approach - lower_idx = torch.where(logit_index < num_logits_to_keep, 0, logit_index+1 - num_logits_to_keep).view(-1,1) # shape: [bsz, 1] - spec_idx = torch.arange(num_logits_to_keep).view(1,-1) # shape: [1, k] - indices = torch.add(lower_idx, spec_idx).unsqueeze(2) # shape: [bsz, k, 1] - indices = indices.repeat(1, 1, hidden_states.size(-1)) # shape: [bsz, ,k, d_model] - hidden_states = torch.gather(hidden_states, dim=1, index=indices) # shape: [bsz, k, d_model] - return hidden_states diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 6ef6d63cf..1d8f5b4d4 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -185,12 +185,23 @@ def load_hf_tokenizer( def get_qpc_dir_path( - model_card_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size, num_speculative_tokens: Optional[int] = None + model_card_name, + num_cores, + mos, + batch_size, + prompt_len, + ctx_len, + mxfp6, + mxint8, + device_group, + full_batch_size, + num_speculative_tokens: Optional[int] = None, ): # Create a unique directory name for the QPC model based on all parameters qpc_base_dir_name = ( - f"qpc_{num_cores}cores_{batch_size}bs_{prompt_len}pl_{ctx_len}cl_{mos}mos" - + f"_{num_speculative_tokens}nst" if num_speculative_tokens else '' + f"qpc_{num_cores}cores_{batch_size}bs_{prompt_len}pl_{ctx_len}cl_{mos}mos" + f"_{num_speculative_tokens}nst" + if num_speculative_tokens + else "" + f"{f'_{full_batch_size}fbs_' if full_batch_size is not None else '_'}" + f"{len(device_group) if device_group is not None else 1}" + "devices" diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index a826a3a00..ca4ef258e 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -15,7 +15,15 @@ class InputHandler: def __init__( - self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_logits_to_keep: Optional[int] = None + self, + batch_size, + tokenizer, + config, + prompt, + prompt_len, + ctx_len, + full_batch_size, + num_logits_to_keep: Optional[int] = None, ): """ Initialization @@ -28,8 +36,8 @@ def __init__( :prompt_len (int): Prompt length for the model to compile. :ctx_len (int): Maximum context length to compile the model. :full_batch_size (int): Continuous batching batch size - :num_logits_to_keep (Optional[int]): - Calculate logits for the last valid `num_logits_to_keep` tokens. + :num_logits_to_keep (Optional[int]): + Calculate logits for the last valid `num_logits_to_keep` tokens. Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. """ @@ -121,7 +129,9 @@ def update_pytorch_inputs(self, inputs, pt_outputs): input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id) input_ids[batch_index.view(-1)] = batch_idx_input_ids position_ids = torch.full((self.full_batch_size, decode_len), 0) - batch_idx_position_ids = torch.arange(decode_len).view(1,-1) + (inputs["position_ids"].max(1, keepdim=True).values + 1) + batch_idx_position_ids = torch.arange(decode_len).view(1, -1) + ( + inputs["position_ids"].max(1, keepdim=True).values + 1 + ) position_ids[batch_index.view(-1)] = batch_idx_position_ids updated_inputs["input_ids"] = input_ids updated_inputs["position_ids"] = position_ids @@ -129,9 +139,11 @@ def update_pytorch_inputs(self, inputs, pt_outputs): else: if self.num_logits_to_keep is not None: - input_ids = pt_outputs["logits"].argmax(-1) # shape: [batch_size, num_logits_to_keep] + input_ids = pt_outputs["logits"].argmax(-1) # shape: [batch_size, num_logits_to_keep] batch_size = input_ids.size(0) - position_ids = torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1) + position_ids = ( + torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1) + ) else: input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1) position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1 diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 9b926c8c7..6f0402c1b 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -223,4 +223,4 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): model_name = "gpt2" prompt_len = 1 - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len) \ No newline at end of file + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len) diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py index cec269ffa..8f377b826 100644 --- a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py +++ b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py @@ -9,6 +9,7 @@ import numpy as np import pytest +import torch from transformers import AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM @@ -16,32 +17,33 @@ configs = [ pytest.param( - [0], # device_group - 2, # num_speculative_tokens - 32, # prefill_seq_len - 128, # ctx_len - 1, # prefill_bsz - 8, # full_batch_size - "JackFram/llama-68m", # model_name - True, # continuous_batching + [0], # device_group + 2, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + 8, # full_batch_size + "JackFram/llama-68m", # model_name + True, # continuous_batching id="CB llama", ), pytest.param( - [0], # device_group - 2, # num_speculative_tokens - 32, # prefill_seq_len - 128, # ctx_len - 1, # prefill_bsz - None, # full_batch_size - "JackFram/llama-68m", # model_name - False, # continuous_batching + [0], # device_group + 2, # num_speculative_tokens + 32, # prefill_seq_len + 128, # ctx_len + 1, # prefill_bsz + None, # full_batch_size + "JackFram/llama-68m", # model_name + False, # continuous_batching id="non-CB llama", ), ] @pytest.mark.parametrize( - "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs + "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", + configs, ) def test_llama_tlm_logit_dims( device_group: List[int], @@ -58,10 +60,12 @@ def test_llama_tlm_logit_dims( vocab_size = len(tokenizer) # export and compile tlm model - qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, num_speculative_tokens=num_speculative_tokens) + qeff_model = AutoModelForCausalLM.from_pretrained( + model_name, continuous_batching=continuous_batching, num_speculative_tokens=num_speculative_tokens + ) qpc_path: str = qeff_model.compile( num_devices=len(device_group), - num_cores=16, + num_cores=14, batch_size=prefill_bsz, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, @@ -78,6 +82,7 @@ def test_llama_tlm_logit_dims( prefill_inputs = dict( input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), + num_logits_to_keep=torch.arange(num_speculative_tokens + 1).view(num_speculative_tokens + 1, 1).numpy(), ) # decode dummy inputs num_logits_to_keep = num_speculative_tokens + 1 @@ -104,7 +109,8 @@ def test_llama_tlm_logit_dims( @pytest.mark.parametrize( - "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", configs + "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", + configs, ) def test_llama_dlm_logit_dims( device_group: List[int], diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index d57a72e6c..60d57fe92 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -13,7 +13,7 @@ from transformers.cache_utils import HybridCache from QEfficient.customop.matmulnbits import QuantLinearORT -from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform From e9309a3899c7b2a601dd146326dd9392ee0d20cf Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 21 Nov 2024 22:11:28 +0530 Subject: [PATCH 12/30] changed interface to be similar to CB Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 5 +-- .../transformers/models/modeling_auto.py | 36 ++++++++++--------- .../spd/test_tlm_dlm_export_and_compile.py | 2 +- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index a7aac49c7..25a97f808 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -132,10 +132,7 @@ def _export( """ export_dir = Path(export_dir or (QEFF_HOME / self.model_name)) export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash) - if self.num_speculative_tokens: - model_name = f"{self.model_name}_{self.num_speculative_tokens+1}nltk.onnx" - else: - model_name = f"{self.model_name}.onnx" + model_name = f"{self.model_name}.onnx" onnx_path = export_dir / model_name # TODO: need to add hash to onnx if onnx_path.is_file(): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 16ca33dd4..83e3f26bf 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -114,10 +114,11 @@ def __init__( self, model: nn.Module, continuous_batching: bool = False, - num_speculative_tokens: Optional[int] = None, + is_tlm: bool = False, is_dlm: bool = False, **kwargs, ): + # TODO: remove from version 1.20 if kwargs.pop("full_batch_size", None): continuous_batching = True warnings.warn( @@ -130,11 +131,15 @@ def __init__( self.model.config.use_cache = True self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching - self.num_speculative_tokens = num_speculative_tokens self.is_dlm = is_dlm + if is_tlm: + # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch + self.model, transformed = SpDTransform.apply(self.model) + self.is_tlm = is_tlm + @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: bool = False, *args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, is_dlm: bool= False, *args, **kwargs): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. Once the model is initialized, you can use other methods such as export, compile, and generate on the same object. @@ -157,15 +162,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo # You can now execute the model model.generate(prompts=["Hi there!!"]) """ - - num_speculative_tokens = kwargs.pop("num_speculative_tokens", None) - is_dlm = kwargs.pop("is_dlm", False) - if num_speculative_tokens is not None: - if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: - ValueError("`num_speculative_tokens` arg should be an integer greater than 1.") - if is_dlm: - raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") - cls._pytorch_transforms.append(SpDTransform) + if is_tlm and is_dlm: + raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") + if kwargs.pop("full_batch_size", None): continuous_batching = True warnings.warn( @@ -174,7 +173,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo self = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) self.continuous_batching = continuous_batching - self.num_speculative_tokens = num_speculative_tokens self.is_dlm = is_dlm return self @@ -184,6 +182,7 @@ def model_hash(self) -> str: mhash = hashlib.sha256() mhash.update(to_hashable(self.model.config.to_diff_dict())) mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + mhash.update(to_hashable({"is_tlm": self.is_tlm})) mhash.update(to_hashable(self._transform_names())) mhash = mhash.hexdigest()[:16] return mhash @@ -272,6 +271,7 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, + num_speculative_tokens: Optional[int] = None, **compiler_options, ) -> str: """ @@ -331,13 +331,15 @@ def compile( {"batch_size": batch_size, "seq_len": 2, "ctx_len": ctx_len}, ) - if self.num_speculative_tokens: + if self.is_tlm: + if num_speculative_tokens is None: + raise AttributeError("Please pass valid integer as input to num_speculative_tokens parameter") + if num_speculative_tokens is not None: + if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: + ValueError("`num_speculative_tokens` arg should be an integer greater than 1.") for specialization in specializations: specialization.update({"num_logits_to_keep": self.num_speculative_tokens + 1}) - import ipdb - - ipdb.set_trace() # Custom IO custom_io = {} kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py index 8f377b826..797af8946 100644 --- a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py +++ b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py @@ -65,7 +65,7 @@ def test_llama_tlm_logit_dims( ) qpc_path: str = qeff_model.compile( num_devices=len(device_group), - num_cores=14, + num_cores=16, batch_size=prefill_bsz, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, From f7917d67f51e693bfea443de9c09ce0784c0428b Mon Sep 17 00:00:00 2001 From: eplatero Date: Thu, 21 Nov 2024 13:30:29 -0600 Subject: [PATCH 13/30] made unit tests work with array approach Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 11 ++-- .../transformers/models/modeling_auto.py | 50 +++++++++---------- QEfficient/utils/constants.py | 1 + .../spd/test_tlm_dlm_export_and_compile.py | 10 ++-- 4 files changed, 35 insertions(+), 37 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 25a97f808..571bbc5cf 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -203,6 +203,7 @@ def _compile( specializations: Optional[List[Dict[str, int]]] = None, custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, + num_speculative_tokens: Optional[int] = None, **compiler_options, ) -> str: """ @@ -219,11 +220,7 @@ def _compile( - convert_to_fp16=True -> -convert-to-fp16 """ if onnx_path is None and self.onnx_path is None: - if self.num_speculative_tokens is not None: - prefill_seq_len = specializations[0]["seq_len"] - self.export(seq_len=prefill_seq_len) - else: - self.export() + self.export() onnx_path = Path(onnx_path or self.onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) @@ -250,8 +247,8 @@ def _compile( if mdp_ts_num_devices > 1: compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices})) - if self.num_speculative_tokens: - compile_hash.update(to_hashable({"num_speculative_tokens": self.num_speculative_tokens})) + if num_speculative_tokens: + compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens})) if self.is_dlm: compile_hash.update(to_hashable({"is_dlm": self.is_dlm})) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 83e3f26bf..8b67f5445 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -51,7 +51,7 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs): if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') @@ -61,7 +61,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - return cls(model) + return cls(model, is_tlm=is_tlm) @property def model_name(self) -> str: @@ -82,6 +82,7 @@ def model_hash(self) -> str: mhash = hashlib.sha256() mhash.update(to_hashable(self.model.config.to_diff_dict())) mhash.update(to_hashable(self._transform_names())) + mhash.update(to_hashable({"is_tlm": self.is_tlm})) mhash = mhash.hexdigest()[:16] return mhash @@ -171,7 +172,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 ) - self = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) self.continuous_batching = continuous_batching self.is_dlm = is_dlm return self @@ -187,11 +188,7 @@ def model_hash(self) -> str: mhash = mhash.hexdigest()[:16] return mhash - def export( - self, - export_dir: Optional[str] = None, - seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, - ) -> str: + def export(self, export_dir: Optional[str] = None) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this." @@ -203,14 +200,8 @@ def export( Returns: :str: Path of the generated ``ONNX`` graph. """ - bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - if self.num_speculative_tokens is not None: - num_logits_to_keep = self.num_speculative_tokens + 1 - if seq_len < num_logits_to_keep: - raise ValueError( - f"sequence length ({seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" - ) - + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len @@ -245,9 +236,10 @@ def export( example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) dynamic_axes["batch_index"] = {0: "batch_size"} - if self.num_speculative_tokens is not None: - example_inputs["num_logits_to_keep"] = torch.arange(self.num_speculative_tokens + 1).view( - self.num_speculative_tokens + 1, 1 + if self.is_tlm: + nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep + example_inputs["num_logits_to_keep"] = torch.arange(nlk).view( + nlk, 1 ) dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} @@ -296,8 +288,18 @@ def compile( Returns: :str: Path of the compiled ``qpc`` package. """ + # assert num_speculative_tokens cfg is acceptable if defined + if num_speculative_tokens is not None or self.is_tlm: + assert num_speculative_tokens is not None and self.is_tlm, f"if `num_speculative_tokens` is specified or `is_tlm` is True, they must both be defined." + if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: + ValueError("`num_speculative_tokens` arg should be an integer greater than 1.") + num_logits_to_keep = num_speculative_tokens + 1 + if prefill_seq_len < num_logits_to_keep: + raise ValueError( + f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" + ) # Specializations - decode_seq_len = self.num_speculative_tokens + 1 if self.num_speculative_tokens else 1 + decode_seq_len = num_speculative_tokens + 1 if num_speculative_tokens else 1 if self.continuous_batching: if full_batch_size is None: raise TypeError("missing required argument: 'full_batch_size'") @@ -332,13 +334,8 @@ def compile( ) if self.is_tlm: - if num_speculative_tokens is None: - raise AttributeError("Please pass valid integer as input to num_speculative_tokens parameter") - if num_speculative_tokens is not None: - if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: - ValueError("`num_speculative_tokens` arg should be an integer greater than 1.") for specialization in specializations: - specialization.update({"num_logits_to_keep": self.num_speculative_tokens + 1}) + specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) # Custom IO custom_io = {} @@ -358,6 +355,7 @@ def compile( mxfp6_matmul=mxfp6_matmul, custom_io=custom_io, mdp_ts_num_devices=num_devices, + num_speculative_tokens=num_speculative_tokens, aic_num_cores=num_cores, **compiler_options, ) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index c660b1897..151a2e19c 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -46,6 +46,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 +ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_OPSET = 13 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py index 797af8946..45e28a35d 100644 --- a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py +++ b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py @@ -9,7 +9,6 @@ import numpy as np import pytest -import torch from transformers import AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM @@ -60,8 +59,9 @@ def test_llama_tlm_logit_dims( vocab_size = len(tokenizer) # export and compile tlm model + is_tlm = num_speculative_tokens is not None qeff_model = AutoModelForCausalLM.from_pretrained( - model_name, continuous_batching=continuous_batching, num_speculative_tokens=num_speculative_tokens + model_name, continuous_batching=continuous_batching, is_tlm=is_tlm ) qpc_path: str = qeff_model.compile( num_devices=len(device_group), @@ -71,6 +71,7 @@ def test_llama_tlm_logit_dims( ctx_len=ctx_len, mxfp6_matmul=True, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, ) # init qaic session @@ -79,17 +80,18 @@ def test_llama_tlm_logit_dims( session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) # prefill dummy inputs + num_logits_to_keep = num_speculative_tokens + 1 prefill_inputs = dict( input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), - num_logits_to_keep=torch.arange(num_speculative_tokens + 1).view(num_speculative_tokens + 1, 1).numpy(), + num_logits_to_keep=np.arange(num_logits_to_keep).reshape(num_logits_to_keep, 1), ) # decode dummy inputs - num_logits_to_keep = num_speculative_tokens + 1 decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz decode_inputs = dict( input_ids=np.zeros((decode_bsz, num_logits_to_keep), dtype=np.int64), position_ids=np.full((decode_bsz, num_logits_to_keep), -1, dtype=np.int64), + num_logits_to_keep=np.arange(num_logits_to_keep).reshape(num_logits_to_keep, 1), ) if full_batch_size is not None: prefill_inputs["batch_index"] = np.arange(prefill_bsz, dtype=np.int64).reshape(prefill_bsz, 1) From 806ef1a131ce3be5478004df8681819c601b0b8f Mon Sep 17 00:00:00 2001 From: eplatero Date: Thu, 21 Nov 2024 16:47:01 -0600 Subject: [PATCH 14/30] for TLM, made specialization return 1 logit for prefill and for decode Signed-off-by: eplatero --- QEfficient/transformers/models/modeling_auto.py | 7 +++++-- tests/transformers/spd/test_tlm_dlm_export_and_compile.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8b67f5445..290ab6eee 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -334,8 +334,11 @@ def compile( ) if self.is_tlm: - for specialization in specializations: - specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) + for i,specialization in enumerate(specializations): + if i: + specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) + else: + specialization.update({"num_logits_to_keep": 1}) # Custom IO custom_io = {} diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py index 45e28a35d..d5211b5f2 100644 --- a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py +++ b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py @@ -84,7 +84,7 @@ def test_llama_tlm_logit_dims( prefill_inputs = dict( input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), - num_logits_to_keep=np.arange(num_logits_to_keep).reshape(num_logits_to_keep, 1), + num_logits_to_keep=np.arange(1).reshape(1, 1), ) # decode dummy inputs decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz @@ -97,7 +97,7 @@ def test_llama_tlm_logit_dims( prefill_inputs["batch_index"] = np.arange(prefill_bsz, dtype=np.int64).reshape(prefill_bsz, 1) decode_inputs["batch_index"] = np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1) # create dummy logits - prefill_logits = dict(logits=np.random.randn(prefill_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) + prefill_logits = dict(logits=np.random.randn(prefill_bsz, 1, vocab_size).astype(np.float32)) decode_logits = dict(logits=np.random.randn(decode_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) # get prefill/decode logits session.set_buffers(prefill_logits) From 7dbb5834f7698795329c3364d66d6c3e7fac155c Mon Sep 17 00:00:00 2001 From: eplatero Date: Thu, 21 Nov 2024 16:56:08 -0600 Subject: [PATCH 15/30] moved from to method because this flag only has implications for compile stage, not export Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 5 +++-- QEfficient/transformers/models/modeling_auto.py | 16 ++++++++-------- .../spd/test_tlm_dlm_export_and_compile.py | 3 ++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 571bbc5cf..0aa36e33b 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -204,6 +204,7 @@ def _compile( custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, num_speculative_tokens: Optional[int] = None, + is_dlm: bool = False, **compiler_options, ) -> str: """ @@ -250,8 +251,8 @@ def _compile( if num_speculative_tokens: compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens})) - if self.is_dlm: - compile_hash.update(to_hashable({"is_dlm": self.is_dlm})) + if is_dlm: + compile_hash.update(to_hashable({"is_dlm": is_dlm})) # Check if already compiled compile_hash = compile_hash.hexdigest()[:16] diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 290ab6eee..16264c2cd 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -116,7 +116,6 @@ def __init__( model: nn.Module, continuous_batching: bool = False, is_tlm: bool = False, - is_dlm: bool = False, **kwargs, ): # TODO: remove from version 1.20 @@ -132,7 +131,6 @@ def __init__( self.model.config.use_cache = True self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching - self.is_dlm = is_dlm if is_tlm: # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch @@ -140,7 +138,7 @@ def __init__( self.is_tlm = is_tlm @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, is_dlm: bool= False, *args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. Once the model is initialized, you can use other methods such as export, compile, and generate on the same object. @@ -163,8 +161,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo # You can now execute the model model.generate(prompts=["Hi there!!"]) """ - if is_tlm and is_dlm: - raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") if kwargs.pop("full_batch_size", None): continuous_batching = True @@ -174,7 +170,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) self.continuous_batching = continuous_batching - self.is_dlm = is_dlm return self @property @@ -264,6 +259,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, + is_dlm: bool = False, **compiler_options, ) -> str: """ @@ -298,6 +294,9 @@ def compile( raise ValueError( f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" ) + # assert is_tlm and is_dlm are mutex + if self.is_tlm and is_dlm: + raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") # Specializations decode_seq_len = num_speculative_tokens + 1 if num_speculative_tokens else 1 if self.continuous_batching: @@ -313,7 +312,7 @@ def compile( "ctx_len": ctx_len, }, ] - if self.is_dlm: + if is_dlm: specializations.append( { "full_batch_size": full_batch_size, @@ -328,7 +327,7 @@ def compile( ] if prefill_seq_len != 1: specializations.append({"batch_size": batch_size, "seq_len": decode_seq_len, "ctx_len": ctx_len}) - if self.is_dlm: + if is_dlm: specializations.append( {"batch_size": batch_size, "seq_len": 2, "ctx_len": ctx_len}, ) @@ -359,6 +358,7 @@ def compile( custom_io=custom_io, mdp_ts_num_devices=num_devices, num_speculative_tokens=num_speculative_tokens, + is_dlm=is_dlm, aic_num_cores=num_cores, **compiler_options, ) diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py index d5211b5f2..b283ddf8b 100644 --- a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py +++ b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py @@ -129,7 +129,7 @@ def test_llama_dlm_logit_dims( vocab_size = len(tokenizer) # export and compile tlm model - qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching, is_dlm=True) + qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching) qpc_path: str = qeff_model.compile( num_devices=len(device_group), num_cores=16, @@ -138,6 +138,7 @@ def test_llama_dlm_logit_dims( ctx_len=ctx_len, mxfp6_matmul=True, full_batch_size=full_batch_size, + is_dlm=True, ) # init qaic session From a713a2c58538526fb8e78176cbde2cf944ac3f3d Mon Sep 17 00:00:00 2001 From: eplatero Date: Thu, 21 Nov 2024 22:48:19 -0600 Subject: [PATCH 16/30] fixing qpc directory naming to be backwards compatible Signed-off-by: eplatero --- QEfficient/utils/_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 1d8f5b4d4..8d1cf3cd7 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -199,10 +199,9 @@ def get_qpc_dir_path( ): # Create a unique directory name for the QPC model based on all parameters qpc_base_dir_name = ( - f"qpc_{num_cores}cores_{batch_size}bs_{prompt_len}pl_{ctx_len}cl_{mos}mos" + f"_{num_speculative_tokens}nst" - if num_speculative_tokens - else "" + f"qpc_{num_cores}cores_{batch_size}bs_{prompt_len}pl_{ctx_len}cl_{mos}mos" + f"{f'_{full_batch_size}fbs_' if full_batch_size is not None else '_'}" + + f"{f'_{num_speculative_tokens}nst_' if num_speculative_tokens is not None else ''}" + f"{len(device_group) if device_group is not None else 1}" + "devices" + ("_mxfp6_mxint8" if (mxfp6 and mxint8) else "_mxfp6" if mxfp6 else "_fp16_mxint8" if mxint8 else "_fp16") From e0c150fb68e23c681eb03bcc7a8eabb6ce531883 Mon Sep 17 00:00:00 2001 From: eplatero Date: Thu, 21 Nov 2024 23:32:14 -0600 Subject: [PATCH 17/30] updating docstrings and documentation Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 2 ++ QEfficient/transformers/models/modeling_auto.py | 7 +++++-- .../transformers/models/pytorch_transforms.py | 2 +- docs/source/quick_start.md | 14 +++++++------- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 0aa36e33b..2ef11ab9d 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -216,6 +216,8 @@ def _compile( :specializations (list): List of specializations to compile for :custom_io (dict): Custom IO to specify the input and outputs in different formats than default :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. + :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. + :is_dlm (bool, optional): Whether this is a Speculative Decoding draft-model. ``Defaults to False``. :compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below: - aic_num_cores=16 -> -aic-num-cores=16 - convert_to_fp16=True -> -convert-to-fp16 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 16264c2cd..d96582c70 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -95,6 +95,7 @@ class QEFFAutoModelForCausalLM(QEFFTransformersBase): ``Mandatory`` Args: :model (nn.Module): PyTorch model :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. + :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. .. code-block:: python @@ -145,7 +146,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo Args: :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory. - :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. + :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. + :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode. :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM. .. code-block:: python @@ -190,7 +192,6 @@ def export(self, export_dir: Optional[str] = None) -> str: ``Optional`` Args: :export_dir (str, optional): The directory path to store ONNX-graph. - :seq_len (int, optional): The length of the pytorch prompt inputs.. ``Defaults to 32``. Returns: :str: Path of the generated ``ONNX`` graph. @@ -278,6 +279,8 @@ def compile( :full_batch_size (int, optional): Continuous batching batch size. :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to True``. :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``. + :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. + :is_dlm (bool, optional): Whether this is a Speculative Decoding draft-model. ``Defaults to False``. :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``. :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5a5344e6c..6b8d00689 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -313,7 +313,7 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: class SpDTransform: """ - Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits. + Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. This is only needed if user is exporting Target Language Model (TLM) for Speculative Decoding to validate output logits against the speculated tokens from a smaller model. Other than the computed logits, there should be no difference between the SpD Transformed model and its corresponding cunterpart. diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index dbc6ac8ee..7fd0ff338 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -151,18 +151,18 @@ qeff_model.generate(prompts=["My name is"]) End to End demo examples for various models are available in **notebooks** directory. Please check them out. ### Draft-Based Speculative Decoding -Draft-based speculative decoding is the approach where a small Draft Language Model (DLM) makes `num_logits_to_keep` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM. +Draft-based speculative decoding is the approach where a small Draft Language Model (DLM) makes `num_speculative_tokens` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM as an input to TLM to validate the speculations. -To export both DLM/TLM, add below flags to `from_pretrained`: +To export and compile both DLM/TLM, add corresponding `is_tlm`, `num_speculative_tokens`, and `is_dlm` arguments: ```Python tlm_name = "meta-llama/Llama-2-70b-chat-hf" dlm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" k = 3 # DLM will make `k` speculations -tlm = AutoModelForCausalLM.from_pretrained(tlm_name, num_speculative_tokens=k) -dlm = AutoModelForCausalLM.from_pretrained(dlm_name, is_dlm=True) +tlm = AutoModelForCausalLM.from_pretrained(tlm_name, is_tlm=True) +dlm = AutoModelForCausalLM.from_pretrained(dlm_name) +tlm.compile(num_speculative_tokens=k) +dlm.compile(is_dlm=True) ``` -Once done, the same high-level python APIs of `export` and `compile` can be used to generate QPC. -When `num_speculative_tokens` is specified, QEfficient transforms the TLM to always output `num_speculative_tokens+1` logits per batch for both prefill and decode. While only the last logit corresponding to the last autoregressive token is needed in prefill, for decode phase, we take in as batch input the speculations from the DLM. As for the DLM, the only addition of adding the `is_dlm=True` flag is that an extra specialization file with `seq_len=2` is created to account for the "bonus" token that happens when all speculations are correct. -> NOTE: due to some compiler limitations, it is currently not possible to create an onnx-graph that parametrizes `num_speculative_tokens`. Because of this, a unique onnx-graph will be created per unique-specified `num_speculative_tokens`. This is also why `num_speculative_tokens+1` will be returned for both prefill and decode. \ No newline at end of file +The `is_tlm` flag is fed during the instantiation of the model because making it a TLM requires slight changes to the ONNX graph. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase. As for the DLM, no new changes are required at the ONNX level. The only change is that the model now receives an additional specialization during the compilation step to account for feeding the "bonus" token in the case where all speculations are accepted. \ No newline at end of file From 12d274936f5c8b09b24c1b0bbb421b3ec44451de Mon Sep 17 00:00:00 2001 From: eplatero Date: Fri, 22 Nov 2024 10:32:52 -0600 Subject: [PATCH 18/30] revert changes to CLI exportation of onnx and specialization to reflect state in main branch Signed-off-by: eplatero --- QEfficient/compile/compile_helper.py | 37 ++++------- .../exporter/export_hf_to_cloud_ai_100.py | 28 +------- QEfficient/utils/generate_inputs.py | 65 +++++-------------- 3 files changed, 31 insertions(+), 99 deletions(-) diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index f6d8b0228..2bf33699d 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -16,33 +16,24 @@ def create_and_dump_specializations( - batch_size: int, - prompt_len: int, - ctx_len: int, - path: str, - is_dlm: bool, - full_batch_size: Optional[int] = None, - num_speculative_tokens: Optional[int] = None, + batch_size: int, prompt_len: int, ctx_len: int, path: str, full_batch_size: Optional[int] = None ): - # Create specialization cfgs - decode_seq_len = 1 if num_speculative_tokens is None else num_speculative_tokens + 1 - specialization_cfgs = [ - dict(batch_size=str(batch_size), seq_len=str(prompt_len), ctx_len=str(ctx_len)), # prefill - dict(batch_size=str(batch_size), seq_len=str(decode_seq_len), ctx_len=str(ctx_len)), # decode - ] - if is_dlm: - specialization_cfgs.append(dict(batch_size=str(batch_size), seq_len="2", ctx_len=str(ctx_len))) - - specializations = dict(specializations=specialization_cfgs) - + # Create specialization file. + specializations = { + "specializations": [ + { + "batch_size": str(batch_size), + "seq_len": str(prompt_len), + "ctx_len": str(ctx_len), + }, + {"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)}, + ] + } # If continuous batching is enabled by proving full_batch_size we need to add FBS to the specialization file and update the batch size of decoder part to FBS if full_batch_size is not None: specializations["specializations"][0]["full_batch_size"] = str(full_batch_size) specializations["specializations"][1]["full_batch_size"] = str(full_batch_size) specializations["specializations"][1]["batch_size"] = str(full_batch_size) - if len(specializations["specializations"]) == 3: - specializations["specializations"][2]["batch_size"] = str(full_batch_size) - specializations["specializations"][2]["full_batch_size"] = str(full_batch_size) # To handle repetative input in specializations when prompt_len is 1 if prompt_len == 1 and full_batch_size is None: @@ -177,8 +168,6 @@ def compile( ctx_len=ctx_len, path=specialization_json_path, full_batch_size=full_batch_size, - is_dlm=kwargs.get("is_dlm", False), - num_speculative_tokens=kwargs.get("num_speculative_tokens", None), ) # Select the customIO config based on the mx flag. @@ -205,4 +194,4 @@ def compile( ) logger.info(f"Compiled QPC files can be found here: {qpc_path}") - return qpc_path + return qpc_path \ No newline at end of file diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index bbfded9f9..d203058c6 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -5,7 +5,6 @@ # # ----------------------------------------------------------------------------- -import math import os import shutil import warnings @@ -190,7 +189,6 @@ def export_kvstyle_transformed_model_to_onnx( onnx_dir_path: str, seq_len: int, full_batch_size: Optional[int] = None, - num_speculative_tokens: Optional[int] = None, ) -> str: # Disabling requires_grad on all parameters for _, p in enumerate(transformed_model.parameters()): @@ -199,18 +197,6 @@ def export_kvstyle_transformed_model_to_onnx( if seq_len <= 0: raise ValueError(f"Need seq_len to be greater than zero, got seq_len={seq_len}") - # Implicitly pass "num_speculative_tokens" if defined and \ - # assert prompt_len >= num_speculative_tokens - prompt_len = Constants.PROMPT_LEN - num_logits_to_keep = None - if num_speculative_tokens is not None: - num_logits_to_keep = num_speculative_tokens + 1 - setattr(transformed_model, "num_logits_to_keep", num_logits_to_keep) - if prompt_len < num_logits_to_keep: - prompt_len *= math.ceil((num_logits_to_keep) / prompt_len) - if prompt_len >= seq_len: - seq_len = prompt_len * 2 - # Preprocess inputs # Build inputs for prefill input_handler = InputHandler( @@ -218,10 +204,9 @@ def export_kvstyle_transformed_model_to_onnx( tokenizer=tokenizer, config=transformed_model.config, prompt=Constants.INPUT_STR, - prompt_len=prompt_len, + prompt_len=Constants.PROMPT_LEN, ctx_len=seq_len, full_batch_size=full_batch_size, - num_logits_to_keep=num_logits_to_keep, ) inputs = input_handler.prepare_pytorch_inputs() @@ -238,9 +223,7 @@ def export_kvstyle_transformed_model_to_onnx( # Build inputs for decode inputs = input_handler.update_pytorch_inputs(inputs, pt_outputs) # To avoid issues in onnx export - bsz = full_batch_size if full_batch_size else 1 - pos_len = inputs["position_ids"].size(1) - inputs["position_ids"] = torch.full((bsz, pos_len), seq_len - 1) + inputs["position_ids"] = torch.full((full_batch_size if full_batch_size else 1, 1), seq_len - 1) # Run PyTorch inference with past pt_outputs = transformed_model(**inputs) @@ -331,7 +314,6 @@ def export_for_cloud( onnx_dir_path: str, seq_length: int = Constants.SEQ_LEN, full_batch_size: Optional[int] = None, - num_speculative_tokens: Optional[int] = None, ) -> str: # FIXME: move all this to class instead of here, and just call qeff_model.export here. if AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(qeff_model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM: # type: ignore @@ -342,7 +324,6 @@ def export_for_cloud( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, ) else: raise NotImplementedError( @@ -357,7 +338,6 @@ def export_lm_model_for_cloud( onnx_dir_path: str, seq_length: int, full_batch_size: Optional[int] = None, - num_speculative_tokens: Optional[int] = None, ) -> str: if os.path.exists(onnx_dir_path): logger.warning(f"Overriding {onnx_dir_path}") @@ -386,7 +366,6 @@ def qualcomm_efficient_converter( kv: bool = True, form_factor: str = "cloud", full_batch_size: Optional[int] = None, - num_speculative_tokens: Optional[int] = None, ) -> Tuple[str, str]: """ This method is an alias for ``QEfficient.export``. @@ -462,11 +441,10 @@ def qualcomm_efficient_converter( onnx_dir_path=onnx_dir_path, seq_length=seq_length, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, ) return onnx_dir_path, generated_onnx_model_path else: # [TODO]: Apply the class transformation to make changes for the KV models in edge use cases # model = QEfficient.transform(model_hf, type="Transformers", form_factor="edge") # model.eval() - raise NotImplementedError("Oops! Reached too far!!") + raise NotImplementedError("Oops! Reached too far!!") \ No newline at end of file diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index ca4ef258e..4f1bc7986 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -5,8 +5,6 @@ # # ----------------------------------------------------------------------------- -from typing import Optional - import numpy as np import torch @@ -14,17 +12,7 @@ class InputHandler: - def __init__( - self, - batch_size, - tokenizer, - config, - prompt, - prompt_len, - ctx_len, - full_batch_size, - num_logits_to_keep: Optional[int] = None, - ): + def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size): """ Initialization @@ -36,10 +24,6 @@ def __init__( :prompt_len (int): Prompt length for the model to compile. :ctx_len (int): Maximum context length to compile the model. :full_batch_size (int): Continuous batching batch size - :num_logits_to_keep (Optional[int]): - Calculate logits for the last valid `num_logits_to_keep` tokens. - Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. """ # check and fix tokenizer viability padding_check_and_fix(tokenizer) @@ -48,7 +32,6 @@ def __init__( self.prompt_len = prompt_len self.ctx_len = ctx_len self.full_batch_size = full_batch_size - self.num_logits_to_keep = num_logits_to_keep self.n_layer = get_num_layers_from_config(config) self.padding_shape = get_padding_shape_from_config( config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len @@ -89,15 +72,9 @@ def prepare_pytorch_inputs(self): ) if self.full_batch_size: - # Feed input without padding (CB pt forward pass fails if padding exists in position_ids) + inputs["input_ids"] = input_ids + inputs["position_ids"] = torch.arange(input_len).view(1, input_len) inputs["batch_index"] = torch.arange(1).view(-1, 1) - if self.num_logits_to_keep is not None: - # preserve length after padding to assert `num_logits_to_keep<=padded_length` - length = inputs["position_ids"].size(1) - inputs["position_ids"] = torch.arange(length).view(1, -1) - else: - inputs["input_ids"] = input_ids - inputs["position_ids"] = position_ids past_key_values = [] for i in range(self.n_layer): @@ -120,35 +97,23 @@ def update_pytorch_inputs(self, inputs, pt_outputs): Return: :Dict: Updated input_ids, position_ids and past_key_values """ - decode_len = 1 if self.num_logits_to_keep is None else self.num_logits_to_keep updated_inputs = {} if self.full_batch_size: - # Create CB inputs (make 1 batch index have proper inputs for decode pass) batch_index = torch.arange(1).view(-1, 1) - batch_idx_input_ids = pt_outputs.logits.detach().argmax(2) - input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id) - input_ids[batch_index.view(-1)] = batch_idx_input_ids - position_ids = torch.full((self.full_batch_size, decode_len), 0) - batch_idx_position_ids = torch.arange(decode_len).view(1, -1) + ( - inputs["position_ids"].max(1, keepdim=True).values + 1 - ) - position_ids[batch_index.view(-1)] = batch_idx_position_ids - updated_inputs["input_ids"] = input_ids - updated_inputs["position_ids"] = position_ids + + input_ids = pt_outputs.logits.detach().argmax(2) + updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id) + updated_inputs["input_ids"][batch_index.view(-1)] = input_ids + + position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1 + updated_inputs["position_ids"] = torch.full((self.full_batch_size, 1), 0) + updated_inputs["position_ids"][batch_index.view(-1)] = position_ids + updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) else: - if self.num_logits_to_keep is not None: - input_ids = pt_outputs["logits"].argmax(-1) # shape: [batch_size, num_logits_to_keep] - batch_size = input_ids.size(0) - position_ids = ( - torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1) - ) - else: - input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1) - position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1 - updated_inputs["input_ids"] = input_ids - updated_inputs["position_ids"] = position_ids + updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1) + updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 updated_inputs["past_key_values"] = tuple( [(key.detach(), value.detach()) for key, value in pt_outputs["past_key_values"]] @@ -232,4 +197,4 @@ def update_ort_outputs(self, ort_outputs): outputs["past_key_values"] = present_key_values outputs["logits"] = ort_outputs["logits"] - return outputs + return outputs \ No newline at end of file From 2bba06bf32260bd329574edcf7553021ff24f3c1 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 25 Nov 2024 21:50:20 +0530 Subject: [PATCH 19/30] fixed specializations creation and ran formatter Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 4 +- QEfficient/compile/compile_helper.py | 2 +- .../exporter/export_hf_to_cloud_ai_100.py | 2 +- .../transformers/models/modeling_auto.py | 120 +++++++++--------- QEfficient/utils/_utils.py | 2 +- QEfficient/utils/constants.py | 2 +- QEfficient/utils/generate_inputs.py | 2 +- 7 files changed, 69 insertions(+), 65 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 2ef11ab9d..8a1c6eb7e 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -132,9 +132,7 @@ def _export( """ export_dir = Path(export_dir or (QEFF_HOME / self.model_name)) export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash) - model_name = f"{self.model_name}.onnx" - onnx_path = export_dir / model_name - # TODO: need to add hash to onnx + onnx_path = export_dir / f"{self.model_name}.onnx" if onnx_path.is_file(): self.onnx_path = onnx_path return onnx_path diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py index 2bf33699d..a94c88d23 100644 --- a/QEfficient/compile/compile_helper.py +++ b/QEfficient/compile/compile_helper.py @@ -194,4 +194,4 @@ def compile( ) logger.info(f"Compiled QPC files can be found here: {qpc_path}") - return qpc_path \ No newline at end of file + return qpc_path diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index d203058c6..c13bb9536 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -447,4 +447,4 @@ def qualcomm_efficient_converter( # [TODO]: Apply the class transformation to make changes for the KV models in edge use cases # model = QEfficient.transform(model_hf, type="Transformers", form_factor="edge") # model.eval() - raise NotImplementedError("Oops! Reached too far!!") \ No newline at end of file + raise NotImplementedError("Oops! Reached too far!!") diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index d96582c70..3df8d0028 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -51,7 +51,7 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') @@ -61,7 +61,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = Fals kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - return cls(model, is_tlm=is_tlm) + return cls(model) @property def model_name(self) -> str: @@ -139,7 +139,9 @@ def __init__( self.is_tlm = is_tlm @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs): + def from_pretrained( + cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs + ): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. Once the model is initialized, you can use other methods such as export, compile, and generate on the same object. @@ -163,15 +165,16 @@ def from_pretrained(cls, pretrained_model_name_or_path, continuous_batching: boo # You can now execute the model model.generate(prompts=["Hi there!!"]) """ - + if kwargs.pop("full_batch_size", None): continuous_batching = True warnings.warn( "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 ) - self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) + self = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) self.continuous_batching = continuous_batching + self.is_tlm = is_tlm return self @property @@ -233,10 +236,8 @@ def export(self, export_dir: Optional[str] = None) -> str: dynamic_axes["batch_index"] = {0: "batch_size"} if self.is_tlm: - nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep - example_inputs["num_logits_to_keep"] = torch.arange(nlk).view( - nlk, 1 - ) + nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep + example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1) dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} return self._export( @@ -287,60 +288,61 @@ def compile( Returns: :str: Path of the compiled ``qpc`` package. """ - # assert num_speculative_tokens cfg is acceptable if defined - if num_speculative_tokens is not None or self.is_tlm: - assert num_speculative_tokens is not None and self.is_tlm, f"if `num_speculative_tokens` is specified or `is_tlm` is True, they must both be defined." - if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2: - ValueError("`num_speculative_tokens` arg should be an integer greater than 1.") + # raise error if is_tlm and is_dlm aren't mutex + if self.is_tlm and is_dlm: + raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") + + if self.is_tlm: + # assert num_speculative_tokens cfg is acceptable if defined + if num_speculative_tokens is None: + raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.") + if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2: + ValueError( + f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}" + ) num_logits_to_keep = num_speculative_tokens + 1 if prefill_seq_len < num_logits_to_keep: raise ValueError( f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" ) - # assert is_tlm and is_dlm are mutex - if self.is_tlm and is_dlm: - raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") - # Specializations - decode_seq_len = num_speculative_tokens + 1 if num_speculative_tokens else 1 - if self.continuous_batching: - if full_batch_size is None: - raise TypeError("missing required argument: 'full_batch_size'") - - specializations = [ - {"full_batch_size": full_batch_size, "batch_size": 1, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, - { - "full_batch_size": full_batch_size, - "batch_size": full_batch_size, - "seq_len": decode_seq_len, - "ctx_len": ctx_len, - }, - ] - if is_dlm: - specializations.append( - { - "full_batch_size": full_batch_size, - "batch_size": full_batch_size, - "seq_len": 2, - "ctx_len": ctx_len, - }, - ) - else: - specializations = [ - {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len}, - ] - if prefill_seq_len != 1: - specializations.append({"batch_size": batch_size, "seq_len": decode_seq_len, "ctx_len": ctx_len}) - if is_dlm: - specializations.append( - {"batch_size": batch_size, "seq_len": 2, "ctx_len": ctx_len}, - ) - if self.is_tlm: - for i,specialization in enumerate(specializations): - if i: - specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) - else: - specialization.update({"num_logits_to_keep": 1}) + if self.continuous_batching and full_batch_size is None: + raise TypeError("missing required argument: 'full_batch_size'") + + # Define prefill specialization + prefill_specialization = { + # Prefill is always run with single BS for continuous batching. + "batch_size": 1 if self.continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + } + prefill_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None + prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else None + specializations = [ + prefill_specialization, + ] + + # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization + if prefill_seq_len != 1 or self.continuous_batching: + decode_specialization = { + "batch_size": full_batch_size if self.continuous_batching else batch_size, + "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1, + "ctx_len": ctx_len, + } + decode_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None + decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else None + specializations.append(decode_specialization) + + # Extra Specializations based on case + if is_dlm: + dlm_specialization = { + "batch_size": full_batch_size if self.continuous_batching else batch_size, + "seq_len": 2, + "ctx_len": ctx_len, + } + dlm_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None + specializations.append(dlm_specialization) + self.is_dlm = True # Custom IO custom_io = {} @@ -390,6 +392,10 @@ def generate( raise ValueError("Only AI_100 runtime is supported right now via generate API") if not isinstance(self.qpc_path, Path): raise TypeError("Please run compile API first!") + if self.is_tlm or getattr(self, "is_dlm", False): + raise NotImplementedError( + "generate method is not yet supported for tlm or dlm models used in Speculative Decoding" + ) generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( tokenizer, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 8d1cf3cd7..29384d008 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -199,7 +199,7 @@ def get_qpc_dir_path( ): # Create a unique directory name for the QPC model based on all parameters qpc_base_dir_name = ( - f"qpc_{num_cores}cores_{batch_size}bs_{prompt_len}pl_{ctx_len}cl_{mos}mos" + f"qpc_{num_cores}cores_{batch_size}bs_{prompt_len}pl_{ctx_len}cl_{mos}mos" + f"{f'_{full_batch_size}fbs_' if full_batch_size is not None else '_'}" + f"{f'_{num_speculative_tokens}nst_' if num_speculative_tokens is not None else ''}" + f"{len(device_group) if device_group is not None else 1}" diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 151a2e19c..a3783af03 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -46,7 +46,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 -ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep +ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_OPSET = 13 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 4f1bc7986..c45cfec41 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -197,4 +197,4 @@ def update_ort_outputs(self, ort_outputs): outputs["past_key_values"] = present_key_values outputs["logits"] = ort_outputs["logits"] - return outputs \ No newline at end of file + return outputs From 547ee4199511815948f004e8a7a58ede9e1ec252 Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 2 Dec 2024 06:09:35 -0600 Subject: [PATCH 20/30] add pytorch-level unit test Signed-off-by: eplatero --- .../test_transformer_pytorch_transforms.py | 113 ++++++++++-------- 1 file changed, 61 insertions(+), 52 deletions(-) diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index 60d57fe92..76357370f 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -13,7 +13,7 @@ from transformers.cache_utils import HybridCache from QEfficient.customop.matmulnbits import QuantLinearORT -from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform @@ -21,48 +21,49 @@ from QEfficient.utils.logging_utils import logger KVCacheTransformTestConfigs = [ - ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("gpt2", 3, 12, 192, {"n_inner": 512}, 0.8), - ("gpt2", 1, 12, 192, {"n_inner": 512}, 0.8), - ("codegen", 1, 16, 1024, {"n_inner": 2048}, 0.8), - ("codegen", 3, 16, 1024, {"n_inner": 2048}, 0.8), - ("falcon", 1, 71, 4544, {"multi_query": True}, 1.5), - ("falcon", 3, 71, 4544, {"multi_query": False}, 1.5), - ("falcon", 1, 71, 4544, {"multi_query": False}, 1.5), - ("falcon", 3, 71, 4544, {"multi_query": True}, 1.5), - ("gptj", 3, 16, 4096, {"n_inner": 512}, 1), - ("gptj", 1, 16, 4096, {"n_inner": 512}, 1.2), - ("mistral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("mistral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("mistral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("mistral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("mixtral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("mixtral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("mixtral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("mixtral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("mpt", 1, 16, 2048, {}, 0.8), - ("mpt", 3, 16, 2048, {}, 0.8), - ("phi", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("phi", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("phi", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("phi", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("phi3", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("phi3", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("phi3", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("phi3", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("qwen2", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("qwen2", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("qwen2", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), - ("qwen2", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), - ("starcoder2", 3, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8), - ("starcoder2", 1, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8), - ("starcoder2", 3, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8), - ("starcoder2", 1, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8), - ("gemma", 3, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), - ("gemma", 1, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), + ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, True), + ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), +# ("gpt2", 3, 12, 192, {"n_inner": 512}, 0.8), +# ("gpt2", 1, 12, 192, {"n_inner": 512}, 0.8), +# ("codegen", 1, 16, 1024, {"n_inner": 2048}, 0.8), +# ("codegen", 3, 16, 1024, {"n_inner": 2048}, 0.8), +# ("falcon", 1, 71, 4544, {"multi_query": True}, 1.5), +# ("falcon", 3, 71, 4544, {"multi_query": False}, 1.5), +# ("falcon", 1, 71, 4544, {"multi_query": False}, 1.5), +# ("falcon", 3, 71, 4544, {"multi_query": True}, 1.5), +# ("gptj", 3, 16, 4096, {"n_inner": 512}, 1), +# ("gptj", 1, 16, 4096, {"n_inner": 512}, 1.2), +# ("mistral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("mistral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("mistral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("mistral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("mixtral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("mixtral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("mixtral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("mixtral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("mpt", 1, 16, 2048, {}, 0.8), +# ("mpt", 3, 16, 2048, {}, 0.8), +# ("phi", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("phi", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("phi", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("phi", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("phi3", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("phi3", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("phi3", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("phi3", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("qwen2", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("qwen2", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("qwen2", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), +# ("qwen2", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +# ("starcoder2", 3, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8), +# ("starcoder2", 1, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8), +# ("starcoder2", 3, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8), +# ("starcoder2", 1, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8), +# ("gemma", 3, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), +# ("gemma", 1, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), ] @@ -90,7 +91,7 @@ def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6 def run_kv_cache_transform_and_test( - hf_model, num_hidden_layers, padding_shape, vocab_size, input_len, logits_tolerance=0.8, kv_cache=None + hf_model, num_hidden_layers, padding_shape, vocab_size, input_len, logits_tolerance=0.8, kv_cache=None, is_tlm=False, ): hf_model.eval() # Run original model @@ -118,6 +119,10 @@ def run_kv_cache_transform_and_test( # Apply transform hf_model, transformed = KVCacheTransform.apply(hf_model) assert transformed + if is_tlm: + hf_model, transformed = SpDTransform.apply(hf_model) + assert transformed + # Prepare KV model inputs past_key_values = [] @@ -126,15 +131,18 @@ def run_kv_cache_transform_and_test( past_value = torch.zeros((padding_shape), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) + inputs = dict( + input_ids=input_ids, + position_ids=torch.Tensor([range(input_ids.shape[1])]).long(), + past_key_values=tuple(past_key_values), + output_hidden_states=True, + ) + if is_tlm: + inputs["num_logits_to_keep"] = torch.zeros((input_len, 1)) # Run KV model with torch.inference_mode(): - transformed_model_outputs = hf_model( - input_ids=input_ids, - position_ids=torch.Tensor([range(input_ids.shape[1])]).long(), - past_key_values=tuple(past_key_values), - output_hidden_states=True, - ) + transformed_model_outputs = hf_model(**inputs) assert original_model_outputs.keys() == transformed_model_outputs.keys(), "Model output keys do not match!" @@ -184,11 +192,11 @@ def test_rms_norm_ops_transform(module: torch.nn.Module, hidden_size: int, input @pytest.mark.parametrize( - "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance", + "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance, is_tlm", KVCacheTransformTestConfigs, ) def test_kv_cache_transform( - config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance + config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance, is_tlm, ): config = AutoConfig.for_model( config_class, @@ -219,6 +227,7 @@ def test_kv_cache_transform( input_len=8, logits_tolerance=logits_tolerance, kv_cache=kv_cache, + is_tlm=is_tlm, ) From f9826c715abfc4f13066a1a2665278ff0b4755f5 Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 2 Dec 2024 06:12:34 -0600 Subject: [PATCH 21/30] uncommented non-llama pytorch-level unit test Signed-off-by: eplatero --- .../test_transformer_pytorch_transforms.py | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index 76357370f..f0758f2f3 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -26,44 +26,44 @@ ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), -# ("gpt2", 3, 12, 192, {"n_inner": 512}, 0.8), -# ("gpt2", 1, 12, 192, {"n_inner": 512}, 0.8), -# ("codegen", 1, 16, 1024, {"n_inner": 2048}, 0.8), -# ("codegen", 3, 16, 1024, {"n_inner": 2048}, 0.8), -# ("falcon", 1, 71, 4544, {"multi_query": True}, 1.5), -# ("falcon", 3, 71, 4544, {"multi_query": False}, 1.5), -# ("falcon", 1, 71, 4544, {"multi_query": False}, 1.5), -# ("falcon", 3, 71, 4544, {"multi_query": True}, 1.5), -# ("gptj", 3, 16, 4096, {"n_inner": 512}, 1), -# ("gptj", 1, 16, 4096, {"n_inner": 512}, 1.2), -# ("mistral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("mistral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("mistral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("mistral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("mixtral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("mixtral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("mixtral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("mixtral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("mpt", 1, 16, 2048, {}, 0.8), -# ("mpt", 3, 16, 2048, {}, 0.8), -# ("phi", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("phi", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("phi", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("phi", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("phi3", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("phi3", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("phi3", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("phi3", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("qwen2", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("qwen2", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("qwen2", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), -# ("qwen2", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), -# ("starcoder2", 3, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8), -# ("starcoder2", 1, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8), -# ("starcoder2", 3, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8), -# ("starcoder2", 1, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8), -# ("gemma", 3, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), -# ("gemma", 1, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), + ("gpt2", 3, 12, 192, {"n_inner": 512}, 0.8, False), + ("gpt2", 1, 12, 192, {"n_inner": 512}, 0.8, False), + ("codegen", 1, 16, 1024, {"n_inner": 2048}, 0.8, False), + ("codegen", 3, 16, 1024, {"n_inner": 2048}, 0.8, False), + ("falcon", 1, 71, 4544, {"multi_query": True}, 1.5, False), + ("falcon", 3, 71, 4544, {"multi_query": False}, 1.5, False), + ("falcon", 1, 71, 4544, {"multi_query": False}, 1.5, False), + ("falcon", 3, 71, 4544, {"multi_query": True}, 1.5, False), + ("gptj", 3, 16, 4096, {"n_inner": 512}, 1, False), + ("gptj", 1, 16, 4096, {"n_inner": 512}, 1.2, False), + ("mistral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("mistral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("mistral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("mistral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("mixtral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("mixtral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("mixtral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("mixtral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("mpt", 1, 16, 2048, {}, 0.8, False), + ("mpt", 3, 16, 2048, {}, 0.8, False), + ("phi", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("phi", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("phi", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("phi", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("phi3", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("phi3", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("phi3", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("phi3", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("qwen2", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("qwen2", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("qwen2", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), + ("qwen2", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), + ("starcoder2", 3, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8, False), + ("starcoder2", 1, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8, False), + ("starcoder2", 3, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8, False), + ("starcoder2", 1, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8, False), + ("gemma", 3, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8, False), + ("gemma", 1, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8, False), ] From 91f2fd71c2caef5e2a2b36da7555d26653b09bed Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 3 Dec 2024 16:23:20 -0600 Subject: [PATCH 22/30] modified pytorch level unit test and added hf vs ort vs qaic unit test Signed-off-by: eplatero --- .../generation/text_generation_inference.py | 62 ++++++-- .../transformers/models/modeling_auto.py | 5 +- QEfficient/utils/constants.py | 1 + QEfficient/utils/run_utils.py | 7 +- .../models/test_causal_lm_models.py | 31 +++- .../test_transformer_pytorch_transforms.py | 136 ++++++++++++------ 6 files changed, 180 insertions(+), 62 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index cc9880a2e..4d1b9da3d 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -274,6 +274,7 @@ def cloud_ai_100_exec_kv( write_io_dir: Optional[str] = None, automation=False, prompt_to_lora_id_mapping: Optional[List[int]] = None, + is_tlm: bool = False, ): """ This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. @@ -319,6 +320,7 @@ def cloud_ai_100_exec_kv( enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, full_batch_size=full_batch_size, + is_tlm=is_tlm, ) if full_batch_size is None: exec_info = [ @@ -355,9 +357,11 @@ def __init__( device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, + is_tlm: Optional[int] = None, ) -> None: self._ctx_len = ctx_len self._write_io_dir = write_io_dir + self.is_tlm = is_tlm # Load QPC self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) @@ -365,6 +369,7 @@ def __init__( # Fetch the variables from the QPC self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len() + self._decode_seq_len = self._fetch_decode_seq_len() self.full_batch_size = ( full_batch_size if full_batch_size else self._fetch_full_batch_size() ) # Check and fetch full batch size if CB is enabled @@ -441,6 +446,22 @@ def _fetch_batch_size_prefill_seq_len( batch_size, prefill_seq_len = self._session.bindings[self._session.binding_index_map["input_ids"]].dims return batch_size, prefill_seq_len + def _fetch_decode_seq_len( + self, + ): + """ + Fetches the decode sequence length from the session's bindings or allowed shapes. + + Returns: + decode_seq_len: The decode sequence length fetched from the session's bindings or allowed shapes. + """ + decode_seq_len = None + if self._session.allowed_shapes: + decode_seq_len = min( + [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] + ) + return decode_seq_len + def _fetch_vocab_size( self, ): @@ -485,9 +506,19 @@ def prepare_decode_inputs(self): Returns: dict: The decode inputs. """ + batch_size = self.full_batch_size if self.full_batch_size is not None else self.batch_size decode_inputs = {} - decode_inputs["input_ids"] = self.decode_input_ids - decode_inputs["position_ids"] = self.decode_pos_ids + if self.is_tlm: + position_ids = np.full((batch_size, self._decode_seq_len), -1, dtype=np.int64) + position_ids[:, -1] = self.decode_pos_ids.flatten() + input_ids = np.zeros((batch_size, self._decode_seq_len), dtype=np.int64) + input_ids[:,-1] = self.decode_input_ids.flatten() + decode_inputs["input_ids"] = input_ids + decode_inputs["position_ids"] = position_ids + decode_inputs["num_logits_to_keep"] = np.zeros((self._decode_seq_len,1)) + else: + decode_inputs["input_ids"] = self.decode_input_ids + decode_inputs["position_ids"] = self.decode_pos_ids if self.batch_index is not None: decode_inputs["batch_index"] = self.batch_index @@ -628,6 +659,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id + if self.is_tlm: + inputs["num_logits_to_keep"] = np.zeros((1,1)) if self._prompt_to_lora_id_mapping_prefill: if self.full_batch_size: @@ -668,7 +701,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): """ # Set logits placeholder for decode - logits_out_placeholder = np.zeros((self.full_batch_size, 1, self._vocab_size), dtype=np.float32) + logits_out_placeholder = np.zeros((self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32) self._session.set_buffers({"logits": logits_out_placeholder}) # Generate flag for tracking progress for each batch ID current_decode_ongoing = np.full((self.full_batch_size, 1), True) @@ -694,7 +727,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): for decode_batch_id in range(self.full_batch_size): if ( - next_token_id[decode_batch_id] == self.tokenizer.eos_token_id + next_token_id[decode_batch_id,-1] == self.tokenizer.eos_token_id or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id] ): if prompt_queue: @@ -724,10 +757,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): current_decode_ongoing[decode_batch_id] = False else: # If the generated sequence is valid and within generation len prepare for next decode - decode_inputs["input_ids"][decode_batch_id] = next_token_id[decode_batch_id] - decode_inputs["position_ids"][decode_batch_id] += 1 + decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1] + decode_inputs["position_ids"][decode_batch_id, -1] += 1 self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = ( - next_token_id[decode_batch_id] + next_token_id[decode_batch_id, -1] ) generated_id_current_index[decode_batch_id] += 1 @@ -747,6 +780,9 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform Returns: num_token (int): The number of tokens processed in the decoding process. """ + if self.is_tlm: + logits_out_placeholder = np.zeros((self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32) + self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 for num_token in range(1, generation_len): @@ -760,8 +796,8 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform # Prepare inputs for next iteration decode_inputs["input_ids"] = outputs["logits"].argmax(2) - decode_inputs["position_ids"] += 1 - self.generated_ids[:, num_token] = decode_inputs["input_ids"].squeeze(1) + decode_inputs["position_ids"][:,-1] += 1 + self.generated_ids[:, num_token] = decode_inputs["input_ids"][:,-1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id if finished_sequences.all(): @@ -811,9 +847,10 @@ def __init__( device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, + is_tlm: bool = False, ) -> None: self._qaic_model = QEffTextGenerationBase( - tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir + tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer @@ -1029,3 +1066,8 @@ def generate( perf_metrics=perf_metrics, ) return latency_stats + + def validate_tlm_gen_tokens(self): + gen_len = (self.generated_ids) + self.prefill_seq_len + diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 3df8d0028..4a20df414 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -392,9 +392,9 @@ def generate( raise ValueError("Only AI_100 runtime is supported right now via generate API") if not isinstance(self.qpc_path, Path): raise TypeError("Please run compile API first!") - if self.is_tlm or getattr(self, "is_dlm", False): + if getattr(self, "is_dlm", False): raise NotImplementedError( - "generate method is not yet supported for tlm or dlm models used in Speculative Decoding" + "generate method is not yet supported for dlm models used in Speculative Decoding" ) generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( @@ -403,6 +403,7 @@ def generate( prompt=prompts, device_id=device_id, generation_len=generation_len, + is_tlm=self.is_tlm, ) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index a3783af03..4a3ba3ff3 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -61,3 +61,4 @@ class Constants: GB = 2**30 MAX_QPC_LIMIT = 30 MAX_RETRIES = 5 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download + NUM_SPECULATIVE_TOKENS = 2 diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index ba1c1fa48..1d4d4516c 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -167,7 +167,7 @@ def run_ort_session(self, inputs, session) -> dict: ort_outputs = dict(zip(output_names, outputs_data)) return ort_outputs - def run_kv_model_on_ort(self, model_path): + def run_kv_model_on_ort(self, model_path, is_tlm=False): """ Function responsible for running ``ONNX`` model on onnxruntime and return the output tokens @@ -197,12 +197,17 @@ def run_kv_model_on_ort(self, model_path): generated_ids = [] inputs = self.input_handler.prepare_ort_inputs() + if is_tlm: + nltk = np.zeros((1,1), dtype=np.int64) + inputs["num_logits_to_keep"] = nltk ort_outputs = self.run_ort_session(inputs, session) ort_outputs = self.input_handler.update_ort_outputs(ort_outputs) for _ in range(1, self.gen_len): generated_ids.append(ort_outputs["logits"].argmax(-1).reshape(-1, 1)) inputs = self.input_handler.update_ort_inputs(inputs, ort_outputs) + if is_tlm: + inputs["num_logits_to_keep"] = nltk ort_outputs = self.run_ort_session(inputs, session) ort_outputs = self.input_handler.update_ort_outputs(ort_outputs) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 6f0402c1b..a0ba7d64b 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import Optional + import numpy as np import pytest from transformers import AutoModelForCausalLM @@ -38,6 +40,9 @@ "ibm-granite/granite-20b-code-base", ] +spd_test_models = [ + "JackFram/llama-68m", +] def load_causal_lm_model(model_config): """ @@ -69,6 +74,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( prompt_len: int = Constants.PROMPT_LEN, ctx_len: int = Constants.CTX_LEN, n_layer: int = 1, + num_speculative_tokens: Optional[int] = None, ): """ Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. @@ -98,7 +104,8 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - qeff_model = QEFFAutoModelForCausalLM(model_hf) + is_tlm = False if num_speculative_tokens is None else True + qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm) pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) @@ -107,7 +114,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" onnx_model_path = qeff_model.export() - ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path) + ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." @@ -120,6 +127,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( num_cores=14, mxfp6=False, aic_enable_depth_first=False, + num_speculative_tokens=num_speculative_tokens, ) exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size @@ -145,7 +153,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) - qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True) + qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm) onnx_model_path = qeff_model.export() if not get_available_device_id(): @@ -158,6 +166,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( mxfp6=False, aic_enable_depth_first=False, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens ) exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) @@ -215,6 +224,22 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", spd_test_models) +def test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + if model_name == "microsoft/Phi-3-mini-4k-instruct": + n_layer = 2 # test only 2 layer models + else: + n_layer = 1 + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS) + + @pytest.mark.on_qaic def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): """ diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index f0758f2f3..bcc1d4129 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -21,51 +21,56 @@ from QEfficient.utils.logging_utils import logger KVCacheTransformTestConfigs = [ - ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, True), - ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("gpt2", 3, 12, 192, {"n_inner": 512}, 0.8, False), - ("gpt2", 1, 12, 192, {"n_inner": 512}, 0.8, False), - ("codegen", 1, 16, 1024, {"n_inner": 2048}, 0.8, False), - ("codegen", 3, 16, 1024, {"n_inner": 2048}, 0.8, False), - ("falcon", 1, 71, 4544, {"multi_query": True}, 1.5, False), - ("falcon", 3, 71, 4544, {"multi_query": False}, 1.5, False), - ("falcon", 1, 71, 4544, {"multi_query": False}, 1.5, False), - ("falcon", 3, 71, 4544, {"multi_query": True}, 1.5, False), - ("gptj", 3, 16, 4096, {"n_inner": 512}, 1, False), - ("gptj", 1, 16, 4096, {"n_inner": 512}, 1.2, False), - ("mistral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("mistral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("mistral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("mistral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("mixtral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("mixtral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("mixtral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("mixtral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("mpt", 1, 16, 2048, {}, 0.8, False), - ("mpt", 3, 16, 2048, {}, 0.8, False), - ("phi", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("phi", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("phi", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("phi", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("phi3", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("phi3", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("phi3", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("phi3", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("qwen2", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("qwen2", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("qwen2", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8, False), - ("qwen2", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8, False), - ("starcoder2", 3, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8, False), - ("starcoder2", 1, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8, False), - ("starcoder2", 3, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8, False), - ("starcoder2", 1, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8, False), - ("gemma", 3, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8, False), - ("gemma", 1, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8, False), + ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("gpt2", 3, 12, 192, {"n_inner": 512}, 0.8), + ("gpt2", 1, 12, 192, {"n_inner": 512}, 0.8), + ("codegen", 1, 16, 1024, {"n_inner": 2048}, 0.8), + ("codegen", 3, 16, 1024, {"n_inner": 2048}, 0.8), + ("falcon", 1, 71, 4544, {"multi_query": True}, 1.5), + ("falcon", 3, 71, 4544, {"multi_query": False}, 1.5), + ("falcon", 1, 71, 4544, {"multi_query": False}, 1.5), + ("falcon", 3, 71, 4544, {"multi_query": True}, 1.5), + ("gptj", 3, 16, 4096, {"n_inner": 512}, 1), + ("gptj", 1, 16, 4096, {"n_inner": 512}, 1.2), + ("mistral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("mistral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("mistral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("mistral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("mixtral", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("mixtral", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("mixtral", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("mixtral", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("mpt", 1, 16, 2048, {}, 0.8), + ("mpt", 3, 16, 2048, {}, 0.8), + ("phi", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("phi", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("phi", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("phi", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("phi3", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("phi3", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("phi3", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("phi3", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("qwen2", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("qwen2", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("qwen2", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("qwen2", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("starcoder2", 3, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8), + ("starcoder2", 1, 24, 192, {"num_key_value_heads": 2, "intermediate_size": 512}, 0.8), + ("starcoder2", 3, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8), + ("starcoder2", 1, 24, 192, {"num_key_value_heads": 24, "intermediate_size": 512}, 0.8), + ("gemma", 3, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), + ("gemma", 1, 8, 2048, {"num_key_value_heads": 1, "intermediate_size": 512}, 0.8), ] +SpDTransformTestConfigs = [ + ("llama", 3, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("llama", 1, 32, 128, {"num_key_value_heads": 8, "intermediate_size": 512}, 0.8), + ("llama", 3, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), + ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), +] def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6) -> bool: # Base case @@ -192,11 +197,50 @@ def test_rms_norm_ops_transform(module: torch.nn.Module, hidden_size: int, input @pytest.mark.parametrize( - "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance, is_tlm", + "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance", KVCacheTransformTestConfigs, ) def test_kv_cache_transform( - config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance, is_tlm, + config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance +): + config = AutoConfig.for_model( + config_class, + **kwargs, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + use_cache=True, + cache_position=None, + position_embeddings=None, + ) + hf_model = AutoModelForCausalLM.from_config(config=config, attn_implementation="eager") + + kv_cache = None + if hasattr(config, "cache_implementation") and config.cache_implementation == "hybrid": + # Create a KV Cache from HybridCache class to pass as an object for models which use Hybrid KV Cache + # Refer https://github.com/huggingface/transformers/issues/32896 for more info + # This requires torch._dynamo present in torch>=2.3.0 + kv_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=32) + + padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) + + run_kv_cache_transform_and_test( + hf_model, + num_hidden_layers=num_hidden_layers, + padding_shape=padding_shape, + vocab_size=config.vocab_size, + input_len=8, + logits_tolerance=logits_tolerance, + kv_cache=kv_cache, + ) + + +@pytest.mark.parametrize( + "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance", + SpDTransformTestConfigs, +) +def test_spd_transform( + config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance ): config = AutoConfig.for_model( config_class, @@ -227,7 +271,7 @@ def test_kv_cache_transform( input_len=8, logits_tolerance=logits_tolerance, kv_cache=kv_cache, - is_tlm=is_tlm, + is_tlm=True, ) From 062b4d564d491a9299622426eda556a476fcee29 Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 3 Dec 2024 19:24:49 -0600 Subject: [PATCH 23/30] change llama test model from jackfram to tinyllama to match other tests Signed-off-by: eplatero --- tests/transformers/models/test_causal_lm_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index a0ba7d64b..9162f7660 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -41,7 +41,7 @@ ] spd_test_models = [ - "JackFram/llama-68m", + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ] def load_causal_lm_model(model_config): From 6ad5a6914d1f6eac74993cc7346efd26de919035 Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 4 Dec 2024 06:44:54 -0600 Subject: [PATCH 24/30] fix failing tlm_dlm tests by passing is_tlm correctly in modeling_auto Signed-off-by: eplatero --- .../transformers/models/modeling_auto.py | 7 +++--- .../spd/test_tlm_dlm_export_and_compile.py | 24 +++++++++---------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4a20df414..94a9a611a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -51,7 +51,7 @@ def __repr__(self) -> str: @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs): if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') @@ -61,7 +61,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - return cls(model) + return cls(model, is_tlm=is_tlm) @property def model_name(self) -> str: @@ -172,9 +172,8 @@ def from_pretrained( "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 ) - self = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) self.continuous_batching = continuous_batching - self.is_tlm = is_tlm return self @property diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py index b283ddf8b..32628951c 100644 --- a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py +++ b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py @@ -13,55 +13,53 @@ from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.device_utils import get_available_device_id configs = [ pytest.param( - [0], # device_group 2, # num_speculative_tokens 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz 8, # full_batch_size "JackFram/llama-68m", # model_name - True, # continuous_batching id="CB llama", ), pytest.param( - [0], # device_group 2, # num_speculative_tokens 32, # prefill_seq_len 128, # ctx_len 1, # prefill_bsz None, # full_batch_size "JackFram/llama-68m", # model_name - False, # continuous_batching id="non-CB llama", ), ] @pytest.mark.parametrize( - "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", + "num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name", configs, ) def test_llama_tlm_logit_dims( - device_group: List[int], num_speculative_tokens: int, prefill_seq_len: int, ctx_len: int, prefill_bsz: int, full_batch_size: Optional[int], model_name: str, - continuous_batching: bool, ): + device_group = get_available_device_id() + if not device_group: + pytest.skip("No available devices to run model on Cloud AI 100") # get vocab size tokenizer = AutoTokenizer.from_pretrained(model_name) vocab_size = len(tokenizer) # export and compile tlm model - is_tlm = num_speculative_tokens is not None + continuous_batching = full_batch_size is not None qeff_model = AutoModelForCausalLM.from_pretrained( - model_name, continuous_batching=continuous_batching, is_tlm=is_tlm + model_name, continuous_batching=continuous_batching, is_tlm=True ) qpc_path: str = qeff_model.compile( num_devices=len(device_group), @@ -111,24 +109,26 @@ def test_llama_tlm_logit_dims( @pytest.mark.parametrize( - "device_group,num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name,continuous_batching", + "num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name", configs, ) def test_llama_dlm_logit_dims( - device_group: List[int], num_speculative_tokens: int, prefill_seq_len: int, ctx_len: int, prefill_bsz: int, full_batch_size: Optional[int], model_name: str, - continuous_batching: bool, ): + device_group = get_available_device_id() + if not device_group: + pytest.skip("No available devices to run model on Cloud AI 100") # get vocab size tokenizer = AutoTokenizer.from_pretrained(model_name) vocab_size = len(tokenizer) # export and compile tlm model + continuous_batching = full_batch_size is not None qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching) qpc_path: str = qeff_model.compile( num_devices=len(device_group), From 58458c0ec1cfda658a638180e580c56a9693ddd4 Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 4 Dec 2024 11:34:00 -0600 Subject: [PATCH 25/30] rm dlm specialization Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 5 -- .../transformers/models/modeling_auto.py | 22 ------ docs/source/quick_start.md | 6 +- .../spd/test_tlm_dlm_export_and_compile.py | 73 ------------------- 4 files changed, 3 insertions(+), 103 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 8a1c6eb7e..064d7e6f0 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -202,7 +202,6 @@ def _compile( custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, num_speculative_tokens: Optional[int] = None, - is_dlm: bool = False, **compiler_options, ) -> str: """ @@ -215,7 +214,6 @@ def _compile( :custom_io (dict): Custom IO to specify the input and outputs in different formats than default :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. - :is_dlm (bool, optional): Whether this is a Speculative Decoding draft-model. ``Defaults to False``. :compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below: - aic_num_cores=16 -> -aic-num-cores=16 - convert_to_fp16=True -> -convert-to-fp16 @@ -251,9 +249,6 @@ def _compile( if num_speculative_tokens: compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens})) - if is_dlm: - compile_hash.update(to_hashable({"is_dlm": is_dlm})) - # Check if already compiled compile_hash = compile_hash.hexdigest()[:16] qpc_path = qpc_path.with_name(qpc_path.name + "-" + compile_hash) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 94a9a611a..d0b58ffb7 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -260,7 +260,6 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, - is_dlm: bool = False, **compiler_options, ) -> str: """ @@ -280,17 +279,12 @@ def compile( :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to True``. :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``. :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. - :is_dlm (bool, optional): Whether this is a Speculative Decoding draft-model. ``Defaults to False``. :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``. :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. Returns: :str: Path of the compiled ``qpc`` package. """ - # raise error if is_tlm and is_dlm aren't mutex - if self.is_tlm and is_dlm: - raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.") - if self.is_tlm: # assert num_speculative_tokens cfg is acceptable if defined if num_speculative_tokens is None: @@ -332,17 +326,6 @@ def compile( decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else None specializations.append(decode_specialization) - # Extra Specializations based on case - if is_dlm: - dlm_specialization = { - "batch_size": full_batch_size if self.continuous_batching else batch_size, - "seq_len": 2, - "ctx_len": ctx_len, - } - dlm_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None - specializations.append(dlm_specialization) - self.is_dlm = True - # Custom IO custom_io = {} kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" @@ -362,7 +345,6 @@ def compile( custom_io=custom_io, mdp_ts_num_devices=num_devices, num_speculative_tokens=num_speculative_tokens, - is_dlm=is_dlm, aic_num_cores=num_cores, **compiler_options, ) @@ -391,10 +373,6 @@ def generate( raise ValueError("Only AI_100 runtime is supported right now via generate API") if not isinstance(self.qpc_path, Path): raise TypeError("Please run compile API first!") - if getattr(self, "is_dlm", False): - raise NotImplementedError( - "generate method is not yet supported for dlm models used in Speculative Decoding" - ) generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( tokenizer, diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 7fd0ff338..fd689c8c5 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -153,7 +153,7 @@ End to End demo examples for various models are available in **notebooks** direc ### Draft-Based Speculative Decoding Draft-based speculative decoding is the approach where a small Draft Language Model (DLM) makes `num_speculative_tokens` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM as an input to TLM to validate the speculations. -To export and compile both DLM/TLM, add corresponding `is_tlm`, `num_speculative_tokens`, and `is_dlm` arguments: +To export and compile both DLM/TLM, add corresponding `is_tlm` and `num_speculative_tokens` for TLM and export DLM as you would any other QEfficient LLM model: ```Python tlm_name = "meta-llama/Llama-2-70b-chat-hf" @@ -162,7 +162,7 @@ k = 3 # DLM will make `k` speculations tlm = AutoModelForCausalLM.from_pretrained(tlm_name, is_tlm=True) dlm = AutoModelForCausalLM.from_pretrained(dlm_name) tlm.compile(num_speculative_tokens=k) -dlm.compile(is_dlm=True) +dlm.compile() ``` -The `is_tlm` flag is fed during the instantiation of the model because making it a TLM requires slight changes to the ONNX graph. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase. As for the DLM, no new changes are required at the ONNX level. The only change is that the model now receives an additional specialization during the compilation step to account for feeding the "bonus" token in the case where all speculations are accepted. \ No newline at end of file +The `is_tlm` flag is fed during the instantiation of the model because slight changes to the ONNX graph are required. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase. As for the DLM, no new changes are required at the ONNX or compile level. \ No newline at end of file diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py index 32628951c..fcb69d2db 100644 --- a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py +++ b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py @@ -106,76 +106,3 @@ def test_llama_tlm_logit_dims( # assert expected logit dims assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape assert decode_logits["logits"].shape == decode_outputs["logits"].shape - - -@pytest.mark.parametrize( - "num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name", - configs, -) -def test_llama_dlm_logit_dims( - num_speculative_tokens: int, - prefill_seq_len: int, - ctx_len: int, - prefill_bsz: int, - full_batch_size: Optional[int], - model_name: str, -): - device_group = get_available_device_id() - if not device_group: - pytest.skip("No available devices to run model on Cloud AI 100") - # get vocab size - tokenizer = AutoTokenizer.from_pretrained(model_name) - vocab_size = len(tokenizer) - - # export and compile tlm model - continuous_batching = full_batch_size is not None - qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=continuous_batching) - qpc_path: str = qeff_model.compile( - num_devices=len(device_group), - num_cores=16, - batch_size=prefill_bsz, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - mxfp6_matmul=True, - full_batch_size=full_batch_size, - is_dlm=True, - ) - - # init qaic session - session = QAICInferenceSession(qpc_path, device_ids=device_group) - # skip inputs/outputs buffers - session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) - session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) - # prefill dummy inputs - prefill_inputs = dict( - input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), - position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), - batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(-1, 1), - ) - # decode-1 dummy inputs - decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz - decode1_inputs = dict( - input_ids=np.zeros((decode_bsz, 1), dtype=np.int64), - position_ids=np.full((decode_bsz, 1), -1, dtype=np.int64), - batch_index=np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1), - ) - # decode-2 dummy inputs - decode2_inputs = dict( - input_ids=np.zeros((decode_bsz, 2), dtype=np.int64), - position_ids=np.full((decode_bsz, 2), -1, dtype=np.int64), - batch_index=np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1), - ) - # create dummy logits - prefill_logits = dict(logits=np.random.randn(prefill_bsz, 1, vocab_size).astype(np.float32)) - decode_logits = dict(logits=np.random.randn(decode_bsz, 1, vocab_size).astype(np.float32)) - # get prefill/decode logits - session.set_buffers(prefill_logits) - prefill_outputs = session.run(prefill_inputs) - session.set_buffers(decode_logits) - decode1_outputs = session.run(decode1_inputs) - decode2_outputs = session.run(decode2_inputs) - - # assert expected logit dims - assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape - assert decode_logits["logits"].shape == decode1_outputs["logits"].shape - assert decode_logits["logits"].shape == decode2_outputs["logits"].shape From 39acd5f0c46abd39c9294113fe6a7a112b7cf2f9 Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 4 Dec 2024 13:47:13 -0600 Subject: [PATCH 26/30] updated quick_docs Signed-off-by: eplatero --- docs/source/quick_start.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index fd689c8c5..470446a98 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -151,7 +151,7 @@ qeff_model.generate(prompts=["My name is"]) End to End demo examples for various models are available in **notebooks** directory. Please check them out. ### Draft-Based Speculative Decoding -Draft-based speculative decoding is the approach where a small Draft Language Model (DLM) makes `num_speculative_tokens` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM as an input to TLM to validate the speculations. +Draft-based speculative decoding is a technique where a small Draft Language Model (DLM) makes `num_speculative_tokens` autoregressive speculations ahead of the Target Language Model (TLM). The objective is to predict what the TLM would have predicted if it would have been used instead of the DLM. This approach is beneficial when the autoregressive decode phase of the TLM is memory bound and thus, we can leverage the extra computing resources of our hardware by batching the speculations of the DLM as an input to TLM to validate the speculations. To export and compile both DLM/TLM, add corresponding `is_tlm` and `num_speculative_tokens` for TLM and export DLM as you would any other QEfficient LLM model: From 6202874c124f3086aeda8fbf047014a808cd3545 Mon Sep 17 00:00:00 2001 From: eplatero Date: Thu, 5 Dec 2024 12:13:09 -0600 Subject: [PATCH 27/30] rm tlm dims test since that's already tested and generalize common code in pytorch_transforms Signed-off-by: eplatero --- .../spd/test_tlm_dlm_export_and_compile.py | 108 ------------------ .../test_transformer_pytorch_transforms.py | 85 ++++++++------ 2 files changed, 52 insertions(+), 141 deletions(-) delete mode 100644 tests/transformers/spd/test_tlm_dlm_export_and_compile.py diff --git a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py b/tests/transformers/spd/test_tlm_dlm_export_and_compile.py deleted file mode 100644 index fcb69d2db..000000000 --- a/tests/transformers/spd/test_tlm_dlm_export_and_compile.py +++ /dev/null @@ -1,108 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from typing import List, Optional - -import numpy as np -import pytest -from transformers import AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM -from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.utils.device_utils import get_available_device_id - -configs = [ - pytest.param( - 2, # num_speculative_tokens - 32, # prefill_seq_len - 128, # ctx_len - 1, # prefill_bsz - 8, # full_batch_size - "JackFram/llama-68m", # model_name - id="CB llama", - ), - pytest.param( - 2, # num_speculative_tokens - 32, # prefill_seq_len - 128, # ctx_len - 1, # prefill_bsz - None, # full_batch_size - "JackFram/llama-68m", # model_name - id="non-CB llama", - ), -] - - -@pytest.mark.parametrize( - "num_speculative_tokens,prefill_seq_len,ctx_len,prefill_bsz,full_batch_size,model_name", - configs, -) -def test_llama_tlm_logit_dims( - num_speculative_tokens: int, - prefill_seq_len: int, - ctx_len: int, - prefill_bsz: int, - full_batch_size: Optional[int], - model_name: str, -): - device_group = get_available_device_id() - if not device_group: - pytest.skip("No available devices to run model on Cloud AI 100") - # get vocab size - tokenizer = AutoTokenizer.from_pretrained(model_name) - vocab_size = len(tokenizer) - - # export and compile tlm model - continuous_batching = full_batch_size is not None - qeff_model = AutoModelForCausalLM.from_pretrained( - model_name, continuous_batching=continuous_batching, is_tlm=True - ) - qpc_path: str = qeff_model.compile( - num_devices=len(device_group), - num_cores=16, - batch_size=prefill_bsz, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - mxfp6_matmul=True, - full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, - ) - - # init qaic session - session = QAICInferenceSession(qpc_path, device_ids=device_group) - # skip inputs/outputs buffers - session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")])) - session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")])) - # prefill dummy inputs - num_logits_to_keep = num_speculative_tokens + 1 - prefill_inputs = dict( - input_ids=np.zeros((prefill_bsz, prefill_seq_len), dtype=np.int64), - position_ids=np.arange(prefill_seq_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(), - num_logits_to_keep=np.arange(1).reshape(1, 1), - ) - # decode dummy inputs - decode_bsz = full_batch_size if full_batch_size is not None else prefill_bsz - decode_inputs = dict( - input_ids=np.zeros((decode_bsz, num_logits_to_keep), dtype=np.int64), - position_ids=np.full((decode_bsz, num_logits_to_keep), -1, dtype=np.int64), - num_logits_to_keep=np.arange(num_logits_to_keep).reshape(num_logits_to_keep, 1), - ) - if full_batch_size is not None: - prefill_inputs["batch_index"] = np.arange(prefill_bsz, dtype=np.int64).reshape(prefill_bsz, 1) - decode_inputs["batch_index"] = np.arange(decode_bsz, dtype=np.int64).reshape(-1, 1) - # create dummy logits - prefill_logits = dict(logits=np.random.randn(prefill_bsz, 1, vocab_size).astype(np.float32)) - decode_logits = dict(logits=np.random.randn(decode_bsz, num_logits_to_keep, vocab_size).astype(np.float32)) - # get prefill/decode logits - session.set_buffers(prefill_logits) - prefill_outputs = session.run(prefill_inputs) - session.set_buffers(decode_logits) - decode_outputs = session.run(decode_inputs) - - # assert expected logit dims - assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape - assert decode_logits["logits"].shape == decode_outputs["logits"].shape diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index bcc1d4129..e1499919f 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -13,6 +13,7 @@ from transformers.cache_utils import HybridCache from QEfficient.customop.matmulnbits import QuantLinearORT +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ @@ -72,6 +73,43 @@ ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), ] +def create_qaic_model_inputs( + input_len: int, + vocab_size: int, + padding_shape: tuple, + num_hidden_layers: int, + is_tlm: bool = False +) -> dict: + """create pytorch QEff model inputs + + ``Mandatory`` Args: + :input_len (int): input length. + :vocab_size (int): vocab size. + :padding_shape (tuple): padding shape of KV$. + :num_hidden_layers (int): number of hidden layers. + ``Optional`` Args: + :is_tlm (bool, optional): whether this is an SpD TLM model. Defaults to False. + + Returns: + :dict: pytorch QEff model inputs + """ + input_ids = torch.randint(0, vocab_size, size=(1, input_len)) + past_key_values = [] + for _ in range(num_hidden_layers): + past_key = torch.zeros((padding_shape), dtype=torch.float32) + past_value = torch.zeros((padding_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs = dict( + input_ids=input_ids, + position_ids=torch.Tensor([range(input_ids.shape[1])]).long(), + past_key_values=tuple(past_key_values), + output_hidden_states=True, + ) + if is_tlm: + inputs["num_logits_to_keep"] = torch.zeros((input_len, 1)) + return inputs + def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6) -> bool: # Base case if original_val is None: @@ -96,11 +134,12 @@ def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6 def run_kv_cache_transform_and_test( - hf_model, num_hidden_layers, padding_shape, vocab_size, input_len, logits_tolerance=0.8, kv_cache=None, is_tlm=False, + hf_model, qaic_model_inputs, logits_tolerance=0.8, kv_cache=None, is_tlm=False, ): hf_model.eval() # Run original model - input_ids = torch.randint(0, vocab_size, size=(1, input_len)) + input_ids = qaic_model_inputs["input_ids"] + input_len = input_ids.shape[1] with torch.inference_mode(): if isinstance(kv_cache, type(None)): original_model_outputs = hf_model( @@ -121,33 +160,13 @@ def run_kv_cache_transform_and_test( else: original_model_outputs = hf_model(input_ids=input_ids, output_hidden_states=True) - # Apply transform - hf_model, transformed = KVCacheTransform.apply(hf_model) - assert transformed - if is_tlm: - hf_model, transformed = SpDTransform.apply(hf_model) - assert transformed - + # Apply transforms + hf_model = QEFFAutoModelForCausalLM(hf_model, is_tlm=is_tlm).model - # Prepare KV model inputs - past_key_values = [] - for _ in range(num_hidden_layers): - past_key = torch.zeros((padding_shape), dtype=torch.float32) - past_value = torch.zeros((padding_shape), dtype=torch.float32) - pkv = (past_key, past_value) - past_key_values.append(pkv) - inputs = dict( - input_ids=input_ids, - position_ids=torch.Tensor([range(input_ids.shape[1])]).long(), - past_key_values=tuple(past_key_values), - output_hidden_states=True, - ) - if is_tlm: - inputs["num_logits_to_keep"] = torch.zeros((input_len, 1)) # Run KV model with torch.inference_mode(): - transformed_model_outputs = hf_model(**inputs) + transformed_model_outputs = hf_model(**qaic_model_inputs) assert original_model_outputs.keys() == transformed_model_outputs.keys(), "Model output keys do not match!" @@ -224,12 +243,12 @@ def test_kv_cache_transform( padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) + # Prepare KV model inputs + qaic_model_inputs = create_qaic_model_inputs(input_len=8, vocab_size=config.vocab_size, padding_shape=padding_shape, num_hidden_layers=num_hidden_layers) + run_kv_cache_transform_and_test( hf_model, - num_hidden_layers=num_hidden_layers, - padding_shape=padding_shape, - vocab_size=config.vocab_size, - input_len=8, + qaic_model_inputs=qaic_model_inputs, logits_tolerance=logits_tolerance, kv_cache=kv_cache, ) @@ -263,12 +282,12 @@ def test_spd_transform( padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) + # Prepare KV model inputs + qaic_model_inputs = create_qaic_model_inputs(input_len=8, vocab_size=config.vocab_size, padding_shape=padding_shape, num_hidden_layers=num_hidden_layers, is_tlm=True) + run_kv_cache_transform_and_test( hf_model, - num_hidden_layers=num_hidden_layers, - padding_shape=padding_shape, - vocab_size=config.vocab_size, - input_len=8, + qaic_model_inputs=qaic_model_inputs, logits_tolerance=logits_tolerance, kv_cache=kv_cache, is_tlm=True, From 0b8520930a2194297451c03bebba392d1a84a3c9 Mon Sep 17 00:00:00 2001 From: eplatero Date: Thu, 5 Dec 2024 13:10:15 -0600 Subject: [PATCH 28/30] rm flag from non-test definition Signed-off-by: eplatero --- tests/transformers/test_transformer_pytorch_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index e1499919f..4acc510f3 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -134,7 +134,7 @@ def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6 def run_kv_cache_transform_and_test( - hf_model, qaic_model_inputs, logits_tolerance=0.8, kv_cache=None, is_tlm=False, + hf_model, qaic_model_inputs, logits_tolerance=0.8, kv_cache=None, ): hf_model.eval() # Run original model @@ -161,6 +161,7 @@ def run_kv_cache_transform_and_test( original_model_outputs = hf_model(input_ids=input_ids, output_hidden_states=True) # Apply transforms + is_tlm = "num_logits_to_keep" in qaic_model_inputs hf_model = QEFFAutoModelForCausalLM(hf_model, is_tlm=is_tlm).model @@ -290,7 +291,6 @@ def test_spd_transform( qaic_model_inputs=qaic_model_inputs, logits_tolerance=logits_tolerance, kv_cache=kv_cache, - is_tlm=True, ) From 5f52d989b1a81b843857482973fe6da37d163c16 Mon Sep 17 00:00:00 2001 From: eplatero Date: Fri, 6 Dec 2024 05:21:21 -0600 Subject: [PATCH 29/30] rm unnecessary function that is not used Signed-off-by: eplatero --- QEfficient/generation/text_generation_inference.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 4d1b9da3d..f1fe5565d 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -1066,8 +1066,3 @@ def generate( perf_metrics=perf_metrics, ) return latency_stats - - def validate_tlm_gen_tokens(self): - gen_len = (self.generated_ids) - self.prefill_seq_len - From 7b967e7cb4b38967d916d459d550c2a65de223d4 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 11 Dec 2024 00:47:12 +0530 Subject: [PATCH 30/30] ran formatter and linter Signed-off-by: Onkar Chougule --- .../generation/text_generation_inference.py | 20 +++++++----- QEfficient/utils/run_utils.py | 2 +- .../models/test_causal_lm_models.py | 7 ++-- .../test_transformer_pytorch_transforms.py | 32 +++++++++++-------- 4 files changed, 37 insertions(+), 24 deletions(-) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index f1fe5565d..4ddd57ada 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -512,10 +512,10 @@ def prepare_decode_inputs(self): position_ids = np.full((batch_size, self._decode_seq_len), -1, dtype=np.int64) position_ids[:, -1] = self.decode_pos_ids.flatten() input_ids = np.zeros((batch_size, self._decode_seq_len), dtype=np.int64) - input_ids[:,-1] = self.decode_input_ids.flatten() + input_ids[:, -1] = self.decode_input_ids.flatten() decode_inputs["input_ids"] = input_ids decode_inputs["position_ids"] = position_ids - decode_inputs["num_logits_to_keep"] = np.zeros((self._decode_seq_len,1)) + decode_inputs["num_logits_to_keep"] = np.zeros((self._decode_seq_len, 1)) else: decode_inputs["input_ids"] = self.decode_input_ids decode_inputs["position_ids"] = self.decode_pos_ids @@ -660,7 +660,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id if self.is_tlm: - inputs["num_logits_to_keep"] = np.zeros((1,1)) + inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self._prompt_to_lora_id_mapping_prefill: if self.full_batch_size: @@ -701,7 +701,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): """ # Set logits placeholder for decode - logits_out_placeholder = np.zeros((self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32) + logits_out_placeholder = np.zeros( + (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 + ) self._session.set_buffers({"logits": logits_out_placeholder}) # Generate flag for tracking progress for each batch ID current_decode_ongoing = np.full((self.full_batch_size, 1), True) @@ -727,7 +729,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): for decode_batch_id in range(self.full_batch_size): if ( - next_token_id[decode_batch_id,-1] == self.tokenizer.eos_token_id + next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id] ): if prompt_queue: @@ -781,7 +783,9 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform num_token (int): The number of tokens processed in the decoding process. """ if self.is_tlm: - logits_out_placeholder = np.zeros((self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32) + logits_out_placeholder = np.zeros( + (self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 + ) self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 @@ -796,8 +800,8 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform # Prepare inputs for next iteration decode_inputs["input_ids"] = outputs["logits"].argmax(2) - decode_inputs["position_ids"][:,-1] += 1 - self.generated_ids[:, num_token] = decode_inputs["input_ids"][:,-1] + decode_inputs["position_ids"][:, -1] += 1 + self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id if finished_sequences.all(): diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index 1d4d4516c..267b2bb9e 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -198,7 +198,7 @@ def run_kv_model_on_ort(self, model_path, is_tlm=False): generated_ids = [] inputs = self.input_handler.prepare_ort_inputs() if is_tlm: - nltk = np.zeros((1,1), dtype=np.int64) + nltk = np.zeros((1, 1), dtype=np.int64) inputs["num_logits_to_keep"] = nltk ort_outputs = self.run_ort_session(inputs, session) ort_outputs = self.input_handler.update_ort_outputs(ort_outputs) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 9162f7660..6e91711e0 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -44,6 +44,7 @@ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ] + def load_causal_lm_model(model_config): """ Function to load model from huggingface and transform to KV model @@ -166,7 +167,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( mxfp6=False, aic_enable_depth_first=False, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens + num_speculative_tokens=num_speculative_tokens, ) exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) @@ -237,7 +238,9 @@ def test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): else: n_layer = 1 - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, n_layer=n_layer, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS + ) @pytest.mark.on_qaic diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index 4acc510f3..e6a7d4588 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -14,7 +14,7 @@ from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM -from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform @@ -73,12 +73,9 @@ ("llama", 1, 32, 128, {"num_key_value_heads": 32, "intermediate_size": 512}, 0.8), ] + def create_qaic_model_inputs( - input_len: int, - vocab_size: int, - padding_shape: tuple, - num_hidden_layers: int, - is_tlm: bool = False + input_len: int, vocab_size: int, padding_shape: tuple, num_hidden_layers: int, is_tlm: bool = False ) -> dict: """create pytorch QEff model inputs @@ -110,6 +107,7 @@ def create_qaic_model_inputs( inputs["num_logits_to_keep"] = torch.zeros((input_len, 1)) return inputs + def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6) -> bool: # Base case if original_val is None: @@ -134,7 +132,10 @@ def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6 def run_kv_cache_transform_and_test( - hf_model, qaic_model_inputs, logits_tolerance=0.8, kv_cache=None, + hf_model, + qaic_model_inputs, + logits_tolerance=0.8, + kv_cache=None, ): hf_model.eval() # Run original model @@ -164,7 +165,6 @@ def run_kv_cache_transform_and_test( is_tlm = "num_logits_to_keep" in qaic_model_inputs hf_model = QEFFAutoModelForCausalLM(hf_model, is_tlm=is_tlm).model - # Run KV model with torch.inference_mode(): transformed_model_outputs = hf_model(**qaic_model_inputs) @@ -245,7 +245,9 @@ def test_kv_cache_transform( padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) # Prepare KV model inputs - qaic_model_inputs = create_qaic_model_inputs(input_len=8, vocab_size=config.vocab_size, padding_shape=padding_shape, num_hidden_layers=num_hidden_layers) + qaic_model_inputs = create_qaic_model_inputs( + input_len=8, vocab_size=config.vocab_size, padding_shape=padding_shape, num_hidden_layers=num_hidden_layers + ) run_kv_cache_transform_and_test( hf_model, @@ -259,9 +261,7 @@ def test_kv_cache_transform( "config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance", SpDTransformTestConfigs, ) -def test_spd_transform( - config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance -): +def test_spd_transform(config_class, num_hidden_layers, num_attention_heads, hidden_size, kwargs, logits_tolerance): config = AutoConfig.for_model( config_class, **kwargs, @@ -284,7 +284,13 @@ def test_spd_transform( padding_shape = get_padding_shape_from_config(config=config, batch_size=1, seq_len=32) # Prepare KV model inputs - qaic_model_inputs = create_qaic_model_inputs(input_len=8, vocab_size=config.vocab_size, padding_shape=padding_shape, num_hidden_layers=num_hidden_layers, is_tlm=True) + qaic_model_inputs = create_qaic_model_inputs( + input_len=8, + vocab_size=config.vocab_size, + padding_shape=padding_shape, + num_hidden_layers=num_hidden_layers, + is_tlm=True, + ) run_kv_cache_transform_and_test( hf_model,