Skip to content

Commit

Permalink
Simplify sinusoidal position embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Aug 19, 2024
1 parent 1229a9d commit 85950c6
Showing 1 changed file with 17 additions and 37 deletions.
54 changes: 17 additions & 37 deletions lib/bumblebee/text/m2m100.ex
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ defmodule Bumblebee.Text.M2m100 do
mask
|> Nx.cumulative_sum(axis: 1)
|> Nx.multiply(mask)
|> Nx.add(spec.pad_token_id)
|> Nx.subtract(1)
end)
end

Expand Down Expand Up @@ -425,46 +425,26 @@ defmodule Bumblebee.Text.M2m100 do
defp position_embedding(position_ids, spec, opts) do
name = opts[:name]

# For M2M100 we need to offset the embeddings
offset = 2
embedding_dim = spec.hidden_size
num_embeddings = spec.max_positions + offset
padding_idx = spec.pad_token_id
half_dim = div(embedding_dim, 2)

Axon.nx(
position_ids,
&position_embedding_impl(&1, embedding_dim, half_dim, num_embeddings, padding_idx),
name: join(name, "sinusoidal_position_embedding")

position_ids = Axon.add(position_ids, Axon.constant(Nx.tensor(offset)))

Axon.layer(&sinusoidal_position_embedding_impl/2, [position_ids],
size: spec.hidden_size,
name: name
)
end

defnp position_embedding_impl(
position_ids,
embedding_dim,
half_dim,
num_embeddings,
padding_idx
) do
zero_pad_slice = Nx.broadcast(0.0, {1, embedding_dim})

Nx.log(10_000)
|> Nx.divide(half_dim - 1)
|> Nx.negate()
|> Nx.multiply(Nx.iota({half_dim}))
|> Nx.exp()
|> Nx.new_axis(0)
|> Nx.multiply(Nx.new_axis(Nx.iota({num_embeddings}), 1))
|> then(&Nx.concatenate([Nx.sin(&1), Nx.cos(&1)], axis: 1))
|> Nx.reshape({num_embeddings, :auto})
|> then(fn emb ->
if rem(embedding_dim, 2) == 1 do
Nx.concatenate([emb, Nx.broadcast(0, {num_embeddings, 1})], axis: 1)
else
emb
end
end)
|> Nx.put_slice([padding_idx, 0], zero_pad_slice)
|> Nx.take(Nx.as_type(position_ids, {:s, 64}))
defnp sinusoidal_position_embedding_impl(position_ids, opts \\ []) do
size = opts[:size]

half_size = div(size, 2)
base = 10_000
range = Nx.iota({half_size}) / (half_size - 1)
inv_frequency = 1 / Nx.pow(base, range)
angle = Nx.outer(position_ids, inv_frequency)
Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1)
end

defp decoder(
Expand Down

0 comments on commit 85950c6

Please sign in to comment.