diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 10963330ff92..2720ca196012 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -249,8 +249,8 @@ def as_text(self) -> str: else: raise - # TODO(skyewm): this should return a single dict (I think returning a list - # was to support MPMD executables, which never fully landed) + # TODO(b/384741132): this should return a single dict (I think returning a list + # was to support MPMD executables, which never fully landed). def cost_analysis(self) -> list[dict[str, float]]: xla_ext_exe = self.xla_extension_executable() @@ -266,9 +266,19 @@ def cost_analysis(self) -> list[dict[str, float]]: # Try client method if executable cost_analysis method is unimplemented if hasattr(xla_ext_exe, "client"): try: + # TODO(b/384741132): We expect that the executable has only one + # HloModule. We should be able to remove this check once we update the + # Executable class to have only a single HloModule (see bug). + hlo_modules = xla_ext_exe.hlo_modules() + assert len(hlo_modules) == 1, ( + f"Exectuable should have only one HloModule ({len(hlo_modules)})" + " were found)." + ) + return [ - xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m) - for m in xla_ext_exe.hlo_modules() + xla_extension.hlo_module_cost_analysis( + xla_ext_exe.client, hlo_modules[0] + ) ] except xla_extension.XlaRuntimeError as e: msg, *_ = e.args