Skip to content

Commit

Permalink
Fix g-retriever example (#9686)
Browse files Browse the repository at this point in the history
Co-authored-by: Alfred Clemedtson <[email protected]>
  • Loading branch information
brs96 and AlfredClemedtson authored Oct 4, 2024
1 parent 5034fef commit 56d53d0
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions examples/llm/g_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def compute_metrics(eval_output):

label = label.split('|')
matches = set(pred).intersection(set(label))
precision = len(matches) / len(set(label))
recall = len(matches) / len(set(pred))
precision = len(matches) / len(set(pred))
recall = len(matches) / len(set(label))
if recall + precision == 0:
f1 = 0
else:
Expand Down Expand Up @@ -159,7 +159,10 @@ def adjust_learning_rate(param_group, LR, epoch):
llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
model = GRetriever(llm=llm, gnn=gnn)

model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm'
model_save_name = 'gnn_llm' if num_gnn_layers != 0 else 'llm'
if model_save_name == 'llm':
model = llm

params = [p for _, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.AdamW([
{
Expand Down

0 comments on commit 56d53d0

Please sign in to comment.