Skip to content

Commit

Permalink
Merge branch 'master' into tagdataset/add-llm-exp-pred
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 authored Jan 7, 2025
2 parents d09912b + 8a651b3 commit 1b3e679
Showing 1 changed file with 148 additions and 24 deletions.
172 changes: 148 additions & 24 deletions examples/llm/g_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,94 +31,197 @@


def compute_metrics(eval_output):
    """Compute and print evaluation metrics (hit, precision, recall, F1).

    Args:
        eval_output (list): List of dictionaries containing prediction
            output. Each dictionary is converted to a
            :class:`pandas.DataFrame` and must provide ``pred`` and
            ``label`` columns, where both values are ``'|'``-separated
            answer strings.

    Returns:
        None: The averaged metrics are printed to the console.
    """
    # `pd.concat` raises on an empty list, so bail out early.
    if not eval_output:
        print('No samples could be evaluated.')
        return

    # Concatenate prediction output into a single DataFrame
    df = pd.concat([pd.DataFrame(d) for d in eval_output])

    # Per-sample metrics; all four lists always have the same length
    all_hit = []  # Whether the first predicted answer occurs in the label
    all_precision = []  # List of precision values
    all_recall = []  # List of recall values
    all_f1 = []  # List of F1 values

    # Iterate over prediction-label pairs
    for pred, label in zip(df.pred.tolist(), df.label.tolist()):
        try:
            # Drop everything after the end-of-sequence marker and split the
            # remaining text into individual answers
            pred_items = pred.split('[/s]')[0].strip().split('|')

            # The prediction is free-form model output, not a regular
            # expression, so match it literally. An empty prediction would
            # trivially "occur" in every label and is counted as a miss.
            first_answer = pred_items[0]
            hit = bool(first_answer) and first_answer in label

            # Compute precision, recall, and F1 on the answer sets
            label_items = label.split('|')
            pred_set = set(pred_items)
            label_set = set(label_items)
            matches = pred_set & label_set
            precision = len(matches) / len(pred_set)
            recall = len(matches) / len(label_set)

            # Handle division by zero
            if recall + precision == 0:
                f1 = 0
            else:
                f1 = 2 * precision * recall / (precision + recall)
        except (AttributeError, TypeError) as e:
            # Non-string entries cannot be scored: report and skip them
            print(f'Label: {label}')
            print(f'Pred: {pred}')
            print(f'Exception: {e}')
            print('------------------')
            continue

        # Store metrics only once the whole sample was scored, which keeps
        # the four lists aligned
        all_hit.append(hit)
        all_precision.append(precision)
        all_recall.append(recall)
        all_f1.append(f1)

    # Every sample may have been skipped above; avoid dividing by zero
    if not all_hit:
        print('No samples could be evaluated.')
        return

    # Compute average metrics
    hit = sum(all_hit) / len(all_hit)
    precision = sum(all_precision) / len(all_precision)
    recall = sum(all_recall) / len(all_recall)
    f1 = sum(all_f1) / len(all_f1)

    # Print metrics to console
    print(f'Hit: {hit:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1: {f1:.4f}')


def save_params_dict(model, save_path):
    """Save a model's state, leaving out parameters that are not trained.

    Args:
        model (torch.nn.Module): The model to save parameters from.
        save_path (str): The path to save the parameters to.
    """
    # Full state dictionary of the model
    checkpoint = model.state_dict()

    # Names of all parameters that do not require gradients
    frozen_names = [
        name for name, param in model.named_parameters()
        if not param.requires_grad
    ]

    # Drop the frozen parameters; every other entry is kept as-is
    for name in frozen_names:
        checkpoint.pop(name, None)

    # Write the filtered state dictionary to the given path
    torch.save(checkpoint, save_path)


