Skip to content

Commit

Permalink
[Pallas] Fix GQA triton kernel test.
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 715576240
  • Loading branch information
justinjfu authored and Google-ML-Automation committed Jan 15, 2025
1 parent d1810b4 commit cc9f6e7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 3 additions & 0 deletions jax/experimental/pallas/ops/gpu/decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ def decode_attn_unbatched(

# final round of flash
m_next = m.max(axis=0)
# TODO(b/389925439): This barrier is necessary to prevent NaNs/invalid
# values appearing after JIT compilation.
m_next = lax.optimization_barrier(m_next)
correction = jnp.exp(m - m_next[None])
o = o * correction[:, :, None].astype(o.dtype)
l_next = (l * correction).sum(axis=0)
Expand Down
2 changes: 0 additions & 2 deletions tests/pallas/gpu_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import os
import sys
import unittest

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -169,7 +168,6 @@ def test_mqa(
for return_residuals in [False, True]
])
@jax.numpy_dtype_promotion("standard")
@unittest.skip("TODO(b/389925439): gqa tests started failing after triton integrate")
def test_gqa(
self,
batch_size,
Expand Down

0 comments on commit cc9f6e7

Please sign in to comment.