Skip to content

Commit

Permalink
Add support for NTK-by-Part Rotary Embedding & set correct rotary base for Llama-3.1 series (#764)
Browse files Browse the repository at this point in the history

* Add support for NTK-by-Part Rotary Embedding & set correct rotary base for Llama-3.1-8B

* Add support for NTK-by-Part Rotary Embedding & set correct rotary base for Llama-3.1 series

* Add support for NTK-by-Part Rotary Embedding & set correct rotary base for Llama-3.1 series

* Add support for NTK-by-Part Rotary Embedding & set correct rotary base for Llama-3.1 series

* fix import order

* fix black check

* fix rope settings also for 3.2 models

* fix rope settings also for llama-3 models

---------

Co-authored-by: Bryce Meyer <[email protected]>
  • Loading branch information
Hzfinfdu and bryce13950 authored Oct 26, 2024
1 parent 8029d13 commit c7837fb
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
16 changes: 16 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ class HookedTransformerConfig:
output_logits_soft_cap (float): An optional softcap for output logits, currently only used
in Gemma-2 (see attn_scores_soft_cap for details). Defaults to -1.0, which means not
set.
use_NTK_by_parts_rope (bool): Whether to apply the "NTK-by-parts" method when using Rotary
Positional Embedding. This method adjusts the interpolation based on frequency factors
for different parts of the hidden dimensions. See Section 3.2 in
https://arxiv.org/pdf/2309.00071 for details. Defaults to False.
NTK_by_parts_low_freq_factor (float): The threshold applied to low-frequency hidden
dimensions during interpolation when using the "NTK-by-parts" method. Defaults to 1.0.
NTK_by_parts_high_freq_factor (float): The threshold applied to high-frequency hidden
dimensions during interpolation in the "NTK-by-parts" method. Defaults to 4.0.
NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
affects the rate of change between low and high-frequency interpolation strategies.
Defaults to 8.0.
"""

Expand Down Expand Up @@ -246,6 +258,10 @@ class HookedTransformerConfig:
use_normalization_before_and_after: bool = False
attn_scores_soft_cap: float = -1.0
output_logits_soft_cap: float = -1.0
use_NTK_by_parts_rope: bool = False
NTK_by_parts_low_freq_factor: float = 1.0
NTK_by_parts_high_freq_factor: float = 4.0
NTK_by_parts_factor: float = 8.0

def __post_init__(self):
if self.n_heads == -1:
Expand Down
30 changes: 28 additions & 2 deletions transformer_lens/components/abstract_attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from abc import ABC
from typing import Dict, Optional, Tuple, Union

Expand Down Expand Up @@ -478,8 +479,33 @@ def calculate_sin_cos_rotary(
pos = torch.arange(n_ctx, dtype=high_precision)
dim = torch.arange(rotary_dim // 2, dtype=high_precision)

# A set of frequencies evenly spaced in log space
freq = base ** (dim / (rotary_dim / 2))
# Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
# Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
if self.cfg.use_NTK_by_parts_rope:
inv_freq = 1.0 / (
base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
)
factor = self.cfg.NTK_by_parts_factor
low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
old_context_len = n_ctx

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

wavelen = 2 * math.pi / inv_freq
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smoothed_inv_freq = (
1 - smooth_factor
) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
freq = 1 / inv_freq_llama
else:
freq = base ** (dim / (rotary_dim / 2))
if self.cfg.rotary_adjacent_pairs:
freq = einops.repeat(freq, "d -> (d 2)")
else:
Expand Down
22 changes: 22 additions & 0 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"rotary_dim": 128,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 500000.0,
}
elif "Meta-Llama-3-70B" in official_model_name:
cfg_dict = {
Expand All @@ -894,6 +895,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
"rotary_dim": 128,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 500000.0,
}
elif "Llama-3.2-1B" in official_model_name:
cfg_dict = {
Expand All @@ -913,6 +915,11 @@ def convert_hf_model_config(model_name: str, **kwargs):
"rotary_dim": 64,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 500000.0,
"use_NTK_by_parts_rope": True,
"NTK_by_parts_low_freq_factor": 1.0,
"NTK_by_parts_high_freq_factor": 4.0,
"NTK_by_parts_factor": 32.0,
}
elif "Llama-3.2-3B" in official_model_name:
cfg_dict = {
Expand All @@ -932,6 +939,11 @@ def convert_hf_model_config(model_name: str, **kwargs):
"rotary_dim": 128,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 500000.0,
"use_NTK_by_parts_rope": True,
"NTK_by_parts_low_freq_factor": 1.0,
"NTK_by_parts_high_freq_factor": 4.0,
"NTK_by_parts_factor": 32.0,
}
elif "Llama-3.1-8B" in official_model_name:
cfg_dict = {
Expand All @@ -951,6 +963,11 @@ def convert_hf_model_config(model_name: str, **kwargs):
"rotary_dim": 128,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 500000.0,
"use_NTK_by_parts_rope": True,
"NTK_by_parts_low_freq_factor": 1.0,
"NTK_by_parts_high_freq_factor": 4.0,
"NTK_by_parts_factor": 8.0,
}
elif "Llama-3.1-70B" in official_model_name:
cfg_dict = {
Expand All @@ -970,6 +987,11 @@ def convert_hf_model_config(model_name: str, **kwargs):
"rotary_dim": 128,
"final_rms": True,
"gated_mlp": True,
"rotary_base": 500000.0,
"use_NTK_by_parts_rope": True,
"NTK_by_parts_low_freq_factor": 1.0,
"NTK_by_parts_high_freq_factor": 4.0,
"NTK_by_parts_factor": 8.0,
}
elif architecture == "GPTNeoForCausalLM":
cfg_dict = {
Expand Down

0 comments on commit c7837fb

Please sign in to comment.