From 78dc2e3866ae7c365b89f1c2795a8d3c5411d01d Mon Sep 17 00:00:00 2001
From: Nikos Maroulis
Date: Fri, 2 Aug 2024 09:03:41 -0400
Subject: [PATCH 1/6] add support for cls token pooling

---
 lib/bumblebee/text/text_embedding.ex        | 14 +++++++++++++-
 test/bumblebee/text/text_embedding_test.exs | 15 +++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex
index bf6a7a97..d65d94a6 100644
--- a/lib/bumblebee/text/text_embedding.ex
+++ b/lib/bumblebee/text/text_embedding.ex
@@ -61,6 +61,18 @@ defmodule Bumblebee.Text.TextEmbedding do
           nil ->
             output
 
+          :cls_token_pooling ->
+            case Nx.rank(output) do
+              3 ->
+                # Assuming CLS token is always at the first position
+                Nx.slice_along_axis(output, 0, 1, axis: 1) |> Nx.squeeze(axes: [1])
+
+              rank ->
+                raise ArgumentError,
+                      "expected the output tensor to have rank 3 to apply :cls pooling, got: #{rank}." <>
+                        " You should either disable pooling or pick a different output using :output_attribute"
+            end
+
           :mean_pooling ->
             case Nx.rank(output) do
               3 ->
@@ -81,7 +93,7 @@ defmodule Bumblebee.Text.TextEmbedding do
 
           other ->
             raise ArgumentError,
-                  "expected :output_pool to be one of nil or :mean_pooling, got: #{inspect(other)}"
+                  "expected :output_pool to be one of :cls_token_pooling, :mean_pooling or nil, got: #{inspect(other)}"
         end
 
       output =
diff --git a/test/bumblebee/text/text_embedding_test.exs b/test/bumblebee/text/text_embedding_test.exs
index cb8ed084..61a8968b 100644
--- a/test/bumblebee/text/text_embedding_test.exs
+++ b/test/bumblebee/text/text_embedding_test.exs
@@ -100,4 +100,19 @@ defmodule Bumblebee.Text.TextEmbeddingTest do
 
     assert_equal(embedding1, embedding2)
   end
