Skip to content

Commit

Permalink
index using block
Browse files Browse the repository at this point in the history
Signed-off-by: Cody Yu <[email protected]>
  • Loading branch information
comaniac committed Nov 7, 2024
1 parent f34c888 commit 2204969
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 149 deletions.
95 changes: 47 additions & 48 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Compare the with and without prefix caching."""
from vllm.inputs import DecoderOnlyInputs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_manager import (KVCacheManager, Request,
hash_block_tokens)


def make_request(request_id, prompt_token_ids):
Expand Down Expand Up @@ -31,25 +32,25 @@ def test_prefill():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_block_ids = manager.get_computed_blocks(req0)
assert not computed_block_ids
block_ids = manager.allocate_slots(req0, 55, computed_block_ids)
assert block_ids == [0, 1, 2, 3, 4]
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

# Check full block metadata
prev_block_id = None
parent_block_hash = None
for block_id in (0, 1, 2):
assert manager.block_pool[block_id].parent_block_id == prev_block_id
assert manager.block_pool[block_id].block_hash is not None
block_hash = hash_block_tokens(parent_block_hash,
manager.block_pool[block_id].token_ids)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
block_id + 1)
assert manager.block_pool[block_id].token_ids == [block_id] * 16
prev_block_id = block_id
parent_block_hash = block_hash

# Check partial/preallocated block metadata
for block_id in (3, 4):
assert manager.block_pool[block_id].parent_block_id == block_id - 1
assert manager.block_pool[block_id].block_hash is None
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 0
Expand All @@ -62,14 +63,13 @@ def test_prefill():
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_block_ids = manager.get_computed_blocks(req1)
assert computed_block_ids == [0, 1, 2]
computed_blocks = manager.get_computed_blocks(req1)
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
block_ids = manager.allocate_slots(req1, num_new_tokens,
computed_block_ids)
assert block_ids == [5, 6]
for block_id in (0, 1, 2):
assert manager.block_pool[block_id].ref_cnt == 2
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
for block in computed_blocks:
assert block.ref_cnt == 2

# At this point, we should have 3 free blocks left.
assert manager.free_block_queue.num_free_blocks == 3
Expand All @@ -92,12 +92,11 @@ def test_prefill():
# Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_block_ids = manager.get_computed_blocks(req2)
assert computed_block_ids == [0, 1, 2]
computed_block = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_block] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
block_ids = manager.allocate_slots(req2, num_new_tokens,
computed_block_ids)
assert block_ids == [7, 8]
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]

# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
Expand All @@ -112,11 +111,11 @@ def test_prefill():

# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9))
computed_block_ids = manager.get_computed_blocks(req3)
assert not computed_block_ids
block_ids = manager.allocate_slots(req2, 16 * 9, computed_block_ids)
computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks
blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert block_ids == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0
assert manager.free_block_queue.free_list_head is None
assert manager.free_block_queue.free_list_tail is None
Expand All @@ -138,16 +137,16 @@ def test_decode():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_block_ids = manager.get_computed_blocks(req0)
assert not computed_block_ids
block_ids = manager.allocate_slots(req0, 55, computed_block_ids)
assert block_ids == [0, 1, 2, 3, 4]
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

# Append slots without allocating a new block.
req0.num_computed_tokens = 55
req0.output_token_ids = [8] * 4
new_block_ids = manager.append_slots(req0, 4)
assert new_block_ids is not None and len(new_block_ids) == 0
new_blocks = manager.append_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 11

# Append slots without allocating a new block, but start using the
Expand All @@ -156,8 +155,8 @@ def test_decode():
# 5 tokens to fill the previous block, and 10 tokens to partially
# fill the preallocated block.
req0.output_token_ids += [7] * (5 + 10)
new_block_ids = manager.append_slots(req0, 15)
assert new_block_ids is not None and len(new_block_ids) == 0
new_blocks = manager.append_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 16
assert len(manager.block_pool[4].token_ids) == 10

Expand All @@ -166,9 +165,9 @@ def test_decode():
# 6 tokens to fill the previous block, and 11 tokens to go into
# the newly allocated block.
req0.output_token_ids += [12] * (6 + 11)
new_block_ids = manager.append_slots(req0, 17)
new_blocks = manager.append_slots(req0, 17)
# Plus one preallocated block.
assert new_block_ids is not None and len(new_block_ids) == 2
assert new_blocks is not None and len(new_blocks) == 2
assert len(manager.block_pool[4].token_ids) == 16
assert len(manager.block_pool[5].token_ids) == 11
assert len(manager.block_pool[6].token_ids) == 0
Expand All @@ -185,18 +184,18 @@ def test_evict():

last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
computed_block_ids = manager.get_computed_blocks(req0)
assert not computed_block_ids
block_ids = manager.allocate_slots(req0, 5 * 16 + 7, computed_block_ids)
assert len(block_ids) == 7 # 5 full + 1 partial + 1 preallocated
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated

# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
computed_block_ids = manager.get_computed_blocks(req1)
assert not computed_block_ids
block_ids = manager.allocate_slots(req1, 3 * 16, computed_block_ids)
assert len(block_ids) == 3 # 3 full blocks
computed_blocks = manager.get_computed_blocks(req1)
assert not computed_blocks
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16

assert manager.free_block_queue.num_free_blocks == 0
Expand All @@ -210,8 +209,8 @@ def test_evict():

# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_block_ids = manager.get_computed_blocks(req2)
assert computed_block_ids == [0, 1]
block_ids = manager.allocate_slots(req2, 3, computed_block_ids)
assert block_ids == [6, 5]
computed_blocks = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6
Loading

0 comments on commit 2204969

Please sign in to comment.