Added support for weighted sparse_cross_entropy (#8340)
rusty1s authored Nov 7, 2023
1 parent c2137ad commit 1725f14
Showing 3 changed files with 59 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added
 
+- Added support for weighted `sparse_cross_entropy` ([#8340](https://github.com/pyg-team/pytorch_geometric/pull/8340))
 - Added a multi GPU training benchmarks for XPU device ([#8288](https://github.com/pyg-team/pytorch_geometric/pull/8288))
 - Support MRR computation in `KGEModel.test()` ([#8298](https://github.com/pyg-team/pytorch_geometric/pull/8298))
 - Added an example for model parallelism (`examples/multi_gpu/model_parallel.py`) ([#8309](https://github.com/pyg-team/pytorch_geometric/pull/8309))
21 changes: 17 additions & 4 deletions test/utils/test_cross_entropy.py
@@ -1,37 +1,50 @@
+import pytest
 import torch
 import torch.nn.functional as F
 
 from torch_geometric.utils.cross_entropy import sparse_cross_entropy
 
 
-def test_sparse_cross_entropy_multiclass():
+@pytest.mark.parametrize('with_edge_label_weight', [False, True])
+def test_sparse_cross_entropy_multiclass(with_edge_label_weight):
     x = torch.randn(5, 5, requires_grad=True)
     y = torch.eye(5)
 
     edge_label_index = y.nonzero().t()
+    edge_label_weight = None
+    if with_edge_label_weight:
+        edge_label_weight = torch.rand(edge_label_index.size(1))
+        y[y == 1.0] = edge_label_weight
 
     expected = F.cross_entropy(x, y)
     expected.backward()
     expected_grad = x.grad
 
     x.grad = None
-    out = sparse_cross_entropy(x, edge_label_index)
+    out = sparse_cross_entropy(x, edge_label_index, edge_label_weight)
     out.backward()
 
     assert torch.allclose(expected, out)
     assert torch.allclose(expected_grad, x.grad)
 
 
-def test_sparse_cross_entropy_multilabel():
+@pytest.mark.parametrize('with_edge_label_weight', [False, True])
+def test_sparse_cross_entropy_multilabel(with_edge_label_weight):
     x = torch.randn(4, 4, requires_grad=True)
     y = torch.randint_like(x, 0, 2)
 
     edge_label_index = y.nonzero().t()
+    edge_label_weight = None
+    if with_edge_label_weight:
+        edge_label_weight = torch.rand(edge_label_index.size(1))
+        y[y == 1.0] = edge_label_weight
 
     expected = F.cross_entropy(x, y)
     expected.backward()
     expected_grad = x.grad
 
     x.grad = None
-    out = sparse_cross_entropy(x, edge_label_index)
+    out = sparse_cross_entropy(x, edge_label_index, edge_label_weight)
     out.backward()
 
     assert torch.allclose(expected, out)
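
Note (an editorial aside, not part of the commit): the new parametrization works because a weighted sparse target is equivalent to a dense soft target carrying the same weights at its nonzero positions, which is exactly what the tests compare against. A minimal standalone sketch of that equivalence, assuming `torch_geometric` with this change is installed:

import torch
import torch.nn.functional as F

from torch_geometric.utils.cross_entropy import sparse_cross_entropy

x = torch.randn(4, 4)
y = torch.randint_like(x, 0, 2)  # random multi-label 0/1 target

edge_label_index = y.nonzero().t()  # sparse indices of positives, shape [2, num_labels]
edge_label_weight = torch.rand(edge_label_index.size(1))  # one weight per positive
y[y == 1.0] = edge_label_weight  # dense soft-target equivalent

expected = F.cross_entropy(x, y)  # dense reference loss
out = sparse_cross_entropy(x, edge_label_index, edge_label_weight)
assert torch.allclose(expected, out)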
55 changes: 41 additions & 14 deletions torch_geometric/utils/cross_entropy.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -10,26 +10,38 @@ class SparseCrossEntropy(torch.autograd.Function):
     # We implement our own custom autograd function for this to avoid the
     # double gradient computation to `inputs`.
     @staticmethod
-    def forward(ctx, inputs: Tensor, edge_label_index: Tensor) -> Tensor:
+    def forward(
+        ctx,
+        inputs: Tensor,
+        edge_label_index: Tensor,
+        edge_label_weight: Optional[Tensor],
+    ) -> Tensor:
         assert inputs.dim() == 2
 
         logsumexp = inputs.logsumexp(dim=-1)
-        ctx.save_for_backward(inputs, edge_label_index, logsumexp)
+        ctx.save_for_backward(inputs, edge_label_index, edge_label_weight,
+                              logsumexp)
 
         out = inputs[edge_label_index[0], edge_label_index[1]]
         out.neg_().add_(logsumexp[edge_label_index[0]])
+        if edge_label_weight is not None:
+            out *= edge_label_weight
 
         return out.sum() / inputs.size(0)
 
     @staticmethod
     @torch.autograd.function.once_differentiable
-    def backward(ctx, grad_out: Tensor) -> Tuple[Tensor, None]:
-        inputs, edge_label_index, logsumexp = ctx.saved_tensors
+    def backward(ctx, grad_out: Tensor) -> Tuple[Tensor, None, None]:
+        inputs, edge_label_index, edge_label_weight, logsumexp = (
+            ctx.saved_tensors)
 
         grad_out = grad_out / inputs.size(0)
+        grad_out = grad_out.expand(edge_label_index.size(1))
 
-        grad_logsumexp = scatter(grad_out.expand(edge_label_index.size(1)),
-                                 edge_label_index[0], dim=0,
+        if edge_label_weight is not None:
+            grad_out = grad_out * edge_label_weight
+
+        grad_logsumexp = scatter(grad_out, edge_label_index[0], dim=0,
                                  dim_size=inputs.size(0), reduce='sum')
 
         # Gradient computation of `logsumexp`: `grad * (self - result).exp()`
@@ -39,27 +51,42 @@ def backward(ctx, grad_out: Tensor) -> Tuple[Tensor, None]:

         grad_input[edge_label_index[0], edge_label_index[1]] -= grad_out
 
-        return grad_input, None
+        return grad_input, None, None
 
 
-def sparse_cross_entropy(inputs: Tensor, edge_label_index: Tensor) -> Tensor:
+def sparse_cross_entropy(
+    inputs: Tensor,
+    edge_label_index: Tensor,
+    edge_label_weight: Optional[Tensor] = None,
+) -> Tensor:
     r"""A sparse-label variant of :func:`torch.nn.functional.cross_entropy`.
     In particular, the binary target matrix is solely given by sparse indices
     :obj:`edge_label_index`.
 
     Args:
         inputs (torch.Tensor): The predicted unnormalized logits of shape
             :obj:`[batch_size, num_classes]`.
-        edge_label_index (torch.Tensor): The sparse ground-truth indices of
+        edge_label_index (torch.Tensor): The sparse ground-truth indices with
             shape :obj:`[2, num_labels]`.
+        edge_label_weight (torch.Tensor, optional): The weight of ground-truth
+            indices with shape :obj:`[num_labels]`. (default: :obj:`None`)
 
     :rtype: :class:`torch.Tensor`
 
     Example:
         >>> inputs = torch.randn(2, 3)
-        >>> edge_label_index = torch.tensor([[0, 0, 1],
-        ...                                  [0, 1, 2]])
-        >>> sparse_cross_entropy(inputs, edge_label_index)
+        >>> edge_label_index = torch.tensor([
+        ...     [0, 0, 1],
+        ...     [0, 1, 2],
+        ... ])
+        >>> loss = sparse_cross_entropy(inputs, edge_label_index)
         tensor(1.2919)
     """
-    return SparseCrossEntropy.apply(inputs, edge_label_index)
+    if edge_label_weight is not None:
+        assert not edge_label_weight.requires_grad
+
+    return SparseCrossEntropy.apply(
+        inputs,
+        edge_label_index,
+        edge_label_weight,
+    )
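
For reference, a compact statement of what the weighted loss computes, in my own notation rather than the commit's: writing the positive target entries as index pairs $(i_k, j_k)$ taken from `edge_label_index`, with weights $w_k$ from `edge_label_weight` (and $w_k = 1$ when it is `None`) and $N$ = `inputs.size(0)`, the forward pass evaluates

$$\mathcal{L}(x) \;=\; \frac{1}{N}\sum_{k=1}^{K} w_k \left( \log\sum_{c} \exp(x_{i_k,c}) \;-\; x_{i_k,j_k} \right),$$

and the hand-written backward applies the matching gradient

$$\frac{\partial \mathcal{L}}{\partial x_{n,c}} \;=\; \frac{1}{N} \left( \operatorname{softmax}(x_n)_c \sum_{k\,:\,i_k = n} w_k \;-\; \sum_{k\,:\,i_k = n,\; j_k = c} w_k \right),$$

i.e. the scatter-sum of the weighted `grad_out` (producing the softmax term) followed by the indexed subtraction. Because the weights enter `backward` only as constants, the new `assert not edge_label_weight.requires_grad` guard makes that assumption explicit.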

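The class comment explains that the custom `autograd.Function` avoids a double gradient computation to `inputs`. For comparison only, here is a hypothetical plain-autograd formulation of the same loss (a sketch for intuition, not code from this commit); autograd would differentiate `inputs` along two paths, once through `logsumexp` and once through the gather:

from typing import Optional

import torch
from torch import Tensor


def naive_sparse_cross_entropy(
    inputs: Tensor,
    edge_label_index: Tensor,
    edge_label_weight: Optional[Tensor] = None,
) -> Tensor:
    # Autograd tracks `inputs` through `logsumexp` *and* the gather below,
    # so backward traverses both paths; the custom Function fuses them into
    # a single hand-written gradient.
    out = (inputs.logsumexp(dim=-1)[edge_label_index[0]]
           - inputs[edge_label_index[0], edge_label_index[1]])
    if edge_label_weight is not None:
        out = out * edge_label_weight
    return out.sum() / inputs.size(0)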