From 82d669736e961d29d60d5f99c9e4ec7f326ca1f8 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Fri, 11 Oct 2024 20:31:09 +0400 Subject: [PATCH] disable assert for nonequal output for cb (#949) Co-authored-by: Andrei Kochin --- llm_bench/python/benchmark.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/llm_bench/python/benchmark.py b/llm_bench/python/benchmark.py index 9dcfe74f66..790e1fa919 100644 --- a/llm_bench/python/benchmark.py +++ b/llm_bench/python/benchmark.py @@ -202,13 +202,14 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list, log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} " f"is different from md5 of the {num - 1} iteration {prev_md5}") llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0]) - if num == 1: - # if the device is CPU, throw exception - if args['devices'].lower().startswith('cpu') is True: + if not args.get("use_cb", False): + if num == 1: + # if the device is CPU, throw exception + if args['devices'].lower().startswith('cpu') is True: + assert (result_md5_list == prev_md5) + else: + # throw exception assert (result_md5_list == prev_md5) - else: - # throw exception - assert (result_md5_list == prev_md5) else: llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0]) if bench_hook is not None: @@ -412,13 +413,14 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} " f"is different from md5 of the {num - 1} iteration {prev_md5}") llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0]) - if num == 1: - # if the device is CPU, throw exception - if args['devices'].lower().startswith('cpu') is True: + if not args.get("use_cb", False): + if num == 1: + # if the device is CPU, throw exception + if args['devices'].lower().startswith('cpu') is True: + assert (result_md5_list == prev_md5) + else: + # throw exception assert (result_md5_list == prev_md5) - else: - # throw exception - assert (result_md5_list == prev_md5) else: llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0]) streamer.reset()