From df36c29803f0a030b2115f45996fdf5eef1c3bb5 Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Thu, 2 Jan 2025 11:24:04 -0800 Subject: [PATCH] Compute cost-analysis on only one HLO module. There was historically a goal to support multiple HLOs in an executable, but this work was never finished and is no longer planned so we don't need this support. This will soon enable us to return only a dict, instead of a list of dicts with only one item. PiperOrigin-RevId: 711477481 --- jax/_src/stages.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 10963330ff92..2720ca196012 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -249,8 +249,8 @@ def as_text(self) -> str: else: raise - # TODO(skyewm): this should return a single dict (I think returning a list - # was to support MPMD executables, which never fully landed) + # TODO(b/384741132): this should return a single dict (I think returning a list + # was to support MPMD executables, which never fully landed). def cost_analysis(self) -> list[dict[str, float]]: xla_ext_exe = self.xla_extension_executable() @@ -266,9 +266,19 @@ def cost_analysis(self) -> list[dict[str, float]]: # Try client method if executable cost_analysis method is unimplemented if hasattr(xla_ext_exe, "client"): try: + # TODO(b/384741132): We expect that the executable has only one + # HloModule. We should be able to remove this check once we update the + # Executable class to have only a single HloModule (see bug). + hlo_modules = xla_ext_exe.hlo_modules() + assert len(hlo_modules) == 1, ( + f"Exectuable should have only one HloModule ({len(hlo_modules)})" + " were found)." + ) + return [ - xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m) - for m in xla_ext_exe.hlo_modules() + xla_extension.hlo_module_cost_analysis( + xla_ext_exe.client, hlo_modules[0] + ) ] except xla_extension.XlaRuntimeError as e: msg, *_ = e.args