From 42d95a3770a02d5d6d1acab0d6ac5a8c772363dd Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 16 Jan 2025 15:11:56 -0800 Subject: [PATCH] [Mosaic TPU] Emulate converting x16 vector to mask if mask packing is supported. PiperOrigin-RevId: 716395352 --- tests/pallas/tpu_ops_test.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index e9e064b4744f..29fd741814ba 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -271,15 +271,13 @@ def body(x_ref, y_ref): expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0)) np.testing.assert_array_equal(result, expected) - @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int8]) + @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int16, jnp.int8]) def test_cast_vector_to_mask(self, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 1, 22): + self.skipTest("Requires libtpu built after 2025-01-22") shape = (128, 128) bitwidth = pallas_utils.dtype_bitwidth(dtype) - if ( - (jtu.get_tpu_version() > 5 and bitwidth < 8) - or (jtu.get_tpu_version() == 5 and bitwidth not in (8, 32)) - or (jtu.get_tpu_version() < 5 and bitwidth < 32) - ): + if jtu.get_tpu_version() < 5 and bitwidth < 32: self.skipTest( f"Not implemented: cast vector to mask with bitwidth == {bitwidth}" )