diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 2279df4f3984..f6cab5654e64 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -673,33 +673,17 @@ def __getitem__(self, offset: ir.Value | int) -> "BarrierRef": 1, ) - def wait_parity(self, parity, expect_wait=False): - i1 = ir.IntegerType.get_signless(1) + def wait_parity(self, parity): i32 = ir.IntegerType.get_signless(32) - ticks = c(10000000, i32) - address = self.get_ptr() + ticks = arith.constant(i32, 10000000) parity = arith.extui(i32, parity) - if expect_wait: - nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks) - return - barrier_ready = llvm.inline_asm( - i1, - [address, parity], - "mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;", - "=b,l,r", - has_side_effects=True, - ) - should_wait = arith.xori(barrier_ready, c(1, i1)) - should_wait = llvm.intr_expect(should_wait, c(0, i1)) - with ir.InsertionPoint(scf.IfOp(should_wait).then_block): - nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks) - scf.yield_([]) + nvvm.mbarrier_try_wait_parity_shared(self.get_ptr(), parity, ticks) - def wait(self, expect_wait=False): + def wait(self): parities = memref.load(self.phases, []) parity, new_parities = self.update_parities(parities) memref.store(new_parities, self.phases, []) - self.wait_parity(parity, expect_wait=expect_wait) + self.wait_parity(parity) def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: i32 = ir.IntegerType.get_signless(32)