From cb6ad8e47e47757bf0141625c95b74c4bd58634d Mon Sep 17 00:00:00 2001 From: Vatsal Rathod Date: Mon, 21 Oct 2024 20:17:06 -0400 Subject: [PATCH] Add configurations for Llama 3.1 models(Llama-3.1-8B and Llama-3.1-70B) (#761) * Add configurations for Llama 3.1 models(Llama-3.1-8B and Llama-3.1-70B) * formatted using black * added models to colab compatability notebook * added configurations for Llama-3.1-8B-Instruct and Llama-3.1-70B-Instruct * cleaned up config a bit * updated compatibility notebook --------- Co-authored-by: Bryce Meyer --- demos/Colab_Compatibility.ipynb | 20 +++++++++------- transformer_lens/loading_from_pretrained.py | 26 ++++++++++++--------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb index d173df4b3..62b5814d4 100644 --- a/demos/Colab_Compatibility.ipynb +++ b/demos/Colab_Compatibility.ipynb @@ -2,25 +2,23 @@ "cells": [ { "cell_type": "code", - "execution_count": 14, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running as a Jupyter notebook - intended for development only!\n", - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" + "Running as a Jupyter notebook - intended for development only!\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_28124/2396058561.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", " ipython.magic(\"load_ext autoreload\")\n", - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_28124/2396058561.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", " ipython.magic(\"autoreload 2\")\n" ] } @@ -60,14 +58,14 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "TransformerLens currently supports 186 models out of the box.\n" + "TransformerLens currently supports 190 models out of the box.\n" ] } ], @@ -91,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -323,6 +321,8 @@ " \"google/gemma-7b-it\",\n", " \"meta-llama/Llama-2-7b-chat-hf\",\n", " \"meta-llama/Llama-2-7b-hf\",\n", + " \"meta-llama/Llama-3.1-8B\",\n", + " \"meta-llama/Llama-3.1-8B-Instruct\",\n", " \"meta-llama/Llama-3.2-3B\",\n", " \"meta-llama/Llama-3.2-3B-Instruct\",\n", " \"meta-llama/Meta-Llama-3-8B\",\n", @@ -396,6 +396,8 @@ " \"google/gemma-2-27b\",\n", " \"google/gemma-2-27b-it\",\n", " \"meta-llama/Llama-2-70b-chat-hf\",\n", + " \"meta-llama/Llama-3.1-70B\",\n", + " \"meta-llama/Llama-3.1-70B-Instruct\",\n", " \"meta-llama/Meta-Llama-3-70B\",\n", " \"meta-llama/Meta-Llama-3-70B-Instruct\",\n", " \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n", diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index bf9077db8..0b8489976 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -155,6 +155,10 @@ "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", + "meta-llama/Llama-3.1-70B", + "meta-llama/Llama-3.1-8B", + "meta-llama/Llama-3.1-8B-Instruct", + "meta-llama/Llama-3.1-70B-Instruct", "Baidicoot/Othello-GPT-Transformer-Lens", "bert-base-cased", "roneneldan/TinyStories-1M", @@ -929,13 +933,13 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } - elif "Llama-3.2-1B-Instruct" in official_model_name: + elif "Llama-3.1-8B" in official_model_name: cfg_dict = { - "d_model": 2048, - "d_head": 64, + "d_model": 4096, + "d_head": 128, "n_heads": 32, - "d_mlp": 8192, - "n_layers": 16, + "d_mlp": 14336, + "n_layers": 32, "n_ctx": 2048, # capped due to memory issues "eps": 1e-5, "d_vocab": 128256, @@ -944,17 +948,17 @@ def convert_hf_model_config(model_name: str, **kwargs): "normalization_type": "RMS", "positional_embedding_type": "rotary", "rotary_adjacent_pairs": False, - "rotary_dim": 64, + "rotary_dim": 128, "final_rms": True, "gated_mlp": True, } - elif "Llama-3.2-3B-Instruct" in official_model_name: + elif "Llama-3.1-70B" in official_model_name: cfg_dict = { - "d_model": 3072, + "d_model": 8192, "d_head": 128, - "n_heads": 24, - "d_mlp": 8192, - "n_layers": 28, + "n_heads": 64, + "d_mlp": 28672, + "n_layers": 80, "n_ctx": 2048, # capped due to memory issues "eps": 1e-5, "d_vocab": 128256,