[Mosaic TPU] Add optimized casts for bf16->s4 in TPUv6
PiperOrigin-RevId: 717829264
apaszke authored and Google-ML-Automation committed Jan 21, 2025
1 parent 0a89760 commit 0727f3a
Showing 2 changed files with 98 additions and 32 deletions.
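For context, a minimal user-level sketch (not part of this commit) of the cast that the new lowering targets: a Pallas TPU kernel converting bfloat16 values to int4. The fast path applies on TPU v6 and later; jnp.int4 support in jax/jaxlib is assumed.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def cast_kernel(x_ref, y_ref):
  # Inside Mosaic this float->int cast lowers to a fptosi op; on v6 the
  # bf16->s4 case now takes the optimized code path added by this commit.
  y_ref[...] = x_ref[...].astype(jnp.int4)

x = jnp.full((8, 128), 3.75, dtype=jnp.bfloat16)
y = pl.pallas_call(
    cast_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int4),
)(x)  # expected: an (8, 128) int4 array of 3s (truncation toward zero)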
4 changes: 3 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -561,9 +561,11 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
if (dst_bitwidth > 32) {
return op.emitOpError("Target bitwidth too large");
}
// We have low-level optimized code for bf16->s8 and bf16->s4 casts on v6.
if (ctx.hardware_generation >= 6 && is_vector &&
src_vty.getElementType().isBF16() &&
dst_vty.getElementType().isSignlessInteger(8)) {
(dst_vty.getElementType().isSignlessInteger(8) ||
dst_vty.getElementType().isSignlessInteger(4))) {
auto new_op = builder.create<tpu::FPToSIOp>(
op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero);
op.replaceAllUsesWith(new_op.getResult());
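The check above routes bf16->s8 and, with this change, bf16->s4 vector casts on TPU v6 through tpu::FPToSIOp with tpu::RoundingMode::kTowardsZero. As a quick reference for what that rounding mode means at the JAX level (a sketch, not part of the commit):

import jax.numpy as jnp

# Float -> signed-int casts truncate toward zero, so the fractional part is
# dropped regardless of sign.
x = jnp.array([1.9, -1.9, 0.6, -0.6], dtype=jnp.bfloat16)
print(x.astype(jnp.int8))  # [ 1 -1  0  0]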
126 changes: 95 additions & 31 deletions tests/pallas/ops_test.py
@@ -85,19 +85,23 @@ def _random_value(key: jax.Array, shape_dtype: jax.ShapeDtypeStruct
raise NotImplementedError(shape_dtype)


# TODO(apaszke): Add 8-bit floats.
# TODO(apaszke): Add int4.
_DTYPES = (
_DTYPES_32BIT = (
"float32",
"bfloat16",
"int32",
"uint32",
)
# TODO(apaszke): Add 8-bit floats.
_DTYPES_SUB_32BIT = (
"bfloat16",
"int16",
"int8",
"uint32",
"int4",
"uint16",
"uint8",
"uint4",
"bool",
)
_DTYPES = (*_DTYPES_32BIT, *_DTYPES_SUB_32BIT)


@hps.composite
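An aside, not part of the diff: the tests below draw random inputs with Hypothesis' NumPy extra. hnp.arrays builds a strategy for arrays of a given dtype and shape, and the elements= mapping is forwarded to from_dtype (for example to disallow NaNs when the target of the cast is an integer). A minimal sketch:

import hypothesis.extra.numpy as hnp
import numpy as np

# Strategy for (8, 128) float32 arrays whose elements never contain NaN.
strategy = hnp.arrays(np.float32, (8, 128), elements=dict(allow_nan=False))
sample = strategy.example()  # draw one concrete array outside a test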
@@ -543,16 +547,60 @@ def kernel(x_ref, y_ref):
out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x)
self.assertAllClose(out, func(x), atol=tol, rtol=tol)

@parameterized.product(from_dtype=_DTYPES, to_dtype=_DTYPES)
@parameterized.product(from_dtype=_DTYPES_32BIT, to_dtype=_DTYPES)
@hp.given(hps.data())
def test_cast(self, from_dtype, to_dtype, data):
def test_cast_from_32bit(self, from_dtype, to_dtype, data):
if from_dtype == to_dtype:
self.skipTest("Unnecessary test")
if jtu.is_device_tpu(version=4):
if to_dtype in {"int8", "uint8", "int4", "uint4"}:
self.skipTest("Not supported on this TPU generation")
if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18):
self.skipTest("Test requires libtpu from 2025/1/18 or later")
if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4:
# Currently only casts between 32-bit types and to bf16 are supported.
if to_dtype not in {"int32", "uint32", "float32", "bfloat16"}:
self.skipTest("Not supported on this TPU generation")
if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}:
self.skipTest("int4/uint4 casts are buggy on GPU") # b/391292861

# XLA does not specify the float->int conversion result for NaNs.
elements = dict(allow_nan=not jnp.issubdtype(to_dtype, jnp.integer))
x = data.draw(hnp.arrays(from_dtype, (8, 128), elements=elements))
x = jnp.asarray(x)
def kernel(x_ref, y_ref):
x = x_ref[...]
y = x.astype(to_dtype)
if to_dtype == jnp.dtype("bool"):
y = y.astype(jnp.int32)
y_ref[...] = y
y_dtype = jnp.int32 if to_dtype == jnp.dtype("bool") else to_dtype
try:
y = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct(x.shape, y_dtype))(x)
except Exception as e:
if "Unsupported cast" in e.args[0]:
self.skipTest("Unsupported cast")
raise
if to_dtype == jnp.dtype("bool"):
y = y.astype(jnp.dtype("bool"))
y_ref = x.astype(to_dtype)
if to_dtype == jnp.bfloat16:
y, y_ref = y.astype(np.float32), y_ref.astype(np.float32)
np.testing.assert_array_equal(y, y_ref)

# Types narrower than 32-bit have few values so we test them exhaustively.
# We also take one more pass with random data just to ensure that we don't
# miss bugs that would be hidden due to exhaustive enumeration being in order.
@parameterized.product(from_dtype=_DTYPES_SUB_32BIT, to_dtype=_DTYPES, randomize=(False, True))
def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize):
if from_dtype == to_dtype:
self.skipTest("Unnecessary test")
if jtu.is_device_tpu(version=4):
allowed_v4_cats = {("int16", "int32"): (2025, 1, 18)}
if (
from_dtype in {"int16", "int8", "uint16", "uint8"}
or to_dtype in {"int8", "uint8"}
from_dtype in {"int16", "int8", "uint16", "uint8", "int4", "uint4"}
or to_dtype in {"int8", "uint8", "int4", "uint4"}
) and (from_dtype, to_dtype) not in allowed_v4_cats:
self.skipTest("Not supported on this TPU generation")
if minimum_libtpu_date := allowed_v4_cats.get((from_dtype, to_dtype), None):
@@ -561,10 +609,9 @@ def test_cast(self, from_dtype, to_dtype, data):
if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18):
self.skipTest("Test requires libtpu from 2025/1/18 or later")
if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4:
# Currently only casts between 32-bit types and to bf16 are supported.
if (from_dtype not in {"int32", "uint32", "float32"} or
to_dtype not in {"int32", "uint32", "float32", "bfloat16"}):
self.skipTest("Not supported on this TPU generation")
self.skipTest("Not supported on this TPU generation")
if jtu.test_device_matches(["gpu"]) and to_dtype in {"int4", "uint4"}:
self.skipTest("int4/uint4 casts are buggy on GPU") # b/391292861

from_int = np.issubdtype(np.dtype(from_dtype), np.integer)
to_int = np.issubdtype(np.dtype(to_dtype), np.integer)
@@ -575,24 +622,42 @@ def test_cast(self, from_dtype, to_dtype, data):
self.skipTest("trunc from non-32 bit only implemented recently")

# TODO(sharadmv,apaszke): add support for the following casts
if from_dtype == "bool" and to_dtype in {"int16", "int8", "uint16", "uint8"}:
if (from_dtype == "bool" and
to_dtype in {"int16", "int8", "int4", "uint16", "uint8", "uint4"}):
self.skipTest("Not supported: cannot extend to sub-32 bit types")

if from_dtype == "bfloat16":
from_dtype = jnp.bfloat16
if to_dtype == "bfloat16":
to_dtype = jnp.bfloat16

# XLA does not specify the float->int conversion result for NaNs.
elements = dict(allow_nan=not jnp.issubdtype(to_dtype, jnp.integer))
if from_dtype == jnp.bfloat16:
x = jnp.asarray(
data.draw(hnp.arrays(jnp.float32, (8, 128), elements=elements))
)
x = x.astype(jnp.bfloat16)
def bitwidth(dtype):
if jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).bits
elif jnp.issubdtype(dtype, jnp.floating):
return jnp.finfo(dtype).bits
else:
raise ValueError(f"Unsupported dtype: {dtype}")

if from_dtype != "bool":
from_bitwidth = bitwidth(from_dtype)
from_int_dtype = getattr(jnp, "uint" + str(from_bitwidth))
if randomize:
# randint has no support for 4 bit integers.
shape = (128, 128)
rand_int_dtype = getattr(jnp, "uint" + str(max(8, from_bitwidth)))
x = random.randint(
random.key(1234), shape, 0, 1 << from_bitwidth, rand_int_dtype
).astype(from_int_dtype)
x = lax.bitcast_convert_type(x, from_dtype)
else:
x = jax.lax.bitcast_convert_type(
jnp.arange(1 << from_bitwidth, dtype=from_int_dtype), from_dtype
).reshape(8, -1)
else:
x = data.draw(hnp.arrays(from_dtype, (8, 128), elements=elements))
x = jnp.asarray(x)
if randomize:
x = random.randint(random.key(234), (16, 16), 0, 1, jnp.int32) != 0
else:
x = jnp.asarray([[False, True], [True, False]], dtype="bool")
assert x.dtype == jnp.dtype(from_dtype)
# XLA does not specify the float->int conversion result for NaNs.
if jnp.issubdtype(from_dtype, jnp.floating):
x = x.at[jnp.isnan(x)].set(0)
if from_dtype == jnp.dtype("bool"):
x = x.astype(jnp.int32)
def kernel(x_ref, y_ref):
@@ -603,8 +668,7 @@ def kernel(x_ref, y_ref):
if to_dtype == jnp.dtype("bool"):
y = y.astype(jnp.int32)
y_ref[...] = y
if (y_dtype := to_dtype) == jnp.dtype("bool"):
y_dtype = jnp.int32
y_dtype = jnp.int32 if to_dtype == jnp.dtype("bool") else to_dtype
try:
y = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct(x.shape, y_dtype))(x)
@@ -617,7 +681,7 @@ def kernel(x_ref, y_ref):
y_ref = x.astype(to_dtype)
if to_dtype == jnp.bfloat16:
y, y_ref = y.astype(np.float32), y_ref.astype(np.float32)
np.testing.assert_allclose(y, y_ref, atol=0., rtol=0.)
np.testing.assert_array_equal(y, y_ref)

@parameterized.parameters(
jnp.bfloat16,
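A standalone sketch (not part of the diff) of the exhaustive-enumeration trick that test_cast_from_sub_32bit relies on: every bit pattern of a narrow dtype can be produced by counting through the matching unsigned integer type and bitcasting, so the cast gets checked over the full value range.

import jax.numpy as jnp
from jax import lax

bits = 8
# All 2**8 bit patterns, reinterpreted as int8: every representable int8 value.
patterns = jnp.arange(1 << bits, dtype=jnp.uint8)
all_int8 = lax.bitcast_convert_type(patterns, jnp.int8)
casted = all_int8.astype(jnp.bfloat16)  # the cast under test, applied exhaustively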
