diff --git a/lib/bumblebee/text/m2m100.ex b/lib/bumblebee/text/m2m100.ex index 631029ea..98503abf 100644 --- a/lib/bumblebee/text/m2m100.ex +++ b/lib/bumblebee/text/m2m100.ex @@ -395,7 +395,7 @@ defmodule Bumblebee.Text.M2m100 do mask |> Nx.cumulative_sum(axis: 1) |> Nx.multiply(mask) - |> Nx.add(spec.pad_token_id) + |> Nx.subtract(1) end) end @@ -425,46 +425,26 @@ defmodule Bumblebee.Text.M2m100 do defp position_embedding(position_ids, spec, opts) do name = opts[:name] + # For M2M100 we need to offset the embeddings offset = 2 - embedding_dim = spec.hidden_size - num_embeddings = spec.max_positions + offset - padding_idx = spec.pad_token_id - half_dim = div(embedding_dim, 2) - - Axon.nx( - position_ids, - &position_embedding_impl(&1, embedding_dim, half_dim, num_embeddings, padding_idx), - name: join(name, "sinusoidal_position_embedding") + + position_ids = Axon.add(position_ids, Axon.constant(Nx.tensor(offset))) + + Axon.layer(&sinusoidal_position_embedding_impl/2, [position_ids], + size: spec.hidden_size, + name: name ) end - defnp position_embedding_impl( - position_ids, - embedding_dim, - half_dim, - num_embeddings, - padding_idx - ) do - zero_pad_slice = Nx.broadcast(0.0, {1, embedding_dim}) - - Nx.log(10_000) - |> Nx.divide(half_dim - 1) - |> Nx.negate() - |> Nx.multiply(Nx.iota({half_dim})) - |> Nx.exp() - |> Nx.new_axis(0) - |> Nx.multiply(Nx.new_axis(Nx.iota({num_embeddings}), 1)) - |> then(&Nx.concatenate([Nx.sin(&1), Nx.cos(&1)], axis: 1)) - |> Nx.reshape({num_embeddings, :auto}) - |> then(fn emb -> - if rem(embedding_dim, 2) == 1 do - Nx.concatenate([emb, Nx.broadcast(0, {num_embeddings, 1})], axis: 1) - else - emb - end - end) - |> Nx.put_slice([padding_idx, 0], zero_pad_slice) - |> Nx.take(Nx.as_type(position_ids, {:s, 64})) + defnp sinusoidal_position_embedding_impl(position_ids, opts \\ []) do + size = opts[:size] + + half_size = div(size, 2) + base = 10_000 + range = Nx.iota({half_size}) / (half_size - 1) + inv_frequency = 1 / Nx.pow(base, range) + angle = Nx.outer(position_ids, inv_frequency) + Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1) end defp decoder(