diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index 0e0e1db2a..550366e1b 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -98,6 +98,6 @@ jobs: - name: Run shortfin Python tests (full) working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd + pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/sd # TODO: Enable further tests and switch to # pytest -s diff --git a/shortfin/python/shortfin_apps/llm/_deps.py b/shortfin/python/shortfin_apps/llm/_deps.py index 7123d011e..fb8ca8176 100644 --- a/shortfin/python/shortfin_apps/llm/_deps.py +++ b/shortfin/python/shortfin_apps/llm/_deps.py @@ -5,13 +5,23 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from shortfin.support.deps import ShortfinDepNotFoundError +import sys -try: - import tokenizers -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "tokenizers") from e +deps = [ + "tokenizers", + "dataclasses_json", +] -try: - import dataclasses_json -except ModuleNotFoundError as e: - raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e +for dep in deps: + try: + __import__(dep) + except ModuleNotFoundError as e: + if "pytest" in sys.modules: + import pytest + + pytest.skip( + f"A test imports shortfin_apps.llm; skipping due to unavailable Shortfin LLM dependency: {dep}", + allow_module_level=True, + ) + else: + raise ShortfinDepNotFoundError(__name__, dep) from e diff --git a/shortfin/python/shortfin_apps/llm/components/cache.py b/shortfin/python/shortfin_apps/llm/components/cache.py deleted file mode 100644 index 12794498f..000000000 --- a/shortfin/python/shortfin_apps/llm/components/cache.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. 
-# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Sequence - -import logging -import math -import threading - -import shortfin as sf - -from .config_struct import ModelParams, human_size - -logger = logging.getLogger(__name__) - - -class AttnPageEntry: - __slots__ = [ - "cache", - "index", - "in_use", - ] - - def __init__(self, cache: "AttnPageCache", index: int): - self.cache = cache - self.index = index - self.in_use = False - - def __repr__(self): - return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})" - - -class AttnPageCache: - """Page table based attention cache. - - While internal to a model, the cache is organized with additional structure - per page, outside of the model, it is just a list of pages of a certain - element type and number of elements (all inner dims are flattened). - - One page table is allocated per device in a fiber. Currently, this is a - dense allocation with committed memory but in the future, we may just - allocate the address space and lazily populate it with committed memory. - - The cache is unique because usage of it can span fibers and concurrency - is implicitly managed at the block level (i.e. freshly acquired blocks - are assumed to be uninitialized and available immediately for use). - - It is initialized with a discrete list of fiberd devices from a fiber but - cache usage can be done from any fiber which includes those devices. - """ - - def __init__( - self, *, devices: Sequence[sf.ScopedDevice], model_params: ModelParams - ): - self._lock = threading.Lock() - self.devices = list(devices) - self.model_params = model_params - self.page_tables: list[sf.array.device_array] = [] - cache_params = model_params.paged_kv_cache - alloc_page_count = cache_params.device_block_count - - # Setup accounting structs. 
- self.attn_page_entries = [ - AttnPageEntry(self, i) for i in range(alloc_page_count) - ] - self.attn_page_free = list(self.attn_page_entries) - - # Initialize a page table on each device. - assert cache_params is not None, "Model does not have a paged kv cache" - page_table_shape = [ - alloc_page_count, - model_params.paged_kv_block_size_elements, - ] - for device in devices: - logging.info( - "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", - page_table_shape, - model_params.attn_dtype, - human_size( - math.prod(page_table_shape) - * model_params.attn_dtype.dense_byte_count - ), - device, - ) - page_table = sf.array.device_array.for_device( - device, page_table_shape, model_params.attn_dtype - ) - self.page_tables.append(page_table) - - def acquire_free_pages(self, count: int) -> list[AttnPageEntry] | None: - with self._lock: - available = len(self.attn_page_free) - if count > available: - return None - return [self.attn_page_free.pop() for _ in range(count)] - - def release_pages(self, pages: list[AttnPageEntry]): - with self._lock: - self.attn_page_free.extend(pages) - - def __repr__(self): - # No need to lock for repr (list is internally synchronized). - free_pages = len(self.attn_page_free) - total_pages = len(self.attn_page_entries) - return ( - f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " - f"{100.0 * free_pages / total_pages}% free)" - ) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py new file mode 100644 index 000000000..0007000bc --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -0,0 +1,80 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Base class for kv caches. +""" + +from typing import List +from .page_pool import PageInfo +import math + + +class BasePagedAttentionCache: + """ + Manages lifecycle of pages (using PageInfo as handles). + + + Page States: + Caching - Page can be read by multiple threads + - Also maintains a reference count + Writing - Page is being modified by a single owner thread + + Transitions: + Caching -> Writing: When acquiring an unreferenced LRU leaf page for writing + Writing -> Caching: When writing is complete and page is released + + Thread Safety: + - Multiple readers allowed in Caching state + - Single writer exclusive access in Writing state + - Reference counting prevents eviction of in-use pages + """ + + def __init__(self, page_pool, tokens_per_page): + self.page_pool = page_pool + self.tokens_per_page = tokens_per_page + + def acquire_pages_for_tokens( + self, tokens: List[int], extra_token_slots: int = 1 + ) -> tuple[list[PageInfo], int]: + """ + Given a list of tokens, return a list of pages and a start position to continue generation from. + + Parameters: + - tokens: all the known tokens for this generation request + - extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens. + + In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. + + The pages are returned in order. + + No token at idx < n_cached_tokens should be written to. TODO: consider enforcing this. + """ + token_count = len(tokens) + pages_needed = math.ceil(token_count / self.tokens_per_page) + pages = self.page_pool.acquire_free_pages(pages_needed) + + n_cached_tokens = 0 + + return pages, n_cached_tokens + + def publish_pages(self, tokens, pages) -> None: + """ + Given a list of tokens and pages containing KV corresponding to these tokens, make these pages available to other requests.
+ + Associates the tokens with the pages, and marks them as done writing. + + It is assumed that hereafter, the calling request will not modify these pages, at least not the positions [0:len(tokens)]. + """ + + pass # the base implementation doesn't cache unfinished requests. + + def release_pages(self, tokens, pages): + """ + Decrement reference count for these pages. When reference count is zero, they will be eligible for eviction. + """ + # in the base implementation, the pages can be owned by 1 request max, so they can be instantly released + self.page_pool.release_pages(pages) diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py new file mode 100644 index 000000000..1686370c0 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py @@ -0,0 +1,159 @@ +from __future__ import annotations +from typing import List, Tuple, Optional, Sequence +import threading +import logging +import shortfin as sf +import shortfin.array as sfnp +from dataclasses import dataclass + +from ..config_struct import human_size +import math + +import time + +logger = logging.getLogger(__name__) + + +@dataclass +class PageInfo: + """ + Page index with some metadata about its contents. + """ + + index: int + pool: PagePool + token_offset: int # Offset within the page + token_count: int # Number of tokens stored in this page + writing: bool = False + read_ref_count: int = 0 # Number of threads that still need to read this page. When this reaches 0, page is eligible for release + + +@dataclass +class PagePoolConfig: + """ + Hyperparameters for the page pool. + """ + + dtype: sf.dtype + alloc_page_count: int + + paged_kv_block_size_elements: int # size of a single page as # of elements + # (e.g.
one configuration for llama3.1 8b has 32x2x16x8x128=1048576 elements where: + # 32: number of transformer blocks + # 2: one for k + one for v + # 16: tokens per page + # 8: head count (32 heads, but every 4 heads share the same kv buffer) + # 128: hidden dimension + + +class PagePool: + """Page table based attention cache. + + While internal to a model, the cache is organized with additional structure + per page, outside of the model, it is just a list of pages of a certain + element type and number of elements (all inner dims are flattened). + + One page table is allocated per device in a fiber. Currently, this is a + dense allocation with committed memory but in the future, we may just + allocate the address space and lazily populate it with committed memory. + + The cache is unique because usage of it can span fibers and concurrency + is implicitly managed at the block level (i.e. freshly acquired blocks + are assumed to be uninitialized and available immediately for use). + + It is initialized with a discrete list of scoped devices from a fiber but + cache usage can be done from any fiber which includes those devices. + + In addition to supporting paged attention standalone, this also serves + as the array / buffer allocation layer for radix attention described in + `radix_tree.py`. + """ + + def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig): + self._lock = threading.Lock() + self.devices = list(devices) + self.config = config + self.page_tables: list[sf.array.device_array] = [] + + # Setup accounting structs. + self.attn_page_entries = [ + PageInfo( + index=i, + pool=self, + token_offset=0, + token_count=0, + ) + for i in range(self.config.alloc_page_count) + ] + + self.attn_page_free = list(self.attn_page_entries) + + # Initialize a page table on each device.
+ page_table_shape = [ + self.config.alloc_page_count, + self.config.paged_kv_block_size_elements, + ] + for device in devices: + logging.info( + "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", + page_table_shape, + self.config.dtype, + human_size(config.dtype.compute_dense_nd_size(page_table_shape)), + device, + ) + page_table = sf.array.device_array.for_device( + device, page_table_shape, self.config.dtype + ) + self.page_tables.append(page_table) + + def acquire_free_pages(self, count: int) -> list[PageInfo] | None: + with self._lock: + available = len(self.attn_page_free) + if count > available: + return None + return [self.attn_page_free.pop() for _ in range(count)] + + def release_pages(self, pages: list[PageInfo]): + with self._lock: + self.attn_page_free.extend(pages) + + def copy_page(self, src_page: PageInfo) -> PageInfo: + """ + Copy a page's contents to a new page. + + Args: + src_page: Source page to copy from + token_count: Optional number of tokens to copy. If None, copies all tokens. + + Returns: + New PageInfo containing the copied data + """ + # Allocate new page + (dst_page,) = self.acquire_free_pages(1) + + # fill src page with data + + # Copy the data on each device + for page_table in self.page_tables: + # View of source and destination pages + src_view = page_table.view(src_page.index) + dst_view = page_table.view(dst_page.index) + # Copy the data + dst_view.copy_from(src_view) + + # Setup destination page metadata + dst_page.token_offset = 0 # Always start at beginning of new page + + return dst_page + + def __repr__(self): + # No need to lock for repr (list is internally synchronized). 
+ free_pages = len(self.attn_page_free) + total_pages = len(self.attn_page_entries) + return ( + f"PagePool({total_pages - free_pages}/{total_pages} pages in use: " + f"{100.0 * free_pages / total_pages}% free)" + ) + + +############################## begin radix attention diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py index fdcbeefc1..c3e6fe34b 100644 --- a/shortfin/python/shortfin_apps/llm/components/messages.py +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -9,7 +9,8 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import AttnPageCache, AttnPageEntry +from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.page_pool import PageInfo class InferencePhase(Enum): @@ -41,8 +42,8 @@ def __init__(self, phase: InferencePhase, input_token_ids: list[int]): self.result_logits: sfnp.device_array | None = None # Cache pages that have been locked for this request. 
- self._cache: AttnPageCache | None = None - self.locked_pages: list[AttnPageEntry] | None = None + self._cache: BasePagedAttentionCache | None = None + self.locked_pages: list[PageInfo] | None = None def reset(self, phase: InferencePhase): """Resets all per request state in preparation for an subsequent execution.""" @@ -66,16 +67,18 @@ def free_cache_pages(self): pages = self.locked_pages self._cache = None self.locked_pages = None - cache.release_pages(pages) + cache.release_pages(self.input_token_ids, pages) def lock_initial_cache_pages( - self, cache: AttnPageCache, pages: list[AttnPageEntry] + self, cache: BasePagedAttentionCache, pages: list[PageInfo] ): assert not self._cache self._cache = cache self.locked_pages = pages - def lock_new_cache_pages(self, cache: AttnPageCache, pages: list[AttnPageEntry]): + def lock_new_cache_pages( + self, cache: BasePagedAttentionCache, pages: list[PageInfo] + ): assert self._cache is cache self.locked_pages.extend(pages) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index bcd08b756..8d3cc1424 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -11,7 +11,8 @@ import shortfin as sf import shortfin.array as sfnp -from .cache import AttnPageCache +from .kvcache.base_attention_cache import BasePagedAttentionCache +from .kvcache.page_pool import PagePoolConfig, PagePool from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage @@ -54,8 +55,17 @@ def __init__( # Scope dependent objects. 
self.batcher = BatcherProcess(self) - self.page_cache = AttnPageCache( - devices=self.main_fiber.devices_dict.values(), model_params=model_params + page_pool_config = PagePoolConfig( + dtype=model_params.attn_dtype, + alloc_page_count=model_params.paged_kv_cache.device_block_count, + paged_kv_block_size_elements=model_params.paged_kv_block_size_elements, + ) + page_pool = PagePool( + devices=self.main_fiber.devices_dict.values(), config=page_pool_config + ) + self.page_cache = BasePagedAttentionCache( + page_pool=page_pool, + tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) self.program_isolation = PROG_ISOLATIONS[program_isolation] @@ -200,7 +210,7 @@ def board_flights(self): self.pending_prefills.clear() logger.debug("Post boarding cache state: %r", cache) - def board_prefills(self, cache: AttnPageCache): + def board_prefills(self, cache: BasePagedAttentionCache): # Fill prefill flights. pending_prefills = self.pending_prefills if len(pending_prefills) == 0: @@ -209,7 +219,7 @@ def board_prefills(self, cache: AttnPageCache): self.service, InferencePhase.PREFILL, self.page_seq_stride, - cache.page_tables, + cache.page_pool.page_tables, ) for prefill_request in pending_prefills: assert prefill_request.phase == InferencePhase.PREFILL @@ -218,7 +228,11 @@ def board_prefills(self, cache: AttnPageCache): needed_pages = math.ceil( len(prefill_request.input_token_ids) / self.page_seq_stride ) - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + prefill_request.input_token_ids, + extra_token_slots=0, # prefill needs no extra kvcache slots to write to + ) if pages is None: logger.debug("Cannot fulfill request for %d pages", needed_pages) continue @@ -236,13 +250,16 @@ def board_prefills(self, cache: AttnPageCache): # And takeoff. 
exec_process.launch() - def board_decodes(self, cache: AttnPageCache): + def board_decodes(self, cache: BasePagedAttentionCache): # Fill decode flights. pending_decodes = self.pending_decodes if len(pending_decodes) == 0: return exec_process = InferenceExecutorProcess( - self.service, InferencePhase.DECODE, self.page_seq_stride, cache.page_tables + self.service, + InferencePhase.DECODE, + self.page_seq_stride, + cache.page_pool.page_tables, ) for decode_request in pending_decodes: assert decode_request.phase == InferencePhase.DECODE @@ -254,7 +271,11 @@ def board_decodes(self, cache: AttnPageCache): / self.page_seq_stride ) if needed_pages > len(decode_request.locked_pages): - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + decode_request.input_token_ids, + extra_token_slots=1, # need 1 extra slot to write result. + ) if pages is None: logger.debug( "Cannot fulfill decode request for %d pages", needed_pages diff --git a/shortfin/tests/apps/llm/components/cache_test.py b/shortfin/tests/apps/llm/components/cache_test.py deleted file mode 100644 index 169d082b1..000000000 --- a/shortfin/tests/apps/llm/components/cache_test.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Tests for llm kvcache component. 
-""" - -import pytest -import time -import tempfile -import shortfin as sf -from _shortfin import lib as sfl -from shortfin_apps.llm.components import cache -from shortfin_apps.llm.components import config_struct -import json -from pathlib import Path - - -@pytest.fixture -def lsys(): - sc = sfl.local.host.CPUSystemBuilder() - ls = sc.create_system() - yield ls - ls.shutdown() - - -@pytest.fixture -def fiber(lsys): - # TODO: Should adopt the main thread. - worker = lsys.create_worker("main") - return lsys.create_fiber(worker) - - -@pytest.fixture -def device(fiber): - return fiber.device(0) - - -@pytest.fixture -def model_params(): - model_params = { - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 2048, - "attn_head_count": 32, - "attn_head_dim": 100, - "prefill_batch_sizes": [4], - "decode_batch_sizes": [4], - "transformer_block_count": 26, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, - } - - # Create a temporary file to store the JSON - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as tmp_file: - json.dump(model_params, tmp_file, indent=4) - tmp_path = Path(tmp_file.name) - - try: - # Load the JSON using config_struct - model_params = config_struct.ModelParams.load_json(tmp_path) - yield model_params - finally: - tmp_path.unlink - - -@pytest.fixture -def cache_fixture(fiber, model_params) -> cache.AttnPageCache: - # Create and return the cache object - return cache.AttnPageCache( - devices=fiber.devices_dict.values(), model_params=model_params - ) - - -@pytest.mark.parametrize("n_allocated", [1, 16, 255]) -def test_alloc( - cache_fixture: cache.AttnPageCache, - n_allocated, - model_params: config_struct.ModelParams, -): - alloc_page_count = cache_fixture.page_tables[0].shape[0] - - assert alloc_page_count == model_params.paged_kv_cache.device_block_count - - pages = cache_fixture.acquire_free_pages(n_allocated) - last_page = alloc_page_count - 1 - expected_indices = 
range(last_page, last_page - n_allocated, -1) - for p, expected_ix in zip(pages, expected_indices): - assert p.index == expected_ix - assert p.index > 0 diff --git a/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py new file mode 100644 index 000000000..a1ec00c07 --- /dev/null +++ b/shortfin/tests/apps/llm/components/kvcache/page_pool_test.py @@ -0,0 +1,57 @@ +import pytest +import logging +from shortfin_apps.llm.components.kvcache.page_pool import PagePool, PagePoolConfig +import shortfin as sf +import shortfin.host +import shortfin.array as sfnp +import shortfin.amdgpu + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def setup_pool(generic_device): + pool = PagePool( + devices=[generic_device], + config=PagePoolConfig( + alloc_page_count=256, + dtype=sfnp.float16, + paged_kv_block_size_elements=393216, + ), + ) + return pool + + +def test_page_acquisition(setup_pool): + pool = setup_pool + logger.info(f"=== Running page acquisition test on system ===") + page0 = pool.acquire_free_pages(1) + assert page0 is not None, f"Failed to acquire a free page on system" + logger.info(f"Successfully acquired page on system") + + +def test_page_copy(setup_pool): + pool = setup_pool + logger.info(f"=== Running page copy test on system ===") + (page0,) = pool.acquire_free_pages(1) + page1 = pool.copy_page(page0) + assert page1 is not None, f"Failed to copy a page on system" + assert page0 != page1, f"Copied page should be different from original on system" + logger.info(f"Successfully copied page on system") + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Set up logging format to include timestamp and level""" + logging.basicConfig( + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + force=True, + ) + + +# Add more tests as needed + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/shortfin/tests/conftest.py 
b/shortfin/tests/conftest.py index 083698968..b16d5a3c9 100644 --- a/shortfin/tests/conftest.py +++ b/shortfin/tests/conftest.py @@ -50,6 +50,17 @@ def pytest_runtest_setup(item): sf.SystemBuilder.default_system_type = system_type +# Dynamic Parameterization for lsys Fixture +def pytest_generate_tests(metafunc): + if "generic_lsys" in metafunc.fixturenames: + system = metafunc.config.getoption("--system") + if system == "amdgpu": + params = ["cpu", "amdgpu"] + else: + params = [system] + metafunc.parametrize("generic_lsys", params, indirect=True) + + # Keys that will be cleaned project wide prior to and after each test run. # Test code can freely modify these. CLEAN_ENV_KEYS = [ @@ -96,6 +107,28 @@ def kill(): kill() +@pytest.fixture(scope="session") +def generic_lsys(request): + system_type = request.param + if system_type == "cpu" or system_type == "hostcpu": + sc = sf.host.CPUSystemBuilder() + elif system_type == "amdgpu": + sc = sf.amdgpu.SystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def generic_fiber(generic_lsys): + return generic_lsys.create_fiber() + + +@pytest.fixture +def generic_device(generic_fiber): + return generic_fiber.device(0) + + @pytest.fixture def cpu_lsys(): sc = sf.host.CPUSystemBuilder()