Update cugraph conv layers for pylibcugraphops=23.04 (#7023)
This PR updates cugraph models to reflect breaking changes in
`pylibcugraphops=23.04`.
~~Right now, it is **blocked** by the RAPIDS 23.04 release.~~
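
For context, a rough before/after sketch of what the breaking change looks like at a call site (illustrative only: the tensors are made up, and a CUDA device plus a matching `pylibcugraphops` build are assumed):

```python
import torch

# Made-up CSC inputs; the conv layers build these from their `csc` argument.
colptr = torch.tensor([0, 2, 4], device='cuda')  # one pointer per destination node, plus one
row = torch.tensor([0, 1, 1, 2], device='cuda')  # source index of every edge

# pylibcugraphops < 23.04 (legacy API):
#   from pylibcugraphops import make_fg_csr
#   graph = make_fg_csr(colptr, row)

# pylibcugraphops >= 23.04:
from pylibcugraphops.pytorch import StaticCSC

graph = StaticCSC(colptr, row)
```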

CC: @MatthiasKohl @stadlmax

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and rusty1s committed Apr 27, 2023
1 parent b3cba5e commit ff9fb3d
Showing 4 changed files with 69 additions and 26 deletions.
63 changes: 43 additions & 20 deletions torch_geometric/nn/conv/cugraph/base.py
@@ -8,15 +8,26 @@
 from torch_geometric.utils.sparse import index2ptr
 
 try:  # pragma: no cover
-    from pylibcugraphops import (
-        make_fg_csr,
-        make_fg_csr_hg,
-        make_mfg_csr,
-        make_mfg_csr_hg,
+    LEGACY_MODE = False
+    from pylibcugraphops.pytorch import (
+        SampledCSC,
+        SampledHeteroCSC,
+        StaticCSC,
+        StaticHeteroCSC,
     )
     HAS_PYLIBCUGRAPHOPS = True
 except ImportError:
     HAS_PYLIBCUGRAPHOPS = False
+    try:  # pragma: no cover
+        from pylibcugraphops import (
+            make_fg_csr,
+            make_fg_csr_hg,
+            make_mfg_csr,
+            make_mfg_csr_hg,
+        )
+        LEGACY_MODE = True
+    except ImportError:
+        pass
 
 
 class CuGraphModule(torch.nn.Module):  # pragma: no cover
@@ -25,9 +36,9 @@ class CuGraphModule(torch.nn.Module):  # pragma: no cover
     def __init__(self):
         super().__init__()
 
-        if HAS_PYLIBCUGRAPHOPS is False:
+        if not HAS_PYLIBCUGRAPHOPS and not LEGACY_MODE:
             raise ModuleNotFoundError(f"'{self.__class__.__name__}' requires "
-                                      f"'pylibcugraphops'")
+                                      f"'pylibcugraphops>=23.02'")
 
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
@@ -99,12 +110,17 @@ def get_cugraph(
             if max_num_neighbors is None:
                 max_num_neighbors = int((colptr[1:] - colptr[:-1]).max())
 
-            dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)
+            if LEGACY_MODE:
+                dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)
+                return make_mfg_csr(dst_nodes, colptr, row, max_num_neighbors,
+                                    num_src_nodes)
 
-            return make_mfg_csr(dst_nodes, colptr, row, max_num_neighbors,
-                                num_src_nodes)
+            return SampledCSC(colptr, row, max_num_neighbors, num_src_nodes)
 
-        return make_fg_csr(colptr, row)
+        if LEGACY_MODE:
+            return make_fg_csr(colptr, row)
+
+        return StaticCSC(colptr, row)
 
     def get_typed_cugraph(
         self,
Expand Down Expand Up @@ -142,17 +158,24 @@ def get_typed_cugraph(
if max_num_neighbors is None:
max_num_neighbors = int((colptr[1:] - colptr[:-1]).max())

dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)
if LEGACY_MODE:
dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)
return make_mfg_csr_hg(dst_nodes, colptr, row,
max_num_neighbors, num_src_nodes,
n_node_types=0,
n_edge_types=num_edge_types,
out_node_types=None, in_node_types=None,
edge_types=edge_type)

return SampledHeteroCSC(colptr, row, edge_type, max_num_neighbors,
num_src_nodes, num_edge_types)

return make_mfg_csr_hg(dst_nodes, colptr, row, max_num_neighbors,
num_src_nodes, n_node_types=0,
n_edge_types=num_edge_types,
out_node_types=None, in_node_types=None,
edge_types=edge_type)
if LEGACY_MODE:
return make_fg_csr_hg(colptr, row, n_node_types=0,
n_edge_types=num_edge_types, node_types=None,
edge_types=edge_type)

return make_fg_csr_hg(colptr, row, n_node_types=0,
n_edge_types=num_edge_types, node_types=None,
edge_types=edge_type)
return StaticHeteroCSC(colptr, row, edge_type, num_edge_types)

def forward(
self,
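
As a reading aid for the `base.py` hunks above, a minimal sketch of the sampled/bipartite path that `get_cugraph` now takes with the new API; the values are illustrative and a CUDA device with `pylibcugraphops>=23.04` is assumed:

```python
import torch

from pylibcugraphops.pytorch import SampledCSC

colptr = torch.tensor([0, 2, 4], device='cuda')  # CSC pointers over the destination nodes
row = torch.tensor([0, 1, 1, 2], device='cuda')  # source index of every sampled edge
num_src_nodes = 3  # bipartite case: more source nodes than destination nodes
max_num_neighbors = int((colptr[1:] - colptr[:-1]).max())  # same fallback as the layer

# The legacy make_mfg_csr needed an explicit dst_nodes range; SampledCSC does not.
graph = SampledCSC(colptr, row, max_num_neighbors, num_src_nodes)
```
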
15 changes: 12 additions & 3 deletions torch_geometric/nn/conv/cugraph/gat_conv.py
@@ -5,10 +5,14 @@
 from torch.nn import Linear, Parameter
 
 from torch_geometric.nn.conv.cugraph import CuGraphModule
+from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
 from torch_geometric.nn.inits import zeros
 
 try:
-    from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg
+    if LEGACY_MODE:
+        from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg
+    else:
+        from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg
 except ImportError:
     pass
 
@@ -67,8 +71,13 @@ def forward(
         graph = self.get_cugraph(csc, max_num_neighbors)
 
         x = self.lin(x)
-        out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
-                         self.negative_slope, False, self.concat)
+
+        if LEGACY_MODE:
+            out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
+                             self.negative_slope, False, self.concat)
+        else:
+            out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
+                             self.negative_slope, self.concat)
 
         if self.bias is not None:
             out = out + self.bias
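
For completeness, a hedged end-to-end usage sketch of `CuGraphGATConv`. It is not part of this diff: the constructor and `to_csc()` signatures are assumed from the surrounding PyG code base, and it needs a CUDA device with either a legacy or a 23.04 `pylibcugraphops` install:

```python
import torch

from torch_geometric.nn import CuGraphGATConv

x = torch.randn(4, 16, device='cuda')
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]], device='cuda')

conv = CuGraphGATConv(16, 32, heads=2).cuda()
# to_csc() builds the (row, colptr, num_src_nodes) tuple that forward() expects.
csc = CuGraphGATConv.to_csc(edge_index, size=(4, 4))
out = conv(x, csc)  # the same call works in both the legacy and the 23.04 branch
print(out.shape)    # torch.Size([4, 64]) with heads=2 and the default concat=True
```
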
9 changes: 7 additions & 2 deletions torch_geometric/nn/conv/cugraph/rgcn_conv.py
@@ -5,11 +5,16 @@
 from torch.nn import Parameter
 
 from torch_geometric.nn.conv.cugraph import CuGraphModule
+from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
 from torch_geometric.nn.inits import glorot, zeros
 
 try:
-    from pylibcugraphops.torch.autograd import \
-        agg_hg_basis_n2n_post as RGCNConvAgg
+    if LEGACY_MODE:
+        from pylibcugraphops.torch.autograd import \
+            agg_hg_basis_n2n_post as RGCNConvAgg
+    else:
+        from pylibcugraphops.pytorch.operators import \
+            agg_hg_basis_n2n_post as RGCNConvAgg
 except ImportError:
     pass
 
8 changes: 7 additions & 1 deletion torch_geometric/nn/conv/cugraph/sage_conv.py
@@ -5,9 +5,15 @@
 from torch.nn import Linear
 
 from torch_geometric.nn.conv.cugraph import CuGraphModule
+from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
 
 try:
-    from pylibcugraphops.torch.autograd import agg_concat_n2n as SAGEConvAgg
+    if LEGACY_MODE:
+        from pylibcugraphops.torch.autograd import \
+            agg_concat_n2n as SAGEConvAgg
+    else:
+        from pylibcugraphops.pytorch.operators import \
+            agg_concat_n2n as SAGEConvAgg
 except ImportError:
     pass
 
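
Finally, a small sketch for checking at runtime which backend the import fallback in `base.py` picked up; `HAS_PYLIBCUGRAPHOPS` and `LEGACY_MODE` are the module-level flags used above:

```python
from torch_geometric.nn.conv.cugraph import base

if base.HAS_PYLIBCUGRAPHOPS:
    print("using pylibcugraphops>=23.04 (pylibcugraphops.pytorch)")
elif base.LEGACY_MODE:
    print("using the legacy pylibcugraphops API (<23.04)")
else:
    print("pylibcugraphops not found; CuGraph layers will raise at init")
```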
