Skip to content

Commit

Permalink
Merge branch 'main' into jk-binwrite
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Aug 8, 2024
2 parents dc647bb + 957ee59 commit 998182f
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 19 deletions.
19 changes: 19 additions & 0 deletions lib/bumblebee/application.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
defmodule Bumblebee.Application do
  @moduledoc false

  use Application

  alias Bumblebee.Utils.HTTP

  # Starts the dedicated httpc profile and then an (initially empty)
  # top-level supervisor registered as Bumblebee.Supervisor.
  @impl true
  def start(_type, _args) do
    HTTP.start_inets_profile()

    Supervisor.start_link([], strategy: :one_for_one, name: Bumblebee.Supervisor)
  end

  # Tears down the httpc profile that was started in start/2.
  @impl true
  def stop(_state), do: HTTP.stop_inets_profile()
end
9 changes: 8 additions & 1 deletion lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,14 @@ defmodule Bumblebee.Text do
this option is ignored. Defaults to `:pooled_state`
* `:output_pool` - pooling to apply on top of the model output, in case
it is not already a pooled embedding. Supported values: `:mean_pooling`.
it is not already a pooled embedding. Supported values:
* `:mean_pooling` - performs a mean across all tokens
* `:cls_token_pooling` - takes the embedding for the special CLS token.
Note that we currently assume that the CLS token is the first token
in the sequence
By default no pooling is applied
* `:embedding_processor` - a post-processing step to apply to the
Expand Down
21 changes: 10 additions & 11 deletions lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,21 @@ defmodule Bumblebee.Text.TextEmbedding do
output
end

if output_pool != nil and Nx.rank(output) != 3 do
raise ArgumentError,
"expected the output tensor to have rank 3 to apply :output_pool, got: #{Nx.rank(output)}." <>
" You should either disable pooling or pick a different output using :output_attribute"
end

output =
case output_pool do
nil ->
output

:mean_pooling ->
case Nx.rank(output) do
3 ->
:ok

rank ->
raise ArgumentError,
"expected the output tensor to have rank 3 to apply :output_pool, got: #{rank}." <>
" You should either disable pooling or pick a different output using :output_attribute"
end
:cls_token_pooling ->
Nx.take(output, 0, axis: 1)

:mean_pooling ->
input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)

output
Expand All @@ -81,7 +80,7 @@ defmodule Bumblebee.Text.TextEmbedding do

other ->
raise ArgumentError,
"expected :output_pool to be one of nil or :mean_pooling, got: #{inspect(other)}"
"expected :output_pool to be one of :cls_token_pooling, :mean_pooling or nil, got: #{inspect(other)}"
end

output =
Expand Down
56 changes: 51 additions & 5 deletions lib/bumblebee/utils/http.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ defmodule Bumblebee.Utils.HTTP do
if Process.alive?(caller) do
send(caller, {:http, reply_info})
else
:httpc.cancel_request(request_id)
:httpc.cancel_request(request_id, :bumblebee)
end
end

opts = [stream: :self, sync: false, receiver: receiver]

{:ok, request_id} = :httpc.request(:get, request, http_opts, opts)
{:ok, request_id} = :httpc.request(:get, request, http_opts, opts, :bumblebee)
download_loop(%{request_id: request_id, file: file, total_size: nil, size: nil})
after
File.close(file)
Expand Down Expand Up @@ -97,7 +97,7 @@ defmodule Bumblebee.Utils.HTTP do
download_loop(state)

{:error, error} ->
:httpc.cancel_request(state.request_id)
:httpc.cancel_request(state.request_id, :bumblebee)
{:error, error}
end
end
Expand Down Expand Up @@ -161,7 +161,7 @@ defmodule Bumblebee.Utils.HTTP do
body_format: :binary
]

case :httpc.request(method, request, http_opts, opts) do
case :httpc.request(method, request, http_opts, opts, :bumblebee) do
{:ok, {{_, status, _}, headers, body}} ->
{:ok, %{status: status, headers: parse_headers(headers), body: body}}

Expand All @@ -186,13 +186,59 @@ defmodule Bumblebee.Utils.HTTP do
end

