diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 0ce1140cfa07..fcba0518620b 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -107,28 +107,46 @@ def c(val: int | float, ty): raise NotImplementedError(ty) return arith.constant(ty, attr) +def _debug_scalar_ty_format(arg): + ty_format = None + if ir.IndexType.isinstance(arg.type): + return "%llu" + if ir.IntegerType.isinstance(arg.type): + width = ir.IntegerType(arg.type).width + ty_format = "%llu" + if width < 64: + arg = arith.extui(ir.IntegerType.get_signless(64), arg) + if ir.F32Type.isinstance(arg.type): + ty_format = "%f" + if ir.F16Type.isinstance(arg.type): + ty_format = "%f" + arg = arith.extf(ir.F32Type.get(), arg) + + return ty_format, arg def debug_print(fmt, *args, uniform=True): type_formats = [] new_args = [] for arg in args: - ty_format = None - if ir.IndexType.isinstance(arg.type): - ty_format = "%llu" - if ir.IntegerType.isinstance(arg.type): - width = ir.IntegerType(arg.type).width - ty_format = "%llu" - if width < 64: - arg = arith.extui(ir.IntegerType.get_signless(64), arg) - if ir.F32Type.isinstance(arg.type): - ty_format = "%f" - if ir.F16Type.isinstance(arg.type): - ty_format = "%f" - arg = arith.extf(ir.F32Type.get(), arg) + if ir.VectorType.isinstance(arg.type): + index = ir.IndexType.get() + vec_ty = ir.VectorType(arg.type) + if len(vec_ty.shape) > 1: + raise NotImplementedError(vec_ty) + vec_args = [ + vector.extractelement(arg, position=c(i, index)) + for i in range(vec_ty.shape[0]) + ] + ty_formats, args = zip(*map(_debug_scalar_ty_format,vec_args)) + ty_format = f"[{','.join(ty_formats)}]" + new_args += args + else: + ty_format, arg = _debug_scalar_ty_format(arg) + new_args.append(arg) + if ty_format is None: raise NotImplementedError(arg.type) type_formats.append(ty_format) - new_args.append(arg) ctx = ( functools.partial(single_thread, per_block=False) if uniform