Skip to content

Commit

Permalink
[pallas:mosaic_gpu] Add test for FragmentedArray.bitcast.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699919048
  • Loading branch information
petebu authored and Google-ML-Automation committed Nov 25, 2024
1 parent b372ce4 commit 69e3f0d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
10 changes: 8 additions & 2 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,8 @@ def __init__(

if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype):
raise TypeError(
"is_signed must only be non-None if the MLIR type is an integer"
f" type, got {_is_signed=} for {self.mlir_dtype}"
"is_signed must be non-None if and only if the MLIR type is an"
f" integer type, got {_is_signed=} for {self.mlir_dtype}"
)

match self.layout:
Expand Down Expand Up @@ -962,6 +962,12 @@ def fast_instr(x):
return fast_instr

def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None):
if (output_is_signed is not None) != ir.IntegerType.isinstance(elt):
raise TypeError(
"output_is_signed must be non-None if and only if the MLIR type is an"
f" integer type, got {output_is_signed=} for {elt}"
)

if elt == self.mlir_dtype:
return self
reg_type = self.registers.flat[0].type
Expand Down
34 changes: 34 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,40 @@ def kernel(ctx, _):

_ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)()

@parameterized.parameters(
(jnp.float16, jnp.float16), # Noop
(jnp.int16, jnp.bfloat16),
(jnp.int16, jnp.float16),
(jnp.uint16, jnp.float16),
(jnp.float32, jnp.int32),
(jnp.float32, jnp.uint32),
(jnp.uint32, jnp.int32),
(jnp.int32, jnp.uint32),
)
def test_bitcast(self, in_dtype, out_dtype):
out_ir_type = utils.dtype_to_ir_type(out_dtype)
in_is_signed = utils.is_signed(in_dtype)
out_is_signed = utils.is_signed(out_dtype)

def kernel(ctx, inp, out, smem):
del ctx, smem
arr = mgpu.FragmentedArray.load_strided(inp, is_signed=in_is_signed)
arr = arr.bitcast(out_ir_type, output_is_signed=out_is_signed)
arr.store_untiled(out)

x = jnp.arange(256, dtype=in_dtype)
reference = jax.lax.bitcast_convert_type(x, out_dtype)

result = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
x,
reference,
None,
)(x)
np.testing.assert_array_equal(result, reference)


class ProfilerTest(TestCase):

Expand Down

0 comments on commit 69e3f0d

Please sign in to comment.