From cc7df8587f85148923150c66bb2d66d73e9a5814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 6 Aug 2024 21:13:46 +0900 Subject: [PATCH] Support Llama 3 --- lib/bumblebee/layers.ex | 82 +++++++++++++++++++++++++----- lib/bumblebee/text/llama.ex | 32 +++++++++++- test/bumblebee/text/llama_test.exs | 26 ++++++++++ 3 files changed, 126 insertions(+), 14 deletions(-) diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 59c9ce9c..cde12666 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1234,9 +1234,9 @@ defmodule Bumblebee.Layers do case scaling_strategy do %{type: :linear, factor: factor} -> - frequency = Nx.pow(base, range) + inv_frequency = inv_frequency(base, range) position = Nx.divide(position, factor) - positions_cos_sin(position, frequency) + positions_cos_sin(position, inv_frequency) %{type: :dynamic, factor: factor} when sequence_length > max_positions -> base = @@ -1244,8 +1244,8 @@ defmodule Bumblebee.Layers do |> Nx.multiply(factor * sequence_length / max_positions - (factor - 1)) |> Nx.pow(size / (size - 2)) - frequency = Nx.pow(base, range) - positions_cos_sin(position, frequency) + inv_frequency = inv_frequency(base, range) + positions_cos_sin(position, inv_frequency) %{ type: type, @@ -1278,22 +1278,80 @@ defmodule Bumblebee.Layers do |> Nx.add(1.0) end - frequency = Nx.multiply(factor, Nx.pow(base, range)) - {cos, sin} = positions_cos_sin(position, frequency) + inv_frequency = inv_frequency(base, range) |> Nx.divide(factor) + {cos, sin} = positions_cos_sin(position, inv_frequency) {Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)} + %{ + type: :llama3, + factor: factor, + low_frequency_factor: low_frequency_factor, + high_frequency_factor: high_frequency_factor, + original_max_positions: original_max_positions + } -> + inv_frequency = inv_frequency(base, range) + + inv_frequency = + llama3_inv_frequency( + inv_frequency, + factor, + low_frequency_factor, + high_frequency_factor, + original_max_positions + ) + + positions_cos_sin(position, inv_frequency) + _other -> - frequency = Nx.pow(base, range) - positions_cos_sin(position, frequency) + inv_frequency = inv_frequency(base, range) + positions_cos_sin(position, inv_frequency) end end - defnp positions_cos_sin(position, frequency) do - inv_frequency = 1.0 / frequency - angle = Nx.outer(position, inv_frequency) + defnp llama3_inv_frequency( + inv_frequency, + factor, + low_frequency_factor, + high_frequency_factor, + original_max_positions + ) do + low_frequency_wavelength = original_max_positions / low_frequency_factor + high_frequency_wavelength = original_max_positions / high_frequency_factor - angle = Nx.concatenate([angle, angle], axis: -1) + # Vectorize to enable cleaner conditional + inv_frequency = Nx.vectorize(inv_frequency, :range) + + wavelength = 2 * Nx.Constants.pi() / inv_frequency + + inv_frequency = + cond do + wavelength < high_frequency_wavelength -> + inv_frequency + + wavelength > low_frequency_wavelength -> + inv_frequency / factor + true -> + # Interpolation between the two cases above + + smooth_factor = + (original_max_positions / wavelength - low_frequency_factor) / + (high_frequency_factor - low_frequency_factor) + + (1 - smooth_factor) * inv_frequency / factor + smooth_factor * inv_frequency + end + + Nx.devectorize(inv_frequency) + end + + defnp inv_frequency(base, range) do + frequency = Nx.pow(base, range) + 1.0 / frequency + end + + defnp positions_cos_sin(position, inv_frequency) do + angle = Nx.outer(position, inv_frequency) + angle = Nx.concatenate([angle, angle], axis: -1) {Nx.cos(angle), Nx.sin(angle)} end diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex index 5a909ddc..dac7b2ea 100644 --- a/lib/bumblebee/text/llama.ex +++ b/lib/bumblebee/text/llama.ex @@ -55,6 +55,8 @@ defmodule Bumblebee.Text.Llama do * `%{type: :dynamic, factor: number()}` + * `%{type: :llama3, factor: number(), low_frequency_factor: number(), high_frequency_factor: number(), original_max_positions: pos_integer()}` + For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases """ ], @@ -383,13 +385,39 @@ defmodule Bumblebee.Text.Llama do import Shared.Converters scaling_strategy_converter = fn name, value -> + # "type" has been renamed to "rope_type" + value = + case Map.pop(value, "type") do + {nil, value} -> value + {type, value} -> Map.put(value, "rope_type", type) + end + case value do - %{"type" => "linear", "factor" => factor} when is_number(factor) -> + %{"rope_type" => "linear", "factor" => factor} when is_number(factor) -> {:ok, %{type: :linear, factor: factor}} - %{"type" => "dynamic", "factor" => factor} when is_number(factor) -> + %{"rope_type" => "dynamic", "factor" => factor} when is_number(factor) -> {:ok, %{type: :dynamic, factor: factor}} + %{ + "rope_type" => "llama3", + "factor" => factor, + "low_freq_factor" => low_frequency_factor, + "high_freq_factor" => high_frequency_factor, + "original_max_position_embeddings" => original_max_positions + } + when is_number(factor) and is_number(low_frequency_factor) and + is_number(high_frequency_factor) and + is_number(original_max_positions) -> + {:ok, + %{ + type: :llama3, + factor: factor, + low_frequency_factor: low_frequency_factor, + high_frequency_factor: high_frequency_factor, + original_max_positions: original_max_positions + }} + _other -> {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"} end diff --git a/test/bumblebee/text/llama_test.exs b/test/bumblebee/text/llama_test.exs index 0e64833c..fb143c60 100644 --- a/test/bumblebee/text/llama_test.exs +++ b/test/bumblebee/text/llama_test.exs @@ -28,6 +28,32 @@ defmodule Bumblebee.Text.LlamaTest do ) end + test ":base rotary embedding scaling strategy :llama3" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, + "bumblebee-testing/tiny-random-LlamaModel-rope_scaling-llama3-original_max_position_embeddings-64"} + ) + + assert %Bumblebee.Text.Llama{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[1.4802, -2.0331, 0.4759], [2.3749, -0.8367, -0.0205], [0.5762, -0.0517, -1.1795]] + ]) + ) + end + test ":for_sequence_classification" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(