making autoencoder work with HookedGPT
shehper committed Apr 1, 2024
1 parent c2861db commit 96e1847
Showing 3 changed files with 39 additions and 34 deletions.
9 changes: 6 additions & 3 deletions autoencoder/generate_mlp_data.py
@@ -13,7 +13,8 @@

## Add the path to the transformer subdirectory as it contains model.py
sys.path.insert(0, '../transformer')
- from model import GPTConfig, GPT
+ from model import GPTConfig
+ from hooked_model import HookedGPT

## define some parameters; these can be overwritten from command line
device = 'cpu'
@@ -41,7 +42,7 @@
checkpoint = torch.load(ckpt_path, map_location=device)
print(f'loaded transformer model checkpoint from {ckpt_path}')
gptconf = GPTConfig(**checkpoint['model_args'])
- model = GPT(gptconf)
+ model = HookedGPT(gptconf)
state_dict = checkpoint['model']
compile = False # TODO: Don't know why I needed to set compile to False before loading the model..
# TODO: I dont know why the next 4 lines are needed. state_dict does not seem to have any keys with unwanted_prefix.
@@ -83,7 +84,9 @@
contexts = torch.stack([torch.from_numpy((text_data[i:i+block_size]).astype(np.int64)) for i in ix]) # (b, t)

# compute MLP activations from the loaded model
- activations = model.get_last_mlp_acts(contexts) # (b, t, n_ffwd)
+ _, _ = model(contexts)
+ activations = model.mlp_activation_hooks[0] # (b, t, n_ffwd)
+ model.clear_mlp_activation_hooks() # free up memory

# pick tokens_per_context (n) tokens from each context; and flatten the first two dimensions
data = torch.stack([activations[i, torch.randint(block_size, (tokens_per_context,)), :] for i in range(contexts_per_batch)]).view(-1, activations.shape[-1]) #(b*n, n_ffwd)
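Note: hooked_model.py itself is not part of this commit, so the following is only a hedged sketch of how a HookedGPT class might provide the mlp_activation_hooks / clear_mlp_activation_hooks / mode="replace" interface used in this diff: a GPT subclass that registers a PyTorch forward hook on the last block's MLP nonlinearity. The module path transformer.h[-1].mlp.gelu, the choice to hook the post-GELU activations, and zero-ablation as the default replacement are assumptions, not taken from the repository.

    # Hypothetical sketch -- not the repository's hooked_model.py.
    import torch
    from model import GPT  # the nanoGPT-style GPT class imported elsewhere in this diff

    class HookedGPT(GPT):
        def __init__(self, config):
            super().__init__(config)
            self.mlp_activation_hooks = []        # activations captured on each forward pass
            self._mode, self._replacement = None, None
            # assumption: nanoGPT layout; hook the last block's post-GELU MLP activations
            self.transformer.h[-1].mlp.gelu.register_forward_hook(self._on_mlp_acts)

        def _on_mlp_acts(self, module, inputs, output):
            if self._mode == "replace":
                # returning a value from a forward hook overwrites the module's output
                if self._replacement is not None:
                    return self._replacement.to(device=output.device, dtype=output.dtype)
                return torch.zeros_like(output)   # zero-ablation (an assumption)
            self.mlp_activation_hooks.append(output.detach())  # (b, t, n_ffwd)
            return None                           # keep the original MLP output

        def clear_mlp_activation_hooks(self):
            self.mlp_activation_hooks.clear()     # free the cached tensors

        def forward(self, idx, targets=None, mode=None, replacement_tensor=None):
            self._mode, self._replacement = mode, replacement_tensor
            logits, loss = super().forward(idx, targets)        # the hook fires in here
            self._mode, self._replacement = None, None
            return logits, loss

Under this sketch the usage above reads naturally: run _, _ = model(contexts), take model.mlp_activation_hooks[0], then call model.clear_mlp_activation_hooks() to release the cached tensor.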
9 changes: 6 additions & 3 deletions autoencoder/top_activations.py
@@ -19,7 +19,8 @@

## Add path to the transformer subdirectory as it contains GPT class in model.py
sys.path.insert(0, '../transformer')
- from model import GPTConfig, GPT
+ from model import GPTConfig
+ from hooked_model import HookedGPT

# hyperparameters
device = 'cuda' # change it to cpu
@@ -109,7 +110,7 @@ def sample_tokens(*args, eval_tokens, num_tokens_either_side, fn_seed=0):
gpt_ckpt_path = os.path.join(os.path.dirname(current_dir), 'transformer', gpt_dir, 'ckpt.pt')
gpt_ckpt = torch.load(gpt_ckpt_path, map_location=device)
gptconf = GPTConfig(**gpt_ckpt['model_args'])
- gpt = GPT(gptconf)
+ gpt = HookedGPT(gptconf)
state_dict = gpt_ckpt['model']
compile = False # TODO: why do this?
unwanted_prefix = '_orig_mod.' # TODO: why do this and the next three lines?
@@ -192,7 +193,9 @@ def sample_tokens(*args, eval_tokens, num_tokens_either_side, fn_seed=0):
else:
X_BT = X_NT[iter * B: (iter + 1) * B].to(device)
# compute MLP activations
- mlp_acts_BTF = gpt.get_last_mlp_acts(X_BT) # TODO: Learn to use hooks instead?
+ _, _ = gpt(X_BT)
+ mlp_acts_BTF = gpt.mlp_activation_hooks[0]
+ gpt.clear_mlp_activation_hooks()
# compute feature activations for features in this phase
feature_acts_BTH = autoencoder.get_feature_acts(x=mlp_acts_BTF, s=phase*H, e=(phase+1)*H)
# sample tokens from the context, and save feature activations and tokens for these tokens in data_MW.
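The feature dictionary is scanned in phases of H features at a time (s = phase*H to e = (phase+1)*H), so only a slice of the feature activations is materialized per pass. The autoencoder's get_feature_acts is not shown in this diff; the snippet below is a hedged sketch assuming a standard sparse-autoencoder encoder f = ReLU((x - b_dec) @ W_enc + b_enc) whose encoder columns can be sliced. W_enc, b_enc, b_dec, and the pre-encoder bias subtraction are assumed names and conventions, not taken from this repository.

    import torch.nn.functional as F

    def get_feature_acts(x, W_enc, b_enc, b_dec, s, e):
        # Hypothetical sketch of phase-sliced feature activations.
        # x: (B, T, n_ffwd) MLP acts; W_enc: (n_ffwd, n_features); returns (B, T, e - s).
        return F.relu((x - b_dec) @ W_enc[:, s:e] + b_enc[s:e])

    # e.g. feature_acts_BTH = get_feature_acts(mlp_acts_BTF, W_enc, b_enc, b_dec, s=phase*H, e=(phase+1)*H)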
55 changes: 27 additions & 28 deletions autoencoder/train.py
@@ -2,7 +2,7 @@
Train a Sparse AutoEncoder model
Run on a macbook on a Shakespeare dataset as
- python train.py --dataset=shakespeare_char --gpt_dir=out-shakespeare-char --eval_contexts=1000 --eval_batch_size=16 --batch_size=128 --device=cpu --eval_interval=100 --n_features=1024 --resampling_interval=150
+ python train.py --dataset=shakespeare_char --gpt_dir=out_sc_1_2_32 --eval_contexts=20 --eval_batch_size=16 --batch_size=128 --device=cpu --eval_interval=100 --n_features=1024 --resampling_interval=150 --wandb_log=True
"""
import os
import torch
@@ -20,7 +20,8 @@

## Add path to the transformer subdirectory as it contains GPT class in model.py
sys.path.insert(0, '../transformer')
- from model import GPTConfig, GPT
+ from model import GPTConfig
+ from hooked_model import HookedGPT

## hyperparameters
device = 'cuda'
@@ -125,7 +126,7 @@ def get_histogram_image(data, bins='auto'):
ckpt_path = os.path.join(os.path.dirname(current_dir), 'transformer', gpt_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
- gpt = GPT(gptconf)
+ gpt = HookedGPT(gptconf)
state_dict = checkpoint['model']
compile = False # TODO: Don't know why I needed to set compile to False before loading the model..
# TODO: Also, I dont know why the next 4 lines are needed. state_dict does not seem to have any keys with unwanted_prefix.
@@ -170,28 +171,20 @@ def get_histogram_image(data, bins='auto'):
# Finally, we pre-select indices of tokens_per_eval_context (=10 by default, as in Anthropic's paper) tokens in each context
# and save it in token_indices. These will be used for the calculation of feature activation counts during evaluation
# TODO: There is probably no need to pre-select these indices. Perhaps remove token_indices and sample tokens during evaluation on the go?
- X, Y = get_text_batch(text_data, block_size=block_size, batch_size=eval_contexts) # (eval_contexts, block_size)
+ X, Y = get_text_batch(text_data, block_size=block_size, batch_size=eval_contexts) # (eval_contexts, block_size)
- token_indices = torch.randint(block_size, (eval_contexts, tokens_per_eval_context)) # (eval_contexts, tokens_per_eval_context)
num_eval_batches = eval_contexts // eval_batch_size
- mlp_activations_storage = torch.tensor([], dtype=torch.float16)
- residual_stream_storage = torch.tensor([], dtype=torch.float16)
- transformer_loss, mlp_ablated_loss = 0, 0

+ ablated_losses = torch.zeros(num_eval_batches,)
for iter in range(num_eval_batches):
print(f'iter = {iter}/{num_eval_batches} in computation of mlp_activations and residual_stream for evaluation data')
# pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
x = slice_fn(X).pin_memory().to(device, non_blocking=True) if device_type == 'cuda' else slice_fn(X).to(device) # select a batch of text data inputs
y = slice_fn(Y).pin_memory().to(device, non_blocking=True) if device_type == 'cuda' else slice_fn(Y).to(device) # select a batch of text data outputs
- res_stream, mlp_activations, batch_loss, batch_ablated_loss = gpt.forward_with_and_without_last_mlp(x, y) # Transformer forward pass; compute residual stream, MLP activations and losses
- mlp_activations_storage = torch.cat([mlp_activations_storage, mlp_activations.to(dtype=torch.float16, device='cpu')]) # store MLP activations
- residual_stream_storage = torch.cat([residual_stream_storage, res_stream.to(dtype=torch.float16, device='cpu')]) # store residual stream
- transformer_loss, mlp_ablated_loss = transformer_loss + batch_loss, mlp_ablated_loss + batch_ablated_loss
- transformer_loss, mlp_ablated_loss = transformer_loss/num_eval_batches, mlp_ablated_loss/num_eval_batches # divide by num_eval_batches to get mean values
+ token_indices = torch.randint(block_size, (eval_contexts, tokens_per_eval_context)) # (eval_contexts, tokens_per_eval_context)

+ ablated_losses[iter] = gpt(x, y, mode="replace")[1].item()
+ ablated_loss = ablated_losses.mean()

memory = psutil.virtual_memory()
print(f'computed mlp activations and losses on eval data; available memory: {memory.available / (1024**3):.2f} GB; memory usage: {memory.percent}%')
- print(f'The full transformer loss and MLP ablated loss on the evaluation data are {transformer_loss:.2f}, {mlp_ablated_loss:.2f}')
- del X; gc.collect() # will not need X anymore; instead res_stream_storage and mlp_acts_storage will be used


## INITIATE AUTOENCODER AND OPTIMIZER
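The ablated loss above now comes from gpt(x, y, mode="replace"), and the reconstructed loss later in this file from gpt(x, y, mode="replace", replacement_tensor=...). Assuming a hooked forward pass like the sketch shown after generate_mlp_data.py, the three evaluation losses could be obtained with the calling pattern below; this is illustration only, and whether the bare mode="replace" call ablates the MLP output with zeros or with something else is not determined by this commit.

    # Illustrative calling pattern (x, y, autoencoder as in train.py above)
    _, transformer_loss = gpt(x, y)                   # full-model loss; hook caches the MLP acts
    batch_mlp_activations = gpt.mlp_activation_hooks[0]
    _, ablated_loss = gpt(x, y, mode="replace")       # last-MLP output replaced/ablated
    output = autoencoder(batch_mlp_activations)       # {'loss', 'f', 'reconst_acts', ...}
    _, reconstructed_nll = gpt(x, y, mode="replace", replacement_tensor=output['reconst_acts'])
    gpt.clear_mlp_activation_hooks()                  # free the cached activations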
@@ -210,6 +203,9 @@ def get_histogram_image(data, bins='auto'):
start_time = time.time()
num_steps = total_training_examples // batch_size
for step in range(num_steps):

+ if step == 200:
+     break

## load a batch of data
batch, current_partition, current_partition_index, offset = load_data(step, batch_size, current_partition_index, current_partition, n_partitions, examples_per_partition, offset)
@@ -261,16 +257,15 @@ def get_histogram_image(data, bins='auto'):
feature_activation_counts = torch.zeros(n_features, dtype=torch.float32) # number of tokens on which each feature is active
start_log_time = time.time()

+ transformer_losses = torch.zeros(num_eval_batches,)

for iter in range(num_eval_batches):

+ x = slice_fn(X).pin_memory().to(device, non_blocking=True) if device_type == 'cuda' else slice_fn(X).to(device) # select a batch of text data inputs
+ y = slice_fn(Y).pin_memory().to(device, non_blocking=True) if device_type == 'cuda' else slice_fn(Y).to(device) # select a batch of text data outputs

- if device_type == 'cuda': # select batch of mlp activations, residual stream and y
- batch_mlp_activations = slice_fn(mlp_activations_storage).pin_memory().to(device, non_blocking=True)
- batch_res_stream = slice_fn(residual_stream_storage).pin_memory().to(device, non_blocking=True)
- batch_targets = slice_fn(Y).pin_memory().to(device, non_blocking=True)
- else:
- batch_mlp_activations = slice_fn(mlp_activations_storage).to(device)
- batch_res_stream = slice_fn(residual_stream_storage).to(device)
- batch_targets = slice_fn(Y).to(device)
+ transformer_losses[iter] = gpt(x, y)[1]
+ batch_mlp_activations = gpt.mlp_activation_hooks[0]

with torch.no_grad():
output = autoencoder(batch_mlp_activations) # output = {'loss': loss, 'f': f, 'reconst_acts': reconst_acts, 'mseloss': mseloss, 'l1loss': l1loss}
@@ -290,17 +285,21 @@ def get_histogram_image(data, bins='auto'):
del batch_mlp_activations, f, f_subset; gc.collect(); torch.cuda.empty_cache()

# Compute reconstructed loss from batch_reconstructed_activations
- log_dict['losses/reconstructed_nll'] += gpt.get_loss_from_last_mlp_acts(batch_res_stream, output['reconst_acts'], batch_targets).item()
+ _, reconstructed_nll = gpt(x, y, mode="replace", replacement_tensor=output['reconst_acts'])
+ log_dict['losses/reconstructed_nll'] += reconstructed_nll
log_dict['losses/autoencoder_loss'] += output['loss'].item()
log_dict['losses/mse_loss'] += output['mse_loss'].item()
log_dict['losses/l1_loss'] += output['l1_loss'].item()
- del batch_res_stream, output, batch_targets; gc.collect(); torch.cuda.empty_cache()
+ # del batch_res_stream, output, batch_targets; gc.collect(); torch.cuda.empty_cache()
+ del output; gc.collect(); torch.cuda.empty_cache()

+ transformer_loss = transformer_losses.mean()

# take mean of all loss values by dividing by the number of evaluation batches
log_dict = {key: val/num_eval_batches for key, val in log_dict.items()}

# add nll score to log_dict
- log_dict['losses/nll_score'] = (transformer_loss - log_dict['losses/reconstructed_nll'])/(transformer_loss - mlp_ablated_loss).item()
+ log_dict['losses/nll_score'] = (transformer_loss - log_dict['losses/reconstructed_nll'])/(transformer_loss - ablated_loss).item()

# compute feature densities and plot feature density histogram
log_feature_activation_density = np.log10(feature_activation_counts[feature_activation_counts != 0]/(eval_tokens)) # (n_features,)
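For reference, the nll_score introduced above normalizes the reconstructed loss against the full and ablated losses. With the formula exactly as written in this diff, 0 means the autoencoder's reconstruction recovers the original transformer loss exactly and 1 means it is no better than ablating the MLP output. Illustrative numbers only, not from this run:

    transformer_loss  = 1.47   # full-model cross-entropy on the eval batches (made up)
    ablated_loss      = 2.10   # loss with the last MLP output replaced (made up)
    reconstructed_nll = 1.53   # loss with MLP acts swapped for the SAE reconstruction (made up)

    nll_score = (transformer_loss - reconstructed_nll) / (transformer_loss - ablated_loss)
    # (1.47 - 1.53) / (1.47 - 2.10) = (-0.06) / (-0.63) ≈ 0.095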
