Make gmm TPU kernel tests significantly cheaper
We were testing many very similar cases that added little coverage.

PiperOrigin-RevId: 707115030
apaszke authored and Google-ML-Automation committed Dec 17, 2024
1 parent 0ec902d commit 3fc2371
Showing 1 changed file with 15 additions and 42 deletions.
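The savings come from two changes visible in the diff below: the shape list shrinks from seven entries to three, and the dtype/transpose choices move out of the parameterized cross-product into per-example Hypothesis draws. A back-of-the-envelope sketch of the implied case counts for test_gmm (illustrative Python, not part of the commit; the variable names are mine):

    shapes_before = 7      # old GROUPED_MATMUL_TESTS entries
    dtype_combos = 2 ** 3  # {float32, bfloat16} for each of lhs/rhs/out
    transpose_flags = 2    # transpose_rhs in {False, True}
    print(shapes_before * dtype_combos * transpose_flags)  # 112 collected cases

    shapes_after = 3       # Small, Big, partial tiles
    print(shapes_after)    # 3 collected cases; dtypes and transpose_rhs are
                           # now drawn by Hypothesis inside each case instead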
tests/pallas/tpu_gmm_test.py (57 changes: 15 additions & 42 deletions)
@@ -48,11 +48,9 @@
 )
 hp.settings.load_profile("deterministic")


 def seed_strategy() -> hps.SearchStrategy[int]:
   return hps.integers(min_value=0, max_value=4)


 @hps.composite
 def group_strategy(
     draw: hps.DrawFn,
@@ -73,7 +71,6 @@ def group_strategy(
   )
   return num_groups, group_stride


 @hps.composite
 def group_sizes_strategy(
     draw: hps.DrawFn, m: int, num_groups: int
@@ -97,19 +94,12 @@ def group_sizes_strategy(
   starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
   return jnp.array(ends - starts, dtype=jnp.int32)


 GROUPED_MATMUL_TESTS = (
-    (128, 128, 128),
-    (256, 128, 128),
-    (128, 256, 128),
-    (128, 128, 256),
-    (256, 128, 512),
-    (512, 128, 128),
-    (512, 2048, 128),
+    (128, 128, 128),  # Small
+    (512, 2048, 256),  # Big
+    (128, 8, 16),  # Test partial tiles.
 )


 def random_dense(
     shape: tuple[int, ...],
     key: jax.Array,
@@ -121,7 +111,6 @@ def random_dense(
   x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)  # pylint: disable=invalid-unary-operand-type
   return x.astype(jnp.bfloat16).astype(dtype)


 def dot(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
@@ -133,7 +122,6 @@ def dot(
   rhs = jnp.transpose(rhs) if transpose_rhs else rhs
   return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type)


 def reference_gmm(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
@@ -154,7 +142,6 @@ def reference_gmm(
     start += group_sizes[i]
   return jnp.concatenate(out, axis=0)


 def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
   dtypes = [jnp.float32, jnp.bfloat16]

@@ -164,7 +151,6 @@ def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
       result.append(x + dtypes_tuple)
   return tuple(result)


 def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
   flags = [False, True]
   result = []
@@ -173,7 +159,6 @@ def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
       result.append(x + (flag,))
   return tuple(result)


 def tolerances(
     lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype
 ) -> tuple[float, float]:
@@ -185,7 +170,6 @@ def tolerances(
     return 1e-3, 1e-2  # atol, rtol
   return 1e-3, 1e-5  # atol, rtol


 # TODO(tgale): Fix errors with strict dtype promotion.
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class GroupedMatmulTest(jtu.JaxTestCase):
@@ -218,15 +202,16 @@ def gmm_test(
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
-      transpose_rhs: bool,
       data: hps.SearchStrategy[hps.DataObject],
       interpret: bool = False,
   ):
     seed = data.draw(seed_strategy())
     num_groups, _ = data.draw(group_strategy(max_stride=1))
+    lhs_dtype, rhs_dtype, out_dtype = [
+        data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
+        for _ in range(3)
+    ]
+    transpose_rhs = data.draw(hps.booleans())

     key = jax.random.key(seed)
     k1, k2 = jax.random.split(key, 2)
@@ -270,64 +255,52 @@ def reference_fn(lhs, rhs, group_sizes, preferred_element_type):
     self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol)
     self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol)

-  @parameterized.parameters(
-      *with_transpose_argument(with_dtype_arguments(GROUPED_MATMUL_TESTS))
-  )
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS)
   @hp.given(hps.data())
   def test_gmm(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
-      transpose_rhs: bool,
       data: hps.SearchStrategy[hps.DataObject],
   ):
-    self.gmm_test(m, k, n, lhs_dtype, rhs_dtype, out_dtype, transpose_rhs, data)
+    self.gmm_test(m, k, n, data)

   # NOTE: Run fewer tests with interpret mode. We just want to sanity check that
   # changes do not break running these kernels with interpret=True.
-  @parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS[0:1]))
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1])
   @hp.given(hps.data())
   def test_gmm_interpret(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
       data: hps.SearchStrategy[hps.DataObject],
   ):
     self.skipTest("interpret mode with dynamic grids is unsupported")
     self.gmm_test(
         m,
         k,
         n,
-        lhs_dtype,
-        rhs_dtype,
-        out_dtype,
-        transpose_rhs=False,
         data=data,
         interpret=True,
     )

-  @parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS))
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS)
   @hp.given(hps.data())
   def test_gmm_sharded_groups(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
       data: hps.SearchStrategy[hps.DataObject],
   ):
     seed = data.draw(seed_strategy())
     num_groups, group_stride = data.draw(group_strategy())
+    lhs_dtype, rhs_dtype, out_dtype = [
+        data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
+        for _ in range(3)
+    ]

     key = jax.random.key(seed)
     k1, k2 = jax.random.split(key, 2)
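For reference, a minimal, self-contained sketch of the pattern the updated tests rely on: enumerate only the cheap axis (the shapes) with parameterized.parameters, and let Hypothesis draw the remaining arguments per example. The class and test names below are illustrative, not part of the commit.

    import hypothesis as hp
    import hypothesis.strategies as hps
    from absl.testing import absltest, parameterized

    hp.settings.register_profile("deterministic", database=None, derandomize=True)
    hp.settings.load_profile("deterministic")


    class PatternTest(parameterized.TestCase):

      @parameterized.parameters((128, 128, 128), (512, 2048, 256))
      @hp.given(hps.data())
      def test_shapes(self, m, k, n, data):
        # Drawn per Hypothesis example instead of enumerated at collection time.
        dtype = data.draw(hps.sampled_from(["float32", "bfloat16"]))
        transpose_rhs = data.draw(hps.booleans())
        self.assertIn(dtype, ("float32", "bfloat16"))
        self.assertIsInstance(transpose_rhs, bool)


    if __name__ == "__main__":
      absltest.main()

With this split, Hypothesis owns the dtype and transpose axes, so shrinking or growing them never multiplies the number of collected test cases.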
