Commit ccb471c: update

asarigun committed Oct 9, 2023
1 parent: e20128f
Showing 2 changed files with 18 additions and 20 deletions.
test/nn/conv/test_gmn_conv.py (1 addition, 1 deletion)
@@ -14,7 +14,7 @@
 ]
 
 
-@pytest.mark.parametrize('divide_input', [True, False])
+@pytest.mark.parametrize('divide_input', [False])
 def test_gmn_conv(divide_input):
     x = torch.randn(4, 16)
     edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
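With the parametrization narrowed to [False], the test now runs a single configuration and no longer exercises the divided-input path. To reproduce just this test locally, a standard pytest invocation works (the -k filter simply selects the test by name):

    pytest test/nn/conv/test_gmn_conv.py -k test_gmn_conv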
torch_geometric/nn/conv/gmn_conv.py (17 additions, 19 deletions)
@@ -1,5 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
+import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import ModuleList, Sequential
@@ -9,16 +10,6 @@
 from torch_geometric.nn.resolver import activation_resolver
 
 
-class Transpose(nn.Module):
-    def __init__(self, dim0, dim1):
-        super(Transpose, self).__init__()
-        self.dim0 = dim0
-        self.dim1 = dim1
-
-    def forward(self, x):
-        return x.transpose(self.dim0, self.dim1)
-
-
 class GMNConv(PNAConv):
     r"""The Graph Mixer convolution operator
     from the `"The Graph Mixer Networks"
@@ -97,14 +88,20 @@ class GMNConv(PNAConv):
           edge features :math:`(|\mathcal{E}|, D)` *(optional)*
         - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
     """
-    def __init__(self, in_channels: int, out_channels: int,
-                 aggregators: List[str], scalers: List[str], deg: Tensor,
-                 edge_dim: Optional[int] = None, towers: int = 1,
-                 post_layers: int = 1, divide_input: bool = False,
+    def __init__(self, in_channels: int,
+                 out_channels: int,
+                 aggregators: List[str],
+                 scalers: List[str],
+                 deg: Tensor,
+                 edge_dim: Optional[int] = None,
+                 towers: int = 1,
+                 post_layers: int = 1,
+                 divide_input: bool = False,
                  act: Union[str, Callable, None] = "relu",
-                 act_kwargs: Optional[Dict[str, Any]] = None, **kwargs):
+                 act_kwargs: Optional[Dict[str, Any]] = None,
+                 **kwargs):
 
-        super().__init__(in_channels, out_channels, aggregators, scalers, deg,
+        super().__init__(in_channels, out_channels, aggregators, scalers, deg,
                          edge_dim, towers, divide_input, **kwargs)
 
         self.post_nns = ModuleList()
@@ -116,10 +113,11 @@ def __init__(self, in_channels: int, out_channels: int,
         for _ in range(post_layers - 1):
             x = self.F_out
             modules += [nn.LayerNorm([x])]
-            modules += [Transpose(-2, -1)]
+            modules += [x.Transpose(1, 2)]
             modules += [activation_resolver(act, **(act_kwargs or {}))]
             modules += [Linear(self.F_out, self.F_out)]
-            modules += [Transpose(-2, -1)]
+            modules += [x.Transpose(1, 2)]
+            modules += x
             modules += [nn.LayerNorm([self.F_out])]
             modules += [activation_resolver(act, **(act_kwargs or {}))]
             modules += [Linear(self.F_out, self.F_out)]
@@ -130,4 +128,4 @@ def __init__(self, in_channels: int, out_channels: int,
         self.reset_parameters()
 
     def message(self, x_j: Tensor) -> Tensor:
-        return x_j
\ No newline at end of file
+        return x_j
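Note that the loop body added above cannot execute as written: x is bound to the integer self.F_out, so x.Transpose(1, 2) raises AttributeError, and modules += x attempts to extend a Python list with an int (TypeError). Below is a minimal, self-contained sketch of one working post-layer block that reuses the Transpose helper this commit removes; the ReLU activation, F_out = 16, and the batch/token sizes are illustrative assumptions, and the token dimension must equal F_out for the transposed Linear to line up:

    import torch
    import torch.nn as nn


    class Transpose(nn.Module):
        """Wraps Tensor.transpose so a transpose can sit inside nn.Sequential."""
        def __init__(self, dim0: int, dim1: int):
            super().__init__()
            self.dim0 = dim0
            self.dim1 = dim1

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.transpose(self.dim0, self.dim1)


    F_out = 16  # illustrative; GMNConv uses self.F_out per tower
    block = nn.Sequential(
        nn.LayerNorm([F_out]),
        Transpose(-2, -1),    # mix across the token axis
        nn.ReLU(),            # stand-in for activation_resolver(act, ...)
        nn.Linear(F_out, F_out),
        Transpose(-2, -1),    # back to the feature axis
        nn.LayerNorm([F_out]),
        nn.ReLU(),
        nn.Linear(F_out, F_out),
    )
    out = block(torch.randn(2, F_out, F_out))  # token dim must equal F_out here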
