diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index 3ac78b14..10001326 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -185,16 +185,13 @@ def precompute_freqs_cis( ) -> tuple[Float[Array, "end half_emb_size"], Float[Array, "end half_emb_size"]]: freqs = 1.0 / ( theta - ** ( - jnp.arange(0.0, embedding_size, 2, dtype=dtype)[jnp.newaxis, :] - / embedding_size - ) + ** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size) ) - t = jnp.arange(float(end), dtype=dtype) + t = jnp.arange(float(end)) freqs_outer = jnp.outer(t, freqs) - return jnp.cos(freqs_outer), jnp.sin(freqs_outer) + return jnp.cos(freqs_outer).astype(dtype), jnp.sin(freqs_outer).astype(dtype) @jax.named_scope("eqx.nn.RotaryPositionalEmbedding") def __call__( @@ -223,32 +220,28 @@ def __call__( ) with jax.ensure_compile_time_eval(): - if (embedding_size, self.dtype) in internal_rope_embedding_cache: - freqs_pair = internal_rope_embedding_cache[(embedding_size, self.dtype)] - freqs_cis_seq_len, _ = freqs_pair[0].shape - if seq_len > freqs_cis_seq_len: - freqs_pair = self.precompute_freqs_cis( - embedding_size, seq_len, self.theta, self.dtype - ) - internal_rope_embedding_cache[(embedding_size, self.dtype)] = ( - freqs_pair - ) - else: - freqs_pair = ( - freqs_pair[0][:seq_len], - freqs_pair[1][:seq_len], - ) - else: - freqs_pair = self.precompute_freqs_cis( + cache_key = (embedding_size, self.dtype) + if cache_key not in internal_rope_embedding_cache: + internal_rope_embedding_cache[cache_key] = self.precompute_freqs_cis( embedding_size, seq_len, self.theta, self.dtype ) - internal_rope_embedding_cache[(embedding_size, self.dtype)] = freqs_pair - freqs_real = jnp.tile(freqs_pair[0], (1, 2)) - freqs_imag = jnp.tile(freqs_pair[1], (1, 2)) + freqs_cos, freqs_sin = internal_rope_embedding_cache[cache_key] + freqs_seq_len, _ = freqs_cos.shape + if seq_len > freqs_seq_len: + internal_rope_embedding_cache[cache_key] = self.precompute_freqs_cis( + embedding_size, seq_len, self.theta, self.dtype + ) + freqs_cos, freqs_sin = internal_rope_embedding_cache[cache_key] + + freqs_cos = freqs_cos[:seq_len] + freqs_sin = freqs_sin[:seq_len] + + freqs_cos = jnp.tile(freqs_cos, (1, 2)) + freqs_sin = jnp.tile(freqs_sin, (1, 2)) rotate_x = self.rotate_half(x) - x_rope = (x * freqs_real) + (rotate_x * freqs_imag) + x_rope = (x * freqs_cos) + (rotate_x * freqs_sin) return x_rope diff --git a/tests/test_nn.py b/tests/test_nn.py index a9a8de1b..0e6f5b30 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1408,9 +1408,9 @@ def test_rope_embeddings_freqs_cis(): embedding_size, seq_length, theta, jnp.float16 ) assert jnp.allclose( - freqs_cis[0], expected_freqs_cis.real.astype(jnp.float16), rtol=1e-2 + freqs_cis[0].astype(jnp.float32), expected_freqs_cis.real, rtol=1e-2 ) and jnp.allclose( - freqs_cis[1], expected_freqs_cis.imag.astype(jnp.float16), rtol=1e-2 + freqs_cis[1].astype(jnp.float32), expected_freqs_cis.imag, rtol=1e-2 ) @@ -1446,17 +1446,29 @@ def test_rope_embeddings_values(): seq_length, embedding_size ) - rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size) + rope_embeddings = eqx.nn.RotaryPositionalEmbedding( + embedding_size, dtype=jnp.float32 + ) res = rope_embeddings(x) assert jnp.allclose(res, expected_values, atol=1e-6) + with jax.numpy_dtype_promotion("standard"): + # Test that high precision rope on low precision input is more + # accurate than low precision rope on low precision input + res = rope_embeddings(x.astype(jnp.float16)) + assert jnp.allclose( + res.astype(jnp.float16), + expected_values.astype(jnp.float16), + rtol=1e-3, + ) + rope_embeddings = eqx.nn.RotaryPositionalEmbedding( embedding_size, dtype=jnp.float16 ) res = rope_embeddings(x.astype(jnp.float16)) assert ( - jnp.allclose(res, expected_values.astype(jnp.float16), rtol=1e-2) + jnp.allclose(res.astype(jnp.float32), expected_values, rtol=1e-2) and res.dtype == jnp.float16 )