-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace AttnPagedCache with BasePagedAttentionCache (#565)
Creates space for #593 (prefix-sharing) Coming next: #607 , which should be the last thing I do before I can check in my blocktrie implementation. Summary of changes: - copied over stella's cache.py and renamed it to page_pool.py - each inference request now notifies the cache when its pages are done written to
- Loading branch information
Showing
10 changed files
with
387 additions
and
229 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
80 changes: 80 additions & 0 deletions
80
shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
""" | ||
Base class for kv caches. | ||
""" | ||
|
||
from typing import List | ||
from .page_pool import PageInfo | ||
import math | ||
|
||
|
||
class BasePagedAttentionCache: | ||
""" | ||
Manages lifecycle of pages (using PageInfo as handles). | ||
Page States: | ||
Caching - Page can be read by multiple threads | ||
- Also maintains a reference count | ||
Writing - Page is being modified by a single owner thread | ||
Transitions: | ||
Caching -> Writing: When acquiring an unreferenced LRU leaf page for writing | ||
Writing -> Caching: When writing is complete and page is released | ||
Thread Safety: | ||
- Multiple readers allowed in ReadableCaching state | ||
- Single writer exclusive access in Writing state | ||
- Reference counting prevents eviction of in-use pages | ||
""" | ||
|
||
def __init__(self, page_pool, tokens_per_page): | ||
self.page_pool = page_pool | ||
self.tokens_per_page = tokens_per_page | ||
|
||
def acquire_pages_for_tokens( | ||
self, tokens: List[int], extra_token_slots: int = 1 | ||
) -> tuple[list[PageInfo], int]: | ||
""" | ||
Given a list of tokens, return a list of pages and a start position to continue generation from. | ||
Parameters: | ||
- tokens: all the known tokens for this generation request | ||
- extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens. | ||
In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. | ||
The pages are returned in order. | ||
No token at idx < n_cached_token should be written to. TODO: consider enforcing this. | ||
""" | ||
token_count = len(tokens) | ||
pages_needed = math.ceil(token_count / self.tokens_per_page) | ||
pages = self.page_pool.acquire_free_pages(pages_needed) | ||
|
||
n_cached_tokens = 0 | ||
|
||
return pages, n_cached_tokens | ||
|
||
def publish_pages(self, tokens, pages) -> None: | ||
""" | ||
Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests. | ||
Associates the tokens with the pages, and mark them as done writing. | ||
It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. | ||
""" | ||
|
||
pass # the base implementation doesn't cache unfinished requests. | ||
|
||
def release_pages(self, tokens, pages): | ||
""" | ||
Decrement reference count for these pages. When reference count is zero, they will be elegible for eviction. | ||
""" | ||
# in the base implementation, the pages can be owned by 1 request max, so they can be instantly release | ||
self.page_pool.release_pages(pages) |
159 changes: 159 additions & 0 deletions
159
shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from __future__ import annotations | ||
from typing import List, Tuple, Optional, Sequence | ||
import threading | ||
import logging | ||
import shortfin as sf | ||
import shortfin.array as sfnp | ||
from dataclasses import dataclass | ||
|
||
from ..config_struct import human_size | ||
import math | ||
|
||
import time | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class PageInfo: | ||
""" | ||
Page index with some metadata about its contents. | ||
""" | ||
|
||
index: int | ||
pool: PagePool | ||
token_offset: int # Offset within the page | ||
token_count: int # Number of tokens stored in this page | ||
writing: bool = False | ||
read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release | ||
|
||
|
||
@dataclass | ||
class PagePoolConfig: | ||
""" | ||
Hyperparameters for the page pool. | ||
""" | ||
|
||
dtype: sf.dtype | ||
alloc_page_count: int | ||
|
||
paged_kv_block_size_elements: int # size of a single page as # of elements | ||
# (e.g. one configuration for llama3.1 8b hax 32x2x16x8x128=1048576 elements where: | ||
# 32: number of transformer blocks | ||
# 2: one for k + one for v | ||
# 16: tokens per page | ||
# 8: head count (32 heads, but every 4 heads share the same kv buffer) | ||
# 128: hidden dimension | ||
|
||
|
||
class PagePool: | ||
"""Page table based attention cache. | ||
While internal to a model, the cache is organized with additional structure | ||
per page, outside of the model, it is just a list of pages of a certain | ||
element type and number of elements (all inner dims are flattened). | ||
One page table is allocated per device in a fiber. Currently, this is a | ||
dense allocation with committed memory but in the future, we may just | ||
allocate the address space and lazily populate it with committed memory. | ||
The cache is unique because usage of it can span fibers and concurrency | ||
is implicitly managed at the block level (i.e. freshly acquired blocks | ||
are assumed to be uninitialized and available immediately for use). | ||
It is initialized with a discrete list of fiberd devices from a fiber but | ||
cache usage can be done from any fiber which includes those devices. | ||
In addition to supporting paged attention standalone, this also serves | ||
as the array / buffer allocation layer for radix attention described in | ||
`radix_tree.py`. | ||
""" | ||
|
||
def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig): | ||
self._lock = threading.Lock() | ||
self.devices = list(devices) | ||
self.config = config | ||
self.page_tables: list[sf.array.device_array] = [] | ||
|
||
# Setup accounting structs. | ||
self.attn_page_entries = [ | ||
PageInfo( | ||
index=i, | ||
pool=self, | ||
token_offset=0, | ||
token_count=0, | ||
) | ||
for i in range(self.config.alloc_page_count) | ||
] | ||
|
||
self.attn_page_free = list(self.attn_page_entries) | ||
|
||
# Initialize a page table on each device. | ||
page_table_shape = [ | ||
self.config.alloc_page_count, | ||
self.config.paged_kv_block_size_elements, | ||
] | ||
for device in devices: | ||
logging.info( | ||
"Allocating page table (shape=%r, dtype=%r, size=%s) on %r", | ||
page_table_shape, | ||
self.config.dtype, | ||
human_size(config.dtype.compute_dense_nd_size(page_table_shape)), | ||
device, | ||
) | ||
page_table = sf.array.device_array.for_device( | ||
device, page_table_shape, self.config.dtype | ||
) | ||
self.page_tables.append(page_table) | ||
|
||
def acquire_free_pages(self, count: int) -> list[PageInfo] | None: | ||
with self._lock: | ||
available = len(self.attn_page_free) | ||
if count > available: | ||
return None | ||
return [self.attn_page_free.pop() for _ in range(count)] | ||
|
||
def release_pages(self, pages: list[PageInfo]): | ||
with self._lock: | ||
self.attn_page_free.extend(pages) | ||
|
||
def copy_page(self, src_page: PageInfo) -> PageInfo: | ||
""" | ||
Copy a page's contents to a new page. | ||
Args: | ||
src_page: Source page to copy from | ||
token_count: Optional number of tokens to copy. If None, copies all tokens. | ||
Returns: | ||
New PageInfo containing the copied data | ||
""" | ||
# Allocate new page | ||
(dst_page,) = self.acquire_free_pages(1) | ||
|
||
# fill src page with data | ||
|
||
# Copy the data on each device | ||
for page_table in self.page_tables: | ||
# View of source and destination pages | ||
src_view = page_table.view(src_page.index) | ||
dst_view = page_table.view(dst_page.index) | ||
# Copy the data | ||
dst_view.copy_from(src_view) | ||
|
||
# Setup destination page metadata | ||
dst_page.token_offset = 0 # Always start at beginning of new page | ||
|
||
return dst_page | ||
|
||
def __repr__(self): | ||
# No need to lock for repr (list is internally synchronized). | ||
free_pages = len(self.attn_page_free) | ||
total_pages = len(self.attn_page_entries) | ||
return ( | ||
f"PagePool({total_pages - free_pages}/{total_pages} pages in use: " | ||
f"{100.0 * free_pages / total_pages}% free)" | ||
) | ||
|
||
|
||
############################## begin radix attention |
Oops, something went wrong.