Skip to content

Commit

Permalink
Fix type hints in group_cat (#9049)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Mar 12, 2024
1 parent f8ebf6a commit cfdb4ce
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torch_geometric/utils/_scatter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -286,8 +286,8 @@ def group_argsort(


def group_cat(
tensors: Sequence[Tensor],
index: Sequence[Tensor],
tensors: Union[List[Tensor], Tuple[Tensor, ...]],
indices: Union[List[Tensor], Tuple[Tensor, ...]],
dim: int = 0,
return_index: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
Expand All @@ -300,7 +300,7 @@ def group_cat(
Args:
tensors ([Tensor]): Sequence of tensors.
index ([Tensor]): Sequence of index tensors.
indices ([Tensor]): Sequence of index tensors.
dim (int, optional): The dimension along which the tensors are
concatenated. (default: :obj:`0`)
return_index (bool, optional): If set to :obj:`True`, will return the
Expand All @@ -323,7 +323,7 @@ def group_cat(
[0.4100, 0.0012],
[0.7757, 0.5999]])
"""
assert len(tensors) == len(index)
index, perm = torch.cat(index).sort(stable=True)
assert len(tensors) == len(indices)
index, perm = torch.cat(indices).sort(stable=True)
out = torch.cat(tensors, dim=0)[perm]
return (out, index) if return_index else out

0 comments on commit cfdb4ce

Please sign in to comment.