Skip to content

Commit

Permalink
[mgpu] Debug print for mlir vectors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700714031
  • Loading branch information
cperivol authored and Google-ML-Automation committed Nov 27, 2024
1 parent d449f12 commit df8ecb9
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions jax/experimental/mosaic/gpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,28 +107,46 @@ def c(val: int | float, ty):
raise NotImplementedError(ty)
return arith.constant(ty, attr)

def _debug_scalar_ty_format(arg):
ty_format = None
if ir.IndexType.isinstance(arg.type):
return "%llu"
if ir.IntegerType.isinstance(arg.type):
width = ir.IntegerType(arg.type).width
ty_format = "%llu"
if width < 64:
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
if ir.F32Type.isinstance(arg.type):
ty_format = "%f"
if ir.F16Type.isinstance(arg.type):
ty_format = "%f"
arg = arith.extf(ir.F32Type.get(), arg)

return ty_format, arg

def debug_print(fmt, *args, uniform=True):
type_formats = []
new_args = []
for arg in args:
ty_format = None
if ir.IndexType.isinstance(arg.type):
ty_format = "%llu"
if ir.IntegerType.isinstance(arg.type):
width = ir.IntegerType(arg.type).width
ty_format = "%llu"
if width < 64:
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
if ir.F32Type.isinstance(arg.type):
ty_format = "%f"
if ir.F16Type.isinstance(arg.type):
ty_format = "%f"
arg = arith.extf(ir.F32Type.get(), arg)
if ir.VectorType.isinstance(arg.type):
index = ir.IndexType.get()
vec_ty = ir.VectorType(arg.type)
if len(vec_ty.shape) > 1:
raise NotImplementedError(vec_ty)
vec_args = [
vector.extractelement(arg, position=c(i, index))
for i in range(vec_ty.shape[0])
]
ty_formats, args = zip(*map(_debug_scalar_ty_format,vec_args))
ty_format = f"[{','.join(ty_formats)}]"
new_args += args
else:
ty_format, arg = _debug_scalar_ty_format(arg)
new_args.append(arg)

if ty_format is None:
raise NotImplementedError(arg.type)
type_formats.append(ty_format)
new_args.append(arg)
ctx = (
functools.partial(single_thread, per_block=False)
if uniform
Expand Down

0 comments on commit df8ecb9

Please sign in to comment.