
Commit

Import new version of Triton
PiperOrigin-RevId: 590496513
jax authors committed Dec 13, 2023
1 parent 4183f29 commit 4991aac
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions jax/_src/pallas/triton/lowering.py
@@ -59,7 +59,7 @@
 import numpy as np
 from triton._C.libtriton.triton import ir as tl_ir
 from triton.compiler import code_generator as code_gen
-from triton.compiler import compiler as tc
+import triton.compiler.backends.cuda as cb
 import triton.language as tl

 # TODO(sharadmv): enable type checking
@@ -229,21 +229,13 @@ def _process_grid_to_3d_grid(builder, grid_mapping: GridMapping):

 def lower_jaxpr_to_triton_module(
     jaxpr: jax_core.Jaxpr, in_shapes, grid_mapping: GridMapping, name: str,
-    num_warps: int
+    cuda_options: cb.CUDAOptions
 ) -> tl_ir.module:
   jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
   ir_context = tl_ir.context()
   ir_context.load_triton()
   builder = tl_ir.builder(ir_context)
-  # TODO(sharadmv): handle multiple devices, right now we assume device 0
-  # which is fine when we have multiple of the same GPU but this won't work in
-  # general.
-  device = 0
-  builder.target = tc.CudaTargetDescriptor(
-      capability=triton_kernel_call_lib.get_compute_capability(device),
-      num_warps=num_warps,
-      enable_fp_fusion=True
-  )
+  builder.options = cuda_options
   module = builder.create_module()
   in_avals = [var.aval for var in jaxpr.invars]
   triton_types = [get_triton_type(x) for x in in_avals]
@@ -1491,17 +1483,31 @@ def compile_jaxpr(
     num_stages: int,
     debug: bool,
 ) -> TritonCompilationResult:
+  # TODO(sharadmv): handle multiple devices, right now we assume device 0
+  # which is fine when we have multiple of the same GPU but this won't work in
+  # general.
+  device = 0
+  arch = triton_kernel_call_lib.get_compute_capability(device)
+  target = ("cuda", arch)
+  cuda_backend = cb.CUDABackend(target)
+  cuda_options = cuda_backend.parse_compiler_options(
+      dict(
+          num_warps=num_warps,
+          num_stages=num_stages,
+          debug=debug,
+      )
+  )
+
   lowering_result = lower_jaxpr_to_triton_module(
-      jaxpr, in_shapes, grid_mapping, name, num_warps
+      jaxpr, in_shapes, grid_mapping, name, cuda_options
   )
-  device = 0
+
   ttir = str(lowering_result.module)
   ptx, name, shared_mem_bytes, compute_capability = compile_ttir_to_ptx_inplace(
       lowering_result.module,
+      cuda_backend,
+      cuda_options,
-      device=device,
-      num_warps=num_warps,
-      num_stages=num_stages,
       dump=debug,
   )
   return TritonCompilationResult(
       name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result
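For context: the change drops the old per-module tc.CudaTargetDescriptor in favor of a CUDA backend object plus an options object that carries num_warps, num_stages, and debug together. A minimal sketch of the new setup, assuming a Triton build that exposes triton.compiler.backends.cuda.CUDABackend and the parse_compiler_options method exactly as they appear in the diff (names taken from the diff, not verified against any released Triton API; the helper name make_cuda_compile_config is hypothetical):

# Sketch only: mirrors the calls visible in this diff; CUDABackend and
# parse_compiler_options are assumed from the diff, not a verified Triton API.
import triton.compiler.backends.cuda as cb

def make_cuda_compile_config(arch, num_warps, num_stages, debug):
  # The compile target is now a ("cuda", compute_capability) tuple handed to
  # the backend, replacing the per-module CudaTargetDescriptor.
  cuda_backend = cb.CUDABackend(("cuda", arch))
  # num_warps / num_stages / debug are folded into a single options object.
  cuda_options = cuda_backend.parse_compiler_options(
      dict(num_warps=num_warps, num_stages=num_stages, debug=debug)
  )
  return cuda_backend, cuda_options

Per the diff, the options object is then attached to the Triton IR builder (builder.options = cuda_options) during lowering and passed, together with the backend, to compile_ttir_to_ptx_inplace.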
