Skip to content

Commit

Permalink
Merge branch 'master' into tagdataset/add-llm-exp-pred
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 authored Jan 7, 2025
2 parents 502659a + cb424a6 commit d09912b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 2 additions & 0 deletions examples/llm/glem.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,11 @@ def load_model(em_phase):
if gnn_val_acc > lm_val_acc:
em_phase = 'gnn'
model.gnn = model.gnn.to(device, non_blocking=True)
test_loader = subgraph_loader
else:
em_phase = 'lm'
model.lm = model.lm.to(device, non_blocking=True)
test_loader = text_test_loader
test_preds = model.inference(em_phase, test_loader, verbose=verbose)
train_acc, val_acc, test_acc = evaluate(test_preds,
['train', 'valid', 'test'])
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/models/glem.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def train(self, em_phase: str, train_loader: Union[DataLoader,
acc (float): training accuracy
loss (float): loss value
"""
pseudo_labels = pseudo_labels.to(self.device)
if pseudo_labels is not None:
pseudo_labels = pseudo_labels.to(self.device)
if em_phase == 'gnn':
acc, loss = self.train_gnn(train_loader, optimizer, epoch,
pseudo_labels, is_augmented, verbose)
Expand Down

0 comments on commit d09912b

Please sign in to comment.