defp http_ssl_opts() do
  # Allow user-specified CA certs to support, for example, HTTPS proxies.
  # When BUMBLEBEE_CACERTS_PATH is not set, we fall back to the
  # certificates returned by :public_key.cacerts_get/0
  cacert_opt =
    case System.get_env("BUMBLEBEE_CACERTS_PATH") do
      nil -> {:cacerts, :public_key.cacerts_get()}
      file -> {:cacertfile, file}
    end

  # Use secure options, see https://gist.github.com/jonatanklosko/5e20ca84127f6b31bbe3906498e1a1d7
  #
  # cacert_opt must be the only trust-store option in this list. Adding
  # another :cacertfile entry would conflict with (or shadow) the
  # user-specified file above
  [
    cacert_opt,
    verify: :verify_peer,
    customize_hostname_check: [
      match_fun: :public_key.pkix_verify_hostname_match_fun(:https)
    ]
  ]
end

@doc false
def start_inets_profile() do
  # httpc options (proxy settings in particular) are stored per profile,
  # so we start a profile named :bumblebee and configure only that one.
  # The match asserts that the profile was started successfully
  {:ok, _pid} = :inets.start(:httpc, profile: :bumblebee)

  set_proxy_options()
end

@doc false
# Stops the httpc profile named :bumblebee
def stop_inets_profile(), do: :inets.stop(:httpc, :bumblebee)

# Reads the proxy-related environment variables and applies them to the
# :bumblebee httpc profile via set_proxy_option/3.
defp set_proxy_options() do
  # Both spellings are consulted, the upper-case one takes precedence
  http_proxy = System.get_env("HTTP_PROXY") || System.get_env("http_proxy")
  https_proxy = System.get_env("HTTPS_PROXY") || System.get_env("https_proxy")

  # NO_PROXY is a comma-separated list that is often written with spaces
  # after the commas or with a trailing comma. We trim every entry and
  # drop blank ones, so that we never pass " host" or an empty host
  # to httpc. The entries are converted to charlists for the Erlang API
  no_proxy =
    (System.get_env("NO_PROXY") || System.get_env("no_proxy") || "")
    |> String.split(",")
    |> Enum.map(&String.trim/1)
    |> Enum.reject(&(&1 == ""))
    |> Enum.map(&String.to_charlist/1)

  set_proxy_option(:proxy, http_proxy, no_proxy)
  set_proxy_option(:https_proxy, https_proxy, no_proxy)
end

# Sets a single proxy option (`:proxy` or `:https_proxy`) on the
# :bumblebee httpc profile. Does nothing and returns nil when the
# proxy is not given or has no host or port after parsing.
defp set_proxy_option(proxy_scheme, proxy, no_proxy) do
  case URI.parse(proxy || "") do
    %URI{host: host, port: port} when is_binary(host) and is_integer(port) ->
      proxy_address = {String.to_charlist(host), port}
      :httpc.set_options([{proxy_scheme, {proxy_address, no_proxy}}], :bumblebee)

    _other ->
      nil
  end
end
end
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ defmodule Bumblebee.MixProject do

def application do
  # Bumblebee.Application is the application callback module and it
  # receives an empty argument list
  callback = {Bumblebee.Application, []}
  extra_apps = [:logger, :inets, :ssl]

  [mod: callback, extra_applications: extra_apps]
end
Expand All @@ -42,7 +43,6 @@ defmodule Bumblebee.MixProject do
{:nx_image, "~> 0.1.0"},
{:unpickler, "~> 0.1.0"},
{:safetensors, "~> 0.1.3"},
{:castore, "~> 0.1 or ~> 1.0"},
{:jason, "~> 1.4.0"},
{:unzip, "~> 0.10.0"},
{:progress_bar, "~> 3.0"},
Expand Down
2 changes: 1 addition & 1 deletion mix.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
%{
"axon": {:git, "https://github.com/elixir-nx/axon.git", "054eb4c1c224582528e8d1603ad08e7c4088f21c", []},
"bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
"castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
"castore": {:hex, :castore, "1.0.8", "dedcf20ea746694647f883590b82d9e96014057aff1d44d03ec90f36a5c0dc6e", [:mix], [], "hexpm", "0b2b66d2ee742cb1d9cb8c8be3b43c3a70ee8651f37b75a8b982e036752983f1"},
"cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"cowboy": {:hex, :cowboy, "2.9.0", "865dd8b6607e14cf03282e10e934023a1bd8be6f6bacf921a7e2a96d800cd452", [:make, :rebar3], [{:cowlib, "2.11.0", [hex: :cowlib, repo: "hexpm", optional: false]}, {:ranch, "1.8.0", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "2c729f934b4e1aa149aff882f57c6372c15399a20d54f65c8d67bef583021bde"},
Expand Down

0 comments on commit 998182f

Please sign in to comment.