[Pallas] Deprecate dictionary compiler_params in favor of dataclass.

PiperOrigin-RevId: 699057658
justinjfu authored and Google-ML-Automation committed Nov 22, 2024
1 parent 355589f commit 73fa0f4
Showing 3 changed files with 8 additions and 9 deletions.
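In short, the change replaces the untyped nested-dict spelling of compiler_params with the typed per-backend dataclasses that Pallas exports. A minimal before/after sketch (not from the diff itself), assuming plgpu and pltpu are the jax.experimental.pallas.triton and jax.experimental.pallas.tpu aliases already used in the touched files:

from jax.experimental.pallas import triton as plgpu
from jax.experimental.pallas import tpu as pltpu

num_warps = 4

# Before: backend chosen by a string key; misspelled or unknown
# fields slip through until lowering.
old_style = dict(triton=dict(num_warps=num_warps))

# After: backend chosen by the dataclass type; an unknown field
# raises a TypeError at construction time.
gpu_params = plgpu.TritonCompilerParams(num_warps=num_warps)
tpu_params = pltpu.TPUCompilerParams(collective_id=0)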
4 changes: 2 additions & 2 deletions jax/experimental/pallas/ops/gpu/layer_norm.py
@@ -94,7 +94,7 @@ def layer_norm_forward(
   ]
   method = pl.pallas_call(
       kernel,
-      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape,
       debug=False,
@@ -215,7 +215,7 @@ def layer_norm_backward(
   out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape_dx,
       debug=False,
2 changes: 1 addition & 1 deletion jax/experimental/pallas/ops/gpu/rms_norm.py
@@ -196,7 +196,7 @@ def rms_norm_backward(
   out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape_dx,
       debug=False,
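Both GPU hunks are the same mechanical one-line swap. For migrating other kernels, here is a self-contained sketch of the new spelling inside a complete pl.pallas_call, built around a hypothetical doubling kernel that is not part of this commit (runs on a Triton-capable GPU backend):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def double_kernel(x_ref, o_ref):
  # Elementwise doubling over the whole block.
  o_ref[...] = x_ref[...] * 2

def double(x):
  return pl.pallas_call(
      double_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      compiler_params=plgpu.TritonCompilerParams(num_warps=4),
  )(x)

print(double(jnp.arange(8, dtype=jnp.float32)))
# [ 0.  2.  4.  6.  8. 10. 12. 14.]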
11 changes: 5 additions & 6 deletions tests/pallas/tpu_pallas_pipeline_test.py
@@ -486,12 +486,11 @@ def _wait_on_prev_dma():
             + [pltpu.SemaphoreType.DMA] * 4
             + inner_allocs
         ),
-        compiler_params=dict(
-            mosaic=dict(collective_id=0,
-                        # must set scoped vmem flag *larger* than below! e.g.:
-                        # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072
-                        vmem_limit_bytes=int(134217728 * 0.9)  # 0.9 * 128MiB
-                        )
+        compiler_params=pltpu.TPUCompilerParams(
+            collective_id=0,
+            # must set scoped vmem flag *larger* than below! e.g.:
+            # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072
+            vmem_limit_bytes=int(134217728 * 0.9)  # 0.9 * 128MiB
         ),
     )
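On the TPU side the keyword arguments carry over unchanged; only the wrapper differs. The arithmetic in the hunk's comment is worth spelling out: 134217728 bytes is 128 MiB, so vmem_limit_bytes requests int(134217728 * 0.9) = 120795955 bytes of scoped VMEM, and the runtime cap must sit above that, e.g. xla_tpu_scoped_vmem_limit_kib = 131072 (exactly 128 MiB). A sketch of the dataclass in isolation, assuming the pltpu alias from the test file:

from jax.experimental.pallas import tpu as pltpu

params = pltpu.TPUCompilerParams(
    # Shared ID for the cross-core collective pipeline under test.
    collective_id=0,
    # ~0.9 * 128 MiB; must stay below the scoped-VMEM cap set via
    # the xla_tpu_scoped_vmem_limit_kib flag (131072 KiB here).
    vmem_limit_bytes=int(134217728 * 0.9),
)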

