Skip to content

Commit

Permalink
Fix LGConv test cases (#8847)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 31, 2024
1 parent c2e2a25 commit f869e3a
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test/nn/conv/test_lg_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@ def test_lg_conv():
assert str(conv) == 'LGConv()'
out1 = conv(x, edge_index)
assert out1.size() == (4, 8)
assert torch.allclose(conv(x, adj1.t()), out1)
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)

out2 = conv(x, edge_index, value)
assert out2.size() == (4, 8)
assert torch.allclose(conv(x, adj2.t()), out2)
assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))
assert torch.allclose(conv(x, adj3.t()), out1)
assert torch.allclose(conv(x, adj4.t()), out2)
assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)

if is_full_test():
jit = torch.jit.script(conv)
assert torch.allclose(jit(x, edge_index), out1)
assert torch.allclose(jit(x, edge_index, value), out2)
assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)
assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)

if torch_geometric.typing.WITH_TORCH_SPARSE:
assert torch.allclose(jit(x, adj3.t()), out1)
assert torch.allclose(jit(x, adj4.t()), out2)
assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)
assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)

0 comments on commit f869e3a

Please sign in to comment.