From 62354d77d37c14cd633c701e6402df3d92e1e20a Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Sat, 2 Nov 2024 17:00:34 -0700
Subject: [PATCH] Replace the free block deque with a doubly linked list
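
Replace the deque-based free block list, the lazy-removal block ID set,
and the async removal executor with FreeKVCacheBlockQueue: a doubly
linked list threaded through the KVCacheBlock objects themselves. The
queue keeps the same eviction order (LRU first; for blocks freed at the
same time, the tail of a block chain first), but a cache-hit block can
now be unlinked from the middle of the free list in O(1), so touched
blocks no longer need lazy removal.

The snippet below is an illustrative sketch only (Node and unlink are
made-up names, not part of this patch); it demonstrates the O(1)
middle-removal property the new queue relies on, which deque.remove()
cannot provide:

    class Node:
        def __init__(self, block_id):
            self.block_id = block_id
            self.prev = None
            self.next = None

    def unlink(node, head):
        # Splice the node out of the list in O(1) and return the
        # (possibly updated) head; no scan over the list is needed.
        if node.prev is not None:
            node.prev.next = node.next
        if node.next is not None:
            node.next.prev = node.prev
        if node is head:
            head = node.next
        node.prev = node.next = None
        return head

    # Build 0 <-> 1 <-> 2, then drop the middle node without scanning.
    n0, n1, n2 = Node(0), Node(1), Node(2)
    n0.next, n1.prev = n1, n0
    n1.next, n2.prev = n2, n1
    head = unlink(n1, n0)
    assert head is n0 and n0.next is n2 and n2.prev is n0
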
Signed-off-by: Cody Yu
---
 tests/v1/core/test_prefix_caching.py |   4 +-
 vllm/v1/core/kv_cache_manager.py     | 304 ++++++++++++++-------------
 2 files changed, 165 insertions(+), 143 deletions(-)

diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 49f939b469eea..06b994866991d 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -39,7 +39,7 @@ def test_prefill():
     # Check full block metadata
     prev_block_id = None
     for block_id in (0, 1, 2):
-        assert manager.block_pool[block_id].prev_block_id == prev_block_id
+        assert manager.block_pool[block_id].parent_block_id == prev_block_id
         assert manager.block_pool[block_id].block_hash is not None
         assert manager.block_pool[block_id].ref_cnt == 1
         assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
@@ -49,7 +49,7 @@ def test_prefill():

     # Check partial/preallocated block metadata
     for block_id in (3, 4):
-        assert manager.block_pool[block_id].prev_block_id == block_id - 1
+        assert manager.block_pool[block_id].parent_block_id == block_id - 1
         assert manager.block_pool[block_id].block_hash is None
         assert manager.block_pool[block_id].ref_cnt == 1
         assert manager.block_pool[block_id].num_hashed_tokens == 0
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 80c3a7d3634cc..10bd6e5e59970 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -1,6 +1,4 @@
-import contextlib
-from collections import defaultdict, deque
-from concurrent.futures import Future, ThreadPoolExecutor
+from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import Dict, List, Optional, Tuple
@@ -17,8 +15,8 @@ class KVCacheBlock:
     """KV-cache block metadata."""
     # Block ID, ranging from 0 to num_gpu_blocks - 1.
     block_id: int
-    # Previous block ID. Used to include block chain in the block hash.
-    prev_block_id: Optional[int] = None
+    # Parent block ID. Used to include block chain in the block hash.
+    parent_block_id: Optional[int] = None
     # Reference count.
     ref_cnt: int = 0
     # Token IDs in the block.
@@ -29,14 +27,110 @@ class KVCacheBlock:
     # is closer to the end of a prompt and more likely to be evicted.
     num_hashed_tokens: int = 0

+    # Used to construct a doubly linked list for free blocks.
+    # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
+    prev_free_block: Optional["KVCacheBlock"] = None
+    next_free_block: Optional["KVCacheBlock"] = None
+
     def reset(self):
-        self.prev_block_id = None
+        """Reset the block metadata."""
+        self.parent_block_id = None
         self.ref_cnt = 0
         self.token_ids.clear()
         self.block_hash = None
         self.num_hashed_tokens = 0


+class FreeKVCacheBlockQueue:
+    """This class organizes a list of KVCacheBlock objects into a doubly
+    linked list of free blocks by manipulating the prev_free_block and
+    next_free_block attributes of the blocks. We implement this class
+    instead of using the Python built-in deque for the following reasons:
+    1. Avoid the overhead of deque operations.
+    2. Remove a block in the middle of the queue in O(1) time.
+
+    The queue is initially ordered by block ID. When a block is allocated
+    and then freed, it will be appended back in the eviction order:
+    1. The least recently used block is at the front (LRU).
+    2. If two blocks have the same last accessed time (allocated by the
+       same sequence), the one with more hashed tokens (the tail of a block
+       chain) is at the front.
+    Note that we maintain this order by reversing the block order when
+    freeing the blocks of a request. This operation is performed outside
+    of this class.
+
+    Args:
+        blocks: A list of KVCacheBlock objects.
+    """
+
+    def __init__(self, blocks: List[KVCacheBlock]) -> None:
+        self.num_free_blocks = len(blocks)
+
+        # Initialize the doubly linked list of free blocks.
+        self.free_list_head = blocks[0]
+        self.free_list_tail = blocks[-1]
+        for i in range(self.num_free_blocks):
+            if i > 0:
+                blocks[i].prev_free_block = blocks[i - 1]
+            if i < self.num_free_blocks - 1:
+                blocks[i].next_free_block = blocks[i + 1]
+
+    def popleft(self) -> KVCacheBlock:
+        """Pop the first free block and reduce num_free_blocks by 1.
+
+        Returns:
+            The first free block.
+        """
+        if not self.free_list_head:
+            raise ValueError("No free blocks available")
+
+        block = self.free_list_head
+        self.remove(block)
+        return block
+
+    def remove(self, block: KVCacheBlock) -> None:
+        """Remove a block from the free list and reduce num_free_blocks by 1.
+
+        Args:
+            block: The block to remove.
+        """
+        if block.prev_free_block is not None:
+            # Link the previous block to the next block.
+            block.prev_free_block.next_free_block = block.next_free_block
+        if block.next_free_block is not None:
+            # Link the next block to the previous block.
+            block.next_free_block.prev_free_block = block.prev_free_block
+
+        if block == self.free_list_head:
+            # Update the head if the block is the head.
+            self.free_list_head = block.next_free_block
+        if block == self.free_list_tail:
+            # Update the tail if the block is the tail.
+            self.free_list_tail = block.prev_free_block
+
+        # Clear the removed block's own pointers.
+        block.prev_free_block = block.next_free_block = None
+        self.num_free_blocks -= 1
+
+    def append(self, block: KVCacheBlock) -> None:
+        """Put a block back into the free list and increase
+        num_free_blocks by 1.
+
+        Args:
+            block: The block to append.
+        """
+        if self.free_list_tail is not None:
+            # Link the last block to the new block.
+            self.free_list_tail.next_free_block = block
+            block.prev_free_block = self.free_list_tail
+            self.free_list_tail = block
+        else:
+            # The free list is empty.
+            self.free_list_head = self.free_list_tail = block
+
+        block.next_free_block = None
+        self.num_free_blocks += 1
+
+
 class KVCacheManager:

     def __init__(
@@ -67,33 +161,7 @@ def __init__(
         self.block_pool: List[KVCacheBlock] = [
             KVCacheBlock(idx) for idx in range(num_gpu_blocks)
         ]
-        # [Prefix caching] The free block list ordered by block ID in the
-        # beginning. However, when a block is allocated and then freed, it
-        # will be added back with the eviction order:
-        # 1. The least recent used block is at the front (LRU).
-        # 2. If two blocks have the same last accessed time (allocated by the
-        #    same sequence), the one with more hash tokens (the tail of a block
-        #    chain) is at the front.
-        # We maintain this order by reversing the block order when free
-        # blocks of a request.
-        #
-        # Note that the block in this list is NOT guaranteed to be free
-        # due to prefix caching. If a block in free block list is touched
-        # by a request, we do not remove it immediately from free_block_list
-        # due to O(n) removal cost. Instead, we remove ref_cnt>0 blocks when
-        # 1. allocate new blocks in the same batch, or
-        # 2. finish scheduling a step and call async_remove_touched_blocks.
- # That's why we need to maintain lazy_remove_block_ids and - # num_free_blocks counter separately. - # - # [No prefix caching] The free block list is simply in the order - # of last accessed time. - self.free_block_queue = deque(self.block_pool) - self.lazy_remove_block_ids = set() - self.num_free_blocks = num_gpu_blocks - - self._async_executor = ThreadPoolExecutor(max_workers=1) - self._async_touch_task: Optional[Future] = None + self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) # {block_hash: {block ID: block}}. A cached block is # a full block with a block hash that can be used for prefix caching. @@ -162,14 +230,14 @@ def append_slots( req_block_ids = self.req_to_block_ids[request.request_id] num_new_blocks = num_required_blocks - len(req_block_ids) - if num_new_blocks > self.num_free_blocks: + if num_new_blocks > self.free_block_queue.num_free_blocks: # Need to allocate new blocks due to insufficient pre-allocated # slots, but we cannot allocate new blocks due to the limit. return None # Assign token IDs to already allocated blocks. new_token_ids = None - prev_block_id = None + parent_block_id = None if self.enable_caching: # Figure out the token IDs to add to the blocks. if request.num_computed_tokens < request.num_prompt_tokens: @@ -192,15 +260,15 @@ def append_slots( req_block_ids[last_full_block_idx]].block_hash is None): last_full_block_idx -= 1 - prev_block_id = (last_full_block_idx - if last_full_block_idx >= 0 else None) + parent_block_id = (last_full_block_idx + if last_full_block_idx >= 0 else None) token_id_idx = self._add_token_ids_to_blocks( block_ids=req_block_ids[last_full_block_idx + 1:], token_ids=new_token_ids, - prev_block_id=prev_block_id) + parent_block_id=parent_block_id) new_token_ids = new_token_ids[token_id_idx:] - prev_block_id = req_block_ids[-1] + parent_block_id = req_block_ids[-1] # No new block is needed. When caching is enabled, we make sure # token_id_idx is equal to len(new_token_ids), meaning that all tokens @@ -213,9 +281,9 @@ def append_slots( # Allocate new blocks considering preallocated blocks, and # add token IDs to them if caching is enabled. num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks, - self.num_free_blocks) + self.free_block_queue.num_free_blocks) new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids, - prev_block_id) + parent_block_id) new_block_ids = [blk.block_id for blk in new_blocks] req_block_ids.extend(new_block_ids) return new_block_ids @@ -241,17 +309,17 @@ def allocate_slots( raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") - # If a computed block is an eviction candidate (in the free queue), - # it cannot be counted as a free block when estimating whether we - # can allocate new blocks for this request. + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. num_evictable_computed_blocks = len([ bid for bid in computed_block_ids if self.block_pool[bid].ref_cnt == 0 ]) num_required_blocks = cdiv(num_tokens, self.block_size) - if (num_required_blocks > - self.num_free_blocks - num_evictable_computed_blocks): + if (num_required_blocks > self.free_block_queue.num_free_blocks - + num_evictable_computed_blocks): # Cannot allocate new blocks. return None @@ -259,7 +327,8 @@ def allocate_slots( # preallocated blocks. 
num_new_blocks = min( num_required_blocks + self.num_preallocate_blocks, - self.num_free_blocks - num_evictable_computed_blocks) + self.free_block_queue.num_free_blocks - + num_evictable_computed_blocks) # Get the token IDs for the blocks being allocated for hashing. # Note that we expect this function to be called only once per # request, so we must have all new token IDs in the prompt. @@ -274,16 +343,16 @@ def allocate_slots( f"#computed_tokens={num_computed_tokens}") # Touch the computed blocks to make sure they won't be evicted. - self._lazy_touch(computed_block_ids) + self._touch(computed_block_ids) - # Get the previous block ID to construct the block chain. - prev_block_id = computed_block_ids[ + # Get the parent block ID to construct the block chain. + parent_block_id = computed_block_ids[ -1] if computed_block_ids else None else: new_token_ids = None - prev_block_id = None + parent_block_id = None new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids, - prev_block_id) + parent_block_id) new_block_ids = [blk.block_id for blk in new_blocks] # Concatenate the computed block IDs and the new block IDs. @@ -301,57 +370,20 @@ def free(self, request: Request) -> None: """ block_ids = self.req_to_block_ids.pop(request.request_id) if self.enable_caching: - # Make sure async tasks (remove touched blocks) are finished. - self.wait_for_removing_touched_blocks() - assert not self.lazy_remove_block_ids - # Free blocks in reverse order so that the tail blocks are # freed first. - for block_id in reversed(block_ids): - self.block_pool[block_id].ref_cnt -= 1 - if self.block_pool[block_id].ref_cnt == 0: - self.free_block_queue.append(self.block_pool[block_id]) - self.num_free_blocks += 1 - else: - for block_id in block_ids: - self.block_pool[block_id].ref_cnt -= 1 - if self.block_pool[block_id].ref_cnt == 0: - self.free_block_queue.append(self.block_pool[block_id]) - self.num_free_blocks += 1 - - def async_remove_touched_blocks(self) -> None: - """Asynchronously remove the touched blocks from the free block list. - This function should be called at the end of a scheduling step, so that - costly operations of removing the touched blocks are done in parallel - with the model forward pass. - """ + block_ids = reversed(block_ids) - def _sync_remove_touched_blocks(): - for block_id in self.lazy_remove_block_ids: - # The block may have been removed during allocate_slots - # so suppress the element not found error. - with contextlib.suppress(ValueError): - self.free_block_queue.remove(self.block_pool[block_id]) - self.lazy_remove_block_ids.clear() - - if self.lazy_remove_block_ids: - self.wait_for_removing_touched_blocks() - self._async_touch_task = self._async_executor.submit( - _sync_remove_touched_blocks) - - def wait_for_removing_touched_blocks(self) -> None: - """Wait for the asynchronous task to finish if there are - a task on the fly.""" - if self._async_touch_task is not None: - self._async_touch_task.result() - self._async_touch_task = None - assert not self.lazy_remove_block_ids + for block_id in block_ids: + self.block_pool[block_id].ref_cnt -= 1 + if self.block_pool[block_id].ref_cnt == 0: + self.free_block_queue.append(self.block_pool[block_id]) def _get_new_blocks( self, num_blocks: int, token_ids: Optional[List[int]] = None, - prev_block_id: Optional[int] = None) -> List[KVCacheBlock]: + parent_block_id: Optional[int] = None) -> List[KVCacheBlock]: """Get new blocks from the free block pool, and add token IDs to allocated blocks if caching is enabled. 
        Note that we do not check block cache in this function.

        Args:
            num_blocks: The number of blocks to allocate.
            token_ids: The token IDs in the blocks. None if caching is disabled.
-            prev_block_id: The previous block ID. Used to include block chain
+            parent_block_id: The parent block ID. Used to include block chain
                in the block hash.

        Returns:
            A list of new blocks.
        """
-        assert num_blocks <= self.num_free_blocks
-        if num_blocks > self.num_free_blocks:
+        if num_blocks > self.free_block_queue.num_free_blocks:
             raise ValueError(
                 f"Cannot get {num_blocks} free blocks from the pool")

         # First allocate blocks.
-        ret = []
+        ret: List[KVCacheBlock] = []
         idx = 0
         while idx < num_blocks:
             curr_block = self.free_block_queue.popleft()
-            # The block has been allocated by another request. This happens
-            # when another request *in the same batch* touches (cache hit)
-            # the block before calling async_remove_touched_blocks.
-            # In this case, this block should also in lazy_remove_block_ids.
-            if curr_block.ref_cnt > 0:
-                continue
+            assert curr_block.ref_cnt == 0

             # Evict blocks from the cache.
             if self.enable_caching:
@@ -404,27 +430,27 @@ def _get_new_blocks(
             token_id_idx = self._add_token_ids_to_blocks(
                 block_ids=[blk.block_id for blk in ret],
                 token_ids=token_ids,
-                prev_block_id=prev_block_id)
+                parent_block_id=parent_block_id)
             assert token_id_idx == len(token_ids)

-        self.num_free_blocks -= num_blocks
         return ret

     def _cache_full_block(self,
                           block: KVCacheBlock,
-                          prev_block: Optional[KVCacheBlock] = None) -> None:
+                          parent_block: Optional[KVCacheBlock] = None) -> None:
         """Cache a full block for prefix caching.

         Args:
             block: The block to cache.
-            prev_block: The previous block. None if this is the first block.
+            parent_block: The parent block. None if this is the first block.
         """
-        prev_block_hash = (prev_block.block_hash
-                           if prev_block is not None else None)
-        block_hash = hash_block_tokens(prev_block_hash, tuple(block.token_ids))
+        parent_block_hash = (parent_block.block_hash
+                             if parent_block is not None else None)
+        block_hash = hash_block_tokens(parent_block_hash,
+                                       tuple(block.token_ids))
         block.block_hash = block_hash
         block.num_hashed_tokens = self.block_size + (
-            prev_block.num_hashed_tokens if prev_block is not None else 0)
+            parent_block.num_hashed_tokens if parent_block is not None else 0)
         self.cached_block_hash_to_block[block_hash][block.block_id] = block

     def _get_cached_block(self, block_hash: int) -> Optional[KVCacheBlock]:
@@ -443,31 +469,26 @@ def _get_cached_block(self, block_hash: int) -> Optional[KVCacheBlock]:
             return self.cached_block_hash_to_block[block_hash][first_block_id]
         return None

-    def _lazy_touch(self, block_ids: List[int]) -> None:
-        """Touch a block manes to remove it from the free block list
-        so that it will not be evicted. "Lazy" touch means that we do not
-        remove the block from the free block list immediately to avoid O(n)
-        cost. The blocks will be removed from the free block list when
-        1. allocate new blocks in the same batch, or
-        2. finish scheduling a step and call async_remove_touched_blocks.
+    def _touch(self, block_ids: List[int]) -> None:
+        """Touching a block increases its reference count by 1, and may
+        remove the block from the free queue. This is used when a block is
+        hit by another request with the same prefix.

         Args:
            block_ids: A list of block IDs to touch.
""" for block_id in block_ids: curr_block = self.block_pool[block_id] - # The block has no reference yet, meaning that it is in - # the free list, so we reduce the number of free blocks by 1, - # but not remove it from the free list now to avoid O(n) cost. + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. if curr_block.ref_cnt == 0: - self.num_free_blocks -= 1 - self.lazy_remove_block_ids.add(block_id) + self.free_block_queue.remove(curr_block) curr_block.ref_cnt += 1 def _add_token_ids_to_blocks(self, block_ids: List[int], token_ids: List[int], - prev_block_id: Optional[int] = None) -> int: + parent_block_id: Optional[int] = None) -> int: """Add token IDs to a list of allocated blocks. If a block becomes full after adding token IDs, cache it. Return the token ID index that has not been added to the blocks @@ -476,22 +497,23 @@ def _add_token_ids_to_blocks(self, Args: block_ids: A list of block IDs to add token IDs. token_ids: A list of token IDs to add. - prev_block_id: The previous block ID. None if this is the + parent_block_id: The parent block ID. None if this is the first block. Returns: The starting token ID index that has not been added to the blocks due to insufficient given blocks. """ - prev_block = self.block_pool[ - prev_block_id] if prev_block_id is not None else None + parent_block = self.block_pool[ + parent_block_id] if parent_block_id is not None else None token_id_start = 0 for block_id in block_ids: curr_block = self.block_pool[block_id] - curr_block.prev_block_id = prev_block_id + curr_block.parent_block_id = parent_block_id - # If all token IDs are added, the rest of the blocks are - # preallocated blocks, so we only need to update the prev_block_id. + # If all token IDs are added, then the rest of the blocks are + # preallocated blocks, so we only need to update the + # parent_block_id. if token_id_start == len(token_ids): continue @@ -501,9 +523,9 @@ def _add_token_ids_to_blocks(self, curr_block.token_ids.extend(token_ids[token_id_start:token_id_end]) # Cache the block if it becomes full. if len(curr_block.token_ids) == self.block_size: - self._cache_full_block(curr_block, prev_block) - prev_block = curr_block - prev_block_id = prev_block.block_id + self._cache_full_block(curr_block, parent_block) + parent_block = curr_block + parent_block_id = parent_block.block_id token_id_start = token_id_end return token_id_start @@ -518,21 +540,21 @@ def hash_prompt_tokens(self, token_ids: List[int]) -> List[int]: The list of computed hash values. """ ret = [] - prev_block_hash = None + parent_block_hash = None for start in range(0, len(token_ids), self.block_size): end = start + self.block_size block_token_ids = tuple(token_ids[start:end]) # Do not hash the block if it is not full. if len(block_token_ids) < self.block_size: break - block_hash = hash_block_tokens(prev_block_hash, block_token_ids) + block_hash = hash_block_tokens(parent_block_hash, block_token_ids) ret.append(block_hash) - prev_block_hash = block_hash + parent_block_hash = block_hash return ret @lru_cache(maxsize=1024) -def hash_block_tokens(prev_block_hash: Optional[int], +def hash_block_tokens(parent_block_hash: Optional[int], cur_block_token_ids: Tuple[int]) -> int: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for @@ -543,7 +565,7 @@ def hash_block_tokens(prev_block_hash: Optional[int], features such as LoRA adapter. Args: - prev_block_hash: The hash of the previous block. 
None + parent_block_hash: The hash of the parent block. None if this is the first block. cur_block_token_ids: A tuple of token ids in the current block. The current block is assumed to be full. @@ -551,4 +573,4 @@ def hash_block_tokens(prev_block_hash: Optional[int], Returns: The computed hash value for the block. """ - return hash((prev_block_hash, *cur_block_token_ids)) + return hash((parent_block_hash, *cur_block_token_ids))
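
Illustrative sketch (not part of the patch; BLOCK_SIZE and hash_prompt
are made-up stand-ins mirroring hash_block_tokens and
KVCacheManager.hash_prompt_tokens above): because each block hash folds
in the parent hash, a block's hash commits to the entire prefix, so two
prompts can share a cached block only if all earlier blocks match too.

    from typing import List, Optional, Tuple

    BLOCK_SIZE = 4

    def hash_block_tokens(parent_block_hash: Optional[int],
                          cur_block_token_ids: Tuple[int, ...]) -> int:
        # Same scheme as above: the parent hash chains in the prefix.
        return hash((parent_block_hash, *cur_block_token_ids))

    def hash_prompt(token_ids: List[int]) -> List[int]:
        ret: List[int] = []
        parent_block_hash = None
        for start in range(0, len(token_ids), BLOCK_SIZE):
            block = tuple(token_ids[start:start + BLOCK_SIZE])
            if len(block) < BLOCK_SIZE:
                break  # partial blocks are never hashed or cached
            parent_block_hash = hash_block_tokens(parent_block_hash, block)
            ret.append(parent_block_hash)
        return ret

    a = hash_prompt([1, 2, 3, 4, 5, 6, 7, 8])
    b = hash_prompt([1, 2, 3, 4, 9, 9, 9, 9])
    assert a[0] == b[0]  # identical first block: prefix cache hit
    assert a[1] != b[1]  # hashes diverge from the second block on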