diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index c2cd8c21c132..dc2a5f0d891a 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1032,12 +1032,11 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
       )
     reg_type = self.registers.flat[0].type
     is_vector_reg = ir.VectorType.isinstance(reg_type)
-    reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else ()
-    if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,):
+    reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
+    [vector_len] = reg_shape  # This is meant to be a 1D assertion.
+    if cur_dtype == i8 and self.is_signed and new_dtype == bf16 and vector_len in {2, 4}:
       new_registers = np.empty_like(self.registers)
-      for idx, reg in np.ndenumerate(self.registers):
-        reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
-        val_16 = llvm.extractelement(reg_16, c(0, i32))
+      def upcast_to_bf16(reg, high):
         # We first embed the s8 into a bf16 with the exponent equal to
         # bias + mantissa bits. Then, we zero the msb that didn't fit into the
         # mantissa, zero out all bits other than msb, and subtract the last
@@ -1045,24 +1044,36 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
         # lsb of the exponent (msb of the second byte) is zero, which allows us
         # to losslesly pack the msb there. When 1, it doubles the value of s2,
         # making the result negative.
-        new_val_32 = llvm.inline_asm(
+        return llvm.inline_asm(
            i32,
-            [val_16],
-            """
-            {
+            [reg],
+            f"""
+            {{
            .reg .b32 s<3>;
-            prmt.b32 s0, $1, 0x43, 0x4140;
+            prmt.b32 s0, $1, 0x43, {0x4342 if high else 0x4140};
            and.b32 s1, s0, 0xff7fff7f;
            and.b32 s2, s0, 0xff80ff80;
            sub.bf16x2 $0, s1, s2;
-            }
+            }}
            """,
            "=r,r",
        )
-        new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32))
-        new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32))
+      empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((vector_len // 2,), i32))
+      for idx, reg in np.ndenumerate(self.registers):
+        if vector_len == 2:
+          reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
+          new_reg_32 = upcast_to_bf16(reg_16, high=False)
+          new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
+        elif vector_len == 4:
+          reg_32 = vector.bitcast(ir.VectorType.get((1,), i32), reg)
+          low = upcast_to_bf16(reg_32, high=False)
+          high = upcast_to_bf16(reg_32, high=True)
+          new_vec_32 = llvm.insertelement(empty_vec_32, low, c(0, i32))
+          new_vec_32 = llvm.insertelement(new_vec_32, high, c(1, i32))
+        else:
+          raise NotImplementedError(vector_len)
         new_registers[idx] = vector.bitcast(
-            ir.VectorType.get((2,), new_dtype), new_vec
+            ir.VectorType.get((vector_len,), new_dtype), new_vec_32
         )
       return FragmentedArray(
           _registers=new_registers, _layout=self.layout, _is_signed=is_signed
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 7dadc71fdcba..71f2d383f809 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -1608,19 +1608,19 @@ def kernel(ctx, out, *_):
     np.testing.assert_array_equal(result, x)
 
-  @parameterized.named_parameters(
-      ("_bf16", jnp.bfloat16)
-  )
-  def test_fast_i8_convert(self, jax_dtype_to):
-    jax_dtype_to = jnp.dtype(jax_dtype_to)
+  @parameterized.parameters(2, 4)
+  def test_fast_i8_convert(self, reg_length):
+    jax_dtype_to = jnp.dtype(jnp.bfloat16)
     jax_dtype_from = jnp.dtype(jnp.int8)
     mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
     def kernel(ctx, inp, out, smem):
       del ctx, smem
       arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True)
+      assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length]
       arr.astype(mlir_dtype_to).store_untiled(out)
 
     x = jnp.arange(-128, 128, dtype=jax_dtype_from)
+    x = jnp.tile(x, reg_length // 2)
     reference = x.astype(jax_dtype_to)
     result = mgpu.as_gpu_kernel(
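
For reference, a minimal host-side NumPy sketch (not part of the diff) of the s8-to-bf16 bit trick that the PTX above implements. It assumes ml_dtypes is available for a CPU bfloat16 type, and the helper name emulate_upcast is made up for illustration; it mirrors the prmt/and/and/sub sequence lane by lane so the reasoning in the comment can be checked directly.

import numpy as np
import ml_dtypes

def emulate_upcast(s8: np.ndarray) -> np.ndarray:
  # prmt.b32 ..., 0x4140 (or 0x4342 for the high half): place each s8 byte
  # under a 0x43 high byte, yielding the bf16 bit pattern 0x43XX per lane.
  bits = np.uint16(0x4300) | s8.astype(np.uint8).astype(np.uint16)
  # and.b32 s1, s0, 0xff7fff7f (0xFF7F per 16-bit lane): clear the s8 sign bit
  # that landed in the exponent lsb, leaving 2**7 * (1 + m/128) = 128 + m,
  # where m is the low 7 bits of the byte.
  s1 = bits & np.uint16(0xFF7F)
  # and.b32 s2, s0, 0xff80ff80 (0xFF80 per lane): keep only the exponent,
  # i.e. 128 when the sign bit is 0, or 256 when it is 1 (exponent bumped by one).
  s2 = bits & np.uint16(0xFF80)
  # sub.bf16x2: (128 + m) - 128 = m, or (128 + m) - 256 = m - 128, which is
  # exactly the two's-complement value of the original s8 byte.
  return s1.view(ml_dtypes.bfloat16) - s2.view(ml_dtypes.bfloat16)

x = np.arange(-128, 128, dtype=np.int8)
np.testing.assert_array_equal(emulate_upcast(x), x.astype(ml_dtypes.bfloat16))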