[eplatero] Add support for exporting and compiling models for SpD #119

Merged: 30 commits, Dec 11, 2024

Commits
457d5ae  rebasing with main. previous local gen_spd_models was broken since it… (eplatero97, Nov 18, 2024)
6aae287  add decode_seq_len to non-continuous batching case (eplatero97, Nov 18, 2024)
7472df8  mirror test_causal_lm_models.py from main (eplatero97, Nov 18, 2024)
bc702b0  add more to the explanation of the model changes (eplatero97, Nov 18, 2024)
e630b8f  lint fixing (eplatero97, Nov 19, 2024)
840cb9f  alphabetical order imports on pytorch_transforms.py (eplatero97, Nov 19, 2024)
ccdcfb7  add init to spd directory (eplatero97, Nov 20, 2024)
ed57de7  replace modifying seq_len by letting user define proper config (eplatero97, Nov 20, 2024)
0a0683d  resolving 1st round comments from Onkar and made fix on gather implem… (eplatero97, Nov 21, 2024)
37b7b71  removing old unit tests (eplatero97, Nov 21, 2024)
15f95b3  * Added way to make num_logits_to_keep dynamic in ONNX and removed ne… (ochougul, Nov 21, 2024)
e9309a3  changed interface to be similar to CB (ochougul, Nov 21, 2024)
f7917d6  made unit tests work with array approach (eplatero97, Nov 21, 2024)
806ef1a  for TLM, made specialization return 1 logit for prefill and for decode (eplatero97, Nov 21, 2024)
7dbb583  moved from to method because this flag only has implications for c… (eplatero97, Nov 21, 2024)
a713a2c  fixing qpc directory naming to be backwards compatible (eplatero97, Nov 22, 2024)
e0c150f  updating docstrings and documentation (eplatero97, Nov 22, 2024)
12d2749  revert changes to CLI exportation of onnx and specialization to refle… (eplatero97, Nov 22, 2024)
2bba06b  fixed specializations creation and ran formatter (ochougul, Nov 25, 2024)
547ee41  add pytorch-level unit test (eplatero97, Dec 2, 2024)
f9826c7  uncommented non-llama pytorch-level unit test (eplatero97, Dec 2, 2024)
91f2fd7  modified pytorch level unit test and added hf vs ort vs qaic unit test (eplatero97, Dec 3, 2024)
062b4d5  change llama test model from jackfram to tinyllama to match other tests (eplatero97, Dec 4, 2024)
6ad5a69  fix failing tlm_dlm tests by passing is_tlm correctly in modeling_auto (eplatero97, Dec 4, 2024)
58458c0  rm dlm specialization (eplatero97, Dec 4, 2024)
39acd5f  updated quick_docs (eplatero97, Dec 4, 2024)
6202874  rm tlm dims test since that's already tested and generalize common co… (eplatero97, Dec 5, 2024)
0b85209  rm flag from non-test definition (eplatero97, Dec 5, 2024)
5f52d98  rm unnecessary function that is not used (eplatero97, Dec 6, 2024)
7b967e7  ran formatter and linter (ochougul, Dec 10, 2024)
QEfficient/base/modeling_qeff.py: 5 additions & 0 deletions

@@ -201,6 +201,7 @@ def _compile(
specializations: Optional[List[Dict[str, int]]] = None,
custom_io: Optional[Dict[str, str]] = None,
mdp_ts_num_devices: int = 1,
num_speculative_tokens: Optional[int] = None,
**compiler_options,
) -> str:
"""
@@ -212,6 +213,7 @@ def _compile(
:specializations (list): List of specializations to compile for
:custom_io (dict): Custom IO to specify the input and outputs in different formats than default
:mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing.
:num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
:compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
- aic_num_cores=16 -> -aic-num-cores=16
- convert_to_fp16=True -> -convert-to-fp16
@@ -244,6 +246,9 @@ def _compile(
if mdp_ts_num_devices > 1:
compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))

if num_speculative_tokens:
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))

# Check if already compiled
compile_hash = compile_hash.hexdigest()[:16]
qpc_path = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
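For readers skimming the diff, the caching effect of the new argument is easy to miss: any truthy num_speculative_tokens is folded into the compile hash, so a speculative-decoding target language model (TLM) gets its own QPC directory suffix. Below is a minimal, self-contained sketch of that hashing step; to_hashable here is a hypothetical stand-in for QEfficient.utils.cache.to_hashable, and the rest of _compile is omitted.

import hashlib
import json
from typing import Dict, List, Optional


def to_hashable(obj) -> bytes:
    # Hypothetical stand-in: deterministic serialization of nested dicts/lists.
    return json.dumps(obj, sort_keys=True).encode("utf-8")


def compile_hash_suffix(
    specializations: Optional[List[Dict[str, int]]] = None,
    custom_io: Optional[Dict[str, str]] = None,
    mdp_ts_num_devices: int = 1,
    num_speculative_tokens: Optional[int] = None,
) -> str:
    compile_hash = hashlib.sha256()
    compile_hash.update(to_hashable(specializations))
    compile_hash.update(to_hashable(custom_io))
    if mdp_ts_num_devices > 1:
        compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))
    # New in this PR: the speculative-token count participates in the hash,
    # so SpD and non-SpD compilations never reuse each other's QPC.
    if num_speculative_tokens:
        compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
    # The QPC directory gets this value appended as "<qpc_name>-<suffix>".
    return compile_hash.hexdigest()[:16]


# Same settings with and without speculative tokens yield different suffixes.
assert compile_hash_suffix() != compile_hash_suffix(num_speculative_tokens=4)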
QEfficient/generation/text_generation_inference.py: 51 additions & 10 deletions

@@ -274,6 +274,7 @@ def cloud_ai_100_exec_kv(
write_io_dir: Optional[str] = None,
automation=False,
prompt_to_lora_id_mapping: Optional[List[int]] = None,
is_tlm: bool = False,
):
"""
This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -319,6 +320,7 @@
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
is_tlm=is_tlm,
)
if full_batch_size is None:
exec_info = [
@@ -355,16 +357,19 @@ def __init__(
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: Optional[int] = None,
) -> None:
self._ctx_len = ctx_len
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm

# Load QPC
self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

# Fetch the variables from the QPC
self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size
self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
self._decode_seq_len = self._fetch_decode_seq_len()
self.full_batch_size = (
full_batch_size if full_batch_size else self._fetch_full_batch_size()
) # Check and fetch full batch size if CB is enabled
@@ -441,6 +446,22 @@ def _fetch_batch_size_prefill_seq_len(
batch_size, prefill_seq_len = self._session.bindings[self._session.binding_index_map["input_ids"]].dims
return batch_size, prefill_seq_len

def _fetch_decode_seq_len(
self,
):
"""
Fetches the decode sequence length from the session's bindings or allowed shapes.

Returns:
decode_seq_len: The decode sequence length fetched from the session's bindings or allowed shapes.
"""
decode_seq_len = None
if self._session.allowed_shapes:
decode_seq_len = min(
[x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes]
)
return decode_seq_len

def _fetch_vocab_size(
self,
):
@@ -485,9 +506,19 @@ def prepare_decode_inputs(self):
Returns:
dict: The decode inputs.
"""
batch_size = self.full_batch_size if self.full_batch_size is not None else self.batch_size
decode_inputs = {}
decode_inputs["input_ids"] = self.decode_input_ids
decode_inputs["position_ids"] = self.decode_pos_ids
if self.is_tlm:
position_ids = np.full((batch_size, self._decode_seq_len), -1, dtype=np.int64)
position_ids[:, -1] = self.decode_pos_ids.flatten()
input_ids = np.zeros((batch_size, self._decode_seq_len), dtype=np.int64)
input_ids[:, -1] = self.decode_input_ids.flatten()
decode_inputs["input_ids"] = input_ids
decode_inputs["position_ids"] = position_ids
decode_inputs["num_logits_to_keep"] = np.zeros((self._decode_seq_len, 1))
else:
decode_inputs["input_ids"] = self.decode_input_ids
decode_inputs["position_ids"] = self.decode_pos_ids
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index

@@ -628,6 +659,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i

if decode_batch_id is not None:
inputs["batch_index"] = decode_batch_id
if self.is_tlm:
inputs["num_logits_to_keep"] = np.zeros((1, 1))

if self._prompt_to_lora_id_mapping_prefill:
if self.full_batch_size:
@@ -668,7 +701,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
"""

# Set logits placeholder for decode
logits_out_placeholder = np.zeros((self.full_batch_size, 1, self._vocab_size), dtype=np.float32)
logits_out_placeholder = np.zeros(
(self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
)
self._session.set_buffers({"logits": logits_out_placeholder})
# Generate flag for tracking progress for each batch ID
current_decode_ongoing = np.full((self.full_batch_size, 1), True)
@@ -694,7 +729,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

for decode_batch_id in range(self.full_batch_size):
if (
next_token_id[decode_batch_id] == self.tokenizer.eos_token_id
next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id
or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id]
):
if prompt_queue:
@@ -724,10 +759,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
current_decode_ongoing[decode_batch_id] = False
else:
# If the generated sequence is valid and within generation len prepare for next decode
decode_inputs["input_ids"][decode_batch_id] = next_token_id[decode_batch_id]
decode_inputs["position_ids"][decode_batch_id] += 1
decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1]
decode_inputs["position_ids"][decode_batch_id, -1] += 1
self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
next_token_id[decode_batch_id]
next_token_id[decode_batch_id, -1]
)

generated_id_current_index[decode_batch_id] += 1
@@ -747,6 +782,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
Returns:
num_token (int): The number of tokens processed in the decoding process.
"""
if self.is_tlm:
logits_out_placeholder = np.zeros(
(self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
)
self._session.set_buffers({"logits": logits_out_placeholder})
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
num_token = 0
for num_token in range(1, generation_len):
@@ -760,8 +800,8 @@

# Prepare inputs for next iteration
decode_inputs["input_ids"] = outputs["logits"].argmax(2)
decode_inputs["position_ids"] += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"].squeeze(1)
decode_inputs["position_ids"][:, -1] += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id

if finished_sequences.all():
@@ -811,9 +851,10 @@ def __init__(
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: bool = False,
) -> None:
self._qaic_model = QEffTextGenerationBase(
tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir
tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
)
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
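The decode-side change is easiest to see in isolation: when is_tlm is set, decode inputs are padded to the decode sequence length and only the last slot carries the real token and position. The sketch below reproduces the TLM branch of prepare_decode_inputs as a standalone function; the function name and example values are illustrative, and the reading of decode_seq_len as num_speculative_tokens + 1 is an assumption based on the TLM specialization described in the commits.

import numpy as np


def prepare_tlm_decode_inputs(decode_input_ids, decode_pos_ids, decode_seq_len):
    # Pad per-sequence decode inputs to decode_seq_len; only the last slot
    # holds the real token id and position id, mirroring the hunk above.
    batch_size = decode_input_ids.shape[0]

    position_ids = np.full((batch_size, decode_seq_len), -1, dtype=np.int64)
    position_ids[:, -1] = decode_pos_ids.flatten()

    input_ids = np.zeros((batch_size, decode_seq_len), dtype=np.int64)
    input_ids[:, -1] = decode_input_ids.flatten()

    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        # Shape matches the diff: one row per decode position.
        "num_logits_to_keep": np.zeros((decode_seq_len, 1)),
    }


# Example: batch of 2, decode_seq_len of 4 (e.g. num_speculative_tokens = 3).
inputs = prepare_tlm_decode_inputs(
    decode_input_ids=np.array([[11], [42]]),
    decode_pos_ids=np.array([[7], [9]]),
    decode_seq_len=4,
)
assert inputs["input_ids"].shape == (2, 4)
assert inputs["position_ids"][0, -1] == 7

The matching logits placeholder set in run_decode is then (batch_size, decode_seq_len, vocab_size), and only the last position's argmax is written back into generated_ids, as the hunks above show.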
QEfficient/peft/auto.py: 1 addition & 1 deletion

@@ -24,7 +24,7 @@
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform
from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform
from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform
from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform
from QEfficient.utils import constants
from QEfficient.utils._utils import get_padding_shape_from_config
from QEfficient.utils.cache import to_hashable
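The PEFT change is purely a path update: the shared transforms now live under QEfficient.transformers.models.pytorch_transforms. Downstream code that imported them from the old location would be updated the same way; a sketch, assuming the package layout shown in the diff:

# Old (pre-PR) location:
#   from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform
# New location after this PR:
from QEfficient.transformers.models.pytorch_transforms import (
    CustomOpsTransform,
    KVCacheTransform,
)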