diff --git a/examples/llm/glem.py b/examples/llm/glem.py index ec76cef4c010..c6bae703fd33 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -371,9 +371,11 @@ def load_model(em_phase): if gnn_val_acc > lm_val_acc: em_phase = 'gnn' model.gnn = model.gnn.to(device, non_blocking=True) + test_loader = subgraph_loader else: em_phase = 'lm' model.lm = model.lm.to(device, non_blocking=True) + test_loader = text_test_loader test_preds = model.inference(em_phase, test_loader, verbose=verbose) train_acc, val_acc, test_acc = evaluate(test_preds, ['train', 'valid', 'test']) diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py index afc8b09d77c7..d30d5f8bd062 100644 --- a/torch_geometric/nn/models/glem.py +++ b/torch_geometric/nn/models/glem.py @@ -144,7 +144,8 @@ def train(self, em_phase: str, train_loader: Union[DataLoader, acc (float): training accuracy loss (float): loss value """ - pseudo_labels = pseudo_labels.to(self.device) + if pseudo_labels is not None: + pseudo_labels = pseudo_labels.to(self.device) if em_phase == 'gnn': acc, loss = self.train_gnn(train_loader, optimizer, epoch, pseudo_labels, is_augmented, verbose)