Skip to content

Commit

Permalink
Add support for CLS token pooling in text embedding (#385)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonatan Kłosko <[email protected]>
  • Loading branch information
nyo16 and jonatanklosko authored Aug 6, 2024
1 parent 88cef27 commit 7db36b8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
9 changes: 8 additions & 1 deletion lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,14 @@ defmodule Bumblebee.Text do
this option is ignored. Defaults to `:pooled_state`
* `:output_pool` - pooling to apply on top of the model output, in case
it is not already a pooled embedding. Supported values: `:mean_pooling`.
it is not already a pooled embedding. Supported values:
* `:mean_pooling` - performs a mean across all tokens
* `cls_token_pooling` - takes the embedding for the special CLS token.
Note that we currently assume that the CLS token is the first token
in the sequence
By default no pooling is applied
* `:embedding_processor` - a post-processing step to apply to the
Expand Down
21 changes: 10 additions & 11 deletions lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,21 @@ defmodule Bumblebee.Text.TextEmbedding do
output
end

if output_pool != nil and Nx.rank(output) != 3 do
raise ArgumentError,
"expected the output tensor to have rank 3 to apply :output_pool, got: #{Nx.rank(output)}." <>
" You should either disable pooling or pick a different output using :output_attribute"
end

output =
case output_pool do
nil ->
output

:mean_pooling ->
case Nx.rank(output) do
3 ->
:ok

rank ->
raise ArgumentError,
"expected the output tensor to have rank 3 to apply :output_pool, got: #{rank}." <>
" You should either disable pooling or pick a different output using :output_attribute"
end
:cls_token_pooling ->
Nx.take(output, 0, axis: 1)

:mean_pooling ->
input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)

output
Expand All @@ -81,7 +80,7 @@ defmodule Bumblebee.Text.TextEmbedding do

other ->
raise ArgumentError,
"expected :output_pool to be one of nil or :mean_pooling, got: #{inspect(other)}"
"expected :output_pool to be one of :cls_token_pooling, :mean_pooling or nil, got: #{inspect(other)}"
end

output =
Expand Down

0 comments on commit 7db36b8

Please sign in to comment.