Add support for CLS token pooling in text embedding #385
@jonatanklosko let me know if this looks good. For testing, I didn't see any existing test covering :mean_pooling, so I didn't add one, and I'm not sure what would actually be testable in that case.
lib/bumblebee/text/text_embedding.ex
raise ArgumentError,
      "expected the output tensor to have rank 3 to apply :cls pooling, got: #{rank}." <>
        " You should either disable pooling or pick a different output using :output_attribute"
I believe we expect rank 3 for any pooling, because we reduce n tokens into 1. Now that we have more pooling types, we can move the check before the case, like this:
if output_pool != nil and Nx.rank(output) != 3 do
raise ...
end
We can use the message from the other clause!
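For readers following along, here is a minimal sketch of how a shared rank check followed by CLS pooling could look. It is an illustration only, not the code from this PR: the `PoolingSketch` module, the `pool/2` function, the `:cls_token_pooling` option name, and the error wording are all assumptions, and the Nx version in `Mix.install` is a guess at a compatible release.

```elixir
# Assumes Nx is available; the version requirement is an assumption.
Mix.install([{:nx, "~> 0.7"}])

defmodule PoolingSketch do
  # Hypothetical helper for illustration, not Bumblebee's implementation.
  # `output` is expected to have shape {batch, tokens, hidden}.
  def pool(output, output_pool) do
    # Single rank check shared by every pooling type, placed before the case.
    if output_pool != nil and Nx.rank(output) != 3 do
      raise ArgumentError,
            "expected the output tensor to have rank 3 to apply pooling, " <>
              "got: #{Nx.rank(output)}"
    end

    case output_pool do
      # No pooling requested, return the output unchanged.
      nil ->
        output

      # CLS pooling: keep only the first token of every sequence,
      # turning {batch, tokens, hidden} into {batch, hidden}.
      :cls_token_pooling ->
        output
        |> Nx.slice_along_axis(0, 1, axis: 1)
        |> Nx.squeeze(axes: [1])
    end
  end
end

# Two sequences, four tokens each, hidden size three.
output = Nx.iota({2, 4, 3}, type: :f32)
pooled = PoolingSketch.pool(output, :cls_token_pooling)
IO.inspect(Nx.shape(pooled), label: "pooled shape")
```

With the check hoisted out of the `case`, each pooling clause only has to describe its own reduction, and a rank 2 input fails the same way regardless of which pooling type was requested.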
@jonatanklosko thank you for the comments. I changed the code to reflect them. Let me know if this looks good!
Thanks!
This PR adds support for CLS token pooling for models like BGE-M3.
Below are the results from the Python implementation (I am using only the "dense" output, not the sparse one).
I believe the small differences come from differences in floating-point arithmetic between the Python and Elixir implementations.