diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index e2fdc532e..4458705de 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -181,6 +181,18 @@ class HookedTransformerConfig:
         output_logits_soft_cap (float): An optional softcap for output logits, currently only used
             in Gemma-2 (see attn_scores_soft_cap for details). Defaults to -1.0, which means not
             set.
+        use_NTK_by_parts_rope (bool): Whether to apply the "NTK-by-parts" method when using Rotary
+            Positional Embedding. This method adjusts the interpolation based on frequency factors
+            for different parts of the hidden dimensions. See Section 3.2 in
+            https://arxiv.org/pdf/2309.00071 for details. Defaults to False.
+        NTK_by_parts_low_freq_factor (float): The threshold applied to low-frequency hidden
+            dimensions during interpolation when using the "NTK-by-parts" method. Defaults to 1.0.
+        NTK_by_parts_high_freq_factor (float): The threshold applied to high-frequency hidden
+            dimensions during interpolation in the "NTK-by-parts" method. Defaults to 4.0.
+        NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
+            affects the rate of change between low and high-frequency interpolation strategies.
+            Defaults to 8.0.
+
     """
@@ -246,6 +258,10 @@ class HookedTransformerConfig:
     use_normalization_before_and_after: bool = False
     attn_scores_soft_cap: float = -1.0
     output_logits_soft_cap: float = -1.0
+    use_NTK_by_parts_rope: bool = False
+    NTK_by_parts_low_freq_factor: float = 1.0
+    NTK_by_parts_high_freq_factor: float = 4.0
+    NTK_by_parts_factor: float = 8.0
 
     def __post_init__(self):
         if self.n_heads == -1:
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 3146de0c2..a2a831e9f 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -1,3 +1,4 @@
+import math
 from abc import ABC
 from typing import Dict, Optional, Tuple, Union
 
@@ -478,8 +479,33 @@ def calculate_sin_cos_rotary(
         pos = torch.arange(n_ctx, dtype=high_precision)
         dim = torch.arange(rotary_dim // 2, dtype=high_precision)
 
-        # A set of frequencies evenly spaced in log space
-        freq = base ** (dim / (rotary_dim / 2))
+        # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
+        # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
+        if self.cfg.use_NTK_by_parts_rope:
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
+            )
+            factor = self.cfg.NTK_by_parts_factor
+            low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
+            high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
+            old_context_len = n_ctx
+
+            low_freq_wavelen = old_context_len / low_freq_factor
+            high_freq_wavelen = old_context_len / high_freq_factor
+
+            wavelen = 2 * math.pi / inv_freq
+            inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+            smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            smoothed_inv_freq = (
+                1 - smooth_factor
+            ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+            is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+            inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+            freq = 1 / inv_freq_llama
+        else:
+            freq = base ** (dim / (rotary_dim / 2))
         if self.cfg.rotary_adjacent_pairs:
             freq = einops.repeat(freq, "d -> (d 2)")
         else:
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 0b8489976..49dffbf04 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -875,6 +875,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_dim": 128,
             "final_rms": True,
             "gated_mlp": True,
+            "rotary_base": 500000.0,
         }
     elif "Meta-Llama-3-70B" in official_model_name:
         cfg_dict = {
@@ -894,6 +895,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_dim": 128,
             "final_rms": True,
             "gated_mlp": True,
+            "rotary_base": 500000.0,
         }
     elif "Llama-3.2-1B" in official_model_name:
         cfg_dict = {
@@ -913,6 +915,11 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_dim": 64,
             "final_rms": True,
             "gated_mlp": True,
+            "rotary_base": 500000.0,
+            "use_NTK_by_parts_rope": True,
+            "NTK_by_parts_low_freq_factor": 1.0,
+            "NTK_by_parts_high_freq_factor": 4.0,
+            "NTK_by_parts_factor": 32.0,
         }
     elif "Llama-3.2-3B" in official_model_name:
         cfg_dict = {
@@ -932,6 +939,11 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_dim": 128,
             "final_rms": True,
             "gated_mlp": True,
+            "rotary_base": 500000.0,
+            "use_NTK_by_parts_rope": True,
+            "NTK_by_parts_low_freq_factor": 1.0,
+            "NTK_by_parts_high_freq_factor": 4.0,
+            "NTK_by_parts_factor": 32.0,
         }
     elif "Llama-3.1-8B" in official_model_name:
         cfg_dict = {
@@ -951,6 +963,11 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_dim": 128,
             "final_rms": True,
             "gated_mlp": True,
+            "rotary_base": 500000.0,
+            "use_NTK_by_parts_rope": True,
+            "NTK_by_parts_low_freq_factor": 1.0,
+            "NTK_by_parts_high_freq_factor": 4.0,
+            "NTK_by_parts_factor": 8.0,
         }
     elif "Llama-3.1-70B" in official_model_name:
         cfg_dict = {
@@ -970,6 +987,11 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_dim": 128,
             "final_rms": True,
             "gated_mlp": True,
+            "rotary_base": 500000.0,
+            "use_NTK_by_parts_rope": True,
+            "NTK_by_parts_low_freq_factor": 1.0,
+            "NTK_by_parts_high_freq_factor": 4.0,
+            "NTK_by_parts_factor": 8.0,
         }
     elif architecture == "GPTNeoForCausalLM":
         cfg_dict = {
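
For reference (not part of the patch): a minimal standalone sketch of the frequency adjustment that the new use_NTK_by_parts_rope branch in calculate_sin_cos_rotary computes, using the Llama-3.1-8B values wired up above (rotary_base 500000.0, rotary_dim 128, NTK_by_parts_factor 8.0, low/high frequency factors 1.0/4.0). The helper name ntk_by_parts_inv_freq and the original context length of 8192 are illustrative assumptions, not taken from the diff.

import math

import torch


def ntk_by_parts_inv_freq(
    rotary_dim: int = 128,
    base: float = 500000.0,
    old_context_len: int = 8192,  # assumed original context length, for illustration only
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
) -> torch.Tensor:
    # Standard RoPE inverse frequencies, one per pair of rotary dimensions.
    inv_freq = 1.0 / (
        base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
    )
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / inv_freq
    # Wavelengths beyond the low-frequency threshold are scaled down by the full factor;
    # wavelengths below the high-frequency threshold are left untouched.
    inv_freq_adj = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # Dimensions in between are blended smoothly between the two regimes.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq_adj / factor + smooth * inv_freq_adj
    is_medium = ~(wavelen < high_freq_wavelen) & ~(wavelen > low_freq_wavelen)
    return torch.where(is_medium, smoothed, inv_freq_adj)


# The patched branch then uses freq = 1 / inv_freq as the rotary frequencies.
print(1 / ntk_by_parts_inv_freq()[:4])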