diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index e9e064b4744f..29fd741814ba 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -271,15 +271,13 @@ def body(x_ref, y_ref): expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0)) np.testing.assert_array_equal(result, expected) - @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int8]) + @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int16, jnp.int8]) def test_cast_vector_to_mask(self, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 1, 22): + self.skipTest("Requires libtpu built after 2025-01-22") shape = (128, 128) bitwidth = pallas_utils.dtype_bitwidth(dtype) - if ( - (jtu.get_tpu_version() > 5 and bitwidth < 8) - or (jtu.get_tpu_version() == 5 and bitwidth not in (8, 32)) - or (jtu.get_tpu_version() < 5 and bitwidth < 32) - ): + if jtu.get_tpu_version() < 5 and bitwidth < 32: self.skipTest( f"Not implemented: cast vector to mask with bitwidth == {bitwidth}" )