diff --git a/tests/pallas/tpu_gmm_test.py b/tests/pallas/tpu_gmm_test.py
index be830a6a4473..9c416dabaeb1 100644
--- a/tests/pallas/tpu_gmm_test.py
+++ b/tests/pallas/tpu_gmm_test.py
@@ -48,11 +48,9 @@
 )
 
 hp.settings.load_profile("deterministic")
 
-
 def seed_strategy() -> hps.SearchStrategy[int]:
   return hps.integers(min_value=0, max_value=4)
 
-
 @hps.composite
 def group_strategy(
     draw: hps.DrawFn,
@@ -73,7 +71,6 @@ def group_strategy(
   )
   return num_groups, group_stride
 
-
 @hps.composite
 def group_sizes_strategy(
     draw: hps.DrawFn, m: int, num_groups: int
@@ -97,19 +94,12 @@ def group_sizes_strategy(
   starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
   return jnp.array(ends - starts, dtype=jnp.int32)
 
-
 GROUPED_MATMUL_TESTS = (
-    (128, 128, 128),
-    (256, 128, 128),
-    (128, 256, 128),
-    (128, 128, 256),
-    (256, 128, 512),
-    (512, 128, 128),
-    (512, 2048, 128),
+    (128, 128, 128),  # Small
+    (512, 2048, 256),  # Big
     (128, 8, 16),  # Test partial tiles.
 )
 
-
 def random_dense(
     shape: tuple[int, ...],
     key: jax.Array,
@@ -121,7 +111,6 @@ def random_dense(
   x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)  # pylint: disable=invalid-unary-operand-type
   return x.astype(jnp.bfloat16).astype(dtype)
 
-
 def dot(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
@@ -133,7 +122,6 @@ def dot(
   rhs = jnp.transpose(rhs) if transpose_rhs else rhs
   return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type)
 
-
 def reference_gmm(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
@@ -154,7 +142,6 @@ def reference_gmm(
     start += group_sizes[i]
   return jnp.concatenate(out, axis=0)
 
-
 def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
   dtypes = [jnp.float32, jnp.bfloat16]
 
@@ -164,7 +151,6 @@ def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
     result.append(x + dtypes_tuple)
   return tuple(result)
 
-
 def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
   flags = [False, True]
   result = []
@@ -173,7 +159,6 @@ def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
     result.append(x + (flag,))
   return tuple(result)
 
-
 def tolerances(
     lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype
 ) -> tuple[float, float]:
@@ -185,7 +170,6 @@ def tolerances(
     return 1e-3, 1e-2  # atol, rtol
   return 1e-3, 1e-5  # atol, rtol
 
-
 # TODO(tgale): Fix errors with strict dtype promotion.
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class GroupedMatmulTest(jtu.JaxTestCase):
@@ -218,15 +202,16 @@ def gmm_test(
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
-      transpose_rhs: bool,
       data: hps.SearchStrategy[hps.DataObject],
       interpret: bool = False,
   ):
     seed = data.draw(seed_strategy())
     num_groups, _ = data.draw(group_strategy(max_stride=1))
+    lhs_dtype, rhs_dtype, out_dtype = [
+        data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
+        for _ in range(3)
+    ]
+    transpose_rhs = data.draw(hps.booleans())
 
     key = jax.random.key(seed)
     k1, k2 = jax.random.split(key, 2)
@@ -270,35 +255,26 @@ def reference_fn(lhs, rhs, group_sizes, preferred_element_type):
     self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol)
     self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol)
 
-  @parameterized.parameters(
-      *with_transpose_argument(with_dtype_arguments(GROUPED_MATMUL_TESTS))
-  )
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS)
   @hp.given(hps.data())
   def test_gmm(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
-      transpose_rhs: bool,
       data: hps.SearchStrategy[hps.DataObject],
   ):
-    self.gmm_test(m, k, n, lhs_dtype, rhs_dtype, out_dtype, transpose_rhs, data)
+    self.gmm_test(m, k, n, data)
 
   # NOTE: Run fewer tests with interpret mode. We just want to sanity check that
   # changes do not break running these kernels with interpret=True.
-  @parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS[0:1]))
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1])
   @hp.given(hps.data())
   def test_gmm_interpret(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
       data: hps.SearchStrategy[hps.DataObject],
   ):
     self.skipTest("interpret mode with dynamic grids is unsupported")
@@ -306,28 +282,25 @@ def test_gmm_interpret(
         m,
         k,
         n,
-        lhs_dtype,
-        rhs_dtype,
-        out_dtype,
-        transpose_rhs=False,
         data=data,
         interpret=True,
     )
 
-  @parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS))
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS)
   @hp.given(hps.data())
   def test_gmm_sharded_groups(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
      data: hps.SearchStrategy[hps.DataObject],
   ):
     seed = data.draw(seed_strategy())
     num_groups, group_stride = data.draw(group_strategy())
+    lhs_dtype, rhs_dtype, out_dtype = [
+        data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
+        for _ in range(3)
+    ]
 
     key = jax.random.key(seed)
     k1, k2 = jax.random.split(key, 2)
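
For context (not part of the diff): the change moves dtype and transpose coverage out of the `@parameterized.parameters` expansion helpers (`with_dtype_arguments`, `with_transpose_argument`) and into hypothesis draws inside the test body, so each (m, k, n) case samples dtype/transpose combinations rather than enumerating every one. Below is a minimal, self-contained sketch of that draw pattern, assuming only that hypothesis and JAX are installed; the function name `check_draw_pattern` and the standalone call at the bottom are illustrative, not taken from the diff.

import hypothesis as hp
import hypothesis.strategies as hps
import jax.numpy as jnp

# Mirror the deterministic profile the test file loads, so repeated runs
# draw the same examples. (Profile name is the same as in the diff.)
hp.settings.register_profile("deterministic", derandomize=True)
hp.settings.load_profile("deterministic")

@hp.given(hps.data())
def check_draw_pattern(data: hps.DataObject):
  # Draw three dtypes independently, mirroring the lhs/rhs/out draws the
  # updated gmm_test and test_gmm_sharded_groups perform.
  lhs_dtype, rhs_dtype, out_dtype = [
      data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
      for _ in range(3)
  ]
  # The transpose flag is likewise drawn at runtime instead of being a
  # parameterized test argument.
  transpose_rhs = data.draw(hps.booleans())
  assert {lhs_dtype, rhs_dtype, out_dtype} <= {jnp.float32, jnp.bfloat16}
  assert isinstance(transpose_rhs, bool)

check_draw_pattern()  # Calling a @given-wrapped function runs the drawn cases.

The practical effect is the trade the diff makes: the parameterized test matrix shrinks (fewer named test cases and fewer shape tuples), while hypothesis still exercises dtype and transpose variation across the examples it generates for each remaining case.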