Skip to content

Commit

Permalink
[Mosaic TPU] Emulate converting x16 vector to mask if mask packing is…
Browse files Browse the repository at this point in the history
… supported.

PiperOrigin-RevId: 716395352
  • Loading branch information
bythew3i authored and Google-ML-Automation committed Jan 21, 2025
1 parent 79bd72e commit 42d95a3
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,13 @@ def body(x_ref, y_ref):
expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0))
np.testing.assert_array_equal(result, expected)

@parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int8])
@parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int16, jnp.int8])
def test_cast_vector_to_mask(self, dtype):
if not jtu.if_cloud_tpu_at_least(2025, 1, 22):
self.skipTest("Requires libtpu built after 2025-01-22")
shape = (128, 128)
bitwidth = pallas_utils.dtype_bitwidth(dtype)
if (
(jtu.get_tpu_version() > 5 and bitwidth < 8)
or (jtu.get_tpu_version() == 5 and bitwidth not in (8, 32))
or (jtu.get_tpu_version() < 5 and bitwidth < 32)
):
if jtu.get_tpu_version() < 5 and bitwidth < 32:
self.skipTest(
f"Not implemented: cast vector to mask with bitwidth == {bitwidth}"
)
Expand Down

0 comments on commit 42d95a3

Please sign in to comment.