[Mosaic GPU] Add support for fast upcasts of s8 to bf16 for vectors of 4 elements

To complement the current path that only handles 2 elements.

PiperOrigin-RevId: 700998965
apaszke authored and Google-ML-Automation committed Nov 28, 2024
1 parent a158e02 commit b09b077
Showing 2 changed files with 30 additions and 19 deletions.
jax/experimental/mosaic/gpu/fragmented_array.py: 39 changes (25 additions, 14 deletions)
@@ -1032,37 +1032,48 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
       )
     reg_type = self.registers.flat[0].type
     is_vector_reg = ir.VectorType.isinstance(reg_type)
-    reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else ()
-    if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,):
+    reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
+    [vector_len] = reg_shape  # This is meant to be a 1D assertion.
+    if cur_dtype == i8 and self.is_signed and new_dtype == bf16 and vector_len in {2, 4}:
       new_registers = np.empty_like(self.registers)
-      for idx, reg in np.ndenumerate(self.registers):
-        reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
-        val_16 = llvm.extractelement(reg_16, c(0, i32))
+      def upcast_to_bf16(reg, high):
         # We first embed the s8 into a bf16 with the exponent equal to
         # bias + mantissa bits. Then, we zero the msb that didn't fit into the
         # mantissa, zero out all bits other than msb, and subtract the last
         # two values from each other. This takes advantage of the fact that the
         # lsb of the exponent (msb of the second byte) is zero, which allows us
         # to losslesly pack the msb there. When 1, it doubles the value of s2,
         # making the result negative.
-        new_val_32 = llvm.inline_asm(
+        return llvm.inline_asm(
             i32,
-            [val_16],
-            """
-            {
+            [reg],
+            f"""
+            {{
             .reg .b32 s<3>;
-            prmt.b32 s0, $1, 0x43, 0x4140;
+            prmt.b32 s0, $1, 0x43, {0x4342 if high else 0x4140};
             and.b32 s1, s0, 0xff7fff7f;
             and.b32 s2, s0, 0xff80ff80;
             sub.bf16x2 $0, s1, s2;
-            }
+            }}
             """,
             "=r,r",
         )
-        new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32))
-        new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32))
+      empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((vector_len // 2,), i32))
+      for idx, reg in np.ndenumerate(self.registers):
+        if vector_len == 2:
+          reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
+          new_reg_32 = upcast_to_bf16(reg_16, high=False)
+          new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
+        elif vector_len == 4:
+          reg_32 = vector.bitcast(ir.VectorType.get((1,), i32), reg)
+          low = upcast_to_bf16(reg_32, high=False)
+          high = upcast_to_bf16(reg_32, high=True)
+          new_vec_32 = llvm.insertelement(empty_vec_32, low, c(0, i32))
+          new_vec_32 = llvm.insertelement(new_vec_32, high, c(1, i32))
+        else:
+          raise NotImplementedError(vector_len)
         new_registers[idx] = vector.bitcast(
-            ir.VectorType.get((2,), new_dtype), new_vec
+            ir.VectorType.get((vector_len,), new_dtype), new_vec_32
         )
       return FragmentedArray(
           _registers=new_registers, _layout=self.layout, _is_signed=is_signed
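The comment in upcast_to_bf16 is dense, so here is a CPU-side sketch (not part of the commit; the helper name emulated_upcast and the use of NumPy are illustration only) that emulates the PTX sequence to show why it recovers the signed value exactly. float32 stands in for bf16, which changes nothing here because every intermediate (128 through 256, and the int8-valued result) is exactly representable in both formats. The prmt selector (0x4140 vs. 0x4342) only decides whether the low or the high pair of bytes in the 32-bit register gets expanded; the per-byte arithmetic modeled below is the same for both.

import numpy as np

def emulated_upcast(x):
  # Emulates the PTX above, one s8 byte at a time: prmt, two ands, sub.bf16x2.
  b = x.view(np.uint8).astype(np.uint16)
  s0 = 0x4300 | b     # prmt: glue the 0x43 exponent byte on top of each s8
  s1 = s0 & 0xFF7F    # zero the s8 sign bit -> bf16 value 128 + (x & 0x7f)
  s2 = s0 & 0xFF80    # keep only the sign bit -> bf16 value 256 or 128
  # A bf16 bit pattern equals the float32 whose upper 16 bits are that pattern.
  as_f32 = lambda h: (h.astype(np.uint32) << 16).view(np.float32)
  return as_f32(s1) - as_f32(s2)  # (128 + (x & 0x7f)) - (128 or 256) == x

x = np.arange(-128, 128, dtype=np.int8)
np.testing.assert_array_equal(emulated_upcast(x), x.astype(np.float32))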
tests/mosaic/gpu_test.py: 10 changes (5 additions, 5 deletions)
@@ -1608,19 +1608,19 @@ def kernel(ctx, out, *_):
 
     np.testing.assert_array_equal(result, x)
 
-  @parameterized.named_parameters(
-      ("_bf16", jnp.bfloat16)
-  )
-  def test_fast_i8_convert(self, jax_dtype_to):
-    jax_dtype_to = jnp.dtype(jax_dtype_to)
+  @parameterized.parameters(2, 4)
+  def test_fast_i8_convert(self, reg_length):
+    jax_dtype_to = jnp.dtype(jnp.bfloat16)
     jax_dtype_from = jnp.dtype(jnp.int8)
     mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
     def kernel(ctx, inp, out, smem):
       del ctx, smem
       arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True)
+      assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length]
       arr.astype(mlir_dtype_to).store_untiled(out)
 
     x = jnp.arange(-128, 128, dtype=jax_dtype_from)
+    x = jnp.tile(x, reg_length // 2)
     reference = x.astype(jax_dtype_to)
 
     result = mgpu.as_gpu_kernel(
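A note on why tiling by reg_length // 2 selects the register length (an inference from the test's new assert, not something the diff states): a Mosaic GPU warpgroup has 128 threads, and for these sizes the strided layout gives each thread a register of total_elements // 128 contiguous int8s, so 256 elements exercise the 2-element path and 512 the 4-element one. A quick arithmetic check of that assumption:

# Assumptions (hypothetical, inferred from the test): 128 threads per
# warpgroup, and a per-thread register of total_elements // 128 int8s.
WARPGROUP_THREADS = 128
for reg_length in (2, 4):
  total_elements = 256 * (reg_length // 2)  # len(jnp.tile(x, reg_length // 2))
  assert total_elements // WARPGROUP_THREADS == reg_length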
