Skip to content

Commit

Permalink
refactor: naming convention + better control flow
Browse files (browse the repository at this point in the history)
  • Loading branch information
knyazer committed Sep 8, 2024
1 parent 419b298 commit d585a36
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 31 deletions.
47 changes: 20 additions & 27 deletions equinox/nn/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,13 @@ def precompute_freqs_cis(
) -> tuple[Float[Array, "end half_emb_size"], Float[Array, "end half_emb_size"]]:
freqs = 1.0 / (
theta
** (
jnp.arange(0.0, embedding_size, 2, dtype=dtype)[jnp.newaxis, :]
/ embedding_size
)
** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size)
)

t = jnp.arange(float(end), dtype=dtype)
t = jnp.arange(float(end))
freqs_outer = jnp.outer(t, freqs)

return jnp.cos(freqs_outer), jnp.sin(freqs_outer)
return jnp.cos(freqs_outer).astype(dtype), jnp.sin(freqs_outer).astype(dtype)

@jax.named_scope("eqx.nn.RotaryPositionalEmbedding")
def __call__(
Expand Down Expand Up @@ -223,32 +220,28 @@ def __call__(
)

with jax.ensure_compile_time_eval():
if (embedding_size, self.dtype) in internal_rope_embedding_cache:
freqs_pair = internal_rope_embedding_cache[(embedding_size, self.dtype)]
freqs_cis_seq_len, _ = freqs_pair[0].shape
if seq_len > freqs_cis_seq_len:
freqs_pair = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta, self.dtype
)
internal_rope_embedding_cache[(embedding_size, self.dtype)] = (
freqs_pair
)
else:
freqs_pair = (
freqs_pair[0][:seq_len],
freqs_pair[1][:seq_len],
)
else:
freqs_pair = self.precompute_freqs_cis(
cache_key = (embedding_size, self.dtype)
if cache_key not in internal_rope_embedding_cache:
internal_rope_embedding_cache[cache_key] = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta, self.dtype
)
internal_rope_embedding_cache[(embedding_size, self.dtype)] = freqs_pair

freqs_real = jnp.tile(freqs_pair[0], (1, 2))
freqs_imag = jnp.tile(freqs_pair[1], (1, 2))
freqs_cos, freqs_sin = internal_rope_embedding_cache[cache_key]
freqs_seq_len, _ = freqs_cos.shape
if seq_len > freqs_seq_len:
internal_rope_embedding_cache[cache_key] = self.precompute_freqs_cis(
embedding_size, seq_len, self.theta, self.dtype
)
freqs_cos, freqs_sin = internal_rope_embedding_cache[cache_key]

freqs_cos = freqs_cos[:seq_len]
freqs_sin = freqs_sin[:seq_len]

freqs_cos = jnp.tile(freqs_cos, (1, 2))
freqs_sin = jnp.tile(freqs_sin, (1, 2))

rotate_x = self.rotate_half(x)
x_rope = (x * freqs_real) + (rotate_x * freqs_imag)
x_rope = (x * freqs_cos) + (rotate_x * freqs_sin)
return x_rope


Expand Down
20 changes: 16 additions & 4 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,9 +1408,9 @@ def test_rope_embeddings_freqs_cis():
embedding_size, seq_length, theta, jnp.float16
)
assert jnp.allclose(
freqs_cis[0], expected_freqs_cis.real.astype(jnp.float16), rtol=1e-2
freqs_cis[0].astype(jnp.float32), expected_freqs_cis.real, rtol=1e-2
) and jnp.allclose(
freqs_cis[1], expected_freqs_cis.imag.astype(jnp.float16), rtol=1e-2
freqs_cis[1].astype(jnp.float32), expected_freqs_cis.imag, rtol=1e-2
)


Expand Down Expand Up @@ -1446,17 +1446,29 @@ def test_rope_embeddings_values():
seq_length, embedding_size
)

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(embedding_size)
rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
embedding_size, dtype=jnp.float32
)
res = rope_embeddings(x)

assert jnp.allclose(res, expected_values, atol=1e-6)

with jax.numpy_dtype_promotion("standard"):
# Test that high-precision RoPE applied to low-precision input is more
# accurate than low-precision RoPE applied to the same low-precision input.
res = rope_embeddings(x.astype(jnp.float16))
assert jnp.allclose(
res.astype(jnp.float16),
expected_values.astype(jnp.float16),
rtol=1e-3,
)

rope_embeddings = eqx.nn.RotaryPositionalEmbedding(
embedding_size, dtype=jnp.float16
)
res = rope_embeddings(x.astype(jnp.float16))

assert (
jnp.allclose(res, expected_values.astype(jnp.float16), rtol=1e-2)
jnp.allclose(res.astype(jnp.float32), expected_values, rtol=1e-2)
and res.dtype == jnp.float16
)

0 comments on commit d585a36

Please sign in to comment.