Skip to content

Commit

Permalink
Merge branch 'habana_main' into yuwen/avoid_graph_break
Browse files: browse the repository at this point in the history
  • Loading branch information
yuwenzho committed Sep 3, 2024
2 parents 941963f + 9abadba commit fc48f7e
Showing 1 changed file with 62 additions and 8 deletions.
70 changes: 62 additions & 8 deletions vllm/worker/habana_model_runner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -94,10 +94,46 @@ def warmup_range(config: Tuple[int, int, int]):
return list(ramp_up_tw) + list(stable)


def warmup_buckets(bs_bucket_config, seq_bucket_config):
    """Return every (bs, seq) warmup bucket, smallest first.

    Each config is expanded with ``warmup_range`` and the two resulting
    value lists are combined pairwise. The pairs are ordered by
    ``bs * seq``, then by ``seq``, then by ``bs``.
    """

    def size_order(bucket):
        bs, seq = bucket[0], bucket[1]
        return (bs * seq, seq, bs)

    bs_values = warmup_range(bs_bucket_config)
    seq_values = warmup_range(seq_bucket_config)
    pairs = list(itertools.product(bs_values, seq_values))
    pairs.sort(key=size_order)
    return pairs
def warmup_buckets(bs_bucket_config, seq_bucket_config,
                   max_num_batched_tokens):
    """Build the (bs, seq) warmup buckets that fit within a token budget.

    Each config is expanded with ``warmup_range`` and every (bs, seq)
    combination of the two expansions is a candidate bucket. A candidate is
    kept only if ``bs * seq <= max_num_batched_tokens``.

    Args:
        bs_bucket_config: (min, step, max_warmup) config for the ``bs``
            dimension, passed to ``warmup_range``.
        seq_bucket_config: (min, step, max_warmup) config for the ``seq``
            dimension, passed to ``warmup_range``.
        max_num_batched_tokens: largest allowed ``bs * seq`` product.

    Returns:
        A ``(captured_buckets, omitted_buckets)`` pair of lists.
        ``captured_buckets`` is ordered by ``(bs * seq, seq, bs)``;
        ``omitted_buckets`` holds the candidates over budget in plain tuple
        order. If no candidate fits the budget, an error is logged and all
        candidates are returned as captured with an empty omitted list.

    Raises:
        ValueError: if the two configs yield no candidate buckets at all.
    """

    def size_order(bucket):
        # Smallest token count first; ties broken by seq, then bs.
        return (bucket[0] * bucket[1], bucket[1], bucket[0])

    buckets = list(
        itertools.product(warmup_range(bs_bucket_config),
                          warmup_range(seq_bucket_config)))
    if not buckets:
        msg = ("No buckets could be captured with following config "
               "(min, step, max_warmup): "
               f"bs:{bs_bucket_config}, "
               f"seq:{seq_bucket_config}")
        raise ValueError(msg)

    # Split the candidates by the token budget in a single pass. This
    # replaces a per-element `not in <list>` scan, which was quadratic in
    # the number of buckets, and yields the same two groups because the
    # budget test depends only on the bucket's value.
    captured_buckets = []
    omitted_buckets = []
    for bucket in buckets:
        if bucket[0] * bucket[1] <= max_num_batched_tokens:
            captured_buckets.append(bucket)
        else:
            omitted_buckets.append(bucket)

    if not captured_buckets:
        # legacy case - we can handle this if we ignore max_num_batched_tokens
        min_bucket_bs, min_bucket_seq = min(buckets,
                                            key=lambda b: b[0] * b[1])
        min_reqd_budget = min_bucket_bs * min_bucket_seq
        msg = (
            "The current bucketing configuration "
            "(min, step, max_warmup): "
            f"bs:{bs_bucket_config}, "
            f"seq:{seq_bucket_config} cannot be used with specified "
            f"max_num_batched_tokens ({max_num_batched_tokens}), as the "
            f"smallest bucket ({min_reqd_budget}) would exceed token budget. "
            "Please increase max_num_batched_tokens or decrease bucket "
            "minimum. "
            "Ignoring max_num_batched_tokens at risk of out-of-memory errors.")
        logger.error(msg)
        return sorted(buckets, key=size_order), []

    captured_buckets.sort(key=size_order)
    omitted_buckets.sort()
    return captured_buckets, omitted_buckets


def next_pow2(value: int):
Expand Down Expand Up @@ -525,8 +561,9 @@ def _setup_buckets(self) -> None:
f"bs:{self.prompt_bs_bucket_cfg}, "
f"seq:{self.prompt_seq_bucket_cfg}")
logger.info(msg)
self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
self.prompt_seq_bucket_cfg)
self.prompt_buckets, prompt_omitted_buckets = warmup_buckets(
self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)

if self.lora_config:
self.prompt_buckets[:] = [
Expand All @@ -538,12 +575,21 @@ def _setup_buckets(self) -> None:
f"prompt buckets: {list(sorted(self.prompt_buckets))}")
logger.info(msg)

msg = (f"Omitted {len(prompt_omitted_buckets)} "
"prompt buckets due to exceeded token budget "
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
logger.info(msg)

msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
logger.debug(msg)

msg = ("Decode bucket config (min, step, max_warmup) "
f"bs:{self.decode_bs_bucket_cfg}, "
f"seq:{self.decode_seq_bucket_cfg}")
logger.info(msg)
self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
self.decode_seq_bucket_cfg)
self.decode_buckets, decode_omitted_buckets = warmup_buckets(
self.decode_bs_bucket_cfg, self.decode_seq_bucket_cfg,
self.max_num_batched_tokens)
if self.lora_config:
self.decode_buckets[:] = [
bucket for bucket in self.decode_buckets
Expand All @@ -553,6 +599,14 @@ def _setup_buckets(self) -> None:
f"{list(sorted(self.decode_buckets))}")
logger.info(msg)

msg = (f"Omitted {len(decode_omitted_buckets)} "
"decode buckets due to exceeded token budget "
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
logger.info(msg)

msg = f"Omitted decode buckets: {list(sorted(decode_omitted_buckets))}"
logger.debug(msg)

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
Expand Down

0 comments on commit fc48f7e

Please sign in to comment.