diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 7f7f15bea86fa..6627ba1ea5643 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -94,10 +94,46 @@ def warmup_range(config: Tuple[int, int, int]): return list(ramp_up_tw) + list(stable) -def warmup_buckets(bs_bucket_config, seq_bucket_config): - buckets = itertools.product(warmup_range(bs_bucket_config), - warmup_range(seq_bucket_config)) - return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) +def warmup_buckets(bs_bucket_config, seq_bucket_config, + max_num_batched_tokens): + buckets = list( + itertools.product(warmup_range(bs_bucket_config), + warmup_range(seq_bucket_config))) + if len(buckets) == 0: + msg = ("No buckets could be captured with following config " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config}") + raise ValueError(msg) + + # Remove buckets exceeding batch token budget + filtered_buckets = list( + filter(lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, + buckets)) + + if len(filtered_buckets) == 0: + # legacy case - we can handle this if we ignore max_num_batched_tokens + min_bucket_bs, min_bucket_seq = min(buckets, + key=lambda b: (b[0] * b[1])) + min_reqd_budget = min_bucket_bs * min_bucket_seq + msg = ( + "The current bucketing configuration " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config} cannot be used with specified " + f"max_num_batched_tokens ({max_num_batched_tokens}), as the " + f"smallest bucket ({min_reqd_budget}) would exceed token budget. " + "Please increase max_num_batched_tokens or decrease bucket minimum " + "Ignoring max_num_batched_tokens at risk of out-of-memory errors.") + logger.error(msg) + return list(sorted(buckets, key=lambda b: + (b[0] * b[1], b[1], b[0]))), [] + + captured_buckets = list( + sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + omitted_buckets = list( + sorted([x for x in buckets if x not in filtered_buckets])) + return captured_buckets, omitted_buckets def next_pow2(value: int): @@ -525,8 +561,9 @@ def _setup_buckets(self) -> None: f"bs:{self.prompt_bs_bucket_cfg}, " f"seq:{self.prompt_seq_bucket_cfg}") logger.info(msg) - self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg, - self.prompt_seq_bucket_cfg) + self.prompt_buckets, prompt_omitted_buckets = warmup_buckets( + self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.prompt_buckets[:] = [ @@ -538,12 +575,21 @@ def _setup_buckets(self) -> None: f"prompt buckets: {list(sorted(self.prompt_buckets))}") logger.info(msg) + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{self.decode_bs_bucket_cfg}, " f"seq:{self.decode_seq_bucket_cfg}") logger.info(msg) - self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg, - self.decode_seq_bucket_cfg) + self.decode_buckets, decode_omitted_buckets = warmup_buckets( + self.decode_bs_bucket_cfg, self.decode_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.decode_buckets[:] = [ bucket for bucket in self.decode_buckets @@ -553,6 +599,14 @@ def _setup_buckets(self) -> None: f"{list(sorted(self.decode_buckets))}") logger.info(msg) + msg = (f"Omitted {len(decode_omitted_buckets)} " + "decode buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted decode buckets: {list(sorted(decode_omitted_buckets))}" + logger.debug(msg) + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata],