From 7a63d028f5f116a88341d010ef49130ef3657673 Mon Sep 17 00:00:00 2001
From: ZincCat
Date: Thu, 24 Oct 2024 21:53:10 -0400
Subject: [PATCH] better safe softmax for pure jax

---
 flaxattention/core/attention.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/flaxattention/core/attention.py b/flaxattention/core/attention.py
index 9d896f9..32156c6 100644
--- a/flaxattention/core/attention.py
+++ b/flaxattention/core/attention.py
@@ -64,12 +64,6 @@ def _math_attention_inner(
     return scores, post_mod_scores  # type: ignore
 
 
-def make_safe(x: Array, axis: int) -> Array:
-    masked = jnp.isnan(x) | jnp.isinf(x)
-    masked_rows = jnp.all(masked, axis=axis, keepdims=True)
-    zeros = jnp.zeros_like(x)
-    return jnp.where(masked_rows, zeros, x)
-
 def math_attention(
     query: Array,
     key: Array,
@@ -107,8 +101,7 @@ def math_attention(
     if use_pallas:
         post_mod_scores = pl_softmax(post_mod_scores, axis=-1)
     else:
-        post_mod_scores = jax.nn.softmax(post_mod_scores, axis=-1)
-        post_mod_scores = make_safe(post_mod_scores, axis=-1)
+        post_mod_scores = jax.nn.softmax(post_mod_scores, axis=-1, where=(post_mod_scores != -jnp.inf))
 
     output = jnp.matmul(post_mod_scores.astype(query.dtype), value)
     return output, logsumexp / jnp.log(2)
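
Note (not part of the patch): a minimal standalone sketch of the behaviour the new call relies on, assuming a JAX release whose jax.nn.softmax accepts the where= argument and zeroes out fully-masked positions instead of returning NaN, which is the job the removed make_safe helper used to do. The score values below are made up for illustration.

    import jax
    import jax.numpy as jnp

    # Scores as they look after masking: -inf marks masked entries.
    scores = jnp.array(
        [
            [1.0, 2.0, -jnp.inf],            # partially masked row
            [-jnp.inf, -jnp.inf, -jnp.inf],  # fully masked row
        ]
    )

    # Same pattern as the patched line: let softmax ignore the -inf entries.
    probs = jax.nn.softmax(scores, axis=-1, where=(scores != -jnp.inf))

    print(probs)
    # First row: a proper distribution over the two unmasked entries.
    # Second row: all zeros rather than NaN, matching what make_safe produced.

Design-wise, letting softmax handle the mask directly presumably avoids the extra pass over the scores that the old NaN/inf scrubbing in make_safe required.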