Skip to content

Commit

Permalink
[Mosaic GPU] Avoid double-predication when async_copy predicate is sp…
Browse files Browse the repository at this point in the history
…ecified

PiperOrigin-RevId: 700999181
  • Loading branch information
apaszke authored and Google-ML-Automation committed Nov 28, 2024
1 parent b09b077 commit 14ddb81
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/experimental/mosaic/gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def async_copy(
arrive: bool | None = None,
uniform: bool = True,
collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
predicate: ir.Value | None = None,
predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG.
):
index = ir.IndexType.get()
i16 = ir.IntegerType.get_signless(16)
Expand Down Expand Up @@ -504,7 +504,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):

uniform_ctx = (
functools.partial(utils.single_thread, per_block=False)
if uniform
if uniform and predicate is None
else contextlib.nullcontext
)

Expand Down

0 comments on commit 14ddb81

Please sign in to comment.