diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index fcba0518620b..2279df4f3984 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -108,21 +108,18 @@ def c(val: int | float, ty): return arith.constant(ty, attr) def _debug_scalar_ty_format(arg): - ty_format = None if ir.IndexType.isinstance(arg.type): - return "%llu" + return "%llu", arg if ir.IntegerType.isinstance(arg.type): - width = ir.IntegerType(arg.type).width - ty_format = "%llu" - if width < 64: + if ir.IntegerType(arg.type).width < 64: arg = arith.extui(ir.IntegerType.get_signless(64), arg) + return "%llu", arg if ir.F32Type.isinstance(arg.type): - ty_format = "%f" + return "%f", arg if ir.F16Type.isinstance(arg.type): - ty_format = "%f" arg = arith.extf(ir.F32Type.get(), arg) - - return ty_format, arg + return "%f", arg + raise NotImplementedError(f"Can't print the type {arg.type}") def debug_print(fmt, *args, uniform=True): type_formats = []