From cb424a6e40fb56477129d893cbd949cd5ab14ec0 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Wed, 8 Jan 2025 03:19:56 +0800 Subject: [PATCH] Fix `glem` example (#9903) Closes #9899. Co-authored-by: Rishi Puri --- examples/llm/glem.py | 2 ++ torch_geometric/nn/models/glem.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index ec76cef4c010..c6bae703fd33 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -371,9 +371,11 @@ def load_model(em_phase): if gnn_val_acc > lm_val_acc: em_phase = 'gnn' model.gnn = model.gnn.to(device, non_blocking=True) + test_loader = subgraph_loader else: em_phase = 'lm' model.lm = model.lm.to(device, non_blocking=True) + test_loader = text_test_loader test_preds = model.inference(em_phase, test_loader, verbose=verbose) train_acc, val_acc, test_acc = evaluate(test_preds, ['train', 'valid', 'test']) diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py index afc8b09d77c7..d30d5f8bd062 100644 --- a/torch_geometric/nn/models/glem.py +++ b/torch_geometric/nn/models/glem.py @@ -144,7 +144,8 @@ def train(self, em_phase: str, train_loader: Union[DataLoader, acc (float): training accuracy loss (float): loss value """ - pseudo_labels = pseudo_labels.to(self.device) + if pseudo_labels is not None: + pseudo_labels = pseudo_labels.to(self.device) if em_phase == 'gnn': acc, loss = self.train_gnn(train_loader, optimizer, epoch, pseudo_labels, is_augmented, verbose)