+
+  test "cls token pooling" do
+    {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-small-v2"})
+    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-small-v2"})
+
+    serving =
+      Bumblebee.Text.text_embedding(model_info, tokenizer,
+        output_attribute: :hidden_state,
+        output_pool: :cls_token_pooling,
+        embedding_processor: :l2_norm
+      )
+
+    # Nx.Serving.run(serving, "A long text to test the embeddings")
+    # TBD: tests
+  end
 end

From 32dfdc300e983f503cd12cce5a475ffe344e9669 Mon Sep 17 00:00:00 2001
From: Nikos Maroulis
Date: Fri, 2 Aug 2024 09:06:01 -0400
Subject: [PATCH 2/6] removing test

---
 test/bumblebee/text/text_embedding_test.exs | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/test/bumblebee/text/text_embedding_test.exs b/test/bumblebee/text/text_embedding_test.exs
index 61a8968b..cb8ed084 100644
--- a/test/bumblebee/text/text_embedding_test.exs
+++ b/test/bumblebee/text/text_embedding_test.exs
@@ -100,19 +100,4 @@ defmodule Bumblebee.Text.TextEmbeddingTest do
 
     assert_equal(embedding1, embedding2)
   end
-
-  test "cls token pooling" do
-    {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-small-v2"})
-    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-small-v2"})
-
-    serving =
-      Bumblebee.Text.text_embedding(model_info, tokenizer,
-        output_attribute: :hidden_state,
-        output_pool: :cls_token_pooling,
-        embedding_processor: :l2_norm
-      )
-
-    # Nx.Serving.run(serving, "A long text to test the embeddings")
-    # TBD: tests
-  end
 end

From 35351416b57beb26c0389dacc4a9f6a7779f67ec Mon Sep 17 00:00:00 2001
From: Nikos Maroulis
Date: Fri, 2 Aug 2024 09:49:49 -0400
Subject: [PATCH 3/6] simplifying Nx slice

---
 lib/bumblebee/text/text_embedding.ex | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex
index d65d94a6..668e01ac 100644
--- a/lib/bumblebee/text/text_embedding.ex
+++ b/lib/bumblebee/text/text_embedding.ex
@@ -65,8 +65,8 @@ defmodule Bumblebee.Text.TextEmbedding do
             case Nx.rank(output) do
               3 ->
                 # Assuming CLS token is always at the first position
-                Nx.slice_along_axis(output, 0, 1, axis: 1) |> Nx.squeeze(axes: [1])
-
+                Nx.take(output, 0, axis: 1)
+
               rank ->
                 raise ArgumentError,
                       "expected the output tensor to have rank 3 to apply :cls pooling, got: #{rank}." <>

From 574128c801fe220ae0224212c165894f314ad8e2 Mon Sep 17 00:00:00 2001
From: Nikos Maroulis
Date: Fri, 2 Aug 2024 09:53:34 -0400
Subject: [PATCH 4/6] fixing formatting issues

---
 lib/bumblebee/text/text_embedding.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex
index 668e01ac..d51ddc03 100644
--- a/lib/bumblebee/text/text_embedding.ex
+++ b/lib/bumblebee/text/text_embedding.ex
@@ -66,7 +66,7 @@ defmodule Bumblebee.Text.TextEmbedding do
               3 ->
                 # Assuming CLS token is always at the first position
                 Nx.take(output, 0, axis: 1)
-
+
               rank ->
                 raise ArgumentError,
                       "expected the output tensor to have rank 3 to apply :cls pooling, got: #{rank}." <>

From 61fe5d2e61fe263ef03c5b7603117b5c56edc40d Mon Sep 17 00:00:00 2001
From: Nikos Maroulis
Date: Tue, 6 Aug 2024 08:47:11 -0400
Subject: [PATCH 5/6] documentation added to main module. Simplifying logic for
 checking the output rank

---
 lib/bumblebee/text.ex                |  5 +++--
 lib/bumblebee/text/text_embedding.ex | 27 +++++++--------------------
 2 files changed, 10 insertions(+), 22 deletions(-)

diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex
index 2e9233b0..1880208a 100644
--- a/lib/bumblebee/text.ex
+++ b/lib/bumblebee/text.ex
@@ -313,8 +313,9 @@ defmodule Bumblebee.Text do
       this option is ignored. Defaults to `:pooled_state`
 
     * `:output_pool` - pooling to apply on top of the model output, in case
-      it is not already a pooled embedding. Supported values: `:mean_pooling`.
-      By default no pooling is applied
+      it is not already a pooled embedding. Supported values: `:mean_pooling` or `:cls_token_pooling`.
+      For `:cls_token_pooling` we assume that the first token is the CLS token.
+      By default no pooling is applied.
 
     * `:embedding_processor` - a post-processing step to apply to the
       embedding. Supported values: `:l2_norm`. By default the output is
diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex
index d51ddc03..41e284e8 100644
--- a/lib/bumblebee/text/text_embedding.ex
+++ b/lib/bumblebee/text/text_embedding.ex
@@ -56,34 +56,21 @@ defmodule Bumblebee.Text.TextEmbedding do
           output
         end
 
+      if output_pool != nil and Nx.rank(output) != 3 do
+        raise ArgumentError,
+              "expected the output tensor to have rank 3 to apply :output_pool, got: #{Nx.rank(output)}." <>
+                " You should either disable pooling or pick a different output using :output_attribute"
+      end
+
       output =
         case output_pool do
           nil ->
             output
 
           :cls_token_pooling ->
-            case Nx.rank(output) do
-              3 ->
-                # Assuming CLS token is always at the first position
-                Nx.take(output, 0, axis: 1)
-
-              rank ->
-                raise ArgumentError,
-                      "expected the output tensor to have rank 3 to apply :cls pooling, got: #{rank}." <>
-                        " You should either disable pooling or pick a different output using :output_attribute"
-            end
+            Nx.take(output, 0, axis: 1)
 
           :mean_pooling ->
-            case Nx.rank(output) do
-              3 ->
-                :ok
-
-              rank ->
-                raise ArgumentError,
-                      "expected the output tensor to have rank 3 to apply :output_pool, got: #{rank}." <>
-                        " You should either disable pooling or pick a different output using :output_attribute"
-            end
-
             input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)
 
             output

From 5ce9de434f9e69de5b6bdc91b129a5c81be86f17 Mon Sep 17 00:00:00 2001
From: Jonatan Kłosko
Date: Tue, 6 Aug 2024 15:05:20 +0200
Subject: [PATCH 6/6] Update lib/bumblebee/text.ex

---
 lib/bumblebee/text.ex | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex
index 1880208a..aa4ca19b 100644
--- a/lib/bumblebee/text.ex
+++ b/lib/bumblebee/text.ex
@@ -313,9 +313,15 @@ defmodule Bumblebee.Text do
       this option is ignored. Defaults to `:pooled_state`
 
     * `:output_pool` - pooling to apply on top of the model output, in case
-      it is not already a pooled embedding. Supported values: `:mean_pooling` or `:cls_token_pooling`.
-      For `:cls_token_pooling` we assume that the first token is the CLS token.
-      By default no pooling is applied.
+      it is not already a pooled embedding. Supported values:
+
+        * `:mean_pooling` - performs a mean across all tokens
+
+        * `:cls_token_pooling` - takes the embedding for the special CLS token.
+          Note that we currently assume that the CLS token is the first token
+          in the sequence
+
+      By default no pooling is applied
 
     * `:embedding_processor` - a post-processing step to apply to the
       embedding. Supported values: `:l2_norm`. By default the output is