Support Llama 3 #387

Merged 1 commit on Aug 6, 2024
82 changes: 70 additions & 12 deletions lib/bumblebee/layers.ex
@@ -1234,18 +1234,18 @@ defmodule Bumblebee.Layers do

     case scaling_strategy do
       %{type: :linear, factor: factor} ->
-        frequency = Nx.pow(base, range)
+        inv_frequency = inv_frequency(base, range)
         position = Nx.divide(position, factor)
-        positions_cos_sin(position, frequency)
+        positions_cos_sin(position, inv_frequency)

       %{type: :dynamic, factor: factor} when sequence_length > max_positions ->
         base =
           base
           |> Nx.multiply(factor * sequence_length / max_positions - (factor - 1))
           |> Nx.pow(size / (size - 2))

-        frequency = Nx.pow(base, range)
-        positions_cos_sin(position, frequency)
+        inv_frequency = inv_frequency(base, range)
+        positions_cos_sin(position, inv_frequency)

       %{
         type: type,
@@ -1278,22 +1278,80 @@ defmodule Bumblebee.Layers do
             |> Nx.add(1.0)
           end

-        frequency = Nx.multiply(factor, Nx.pow(base, range))
-        {cos, sin} = positions_cos_sin(position, frequency)
+        inv_frequency = inv_frequency(base, range) |> Nx.divide(factor)
+        {cos, sin} = positions_cos_sin(position, inv_frequency)
         {Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)}

+      %{
+        type: :llama3,
+        factor: factor,
+        low_frequency_factor: low_frequency_factor,
+        high_frequency_factor: high_frequency_factor,
+        original_max_positions: original_max_positions
+      } ->
+        inv_frequency = inv_frequency(base, range)
+
+        inv_frequency =
+          llama3_inv_frequency(
+            inv_frequency,
+            factor,
+            low_frequency_factor,
+            high_frequency_factor,
+            original_max_positions
+          )
+
+        positions_cos_sin(position, inv_frequency)
+
       _other ->
-        frequency = Nx.pow(base, range)
-        positions_cos_sin(position, frequency)
+        inv_frequency = inv_frequency(base, range)
+        positions_cos_sin(position, inv_frequency)
     end
   end

-  defnp positions_cos_sin(position, frequency) do
-    inv_frequency = 1.0 / frequency
-    angle = Nx.outer(position, inv_frequency)
-    angle = Nx.concatenate([angle, angle], axis: -1)
-    {Nx.cos(angle), Nx.sin(angle)}
-  end
+  defnp llama3_inv_frequency(
+          inv_frequency,
+          factor,
+          low_frequency_factor,
+          high_frequency_factor,
+          original_max_positions
+        ) do
+    low_frequency_wavelength = original_max_positions / low_frequency_factor
+    high_frequency_wavelength = original_max_positions / high_frequency_factor
+
+    # Vectorize to enable cleaner conditional
+    inv_frequency = Nx.vectorize(inv_frequency, :range)
+
+    wavelength = 2 * Nx.Constants.pi() / inv_frequency
+
+    inv_frequency =
+      cond do
+        wavelength < high_frequency_wavelength ->
+          inv_frequency
+
+        wavelength > low_frequency_wavelength ->
+          inv_frequency / factor
+
+        true ->
+          # Interpolation between the two cases above
+
+          smooth_factor =
+            (original_max_positions / wavelength - low_frequency_factor) /
+              (high_frequency_factor - low_frequency_factor)
+
+          (1 - smooth_factor) * inv_frequency / factor + smooth_factor * inv_frequency
+      end
+
+    Nx.devectorize(inv_frequency)
+  end
+
+  defnp inv_frequency(base, range) do
+    frequency = Nx.pow(base, range)
+    1.0 / frequency
+  end
+
+  defnp positions_cos_sin(position, inv_frequency) do
+    angle = Nx.outer(position, inv_frequency)
+    angle = Nx.concatenate([angle, angle], axis: -1)
+    {Nx.cos(angle), Nx.sin(angle)}
+  end

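The core of this change is the wavelength-dependent scaling applied in llama3_inv_frequency/5: inverse frequencies with a short wavelength relative to the original context window are left as-is, long-wavelength ones are divided by the scaling factor, and the band in between is interpolated smoothly (the defn version vectorizes the frequencies so the cond can be written per element). Below is a minimal scalar sketch of that rule in plain Elixir; it is not part of the PR, and the module name, function name, and example values are illustrative only.

defmodule Llama3RopeSketch do
  # Scalar version of the wavelength-dependent scaling in llama3_inv_frequency/5.
  # Plain floats instead of Nx tensors; hypothetical helper for illustration.
  def scale_inv_frequency(inv_frequency, factor, low_freq_factor, high_freq_factor, original_max_positions) do
    low_frequency_wavelength = original_max_positions / low_freq_factor
    high_frequency_wavelength = original_max_positions / high_freq_factor
    wavelength = 2 * :math.pi() / inv_frequency

    cond do
      # Short wavelengths (high frequencies) fit well within the original
      # context window, so they are left untouched
      wavelength < high_frequency_wavelength ->
        inv_frequency

      # Long wavelengths (low frequencies) are stretched by the factor
      wavelength > low_frequency_wavelength ->
        inv_frequency / factor

      # In between, interpolate smoothly between the two regimes
      true ->
        smooth_factor =
          (original_max_positions / wavelength - low_freq_factor) /
            (high_freq_factor - low_freq_factor)

        (1 - smooth_factor) * inv_frequency / factor + smooth_factor * inv_frequency
    end
  end
end

# Example values in the style of Llama 3.1 checkpoints (factor 8.0,
# low_freq_factor 1.0, high_freq_factor 4.0, original max positions 8192);
# treat them as illustrative, not as read from a specific model.
Llama3RopeSketch.scale_inv_frequency(1.0e-4, 8.0, 1.0, 4.0, 8192)
#=> 1.25e-5 (long wavelength, divided by the factor)
Llama3RopeSketch.scale_inv_frequency(1.0, 8.0, 1.0, 4.0, 8192)
#=> 1.0 (short wavelength, unchanged)
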
32 changes: 30 additions & 2 deletions lib/bumblebee/text/llama.ex
@@ -55,6 +55,8 @@ defmodule Bumblebee.Text.Llama do

         * `%{type: :dynamic, factor: number()}`

+        * `%{type: :llama3, factor: number(), low_frequency_factor: number(), high_frequency_factor: number(), original_max_positions: pos_integer()}`
+
         For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
         """
       ],
@@ -383,13 +385,39 @@ defmodule Bumblebee.Text.Llama do
       import Shared.Converters

       scaling_strategy_converter = fn name, value ->
+        # "type" has been renamed to "rope_type"
+        value =
+          case Map.pop(value, "type") do
+            {nil, value} -> value
+            {type, value} -> Map.put(value, "rope_type", type)
+          end
+
         case value do
-          %{"type" => "linear", "factor" => factor} when is_number(factor) ->
+          %{"rope_type" => "linear", "factor" => factor} when is_number(factor) ->
             {:ok, %{type: :linear, factor: factor}}

-          %{"type" => "dynamic", "factor" => factor} when is_number(factor) ->
+          %{"rope_type" => "dynamic", "factor" => factor} when is_number(factor) ->
             {:ok, %{type: :dynamic, factor: factor}}

+          %{
+            "rope_type" => "llama3",
+            "factor" => factor,
+            "low_freq_factor" => low_frequency_factor,
+            "high_freq_factor" => high_frequency_factor,
+            "original_max_position_embeddings" => original_max_positions
+          }
+          when is_number(factor) and is_number(low_frequency_factor) and
+                 is_number(high_frequency_factor) and
+                 is_number(original_max_positions) ->
+            {:ok,
+             %{
+               type: :llama3,
+               factor: factor,
+               low_frequency_factor: low_frequency_factor,
+               high_frequency_factor: high_frequency_factor,
+               original_max_positions: original_max_positions
+             }}
+
           _other ->
             {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"}
         end
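
The converter above also accepts the older "type" key by renaming it to "rope_type" before matching. As a sketch of the mapping it performs for a Llama 3 style checkpoint, the Hugging Face rope_scaling entry shown first is parsed into the spec value shown after it; the concrete numbers are illustrative, not read from any particular model.

# Shape of the `rope_scaling` entry in a Llama 3 style config.json
# (illustrative values, not taken from a specific checkpoint):
%{
  "rope_type" => "llama3",
  "factor" => 8.0,
  "low_freq_factor" => 1.0,
  "high_freq_factor" => 4.0,
  "original_max_position_embeddings" => 8192
}

# What the converter turns it into, i.e. the value stored as the spec's
# rotary embedding scaling strategy option:
%{
  type: :llama3,
  factor: 8.0,
  low_frequency_factor: 1.0,
  high_frequency_factor: 4.0,
  original_max_positions: 8192
}
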
26 changes: 26 additions & 0 deletions test/bumblebee/text/llama_test.exs
@@ -28,6 +28,32 @@ defmodule Bumblebee.Text.LlamaTest do
     )
   end

+  test ":base rotary embedding scaling strategy :llama3" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf,
+                "bumblebee-testing/tiny-random-LlamaModel-rope_scaling-llama3-original_max_position_embeddings-64"}
+             )
+
+    assert %Bumblebee.Text.Llama{architecture: :base} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
+
+    assert_all_close(
+      outputs.hidden_state[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [[1.4802, -2.0331, 0.4759], [2.3749, -0.8367, -0.0205], [0.5762, -0.0517, -1.1795]]
+      ])
+    )
+  end
+
   test ":for_sequence_classification" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
              Bumblebee.load_model(
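
With the converter and scaling strategy in place, checkpoints whose config carries a "llama3" rope_scaling entry should load like any other Llama model. Below is a usage sketch with the standard Bumblebee serving calls; the repository name, the token handling, and the prompt are assumptions for illustration, not part of this PR.

# Llama 3 checkpoints on Hugging Face are gated, hence the auth token; the
# repository name below is an assumption, adjust it to the checkpoint you use.
repo =
  {:hf, "meta-llama/Meta-Llama-3.1-8B-Instruct",
   auth_token: System.fetch_env!("HF_TOKEN")}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

# Build a text generation serving and run a single prompt through it.
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Explain rotary position embeddings in one sentence.")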