Skip to content

Commit

Permalink
[Mosaic][Easy] - Wire up kernel names to MLIR dump
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699408419
  • Loading branch information
Google-ML-Automation committed Nov 23, 2024
1 parent b259fde commit e53ff2c
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def _lower_tpu_kernel(
module: ir.Module,
hardware_generation: int,
target_shape: tuple[int, int],
kernel_name: str | None = None,
) -> ir.Module:
"""Runs MLIR passes lowering the given module to an MLIR module.
Expand All @@ -303,8 +304,7 @@ def _lower_tpu_kernel(
tpu.register_dialect(ctx)
mhlo.register_mhlo_dialect(ctx)
mhlo.register_mhlo_passes()

dump_mlir(module, "original")
dump_mlir(module, "original", kernel_name)

if _MOSAIC_ALLOW_HLO.value:
# Run hlo dialect conversion: hlo -> linalg -> vector.
Expand Down Expand Up @@ -406,6 +406,7 @@ def _lower_mosaic_module_to_asm(
*,
backend: str,
device_type: str | None,
kernel_name: str | None,
) -> tuple[ir.Module, tuple[bool, bool, bool, bool]]:
has_communication, has_custom_barrier = tpu.private_has_communication(
module.operation
Expand All @@ -429,7 +430,7 @@ def _lower_mosaic_module_to_asm(
hardware_generation = int(device_kind[len("TPU v")])
target_shape = get_target_shape(hardware_generation)
module = _lower_tpu_kernel(
module, hardware_generation, target_shape=target_shape
module, hardware_generation, target_shape=target_shape, kernel_name=kernel_name,
)
needs_hlo_passes = False
needs_layout_passes = False
Expand Down Expand Up @@ -504,6 +505,7 @@ def _lower_to_custom_call_config(
collective_id: int | None,
serialization_format: int | None,
output_memory_spaces: tuple[MemorySpace | None, ...] | None = None,
kernel_name: str | None = None,
) -> CustomCallBackendConfig:
lowered_module_asm, (
has_communication,
Expand All @@ -514,6 +516,7 @@ def _lower_to_custom_call_config(
module,
backend=backend,
device_type=device_type,
kernel_name=kernel_name,
)
return _lowered_to_custom_call_config(
lowered_module_asm,
Expand Down Expand Up @@ -613,6 +616,7 @@ def lower_module_to_custom_call(
device_type=device_type,
serialization_format=serialization_format,
output_memory_spaces=output_memory_spaces,
kernel_name=kernel_name,
)
return _tpu_custom_call_lowering(
ctx,
Expand Down Expand Up @@ -654,6 +658,7 @@ def as_tpu_kernel(
collective_id=collective_id,
serialization_format=serialization_format,
output_memory_spaces=output_memory_spaces,
kernel_name=kernel_name,
)
return _as_jax_callable(
config,
Expand Down Expand Up @@ -735,7 +740,7 @@ def apply_kernel(*args):
return jax.jit(apply_kernel)


def dump_mlir(module: ir.Module, name: str):
def dump_mlir(module: ir.Module, name: str, kernel_name: str | None = None):
"""A helper function to dump mosaic mlir module"""
try:
should_dump = FLAGS["xla_mosaic_dump_to"].value
Expand All @@ -744,6 +749,8 @@ def dump_mlir(module: ir.Module, name: str):
if should_dump == "sponge":
outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None)
if outdir:
if kernel_name:
name = f"{kernel_name}-{name}"
path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}-py.txt")
with open(path, "w") as f:
f.write(str(module))

0 comments on commit e53ff2c

Please sign in to comment.