Skip to content

Commit

Permalink
[Mosaic GPU] Remove expect_wait from Barrier.wait
Browse files Browse the repository at this point in the history
It looks like LLVM already moves the wait loops to the end of the program,
so the whole optimization is no longer necessary and only adds unnecessary operations.

PiperOrigin-RevId: 703052393
  • Loading branch information
apaszke authored and Google-ML-Automation committed Dec 5, 2024
1 parent 7214a3a commit c965ffb
Showing 1 changed file with 5 additions and 21 deletions.
26 changes: 5 additions & 21 deletions jax/experimental/mosaic/gpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,33 +673,17 @@ def __getitem__(self, offset: ir.Value | int) -> "BarrierRef":
1,
)

def wait_parity(self, parity, expect_wait=False):
i1 = ir.IntegerType.get_signless(1)
def wait_parity(self, parity):
i32 = ir.IntegerType.get_signless(32)
ticks = c(10000000, i32)
address = self.get_ptr()
ticks = arith.constant(i32, 10000000)
parity = arith.extui(i32, parity)
if expect_wait:
nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks)
return
barrier_ready = llvm.inline_asm(
i1,
[address, parity],
"mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;",
"=b,l,r",
has_side_effects=True,
)
should_wait = arith.xori(barrier_ready, c(1, i1))
should_wait = llvm.intr_expect(should_wait, c(0, i1))
with ir.InsertionPoint(scf.IfOp(should_wait).then_block):
nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks)
scf.yield_([])
nvvm.mbarrier_try_wait_parity_shared(self.get_ptr(), parity, ticks)

def wait(self, expect_wait=False):
def wait(self):
parities = memref.load(self.phases, [])
parity, new_parities = self.update_parities(parities)
memref.store(new_parities, self.phases, [])
self.wait_parity(parity, expect_wait=expect_wait)
self.wait_parity(parity)

def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]:
i32 = ir.IntegerType.get_signless(32)
Expand Down

0 comments on commit c965ffb

Please sign in to comment.