Skip to content

Commit

Permalink
fix: address the review
Browse files Browse the repository at this point in the history
  • Loading branch information
knyazer committed Sep 9, 2024
1 parent b614639 commit 54c0bfe
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
21 changes: 15 additions & 6 deletions equinox/nn/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jax._src.dtypes import TypePromotionError
from jaxtyping import Array, ArrayLike, Float, Int, PRNGKeyArray

from .._caches import cache_clears
Expand Down Expand Up @@ -162,11 +163,7 @@ def process_heads(

embedding_size: int = field(static=True)
theta: float = field(static=True, default=10_000.0)
dtype: Any = field(static=True, default=None)

def __post_init__(self):
if self.dtype is None:
self.dtype = default_floating_dtype()
dtype: Any = field(static=True, default_factory=default_floating_dtype)

def __check_init__(self):
if self.embedding_size < 0:
Expand All @@ -191,6 +188,7 @@ def precompute_freqs_cis(
t = jnp.arange(float(end))
freqs_outer = jnp.outer(t, freqs)

# cast to the requested dtype only at the very end to minimize the loss of precision
return jnp.cos(freqs_outer).astype(dtype), jnp.sin(freqs_outer).astype(dtype)

@jax.named_scope("eqx.nn.RotaryPositionalEmbedding")
Expand Down Expand Up @@ -241,7 +239,18 @@ def __call__(
freqs_sin = jnp.tile(freqs_sin, (1, 2))

rotate_x = self.rotate_half(x)
x_rope = (x * freqs_cos) + (rotate_x * freqs_sin)
try:
x_rope = (x * freqs_cos) + (rotate_x * freqs_sin)
except TypePromotionError as e:
inp_dtype = jnp.dtype(x.dtype)
rope_dtype = jnp.dtype(self.dtype)
raise TypePromotionError(
f"The type of the passed value differs from the type "
f"of the rotary embeddings ({inp_dtype} != {rope_dtype}), thus leading "
"to a conflict when numpy_dtype_promotion is set to strict. To avoid "
f"this error, either initialize RoPE module with {inp_dtype} "
f"dtype, or explicitly cast the input argument to {rope_dtype}."
) from e
return x_rope.astype(x.dtype)


Expand Down
5 changes: 5 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax.numpy as jnp
import jax.random as jrandom
import pytest
from jax._src.dtypes import TypePromotionError


def test_custom_init():
Expand Down Expand Up @@ -1463,6 +1464,10 @@ def test_rope_embeddings_values():
rtol=1e-3,
)

# check that an input whose dtype differs from the embedding dtype raises a TypePromotionError
with pytest.raises(TypePromotionError):
rope_embeddings(x.astype(jnp.float16))

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
embedding_size, dtype=jnp.float16
)
Expand Down

0 comments on commit 54c0bfe

Please sign in to comment.