diff --git a/torch_geometric/utils/rag/backend_utils.py b/torch_geometric/utils/rag/backend_utils.py index 19b2da0656d4..7a89e0f922d8 100644 --- a/torch_geometric/utils/rag/backend_utils.py +++ b/torch_geometric/utils/rag/backend_utils.py @@ -246,8 +246,11 @@ def make_pcst_filter(triples: List[Tuple[str, str, str]], raise Exception("PCST requires `pip install pandas`") all_nodes = [] for triple in triples: - all_nodes += [triple[0]] + [triple[2]] - full_textual_nodes = list(set(all_nodes)) + if triple[0] in all_nodes: + all_nodes.append(all_nodes) + if triple[2] in all_nodes: + all_nodes.append(all_nodes) + full_textual_nodes = all_nodes def apply_retrieval_via_pcst( graph: Data,