
Commit

Restore upstream
Jan Bauer committed Jan 19, 2025
1 parent 54bd335 commit cb056c6
Showing 14 changed files with 265 additions and 75 deletions.
24 changes: 17 additions & 7 deletions .devcontainer/Dockerfile
@@ -1,3 +1,5 @@
# If .venv is already set up with python3.8, it will use python3.8. To use 3.11, remove it first.

# Use Nvidia Ubuntu 20 base (includes CUDA if a supported GPU is present)
# https://hub.docker.com/r/nvidia/cuda
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04@sha256:55211df43bf393d3393559d5ab53283d4ebc3943d802b04546a24f3345825bd9
@@ -17,18 +19,26 @@ RUN groupadd --gid $USER_GID $USERNAME \
&& chmod 0440 /etc/sudoers.d/$USERNAME

# Install dependencies
RUN sudo apt-get update && \
RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
software-properties-common && \
add-apt-repository -y ppa:deadsnakes/ppa && \
apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
build-essential \
python3.9 \
python3.9-dev \
python3.9-distutils \
python3.9-venv \
python3.11 \
python3.11-dev \
python3.11-distutils \
python3.11-venv \
curl \
git
git && \
# Update python3 default to point to python3.11
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 && \
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 2 && \
update-alternatives --set python3 /usr/bin/python3.11

# Use the new user
USER $USERNAME

# Install poetry
RUN curl -sSL https://install.python-poetry.org | python3 -
RUN curl -sSL https://install.python-poetry.org | python3.11 -
6 changes: 5 additions & 1 deletion .github/workflows/checks.yml
@@ -69,6 +69,8 @@ jobs:
poetry install --with dev
- name: Unit Test
run: make unit-test
env:
HF_TOKEN: ${{ vars.HF_TOKEN }}
- name: Acceptance Test
run: make acceptance-test
- name: Build check
@@ -106,6 +108,8 @@ jobs:
run: poetry run mypy .
- name: Test Suite with Coverage Report
run: make coverage-report-test
env:
HF_TOKEN: ${{ vars.HF_TOKEN }}
- name: Build check
run: poetry build
- name: Upload Coverage Report Artifact
@@ -195,7 +199,7 @@ jobs:
- name: Build Docs
run: poetry run build-docs
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ vars.HF_TOKEN }}
- name: Upload Docs Artifact
uses: actions/upload-artifact@v3
with:
1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ docs/build
.pylintrc
docs/source/generated
**.orig
.venv
90 changes: 71 additions & 19 deletions demos/Colab_Compatibility.ipynb
@@ -16,9 +16,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_57027/2944939757.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
" ipython.magic(\"load_ext autoreload\")\n",
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_57027/2944939757.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
" ipython.magic(\"autoreload 2\")\n"
]
}
@@ -51,28 +51,28 @@
" %pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n",
" %pip install torch\n",
" %pip install tiktoken\n",
" %pip install transformer_lens\n",
" # %pip install transformer_lens\n",
" %pip install transformers_stream_generator\n",
" # !huggingface-cli login --token NEEL'S TOKEN"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TransformerLens currently supports 190 models out of the box.\n"
"TransformerLens currently supports 206 models out of the box.\n"
]
}
],
"source": [
"import torch\n",
"\n",
"from transformer_lens import HookedTransformer, HookedEncoderDecoder, loading\n",
"from transformer_lens import HookedTransformer, HookedEncoderDecoder, HookedEncoder, loading\n",
"from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer\n",
"from typing import List\n",
"import gc\n",
@@ -144,11 +144,11 @@
" inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
" input_ids = inputs[\"input_ids\"]\n",
" attention_mask = inputs[\"attention_mask\"]\n",
" decoder_input_ids = torch.tensor([[model.cfg.decoder_start_token_id]]).to(input_ids.device)\n",
" decoder_input_ids = torch.tensor([[tl_model.cfg.decoder_start_token_id]]).to(input_ids.device)\n",
"\n",
"\n",
" while True:\n",
" logits = model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n",
" logits = tl_model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n",
" # logits.shape == (batch_size (1), predicted_pos, vocab_size)\n",
"\n",
" token_idx = torch.argmax(logits[0, -1, :]).item()\n",
@@ -160,7 +160,29 @@
" # break if End-Of-Sequence token generated\n",
" if token_idx == tokenizer.eos_token_id:\n",
" break\n",
" print(tl_model.generate(\"Hello my name is\"))\n",
" del tl_model\n",
" gc.collect()\n",
" if IN_COLAB:\n",
" %rm -rf /root/.cache/huggingface/hub/models*\n",
"\n",
"def run_encoder_only_set(model_set: List[str], device=\"cuda\") -> None:\n",
" for model in model_set:\n",
" print(\"Testing \" + model)\n",
" tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
" tl_model = HookedEncoder.from_pretrained(model, device=device)\n",
"\n",
" if GENERATE:\n",
" # Slightly adapted version of the BERT demo\n",
" prompt = \"The capital of France is [MASK].\"\n",
"\n",
" input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
"\n",
" logprobs = tl_model(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n",
" prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
"\n",
" print(f\"Prompt: {prompt}\")\n",
" print(f'Prediction: \"{prediction}\"')\n",
"\n",
" del tl_model\n",
" gc.collect()\n",
" if IN_COLAB:\n",
@@ -169,15 +191,14 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# The following models can run in the T4 free environment\n",
"free_compatible = [\n",
" \"ai-forever/mGPT\",\n",
" \"ArthurConmy/redwood_attn_2l\",\n",
" \"bert-base-cased\",\n",
" \"bigcode/santacoder\",\n",
" \"bigscience/bloom-1b1\",\n",
" \"bigscience/bloom-560m\",\n",
@@ -256,6 +277,10 @@
" \"Qwen/Qwen2-0.5B-Instruct\",\n",
" \"Qwen/Qwen2-1.5B\",\n",
" \"Qwen/Qwen2-1.5B-Instruct\",\n",
" \"Qwen/Qwen2.5-0.5B\",\n",
" \"Qwen/Qwen2.5-0.5B-Instruct\",\n",
" \"Qwen/Qwen2.5-1.5B\",\n",
" \"Qwen/Qwen2.5-1.5B-Instruct\",\n",
" \"roneneldan/TinyStories-1Layer-21M\",\n",
" \"roneneldan/TinyStories-1M\",\n",
" \"roneneldan/TinyStories-28M\",\n",
@@ -290,7 +315,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -340,6 +365,10 @@
" \"Qwen/Qwen1.5-7B-Chat\",\n",
" \"Qwen/Qwen2-7B\",\n",
" \"Qwen/Qwen2-7B-Instruct\",\n",
" \"Qwen/Qwen2.5-3B\",\n",
" \"Qwen/Qwen2.5-3B-Instruct\",\n",
" \"Qwen/Qwen2.5-7B\",\n",
" \"Qwen/Qwen2.5-7B-Instruct\",\n",
" \"stabilityai/stablelm-base-alpha-3b\",\n",
" \"stabilityai/stablelm-base-alpha-7b\",\n",
" \"stabilityai/stablelm-tuned-alpha-3b\",\n",
@@ -354,7 +383,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -374,6 +403,8 @@
" \"Qwen/Qwen-14B-Chat\",\n",
" \"Qwen/Qwen1.5-14B\",\n",
" \"Qwen/Qwen1.5-14B-Chat\",\n",
" \"Qwen/Qwen2.5-14B\",\n",
" \"Qwen/Qwen2.5-14B-Instruct\",\n",
"]\n",
"\n",
"if IN_COLAB:\n",
@@ -384,7 +415,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -398,18 +429,24 @@
" \"meta-llama/Llama-2-70b-chat-hf\",\n",
" \"meta-llama/Llama-3.1-70B\",\n",
" \"meta-llama/Llama-3.1-70B-Instruct\",\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
" \"meta-llama/Meta-Llama-3-70B\",\n",
" \"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
" \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
" \"mistralai/Mixtral-8x7B-v0.1\",\n",
" \"Qwen/Qwen2.5-32B\",\n",
" \"Qwen/Qwen2.5-32B-Instruct\",\n",
" \"Qwen/Qwen2.5-72B\",\n",
" \"Qwen/Qwen2.5-72B-Instruct\",\n",
" \"Qwen/QwQ-32B-Preview\",\n",
"]\n",
"\n",
"mark_models_as_tested(incompatible_models)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -431,7 +468,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -449,7 +486,22 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# This model works on the free version of Colab\n",
"encoder_only_models = [\"bert-base-cased\"]\n",
"\n",
"if IN_COLAB:\n",
" run_encoder_only_set(encoder_only_models)\n",
"\n",
"mark_models_as_tested(encoder_only_models)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -460,7 +512,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -499,5 +551,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
20 changes: 20 additions & 0 deletions demos/Main_Demo.ipynb
@@ -429,6 +429,26 @@
"cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, we only wanted the layer 0 attention patterns, but we are storing the internal activations from all locations in the model. It's convenient to have access to all activations, but this can be prohibitively expensive for memory use with larger models, batch sizes, or sequence lengths. In addition, we don't need to do the full forward pass through the model to collect layer 0 attention patterns. The following cell will collect only the layer 0 attention patterns and stop the forward pass at layer 1, requiring far less memory and compute."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"attn_hook_name = \"blocks.0.attn.hook_pattern\"\n",
"attn_layer = 0\n",
"_, gpt2_attn_cache = model.run_with_cache(gpt2_tokens, remove_batch_dim=True, stop_at_layer=attn_layer + 1, names_filter=[attn_hook_name])\n",
"gpt2_attn = gpt2_attn_cache[attn_hook_name]\n",
"assert torch.equal(gpt2_attn, attention_pattern)"
]
},
{
"attachments": {},
"cell_type": "markdown",
7 changes: 7 additions & 0 deletions tests/integration/test_hooks.py
@@ -234,3 +234,10 @@ def set_to_randn(z, hook):
# exactly when the zero hook is attached last XOR it is prepended

assert torch.allclose(logits, model.unembed.b_U[None, :]) == logits_are_unembed_bias


def test_use_attn_in_with_gqa_raises_error():
# Create model that uses GroupedQueryAttention
model = HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B")
with pytest.raises(AssertionError):
model.set_use_attn_in(True)
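
The test above pins down that per-head attention-input hooks are rejected for grouped-query attention, where query heads share key/value projections. For contrast, a minimal sketch of both sides of that behavior; the standard multi-head "gpt2" checkpoint used here is an illustrative assumption, not part of this diff:

from transformer_lens import HookedTransformer

# Standard multi-head attention: per-head attention inputs are supported.
mha_model = HookedTransformer.from_pretrained("gpt2")
mha_model.set_use_attn_in(True)

# Grouped-query attention shares key/value projections across query heads,
# so splitting the attention input per head is rejected with an assertion.
gqa_model = HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B")
try:
    gqa_model.set_use_attn_in(True)
except AssertionError as err:
    print(f"use_attn_in is unsupported for GQA models: {err}")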
30 changes: 30 additions & 0 deletions tests/unit/components/test_attention.py
@@ -1,10 +1,12 @@
import einops
import pytest
import torch
import torch.nn as nn
from transformers.utils import is_bitsandbytes_available

from transformer_lens.components import Attention
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.attention import complex_attn_linear

if is_bitsandbytes_available():
from bitsandbytes.nn.modules import Params4bit
@@ -98,3 +100,31 @@ def test_attention_config_dict():
assert attn.cfg.load_in_4bit == False
assert attn.cfg.dtype == torch.float32
assert attn.cfg.act_fn == "relu"


def test_remove_einsum_from_complex_attn_linear():
batch = 64
pos = 128
head_index = 8
d_model = 512
d_head = 64
input = torch.randn(batch, pos, head_index, d_model)
w = torch.randn(head_index, d_model, d_head)
b = torch.randn(head_index, d_head)
result_new = complex_attn_linear(input, w, b)

# Check if new implementation without einsum produces correct shape
assert result_new.shape == (batch, pos, head_index, d_head)

# Old implementation used einsum
result_old = (
einops.einsum(
input,
w,
"batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head",
)
+ b
)

# Check if the results are the same
assert torch.allclose(result_new, result_old, atol=1e-4)
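
As a companion to the test above, here is one einsum-free formulation that matches the checked semantics using plain batched matmuls. This is a minimal sketch assuming only the tensor shapes from the test; it is not the actual body of complex_attn_linear in transformer_lens.utilities.attention, which may differ:

import torch

def attn_linear_matmul_sketch(
    x: torch.Tensor,  # (batch, pos, head_index, d_model)
    w: torch.Tensor,  # (head_index, d_model, d_head)
    b: torch.Tensor,  # (head_index, d_head)
) -> torch.Tensor:
    # Bring head_index in front of pos so torch.matmul pairs each head's
    # activations with that head's weight matrix via batch broadcasting.
    xh = x.permute(0, 2, 1, 3)           # (batch, head_index, pos, d_model)
    out = torch.matmul(xh, w)            # (batch, head_index, pos, d_head)
    return out.permute(0, 2, 1, 3) + b   # (batch, pos, head_index, d_head)

# Spot-check against the einsum reference used in the test:
inp = torch.randn(2, 3, 4, 8)
w = torch.randn(4, 8, 5)
b = torch.randn(4, 5)
reference = torch.einsum("bphm,hmd->bphd", inp, w) + b
assert torch.allclose(attn_linear_matmul_sketch(inp, w, b), reference, atol=1e-5)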
