Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jan 16, 2024
1 parent 4d451a0 commit d245f87
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 23 deletions.
10 changes: 3 additions & 7 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,13 +482,9 @@ def test_my_default_arg_conv():
adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
assert conv(x, adj2.t()).view(-1).tolist() == [0, 0, 0, 0]


def test_my_default_arg_conv_jit():
conv = MyDefaultArgConv()

# This should not succeed in JIT mode.
with pytest.raises((RuntimeError, AttributeError)):
torch.jit.script(conv.jit)
jit = torch.jit.script(conv)
assert jit(x, edge_index).view(-1).tolist() == [0, 0, 0, 0]
assert jit(x, adj1.t()).view(-1).tolist() == [0, 0, 0, 0]


class MyMultipleOutputConv(MessagePassing):
Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/nn/conv/arma_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def forward(self, x: Tensor, edge_index: Adj,

return out.mean(dim=-3)

def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
assert edge_weight is not None
def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
return edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/nn/conv/cluster_gcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:

return out

def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
assert edge_weight is not None
def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
return edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
Expand Down
8 changes: 1 addition & 7 deletions torch_geometric/nn/conv/dna_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,7 @@ def forward(
# propagate_type: (x: Tensor, edge_weight: OptTensor)
return self.propagate(edge_index, x=x, edge_weight=edge_weight)

def message(
self,
x_i: Tensor,
x_j: Tensor,
edge_weight: OptTensor,
) -> Tensor:
assert edge_weight is not None
def message(self, x_i: Tensor, x_j: Tensor, edge_weight: Tensor) -> Tensor:
x_i = x_i[:, -1:] # [num_edges, 1, channels]
out = self.multi_head(x_i, x_j, x_j) # [num_edges, 1, channels]
return edge_weight.view(-1, 1) * out.squeeze(1)
Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/nn/conv/gin_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ def forward(

return self.nn(out)

def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
assert edge_attr is not None
def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
if self.lin is None and x_j.size(-1) != edge_attr.size(-1):
raise ValueError("Node and edge feature dimensionalities do not "
"match. Consider setting the 'edge_dim' "
Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/nn/conv/gmm_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,

return out

def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
assert edge_attr is not None
def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
EPS = 1e-15
F, M = self.rel_in_channels, self.out_channels
(E, D), K = edge_attr.size(), self.kernel_size
Expand Down
1 change: 0 additions & 1 deletion torch_geometric/nn/conv/simple_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
assert isinstance(self.aggr, str)
return spmm(adj_t, x[0], reduce=self.aggr)

0 comments on commit d245f87

Please sign in to comment.