Skip to content

Commit

Permalink
[Fix] Correct the benchmark logic when an internal KV cache is involved (#1377)
Browse files Browse the repository at this point in the history
* initial commit

* fix indentation of logging
  • Loading branch information
dbogunowicz authored Nov 2, 2023
1 parent 28c666c commit 02348f2
Showing 1 changed file with 40 additions and 25 deletions.
65 changes: 40 additions & 25 deletions src/deepsparse/benchmark/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,37 +364,52 @@ def benchmark_model(
model_path = model_to_path(model_path)

cached_outputs = None
if not disable_kv_cache_overrides and has_model_kv_cache(model_path):
if not sequence_length:
sequence_length = infer_sequence_length(model_path)
if input_ids_length > sequence_length:
if has_model_kv_cache(model_path):
if not disable_kv_cache_overrides:
if not sequence_length:
sequence_length = infer_sequence_length(model_path)
if input_ids_length > sequence_length:
raise ValueError(
f"input_ids_length: {input_ids_length} "
f"must be less than sequence_length: {sequence_length}"
)

_LOGGER.info(
"Found model with KV cache support. "
"Benchmarking the autoregressive model with "
f"input_ids_length: {input_ids_length} and "
f"sequence length: {sequence_length}."
)

(
model_path,
cached_outputs,
_,
) = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=model_path,
input_ids_length=input_ids_length,
sequence_length=sequence_length,
batch_size=batch_size,
)

if internal_kv_cache and engine != DEEPSPARSE_ENGINE:
raise ValueError(
f"input_ids_length: {input_ids_length} "
f"must be less than sequence_length: {sequence_length}"
"Attempting to benchmark a model using engine: "
f"{engine} and internal_kv_cache set to True. "
"The use of internal_kv_cache is only "
f"supported for the engine: {DEEPSPARSE_ENGINE}. "
f"To disable the use of the internal_kv_cache, "
f"set the flag: --no-internal-kv-cache"
)

_LOGGER.info(
"Found model with KV cache support. "
"Benchmarking the autoregressive model with "
f"input_ids_length: {input_ids_length} and "
f"sequence length: {sequence_length}."
f"Benchmarking Engine: {engine} with "
f"{'internal' if internal_kv_cache else 'external'} KV cache management"
)

model_path, cached_outs, _ = overwrite_onnx_model_inputs_for_kv_cache_models(
onnx_file_path=model_path,
input_ids_length=input_ids_length,
sequence_length=sequence_length,
batch_size=batch_size,
)

if internal_kv_cache:
_LOGGER.info(
"Benchmarking DeepSparse Engine with internal KV Cache management"
)
cached_outputs = cached_outs
else:
input_ids_length = None
sequence_length = None
internal_kv_cache = False

num_streams = parse_num_streams(num_streams, num_cores, scenario)

Expand All @@ -407,7 +422,7 @@ def benchmark_model(
num_streams=num_streams,
scheduler=scheduler,
input_shapes=input_shapes,
cached_outputs=cached_outputs,
cached_outputs=cached_outputs if internal_kv_cache else None,
)
elif engine == ORT_ENGINE:
model = ORTEngine(
Expand Down Expand Up @@ -450,7 +465,7 @@ def benchmark_model(
seconds_to_run=time,
seconds_to_warmup=warmup_time,
num_streams=num_streams,
internal_kv_cache=cached_outputs,
internal_kv_cache=internal_kv_cache,
)
export_dict = {
"engine": str(model),
Expand Down

0 comments on commit 02348f2

Please sign in to comment.