From d99233aec9ae440836dbb40100090f8d3cafb8b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Nov 2023 01:02:15 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 138 ++++++++++--------- pyg_lib/csrc/sampler/neighbor.cpp | 21 +-- pyg_lib/sampler/__init__.py | 5 +- test/csrc/sampler/test_neighbor.cpp | 16 +-- 4 files changed, 92 insertions(+), 88 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index d37fec9e7..1481b3aa1 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -75,24 +75,25 @@ class NeighborSampler { } void edge_temporal_sample(const node_t global_src_node, - const scalar_t local_src_node, - const temporal_t* edge_times, - const int64_t count, - const temporal_t seed_time, - pyg::sampler::Mapper& dst_mapper, - pyg::random::RandintEngine& generator, - std::vector& out_global_dst_nodes) { + const scalar_t local_src_node, + const temporal_t* edge_times, + const int64_t count, + const temporal_t seed_time, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { auto row_start = rowptr_[to_scalar_t(global_src_node)]; auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; if ((row_end - row_start == 0) || (count == 0)) return; // Find new `row_end` such that all neighbors fulfill temporal constraints: - std::vector v(row_end-row_start); - std::iota (std::begin(v), std::end(v), row_start); - auto it = std::upper_bound( - v.begin(), v.end(), seed_time, - [&](const scalar_t& a, const scalar_t& b) { return a < edge_times[b]; }); + std::vector v(row_end - row_start); + std::iota(std::begin(v), std::end(v), row_start); + auto it = std::upper_bound(v.begin(), v.end(), seed_time, + [&](const scalar_t& a, const scalar_t& b) { + return a < edge_times[b]; + }); row_end = it - v.begin() + row_start; if (temporal_strategy_ == "last" && count >= 0) { @@ -103,20 +104,18 @@ class NeighborSampler { "Found invalid non-sorted temporal neighborhood"); } - _sample(global_src_node, local_src_node, row_start, row_end, count, dst_mapper, generator, out_global_dst_nodes); } - - + void node_temporal_sample(const node_t global_src_node, - const scalar_t local_src_node, - const int64_t count, - const temporal_t seed_time, - const temporal_t* time, - pyg::sampler::Mapper& dst_mapper, - pyg::random::RandintEngine& generator, - std::vector& out_global_dst_nodes) { + const scalar_t local_src_node, + const int64_t count, + const temporal_t seed_time, + const temporal_t* time, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { auto row_start = rowptr_[to_scalar_t(global_src_node)]; auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; @@ -360,10 +359,12 @@ sample(const at::Tensor& rowptr, TORCH_CHECK(col.is_contiguous(), "Non-contiguous 'col'"); TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); if (node_time.has_value()) { - TORCH_CHECK(node_time.value().is_contiguous(), "Non-contiguous 'node_time'"); + TORCH_CHECK(node_time.value().is_contiguous(), + "Non-contiguous 'node_time'"); } if (edge_time.has_value()) { - TORCH_CHECK(edge_time.value().is_contiguous(), "Non-contiguous 'edge_time'"); + TORCH_CHECK(edge_time.value().is_contiguous(), + "Non-contiguous 'edge_time'"); } if (seed_time.has_value()) { TORCH_CHECK(seed_time.value().is_contiguous(), @@ -461,7 +462,7 @@ sample(const at::Tensor& rowptr, const auto batch_idx = sampled_nodes[i].first; sampler.edge_temporal_sample( /*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, + /*local_src_node=*/i, /*edge_time=*/edge_time_data, /*count=*/count, /*seed_time=*/seed_times[batch_idx], @@ -471,8 +472,7 @@ sample(const at::Tensor& rowptr, if constexpr (distributed) cumsum_neighbors_per_node.push_back(sampled_nodes.size()); } - } - else { + } else { const auto time_data = node_time.value().data_ptr(); for (size_t i = begin; i < end; ++i) { const auto batch_idx = sampled_nodes[i].first; @@ -690,7 +690,6 @@ sample(const std::vector& node_types, // Not supported exit(0); } - } num_sampled_nodes_per_hop_map.at(kv.key())[0] = @@ -738,8 +737,10 @@ sample(const std::vector& node_types, /*generator=*/generator, /*out_global_dst_nodes=*/dst_sampled_nodes); } - } else if ((!node_time_dict.has_value() || !node_time_dict.value().contains(dst)) && - (!edge_time_dict.has_value() || !edge_time_dict.value().contains(to_rel_type(k)))) { + } else if ((!node_time_dict.has_value() || + !node_time_dict.value().contains(dst)) && + (!edge_time_dict.has_value() || + !edge_time_dict.value().contains(to_rel_type(k)))) { for (size_t i = begin; i < end; ++i) { sampler.uniform_sample( /*global_src_node=*/src_sampled_nodes[i], @@ -749,41 +750,44 @@ sample(const std::vector& node_types, /*generator=*/generator, /*out_global_dst_nodes=*/dst_sampled_nodes); } - } else if constexpr (!std::is_scalar::value){ - if (edge_time_dict.has_value() && edge_time_dict.value().contains(to_rel_type(k))) { - //Edge temporal sampling - const at::Tensor& edge_times = edge_time_dict.value().at(to_rel_type(k)); - const auto edge_times_data = edge_times.data_ptr(); - for (size_t i = begin; i < end; ++i) { - const auto batch_idx = src_sampled_nodes[i].first; - sampler.edge_temporal_sample( - /*global_src_node=*/src_sampled_nodes[i], - /*local_src_node=*/i, - /*edge_times=*/edge_times_data, - /*count=*/count, - /*seed_time=*/seed_times[batch_idx], - /*dst_mapper=*/dst_mapper, - /*generator=*/generator, - /*out_global_dst_nodes=*/dst_sampled_nodes); - } - } else { - // Temporal sampling: - const at::Tensor& dst_time = node_time_dict.value().at(dst); - const auto dst_time_data = dst_time.data_ptr(); - for (size_t i = begin; i < end; ++i) { - const auto batch_idx = src_sampled_nodes[i].first; - sampler.node_temporal_sample( - /*global_src_node=*/src_sampled_nodes[i], - /*local_src_node=*/i, - /*count=*/count, - /*seed_time=*/seed_times[batch_idx], - /*time=*/dst_time_data, - /*dst_mapper=*/dst_mapper, - /*generator=*/generator, - /*out_global_dst_nodes=*/dst_sampled_nodes); + } else if constexpr (!std::is_scalar::value) { + if (edge_time_dict.has_value() && + edge_time_dict.value().contains(to_rel_type(k))) { + // Edge temporal sampling + const at::Tensor& edge_times = + edge_time_dict.value().at(to_rel_type(k)); + const auto edge_times_data = + edge_times.data_ptr(); + for (size_t i = begin; i < end; ++i) { + const auto batch_idx = src_sampled_nodes[i].first; + sampler.edge_temporal_sample( + /*global_src_node=*/src_sampled_nodes[i], + /*local_src_node=*/i, + /*edge_times=*/edge_times_data, + /*count=*/count, + /*seed_time=*/seed_times[batch_idx], + /*dst_mapper=*/dst_mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/dst_sampled_nodes); + } + } else { + // Temporal sampling: + const at::Tensor& dst_time = node_time_dict.value().at(dst); + const auto dst_time_data = dst_time.data_ptr(); + for (size_t i = begin; i < end; ++i) { + const auto batch_idx = src_sampled_nodes[i].first; + sampler.node_temporal_sample( + /*global_src_node=*/src_sampled_nodes[i], + /*local_src_node=*/i, + /*count=*/count, + /*seed_time=*/seed_times[batch_idx], + /*time=*/dst_time_data, + /*dst_mapper=*/dst_mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/dst_sampled_nodes); + } } } - } } } }); @@ -908,8 +912,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr, bool return_edge_id) { const auto out = [&] { DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, - seed, num_neighbors, node_time, seed_time, edge_time, edge_weight, csc, - temporal_strategy); + seed, num_neighbors, node_time, seed_time, edge_time, + edge_weight, csc, temporal_strategy); }(); return std::make_tuple(std::get<0>(out), std::get<1>(out), std::get<2>(out), std::get<3>(out), std::get<4>(out), std::get<5>(out)); @@ -960,8 +964,8 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, std::string temporal_strategy) { const auto out = [&] { DISPATCH_DIST_SAMPLE(replace, directed, disjoint, rowptr, col, seed, - {num_neighbors}, node_time, seed_time, edge_time, edge_weight, csc, - temporal_strategy); + {num_neighbors}, node_time, seed_time, edge_time, + edge_weight, csc, temporal_strategy); }(); return std::make_tuple(std::get<2>(out), std::get<3>(out).value(), std::get<6>(out)); diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index d8fa94355..4ebc693de 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -39,9 +39,9 @@ neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, node_time, seed_time, edge_time, edge_weight, - csc, replace, directed, disjoint, temporal_strategy, - return_edge_id); + return op.call(rowptr, col, seed, num_neighbors, node_time, seed_time, + edge_time, edge_weight, csc, replace, directed, disjoint, + temporal_strategy, return_edge_id); } std::tuple, @@ -92,8 +92,8 @@ hetero_neighbor_sample( .typed(); return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, node_time_dict, seed_time_dict, - edge_time_dict, edge_weight_dict, csc, replace, directed, disjoint, - temporal_strategy, return_edge_id); + edge_time_dict, edge_weight_dict, csc, replace, directed, + disjoint, temporal_strategy, return_edge_id); } std::tuple> dist_neighbor_sample( @@ -121,14 +121,16 @@ std::tuple> dist_neighbor_sample( static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::dist_neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, node_time, seed_time, edge_time, edge_weight, - csc, replace, directed, disjoint, temporal_strategy); + return op.call(rowptr, col, seed, num_neighbors, node_time, seed_time, + edge_time, edge_weight, csc, replace, directed, disjoint, + temporal_strategy); } TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? edge_time = None, Tensor? " + "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " + "edge_time = None, Tensor? " "edge_weight = None, bool csc = False, bool replace = False, bool " "directed = True, bool disjoint = False, str temporal_strategy = " "'uniform', bool return_edge_id = True) -> " @@ -146,7 +148,8 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? edge_time = None, Tensor? " + "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " + "edge_time = None, Tensor? " "edge_weight = None, bool csc = False, bool replace = False, bool " "directed = True, bool disjoint = False, str temporal_strategy = " "'uniform') -> (Tensor, Tensor, int[])")); diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index ba521bd4b..b171954e1 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -138,10 +138,7 @@ def hetero_neighbor_sample( for k, v in edge_weight_dict.items() } if edge_time_dict is not None: - edge_time_dict = { - TO_REL_TYPE[k]: v - for k, v in edge_time_dict.items() - } + edge_time_dict = {TO_REL_TYPE[k]: v for k, v in edge_time_dict.items()} out = torch.ops.pyg.hetero_neighbor_sample( node_types, diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 49baaab58..1603eee9d 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -236,14 +236,14 @@ TEST(BiasedNeighborTest, BasicAssertions) { /*edge_time=*/edge_weight); /*edge_weight=*/edge_weight); - auto expected_row = at::tensor({0, 1}, options); - EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); - auto expected_col = at::tensor({2, 0}, options); - EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); - auto expected_nodes = at::tensor({0, 1, 5}, options); - EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); - auto expected_edges = at::tensor({0, 2}, options); - EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); + auto expected_row = at::tensor({0, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); + auto expected_col = at::tensor({2, 0}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); + auto expected_nodes = at::tensor({0, 1, 5}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); + auto expected_edges = at::tensor({0, 2}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); } TEST(HeteroBiasedNeighborTest, BasicAssertions) {