diff --git a/torch_geometric/utils/rag/backend_utils.py b/torch_geometric/utils/rag/backend_utils.py index bcb15fcd4b34..be5f41d998a7 100644 --- a/torch_geometric/utils/rag/backend_utils.py +++ b/torch_geometric/utils/rag/backend_utils.py @@ -244,13 +244,12 @@ def make_pcst_filter(triples: List[Tuple[str, str, str]], model: SentenceTransformer): if DataFrame is None: raise Exception("PCST requires `pip install pandas`") - all_nodes = [] - for triple in triples: - if triple[0] in all_nodes: - all_nodes.append(triple[0]) - if triple[2] in all_nodes: - all_nodes.append(triple[2]) - full_textual_nodes = all_nodes + nodes = [] + for h, r, t in triples: + for node in (h, t): + if node not in nodes: + nodes.append(node) + full_textual_nodes = nodes def apply_retrieval_via_pcst( graph: Data,