Skip to content

Commit

Permalink
simplify test heavily and hopefully this reenables windows support
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Oct 30, 2024
1 parent 10f3391 commit 8f77093
Showing 1 changed file with 22 additions and 100 deletions.
122 changes: 22 additions & 100 deletions shortfin/tests/invocation/vmfb_buffer_access_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,15 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import urllib.request
from pathlib import Path
import functools
import pytest
import shortfin as sf
import shortfin.array as sfnp
import array
import random
import struct
import sys


@pytest.fixture(scope="session")
def kvcache_compiled_cpu_path(tmp_path):
def kvcache_compiled_cpu_path(tmp_path_factory):
try:
import iree.compiler.tools as tools
except ModuleNotFoundError:
Expand Down Expand Up @@ -96,65 +91,20 @@ def kvcache_compiled_cpu_path(tmp_path):
}
"""

# Create temporary directory for our files

tmp_dir_path = Path(tmp_path)

# Write MLIR to temporary file
mlir_path = tmp_dir_path / "kvcache.mlir"
# Get a temporary directory using tmp_path_factory
tmp_dir = tmp_path_factory.mktemp("vmfb_buffer_access_test")
mlir_path = tmp_dir / "kvcache.mlir"
mlir_path.write_text(KVCACHE_MODULE_CONTENTS)
vmfb_path = tmp_dir / "kvcache_cpu.vmfb"

# Define output path for compiled binary
vmfb_path = tmp_dir_path / "kvcache_cpu.vmfb"

# Compile the MLIR to VMFB
tools.compile_file(
str(mlir_path),
output_file=str(vmfb_path),
target_backends=["llvm-cpu"],
input_type="AUTO",
)

yield vmfb_path


def float_to_float16(f):
"""Convert a Python float to float16 (stored as uint16)."""
return struct.unpack("H", struct.pack("e", f))[0]


def float16_to_float(h):
"""Convert a float16 (stored as uint16) to Python float."""
return struct.unpack("e", struct.pack("H", h))[0]


def create_random_float16_array(size):
"""Create an array of random float16 values between -1 and 1."""
return array.array(
"H", [float_to_float16(random.uniform(-1, 1)) for _ in range(size)]
)


def calculate_mean_abs(float16_array):
"""Calculate mean of absolute values for a float16 array."""
return sum(abs(float16_to_float(x)) for x in float16_array) / len(float16_array)


def assert_close(a, b, rtol=1e-3, atol=1e-3, err_msg=""):
"""Check if two float16 arrays are close within tolerance."""
if len(a) != len(b):
raise AssertionError(f"{err_msg}: Arrays have different lengths")

for i, (x, y) in enumerate(zip(a, b)):
x_float = float16_to_float(x)
y_float = float16_to_float(y)
abs_diff = abs(x_float - y_float)
tol = atol + rtol * abs(y_float)
if abs_diff > tol:
raise AssertionError(
f"{err_msg}: Arrays differ at index {i}: {x_float} != {y_float} "
f"(diff={abs_diff}, tol={tol})"
)
return vmfb_path


@pytest.fixture
Expand All @@ -175,6 +125,11 @@ def device(fiber):
return fiber.device(0)


def create_random_float16_array(size):
"""Create an array of random uint16 values."""
return array.array("H", [random.randint(0, 65535) for _ in range(size)])


def create_scalar_device_array(device, value, dtype=sfnp.int64):
"""Helper function to create a scalar device array."""
arr = sfnp.device_array.for_device(device, [1], dtype)
Expand All @@ -185,23 +140,8 @@ def create_scalar_device_array(device, value, dtype=sfnp.int64):
return arr


@pytest.mark.parametrize(
"await_before_invoke",
[
True,
False, # Need to potentially xfail this case if using GPU
],
)
@pytest.mark.parametrize("await_before_invoke", [True, False])
def test_kvcache_noreturn(lsys, fiber, kvcache_compiled_cpu_path, await_before_invoke):
"""
This test mimics the kvcache indexing, reading, and writing in open-llama-3b-v2-f16.gguf.
It's a simple test to verify that shortfin can:
- Create device arrays
- Map and fill them
- Invoke a VMFB and pass arguments to it
- Properly retain changes made by the VMFB to device arrays provided as arguments
"""
device = fiber.device(0)
program_module = lsys.load_module(kvcache_compiled_cpu_path)
program = sf.Program([program_module], devices=fiber.raw_devices)
Expand All @@ -211,49 +151,41 @@ def test_kvcache_noreturn(lsys, fiber, kvcache_compiled_cpu_path, await_before_i

# Test parameters
num_pages = 4
num_layers = 26 # Number of transformer layers
num_kv = 2 # K and V states
num_layers = 26
num_kv = 2
batch_size = 16
num_heads = 32
head_dim = 100

# Initialize test data - note we're only writing one layer at a time
test_data_size = batch_size * num_heads * head_dim
test_data = create_random_float16_array(test_data_size)

# The kvcache shape should match the MLIR module's expected shape
# [num_pages, num_layers * num_kv * batch_size * num_heads * head_dim]
total_dim = num_layers * num_kv * batch_size * num_heads * head_dim
assert total_dim == 2662400
kvcache_shape = [num_pages, total_dim]
kvcache_data = array.array("H", [0] * (kvcache_shape[0] * kvcache_shape[1]))

async def main():
# Create device arrays
device_kvcache = sfnp.device_array(device, kvcache_shape, sfnp.float16)
device_new_data = sfnp.device_array(
device, [batch_size, num_heads, head_dim], sfnp.float16
)

# Initialize kvcache on device
staging_kvcache = device_kvcache.for_transfer()
with staging_kvcache.map(discard=True) as m:
m.fill(kvcache_data)
device_kvcache.copy_from(staging_kvcache)

# Initialize new data on device
staging_new_data = device_new_data.for_transfer()
with staging_new_data.map(discard=True) as m:
m.fill(test_data)
device_new_data.copy_from(staging_new_data)

# Test writing and reading for both K and V states
for layer_idx in range(2): # Test first two layers
for kv_idx in range(num_kv): # Test both K and V
page_index = create_scalar_device_array(device, 1) # Write to page 1
for layer_idx in range(2):
for kv_idx in range(num_kv):
page_index = create_scalar_device_array(device, 1)
layer_index = create_scalar_device_array(device, layer_idx)

# Write to kvcache
if await_before_invoke:
await device
ret = await write_function(
Expand All @@ -264,33 +196,23 @@ async def main():
fiber=fiber,
)

# Read from kvcache
if await_before_invoke:
await device
(read_result,) = await read_function(
device_kvcache, page_index, layer_index, fiber=fiber
)

# Transfer results back to host
host_result = read_result.for_transfer()
host_result.copy_from(read_result)
await device

# Get the result array for the specific K/V state
# Simple byte comparison of the arrays
result_array = array.array("H", host_result.items)
offset = kv_idx * test_data_size
result_slice = result_array[offset : offset + test_data_size]

# Verify numerical correctness for the specific K/V state
assert_close(
result_slice,
test_data,
rtol=1e-3,
atol=1e-3,
err_msg=f"KV cache read/write mismatch for layer {layer_idx}, {'key' if kv_idx == 0 else 'value'} state",
assert result_slice.tobytes() == test_data.tobytes(), (
f"KV cache read/write mismatch for layer {layer_idx}, "
f"{'key' if kv_idx == 0 else 'value'} state"
)

# Additional statistical checks
result_mean = calculate_mean_abs(result_slice)
test_mean = calculate_mean_abs(test_data)
assert abs(result_mean - test_mean) < 1e-3
lsys.run(main())

0 comments on commit 8f77093

Please sign in to comment.