Skip to content

Commit

Permalink
Support Llama 3 (#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Aug 6, 2024
1 parent 5245a7e commit 88cef27
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 14 deletions.
82 changes: 70 additions & 12 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1234,18 +1234,18 @@ defmodule Bumblebee.Layers do

case scaling_strategy do
%{type: :linear, factor: factor} ->
frequency = Nx.pow(base, range)
inv_frequency = inv_frequency(base, range)
position = Nx.divide(position, factor)
positions_cos_sin(position, frequency)
positions_cos_sin(position, inv_frequency)

%{type: :dynamic, factor: factor} when sequence_length > max_positions ->
base =
base
|> Nx.multiply(factor * sequence_length / max_positions - (factor - 1))
|> Nx.pow(size / (size - 2))

frequency = Nx.pow(base, range)
positions_cos_sin(position, frequency)
inv_frequency = inv_frequency(base, range)
positions_cos_sin(position, inv_frequency)

%{
type: type,
Expand Down Expand Up @@ -1278,22 +1278,80 @@ defmodule Bumblebee.Layers do
|> Nx.add(1.0)
end

frequency = Nx.multiply(factor, Nx.pow(base, range))
{cos, sin} = positions_cos_sin(position, frequency)
inv_frequency = inv_frequency(base, range) |> Nx.divide(factor)
{cos, sin} = positions_cos_sin(position, inv_frequency)
{Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)}

%{
type: :llama3,
factor: factor,
low_frequency_factor: low_frequency_factor,
high_frequency_factor: high_frequency_factor,
original_max_positions: original_max_positions
} ->
inv_frequency = inv_frequency(base, range)

inv_frequency =
llama3_inv_frequency(
inv_frequency,
factor,
low_frequency_factor,
high_frequency_factor,
original_max_positions
)

positions_cos_sin(position, inv_frequency)

_other ->
frequency = Nx.pow(base, range)
positions_cos_sin(position, frequency)
inv_frequency = inv_frequency(base, range)
positions_cos_sin(position, inv_frequency)
end
end

defnp positions_cos_sin(position, frequency) do
inv_frequency = 1.0 / frequency
angle = Nx.outer(position, inv_frequency)
# Rescales rotary inverse frequencies according to the Llama 3 strategy.
#
# Every entry of `inv_frequency` is handled based on its wavelength,
# computed as `2 * pi / inv_frequency`:
#
#   * wavelength below `original_max_positions / high_frequency_factor` -
#     the entry is kept as is
#
#   * wavelength above `original_max_positions / low_frequency_factor` -
#     the entry is divided by `factor`
#
#   * otherwise - the entry is a blend of the two cases above, weighted
#     by `smooth_factor`
#
# Returns the devectorized tensor of rescaled inverse frequencies.
defnp llama3_inv_frequency(
        inv_frequency,
        factor,
        low_frequency_factor,
        high_frequency_factor,
        original_max_positions
      ) do
  low_frequency_wavelength = original_max_positions / low_frequency_factor
  high_frequency_wavelength = original_max_positions / high_frequency_factor

  # Vectorize to enable cleaner conditional, the branches below are
  # then written as if operating on a single entry
  inv_frequency = Nx.vectorize(inv_frequency, :range)

  wavelength = 2 * Nx.Constants.pi() / inv_frequency

  inv_frequency =
    cond do
      wavelength < high_frequency_wavelength ->
        inv_frequency

      wavelength > low_frequency_wavelength ->
        inv_frequency / factor

      true ->
        # Interpolation between the two cases above

        smooth_factor =
          (original_max_positions / wavelength - low_frequency_factor) /
            (high_frequency_factor - low_frequency_factor)

        (1 - smooth_factor) * inv_frequency / factor + smooth_factor * inv_frequency
    end

  Nx.devectorize(inv_frequency)
end

# Computes the inverse rotary frequencies, that is, the element-wise
# reciprocal of `base` raised to the power of `range`.
defnp inv_frequency(base, range) do
  1.0 / Nx.pow(base, range)
end

# Builds the `{cos, sin}` pair for the given positions.
#
# The angles are the outer product of `position` and `inv_frequency`,
# repeated twice along the last axis before taking cosine and sine.
defnp positions_cos_sin(position, inv_frequency) do
  half_angle = Nx.outer(position, inv_frequency)
  full_angle = Nx.concatenate([half_angle, half_angle], axis: -1)

  cos = Nx.cos(full_angle)
  sin = Nx.sin(full_angle)

  {cos, sin}
end

Expand Down
32 changes: 30 additions & 2 deletions lib/bumblebee/text/llama.ex
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ defmodule Bumblebee.Text.Llama do
* `%{type: :dynamic, factor: number()}`
* `%{type: :llama3, factor: number(), low_frequency_factor: number(), high_frequency_factor: number(), original_max_positions: pos_integer()}`
For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
"""
],
Expand Down Expand Up @@ -383,13 +385,39 @@ defmodule Bumblebee.Text.Llama do
import Shared.Converters

scaling_strategy_converter = fn name, value ->
# "type" has been renamed to "rope_type"
value =
case Map.pop(value, "type") do
{nil, value} -> value
{type, value} -> Map.put(value, "rope_type", type)
end

case value do
%{"type" => "linear", "factor" => factor} when is_number(factor) ->
%{"rope_type" => "linear", "factor" => factor} when is_number(factor) ->
{:ok, %{type: :linear, factor: factor}}

%{"type" => "dynamic", "factor" => factor} when is_number(factor) ->
%{"rope_type" => "dynamic", "factor" => factor} when is_number(factor) ->
{:ok, %{type: :dynamic, factor: factor}}

%{
"rope_type" => "llama3",
"factor" => factor,
"low_freq_factor" => low_frequency_factor,
"high_freq_factor" => high_frequency_factor,
"original_max_position_embeddings" => original_max_positions
}
when is_number(factor) and is_number(low_frequency_factor) and
is_number(high_frequency_factor) and
is_number(original_max_positions) ->
{:ok,
%{
type: :llama3,
factor: factor,
low_frequency_factor: low_frequency_factor,
high_frequency_factor: high_frequency_factor,
original_max_positions: original_max_positions
}}

_other ->
{:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"}
end
Expand Down
26 changes: 26 additions & 0 deletions test/bumblebee/text/llama_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,32 @@ defmodule Bumblebee.Text.LlamaTest do
)
end

test ":base rotary embedding scaling strategy :llama3" do
  repository =
    {:hf,
     "bumblebee-testing/tiny-random-LlamaModel-rope_scaling-llama3-original_max_position_embeddings-64"}

  assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repository)

  assert %Bumblebee.Text.Llama{architecture: :base} = spec

  # Ten tokens, the trailing two have attention mask set to 0
  input_ids = Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]])
  attention_mask = Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
  inputs = %{"input_ids" => input_ids, "attention_mask" => attention_mask}

  outputs = Axon.predict(model, params, inputs)
  hidden_state = outputs.hidden_state

  assert Nx.shape(hidden_state) == {1, 10, 32}

  # Reference values for a small slice of the hidden state
  expected_slice =
    Nx.tensor([
      [[1.4802, -2.0331, 0.4759], [2.3749, -0.8367, -0.0205], [0.5762, -0.0517, -1.1795]]
    ])

  assert_all_close(hidden_state[[.., 1..3, 1..3]], expected_slice)
end

test ":for_sequence_classification" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
Expand Down

0 comments on commit 88cef27

Please sign in to comment.