Skip to content

Commit

Permalink
Add Swin model (#394)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonatan Kłosko <[email protected]>
  • Loading branch information
bosko and jonatanklosko authored Sep 3, 2024
1 parent 9421eca commit 45a265b
Show file tree
Hide file tree
Showing 5 changed files with 804 additions and 3 deletions.
2 changes: 2 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ defmodule Bumblebee do
"RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
"RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
"RobertaModel" => {Bumblebee.Text.Roberta, :base},
"SwinModel" => {Bumblebee.Vision.Swin, :base},
"SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification},
"T5Model" => {Bumblebee.Text.T5, :base},
"T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation},
"T5EncoderModel" => {Bumblebee.Text.T5, :encoder},
Expand Down
12 changes: 9 additions & 3 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ defmodule Bumblebee.Layers do
* `query` - `{batch_size, sequence_length, num_heads, head_size}`
* `key` - `{batch_size, kv_sequence_length, num_heads, head_size}`
* `value` - `{batch_size, kv_sequence_length, num_heads, head_size}`
* `key_mask` (optional) - `{batch_size, kv_sequence_length}`
* `key_mask` (optional) - `{batch_size, kv_sequence_length} | {batch_size, num_heads, sequence_length, kv_sequence_length}`
* `head_mask` (optional) - `{num_heads}`
* `bias` (optional) - `{batch_size | 1, num_heads | 1, sequence_length, kv_sequence_length}`
* `offset` (optional) - `{}`
Expand Down Expand Up @@ -273,8 +273,14 @@ defmodule Bumblebee.Layers do

key_mask =
case key_mask do
%Axon.None{} -> Nx.broadcast(1, {1, 1, 1, 1})
key_mask -> key_mask |> Nx.new_axis(1) |> Nx.new_axis(1)
%Axon.None{} ->
Nx.broadcast(1, {1, 1, 1, 1})

key_mask ->
case Nx.rank(key_mask) do
2 -> key_mask |> Nx.new_axis(1) |> Nx.new_axis(1)
4 -> key_mask
end
end

query_sequence_length = Nx.axis_size(query, 2)
Expand Down
76 changes: 76 additions & 0 deletions lib/bumblebee/utils/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,82 @@ defmodule Bumblebee.Utils.Nx do
Nx.take(tensor, flat_idx, axis: opts[:axis])
end

@doc """
Shifts elements along the specified axes.
When an shift is positive, the elements are shifted clockwise.
Negative shifts result in counter-clockwise shift.
## Options
* `:shifts` - the shift size to apply to the corresponding axis
from `:axes`
* `:axes` - the axes to apply shift to, must have the same length
as `:shifts`
## Examples
iex> x = Nx.iota({3, 3})
iex> Bumblebee.Utils.Nx.roll(x, shifts: [1], axes: [0])
#Nx.Tensor<
s64[3][3]
[
[6, 7, 8],
[0, 1, 2],
[3, 4, 5]
]
>
iex> x = Nx.iota({3, 3})
iex> Bumblebee.Utils.Nx.roll(x, shifts: [-1], axes: [0])
#Nx.Tensor<
s64[3][3]
[
[3, 4, 5],
[6, 7, 8],
[0, 1, 2]
]
>
iex> x = Nx.iota({3, 3})
iex> Bumblebee.Utils.Nx.roll(x, shifts: [1, 2], axes: [0, 1])
#Nx.Tensor<
s64[3][3]
[
[7, 8, 6],
[1, 2, 0],
[4, 5, 3]
]
>
"""
deftransform roll(tensor, opts) do
opts = Keyword.validate!(opts, shifts: [], axes: [])

shifts = opts[:shifts]
axes = opts[:axes]

if length(shifts) != length(axes) do
raise ArgumentError,
"expected shifts and axes to have the same number of elements," <>
" got shifts: #{inspect(shifts)}, axes: #{inspect(axes)}"
end

shifts
|> Enum.zip(axes)
|> Enum.reduce(tensor, fn {shift, axis}, tensor ->
shift = rem(shift, Nx.axis_size(tensor, axis))

if shift == 0 do
tensor
else
{left, right} = Nx.split(tensor, -shift, axis: axis)
Nx.concatenate([right, left], axis: axis)
end
end)
end

@doc """
Returns size of the given `Nx.Batch`, including padding.
"""
Expand Down
Loading

0 comments on commit 45a265b

Please sign in to comment.