def load_params_dict(model, save_path):
    """Load parameters written by ``save_params_dict`` back into a model.

    ``save_params_dict`` removes every parameter that does not require
    gradients from the checkpoint, so a strict load would fail on those
    missing keys. The checkpoint is therefore loaded non-strictly, and the
    result is validated by hand: only frozen parameters may be absent.

    Args:
        model (torch.nn.Module): The model to load parameters into.
        save_path (str): The path the parameters were saved to.

    Returns:
        torch.nn.Module: The same ``model`` object with updated parameters.

    Raises:
        RuntimeError: If the checkpoint contains keys unknown to the model,
            or lacks an entry that is not a frozen parameter.
    """
    # Load the saved model parameters from the specified file path
    state_dict = torch.load(save_path)

    # Non-strict: frozen parameters are intentionally not in the checkpoint
    result = model.load_state_dict(state_dict, strict=False)

    # Anything missing other than a frozen parameter is a real mismatch
    frozen = {
        name
        for name, param in model.named_parameters()
        if not param.requires_grad
    }
    missing = [key for key in result.missing_keys if key not in frozen]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            f'Checkpoint {save_path!r} does not match the model '
            f'(missing keys: {missing}, unexpected keys: {unexpected})')

    # Return the model with updated parameters
    return model


def get_loss(model, batch, model_save_name: str) -> Tensor:
    """Compute the loss for a given model and batch of data.

    Args:
        model: The model to compute the loss for. It is called directly
            with the batch fields as positional arguments.
        batch: The batch of data to compute the loss for.
        model_save_name: The name of the model being used (e.g. 'llm').

    Returns:
        Tensor: The computed loss, i.e. whatever ``model(...)`` returns.
    """
    # The model name determines which batch fields are passed on
    if model_save_name == 'llm':
        # For LLM models: question, answers (labels) and description only
        return model(batch.question, batch.label, batch.desc)

    # Any other name is treated as (GNN+LLM) and also gets the graph inputs
    return model(
        batch.question,
        batch.x,  # node features
        batch.edge_index,  # edge indices
        batch.batch,  # batch indices
        batch.label,  # answers (labels)
        batch.edge_attr,  # edge attributes
        batch.desc,  # description
    )


def inference_step(model, batch, model_save_name):
    """Performs inference on a given batch of data using the provided model.

    Args:
        model (nn.Module): The model to use for inference. Its
            ``inference`` method is called with the batch fields as
            positional arguments.
        batch: The batch of data to process.
        model_save_name (str): The name of the model (e.g. 'llm').

    Returns:
        The output of the inference step, i.e. whatever
        ``model.inference(...)`` returns.
    """
    # The model name determines which batch fields are passed on
    if model_save_name == 'llm':
        # Perform inference on the question and textual graph description
        return model.inference(batch.question, batch.desc)

    # Any other name is treated as (GNN+LLM) and also gets the graph inputs
    return model.inference(
        batch.question,
        batch.x,  # node features
        batch.edge_index,  # edge indices
        batch.batch,  # batch indices
        batch.edge_attr,  # edge attributes
        batch.desc,  # description
    )


def train(
num_epochs,
hidden_channels,
num_gnn_layers,
batch_size,
eval_batch_size,
lr,
checkpointing=False,
tiny_llama=False,
num_epochs, # Total number of training epochs
hidden_channels, # Number of hidden channels in GNN
num_gnn_layers, # Number of GNN layers
batch_size, # Training batch size
eval_batch_size, # Evaluation batch size
lr, # Initial learning rate
checkpointing=False, # Whether to checkpoint model
tiny_llama=False, # Whether to use tiny LLaMA model
):
"""Train a GNN+LLM model on WebQSP dataset.
Args:
num_epochs (int): Total number of training epochs.
hidden_channels (int): Number of hidden channels in GNN.
num_gnn_layers (int): Number of GNN layers.
batch_size (int): Training batch size.
eval_batch_size (int): Evaluation batch size.
lr (float): Initial learning rate.
checkpointing (bool, optional): Whether to checkpoint model.
Defaults to False.
tiny_llama (bool, optional): Whether to use tiny LLaMA model.
Defaults to False.
Returns:
None
"""
def adjust_learning_rate(param_group, LR, epoch):
# Decay the learning rate with half-cycle cosine after warmup
"""Decay learning rate with half-cycle cosine after warmup.
Args:
param_group (dict): Parameter group.
LR (float): Learning rate.
epoch (int): Current epoch.
Returns:
float: Adjusted learning rate.
"""
min_lr = 5e-6
warmup_epochs = 1
if epoch < warmup_epochs:
Expand All @@ -130,7 +233,10 @@ def adjust_learning_rate(param_group, LR, epoch):
param_group['lr'] = lr
return lr

