From 7f50de1b8450d59c29b5e8f41c5d132f18169fd6 Mon Sep 17 00:00:00 2001 From: Rishi Puri Date: Fri, 17 Nov 2023 00:05:22 -0800 Subject: [PATCH] default to None for input_ids, fixes cugraph failures (#8394) ``` 713E File "/opt/rapids/cugraph/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py", line 511, in __iter__ 714E self.current_loader = EXPERIMENTAL__BulkSampleLoader( 715E File "/opt/rapids/cugraph/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py", line 151, in __init__ 716E input_type, input_nodes = torch_geometric.loader.utils.get_input_nodes( 717E TypeError: get_input_nodes() missing 1 required positional argument: 'input_id' ``` --- torch_geometric/loader/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index 3eaa31f3e13c..2c6d91e028cc 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -248,7 +248,7 @@ def filter_custom_store( def get_input_nodes( data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], input_nodes: Union[InputNodes, TensorAttr], - input_id: Optional[Tensor], + input_id: Optional[Tensor] = None, ) -> Tuple[Optional[str], Tensor, Optional[Tensor]]: def to_index(nodes, input_id) -> Tuple[Tensor, Optional[Tensor]]: if isinstance(nodes, Tensor) and nodes.dtype == torch.bool: