EdgeIndex.is_sorted and torch.sparse.mm support (#8497)
rusty1s authored Nov 30, 2023
1 parent b0053ce commit 369c955
Showing 2 changed files with 119 additions and 50 deletions.
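
For orientation, here is a minimal usage sketch of the sorting-related properties this commit introduces, inferred directly from the tests below; the import path is an assumption based on the file layout in this diff.

from torch_geometric.data.edge_index import EdgeIndex

# Edges of a 3-node path graph (0-1, 1-2), listed in row-sorted order.
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
assert adj.is_sorted             # Some sort order is known.
assert adj.is_sorted_by_row      # Sorted by row (source) indices.
assert not adj.is_sorted_by_col  # Not sorted by column (destination) indices.

# With no declared sort order, all sortedness properties report False:
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]])
assert adj.sort_order is None
assert not adj.is_sorted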
85 changes: 61 additions & 24 deletions test/data/test_edge_index.py
@@ -26,7 +26,11 @@ def test_basic():
assert str(adj) == ('EdgeIndex([[0, 1, 1, 2],\n'
' [1, 0, 2, 1]])')
assert adj.sparse_size == (3, 3)

assert adj.sort_order is None
assert not adj.is_sorted
assert not adj.is_sorted_by_row
assert not adj.is_sorted_by_col

assert not isinstance(adj.as_tensor(), EdgeIndex)

@@ -39,22 +43,32 @@ def test_fill_cache_():
assert adj.sparse_size == (3, 3)
assert torch.equal(adj._rowptr, torch.tensor([0, 1, 3, 4]))

assert adj.sort_order == 'row'
assert adj.is_sorted
assert adj.is_sorted_by_row
assert not adj.is_sorted_by_col

adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col')
adj.validate().fill_cache_()
assert adj.sparse_size == (3, 3)
assert torch.equal(adj._colptr, torch.tensor([0, 1, 3, 4]))

assert adj.sort_order == 'col'
assert adj.is_sorted
assert not adj.is_sorted_by_row
assert adj.is_sorted_by_col


def test_clone():
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

out = adj.clone()
assert isinstance(out, EdgeIndex)
assert out.sort_order == 'row'
assert out.is_sorted_by_row

out = torch.clone(adj)
assert isinstance(out, EdgeIndex)
assert out.sort_order == 'row'
assert out.is_sorted_by_row


@withCUDA
@@ -161,7 +175,7 @@ def test_cat():
assert out.size() == (2, 8)
assert isinstance(out, EdgeIndex)
assert out.sparse_size == (4, 4)
assert out.sort_order is None
assert not out.is_sorted

out = torch.cat([adj1, adj2], dim=0)
assert out.size() == (4, 4)
@@ -176,14 +190,14 @@ def test_flip():
assert isinstance(out, EdgeIndex)
assert torch.equal(out, torch.tensor([[1, 0, 2, 1], [0, 1, 1, 2]]))
assert out.sparse_size == (3, 3)
assert out.sort_order == 'col'
assert out.is_sorted_by_col
assert torch.equal(out._colptr, torch.tensor([0, 1, 3, 4]))

out = adj.flip([0, 1])
assert isinstance(out, EdgeIndex)
assert torch.equal(out, torch.tensor([[1, 2, 0, 1], [2, 1, 1, 0]]))
assert out.sparse_size == (3, 3)
assert out.sort_order is None
assert not out.is_sorted
assert out._colptr is None


@@ -205,7 +219,7 @@ def test_narrow():
out = adj.narrow(dim=1, start=1, length=2)
assert torch.equal(out, torch.tensor([[1, 1], [0, 2]]))
assert isinstance(out, EdgeIndex)
assert out.sort_order == 'row'
assert out.is_sorted_by_row

out = adj.narrow(dim=0, start=0, length=1)
assert torch.equal(out, torch.tensor([[0, 1, 1, 2]]))
@@ -218,17 +232,17 @@ def test_getitem():
out = adj[:, torch.tensor([False, True, False, True])]
assert isinstance(out, EdgeIndex)
assert torch.equal(out, torch.tensor([[1, 2], [0, 1]]))
assert out.sort_order == 'row'
assert out.is_sorted_by_row

out = adj[..., torch.tensor([1, 3])]
assert isinstance(out, EdgeIndex)
assert torch.equal(out, torch.tensor([[1, 2], [0, 1]]))
assert out.sort_order is None
assert not out.is_sorted

out = adj[..., 1::2]
assert isinstance(out, EdgeIndex)
assert torch.equal(out, torch.tensor([[1, 2], [0, 1]]))
assert out.sort_order == 'row'
assert out.is_sorted_by_row

out = adj[:, 0]
assert not isinstance(out, EdgeIndex)
@@ -317,46 +331,69 @@ def test_to_sparse_csc():
def test_matmul_forward():
x = torch.randn(3, 1)
adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
adj1_dense = adj1.to_dense()
adj2 = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col')
adj2_dense = adj2.to_dense()

out = adj1 @ x
assert torch.allclose(out, adj1.to_dense() @ x)
assert torch.allclose(out, adj1_dense @ x)

out = adj1 @ adj1
assert torch.allclose(out.to_dense(), adj1.to_dense() @ adj1.to_dense())
out = adj1.matmul(x)
assert torch.allclose(out, adj1_dense @ x)

out = adj1 @ adj2
assert torch.allclose(out.to_dense(), adj1.to_dense() @ adj2.to_dense())
out = torch.matmul(adj1, x)
assert torch.allclose(out, adj1_dense @ x)

out = adj2 @ adj1
assert torch.allclose(out.to_dense(), adj2.to_dense() @ adj1.to_dense())
if torch_geometric.typing.WITH_PT20:
out = torch.sparse.mm(adj1, x, reduce='sum')
else:
with pytest.raises(TypeError, match="got an unexpected keyword"):
torch.sparse.mm(adj1, x, reduce='sum')
out = torch.sparse.mm(adj1, x)
assert torch.allclose(out, adj1_dense @ x)

out = adj2 @ adj2
assert torch.allclose(out.to_dense(), adj2.to_dense() @ adj2.to_dense())
out, value = adj1 @ adj1
assert isinstance(out, EdgeIndex)
assert out.is_sorted_by_row
assert out._sparse_size == (3, 3)
assert out._rowptr is not None
assert torch.allclose(to_dense(out, value=value), adj1_dense @ adj1_dense)

out, value = adj1 @ adj2
assert isinstance(out, EdgeIndex)
assert torch.allclose(to_dense(out, value=value), adj1_dense @ adj2_dense)

out, value = adj2 @ adj1
assert isinstance(out, EdgeIndex)
assert torch.allclose(to_dense(out, value=value), adj2_dense @ adj1_dense)

out, value = adj2 @ adj2
assert isinstance(out, EdgeIndex)
assert torch.allclose(to_dense(out, value=value), adj2_dense @ adj2_dense)


def test_matmul_input_value():
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

x = torch.randn(3, 1)
input_value = torch.randn(4)
value = torch.randn(4)

out = matmul(adj, x, input_value)
assert torch.allclose(out, to_dense(adj, value=input_value) @ x)
out = matmul(adj, x, input_value=value)
assert torch.allclose(out, to_dense(adj, value=value) @ x)


def test_matmul_backward():
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

x1 = torch.randn(3, 1, requires_grad=True)
input_value = torch.randn(4)
value = torch.randn(4)

out = matmul(adj, x1, input_value)
out = matmul(adj, x1, input_value=value)
grad_out = torch.randn_like(out)
out.backward(grad_out)

x2 = x1.detach().requires_grad_()
dense_adj = to_dense(adj, value=input_value)
dense_adj = to_dense(adj, value=value)
out = dense_adj @ x2
out.backward(grad_out)

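As the tests above exercise it, sparse-dense multiplication with an EdgeIndex still returns a plain dense tensor, while sparse-sparse multiplication now returns the product as a row-sorted EdgeIndex together with its non-zero values. A minimal sketch, assuming the same import path as above:

import torch
from torch_geometric.data.edge_index import EdgeIndex

x = torch.randn(3, 1)
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

out = adj @ x                  # Sparse-dense: a plain dense tensor.
assert torch.allclose(out, adj.to_dense() @ x)

edge_index, value = adj @ adj  # Sparse-sparse: (EdgeIndex, values) pair.
assert isinstance(edge_index, EdgeIndex)
assert edge_index.is_sorted_by_row
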
84 changes: 58 additions & 26 deletions torch_geometric/data/edge_index.py
@@ -29,6 +29,8 @@
torch.int64,
}

ReduceType = Literal['sum']


class SortOrder(Enum):
ROW = 'row'
@@ -233,6 +235,21 @@ def sort_order(self) -> Optional[str]:
"""
return None if self._sort_order is None else self._sort_order.value

@property
def is_sorted(self) -> bool:
r"""Returns whether indices are either sorted by rows or columns."""
return self._sort_order is not None

@property
def is_sorted_by_row(self) -> bool:
r"""Returns whether indices are sorted by rows."""
return self._sort_order == SortOrder.ROW

@property
def is_sorted_by_col(self) -> bool:
r"""Returns whether indices are sorted by columns."""
return self._sort_order == SortOrder.COL

def get_rowptr(self) -> Tensor:
r"""Returns the :obj:`rowptr` vector of :class:`EdgeIndex`, a
compressed representation of row indices in case :class:`EdgeIndex` is
@@ -271,9 +288,10 @@ def get_colptr(self) -> Tensor:

return self._colptr

def get_value(self, dtype: Optional[torch.dtype] = None) -> Tensor:
def _get_value(self, dtype: Optional[torch.dtype] = None) -> Tensor:
if self._value is not None:
return self._value # TODO Respect `dtype`
if (dtype or torch.get_default_dtype()) == self._value.dtype:
return self._value

if (torch_geometric.typing.WITH_PT20
and not torch_geometric.typing.WITH_ARM):
@@ -327,15 +345,11 @@ def sort_by(
return torch.return_types.sort([self, perm])

# If conversion from CSR->CSC or CSC->CSR is known, make use of it:
if (self._sort_order == SortOrder.ROW and sort_order == SortOrder.COL
and self._csr2csc is not None):

if self._sort_order == SortOrder.ROW and self._csr2csc is not None:
edge_index = self.as_tensor()[:, self._csr2csc]
perm = self._csr2csc

elif (self._sort_order == SortOrder.COL and sort_order == SortOrder.ROW
and self._csc2csr is not None):

elif self._sort_order == SortOrder.COL and self._csc2csr is not None:
edge_index = self.as_tensor()[:, self._csc2csr]
perm = self._csc2csr

@@ -360,9 +374,9 @@ def sort_by(
out._value = self._value

# Fill information for faster future CSR->CSC or CSC->CSR conversion:
if self._sort_order == SortOrder.ROW and sort_order == SortOrder.COL:
if self._sort_order == SortOrder.ROW:
out._csr2csc = self._csr2csc = perm
elif self._sort_order == SortOrder.COL and sort_order == SortOrder.ROW:
elif self._sort_order == SortOrder.COL:
out._csc2csr = self._csc2csr = perm

return torch.return_types.sort([out, perm])
@@ -379,15 +393,13 @@ def to_sparse_tensor(
value (torch.Tensor, optional): The values of non-zero indices.
(default: :obj:`None`)
"""
is_sorted = self._sort_order == SortOrder.ROW

return SparseTensor(
row=self[0],
col=self[1],
rowptr=self.get_rowptr() if is_sorted else None,
rowptr=self.get_rowptr() if self.is_sorted_by_row else None,
value=value,
sparse_sizes=self.get_sparse_size(),
is_sorted=is_sorted,
is_sorted=self.is_sorted_by_row,
trust_data=True,
)

@@ -426,7 +438,7 @@ def apply_(
out._sparse_size = tensor.sparse_size
out._sort_order = tensor._sort_order

# Convert cache:
# Convert cache (but do not consider `_value`):
if tensor._rowptr is not None:
out._rowptr = fn(tensor._rowptr, *args, **kwargs)
if tensor._colptr is not None:
@@ -649,7 +661,7 @@ def to_dense(
def to_sparse_coo(tensor: EdgeIndex, value: Optional[Tensor] = None) -> Tensor:
out = torch.sparse_coo_tensor(
indices=tensor.as_tensor(),
values=tensor.get_value() if value is None else value,
values=tensor._get_value() if value is None else value,
size=tensor.get_sparse_size(),
device=tensor.device,
)
@@ -665,7 +677,7 @@ def to_sparse_csr(tensor: EdgeIndex, value: Optional[Tensor] = None) -> Tensor:
return torch.sparse_csr_tensor(
crow_indices=tensor.get_rowptr(),
col_indices=tensor[1],
values=tensor.get_value() if value is None else value,
values=tensor._get_value() if value is None else value,
size=tensor.get_sparse_size(),
device=tensor.device,
)
@@ -681,7 +693,7 @@ def to_sparse_csc(
return torch.sparse_csc_tensor(
ccol_indices=tensor.get_colptr(),
row_indices=tensor[0],
values=tensor.get_value() if value is None else value,
values=tensor._get_value() if value is None else value,
size=tensor.get_sparse_size(),
device=tensor.device,
)
@@ -722,9 +734,6 @@ def to_sparse(tensor: EdgeIndex, value: Optional[Tensor] = None) -> Tensor:
return to_sparse_coo(tensor, value)


ReduceType = Literal['sum']


class SparseDenseMatmul(torch.autograd.Function):
@staticmethod
def forward(
@@ -740,7 +749,7 @@ def forward(
"'EdgeIndex' to be sorted by rows")

if reduce not in ReduceType.__args__:
raise NotImplementedError("`reduce='{reduce}'` not yet supported")
raise NotImplementedError(f"`reduce='{reduce}'` not yet supported")

if other.requires_grad:
ctx.save_for_backward(input, input_value)
@@ -756,7 +765,8 @@ def forward(
return torch.ops.torch_sparse.spmm_sum( #
None, rowptr, col, input_value, None, None, other)

input_value = input.get_value() if input_value is None else input_value
if input_value is None:
input_value = input._get_value()
adj = to_sparse_csr(input, input_value)

return adj @ other
@@ -787,7 +797,7 @@ def backward(

else:
if input_value is None:
input_value = input.get_value()
input_value = input._get_value()

adj_t = torch.sparse_csr_tensor(
crow_indices=input.get_colptr(),
@@ -817,7 +827,7 @@ def matmul(
) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:

if reduce not in ReduceType.__args__:
raise NotImplementedError("`reduce='{reduce}'` not yet supported")
raise NotImplementedError(f"`reduce='{reduce}'` not yet supported")

if not isinstance(other, EdgeIndex):
if other_value is not None:
@@ -835,4 +845,26 @@
else:
other = to_sparse_csr(other, other_value)

return torch.matmul(input, other)
out = torch.matmul(input, other)
assert out.layout == torch.sparse_csr

rowptr, col = out.crow_indices(), out.col_indices()
edge_index = torch._convert_indices_from_csr_to_coo(
rowptr, col, out_int32=rowptr.dtype != torch.int64)
edge_index = edge_index.to(rowptr.device)

edge_index = edge_index.as_subclass(EdgeIndex)
edge_index._sort_order = SortOrder.ROW
edge_index._sparse_size = (out.size(0), out.size(1))
edge_index._rowptr = rowptr

return edge_index, out.values()


@implements(torch.sparse.mm)
def _matmul(
mat1: EdgeIndex,
mat2: Union[Tensor, EdgeIndex],
reduce: ReduceType = 'sum',
) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:
return matmul(mat1, mat2, reduce=reduce)
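
With this override registered, calling torch.sparse.mm on an EdgeIndex routes into matmul above; per the test changes earlier in this diff, the reduce keyword is only accepted on PyTorch >= 2.0. A small sketch under the same import-path assumption:

import torch
from torch_geometric.data.edge_index import EdgeIndex

x = torch.randn(3, 1)
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

out1 = torch.sparse.mm(adj, x)  # Dispatched to `matmul` via `@implements`.
out2 = adj @ x                  # Same result via `torch.matmul` dispatch.
assert torch.allclose(out1, out2)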
