diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index 3eaa31f3e13c..2c6d91e028cc 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -248,7 +248,7 @@ def filter_custom_store( def get_input_nodes( data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], input_nodes: Union[InputNodes, TensorAttr], - input_id: Optional[Tensor], + input_id: Optional[Tensor] = None, ) -> Tuple[Optional[str], Tensor, Optional[Tensor]]: def to_index(nodes, input_id) -> Tuple[Tensor, Optional[Tensor]]: if isinstance(nodes, Tensor) and nodes.dtype == torch.bool: