[V1] Add KV cache group dimension to block table #12086

Open · wants to merge 13 commits into base: main
127 changes: 65 additions & 62 deletions tests/v1/core/test_prefix_caching.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion tests/v1/worker/test_gpu_input_batch.py
@@ -166,7 +166,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024)
vocab_size=1024,
num_kv_cache_groups=1)
reqs: List[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
4 changes: 4 additions & 0 deletions vllm/attention/layer.py
@@ -243,6 +243,8 @@ def unified_attention(
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._k_scale, self._v_scale)

@@ -276,6 +278,8 @@ def unified_attention_with_output(
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self.impl.forward(query,
key,
value,
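The two hunks above make the attention entry points tolerate a per-layer metadata dict: in v1 the forward context can map each layer name to its own `AttentionMetadata`, and every layer picks out its entry before calling its implementation. The sketch below illustrates only that dispatch pattern; `FakeAttentionMetadata` and `resolve_metadata` are made-up names for illustration, not code from this PR.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class FakeAttentionMetadata:
    num_tokens: int


def resolve_metadata(
    attn_metadata: Union[FakeAttentionMetadata, Dict[str, FakeAttentionMetadata]],
    layer_name: str,
) -> FakeAttentionMetadata:
    # v1 passes a per-layer dict; v0 still passes one shared object.
    if isinstance(attn_metadata, dict):
        return attn_metadata[layer_name]
    return attn_metadata


per_layer = {
    "model.layers.0.self_attn": FakeAttentionMetadata(num_tokens=8),
    "model.layers.1.self_attn": FakeAttentionMetadata(num_tokens=8),
}
assert resolve_metadata(per_layer, "model.layers.0.self_attn").num_tokens == 8
```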
21 changes: 16 additions & 5 deletions vllm/forward_context.py
@@ -2,7 +2,7 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import torch

@@ -24,11 +24,22 @@

@dataclass
class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
"""
Map from layer_name to all attention modules
copy from vllm_config.compilation_config.static_forward_context
"""
attn_layers: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
"""
Type AttentionMetadata for v0,
Type Dict[str, AttentionMetadata] for v1, mapping from layer_name to
AttentionMetadata of that layer
set dynamically for each forward pass
"""
attn_metadata: Union["AttentionMetadata", Dict[str, "AttentionMetadata"]]
"""
The virtual_engine for v0 pipeline parallelism
set dynamically for each forward pass
"""
virtual_engine: int # set dynamically for each forward pass


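With `ForwardContext.attn_metadata` now typed as a union, a v1 model runner can populate it as a per-layer dict in which all layers of one KV cache group share a single metadata object. The helper below is only a sketch of that idea; `build_per_layer_metadata`, the group layout, and `make_metadata` are assumptions for illustration, not part of this diff.

```python
from typing import Any, Dict, List


def build_per_layer_metadata(
        kv_cache_groups: List[List[str]],  # layer names, one list per group
        make_metadata) -> Dict[str, Any]:
    """Build a v1-style per-layer dict: layers in one group share metadata."""
    attn_metadata: Dict[str, Any] = {}
    for group_id, layer_names in enumerate(kv_cache_groups):
        group_metadata = make_metadata(group_id)
        for layer_name in layer_names:
            attn_metadata[layer_name] = group_metadata
    return attn_metadata


groups = [["layers.0.attn", "layers.1.attn"], ["layers.0.cross_attn"]]
metadata = build_per_layer_metadata(groups, make_metadata=lambda gid: {"group": gid})
assert metadata["layers.0.attn"] is metadata["layers.1.attn"]
```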
71 changes: 45 additions & 26 deletions vllm/v1/core/kv_cache_manager.py
@@ -4,7 +4,8 @@
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
KVCacheBlock, KVCacheBlocks,
ReqKVCacheBlocks,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
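The import hunk pulls in two new aliases, `KVCacheBlocks` and `ReqKVCacheBlocks`, whose definitions live in `vllm/v1/core/kv_cache_utils.py` and are not shown in this diff. From how the rest of the file indexes them, they appear to add the group dimension as a nested list; the sketch below is an assumed shape, not the actual definitions.

```python
from typing import List

# Assumed shape only; the real aliases are defined in kv_cache_utils.py.
KVCacheBlocks = List["KVCacheBlock"]      # blocks of one KV cache group
ReqKVCacheBlocks = List[KVCacheBlocks]    # indexed by kv_cache_group_id
```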
@@ -14,6 +15,11 @@


class KVCacheManager:
"""
The KVCacheManager for models with one KV cache type (e.g., Llama) and
thus one kv cache group (Refer to class `KVCacheConfig` for the meaning of
kv cache group).
"""

def __init__(
self,
@@ -67,10 +73,13 @@ def __init__(
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
# KVCacheManager only supports models with one kv cache group, so we
# save KVCachedBlocks of that group instead of ReqKVCacheBlocks for
# simplicity.
self.req_to_blocks: Dict[str, KVCacheBlocks] = {}

def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
def get_computed_blocks(self,
request: Request) -> Tuple[ReqKVCacheBlocks, int]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.

@@ -79,21 +88,21 @@ def get_computed_blocks(

Returns:
A tuple containing:
- A list of blocks that are computed for the request.
- The blocks that are computed for the request
- The number of computed tokens.
"""
if not self.enable_caching:
# Prefix caching is disabled.
return [], 0
return [[]], 0

computed_blocks = []

# The block hashes for the request may already be computed
# if the request was preempted and resumed.
if not request.kv_block_hashes:
request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes
[hash_request_tokens(self.block_size, request)])
block_hashes = request.kv_block_hashes[0]

for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
@@ -108,13 +117,13 @@
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
return [computed_blocks], num_computed_tokens

def append_slots(
self,
request: Request,
num_tokens: int,
) -> Optional[List[KVCacheBlock]]:
) -> Optional[ReqKVCacheBlocks]:
"""Append slots to the block table of the request.
We first append slots to already allocated blocks. If the allocated
blocks are not enough, we allocate new blocks.
@@ -124,8 +133,8 @@ def append_slots(
num_tokens: The number of tokens to append.

Returns:
A list of new blocks if new blocks are allocated, or None
if new blocks are required but cannot be allocated.
The new blocks if new blocks are allocated, or None if new blocks
are required but cannot be allocated.
"""
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
self.block_size)
@@ -159,7 +168,7 @@ def append_slots(
req_blocks.extend(new_blocks)

if not self.enable_caching:
return new_blocks
return [new_blocks]

num_computed_full_blocks = (request.num_computed_tokens //
self.block_size)
@@ -182,31 +191,36 @@
full_blocks=new_full_blocks,
prev_block=req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks >= 1 else None,
kv_cache_group_id=0,
)

return new_blocks
return [new_blocks]

def allocate_slots(
self,
request: Request,
num_tokens: int,
computed_blocks: List[KVCacheBlock],
) -> Optional[List[KVCacheBlock]]:
computed_blocks_all_groups: ReqKVCacheBlocks,
) -> Optional[ReqKVCacheBlocks]:
"""Allocate slots for a new request.

Args:
request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks: A list of computed blocks.
computed_blocks_all_groups: The computed blocks. Should contain
only one KV cache group.

Returns:
A list of new allocated blocks.
The new blocks if new blocks are allocated, or None if new blocks
are required but cannot be allocated.
"""
if num_tokens == 0:
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")

assert len(computed_blocks_all_groups) == 1
computed_blocks = computed_blocks_all_groups[0]
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
@@ -246,7 +260,7 @@ def allocate_slots(
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks

if not self.enable_caching:
return new_blocks
return [new_blocks]

num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
@@ -260,9 +274,10 @@
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
kv_cache_group_id=0,
)

return new_blocks
return [new_blocks]
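With `get_computed_blocks` and `allocate_slots` both speaking in per-group lists, scheduler-side callers index group 0 for this single-group manager. The sketch below is a hedged usage example, not code from this PR; it assumes a checkout with this change applied, and `num_tokens=16` is an arbitrary value.

```python
from typing import Optional

from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.request import Request


def schedule_new_request(manager: KVCacheManager, request: Request,
                         num_tokens: int = 16) -> Optional[int]:
    """Return the number of newly allocated group-0 blocks, or None."""
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(request)
    # computed_blocks is a ReqKVCacheBlocks: one block list per KV cache
    # group; this single-group manager only ever populates index 0.
    new_blocks = manager.allocate_slots(request, num_tokens, computed_blocks)
    if new_blocks is None:
        return None  # not enough free blocks; the caller can retry later
    return len(new_blocks[0])
```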

def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
@@ -289,7 +304,7 @@ def get_num_common_prefix_blocks(
self,
request: Request,
num_running_requests: int,
) -> int:
) -> List[int]:
"""Calculate the number of common prefix blocks shared by all requests
in the RUNNING state.

@@ -323,7 +338,7 @@ def get_num_common_prefix_blocks(
requests in the current step.

Returns:
int: The number of common prefix blocks.
List[int]: The number of common prefix blocks per KV cache group.
"""
assert request.status == RequestStatus.RUNNING
blocks = self.req_to_blocks[request.request_id]
Expand All @@ -333,7 +348,7 @@ def get_num_common_prefix_blocks(
num_common_blocks += 1
else:
break
return num_common_blocks
return [num_common_blocks]

def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool.
@@ -421,6 +436,7 @@ def _cache_full_blocks(
blk_start_idx: int,
full_blocks: List[KVCacheBlock],
prev_block: Optional[KVCacheBlock],
kv_cache_group_id: int,
) -> None:
"""Cache a list of full blocks for prefix caching.

@@ -436,8 +452,10 @@
to cache.
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
kv_cache_group_id: The KV cache group ID that the blocks belong to
"""
num_cached_block_hashes = len(request.kv_block_hashes)
num_cached_block_hashes = len(
request.kv_block_hashes[kv_cache_group_id])

# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
@@ -456,7 +474,8 @@
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = request.kv_block_hashes[blk_idx]
block_hash = request.kv_block_hashes[kv_cache_group_id][
blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
Expand All @@ -478,7 +497,7 @@ def _cache_full_blocks(
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)
request.append_kv_block_hashes(kv_cache_group_id, block_hash)

# Update and added the full block to the cache.
blk.block_hash = block_hash
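The `_cache_full_blocks` hunks index `request.kv_block_hashes` by `kv_cache_group_id` before the block index, i.e. the block-hash store also gains the group dimension. The snippet below shows only that assumed shape; it is not code from this PR, and plain strings stand in for the real `BlockHashType` values.

```python
# Assumed shape: one hash list per KV cache group, outer index = group id.
kv_block_hashes = [
    ["hash(block 0)", "hash(block 1)"],  # group 0 (the only group here)
]
kv_cache_group_id, blk_idx = 0, 1
block_hash = kv_block_hashes[kv_cache_group_id][blk_idx]
assert block_hash == "hash(block 1)"
```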