Skip to content

Commit

Permalink
Add group_cat functionality (#9029)
Browse files Browse the repository at this point in the history
Concatenates the given sequence of `tensors` in the given dimension
`dim`. Unlike `torch.cat`, values along the concatenating dimension
are grouped according to the indices defined in the `index` tensors.

All tensors must either have the same shape (except in the concatenating
dimension) or be empty, and all index tensors must be in the same sequence
as in `tensors`.

Examples:
```
>>> x1 = torch.tensor([[0.2716, 0.4233, 0.2658, 0.8284],
...                    [0.3166, 0.0142, 0.1700, 0.2944],
...                    [0.2371, 0.3839, 0.7193, 0.2954],
...                    [0.4100, 0.0012, 0.5114, 0.3353]])
>>> x2 = torch.tensor([[0.3752, 0.5782, 0.7105, 0.4002],
...                    [0.7757, 0.5999, 0.7898, 0.0753]])

>>> x1_index = torch.LongTensor([0,0,1,2])
>>> x2_index = torch.LongTensor([0,2])

>>> group_cat([x1, x2], [x1_index, x2_index], 0)
tensor([[0.2716, 0.4233, 0.2658, 0.8284],
              [0.3166, 0.0142, 0.1700, 0.2944],
              [0.3752, 0.5782, 0.7105, 0.4002],
              [0.2371, 0.3839, 0.7193, 0.2954],
              [0.4100, 0.0012, 0.5114, 0.3353],
              [0.7757, 0.5999, 0.7898, 0.0753]])
```

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
3 people authored Mar 12, 2024
1 parent e0d6b66 commit f8ebf6a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029))
- Added support for `EdgeIndex` in `spmm` ([#9026](https://github.com/pyg-team/pytorch_geometric/pull/9026))
- Added option to pre-allocate memory in GPU-based `ApproxKNN` ([#9046](https://github.com/pyg-team/pytorch_geometric/pull/9046))
- Added support for `EdgeIndex` in `MessagePassing` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
Expand Down
21 changes: 20 additions & 1 deletion test/utils/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch_geometric.typing
from torch_geometric.profile import benchmark
from torch_geometric.testing import withCUDA, withPackage
from torch_geometric.utils import group_argsort, scatter
from torch_geometric.utils import group_argsort, group_cat, scatter
from torch_geometric.utils._scatter import scatter_argmax


Expand Down Expand Up @@ -111,6 +111,25 @@ def test_scatter_argmax(device):
assert argmax.tolist() == [3, 5, 1, 4, 5, 5]


@withCUDA
def test_group_cat(device):
    """Check that `group_cat` interleaves rows of two tensors by group index.

    Rows sharing an index value must end up adjacent, with rows of the first
    tensor preceding rows of the second tensor inside each group.
    """
    first = torch.randn(4, 4, device=device)
    second = torch.randn(2, 4, device=device)
    first_index = torch.tensor([0, 0, 1, 2], device=device)
    second_index = torch.tensor([0, 2], device=device)

    grouped, grouped_index = group_cat(
        [first, second],
        [first_index, second_index],
        dim=0,
        return_index=True,
    )

    # Group 0: first[0:2] + second[0]; group 1: first[2];
    # group 2: first[3] + second[1].
    pieces = [first[:2], second[:1], first[2:4], second[1:]]
    assert torch.equal(grouped, torch.cat(pieces, dim=0))
    assert grouped_index.tolist() == [0, 0, 0, 1, 2, 2]


if __name__ == '__main__':
# Insights on GPU:
# ================
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import copy

from ._scatter import scatter, group_argsort
from ._scatter import scatter, group_argsort, group_cat
from ._segment import segment
from ._index_sort import index_sort
from .functions import cumsum
Expand Down Expand Up @@ -60,6 +60,7 @@
__all__ = [
'scatter',
'group_argsort',
'group_cat',
'segment',
'index_sort',
'cumsum',
Expand Down
46 changes: 45 additions & 1 deletion torch_geometric/utils/_scatter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -283,3 +283,47 @@ def group_argsort(
ptr = cumsum(count)

return out - ptr[index]


def group_cat(
    tensors: Sequence[Tensor],
    index: Sequence[Tensor],
    dim: int = 0,
    return_index: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    r"""Concatenates the given sequence of tensors :obj:`tensors` in the given
    dimension :obj:`dim`.
    Different from :meth:`torch.cat`, values along the concatenating dimension
    are grouped according to the indices defined in the :obj:`index` tensors.
    All tensors must have the same shape (except in the concatenating
    dimension).

    Within each group, entries keep the order in which they appear in
    :obj:`tensors` (a stable sort is used on the concatenated indices).

    Args:
        tensors ([Tensor]): Sequence of tensors.
        index ([Tensor]): Sequence of index tensors, one per entry in
            :obj:`tensors`, given in the same order. Each index tensor
            assigns a group to every slice of its tensor along :obj:`dim`.
        dim (int, optional): The dimension along which the tensors are
            concatenated. (default: :obj:`0`)
        return_index (bool, optional): If set to :obj:`True`, will return the
            new index tensor. (default: :obj:`False`)

    Returns:
        The grouped concatenation, or a tuple of the grouped concatenation
        and the sorted index tensor if :obj:`return_index` is :obj:`True`.

    Example:
        >>> x1 = torch.tensor([[0.2716, 0.4233],
        ...                    [0.3166, 0.0142],
        ...                    [0.2371, 0.3839],
        ...                    [0.4100, 0.0012]])
        >>> x2 = torch.tensor([[0.3752, 0.5782],
        ...                    [0.7757, 0.5999]])
        >>> index1 = torch.tensor([0, 0, 1, 2])
        >>> index2 = torch.tensor([0, 2])
        >>> group_cat([x1, x2], [index1, index2], dim=0)
        tensor([[0.2716, 0.4233],
                [0.3166, 0.0142],
                [0.3752, 0.5782],
                [0.2371, 0.3839],
                [0.4100, 0.0012],
                [0.7757, 0.5999]])
    """
    assert len(tensors) == len(index)
    # A stable sort keeps the relative order of equal indices, so entries of
    # earlier tensors come first inside every group.
    index, perm = torch.cat(index).sort(stable=True)
    # Concatenate and permute along `dim` (not a hard-coded dimension 0), so
    # that the `dim` argument is actually honored.
    out = torch.cat(tensors, dim=dim).index_select(dim, perm)
    return (out, index) if return_index else out

0 comments on commit f8ebf6a

Please sign in to comment.