Skip to content

Commit

Permalink
Prefix Cache Aware Scheduling [1/n] (#10128)
Browse files Browse the repository at this point in the history
Signed-off-by: rickyx <[email protected]>
  • Loading branch information
rickyyx authored Nov 23, 2024
1 parent 7c25fe4 commit 4634a89
Show file tree
Hide file tree
Showing 13 changed files with 967 additions and 241 deletions.
181 changes: 175 additions & 6 deletions tests/core/block/test_prefix_caching_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@

import pytest

from tests.core.utils import create_dummy_sequence
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (PrefixCachingBlock,
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
PrefixCachingBlock,
PrefixCachingBlockAllocator)
from vllm.sequence import Logprob
from vllm.utils import Device


class TestPrefixCachingBlock:
Expand Down Expand Up @@ -726,18 +731,71 @@ def test_touch_block():
token_ids=common_token_ids,
allocator=allocator,
)
block_ids = [block.block_id for block in blocks]
block_hashes = [block.content_hash for block in blocks]
# The allocated blocks should be marked as touched
# but not computed.
computed_block_ids = allocator.get_computed_block_ids(
[], block_ids, skip_last_block_id=False)
computed_block_ids = allocator.find_cached_blocks_prefix(
block_hashes)
assert len(computed_block_ids) == 0

allocator.mark_blocks_as_computed([])
computed_block_ids = allocator.get_computed_block_ids(
[], block_ids, skip_last_block_id=False)
computed_block_ids = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes)
assert len(computed_block_ids) == common_blocks

@staticmethod
def test_find_cached_blocks_prefix():
"""
This test verifies the behavior of find_cached_blocks_prefix.
"""
block_size = 4
num_blocks = 8
total_test_blocks = 12
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)

token_ids = list(range(total_test_blocks * block_size))
block_tokens_seq1 = token_ids[:num_blocks * block_size]
blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=block_tokens_seq1,
allocator=allocator,
)
block_hashes_seq1 = [block.content_hash for block in blocks_seq1]
allocator.mark_blocks_as_computed([])

# All blocks should be cached.
cached_blocks_seq1 = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks_seq1) == num_blocks

# Free the first sequence.
for block in blocks_seq1:
allocator.free(block)

# All blocks should be still be cached if not required to be allocated.
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks) == num_blocks

block_tokens_seq2 = token_ids[num_blocks * block_size:]
blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=block_tokens_seq2,
allocator=allocator,
)
block_hashes_seq2 = [block.content_hash for block in blocks_seq2]
allocator.mark_blocks_as_computed([])
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq2)
assert len(cached_blocks) == len(blocks_seq2)

# Half of the blocks from seq1 should still be cached.
num_evicted_blocks = len(blocks_seq2)
cached_blocks = allocator.find_cached_blocks_prefix(
block_hashes=block_hashes_seq1)
assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks

@staticmethod
def create_immutable_chain(
block_size: int,
Expand All @@ -762,3 +820,114 @@ def create_immutable_chain(
blocks.append(prev_block)

return blocks


class TestComputedBlocksTracker:

@staticmethod
def _get_mock_allocator():
return MagicMock(spec=PrefixCachingBlockAllocator)

@staticmethod
def test_get_num_cached_tokens():
"""
Test it correctly computes the number of cached tokens for a given
sequence:
- The cache token count is derived from the number of cached blocks.
- The cache token count is updated when the allocator is updated.
- When a sequence is removed, the cache token count should be updated
accordingly.
# TODO(rickyx): This behaviour for prefill sequence is a hack until
we fix the computed blocks tracking.
- The cache token count for prefill sequence doesn't change while
the sequence is in continuous prefill (chunked prefill).
"""
block_size = 4
mock_allocator = TestComputedBlocksTracker._get_mock_allocator()
tracker = ComputedBlocksTracker(
allocator=mock_allocator,
block_size=block_size,
enable_caching=True,
)

# Not yet allocated.
tokens = [0, 1, 2, 3, 4, 5]
seq1 = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
mock_allocator.find_cached_blocks_prefix.return_value = []
assert tracker.get_num_cached_tokens(seq1) == 0

mock_allocator.find_cached_blocks_prefix.return_value = [
None
] # 1 block cached.
# Result is cached for prefill sequence.
assert tracker.get_num_cached_tokens(seq1) == 0

# Mark the sequence as non-prefill.
seq1.data.update_num_computed_tokens(len(tokens)) # 6 tokens computed.
assert not seq1.is_prefill()

# Recomputes for decoding sequence.
assert tracker.get_num_cached_tokens(seq1) == 4

# Append new tokens to the sequence.
num_new_tokens = 3
for i in range(num_new_tokens):
seq1.append_token_id(i, {i: Logprob(logprob=0.0)})

assert tracker.get_num_cached_tokens(seq1) == 4

# Update the allocator.
mock_allocator.find_cached_blocks_prefix.return_value = [
None
] * 2 # 2 blocks cached.
assert tracker.get_num_cached_tokens(seq1) == 8

# Remove the sequence.
tracker.remove_seq(seq1.seq_id)

# Re-create the sequence with the same request id to simulate recompute.
seq1 = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
mock_allocator.find_cached_blocks_prefix.return_value = [
] # no cached block
assert tracker.get_num_cached_tokens(seq1) == 0

@staticmethod
def test_correct_block_hash():
"""
Test that the block hash is correctly computed for a sequence (should
match the underlying block allocator's block hash). So the number of
cached tokens is correctly retrieved.
"""
block_size = 4
allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching",
num_gpu_blocks=16,
num_cpu_blocks=16,
block_size=block_size,
)
gpu_allocator = allocator._allocators[Device.GPU]

tracker = ComputedBlocksTracker(
allocator=allocator,
block_size=block_size,
enable_caching=True,
)

tokens = list(range(block_size * 4)) # 4 blocks.
seq = create_dummy_sequence(request_id=0,
token_ids=tokens,
block_size=block_size)
_ = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=tokens,
allocator=gpu_allocator,
)
allocator.mark_blocks_as_computed([])

assert tracker.get_num_cached_tokens(seq) == len(tokens)
Loading

0 comments on commit 4634a89

Please sign in to comment.