# Start training time
start_time = time.time()

# Load dataset and create data loaders
path = osp.dirname(osp.realpath(__file__))
path = osp.join(path, '..', '..', 'data', 'WebQSPDataset')
train_dataset = WebQSPDataset(path, split='train')
Expand All @@ -146,16 +252,20 @@ def adjust_learning_rate(param_group, LR, epoch):
test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
drop_last=False, pin_memory=True, shuffle=False)

# To clean up after Data Preproc
# Clean up memory
gc.collect()
torch.cuda.empty_cache()

# Create GNN model
gnn = GAT(
in_channels=1024,
hidden_channels=hidden_channels,
out_channels=1024,
num_layers=num_gnn_layers,
heads=4,
)

# Create LLaMA model
if tiny_llama:
llm = LLM(
model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
Expand All @@ -166,10 +276,12 @@ def adjust_learning_rate(param_group, LR, epoch):
llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
model = GRetriever(llm=llm, gnn=gnn)

# Set model save name
model_save_name = 'gnn_llm' if num_gnn_layers != 0 else 'llm'
if model_save_name == 'llm':
model = llm

# Create optimizer
params = [p for _, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.AdamW([
{
Expand All @@ -178,10 +290,12 @@ def adjust_learning_rate(param_group, LR, epoch):
'weight_decay': 0.05
},
], betas=(0.9, 0.95))
grad_steps = 2

# Initialize best epoch and best validation loss
best_epoch = 0
best_val_loss = float('inf')

# Train model
for epoch in range(num_epochs):
model.train()
epoch_loss = 0
Expand All @@ -198,18 +312,19 @@ def adjust_learning_rate(param_group, LR, epoch):

clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

if (step + 1) % grad_steps == 0:
if (step + 1) % 2 == 0:
adjust_learning_rate(optimizer.param_groups[0], lr,
step / len(train_loader) + epoch)

optimizer.step()
epoch_loss = epoch_loss + float(loss)

if (step + 1) % grad_steps == 0:
if (step + 1) % 2 == 0:
lr = optimizer.param_groups[0]['lr']
train_loss = epoch_loss / len(train_loader)
print(epoch_str + f', Train Loss: {train_loss:4f}')

# Evaluate model
val_loss = 0
eval_output = []
model.eval()
Expand All @@ -224,16 +339,20 @@ def adjust_learning_rate(param_group, LR, epoch):
best_val_loss = val_loss
best_epoch = epoch
save_params_dict(model, f'{model_save_name}_best_val_loss_ckpt.pt')

# Clean up memory
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

# Load best checkpoint if necessary
if checkpointing and best_epoch != num_epochs - 1:
print("Loading best checkpoint...")
model = load_params_dict(
model,
f'{model_save_name}_best_val_loss_ckpt.pt',
)

# Evaluate model on test set
model.eval()
eval_output = []
print("Final evaluation...")
Expand All @@ -250,8 +369,13 @@ def adjust_learning_rate(param_group, LR, epoch):
eval_output.append(eval_data)
progress_bar_test.update(1)

# Compute metrics
compute_metrics(eval_output)

# Print final training time
print(f"Total Training Time: {time.time() - start_time:2f}s")

# Save model and evaluation output
save_params_dict(model, f'{model_save_name}.pt')
torch.save(eval_output, f'{model_save_name}_eval_outs.pt')

Expand Down

0 comments on commit 1b3e679

Please sign in to comment.