diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index b170822ea873..c333f5e65877 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -59,7 +59,7 @@ import numpy as np
 
 from triton._C.libtriton.triton import ir as tl_ir
 from triton.compiler import code_generator as code_gen
-from triton.compiler import compiler as tc
+import triton.compiler.backends.cuda as cb
 import triton.language as tl
 
 # TODO(sharadmv): enable type checking
@@ -229,21 +229,13 @@ def _process_grid_to_3d_grid(builder, grid_mapping: GridMapping):
 
 def lower_jaxpr_to_triton_module(
     jaxpr: jax_core.Jaxpr, in_shapes, grid_mapping: GridMapping, name: str,
-    num_warps: int
+    cuda_options: cb.CUDAOptions
 ) -> tl_ir.module:
   jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
   ir_context = tl_ir.context()
   ir_context.load_triton()
   builder = tl_ir.builder(ir_context)
-  # TODO(sharadmv): handle multiple devices, right now we assume device 0
-  # which is fine when we have multiple of the same GPU but this won't work in
-  # general.
-  device = 0
-  builder.target = tc.CudaTargetDescriptor(
-      capability=triton_kernel_call_lib.get_compute_capability(device),
-      num_warps=num_warps,
-      enable_fp_fusion=True
-  )
+  builder.options = cuda_options
   module = builder.create_module()
   in_avals = [var.aval for var in jaxpr.invars]
   triton_types = [get_triton_type(x) for x in in_avals]
@@ -1491,17 +1483,31 @@ def compile_jaxpr(
     num_stages: int,
     debug: bool,
 ) -> TritonCompilationResult:
+  # TODO(sharadmv): handle multiple devices, right now we assume device 0
+  # which is fine when we have multiple of the same GPU but this won't work in
+  # general.
+  device = 0
+  arch = triton_kernel_call_lib.get_compute_capability(device)
+  target = ("cuda", arch)
+  cuda_backend = cb.CUDABackend(target)
+  cuda_options = cuda_backend.parse_compiler_options(
+      dict(
+          num_warps=num_warps,
+          num_stages=num_stages,
+          debug=debug,
+      )
+  )
+
   lowering_result = lower_jaxpr_to_triton_module(
-      jaxpr, in_shapes, grid_mapping, name, num_warps
+      jaxpr, in_shapes, grid_mapping, name, cuda_options
   )
-  device = 0
+
   ttir = str(lowering_result.module)
   ptx, name, shared_mem_bytes, compute_capability = compile_ttir_to_ptx_inplace(
       lowering_result.module,
+      cuda_backend,
+      cuda_options,
       device=device,
-      num_warps=num_warps,
-      num_stages=num_stages,
-      dump=debug,
   )
   return TritonCompilationResult(
       name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result
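
Note (not part of the patch): the change above swaps Triton's old `tc.CudaTargetDescriptor` for the newer backend objects in `triton.compiler.backends.cuda`, and threads a single `CUDAOptions` value through both lowering (`builder.options = cuda_options`) and PTX compilation. The minimal sketch below, assuming the in-development Triton API this patch targets, isolates that options-building step; `arch` stands in for the value `triton_kernel_call_lib.get_compute_capability(device)` returns in the diff, and `make_cuda_options` is a hypothetical helper used only for illustration.

```python
# Sketch only: mirrors the flow introduced in compile_jaxpr above, assuming
# the Triton development API (triton.compiler.backends.cuda) used by this
# change is importable.
import triton.compiler.backends.cuda as cb


def make_cuda_options(arch: int, num_warps: int, num_stages: int,
                      debug: bool) -> cb.CUDAOptions:
  # The backend is constructed from a ("cuda", compute_capability) target,
  # replacing the old per-builder CudaTargetDescriptor.
  backend = cb.CUDABackend(("cuda", arch))
  # parse_compiler_options turns a plain dict of knobs into a CUDAOptions
  # value; the same object is later attached to the MLIR builder and passed
  # to compile_ttir_to_ptx_inplace, so num_warps/num_stages/debug no longer
  # travel as separate arguments.
  return backend.parse_compiler_options(
      dict(num_warps=num_warps, num_stages=num_stages, debug=debug)
  )


# Example: options for an sm_80 device with 4 warps and 3 pipeline stages.
options = make_cuda_options(arch=80, num_warps=4, num_stages=3, debug=False)
```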