diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py index 358fe8163f56..b0c03656aa87 100644 --- a/test/nn/conv/test_hetero_conv.py +++ b/test/nn/conv/test_hetero_conv.py @@ -204,4 +204,4 @@ def test_compile_hetero_conv_graph_breaks(device): out = compiled_conv(data.x_dict, data.edge_index_dict) assert len(out) == len(expected) for key in expected.keys(): - assert torch.allclose(out[key], expected[key]) + assert torch.allclose(out[key], expected[key], atol=1e-6) diff --git a/test/utils/test_trim_to_layer.py b/test/utils/test_trim_to_layer.py index c8296008451c..884381dd8435 100644 --- a/test/utils/test_trim_to_layer.py +++ b/test/utils/test_trim_to_layer.py @@ -196,7 +196,7 @@ def test_trim_to_layer_with_neighbor_loader(): batch.num_sampled_nodes, batch.num_sampled_edges)[:2] assert out2.size() == (2, 16) - assert torch.allclose(out1, out2) + assert torch.allclose(out1, out2, atol=1e-6) def test_trim_to_layer_filtering():