Skip to content

Commit

Permalink
[mgpu] Added a missed case for debug_print types and raise a proper e…
Browse files Browse the repository at this point in the history
…rror if a type is unexpected.

PiperOrigin-RevId: 701003002
  • Loading branch information
cperivol authored and Google-ML-Automation committed Nov 28, 2024
1 parent 14ddb81 commit d5bfafb
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions jax/experimental/mosaic/gpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,18 @@ def c(val: int | float, ty):
return arith.constant(ty, attr)

def _debug_scalar_ty_format(arg):
ty_format = None
if ir.IndexType.isinstance(arg.type):
return "%llu"
return "%llu", arg
if ir.IntegerType.isinstance(arg.type):
width = ir.IntegerType(arg.type).width
ty_format = "%llu"
if width < 64:
if ir.IntegerType(arg.type).width < 64:
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
return "%llu", arg
if ir.F32Type.isinstance(arg.type):
ty_format = "%f"
return "%f", arg
if ir.F16Type.isinstance(arg.type):
ty_format = "%f"
arg = arith.extf(ir.F32Type.get(), arg)

return ty_format, arg
return "%f", arg
raise NotImplementedError(f"Can't print the type {arg.type}")

def debug_print(fmt, *args, uniform=True):
type_formats = []
Expand Down

0 comments on commit d5bfafb

Please sign in to comment.