From 794963de33c55e8e543e68876ae6afb6f0d9bb7c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 14 Jan 2025 07:50:47 -0800 Subject: [PATCH] [pallas:triton] The lowering now uses PTX instead of Triton IR This change improves the stability and backward compatibility of Pallas Triton calls, because unlike PTX, the Triton dialect has no stability guarantees and does change in practice. A few notes * The implementation only supports CUDA at the moment. More work is needed to support ROCm. * Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead, compilation is done via a new PjRt extension, which uses its own compilation pipeline mirrored after the one in the Triton Python bindings. * The implementation of the old custom call used by Pallas Triton is deprecated and will be removed after 6 months as per [compatibility guarantees] [*] [*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees PiperOrigin-RevId: 715379979 --- jax/_src/lib/triton.py | 58 +++++++++++++++ .../pallas/triton/pallas_call_registration.py | 73 ++++++++++++------- jax_plugins/cuda/__init__.py | 7 ++ jaxlib/BUILD | 3 +- jaxlib/gpu_plugin_extension.cc | 63 +++++++++++++++- 5 files changed, 176 insertions(+), 28 deletions(-) diff --git a/jax/_src/lib/triton.py b/jax/_src/lib/triton.py index c0a5202e9dbc..74ff0843f3dc 100644 --- a/jax/_src/lib/triton.py +++ b/jax/_src/lib/triton.py @@ -12,4 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading +from typing import Protocol + from jaxlib.triton import dialect # noqa: F401 # pytype: disable=import-error + + +class CompilationResult(Protocol): + asm: bytes + smem_bytes: int + cluster_dim_x: int + cluster_dim_y: int + cluster_dim_z: int + + +class CompilationHandler(Protocol): + + def __call__( + self, + module: bytes, + arch_name: str, + num_warps: int, + num_ctas: int, + num_stages: int, + ) -> CompilationResult: + ... + + +_compilation_handlers: dict[str, CompilationHandler] = {} +_compilation_handlers_lock = threading.Lock() + + +def register_compilation_handler( + platform: str, handler: CompilationHandler +) -> None: + with _compilation_handlers_lock: + if existing_handler := _compilation_handlers.get(platform): + raise RuntimeError( + f'Platform {platform} already has a Triton compilation handler:' + f' {existing_handler}' + ) + _compilation_handlers[platform] = handler + + +def compile( + platform: str, + module: bytes, + arch_name: str, + *, + num_warps: int, + num_ctas: int, + num_stages: int, +) -> CompilationResult: + with _compilation_handlers_lock: + handler = _compilation_handlers.get(platform) + if handler is None: + raise RuntimeError( + f'Platform {platform} does not have a Triton compilation handler' + ) + return handler(module, arch_name, num_warps, num_ctas, num_stages) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 59b1b86f33fc..9abe6770789f 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -17,11 +17,15 @@ from __future__ import annotations import io +import re from typing import Any +import zlib +import jax import jax._src.core as jax_core from jax._src.interpreters import mlir -from jax._src.lib.mlir import ir +from jax._src.lib import gpu_triton as triton_kernel_call_lib +from jax._src.lib import triton from jax._src.pallas import core as pallas_core from jax._src.pallas.triton import lowering @@ -51,7 +55,7 @@ def pallas_call_lowering( cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], ): - del interpret, out_avals + del interpret, cost_estimate, out_avals if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" @@ -77,6 +81,11 @@ def pallas_call_lowering( print("The grid mapping for pallas_call {name_and_src_info}:") print(grid_mapping) + # Sanitize the name to conform to NVPTX requirements. We do this here + # to avoid the need to fetch the new name from PTX post compilation. + name_and_src_info = name_and_src_info.replace( + name=re.sub(r"[^a-zA-Z0-9_$]", "_", name_and_src_info.name) + ) lowering_result = lowering.lower_jaxpr_to_triton_module( jaxpr, grid_mapping, name_and_src_info, lowering_platform ) @@ -86,35 +95,47 @@ def pallas_call_lowering( print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True)) grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid) - out_types = [ - ir.RankedTensorType.get(bm.array_shape_dtype.shape, - mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype)) - for bm in grid_mapping.block_mappings_output - ] buf = io.BytesIO() module_op.write_bytecode(buf) - backend_config = dict( - name=ir.StringAttr.get(name_and_src_info.name), - ir=ir.StringAttr.get(buf.getvalue()), - num_stages=mlir.i32_attr(num_stages), - num_warps=mlir.i32_attr(num_warps), - grid_x=mlir.i32_attr(grid_x), - grid_y=mlir.i32_attr(grid_y), - grid_z=mlir.i32_attr(grid_z), - debug=ir.BoolAttr.get(debug), + gpu_device, *_ = jax.local_devices(backend="gpu") + compilation_result = triton.compile( + lowering_platform.upper(), + buf.getvalue(), + str(gpu_device.compute_capability), # e.g. 7.0 + num_warps=num_warps, + num_ctas=1, + num_stages=num_stages, + ) + kernel = triton_kernel_call_lib.TritonKernel( + name_and_src_info.name, + num_warps, + compilation_result.smem_bytes, + compilation_result.asm, + module_op.get_asm(enable_debug_info=True, pretty_debug_info=True), + triton_kernel_call_lib.get_compute_capability(0), + compilation_result.cluster_dim_x, + compilation_result.cluster_dim_y, + compilation_result.cluster_dim_z, + ) + kernel_call = triton_kernel_call_lib.TritonKernelCall( + kernel, + grid_x, + grid_y, + grid_z, + [triton_kernel_call_lib.create_array_parameter(0, 16)] + * (len(ctx.avals_in) + len(ctx.avals_out)), ) - if "serialized_metadata" in (triton_params or {}): - # This field is unstable and may be removed in the future. - if triton_params["serialized_metadata"] is not None: - backend_config["serialized_metadata"] = ir.StringAttr.get( - triton_params["serialized_metadata"] - ) + # TODO(slebedev): Migrate to ``jax.ffi``. return mlir.custom_call( - call_target_name="__gpu$xla.gpu.triton", - result_types=out_types, + call_target_name="triton_kernel_call", + result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], # type: ignore[list-item] operands=in_nodes, - backend_config=backend_config, - api_version=4, + backend_config=zlib.compress( + kernel_call.to_proto( + name_and_src_info.name, + triton_params.get("serialized_metadata") or b"", + ) + ), operand_layouts=avals_to_layouts(ctx.avals_in), result_layouts=avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index a09e21c6dd77..68281f4f32b3 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -18,6 +18,7 @@ import os import pathlib +from jax._src.lib import triton from jax._src.lib import xla_client import jax._src.xla_bridge as xb @@ -99,5 +100,11 @@ def initialize(): cuda_plugin_extension.register_custom_type_id, c_api ), ) + triton.register_compilation_handler( + "CUDA", + functools.partial( + cuda_plugin_extension.compile_triton_to_asm, c_api + ), + ) else: logger.warning('cuda_plugin_extension is not found.') diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 843ccb112871..a5680920d808 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -234,8 +234,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@nanobind", - "@tsl//tsl/platform:statusor", "@xla//xla:util", "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt:status_casters", @@ -243,6 +243,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", "@xla//xla/python:py_client_gpu", "@xla//xla/tsl/python/lib/core:numpy", ], diff --git a/jaxlib/gpu_plugin_extension.cc b/jaxlib/gpu_plugin_extension.cc index 46263bdcd40c..8863ebf19b39 100644 --- a/jaxlib/gpu_plugin_extension.cc +++ b/jaxlib/gpu_plugin_extension.cc @@ -16,23 +16,28 @@ limitations under the License. #include "jaxlib/gpu_plugin_extension.h" #include +#include +#include #include #include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" #include "xla/python/py_client_gpu.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" namespace nb = nanobind; @@ -40,6 +45,44 @@ namespace xla { namespace { +struct TritonCompilationResult { + std::string asm_text; + int64_t smem_bytes; + int cluster_dim_x; + int cluster_dim_y; + int cluster_dim_z; +}; + +absl::StatusOr CompileTritonToASM( + const PJRT_Api* c_api, absl::string_view module, + absl::string_view arch_name, int num_warps, int num_ctas, int num_stages) { + const PJRT_Triton_Extension* triton_ext = + pjrt::FindExtension( + c_api, PJRT_Extension_Type::PJRT_Extension_Type_Triton); + if (triton_ext == nullptr) { + return Unimplemented("The plugin does not have a Triton extension."); + } + PJRT_Triton_Compile_Args args; + args.struct_size = PJRT_Triton_Compile_Args_STRUCT_SIZE; + args.module = module.data(); + args.module_size = module.size(); + args.arch_name = arch_name.data(); + args.arch_name_size = arch_name.size(); + args.num_warps = num_warps; + args.num_ctas = num_ctas; + args.num_stages = num_stages; + RETURN_STATUS_IF_PJRT_ERROR(triton_ext->compile(&args), c_api); + auto asm_text = std::string(args.out_asm, args.out_asm_size); + delete[] args.out_asm; + return TritonCompilationResult{ + .asm_text = std::string(args.out_asm, args.out_asm_size), + .smem_bytes = args.out_smem_bytes, + .cluster_dim_x = args.out_cluster_dim_x, + .cluster_dim_y = args.out_cluster_dim_y, + .cluster_dim_z = args.out_cluster_dim_z, + }; +} + absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, const char* fn_name_c_str, size_t fn_name_size, nb::object fn, @@ -170,6 +213,24 @@ nb::dict Registrations() { void BuildGpuPluginExtension(nanobind::module_& m) { tsl::ImportNumpy(); + + nb::class_(m, "TritonCompilationResult") + .def_ro("asm", &TritonCompilationResult::asm_text) + .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes) + .def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x) + .def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y) + .def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z); + + m.def("compile_triton_to_asm", + [](nb::capsule c_api, nb::bytes module, absl::string_view arch_name, + int num_warps, int num_ctas, int num_stages) { + return xla::ValueOrThrow(CompileTritonToASM( + static_cast(c_api.data()), + absl::string_view(static_cast(module.data()), + module.size()), + arch_name, num_warps, num_ctas, num_stages)); + }); + m.def( "register_custom_call_target", [](nb::capsule c_api, nb::object fn_name_py, nb::object fn,