diff --git a/equinox/nn/_embedding.py b/equinox/nn/_embedding.py index 6a753876..3ac78b14 100644 --- a/equinox/nn/_embedding.py +++ b/equinox/nn/_embedding.py @@ -185,11 +185,14 @@ def precompute_freqs_cis( ) -> tuple[Float[Array, "end half_emb_size"], Float[Array, "end half_emb_size"]]: freqs = 1.0 / ( theta - ** (jnp.arange(0.0, embedding_size, 2)[jnp.newaxis, :] / embedding_size) + ** ( + jnp.arange(0.0, embedding_size, 2, dtype=dtype)[jnp.newaxis, :] + / embedding_size + ) ) - t = jnp.arange(float(end)) - freqs_outer = jnp.outer(t, freqs).astype(dtype) + t = jnp.arange(float(end), dtype=dtype) + freqs_outer = jnp.outer(t, freqs) return jnp.cos(freqs_outer), jnp.sin(freqs_outer)