
Mixed precision for attention (#597)
* Add argument "softmax_dtype" to attention

* Relying on JAX type promotion for softmax

* Cast attention logits to Array
mk-0 authored Nov 26, 2023
1 parent 6846a2b commit 588a8c3
Showing 2 changed files with 11 additions and 2 deletions.
equinox/nn/_attention.py: 7 changes (5 additions, 2 deletions)
@@ -2,7 +2,7 @@
 import math
 import warnings
 from functools import partial
-from typing import Optional, Union
+from typing import cast, Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -29,8 +29,11 @@ def dot_product_attention_weights(
                 f"{key.shape[0]}). Got {mask.shape}."
             )
         logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
+        logits = cast(Array, logits)
 
-    return jax.nn.softmax(logits, axis=-1)  # pyright: ignore
+    dtype = jnp.result_type(logits.dtype, jnp.float32)
+    weights = jax.nn.softmax(logits.astype(dtype)).astype(logits.dtype)
+    return weights
 
 
 def dot_product_attention(
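For context, here is a standalone sketch (not part of the diff) of the promotion trick used above: jnp.result_type(logits.dtype, jnp.float32) promotes float16/bfloat16 logits to float32 but leaves float32/float64 inputs unchanged, so the softmax is computed in at least single precision and then cast back to the input dtype. The helper name below is hypothetical.

import jax
import jax.numpy as jnp

def softmax_in_at_least_float32(logits: jax.Array) -> jax.Array:
    # Promote half precision to float32; float32/float64 stay as they are.
    dtype = jnp.result_type(logits.dtype, jnp.float32)
    # Compute the softmax in the promoted dtype, then cast back so the
    # output dtype matches the input dtype.
    return jax.nn.softmax(logits.astype(dtype), axis=-1).astype(logits.dtype)

print(softmax_in_at_least_float32(jnp.array([9.0, 1.0], dtype=jnp.float16)).dtype)  # float16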
tests/test_nn.py: 6 changes (6 additions, 0 deletions)
@@ -642,6 +642,12 @@ def test_dot_product_attention_weights(getkey):
     weights = eqx.nn._attention.dot_product_attention_weights(q, k, mask)
     assert jnp.allclose(weights, jnp.array([[1.0, 0.0]]))
 
+    q = jnp.array([[1.0]], dtype="float16")
+    k = jnp.array([[9.0], [1.0]], dtype="float16")
+    weights = eqx.nn._attention.dot_product_attention_weights(q, k)
+    assert weights.dtype == q.dtype
+    assert weights.max() < 1
+
 
 def test_dot_product_attention(getkey):
     q = jnp.array([[0.0, 2**0.5]])
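The new test exercises the float16 path: with logits [9.0, 1.0], a softmax computed entirely in float16 loses the tiny exp(-8) term in the normalising sum, so the largest weight saturates at exactly 1, whereas computing in float32 and casting back keeps it strictly below 1, which is presumably what the weights.max() < 1 assertion guards against. A hypothetical illustration of that difference (expected outputs are approximate and may depend on the backend):

import jax
import jax.numpy as jnp

logits = jnp.array([[9.0, 1.0]], dtype=jnp.float16)

# Softmax entirely in float16: 1 + exp(-8) rounds back to 1 at half precision.
pure_f16 = jax.nn.softmax(logits, axis=-1)

# Softmax in float32, cast back to float16, as in the patched function.
upcast = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(jnp.float16)

print(pure_f16.max())  # expected: 1.0 (saturated)
print(upcast.max())    # expected: ~0.9995, strictly below 1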
